Files
dsProject/dsLightRag/Midjourney/Txt2Img.py
2025-08-25 14:03:04 +08:00

151 lines
5.7 KiB
Python
Raw Blame History

This file contains ambiguous Unicode characters

This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.

import os
import time
import json
import logging
import requests
from Midjourney.Kit.MjCommon import MjCommon
# 配置日志
logging.basicConfig(level=logging.INFO, format='%(asctime)s - %(name)s - %(levelname)s - %(message)s')
log = logging.getLogger(__name__)
class Txt2Img(MjCommon):
"""
Midjourney API 工具类
用于调用 Midjourney 的 imagine 接口生成图片
"""
@staticmethod
def submit_imagine(prompt, base64_array=None, notify_hook=None):
"""
提交 imagine 请求
:param prompt: 提示词
:param base64_array: 垫图base64数组
:param notify_hook: 回调地址
:return: 任务ID
:raises Exception: 异常信息
"""
try:
# 构建请求体
request_body = {
"prompt": prompt,
"state": f"midjourney_task_{int(time.time() * 1000)}"
}
# 如果提供了垫图base64数组则添加到请求体中
if base64_array and len(base64_array) > 0:
request_body["base64Array"] = base64_array
# 如果提供了回调地址,则添加到请求体中
if notify_hook and notify_hook.strip():
request_body["notifyHook"] = notify_hook
# 创建请求
url = f"{MjCommon.BASE_URL}/mj/submit/imagine"
headers = {
"Content-Type": "application/json",
"Authorization": f"Bearer {MjCommon.ak}"
}
# 发送请求并获取响应
log.info(f"提交Midjourney imagine请求: {json.dumps(request_body)}")
response = requests.post(url, headers=headers, json=request_body, timeout=30)
# 检查响应状态
if response.status_code != 200:
error_msg = f"Midjourney API请求失败状态码: {response.status_code}"
log.error(error_msg)
raise Exception(error_msg)
# 解析响应
response_body = response.text
log.info(f"Midjourney imagine响应: {response_body}")
response_json = json.loads(response_body)
# 检查响应状态 - code=1 表示成功
if response_json.get("code") != 1:
error_msg = f"Midjourney imagine失败: {response_json.get('description', '未知错误')}"
log.error(error_msg)
raise Exception(error_msg)
# 获取任务ID - 从result字段获取
task_id = response_json.get("result")
log.info(f"Midjourney imagine任务ID: {task_id}")
return task_id
except Exception as e:
log.error(f"提交imagine请求异常: {str(e)}")
raise
@staticmethod
def main():
"""测试方法"""
try:
# 提示词
prompt = "A cute cat playing with a ball of yarn, digital art style"
# 提交imagine请求
task_id = Txt2Img.submit_imagine(prompt)
# 轮询查询任务状态
max_retries = 1000
retry_count = 0
retry_interval = 5 # 5秒
image_url = None
while retry_count < max_retries:
result = MjCommon.query_task_status(task_id)
# 直接使用响应中的字段
status = result.get("status")
log.info(f"任务状态: {status}")
# 检查进度
progress = result.get("progress")
log.info(f"任务进度: {progress}")
# 任务状态可能为空字符串需要检查progress或其他字段来判断任务是否完成
if status and status.strip():
if status == "SUCCESS":
# 任务成功获取图片URL
image_url = result.get("imageUrl")
log.info(f"生成的图片URL: {image_url}")
break
elif status == "FAILED":
# 任务失败
log.error(f"任务失败: {result.get('failReason', '未知原因')}")
break
else:
# 检查description字段
description = result.get("description")
if description and "成功" in description and progress != "0%":
# 如果描述包含"成功"且进度不为0%,可能任务已完成
image_url = result.get("imageUrl")
if image_url and image_url.strip():
log.info(f"生成的图片URL: {image_url}")
break
# 任务仍在进行中,等待后重试
log.info(f"任务进行中,等待{retry_interval}秒后重试...")
time.sleep(retry_interval)
retry_count += 1
if retry_count >= max_retries:
log.error(f"查询任务状态超时,已达到最大重试次数: {max_retries}")
# 如果获取到了图片URL则下载保存
if image_url and image_url.strip():
# 生成文件名使用时间戳和任务ID
file_name = f"mj_{int(time.time() * 1000)}_{task_id}.png"
# 完整保存路径
save_path = os.path.join("d:/", file_name)
# 下载图片
MjCommon.download_file(image_url, save_path)
except Exception as e:
log.error(f"程序执行异常: {str(e)}")
if __name__ == "__main__":
Txt2Img.main()