|
|
import os
|
|
|
import copy
|
|
|
import json
|
|
|
import uuid
|
|
|
import time
|
|
|
import queue
|
|
|
import asyncio
|
|
|
import traceback
|
|
|
|
|
|
import threading
|
|
|
import websockets
|
|
|
from typing import Dict, Any
|
|
|
from plugins_func.loadplugins import auto_import_modules
|
|
|
from config.logger import setup_logging
|
|
|
from core.utils.dialogue import Message, Dialogue
|
|
|
from core.handle.textHandle import handleTextMessage
|
|
|
from core.utils.util import (
|
|
|
get_string_no_punctuation_or_emoji,
|
|
|
extract_json_from_string,
|
|
|
get_ip_info,
|
|
|
initialize_modules,
|
|
|
)
|
|
|
from concurrent.futures import ThreadPoolExecutor, TimeoutError
|
|
|
from core.handle.sendAudioHandle import sendAudioMessage
|
|
|
from core.handle.receiveAudioHandle import handleAudioMessage
|
|
|
from core.handle.functionHandler import FunctionHandler
|
|
|
from plugins_func.register import Action, ActionResponse
|
|
|
from core.auth import AuthMiddleware, AuthenticationError
|
|
|
from core.mcp.manager import MCPManager
|
|
|
from config.config_loader import get_private_config_from_api
|
|
|
|
|
|
# Module name, passed as the `tag` value to `logger.bind(...)` by the log calls in this file.
TAG = __name__
|
|
|
|
|
|
# Runs at import time for the "plugins_func.functions" package path.
# NOTE(review): presumably this imports/registers the plugin functions — confirm in plugins_func.loadplugins.
auto_import_modules("plugins_func.functions")
|
|
|
|
|
|
|
|
|
class TTSException(RuntimeError):
    """RuntimeError subclass named for text-to-speech failures.

    It adds no behaviour of its own; nothing in the visible part of this
    file raises it.
    """
|
|
|
|
|
|
|
|
|
class ConnectionHandler:
|
|
|
def __init__(
    self, config: Dict[str, Any], _vad, _asr, _llm, _tts, _memory, _intent
):
    """Create the per-connection state for one client session.

    Args:
        config: Server configuration mapping. A deep copy is kept on the
            instance; the auth middleware receives the mapping as passed in.
        _vad, _asr, _llm, _tts, _memory, _intent: Component instances that
            are stored unchanged on the handler.
    """
    self.config = copy.deepcopy(config)
    self.logger = setup_logging()
    self.auth = AuthMiddleware(config)

    # Connection / session identity, filled in by handle_connection.
    self.websocket = None
    self.headers = None
    self.client_ip = None
    self.client_ip_info = {}
    self.session_id = None
    self.prompt = None
    self.welcome_msg = None

    # Client state.
    self.client_abort = False
    self.client_listen_mode = "auto"

    # Threads, queues and the event loop used to hop back into asyncio.
    self.loop = asyncio.get_event_loop()
    self.stop_event = threading.Event()
    self.tts_queue = queue.Queue()
    self.audio_play_queue = queue.Queue()
    self.executor = ThreadPoolExecutor(max_workers=10)

    # Injected components.
    self.vad, self.asr, self.llm = _vad, _asr, _llm
    self.tts, self.memory, self.intent = _tts, _memory, _intent

    # VAD state.
    self.client_audio_buffer = bytearray()
    self.client_have_voice = False
    self.client_have_voice_last_time = 0.0
    self.client_no_voice_last_time = 0.0
    self.client_voice_stop = False

    # ASR state.
    self.asr_audio = []
    self.asr_server_receive = True

    # LLM state.
    self.llm_finish_task = False
    self.dialogue = Dialogue()

    # TTS state: indices of the first and last sentence of the current reply.
    self.tts_first_text_index = -1
    self.tts_last_text_index = -1

    # IoT / function-call state.
    self.iot_descriptors = {}
    self.func_handler = None

    # Exit commands and the length of the longest one (0 when there are none).
    self.cmd_exit = self.config["exit_commands"]
    self.max_cmd_length = max((len(cmd) for cmd in self.cmd_exit), default=0)

    # Whether to close the connection once the current chat has finished.
    self.close_after_chat = False
    self.use_function_call_mode = False
