This commit is contained in:
2025-08-25 14:22:53 +08:00
parent 456be765ab
commit 71a67388fe
5 changed files with 150 additions and 262 deletions

View File

@@ -1,111 +0,0 @@
import os
import logging
import requests
import json
import time
from pathlib import Path
# 配置日志
logging.basicConfig(level=logging.INFO, format='%(asctime)s - %(name)s - %(levelname)s - %(message)s')
logger = logging.getLogger('MjCommon')
class MjCommon:
# 获取项目根目录路径
project_root = os.path.abspath(os.path.join(os.path.dirname(__file__), "..", "..", "..", "..", "..", ".."))
# 拼接相对路径
base_path = os.path.join(project_root, "src", "main", "python", "com", "dsideal", "aiSupport", "Util", "Midjourney", "Example")
ak = "sk-amQHwiEzPIZIB2KuF5A10dC23a0e4b02B48a7a2b6aFa0662" # 填写access key
BASE_URL = "https://goapi.gptnb.ai"
@staticmethod
def download_file(file_url, save_file_path):
"""从URL下载文件到指定路径
Args:
file_url (str): 文件URL
save_file_path (str): 保存路径
Raises:
Exception: 下载过程中的异常
"""
try:
# 确保目录存在
os.makedirs(os.path.dirname(save_file_path), exist_ok=True)
# 下载文件
logger.info(f"开始下载文件: {file_url}")
response = requests.get(file_url, stream=True, timeout=30)
response.raise_for_status()
# 保存文件
with open(save_file_path, 'wb') as f:
for chunk in response.iter_content(chunk_size=8192):
f.write(chunk)
file_size = os.path.getsize(save_file_path)
logger.info(f"文件下载成功,保存路径: {save_file_path}, 文件大小: {file_size}字节")
except Exception as e:
logger.error(f"文件下载失败: {str(e)}")
raise Exception(f"文件下载失败: {str(e)}")
@staticmethod
def query_task_status(task_id):
"""查询任务状态
Args:
task_id (str): 任务ID
Returns:
dict: 任务结果
Raises:
Exception: 异常信息
"""
# 创建请求URL
url = f"{MjCommon.BASE_URL}/mj/task/{task_id}/fetch"
# 设置请求头
headers = {
"Authorization": f"Bearer {MjCommon.ak}"
}
try:
logger.info(f"查询Midjourney任务状态: {task_id}")
response = requests.get(url, headers=headers, timeout=30)
# 检查响应状态
if not response.ok:
error_msg = f"Midjourney API请求失败状态码: {response.status_code}"
logger.error(error_msg)
raise Exception(error_msg)
# 解析响应
response_body = response.text
logger.info(f"查询Midjourney任务状态响应: {response_body}")
return json.loads(response_body)
except Exception as e:
logger.error(f"查询任务状态失败: {str(e)}")
raise Exception(f"查询任务状态失败: {str(e)}")
# 测试代码
if __name__ == "__main__":
# 测试下载文件
# try:
# test_url = "https://example.com/test.jpg"
# test_save_path = os.path.join(MjCommon.base_path, "test.jpg")
# MjCommon.download_file(test_url, test_save_path)
# except Exception as e:
# print(f"测试下载失败: {e}")
# 测试查询任务状态
# try:
# test_task_id = "your_task_id_here"
# result = MjCommon.query_task_status(test_task_id)
# print(f"任务状态查询结果: {result}")
# except Exception as e:
# print(f"测试查询任务状态失败: {e}")
pass

View File

