2025-08-22 08:32:39 +08:00
|
|
|
|
import logging
|
2025-08-22 08:35:57 +08:00
|
|
|
|
import os
|
|
|
|
|
import tempfile
|
2025-08-22 10:01:04 +08:00
|
|
|
|
import uuid
|
2025-08-22 08:35:57 +08:00
|
|
|
|
from datetime import datetime
|
2025-08-22 08:32:39 +08:00
|
|
|
|
|
2025-08-31 10:22:31 +08:00
|
|
|
|
from fastapi import APIRouter, Request, File, UploadFile, WebSocket, WebSocketDisconnect
|
2025-08-22 08:35:57 +08:00
|
|
|
|
from fastapi.responses import JSONResponse
|
2025-08-22 08:32:39 +08:00
|
|
|
|
|
|
|
|
|
# Create the API router
|
|
|
|
|
router = APIRouter(prefix="/api", tags=["学伴"])
|
|
|
|
|
|
|
|
|
|
# 配置日志
|
|
|
|
|
logger = logging.getLogger(__name__)
|
|
|
|
|
|
2025-08-22 10:01:04 +08:00
|
|
|
|
# 导入学伴工具函数、ASR客户端和OBS上传工具
|
2025-08-22 09:36:29 +08:00
|
|
|
|
from Util.XueBanUtil import get_xueban_response_async
|
2025-08-22 10:01:04 +08:00
|
|
|
|
from Util.ASRClient import ASRClient
|
|
|
|
|
from Util.ObsUtil import ObsUploader
|
2025-08-31 10:22:31 +08:00
|
|
|
|
# 导入TTS管道
|
|
|
|
|
from Util.TTS_Pipeline import stream_and_split_text, StreamingVolcanoTTS
|
2025-08-22 08:35:57 +08:00
|
|
|
|
|
2025-08-31 10:22:31 +08:00
|
|
|
|
# 保留原有的HTTP接口,用于向后兼容
|
2025-08-22 08:35:57 +08:00
|
|
|
|
@router.post("/xueban/upload-audio")
async def upload_audio(file: UploadFile = File(...)):
    """
    Run an uploaded audio file through ASR, LLM feedback and TTS.

    Legacy HTTP endpoint, kept for backward compatibility.

    - Parameter: file - the uploaded audio file
    - Returns: a JSONResponse. On success ``data`` carries ``asr_text``,
      ``feedback_text`` and ``audio_url``. Any exception raised while
      processing is reported as HTTP 500 with ``success: False``.
    """
    try:
        logger.info(f"接收到音频文件: {file.filename}")

        # A second-resolution timestamp alone lets two requests arriving in
        # the same second overwrite each other's temp files, so a random
        # suffix is appended to make the names unique per request.
        timestamp = datetime.now().strftime("%Y%m%d%H%M%S")
        request_token = f"{timestamp}_{uuid.uuid4().hex}"

        # The filename may be missing on the upload; fall back to no extension.
        file_ext = os.path.splitext(file.filename or "")[1]
        temp_file_path = os.path.join(
            tempfile.gettempdir(), f"temp_audio_{request_token}{file_ext}"
        )

        try:
            content = await file.read()
            with open(temp_file_path, "wb") as f:
                f.write(content)
            logger.info(f"音频文件已保存至临时目录: {temp_file_path}")

            # Speech recognition on the saved file.
            asr_result = await process_asr(temp_file_path)
        finally:
            # Remove the upload even when saving or ASR fails.
            if os.path.exists(temp_file_path):
                os.remove(temp_file_path)
                logger.info(f"临时文件已删除: {temp_file_path}")

        asr_text = asr_result['text']

        # Ask the LLM for feedback on the recognised text.
        logger.info(f"使用大模型生成反馈,输入文本: {asr_text}")
        feedback_parts = [
            chunk async for chunk in get_xueban_response_async(asr_text, stream=False)
        ]
        feedback_text = "".join(feedback_parts)
        logger.info(f"大模型反馈生成完成: {feedback_text}")

        # Collect the synthesised audio in memory through a callback.
        audio_chunks = []

        def audio_callback(audio_chunk):
            audio_chunks.append(audio_chunk)

        # NOTE(review): this passes the raw ASR text to stream_and_split_text,
        # while streaming_chat passes it an LLM response stream - confirm which
        # input the helper expects.
        text_stream = stream_and_split_text(asr_text)
        tts = StreamingVolcanoTTS(max_concurrency=2)
        await tts.synthesize_stream(text_stream, audio_callback)

        if not audio_chunks:
            raise RuntimeError("TTS语音合成失败,未生成音频数据")

        tts_temp_file = os.path.join(tempfile.gettempdir(), f"tts_{request_token}.mp3")
        try:
            with open(tts_temp_file, "wb") as f:
                f.writelines(audio_chunks)
            logger.info(f"TTS语音合成成功,文件保存至: {tts_temp_file}")

            # Upload the merged audio to OBS.
            tts_audio_url = upload_file_to_obs(tts_temp_file)
        finally:
            # Remove the local TTS file even when the upload fails.
            if os.path.exists(tts_temp_file):
                os.remove(tts_temp_file)
        logger.info(f"TTS文件已上传至OBS: {tts_audio_url}")

        # Return the ASR text, the LLM feedback and the TTS audio URL.
        return JSONResponse(content={
            "success": True,
            "message": "音频处理和语音反馈生成成功",
            "data": {
                "asr_text": asr_text,
                "feedback_text": feedback_text,
                "audio_url": tts_audio_url
            }
        })

    except Exception as e:
        # logger.exception keeps the traceback alongside the message.
        logger.exception(f"音频处理失败: {str(e)}")
        return JSONResponse(content={
            "success": False,
            "message": f"音频处理失败: {str(e)}"
        }, status_code=500)
