# Source file: QingLong/XiaoZhi/xiaozhi-esp32-server/main/xiaozhi-server/core/connection.py
# Repository view as of 2025-08-15 09:13:13 +08:00 (939 lines, 36 KiB, Python).
#
# Viewer warning: this file contains ambiguous Unicode characters, i.e.
# characters that might be confused with other characters (for example
# full-width CJK punctuation versus its ASCII look-alikes). If this is
# intentional, the warning can safely be ignored.

import os
import copy
import json
import uuid
import time
import queue
import asyncio
import traceback
import threading
import websockets
from typing import Dict, Any
from plugins_func.loadplugins import auto_import_modules
from config.logger import setup_logging
from core.utils.dialogue import Message, Dialogue
from core.handle.textHandle import handleTextMessage
from core.utils.util import (
get_string_no_punctuation_or_emoji,
extract_json_from_string,
get_ip_info,
initialize_modules,
)
from concurrent.futures import ThreadPoolExecutor, TimeoutError
from core.handle.sendAudioHandle import sendAudioMessage
from core.handle.receiveAudioHandle import handleAudioMessage
from core.handle.functionHandler import FunctionHandler
from plugins_func.register import Action, ActionResponse
from core.auth import AuthMiddleware, AuthenticationError
from core.mcp.manager import MCPManager
from config.config_loader import get_private_config_from_api
from config.manage_api_client import DeviceNotFoundException, DeviceBindException
from core.utils.output_counter import add_device_output
# Tag passed to ``logger.bind(tag=...)`` by every log call in this module.
TAG = __name__
# Import every module under ``plugins_func.functions`` as a side effect of
# importing this module. NOTE(review): presumably so those modules register
# their plugin functions for FunctionHandler — confirm in plugins_func.
auto_import_modules("plugins_func.functions")
class TTSException(RuntimeError):
    """Error type for text-to-speech failures.

    NOTE(review): not raised anywhere in this module's visible code;
    presumably raised by TTS providers or their callers — confirm.
    """