@@ -1,150 +0,0 @@
import os
import time
import json
import logging
import requests
from com.dsideal.Ai.Util.Midjourney.Kit.MjCommon import MjCommon
# 配置日志
logging.basicConfig(level=logging.INFO, format='%(asctime)s - %(name)s - %(levelname)s - %(message)s')
log = logging.getLogger(__name__)
class Txt2Img(MjCommon):
"""
Midjourney API 工具类
用于调用 Midjourney 的 imagine 接口生成图片
"""
@staticmethod
def submit_imagine(prompt, base64_array=None, notify_hook=None):
"""
提交 imagine 请求
:param prompt: 提示词
:param base64_array: 垫图base64数组
:param notify_hook: 回调地址
:return: 任务ID
:raises Exception: 异常信息
"""
try:
# 构建请求体
request_body = {
"prompt": prompt,
"state": f"midjourney_task_{int(time.time() * 1000)}"
}
# 如果提供了垫图base64数组则添加到请求体中
if base64_array and len(base64_array) > 0:
request_body["base64Array"] = base64_array
# 如果提供了回调地址,则添加到请求体中
if notify_hook and notify_hook.strip():
request_body["notifyHook"] = notify_hook
# 创建请求
url = f"{MjCommon.BASE_URL}/mj/submit/imagine"
headers = {
"Content-Type": "application/json",
"Authorization": f"Bearer {MjCommon.ak}"
}
# 发送请求并获取响应
log.info(f"提交Midjourney imagine请求: {json.dumps(request_body)}")
response = requests.post(url, headers=headers, json=request_body, timeout=30)
# 检查响应状态
if response.status_code != 200:
error_msg = f"Midjourney API请求失败状态码: {response.status_code}"
log.error(error_msg)
raise Exception(error_msg)
# 解析响应
response_body = response.text
log.info(f"Midjourney imagine响应: {response_body}")
response_json = json.loads(response_body)
# 检查响应状态 - code=1 表示成功
if response_json.get("code") != 1:
error_msg = f"Midjourney imagine失败: {response_json.get('description', '未知错误')}"
log.error(error_msg)
raise Exception(error_msg)
# 获取任务ID - 从result字段获取
task_id = response_json.get("result")
log.info(f"Midjourney imagine任务ID: {task_id}")
return task_id
except Exception as e:
log.error(f"提交imagine请求异常: {str(e)}")
raise
@staticmethod
def main():
"""测试方法"""
try:
# 提示词
prompt = "A cute cat playing with a ball of yarn, digital art style"
# 提交imagine请求
task_id = Txt2Img.submit_imagine(prompt)
# 轮询查询任务状态
max_retries = 1000
retry_count = 0
retry_interval = 5 # 5秒
image_url = None
while retry_count < max_retries:
result = MjCommon.query_task_status(task_id)
# 直接使用响应中的字段
status = result.get("status")
log.info(f"任务状态: {status}")
# 检查进度
progress = result.get("progress")
log.info(f"任务进度: {progress}")
# 任务状态可能为空字符串需要检查progress或其他字段来判断任务是否完成
if status and status.strip():
if status == "SUCCESS":
# 任务成功获取图片URL
image_url = result.get("imageUrl")
log.info(f"生成的图片URL: {image_url}")
break
elif status == "FAILED":
# 任务失败
log.error(f"任务失败: {result.get('failReason', '未知原因')}")
break
else:
# 检查description字段
description = result.get("description")
if description and "成功" in description and progress != "0%":
# 如果描述包含"成功"且进度不为0%,可能任务已完成
image_url = result.get("imageUrl")
if image_url and image_url.strip():
log.info(f"生成的图片URL: {image_url}")
break
# 任务仍在进行中,等待后重试
log.info(f"任务进行中,等待{retry_interval}秒后重试...")
time.sleep(retry_interval)
retry_count += 1
if retry_count >= max_retries:
log.error(f"查询任务状态超时,已达到最大重试次数: {max_retries}")
# 如果获取到了图片URL则下载保存
if image_url and image_url.strip():
# 生成文件名使用时间戳和任务ID
file_name = f"mj_{int(time.time() * 1000)}_{task_id}.png"
# 完整保存路径
save_path = os.path.join(MjCommon.base_path, file_name)
# 下载图片
MjCommon.download_file(image_url, save_path)
except Exception as e:
log.error(f"程序执行异常: {str(e)}")
if __name__ == "__main__":
Txt2Img.main()

View File

