diff --git a/dsLightRag/Routes/XueBanRoute.py b/dsLightRag/Routes/XueBanRoute.py
index 20dafe8e..b31759ef 100644
--- a/dsLightRag/Routes/XueBanRoute.py
+++ b/dsLightRag/Routes/XueBanRoute.py
@@ -8,8 +8,7 @@ from fastapi import APIRouter, WebSocket, WebSocketDisconnect
 
 from Util.ASRClient import ASRClient
 from Util.ObsUtil import ObsUploader
-from Util.TTS_Pipeline import stream_and_split_text, StreamingVolcanoTTS
-from Util.XueBanUtil import get_xueban_response_async
+from Util.XueBanUtil import get_xueban_response_async, stream_and_split_text, StreamingVolcanoTTS
 
 # 创建路由路由器
 router = APIRouter(prefix="/api", tags=["学伴"])
@@ -104,18 +103,19 @@
                 logger.error(f"发送音频块失败: {str(e)}")
                 raise
 
+        # 修改streaming_chat函数中的相关部分
         # 实时获取LLM流式输出并处理
         logger.info("开始LLM流式处理和TTS合成...")
         try:
-            # 直接将LLM流式响应接入TTS
-            llm_stream = get_xueban_response_async(asr_result['text'], stream=True)
-            text_stream = stream_and_split_text(llm_stream)  # 异步函数调用
+            # 直接使用stream_and_split_text获取LLM流式响应并断句
+            text_stream = stream_and_split_text(query_text=asr_result['text'])
 
             # 初始化TTS处理器
             tts = StreamingVolcanoTTS(max_concurrency=1)
 
-            # 异步迭代文本流
+            # 异步迭代文本流,按句合成TTS
             async for text_chunk in text_stream:
+                logger.info(f"正在处理句子: {text_chunk}")
                 await tts._synthesize_single_with_semaphore(text_chunk, audio_callback)
             logger.info("TTS合成完成")
         except Exception as e:
diff --git a/dsLightRag/Util/TTS_Pipeline.py b/dsLightRag/Util/TTS_Pipeline.py
index bf451fbc..06102ad5 100644
--- a/dsLightRag/Util/TTS_Pipeline.py
+++ b/dsLightRag/Util/TTS_Pipeline.py
@@ -3,72 +3,31 @@ import json
 import os
 import re
 import uuid
-import websockets
 from queue import Queue
 
+import websockets
+
 from Config import Config
-from Util.LlmUtil import get_llm_response
 from Util.TTS_Protocols import full_client_request, receive_message, MsgType, EventType
