This commit is contained in:
2025-08-25 14:01:48 +08:00
parent 3dd85cd80d
commit 5e531d565b
8 changed files with 531 additions and 3 deletions

View File

@@ -0,0 +1,111 @@
import os
import logging
import requests
import json
import time
from pathlib import Path
# 配置日志
logging.basicConfig(level=logging.INFO, format='%(asctime)s - %(name)s - %(levelname)s - %(message)s')
logger = logging.getLogger('MjCommon')
class MjCommon:
# 获取项目根目录路径
project_root = os.path.abspath(os.path.join(os.path.dirname(__file__), "..", "..", "..", "..", "..", ".."))
# 拼接相对路径
base_path = os.path.join(project_root, "src", "main", "python", "com", "dsideal", "aiSupport", "Util", "Midjourney", "Example")
ak = "sk-amQHwiEzPIZIB2KuF5A10dC23a0e4b02B48a7a2b6aFa0662" # 填写access key
BASE_URL = "https://goapi.gptnb.ai"
@staticmethod
def download_file(file_url, save_file_path):
"""从URL下载文件到指定路径
Args:
file_url (str): 文件URL
save_file_path (str): 保存路径
Raises:
Exception: 下载过程中的异常
"""
try:
# 确保目录存在
os.makedirs(os.path.dirname(save_file_path), exist_ok=True)
# 下载文件
logger.info(f"开始下载文件: {file_url}")
response = requests.get(file_url, stream=True, timeout=30)
response.raise_for_status()
# 保存文件
with open(save_file_path, 'wb') as f:
for chunk in response.iter_content(chunk_size=8192):
f.write(chunk)
file_size = os.path.getsize(save_file_path)
logger.info(f"文件下载成功,保存路径: {save_file_path}, 文件大小: {file_size}字节")
except Exception as e:
logger.error(f"文件下载失败: {str(e)}")
raise Exception(f"文件下载失败: {str(e)}")
@staticmethod
def query_task_status(task_id):
"""查询任务状态
Args:
task_id (str): 任务ID
Returns:
dict: 任务结果
Raises:
Exception: 异常信息
"""
# 创建请求URL
url = f"{MjCommon.BASE_URL}/mj/task/{task_id}/fetch"
# 设置请求头
headers = {
"Authorization": f"Bearer {MjCommon.ak}"
}
try:
logger.info(f"查询Midjourney任务状态: {task_id}")
response = requests.get(url, headers=headers, timeout=30)
# 检查响应状态
if not response.ok:
error_msg = f"Midjourney API请求失败状态码: {response.status_code}"
logger.error(error_msg)
raise Exception(error_msg)
# 解析响应
response_body = response.text
logger.info(f"查询Midjourney任务状态响应: {response_body}")
return json.loads(response_body)
except Exception as e:
logger.error(f"查询任务状态失败: {str(e)}")
raise Exception(f"查询任务状态失败: {str(e)}")
# 测试代码
if __name__ == "__main__":
# 测试下载文件
# try:
# test_url = "https://example.com/test.jpg"
# test_save_path = os.path.join(MjCommon.base_path, "test.jpg")
# MjCommon.download_file(test_url, test_save_path)
# except Exception as e:
# print(f"测试下载失败: {e}")
# 测试查询任务状态
# try:
# test_task_id = "your_task_id_here"
# result = MjCommon.query_task_status(test_task_id)
# print(f"任务状态查询结果: {result}")
# except Exception as e:
# print(f"测试查询任务状态失败: {e}")
pass

View File

