You cannot select more than 25 topics. Topics must start with a letter or number, can include dashes ('-'), and can be up to 35 characters long.

1063 lines
42 KiB

This file contains ambiguous Unicode characters!

This file contains ambiguous Unicode characters that may be confused with others in your current locale. If your use case is intentional and legitimate, you can safely ignore this warning. Use the Escape button to highlight these characters.

import os
import sys
import copy
import json
import uuid
import time
import queue
import asyncio
import threading
import traceback
import subprocess
import websockets
from core.utils.util import (
extract_json_from_string,
check_vad_update,
check_asr_update,
filter_sensitive_info,
)
from typing import Dict, Any
from collections import deque
from core.utils.modules_initialize import (
initialize_modules,
initialize_tts,
initialize_asr,
)
from core.handle.reportHandle import report
from core.utils.modules_initialize import initialize_voiceprint
from core.providers.tts.default import DefaultTTS
from concurrent.futures import ThreadPoolExecutor
from core.utils.dialogue import Message, Dialogue
from core.providers.asr.dto.dto import InterfaceType
from core.handle.textHandle import handleTextMessage
from core.providers.tools.unified_tool_handler import UnifiedToolHandler
from plugins_func.loadplugins import auto_import_modules
from plugins_func.register import Action, ActionResponse
from core.auth import AuthMiddleware, AuthenticationError
from config.config_loader import get_private_config_from_api
from core.providers.tts.dto.dto import ContentType, TTSMessageDTO, SentenceType
from config.logger import setup_logging, build_module_string, create_connection_logger
from config.manage_api_client import DeviceNotFoundException, DeviceBindException
from core.utils.prompt_manager import PromptManager
# Logger tag for this module; passed as tag=TAG to the logger.bind() calls below.
TAG = __name__
# Import the "plugins_func.functions" package contents at module load time
# (presumably so plugin functions register themselves — confirm in loadplugins).
auto_import_modules("plugins_func.functions")
class TTSException(RuntimeError):
    """RuntimeError subclass for TTS-related failures (not raised in the code shown here)."""
class ConnectionHandler:
def __init__(
    self,
    config: Dict[str, Any],
    _vad,
    _asr,
    _llm,
    _memory,
    _intent,
    server=None,
):
    """Create the per-connection state for one client session.

    Args:
        config: Server configuration. The object itself is kept as
            ``common_config``; a deep copy becomes this connection's own
            ``config`` so later per-device overrides stay local.
        _vad: Shared VAD instance, adopted in ``_initialize_components``.
        _asr: Shared ASR instance, reused or replaced in ``_initialize_asr``.
        _llm: LLM provider for this connection.
        _memory: Memory provider (may be ``None``).
        _intent: Intent provider (may be ``None``).
        server: Optional reference to the owning server instance.

    Raises:
        KeyError: If the configuration has no ``"exit_commands"`` entry.
    """
    # --- configuration and identity ---
    self.common_config = config
    self.config = copy.deepcopy(config)
    self.session_id = str(uuid.uuid4())
    self.logger = setup_logging()
    self.server = server  # keep a reference to the server instance
    self.auth = AuthMiddleware(config)

    # --- device binding / remote configuration ---
    self.need_bind = False
    self.bind_code = None
    self.read_config_from_api = self.config.get("read_config_from_api", False)

    # --- transport and client metadata (filled in by handle_connection) ---
    self.websocket = None
    self.headers = None
    self.device_id = None
    self.client_ip = None
    self.prompt = None
    self.welcome_msg = None
    self.max_output_size = 0
    self.chat_history_conf = 0
    self.audio_format = "opus"

    # --- client state ---
    self.client_abort = False
    self.client_is_speaking = False
    self.client_listen_mode = "auto"

    # --- threads and tasks ---
    self.loop = asyncio.get_event_loop()
    self.stop_event = threading.Event()
    self.executor = ThreadPoolExecutor(max_workers=5)

    # --- chat-history reporting ---
    self.report_queue = queue.Queue()
    self.report_thread = None
    # ASR and TTS reporting can be tuned here later; both currently follow
    # the read_config_from_api switch.
    self.report_asr_enable = self.read_config_from_api
    self.report_tts_enable = self.read_config_from_api

    # --- components this connection depends on ---
    self.vad = None
    self.asr = None
    self.tts = None
    self._asr = _asr
    self._vad = _vad
    self.llm = _llm
    self.memory = _memory
    self.intent = _intent

    # --- VAD state ---
    self.client_audio_buffer = bytearray()
    self.client_have_voice = False
    self.last_activity_time = 0.0  # unified activity timestamp (milliseconds)
    self.client_voice_stop = False
    self.client_voice_window = deque(maxlen=5)
    self.last_is_voice = False

    # --- ASR state ---
    # A local ASR instance may be shared between connections, so everything
    # ASR-related that is per-connection lives here rather than on the ASR.
    self.asr_audio = []
    self.asr_audio_queue = queue.Queue()

    # --- LLM state ---
    self.llm_finish_task = True
    self.dialogue = Dialogue()

    # --- TTS state ---
    self.sentence_id = None

    # --- IoT / tool state ---
    self.iot_descriptors = {}
    self.func_handler = None

    # --- exit commands: remember the longest one ---
    self.cmd_exit = self.config["exit_commands"]
    self.max_cmd_length = max((len(cmd) for cmd in self.cmd_exit), default=0)

    # Whether to close the connection once the current chat finishes.
    self.close_after_chat = False
    self.load_function_plugin = False
    self.intent_type = "nointent"

    # Second-stage close: the configured no-voice time plus 60 seconds.
    no_voice_seconds = int(self.config.get("close_connection_no_voice_time", 120))
    self.timeout_seconds = no_voice_seconds + 60
    self.timeout_task = None

    # e.g. {"mcp": true} means the MCP feature is enabled
    self.features = None

    # Prompt manager
    self.prompt_manager = PromptManager(config, self.logger)
