168 lines
5.7 KiB
Python
168 lines
5.7 KiB
Python
import logging
|
||
import json
|
||
import time
|
||
import datetime
|
||
import requests
|
||
import uuid
|
||
import os
|
||
from typing import Optional, Dict, Any
|
||
from fastapi import APIRouter, HTTPException, Query, BackgroundTasks
|
||
from pydantic import BaseModel
|
||
from Midjourney.Txt2Img import Txt2Img
|
||
from Config import Config
|
||
import asyncio
|
||
import threading
|
||
from fastapi.responses import StreamingResponse
|
||
|
||
# Module-level logger for this router.
logger = logging.getLogger(__name__)

# Router for the text-to-image ("文生图") endpoints, mounted under /api/mj.
router = APIRouter(prefix="/api/mj", tags=["文生图"])

# In-memory task registry: local task ID -> status record
# (keys: status, image_url, progress, error, and midjourney_task_id once known).
# NOTE(review): nothing visible here evicts finished entries, so this appears to
# grow for the life of the process; it is also per-process and lost on restart.
TASK_STATUS: Dict[str, Dict[str, Any]] = dict()
|
||
|
||
|
||
class ImagineRequest(BaseModel):
    """Request body for submitting a text-to-image job.

    All three fields are forwarded as-is to ``Txt2Img.submit_imagine``.
    """

    # Text prompt describing the image to generate.
    prompt: str
    # Optional reference ("垫图") images. Given the name, presumably a list of
    # base64-encoded strings — element type is not validated here; confirm with Txt2Img.
    base64_array: Optional[list] = None
    # Optional callback URL handed to the upstream service; its semantics are
    # defined by Txt2Img, not by this module.
    notify_hook: Optional[str] = None
|
||
|
||
|
||
class TaskStatusResponse(BaseModel):
    """Response model describing the current state of a text-to-image task."""

    # Local task ID issued by the submit endpoint (not the Midjourney task ID).
    task_id: str
    # Status string written by this module: "pending", "processing",
    # "completed" or "failed".
    status: str
    # Result image URL; only populated once the task has completed.
    image_url: Optional[str] = None
    # Progress percentage: 0 while pending or after failure, 100 when completed,
    # otherwise whatever the upstream status query reports.
    progress: Optional[int] = None
    # Failure reason; while processing it may also hold the last transient
    # polling error.
    error: Optional[str] = None
|
||
|
||
|
||
def _failed_status(midjourney_task_id: Any, error: Optional[str]) -> Dict[str, Any]:
    """Build the terminal "failed" record stored in TASK_STATUS."""
    return {
        "status": "failed",
        "image_url": None,
        "progress": 0,
        "error": error,
        "midjourney_task_id": midjourney_task_id,
    }


def _normalize_progress(raw: Any) -> Optional[int]:
    """Coerce an upstream progress value to an integer percentage.

    ``TaskStatusResponse.progress`` is ``Optional[int]``, so storing a
    non-integer here would make the status endpoint fail response validation.
    NOTE(review): some Midjourney proxies report progress as a string such as
    ``"45%"`` — confirm what ``Txt2Img.query_task_status`` returns.
    Values that cannot be parsed become ``None``.
    """
    if raw is None:
        return None
    if isinstance(raw, str):
        raw = raw.strip().rstrip("%")
    try:
        return int(float(raw))
    except (TypeError, ValueError):
        return None


