This commit is contained in:
2025-08-25 16:25:12 +08:00
parent b68431828c
commit 54f1365d24
3 changed files with 351 additions and 374 deletions

View File

@@ -14,22 +14,18 @@ import asyncio
import threading
from fastapi.responses import StreamingResponse
from fastapi import APIRouter, Request, HTTPException, BackgroundTasks
# 创建路由路由器
router = APIRouter(prefix="/api/mj", tags=["文生图"])
# 配置日志
logger = logging.getLogger(__name__)
# 任务状态存储
TASK_STATUS: Dict[str, Dict[str, Any]] = {}
class ImagineRequest(BaseModel):
prompt: str
base64_array: Optional[list] = None
notify_hook: Optional[str] = None
class TaskStatusResponse(BaseModel):
task_id: str
status: str
@@ -37,128 +33,74 @@ class TaskStatusResponse(BaseModel):
progress: Optional[int] = None
error: Optional[str] = None
@router.post("/imagine", response_model=Dict[str, str])
async def submit_imagine(request: ImagineRequest, background_tasks: BackgroundTasks):
async def submit_imagine(request: ImagineRequest):
"""
提交文生图请求
Args:
request: 包含提示词、垫图和回调地址的请求体
background_tasks: 用于异步处理任务状态查询的后台任务
Returns:
包含任务ID的字典
包含MJ服务器任务ID的字典
"""
try:
# 生成唯一任务ID
task_id = str(uuid.uuid4())
logger.info(f"收到文生图请求任务ID: {task_id}, 提示词: {request.prompt}")
# 初始化任务状态
TASK_STATUS[task_id] = {
"status": "pending",
"image_url": None,
"progress": 0,
"error": None
}
# 提交到Midjourney
# 直接提交到Midjourney并返回MJ任务ID
midjourney_task_id = Txt2Img.submit_imagine(
prompt=request.prompt,
base64_array=request.base64_array,
notify_hook=request.notify_hook
)
# 存储Midjourney任务ID
TASK_STATUS[task_id]["midjourney_task_id"] = midjourney_task_id
# 添加后台任务轮询状态
def poll_task_status_background(task_id, midjourney_task_id):
max_retries = 1000
retry_count = 0
retry_interval = 5 # 5秒
while retry_count < max_retries:
try:
# 查询任务状态
result = Txt2Img.query_task_status(midjourney_task_id)
# 更新任务状态
if result.get("status") == "SUCCESS":
TASK_STATUS[task_id] = {
"status": "completed",
"image_url": result.get("imageUrl"),
"progress": 100,
"error": None,
"midjourney_task_id": midjourney_task_id
}
logger.info(f"任务 {task_id} 完成图片URL: {result.get('imageUrl')}")
break
elif result.get("status") == "FAILED":
TASK_STATUS[task_id] = {
"status": "failed",
"image_url": None,
"progress": 0,
"error": result.get("errorMsg", "未知错误"),
"midjourney_task_id": midjourney_task_id
}
logger.error(f"任务 {task_id} 失败: {result.get('errorMsg', '未知错误')}")
break
else:
# 更新进度
progress = result.get("progress", 0)
TASK_STATUS[task_id]["progress"] = progress
TASK_STATUS[task_id]["status"] = "processing"
logger.info(f"任务 {task_id} 处理中,进度: {progress}%")
# 增加重试计数
retry_count += 1
# 等待重试间隔
time.sleep(retry_interval)
except Exception as e:
logger.error(f"轮询任务 {task_id} 状态失败: {str(e)}")
TASK_STATUS[task_id]["error"] = str(e)
time.sleep(retry_interval)
if retry_count >= max_retries:
logger.error(f"任务 {task_id} 超时")
TASK_STATUS[task_id] = {
"status": "failed",
"image_url": None,
"progress": 0,
"error": "任务处理超时",
"midjourney_task_id": midjourney_task_id
}
# 使用线程运行后台任务
thread = threading.Thread(target=poll_task_status_background, args=(task_id, midjourney_task_id))
thread.daemon = True
thread.start()
return {"task_id": task_id}
logger.info(f"提交文生图请求成功MJ任务ID: {midjourney_task_id}, 提示词: {request.prompt}")
return {"task_id": midjourney_task_id}
except Exception as e:
logger.error(f"提交文生图请求失败: {str(e)}")
raise HTTPException(status_code=500, detail=f"提交文生图请求失败: {str(e)}")
# 在文件顶部确保已定义任务状态缓存
task_status_cache = {}
@router.get("/task_status")
async def get_task_status(task_id: str):
if not task_id:
raise HTTPException(status_code=400, detail="task_id 参数缺失")
"""
直接查询MJ服务器获取任务状态
# 将所有 task_status_store 替换为 task_status_cache
if task_id not in task_status_cache:
raise HTTPException(status_code=404, detail="任务ID不存在")
Args:
task_id: MJ服务器返回的任务ID
Returns:
包含任务状态的字典
"""
try:
# 直接查询MJ服务器获取实时状态
result = Txt2Img.query_task_status(task_id)
# 处理查询结果
if result.get("status") == "SUCCESS":
return {
"task_id": task_id,
"status": task_status_cache[task_id]["status"],
"progress": task_status_cache[task_id]["progress"] ,
"result": task_status_cache[task_id]["result"],
"error": task_status_cache[task_id]["error"]
"status": "completed",
"image_url": result.get("imageUrl"),
"progress": 100,
"error": None
}
elif result.get("status") == "FAILED":
return {
"task_id": task_id,
"status": "failed",
"image_url": None,
"progress": 0,
"error": result.get("errorMsg", "未知错误")
}
else:
return {
"task_id": task_id,
"status": "processing",
"progress": result.get("progress", 0),
"error": None
}
except Exception as e:
logger.error(f"查询任务状态失败: {str(e)}")
return {
"task_id": task_id,
"status": "error",
"message": f"查询任务状态失败: {str(e)}"
}, 500