async def handle_connection(self, ws):
    """Run the full lifecycle of one websocket connection.

    Reads the request headers (falling back to the URL query string for
    ``device-id`` / ``client-id``), authenticates, starts the timeout
    watchdog, loads per-device configuration, kicks off component
    initialisation on the thread pool and then routes every incoming
    message until the socket closes. Cleanup always runs via
    ``_save_and_close``.

    Args:
        ws: The accepted websocket connection object.
    """
    try:
        # Collect and validate the headers.
        self.headers = dict(ws.request.headers)

        if self.headers.get("device-id", None) is None:
            # Try to read device-id from the URL query parameters instead.
            from urllib.parse import parse_qs, urlparse

            request_path = ws.request.path
            if not request_path:
                self.logger.bind(tag=TAG).error("无法获取请求路径")
                return
            query_params = parse_qs(urlparse(request_path).query)
            if "device-id" in query_params:
                device_id = query_params["device-id"][0]
                self.headers["device-id"] = device_id
                # client-id is optional: fall back to device-id, mirroring
                # the fallback used in _initialize_private_config, instead
                # of raising KeyError when it is absent from the query.
                self.headers["client-id"] = query_params.get(
                    "client-id", [device_id]
                )[0]
            else:
                await ws.send("端口正常如需测试连接请使用test_page.html")
                await self.close(ws)
                return

        # Client IP address.
        self.client_ip = ws.remote_address[0]
        self.logger.bind(tag=TAG).info(
            f"{self.client_ip} conn - Headers: {self.headers}"
        )

        # Authenticate; AuthenticationError is handled below.
        await self.auth.authenticate(self.headers)

        # Authenticated: continue.
        self.websocket = ws
        self.device_id = self.headers.get("device-id", None)

        # Initialise the activity timestamp (milliseconds).
        self.last_activity_time = time.time() * 1000

        # Start the timeout watchdog.
        self.timeout_task = asyncio.create_task(self._check_timeout())

        self.welcome_msg = self.config["xiaozhi"]
        self.welcome_msg["session_id"] = self.session_id

        # Fetch per-device configuration overrides.
        self._initialize_private_config()
        # Initialise the components on the thread pool.
        self.executor.submit(self._initialize_components)

        try:
            async for message in self.websocket:
                await self._route_message(message)
        except websockets.exceptions.ConnectionClosed:
            self.logger.bind(tag=TAG).info("客户端断开连接")

    except AuthenticationError as e:
        self.logger.bind(tag=TAG).error(f"Authentication failed: {str(e)}")
        return
    except Exception as e:
        stack_trace = traceback.format_exc()
        self.logger.bind(tag=TAG).error(f"Connection error: {str(e)}-{stack_trace}")
        return
    finally:
        try:
            await self._save_and_close(ws)
        except Exception as final_error:
            self.logger.bind(tag=TAG).error(f"最终清理时出错: {final_error}")
            # Make sure the connection is closed even if saving memory failed.
            try:
                await self.close(ws)
            except Exception as close_error:
                self.logger.bind(tag=TAG).error(
                    f"强制关闭连接时出错: {close_error}"
                )