|
|
|
|
|
|
2025-08-31 10:22:31 +08:00
|
|
|
|
# 新增WebSocket接口,用于流式处理
|
|
|
|
|
@router.websocket("/xueban/streaming-chat")
async def streaming_chat(websocket: WebSocket):
    """
    WebSocket endpoint: recognise one audio message and stream back speech.

    Protocol as implemented here:
    - The client sends one JSON object whose ``audio_data`` field holds the
      base64-encoded audio.
    - The server replies with ``{"type": "asr_result", "text": ...}``, then
      binary frames containing synthesised audio, then ``{"type": "end"}``.
    - On failure the server sends ``{"type": "error", "message": ...}`` and
      stops processing.

    - Parameter: websocket - the incoming WebSocket connection
    """
    import base64

    await websocket.accept()
    logger.info("WebSocket连接已接受")
    try:
        # Receive the audio payload from the client.
        logger.info("等待接收音频数据...")
        data = await websocket.receive_json()
        logger.info(f"接收到数据类型: {type(data)}")
        logger.info(f"接收到数据内容: {data.keys() if isinstance(data, dict) else '非字典类型'}")

        # Validate the message shape.
        if not isinstance(data, dict):
            logger.error(f"接收到的数据不是字典类型,而是: {type(data)}")
            await websocket.send_json({"type": "error", "message": "数据格式错误"})
            return

        audio_data = data.get("audio_data")
        logger.info(f"音频数据是否存在: {audio_data is not None}")
        logger.info(f"音频数据长度: {len(audio_data) if audio_data else 0}")

        if not audio_data:
            logger.error("未收到音频数据")
            await websocket.send_json({"type": "error", "message": "未收到音频数据"})
            return

        # A second-resolution timestamp alone lets simultaneous connections
        # overwrite each other's temp file, so a random suffix is appended.
        timestamp = datetime.now().strftime("%Y%m%d%H%M%S")
        temp_file_path = os.path.join(
            tempfile.gettempdir(), f"temp_audio_{timestamp}_{uuid.uuid4().hex}.wav"
        )
        logger.info(f"保存临时音频文件到: {temp_file_path}")

        try:
            # Decode before opening the file so malformed base64 does not
            # leave an empty temp file behind.
            audio_bytes = base64.b64decode(audio_data)
            with open(temp_file_path, "wb") as f:
                f.write(audio_bytes)
            logger.info("音频文件保存完成")
        except Exception as e:
            logger.error(f"音频文件保存失败: {str(e)}")
            if os.path.exists(temp_file_path):
                os.remove(temp_file_path)
            await websocket.send_json({"type": "error", "message": f"音频文件保存失败: {str(e)}"})
            return

        # Speech recognition.
        logger.info("开始ASR处理...")
        try:
            asr_result = await process_asr(temp_file_path)
            logger.info(f"ASR处理完成,结果: {asr_result['text']}")
        except Exception as e:
            logger.error(f"ASR处理失败: {str(e)}")
            await websocket.send_json({"type": "error", "message": f"ASR处理失败: {str(e)}"})
            return
        finally:
            # Runs on success and on failure, including when sending the
            # error message itself fails.
            if os.path.exists(temp_file_path):
                os.remove(temp_file_path)

        asr_text = asr_result['text']

        # Send the recognised text to the client.
        logger.info("发送ASR结果给前端")
        try:
            await websocket.send_json({
                "type": "asr_result",
                "text": asr_text
            })
            logger.info("ASR结果发送成功")
        except Exception as e:
            logger.error(f"发送ASR结果失败: {str(e)}")
            return

        # Callback that forwards each synthesised audio chunk to the client.
        async def audio_callback(audio_chunk):
            logger.info(f"发送音频块,大小: {len(audio_chunk)}")
            try:
                await websocket.send_bytes(audio_chunk)
                logger.info("音频块发送成功")
            except Exception as e:
                logger.error(f"发送音频块失败: {str(e)}")
                raise

        # Call the LLM once. The chunks that feed TTS are also recorded here,
        # so the logged reply is the one actually spoken and TTS does not
        # wait for a separate full LLM round-trip.
        llm_chunks = []

        async def recorded_llm_stream():
            async for chunk in get_xueban_response_async(asr_text, stream=True):
                llm_chunks.append(chunk)
                yield chunk

        logger.info("开始LLM流式处理和TTS合成...")
        try:
            text_stream = stream_and_split_text(recorded_llm_stream())
            tts = StreamingVolcanoTTS(max_concurrency=1)

            # NOTE(review): _synthesize_single_with_semaphore is a private
            # method of StreamingVolcanoTTS - confirm it is meant to be
            # called from outside the class.
            async for text_chunk in text_stream:
                await tts._synthesize_single_with_semaphore(text_chunk, audio_callback)
            logger.info("TTS合成完成")
            logger.info(f"学伴响应内容: {''.join(llm_chunks)}")
        except Exception as e:
            logger.error(f"TTS合成失败: {str(e)}")
            await websocket.send_json({"type": "error", "message": f"TTS合成失败: {str(e)}"})
            return

        # Tell the client the stream is complete.
        logger.info("发送结束信号")
        try:
            await websocket.send_json({"type": "end"})
            logger.info("结束信号发送成功")
        except Exception as e:
            logger.error(f"发送结束信号失败: {str(e)}")
            return

    except WebSocketDisconnect:
        logger.info("客户端断开连接")
    except Exception as e:
        logger.error(f"WebSocket处理失败: {str(e)}")
        try:
            await websocket.send_json({"type": "error", "message": str(e)})
        except Exception:
            # Best effort only: the socket may already be closed.
            logger.error("发送错误消息失败")
