Files
dsProject/dsLightRag/Routes/MjRoute.py

168 lines
5.7 KiB
Python
Raw Normal View History

2025-08-25 14:22:53 +08:00
import logging
import json
import time
import datetime
import requests
import uuid
import os
from typing import Optional, Dict, Any
from fastapi import APIRouter, HTTPException, Query, BackgroundTasks
from pydantic import BaseModel
from Midjourney.Txt2Img import Txt2Img
from Config import Config
import asyncio
2025-08-25 14:40:05 +08:00
import threading
2025-08-25 14:22:53 +08:00
from fastapi.responses import StreamingResponse
# Module-level logger for this route module.
logger = logging.getLogger(__name__)

# Router for the text-to-image endpoints; every path below is mounted under /api/mj.
router = APIRouter(
    prefix="/api/mj",
    tags=["文生图"],
)

# In-process task registry keyed by our own task ID (a UUID4 string).
# Written by the polling threads and read by the status endpoint; nothing in
# this module ever removes completed entries.
TASK_STATUS: Dict[str, Dict[str, Any]] = dict()
class ImagineRequest(BaseModel):
    """Request body for POST /api/mj/imagine.

    All three fields are passed straight through to ``Txt2Img.submit_imagine``.
    """
    # Text prompt for image generation.
    prompt: str
    # Optional reference images ("垫图"); presumably base64-encoded strings
    # judging by the name — TODO confirm the element format expected by Txt2Img.
    base64_array: Optional[list] = None
    # Optional callback URL forwarded as notify_hook; how it is used is decided
    # by Txt2Img (not visible in this module).
    notify_hook: Optional[str] = None
class TaskStatusResponse(BaseModel):
    """Response body for GET /api/mj/task_status."""
    # Our task ID (the UUID returned by /imagine), not the Midjourney task ID.
    task_id: str
    # One of "pending", "processing", "completed" or "failed" as written by this module.
    status: str
    # URL of the generated image; only set once the task has completed.
    image_url: Optional[str] = None
    # Progress value: 0 initially, 100 on success, otherwise as reported while polling.
    progress: Optional[int] = None
    # Failure reason, timeout message, or the last polling error; None otherwise.
    error: Optional[str] = None
def _parse_progress(raw: Any) -> int:
    """Normalize a progress value reported while polling to an int percentage.

    ``TaskStatusResponse.progress`` is declared ``Optional[int]``; a value that
    pydantic cannot coerce to int (for example "45%") would make the status
    endpoint fail. Accepts ints, floats, numeric strings and percent strings.
    NOTE(review): confirm which form Txt2Img.query_task_status actually returns.
    Missing or unparseable values map to 0.
    """
    if raw is None:
        return 0
    text = str(raw).strip().rstrip("%").strip()
    try:
        return int(float(text))
    except (ValueError, OverflowError):
        return 0


def _poll_task_status_background(
    task_id: str,
    midjourney_task_id: Any,
    max_retries: int = 1000,
    retry_interval: float = 5,
) -> None:
    """Poll Midjourney until the task finishes and mirror the result into TASK_STATUS.

    Runs in a daemon thread started by ``submit_imagine``. Every attempt,
    including attempts where the query raises, counts toward ``max_retries``,
    so the thread always terminates (after roughly max_retries * retry_interval
    seconds, about 83 minutes with the defaults) and then marks the task failed.

    Args:
        task_id: Our task ID (key into TASK_STATUS).
        midjourney_task_id: ID returned by Txt2Img.submit_imagine.
        max_retries: Maximum number of polling attempts before timing out.
        retry_interval: Seconds to sleep between attempts.
    """
    for _attempt in range(max_retries):
        try:
            result = Txt2Img.query_task_status(midjourney_task_id)
            status = result.get("status")
            if status == "SUCCESS":
                TASK_STATUS[task_id] = {
                    "status": "completed",
                    "image_url": result.get("imageUrl"),
                    "progress": 100,
                    "error": None,
                    "midjourney_task_id": midjourney_task_id,
                }
                logger.info(f"任务 {task_id} 完成图片URL: {result.get('imageUrl')}")
                return
            # NOTE(review): confirm Txt2Img reports failures as "FAILED"; any other
            # spelling falls through to the "processing" branch until timeout.
            if status == "FAILED":
                error_msg = result.get("errorMsg", "未知错误")
                TASK_STATUS[task_id] = {
                    "status": "failed",
                    "image_url": None,
                    "progress": 0,
                    "error": error_msg,
                    "midjourney_task_id": midjourney_task_id,
                }
                logger.error(f"任务 {task_id} 失败: {error_msg}")
                return
            progress = _parse_progress(result.get("progress", 0))
            # Replace the entry in one assignment so a concurrent reader never
            # sees the new progress paired with the old status.
            TASK_STATUS[task_id] = {
                **TASK_STATUS[task_id],
                "status": "processing",
                "progress": progress,
            }
            logger.info(f"任务 {task_id} 处理中,进度: {progress}%")
        except Exception as e:
            # Transient query failures are recorded and retried; they still
            # consume an attempt so a permanently failing query cannot loop forever.
            logger.exception(f"轮询任务 {task_id} 状态失败: {str(e)}")
            TASK_STATUS[task_id]["error"] = str(e)
        time.sleep(retry_interval)

    logger.error(f"任务 {task_id} 超时")
    TASK_STATUS[task_id] = {
        "status": "failed",
        "image_url": None,
        "progress": 0,
        "error": "任务处理超时",
        "midjourney_task_id": midjourney_task_id,
    }