async def _save_and_close(self, ws):
"""保存记忆并关闭连接"""
try:
if self.memory:
# 使用线程池异步保存记忆
def save_memory_task():
try:
# 创建新事件循环(避免与主循环冲突)
loop = asyncio.new_event_loop()
asyncio.set_event_loop(loop)
loop.run_until_complete(
self.memory.save_memory(self.dialogue.dialogue)
)
except Exception as e:
self.logger.bind(tag=TAG).error(f"保存记忆失败: {e}")
finally:
try:
loop.close()
except Exception:
pass
# 启动线程保存记忆,不等待完成
threading.Thread(target=save_memory_task, daemon=True).start()
except Exception as e:
self.logger.bind(tag=TAG).error(f"保存记忆失败: {e}")
finally:
# 立即关闭连接,不等待记忆保存完成
try:
await self.close(ws)
except Exception as close_error:
self.logger.bind(tag=TAG).error(
f"保存记忆后关闭连接失败: {close_error}"
)
async def _route_message(self, message):
"""消息路由"""
if isinstance(message, str):
self.last_activity_time = time.time() * 1000
await handleTextMessage(self, message)
elif isinstance(message, bytes):
if self.vad is None:
return
if self.asr is None:
return
self.asr_audio_queue.put(message)
async def handle_restart(self, message):
    """Handle a server restart request.

    Sends a success acknowledgement to the client, then spawns a new
    ``app.py`` process from a background thread and terminates the current
    process. If anything fails, an error reply is sent instead.

    Args:
        message: The incoming request; not inspected here.
    """

    def build_reply(status, text):
        # Both replies share the same envelope; only status/message differ.
        return json.dumps(
            {
                "type": "server",
                "status": status,
                "message": text,
                "content": {"action": "restart"},
            }
        )

    def restart_server():
        # Performs the actual restart.
        time.sleep(1)
        self.logger.bind(tag=TAG).info("执行服务器重启...")
        subprocess.Popen(
            [sys.executable, "app.py"],
            stdin=sys.stdin,
            stdout=sys.stdout,
            stderr=sys.stderr,
            start_new_session=True,
        )
        os._exit(0)

    try:
        self.logger.bind(tag=TAG).info("收到服务器重启指令,准备执行...")
        # Acknowledge first.
        await self.websocket.send(build_reply("success", "服务器重启中..."))
        # Restart from a thread so the event loop is not blocked.
        threading.Thread(target=restart_server, daemon=True).start()
    except Exception as e:
        self.logger.bind(tag=TAG).error(f"重启失败: {str(e)}")
        await self.websocket.send(
            build_reply("error", f"Restart failed: {str(e)}")
        )
def _initialize_components(self):
    """Initialise the per-connection components (runs on the thread pool).

    Order: connection logger, quick system prompt, VAD/ASR/TTS (opening
    their audio channels on the event loop), memory, intent, report thread,
    and finally the enhanced system prompt. Any exception is logged and
    swallowed.
    """
    try:
        module_str = build_module_string(self.config.get("selected_module", {}))
        self.selected_module_str = module_str
        self.logger = create_connection_logger(self.selected_module_str)

        # Quick prompt initialisation, when a prompt is configured.
        if self.config.get("prompt") is not None:
            user_prompt = self.config["prompt"]
            prompt = self.prompt_manager.get_quick_prompt(user_prompt)
            self.change_system_prompt(prompt)
            self.logger.bind(tag=TAG).info(
                f"快速初始化组件: prompt成功 {prompt[:50]}..."
            )

        # Local components: only fill in what is not set already.
        if self.vad is None:
            self.vad = self._vad

        if self.asr is None:
            self.asr = self._initialize_asr()
        # Open the speech-recognition channel.
        asyncio.run_coroutine_threadsafe(
            self.asr.open_audio_channels(self), self.loop
        )

        if self.tts is None:
            self.tts = self._initialize_tts()
        # Open the speech-synthesis channel.
        asyncio.run_coroutine_threadsafe(
            self.tts.open_audio_channels(self), self.loop
        )

        # Memory, intent recognition, report thread, then the full prompt.
        self._initialize_memory()
        self._initialize_intent()
        self._init_report_threads()
        self._init_prompt_enhancement()
    except Exception as e:
        self.logger.bind(tag=TAG).error(f"实例化组件失败: {e}")
def _init_prompt_enhancement(self):
    """Refresh the prompt manager's context and apply the enhanced prompt.

    The system prompt is only replaced when the prompt manager returns a
    non-empty result.
    """
    manager = self.prompt_manager
    # Update the context information first.
    manager.update_context_info(self, self.client_ip)
    enhanced_prompt = manager.build_enhanced_prompt(
        self.config["prompt"], self.device_id, self.client_ip
    )
    if not enhanced_prompt:
        return
    self.change_system_prompt(enhanced_prompt)
    self.logger.bind(tag=TAG).info("系统提示词已增强更新")
def _init_report_threads(self):
"""初始化ASR和TTS上报线程"""
if not self.read_config_from_api or self.need_bind:
return
if self.chat_history_conf == 0:
return
if self.report_thread is None or not self.report_thread.is_alive():
self.report_thread = threading.Thread(
target=self._report_worker, daemon=True
)
self.report_thread.start()
self.logger.bind(tag=TAG).info("TTS上报线程已启动")
def _initialize_tts(self):
    """Return the TTS provider for this connection.

    The configured TTS is only built when the device does not need binding;
    whenever no provider results, a ``DefaultTTS`` is returned instead.
    """
    configured = None if self.need_bind else initialize_tts(self.config)
    if configured is not None:
        return configured
    return DefaultTTS(self.config, delete_audio_file=True)
