156 lines
5.0 KiB
Python
156 lines
5.0 KiB
Python
|
import asyncio
|
|||
|
import json
|
|||
|
import uuid
|
|||
|
import websockets
|
|||
|
import os
|
|||
|
from queue import Queue
|
|||
|
from Config import Config
|
|||
|
from Util.TTS_Protocols import full_client_request, receive_message, MsgType, EventType
|
|||
|
|
|||
|
|
|||
|
class StreamingVolcanoTTS:
|
|||
|
def __init__(self, voice_type='zh_female_wanwanxiaohe_moon_bigtts', encoding='wav', max_concurrency=2):
|
|||
|
self.voice_type = voice_type
|
|||
|
self.encoding = encoding
|
|||
|
self.app_key = Config.HS_APP_ID
|
|||
|
self.access_token = Config.HS_ACCESS_TOKEN
|
|||
|
self.endpoint = "wss://openspeech.bytedance.com/api/v3/tts/unidirectional/stream"
|
|||
|
self.audio_queue = Queue()
|
|||
|
self.max_concurrency = max_concurrency # 最大并发数
|
|||
|
self.semaphore = asyncio.Semaphore(max_concurrency) # 并发控制信号量
|
|||
|
|
|||
|
@staticmethod
|
|||
|
def get_resource_id(voice: str) -> str:
|
|||
|
if voice.startswith("S_"):
|
|||
|
return "volc.megatts.default"
|
|||
|
return "volc.service_type.10029"
|
|||
|
|
|||
|
async def synthesize_stream(self, text_stream, audio_callback):
|
|||
|
"""
|
|||
|
流式合成语音
|
|||
|
|
|||
|
Args:
|
|||
|
text_stream: 文本流生成器
|
|||
|
audio_callback: 音频数据回调函数,接收音频片段
|
|||
|
"""
|
|||
|
# 为每个文本片段创建一个WebSocket连接,但限制并发数
|
|||
|
tasks = []
|
|||
|
for text in text_stream:
|
|||
|
if text.strip(): # 忽略空文本
|
|||
|
task = asyncio.create_task(self._synthesize_single_with_semaphore(text, audio_callback))
|
|||
|
tasks.append(task)
|
|||
|
|
|||
|
# 等待所有任务完成
|
|||
|
await asyncio.gather(*tasks)
|
|||
|
|
|||
|
async def _synthesize_single_with_semaphore(self, text, audio_callback):
|
|||
|
"""使用信号量控制并发数的单个文本合成"""
|
|||
|
async with self.semaphore: # 获取信号量,限制并发数
|
|||
|
await self._synthesize_single(text, audio_callback)
|
|||
|
|
|||
|
async def _synthesize_single(self, text, audio_callback):
|
|||
|
"""合成单个文本片段"""
|
|||
|
headers = {
|
|||
|
"X-Api-App-Key": self.app_key,
|
|||
|
"X-Api-Access-Key": self.access_token,
|
|||
|
"X-Api-Resource-Id": self.get_resource_id(self.voice_type),
|
|||
|
"X-Api-Connect-Id": str(uuid.uuid4()),
|
|||
|
}
|
|||
|
|
|||
|
websocket = await websockets.connect(
|
|||
|
self.endpoint, additional_headers=headers, max_size=10 * 1024 * 1024
|
|||
|
)
|
|||
|
|
|||
|
try:
|
|||
|
request = {
|
|||
|
"user": {
|
|||
|
"uid": str(uuid.uuid4()),
|
|||
|
},
|
|||
|
"req_params": {
|
|||
|
"speaker": self.voice_type,
|
|||
|
"audio_params": {
|
|||
|
"format": self.encoding,
|
|||
|
"sample_rate": 24000,
|
|||
|
"enable_timestamp": True,
|
|||
|
},
|
|||
|
"text": text,
|
|||
|
"additions": json.dumps({"disable_markdown_filter": False}),
|
|||
|
},
|
|||
|
}
|
|||
|
|
|||
|
# 发送请求
|
|||
|
await full_client_request(websocket, json.dumps(request).encode())
|
|||
|
|
|||
|
# 接收音频数据
|
|||
|
audio_data = bytearray()
|
|||
|
while True:
|
|||
|
msg = await receive_message(websocket)
|
|||
|
|
|||
|
if msg.type == MsgType.FullServerResponse:
|
|||
|
if msg.event == EventType.SessionFinished:
|
|||
|
break
|
|||
|
elif msg.type == MsgType.AudioOnlyServer:
|
|||
|
audio_data.extend(msg.payload)
|
|||
|
else:
|
|||
|
raise RuntimeError(f"TTS conversion failed: {msg}")
|
|||
|
|
|||
|
# 通过回调函数返回音频数据
|
|||
|
if audio_data:
|
|||
|
audio_callback(audio_data)
|
|||
|
|
|||
|
finally:
|
|||
|
await websocket.close()
|
|||
|
|
|||
|
|
|||
|
def audio_callback(audio_data):
|
|||
|
"""
|
|||
|
音频数据回调函数,将音频数据保存到文件
|
|||
|
"""
|
|||
|
# 创建输出目录
|
|||
|
output_dir = os.path.join(os.path.dirname(os.path.abspath(__file__)), "output")
|
|||
|
os.makedirs(output_dir, exist_ok=True)
|
|||
|
|
|||
|
# 生成文件名
|
|||
|
filename = f"streaming_tts_{uuid.uuid4().hex[:8]}.wav"
|
|||
|
filepath = os.path.join(output_dir, filename)
|
|||
|
|
|||
|
# 保存音频文件
|
|||
|
with open(filepath, "wb") as f:
|
|||
|
f.write(audio_data)
|
|||
|
|
|||
|
print(f"音频片段已保存到: {filepath}")
|
|||
|
|
|||
|
|
|||
|
async def test_streaming_tts():
|
|||
|
"""
|
|||
|
测试流式TTS功能
|
|||
|
"""
|
|||
|
# 创建TTS实例
|
|||
|
tts = StreamingVolcanoTTS()
|
|||
|
|
|||
|
# 准备测试文本流
|
|||
|
test_texts = [
|
|||
|
"你好,我是火山引擎的语音合成服务。",
|
|||
|
"这是一个流式语音合成的测试。",
|
|||
|
"我们将文本分成多个片段进行合成。",
|
|||
|
"这样可以减少等待时间,提高用户体验。"
|
|||
|
]
|
|||
|
|
|||
|
print("开始测试流式TTS...")
|
|||
|
print(f"测试文本: {test_texts}")
|
|||
|
|
|||
|
# 调用流式合成
|
|||
|
await tts.synthesize_stream(test_texts, audio_callback)
|
|||
|
|
|||
|
print("流式TTS测试完成!")
|
|||
|
|
|||
|
|
|||
|
def main():
|
|||
|
"""
|
|||
|
主函数,运行测试
|
|||
|
"""
|
|||
|
asyncio.run(test_streaming_tts())
|
|||
|
|
|||
|
|
|||
|
if __name__ == "__main__":
|
|||
|
main()
|