This commit is contained in:
2025-08-25 16:25:12 +08:00
parent b68431828c
commit 54f1365d24
3 changed files with 351 additions and 374 deletions

View File

@@ -14,22 +14,18 @@ import asyncio
import threading
from fastapi.responses import StreamingResponse
from fastapi import APIRouter, Request, HTTPException, BackgroundTasks
# 创建路由路由器
router = APIRouter(prefix="/api/mj", tags=["文生图"])
# 配置日志
logger = logging.getLogger(__name__)
# 任务状态存储
TASK_STATUS: Dict[str, Dict[str, Any]] = {}
class ImagineRequest(BaseModel):
prompt: str
base64_array: Optional[list] = None
notify_hook: Optional[str] = None
class TaskStatusResponse(BaseModel):
task_id: str
status: str
@@ -37,128 +33,74 @@ class TaskStatusResponse(BaseModel):
progress: Optional[int] = None
error: Optional[str] = None
@router.post("/imagine", response_model=Dict[str, str])
async def submit_imagine(request: ImagineRequest, background_tasks: BackgroundTasks):
async def submit_imagine(request: ImagineRequest):
"""
提交文生图请求
Args:
request: 包含提示词、垫图和回调地址的请求体
background_tasks: 用于异步处理任务状态查询的后台任务
Returns:
包含任务ID的字典
包含MJ服务器任务ID的字典
"""
try:
# 生成唯一任务ID
task_id = str(uuid.uuid4())
logger.info(f"收到文生图请求任务ID: {task_id}, 提示词: {request.prompt}")
# 初始化任务状态
TASK_STATUS[task_id] = {
"status": "pending",
"image_url": None,
"progress": 0,
"error": None
}
# 提交到Midjourney
# 直接提交到Midjourney并返回MJ任务ID
midjourney_task_id = Txt2Img.submit_imagine(
prompt=request.prompt,
base64_array=request.base64_array,
notify_hook=request.notify_hook
)
# 存储Midjourney任务ID
TASK_STATUS[task_id]["midjourney_task_id"] = midjourney_task_id
# 添加后台任务轮询状态
def poll_task_status_background(task_id, midjourney_task_id):
max_retries = 1000
retry_count = 0
retry_interval = 5 # 5秒
while retry_count < max_retries:
try:
# 查询任务状态
result = Txt2Img.query_task_status(midjourney_task_id)
# 更新任务状态
if result.get("status") == "SUCCESS":
TASK_STATUS[task_id] = {
"status": "completed",
"image_url": result.get("imageUrl"),
"progress": 100,
"error": None,
"midjourney_task_id": midjourney_task_id
}
logger.info(f"任务 {task_id} 完成图片URL: {result.get('imageUrl')}")
break
elif result.get("status") == "FAILED":
TASK_STATUS[task_id] = {
"status": "failed",
"image_url": None,
"progress": 0,
"error": result.get("errorMsg", "未知错误"),
"midjourney_task_id": midjourney_task_id
}
logger.error(f"任务 {task_id} 失败: {result.get('errorMsg', '未知错误')}")
break
else:
# 更新进度
progress = result.get("progress", 0)
TASK_STATUS[task_id]["progress"] = progress
TASK_STATUS[task_id]["status"] = "processing"
logger.info(f"任务 {task_id} 处理中,进度: {progress}%")
# 增加重试计数
retry_count += 1
# 等待重试间隔
time.sleep(retry_interval)
except Exception as e:
logger.error(f"轮询任务 {task_id} 状态失败: {str(e)}")
TASK_STATUS[task_id]["error"] = str(e)
time.sleep(retry_interval)
if retry_count >= max_retries:
logger.error(f"任务 {task_id} 超时")
TASK_STATUS[task_id] = {
"status": "failed",
"image_url": None,
"progress": 0,
"error": "任务处理超时",
"midjourney_task_id": midjourney_task_id
}
# 使用线程运行后台任务
thread = threading.Thread(target=poll_task_status_background, args=(task_id, midjourney_task_id))
thread.daemon = True
thread.start()
return {"task_id": task_id}
logger.info(f"提交文生图请求成功MJ任务ID: {midjourney_task_id}, 提示词: {request.prompt}")
return {"task_id": midjourney_task_id}
except Exception as e:
logger.error(f"提交文生图请求失败: {str(e)}")
raise HTTPException(status_code=500, detail=f"提交文生图请求失败: {str(e)}")
# 在文件顶部确保已定义任务状态缓存
task_status_cache = {}
@router.get("/task_status")
async def get_task_status(task_id: str):
if not task_id:
raise HTTPException(status_code=400, detail="task_id 参数缺失")
"""
直接查询MJ服务器获取任务状态
# 将所有 task_status_store 替换为 task_status_cache
if task_id not in task_status_cache:
raise HTTPException(status_code=404, detail="任务ID不存在")
Args:
task_id: MJ服务器返回的任务ID
return {
"task_id": task_id,
"status": task_status_cache[task_id]["status"],
"progress": task_status_cache[task_id]["progress"] ,
"result": task_status_cache[task_id]["result"],
"error": task_status_cache[task_id]["error"]
}
Returns:
包含任务状态的字典
"""
try:
# 直接查询MJ服务器获取实时状态
result = Txt2Img.query_task_status(task_id)
# 处理查询结果
if result.get("status") == "SUCCESS":
return {
"task_id": task_id,
"status": "completed",
"image_url": result.get("imageUrl"),
"progress": 100,
"error": None
}
elif result.get("status") == "FAILED":
return {
"task_id": task_id,
"status": "failed",
"image_url": None,
"progress": 0,
"error": result.get("errorMsg", "未知错误")
}
else:
return {
"task_id": task_id,
"status": "processing",
"progress": result.get("progress", 0),
"error": None
}
except Exception as e:
logger.error(f"查询任务状态失败: {str(e)}")
return {
"task_id": task_id,
"status": "error",
"message": f"查询任务状态失败: {str(e)}"
}, 500