def _initialize_asr(self):
    """Return the ASR provider for this connection.

    A shared ASR whose interface type is LOCAL is reused as-is (one local
    instance can serve several connections). Otherwise a new instance is
    created, because a remote ASR involves its own websocket connection and
    receive thread per connection. Voiceprint support is then initialised
    on the chosen instance; a failure there is logged but not fatal.
    """
    shared_is_local = self._asr.interface_type == InterfaceType.LOCAL
    asr = self._asr if shared_is_local else initialize_asr(self.config)

    # Dynamically initialise voiceprint recognition.
    try:
        if initialize_voiceprint(asr, self.config):
            self.logger.bind(tag=TAG).info("声纹识别功能已在连接时动态启用")
        else:
            self.logger.bind(tag=TAG).info("声纹识别功能未启用或配置不完整")
    except Exception as e:
        self.logger.bind(tag=TAG).error(f"动态初始化声纹识别时发生错误: {str(e)}")

    return asr
def _initialize_private_config(self):
"""如果是从配置文件获取,则进行二次实例化"""
if not self.read_config_from_api:
return
"""从接口获取差异化的配置进行二次实例化,非全量重新实例化"""
try:
begin_time = time.time()
private_config = get_private_config_from_api(
self.config,
self.headers.get("device-id"),
self.headers.get("client-id", self.headers.get("device-id")),
)
private_config["delete_audio"] = bool(self.config.get("delete_audio", True))
self.logger.bind(tag=TAG).info(
f"{time.time() - begin_time} 秒,获取差异化配置成功: {json.dumps(filter_sensitive_info(private_config), ensure_ascii=False)}"
)
except DeviceNotFoundException as e:
self.need_bind = True
private_config = {}
except DeviceBindException as e:
self.need_bind = True
self.bind_code = e.bind_code
private_config = {}
except Exception as e:
self.need_bind = True
self.logger.bind(tag=TAG).error(f"获取差异化配置失败: {e}")
private_config = {}
init_llm, init_tts, init_memory, init_intent = (
False,
False,
False,
False,
)
init_vad = check_vad_update(self.common_config, private_config)
init_asr = check_asr_update(self.common_config, private_config)
if init_vad:
self.config["VAD"] = private_config["VAD"]
self.config["selected_module"]["VAD"] = private_config["selected_module"][
"VAD"
]
if init_asr:
self.config["ASR"] = private_config["ASR"]
self.config["selected_module"]["ASR"] = private_config["selected_module"][
"ASR"
]
if private_config.get("TTS", None) is not None:
init_tts = True
self.config["TTS"] = private_config["TTS"]
self.config["selected_module"]["TTS"] = private_config["selected_module"][
"TTS"
]
if private_config.get("LLM", None) is not None:
init_llm = True
self.config["LLM"] = private_config["LLM"]
self.config["selected_module"]["LLM"] = private_config["selected_module"][
"LLM"
]
if private_config.get("Memory", None) is not None:
init_memory = True
self.config["Memory"] = private_config["Memory"]
self.config["selected_module"]["Memory"] = private_config[
"selected_module"
]["Memory"]
if private_config.get("Intent", None) is not None:
init_intent = True
self.config["Intent"] = private_config["Intent"]
model_intent = private_config.get("selected_module", {}).get("Intent", {})
self.config["selected_module"]["Intent"] = model_intent
# 加载插件配置
if model_intent != "Intent_nointent":
plugin_from_server = private_config.get("plugins", {})
for plugin, config_str in plugin_from_server.items():
plugin_from_server[plugin] = json.loads(config_str)
self.config["plugins"] = plugin_from_server
self.config["Intent"][self.config["selected_module"]["Intent"]][
"functions"
] = plugin_from_server.keys()
if private_config.get("prompt", None) is not None:
self.config["prompt"] = private_config["prompt"]
# 获取声纹信息
if private_config.get("voiceprint", None) is not None:
self.config["voiceprint"] = private_config["voiceprint"]
if private_config.get("summaryMemory", None) is not None:
self.config["summaryMemory"] = private_config["summaryMemory"]
if private_config.get("device_max_output_size", None) is not None:
self.max_output_size = int(private_config["device_max_output_size"])
if private_config.get("chat_history_conf", None) is not None:
self.chat_history_conf = int(private_config["chat_history_conf"])
if private_config.get("mcp_endpoint", None) is not None:
self.config["mcp_endpoint"] = private_config["mcp_endpoint"]
try:
modules = initialize_modules(
self.logger,
private_config,
init_vad,
init_asr,
init_llm,
init_tts,
init_memory,
init_intent,
)
except Exception as e:
self.logger.bind(tag=TAG).error(f"初始化组件失败: {e}")
modules = {}
if modules.get("tts", None) is not None:
self.tts = modules["tts"]
if modules.get("vad", None) is not None:
self.vad = modules["vad"]
if modules.get("asr", None) is not None:
self.asr = modules["asr"]
if modules.get("llm", None) is not None:
self.llm = modules["llm"]
if modules.get("intent", None) is not None:
self.intent = modules["intent"]
if modules.get("memory", None) is not None:
self.memory = modules["memory"]
def _initialize_memory(self):
if self.memory is None:
return
"""初始化记忆模块"""
self.memory.init_memory(
role_id=self.device_id,
llm=self.llm,
summary_memory=self.config.get("summaryMemory", None),
save_to_file=not self.read_config_from_api,
)
# 获取记忆总结配置
memory_config = self.config["Memory"]
memory_type = self.config["Memory"][self.config["selected_module"]["Memory"]][
"type"
]
# 如果使用 nomen直接返回
if memory_type == "nomem":
return
# 使用 mem_local_short 模式
elif memory_type == "mem_local_short":
memory_llm_name = memory_config[self.config["selected_module"]["Memory"]][
"llm"
]
if memory_llm_name and memory_llm_name in self.config["LLM"]:
# 如果配置了专用LLM则创建独立的LLM实例
from core.utils import llm as llm_utils
memory_llm_config = self.config["LLM"][memory_llm_name]
memory_llm_type = memory_llm_config.get("type", memory_llm_name)
memory_llm = llm_utils.create_instance(
memory_llm_type, memory_llm_config
)
self.logger.bind(tag=TAG).info(
f"为记忆总结创建了专用LLM: {memory_llm_name}, 类型: {memory_llm_type}"
)
self.memory.set_llm(memory_llm)
else:
# 否则使用主LLM
self.memory.set_llm(self.llm)
self.logger.bind(tag=TAG).info("使用主LLM作为意图识别模型")
def _initialize_intent(self):
if self.intent is None:
return
self.intent_type = self.config["Intent"][
self.config["selected_module"]["Intent"]
]["type"]
if self.intent_type == "function_call" or self.intent_type == "intent_llm":
self.load_function_plugin = True
"""初始化意图识别模块"""
# 获取意图识别配置
intent_config = self.config["Intent"]
intent_type = self.config["Intent"][self.config["selected_module"]["Intent"]][
"type"
]
# 如果使用 nointent直接返回
if intent_type == "nointent":
return
# 使用 intent_llm 模式
elif intent_type == "intent_llm":
intent_llm_name = intent_config[self.config["selected_module"]["Intent"]][
"llm"
]
if intent_llm_name and intent_llm_name in self.config["LLM"]:
# 如果配置了专用LLM则创建独立的LLM实例
from core.utils import llm as llm_utils
intent_llm_config = self.config["LLM"][intent_llm_name]
intent_llm_type = intent_llm_config.get("type", intent_llm_name)
intent_llm = llm_utils.create_instance(
intent_llm_type, intent_llm_config
)
self.logger.bind(tag=TAG).info(
f"为意图识别创建了专用LLM: {intent_llm_name}, 类型: {intent_llm_type}"
)
self.intent.set_llm(intent_llm)
else:
# 否则使用主LLM
self.intent.set_llm(self.llm)
self.logger.bind(tag=TAG).info("使用主LLM作为意图识别模型")
"""加载统一工具处理器"""
self.func_handler = UnifiedToolHandler(self)
# 异步初始化工具处理器
if hasattr(self, "loop") and self.loop:
asyncio.run_coroutine_threadsafe(self.func_handler._initialize(), self.loop)
def change_system_prompt(self, prompt):
    """Store ``prompt`` and push it into the dialogue as the system message."""
    self.prompt = prompt
    # Keep the dialogue context in sync with the new system prompt.
    self.dialogue.update_system_message(prompt)