|
|
|
|
|
|
async def handle_connection(self, ws):
    """Serve one websocket client from handshake until it disconnects.

    Records the request headers and peer address, authenticates, sends the
    welcome message carrying a fresh session id, submits component
    initialization to the executor, starts the two daemon worker threads
    and then routes every incoming message. ``_save_and_close`` runs on
    every exit path.

    Args:
        ws: The accepted websocket connection.
    """
    try:
        # Request headers and client IP address.
        self.headers = dict(ws.request.headers)
        self.client_ip = ws.remote_address[0]
        self.logger.bind(tag=TAG).info(
            f"{self.client_ip} conn - Headers: {self.headers}"
        )

        # An AuthenticationError raised here is handled below.
        await self.auth.authenticate(self.headers)

        # Authenticated: bind the socket and open a session.
        self.websocket = ws
        self.session_id = str(uuid.uuid4())

        welcome = self.config["xiaozhi"]
        self.welcome_msg = welcome
        welcome["session_id"] = self.session_id
        await self.websocket.send(json.dumps(welcome))

        # Initialize components in the background.
        self.executor.submit(self._initialize_components)

        # Consumer threads: TTS results first, then audio playback.
        workers = (
            ("tts_priority_thread", self._tts_priority_thread),
            ("audio_play_priority_thread", self._audio_play_priority_thread),
        )
        for attr_name, target in workers:
            worker = threading.Thread(target=target, daemon=True)
            setattr(self, attr_name, worker)
            worker.start()

        try:
            async for incoming in self.websocket:
                await self._route_message(incoming)
        except websockets.exceptions.ConnectionClosed:
            self.logger.bind(tag=TAG).info("客户端断开连接")

    except AuthenticationError as exc:
        self.logger.bind(tag=TAG).error(f"Authentication failed: {str(exc)}")
    except Exception as exc:
        stack_trace = traceback.format_exc()
        self.logger.bind(tag=TAG).error(f"Connection error: {str(exc)}-{stack_trace}")
    finally:
        await self._save_and_close(ws)
|
|
|
|
|
|
async def _save_and_close(self, ws):
|
|
|
"""保存记忆并关闭连接"""
|
|
|
try:
|
|
|
await self.memory.save_memory(self.dialogue.dialogue)
|
|
|
except Exception as e:
|
|
|
self.logger.bind(tag=TAG).error(f"保存记忆失败: {e}")
|
|
|
finally:
|
|
|
await self.close(ws)
|
|
|
|
|
|
async def _route_message(self, message):
|
|
|
"""消息路由"""
|
|
|
if isinstance(message, str):
|
|
|
await handleTextMessage(self, message)
|
|
|
elif isinstance(message, bytes):
|
|
|
await handleAudioMessage(self, message)
|
|
|
|
|
|
def _initialize_components(self):
    """Initialize models, prompt, memory, intent and client location.

    ``handle_connection`` submits this method to the executor and drops the
    returned Future, so an exception raised here would otherwise never be
    seen. It is therefore logged with its traceback before being re-raised.
    """
    try:
        self._initialize_models()

        # Load the system prompt and seed the dialogue with it.
        self.prompt = self.config["prompt"]
        self.dialogue.put(Message(role="system", content=self.prompt))

        # Memory, then intent recognition.
        self._initialize_memory()
        self._initialize_intent()

        # Location info derived from the client IP.
        self.client_ip_info = get_ip_info(self.client_ip, self.logger)
        if self.client_ip_info is not None and "city" in self.client_ip_info:
            self.logger.bind(tag=TAG).info(f"Client ip info: {self.client_ip_info}")
            self.prompt = self.prompt + f"\nuser location:{self.client_ip_info}"

            self.dialogue.update_system_message(self.prompt)
    except Exception as e:
        # Without this the failure only lives inside the discarded Future.
        self.logger.bind(tag=TAG).error(
            f"Component initialization failed: {e}-{traceback.format_exc()}"
        )
        raise
