# core/connection.py
# Repository path: QingLong/XiaoZhi/xiaozhi-esp32-server/main/xiaozhi-server/core/connection.py
# Snapshot exported from a web file viewer (original export: 939 lines, 36 KiB, Python),
# last changed 2025-08-15 09:13:13 +08:00.
import os
import copy
import json
import uuid
import time
import queue
import asyncio
import traceback
import threading
import websockets
from typing import Dict, Any
from plugins_func.loadplugins import auto_import_modules
from config.logger import setup_logging
from core.utils.dialogue import Message, Dialogue
from core.handle.textHandle import handleTextMessage
from core.utils.util import (
get_string_no_punctuation_or_emoji,
extract_json_from_string,
get_ip_info,
initialize_modules,
)
from concurrent.futures import ThreadPoolExecutor, TimeoutError
from core.handle.sendAudioHandle import sendAudioMessage
from core.handle.receiveAudioHandle import handleAudioMessage
from core.handle.functionHandler import FunctionHandler
from plugins_func.register import Action, ActionResponse
from core.auth import AuthMiddleware, AuthenticationError
from core.mcp.manager import MCPManager
from config.config_loader import get_private_config_from_api
from config.manage_api_client import DeviceNotFoundException, DeviceBindException
from core.utils.output_counter import add_device_output
# Tag attached to every log record from this module via logger.bind(tag=TAG).
TAG = __name__
# Import every module under plugins_func.functions at load time; presumably this
# lets plugin functions register themselves before FunctionHandler looks them up
# (see plugins_func.loadplugins to confirm).
auto_import_modules("plugins_func.functions")
class TTSException(RuntimeError):
    """Error type for text-to-speech failures; a plain RuntimeError subclass."""