def chat(self, query, tool_call=False, depth=0):
    """Run one LLM turn for ``query`` and stream the reply to TTS.

    Args:
        query: The user text, or the tool output when re-entered after a
            function call.
        tool_call: True when re-entered from ``_handle_function_result``;
            the query is then not stored as a user message.
        depth: Recursion depth. Only depth 0 allocates a new sentence id
            and emits the FIRST / LAST TTS markers.

    Returns:
        True when the turn completed, None when requesting the LLM failed.
    """
    self.logger.bind(tag=TAG).info(f"大模型收到用户消息: {query}")
    self.llm_finish_task = False

    # NOTE(review): hard-coded, site-specific instruction wrapper around the
    # query. It is also applied on tool-call re-entry, where the wrapped text
    # is what gets sent to the memory lookup — confirm this is intended.
    query = f"""【系统指令】请特别注意:
1. 当用户问题包含"史记"时,仅回复:"你真是问对人了,我读过《少年读史记》五本书呢!"
2. 其他情况请正常回答问题
3. 本指令优先级最高
【当前用户问题】:{query}"""

    if not tool_call:
        self.dialogue.put(Message(role="user", content=query))

    # Top level only: new sentence id and the FIRST marker.
    if depth == 0:
        self.sentence_id = str(uuid.uuid4().hex)
        self.tts.tts_text_queue.put(
            TTSMessageDTO(
                sentence_id=self.sentence_id,
                sentence_type=SentenceType.FIRST,
                content_type=ContentType.ACTION,
            )
        )

    # Intent functions. func_handler is created as None in __init__ and only
    # filled in later by _initialize_intent, so test for None: a hasattr()
    # check is always true and would let get_functions() hit None.
    functions = None
    if (
        self.intent_type == "function_call"
        and getattr(self, "func_handler", None) is not None
    ):
        functions = self.func_handler.get_functions()

    use_functions = self.intent_type == "function_call" and functions is not None
    response_message = []

    try:
        # Dialogue enriched with memory.
        memory_str = None
        if self.memory is not None:
            future = asyncio.run_coroutine_threadsafe(
                self.memory.query_memory(query), self.loop
            )
            memory_str = future.result()

        if use_functions:
            # Streaming interface that supports functions.
            llm_responses = self.llm.response_with_functions(
                self.session_id,
                self.dialogue.get_llm_dialogue_with_memory(
                    memory_str, self.config.get("voiceprint", {})
                ),
                functions=functions,
            )
        else:
            llm_responses = self.llm.response(
                self.session_id,
                self.dialogue.get_llm_dialogue_with_memory(
                    memory_str, self.config.get("voiceprint", {})
                ),
            )
    except Exception as e:
        self.logger.bind(tag=TAG).error(f"LLM 处理出错 {query}: {e}")
        # NOTE(review): llm_finish_task stays False on this path.
        return None

    # Consume the streamed response.
    tool_call_flag = False
    function_name = None
    function_id = None
    function_arguments = ""
    content_arguments = ""
    self.client_abort = False

    for response in llm_responses:
        if self.client_abort:
            break

        if self.intent_type == "function_call" and functions is not None:
            content, tools_call = response
            if "content" in response:
                content = response["content"]
                tools_call = None
            if content is not None and len(content) > 0:
                content_arguments += content

            # Some models emit the tool call as text inside <tool_call>.
            if not tool_call_flag and content_arguments.startswith("<tool_call>"):
                tool_call_flag = True

            if tools_call is not None and len(tools_call) > 0:
                tool_call_flag = True
                first_call = tools_call[0]
                if first_call.id is not None:
                    function_id = first_call.id
                if first_call.function.name is not None:
                    function_name = first_call.function.name
                if first_call.function.arguments is not None:
                    function_arguments += first_call.function.arguments
        else:
            content = response

        if content is not None and len(content) > 0:
            if not tool_call_flag:
                response_message.append(content)
                self.tts.tts_text_queue.put(
                    TTSMessageDTO(
                        sentence_id=self.sentence_id,
                        sentence_type=SentenceType.MIDDLE,
                        content_type=ContentType.TEXT,
                        content_detail=content,
                    )
                )

    # Handle the function call, if any.
    if tool_call_flag:
        has_error = False
        if function_id is None:
            # No structured call was streamed: parse it from the text.
            extracted = extract_json_from_string(content_arguments)
            if extracted is not None:
                try:
                    content_arguments_json = json.loads(extracted)
                    function_name = content_arguments_json["name"]
                    function_arguments = json.dumps(
                        content_arguments_json["arguments"], ensure_ascii=False
                    )
                    function_id = str(uuid.uuid4().hex)
                except Exception:
                    has_error = True
                    response_message.append(extracted)
            else:
                has_error = True
                response_message.append(content_arguments)
            if has_error:
                self.logger.bind(tag=TAG).error(
                    f"function call error: {content_arguments}"
                )

        if not has_error:
            # Store any text the model produced before the tool call.
            if len(response_message) > 0:
                self.dialogue.put(
                    Message(role="assistant", content="".join(response_message))
                )
            response_message.clear()
            self.logger.bind(tag=TAG).debug(
                f"function_name={function_name}, function_id={function_id}, function_arguments={function_arguments}"
            )
            function_call_data = {
                "name": function_name,
                "id": function_id,
                "arguments": function_arguments,
            }

            # All tool calls go through the unified tool handler.
            result = asyncio.run_coroutine_threadsafe(
                self.func_handler.handle_llm_function_call(
                    self, function_call_data
                ),
                self.loop,
            ).result()
            self._handle_function_result(result, function_call_data, depth=depth)

    # Store the assistant reply.
    if len(response_message) > 0:
        self.dialogue.put(
            Message(role="assistant", content="".join(response_message))
        )

    if depth == 0:
        self.tts.tts_text_queue.put(
            TTSMessageDTO(
                sentence_id=self.sentence_id,
                sentence_type=SentenceType.LAST,
                content_type=ContentType.ACTION,
            )
        )
    self.llm_finish_task = True

    # The lambda defers get_llm_dialogue() until DEBUG logging needs it.
    self.logger.bind(tag=TAG).debug(
        lambda: json.dumps(
            self.dialogue.get_llm_dialogue(), indent=4, ensure_ascii=False
        )
    )

    return True