|
|
|
|
|
|
def _initialize_models(self):
    """Re-instantiate components from per-device configuration, if enabled.

    Does nothing unless ``read_config_from_api`` is set in the config.
    Otherwise the private configuration is fetched using the ``device-id``
    and ``client-id`` headers, every component section it provides is
    merged into ``self.config``, and ``initialize_modules`` is asked to
    rebuild just those components. Fetch and rebuild failures are logged
    and leave the existing components in place.
    """
    if not self.config.get("read_config_from_api", False):
        # Configuration comes from the config file only: nothing to rebuild.
        return

    # Fetch the per-device differences; this is not a full re-instantiation.
    try:
        private_config = get_private_config_from_api(
            self.config,
            self.headers.get("device-id", None),
            self.headers.get("client-id", None),
        )
        private_config["delete_audio"] = self.config["delete_audio"]
        self.logger.bind(tag=TAG).info(f"获取差异化配置成功: {private_config}")
    except Exception as e:
        self.logger.bind(tag=TAG).error(f"获取差异化配置失败: {e}")
        private_config = {}

    # Merge each overridden section and remember which ones need rebuilding.
    needs_init = {}
    for section in ("VAD", "ASR", "LLM", "TTS", "Memory", "Intent"):
        overridden = private_config.get(section, None) is not None
        needs_init[section] = overridden
        if overridden:
            self.config[section] = private_config[section]
            self.config["selected_module"][section] = private_config[
                "selected_module"
            ][section]
    if private_config.get("prompt", None) is not None:
        self.config["prompt"] = private_config["prompt"]

    try:
        modules = initialize_modules(
            self.logger,
            private_config,
            needs_init["VAD"],
            needs_init["ASR"],
            needs_init["LLM"],
            needs_init["TTS"],
            needs_init["Memory"],
            needs_init["Intent"],
        )
    except Exception as e:
        self.logger.bind(tag=TAG).error(f"初始化组件失败: {e}")
        modules = {}

    # Adopt every rebuilt component. "vad" and "asr" were previously
    # requested but never applied.
    # NOTE(review): assumes initialize_modules reports rebuilt VAD/ASR
    # instances under "vad"/"asr" like the other components — confirm.
    for attr in ("vad", "asr", "tts", "llm", "intent", "memory"):
        rebuilt = modules.get(attr, None)
        if rebuilt is not None:
            setattr(self, attr, rebuilt)
|
|
|
|
|
|
def _initialize_memory(self):
|
|
|
"""初始化记忆模块"""
|
|
|
device_id = self.headers.get("device-id", None)
|
|
|
self.memory.init_memory(device_id, self.llm)
|
|
|
|
|
|
def _initialize_intent(self):
    """Configure intent recognition and create the function/MCP handlers.

    The intent type is read from the selected Intent module's config:
    ``function_call`` switches on function-call mode, ``nointent`` returns
    immediately, and ``intent_llm`` gives the intent component either a
    dedicated LLM (when the configured name exists under ``LLM``) or the
    main one. Unless it returned early, the method then builds the
    function handler and MCP manager and schedules MCP server
    initialization on the event loop.
    """
    selected = self.config["selected_module"]["Intent"]
    intent_type = self.config["Intent"][selected]["type"]

    if intent_type == "function_call":
        self.use_function_call_mode = True

    if intent_type == "nointent":
        # No intent recognition: the handlers below are not created.
        return

    if intent_type == "intent_llm":
        intent_llm_name = self.config["Intent"][selected]["llm"]

        if intent_llm_name and intent_llm_name in self.config["LLM"]:
            # A dedicated LLM is configured: build a separate instance.
            from core.utils import llm as llm_utils

            intent_llm_config = self.config["LLM"][intent_llm_name]
            intent_llm_type = intent_llm_config.get("type", intent_llm_name)
            dedicated_llm = llm_utils.create_instance(
                intent_llm_type, intent_llm_config
            )
            self.logger.bind(tag=TAG).info(
                f"为意图识别创建了专用LLM: {intent_llm_name}, 类型: {intent_llm_type}"
            )
            self.intent.set_llm(dedicated_llm)
        else:
            # Otherwise fall back to the main LLM.
            self.intent.set_llm(self.llm)
            self.logger.bind(tag=TAG).info("使用主LLM作为意图识别模型")

    # Plugin function handler and MCP tool manager.
    self.func_handler = FunctionHandler(self)
    self.mcp_manager = MCPManager(self)

    # Start MCP server initialization on the event loop; the result is not awaited.
    asyncio.run_coroutine_threadsafe(
        self.mcp_manager.initialize_servers(), self.loop
    )
