Files
dsProject/dsLightRag/Routes/SunoRoute.py

302 lines
12 KiB
Python
Raw Normal View History

2025-08-21 14:16:13 +08:00
import logging
import json
import time
import datetime
import requests
2025-08-21 14:50:28 +08:00
import uuid # 新增导入uuid模块
import os # 新增导入os模块
2025-08-21 14:16:13 +08:00
from typing import Optional
import fastapi
from fastapi import APIRouter, HTTPException, Query
from pydantic import BaseModel
2025-08-21 14:24:35 +08:00
from Suno.sunoUtil import SunoMusicGenerator
2025-08-21 14:16:13 +08:00
from Config import Config
# 创建路由路由器
router = APIRouter(prefix="/api/suno", tags=["音乐"])
# 配置日志
logger = logging.getLogger(__name__)
# 获取API密钥
AK = Config.GPTNB_API_KEY
# 请求模型
class MusicGenerateRequest(BaseModel):
prompt: str
make_instrumental: Optional[bool] = True
# 初始化音乐生成器
music_generator = SunoMusicGenerator(AK)
# 任务状态存储(实际应用中可使用数据库)
music_tasks = {}
@router.post("/prompt_input")
async def prompt_input(request: MusicGenerateRequest):
"""
生成音乐任务接口
:param request: 包含提示词和是否纯音乐的请求体
:return: 任务ID和状态信息
"""
prompt = request.prompt
make_instrumental = request.make_instrumental
if not prompt:
raise HTTPException(status_code=400, detail="缺少提示词参数")
try:
logger.info(f"开始生成音乐任务,提示词: {prompt}")
# 调用音乐生成器生成音乐
# 注意:我们只执行生成请求,不等待结果,因为这是一个异步过程
# 构建JSON请求体
request_json = {
"gpt_description_prompt": prompt,
"mv": "chirp-v3-5",
"prompt": "",
"make_instrumental": make_instrumental
}
# 设置请求头
headers = {
"Content-Type": "application/json",
"Authorization": f"Bearer {AK}"
}
# 执行生成请求
response = requests.post(
SunoMusicGenerator.GENERATE_URL,
headers=headers,
json=request_json,
timeout=30
)
response.raise_for_status()
# 解析响应
generate_json = response.json()
logger.info(f"音乐生成响应: {generate_json}")
# 提取任务ID
task_id = None
if "id" in generate_json:
task_id = generate_json["id"]
elif "task_id" in generate_json:
task_id = generate_json["task_id"]
elif "clip_id" in generate_json:
task_id = generate_json["clip_id"]
if task_id is None:
raise HTTPException(status_code=500, detail="无法从响应中提取任务ID")
# 存储任务信息
music_tasks[task_id] = {
"status": "processing",
"prompt": prompt,
"create_time": datetime.datetime.now().isoformat(),
"response": generate_json
}
logger.info(f"音乐生成任务已提交任务ID: {task_id}")
return {
"code": 200,
"message": "音乐生成任务已提交",
"data": {
"task_id": task_id,
"status": "processing"
}
}
except json.JSONDecodeError as e:
logger.error(f"解析音乐生成响应失败: {e}")
raise HTTPException(status_code=500, detail=f"解析音乐生成响应失败: {str(e)}")
except Exception as e:
logger.error(f"音乐生成过程中发生错误: {e}")
raise HTTPException(status_code=500, detail=f"音乐生成过程中发生错误: {str(e)}")
@router.get("/check_task_status")
async def check_task_status(task_id: str = Query(..., description="音乐生成任务ID")):
"""
检查音乐生成任务状态接口
:param task_id: 音乐生成任务ID
:return: 任务状态信息
"""
if not task_id:
raise HTTPException(status_code=400, detail="缺少任务ID参数")
# 检查任务是否存在
if task_id not in music_tasks:
raise HTTPException(status_code=404, detail=f"任务ID不存在: {task_id}")
try:
logger.info(f"检查任务状态任务ID: {task_id}")
# 获取任务信息
task_info = music_tasks[task_id]
generate_json = task_info["response"]
# 构建查询URL
url_builder = [SunoMusicGenerator.FEED_URL, "?"]
# 尝试从生成响应中获取clips的ID
clip_ids = []
if "clips" in generate_json:
clips_array = generate_json["clips"]
for clip in clips_array:
if "id" in clip:
clip_ids.append(clip["id"])
# 添加ids参数
if clip_ids:
ids_param = ",".join(clip_ids)
url_builder.append(f"ids={ids_param}")
logger.info(f"使用clips ID查询: {ids_param}")
else:
url_builder.append(f"ids={task_id}")
logger.info(f"使用任务ID查询: {task_id}")
url = "".join(url_builder)
logger.info(f"查询URL: {url}")
# 设置请求头
headers = {
"Authorization": f"Bearer {AK}",
"Accept": "application/json"
}
# 执行查询请求
response = requests.get(url, headers=headers, timeout=30)
response.raise_for_status()
# 解析查询响应
json_response = response.json()
clips = json_response.get("clips", [])
if clips:
# 遍历所有返回的音乐片段
for clip in clips:
clip_id = clip.get("id")
status = clip.get("status")
title = clip.get("title")
audio_url = clip.get("audio_url")
logger.info(f"查询结果:")
logger.info(f"ID: {clip_id}")
logger.info(f"标题: {title}")
logger.info(f"状态: {status}")
# 更新任务状态
task_info["status"] = status
task_info["last_check_time"] = datetime.datetime.now().isoformat()
task_info["title"] = title
if status == "complete":
# 确保audio_url字段存在
if audio_url:
# 移除URL中可能存在的反引号
audio_url = audio_url.replace("`", "").strip()
task_info["audio_url"] = audio_url
logger.info("音乐生成已完成!")
logger.info(f"音频URL: {audio_url}")
2025-08-21 15:06:32 +08:00
# 新增检查是否已经上传到OBS
if "obs_url" not in task_info or not task_info["obs_url"]:
# 新增下载音频文件并上传到OBS
try:
# 使用UUID生成唯一文件名
unique_id = uuid.uuid4()
object_key = f"HuangHai/JiMeng/{unique_id}.mp3"
# 临时文件保存路径
temp_file_path = os.path.join(os.path.dirname(__file__), f"{unique_id}.mp3")
# 下载URL内容到临时文件
logger.info(f"开始下载URL内容: {audio_url}")
with requests.get(audio_url, stream=True, timeout=3600) as r:
r.raise_for_status()
with open(temp_file_path, 'wb') as f:
for chunk in r.iter_content(chunk_size=8192):
f.write(chunk)
logger.info(f"URL内容下载完成保存至: {temp_file_path}")
# 上传文件到OBS
from Util.ObsUtil import ObsUploader # 导入ObsUploader
obs_uploader = ObsUploader()
logger.info(f"开始上传文件到OBS: {temp_file_path}\n对象键: {object_key}")
# 正确处理元组返回值
success, response_info = obs_uploader.upload_file(
file_path=temp_file_path,
object_key=object_key
)
if success:
# 构造OBS URL
from Config.Config import OBS_SERVER, OBS_BUCKET
obs_url = f"https://{OBS_BUCKET}.{OBS_SERVER}/{object_key}"
logger.info(f"文件上传成功OBS URL: {obs_url}")
task_info["obs_url"] = obs_url
task_info["object_key"] = object_key
else:
error_msg = f"上传失败: {str(response_info)}"
logger.error(error_msg)
finally:
# 清理临时文件
if os.path.exists(temp_file_path):
os.remove(temp_file_path)
logger.info(f"临时文件已删除: {temp_file_path}")
2025-08-21 14:50:28 +08:00
# 新增结束
2025-08-21 14:16:13 +08:00
else:
logger.warning("音乐生成已完成但未找到音频URL!")
elif status == "failed":
logger.error("音乐生成失败!")
task_info["error_message"] = clip.get("error_message", "未知错误")
# 更新任务存储
music_tasks[task_id] = task_info
return {
"code": 200,
"message": "查询成功",
"data": {
"task_id": task_id,
"status": status,
"title": title,
"audio_url": audio_url if status == "complete" else None,
2025-08-21 14:50:28 +08:00
"obs_url": task_info.get("obs_url") if status == "complete" else None, # 新增返回OBS URL
2025-08-21 14:16:13 +08:00
"error_message": task_info.get("error_message") if status == "failed" else None
}
}
else:
logger.info("未找到音乐片段")
return {
"code": 200,
"message": "查询成功",
"data": {
"task_id": task_id,
"status": "processing",
"title": None,
"audio_url": None,
"local_path": None,
"error_message": None
}
}
except fastapi.Requests.exceptions.RequestException as e:
logger.error(f"查询任务状态失败: {e}")
raise HTTPException(status_code=500, detail=f"查询任务状态失败: {str(e)}")
except json.JSONDecodeError as e:
logger.error(f"解析查询响应失败: {e}")
raise HTTPException(status_code=500, detail=f"解析查询响应失败: {str(e)}")
except Exception as e:
logger.error(f"查询任务状态过程中发生错误: {e}")
raise HTTPException(status_code=500, detail=f"查询任务状态过程中发生错误: {str(e)}")