Files
dsProject/dsLightRag/Midjourney/Kit/MjCommon.py
2025-08-25 14:03:04 +08:00

108 lines
3.3 KiB
Python
Raw Blame History

This file contains ambiguous Unicode characters

This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.

import json
import logging
import os
import requests
from Config import Config
# 配置日志
logging.basicConfig(level=logging.INFO, format='%(asctime)s - %(name)s - %(levelname)s - %(message)s')
logger = logging.getLogger('MjCommon')
class MjCommon:
ak = Config.GPTNB_API_KEY
BASE_URL = Config.GPTNB_BASE_URL
@staticmethod
def download_file(file_url, save_file_path):
"""从URL下载文件到指定路径
Args:
file_url (str): 文件URL
save_file_path (str): 保存路径
Raises:
Exception: 下载过程中的异常
"""
try:
# 确保目录存在
os.makedirs(os.path.dirname(save_file_path), exist_ok=True)
# 下载文件
logger.info(f"开始下载文件: {file_url}")
response = requests.get(file_url, stream=True, timeout=30)
response.raise_for_status()
# 保存文件
with open(save_file_path, 'wb') as f:
for chunk in response.iter_content(chunk_size=8192):
f.write(chunk)
file_size = os.path.getsize(save_file_path)
logger.info(f"文件下载成功,保存路径: {save_file_path}, 文件大小: {file_size}字节")
except Exception as e:
logger.error(f"文件下载失败: {str(e)}")
raise Exception(f"文件下载失败: {str(e)}")
@staticmethod
def query_task_status(task_id):
"""查询任务状态
Args:
task_id (str): 任务ID
Returns:
dict: 任务结果
Raises:
Exception: 异常信息
"""
# 创建请求URL
url = f"{MjCommon.BASE_URL}/mj/task/{task_id}/fetch"
# 设置请求头
headers = {
"Authorization": f"Bearer {MjCommon.ak}"
}
try:
logger.info(f"查询Midjourney任务状态: {task_id}")
response = requests.get(url, headers=headers, timeout=30)
# 检查响应状态
if not response.ok:
error_msg = f"Midjourney API请求失败状态码: {response.status_code}"
logger.error(error_msg)
raise Exception(error_msg)
# 解析响应
response_body = response.text
logger.info(f"查询Midjourney任务状态响应: {response_body}")
return json.loads(response_body)
except Exception as e:
logger.error(f"查询任务状态失败: {str(e)}")
raise Exception(f"查询任务状态失败: {str(e)}")
# 测试代码
if __name__ == "__main__":
# 测试下载文件
# try:
# test_url = "https://example.com/test.jpg"
# test_save_path = os.path.join(MjCommon.base_path, "test.jpg")
# MjCommon.download_file(test_url, test_save_path)
# except Exception as e:
# print(f"测试下载失败: {e}")
# 测试查询任务状态
# try:
# test_task_id = "your_task_id_here"
# result = MjCommon.query_task_status(test_task_id)
# print(f"任务状态查询结果: {result}")
# except Exception as e:
# print(f"测试查询任务状态失败: {e}")
pass