diff --git a/dsAi/src/main/python/com/dsideal/Ai/Util/Midjourney/Kit/MjCommon.py b/dsAi/src/main/python/com/dsideal/Ai/Util/Midjourney/Kit/MjCommon.py new file mode 100644 index 00000000..f85f463e --- /dev/null +++ b/dsAi/src/main/python/com/dsideal/Ai/Util/Midjourney/Kit/MjCommon.py @@ -0,0 +1,111 @@ +import os +import logging +import requests +import json +import time +from pathlib import Path + +# 配置日志 +logging.basicConfig(level=logging.INFO, format='%(asctime)s - %(name)s - %(levelname)s - %(message)s') +logger = logging.getLogger('MjCommon') + + +class MjCommon: + # 获取项目根目录路径 + project_root = os.path.abspath(os.path.join(os.path.dirname(__file__), "..", "..", "..", "..", "..", "..")) + # 拼接相对路径 + base_path = os.path.join(project_root, "src", "main", "python", "com", "dsideal", "aiSupport", "Util", "Midjourney", "Example") + + ak = "sk-amQHwiEzPIZIB2KuF5A10dC23a0e4b02B48a7a2b6aFa0662" # 填写access key + BASE_URL = "https://goapi.gptnb.ai" + + @staticmethod + def download_file(file_url, save_file_path): + """从URL下载文件到指定路径 + + Args: + file_url (str): 文件URL + save_file_path (str): 保存路径 + + Raises: + Exception: 下载过程中的异常 + """ + try: + # 确保目录存在 + os.makedirs(os.path.dirname(save_file_path), exist_ok=True) + + # 下载文件 + logger.info(f"开始下载文件: {file_url}") + response = requests.get(file_url, stream=True, timeout=30) + response.raise_for_status() + + # 保存文件 + with open(save_file_path, 'wb') as f: + for chunk in response.iter_content(chunk_size=8192): + f.write(chunk) + + file_size = os.path.getsize(save_file_path) + logger.info(f"文件下载成功,保存路径: {save_file_path}, 文件大小: {file_size}字节") + except Exception as e: + logger.error(f"文件下载失败: {str(e)}") + raise Exception(f"文件下载失败: {str(e)}") + + @staticmethod + def query_task_status(task_id): + """查询任务状态 + + Args: + task_id (str): 任务ID + + Returns: + dict: 任务结果 + + Raises: + Exception: 异常信息 + """ + # 创建请求URL + url = f"{MjCommon.BASE_URL}/mj/task/{task_id}/fetch" + + # 设置请求头 + headers = { + "Authorization": f"Bearer {MjCommon.ak}" + } + + try: + logger.info(f"查询Midjourney任务状态: {task_id}") + response = requests.get(url, headers=headers, timeout=30) + + # 检查响应状态 + if not response.ok: + error_msg = f"Midjourney API请求失败,状态码: {response.status_code}" + logger.error(error_msg) + raise Exception(error_msg) + + # 解析响应 + response_body = response.text + logger.info(f"查询Midjourney任务状态响应: {response_body}") + + return json.loads(response_body) + except Exception as e: + logger.error(f"查询任务状态失败: {str(e)}") + raise Exception(f"查询任务状态失败: {str(e)}") + + +# 测试代码 +if __name__ == "__main__": + # 测试下载文件 + # try: + # test_url = "https://example.com/test.jpg" + # test_save_path = os.path.join(MjCommon.base_path, "test.jpg") + # MjCommon.download_file(test_url, test_save_path) + # except Exception as e: + # print(f"测试下载失败: {e}") + + # 测试查询任务状态 + # try: + # test_task_id = "your_task_id_here" + # result = MjCommon.query_task_status(test_task_id) + # print(f"任务状态查询结果: {result}") + # except Exception as e: + # print(f"测试查询任务状态失败: {e}") + pass \ No newline at end of file diff --git a/dsAi/src/main/python/com/dsideal/Ai/Util/Midjourney/Txt2Img.py b/dsAi/src/main/python/com/dsideal/Ai/Util/Midjourney/Txt2Img.py new file mode 100644 index 00000000..98ed9c3f --- /dev/null +++ b/dsAi/src/main/python/com/dsideal/Ai/Util/Midjourney/Txt2Img.py @@ -0,0 +1,150 @@ +import os +import time +import json +import logging +import requests +from com.dsideal.Ai.Util.Midjourney.Kit.MjCommon import MjCommon + +# 配置日志 +logging.basicConfig(level=logging.INFO, format='%(asctime)s - %(name)s - %(levelname)s - %(message)s') +log = logging.getLogger(__name__) + +class Txt2Img(MjCommon): + """ + Midjourney API 工具类 + 用于调用 Midjourney 的 imagine 接口生成图片 + """ + + @staticmethod + def submit_imagine(prompt, base64_array=None, notify_hook=None): + """ + 提交 imagine 请求 + + :param prompt: 提示词 + :param base64_array: 垫图base64数组 + :param notify_hook: 回调地址 + :return: 任务ID + :raises Exception: 异常信息 + """ + try: + # 构建请求体 + request_body = { + "prompt": prompt, + "state": f"midjourney_task_{int(time.time() * 1000)}" + } + + # 如果提供了垫图base64数组,则添加到请求体中 + if base64_array and len(base64_array) > 0: + request_body["base64Array"] = base64_array + + # 如果提供了回调地址,则添加到请求体中 + if notify_hook and notify_hook.strip(): + request_body["notifyHook"] = notify_hook + + # 创建请求 + url = f"{MjCommon.BASE_URL}/mj/submit/imagine" + headers = { + "Content-Type": "application/json", + "Authorization": f"Bearer {MjCommon.ak}" + } + + # 发送请求并获取响应 + log.info(f"提交Midjourney imagine请求: {json.dumps(request_body)}") + response = requests.post(url, headers=headers, json=request_body, timeout=30) + + # 检查响应状态 + if response.status_code != 200: + error_msg = f"Midjourney API请求失败,状态码: {response.status_code}" + log.error(error_msg) + raise Exception(error_msg) + + # 解析响应 + response_body = response.text + log.info(f"Midjourney imagine响应: {response_body}") + + response_json = json.loads(response_body) + + # 检查响应状态 - code=1 表示成功 + if response_json.get("code") != 1: + error_msg = f"Midjourney imagine失败: {response_json.get('description', '未知错误')}" + log.error(error_msg) + raise Exception(error_msg) + + # 获取任务ID - 从result字段获取 + task_id = response_json.get("result") + log.info(f"Midjourney imagine任务ID: {task_id}") + + return task_id + except Exception as e: + log.error(f"提交imagine请求异常: {str(e)}") + raise + + @staticmethod + def main(): + """测试方法""" + try: + # 提示词 + prompt = "A cute cat playing with a ball of yarn, digital art style" + + # 提交imagine请求 + task_id = Txt2Img.submit_imagine(prompt) + + # 轮询查询任务状态 + max_retries = 1000 + retry_count = 0 + retry_interval = 5 # 5秒 + image_url = None + + while retry_count < max_retries: + result = MjCommon.query_task_status(task_id) + + # 直接使用响应中的字段 + status = result.get("status") + log.info(f"任务状态: {status}") + + # 检查进度 + progress = result.get("progress") + log.info(f"任务进度: {progress}") + + # 任务状态可能为空字符串,需要检查progress或其他字段来判断任务是否完成 + if status and status.strip(): + if status == "SUCCESS": + # 任务成功,获取图片URL + image_url = result.get("imageUrl") + log.info(f"生成的图片URL: {image_url}") + break + elif status == "FAILED": + # 任务失败 + log.error(f"任务失败: {result.get('failReason', '未知原因')}") + break + else: + # 检查description字段 + description = result.get("description") + if description and "成功" in description and progress != "0%": + # 如果描述包含"成功"且进度不为0%,可能任务已完成 + image_url = result.get("imageUrl") + if image_url and image_url.strip(): + log.info(f"生成的图片URL: {image_url}") + break + + # 任务仍在进行中,等待后重试 + log.info(f"任务进行中,等待{retry_interval}秒后重试...") + time.sleep(retry_interval) + retry_count += 1 + + if retry_count >= max_retries: + log.error(f"查询任务状态超时,已达到最大重试次数: {max_retries}") + + # 如果获取到了图片URL,则下载保存 + if image_url and image_url.strip(): + # 生成文件名(使用时间戳和任务ID) + file_name = f"mj_{int(time.time() * 1000)}_{task_id}.png" + # 完整保存路径 + save_path = os.path.join(MjCommon.base_path, file_name) + # 下载图片 + MjCommon.download_file(image_url, save_path) + except Exception as e: + log.error(f"程序执行异常: {str(e)}") + +if __name__ == "__main__": + Txt2Img.main() \ No newline at end of file diff --git a/dsLightRag/Config/Config.py b/dsLightRag/Config/Config.py index f2d9d377..68c321bd 100644 --- a/dsLightRag/Config/Config.py +++ b/dsLightRag/Config/Config.py @@ -70,6 +70,7 @@ ZHIPU_API_KEY = "78dc1dfe37e04f29bd4ca9a49858a969.gn7TIZTfzpY35nx9" # GPTNB的API KEY GPTNB_API_KEY = "sk-amQHwiEzPIZIB2KuF5A10dC23a0e4b02B48a7a2b6aFa0662" +GPTNB_BASE_URL="https://goapi.gptnb.ai" # JWT配置信息 JWT_SECRET_KEY = "ZXZnZWVr5b+r5LmQ5L2g55qE5Ye66KGM" diff --git a/dsLightRag/Midjourney/Kit/MjCommon.py b/dsLightRag/Midjourney/Kit/MjCommon.py new file mode 100644 index 00000000..625b276c --- /dev/null +++ b/dsLightRag/Midjourney/Kit/MjCommon.py @@ -0,0 +1,113 @@ +import os +import logging +import requests +import json +import time +from pathlib import Path + +from Config import Config + +# 配置日志 +logging.basicConfig(level=logging.INFO, format='%(asctime)s - %(name)s - %(levelname)s - %(message)s') +logger = logging.getLogger('MjCommon') + + +class MjCommon: + # 获取项目根目录路径 + project_root = os.path.abspath(os.path.join(os.path.dirname(__file__), "..", "..", "..", "..", "..", "..")) + # 拼接相对路径 + base_path = os.path.join(project_root, "src", "main", "python", "com", "dsideal", "aiSupport", "Util", "Midjourney", "Example") + + ak = Config.GPTNB_API_KEY + BASE_URL = Config.GPTNB_BASE_URL + + @staticmethod + def download_file(file_url, save_file_path): + """从URL下载文件到指定路径 + + Args: + file_url (str): 文件URL + save_file_path (str): 保存路径 + + Raises: + Exception: 下载过程中的异常 + """ + try: + # 确保目录存在 + os.makedirs(os.path.dirname(save_file_path), exist_ok=True) + + # 下载文件 + logger.info(f"开始下载文件: {file_url}") + response = requests.get(file_url, stream=True, timeout=30) + response.raise_for_status() + + # 保存文件 + with open(save_file_path, 'wb') as f: + for chunk in response.iter_content(chunk_size=8192): + f.write(chunk) + + file_size = os.path.getsize(save_file_path) + logger.info(f"文件下载成功,保存路径: {save_file_path}, 文件大小: {file_size}字节") + except Exception as e: + logger.error(f"文件下载失败: {str(e)}") + raise Exception(f"文件下载失败: {str(e)}") + + @staticmethod + def query_task_status(task_id): + """查询任务状态 + + Args: + task_id (str): 任务ID + + Returns: + dict: 任务结果 + + Raises: + Exception: 异常信息 + """ + # 创建请求URL + url = f"{MjCommon.BASE_URL}/mj/task/{task_id}/fetch" + + # 设置请求头 + headers = { + "Authorization": f"Bearer {MjCommon.ak}" + } + + try: + logger.info(f"查询Midjourney任务状态: {task_id}") + response = requests.get(url, headers=headers, timeout=30) + + # 检查响应状态 + if not response.ok: + error_msg = f"Midjourney API请求失败,状态码: {response.status_code}" + logger.error(error_msg) + raise Exception(error_msg) + + # 解析响应 + response_body = response.text + logger.info(f"查询Midjourney任务状态响应: {response_body}") + + return json.loads(response_body) + except Exception as e: + logger.error(f"查询任务状态失败: {str(e)}") + raise Exception(f"查询任务状态失败: {str(e)}") + + +# 测试代码 +if __name__ == "__main__": + # 测试下载文件 + # try: + # test_url = "https://example.com/test.jpg" + # test_save_path = os.path.join(MjCommon.base_path, "test.jpg") + # MjCommon.download_file(test_url, test_save_path) + # except Exception as e: + # print(f"测试下载失败: {e}") + + # 测试查询任务状态 + # try: + # test_task_id = "your_task_id_here" + # result = MjCommon.query_task_status(test_task_id) + # print(f"任务状态查询结果: {result}") + # except Exception as e: + # print(f"测试查询任务状态失败: {e}") + pass \ No newline at end of file diff --git a/dsLightRag/Midjourney/Txt2Img.py b/dsLightRag/Midjourney/Txt2Img.py new file mode 100644 index 00000000..426e0938 --- /dev/null +++ b/dsLightRag/Midjourney/Txt2Img.py @@ -0,0 +1,151 @@ +import os +import time +import json +import logging +import requests + +from Midjourney.Kit.MjCommon import MjCommon + +# 配置日志 +logging.basicConfig(level=logging.INFO, format='%(asctime)s - %(name)s - %(levelname)s - %(message)s') +log = logging.getLogger(__name__) + +class Txt2Img(MjCommon): + """ + Midjourney API 工具类 + 用于调用 Midjourney 的 imagine 接口生成图片 + """ + + @staticmethod + def submit_imagine(prompt, base64_array=None, notify_hook=None): + """ + 提交 imagine 请求 + + :param prompt: 提示词 + :param base64_array: 垫图base64数组 + :param notify_hook: 回调地址 + :return: 任务ID + :raises Exception: 异常信息 + """ + try: + # 构建请求体 + request_body = { + "prompt": prompt, + "state": f"midjourney_task_{int(time.time() * 1000)}" + } + + # 如果提供了垫图base64数组,则添加到请求体中 + if base64_array and len(base64_array) > 0: + request_body["base64Array"] = base64_array + + # 如果提供了回调地址,则添加到请求体中 + if notify_hook and notify_hook.strip(): + request_body["notifyHook"] = notify_hook + + # 创建请求 + url = f"{MjCommon.BASE_URL}/mj/submit/imagine" + headers = { + "Content-Type": "application/json", + "Authorization": f"Bearer {MjCommon.ak}" + } + + # 发送请求并获取响应 + log.info(f"提交Midjourney imagine请求: {json.dumps(request_body)}") + response = requests.post(url, headers=headers, json=request_body, timeout=30) + + # 检查响应状态 + if response.status_code != 200: + error_msg = f"Midjourney API请求失败,状态码: {response.status_code}" + log.error(error_msg) + raise Exception(error_msg) + + # 解析响应 + response_body = response.text + log.info(f"Midjourney imagine响应: {response_body}") + + response_json = json.loads(response_body) + + # 检查响应状态 - code=1 表示成功 + if response_json.get("code") != 1: + error_msg = f"Midjourney imagine失败: {response_json.get('description', '未知错误')}" + log.error(error_msg) + raise Exception(error_msg) + + # 获取任务ID - 从result字段获取 + task_id = response_json.get("result") + log.info(f"Midjourney imagine任务ID: {task_id}") + + return task_id + except Exception as e: + log.error(f"提交imagine请求异常: {str(e)}") + raise + + @staticmethod + def main(): + """测试方法""" + try: + # 提示词 + prompt = "A cute cat playing with a ball of yarn, digital art style" + + # 提交imagine请求 + task_id = Txt2Img.submit_imagine(prompt) + + # 轮询查询任务状态 + max_retries = 1000 + retry_count = 0 + retry_interval = 5 # 5秒 + image_url = None + + while retry_count < max_retries: + result = MjCommon.query_task_status(task_id) + + # 直接使用响应中的字段 + status = result.get("status") + log.info(f"任务状态: {status}") + + # 检查进度 + progress = result.get("progress") + log.info(f"任务进度: {progress}") + + # 任务状态可能为空字符串,需要检查progress或其他字段来判断任务是否完成 + if status and status.strip(): + if status == "SUCCESS": + # 任务成功,获取图片URL + image_url = result.get("imageUrl") + log.info(f"生成的图片URL: {image_url}") + break + elif status == "FAILED": + # 任务失败 + log.error(f"任务失败: {result.get('failReason', '未知原因')}") + break + else: + # 检查description字段 + description = result.get("description") + if description and "成功" in description and progress != "0%": + # 如果描述包含"成功"且进度不为0%,可能任务已完成 + image_url = result.get("imageUrl") + if image_url and image_url.strip(): + log.info(f"生成的图片URL: {image_url}") + break + + # 任务仍在进行中,等待后重试 + log.info(f"任务进行中,等待{retry_interval}秒后重试...") + time.sleep(retry_interval) + retry_count += 1 + + if retry_count >= max_retries: + log.error(f"查询任务状态超时,已达到最大重试次数: {max_retries}") + + # 如果获取到了图片URL,则下载保存 + if image_url and image_url.strip(): + # 生成文件名(使用时间戳和任务ID) + file_name = f"mj_{int(time.time() * 1000)}_{task_id}.png" + # 完整保存路径 + save_path = os.path.join(MjCommon.base_path, file_name) + # 下载图片 + MjCommon.download_file(image_url, save_path) + except Exception as e: + log.error(f"程序执行异常: {str(e)}") + +if __name__ == "__main__": + Txt2Img.main() \ No newline at end of file diff --git a/dsLightRag/Midjourney/__init__.py b/dsLightRag/Midjourney/__init__.py new file mode 100644 index 00000000..e69de29b diff --git a/dsLightRag/Suno/sunoUtil.py b/dsLightRag/Suno/sunoUtil.py index 464260cf..23a005a5 100644 --- a/dsLightRag/Suno/sunoUtil.py +++ b/dsLightRag/Suno/sunoUtil.py @@ -13,8 +13,8 @@ log = logging.getLogger('SunoMusicGenerator') class SunoMusicGenerator: - GENERATE_URL = "https://goapi.gptnb.ai/suno/v2/generate" - FEED_URL = "https://goapi.gptnb.ai/suno/v2/feed" + GENERATE_URL =f"{Config.GPTNB_BASE_URL}/suno/v2/generate" + FEED_URL = f"{Config.GPTNB_BASE_URL}/suno/v2/feed" MAX_RETRIES = 30 # 最大重试次数 RETRY_INTERVAL = 5000 # 重试间隔(毫秒) diff --git a/dsLightRag/Util/GoApiUtil.py b/dsLightRag/Util/GoApiUtil.py index 5955e830..67054a6a 100644 --- a/dsLightRag/Util/GoApiUtil.py +++ b/dsLightRag/Util/GoApiUtil.py @@ -1,9 +1,11 @@ import json import requests + +from Config import Config from Config.Config import GPTNB_API_KEY class ModelInteractor: - def __init__(self, api_key=GPTNB_API_KEY, api_url="https://goapi.gptnb.ai/v1/chat/completions"): # 修复URL + def __init__(self, api_key=GPTNB_API_KEY, api_url=f"{Config.GPTNB_BASE_URL}/v1/chat/completions"): # 修复URL self.api_key = api_key self.api_url = api_url self.headers = {