"""Midjourney text-to-image API routes.

Exposes an endpoint that submits an imagine request to Midjourney and an
endpoint that reports the stored status of a previously submitted task.
Task state is kept in a module-level dict (``TASK_STATUS``).
"""

import asyncio
import datetime
import json
import logging
import os
import time
import uuid
from typing import Any, Dict, Optional

import requests
from fastapi import APIRouter, BackgroundTasks, HTTPException, Query
from fastapi.concurrency import run_in_threadpool
from fastapi.responses import StreamingResponse
from pydantic import BaseModel

# Local imports kept in their original relative order in case importing one
# has side effects the other relies on.
from Midjourney.Txt2Img import Txt2Img
from Config import Config

# Router for the text-to-image endpoints, mounted under /api/mj.
router = APIRouter(prefix="/api/mj", tags=["文生图"])

# Module-level logger.
logger = logging.getLogger(__name__)

# Task registry keyed by our own task ID (a uuid4 string).
# NOTE(review): this is plain module-level state, so it lives only in the
# current process (not shared across worker processes, lost on restart).
TASK_STATUS: Dict[str, Dict[str, Any]] = {}


class ImagineRequest(BaseModel):
    """Request body for submitting a text-to-image task."""

    # Text prompt forwarded to Midjourney.
    prompt: str
    # Optional reference images; presumably base64-encoded strings — TODO confirm.
    base64_array: Optional[list] = None
    # Optional callback URL forwarded to Midjourney.
    notify_hook: Optional[str] = None


class TaskStatusResponse(BaseModel):
    """Status snapshot returned by ``GET /api/mj/task_status``."""

    task_id: str
    status: str
    image_url: Optional[str] = None
    progress: Optional[int] = None
    error: Optional[str] = None


@router.post("/imagine", response_model=Dict[str, str])
async def submit_imagine(request: ImagineRequest, background_tasks: BackgroundTasks):
    """Submit a text-to-image request to Midjourney.

    Args:
        request: Body carrying the prompt, optional reference images and an
            optional callback URL.
        background_tasks: FastAPI background-task queue. Currently unused;
            kept so the endpoint signature stays unchanged.

    Returns:
        ``{"task_id": <uuid4 string>}``; the ID can be passed to
        ``GET /api/mj/task_status``.

    Raises:
        HTTPException: 500 if submitting to Midjourney fails.
    """
    task_id = str(uuid.uuid4())
    logger.info("收到文生图请求,任务ID: %s, 提示词: %s", task_id, request.prompt)

    # Register the task before submitting so it is visible immediately.
    TASK_STATUS[task_id] = {
        "status": "pending",
        "image_url": None,
        "progress": 0,
        "error": None,
    }

    try:
        # Txt2Img.submit_imagine is a synchronous call (it is not awaited), so
        # run it in the thread pool instead of blocking the event loop while
        # it talks to Midjourney.
        midjourney_task_id = await run_in_threadpool(
            Txt2Img.submit_imagine,
            prompt=request.prompt,
            base64_array=request.base64_array,
            notify_hook=request.notify_hook,
        )
    except Exception as e:
        # The client never receives this task_id, so drop the entry instead of
        # leaving it stuck in "pending" forever.
        TASK_STATUS.pop(task_id, None)
        logger.exception("提交文生图请求失败: %s", e)
        raise HTTPException(status_code=500, detail=f"提交文生图请求失败: {str(e)}") from e

    # Keep the upstream Midjourney ID so the task can be correlated later.
    # NOTE(review): nothing in this module advances a task past "pending";
    # presumably a poller or notify_hook handler elsewhere does — confirm.
    TASK_STATUS[task_id]["midjourney_task_id"] = midjourney_task_id
    return {"task_id": task_id}


@router.get("/task_status", response_model=TaskStatusResponse)
async def get_task_status(task_id: str = Query(..., description="任务ID")):
    """Return the stored status of a text-to-image task.

    Args:
        task_id: ID returned by ``POST /api/mj/imagine``.

    Returns:
        A ``TaskStatusResponse`` built from the task's entry in ``TASK_STATUS``.

    Raises:
        HTTPException: 404 if the task ID is unknown.
    """
    task = TASK_STATUS.get(task_id)
    if task is None:
        raise HTTPException(status_code=404, detail="任务ID不存在")

    return TaskStatusResponse(
        task_id=task_id,
        status=task["status"],
        image_url=task["image_url"],
        progress=task["progress"],
        error=task["error"],
    )