def _poll_task_status(
    task_id: str,
    midjourney_task_id: Any,
    max_retries: int = 1000,
    retry_interval: float = 5,
) -> None:
    """Poll Midjourney until the task succeeds, fails, or retries run out.

    Runs in the daemon thread started by ``submit_imagine`` and writes into
    ``TASK_STATUS[task_id]``. Every attempt, including attempts where the
    status query raises, counts against ``max_retries``. The thread therefore
    always terminates, after roughly ``max_retries * retry_interval`` seconds
    at most.
    """
    for _ in range(max_retries):
        try:
            result = Txt2Img.query_task_status(midjourney_task_id)
            status = result.get("status")

            if status == "SUCCESS":
                TASK_STATUS[task_id] = {
                    "status": "completed",
                    "image_url": result.get("imageUrl"),
                    "progress": 100,
                    "error": None,
                    "midjourney_task_id": midjourney_task_id,
                }
                logger.info(f"任务 {task_id} 完成,图片URL: {result.get('imageUrl')}")
                return

            if status == "FAILED":
                TASK_STATUS[task_id] = _failed_status(
                    midjourney_task_id, result.get("errorMsg", "未知错误")
                )
                logger.error(f"任务 {task_id} 失败: {result.get('errorMsg', '未知错误')}")
                return

            # Any other upstream status is treated as still in progress.
            progress = _normalize_progress(result.get("progress", 0))
            TASK_STATUS[task_id]["progress"] = progress
            TASK_STATUS[task_id]["status"] = "processing"
            logger.info(f"任务 {task_id} 处理中,进度: {progress}%")
        except Exception as e:
            # Treat this as a transient failure: record it and try again. The
            # attempt still counts toward max_retries, so a permanently
            # failing query can no longer loop forever.
            logger.error(f"轮询任务 {task_id} 状态失败: {str(e)}")
            TASK_STATUS[task_id]["error"] = str(e)

        time.sleep(retry_interval)

    logger.error(f"任务 {task_id} 超时")
    TASK_STATUS[task_id] = _failed_status(midjourney_task_id, "任务处理超时")


@router.post("/imagine", response_model=Dict[str, str])
async def submit_imagine(request: ImagineRequest, background_tasks: BackgroundTasks):
    """Submit a text-to-image request.

    Args:
        request: Request body carrying the prompt, optional reference images
            and optional callback URL.
        background_tasks: Accepted for interface compatibility but unused.
            Status polling runs in a daemon thread, because the poller may
            block for a long time and should not occupy the server's worker
            threadpool.

    Returns:
        A dict ``{"task_id": <local task id>}``. Pass this ID to
        ``/task_status`` to track progress.

    Raises:
        HTTPException: 500 if submitting to Midjourney, or starting the
            poller, fails.
    """
    # Generate a local task ID. This is distinct from the Midjourney task ID.
    task_id = str(uuid.uuid4())
    logger.info(f"收到文生图请求,任务ID: {task_id}, 提示词: {request.prompt}")

    TASK_STATUS[task_id] = {
        "status": "pending",
        "image_url": None,
        "progress": 0,
        "error": None,
    }

    try:
        # Txt2Img.submit_imagine is synchronous, presumably a blocking HTTP
        # call. Run it in the default executor so it does not stall the event
        # loop for every other request.
        loop = asyncio.get_running_loop()
        midjourney_task_id = await loop.run_in_executor(
            None,
            lambda: Txt2Img.submit_imagine(
                prompt=request.prompt,
                base64_array=request.base64_array,
                notify_hook=request.notify_hook,
            ),
        )
        TASK_STATUS[task_id]["midjourney_task_id"] = midjourney_task_id

        thread = threading.Thread(
            target=_poll_task_status,
            args=(task_id, midjourney_task_id),
            daemon=True,
        )
        thread.start()
    except Exception as e:
        # The client never receives this task_id, so drop the entry instead of
        # leaving it stuck in "pending" forever.
        TASK_STATUS.pop(task_id, None)
        logger.error(f"提交文生图请求失败: {str(e)}")
        raise HTTPException(status_code=500, detail=f"提交文生图请求失败: {str(e)}") from e

    return {"task_id": task_id}
|
||
|
||
|
||
@router.get("/task_status", response_model=TaskStatusResponse)
async def get_task_status(task_id: str = Query(..., description="任务ID")):
    """Return the current status of a text-to-image task.

    Args:
        task_id: The local task ID returned by the submit endpoint.

    Returns:
        A ``TaskStatusResponse`` built from the task's status record.

    Raises:
        HTTPException: 404 if the task ID is unknown.
    """
    # Read the record exactly once. The background poller replaces the whole
    # dict when a task completes, fails or times out. Repeated
    # TASK_STATUS[task_id] lookups could therefore mix fields from two
    # snapshots, or raise KeyError if the entry disappears between the
    # membership check and the reads.
    entry = TASK_STATUS.get(task_id)
    if entry is None:
        raise HTTPException(status_code=404, detail="任务ID不存在")

    return TaskStatusResponse(
        task_id=task_id,
        status=entry["status"],
        image_url=entry["image_url"],
        progress=entry["progress"],
        error=entry["error"],
    )
|