def _handle_function_result(self, result, function_call_data, depth):
    """Act on the outcome of a tool call.

    RESPONSE, NOTFOUND and ERROR results are spoken and stored as an
    assistant message. REQLLM results with non-empty text are recorded as a
    tool call plus tool output and fed back into ``chat`` one level deeper.
    Any other action is ignored.

    Args:
        result: Tool handler result; ``action``, ``response`` and ``result``
            are read from it.
        function_call_data: Dict with the call's ``id``, ``name`` and
            ``arguments``.
        depth: Current chat recursion depth.
    """

    def speak_and_store(text):
        self.tts.tts_one_sentence(self, ContentType.TEXT, content_detail=text)
        self.dialogue.put(Message(role="assistant", content=text))

    action = result.action

    if action == Action.RESPONSE:
        # Reply to the client directly.
        speak_and_store(result.response)
        return

    if action == Action.REQLLM:
        # Ask the LLM to phrase a reply from the function output.
        tool_output = result.result
        if tool_output is None or len(tool_output) == 0:
            return

        function_id = function_call_data["id"]
        function_name = function_call_data["name"]
        function_arguments = function_call_data["arguments"]

        call_entry = {
            "id": function_id,
            "function": {
                "arguments": function_arguments,
                "name": function_name,
            },
            "type": "function",
            "index": 0,
        }
        self.dialogue.put(Message(role="assistant", tool_calls=[call_entry]))

        tool_call_id = function_id if function_id is not None else str(uuid.uuid4())
        self.dialogue.put(
            Message(role="tool", tool_call_id=tool_call_id, content=tool_output)
        )
        self.chat(tool_output, tool_call=True, depth=depth + 1)
        return

    if action == Action.NOTFOUND or action == Action.ERROR:
        speak_and_store(result.response if result.response else result.result)
