# dsProject/dsLightRag/Routes/MjRoute.py
# FastAPI routes for Midjourney text-to-image (文生图) task submission and status polling.
import asyncio
import datetime
import json
import logging
import os
import threading
import time
import uuid
from typing import Any, Dict, Optional

import requests
from fastapi import APIRouter, BackgroundTasks, HTTPException, Query, Request
from fastapi.responses import JSONResponse, StreamingResponse
from pydantic import BaseModel

from Midjourney.Txt2Img import Txt2Img
from Config import Config
# Logger for this module.
logger = logging.getLogger(__name__)

# Router grouping the Midjourney text-to-image endpoints under the /api/mj prefix.
router = APIRouter(tags=["文生图"], prefix="/api/mj")
class ImagineRequest(BaseModel):
    """Request body for ``POST /api/mj/imagine``.

    Carries the prompt, optional reference images and an optional callback
    address; all three are forwarded to ``Txt2Img.submit_imagine``.
    """

    # Text prompt for image generation.
    prompt: str
    # Optional reference ("pad") images. Presumably a list of base64-encoded
    # strings, judging by the name — confirm against Txt2Img.submit_imagine.
    base64_array: Optional[list] = None
    # Optional callback URL, passed through as ``notify_hook``.
    notify_hook: Optional[str] = None
class TaskStatusResponse(BaseModel):
    """Schema describing a task-status reply.

    NOTE(review): not referenced by the routes visible in this module —
    ``/task_status`` returns plain dicts and declares no ``response_model``.
    Confirm whether it is used elsewhere, or wire it in / remove it.
    """

    # Task ID issued by the MJ server.
    task_id: str
    # Status label; get_task_status emits "completed", "failed" or "processing"
    # (and "error" when the upstream query itself fails).
    status: str
    # URL of the generated image, when available.
    image_url: Optional[str] = None
    # Progress value. Declared as int here; get_task_status passes the upstream
    # value through unchanged — confirm the upstream type matches.
    progress: Optional[int] = None
    # Error message for failed tasks.
    error: Optional[str] = None
@router.post("/imagine", response_model=Dict[str, str])
async def submit_imagine(request: ImagineRequest):
    """
    Submit a text-to-image (imagine) request to Midjourney.

    Args:
        request: Request body with the prompt, optional reference images
            (``base64_array``) and optional callback URL (``notify_hook``).

    Returns:
        ``{"task_id": <MJ server task ID>}``; poll ``/api/mj/task_status``
        with that ID to follow progress.

    Raises:
        HTTPException: 500 if submitting to Midjourney raises.
    """
    try:
        # Submit straight to Midjourney and hand back the MJ task ID unchanged.
        # NOTE(review): Txt2Img.submit_imagine is called synchronously inside an
        # async endpoint; if it performs blocking network I/O it stalls the event
        # loop — confirm, and consider a plain `def` endpoint or a threadpool.
        midjourney_task_id = Txt2Img.submit_imagine(
            prompt=request.prompt,
            base64_array=request.base64_array,
            notify_hook=request.notify_hook,
        )
    except Exception as e:
        # logger.exception records the traceback, not just the message.
        logger.exception(f"提交文生图请求失败: {str(e)}")
        # Chain the cause so the upstream error stays visible when debugging.
        raise HTTPException(status_code=500, detail=f"提交文生图请求失败: {str(e)}") from e

    logger.info(f"提交文生图请求成功MJ任务ID: {midjourney_task_id}, 提示词: {request.prompt}")
    # NOTE(review): response_model is Dict[str, str], so a non-str task ID
    # (e.g. None) would fail response validation — confirm the return type of
    # Txt2Img.submit_imagine.
    return {"task_id": midjourney_task_id}
@router.get("/task_status")
async def get_task_status(task_id: str):
    """
    Query the MJ server directly for the current status of a task.

    Args:
        task_id: Task ID previously returned by the MJ server
            (the ``task_id`` from ``POST /api/mj/imagine``).

    Returns:
        On success of the query, a dict with keys ``task_id``, ``status``
        ("completed", "failed" or "processing"), ``image_url``, ``progress``
        and ``error``. If the query raises, a ``JSONResponse`` with HTTP
        status 500 and body keys ``task_id``, ``status`` ("error") and
        ``message``.
    """
    try:
        # Fetch the live status straight from the MJ server.
        result = Txt2Img.query_task_status(task_id)
        upstream_status = result.get("status")

        if upstream_status == "SUCCESS":
            return {
                "task_id": task_id,
                "status": "completed",
                "image_url": result.get("imageUrl"),
                "progress": 100,
                "error": None,
            }
        if upstream_status == "FAILED":
            return {
                "task_id": task_id,
                "status": "failed",
                "image_url": None,
                "progress": 0,
                "error": result.get("errorMsg", "未知错误"),
            }
        # Any other upstream status is reported as still in progress.
        # image_url is included so every branch returns the same keys.
        # NOTE(review): progress is passed through as-is; confirm its upstream type.
        return {
            "task_id": task_id,
            "status": "processing",
            "image_url": None,
            "progress": result.get("progress", 0),
            "error": None,
        }
    except Exception as e:
        logger.exception(f"查询任务状态失败: {str(e)}")
        # FastAPI does not understand Flask-style `(body, status)` tuples: a
        # returned tuple is serialized as a JSON array with HTTP 200. Build the
        # response explicitly so the client actually receives a 500.
        return JSONResponse(
            status_code=500,
            content={
                "task_id": task_id,
                "status": "error",
                "message": f"查询任务状态失败: {str(e)}",
            },
        )