@@ -147,5 +147,52 @@ class Txt2Img(MjCommon):
except Exception as e:
log.error(f"程序执行异常: {str(e)}")
@staticmethod
def query_task_status(task_id):
"""
查询任务状态
:param task_id: 任务ID
:return: 任务状态信息
:raises Exception: 异常信息
"""
try:
# 创建请求
url = f"{MjCommon.BASE_URL}/mj/task/status"
headers = {
"Content-Type": "application/json",
"Authorization": f"Bearer {MjCommon.ak}"
}
params = {
"taskId": task_id
}
# 发送请求并获取响应
log.info(f"查询Midjourney任务状态: {task_id}")
response = requests.get(url, headers=headers, params=params, timeout=30)
# 检查响应状态
if response.status_code != 200:
error_msg = f"Midjourney API请求失败状态码: {response.status_code}"
log.error(error_msg)
raise Exception(error_msg)
# 解析响应
response_body = response.text
log.info(f"Midjourney任务状态响应: {response_body}")
response_json = json.loads(response_body)
# 检查响应状态 - code=1 表示成功
if response_json.get("code") != 1:
error_msg = f"Midjourney任务状态查询失败: {response_json.get('description', '未知错误')}"
log.error(error_msg)
raise Exception(error_msg)
# 返回任务状态信息
return response_json.get("result", {})
except Exception as e:
log.error(f"查询任务状态异常: {str(e)}")
raise
if __name__ == "__main__":
Txt2Img.main()

View File

@@ -0,0 +1,100 @@
import logging
import json
import time
import datetime
import requests
import uuid
import os
from typing import Optional, Dict, Any
from fastapi import APIRouter, HTTPException, Query, BackgroundTasks
from pydantic import BaseModel
from Midjourney.Txt2Img import Txt2Img
from Config import Config
import asyncio
from fastapi.responses import StreamingResponse
# 创建路由路由器
router = APIRouter(prefix="/api/mj", tags=["文生图"])
# 配置日志
logger = logging.getLogger(__name__)
# 任务状态存储
TASK_STATUS: Dict[str, Dict[str, Any]] = {}
class ImagineRequest(BaseModel):
prompt: str
base64_array: Optional[list] = None
notify_hook: Optional[str] = None
class TaskStatusResponse(BaseModel):
task_id: str
status: str
image_url: Optional[str] = None
progress: Optional[int] = None
error: Optional[str] = None
@router.post("/imagine", response_model=Dict[str, str])
async def submit_imagine(request: ImagineRequest, background_tasks: BackgroundTasks):
"""
提交文生图请求
Args:
request: 包含提示词、垫图和回调地址的请求体
background_tasks: 用于异步处理任务状态查询的后台任务
Returns:
包含任务ID的字典
"""
try:
# 生成唯一任务ID
task_id = str(uuid.uuid4())
logger.info(f"收到文生图请求任务ID: {task_id}, 提示词: {request.prompt}")
# 初始化任务状态
TASK_STATUS[task_id] = {
"status": "pending",
"image_url": None,
"progress": 0,
"error": None
}
# 提交到Midjourney
midjourney_task_id = Txt2Img.submit_imagine(
prompt=request.prompt,
base64_array=request.base64_array,
notify_hook=request.notify_hook
)
# 存储Midjourney任务ID
TASK_STATUS[task_id]["midjourney_task_id"] = midjourney_task_id
return {"task_id": task_id}
except Exception as e:
logger.error(f"提交文生图请求失败: {str(e)}")
raise HTTPException(status_code=500, detail=f"提交文生图请求失败: {str(e)}")
@router.get("/task_status", response_model=TaskStatusResponse)
async def get_task_status(task_id: str = Query(..., description="任务ID")):
"""
查询文生图任务状态
Args:
task_id: 任务ID
Returns:
任务状态信息
"""
if task_id not in TASK_STATUS:
raise HTTPException(status_code=404, detail="任务ID不存在")
return TaskStatusResponse(
task_id=task_id,
status=TASK_STATUS[task_id]["status"],
image_url=TASK_STATUS[task_id]["image_url"],
progress=TASK_STATUS[task_id]["progress"],
error=TASK_STATUS[task_id]["error"]
)

View File

@@ -22,6 +22,7 @@ from Routes.QA import router as qa_router
from Routes.JiMengRoute import router as jimeng_router
from Routes.SunoRoute import router as suno_router
from Routes.XueBanRoute import router as xueban_router
from Routes.MjRoute import router as mj_router
from Util.LightRagUtil import *
from contextlib import asynccontextmanager
@@ -62,7 +63,8 @@ app.include_router(llm_router) # 大模型路由
app.include_router(qa_router) # 答疑路由
app.include_router(jimeng_router) # 即梦路由
app.include_router(suno_router) # Suno路由
app.include_router(xueban_router) # 学伴路由
app.include_router(xueban_router) # 学伴路由
app.include_router(mj_router) # Midjourney路由
# Teaching Model 相关路由
# 登录相关(不用登录)