|
2025-08-22 10:01:04 +08:00
|
|
|
|
|
2025-08-31 10:22:31 +08:00
|
|
|
|
# 原有的辅助函数保持不变
|
2025-08-22 08:35:57 +08:00
|
|
|
|
async def process_asr(audio_path: str) -> dict:
    """
    Recognise the speech in a local audio file via the ASR service.

    The file is uploaded to OBS, the resulting URL is handed to a new
    ASRClient, and the client's task result is returned as the text.

    :param audio_path: path of the local audio file
    :return: dict with ``text``, ``confidence`` and ``audio_duration``
    :raises Exception: any failure is logged and re-raised unchanged
    """
    try:
        # NOTE(review): both calls below are synchronous inside a coroutine;
        # if they perform blocking network I/O they will hold the event loop.
        remote_url = upload_file_to_obs(audio_path)

        recognizer = ASRClient()
        recognizer.file_url = remote_url
        transcript = recognizer.process_task()

        size_in_bytes = os.path.getsize(audio_path)
        outcome = dict(
            text=transcript,
            # Fixed placeholder; the value is not taken from the ASR result.
            confidence=1.0,
            # Rough duration estimate derived from the file size.
            audio_duration=size_in_bytes / 1024 / 16,
        )
    except Exception as exc:
        logger.error(f"ASR处理失败: {str(exc)}")
        raise
    return outcome
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
def upload_file_to_obs(file_path: str) -> str:
    """
    Upload a local file to Huawei Cloud OBS and return its URL.

    The object is stored under ``HuangHai/XueBan/`` with a random UUID name.

    :param file_path: path of the local file
    :return: URL of the uploaded object
    :raises Exception: when the uploader reports failure or any step raises;
        the error is logged before being re-raised
    """
    try:
        uploader = ObsUploader()

        unique_name = str(uuid.uuid4())
        suffix = os.path.splitext(file_path)[1].lower()
        # The object key always ends in .wav, whatever the local extension.
        # NOTE(review): upload_audio sends an .mp3 file through here; confirm
        # the .wav suffix is intended for it.
        suffix = suffix if suffix == '.wav' else '.wav'

        object_key = "HuangHai/XueBan/" + unique_name + suffix

        ok, detail = uploader.upload_file(object_key, file_path)
        if not ok:
            failure = f"文件上传失败: {detail}"
            logger.error(failure)
            raise Exception(failure)

        logger.info(f"文件上传成功: {file_path} -> {object_key}")
        # The URL layout assumes a publicly readable bucket and may need
        # adjusting to the actual OBS configuration.
        from Config.Config import OBS_SERVER, OBS_BUCKET
        return f"https://{OBS_BUCKET}.{OBS_SERVER}/{object_key}"
    except Exception as exc:
        logger.error(f"上传文件到OBS失败: {str(exc)}")
        raise
