Files
dsProject/dsLightRag/Test/TTS/T2_StreamingVolanoTTS.py
2025-08-31 09:37:21 +08:00

156 lines
5.0 KiB
Python
Raw Blame History

This file contains ambiguous Unicode characters

This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.

import asyncio
import json
import uuid
import websockets
import os
from queue import Queue
from Config import Config
from Util.TTS_Protocols import full_client_request, receive_message, MsgType, EventType
class StreamingVolcanoTTS:
def __init__(self, voice_type='zh_female_wanwanxiaohe_moon_bigtts', encoding='wav', max_concurrency=2):
self.voice_type = voice_type
self.encoding = encoding
self.app_key = Config.HS_APP_ID
self.access_token = Config.HS_ACCESS_TOKEN
self.endpoint = "wss://openspeech.bytedance.com/api/v3/tts/unidirectional/stream"
self.audio_queue = Queue()
self.max_concurrency = max_concurrency # 最大并发数
self.semaphore = asyncio.Semaphore(max_concurrency) # 并发控制信号量
@staticmethod
def get_resource_id(voice: str) -> str:
if voice.startswith("S_"):
return "volc.megatts.default"
return "volc.service_type.10029"
async def synthesize_stream(self, text_stream, audio_callback):
"""
流式合成语音
Args:
text_stream: 文本流生成器
audio_callback: 音频数据回调函数,接收音频片段
"""
# 为每个文本片段创建一个WebSocket连接但限制并发数
tasks = []
for text in text_stream:
if text.strip(): # 忽略空文本
task = asyncio.create_task(self._synthesize_single_with_semaphore(text, audio_callback))
tasks.append(task)
# 等待所有任务完成
await asyncio.gather(*tasks)
async def _synthesize_single_with_semaphore(self, text, audio_callback):
"""使用信号量控制并发数的单个文本合成"""
async with self.semaphore: # 获取信号量,限制并发数
await self._synthesize_single(text, audio_callback)
async def _synthesize_single(self, text, audio_callback):
"""合成单个文本片段"""
headers = {
"X-Api-App-Key": self.app_key,
"X-Api-Access-Key": self.access_token,
"X-Api-Resource-Id": self.get_resource_id(self.voice_type),
"X-Api-Connect-Id": str(uuid.uuid4()),
}
websocket = await websockets.connect(
self.endpoint, additional_headers=headers, max_size=10 * 1024 * 1024
)
try:
request = {
"user": {
"uid": str(uuid.uuid4()),
},
"req_params": {
"speaker": self.voice_type,
"audio_params": {
"format": self.encoding,
"sample_rate": 24000,
"enable_timestamp": True,
},
"text": text,
"additions": json.dumps({"disable_markdown_filter": False}),
},
}
# 发送请求
await full_client_request(websocket, json.dumps(request).encode())
# 接收音频数据
audio_data = bytearray()
while True:
msg = await receive_message(websocket)
if msg.type == MsgType.FullServerResponse:
if msg.event == EventType.SessionFinished:
break
elif msg.type == MsgType.AudioOnlyServer:
audio_data.extend(msg.payload)
else:
raise RuntimeError(f"TTS conversion failed: {msg}")
# 通过回调函数返回音频数据
if audio_data:
audio_callback(audio_data)
finally:
await websocket.close()
def audio_callback(audio_data):
"""
音频数据回调函数,将音频数据保存到文件
"""
# 创建输出目录
output_dir = os.path.join(os.path.dirname(os.path.abspath(__file__)), "output")
os.makedirs(output_dir, exist_ok=True)
# 生成文件名
filename = f"streaming_tts_{uuid.uuid4().hex[:8]}.wav"
filepath = os.path.join(output_dir, filename)
# 保存音频文件
with open(filepath, "wb") as f:
f.write(audio_data)
print(f"音频片段已保存到: {filepath}")
async def test_streaming_tts():
"""
测试流式TTS功能
"""
# 创建TTS实例
tts = StreamingVolcanoTTS()
# 准备测试文本流
test_texts = [
"你好,我是火山引擎的语音合成服务。",
"这是一个流式语音合成的测试。",
"我们将文本分成多个片段进行合成。",
"这样可以减少等待时间,提高用户体验。"
]
print("开始测试流式TTS...")
print(f"测试文本: {test_texts}")
# 调用流式合成
await tts.synthesize_stream(test_texts, audio_callback)
print("流式TTS测试完成")
def main():
"""
主函数,运行测试
"""
asyncio.run(test_streaming_tts())
if __name__ == "__main__":
main()