@@ -0,0 +1,150 @@
import os
import time
import json
import logging
import requests
from com.dsideal.Ai.Util.Midjourney.Kit.MjCommon import MjCommon
# 配置日志
logging.basicConfig(level=logging.INFO, format='%(asctime)s - %(name)s - %(levelname)s - %(message)s')
log = logging.getLogger(__name__)
class Txt2Img(MjCommon):
"""
Midjourney API 工具类
用于调用 Midjourney 的 imagine 接口生成图片
"""
@staticmethod
def submit_imagine(prompt, base64_array=None, notify_hook=None):
"""
提交 imagine 请求
:param prompt: 提示词
:param base64_array: 垫图base64数组
:param notify_hook: 回调地址
:return: 任务ID
:raises Exception: 异常信息
"""
try:
# 构建请求体
request_body = {
"prompt": prompt,
"state": f"midjourney_task_{int(time.time() * 1000)}"
}
# 如果提供了垫图base64数组则添加到请求体中
if base64_array and len(base64_array) > 0:
request_body["base64Array"] = base64_array
# 如果提供了回调地址,则添加到请求体中
if notify_hook and notify_hook.strip():
request_body["notifyHook"] = notify_hook
# 创建请求
url = f"{MjCommon.BASE_URL}/mj/submit/imagine"
headers = {
"Content-Type": "application/json",
"Authorization": f"Bearer {MjCommon.ak}"
}
# 发送请求并获取响应
log.info(f"提交Midjourney imagine请求: {json.dumps(request_body)}")
response = requests.post(url, headers=headers, json=request_body, timeout=30)
# 检查响应状态
if response.status_code != 200:
error_msg = f"Midjourney API请求失败状态码: {response.status_code}"
log.error(error_msg)
raise Exception(error_msg)
# 解析响应
response_body = response.text
log.info(f"Midjourney imagine响应: {response_body}")
response_json = json.loads(response_body)
# 检查响应状态 - code=1 表示成功
if response_json.get("code") != 1:
error_msg = f"Midjourney imagine失败: {response_json.get('description', '未知错误')}"
log.error(error_msg)
raise Exception(error_msg)
# 获取任务ID - 从result字段获取
task_id = response_json.get("result")
log.info(f"Midjourney imagine任务ID: {task_id}")
return task_id
except Exception as e:
log.error(f"提交imagine请求异常: {str(e)}")
raise
@staticmethod
def main():
"""测试方法"""
try:
# 提示词
prompt = "A cute cat playing with a ball of yarn, digital art style"
# 提交imagine请求
task_id = Txt2Img.submit_imagine(prompt)
# 轮询查询任务状态
max_retries = 1000
retry_count = 0
retry_interval = 5 # 5秒
image_url = None
while retry_count < max_retries:
result = MjCommon.query_task_status(task_id)
# 直接使用响应中的字段
status = result.get("status")
log.info(f"任务状态: {status}")
# 检查进度
progress = result.get("progress")
log.info(f"任务进度: {progress}")
# 任务状态可能为空字符串需要检查progress或其他字段来判断任务是否完成
if status and status.strip():
if status == "SUCCESS":
# 任务成功获取图片URL
image_url = result.get("imageUrl")
log.info(f"生成的图片URL: {image_url}")
break
elif status == "FAILED":
# 任务失败
log.error(f"任务失败: {result.get('failReason', '未知原因')}")
break
else:
# 检查description字段
description = result.get("description")
if description and "成功" in description and progress != "0%":
# 如果描述包含"成功"且进度不为0%,可能任务已完成
image_url = result.get("imageUrl")
if image_url and image_url.strip():
log.info(f"生成的图片URL: {image_url}")
break
# 任务仍在进行中,等待后重试
log.info(f"任务进行中,等待{retry_interval}秒后重试...")
time.sleep(retry_interval)
retry_count += 1
if retry_count >= max_retries:
log.error(f"查询任务状态超时,已达到最大重试次数: {max_retries}")
# 如果获取到了图片URL则下载保存
if image_url and image_url.strip():
# 生成文件名使用时间戳和任务ID
file_name = f"mj_{int(time.time() * 1000)}_{task_id}.png"
# 完整保存路径
save_path = os.path.join(MjCommon.base_path, file_name)
# 下载图片
MjCommon.download_file(image_url, save_path)
except Exception as e:
log.error(f"程序执行异常: {str(e)}")
if __name__ == "__main__":
Txt2Img.main()