'commit'
This commit is contained in:
@@ -1,12 +1,137 @@
|
||||
import asyncio
|
||||
import json
|
||||
import os
|
||||
import sys
|
||||
import re
|
||||
import uuid
|
||||
import websockets
|
||||
from queue import Queue
|
||||
|
||||
# 添加路径以导入其他模块
|
||||
sys.path.append(os.path.dirname(os.path.abspath(__file__)))
|
||||
from T1_LLM import stream_and_split_text
|
||||
from T2_StreamingVolanoTTS import StreamingVolcanoTTS
|
||||
from Config import Config
|
||||
from Util.LlmUtil import get_llm_response
|
||||
from Util.TTS_Protocols import full_client_request, receive_message, MsgType, EventType
|
||||
|
||||
|
||||
def stream_and_split_text(prompt):
|
||||
"""
|
||||
流式获取LLM输出并按句子分割
|
||||
@param prompt: 提示文本
|
||||
@return: 生成器,每次产生一个完整句子
|
||||
"""
|
||||
buffer = ""
|
||||
|
||||
# 使用LlmUtil中的get_llm_response函数获取流式响应
|
||||
for content in get_llm_response(prompt, stream=True):
|
||||
buffer += content
|
||||
|
||||
# 使用正则表达式检测句子结束
|
||||
sentences = re.split(r'([。!?.!?])', buffer)
|
||||
if len(sentences) > 1:
|
||||
# 提取完整句子
|
||||
for i in range(0, len(sentences)-1, 2):
|
||||
if i+1 < len(sentences):
|
||||
sentence = sentences[i] + sentences[i+1]
|
||||
yield sentence
|
||||
|
||||
# 保留不完整的部分
|
||||
buffer = sentences[-1]
|
||||
|
||||
# 处理最后剩余的部分
|
||||
if buffer:
|
||||
yield buffer
|
||||
|
||||
|
||||
class StreamingVolcanoTTS:
|
||||
def __init__(self, voice_type='zh_female_wanwanxiaohe_moon_bigtts', encoding='wav', max_concurrency=2):
|
||||
self.voice_type = voice_type
|
||||
self.encoding = encoding
|
||||
self.app_key = Config.HS_APP_ID
|
||||
self.access_token = Config.HS_ACCESS_TOKEN
|
||||
self.endpoint = "wss://openspeech.bytedance.com/api/v3/tts/unidirectional/stream"
|
||||
self.audio_queue = Queue()
|
||||
self.max_concurrency = max_concurrency # 最大并发数
|
||||
self.semaphore = asyncio.Semaphore(max_concurrency) # 并发控制信号量
|
||||
|
||||
@staticmethod
|
||||
def get_resource_id(voice: str) -> str:
|
||||
if voice.startswith("S_"):
|
||||
return "volc.megatts.default"
|
||||
return "volc.service_type.10029"
|
||||
|
||||
async def synthesize_stream(self, text_stream, audio_callback):
|
||||
"""
|
||||
流式合成语音
|
||||
|
||||
Args:
|
||||
text_stream: 文本流生成器
|
||||
audio_callback: 音频数据回调函数,接收音频片段
|
||||
"""
|
||||
# 为每个文本片段创建一个WebSocket连接,但限制并发数
|
||||
tasks = []
|
||||
for text in text_stream:
|
||||
if text.strip(): # 忽略空文本
|
||||
task = asyncio.create_task(self._synthesize_single_with_semaphore(text, audio_callback))
|
||||
tasks.append(task)
|
||||
|
||||
# 等待所有任务完成
|
||||
await asyncio.gather(*tasks)
|
||||
|
||||
async def _synthesize_single_with_semaphore(self, text, audio_callback):
|
||||
"""使用信号量控制并发数的单个文本合成"""
|
||||
async with self.semaphore: # 获取信号量,限制并发数
|
||||
await self._synthesize_single(text, audio_callback)
|
||||
|
||||
async def _synthesize_single(self, text, audio_callback):
|
||||
"""合成单个文本片段"""
|
||||
headers = {
|
||||
"X-Api-App-Key": self.app_key,
|
||||
"X-Api-Access-Key": self.access_token,
|
||||
"X-Api-Resource-Id": self.get_resource_id(self.voice_type),
|
||||
"X-Api-Connect-Id": str(uuid.uuid4()),
|
||||
}
|
||||
|
||||
websocket = await websockets.connect(
|
||||
self.endpoint, additional_headers=headers, max_size=10 * 1024 * 1024
|
||||
)
|
||||
|
||||
try:
|
||||
request = {
|
||||
"user": {
|
||||
"uid": str(uuid.uuid4()),
|
||||
},
|
||||
"req_params": {
|
||||
"speaker": self.voice_type,
|
||||
"audio_params": {
|
||||
"format": self.encoding,
|
||||
"sample_rate": 24000,
|
||||
"enable_timestamp": True,
|
||||
},
|
||||
"text": text,
|
||||
"additions": json.dumps({"disable_markdown_filter": False}),
|
||||
},
|
||||
}
|
||||
|
||||
# 发送请求
|
||||
await full_client_request(websocket, json.dumps(request).encode())
|
||||
|
||||
# 接收音频数据
|
||||
audio_data = bytearray()
|
||||
while True:
|
||||
msg = await receive_message(websocket)
|
||||
|
||||
if msg.type == MsgType.FullServerResponse:
|
||||
if msg.event == EventType.SessionFinished:
|
||||
break
|
||||
elif msg.type == MsgType.AudioOnlyServer:
|
||||
audio_data.extend(msg.payload)
|
||||
else:
|
||||
raise RuntimeError(f"TTS conversion failed: {msg}")
|
||||
|
||||
# 通过回调函数返回音频数据
|
||||
if audio_data:
|
||||
audio_callback(audio_data)
|
||||
|
||||
finally:
|
||||
await websocket.close()
|
||||
|
||||
|
||||
async def streaming_tts_pipeline(prompt, audio_callback):
|
||||
|
Reference in New Issue
Block a user