def _report_worker(self):
    """Worker loop of the chat-history report thread.

    Pulls items from ``report_queue`` and submits them to the thread pool
    until the stop event is set or a ``None`` item arrives.
    """
    log = self.logger.bind

    while not self.stop_event.is_set():
        try:
            # The timeout lets the loop re-check the stop event regularly.
            entry = self.report_queue.get(timeout=1)
            if entry is None:  # None is the shutdown marker
                break
            try:
                # Entries are skipped once the thread pool is gone.
                if self.executor is not None:
                    self.executor.submit(self._process_report, *entry)
            except Exception as e:
                log(tag=TAG).error(f"聊天记录上报线程异常: {e}")
        except queue.Empty:
            continue
        except Exception as e:
            log(tag=TAG).error(f"聊天记录上报工作线程异常: {e}")

    self.logger.bind(tag=TAG).info("聊天记录上报线程已退出")
def _process_report(self, type, text, audio_data, report_time):
    """Run one report task and always mark its queue item as done.

    Args:
        type: Report type, forwarded to ``report``.
        text: Text content, forwarded to ``report``.
        audio_data: Binary audio data, forwarded to ``report``.
        report_time: Timestamp, forwarded to ``report``.
    """
    try:
        report(self, type, text, audio_data, report_time)
    except Exception as e:
        self.logger.bind(tag=TAG).error(f"上报处理异常: {e}")
    finally:
        # Pairs with the get() done in the report worker.
        self.report_queue.task_done()
def clearSpeakStatus(self):
    """Clear the ``client_is_speaking`` flag."""
    self.client_is_speaking = False
    self.logger.bind(tag=TAG).debug("清除服务端讲话状态")
