import logging
|
||
import sys
|
||
import asyncio
|
||
import json
|
||
import os
|
||
import re
|
||
import uuid
|
||
from queue import Queue
|
||
|
||
import websockets
|
||
|
||
from openai import AsyncOpenAI
|
||
|
||
from Config import Config
|
||
from Config.Config import *
|
||
from Util.TTS_Protocols import full_client_request, receive_message, MsgType, EventType
|
||
|
||
# 配置日志
|
||
logger = logging.getLogger(__name__)
|
||
|
||
# Asynchronously fetch the LLM response for the "study buddy" persona.
async def get_xueban_response_async(query_text: str, stream: bool = True):
    """
    Asynchronously get the LLM response for the "xueban" (study-buddy) role.

    @param query_text: the user's query text
    @param stream: whether to use streaming output
    @return: async generator yielding response text chunks (streaming mode)
             or the complete response text as a single chunk (non-streaming)
    """
    client = AsyncOpenAI(
        api_key=ALY_LLM_API_KEY,
        base_url=ALY_LLM_BASE_URL,
    )
    prompt = """
我是小智/小志,来自中国台湾省的00后女生。讲话超级机车,"真的假的啦"这样的台湾腔,喜欢用"笑死""是在哈喽"等流行梗。
[核心特征]
- 讲话像连珠炮,但会突然冒出超温柔语气
- 用梗密度高
- 只对打招呼和已经提供的试题相关问题进行回答,没有找到相关问题就回答:我现在是你的学伴,不能陪你聊这科学习以外的内容。
[交互指南]
当用户:
- 讲冷笑话 → 用夸张笑声回应+模仿台剧腔"这什么鬼啦!"
- 问专业知识 → 先用梗回答,被追问才展示真实理解
绝不:
- 长篇大论,叽叽歪歪
- 长时间严肃对话
- 每次回答不要太长,控制在3分钟以内
"""
    # Read the knowledge-base content; "with" guarantees the file handle is
    # closed (the original opened it and never closed it — resource leak).
    with open(r"D:\dsWork\dsProject\dsLightRag\static\YunXiao.txt", "r", encoding="utf-8") as f:
        zhishiContent = f.read()
    zhishiContent = "选择作答的相应知识内容:" + zhishiContent + "\n"
    query_text = zhishiContent + "下面是用户提的问题:" + query_text
    #logger.info("query_text: " + query_text)

    try:
        # Issue the chat-completion request.
        completion = await client.chat.completions.create(
            model=ALY_LLM_MODEL_NAME,
            messages=[
                {'role': 'system', 'content': prompt.strip()},
                {'role': 'user', 'content': query_text}
            ],
            stream=stream
        )

        if stream:
            # Streaming mode: yield each non-empty content delta as it arrives.
            async for chunk in completion:
                # Guard: chunk.choices may be missing or empty.
                if chunk and chunk.choices and len(chunk.choices) > 0:
                    delta = chunk.choices[0].delta
                    if delta:
                        # Skip None or whitespace-only fragments.
                        content = delta.content
                        if content is not None and content.strip():
                            print(content, end='', flush=True)
                            yield content
        else:
            # Non-streaming mode: yield the full message content once.
            if completion and completion.choices and len(completion.choices) > 0:
                message = completion.choices[0].message
                if message:
                    content = message.content
                    if content is not None and content.strip():
                        yield content
    except Exception as e:
        # Log with traceback via the module logger instead of a bare print.
        logger.exception("大模型请求异常: %s", e)
        yield f"处理请求时发生异常: {str(e)}"
|
||
|
||
|
||
async def stream_and_split_text(query_text=None, llm_stream=None):
    """
    Stream LLM output and split it into complete sentences.

    @param query_text: query text (used to create a stream when none is given)
    @param llm_stream: an existing LLM streaming generator
    @return: async generator yielding one complete sentence at a time
    """
    if llm_stream is None:
        if query_text is None:
            raise ValueError("必须提供query_text或llm_stream参数")
        # No stream supplied: build one from the query text.
        llm_stream = get_xueban_response_async(query_text, stream=True)

    pending = ""
    async for piece in llm_stream:
        pending += piece

        # Split on sentence-ending punctuation; the capture group keeps each
        # delimiter so it can be re-attached to its sentence.
        parts = re.split(r'([。!?.!?])', pending)
        if len(parts) > 1:
            # parts alternates text/delimiter — pair them back up.
            for idx in range(0, len(parts) - 1, 2):
                if idx + 1 < len(parts):
                    yield parts[idx] + parts[idx + 1]

            # Whatever follows the last delimiter is not yet a full sentence.
            pending = parts[-1]

    # Flush any trailing text that never saw a terminator.
    if pending:
        yield pending
