2025-08-31 13:15:58 +08:00
parent 1bdf7a3af4
commit 6af46bec96
7 changed files with 11 additions and 255 deletions

View File

@@ -84,15 +84,6 @@ async def streaming_chat(websocket: WebSocket):
logger.error(f"发送ASR结果失败: {str(e)}")
return
# 获取学伴响应内容(包含题目信息)
logger.info("获取学伴响应内容...")
llm_chunks = []
async for chunk in get_xueban_response_async(asr_result['text'], stream=True):
llm_chunks.append(chunk)
full_llm_response = ''.join(llm_chunks)
logger.info(f"学伴响应内容: {full_llm_response}")
# 定义音频回调函数,将音频块发送给前端
async def audio_callback(audio_chunk):
logger.info(f"发送音频块,大小: {len(audio_chunk)}")
@@ -103,12 +94,14 @@ async def streaming_chat(websocket: WebSocket):
logger.error(f"发送音频块失败: {str(e)}")
raise
# 修改streaming_chat函数中的相关部分
# 实时获取LLM流式输出并处理
logger.info("开始LLM流式处理和TTS合成...")
try:
# 直接使用stream_and_split_text获取LLM流式响应并断句
text_stream = stream_and_split_text(query_text=asr_result['text'])
# 获取LLM流式响应
llm_stream = get_xueban_response_async(asr_result['text'], stream=True)
# 使用stream_and_split_text处理流式响应并断句
text_stream = stream_and_split_text(llm_stream=llm_stream)
# 初始化TTS处理器
tts = StreamingVolcanoTTS(max_concurrency=1)
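Context for this change: judging from the hunk headers, the old code both collected the full response from one get_xueban_response_async call and let stream_and_split_text(query_text=...) open a second stream, so each turn paid for two LLM calls. The new code obtains a single stream and hands it straight to the splitter. A self-contained sketch (all names here are invented for illustration) of why one async generator cannot feed both a collector and the splitter:

import asyncio

async def llm_stream():
    # Invented stand-in for get_xueban_response_async(text, stream=True)
    for token in ("a", "b", "c"):
        yield token

async def main():
    stream = llm_stream()
    collected = [t async for t in stream]  # consuming the generator here...
    leftover = [t async for t in stream]   # ...leaves nothing for a second pass
    print(collected, leftover)             # ['a', 'b', 'c'] []

asyncio.run(main())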

View File

@@ -1,235 +0,0 @@
import asyncio
import json
import os
import re
import uuid
from queue import Queue

import websockets

from Config import Config
from Util.TTS_Protocols import full_client_request, receive_message, MsgType, EventType
# Import needed to fetch the LLM stream directly
from Util.XueBanUtil import get_xueban_response_async


async def stream_and_split_text(query_text=None, llm_stream=None):
    """
    Stream the LLM output and split it into sentences.
    @param query_text: the query text (when the query is supplied directly)
    @param llm_stream: an LLM streaming response generator (when a stream already exists)
    @return: async generator that yields one complete sentence at a time
    """
    buffer = ""
    if llm_stream is None and query_text is not None:
        # No llm_stream given but query_text is: fetch a stream via get_xueban_response_async
        llm_stream = get_xueban_response_async(query_text, stream=True)
    elif llm_stream is None:
        raise ValueError("Either query_text or llm_stream must be provided")
    # Consume the LLM stream directly
    async for content in llm_stream:
        buffer += content
        # Detect sentence boundaries with a regex
        sentences = re.split(r'([。!?.!?])', buffer)
        if len(sentences) > 1:
            # Emit the complete sentences
            for i in range(0, len(sentences) - 1, 2):
                if i + 1 < len(sentences):
                    sentence = sentences[i] + sentences[i + 1]
                    yield sentence
            # Keep the incomplete remainder
            buffer = sentences[-1]
    # Flush whatever remains at the end
    if buffer:
        yield buffer
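For reference, the splitting rule above behaves like this on a made-up buffer (plain re.split, shown synchronously):

import re

buffer = "第一句。第二句!还没说完"
parts = re.split(r'([。!?.!?])', buffer)
# parts == ['第一句', '。', '第二句', '!', '还没说完']
complete = [parts[i] + parts[i + 1] for i in range(0, len(parts) - 1, 2)]
print(complete)   # ['第一句。', '第二句!']
print(parts[-1])  # '还没说完' stays in the buffer until more text arrives

The capture group keeps each delimiter in the result list, which is why sentences and their enders come back in adjacent pairs.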
# Modified version of the streaming_tts_pipeline function
async def streaming_tts_pipeline(query_text, audio_callback):
    """
    Streaming TTS pipeline: fetch the LLM stream, split it into sentences,
    then synthesize speech with TTS.
    Args:
        query_text: the query text
        audio_callback: callback that receives the audio data
    """
    # 1. Fetch the LLM stream and split it into sentences
    text_stream = stream_and_split_text(query_text=query_text)
    # 2. Initialize the TTS processor
    tts = StreamingVolcanoTTS()
    # 3. Process the text stream and generate audio
    await tts.synthesize_stream(text_stream, audio_callback)


class StreamingVolcanoTTS:
    def __init__(self, voice_type='zh_female_wanwanxiaohe_moon_bigtts', encoding='wav', max_concurrency=2):
        self.voice_type = voice_type
        self.encoding = encoding
        self.app_key = Config.HS_APP_ID
        self.access_token = Config.HS_ACCESS_TOKEN
        self.endpoint = "wss://openspeech.bytedance.com/api/v3/tts/unidirectional/stream"
        self.audio_queue = Queue()
        self.max_concurrency = max_concurrency  # maximum number of concurrent syntheses
        self.semaphore = asyncio.Semaphore(max_concurrency)  # concurrency-limiting semaphore

    @staticmethod
    def get_resource_id(voice: str) -> str:
        if voice.startswith("S_"):
            return "volc.megatts.default"
        return "volc.service_type.10029"

    async def synthesize_stream(self, text_stream, audio_callback):
        """
        Synthesize speech from a stream of text.
        Args:
            text_stream: text stream generator
            audio_callback: callback that receives each audio fragment
        """
        # Process each text fragment as it arrives (the task list and gather were removed)
        async for text in text_stream:
            if text.strip():
                await self._synthesize_single_with_semaphore(text, audio_callback)

    async def _synthesize_single_with_semaphore(self, text, audio_callback):
        """Synthesize one text fragment under the concurrency-limiting semaphore"""
        async with self.semaphore:  # acquire the semaphore to cap concurrency
            await self._synthesize_single(text, audio_callback)

    async def _synthesize_single(self, text, audio_callback):
        """Synthesize a single text fragment"""
        headers = {
            "X-Api-App-Key": self.app_key,
            "X-Api-Access-Key": self.access_token,
            "X-Api-Resource-Id": self.get_resource_id(self.voice_type),
            "X-Api-Connect-Id": str(uuid.uuid4()),
        }
        websocket = await websockets.connect(
            self.endpoint, additional_headers=headers, max_size=10 * 1024 * 1024
        )
        try:
            request = {
                "user": {
                    "uid": str(uuid.uuid4()),
                },
                "req_params": {
                    "speaker": self.voice_type,
                    "audio_params": {
                        "format": self.encoding,
                        "sample_rate": 24000,
                        "enable_timestamp": True,
                    },
                    "text": text,
                    "additions": json.dumps({"disable_markdown_filter": False}),
                },
            }
            # Send the request
            await full_client_request(websocket, json.dumps(request).encode())
            # Receive the audio data
            audio_data = bytearray()
            while True:
                msg = await receive_message(websocket)
                if msg.type == MsgType.FullServerResponse:
                    if msg.event == EventType.SessionFinished:
                        break
                elif msg.type == MsgType.AudioOnlyServer:
                    audio_data.extend(msg.payload)
                else:
                    raise RuntimeError(f"TTS conversion failed: {msg}")
            # Hand the collected audio back through the callback
            if audio_data:
                await audio_callback(audio_data)
        finally:
            await websocket.close()
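The semaphore above caps how many syntheses run at once; note that synthesize_stream awaits each fragment sequentially, so the cap only matters if fragments are ever scheduled concurrently again. A minimal standalone sketch of the pattern, with a sleep standing in for the TTS websocket round trip:

import asyncio

async def worker(sem, i):
    async with sem:               # acquired here; released automatically on exit
        print(f"synthesizing fragment {i}")
        await asyncio.sleep(0.1)  # stands in for the websocket round trip

async def main():
    sem = asyncio.Semaphore(2)    # mirrors max_concurrency=2
    await asyncio.gather(*(worker(sem, i) for i in range(5)))

asyncio.run(main())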
# Note: this second definition overrides the streaming_tts_pipeline defined above
async def streaming_tts_pipeline(prompt, audio_callback):
    """
    Streaming TTS pipeline: fetch the LLM stream, split it into sentences,
    then synthesize speech with TTS.
    Args:
        prompt: the prompt text
        audio_callback: callback that receives the audio data
    """
    # 1. Fetch the LLM stream and split it into sentences
    text_stream = stream_and_split_text(prompt)
    # 2. Initialize the TTS processor
    tts = StreamingVolcanoTTS()
    # 3. Process the text stream and generate audio
    await tts.synthesize_stream(text_stream, audio_callback)


def save_audio_callback(output_dir=None):
    """
    Create an audio callback that saves audio data to files.
    Args:
        output_dir: output directory; defaults to an "output" folder next to this file
    Returns:
        the audio callback function
    """
    if output_dir is None:
        output_dir = os.path.join(os.path.dirname(os.path.abspath(__file__)), "output")
    # Make sure the output directory exists
    os.makedirs(output_dir, exist_ok=True)

    # Must be async: StreamingVolcanoTTS awaits the callback it is given
    async def callback(audio_data):
        # Build the file name
        filename = f"pipeline_tts_{uuid.uuid4().hex[:8]}.wav"
        filepath = os.path.join(output_dir, filename)
        # Save the audio file
        with open(filepath, "wb") as f:
            f.write(audio_data)
        print(f"Audio fragment saved to: {filepath} ({len(audio_data)} bytes)")

    return callback


async def test_pipeline():
    """
    Test the streaming TTS pipeline.
    """
    # Create the audio callback
    audio_handler = save_audio_callback()
    # Test prompt
    prompt = "请详细解释一下量子力学的基本原理,包括波粒二象性、不确定性原理和薛定谔方程。"
    print("Starting the streaming TTS pipeline test...")
    print(f"Test prompt: {prompt}")
    print("Waiting for the LLM to generate text and convert it to speech...")
    # Run the pipeline
    await streaming_tts_pipeline(prompt, audio_handler)
    print("Streaming TTS pipeline test finished")


def main():
    """
    Entry point: run the test.
    """
    asyncio.run(test_pipeline())


if __name__ == "__main__":
    main()

View File

@@ -50,7 +50,7 @@ async def get_xueban_response_async(query_text: str, stream: bool = True):
        zhishiContent = f.read()
    zhishiContent = "选择作答的相应知识内容:" + zhishiContent + "\n"
    query_text = zhishiContent + "下面是用户提的问题:" + query_text
    logger.info("query_text: " + query_text)
    #logger.info("query_text: " + query_text)
    try:
        # Create the request

View File

@@ -216,7 +216,7 @@ const WebSocketManager = {
    console.log('Current audio queue length:', AudioState.playback.audioQueue.length);
    // If streaming playback has not started yet, start it
    if (!AudioState.playback.isStreamPlaying && !AudioState.playback.isPlaying) {
    if (!AudioState.playback.isStreamPlaying) {
        console.log('Starting streaming audio playback');
        AudioState.playback.isStreamPlaying = true;
        AudioPlayer.processAudioQueue();
@@ -393,17 +393,16 @@ const AudioPlayer = {
    // Process the audio queue
    processAudioQueue() {
        // If already playing or the queue is empty, return
        if (AudioState.playback.isStreamPlaying || AudioState.playback.audioQueue.length === 0) {
        // If the queue is empty, return
        if (AudioState.playback.audioQueue.length === 0) {
            AudioState.playback.isStreamPlaying = false;
            console.log('Audio queue is empty; stopping streaming playback');
            return;
        }
        // Mark streaming playback as active
        AudioState.playback.isStreamPlaying = true;
        // Take the first audio chunk off the queue
        const audioBlob = AudioState.playback.audioQueue.shift();
        console.log('Dequeued an audio chunk; remaining queue length:', AudioState.playback.audioQueue.length);
        // Create an object URL for the audio
        const audioUrl = URL.createObjectURL(audioBlob);
@@ -417,7 +416,6 @@ const AudioPlayer = {
        .catch(error => {
            console.error('Failed to play audio chunk:', error);
            // Playback failed; move on to the next chunk
            AudioState.playback.isStreamPlaying = false;
            this.processAudioQueue();
        })
        .finally(() => {