113 lines
3.6 KiB
Python
113 lines
3.6 KiB
Python
|
import os
|
|||
|
import logging
|
|||
|
import requests
|
|||
|
import json
|
|||
|
import time
|
|||
|
from pathlib import Path
|
|||
|
|
|||
|
from Config import Config
|
|||
|
|
|||
|
# 配置日志
|
|||
|
logging.basicConfig(level=logging.INFO, format='%(asctime)s - %(name)s - %(levelname)s - %(message)s')
|
|||
|
logger = logging.getLogger('MjCommon')
|
|||
|
|
|||
|
|
|||
|
class MjCommon:
|
|||
|
# 获取项目根目录路径
|
|||
|
project_root = os.path.abspath(os.path.join(os.path.dirname(__file__), "..", "..", "..", "..", "..", ".."))
|
|||
|
# 拼接相对路径
|
|||
|
base_path = os.path.join(project_root, "src", "main", "python", "com", "dsideal", "aiSupport", "Util", "Midjourney", "Example")
|
|||
|
|
|||
|
ak = Config.GPTNB_API_KEY
|
|||
|
BASE_URL = Config.GPTNB_BASE_URL
|
|||
|
|
|||
|
@staticmethod
|
|||
|
def download_file(file_url, save_file_path):
|
|||
|
"""从URL下载文件到指定路径
|
|||
|
|
|||
|
Args:
|
|||
|
file_url (str): 文件URL
|
|||
|
save_file_path (str): 保存路径
|
|||
|
|
|||
|
Raises:
|
|||
|
Exception: 下载过程中的异常
|
|||
|
"""
|
|||
|
try:
|
|||
|
# 确保目录存在
|
|||
|
os.makedirs(os.path.dirname(save_file_path), exist_ok=True)
|
|||
|
|
|||
|
# 下载文件
|
|||
|
logger.info(f"开始下载文件: {file_url}")
|
|||
|
response = requests.get(file_url, stream=True, timeout=30)
|
|||
|
response.raise_for_status()
|
|||
|
|
|||
|
# 保存文件
|
|||
|
with open(save_file_path, 'wb') as f:
|
|||
|
for chunk in response.iter_content(chunk_size=8192):
|
|||
|
f.write(chunk)
|
|||
|
|
|||
|
file_size = os.path.getsize(save_file_path)
|
|||
|
logger.info(f"文件下载成功,保存路径: {save_file_path}, 文件大小: {file_size}字节")
|
|||
|
except Exception as e:
|
|||
|
logger.error(f"文件下载失败: {str(e)}")
|
|||
|
raise Exception(f"文件下载失败: {str(e)}")
|
|||
|
|
|||
|
@staticmethod
|
|||
|
def query_task_status(task_id):
|
|||
|
"""查询任务状态
|
|||
|
|
|||
|
Args:
|
|||
|
task_id (str): 任务ID
|
|||
|
|
|||
|
Returns:
|
|||
|
dict: 任务结果
|
|||
|
|
|||
|
Raises:
|
|||
|
Exception: 异常信息
|
|||
|
"""
|
|||
|
# 创建请求URL
|
|||
|
url = f"{MjCommon.BASE_URL}/mj/task/{task_id}/fetch"
|
|||
|
|
|||
|
# 设置请求头
|
|||
|
headers = {
|
|||
|
"Authorization": f"Bearer {MjCommon.ak}"
|
|||
|
}
|
|||
|
|
|||
|
try:
|
|||
|
logger.info(f"查询Midjourney任务状态: {task_id}")
|
|||
|
response = requests.get(url, headers=headers, timeout=30)
|
|||
|
|
|||
|
# 检查响应状态
|
|||
|
if not response.ok:
|
|||
|
error_msg = f"Midjourney API请求失败,状态码: {response.status_code}"
|
|||
|
logger.error(error_msg)
|
|||
|
raise Exception(error_msg)
|
|||
|
|
|||
|
# 解析响应
|
|||
|
response_body = response.text
|
|||
|
logger.info(f"查询Midjourney任务状态响应: {response_body}")
|
|||
|
|
|||
|
return json.loads(response_body)
|
|||
|
except Exception as e:
|
|||
|
logger.error(f"查询任务状态失败: {str(e)}")
|
|||
|
raise Exception(f"查询任务状态失败: {str(e)}")
|
|||
|
|
|||
|
|
|||
|
# 测试代码
|
|||
|
if __name__ == "__main__":
|
|||
|
# 测试下载文件
|
|||
|
# try:
|
|||
|
# test_url = "https://example.com/test.jpg"
|
|||
|
# test_save_path = os.path.join(MjCommon.base_path, "test.jpg")
|
|||
|
# MjCommon.download_file(test_url, test_save_path)
|
|||
|
# except Exception as e:
|
|||
|
# print(f"测试下载失败: {e}")
|
|||
|
|
|||
|
# 测试查询任务状态
|
|||
|
# try:
|
|||
|
# test_task_id = "your_task_id_here"
|
|||
|
# result = MjCommon.query_task_status(test_task_id)
|
|||
|
# print(f"任务状态查询结果: {result}")
|
|||
|
# except Exception as e:
|
|||
|
# print(f"测试查询任务状态失败: {e}")
|
|||
|
pass
|