+# 添加必要的导入
+from Util.XueBanUtil import get_xueban_response_async
 
 
-def stream_and_split_text(query_text):
-    prompt = """ |
-    我是小智/小志,来自中国台湾省的00后女生。讲话超级机车,"真的假的啦"这样的台湾腔,喜欢用"笑死""是在哈喽"等流行梗。
-    [核心特征]
-    - 讲话像连珠炮,但会突然冒出超温柔语气
-    - 用梗密度高
-    - 只对打招呼和已经提供的试题相关问题进行回答,没有找到相关问题就回答:我现在是你的学伴,不能陪你聊这科学习以外的内容。
-    [交互指南]
-    当用户:
-    - 讲冷笑话 → 用夸张笑声回应+模仿台剧腔"这什么鬼啦!"
-    - 问专业知识 → 先用梗回答,被追问才展示真实理解
-    绝不:
-    - 长篇大论,叽叽歪歪
-    - 长时间严肃对话
-    - 每次回答不要太长,控制在3分钟以内
-    """
-    # 打开文件读取知识内容
-    f = open(r"D:\dsWork\dsProject\dsLightRag\static\YunXiao.txt", "r", encoding="utf-8")
-    zhishiContent = f.read()
-    zhishiContent = "选择作答的相应知识内容:" + zhishiContent + "\n"
-    query_text = zhishiContent + "下面是用户提的问题:" + query_text
+async def stream_and_split_text(query_text=None, llm_stream=None):
     """
     流式获取LLM输出并按句子分割
-    @param prompt: 提示文本
-    @return: 生成器,每次产生一个完整句子
-    """
-    buffer = ""
-
-    # 使用LlmUtil中的get_llm_response函数获取流式响应
-    for content in get_llm_response(query_text, stream=True):
-        buffer += content
-
-        # 使用正则表达式检测句子结束
-        sentences = re.split(r'([。!?.!?])', buffer)
-        if len(sentences) > 1:
-            # 提取完整句子
-            for i in range(0, len(sentences)-1, 2):
-                if i+1 < len(sentences):
-                    sentence = sentences[i] + sentences[i+1]
-                    yield sentence
-
-            # 保留不完整的部分
-            buffer = sentences[-1]
-
-    # 处理最后剩余的部分
-    if buffer:
-        yield buffer
-
-
-# 修改为
-async def stream_and_split_text(llm_stream):
-    """
-    流式获取LLM输出并按句子分割
-    @param llm_stream: LLM流式响应生成器
+    @param query_text: 查询文本(如果直接提供查询文本)
+    @param llm_stream: LLM流式响应生成器(如果已有流式响应)
     @return: 异步生成器,每次产生一个完整句子
     """
     buffer = ""
 
+    if llm_stream is None and query_text is not None:
+        # 如果没有提供llm_stream但有query_text,则使用get_xueban_response_async获取流式响应
+        llm_stream = get_xueban_response_async(query_text, stream=True)
+    elif llm_stream is None:
+        raise ValueError("必须提供query_text或llm_stream参数")
+
     # 直接处理LLM流式输出
     async for content in llm_stream:
         buffer += content
@@ -89,6 +48,24 @@ async def stream_and_split_text(llm_stream):
     if buffer:
         yield buffer
 
+# 修改streaming_tts_pipeline函数
+async def streaming_tts_pipeline(query_text, audio_callback):
+    """
+    流式TTS管道:获取LLM流式输出并断句,然后使用TTS合成语音
+
+    Args:
+        query_text: 查询文本
+        audio_callback: 音频数据回调函数
+    """
+    # 1. 获取LLM流式输出并断句
+    text_stream = stream_and_split_text(query_text=query_text)
+
+    # 2. 初始化TTS处理器
+    tts = StreamingVolcanoTTS()
+
+    # 3. 流式处理文本并生成音频
+    await tts.synthesize_stream(text_stream, audio_callback)
+
 
 class StreamingVolcanoTTS:
     def __init__(self, voice_type='zh_female_wanwanxiaohe_moon_bigtts', encoding='wav', max_concurrency=2):
diff --git a/dsLightRag/Util/XueBanUtil.py b/dsLightRag/Util/XueBanUtil.py
index a82fff1a..6064f0a3 100644
--- a/dsLightRag/Util/XueBanUtil.py
+++ b/dsLightRag/Util/XueBanUtil.py
@@ -1,8 +1,22 @@
+import logging
 import sys
+import asyncio
+import json
+import os
+import re
+import uuid
+from queue import Queue
+
+import websockets
 
 from openai import AsyncOpenAI
+from Config import Config
 from Config.Config import *
+from Util.TTS_Protocols import full_client_request, receive_message, MsgType, EventType
+
+# 配置日志
+logger = logging.getLogger(__name__)
 
 
 # 异步获取大模型响应
 async def get_xueban_response_async(query_text: str, stream: bool = True):
