2025-08-31 13:05:50 +08:00
parent d9a29a9d30
commit 1bdf7a3af4
4 changed files with 258 additions and 61 deletions

View File

@@ -3,72 +3,31 @@ import json
import os
import re
import uuid
from queue import Queue
import websockets
from Config import Config
from Util.LlmUtil import get_llm_response
from Util.TTS_Protocols import full_client_request, receive_message, MsgType, EventType
# Added required imports
from Util.XueBanUtil import get_xueban_response_async
def stream_and_split_text(query_text):
    prompt = """
    I'm Xiaozhi/Xiaozhi (小智/小志), a girl from Taiwan, China, born after 2000. I talk a mile a minute in Taiwanese slang ("真的假的啦") and lean on memes like "笑死" and "是在哈喽".
    [Core traits]
    - Rapid-fire delivery, with sudden swerves into a super-gentle tone
    - Heavy meme usage
    - Only answer greetings and questions about the provided exam material; if no matching question is found, reply: "I'm your study buddy right now, so I can't chat about anything outside this subject."
    [Interaction guide]
    When the user:
    - Tells a bad joke → respond with exaggerated laughter plus a Taiwanese-drama "这什么鬼啦!"
    - Asks a technical question → answer with a meme first; only show real understanding when pressed
    Never:
    - Ramble on at length
    - Stay serious for long
    - Keep every answer short, under 3 minutes
    """
    # Read the knowledge content from file
    with open(r"D:\dsWork\dsProject\dsLightRag\static\YunXiao.txt", "r", encoding="utf-8") as f:
        zhishiContent = f.read()
    zhishiContent = "Relevant knowledge content for answering: " + zhishiContent + "\n"
    query_text = zhishiContent + "The user's question follows: " + query_text
async def stream_and_split_text(query_text=None, llm_stream=None):
    """
    Stream the LLM output and split it into sentences.
    @param query_text: query text
    @param llm_stream: LLM streaming response generator
    @return: generator that yields one complete sentence at a time
    """
    buffer = ""
    # Use get_llm_response from LlmUtil to obtain a streaming response
    for content in get_llm_response(query_text, stream=True):
        buffer += content
        # Detect sentence boundaries with a regex
        sentences = re.split(r'([。!?.!?])', buffer)
        if len(sentences) > 1:
            # Emit every complete sentence
            for i in range(0, len(sentences)-1, 2):
                if i+1 < len(sentences):
                    sentence = sentences[i] + sentences[i+1]
                    yield sentence
            # Keep the incomplete remainder
            buffer = sentences[-1]
    # Flush whatever is left at the end
    if buffer:
        yield buffer
# Changed to:
async def stream_and_split_text(llm_stream):
    """
    Stream the LLM output and split it into sentences.
    @param query_text: query text (when the query is passed in directly)
    @param llm_stream: LLM streaming response generator (when a stream already exists)
    @return: async generator that yields one complete sentence at a time
    """
    buffer = ""
    if llm_stream is None and query_text is not None:
        # No llm_stream given but query_text is: fetch a streaming response via get_xueban_response_async
        llm_stream = get_xueban_response_async(query_text, stream=True)
    elif llm_stream is None:
        raise ValueError("Either query_text or llm_stream must be provided")
    # Consume the LLM stream directly
    async for content in llm_stream:
        buffer += content
@@ -89,6 +48,24 @@ async def stream_and_split_text(llm_stream):
    if buffer:
        yield buffer
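The pairing loop above leans on a re.split detail: when the pattern is a capturing group, the delimiters are kept in the result at the odd indices. A minimal standalone check of that behavior (plain Python, independent of this repo):

import re

parts = re.split(r'([。!?.!?])', "第一句。第二句!还没结束")
# parts == ['第一句', '。', '第二句', '!', '还没结束']
# Even indices hold text, odd indices the punctuation that closed it, so
# parts[i] + parts[i+1] rebuilds complete sentences and parts[-1] is the
# unfinished tail that stays in the buffer.
print([parts[i] + parts[i+1] for i in range(0, len(parts) - 1, 2)])  # ['第一句。', '第二句!']
print(parts[-1])  # 还没结束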
# Updated the streaming_tts_pipeline function
async def streaming_tts_pipeline(query_text, audio_callback):
    """
    Streaming TTS pipeline: take the LLM's streaming output, split it into sentences, then synthesize speech with TTS.
    Args:
        query_text: query text
        audio_callback: callback that receives audio data
    """
    # 1. Get the LLM stream split into sentences
    text_stream = stream_and_split_text(query_text=query_text)
    # 2. Initialize the TTS processor
    tts = StreamingVolcanoTTS()
    # 3. Stream the text through TTS to produce audio
    await tts.synthesize_stream(text_stream, audio_callback)
class StreamingVolcanoTTS:
    def __init__(self, voice_type='zh_female_wanwanxiaohe_moon_bigtts', encoding='wav', max_concurrency=2):

View File

@@ -1,8 +1,22 @@
import logging
import sys
import asyncio
import json
import os
import re
import uuid
from queue import Queue
import websockets
from openai import AsyncOpenAI
from Config import Config
from Config.Config import *
from Util.TTS_Protocols import full_client_request, receive_message, MsgType, EventType
# Configure logging
logger = logging.getLogger(__name__)
# Fetch the LLM response asynchronously
async def get_xueban_response_async(query_text: str, stream: bool = True):
@@ -32,10 +46,12 @@ async def get_xueban_response_async(query_text: str, stream: bool = True):
    - Keep every answer short, under 3 minutes
    """
    # Read the knowledge content from file
-    f = open(r"D:\dsWork\dsProject\dsLightRag\static\YunXiao.txt", "r", encoding="utf-8")
+    f = open(r"D:\dsWork\dsProject\dsLightRag\static\WanYouYinLi.txt", "r", encoding="utf-8")
    zhishiContent = f.read()
    zhishiContent = "Relevant knowledge content for answering: " + zhishiContent + "\n"
    query_text = zhishiContent + "The user's question follows: " + query_text
    logger.info("query_text: " + query_text)
    try:
        # Create the request
        completion = await client.chat.completions.create(
@@ -71,3 +87,207 @@ async def get_xueban_response_async(query_text: str, stream: bool = True):
    except Exception as e:
        print(f"LLM request error: {str(e)}", file=sys.stderr)
        yield f"An exception occurred while handling the request: {str(e)}"
async def stream_and_split_text(query_text=None, llm_stream=None):
    """
    Stream the LLM output and split it into sentences.
    @param query_text: query text (when the query is passed in directly)
    @param llm_stream: LLM streaming response generator (when a stream already exists)
    @return: async generator that yields one complete sentence at a time
    """
    buffer = ""
    if llm_stream is None and query_text is not None:
        # No llm_stream given but query_text is: fetch a streaming response via get_xueban_response_async
        llm_stream = get_xueban_response_async(query_text, stream=True)
    elif llm_stream is None:
        raise ValueError("Either query_text or llm_stream must be provided")
    # Consume the LLM stream directly
    async for content in llm_stream:
        buffer += content
        # Detect sentence boundaries with a regex
        sentences = re.split(r'([。!?.!?])', buffer)
        if len(sentences) > 1:
            # Emit every complete sentence
            for i in range(0, len(sentences)-1, 2):
                if i+1 < len(sentences):
                    sentence = sentences[i] + sentences[i+1]
                    yield sentence
            # Keep the incomplete remainder
            buffer = sentences[-1]
    # Flush whatever is left at the end
    if buffer:
        yield buffer
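As a quick sanity check, the splitter can be driven without any LLM by handing it a fake async stream; fake_stream below is a hypothetical stand-in, and any async iterable of strings works:

import asyncio

async def fake_stream():
    for chunk in ["今天天气", "不错。我们", "去公园!好"]:
        yield chunk

async def demo():
    async for sentence in stream_and_split_text(llm_stream=fake_stream()):
        print(sentence)  # 今天天气不错。 then 我们去公园! then the flushed tail 好

asyncio.run(demo())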
class StreamingVolcanoTTS:
    def __init__(self, voice_type='zh_female_wanwanxiaohe_moon_bigtts', encoding='wav', max_concurrency=2):
        self.voice_type = voice_type
        self.encoding = encoding
        self.app_key = Config.HS_APP_ID
        self.access_token = Config.HS_ACCESS_TOKEN
        self.endpoint = "wss://openspeech.bytedance.com/api/v3/tts/unidirectional/stream"
        self.audio_queue = Queue()
        self.max_concurrency = max_concurrency  # maximum number of concurrent TTS requests
        self.semaphore = asyncio.Semaphore(max_concurrency)  # semaphore that bounds concurrency

    @staticmethod
    def get_resource_id(voice: str) -> str:
        if voice.startswith("S_"):
            return "volc.megatts.default"
        return "volc.service_type.10029"
    async def synthesize_stream(self, text_stream, audio_callback):
        """
        Synthesize speech from a stream of text.
        Args:
            text_stream: text stream generator
            audio_callback: callback that receives each audio segment
        """
        # Handle each text fragment as it arrives (the task list and gather were removed)
        async for text in text_stream:
            if text.strip():
                await self._synthesize_single_with_semaphore(text, audio_callback)

    async def _synthesize_single_with_semaphore(self, text, audio_callback):
        """Synthesize a single text fragment, with the semaphore bounding concurrency"""
        async with self.semaphore:  # acquire the semaphore to limit concurrent requests
            await self._synthesize_single(text, audio_callback)
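Note that because synthesize_stream awaits each fragment in turn, at most one request is ever in flight, so the semaphore never actually throttles anything; it only matters if synthesis is made concurrent again. A hypothetical task-based variant (not part of this commit) that would exercise the semaphore, at the cost of unordered audio callbacks:

    async def synthesize_stream_concurrent(self, text_stream, audio_callback):
        # Launch one task per fragment; the semaphore caps how many run at once.
        tasks = []
        async for text in text_stream:
            if text.strip():
                tasks.append(asyncio.create_task(
                    self._synthesize_single_with_semaphore(text, audio_callback)))
        if tasks:
            await asyncio.gather(*tasks)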
    async def _synthesize_single(self, text, audio_callback):
        """Synthesize one text fragment"""
        headers = {
            "X-Api-App-Key": self.app_key,
            "X-Api-Access-Key": self.access_token,
            "X-Api-Resource-Id": self.get_resource_id(self.voice_type),
            "X-Api-Connect-Id": str(uuid.uuid4()),
        }
        websocket = await websockets.connect(
            self.endpoint, additional_headers=headers, max_size=10 * 1024 * 1024
        )
        try:
            request = {
                "user": {
                    "uid": str(uuid.uuid4()),
                },
                "req_params": {
                    "speaker": self.voice_type,
                    "audio_params": {
                        "format": self.encoding,
                        "sample_rate": 24000,
                        "enable_timestamp": True,
                    },
                    "text": text,
                    "additions": json.dumps({"disable_markdown_filter": False}),
                },
            }
            # Send the request
            await full_client_request(websocket, json.dumps(request).encode())
            # Receive the audio data
            audio_data = bytearray()
            while True:
                msg = await receive_message(websocket)
                if msg.type == MsgType.FullServerResponse:
                    if msg.event == EventType.SessionFinished:
                        break
                elif msg.type == MsgType.AudioOnlyServer:
                    audio_data.extend(msg.payload)
                else:
                    raise RuntimeError(f"TTS conversion failed: {msg}")
            # Hand the audio back through the callback
            if audio_data:
                await audio_callback(audio_data)
        finally:
            await websocket.close()
async def streaming_tts_pipeline(prompt, audio_callback):
    """
    Streaming TTS pipeline: take the LLM's streaming output, split it into sentences, then synthesize speech with TTS.
    Args:
        prompt: prompt text
        audio_callback: callback that receives audio data
    """
    # 1. Get the LLM stream split into sentences
    text_stream = stream_and_split_text(prompt)
    # 2. Initialize the TTS processor
    tts = StreamingVolcanoTTS()
    # 3. Stream the text through TTS to produce audio
    await tts.synthesize_stream(text_stream, audio_callback)
def save_audio_callback(output_dir=None):
    """
    Create an audio callback that saves audio data to files.
    Args:
        output_dir: output directory; defaults to an "output" folder next to this file
    Returns:
        the audio callback
    """
    if output_dir is None:
        output_dir = os.path.join(os.path.dirname(os.path.abspath(__file__)), "output")
    # Make sure the output directory exists
    os.makedirs(output_dir, exist_ok=True)

    # _synthesize_single awaits the callback, so it must be a coroutine function
    async def callback(audio_data):
        # Generate a file name
        filename = f"pipeline_tts_{uuid.uuid4().hex[:8]}.wav"
        filepath = os.path.join(output_dir, filename)
        # Save the audio file
        with open(filepath, "wb") as f:
            f.write(audio_data)
        print(f"Audio segment saved to: {filepath} ({len(audio_data)} bytes)")

    return callback
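Since _synthesize_single awaits the callback, any async callable works as a sink. A hypothetical alternative that forwards each segment to a connected client instead of writing files (ws is assumed to be an open websocket connection, not something this commit provides):

def forward_audio_callback(ws):
    async def callback(audio_data):
        await ws.send(bytes(audio_data))  # bytearray -> bytes for a plain immutable payload
    return callback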
async def test_pipeline():
    """
    Test the streaming TTS pipeline
    """
    # Create the audio callback
    audio_handler = save_audio_callback()
    # Test prompt
    prompt = "请详细解释一下量子力学的基本原理,包括波粒二象性、不确定性原理和薛定谔方程。"
    print("Starting the streaming TTS pipeline test...")
    print(f"Test prompt: {prompt}")
    print("Waiting for the LLM to generate text and convert it to speech...")
    # Run the pipeline
    await streaming_tts_pipeline(prompt, audio_handler)
    print("Streaming TTS pipeline test finished")
def main():
    """
    Entry point: run the test
    """
    asyncio.run(test_pipeline())

if __name__ == "__main__":
    main()