diff --git a/dsLightRag/Routes/SunoRoute.py b/dsLightRag/Routes/SunoRoute.py new file mode 100644 index 00000000..3de2678f --- /dev/null +++ b/dsLightRag/Routes/SunoRoute.py @@ -0,0 +1,256 @@ +import logging +import json +import time +import datetime +import requests +from typing import Optional + +import fastapi +from fastapi import APIRouter, HTTPException, Query +from pydantic import BaseModel + +from Routes.suno_music_generator import SunoMusicGenerator +from Config import Config + +# 创建路由路由器 +router = APIRouter(prefix="/api/suno", tags=["音乐"]) + +# 配置日志 +logger = logging.getLogger(__name__) + +# 获取API密钥 +AK = Config.GPTNB_API_KEY + +# 请求模型 +class MusicGenerateRequest(BaseModel): + prompt: str + make_instrumental: Optional[bool] = True + +# 初始化音乐生成器 +music_generator = SunoMusicGenerator(AK) + +# 任务状态存储(实际应用中可使用数据库) +music_tasks = {} + + +@router.post("/prompt_input") +async def prompt_input(request: MusicGenerateRequest): + """ + 生成音乐任务接口 + :param request: 包含提示词和是否纯音乐的请求体 + :return: 任务ID和状态信息 + """ + prompt = request.prompt + make_instrumental = request.make_instrumental + + if not prompt: + raise HTTPException(status_code=400, detail="缺少提示词参数") + + try: + logger.info(f"开始生成音乐任务,提示词: {prompt}") + + # 调用音乐生成器生成音乐 + # 注意:我们只执行生成请求,不等待结果,因为这是一个异步过程 + # 构建JSON请求体 + request_json = { + "gpt_description_prompt": prompt, + "mv": "chirp-v3-5", + "prompt": "", + "make_instrumental": make_instrumental + } + + # 设置请求头 + headers = { + "Content-Type": "application/json", + "Authorization": f"Bearer {AK}" + } + + # 执行生成请求 + response = requests.post( + SunoMusicGenerator.GENERATE_URL, + headers=headers, + json=request_json, + timeout=30 + ) + response.raise_for_status() + + # 解析响应 + generate_json = response.json() + logger.info(f"音乐生成响应: {generate_json}") + + # 提取任务ID + task_id = None + if "id" in generate_json: + task_id = generate_json["id"] + elif "task_id" in generate_json: + task_id = generate_json["task_id"] + elif "clip_id" in generate_json: + task_id = generate_json["clip_id"] + + if task_id is None: + raise HTTPException(status_code=500, detail="无法从响应中提取任务ID") + + # 存储任务信息 + music_tasks[task_id] = { + "status": "processing", + "prompt": prompt, + "create_time": datetime.datetime.now().isoformat(), + "response": generate_json + } + + logger.info(f"音乐生成任务已提交,任务ID: {task_id}") + + return { + "code": 200, + "message": "音乐生成任务已提交", + "data": { + "task_id": task_id, + "status": "processing" + } + } + + except json.JSONDecodeError as e: + logger.error(f"解析音乐生成响应失败: {e}") + raise HTTPException(status_code=500, detail=f"解析音乐生成响应失败: {str(e)}") + except Exception as e: + logger.error(f"音乐生成过程中发生错误: {e}") + raise HTTPException(status_code=500, detail=f"音乐生成过程中发生错误: {str(e)}") + + +@router.get("/check_task_status") +async def check_task_status(task_id: str = Query(..., description="音乐生成任务ID")): + """ + 检查音乐生成任务状态接口 + :param task_id: 音乐生成任务ID + :return: 任务状态信息 + """ + if not task_id: + raise HTTPException(status_code=400, detail="缺少任务ID参数") + + # 检查任务是否存在 + if task_id not in music_tasks: + raise HTTPException(status_code=404, detail=f"任务ID不存在: {task_id}") + + try: + logger.info(f"检查任务状态,任务ID: {task_id}") + + # 获取任务信息 + task_info = music_tasks[task_id] + generate_json = task_info["response"] + + # 构建查询URL + url_builder = [SunoMusicGenerator.FEED_URL, "?"] + + # 尝试从生成响应中获取clips的ID + clip_ids = [] + if "clips" in generate_json: + clips_array = generate_json["clips"] + for clip in clips_array: + if "id" in clip: + clip_ids.append(clip["id"]) + + # 添加ids参数 + if clip_ids: + ids_param = ",".join(clip_ids) + url_builder.append(f"ids={ids_param}") + logger.info(f"使用clips ID查询: {ids_param}") + else: + url_builder.append(f"ids={task_id}") + logger.info(f"使用任务ID查询: {task_id}") + + url = "".join(url_builder) + logger.info(f"查询URL: {url}") + + # 设置请求头 + headers = { + "Authorization": f"Bearer {AK}", + "Accept": "application/json" + } + + # 执行查询请求 + response = requests.get(url, headers=headers, timeout=30) + response.raise_for_status() + + # 解析查询响应 + json_response = response.json() + clips = json_response.get("clips", []) + + if clips: + # 遍历所有返回的音乐片段 + for clip in clips: + clip_id = clip.get("id") + status = clip.get("status") + title = clip.get("title") + audio_url = clip.get("audio_url") + + logger.info(f"查询结果:") + logger.info(f"ID: {clip_id}") + logger.info(f"标题: {title}") + logger.info(f"状态: {status}") + + # 更新任务状态 + task_info["status"] = status + task_info["last_check_time"] = datetime.datetime.now().isoformat() + task_info["title"] = title + + if status == "complete": + # 确保audio_url字段存在 + if audio_url: + # 移除URL中可能存在的反引号 + audio_url = audio_url.replace("`", "").strip() + task_info["audio_url"] = audio_url + logger.info("音乐生成已完成!") + logger.info(f"音频URL: {audio_url}") + + # 下载音频文件 + file_name = f"suno_music_{int(time.time())}.mp3" + save_path = music_generator.base_path / file_name + if music_generator.download_audio(audio_url, str(save_path)): + task_info["local_path"] = str(save_path) + else: + logger.warning("音乐生成已完成,但未找到音频URL!") + + elif status == "failed": + logger.error("音乐生成失败!") + task_info["error_message"] = clip.get("error_message", "未知错误") + + # 更新任务存储 + music_tasks[task_id] = task_info + + return { + "code": 200, + "message": "查询成功", + "data": { + "task_id": task_id, + "status": status, + "title": title, + "audio_url": audio_url if status == "complete" else None, + "local_path": task_info.get("local_path") if status == "complete" else None, + "error_message": task_info.get("error_message") if status == "failed" else None + } + } + else: + logger.info("未找到音乐片段") + return { + "code": 200, + "message": "查询成功", + "data": { + "task_id": task_id, + "status": "processing", + "title": None, + "audio_url": None, + "local_path": None, + "error_message": None + } + } + + except fastapi.Requests.exceptions.RequestException as e: + logger.error(f"查询任务状态失败: {e}") + raise HTTPException(status_code=500, detail=f"查询任务状态失败: {str(e)}") + except json.JSONDecodeError as e: + logger.error(f"解析查询响应失败: {e}") + raise HTTPException(status_code=500, detail=f"解析查询响应失败: {str(e)}") + except Exception as e: + logger.error(f"查询任务状态过程中发生错误: {e}") + raise HTTPException(status_code=500, detail=f"查询任务状态过程中发生错误: {str(e)}") + diff --git a/dsLightRag/Routes/__pycache__/SunoRoute.cpython-310.pyc b/dsLightRag/Routes/__pycache__/SunoRoute.cpython-310.pyc new file mode 100644 index 00000000..9ee09302 Binary files /dev/null and b/dsLightRag/Routes/__pycache__/SunoRoute.cpython-310.pyc differ diff --git a/dsLightRag/Routes/__pycache__/suno_music_generator.cpython-310.pyc b/dsLightRag/Routes/__pycache__/suno_music_generator.cpython-310.pyc new file mode 100644 index 00000000..e45569b3 Binary files /dev/null and b/dsLightRag/Routes/__pycache__/suno_music_generator.cpython-310.pyc differ diff --git a/dsLightRag/Suno/suno_music_generator.py b/dsLightRag/Routes/suno_music_generator.py similarity index 100% rename from dsLightRag/Suno/suno_music_generator.py rename to dsLightRag/Routes/suno_music_generator.py diff --git a/dsLightRag/Start.py b/dsLightRag/Start.py index 06a4f8ce..94c0c299 100644 --- a/dsLightRag/Start.py +++ b/dsLightRag/Start.py @@ -20,6 +20,7 @@ from Routes.TeachingModel.api.DocumentController import router as document_route from Routes.TeachingModel.api.TeachingModelController import router as teaching_model_router from Routes.QA import router as qa_router from Routes.JiMengRoute import router as jimeng_router +from Routes.SunoRoute import router as suno_router from Util.LightRagUtil import * from contextlib import asynccontextmanager @@ -33,17 +34,19 @@ logger.addHandler(handler) @asynccontextmanager -async def lifespan(_: FastAPI): +async def lifespan(_: FastAPI): pool = await init_postgres_pool() app.state.pool = pool asyncio.create_task(train_document_task()) - + try: - yield + yield finally: # 应用关闭时销毁连接池 await close_postgres_pool(pool) + + app = FastAPI(lifespan=lifespan) # 挂载静态文件目录 @@ -56,8 +59,9 @@ app.include_router(rag_router) # LightRAG路由 app.include_router(knowledge_router) # 知识图谱路由 app.include_router(oss_router) # 阿里云OSS路由 app.include_router(llm_router) # 大模型路由 -app.include_router(qa_router) # 答疑路由 -app.include_router(jimeng_router) # 即梦路由 +app.include_router(qa_router) # 答疑路由 +app.include_router(jimeng_router) # 即梦路由 +app.include_router(suno_router) # Suno路由 # Teaching Model 相关路由 # 登录相关(不用登录) @@ -75,6 +79,5 @@ app.include_router(teaching_model_router, prefix="/api/teaching/model", tags=["t # 教学答疑 app.include_router(teaching_model_router, prefix="/api/teaching/model", tags=["teacher_model"]) - if __name__ == "__main__": uvicorn.run(app, host="0.0.0.0", port=8100) diff --git a/dsLightRag/static/Suno/index.html b/dsLightRag/static/Suno/index.html index 77453110..75ae4f65 100644 --- a/dsLightRag/static/Suno/index.html +++ b/dsLightRag/static/Suno/index.html @@ -134,7 +134,7 @@ - +