class ConnectionHandler:
    """State and processing pipeline for a single device WebSocket connection.

    Owns a private copy of the configuration, the model components
    (VAD/ASR/LLM/TTS/memory/intent), the TTS and audio-playback worker
    threads, and the idle-timeout watchdog task.
    """

    # Full-width sentence-ending punctuation used to cut the streamed LLM reply
    # into segments that are sent to TTS as soon as each one is complete.
    _SEGMENT_PUNCTUATIONS = ("。", "？", "！", "；", "：")

    def __init__(
        self, config: Dict[str, Any], _vad, _asr, _llm, _tts, _memory, _intent
    ):
        """Create the handler with shared default components.

        Args:
            config: Server configuration; deep-copied so per-device overrides
                applied later never leak into the shared dict.
            _vad, _asr, _llm, _tts, _memory, _intent: Default component
                instances; some may be replaced by per-device ones during
                initialisation.
        """
        self.config = copy.deepcopy(config)
        self.logger = setup_logging()
        self.auth = AuthMiddleware(config)

        self.need_bind = False
        self.bind_code = None

        self.websocket = None
        self.headers = None
        self.client_ip = None
        self.client_ip_info = {}
        self.session_id = None
        self.prompt = None
        self.welcome_msg = None
        self.max_output_size = 0

        # Client state
        self.client_abort = False
        self.client_listen_mode = "auto"

        # Threads, queues and the event loop used to hop back from worker threads
        self.loop = asyncio.get_event_loop()
        self.stop_event = threading.Event()
        self.tts_queue = queue.Queue()
        self.audio_play_queue = queue.Queue()
        self.executor = ThreadPoolExecutor(max_workers=10)

        # Components
        self.vad = _vad
        self.asr = _asr
        self.llm = _llm
        self.tts = _tts
        self.memory = _memory
        self.intent = _intent

        # VAD state
        self.client_audio_buffer = bytearray()
        self.client_have_voice = False
        self.client_have_voice_last_time = 0.0
        self.client_no_voice_last_time = 0.0
        self.client_voice_stop = False

        # ASR state
        self.asr_audio = []
        self.asr_server_receive = True

        # LLM state
        self.llm_finish_task = False
        self.dialogue = Dialogue()

        # TTS state: indexes of the first/last segment of the current reply
        self.tts_first_text_index = -1
        self.tts_last_text_index = -1

        # IoT state
        self.iot_descriptors = {}
        self.func_handler = None

        self.cmd_exit = self.config["exit_commands"]
        self.max_cmd_length = max((len(cmd) for cmd in self.cmd_exit), default=0)

        self.close_after_chat = False  # close the connection once the chat finishes
        self.use_function_call_mode = False
        self.timeout_task = None
        # Second-stage close: 60 s on top of the first-stage "no voice" close time.
        self.timeout_seconds = (
            int(self.config.get("close_connection_no_voice_time", 120)) + 60
        )

    async def handle_connection(self, ws):
        """Serve one WebSocket connection from handshake to close.

        Resolves the device id (header, else URL query), authenticates, sends
        the welcome message, starts component initialisation and the worker
        threads, then routes each incoming message until the socket closes.
        Memory is saved and resources are released in every case.
        """
        try:
            # Read headers; fall back to URL query parameters for device-id.
            self.headers = dict(ws.request.headers)
            if self.headers.get("device-id", None) is None:
                from urllib.parse import parse_qs, urlparse

                request_path = ws.request.path
                if not request_path:
                    self.logger.bind(tag=TAG).error("无法获取请求路径")
                    return
                query_params = parse_qs(urlparse(request_path).query)
                if "device-id" not in query_params:
                    self.logger.bind(tag=TAG).error(
                        "无法从请求头和URL查询参数中获取device-id"
                    )
                    return
                self.headers["device-id"] = query_params["device-id"][0]
                # client-id is optional: consumers fall back to device-id, so a
                # missing value must not raise KeyError here.
                if "client-id" in query_params:
                    self.headers["client-id"] = query_params["client-id"][0]

            self.client_ip = ws.remote_address[0]
            self.logger.bind(tag=TAG).info(
                f"{self.client_ip} conn - Headers: {self.headers}"
            )

            # Failure is reported via AuthenticationError (handled below).
            await self.auth.authenticate(self.headers)

            self.websocket = ws
            self.session_id = str(uuid.uuid4())

            # Start the idle-timeout watchdog.
            self.timeout_task = asyncio.create_task(self._check_timeout())

            self.welcome_msg = self.config["xiaozhi"]
            self.welcome_msg["session_id"] = self.session_id
            await self.websocket.send(json.dumps(self.welcome_msg))

            # Per-device configuration. NOTE(review): this call is synchronous
            # and may do network I/O on the event loop thread.
            private_config = self._initialize_private_config()
            # Finish initialisation in the background.
            self.executor.submit(self._initialize_components, private_config)

            # TTS consumer thread
            self.tts_priority_thread = threading.Thread(
                target=self._tts_priority_thread, daemon=True
            )
            self.tts_priority_thread.start()

            # Audio playback consumer thread
            self.audio_play_priority_thread = threading.Thread(
                target=self._audio_play_priority_thread, daemon=True
            )
            self.audio_play_priority_thread.start()

            try:
                async for message in self.websocket:
                    await self._route_message(message)
            except websockets.exceptions.ConnectionClosed:
                self.logger.bind(tag=TAG).info("客户端断开连接")

        except AuthenticationError as e:
            self.logger.bind(tag=TAG).error(f"Authentication failed: {str(e)}")
        except Exception as e:
            stack_trace = traceback.format_exc()
            self.logger.bind(tag=TAG).error(f"Connection error: {str(e)}-{stack_trace}")
        finally:
            await self._save_and_close(ws)

    async def _save_and_close(self, ws):
        """Persist the dialogue to memory, then close the connection regardless."""
        try:
            await self.memory.save_memory(self.dialogue.dialogue)
        except Exception as e:
            self.logger.bind(tag=TAG).error(f"保存记忆失败: {e}")
        finally:
            await self.close(ws)

    async def _route_message(self, message):
        """Restart the idle timer and dispatch text/binary frames to their handlers."""
        if self.timeout_task:
            self.timeout_task.cancel()
        self.timeout_task = asyncio.create_task(self._check_timeout())

        if isinstance(message, str):
            await handleTextMessage(self, message)
        elif isinstance(message, bytes):
            await handleAudioMessage(self, message)

    def _initialize_components(self, private_config):
        """Initialise models, memory, intent handling and location context.

        Runs on the executor, and the future returned by ``submit`` is never
        inspected. Any exception would therefore disappear silently, so it is
        logged here instead.
        """
        try:
            if private_config is not None:
                self._initialize_models(private_config)
            else:
                self.prompt = self.config["prompt"]
                self.change_system_prompt(self.prompt)
            self._initialize_memory()
            self._initialize_intent()
            self._initialize_location()
        except Exception as e:
            self.logger.bind(tag=TAG).error(
                f"初始化组件失败: {e}\n{traceback.format_exc()}"
            )

    def _initialize_location(self):
        """Look up the client IP's location and append it to the system prompt."""
        self.client_ip_info = get_ip_info(self.client_ip, self.logger)
        if self.client_ip_info is not None and "city" in self.client_ip_info:
            self.logger.bind(tag=TAG).info(f"Client ip info: {self.client_ip_info}")
            # The prompt can still be None when no prompt was applied (e.g. the
            # per-device config lookup failed); concatenating would raise.
            if self.prompt is not None:
                self.prompt = self.prompt + f"\nuser location:{self.client_ip_info}"
                self.dialogue.update_system_message(self.prompt)

    def _initialize_private_config(self):
        """Fetch per-device overrides from the management API and apply the TTS part.

        Returns:
            None when ``read_config_from_api`` is off (the caller then uses the
            file config). Otherwise a dict, empty on lookup failure, whose
            ``prompt`` entry is set to None; the caller passes it on to
            ``_initialize_models``.
        """
        read_config_from_api = self.config.get("read_config_from_api", False)
        # Config comes from the local file: nothing to re-instantiate.
        if not read_config_from_api:
            return None

        # Re-instantiate only the modules the API overrides, not everything.
        try:
            begin_time = time.time()
            private_config = get_private_config_from_api(
                self.config,
                self.headers.get("device-id"),
                self.headers.get("client-id", self.headers.get("device-id")),
            )
            private_config["delete_audio"] = bool(self.config.get("delete_audio", True))
            self.logger.bind(tag=TAG).info(
                f"{time.time() - begin_time} 秒,获取差异化配置成功: {private_config}"
            )
        except DeviceNotFoundException:
            self.need_bind = True
            private_config = {}
        except DeviceBindException as e:
            self.need_bind = True
            self.bind_code = e.bind_code
            private_config = {}
        except Exception as e:
            self.need_bind = True
            self.logger.bind(tag=TAG).error(f"获取差异化配置失败: {e}")
            private_config = {}

        init_tts = False
        if private_config.get("TTS", None) is not None:
            init_tts = True
            self.config["TTS"] = private_config["TTS"]
            self.config["selected_module"]["TTS"] = private_config["selected_module"][
                "TTS"
            ]
        try:
            modules = initialize_modules(
                self.logger,
                private_config,
                False,  # VAD
                False,  # ASR
                False,  # LLM
                init_tts,  # TTS
                False,  # Memory
                False,  # Intent
            )
        except Exception as e:
            self.logger.bind(tag=TAG).error(f"初始化组件失败: {e}")
            modules = {}
        if modules.get("tts", None) is not None:
            self.tts = modules["tts"]
        if modules.get("prompt", None) is not None:
            self.change_system_prompt(modules["prompt"])
        # Presumably cleared so the later _initialize_models pass does not
        # re-apply the prompt handled above.
        private_config["prompt"] = None
        return private_config

    def _initialize_models(self, private_config):
        """Re-instantiate VAD/ASR/LLM/Memory/Intent wherever the device config overrides them."""
        module_names = ("VAD", "ASR", "LLM", "Memory", "Intent")
        overridden = {
            name: private_config.get(name, None) is not None for name in module_names
        }
        for name in module_names:
            if overridden[name]:
                self.config[name] = private_config[name]
                self.config["selected_module"][name] = private_config[
                    "selected_module"
                ][name]

        if private_config.get("device_max_output_size", None) is not None:
            self.max_output_size = int(private_config["device_max_output_size"])

        try:
            modules = initialize_modules(
                self.logger,
                private_config,
                overridden["VAD"],
                overridden["ASR"],
                overridden["LLM"],
                False,  # TTS was already handled by _initialize_private_config
                overridden["Memory"],
                overridden["Intent"],
            )
        except Exception as e:
            self.logger.bind(tag=TAG).error(f"初始化组件失败: {e}")
            modules = {}

        for attr in ("vad", "asr", "llm", "intent", "memory"):
            if modules.get(attr, None) is not None:
                setattr(self, attr, modules[attr])

    def _initialize_memory(self):
        """Bind the memory module to this device and the current main LLM."""
        device_id = self.headers.get("device-id", None)
        self.memory.init_memory(device_id, self.llm)

    def _initialize_intent(self):
        """Configure intent recognition, then load plugin functions and MCP tools.

        "nointent" returns early, so no function handler or MCP manager is
        created. "intent_llm" gives the intent module either a dedicated LLM
        or the main one. "function_call" switches the chat path to native
        function calling.
        """
        selected_intent = self.config["selected_module"]["Intent"]
        intent_config = self.config["Intent"][selected_intent]
        intent_type = intent_config["type"]

        if intent_type == "function_call":
            self.use_function_call_mode = True

        if intent_type == "nointent":
            return

        if intent_type == "intent_llm":
            intent_llm_name = intent_config.get("llm")
            if intent_llm_name and intent_llm_name in self.config["LLM"]:
                # A dedicated LLM is configured: build a separate instance.
                from core.utils import llm as llm_utils

                intent_llm_config = self.config["LLM"][intent_llm_name]
                intent_llm_type = intent_llm_config.get("type", intent_llm_name)
                intent_llm = llm_utils.create_instance(
                    intent_llm_type, intent_llm_config
                )
                self.logger.bind(tag=TAG).info(
                    f"为意图识别创建了专用LLM: {intent_llm_name}, 类型: {intent_llm_type}"
                )
                self.intent.set_llm(intent_llm)
            else:
                # Otherwise reuse the main LLM.
                self.intent.set_llm(self.llm)
                self.logger.bind(tag=TAG).info("使用主LLM作为意图识别模型")

        # Plugin functions
        self.func_handler = FunctionHandler(self)
        self.mcp_manager = MCPManager(self)
        # Start MCP servers on the event loop (this method runs on a worker thread).
        asyncio.run_coroutine_threadsafe(
            self.mcp_manager.initialize_servers(), self.loop
        )

    def change_system_prompt(self, prompt):
        """Replace the system prompt and push it into the dialogue context."""
        self.prompt = prompt
        self.dialogue.update_system_message(self.prompt)

    @classmethod
    def _find_last_punctuation(cls, text):
        """Return the index of the last sentence-ending punctuation in text, or -1."""
        return max(text.rfind(punct) for punct in cls._SEGMENT_PUNCTUATIONS)

    @staticmethod
    def _unpack_llm_response(response):
        """Normalise one streamed LLM item to ``(content, tools_call)``.

        Dict items carry text only, under "content". Anything else is
        unpacked as a ``(content, tools_call)`` pair. Checking the type first
        matters: the old ``"content" in response`` test on a tuple compared
        elements, so a streamed token equal to "content" crashed.
        """
        if isinstance(response, dict):
            return response.get("content"), None
        content, tools_call = response
        return content, tools_call

    def _queue_tts_segment(self, text, text_index):
        """Record the segment index and submit TTS for it, preserving playback order.

        Does nothing once ``close()`` has shut the executor down.
        """
        executor = self.executor  # read once: close() may set it to None concurrently
        if executor is None:
            self.logger.bind(tag=TAG).debug(f"executor closed, dropping TTS segment: {text}")
            return
        self.recode_first_last_text(text, text_index)
        try:
            future = executor.submit(self.speak_and_play, text, text_index)
        except RuntimeError:
            # Raised by submit() after shutdown.
            self.logger.bind(tag=TAG).debug(f"executor closed, dropping TTS segment: {text}")
            return
        self.tts_queue.put(future)

    def _segment_streamed_text(self, response_message, processed_chars, text_index):
        """Send every complete sentence not yet spoken to TTS.

        Args:
            response_message: Text chunks streamed so far.
            processed_chars: Length of the prefix already consumed.
            text_index: Index of the last submitted segment.

        Returns:
            The updated ``(processed_chars, text_index)``.
        """
        current_text = "".join(response_message)[processed_chars:]
        last_punct_pos = self._find_last_punctuation(current_text)
        if last_punct_pos == -1:
            return processed_chars, text_index

        segment_text_raw = current_text[: last_punct_pos + 1]
        segment_text = get_string_no_punctuation_or_emoji(segment_text_raw)
        if segment_text:
            text_index += 1
            self._queue_tts_segment(segment_text, text_index)
        # Consume the raw span even if nothing speakable remained after cleaning.
        processed_chars += len(segment_text_raw)
        return processed_chars, text_index

    def _flush_remaining_text(self, response_message, processed_chars, text_index):
        """Send any trailing unsegmented text to TTS; return the updated text_index."""
        remaining_text = "".join(response_message)[processed_chars:]
        if remaining_text:
            segment_text = get_string_no_punctuation_or_emoji(remaining_text)
            if segment_text:
                text_index += 1
                self._queue_tts_segment(segment_text, text_index)
        return text_index

    def chat(self, query):
        """Stream a plain LLM reply for ``query`` and speak it sentence by sentence.

        Returns:
            True when the reply was processed, None if the LLM request failed.
        """
        self.dialogue.put(Message(role="user", content=query))

        response_message = []
        processed_chars = 0  # prefix of the reply already sent to TTS
        try:
            # Query long-term memory on the event loop, then start streaming.
            future = asyncio.run_coroutine_threadsafe(
                self.memory.query_memory(query), self.loop
            )
            memory_str = future.result()
            self.logger.bind(tag=TAG).debug(f"记忆内容: {memory_str}")

            llm_responses = self.llm.response(
                self.session_id, self.dialogue.get_llm_dialogue_with_memory(memory_str)
            )
        except Exception as e:
            self.logger.bind(tag=TAG).error(f"LLM 处理出错 {query}: {e}")
            return None

        self.llm_finish_task = False
        text_index = 0
        for content in llm_responses:
            response_message.append(content)
            if self.client_abort:
                break
            processed_chars, text_index = self._segment_streamed_text(
                response_message, processed_chars, text_index
            )

        text_index = self._flush_remaining_text(
            response_message, processed_chars, text_index
        )

        self.llm_finish_task = True
        self.dialogue.put(Message(role="assistant", content="".join(response_message)))
        self.logger.bind(tag=TAG).debug(
            json.dumps(self.dialogue.get_llm_dialogue(), indent=4, ensure_ascii=False)
        )
        return True

    def chat_with_function_calling(self, query, tool_call=False):
        """Stream an LLM reply with function-calling support.

        Plain text is segmented and spoken as it streams. A tool call, either
        native (``tools_call`` deltas) or a ``<tool_call>`` block in the text,
        is accumulated and executed after the stream ends (MCP tool or local
        function). Its result is handled by ``_handle_function_result``.

        Args:
            query: User text, or the tool result text on a follow-up turn.
            tool_call: True on a follow-up turn after a tool ran. The query is
                then not added to the dialogue as a user message.

        Returns:
            True when the reply was processed, None if the LLM request failed.
        """
        self.logger.bind(tag=TAG).debug(f"Chat with function calling start: {query}")
        if not tool_call:
            self.dialogue.put(Message(role="user", content=query))

        # Tool definitions offered to the LLM. func_handler stays None until
        # _initialize_intent has run (and always in "nointent" mode). Test the
        # value: hasattr() is always true because __init__ sets it to None.
        functions = None
        if self.func_handler is not None:
            functions = self.func_handler.get_functions()

        response_message = []
        processed_chars = 0  # prefix of the reply already sent to TTS
        try:
            # Query long-term memory on the event loop, then start streaming.
            future = asyncio.run_coroutine_threadsafe(
                self.memory.query_memory(query), self.loop
            )
            memory_str = future.result()
            llm_responses = self.llm.response_with_functions(
                self.session_id,
                self.dialogue.get_llm_dialogue_with_memory(memory_str),
                functions=functions,
            )
        except Exception as e:
            self.logger.bind(tag=TAG).error(f"LLM 处理出错 {query}: {e}")
            return None

        self.llm_finish_task = False
        text_index = 0

        # Tool-call accumulation state
        tool_call_flag = False
        function_name = None
        function_id = None
        function_arguments = ""
        content_arguments = ""  # all text deltas, parsed later for <tool_call> text

        for response in llm_responses:
            content, tools_call = self._unpack_llm_response(response)
            has_content = content is not None and len(content) > 0

            if has_content:
                content_arguments += content
            if not tool_call_flag and content_arguments.startswith("<tool_call>"):
                tool_call_flag = True

            if tools_call is not None:
                tool_call_flag = True
                call = tools_call[0]
                if call.id is not None:
                    function_id = call.id
                if call.function.name is not None:
                    function_name = call.function.name
                if call.function.arguments is not None:
                    function_arguments += call.function.arguments

            if has_content and not tool_call_flag:
                response_message.append(content)
                if self.client_abort:
                    break
                processed_chars, text_index = self._segment_streamed_text(
                    response_message, processed_chars, text_index
                )

        if tool_call_flag:
            has_error = False
            if function_id is None:
                # Text-based tool call: expect {"name": ..., "arguments": ...}.
                json_text = extract_json_from_string(content_arguments)
                if json_text is not None:
                    try:
                        parsed = json.loads(json_text)
                        function_name = parsed["name"]
                        function_arguments = json.dumps(
                            parsed["arguments"], ensure_ascii=False
                        )
                        function_id = str(uuid.uuid4().hex)
                    except Exception:
                        has_error = True
                        response_message.append(json_text)
                else:
                    has_error = True
                    response_message.append(content_arguments)
                if has_error:
                    self.logger.bind(tag=TAG).error(
                        f"function call error: {content_arguments}"
                    )
            if not has_error:
                response_message.clear()
                self.logger.bind(tag=TAG).debug(
                    f"function_name={function_name}, function_id={function_id}, function_arguments={function_arguments}"
                )
                function_call_data = {
                    "name": function_name,
                    "id": function_id,
                    "arguments": function_arguments,
                }
                # mcp_manager only exists once _initialize_intent has created it.
                mcp_manager = getattr(self, "mcp_manager", None)
                if mcp_manager is not None and mcp_manager.is_mcp_tool(function_name):
                    result = self._handle_mcp_tool_call(function_call_data)
                else:
                    result = self.func_handler.handle_llm_function_call(
                        self, function_call_data
                    )
                self._handle_function_result(result, function_call_data, text_index + 1)

        text_index = self._flush_remaining_text(
            response_message, processed_chars, text_index
        )

        # Store the assistant's text reply, if any.
        if len(response_message) > 0:
            self.dialogue.put(
                Message(role="assistant", content="".join(response_message))
            )
        self.llm_finish_task = True
        self.logger.bind(tag=TAG).debug(
            json.dumps(self.dialogue.get_llm_dialogue(), indent=4, ensure_ascii=False)
        )
        return True

    def _handle_mcp_tool_call(self, function_call_data):
        """Execute an MCP tool on the event loop and wrap its text output.

        Always returns an ``Action.REQLLM`` response, so the LLM phrases the
        answer. On failure the result carries an error message instead.
        """
        function_arguments = function_call_data["arguments"]
        function_name = function_call_data["name"]
        try:
            args_dict = function_arguments
            if isinstance(function_arguments, str):
                try:
                    args_dict = json.loads(function_arguments)
                except json.JSONDecodeError:
                    self.logger.bind(tag=TAG).error(
                        f"无法解析 function_arguments: {function_arguments}"
                    )
                    return ActionResponse(
                        action=Action.REQLLM, result="参数解析失败", response=""
                    )

            tool_result = asyncio.run_coroutine_threadsafe(
                self.mcp_manager.execute_tool(function_name, args_dict), self.loop
            ).result()
            # Example result:
            # meta=None content=[TextContent(type='text', text='...', annotations=None)] isError=False
            content_text = ""
            if tool_result is not None and tool_result.content is not None:
                for content in tool_result.content:
                    if content.type == "text":
                        # NOTE(review): only the last text item is kept.
                        content_text = content.text
                    # "image" items are ignored.

            if len(content_text) > 0:
                return ActionResponse(
                    action=Action.REQLLM, result=content_text, response=""
                )
        except Exception as e:
            self.logger.bind(tag=TAG).error(f"MCP工具调用错误: {e}")
            return ActionResponse(
                action=Action.REQLLM, result="工具调用出错", response=""
            )
        return ActionResponse(action=Action.REQLLM, result="工具调用出错", response="")

    def _handle_function_result(self, result, function_call_data, text_index):
        """Act on a tool/function result.

        RESPONSE: speak ``result.response`` directly. REQLLM: record the call
        and its output in the dialogue, then ask the LLM again. NOTFOUND /
        ERROR: speak ``result.result``.
        """
        if result.action == Action.RESPONSE:
            text = result.response
            self._queue_tts_segment(text, text_index)
            self.dialogue.put(Message(role="assistant", content=text))
        elif result.action == Action.REQLLM:
            text = result.result
            if text is not None and len(text) > 0:
                function_id = function_call_data["id"]
                function_name = function_call_data["name"]
                function_arguments = function_call_data["arguments"]
                self.dialogue.put(
                    Message(
                        role="assistant",
                        tool_calls=[
                            {
                                "id": function_id,
                                "function": {
                                    "arguments": function_arguments,
                                    "name": function_name,
                                },
                                "type": "function",
                                "index": 0,
                            }
                        ],
                    )
                )
                self.dialogue.put(
                    Message(role="tool", tool_call_id=function_id, content=text)
                )
                self.chat_with_function_calling(text, tool_call=True)
        elif result.action == Action.NOTFOUND or result.action == Action.ERROR:
            text = result.result
            self._queue_tts_segment(text, text_index)
            self.dialogue.put(Message(role="assistant", content=text))

    def _tts_priority_thread(self):
        """Worker thread: resolve queued TTS futures and hand the audio to playback.

        A single thread drains the FIFO ``tts_queue``, so segments reach the
        playback queue in submission order even though TTS itself runs
        concurrently on the executor.
        """
        while not self.stop_event.is_set():
            text = None
            try:
                try:
                    future = self.tts_queue.get(timeout=1)
                except queue.Empty:
                    if self.stop_event.is_set():
                        break
                    continue
                if future is None:
                    continue

                opus_datas, text_index, tts_file = [], 0, None
                try:
                    self.logger.bind(tag=TAG).debug("正在处理TTS任务...")
                    tts_timeout = int(self.config.get("tts_timeout", 10))
                    tts_file, text, text_index = future.result(timeout=tts_timeout)
                    if text is None or len(text) <= 0:
                        self.logger.bind(tag=TAG).error(
                            f"TTS出错{text_index}: tts text is empty"
                        )
                    elif tts_file is None:
                        self.logger.bind(tag=TAG).error(
                            f"TTS出错 file is empty: {text_index}: {text}"
                        )
                    else:
                        self.logger.bind(tag=TAG).debug(f"TTS生成文件路径: {tts_file}")
                        if os.path.exists(tts_file):
                            opus_datas, _duration = self.tts.audio_to_opus_data(tts_file)
                        else:
                            self.logger.bind(tag=TAG).error(
                                f"TTS出错文件不存在{tts_file}"
                            )
                except TimeoutError:
                    self.logger.bind(tag=TAG).error("TTS超时")
                except Exception as e:
                    self.logger.bind(tag=TAG).error(f"TTS出错: {e}")

                if not self.client_abort:
                    # Only queue playback when the client has not interrupted.
                    self.audio_play_queue.put((opus_datas, text, text_index))
                if (
                    self.tts.delete_audio_file
                    and tts_file is not None
                    and os.path.exists(tts_file)
                ):
                    os.remove(tts_file)
            except Exception as e:
                self.logger.bind(tag=TAG).error(f"TTS任务处理错误: {e}")
                self.clearSpeakStatus()
                asyncio.run_coroutine_threadsafe(
                    self.websocket.send(
                        json.dumps(
                            {
                                "type": "tts",
                                "state": "stop",
                                "session_id": self.session_id,
                            }
                        )
                    ),
                    self.loop,
                )
                self.logger.bind(tag=TAG).error(
                    f"tts_priority priority_thread: {text} {e}"
                )

    def _audio_play_priority_thread(self):
        """Worker thread: send queued opus audio to the client via the event loop."""
        while not self.stop_event.is_set():
            text = None
            try:
                try:
                    opus_datas, text, text_index = self.audio_play_queue.get(timeout=1)
                except queue.Empty:
                    if self.stop_event.is_set():
                        break
                    continue
                future = asyncio.run_coroutine_threadsafe(
                    sendAudioMessage(self, opus_datas, text, text_index), self.loop
                )
                future.result()
            except Exception as e:
                self.logger.bind(tag=TAG).error(
                    f"audio_play_priority priority_thread: {text} {e}"
                )

    def speak_and_play(self, text, text_index=0):
        """Synthesise ``text`` to an audio file (runs on the executor).

        Returns:
            ``(tts_file, text, text_index)``; ``tts_file`` is None when the
            text is empty or synthesis failed.
        """
        if text is None or len(text) <= 0:
            self.logger.bind(tag=TAG).info(f"无需tts转换query为空{text}")
            return None, text, text_index
        tts_file = self.tts.to_tts(text)
        if tts_file is None:
            self.logger.bind(tag=TAG).error(f"tts转换失败{text}")
            return None, text, text_index
        self.logger.bind(tag=TAG).debug(f"TTS 文件生成完毕: {tts_file}")
        if self.max_output_size > 0:
            add_device_output(self.headers.get("device-id"), len(text))
        return tts_file, text, text_index

    def clearSpeakStatus(self):
        """Reset server-side speaking state so ASR input is accepted again."""
        self.logger.bind(tag=TAG).debug("清除服务端讲话状态")
        self.asr_server_receive = True
        self.tts_last_text_index = -1
        self.tts_first_text_index = -1

    def recode_first_last_text(self, text, text_index=0):
        """Track the first and latest segment index of the current reply."""
        if self.tts_first_text_index == -1:
            self.logger.bind(tag=TAG).info(f"大模型说出第一句话: {text}")
            self.tts_first_text_index = text_index
        self.tts_last_text_index = text_index

    async def close(self, ws=None):
        """Release per-connection resources and close the socket.

        May run more than once: from the timeout watchdog, and again from
        handle_connection's ``finally`` clause.
        """
        # Cancel the idle-timeout watchdog, but not when close() is running
        # inside that very task. Cancelling ourselves would raise
        # CancelledError at the next await and abort the cleanup below.
        if self.timeout_task:
            if self.timeout_task is not asyncio.current_task():
                self.timeout_task.cancel()
            self.timeout_task = None

        # MCP resources
        mcp_manager = getattr(self, "mcp_manager", None)
        if mcp_manager:
            await mcp_manager.cleanup_all()

        # Signal the worker threads to stop.
        if self.stop_event:
            self.stop_event.set()

        # Shut the executor down without waiting; queued TTS jobs are cancelled.
        if self.executor:
            self.executor.shutdown(wait=False, cancel_futures=True)
            self.executor = None

        self._clear_queues()

        if ws:
            await ws.close()
        elif self.websocket:
            await self.websocket.close()
        self.logger.bind(tag=TAG).info("连接资源已释放")

    def _clear_queues(self):
        """Drop every pending item from the TTS and audio-playback queues."""
        for q in (self.tts_queue, self.audio_play_queue):
            if not q:
                continue
            while not q.empty():
                try:
                    q.get_nowait()
                except queue.Empty:
                    continue
            q.queue.clear()
            # NOTE: no sentinel is enqueued. Worker threads exit via stop_event
            # and their 1 s get() timeout.

    def reset_vad_states(self):
        """Clear buffered audio and voice-activity flags."""
        self.client_audio_buffer = bytearray()
        self.client_have_voice = False
        self.client_have_voice_last_time = 0.0
        self.client_voice_stop = False
        self.logger.bind(tag=TAG).debug("VAD states reset.")

    def chat_and_close(self, text):
        """Chat with the user, then flag the connection to close afterwards."""
        try:
            self.chat(text)
            self.close_after_chat = True
        except Exception as e:
            self.logger.bind(tag=TAG).error(f"Chat and close error: {str(e)}")

    async def _check_timeout(self):
        """Close the connection after ``timeout_seconds`` without a new message.

        ``_route_message`` replaces this task on every message, which restarts
        the countdown.
        """
        try:
            while not self.stop_event.is_set():
                await asyncio.sleep(self.timeout_seconds)
                if not self.stop_event.is_set():
                    self.logger.bind(tag=TAG).info("连接超时,准备关闭")
                    await self.close(self.websocket)
                    break
        except Exception as e:
            self.logger.bind(tag=TAG).error(f"超时检查任务出错: {e}")