dsProject/dsLightRag/Routes/XueBanRoute.py
import base64
import logging
import os
import tempfile
import uuid
from datetime import datetime

from fastapi import APIRouter, Request, File, UploadFile, WebSocket, WebSocketDisconnect
from fastapi.responses import JSONResponse

from Util.XueBanUtil import get_xueban_response_async
from Util.ASRClient import ASRClient
from Util.ObsUtil import ObsUploader
from Util.TTS_Pipeline import stream_and_split_text, StreamingVolcanoTTS

# Create the router
router = APIRouter(prefix="/api", tags=["XueBan"])

# Configure logging
logger = logging.getLogger(__name__)
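
# A minimal sketch of how this router is expected to be mounted on the main
# FastAPI app; the application module name and app instance below are
# assumptions for illustration, not part of this file:
#
#     from fastapi import FastAPI
#     from Routes.XueBanRoute import router as xueban_router
#
#     app = FastAPI()
#     app.include_router(xueban_router)  # endpoints become /api/xueban/...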

# Original HTTP endpoint, kept for backward compatibility
@router.post("/xueban/upload-audio")
async def upload_audio(file: UploadFile = File(...)):
    """
    Upload an audio file and run ASR on it - legacy endpoint kept for backward compatibility
    - Parameter: file - the audio file
    - Returns: JSON containing the recognition result
    """
    try:
        # Log the incoming request
        logger.info(f"Received audio file: {file.filename}")

        # Save the upload to a temporary file
        timestamp = datetime.now().strftime("%Y%m%d%H%M%S")
        file_ext = os.path.splitext(file.filename)[1]
        temp_file_name = f"temp_audio_{timestamp}{file_ext}"
        temp_file_path = os.path.join(tempfile.gettempdir(), temp_file_name)
        with open(temp_file_path, "wb") as f:
            content = await file.read()
            f.write(content)
        logger.info(f"Audio file saved to temporary directory: {temp_file_path}")

        # Run ASR on the uploaded audio
        asr_result = await process_asr(temp_file_path)

        # Remove the temporary file
        os.remove(temp_file_path)
        logger.info(f"Temporary file deleted: {temp_file_path}")

        # Generate feedback with the LLM
        logger.info(f"Generating LLM feedback for input text: {asr_result['text']}")
        response_generator = get_xueban_response_async(asr_result['text'], stream=False)
        feedback_text = ""
        async for chunk in response_generator:
            feedback_text += chunk
        logger.info(f"LLM feedback generated: {feedback_text}")

        # Synthesize speech with the streaming TTS pipeline
        audio_chunks = []

        # Audio callback that collects the synthesized chunks
        def audio_callback(audio_chunk):
            audio_chunks.append(audio_chunk)

        # Split the LLM feedback into sentences for TTS
        # (the feedback text, not the ASR transcript, is what should be spoken back)
        text_stream = stream_and_split_text(feedback_text)

        # Initialize the TTS processor
        tts = StreamingVolcanoTTS(max_concurrency=2)

        # Stream the text through TTS and collect the audio
        await tts.synthesize_stream(text_stream, audio_callback)

        # Merge all audio chunks into a single file
        if audio_chunks:
            tts_temp_file = os.path.join(tempfile.gettempdir(), f"tts_{timestamp}.mp3")
            with open(tts_temp_file, "wb") as f:
                for chunk in audio_chunks:
                    f.write(chunk)
            logger.info(f"TTS synthesis succeeded, file saved to: {tts_temp_file}")
        else:
            raise Exception("TTS synthesis failed: no audio data was generated")

        # Upload the TTS audio file to OBS
        tts_audio_url = upload_file_to_obs(tts_temp_file)
        os.remove(tts_temp_file)  # Remove the temporary TTS file
        logger.info(f"TTS file uploaded to OBS: {tts_audio_url}")

        # Return the ASR text, the feedback text and the TTS audio URL
        return JSONResponse(content={
            "success": True,
            "message": "Audio processed and spoken feedback generated successfully",
            "data": {
                "asr_text": asr_result['text'],
                "feedback_text": feedback_text,
                "audio_url": tts_audio_url
            }
        })
    except Exception as e:
        logger.error(f"Audio processing failed: {str(e)}")
        return JSONResponse(content={
            "success": False,
            "message": f"Audio processing failed: {str(e)}"
        }, status_code=500)
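
# A hedged client-side sketch for the upload-audio endpoint above; the
# host/port and the "requests" dependency are assumptions for illustration:
#
#     import requests
#
#     with open("question.wav", "rb") as f:
#         resp = requests.post(
#             "http://localhost:8000/api/xueban/upload-audio",
#             files={"file": ("question.wav", f, "audio/wav")},
#         )
#     data = resp.json()["data"]
#     print(data["asr_text"], data["feedback_text"], data["audio_url"])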

# New WebSocket endpoint for streaming processing
@router.websocket("/xueban/streaming-chat")
async def streaming_chat(websocket: WebSocket):
    await websocket.accept()
    logger.info("WebSocket connection accepted")
    try:
        # Receive the user's audio payload
        logger.info("Waiting for audio data...")
        data = await websocket.receive_json()
        logger.info(f"Received data type: {type(data)}")
        logger.info(f"Received data keys: {data.keys() if isinstance(data, dict) else 'not a dict'}")

        # Validate the payload format
        if not isinstance(data, dict):
            logger.error(f"Received data is not a dict but: {type(data)}")
            await websocket.send_json({"type": "error", "message": "Invalid data format"})
            return

        audio_data = data.get("audio_data")
        logger.info(f"Audio data present: {audio_data is not None}")
        logger.info(f"Audio data length: {len(audio_data) if audio_data else 0}")
        if not audio_data:
            logger.error("No audio data received")
            await websocket.send_json({"type": "error", "message": "No audio data received"})
            return

        # Save the audio to a temporary file
        timestamp = datetime.now().strftime("%Y%m%d%H%M%S")
        temp_file_path = os.path.join(tempfile.gettempdir(), f"temp_audio_{timestamp}.wav")
        logger.info(f"Saving temporary audio file to: {temp_file_path}")

        # Decode the base64 audio data and write it to disk
        try:
            with open(temp_file_path, "wb") as f:
                f.write(base64.b64decode(audio_data))
            logger.info("Audio file saved")
        except Exception as e:
            logger.error(f"Failed to save audio file: {str(e)}")
            await websocket.send_json({"type": "error", "message": f"Failed to save audio file: {str(e)}"})
            return

        # Run ASR
        logger.info("Starting ASR processing...")
        try:
            asr_result = await process_asr(temp_file_path)
            logger.info(f"ASR finished, result: {asr_result['text']}")
            os.remove(temp_file_path)  # Remove the temporary file
        except Exception as e:
            logger.error(f"ASR processing failed: {str(e)}")
            await websocket.send_json({"type": "error", "message": f"ASR processing failed: {str(e)}"})
            if os.path.exists(temp_file_path):
                os.remove(temp_file_path)  # Make sure the temporary file is removed
            return

        # Send the ASR result to the frontend
        logger.info("Sending ASR result to the frontend")
        try:
            await websocket.send_json({
                "type": "asr_result",
                "text": asr_result['text']
            })
            logger.info("ASR result sent")
        except Exception as e:
            logger.error(f"Failed to send ASR result: {str(e)}")
            return

        # Audio callback that forwards each synthesized chunk to the frontend
        async def audio_callback(audio_chunk):
            logger.info(f"Sending audio chunk, size: {len(audio_chunk)}")
            try:
                await websocket.send_bytes(audio_chunk)
                logger.info("Audio chunk sent")
            except Exception as e:
                logger.error(f"Failed to send audio chunk: {str(e)}")
                raise

        # Fetch the full XueBan response (including exercise information) for logging.
        # Note: this issues a separate LLM request from the streaming call below.
        logger.info("Fetching XueBan response content...")
        llm_chunks = []
        async for chunk in get_xueban_response_async(asr_result['text'], stream=True):
            llm_chunks.append(chunk)
        full_llm_response = ''.join(llm_chunks)
        logger.info(f"XueBan response content: {full_llm_response}")

        # Fetch the LLM output as a stream and process it in real time
        logger.info("Starting streaming LLM processing and TTS synthesis...")
        try:
            # Feed the LLM streaming response directly into TTS
            llm_stream = get_xueban_response_async(asr_result['text'], stream=True)
            text_stream = stream_and_split_text(llm_stream)  # async generator
            # Initialize the TTS processor
            tts = StreamingVolcanoTTS(max_concurrency=1)
            # Iterate over the text stream and synthesize each chunk
            async for text_chunk in text_stream:
                await tts._synthesize_single_with_semaphore(text_chunk, audio_callback)
            logger.info("TTS synthesis finished")
        except Exception as e:
            logger.error(f"TTS synthesis failed: {str(e)}")
            await websocket.send_json({"type": "error", "message": f"TTS synthesis failed: {str(e)}"})
            return

        # Send the end-of-stream signal
        logger.info("Sending end signal")
        try:
            await websocket.send_json({"type": "end"})
            logger.info("End signal sent")
        except Exception as e:
            logger.error(f"Failed to send end signal: {str(e)}")
            return
    except WebSocketDisconnect:
        logger.info("Client disconnected")
    except Exception as e:
        logger.error(f"WebSocket handling failed: {str(e)}")
        try:
            await websocket.send_json({"type": "error", "message": str(e)})
        except Exception:
            logger.error("Failed to send error message")

# The original helper functions are unchanged
async def process_asr(audio_path: str) -> dict:
    """
    Run the ASR service on an audio file
    :param audio_path: path to the audio file
    :return: recognition result as a dict
    """
    try:
        # Upload the file to Huawei Cloud OBS
        audio_url = upload_file_to_obs(audio_path)
        # Create an ASR client instance
        asr_client = ASRClient()
        # Set the audio file URL
        asr_client.file_url = audio_url
        # Run the ASR task and get the text result
        text_result = asr_client.process_task()
        # Build the return value
        return {
            "text": text_result,
            "confidence": 1.0,  # In production this should come from the ASR result
            "audio_duration": os.path.getsize(audio_path) / 1024 / 16  # Rough duration estimate
        }
    except Exception as e:
        logger.error(f"ASR processing failed: {str(e)}")
        raise

def upload_file_to_obs(file_path: str) -> str:
    """
    Upload a local file to Huawei Cloud OBS and return its URL
    :param file_path: local file path
    :return: URL of the file on OBS
    """
    try:
        # Create an OBS uploader instance
        obs_uploader = ObsUploader()
        # Generate a UUID-based file name
        file_uuid = str(uuid.uuid4())
        file_ext = os.path.splitext(file_path)[1].lower()
        # Force the file extension to .wav
        if file_ext != '.wav':
            file_ext = '.wav'
        # Build the object key (prefix + UUID + extension)
        object_key = f"HuangHai/XueBan/{file_uuid}{file_ext}"
        # Upload the file
        success, result = obs_uploader.upload_file(object_key, file_path)
        if success:
            logger.info(f"File uploaded successfully: {file_path} -> {object_key}")
            # Build the file URL (assumes the OBS bucket is publicly readable);
            # the URL format may need adjusting for your OBS configuration
            from Config.Config import OBS_SERVER, OBS_BUCKET
            file_url = f"https://{OBS_BUCKET}.{OBS_SERVER}/{object_key}"
            return file_url
        else:
            error_msg = f"File upload failed: {result}"
            logger.error(error_msg)
            raise Exception(error_msg)
    except Exception as e:
        logger.error(f"Failed to upload file to OBS: {str(e)}")
        raise
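
# For reference, upload_file_to_obs() returns URLs of the form below;
# the bucket and server values are illustrative placeholders:
#
#     https://<OBS_BUCKET>.<OBS_SERVER>/HuangHai/XueBan/<uuid>.wav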
@router.post("/xueban/chat")
async def chat_with_xueban(request: Request):
"""
与学伴大模型聊天的接口
- 参数: request body 中的 query_text (用户查询文本)
- 返回: JSON包含聊天响应
"""
try:
# 获取请求体数据
data = await request.json()
query_text = data.get("query_text", "")
if not query_text.strip():
return JSONResponse(content={
"success": False,
"message": "查询文本不能为空"
}, status_code=400)
# 记录日志
logger.info(f"接收到学伴聊天请求: {query_text}")
# 调用异步接口获取学伴响应
response_content = []
async for chunk in get_xueban_response_async(query_text, stream=True):
response_content.append(chunk)
full_response = "".join(response_content)
# 返回响应
return JSONResponse(content={
"success": True,
"message": "聊天成功",
"data": {
"response": full_response
}
})
except Exception as e:
logger.error(f"学伴聊天失败: {str(e)}")
return JSONResponse(content={
"success": False,
"message": f"聊天处理失败: {str(e)}"
}, status_code=500)
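
# A hedged client-side sketch for the chat endpoint above; the host/port and
# the "requests" dependency are assumptions for illustration only:
#
#     import requests
#
#     resp = requests.post(
#         "http://localhost:8000/api/xueban/chat",
#         json={"query_text": "Help me review fractions"},
#     )
#     print(resp.json()["data"]["response"])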