diff --git a/dsLightRag/Routes/MjRoute.py b/dsLightRag/Routes/MjRoute.py index c6bd6142..a074af29 100644 --- a/dsLightRag/Routes/MjRoute.py +++ b/dsLightRag/Routes/MjRoute.py @@ -14,22 +14,18 @@ import asyncio import threading from fastapi.responses import StreamingResponse from fastapi import APIRouter, Request, HTTPException, BackgroundTasks + # 创建路由路由器 router = APIRouter(prefix="/api/mj", tags=["文生图"]) # 配置日志 logger = logging.getLogger(__name__) -# 任务状态存储 -TASK_STATUS: Dict[str, Dict[str, Any]] = {} - - class ImagineRequest(BaseModel): prompt: str base64_array: Optional[list] = None notify_hook: Optional[str] = None - class TaskStatusResponse(BaseModel): task_id: str status: str @@ -37,128 +33,74 @@ class TaskStatusResponse(BaseModel): progress: Optional[int] = None error: Optional[str] = None - @router.post("/imagine", response_model=Dict[str, str]) -async def submit_imagine(request: ImagineRequest, background_tasks: BackgroundTasks): +async def submit_imagine(request: ImagineRequest): """ 提交文生图请求 Args: request: 包含提示词、垫图和回调地址的请求体 - background_tasks: 用于异步处理任务状态查询的后台任务 Returns: - 包含任务ID的字典 + 包含MJ服务器任务ID的字典 """ try: - # 生成唯一任务ID - task_id = str(uuid.uuid4()) - logger.info(f"收到文生图请求,任务ID: {task_id}, 提示词: {request.prompt}") - - # 初始化任务状态 - TASK_STATUS[task_id] = { - "status": "pending", - "image_url": None, - "progress": 0, - "error": None - } - - # 提交到Midjourney + # 直接提交到Midjourney并返回MJ任务ID midjourney_task_id = Txt2Img.submit_imagine( prompt=request.prompt, base64_array=request.base64_array, notify_hook=request.notify_hook ) - - # 存储Midjourney任务ID - TASK_STATUS[task_id]["midjourney_task_id"] = midjourney_task_id - - # 添加后台任务轮询状态 - def poll_task_status_background(task_id, midjourney_task_id): - max_retries = 1000 - retry_count = 0 - retry_interval = 5 # 5秒 - - while retry_count < max_retries: - try: - # 查询任务状态 - result = Txt2Img.query_task_status(midjourney_task_id) - - # 更新任务状态 - if result.get("status") == "SUCCESS": - TASK_STATUS[task_id] = { - "status": "completed", - "image_url": result.get("imageUrl"), - "progress": 100, - "error": None, - "midjourney_task_id": midjourney_task_id - } - logger.info(f"任务 {task_id} 完成,图片URL: {result.get('imageUrl')}") - break - elif result.get("status") == "FAILED": - TASK_STATUS[task_id] = { - "status": "failed", - "image_url": None, - "progress": 0, - "error": result.get("errorMsg", "未知错误"), - "midjourney_task_id": midjourney_task_id - } - logger.error(f"任务 {task_id} 失败: {result.get('errorMsg', '未知错误')}") - break - else: - # 更新进度 - progress = result.get("progress", 0) - TASK_STATUS[task_id]["progress"] = progress - TASK_STATUS[task_id]["status"] = "processing" - logger.info(f"任务 {task_id} 处理中,进度: {progress}%") - - # 增加重试计数 - retry_count += 1 - - # 等待重试间隔 - time.sleep(retry_interval) - - except Exception as e: - logger.error(f"轮询任务 {task_id} 状态失败: {str(e)}") - TASK_STATUS[task_id]["error"] = str(e) - time.sleep(retry_interval) - - if retry_count >= max_retries: - logger.error(f"任务 {task_id} 超时") - TASK_STATUS[task_id] = { - "status": "failed", - "image_url": None, - "progress": 0, - "error": "任务处理超时", - "midjourney_task_id": midjourney_task_id - } - - # 使用线程运行后台任务 - thread = threading.Thread(target=poll_task_status_background, args=(task_id, midjourney_task_id)) - thread.daemon = True - thread.start() - - return {"task_id": task_id} + + logger.info(f"提交文生图请求成功,MJ任务ID: {midjourney_task_id}, 提示词: {request.prompt}") + return {"task_id": midjourney_task_id} except Exception as e: logger.error(f"提交文生图请求失败: {str(e)}") raise HTTPException(status_code=500, detail=f"提交文生图请求失败: {str(e)}") - -# 在文件顶部确保已定义任务状态缓存 -task_status_cache = {} @router.get("/task_status") async def get_task_status(task_id: str): - if not task_id: - raise HTTPException(status_code=400, detail="task_id 参数缺失") + """ + 直接查询MJ服务器获取任务状态 - # 将所有 task_status_store 替换为 task_status_cache - if task_id not in task_status_cache: - raise HTTPException(status_code=404, detail="任务ID不存在") + Args: + task_id: MJ服务器返回的任务ID - return { - "task_id": task_id, - "status": task_status_cache[task_id]["status"], - "progress": task_status_cache[task_id]["progress"] , - "result": task_status_cache[task_id]["result"], - "error": task_status_cache[task_id]["error"] - } + Returns: + 包含任务状态的字典 + """ + try: + # 直接查询MJ服务器获取实时状态 + result = Txt2Img.query_task_status(task_id) + + # 处理查询结果 + if result.get("status") == "SUCCESS": + return { + "task_id": task_id, + "status": "completed", + "image_url": result.get("imageUrl"), + "progress": 100, + "error": None + } + elif result.get("status") == "FAILED": + return { + "task_id": task_id, + "status": "failed", + "image_url": None, + "progress": 0, + "error": result.get("errorMsg", "未知错误") + } + else: + return { + "task_id": task_id, + "status": "processing", + "progress": result.get("progress", 0), + "error": None + } + except Exception as e: + logger.error(f"查询任务状态失败: {str(e)}") + return { + "task_id": task_id, + "status": "error", + "message": f"查询任务状态失败: {str(e)}" + }, 500 diff --git a/dsLightRag/Routes/__pycache__/MjRoute.cpython-310.pyc b/dsLightRag/Routes/__pycache__/MjRoute.cpython-310.pyc index 53050568..4ffa347a 100644 Binary files a/dsLightRag/Routes/__pycache__/MjRoute.cpython-310.pyc and b/dsLightRag/Routes/__pycache__/MjRoute.cpython-310.pyc differ diff --git a/dsLightRag/static/Midjourney/mj.html b/dsLightRag/static/Midjourney/mj.html index d984e6ed..c18b4e6d 100644 --- a/dsLightRag/static/Midjourney/mj.html +++ b/dsLightRag/static/Midjourney/mj.html @@ -57,6 +57,23 @@ +
@@ -275,36 +292,20 @@ 生成图像
- - - - +