|
2025-08-22 09:36:29 +08:00
|
|
|
|
|
|
|
|
|
|
|
|
|
|
@router.post("/xueban/chat")
async def chat_with_xueban(request: Request):
    """
    Chat with the XueBan (study companion) LLM.

    - Parameter: ``query_text`` in the JSON request body (the user's query)
    - Returns: a JSONResponse. On success ``data.response`` holds the full
      reply. HTTP 400 when ``query_text`` is missing, blank or not a string.
      HTTP 500 when processing raises.
    """
    try:
        payload = await request.json()

        # A body that is not a JSON object, or a null / non-string query_text,
        # is treated as missing so the caller gets the 400 response instead
        # of a 500 caused by an AttributeError.
        raw_query = payload.get("query_text") if isinstance(payload, dict) else None
        query_text = raw_query if isinstance(raw_query, str) else ""

        if not query_text.strip():
            return JSONResponse(content={
                "success": False,
                "message": "查询文本不能为空"
            }, status_code=400)

        logger.info(f"接收到学伴聊天请求: {query_text}")

        # Collect the streamed chunks and return them as one string.
        response_content = [
            chunk async for chunk in get_xueban_response_async(query_text, stream=True)
        ]
        full_response = "".join(response_content)

        return JSONResponse(content={
            "success": True,
            "message": "聊天成功",
            "data": {
                "response": full_response
            }
        })

    except Exception as e:
        # logger.exception keeps the traceback alongside the message.
        logger.exception(f"学伴聊天失败: {str(e)}")
        return JSONResponse(content={
            "success": False,
            "message": f"聊天处理失败: {str(e)}"
        }, status_code=500)
|