Files
dsProject/dsLightRag/Routes/XueBanRoute.py

200 lines
7.8 KiB
Python
Raw Normal View History

2025-08-22 08:32:39 +08:00
import logging
2025-08-22 08:35:57 +08:00
import os
import tempfile
2025-08-22 10:01:04 +08:00
import uuid
2025-08-22 08:35:57 +08:00
from datetime import datetime
2025-08-22 08:32:39 +08:00
2025-08-31 12:50:37 +08:00
from fastapi import APIRouter, WebSocket, WebSocketDisconnect
2025-08-31 12:43:43 +08:00
from Util.ASRClient import ASRClient
from Util.ObsUtil import ObsUploader
2025-08-31 13:05:50 +08:00
from Util.XueBanUtil import get_xueban_response_async, stream_and_split_text, StreamingVolcanoTTS
2025-08-22 08:32:39 +08:00
# Module-level logger for this route file.
logger = logging.getLogger(__name__)
# Router for the "学伴" (study companion) endpoints; every route registered on it is mounted under /api.
router = APIRouter(tags=["学伴"], prefix="/api")
2025-08-31 10:22:31 +08:00
# WebSocket endpoint for streaming chat: audio in -> ASR -> LLM -> sentence-split TTS audio out.
@router.websocket("/xueban/streaming-chat")
async def streaming_chat(websocket: WebSocket):
    """Handle one streaming voice-chat exchange over a WebSocket.

    Protocol as implemented below:
      1. The client sends one JSON object whose ``audio_data`` field holds
         base64-encoded audio; it is written to a temporary ``.wav`` file.
      2. The server replies ``{"type": "asr_result", "text": ...}``.
      3. The LLM reply is split into sentence chunks; each chunk is synthesized
         and the audio is pushed to the client as binary frames via
         ``audio_callback``.
      4. The server sends ``{"type": "end"}``.

    On a failure the handler sends ``{"type": "error", "message": ...}``
    (best effort) and returns. The temporary audio file is always removed once
    ASR has finished or failed.
    """
    import base64

    await websocket.accept()
    logger.info("WebSocket连接已接受")
    try:
        # Receive the user's audio payload.
        logger.info("等待接收音频数据...")
        data = await websocket.receive_json()
        logger.info(f"接收到数据类型: {type(data)}")
        logger.info(f"接收到数据内容: {data.keys() if isinstance(data, dict) else '非字典类型'}")
        if not isinstance(data, dict):
            logger.error(f"接收到的数据不是字典类型,而是: {type(data)}")
            await websocket.send_json({"type": "error", "message": "数据格式错误"})
            return

        audio_data = data.get("audio_data")
        logger.info(f"音频数据是否存在: {audio_data is not None}")
        logger.info(f"音频数据长度: {len(audio_data) if audio_data else 0}")
        if not audio_data:
            logger.error("未收到音频数据")
            await websocket.send_json({"type": "error", "message": "未收到音频数据"})
            return

        # Save the decoded audio to a temp file. The uuid suffix prevents two
        # connections arriving within the same second from sharing (and
        # overwriting) one file; the timestamp alone was not unique.
        timestamp = datetime.now().strftime("%Y%m%d%H%M%S")
        temp_file_path = os.path.join(
            tempfile.gettempdir(), f"temp_audio_{timestamp}_{uuid.uuid4().hex}.wav"
        )
        logger.info(f"保存临时音频文件到: {temp_file_path}")
        try:
            try:
                with open(temp_file_path, "wb") as f:
                    f.write(base64.b64decode(audio_data))
                logger.info("音频文件保存完成")
            except Exception as e:
                logger.error(f"音频文件保存失败: {str(e)}")
                await websocket.send_json({"type": "error", "message": f"音频文件保存失败: {str(e)}"})
                return

            # Speech recognition.
            logger.info("开始ASR处理...")
            try:
                asr_result = await process_asr(temp_file_path)
                logger.info(f"ASR处理完成结果: {asr_result['text']}")
            except Exception as e:
                logger.error(f"ASR处理失败: {str(e)}")
                await websocket.send_json({"type": "error", "message": f"ASR处理失败: {str(e)}"})
                return
        finally:
            # Runs on every exit path. This includes a failed decode/write,
            # where open() has already created the file.
            _remove_temp_file(temp_file_path)

        # Send the ASR transcript to the client.
        logger.info("发送ASR结果给前端")
        try:
            await websocket.send_json({
                "type": "asr_result",
                "text": asr_result['text']
            })
            logger.info("ASR结果发送成功")
        except Exception as e:
            logger.error(f"发送ASR结果失败: {str(e)}")
            return

        # Callback that forwards each synthesized audio chunk to the client as a binary frame.
        async def audio_callback(audio_chunk):
            logger.info(f"发送音频块,大小: {len(audio_chunk)}")
            try:
                await websocket.send_bytes(audio_chunk)
                logger.info("音频块发送成功")
            except Exception as e:
                logger.error(f"发送音频块失败: {str(e)}")
                raise

        # Stream the LLM reply, split it into sentences and synthesize each one.
        logger.info("开始LLM流式处理和TTS合成...")
        try:
            llm_stream = get_xueban_response_async(asr_result['text'], stream=True)
            text_stream = stream_and_split_text(llm_stream=llm_stream)
            tts = StreamingVolcanoTTS(max_concurrency=1)
            async for text_chunk in text_stream:
                logger.info(f"正在处理句子: {text_chunk}")
                # NOTE(review): this calls a private StreamingVolcanoTTS method;
                # switch to a public API if the class exposes one.
                await tts._synthesize_single_with_semaphore(text_chunk, audio_callback)
            logger.info("TTS合成完成")
        except WebSocketDisconnect:
            # The client went away mid-stream. Do not try to send an error
            # frame over a closed socket; the outer handler logs the disconnect.
            raise
        except Exception as e:
            logger.error(f"TTS合成失败: {str(e)}")
            await websocket.send_json({"type": "error", "message": f"TTS合成失败: {str(e)}"})
            return

        # Tell the client the exchange is complete.
        logger.info("发送结束信号")
        try:
            await websocket.send_json({"type": "end"})
            logger.info("结束信号发送成功")
        except Exception as e:
            logger.error(f"发送结束信号失败: {str(e)}")
            return
    except WebSocketDisconnect:
        logger.info("客户端断开连接")
    except Exception as e:
        logger.error(f"WebSocket处理失败: {str(e)}")
        try:
            await websocket.send_json({"type": "error", "message": str(e)})
        except Exception:
            # Best effort only: the socket may already be closed.
            logger.error("发送错误消息失败")