|
|
|
|
|
|
def change_system_prompt(self, prompt):
|
|
|
self.prompt = prompt
|
|
|
# 找到原来的role==system,替换原来的系统提示
|
|
|
for m in self.dialogue.dialogue:
|
|
|
if m.role == "system":
|
|
|
m.content = prompt
|
|
|
|
|
|
def chat(self, query):
    """Send ``query`` to the LLM and queue the streamed reply for TTS.

    The reply is consumed chunk by chunk. Whenever the unsent text contains
    one of the sentence-ending marks, everything up to the last mark is
    cleaned and submitted to the executor for synthesis, and the resulting
    future is put on ``tts_queue``. Text left over after the stream ends is
    queued the same way, and the full reply is appended to the dialogue.

    Returns:
        True after the reply was processed, or None when the memory lookup
        or the LLM request raised.
    """
    self.dialogue.put(Message(role="user", content=query))

    chunks = []
    consumed = 0  # number of reply characters already handed to TTS
    try:
        started = time.time()
        # Query memory on the event loop and wait for the result.
        memory_future = asyncio.run_coroutine_threadsafe(
            self.memory.query_memory(query), self.loop
        )
        memory_str = memory_future.result()

        self.logger.bind(tag=TAG).debug(f"记忆内容: {memory_str}")
        llm_responses = self.llm.response(
            self.session_id, self.dialogue.get_llm_dialogue_with_memory(memory_str)
        )
    except Exception as exc:
        self.logger.bind(tag=TAG).error(f"LLM 处理出错 {query}: {exc}")
        return None

    def enqueue_sentence(sentence, index):
        # Record the sentence index, then synthesize it on the executor.
        self.recode_first_last_text(sentence, index)
        pending = self.executor.submit(self.speak_and_play, sentence, index)
        self.tts_queue.put(pending)

    self.llm_finish_task = False
    text_index = 0
    sentence_marks = ("。", "?", "!", ";", ":")
    for content in llm_responses:
        chunks.append(content)
        if self.client_abort:
            break

        elapsed = time.time() - started
        self.logger.bind(tag=TAG).debug(f"大模型返回时间: {elapsed} 秒, 生成token={content}")

        # Text received so far that has not been sent to TTS yet.
        unsent = "".join(chunks)[consumed:]

        # Position of the last sentence-ending mark, or -1 if there is none.
        cut = max(unsent.rfind(mark) for mark in sentence_marks)
        if cut == -1:
            continue

        raw_segment = unsent[: cut + 1]
        sentence = get_string_no_punctuation_or_emoji(raw_segment)
        if sentence:
            text_index += 1
            enqueue_sentence(sentence, text_index)
            consumed += len(raw_segment)

    # Whatever is left once the stream has ended.
    leftover = "".join(chunks)[consumed:]
    if leftover:
        sentence = get_string_no_punctuation_or_emoji(leftover)
        if sentence:
            text_index += 1
            enqueue_sentence(sentence, text_index)

    self.llm_finish_task = True
    self.dialogue.put(Message(role="assistant", content="".join(chunks)))
    self.logger.bind(tag=TAG).debug(
        json.dumps(self.dialogue.get_llm_dialogue(), indent=4, ensure_ascii=False)
    )
    return True