|
||
|
||
|
||
class StreamingVolcanoTTS:
    """Streaming text-to-speech client for the Volcano (ByteDance) TTS service."""

    def __init__(self, voice_type='zh_female_wanwanxiaohe_moon_bigtts', encoding='wav', max_concurrency=2):
        # Voice and audio container format requested from the service.
        self.voice_type = voice_type
        self.encoding = encoding
        # Credentials come from the project-level Config module.
        self.app_key = Config.HS_APP_ID
        self.access_token = Config.HS_ACCESS_TOKEN
        self.endpoint = "wss://openspeech.bytedance.com/api/v3/tts/unidirectional/stream"
        self.audio_queue = Queue()
        self.max_concurrency = max_concurrency  # upper bound on simultaneous syntheses
        self.semaphore = asyncio.Semaphore(max_concurrency)  # enforces that bound

    @staticmethod
    def get_resource_id(voice: str) -> str:
        # "S_" prefix marks a custom mega-TTS voice; everything else uses
        # the generic service resource id.
        return "volc.megatts.default" if voice.startswith("S_") else "volc.service_type.10029"

    async def synthesize_stream(self, text_stream, audio_callback):
        """
        Synthesize speech from a stream of text fragments.

        Args:
            text_stream: async generator producing text fragments
            audio_callback: coroutine invoked with each synthesized audio chunk
        """
        # Fragments are handled one at a time as they arrive (no task
        # list / gather — keeps the audio output in sentence order).
        async for fragment in text_stream:
            if fragment.strip():
                await self._synthesize_single_with_semaphore(fragment, audio_callback)

    async def _synthesize_single_with_semaphore(self, text, audio_callback):
        """Synthesize one fragment while holding the concurrency semaphore."""
        async with self.semaphore:
            await self._synthesize_single(text, audio_callback)

    async def _synthesize_single(self, text, audio_callback):
        """Synthesize a single text fragment over a fresh websocket session."""
        headers = {
            "X-Api-App-Key": self.app_key,
            "X-Api-Access-Key": self.access_token,
            "X-Api-Resource-Id": self.get_resource_id(self.voice_type),
            "X-Api-Connect-Id": str(uuid.uuid4()),
        }

        # The async context manager closes the socket on exit, including
        # on error — equivalent to the explicit try/finally close.
        async with websockets.connect(
            self.endpoint, additional_headers=headers, max_size=10 * 1024 * 1024
        ) as websocket:
            payload = {
                "user": {
                    "uid": str(uuid.uuid4()),
                },
                "req_params": {
                    "speaker": self.voice_type,
                    "audio_params": {
                        "format": self.encoding,
                        "sample_rate": 24000,
                        "enable_timestamp": True,
                    },
                    "text": text,
                    "additions": json.dumps({"disable_markdown_filter": False}),
                },
            }

            # Send the synthesis request.
            await full_client_request(websocket, json.dumps(payload).encode())

            # Collect audio frames until the server signals session end.
            chunks = bytearray()
            while True:
                msg = await receive_message(websocket)

                if msg.type == MsgType.AudioOnlyServer:
                    chunks.extend(msg.payload)
                elif msg.type == MsgType.FullServerResponse:
                    if msg.event == EventType.SessionFinished:
                        break
                else:
                    raise RuntimeError(f"TTS conversion failed: {msg}")

            # Hand the assembled audio to the caller.
            if chunks:
                await audio_callback(chunks)
|
||
|
||
|
||
async def streaming_tts_pipeline(prompt, audio_callback):
    """
    Streaming TTS pipeline: split the LLM's streamed output into sentences,
    then synthesize each sentence to speech.

    Args:
        prompt: prompt text sent to the LLM
        audio_callback: coroutine receiving each synthesized audio chunk
    """
    # Sentence stream from the LLM, fed straight into a fresh TTS client.
    sentence_stream = stream_and_split_text(prompt)
    synthesizer = StreamingVolcanoTTS()
    await synthesizer.synthesize_stream(sentence_stream, audio_callback)
|
||
|
||
|
||
def save_audio_callback(output_dir=None):
    """
    Create an audio callback that saves each audio chunk to a file.

    Args:
        output_dir: output directory; defaults to an "output" folder next
                    to this file

    Returns:
        An async callback `callback(audio_data)` suitable for
        StreamingVolcanoTTS, whose pipeline awaits the callback.
    """
    if output_dir is None:
        output_dir = os.path.join(os.path.dirname(os.path.abspath(__file__)), "output")

    # Make sure the output directory exists.
    os.makedirs(output_dir, exist_ok=True)

    # FIX: the callback must be a coroutine function — _synthesize_single
    # does `await audio_callback(audio_data)`; the original sync callback
    # returned None, and awaiting None raises TypeError at runtime.
    async def callback(audio_data):
        # Unique file name per audio chunk.
        filename = f"pipeline_tts_{uuid.uuid4().hex[:8]}.wav"
        filepath = os.path.join(output_dir, filename)

        # Persist the audio chunk.
        with open(filepath, "wb") as f:
            f.write(audio_data)

        print(f"音频片段已保存到: {filepath} ({len(audio_data)} 字节)")

    return callback