def _remove_temp_file(path: str) -> None:
    """Delete a temporary file if it exists; a failed delete is logged, not raised."""
    try:
        if os.path.exists(path):
            os.remove(path)
    except OSError as e:
        logger.warning(f"删除临时文件失败: {path}: {str(e)}")
2025-08-22 10:01:04 +08:00
2025-08-31 10:22:31 +08:00
# 原有的辅助函数保持不变
2025-08-22 08:35:57 +08:00
async def process_asr(audio_path: str) -> dict:
    """
    Run speech recognition on a local audio file.

    The OBS upload and ``ASRClient.process_task()`` are synchronous calls.
    Running them directly inside this coroutine would block the event loop,
    and with it every other open connection, for the whole upload plus ASR
    round trip. They therefore run on the loop's default thread-pool executor.

    :param audio_path: path to the local audio file
    :return: dict with "text" (the ASR transcript), "confidence" (fixed
             placeholder 1.0) and "audio_duration" (rough size-based estimate)
    :raises Exception: re-raises whatever the upload or ASR client raised,
                       after logging it
    """
    import asyncio

    try:
        loop = asyncio.get_running_loop()
        return await loop.run_in_executor(None, _process_asr_sync, audio_path)
    except Exception as e:
        logger.error(f"ASR处理失败: {str(e)}")
        raise


def _process_asr_sync(audio_path: str) -> dict:
    """Blocking part of process_asr: upload to OBS, run the ASR task, build the result dict."""
    # Upload to Huawei Cloud OBS so the ASR service can fetch the audio by URL.
    audio_url = upload_file_to_obs(audio_path)
    # A fresh client per call, so no instance is shared between worker threads.
    asr_client = ASRClient()
    asr_client.file_url = audio_url
    text_result = asr_client.process_task()
    return {
        "text": text_result,
        "confidence": 1.0,  # placeholder: a real confidence should come from the ASR result
        # Size in KiB divided by 16. This presumably assumes ~16 KiB of audio
        # per second (e.g. 8 kHz 16-bit mono); TODO confirm the expected format.
        "audio_duration": os.path.getsize(audio_path) / 1024 / 16,
    }
def upload_file_to_obs(file_path: str) -> str:
    """
    Upload a local file to Huawei Cloud OBS and return its URL.

    The object key is ``HuangHai/XueBan/<uuid4>.wav``. The extension is always
    ``.wav``, whatever suffix the local file has.

    :param file_path: local file path
    :return: URL of the uploaded object, built as https://<bucket>.<server>/<key>
             (assumes the bucket is publicly readable)
    :raises Exception: if the uploader reports failure, or if creating the
                       uploader / performing the upload raises
    """
    # The stored extension is fixed to .wav, so the key needs only a random UUID.
    object_key = f"HuangHai/XueBan/{uuid.uuid4()}.wav"

    try:
        obs_uploader = ObsUploader()
        success, result = obs_uploader.upload_file(object_key, file_path)
    except Exception as e:
        logger.error(f"上传文件到OBS失败: {str(e)}")
        raise

    # This check is outside the try above, so the failure is logged exactly once.
    if not success:
        error_msg = f"文件上传失败: {result}"
        logger.error(error_msg)
        raise Exception(error_msg)

    logger.info(f"文件上传成功: {file_path} -> {object_key}")
    # The URL format may need adjusting to the actual OBS bucket/domain configuration.
    from Config.Config import OBS_SERVER, OBS_BUCKET
    return f"https://{OBS_BUCKET}.{OBS_SERVER}/{object_key}"