|
|
|
|
|
|
def chat_with_function_calling(self, query, tool_call=False):
    """Stream an LLM reply that may contain a tool call.

    Plain text is split at sentence-ending marks and queued for TTS as it
    arrives. A tool call is collected either from the structured
    ``tools_call`` deltas or, when the reply text starts with a code fence
    or contains ``<tool_call>``, from JSON embedded in that text. It is
    then dispatched to the MCP manager or the function handler, and the
    result is passed to ``_handle_function_result``.

    Args:
        query: Text used for the memory lookup; appended to the dialogue as
            a user message unless ``tool_call`` is True.
        tool_call: True when re-entering after a tool has run (see
            ``_handle_function_result``), so no user message is added.

    Returns:
        True after the reply was processed, or None when the memory lookup
        or the LLM request raised.
    """
    self.logger.bind(tag=TAG).debug(f"Chat with function calling start: {query}")

    if not tool_call:
        self.dialogue.put(Message(role="user", content=query))

    # func_handler starts out as None and is only assigned by
    # _initialize_intent, so test its value rather than its presence.
    handler = getattr(self, "func_handler", None)
    functions = handler.get_functions() if handler is not None else None

    response_message = []
    processed_chars = 0  # number of reply characters already handed to TTS

    try:
        # Query memory on the event loop and wait for the result.
        future = asyncio.run_coroutine_threadsafe(
            self.memory.query_memory(query), self.loop
        )
        memory_str = future.result()

        self.logger.bind(tag=TAG).info(f"对话记录: {self.dialogue.get_llm_dialogue_with_memory(memory_str)}")

        # Streaming interface that accepts function definitions.
        llm_responses = self.llm.response_with_functions(
            self.session_id,
            self.dialogue.get_llm_dialogue_with_memory(memory_str),
            functions=functions,
        )
    except Exception as e:
        self.logger.bind(tag=TAG).error(f"LLM 处理出错 {query}: {e}")
        return None

    self.llm_finish_task = False
    text_index = 0
    punctuations = ("。", "?", "!", ";", ":")

    # State accumulated while consuming the stream.
    tool_call_flag = False
    function_name = None
    function_id = None
    function_arguments = ""
    content_arguments = ""
    for response in llm_responses:
        content, tools_call = response
        # NOTE(review): presumably some providers yield a mapping with a
        # "content" key instead of a (content, tools_call) pair — confirm.
        if "content" in response:
            content = response["content"]
            tools_call = None

        has_content = content is not None and len(content) > 0

        # A reply that opens with a code fence or a <tool_call> tag is
        # treated as a text-encoded tool call.
        if (
            has_content
            and len(response_message) <= 0
            and (content == "```" or "<tool_call>" in content)
        ):
            tool_call_flag = True

        if tools_call is not None:
            tool_call_flag = True
            first_call = tools_call[0]
            if first_call.id is not None:
                function_id = first_call.id
            if first_call.function.name is not None:
                function_name = first_call.function.name
            if first_call.function.arguments is not None:
                function_arguments += first_call.function.arguments

        if not has_content:
            continue
        if tool_call_flag:
            content_arguments += content
            continue

        response_message.append(content)

        if self.client_abort:
            break

        # Text received so far that has not been sent to TTS yet.
        current_text = "".join(response_message)[processed_chars:]

        # Position of the last sentence-ending mark, or -1 if there is none.
        last_punct_pos = max(current_text.rfind(p) for p in punctuations)
        if last_punct_pos == -1:
            continue

        segment_text_raw = current_text[: last_punct_pos + 1]
        segment_text = get_string_no_punctuation_or_emoji(segment_text_raw)
        if segment_text:
            text_index += 1
            self.recode_first_last_text(segment_text, text_index)
            future = self.executor.submit(
                self.speak_and_play, segment_text, text_index
            )
            self.tts_queue.put(future)
            processed_chars += len(segment_text_raw)

    # Handle the function call, if any.
    if tool_call_flag:
        has_error = False
        if function_id is None:
            # No structured call id: recover the call from the reply text.
            extracted = extract_json_from_string(content_arguments)
            if extracted is not None:
                try:
                    call_json = json.loads(extracted)
                    function_name = call_json["name"]
                    function_arguments = json.dumps(
                        call_json["arguments"], ensure_ascii=False
                    )
                    function_id = str(uuid.uuid4().hex)
                except Exception:
                    has_error = True
                    response_message.append(extracted)
            else:
                has_error = True
                response_message.append(content_arguments)
            if has_error:
                self.logger.bind(tag=TAG).error(
                    f"function call error: {content_arguments}"
                )
            else:
                function_arguments = json.loads(function_arguments)

        if not has_error:
            self.logger.bind(tag=TAG).info(
                f"function_name={function_name}, function_id={function_id}, function_arguments={function_arguments}"
            )
            function_call_data = {
                "name": function_name,
                "id": function_id,
                "arguments": function_arguments,
            }

            # mcp_manager and func_handler only exist after
            # _initialize_intent has run, so look them up defensively.
            mcp_manager = getattr(self, "mcp_manager", None)
            handler = getattr(self, "func_handler", None)
            dispatched = True
            if mcp_manager is not None and mcp_manager.is_mcp_tool(function_name):
                result = self._handle_mcp_tool_call(function_call_data)
            elif handler is not None:
                # Built-in (plugin) function.
                result = handler.handle_llm_function_call(self, function_call_data)
            else:
                dispatched = False
                self.logger.bind(tag=TAG).error(
                    f"function call error: no function handler available for {function_name}"
                )
            if dispatched:
                self._handle_function_result(
                    result, function_call_data, text_index + 1
                )

    # Whatever is left once the stream has ended.
    remaining_text = "".join(response_message)[processed_chars:]
    if remaining_text:
        segment_text = get_string_no_punctuation_or_emoji(remaining_text)
        if segment_text:
            text_index += 1
            self.recode_first_last_text(segment_text, text_index)
            future = self.executor.submit(
                self.speak_and_play, segment_text, text_index
            )
            self.tts_queue.put(future)

    # Store the assistant reply.
    if len(response_message) > 0:
        self.dialogue.put(
            Message(role="assistant", content="".join(response_message))
        )

    self.llm_finish_task = True
    self.logger.bind(tag=TAG).debug(
        json.dumps(self.dialogue.get_llm_dialogue(), indent=4, ensure_ascii=False)
    )

    return True
