Files
dsProject/dsLightRag/Routes/SunoRoute.py
2025-08-21 15:06:32 +08:00

302 lines
12 KiB
Python
Raw Blame History

This file contains ambiguous Unicode characters

This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.

import logging
import json
import time
import datetime
import requests
import uuid # 新增导入uuid模块
import os # 新增导入os模块
from typing import Optional
import fastapi
from fastapi import APIRouter, HTTPException, Query
from pydantic import BaseModel
from Suno.sunoUtil import SunoMusicGenerator
from Config import Config
# 创建路由路由器
router = APIRouter(prefix="/api/suno", tags=["音乐"])
# 配置日志
logger = logging.getLogger(__name__)
# 获取API密钥
AK = Config.GPTNB_API_KEY
# 请求模型
class MusicGenerateRequest(BaseModel):
prompt: str
make_instrumental: Optional[bool] = True
# 初始化音乐生成器
music_generator = SunoMusicGenerator(AK)
# 任务状态存储(实际应用中可使用数据库)
music_tasks = {}
@router.post("/prompt_input")
async def prompt_input(request: MusicGenerateRequest):
"""
生成音乐任务接口
:param request: 包含提示词和是否纯音乐的请求体
:return: 任务ID和状态信息
"""
prompt = request.prompt
make_instrumental = request.make_instrumental
if not prompt:
raise HTTPException(status_code=400, detail="缺少提示词参数")
try:
logger.info(f"开始生成音乐任务,提示词: {prompt}")
# 调用音乐生成器生成音乐
# 注意:我们只执行生成请求,不等待结果,因为这是一个异步过程
# 构建JSON请求体
request_json = {
"gpt_description_prompt": prompt,
"mv": "chirp-v3-5",
"prompt": "",
"make_instrumental": make_instrumental
}
# 设置请求头
headers = {
"Content-Type": "application/json",
"Authorization": f"Bearer {AK}"
}
# 执行生成请求
response = requests.post(
SunoMusicGenerator.GENERATE_URL,
headers=headers,
json=request_json,
timeout=30
)
response.raise_for_status()
# 解析响应
generate_json = response.json()
logger.info(f"音乐生成响应: {generate_json}")
# 提取任务ID
task_id = None
if "id" in generate_json:
task_id = generate_json["id"]
elif "task_id" in generate_json:
task_id = generate_json["task_id"]
elif "clip_id" in generate_json:
task_id = generate_json["clip_id"]
if task_id is None:
raise HTTPException(status_code=500, detail="无法从响应中提取任务ID")
# 存储任务信息
music_tasks[task_id] = {
"status": "processing",
"prompt": prompt,
"create_time": datetime.datetime.now().isoformat(),
"response": generate_json
}
logger.info(f"音乐生成任务已提交任务ID: {task_id}")
return {
"code": 200,
"message": "音乐生成任务已提交",
"data": {
"task_id": task_id,
"status": "processing"
}
}
except json.JSONDecodeError as e:
logger.error(f"解析音乐生成响应失败: {e}")
raise HTTPException(status_code=500, detail=f"解析音乐生成响应失败: {str(e)}")
except Exception as e:
logger.error(f"音乐生成过程中发生错误: {e}")
raise HTTPException(status_code=500, detail=f"音乐生成过程中发生错误: {str(e)}")
@router.get("/check_task_status")
async def check_task_status(task_id: str = Query(..., description="音乐生成任务ID")):
"""
检查音乐生成任务状态接口
:param task_id: 音乐生成任务ID
:return: 任务状态信息
"""
if not task_id:
raise HTTPException(status_code=400, detail="缺少任务ID参数")
# 检查任务是否存在
if task_id not in music_tasks:
raise HTTPException(status_code=404, detail=f"任务ID不存在: {task_id}")
try:
logger.info(f"检查任务状态任务ID: {task_id}")
# 获取任务信息
task_info = music_tasks[task_id]
generate_json = task_info["response"]
# 构建查询URL
url_builder = [SunoMusicGenerator.FEED_URL, "?"]
# 尝试从生成响应中获取clips的ID
clip_ids = []
if "clips" in generate_json:
clips_array = generate_json["clips"]
for clip in clips_array:
if "id" in clip:
clip_ids.append(clip["id"])
# 添加ids参数
if clip_ids:
ids_param = ",".join(clip_ids)
url_builder.append(f"ids={ids_param}")
logger.info(f"使用clips ID查询: {ids_param}")
else:
url_builder.append(f"ids={task_id}")
logger.info(f"使用任务ID查询: {task_id}")
url = "".join(url_builder)
logger.info(f"查询URL: {url}")
# 设置请求头
headers = {
"Authorization": f"Bearer {AK}",
"Accept": "application/json"
}
# 执行查询请求
response = requests.get(url, headers=headers, timeout=30)
response.raise_for_status()
# 解析查询响应
json_response = response.json()
clips = json_response.get("clips", [])
if clips:
# 遍历所有返回的音乐片段
for clip in clips:
clip_id = clip.get("id")
status = clip.get("status")
title = clip.get("title")
audio_url = clip.get("audio_url")
logger.info(f"查询结果:")
logger.info(f"ID: {clip_id}")
logger.info(f"标题: {title}")
logger.info(f"状态: {status}")
# 更新任务状态
task_info["status"] = status
task_info["last_check_time"] = datetime.datetime.now().isoformat()
task_info["title"] = title
if status == "complete":
# 确保audio_url字段存在
if audio_url:
# 移除URL中可能存在的反引号
audio_url = audio_url.replace("`", "").strip()
task_info["audio_url"] = audio_url
logger.info("音乐生成已完成!")
logger.info(f"音频URL: {audio_url}")
# 新增检查是否已经上传到OBS
if "obs_url" not in task_info or not task_info["obs_url"]:
# 新增下载音频文件并上传到OBS
try:
# 使用UUID生成唯一文件名
unique_id = uuid.uuid4()
object_key = f"HuangHai/JiMeng/{unique_id}.mp3"
# 临时文件保存路径
temp_file_path = os.path.join(os.path.dirname(__file__), f"{unique_id}.mp3")
# 下载URL内容到临时文件
logger.info(f"开始下载URL内容: {audio_url}")
with requests.get(audio_url, stream=True, timeout=3600) as r:
r.raise_for_status()
with open(temp_file_path, 'wb') as f:
for chunk in r.iter_content(chunk_size=8192):
f.write(chunk)
logger.info(f"URL内容下载完成保存至: {temp_file_path}")
# 上传文件到OBS
from Util.ObsUtil import ObsUploader # 导入ObsUploader
obs_uploader = ObsUploader()
logger.info(f"开始上传文件到OBS: {temp_file_path}\n对象键: {object_key}")
# 正确处理元组返回值
success, response_info = obs_uploader.upload_file(
file_path=temp_file_path,
object_key=object_key
)
if success:
# 构造OBS URL
from Config.Config import OBS_SERVER, OBS_BUCKET
obs_url = f"https://{OBS_BUCKET}.{OBS_SERVER}/{object_key}"
logger.info(f"文件上传成功OBS URL: {obs_url}")
task_info["obs_url"] = obs_url
task_info["object_key"] = object_key
else:
error_msg = f"上传失败: {str(response_info)}"
logger.error(error_msg)
finally:
# 清理临时文件
if os.path.exists(temp_file_path):
os.remove(temp_file_path)
logger.info(f"临时文件已删除: {temp_file_path}")
# 新增结束
else:
logger.warning("音乐生成已完成但未找到音频URL!")
elif status == "failed":
logger.error("音乐生成失败!")
task_info["error_message"] = clip.get("error_message", "未知错误")
# 更新任务存储
music_tasks[task_id] = task_info
return {
"code": 200,
"message": "查询成功",
"data": {
"task_id": task_id,
"status": status,
"title": title,
"audio_url": audio_url if status == "complete" else None,
"obs_url": task_info.get("obs_url") if status == "complete" else None, # 新增返回OBS URL
"error_message": task_info.get("error_message") if status == "failed" else None
}
}
else:
logger.info("未找到音乐片段")
return {
"code": 200,
"message": "查询成功",
"data": {
"task_id": task_id,
"status": "processing",
"title": None,
"audio_url": None,
"local_path": None,
"error_message": None
}
}
except fastapi.Requests.exceptions.RequestException as e:
logger.error(f"查询任务状态失败: {e}")
raise HTTPException(status_code=500, detail=f"查询任务状态失败: {str(e)}")
except json.JSONDecodeError as e:
logger.error(f"解析查询响应失败: {e}")
raise HTTPException(status_code=500, detail=f"解析查询响应失败: {str(e)}")
except Exception as e:
logger.error(f"查询任务状态过程中发生错误: {e}")
raise HTTPException(status_code=500, detail=f"查询任务状态过程中发生错误: {str(e)}")