'commit'
This commit is contained in:
@@ -147,5 +147,52 @@ class Txt2Img(MjCommon):
|
||||
except Exception as e:
|
||||
log.error(f"程序执行异常: {str(e)}")
|
||||
|
||||
@staticmethod
|
||||
def query_task_status(task_id):
|
||||
"""
|
||||
查询任务状态
|
||||
|
||||
:param task_id: 任务ID
|
||||
:return: 任务状态信息
|
||||
:raises Exception: 异常信息
|
||||
"""
|
||||
try:
|
||||
# 创建请求
|
||||
url = f"{MjCommon.BASE_URL}/mj/task/status"
|
||||
headers = {
|
||||
"Content-Type": "application/json",
|
||||
"Authorization": f"Bearer {MjCommon.ak}"
|
||||
}
|
||||
params = {
|
||||
"taskId": task_id
|
||||
}
|
||||
|
||||
# 发送请求并获取响应
|
||||
log.info(f"查询Midjourney任务状态: {task_id}")
|
||||
response = requests.get(url, headers=headers, params=params, timeout=30)
|
||||
|
||||
# 检查响应状态
|
||||
if response.status_code != 200:
|
||||
error_msg = f"Midjourney API请求失败,状态码: {response.status_code}"
|
||||
log.error(error_msg)
|
||||
raise Exception(error_msg)
|
||||
|
||||
# 解析响应
|
||||
response_body = response.text
|
||||
log.info(f"Midjourney任务状态响应: {response_body}")
|
||||
|
||||
response_json = json.loads(response_body)
|
||||
|
||||
# 检查响应状态 - code=1 表示成功
|
||||
if response_json.get("code") != 1:
|
||||
error_msg = f"Midjourney任务状态查询失败: {response_json.get('description', '未知错误')}"
|
||||
log.error(error_msg)
|
||||
raise Exception(error_msg)
|
||||
|
||||
# 返回任务状态信息
|
||||
return response_json.get("result", {})
|
||||
except Exception as e:
|
||||
log.error(f"查询任务状态异常: {str(e)}")
|
||||
raise
|
||||
if __name__ == "__main__":
|
||||
Txt2Img.main()
|
100
dsLightRag/Routes/MjRoute.py
Normal file
100
dsLightRag/Routes/MjRoute.py
Normal file
@@ -0,0 +1,100 @@
|
||||
import logging
|
||||
import json
|
||||
import time
|
||||
import datetime
|
||||
import requests
|
||||
import uuid
|
||||
import os
|
||||
from typing import Optional, Dict, Any
|
||||
from fastapi import APIRouter, HTTPException, Query, BackgroundTasks
|
||||
from pydantic import BaseModel
|
||||
from Midjourney.Txt2Img import Txt2Img
|
||||
from Config import Config
|
||||
import asyncio
|
||||
from fastapi.responses import StreamingResponse
|
||||
|
||||
# 创建路由路由器
|
||||
router = APIRouter(prefix="/api/mj", tags=["文生图"])
|
||||
|
||||
# 配置日志
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
# 任务状态存储
|
||||
TASK_STATUS: Dict[str, Dict[str, Any]] = {}
|
||||
|
||||
|
||||
class ImagineRequest(BaseModel):
|
||||
prompt: str
|
||||
base64_array: Optional[list] = None
|
||||
notify_hook: Optional[str] = None
|
||||
|
||||
|
||||
class TaskStatusResponse(BaseModel):
|
||||
task_id: str
|
||||
status: str
|
||||
image_url: Optional[str] = None
|
||||
progress: Optional[int] = None
|
||||
error: Optional[str] = None
|
||||
|
||||
|
||||
@router.post("/imagine", response_model=Dict[str, str])
|
||||
async def submit_imagine(request: ImagineRequest, background_tasks: BackgroundTasks):
|
||||
"""
|
||||
提交文生图请求
|
||||
|
||||
Args:
|
||||
request: 包含提示词、垫图和回调地址的请求体
|
||||
background_tasks: 用于异步处理任务状态查询的后台任务
|
||||
|
||||
Returns:
|
||||
包含任务ID的字典
|
||||
"""
|
||||
try:
|
||||
# 生成唯一任务ID
|
||||
task_id = str(uuid.uuid4())
|
||||
logger.info(f"收到文生图请求,任务ID: {task_id}, 提示词: {request.prompt}")
|
||||
|
||||
# 初始化任务状态
|
||||
TASK_STATUS[task_id] = {
|
||||
"status": "pending",
|
||||
"image_url": None,
|
||||
"progress": 0,
|
||||
"error": None
|
||||
}
|
||||
|
||||
# 提交到Midjourney
|
||||
midjourney_task_id = Txt2Img.submit_imagine(
|
||||
prompt=request.prompt,
|
||||
base64_array=request.base64_array,
|
||||
notify_hook=request.notify_hook
|
||||
)
|
||||
|
||||
# 存储Midjourney任务ID
|
||||
TASK_STATUS[task_id]["midjourney_task_id"] = midjourney_task_id
|
||||
return {"task_id": task_id}
|
||||
except Exception as e:
|
||||
logger.error(f"提交文生图请求失败: {str(e)}")
|
||||
raise HTTPException(status_code=500, detail=f"提交文生图请求失败: {str(e)}")
|
||||
|
||||
|
||||
@router.get("/task_status", response_model=TaskStatusResponse)
|
||||
async def get_task_status(task_id: str = Query(..., description="任务ID")):
|
||||
"""
|
||||
查询文生图任务状态
|
||||
|
||||
Args:
|
||||
task_id: 任务ID
|
||||
|
||||
Returns:
|
||||
任务状态信息
|
||||
"""
|
||||
if task_id not in TASK_STATUS:
|
||||
raise HTTPException(status_code=404, detail="任务ID不存在")
|
||||
|
||||
return TaskStatusResponse(
|
||||
task_id=task_id,
|
||||
status=TASK_STATUS[task_id]["status"],
|
||||
image_url=TASK_STATUS[task_id]["image_url"],
|
||||
progress=TASK_STATUS[task_id]["progress"],
|
||||
error=TASK_STATUS[task_id]["error"]
|
||||
)
|
@@ -22,6 +22,7 @@ from Routes.QA import router as qa_router
|
||||
from Routes.JiMengRoute import router as jimeng_router
|
||||
from Routes.SunoRoute import router as suno_router
|
||||
from Routes.XueBanRoute import router as xueban_router
|
||||
from Routes.MjRoute import router as mj_router
|
||||
from Util.LightRagUtil import *
|
||||
from contextlib import asynccontextmanager
|
||||
|
||||
@@ -62,7 +63,8 @@ app.include_router(llm_router) # 大模型路由
|
||||
app.include_router(qa_router) # 答疑路由
|
||||
app.include_router(jimeng_router) # 即梦路由
|
||||
app.include_router(suno_router) # Suno路由
|
||||
app.include_router(xueban_router) # 学伴路由
|
||||
app.include_router(xueban_router) # 学伴路由
|
||||
app.include_router(mj_router) # Midjourney路由
|
||||
|
||||
# Teaching Model 相关路由
|
||||
# 登录相关(不用登录)
|
||||
|
Reference in New Issue
Block a user