|
|
|
|
|
|
def _handle_mcp_tool_call(self, function_call_data):
    """Run an MCP tool on the event loop and wrap its text output.

    Args:
        function_call_data: Mapping with the tool ``name`` and its
            ``arguments`` (a JSON string or an already-decoded value).

    Returns:
        An ``ActionResponse`` with action ``REQLLM``. Its result is the
        text of the last ``text`` content item, or a fixed error message
        when arguments cannot be parsed, the call raises, or no text came
        back.
    """
    raw_arguments = function_call_data["arguments"]
    tool_name = function_call_data["name"]
    try:
        call_args = raw_arguments
        if isinstance(raw_arguments, str):
            try:
                call_args = json.loads(raw_arguments)
            except json.JSONDecodeError:
                self.logger.bind(tag=TAG).error(
                    f"无法解析 function_arguments: {raw_arguments}"
                )
                return ActionResponse(
                    action=Action.REQLLM, result="参数解析失败", response=""
                )

        pending = asyncio.run_coroutine_threadsafe(
            self.mcp_manager.execute_tool(tool_name, call_args), self.loop
        )
        tool_result = pending.result()

        # Example result: meta=None content=[TextContent(type='text', text='...', annotations=None)] isError=False
        reply_text = ""
        if tool_result is not None and tool_result.content is not None:
            for part in tool_result.content:
                # Only "text" items are used; "image" items are ignored.
                if part.type == "text":
                    reply_text = part.text

        if len(reply_text) > 0:
            return ActionResponse(
                action=Action.REQLLM, result=reply_text, response=""
            )

    except Exception as exc:
        self.logger.bind(tag=TAG).error(f"MCP工具调用错误: {exc}")
        return ActionResponse(
            action=Action.REQLLM, result="工具调用出错", response=""
        )

    return ActionResponse(action=Action.REQLLM, result="工具调用出错", response="")
|
|
|
|
|
|
def _handle_function_result(self, result, function_call_data, text_index):
    """Act on the outcome of a function/tool call.

    ``REQLLM`` with non-empty result text records the tool call and its
    output in the dialogue and asks the LLM again. Every other action
    speaks a text directly: ``RESPONSE`` uses ``result.response``, all
    remaining actions (including ``NOTFOUND``) use ``result.result``.

    Args:
        result: Outcome object exposing ``action``, ``result``, ``response``.
        function_call_data: Mapping with the call ``id``, ``name`` and
            ``arguments``.
        text_index: Sentence index used when the text is spoken directly.
    """
    action = result.action

    if action != Action.RESPONSE and action == Action.REQLLM:
        # Feed the tool output back to the LLM for a generated reply.
        tool_output = result.result
        if tool_output is None or len(tool_output) == 0:
            return
        call_id = function_call_data["id"]
        call_record = {
            "id": call_id,
            "function": {
                "arguments": function_call_data["arguments"],
                "name": function_call_data["name"],
            },
            "type": "function",
            "index": 0,
        }
        self.dialogue.put(Message(role="assistant", tool_calls=[call_record]))
        self.dialogue.put(
            Message(role="tool", tool_call_id=call_id, content=tool_output)
        )
        self.chat_with_function_calling(tool_output, tool_call=True)
        return

    # Speak directly to the client.
    spoken = result.response if action == Action.RESPONSE else result.result
    self.recode_first_last_text(spoken, text_index)
    pending = self.executor.submit(self.speak_and_play, spoken, text_index)
    self.tts_queue.put(pending)
    self.dialogue.put(Message(role="assistant", content=spoken))
