From 71a67388fea54b4d95987ed2f1ef022b8c827d9e Mon Sep 17 00:00:00 2001 From: HuangHai <10402852@qq.com> Date: Mon, 25 Aug 2025 14:22:53 +0800 Subject: [PATCH] 'commit' --- .../Ai/Util/Midjourney/Kit/MjCommon.py | 111 ------------- .../com/dsideal/Ai/Util/Midjourney/Txt2Img.py | 150 ------------------ dsLightRag/Midjourney/Txt2Img.py | 47 ++++++ dsLightRag/Routes/MjRoute.py | 100 ++++++++++++ dsLightRag/Start.py | 4 +- 5 files changed, 150 insertions(+), 262 deletions(-) delete mode 100644 dsAi/src/main/python/com/dsideal/Ai/Util/Midjourney/Kit/MjCommon.py delete mode 100644 dsAi/src/main/python/com/dsideal/Ai/Util/Midjourney/Txt2Img.py create mode 100644 dsLightRag/Routes/MjRoute.py diff --git a/dsAi/src/main/python/com/dsideal/Ai/Util/Midjourney/Kit/MjCommon.py b/dsAi/src/main/python/com/dsideal/Ai/Util/Midjourney/Kit/MjCommon.py deleted file mode 100644 index f85f463e..00000000 --- a/dsAi/src/main/python/com/dsideal/Ai/Util/Midjourney/Kit/MjCommon.py +++ /dev/null @@ -1,111 +0,0 @@ -import os -import logging -import requests -import json -import time -from pathlib import Path - -# 配置日志 -logging.basicConfig(level=logging.INFO, format='%(asctime)s - %(name)s - %(levelname)s - %(message)s') -logger = logging.getLogger('MjCommon') - - -class MjCommon: - # 获取项目根目录路径 - project_root = os.path.abspath(os.path.join(os.path.dirname(__file__), "..", "..", "..", "..", "..", "..")) - # 拼接相对路径 - base_path = os.path.join(project_root, "src", "main", "python", "com", "dsideal", "aiSupport", "Util", "Midjourney", "Example") - - ak = "sk-amQHwiEzPIZIB2KuF5A10dC23a0e4b02B48a7a2b6aFa0662" # 填写access key - BASE_URL = "https://goapi.gptnb.ai" - - @staticmethod - def download_file(file_url, save_file_path): - """从URL下载文件到指定路径 - - Args: - file_url (str): 文件URL - save_file_path (str): 保存路径 - - Raises: - Exception: 下载过程中的异常 - """ - try: - # 确保目录存在 - os.makedirs(os.path.dirname(save_file_path), exist_ok=True) - - # 下载文件 - logger.info(f"开始下载文件: {file_url}") - response = requests.get(file_url, stream=True, timeout=30) - response.raise_for_status() - - # 保存文件 - with open(save_file_path, 'wb') as f: - for chunk in response.iter_content(chunk_size=8192): - f.write(chunk) - - file_size = os.path.getsize(save_file_path) - logger.info(f"文件下载成功,保存路径: {save_file_path}, 文件大小: {file_size}字节") - except Exception as e: - logger.error(f"文件下载失败: {str(e)}") - raise Exception(f"文件下载失败: {str(e)}") - - @staticmethod - def query_task_status(task_id): - """查询任务状态 - - Args: - task_id (str): 任务ID - - Returns: - dict: 任务结果 - - Raises: - Exception: 异常信息 - """ - # 创建请求URL - url = f"{MjCommon.BASE_URL}/mj/task/{task_id}/fetch" - - # 设置请求头 - headers = { - "Authorization": f"Bearer {MjCommon.ak}" - } - - try: - logger.info(f"查询Midjourney任务状态: {task_id}") - response = requests.get(url, headers=headers, timeout=30) - - # 检查响应状态 - if not response.ok: - error_msg = f"Midjourney API请求失败,状态码: {response.status_code}" - logger.error(error_msg) - raise Exception(error_msg) - - # 解析响应 - response_body = response.text - logger.info(f"查询Midjourney任务状态响应: {response_body}") - - return json.loads(response_body) - except Exception as e: - logger.error(f"查询任务状态失败: {str(e)}") - raise Exception(f"查询任务状态失败: {str(e)}") - - -# 测试代码 -if __name__ == "__main__": - # 测试下载文件 - # try: - # test_url = "https://example.com/test.jpg" - # test_save_path = os.path.join(MjCommon.base_path, "test.jpg") - # MjCommon.download_file(test_url, test_save_path) - # except Exception as e: - # print(f"测试下载失败: {e}") - - # 测试查询任务状态 - # try: - # test_task_id = "your_task_id_here" - # result = MjCommon.query_task_status(test_task_id) - # print(f"任务状态查询结果: {result}") - # except Exception as e: - # print(f"测试查询任务状态失败: {e}") - pass \ No newline at end of file diff --git a/dsAi/src/main/python/com/dsideal/Ai/Util/Midjourney/Txt2Img.py b/dsAi/src/main/python/com/dsideal/Ai/Util/Midjourney/Txt2Img.py deleted file mode 100644 index 98ed9c3f..00000000 --- a/dsAi/src/main/python/com/dsideal/Ai/Util/Midjourney/Txt2Img.py +++ /dev/null @@ -1,150 +0,0 @@ -import os -import time -import json -import logging -import requests -from com.dsideal.Ai.Util.Midjourney.Kit.MjCommon import MjCommon - -# 配置日志 -logging.basicConfig(level=logging.INFO, format='%(asctime)s - %(name)s - %(levelname)s - %(message)s') -log = logging.getLogger(__name__) - -class Txt2Img(MjCommon): - """ - Midjourney API 工具类 - 用于调用 Midjourney 的 imagine 接口生成图片 - """ - - @staticmethod - def submit_imagine(prompt, base64_array=None, notify_hook=None): - """ - 提交 imagine 请求 - - :param prompt: 提示词 - :param base64_array: 垫图base64数组 - :param notify_hook: 回调地址 - :return: 任务ID - :raises Exception: 异常信息 - """ - try: - # 构建请求体 - request_body = { - "prompt": prompt, - "state": f"midjourney_task_{int(time.time() * 1000)}" - } - - # 如果提供了垫图base64数组,则添加到请求体中 - if base64_array and len(base64_array) > 0: - request_body["base64Array"] = base64_array - - # 如果提供了回调地址,则添加到请求体中 - if notify_hook and notify_hook.strip(): - request_body["notifyHook"] = notify_hook - - # 创建请求 - url = f"{MjCommon.BASE_URL}/mj/submit/imagine" - headers = { - "Content-Type": "application/json", - "Authorization": f"Bearer {MjCommon.ak}" - } - - # 发送请求并获取响应 - log.info(f"提交Midjourney imagine请求: {json.dumps(request_body)}") - response = requests.post(url, headers=headers, json=request_body, timeout=30) - - # 检查响应状态 - if response.status_code != 200: - error_msg = f"Midjourney API请求失败,状态码: {response.status_code}" - log.error(error_msg) - raise Exception(error_msg) - - # 解析响应 - response_body = response.text - log.info(f"Midjourney imagine响应: {response_body}") - - response_json = json.loads(response_body) - - # 检查响应状态 - code=1 表示成功 - if response_json.get("code") != 1: - error_msg = f"Midjourney imagine失败: {response_json.get('description', '未知错误')}" - log.error(error_msg) - raise Exception(error_msg) - - # 获取任务ID - 从result字段获取 - task_id = response_json.get("result") - log.info(f"Midjourney imagine任务ID: {task_id}") - - return task_id - except Exception as e: - log.error(f"提交imagine请求异常: {str(e)}") - raise - - @staticmethod - def main(): - """测试方法""" - try: - # 提示词 - prompt = "A cute cat playing with a ball of yarn, digital art style" - - # 提交imagine请求 - task_id = Txt2Img.submit_imagine(prompt) - - # 轮询查询任务状态 - max_retries = 1000 - retry_count = 0 - retry_interval = 5 # 5秒 - image_url = None - - while retry_count < max_retries: - result = MjCommon.query_task_status(task_id) - - # 直接使用响应中的字段 - status = result.get("status") - log.info(f"任务状态: {status}") - - # 检查进度 - progress = result.get("progress") - log.info(f"任务进度: {progress}") - - # 任务状态可能为空字符串,需要检查progress或其他字段来判断任务是否完成 - if status and status.strip(): - if status == "SUCCESS": - # 任务成功,获取图片URL - image_url = result.get("imageUrl") - log.info(f"生成的图片URL: {image_url}") - break - elif status == "FAILED": - # 任务失败 - log.error(f"任务失败: {result.get('failReason', '未知原因')}") - break - else: - # 检查description字段 - description = result.get("description") - if description and "成功" in description and progress != "0%": - # 如果描述包含"成功"且进度不为0%,可能任务已完成 - image_url = result.get("imageUrl") - if image_url and image_url.strip(): - log.info(f"生成的图片URL: {image_url}") - break - - # 任务仍在进行中,等待后重试 - log.info(f"任务进行中,等待{retry_interval}秒后重试...") - time.sleep(retry_interval) - retry_count += 1 - - if retry_count >= max_retries: - log.error(f"查询任务状态超时,已达到最大重试次数: {max_retries}") - - # 如果获取到了图片URL,则下载保存 - if image_url and image_url.strip(): - # 生成文件名(使用时间戳和任务ID) - file_name = f"mj_{int(time.time() * 1000)}_{task_id}.png" - # 完整保存路径 - save_path = os.path.join(MjCommon.base_path, file_name) - # 下载图片 - MjCommon.download_file(image_url, save_path) - except Exception as e: - log.error(f"程序执行异常: {str(e)}") - -if __name__ == "__main__": - Txt2Img.main() \ No newline at end of file diff --git a/dsLightRag/Midjourney/Txt2Img.py b/dsLightRag/Midjourney/Txt2Img.py index c2d6e04d..4629279a 100644 --- a/dsLightRag/Midjourney/Txt2Img.py +++ b/dsLightRag/Midjourney/Txt2Img.py @@ -147,5 +147,52 @@ class Txt2Img(MjCommon): except Exception as e: log.error(f"程序执行异常: {str(e)}") +@staticmethod +def query_task_status(task_id): + """ + 查询任务状态 + + :param task_id: 任务ID + :return: 任务状态信息 + :raises Exception: 异常信息 + """ + try: + # 创建请求 + url = f"{MjCommon.BASE_URL}/mj/task/status" + headers = { + "Content-Type": "application/json", + "Authorization": f"Bearer {MjCommon.ak}" + } + params = { + "taskId": task_id + } + + # 发送请求并获取响应 + log.info(f"查询Midjourney任务状态: {task_id}") + response = requests.get(url, headers=headers, params=params, timeout=30) + + # 检查响应状态 + if response.status_code != 200: + error_msg = f"Midjourney API请求失败,状态码: {response.status_code}" + log.error(error_msg) + raise Exception(error_msg) + + # 解析响应 + response_body = response.text + log.info(f"Midjourney任务状态响应: {response_body}") + + response_json = json.loads(response_body) + + # 检查响应状态 - code=1 表示成功 + if response_json.get("code") != 1: + error_msg = f"Midjourney任务状态查询失败: {response_json.get('description', '未知错误')}" + log.error(error_msg) + raise Exception(error_msg) + + # 返回任务状态信息 + return response_json.get("result", {}) + except Exception as e: + log.error(f"查询任务状态异常: {str(e)}") + raise if __name__ == "__main__": Txt2Img.main() \ No newline at end of file diff --git a/dsLightRag/Routes/MjRoute.py b/dsLightRag/Routes/MjRoute.py new file mode 100644 index 00000000..2dfd3897 --- /dev/null +++ b/dsLightRag/Routes/MjRoute.py @@ -0,0 +1,100 @@ +import logging +import json +import time +import datetime +import requests +import uuid +import os +from typing import Optional, Dict, Any +from fastapi import APIRouter, HTTPException, Query, BackgroundTasks +from pydantic import BaseModel +from Midjourney.Txt2Img import Txt2Img +from Config import Config +import asyncio +from fastapi.responses import StreamingResponse + +# 创建路由路由器 +router = APIRouter(prefix="/api/mj", tags=["文生图"]) + +# 配置日志 +logger = logging.getLogger(__name__) + +# 任务状态存储 +TASK_STATUS: Dict[str, Dict[str, Any]] = {} + + +class ImagineRequest(BaseModel): + prompt: str + base64_array: Optional[list] = None + notify_hook: Optional[str] = None + + +class TaskStatusResponse(BaseModel): + task_id: str + status: str + image_url: Optional[str] = None + progress: Optional[int] = None + error: Optional[str] = None + + +@router.post("/imagine", response_model=Dict[str, str]) +async def submit_imagine(request: ImagineRequest, background_tasks: BackgroundTasks): + """ + 提交文生图请求 + + Args: + request: 包含提示词、垫图和回调地址的请求体 + background_tasks: 用于异步处理任务状态查询的后台任务 + + Returns: + 包含任务ID的字典 + """ + try: + # 生成唯一任务ID + task_id = str(uuid.uuid4()) + logger.info(f"收到文生图请求,任务ID: {task_id}, 提示词: {request.prompt}") + + # 初始化任务状态 + TASK_STATUS[task_id] = { + "status": "pending", + "image_url": None, + "progress": 0, + "error": None + } + + # 提交到Midjourney + midjourney_task_id = Txt2Img.submit_imagine( + prompt=request.prompt, + base64_array=request.base64_array, + notify_hook=request.notify_hook + ) + + # 存储Midjourney任务ID + TASK_STATUS[task_id]["midjourney_task_id"] = midjourney_task_id + return {"task_id": task_id} + except Exception as e: + logger.error(f"提交文生图请求失败: {str(e)}") + raise HTTPException(status_code=500, detail=f"提交文生图请求失败: {str(e)}") + + +@router.get("/task_status", response_model=TaskStatusResponse) +async def get_task_status(task_id: str = Query(..., description="任务ID")): + """ + 查询文生图任务状态 + + Args: + task_id: 任务ID + + Returns: + 任务状态信息 + """ + if task_id not in TASK_STATUS: + raise HTTPException(status_code=404, detail="任务ID不存在") + + return TaskStatusResponse( + task_id=task_id, + status=TASK_STATUS[task_id]["status"], + image_url=TASK_STATUS[task_id]["image_url"], + progress=TASK_STATUS[task_id]["progress"], + error=TASK_STATUS[task_id]["error"] + ) diff --git a/dsLightRag/Start.py b/dsLightRag/Start.py index 045df23b..6307577f 100644 --- a/dsLightRag/Start.py +++ b/dsLightRag/Start.py @@ -22,6 +22,7 @@ from Routes.QA import router as qa_router from Routes.JiMengRoute import router as jimeng_router from Routes.SunoRoute import router as suno_router from Routes.XueBanRoute import router as xueban_router +from Routes.MjRoute import router as mj_router from Util.LightRagUtil import * from contextlib import asynccontextmanager @@ -62,7 +63,8 @@ app.include_router(llm_router) # 大模型路由 app.include_router(qa_router) # 答疑路由 app.include_router(jimeng_router) # 即梦路由 app.include_router(suno_router) # Suno路由 -app.include_router(xueban_router) # 学伴路由 +app.include_router(xueban_router) # 学伴路由 +app.include_router(mj_router) # Midjourney路由 # Teaching Model 相关路由 # 登录相关(不用登录)