'commit'
This commit is contained in:
@@ -4,7 +4,7 @@ import tempfile
|
||||
import uuid
|
||||
from datetime import datetime
|
||||
|
||||
from fastapi import APIRouter, Request, File, UploadFile
|
||||
from fastapi import APIRouter, Request, File, UploadFile, WebSocket, WebSocketDisconnect
|
||||
from fastapi.responses import JSONResponse
|
||||
|
||||
# Create the API router
|
||||
@@ -17,14 +17,14 @@ logger = logging.getLogger(__name__)
|
||||
from Util.XueBanUtil import get_xueban_response_async
|
||||
from Util.ASRClient import ASRClient
|
||||
from Util.ObsUtil import ObsUploader
|
||||
# 新增导入TTSService
|
||||
from Util.TTSService import TTSService
|
||||
|
||||
# 导入TTS管道
|
||||
from Util.TTS_Pipeline import stream_and_split_text, StreamingVolcanoTTS
|
||||
|
||||
# 保留原有的HTTP接口,用于向后兼容
|
||||
@router.post("/xueban/upload-audio")
|
||||
async def upload_audio(file: UploadFile = File(...)):
|
||||
"""
|
||||
上传音频文件并进行ASR处理
|
||||
上传音频文件并进行ASR处理 - 原有接口,用于向后兼容
|
||||
- 参数: file - 音频文件
|
||||
- 返回: JSON包含识别结果
|
||||
"""
|
||||
@@ -59,13 +59,32 @@ async def upload_audio(file: UploadFile = File(...)):
|
||||
feedback_text += chunk
|
||||
logger.info(f"大模型反馈生成完成: {feedback_text}")
|
||||
|
||||
# 使用TTS生成语音
|
||||
tts_service = TTSService()
|
||||
tts_temp_file = os.path.join(tempfile.gettempdir(), f"tts_{timestamp}.mp3")
|
||||
success = tts_service.synthesize(feedback_text, output_file=tts_temp_file)
|
||||
if not success:
|
||||
raise Exception("TTS语音合成失败")
|
||||
logger.info(f"TTS语音合成成功,文件保存至: {tts_temp_file}")
|
||||
# 使用流式TTS生成语音
|
||||
import io
|
||||
audio_chunks = []
|
||||
|
||||
# 定义音频回调函数,收集音频块
|
||||
def audio_callback(audio_chunk):
|
||||
audio_chunks.append(audio_chunk)
|
||||
|
||||
# 获取LLM流式输出并断句
|
||||
text_stream = stream_and_split_text(asr_result['text'])
|
||||
|
||||
# 初始化TTS处理器
|
||||
tts = StreamingVolcanoTTS(max_concurrency=2)
|
||||
|
||||
# 流式处理文本并生成音频
|
||||
await tts.synthesize_stream(text_stream, audio_callback)
|
||||
|
||||
# 合并所有音频块
|
||||
if audio_chunks:
|
||||
tts_temp_file = os.path.join(tempfile.gettempdir(), f"tts_{timestamp}.mp3")
|
||||
with open(tts_temp_file, "wb") as f:
|
||||
for chunk in audio_chunks:
|
||||
f.write(chunk)
|
||||
logger.info(f"TTS语音合成成功,文件保存至: {tts_temp_file}")
|
||||
else:
|
||||
raise Exception("TTS语音合成失败,未生成音频数据")
|
||||
|
||||
# 上传TTS音频文件到OBS
|
||||
tts_audio_url = upload_file_to_obs(tts_temp_file)
|
||||
@@ -90,7 +109,119 @@ async def upload_audio(file: UploadFile = File(...)):
|
||||
"message": f"音频处理失败: {str(e)}"
|
||||
}, status_code=500)
|
||||
|
||||
# 新增WebSocket接口,用于流式处理
|
||||
@router.websocket("/xueban/streaming-chat")
async def streaming_chat(websocket: WebSocket):
    """Handle one streaming voice-chat exchange over a WebSocket.

    Protocol, as implemented below:

    1. The client sends a single JSON object whose ``"audio_data"`` field
       holds base64-encoded audio.
    2. The audio is written to a temporary ``.wav`` file and passed to
       ``process_asr``; the recognised text is sent back as
       ``{"type": "asr_result", "text": ...}``.
    3. The recognised text is fed to ``stream_and_split_text`` and the
       resulting stream to ``StreamingVolcanoTTS.synthesize_stream``; every
       audio chunk handed to the callback is forwarded as a binary frame.
    4. ``{"type": "end"}`` marks the end of the exchange.

    Any failure is reported as ``{"type": "error", "message": ...}`` and
    ends the handler. A client disconnect is logged and otherwise ignored.

    Args:
        websocket: The WebSocket connection for this client.
    """
    import base64

    await websocket.accept()
    logger.info("WebSocket连接已接受")
    try:
        # Receive the user's audio payload.
        logger.info("等待接收音频数据...")
        data = await websocket.receive_json()
        logger.info(f"接收到数据类型: {type(data)}")
        logger.info(f"接收到数据内容: {data.keys() if isinstance(data, dict) else '非字典类型'}")

        # Valid JSON is not necessarily an object; reject anything else.
        if not isinstance(data, dict):
            logger.error(f"接收到的数据不是字典类型,而是: {type(data)}")
            await websocket.send_json({"type": "error", "message": "数据格式错误"})
            return

        audio_data = data.get("audio_data")
        logger.info(f"音频数据是否存在: {audio_data is not None}")
        logger.info(f"音频数据长度: {len(audio_data) if audio_data else 0}")

        if not audio_data:
            logger.error("未收到音频数据")
            await websocket.send_json({"type": "error", "message": "未收到音频数据"})
            return

        timestamp = datetime.now().strftime("%Y%m%d%H%M%S")
        temp_file_path = None
        try:
            # Decode the base64 payload and save it to a temporary file.
            try:
                # Decode first so a malformed payload never leaves an
                # empty file behind.
                audio_bytes = base64.b64decode(audio_data)
                # mkstemp guarantees a unique name; a timestamp alone
                # collides when two clients connect within the same second.
                fd, temp_file_path = tempfile.mkstemp(
                    prefix=f"temp_audio_{timestamp}_", suffix=".wav"
                )
                logger.info(f"保存临时音频文件到: {temp_file_path}")
                with os.fdopen(fd, "wb") as f:
                    f.write(audio_bytes)
                logger.info("音频文件保存完成")
            except Exception as e:
                logger.error(f"音频文件保存失败: {str(e)}")
                await websocket.send_json({"type": "error", "message": f"音频文件保存失败: {str(e)}"})
                return

            # Run speech recognition on the saved audio.
            logger.info("开始ASR处理...")
            try:
                asr_result = await process_asr(temp_file_path)
                logger.info(f"ASR处理完成,结果: {asr_result['text']}")
            except Exception as e:
                logger.error(f"ASR处理失败: {str(e)}")
                await websocket.send_json({"type": "error", "message": f"ASR处理失败: {str(e)}"})
                return
        finally:
            # Remove the temporary file on every path, including when
            # sending the error message itself fails.
            if temp_file_path is not None and os.path.exists(temp_file_path):
                try:
                    os.remove(temp_file_path)
                except OSError as e:
                    logger.warning(f"临时文件删除失败: {str(e)}")

        # Send the ASR result to the front end.
        logger.info("发送ASR结果给前端")
        try:
            await websocket.send_json({
                "type": "asr_result",
                "text": asr_result['text']
            })
            logger.info("ASR结果发送成功")
        except WebSocketDisconnect:
            # Let the outer handler log the disconnect.
            raise
        except Exception as e:
            logger.error(f"发送ASR结果失败: {str(e)}")
            return

        # Callback that forwards each synthesized audio chunk to the client.
        async def audio_callback(audio_chunk):
            logger.info(f"发送音频块,大小: {len(audio_chunk)}")
            try:
                await websocket.send_bytes(audio_chunk)
                logger.info("音频块发送成功")
            except Exception as e:
                logger.error(f"发送音频块失败: {str(e)}")
                # Re-raise so the failure propagates out of the callback.
                raise

        # Get the LLM's streamed output, split into sentences, and
        # synthesize it to audio as it arrives.
        logger.info("开始LLM处理和TTS合成...")
        try:
            text_stream = stream_and_split_text(asr_result['text'])

            # Initialise the TTS processor.
            tts = StreamingVolcanoTTS(max_concurrency=2)

            # Stream the text through TTS; audio goes out via the callback.
            await tts.synthesize_stream(text_stream, audio_callback)
            logger.info("TTS合成完成")
        except WebSocketDisconnect:
            # The client is gone; do not try to send an error frame to it.
            raise
        except Exception as e:
            logger.error(f"TTS合成失败: {str(e)}")
            await websocket.send_json({"type": "error", "message": f"TTS合成失败: {str(e)}"})
            return

        # Tell the client that no more audio will follow.
        logger.info("发送结束信号")
        try:
            await websocket.send_json({"type": "end"})
            logger.info("结束信号发送成功")
        except WebSocketDisconnect:
            raise
        except Exception as e:
            logger.error(f"发送结束信号失败: {str(e)}")
            return

    except WebSocketDisconnect:
        logger.info("客户端断开连接")
    except Exception as e:
        logger.error(f"WebSocket处理失败: {str(e)}")
        try:
            await websocket.send_json({"type": "error", "message": str(e)})
        except Exception:
            # Best effort only: the connection may already be unusable.
            logger.error("发送错误消息失败")
|
||||
|
||||
# 原有的辅助函数保持不变
|
||||
async def process_asr(audio_path: str) -> dict:
|
||||
"""
|
||||
调用ASR服务处理音频文件
|
||||
|
Reference in New Issue
Block a user