|
|
|
|
|
|
def _tts_priority_thread(self):
    """Worker loop that turns finished TTS futures into playable audio.

    Until ``stop_event`` is set, futures are taken from ``tts_queue`` in
    order and awaited for up to ``tts_timeout`` seconds (default 10). The
    produced file is converted to opus data and the item is forwarded to
    ``audio_play_queue`` unless the client aborted. The file is removed
    afterwards when the TTS component asks for that. An unexpected error
    clears the speaking state and sends a TTS "stop" message to the client.
    """
    while not self.stop_event.is_set():
        text = None
        try:
            try:
                pending = self.tts_queue.get(timeout=1)
            except queue.Empty:
                if self.stop_event.is_set():
                    break
                continue
            if pending is None:
                continue

            text = None
            opus_datas, text_index, tts_file = [], 0, None
            try:
                self.logger.bind(tag=TAG).debug("正在处理TTS任务...")
                wait_seconds = self.config.get("tts_timeout", 10)
                tts_file, text, text_index = pending.result(timeout=wait_seconds)

                if text is None or len(text) == 0:
                    self.logger.bind(tag=TAG).error(
                        f"TTS出错:{text_index}: tts text is empty"
                    )
                elif tts_file is None:
                    self.logger.bind(tag=TAG).error(
                        f"TTS出错: file is empty: {text_index}: {text}"
                    )
                else:
                    self.logger.bind(tag=TAG).debug(
                        f"TTS生成:文件路径: {tts_file}"
                    )
                    if not os.path.exists(tts_file):
                        self.logger.bind(tag=TAG).error(
                            f"TTS出错:文件不存在{tts_file}"
                        )
                    else:
                        # The second element of the pair is not used here.
                        opus_datas, _ = self.tts.audio_to_opus_data(tts_file)
            except TimeoutError:
                self.logger.bind(tag=TAG).error("TTS超时")
            except Exception as exc:
                self.logger.bind(tag=TAG).error(f"TTS出错: {exc}")

            # Forward the audio only if the client did not interrupt.
            if not self.client_abort:
                self.audio_play_queue.put((opus_datas, text, text_index))

            removable = (
                self.tts.delete_audio_file
                and tts_file is not None
                and os.path.exists(tts_file)
            )
            if removable:
                os.remove(tts_file)
        except Exception as exc:
            self.logger.bind(tag=TAG).error(f"TTS任务处理错误: {exc}")
            self.clearSpeakStatus()
            stop_notice = json.dumps(
                {
                    "type": "tts",
                    "state": "stop",
                    "session_id": self.session_id,
                }
            )
            asyncio.run_coroutine_threadsafe(
                self.websocket.send(stop_notice),
                self.loop,
            )
            self.logger.bind(tag=TAG).error(
                f"tts_priority priority_thread: {text} {exc}"
            )
|
|
|
|
|
|
def _audio_play_priority_thread(self):
    """Worker loop that sends queued audio to the client in order.

    Until ``stop_event`` is set, items are taken from ``audio_play_queue``
    and ``sendAudioMessage`` is run on the event loop, waiting for each
    send to finish before taking the next item. Errors are logged and the
    loop continues.
    """
    while not self.stop_event.is_set():
        text = None
        try:
            try:
                item = self.audio_play_queue.get(timeout=1)
                opus_datas, text, text_index = item
            except queue.Empty:
                if self.stop_event.is_set():
                    break
                continue

            sending = asyncio.run_coroutine_threadsafe(
                sendAudioMessage(self, opus_datas, text, text_index), self.loop
            )
            sending.result()
        except Exception as exc:
            self.logger.bind(tag=TAG).error(
                f"audio_play_priority priority_thread: {text} {exc}"
            )