@router.post("/imagine", response_model=Dict[str, str])
async def submit_imagine(request: ImagineRequest, background_tasks: BackgroundTasks):
    """Submit a text-to-image request to Midjourney.

    Registers a task in TASK_STATUS, submits the prompt via Txt2Img, and starts
    a daemon thread that polls for completion.

    Args:
        request: Body with the prompt, optional reference images and callback URL.
        background_tasks: Injected by FastAPI; unused because polling runs in a
            dedicated thread. Kept so the endpoint signature stays unchanged.

    Returns:
        ``{"task_id": <our task ID>}`` for use with GET /api/mj/task_status.

    Raises:
        HTTPException: 500 if submission or starting the poller fails.
    """
    task_id = str(uuid.uuid4())
    try:
        logger.info(f"收到文生图请求任务ID: {task_id}, 提示词: {request.prompt}")
        TASK_STATUS[task_id] = {
            "status": "pending",
            "image_url": None,
            "progress": 0,
            "error": None,
        }
        # Txt2Img.submit_imagine is synchronous (its return value is used
        # directly as an ID); run it in the default executor so this request
        # does not block the event loop (presumably it does network I/O).
        loop = asyncio.get_running_loop()
        midjourney_task_id = await loop.run_in_executor(
            None,
            lambda: Txt2Img.submit_imagine(
                prompt=request.prompt,
                base64_array=request.base64_array,
                notify_hook=request.notify_hook,
            ),
        )
        TASK_STATUS[task_id]["midjourney_task_id"] = midjourney_task_id

        poller = threading.Thread(
            target=_poll_task_status_background,
            args=(task_id, midjourney_task_id),
            daemon=True,
        )
        poller.start()
        return {"task_id": task_id}
    except Exception as e:
        # The client never receives task_id on failure, so the entry would be
        # unreachable and stuck in "pending" forever; drop it.
        TASK_STATUS.pop(task_id, None)
        logger.exception(f"提交文生图请求失败: {str(e)}")
        raise HTTPException(status_code=500, detail=f"提交文生图请求失败: {str(e)}") from e
@router.get("/task_status", response_model=TaskStatusResponse)
async def get_task_status(task_id: str = Query(..., description="任务ID")):
    """Return the current status of a text-to-image task.

    Args:
        task_id: Our task ID as returned by POST /api/mj/imagine.

    Returns:
        TaskStatusResponse built from the task's entry in TASK_STATUS.

    Raises:
        HTTPException: 404 if the task ID is unknown.
    """
    # Take one snapshot of the entry: the polling thread may replace the whole
    # dict at any moment, and separate TASK_STATUS[task_id] lookups could mix
    # fields from the old and new versions.
    task = TASK_STATUS.get(task_id)
    if task is None:
        raise HTTPException(status_code=404, detail="任务ID不存在")
    return TaskStatusResponse(
        task_id=task_id,
        status=task["status"],
        image_url=task.get("image_url"),
        progress=task.get("progress"),
        error=task.get("error"),
    )