Files
dsProject/dsLightRag/Routes/SunoRoute.py
2025-08-21 14:24:35 +08:00

257 lines
8.8 KiB
Python
Raw Blame History

This file contains ambiguous Unicode characters

This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.

import logging
import json
import time
import datetime
import requests
from typing import Optional
import fastapi
from fastapi import APIRouter, HTTPException, Query
from pydantic import BaseModel
from Suno.sunoUtil import SunoMusicGenerator
from Config import Config
# 创建路由路由器
router = APIRouter(prefix="/api/suno", tags=["音乐"])
# 配置日志
logger = logging.getLogger(__name__)
# 获取API密钥
AK = Config.GPTNB_API_KEY
# 请求模型
class MusicGenerateRequest(BaseModel):
prompt: str
make_instrumental: Optional[bool] = True
# 初始化音乐生成器
music_generator = SunoMusicGenerator(AK)
# 任务状态存储(实际应用中可使用数据库)
music_tasks = {}
@router.post("/prompt_input")
async def prompt_input(request: MusicGenerateRequest):
"""
生成音乐任务接口
:param request: 包含提示词和是否纯音乐的请求体
:return: 任务ID和状态信息
"""
prompt = request.prompt
make_instrumental = request.make_instrumental
if not prompt:
raise HTTPException(status_code=400, detail="缺少提示词参数")
try:
logger.info(f"开始生成音乐任务,提示词: {prompt}")
# 调用音乐生成器生成音乐
# 注意:我们只执行生成请求,不等待结果,因为这是一个异步过程
# 构建JSON请求体
request_json = {
"gpt_description_prompt": prompt,
"mv": "chirp-v3-5",
"prompt": "",
"make_instrumental": make_instrumental
}
# 设置请求头
headers = {
"Content-Type": "application/json",
"Authorization": f"Bearer {AK}"
}
# 执行生成请求
response = requests.post(
SunoMusicGenerator.GENERATE_URL,
headers=headers,
json=request_json,
timeout=30
)
response.raise_for_status()
# 解析响应
generate_json = response.json()
logger.info(f"音乐生成响应: {generate_json}")
# 提取任务ID
task_id = None
if "id" in generate_json:
task_id = generate_json["id"]
elif "task_id" in generate_json:
task_id = generate_json["task_id"]
elif "clip_id" in generate_json:
task_id = generate_json["clip_id"]
if task_id is None:
raise HTTPException(status_code=500, detail="无法从响应中提取任务ID")
# 存储任务信息
music_tasks[task_id] = {
"status": "processing",
"prompt": prompt,
"create_time": datetime.datetime.now().isoformat(),
"response": generate_json
}
logger.info(f"音乐生成任务已提交任务ID: {task_id}")
return {
"code": 200,
"message": "音乐生成任务已提交",
"data": {
"task_id": task_id,
"status": "processing"
}
}
except json.JSONDecodeError as e:
logger.error(f"解析音乐生成响应失败: {e}")
raise HTTPException(status_code=500, detail=f"解析音乐生成响应失败: {str(e)}")
except Exception as e:
logger.error(f"音乐生成过程中发生错误: {e}")
raise HTTPException(status_code=500, detail=f"音乐生成过程中发生错误: {str(e)}")
@router.get("/check_task_status")
async def check_task_status(task_id: str = Query(..., description="音乐生成任务ID")):
"""
检查音乐生成任务状态接口
:param task_id: 音乐生成任务ID
:return: 任务状态信息
"""
if not task_id:
raise HTTPException(status_code=400, detail="缺少任务ID参数")
# 检查任务是否存在
if task_id not in music_tasks:
raise HTTPException(status_code=404, detail=f"任务ID不存在: {task_id}")
try:
logger.info(f"检查任务状态任务ID: {task_id}")
# 获取任务信息
task_info = music_tasks[task_id]
generate_json = task_info["response"]
# 构建查询URL
url_builder = [SunoMusicGenerator.FEED_URL, "?"]
# 尝试从生成响应中获取clips的ID
clip_ids = []
if "clips" in generate_json:
clips_array = generate_json["clips"]
for clip in clips_array:
if "id" in clip:
clip_ids.append(clip["id"])
# 添加ids参数
if clip_ids:
ids_param = ",".join(clip_ids)
url_builder.append(f"ids={ids_param}")
logger.info(f"使用clips ID查询: {ids_param}")
else:
url_builder.append(f"ids={task_id}")
logger.info(f"使用任务ID查询: {task_id}")
url = "".join(url_builder)
logger.info(f"查询URL: {url}")
# 设置请求头
headers = {
"Authorization": f"Bearer {AK}",
"Accept": "application/json"
}
# 执行查询请求
response = requests.get(url, headers=headers, timeout=30)
response.raise_for_status()
# 解析查询响应
json_response = response.json()
clips = json_response.get("clips", [])
if clips:
# 遍历所有返回的音乐片段
for clip in clips:
clip_id = clip.get("id")
status = clip.get("status")
title = clip.get("title")
audio_url = clip.get("audio_url")
logger.info(f"查询结果:")
logger.info(f"ID: {clip_id}")
logger.info(f"标题: {title}")
logger.info(f"状态: {status}")
# 更新任务状态
task_info["status"] = status
task_info["last_check_time"] = datetime.datetime.now().isoformat()
task_info["title"] = title
if status == "complete":
# 确保audio_url字段存在
if audio_url:
# 移除URL中可能存在的反引号
audio_url = audio_url.replace("`", "").strip()
task_info["audio_url"] = audio_url
logger.info("音乐生成已完成!")
logger.info(f"音频URL: {audio_url}")
# 下载音频文件
file_name = f"suno_music_{int(time.time())}.mp3"
save_path = music_generator.base_path / file_name
if music_generator.download_audio(audio_url, str(save_path)):
task_info["local_path"] = str(save_path)
else:
logger.warning("音乐生成已完成但未找到音频URL!")
elif status == "failed":
logger.error("音乐生成失败!")
task_info["error_message"] = clip.get("error_message", "未知错误")
# 更新任务存储
music_tasks[task_id] = task_info
return {
"code": 200,
"message": "查询成功",
"data": {
"task_id": task_id,
"status": status,
"title": title,
"audio_url": audio_url if status == "complete" else None,
"local_path": task_info.get("local_path") if status == "complete" else None,
"error_message": task_info.get("error_message") if status == "failed" else None
}
}
else:
logger.info("未找到音乐片段")
return {
"code": 200,
"message": "查询成功",
"data": {
"task_id": task_id,
"status": "processing",
"title": None,
"audio_url": None,
"local_path": None,
"error_message": None
}
}
except fastapi.Requests.exceptions.RequestException as e:
logger.error(f"查询任务状态失败: {e}")
raise HTTPException(status_code=500, detail=f"查询任务状态失败: {str(e)}")
except json.JSONDecodeError as e:
logger.error(f"解析查询响应失败: {e}")
raise HTTPException(status_code=500, detail=f"解析查询响应失败: {str(e)}")
except Exception as e:
logger.error(f"查询任务状态过程中发生错误: {e}")
raise HTTPException(status_code=500, detail=f"查询任务状态过程中发生错误: {str(e)}")