|
|
|
|
|
|
def speak_and_play(self, text, text_index=0):
    """Synthesize ``text`` with the TTS component.

    Returns:
        A ``(tts_file, text, text_index)`` tuple. ``tts_file`` is None when
        the text is empty/None or when ``to_tts`` returned None.
    """
    if text is None or len(text) == 0:
        self.logger.bind(tag=TAG).info(f"无需tts转换,query为空,{text}")
        return None, text, text_index

    audio_path = self.tts.to_tts(text)
    if audio_path is not None:
        self.logger.bind(tag=TAG).debug(f"TTS 文件生成完毕: {audio_path}")
        return audio_path, text, text_index

    self.logger.bind(tag=TAG).error(f"tts转换失败,{text}")
    return None, text, text_index
|
|
|
|
|
|
def clearSpeakStatus(self):
    """Reset the server-side speaking state.

    Sets ``asr_server_receive`` back to True and both TTS sentence indices
    back to -1.
    """
    self.logger.bind(tag=TAG).debug(f"清除服务端讲话状态")
    self.asr_server_receive = True
    self.tts_last_text_index = self.tts_first_text_index = -1
|
|
|
|
|
|
def recode_first_last_text(self, text, text_index=0):
|
|
|
if self.tts_first_text_index == -1:
|
|
|
self.logger.bind(tag=TAG).info(f"大模型说出第一句话: {text}")
|
|
|
self.tts_first_text_index = text_index
|
|
|
self.tts_last_text_index = text_index
|
|
|
|
|
|
async def close(self, ws=None):
    """Release everything this connection owns.

    Cleans up MCP resources, sets the stop event, shuts the executor down
    without waiting (cancelling queued futures), empties both task queues
    and closes the websocket. A failure during MCP cleanup is logged and
    does not stop the remaining steps.

    Args:
        ws: Websocket to close. When falsy, ``self.websocket`` is closed
            instead if it is set.
    """
    # mcp_manager only exists once _initialize_intent has created it.
    mcp_manager = getattr(self, "mcp_manager", None)
    if mcp_manager:
        try:
            await mcp_manager.cleanup_all()
        except Exception as e:
            # A failed MCP teardown must not leave the worker threads,
            # the executor and the socket alive.
            self.logger.bind(tag=TAG).error(f"MCP cleanup failed: {e}")

    # Tell the worker threads to stop.
    if self.stop_event:
        self.stop_event.set()

    # Shut the thread pool down immediately.
    if self.executor:
        self.executor.shutdown(wait=False, cancel_futures=True)
        self.executor = None

    # Drop any queued work.
    self._clear_queues()

    if ws:
        await ws.close()
    elif self.websocket:
        await self.websocket.close()
    self.logger.bind(tag=TAG).info("连接资源已释放")
|
|
|
|
|
|
def _clear_queues(self):
|
|
|
# 清空所有任务队列
|
|
|
for q in [self.tts_queue, self.audio_play_queue]:
|
|
|
if not q:
|
|
|
continue
|
|
|
while not q.empty():
|
|
|
try:
|
|
|
q.get_nowait()
|
|
|
except queue.Empty:
|
|
|
continue
|
|
|
q.queue.clear()
|
|
|
# 添加毒丸信号到队列,确保线程退出
|
|
|
# q.queue.put(None)
|
|
|
|
|
|
def reset_vad_states(self):
    """Reset the voice-activity state.

    Replaces the audio buffer with an empty one, clears the voice and
    voice-stop flags and zeroes the last-voice timestamp.
    """
    self.client_audio_buffer = bytearray()
    self.client_have_voice = self.client_voice_stop = False
    self.client_have_voice_last_time = 0
    self.logger.bind(tag=TAG).debug("VAD states reset.")
|
|
|
|
|
|
def chat_and_close(self, text):
|
|
|
"""Chat with the user and then close the connection"""
|
|
|
try:
|
|
|
# Use the existing chat method
|
|
|
self.chat(text)
|
|
|
|
|
|
# After chat is complete, close the connection
|
|
|
self.close_after_chat = True
|
|
|
except Exception as e:
|
|
|
self.logger.bind(tag=TAG).error(f"Chat and close error: {str(e)}")
|