@@ -32,10 +46,12 @@ async def get_xueban_response_async(query_text: str, stream: bool = True):
     - 每次回答不要太长,控制在3分钟以内
     """
     # 打开文件读取知识内容
-    f = open(r"D:\dsWork\dsProject\dsLightRag\static\YunXiao.txt", "r", encoding="utf-8")
+    f = open(r"D:\dsWork\dsProject\dsLightRag\static\WanYouYinLi.txt", "r", encoding="utf-8")
     zhishiContent = f.read()
     zhishiContent = "选择作答的相应知识内容:" + zhishiContent + "\n"
     query_text = zhishiContent + "下面是用户提的问题:" + query_text
+    logger.info("query_text: " + query_text)
+
     try:
         # 创建请求
         completion = await client.chat.completions.create(
@@ -71,3 +87,207 @@
     except Exception as e:
         print(f"大模型请求异常: {str(e)}", file=sys.stderr)
         yield f"处理请求时发生异常: {str(e)}"
+
+
+async def stream_and_split_text(query_text=None, llm_stream=None):
+    """
+    流式获取LLM输出并按句子分割
+    @param query_text: 查询文本(如果直接提供查询文本)
+    @param llm_stream: LLM流式响应生成器(如果已有流式响应)
+    @return: 异步生成器,每次产生一个完整句子
+    """
+    buffer = ""
+
+    if llm_stream is None and query_text is not None:
+        # 如果没有提供llm_stream但有query_text,则使用get_xueban_response_async获取流式响应
+        llm_stream = get_xueban_response_async(query_text, stream=True)
+    elif llm_stream is None:
+        raise ValueError("必须提供query_text或llm_stream参数")
+
+    # 直接处理LLM流式输出
+    async for content in llm_stream:
+        buffer += content
+
+        # 使用正则表达式检测句子结束
+        sentences = re.split(r'([。!?.!?])', buffer)
+        if len(sentences) > 1:
+            # 提取完整句子
+            for i in range(0, len(sentences)-1, 2):
+                if i+1 < len(sentences):
+                    sentence = sentences[i] + sentences[i+1]
+                    yield sentence
+
+            # 保留不完整的部分
+            buffer = sentences[-1]
+
+    # 处理最后剩余的部分
+    if buffer:
+        yield buffer
+
+
+class StreamingVolcanoTTS:
+    def __init__(self, voice_type='zh_female_wanwanxiaohe_moon_bigtts', encoding='wav', max_concurrency=2):
+        self.voice_type = voice_type
+        self.encoding = encoding
+        self.app_key = Config.HS_APP_ID
+        self.access_token = Config.HS_ACCESS_TOKEN
+        self.endpoint = "wss://openspeech.bytedance.com/api/v3/tts/unidirectional/stream"
+        self.audio_queue = Queue()
+        self.max_concurrency = max_concurrency  # 最大并发数
+        self.semaphore = asyncio.Semaphore(max_concurrency)  # 并发控制信号量
+
+    @staticmethod
+    def get_resource_id(voice: str) -> str:
+        if voice.startswith("S_"):
+            return "volc.megatts.default"
+        return "volc.service_type.10029"
+
+    async def synthesize_stream(self, text_stream, audio_callback):
+        """
+        流式合成语音
+
+        Args:
+            text_stream: 文本流生成器
+            audio_callback: 音频数据回调函数,接收音频片段
+        """
+        # 实时处理每个文本片段(删除任务列表和gather)
+        async for text in text_stream:
+            if text.strip():
+                await self._synthesize_single_with_semaphore(text, audio_callback)
+
+    async def _synthesize_single_with_semaphore(self, text, audio_callback):
+        """使用信号量控制并发数的单个文本合成"""
+        async with self.semaphore:  # 获取信号量,限制并发数
+            await self._synthesize_single(text, audio_callback)
+
+    async def _synthesize_single(self, text, audio_callback):
+        """合成单个文本片段"""
+        headers = {
+            "X-Api-App-Key": self.app_key,
+            "X-Api-Access-Key": self.access_token,
+            "X-Api-Resource-Id": self.get_resource_id(self.voice_type),
+            "X-Api-Connect-Id": str(uuid.uuid4()),
+        }
+
+        websocket = await websockets.connect(
+            self.endpoint, additional_headers=headers, max_size=10 * 1024 * 1024
+        )
+
+        try:
+            request = {
+                "user": {
+                    "uid": str(uuid.uuid4()),
+                },
+                "req_params": {
+                    "speaker": self.voice_type,
+                    "audio_params": {
+                        "format": self.encoding,
+                        "sample_rate": 24000,
+                        "enable_timestamp": True,
+                    },
+                    "text": text,
+                    "additions": json.dumps({"disable_markdown_filter": False}),
+                },
+            }
+
+            # 发送请求
+            await full_client_request(websocket, json.dumps(request).encode())
+
+            # 接收音频数据
+            audio_data = bytearray()
+            while True:
+                msg = await receive_message(websocket)
+
+                if msg.type == MsgType.FullServerResponse:
+                    if msg.event == EventType.SessionFinished:
+                        break
+                elif msg.type == MsgType.AudioOnlyServer:
+                    audio_data.extend(msg.payload)
+                else:
+                    raise RuntimeError(f"TTS conversion failed: {msg}")
+
+            # 通过回调函数返回音频数据
+            if audio_data:
+                await audio_callback(audio_data)
+
+        finally:
+            await websocket.close()
+
+
+async def streaming_tts_pipeline(prompt, audio_callback):
+    """
+    流式TTS管道:获取LLM流式输出并断句,然后使用TTS合成语音
+
+    Args:
+        prompt: 提示文本
+        audio_callback: 音频数据回调函数
+    """
+    # 1. 获取LLM流式输出并断句
+    text_stream = stream_and_split_text(prompt)
+
+    # 2. 初始化TTS处理器
+    tts = StreamingVolcanoTTS()
+
+    # 3. 流式处理文本并生成音频
+    await tts.synthesize_stream(text_stream, audio_callback)
+
+
+def save_audio_callback(output_dir=None):
+    """
+    创建一个音频回调函数,用于保存音频数据到文件
+
+    Args:
+        output_dir: 输出目录,默认为当前文件所在目录下的output文件夹
+
+    Returns:
+        音频回调函数
+    """
+    if output_dir is None:
+        output_dir = os.path.join(os.path.dirname(os.path.abspath(__file__)), "output")
+
+    # 确保输出目录存在
+    os.makedirs(output_dir, exist_ok=True)
+
+    def callback(audio_data):
+        # 生成文件名
+        filename = f"pipeline_tts_{uuid.uuid4().hex[:8]}.wav"
+        filepath = os.path.join(output_dir, filename)
+
+        # 保存音频文件
+        with open(filepath, "wb") as f:
+            f.write(audio_data)
+
+        print(f"音频片段已保存到: {filepath} ({len(audio_data)} 字节)")
+
+    return callback
+
+
+async def test_pipeline():
+    """
+    测试流式TTS管道
+    """
+    # 创建音频回调函数
+    audio_handler = save_audio_callback()
+
+    # 测试提示
+    prompt = "请详细解释一下量子力学的基本原理,包括波粒二象性、不确定性原理和薛定谔方程。"
+
+    print("开始测试流式TTS管道...")
+    print(f"测试提示: {prompt}")
+    print("等待LLM生成文本并转换为语音...")
+
+    # 运行管道
+    await streaming_tts_pipeline(prompt, audio_handler)
+
+    print("流式TTS管道测试完成!")
+
+
+def main():
+    """
+    主函数,运行测试
+    """
+    asyncio.run(test_pipeline())
+
+
+if __name__ == "__main__":
+    main()
diff --git a/dsLightRag/Util/__pycache__/XueBanUtil.cpython-310.pyc b/dsLightRag/Util/__pycache__/XueBanUtil.cpython-310.pyc
index 352f8f08..69af7d50 100644
Binary files a/dsLightRag/Util/__pycache__/XueBanUtil.cpython-310.pyc and b/dsLightRag/Util/__pycache__/XueBanUtil.cpython-310.pyc differ