'commit'
This commit is contained in:
@@ -1,111 +0,0 @@
|
||||
import os
|
||||
import logging
|
||||
import requests
|
||||
import json
|
||||
import time
|
||||
from pathlib import Path
|
||||
|
||||
# 配置日志
|
||||
logging.basicConfig(level=logging.INFO, format='%(asctime)s - %(name)s - %(levelname)s - %(message)s')
|
||||
logger = logging.getLogger('MjCommon')
|
||||
|
||||
|
||||
class MjCommon:
|
||||
# 获取项目根目录路径
|
||||
project_root = os.path.abspath(os.path.join(os.path.dirname(__file__), "..", "..", "..", "..", "..", ".."))
|
||||
# 拼接相对路径
|
||||
base_path = os.path.join(project_root, "src", "main", "python", "com", "dsideal", "aiSupport", "Util", "Midjourney", "Example")
|
||||
|
||||
ak = "sk-amQHwiEzPIZIB2KuF5A10dC23a0e4b02B48a7a2b6aFa0662" # 填写access key
|
||||
BASE_URL = "https://goapi.gptnb.ai"
|
||||
|
||||
@staticmethod
|
||||
def download_file(file_url, save_file_path):
|
||||
"""从URL下载文件到指定路径
|
||||
|
||||
Args:
|
||||
file_url (str): 文件URL
|
||||
save_file_path (str): 保存路径
|
||||
|
||||
Raises:
|
||||
Exception: 下载过程中的异常
|
||||
"""
|
||||
try:
|
||||
# 确保目录存在
|
||||
os.makedirs(os.path.dirname(save_file_path), exist_ok=True)
|
||||
|
||||
# 下载文件
|
||||
logger.info(f"开始下载文件: {file_url}")
|
||||
response = requests.get(file_url, stream=True, timeout=30)
|
||||
response.raise_for_status()
|
||||
|
||||
# 保存文件
|
||||
with open(save_file_path, 'wb') as f:
|
||||
for chunk in response.iter_content(chunk_size=8192):
|
||||
f.write(chunk)
|
||||
|
||||
file_size = os.path.getsize(save_file_path)
|
||||
logger.info(f"文件下载成功,保存路径: {save_file_path}, 文件大小: {file_size}字节")
|
||||
except Exception as e:
|
||||
logger.error(f"文件下载失败: {str(e)}")
|
||||
raise Exception(f"文件下载失败: {str(e)}")
|
||||
|
||||
@staticmethod
|
||||
def query_task_status(task_id):
|
||||
"""查询任务状态
|
||||
|
||||
Args:
|
||||
task_id (str): 任务ID
|
||||
|
||||
Returns:
|
||||
dict: 任务结果
|
||||
|
||||
Raises:
|
||||
Exception: 异常信息
|
||||
"""
|
||||
# 创建请求URL
|
||||
url = f"{MjCommon.BASE_URL}/mj/task/{task_id}/fetch"
|
||||
|
||||
# 设置请求头
|
||||
headers = {
|
||||
"Authorization": f"Bearer {MjCommon.ak}"
|
||||
}
|
||||
|
||||
try:
|
||||
logger.info(f"查询Midjourney任务状态: {task_id}")
|
||||
response = requests.get(url, headers=headers, timeout=30)
|
||||
|
||||
# 检查响应状态
|
||||
if not response.ok:
|
||||
error_msg = f"Midjourney API请求失败,状态码: {response.status_code}"
|
||||
logger.error(error_msg)
|
||||
raise Exception(error_msg)
|
||||
|
||||
# 解析响应
|
||||
response_body = response.text
|
||||
logger.info(f"查询Midjourney任务状态响应: {response_body}")
|
||||
|
||||
return json.loads(response_body)
|
||||
except Exception as e:
|
||||
logger.error(f"查询任务状态失败: {str(e)}")
|
||||
raise Exception(f"查询任务状态失败: {str(e)}")
|
||||
|
||||
|
||||
# 测试代码
|
||||
if __name__ == "__main__":
|
||||
# 测试下载文件
|
||||
# try:
|
||||
# test_url = "https://example.com/test.jpg"
|
||||
# test_save_path = os.path.join(MjCommon.base_path, "test.jpg")
|
||||
# MjCommon.download_file(test_url, test_save_path)
|
||||
# except Exception as e:
|
||||
# print(f"测试下载失败: {e}")
|
||||
|
||||
# 测试查询任务状态
|
||||
# try:
|
||||
# test_task_id = "your_task_id_here"
|
||||
# result = MjCommon.query_task_status(test_task_id)
|
||||
# print(f"任务状态查询结果: {result}")
|
||||
# except Exception as e:
|
||||
# print(f"测试查询任务状态失败: {e}")
|
||||
pass
|
@@ -1,150 +0,0 @@
|
||||
import os
|
||||
import time
|
||||
import json
|
||||
import logging
|
||||
import requests
|
||||
from com.dsideal.Ai.Util.Midjourney.Kit.MjCommon import MjCommon
|
||||
|
||||
# 配置日志
|
||||
logging.basicConfig(level=logging.INFO, format='%(asctime)s - %(name)s - %(levelname)s - %(message)s')
|
||||
log = logging.getLogger(__name__)
|
||||
|
||||
class Txt2Img(MjCommon):
|
||||
"""
|
||||
Midjourney API 工具类
|
||||
用于调用 Midjourney 的 imagine 接口生成图片
|
||||
"""
|
||||
|
||||
@staticmethod
|
||||
def submit_imagine(prompt, base64_array=None, notify_hook=None):
|
||||
"""
|
||||
提交 imagine 请求
|
||||
|
||||
:param prompt: 提示词
|
||||
:param base64_array: 垫图base64数组
|
||||
:param notify_hook: 回调地址
|
||||
:return: 任务ID
|
||||
:raises Exception: 异常信息
|
||||
"""
|
||||
try:
|
||||
# 构建请求体
|
||||
request_body = {
|
||||
"prompt": prompt,
|
||||
"state": f"midjourney_task_{int(time.time() * 1000)}"
|
||||
}
|
||||
|
||||
# 如果提供了垫图base64数组,则添加到请求体中
|
||||
if base64_array and len(base64_array) > 0:
|
||||
request_body["base64Array"] = base64_array
|
||||
|
||||
# 如果提供了回调地址,则添加到请求体中
|
||||
if notify_hook and notify_hook.strip():
|
||||
request_body["notifyHook"] = notify_hook
|
||||
|
||||
# 创建请求
|
||||
url = f"{MjCommon.BASE_URL}/mj/submit/imagine"
|
||||
headers = {
|
||||
"Content-Type": "application/json",
|
||||
"Authorization": f"Bearer {MjCommon.ak}"
|
||||
}
|
||||
|
||||
# 发送请求并获取响应
|
||||
log.info(f"提交Midjourney imagine请求: {json.dumps(request_body)}")
|
||||
response = requests.post(url, headers=headers, json=request_body, timeout=30)
|
||||
|
||||
# 检查响应状态
|
||||
if response.status_code != 200:
|
||||
error_msg = f"Midjourney API请求失败,状态码: {response.status_code}"
|
||||
log.error(error_msg)
|
||||
raise Exception(error_msg)
|
||||
|
||||
# 解析响应
|
||||
response_body = response.text
|
||||
log.info(f"Midjourney imagine响应: {response_body}")
|
||||
|
||||
response_json = json.loads(response_body)
|
||||
|
||||
# 检查响应状态 - code=1 表示成功
|
||||
if response_json.get("code") != 1:
|
||||
error_msg = f"Midjourney imagine失败: {response_json.get('description', '未知错误')}"
|
||||
log.error(error_msg)
|
||||
raise Exception(error_msg)
|
||||
|
||||
# 获取任务ID - 从result字段获取
|
||||
task_id = response_json.get("result")
|
||||
log.info(f"Midjourney imagine任务ID: {task_id}")
|
||||
|
||||
return task_id
|
||||
except Exception as e:
|
||||
log.error(f"提交imagine请求异常: {str(e)}")
|
||||
raise
|
||||
|
||||
@staticmethod
|
||||
def main():
|
||||
"""测试方法"""
|
||||
try:
|
||||
# 提示词
|
||||
prompt = "A cute cat playing with a ball of yarn, digital art style"
|
||||
|
||||
# 提交imagine请求
|
||||
task_id = Txt2Img.submit_imagine(prompt)
|
||||
|
||||
# 轮询查询任务状态
|
||||
max_retries = 1000
|
||||
retry_count = 0
|
||||
retry_interval = 5 # 5秒
|
||||
image_url = None
|
||||
|
||||
while retry_count < max_retries:
|
||||
result = MjCommon.query_task_status(task_id)
|
||||
|
||||
# 直接使用响应中的字段
|
||||
status = result.get("status")
|
||||
log.info(f"任务状态: {status}")
|
||||
|
||||
# 检查进度
|
||||
progress = result.get("progress")
|
||||
log.info(f"任务进度: {progress}")
|
||||
|
||||
# 任务状态可能为空字符串,需要检查progress或其他字段来判断任务是否完成
|
||||
if status and status.strip():
|
||||
if status == "SUCCESS":
|
||||
# 任务成功,获取图片URL
|
||||
image_url = result.get("imageUrl")
|
||||
log.info(f"生成的图片URL: {image_url}")
|
||||
break
|
||||
elif status == "FAILED":
|
||||
# 任务失败
|
||||
log.error(f"任务失败: {result.get('failReason', '未知原因')}")
|
||||
break
|
||||
else:
|
||||
# 检查description字段
|
||||
description = result.get("description")
|
||||
if description and "成功" in description and progress != "0%":
|
||||
# 如果描述包含"成功"且进度不为0%,可能任务已完成
|
||||
image_url = result.get("imageUrl")
|
||||
if image_url and image_url.strip():
|
||||
log.info(f"生成的图片URL: {image_url}")
|
||||
break
|
||||
|
||||
# 任务仍在进行中,等待后重试
|
||||
log.info(f"任务进行中,等待{retry_interval}秒后重试...")
|
||||
time.sleep(retry_interval)
|
||||
retry_count += 1
|
||||
|
||||
if retry_count >= max_retries:
|
||||
log.error(f"查询任务状态超时,已达到最大重试次数: {max_retries}")
|
||||
|
||||
# 如果获取到了图片URL,则下载保存
|
||||
if image_url and image_url.strip():
|
||||
# 生成文件名(使用时间戳和任务ID)
|
||||
file_name = f"mj_{int(time.time() * 1000)}_{task_id}.png"
|
||||
# 完整保存路径
|
||||
save_path = os.path.join(MjCommon.base_path, file_name)
|
||||
# 下载图片
|
||||
MjCommon.download_file(image_url, save_path)
|
||||
except Exception as e:
|
||||
log.error(f"程序执行异常: {str(e)}")
|
||||
|
||||
if __name__ == "__main__":
|
||||
Txt2Img.main()
|
@@ -147,5 +147,52 @@ class Txt2Img(MjCommon):
|
||||
except Exception as e:
|
||||
log.error(f"程序执行异常: {str(e)}")
|
||||
|
||||
@staticmethod
|
||||
def query_task_status(task_id):
|
||||
"""
|
||||
查询任务状态
|
||||
|
||||
:param task_id: 任务ID
|
||||
:return: 任务状态信息
|
||||
:raises Exception: 异常信息
|
||||
"""
|
||||
try:
|
||||
# 创建请求
|
||||
url = f"{MjCommon.BASE_URL}/mj/task/status"
|
||||
headers = {
|
||||
"Content-Type": "application/json",
|
||||
"Authorization": f"Bearer {MjCommon.ak}"
|
||||
}
|
||||
params = {
|
||||
"taskId": task_id
|
||||
}
|
||||
|
||||
# 发送请求并获取响应
|
||||
log.info(f"查询Midjourney任务状态: {task_id}")
|
||||
response = requests.get(url, headers=headers, params=params, timeout=30)
|
||||
|
||||
# 检查响应状态
|
||||
if response.status_code != 200:
|
||||
error_msg = f"Midjourney API请求失败,状态码: {response.status_code}"
|
||||
log.error(error_msg)
|
||||
raise Exception(error_msg)
|
||||
|
||||
# 解析响应
|
||||
response_body = response.text
|
||||
log.info(f"Midjourney任务状态响应: {response_body}")
|
||||
|
||||
response_json = json.loads(response_body)
|
||||
|
||||
# 检查响应状态 - code=1 表示成功
|
||||
if response_json.get("code") != 1:
|
||||
error_msg = f"Midjourney任务状态查询失败: {response_json.get('description', '未知错误')}"
|
||||
log.error(error_msg)
|
||||
raise Exception(error_msg)
|
||||
|
||||
# 返回任务状态信息
|
||||
return response_json.get("result", {})
|
||||
except Exception as e:
|
||||
log.error(f"查询任务状态异常: {str(e)}")
|
||||
raise
|
||||
if __name__ == "__main__":
|
||||
Txt2Img.main()
|
100
dsLightRag/Routes/MjRoute.py
Normal file
100
dsLightRag/Routes/MjRoute.py
Normal file
@@ -0,0 +1,100 @@
|
||||
import logging
|
||||
import json
|
||||
import time
|
||||
import datetime
|
||||
import requests
|
||||
import uuid
|
||||
import os
|
||||
from typing import Optional, Dict, Any
|
||||
from fastapi import APIRouter, HTTPException, Query, BackgroundTasks
|
||||
from pydantic import BaseModel
|
||||
from Midjourney.Txt2Img import Txt2Img
|
||||
from Config import Config
|
||||
import asyncio
|
||||
from fastapi.responses import StreamingResponse
|
||||
|
||||
# 创建路由路由器
|
||||
router = APIRouter(prefix="/api/mj", tags=["文生图"])
|
||||
|
||||
# 配置日志
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
# 任务状态存储
|
||||
TASK_STATUS: Dict[str, Dict[str, Any]] = {}
|
||||
|
||||
|
||||
class ImagineRequest(BaseModel):
|
||||
prompt: str
|
||||
base64_array: Optional[list] = None
|
||||
notify_hook: Optional[str] = None
|
||||
|
||||
|
||||
class TaskStatusResponse(BaseModel):
|
||||
task_id: str
|
||||
status: str
|
||||
image_url: Optional[str] = None
|
||||
progress: Optional[int] = None
|
||||
error: Optional[str] = None
|
||||
|
||||
|
||||
@router.post("/imagine", response_model=Dict[str, str])
|
||||
async def submit_imagine(request: ImagineRequest, background_tasks: BackgroundTasks):
|
||||
"""
|
||||
提交文生图请求
|
||||
|
||||
Args:
|
||||
request: 包含提示词、垫图和回调地址的请求体
|
||||
background_tasks: 用于异步处理任务状态查询的后台任务
|
||||
|
||||
Returns:
|
||||
包含任务ID的字典
|
||||
"""
|
||||
try:
|
||||
# 生成唯一任务ID
|
||||
task_id = str(uuid.uuid4())
|
||||
logger.info(f"收到文生图请求,任务ID: {task_id}, 提示词: {request.prompt}")
|
||||
|
||||
# 初始化任务状态
|
||||
TASK_STATUS[task_id] = {
|
||||
"status": "pending",
|
||||
"image_url": None,
|
||||
"progress": 0,
|
||||
"error": None
|
||||
}
|
||||
|
||||
# 提交到Midjourney
|
||||
midjourney_task_id = Txt2Img.submit_imagine(
|
||||
prompt=request.prompt,
|
||||
base64_array=request.base64_array,
|
||||
notify_hook=request.notify_hook
|
||||
)
|
||||
|
||||
# 存储Midjourney任务ID
|
||||
TASK_STATUS[task_id]["midjourney_task_id"] = midjourney_task_id
|
||||
return {"task_id": task_id}
|
||||
except Exception as e:
|
||||
logger.error(f"提交文生图请求失败: {str(e)}")
|
||||
raise HTTPException(status_code=500, detail=f"提交文生图请求失败: {str(e)}")
|
||||
|
||||
|
||||
@router.get("/task_status", response_model=TaskStatusResponse)
|
||||
async def get_task_status(task_id: str = Query(..., description="任务ID")):
|
||||
"""
|
||||
查询文生图任务状态
|
||||
|
||||
Args:
|
||||
task_id: 任务ID
|
||||
|
||||
Returns:
|
||||
任务状态信息
|
||||
"""
|
||||
if task_id not in TASK_STATUS:
|
||||
raise HTTPException(status_code=404, detail="任务ID不存在")
|
||||
|
||||
return TaskStatusResponse(
|
||||
task_id=task_id,
|
||||
status=TASK_STATUS[task_id]["status"],
|
||||
image_url=TASK_STATUS[task_id]["image_url"],
|
||||
progress=TASK_STATUS[task_id]["progress"],
|
||||
error=TASK_STATUS[task_id]["error"]
|
||||
)
|
@@ -22,6 +22,7 @@ from Routes.QA import router as qa_router
|
||||
from Routes.JiMengRoute import router as jimeng_router
|
||||
from Routes.SunoRoute import router as suno_router
|
||||
from Routes.XueBanRoute import router as xueban_router
|
||||
from Routes.MjRoute import router as mj_router
|
||||
from Util.LightRagUtil import *
|
||||
from contextlib import asynccontextmanager
|
||||
|
||||
@@ -63,6 +64,7 @@ app.include_router(qa_router) # 答疑路由
|
||||
app.include_router(jimeng_router) # 即梦路由
|
||||
app.include_router(suno_router) # Suno路由
|
||||
app.include_router(xueban_router) # 学伴路由
|
||||
app.include_router(mj_router) # Midjourney路由
|
||||
|
||||
# Teaching Model 相关路由
|
||||
# 登录相关(不用登录)
|
||||
|
Reference in New Issue
Block a user