async def close(self, ws=None):
    """Release every resource held by this connection.

    Stops the timeout watchdog, cleans up the tool handler, sets the stop
    event, drains the queues, closes the websocket and the TTS provider,
    and shuts the thread pool down without waiting. Errors are logged,
    never raised, and the stop event is always set.

    Args:
        ws: Websocket to close. When omitted or falsy, ``self.websocket``
            is closed instead.
    """
    try:
        # Stop the timeout watchdog.
        timeout_task = self.timeout_task
        if timeout_task and not timeout_task.done():
            # When this runs inside the watchdog itself (timeout path), do
            # not cancel or await it: a task cannot await itself, and it
            # exits on its own once close() returns.
            if timeout_task is not asyncio.current_task():
                timeout_task.cancel()
                try:
                    await timeout_task
                except asyncio.CancelledError:
                    pass
            self.timeout_task = None

        # Clean up the tool handler.
        if hasattr(self, "func_handler") and self.func_handler:
            try:
                await self.func_handler.cleanup()
            except Exception as cleanup_error:
                self.logger.bind(tag=TAG).error(
                    f"清理工具处理器时出错: {cleanup_error}"
                )

        # Signal the worker threads to stop.
        if self.stop_event:
            self.stop_event.set()

        # Drain the task queues.
        self.clear_queues()

        # Close the websocket. Every branch of the earlier closed/state
        # probing ended in close(), so it is simply called once here.
        target = ws or self.websocket
        if target:
            try:
                await target.close()
            except Exception as ws_error:
                # Best effort: a failed close must not abort the cleanup.
                self.logger.bind(tag=TAG).debug(
                    f"关闭WebSocket连接时出错: {ws_error}"
                )

        if self.tts:
            await self.tts.close()

        # Shut the thread pool down last, without blocking.
        if self.executor:
            try:
                self.executor.shutdown(wait=False)
            except Exception as executor_error:
                self.logger.bind(tag=TAG).error(
                    f"关闭线程池时出错: {executor_error}"
                )
            self.executor = None

        self.logger.bind(tag=TAG).info("连接资源已释放")
    except Exception as e:
        self.logger.bind(tag=TAG).error(f"关闭连接时出错: {e}")
    finally:
        # Make sure the stop event is set whatever happened above.
        if self.stop_event:
            self.stop_event.set()
def clear_queues(self):
    """Drain the TTS text, TTS audio and report queues without blocking.

    Nothing happens (the report queue included) when no TTS provider is set.
    """
    if not self.tts:
        return

    self.logger.bind(tag=TAG).debug(
        f"开始清理: TTS队列大小={self.tts.tts_text_queue.qsize()}, 音频队列大小={self.tts.tts_audio_queue.qsize()}"
    )

    pending_queues = (
        self.tts.tts_text_queue,
        self.tts.tts_audio_queue,
        self.report_queue,
    )
    for pending in pending_queues:
        if not pending:
            continue
        # Pop until the queue reports empty; get_nowait never blocks.
        try:
            while True:
                pending.get_nowait()
        except queue.Empty:
            pass

    self.logger.bind(tag=TAG).debug(
        f"清理结束: TTS队列大小={self.tts.tts_text_queue.qsize()}, 音频队列大小={self.tts.tts_audio_queue.qsize()}"
    )
def reset_vad_states(self):
    """Reset the VAD bookkeeping: empty audio buffer, both voice flags off."""
    self.client_have_voice = False
    self.client_voice_stop = False
    self.client_audio_buffer = bytearray()
    self.logger.bind(tag=TAG).debug("VAD states reset.")
def chat_and_close(self, text):
    """Run a chat turn for ``text``, then flag the connection for closing.

    The flag is only set when ``chat`` returns without raising; errors are
    logged and swallowed.
    """
    try:
        self.chat(text)
    except Exception as e:
        self.logger.bind(tag=TAG).error(f"Chat and close error: {str(e)}")
    else:
        # The chat finished: request closing afterwards.
        self.close_after_chat = True
async def _check_timeout(self):
    """Watchdog that closes the connection after prolonged inactivity.

    Every 10 seconds the time since ``last_activity_time`` (milliseconds) is
    compared with ``timeout_seconds``. On expiry the stop event is set and
    the connection is closed. The loop also ends once the stop event is set
    elsewhere.
    """
    try:
        while not self.stop_event.is_set():
            # Only meaningful once the timestamp has been initialised.
            if self.last_activity_time > 0.0:
                now_ms = time.time() * 1000
                expired = (
                    now_ms - self.last_activity_time > self.timeout_seconds * 1000
                )
                if expired:
                    if not self.stop_event.is_set():
                        self.logger.bind(tag=TAG).info("连接超时,准备关闭")
                        # Set the stop event first to avoid handling it twice.
                        self.stop_event.set()
                        # Guard the close so a failure cannot wedge the task.
                        try:
                            await self.close(self.websocket)
                        except Exception as close_error:
                            self.logger.bind(tag=TAG).error(
                                f"超时关闭连接时出错: {close_error}"
                            )
                    break
            # Check every 10 seconds; more often is not needed.
            await asyncio.sleep(10)
    except Exception as e:
        self.logger.bind(tag=TAG).error(f"超时检查任务出错: {e}")
    finally:
        self.logger.bind(tag=TAG).info("超时检查任务已退出")