class ConnectionHandler:
    """Per-client WebSocket session of the xiaozhi server.

    One instance serves one device connection: it authenticates the client,
    optionally pulls a per-device ("private") configuration from the
    management API, wires up the VAD/ASR/LLM/TTS/memory/intent components,
    routes incoming text/audio frames to their handlers, and runs two worker
    threads (TTS result collection and audio playback to the client).
    """

    # Full-width (CJK) punctuation at which streamed LLM output is cut into
    # TTS segments. Written as escapes so the characters cannot be silently
    # lost or confused with ASCII look-alikes:
    # U+3002 "。", U+FF1F "？", U+FF01 "！", U+FF1B "；", U+FF1A "：".
    TTS_SEGMENT_PUNCTUATIONS = ("\u3002", "\uff1f", "\uff01", "\uff1b", "\uff1a")

    def __init__(
        self, config: Dict[str, Any], _vad, _asr, _llm, _tts, _memory, _intent
    ):
        """Create the handler around the server's shared default components.

        Args:
            config: Server configuration. It is deep-copied so per-connection
                changes (private config, welcome message) stay local.
            _vad, _asr, _llm, _tts, _memory, _intent: Default component
                instances; individual ones may be replaced per device during
                private-config initialisation.
        """
        self.config = copy.deepcopy(config)
        self.logger = setup_logging()
        self.auth = AuthMiddleware(config)
        self.need_bind = False
        self.bind_code = None
        self.websocket = None
        self.headers = None
        self.client_ip = None
        self.client_ip_info = {}
        self.session_id = None
        self.prompt = None
        self.welcome_msg = None
        self.max_output_size = 0

        # Client state
        self.client_abort = False
        self.client_listen_mode = "auto"

        # Threads / tasks
        self.loop = asyncio.get_event_loop()
        self.stop_event = threading.Event()
        self.tts_queue = queue.Queue()
        self.audio_play_queue = queue.Queue()
        self.executor = ThreadPoolExecutor(max_workers=10)

        # Components
        self.vad = _vad
        self.asr = _asr
        self.llm = _llm
        self.tts = _tts
        self.memory = _memory
        self.intent = _intent

        # VAD state
        self.client_audio_buffer = bytearray()
        self.client_have_voice = False
        self.client_have_voice_last_time = 0.0
        self.client_no_voice_last_time = 0.0
        self.client_voice_stop = False

        # ASR state
        self.asr_audio = []
        self.asr_server_receive = True

        # LLM state
        self.llm_finish_task = False
        self.dialogue = Dialogue()

        # TTS state: indices of the first/last segment of the current reply
        self.tts_first_text_index = -1
        self.tts_last_text_index = -1

        # IoT / function calling
        self.iot_descriptors = {}
        self.func_handler = None
        self.cmd_exit = self.config["exit_commands"]
        self.max_cmd_length = max((len(cmd) for cmd in self.cmd_exit), default=0)

        self.close_after_chat = False  # close the connection once the chat ends
        self.use_function_call_mode = False
        self.timeout_task = None
        # Second-stage close: 60 s on top of the first-stage
        # "close_connection_no_voice_time" (default 120 s).
        self.timeout_seconds = (
            int(self.config.get("close_connection_no_voice_time", 120)) + 60
        )

    async def handle_connection(self, ws):
        """Serve one WebSocket client from handshake to disconnect.

        Resolves ``device-id`` from the headers or, failing that, from the URL
        query string; authenticates; sends the welcome message; loads the
        per-device config; starts component initialisation and the TTS/audio
        worker threads; then routes every incoming message. Memory is saved
        and resources released in ``finally`` whatever the outcome.

        Args:
            ws: Accepted websocket connection; ``ws.request.headers``,
                ``ws.request.path`` and ``ws.remote_address`` are read.
        """
        try:
            self.headers = dict(ws.request.headers)
            if self.headers.get("device-id", None) is None:
                # Fall back to the device-id in the URL query string.
                from urllib.parse import parse_qs, urlparse

                request_path = ws.request.path
                if not request_path:
                    self.logger.bind(tag=TAG).error("无法获取请求路径")
                    return
                query_params = parse_qs(urlparse(request_path).query)
                if "device-id" not in query_params:
                    self.logger.bind(tag=TAG).error(
                        "无法从请求头和URL查询参数中获取device-id"
                    )
                    return
                self.headers["device-id"] = query_params["device-id"][0]
                # client-id is optional: consumers fall back to device-id
                # when it is absent, so a missing one must not abort the
                # connection with a KeyError.
                if "client-id" in query_params:
                    self.headers["client-id"] = query_params["client-id"][0]

            self.client_ip = ws.remote_address[0]
            self.logger.bind(tag=TAG).info(
                f"{self.client_ip} conn - Headers: {self.headers}"
            )

            # Raises AuthenticationError on failure (handled below).
            await self.auth.authenticate(self.headers)

            self.websocket = ws
            self.session_id = str(uuid.uuid4())

            # Start the inactivity watchdog.
            self.timeout_task = asyncio.create_task(self._check_timeout())

            self.welcome_msg = self.config["xiaozhi"]
            self.welcome_msg["session_id"] = self.session_id
            await self.websocket.send(json.dumps(self.welcome_msg))

            # Loading the private config makes synchronous calls (management
            # API request, TTS construction); run it in the executor so the
            # event loop keeps serving other connections meanwhile.
            private_config = await asyncio.get_running_loop().run_in_executor(
                self.executor, self._initialize_private_config
            )

            # Remaining initialisation runs in the background; log failures
            # instead of letting them vanish inside an unchecked future.
            init_future = self.executor.submit(
                self._initialize_components, private_config
            )
            init_future.add_done_callback(self._log_init_failure)

            # Worker that turns finished TTS jobs into audio frames.
            self.tts_priority_thread = threading.Thread(
                target=self._tts_priority_thread, daemon=True
            )
            self.tts_priority_thread.start()

            # Worker that streams audio frames back to the client.
            self.audio_play_priority_thread = threading.Thread(
                target=self._audio_play_priority_thread, daemon=True
            )
            self.audio_play_priority_thread.start()

            try:
                async for message in self.websocket:
                    await self._route_message(message)
            except websockets.exceptions.ConnectionClosed:
                self.logger.bind(tag=TAG).info("客户端断开连接")

        except AuthenticationError as e:
            self.logger.bind(tag=TAG).error(f"Authentication failed: {str(e)}")
            return
        except Exception as e:
            stack_trace = traceback.format_exc()
            self.logger.bind(tag=TAG).error(f"Connection error: {str(e)}-{stack_trace}")
            return
        finally:
            await self._save_and_close(ws)

    def _log_init_failure(self, future):
        """Done-callback: log an exception raised by ``_initialize_components``."""
        if future.cancelled():
            return
        exc = future.exception()
        if exc is not None:
            self.logger.bind(tag=TAG).error(f"初始化组件失败: {exc}")

    async def _save_and_close(self, ws):
        """Persist the dialogue to memory, then release the connection.

        ``close`` runs even if saving memory fails.
        """
        try:
            await self.memory.save_memory(self.dialogue.dialogue)
        except Exception as e:
            self.logger.bind(tag=TAG).error(f"保存记忆失败: {e}")
        finally:
            await self.close(ws)

    async def _route_message(self, message):
        """Restart the inactivity watchdog and dispatch *message* by type.

        ``str`` frames go to the text handler, ``bytes`` frames to the audio
        handler; anything else is ignored.
        """
        if self.timeout_task:
            self.timeout_task.cancel()
        self.timeout_task = asyncio.create_task(self._check_timeout())

        if isinstance(message, str):
            await handleTextMessage(self, message)
        elif isinstance(message, bytes):
            await handleAudioMessage(self, message)

    def _initialize_components(self, private_config):
        """Initialise models, prompt, memory, intent and location (executor thread).

        Args:
            private_config: Result of ``_initialize_private_config``; ``None``
                means the local config file is authoritative.
        """
        if private_config is not None:
            self._initialize_models(private_config)
        else:
            self.prompt = self.config["prompt"]
            self.change_system_prompt(self.prompt)
        if self.prompt is None:
            # The private config supplied no prompt (e.g. the API call failed
            # and an empty config was used). Fall back to the local prompt so
            # the location concatenation below cannot hit None.
            self.change_system_prompt(self.config.get("prompt", ""))

        self._initialize_memory()
        self._initialize_intent()

        # Append the client's approximate location to the system prompt.
        self.client_ip_info = get_ip_info(self.client_ip, self.logger)
        if self.client_ip_info is not None and "city" in self.client_ip_info:
            self.logger.bind(tag=TAG).info(f"Client ip info: {self.client_ip_info}")
            self.prompt = self.prompt + f"\nuser location:{self.client_ip_info}"

            self.dialogue.update_system_message(self.prompt)

    def _initialize_private_config(self):
        """Fetch the per-device config from the management API, if enabled.

        Only TTS (and the system prompt) are re-instantiated here; the other
        modules are handled later by ``_initialize_models``.

        Returns:
            ``None`` when ``read_config_from_api`` is off. Otherwise the
            private config dict (``{}`` when fetching failed, in which case
            ``need_bind`` is set). Its ``"prompt"`` entry is set to ``None``
            because any prompt returned by ``initialize_modules`` has already
            been applied here.
        """
        if not self.config.get("read_config_from_api", False):
            return None

        try:
            begin_time = time.time()
            private_config = get_private_config_from_api(
                self.config,
                self.headers.get("device-id"),
                self.headers.get("client-id", self.headers.get("device-id")),
            )
            private_config["delete_audio"] = bool(self.config.get("delete_audio", True))
            self.logger.bind(tag=TAG).info(
                f"{time.time() - begin_time} 秒,获取差异化配置成功: {private_config}"
            )
        except DeviceNotFoundException:
            self.need_bind = True
            private_config = {}
        except DeviceBindException as e:
            self.need_bind = True
            self.bind_code = e.bind_code
            private_config = {}
        except Exception as e:
            self.need_bind = True
            self.logger.bind(tag=TAG).error(f"获取差异化配置失败: {e}")
            private_config = {}

        init_tts = private_config.get("TTS", None) is not None
        if init_tts:
            self.config["TTS"] = private_config["TTS"]
            self.config["selected_module"]["TTS"] = private_config["selected_module"][
                "TTS"
            ]

        try:
            # Positional flags: init_vad, init_asr, init_llm, init_tts,
            # init_memory, init_intent (same order as in _initialize_models).
            modules = initialize_modules(
                self.logger,
                private_config,
                False,
                False,
                False,
                init_tts,
                False,
                False,
            )
        except Exception as e:
            self.logger.bind(tag=TAG).error(f"初始化组件失败: {e}")
            modules = {}

        if modules.get("tts", None) is not None:
            self.tts = modules["tts"]
        if modules.get("prompt", None) is not None:
            self.change_system_prompt(modules["prompt"])
        private_config["prompt"] = None
        return private_config

    def _initialize_models(self, private_config):
        """Re-instantiate the VAD/ASR/LLM/Memory/Intent modules overridden by *private_config*.

        TTS is excluded because ``_initialize_private_config`` already
        handled it. ``device_max_output_size`` is also applied here.
        """
        overridden = set()
        for module_name in ("VAD", "ASR", "LLM", "Memory", "Intent"):
            if private_config.get(module_name, None) is not None:
                overridden.add(module_name)
                self.config[module_name] = private_config[module_name]
                self.config["selected_module"][module_name] = private_config[
                    "selected_module"
                ][module_name]

        if private_config.get("device_max_output_size", None) is not None:
            self.max_output_size = int(private_config["device_max_output_size"])

        try:
            modules = initialize_modules(
                self.logger,
                private_config,
                "VAD" in overridden,
                "ASR" in overridden,
                "LLM" in overridden,
                False,  # TTS: handled in _initialize_private_config
                "Memory" in overridden,
                "Intent" in overridden,
            )
        except Exception as e:
            self.logger.bind(tag=TAG).error(f"初始化组件失败: {e}")
            modules = {}

        for attr_name in ("vad", "asr", "llm", "intent", "memory"):
            if modules.get(attr_name, None) is not None:
                setattr(self, attr_name, modules[attr_name])

    def _initialize_memory(self):
        """Bind the memory module to this device and the current LLM."""
        device_id = self.headers.get("device-id", None)
        self.memory.init_memory(device_id, self.llm)

    def _initialize_intent(self):
        """Configure intent recognition and load plugin/MCP tooling.

        - ``function_call``: enables ``use_function_call_mode``.
        - ``nointent``: returns early; no FunctionHandler/MCPManager is made.
        - ``intent_llm``: gives the intent module a dedicated LLM when its
          configured ``llm`` name exists under ``LLM``, else the main LLM.
        """
        intent_config = self.config["Intent"]
        selected_intent = self.config["selected_module"]["Intent"]
        intent_type = intent_config[selected_intent]["type"]

        if intent_type == "function_call":
            self.use_function_call_mode = True

        if intent_type == "nointent":
            return

        if intent_type == "intent_llm":
            intent_llm_name = intent_config[selected_intent].get("llm")
            if intent_llm_name and intent_llm_name in self.config["LLM"]:
                # A dedicated LLM is configured: create a separate instance.
                from core.utils import llm as llm_utils

                intent_llm_config = self.config["LLM"][intent_llm_name]
                intent_llm_type = intent_llm_config.get("type", intent_llm_name)
                intent_llm = llm_utils.create_instance(
                    intent_llm_type, intent_llm_config
                )
                self.logger.bind(tag=TAG).info(
                    f"为意图识别创建了专用LLM: {intent_llm_name}, 类型: {intent_llm_type}"
                )
                self.intent.set_llm(intent_llm)
            else:
                # Otherwise share the main LLM.
                self.intent.set_llm(self.llm)
                self.logger.bind(tag=TAG).info("使用主LLM作为意图识别模型")

        # Load plugin functions and MCP tools.
        self.func_handler = FunctionHandler(self)
        self.mcp_manager = MCPManager(self)
        # Schedule MCP server start-up on the event loop (not awaited here).
        asyncio.run_coroutine_threadsafe(
            self.mcp_manager.initialize_servers(), self.loop
        )

    def change_system_prompt(self, prompt):
        """Replace the system prompt and push it into the dialogue context."""
        self.prompt = prompt
        self.dialogue.update_system_message(self.prompt)

    @classmethod
    def _find_last_punct_pos(cls, text):
        """Return the index of the last segmenting punctuation in *text*, or -1."""
        return max(
            (text.rfind(punct) for punct in cls.TTS_SEGMENT_PUNCTUATIONS),
            default=-1,
        )

    def _enqueue_tts(self, text, text_index):
        """Record *text_index* as the latest segment and queue its TTS job.

        Futures are put on ``tts_queue`` in submission order, which is the
        order ``_tts_priority_thread`` plays them in.
        """
        self.recode_first_last_text(text, text_index)
        future = self.executor.submit(self.speak_and_play, text, text_index)
        self.tts_queue.put(future)

    def _flush_complete_segments(self, response_message, processed_chars, text_index):
        """Send the not-yet-spoken text up to the last punctuation to TTS.

        Args:
            response_message: Chunks of the reply received so far.
            processed_chars: Number of characters of the reply already handled.
            text_index: Index of the last queued segment.

        Returns:
            The updated ``(processed_chars, text_index)`` pair. Fragments that
            are only punctuation/emoji are consumed without being spoken.
        """
        current_text = "".join(response_message)[processed_chars:]
        last_punct_pos = self._find_last_punct_pos(current_text)
        if last_punct_pos != -1:
            segment_text_raw = current_text[: last_punct_pos + 1]
            segment_text = get_string_no_punctuation_or_emoji(segment_text_raw)
            if segment_text:
                text_index += 1
                self._enqueue_tts(segment_text, text_index)
            processed_chars += len(segment_text_raw)
        return processed_chars, text_index

    def _flush_remaining_text(self, response_message, processed_chars, text_index):
        """Send whatever follows the last processed character to TTS.

        Returns the updated ``text_index``.
        """
        remaining_text = "".join(response_message)[processed_chars:]
        if remaining_text:
            segment_text = get_string_no_punctuation_or_emoji(remaining_text)
            if segment_text:
                text_index += 1
                self._enqueue_tts(segment_text, text_index)
        return text_index

    def chat(self, query):
        """Answer *query* via the plain (no function calling) LLM path.

        Streams the reply, queueing TTS for each punctuated segment as it
        completes, then stores the full reply in the dialogue.

        Returns:
            True when the stream was consumed; None if the memory query or
            the LLM request failed.
        """
        self.dialogue.put(Message(role="user", content=query))

        response_message = []
        processed_chars = 0  # characters of the reply already segmented
        try:
            # Blocks this worker thread until the memory lookup completes on the loop.
            future = asyncio.run_coroutine_threadsafe(
                self.memory.query_memory(query), self.loop
            )
            memory_str = future.result()
            self.logger.bind(tag=TAG).debug(f"记忆内容: {memory_str}")
            llm_responses = self.llm.response(
                self.session_id, self.dialogue.get_llm_dialogue_with_memory(memory_str)
            )
        except Exception as e:
            self.logger.bind(tag=TAG).error(f"LLM 处理出错 {query}: {e}")
            return None

        self.llm_finish_task = False
        text_index = 0
        for content in llm_responses:
            response_message.append(content)
            if self.client_abort:
                break
            processed_chars, text_index = self._flush_complete_segments(
                response_message, processed_chars, text_index
            )

        self._flush_remaining_text(response_message, processed_chars, text_index)
        self.llm_finish_task = True
        self.dialogue.put(Message(role="assistant", content="".join(response_message)))
        self.logger.bind(tag=TAG).debug(
            json.dumps(self.dialogue.get_llm_dialogue(), indent=4, ensure_ascii=False)
        )
        return True

    @staticmethod
    def _unpack_llm_response(response):
        """Normalise one streamed chunk into ``(content, tool_calls)``.

        A chunk is expected to be a ``(content, tool_calls)`` pair; a dict is
        also accepted, in which case its ``"content"`` is used and no tool
        calls are reported.
        """
        if isinstance(response, dict):
            return response.get("content"), None
        content, tools_call = response
        return content, tools_call

    def chat_with_function_calling(self, query, tool_call=False):
        """Answer *query* via the LLM with tool/function calling enabled.

        Streams the reply, speaking complete segments as they arrive. A tool
        call is detected from native ``tool_calls`` chunks or from content
        starting with ``<tool_call>`` (whose embedded JSON must then hold
        ``name`` and ``arguments``). The call goes to the MCP manager or the
        FunctionHandler, and ``_handle_function_result`` processes the result
        (possibly recursing into this method).

        Args:
            query: User text, or the tool result text when recursing.
            tool_call: True when called after a tool result; *query* is then
                not added to the dialogue as a user message.

        Returns:
            True when the stream was consumed; None if the memory query or
            the LLM request failed.
        """
        self.logger.bind(tag=TAG).debug(f"Chat with function calling start: {query}")
        if not tool_call:
            self.dialogue.put(Message(role="user", content=query))

        # func_handler stays None when the intent type is "nointent".
        functions = None
        if self.func_handler is not None:
            functions = self.func_handler.get_functions()

        response_message = []
        processed_chars = 0  # characters of the reply already segmented
        try:
            future = asyncio.run_coroutine_threadsafe(
                self.memory.query_memory(query), self.loop
            )
            memory_str = future.result()
            llm_responses = self.llm.response_with_functions(
                self.session_id,
                self.dialogue.get_llm_dialogue_with_memory(memory_str),
                functions=functions,
            )
        except Exception as e:
            self.logger.bind(tag=TAG).error(f"LLM 处理出错 {query}: {e}")
            return None

        self.llm_finish_task = False
        text_index = 0

        tool_call_flag = False
        function_name = None
        function_id = None
        function_arguments = ""
        content_arguments = ""  # all streamed content, for text-based tool calls
        for response in llm_responses:
            content, tools_call = self._unpack_llm_response(response)

            if content:
                content_arguments += content
            if not tool_call_flag and content_arguments.startswith("<tool_call>"):
                tool_call_flag = True

            if tools_call:
                tool_call_flag = True
                call = tools_call[0]
                if call.id is not None:
                    function_id = call.id
                if call.function.name is not None:
                    function_name = call.function.name
                if call.function.arguments is not None:
                    function_arguments += call.function.arguments

            # Plain text (not part of a tool call) is spoken as it streams.
            # Abort is only honoured here so a tool call is never truncated.
            if content and not tool_call_flag:
                response_message.append(content)
                if self.client_abort:
                    break
                processed_chars, text_index = self._flush_complete_segments(
                    response_message, processed_chars, text_index
                )

        if tool_call_flag:
            has_error = False
            if function_id is None:
                # Text-based tool call: recover name/arguments from the JSON.
                json_text = extract_json_from_string(content_arguments)
                if json_text is not None:
                    try:
                        call_json = json.loads(json_text)
                        function_name = call_json["name"]
                        function_arguments = json.dumps(
                            call_json["arguments"], ensure_ascii=False
                        )
                        function_id = str(uuid.uuid4().hex)
                    except (ValueError, KeyError, TypeError):
                        has_error = True
                        response_message.append(json_text)
                else:
                    has_error = True
                    response_message.append(content_arguments)
                if has_error:
                    self.logger.bind(tag=TAG).error(
                        f"function call error: {content_arguments}"
                    )
            if not has_error:
                response_message.clear()
                self.logger.bind(tag=TAG).debug(
                    f"function_name={function_name}, function_id={function_id}, function_arguments={function_arguments}"
                )
                function_call_data = {
                    "name": function_name,
                    "id": function_id,
                    "arguments": function_arguments,
                }

                # mcp_manager only exists once _initialize_intent has loaded it.
                mcp_manager = getattr(self, "mcp_manager", None)
                if mcp_manager is not None and mcp_manager.is_mcp_tool(function_name):
                    result = self._handle_mcp_tool_call(function_call_data)
                elif self.func_handler is not None:
                    result = self.func_handler.handle_llm_function_call(
                        self, function_call_data
                    )
                else:
                    result = None
                self._handle_function_result(result, function_call_data, text_index + 1)

        self._flush_remaining_text(response_message, processed_chars, text_index)

        if response_message:
            self.dialogue.put(
                Message(role="assistant", content="".join(response_message))
            )
        self.llm_finish_task = True
        self.logger.bind(tag=TAG).debug(
            json.dumps(self.dialogue.get_llm_dialogue(), indent=4, ensure_ascii=False)
        )
        return True

    @staticmethod
    def _parse_tool_arguments(function_arguments):
        """Return tool arguments ready for ``execute_tool``.

        Strings are JSON-decoded, and an empty or whitespace-only string
        means "no arguments" (``{}``). Other values are returned unchanged.

        Raises:
            json.JSONDecodeError: if a non-empty string is not valid JSON.
        """
        if not isinstance(function_arguments, str):
            return function_arguments
        if not function_arguments.strip():
            return {}
        return json.loads(function_arguments)

    def _handle_mcp_tool_call(self, function_call_data):
        """Run an MCP tool on the event loop and wrap its text output.

        Blocks the calling worker thread until the tool finishes.

        Returns:
            ``ActionResponse`` with action ``REQLLM``. Its ``result`` is the
            tool's text parts joined by newlines, "参数解析失败" when the
            arguments are not valid JSON, or "工具调用出错" when the call
            raised or produced no text.
        """
        function_arguments = function_call_data["arguments"]
        function_name = function_call_data["name"]
        try:
            try:
                args_dict = self._parse_tool_arguments(function_arguments)
            except json.JSONDecodeError:
                self.logger.bind(tag=TAG).error(
                    f"无法解析 function_arguments: {function_arguments}"
                )
                return ActionResponse(
                    action=Action.REQLLM, result="参数解析失败", response=""
                )

            tool_result = asyncio.run_coroutine_threadsafe(
                self.mcp_manager.execute_tool(function_name, args_dict), self.loop
            ).result()
            # Example result:
            # meta=None content=[TextContent(type='text', text='...', annotations=None)] isError=False
            texts = []
            if tool_result is not None and tool_result.content is not None:
                for content in tool_result.content:
                    # Only text parts are forwarded; image parts are ignored.
                    if content.type == "text" and content.text:
                        texts.append(content.text)
            content_text = "\n".join(texts)
            if content_text:
                return ActionResponse(
                    action=Action.REQLLM, result=content_text, response=""
                )
        except Exception as e:
            self.logger.bind(tag=TAG).error(f"MCP工具调用错误: {e}")

        return ActionResponse(action=Action.REQLLM, result="工具调用出错", response="")

    def _handle_function_result(self, result, function_call_data, text_index):
        """Act on a tool/function result according to ``result.action``.

        - RESPONSE: speak ``result.response`` directly and record it.
        - REQLLM: record the tool call and tool result in the dialogue, then
          ask the LLM again (recursively) to phrase a reply.
        - NOTFOUND / ERROR: speak ``result.result`` and record it.
        - anything else: ignored.
        """
        if result is None:
            self.logger.bind(tag=TAG).error(
                f"function call returned no result: {function_call_data.get('name')}"
            )
            return

        if result.action == Action.RESPONSE:
            text = result.response
            self._enqueue_tts(text, text_index)
            self.dialogue.put(Message(role="assistant", content=text))
        elif result.action == Action.REQLLM:
            text = result.result
            if text is not None and len(text) > 0:
                function_id = function_call_data["id"]
                function_name = function_call_data["name"]
                function_arguments = function_call_data["arguments"]
                self.dialogue.put(
                    Message(
                        role="assistant",
                        tool_calls=[
                            {
                                "id": function_id,
                                "function": {
                                    "arguments": function_arguments,
                                    "name": function_name,
                                },
                                "type": "function",
                                "index": 0,
                            }
                        ],
                    )
                )
                self.dialogue.put(
                    Message(role="tool", tool_call_id=function_id, content=text)
                )
                self.chat_with_function_calling(text, tool_call=True)
        elif result.action in (Action.NOTFOUND, Action.ERROR):
            text = result.result
            self._enqueue_tts(text, text_index)
            self.dialogue.put(Message(role="assistant", content=text))

    def _tts_priority_thread(self):
        """Worker thread: collect TTS results in order and queue their audio.

        Futures are taken from ``tts_queue`` in submission order (polling
        with a 1 s timeout so ``stop_event`` is noticed), converted to opus
        frames, and put on ``audio_play_queue`` unless the client aborted.
        Generated audio files are deleted when the TTS module asks for it.
        """
        while not self.stop_event.is_set():
            text = None
            try:
                try:
                    future = self.tts_queue.get(timeout=1)
                except queue.Empty:
                    if self.stop_event.is_set():
                        break
                    continue
                if future is None:
                    continue

                opus_datas, text_index, tts_file = [], 0, None
                try:
                    self.logger.bind(tag=TAG).debug("正在处理TTS任务...")
                    tts_timeout = int(self.config.get("tts_timeout", 10))
                    tts_file, text, text_index = future.result(timeout=tts_timeout)
                    if text is None or len(text) <= 0:
                        self.logger.bind(tag=TAG).error(
                            f"TTS出错{text_index}: tts text is empty"
                        )
                    elif tts_file is None:
                        self.logger.bind(tag=TAG).error(
                            f"TTS出错 file is empty: {text_index}: {text}"
                        )
                    else:
                        self.logger.bind(tag=TAG).debug(
                            f"TTS生成文件路径: {tts_file}"
                        )
                        if os.path.exists(tts_file):
                            opus_datas, _duration = self.tts.audio_to_opus_data(
                                tts_file
                            )
                        else:
                            self.logger.bind(tag=TAG).error(
                                f"TTS出错文件不存在{tts_file}"
                            )
                except TimeoutError:
                    self.logger.bind(tag=TAG).error("TTS超时")
                except Exception as e:
                    self.logger.bind(tag=TAG).error(f"TTS出错: {e}")

                # Queue the audio (possibly empty) only if the client did not
                # interrupt.
                if not self.client_abort:
                    self.audio_play_queue.put((opus_datas, text, text_index))
                if (
                    self.tts.delete_audio_file
                    and tts_file is not None
                    and os.path.exists(tts_file)
                ):
                    os.remove(tts_file)
            except Exception as e:
                self.logger.bind(tag=TAG).error(f"TTS任务处理错误: {e}")
                self.clearSpeakStatus()
                asyncio.run_coroutine_threadsafe(
                    self.websocket.send(
                        json.dumps(
                            {
                                "type": "tts",
                                "state": "stop",
                                "session_id": self.session_id,
                            }
                        )
                    ),
                    self.loop,
                )
                self.logger.bind(tag=TAG).error(
                    f"tts_priority priority_thread: {text} {e}"
                )

    def _audio_play_priority_thread(self):
        """Worker thread: send queued audio to the client, one item at a time.

        Each item is sent via ``sendAudioMessage`` on the event loop, and the
        thread waits for it to finish before taking the next one.
        """
        while not self.stop_event.is_set():
            text = None
            try:
                try:
                    opus_datas, text, text_index = self.audio_play_queue.get(timeout=1)
                except queue.Empty:
                    if self.stop_event.is_set():
                        break
                    continue
                future = asyncio.run_coroutine_threadsafe(
                    sendAudioMessage(self, opus_datas, text, text_index), self.loop
                )
                future.result()
            except Exception as e:
                self.logger.bind(tag=TAG).error(
                    f"audio_play_priority priority_thread: {text} {e}"
                )

    def speak_and_play(self, text, text_index=0):
        """Synthesize *text* to an audio file (runs in the executor).

        Returns:
            ``(tts_file, text, text_index)``; ``tts_file`` is None when *text*
            is empty or synthesis failed.
        """
        if text is None or len(text) <= 0:
            self.logger.bind(tag=TAG).info(f"无需tts转换query为空{text}")
            return None, text, text_index
        tts_file = self.tts.to_tts(text)
        if tts_file is None:
            self.logger.bind(tag=TAG).error(f"tts转换失败{text}")
            return None, text, text_index
        self.logger.bind(tag=TAG).debug(f"TTS 文件生成完毕: {tts_file}")
        if self.max_output_size > 0:
            # Count spoken characters toward the device's output quota.
            add_device_output(self.headers.get("device-id"), len(text))
        return tts_file, text, text_index

    def clearSpeakStatus(self):
        """Reset server-side speaking state so ASR input is accepted again."""
        self.logger.bind(tag=TAG).debug(f"清除服务端讲话状态")
        self.asr_server_receive = True
        self.tts_last_text_index = -1
        self.tts_first_text_index = -1

    def recode_first_last_text(self, text, text_index=0):
        """Record *text_index* as the latest segment (and as the first, if none yet)."""
        if self.tts_first_text_index == -1:
            self.logger.bind(tag=TAG).info(f"大模型说出第一句话: {text}")
            self.tts_first_text_index = text_index
        self.tts_last_text_index = text_index

    async def close(self, ws=None):
        """Release per-connection resources and close the websocket.

        Both the inactivity watchdog and ``handle_connection``'s cleanup call
        this, so it may run more than once per connection. The executor is
        only shut down the first time.

        Args:
            ws: Websocket to close; defaults to ``self.websocket``.
        """
        if self.timeout_task:
            # When the watchdog itself awaits close(), cancelling the current
            # task would raise CancelledError at the next await below and
            # abort this cleanup half-way, so only cancel it from elsewhere.
            if self.timeout_task is not asyncio.current_task():
                self.timeout_task.cancel()
            self.timeout_task = None

        mcp_manager = getattr(self, "mcp_manager", None)
        if mcp_manager:
            try:
                await mcp_manager.cleanup_all()
            except Exception as e:
                # Keep going: threads, executor and socket still need releasing.
                self.logger.bind(tag=TAG).error(f"MCP cleanup failed: {e}")

        # Signal the worker threads to exit.
        if self.stop_event:
            self.stop_event.set()

        # Shut the pool down without waiting; queued jobs are cancelled.
        if self.executor:
            self.executor.shutdown(wait=False, cancel_futures=True)
            self.executor = None

        self._clear_queues()

        if ws:
            await ws.close()
        elif self.websocket:
            await self.websocket.close()
        self.logger.bind(tag=TAG).info("连接资源已释放")

    def _clear_queues(self):
        """Discard everything still waiting in the TTS and audio-play queues.

        The worker threads stop via ``stop_event`` (they poll with a 1 s
        timeout), so no sentinel is enqueued.
        """
        for pending in (self.tts_queue, self.audio_play_queue):
            if pending is None:
                continue
            while True:
                try:
                    pending.get_nowait()
                except queue.Empty:
                    break
            # Drop anything that slipped in after the last get.
            pending.queue.clear()

    def reset_vad_states(self):
        """Clear buffered client audio and the voice-activity flags."""
        self.client_audio_buffer = bytearray()
        self.client_have_voice = False
        self.client_have_voice_last_time = 0
        self.client_voice_stop = False
        self.logger.bind(tag=TAG).debug("VAD states reset.")

    def chat_and_close(self, text):
        """Reply to *text* via ``chat``, then mark the connection to close after the chat."""
        try:
            self.chat(text)
            self.close_after_chat = True
        except Exception as e:
            self.logger.bind(tag=TAG).error(f"Chat and close error: {str(e)}")

    async def _check_timeout(self):
        """Watchdog: close the connection after ``timeout_seconds`` of silence.

        ``_route_message`` cancels and restarts this task on every incoming
        message, so the sleep only completes when no message arrived for the
        whole period.
        """
        try:
            while not self.stop_event.is_set():
                await asyncio.sleep(self.timeout_seconds)
                if not self.stop_event.is_set():
                    self.logger.bind(tag=TAG).info("连接超时,准备关闭")
                    await self.close(self.websocket)
                    break
        except Exception as e:
            self.logger.bind(tag=TAG).error(f"超时检查任务出错: {e}")