View File

@@ -57,6 +57,23 @@
</style>
</head>
<body class="min-h-screen bg-gradient-to-br from-light to-gray-100 dark:from-dark dark:to-gray-900 text-gray-800 dark:text-gray-100 transition-colors duration-300">
<script>
function showError(message) {
const errorElement = document.getElementById('error-message');
if (errorElement) {
errorElement.textContent = message;
errorElement.classList.remove('hidden');
// Show error section if it's hidden
const statusSection = document.getElementById('statusSection');
if (statusSection) statusSection.classList.remove('hidden');
// Auto-hide after 5 seconds
setTimeout(() => errorElement.classList.add('hidden'), 5000);
} else {
console.error('Error: showError - error-message element not found');
alert(message);
}
}
</script>
<!-- 导航栏 -->
<header class="sticky top-0 z-50 glass-effect dark:glass-effect-dark border-b border-gray-200 dark:border-gray-700 shadow-sm">
<div class="container mx-auto px-4 py-3 flex justify-between items-center">
@@ -275,36 +292,20 @@
<i class="fa fa-magic mr-2"></i>生成图像
</button>
</div>
</section>
<!-- 生成状态区域 -->
<section id="statusSection" class="mb-16 hidden">
<div class="bg-white dark:bg-gray-800 rounded-2xl shadow-lg p-6 md:p-8">
<div class="flex flex-col items-center text-center space-y-6">
<div id="thinkingIndicator" class="flex flex-col items-center space-y-4">
<div class="w-16 h-16 rounded-full bg-primary/10 flex items-center justify-center animate-pulse-slow">
<i class="fa fa-cog fa-spin text-2xl text-primary"></i>
</div>
<h3 class="text-xl font-semibold">AI 正在创作中...</h3>
<p class="text-gray-600 dark:text-gray-300 max-w-md">
请稍候,我们的 AI 正在努力为您生成精美的图像
</p>
</div>
<div id="progressBarContainer"
class="w-full max-w-2xl bg-gray-200 dark:bg-gray-700 rounded-full h-3 overflow-hidden">
<div id="progressBar"
class="bg-gradient-to-r from-primary to-secondary h-full w-0 transition-all duration-300"></div>
</div>
<div class="text-sm text-gray-500 dark:text-gray-400">
<span id="progressText">0%</span> 完成
</div>
<!-- 新增状态显示区域 -->
<div id="statusSection" class="hidden p-6 bg-gray-50 dark:bg-gray-900 rounded-b-lg">
<div class="progress-container mb-4">
<div id="progress-bar" class="progress-bar h-2 bg-primary rounded-full transition-all duration-300"
style="width: 0%"></div>
<div id="progress-text" class="text-right text-sm mt-1 text-gray-600 dark:text-gray-300">0%</div>
</div>
<div id="status-display" class="text-gray-700 dark:text-gray-300 mb-2"></div>
<div id="error-message" class="error-message text-red-500 hidden"></div>
</div>
</section>
<!-- 结果展示区域 -->
<!-- 生成结果区域 -->
<section id="resultSection" class="mb-16 hidden">
<div class="bg-white dark:bg-gray-800 rounded-2xl shadow-lg overflow-hidden">
<div class="p-6 md:p-8 border-b border-gray-200 dark:border-gray-700">
@@ -429,51 +430,6 @@
</div>
</footer>
<script>
// ==============================================
// 全局函数 - 必须放在所有事件监听和调用之前
// ==============================================
function checkTaskStatus(taskId) {
if (!taskId) {
console.error('任务ID不存在');
return;
}
fetch(`/api/mj/task_status?task_id=${encodeURIComponent(taskId)}`)
.then(response => {
if (!response.ok) throw new Error('网络响应不正常');
return response.json();
})
.then(data => {
const progressBar = document.getElementById('progress-bar');
const progressText = document.getElementById('progress-text');
const statusDisplay = document.getElementById('status-display');
if (progressBar && progressText) {
const progress = data.progress || 0;
progressBar.style.width = `${progress}%`;
progressText.textContent = `${progress}%`;
}
if (statusDisplay) {
statusDisplay.textContent = `状态: ${data.status || '未知'}`;
}
if (data.status === 'completed') {
stopPolling();
displayResult(data.result);
addToHistory(data);
} else if (data.status === 'failed') {
stopPolling();
showError(data.error || '生成失败,请重试');
}
})
.catch(error => {
console.error('获取任务状态时出错:', error);
stopPolling();
showError('获取状态失败: ' + error.message);
});
}
document.addEventListener('DOMContentLoaded', function () {
// 主题切换功能
const themeToggle = document.getElementById('themeToggle');
@@ -500,6 +456,7 @@
const textToImagePanel = document.getElementById('textToImagePanel');
const imageToImagePanel = document.getElementById('imageToImagePanel');
textToImageTab.addEventListener('click', function () {
textToImageTab.classList.add('text-primary', 'border-primary');
textToImageTab.classList.remove('text-gray-500', 'border-transparent');
@@ -517,7 +474,7 @@
imageToImagePanel.classList.remove('hidden');
textToImagePanel.classList.add('hidden');
});
});
// 图生图相关功能
const browseImage = document.getElementById('browseImage');
const imageUpload = document.getElementById('imageUpload');
@@ -655,53 +612,105 @@
return;
}
// 提交生成请求
submitGenerateRequest(prompt, base64Array);
// 获取表单元素并进行空值检查
const promptElement = document.getElementById('prompt');
const aspectRatioElement = document.getElementById('aspectRatio');
const qualityElement = document.getElementById('quality');
const styleElement = document.getElementById('style');
// 检查是否所有必要元素都存在
if (!promptElement || !aspectRatioElement || !qualityElement || !styleElement) {
showError('表单元素缺失,请检查页面配置');
return;
}
// 创建FormData并添加值
const formData = new FormData();
formData.append('prompt', promptElement.value);
formData.append('aspect_ratio', aspectRatioElement.value);
formData.append('quality', qualityElement.value);
formData.append('style', styleElement.value);
// 调用提交函数
submitGenerateRequest(formData);
}
// 提交生成请求函数
async function submitGenerateRequest(prompt, base64Array) {
async function submitGenerateRequest(formData) {
// 验证formData是有效的FormData对象
if (!(formData instanceof FormData)) {
showError('无效的表单数据格式');
return;
}
// 提取表单数据
const prompt = formData.get('prompt');
const aspect_ratio = formData.get('aspect_ratio');
const quality = formData.get('quality');
const style = formData.get('style');
// 验证必填字段
if (!prompt) {
showError('请输入提示词');
return;
}
// 从formData提取值确保表单元素有正确的name属性
const promptText = formData.get('prompt')?.trim();
const selectedAspectRatio = formData.get('aspect_ratio');
const selectedQuality = formData.get('quality');
const selectedStyle = formData.get('style');
// 验证必要参数
if (!promptText) {
showError('提示词不能为空');
return;
}
if (!selectedAspectRatio || !selectedQuality || !selectedStyle) {
showError('请选择完整的生成参数');
return;
}
try {
// 禁用生成按钮
generateBtn.disabled = true;
generateBtn.innerHTML = '<i class="fa fa-spinner fa-spin mr-2"></i>正在提交...';
// 准备请求数据
const requestData = {
prompt: prompt,
base64_array: base64Array
};
// 发送请求
const response = await fetch('/api/mj/imagine', {
method: 'POST',
headers: {
'Content-Type': 'application/json'
'Content-Type': 'application/json',
},
body: JSON.stringify(requestData)
body: JSON.stringify({
prompt: promptText,
aspect_ratio: selectedAspectRatio,
quality: selectedQuality,
style: selectedStyle
})
});
if (!response.ok) {
throw new Error(`提交请求失败: ${response.statusText}`);
const errorData = await response.json().catch(() => null);
throw new Error(`生成请求失败: ${errorData?.message || response.statusText}`);
}
const data = await response.json();
const taskId = data.task_id;
const result = await response.json();
console.log('Backend response:', result); // Add this line
// 显示状态区域
if (!result || !result.task_id) {
showError('生成失败: 未获取到任务ID');
return null;
}
// 显示状态区域并开始轮询
document.getElementById('statusSection').classList.remove('hidden');
// 开始轮询任务状态
pollTaskStatus(taskId, prompt);
pollTaskStatus(result.task_id, promptText);
return result.task_id;
} catch (error) {
console.error('生成图像失败:', error);
alert(`生成图像失败: ${error.message}`);
// 恢复生成按钮
generateBtn.disabled = false;
generateBtn.innerHTML = '<i class="fa fa-magic mr-2"></i>生成图像';
console.error('提交生成请求失败:', error);
showError('提交请求失败: ' + error.message);
return null;
}
}
// 轮询任务状态函数
// 轮询任务状态函数
function pollTaskStatus(taskId, prompt) {
// 轮询配置
@@ -711,25 +720,35 @@
const startTime = new Date().getTime(); // 开始时间
let timerId = null; // 用于存储setTimeout的ID
// 定义内部轮询函数
function checkStatus() {
// 检查是否超过最大查询次数
if (queryCount >= maxQueries) {
const elapsedTime = Math.round((new Date().getTime() - startTime) / 1000);
document.getElementById('statusMessage').innerHTML = `<i class="fa fa-exclamation-circle"></i> 图像生成超时(已等待${elapsedTime}秒),您可以稍后在历史记录中查看,或重新生成。`;
document.getElementById('progressText').textContent = '生成超时';
document.getElementById('status-display').innerHTML = `<i class="fa fa-exclamation-circle"></i> 图像生成超时(已等待${elapsedTime}秒),您可以稍后在历史记录中查看,或重新生成。`;
document.getElementById('progress-text').textContent = '生成超时';
// 恢复生成按钮
generateBtn.disabled = false;
generateBtn.innerHTML = '<i class="fa fa-magic mr-2"></i>生成图像';
return;
}
// 发起状态查询请求
fetch(`/api/mj/task_status?task_id=${taskId}`)
.then(response => response.json())
.then(response => {
if (!response.ok) {
throw new Error(`HTTP error! status: ${response.status}`);
}
return response.json();
})
.then(data => {
queryCount++;
// 更新状态信息
const statusMessage = document.getElementById('statusMessage');
const progressFill = document.getElementById('progressBar');
const progressText = document.getElementById('progressText');
// 获取DOM元素
const statusMessage = document.getElementById('status-display');
const progressFill = document.getElementById('progress-bar');
const progressText = document.getElementById('progress-text');
// 处理错误状态
if (data.error) {
statusMessage.innerHTML = `<i class="fa fa-exclamation-circle"></i> 生成失败: ${data.error}`;
statusMessage.style.color = '#ef4444';
@@ -739,13 +758,13 @@
return;
}
// 更新状态信息
statusMessage.innerHTML = `<i class="fa fa-spinner fa-spin"></i> AI 正在创作中...`;
statusMessage.style.color = ''; // 重置颜色
// 更新进度条
let progress = data.progress || 0;
if (progress > 100) progress = 100;
// 计算已等待时间
const elapsedTime = Math.round((new Date().getTime() - startTime) / 1000);
progressFill.style.width = `${progress}%`;
progressText.textContent = `${Math.round(progress)}%`;
@@ -757,9 +776,8 @@
statusMessage.style.color = '#10b981';
// 确保进度为100%
progress = 100;
progressFill.style.width = `${progress}%`;
progressText.textContent = `${Math.round(progress)}%`;
progressFill.style.width = '100%';
progressText.textContent = '100%';
// 显示结果
setTimeout(() => {
@@ -771,23 +789,29 @@
clearTimeout(timerId);
timerId = null;
}
} else if (data.status === 'failed') {
}
// 任务失败
else if (data.status === 'failed') {
statusMessage.innerHTML = `<i class="fa fa-exclamation-circle"></i> 生成失败: ${data.error || '未知错误'}`;
statusMessage.style.color = '#ef4444';
// 恢复生成按钮
generateBtn.disabled = false;
generateBtn.innerHTML = '<i class="fa fa-magic mr-2"></i>生成图像';
} else {
// 继续
timerId = setTimeout(checkTaskStatus, queryInterval);
}
// 继续
else {
timerId = setTimeout(checkStatus, queryInterval);
}
})
.catch(error => {
console.error('检查任务状态失败:', error);
// 只有在任务未完成且未超时时才继续查询
// 在错误情况下仍继续轮询直到达到最大次数
if (queryCount < maxQueries) {
timerId = setTimeout(checkTaskStatus, queryInterval);
queryCount++;
timerId = setTimeout(checkStatus, queryInterval);
} else {
document.getElementById('status-display').innerHTML = `<i class="fa fa-exclamation-circle"></i> 连接服务器失败,请重试`;
document.getElementById('status-display').style.color = '#ef4444';
// 恢复生成按钮
generateBtn.disabled = false;
generateBtn.innerHTML = '<i class="fa fa-magic mr-2"></i>生成图像';
@@ -795,9 +819,9 @@
});
}
// 立即检查一次
checkTaskStatus();
})
// 立即开始第一次检查
checkStatus();
}
// 显示结果函数
function showResult(imageUrl, taskId, prompt) {
@@ -895,10 +919,21 @@
// 初始化历史记录网格
updateHistoryGrid();
}
</script>
<!-- 添加随机示例功能的JavaScript -->
<script>
// 添加 showError function after checkTaskStatus
function showError(message) {
const errorElement = document.getElementById('error-message');
if (errorElement) {
errorElement.textContent = message;
errorElement.classList.remove('hidden');
// Auto-hide after 5 seconds
setTimeout(() => errorElement.classList.add('hidden'), 5000);
} else {
// Fallback if error element not found
alert(message);
}
}
// 文生图示例提示词库
const textPromptExamples = [
"一只穿着太空服的柯基犬在火星表面行走,背景是红色星球和远处的地球,科幻风格,高清细节",