You can not select more than 25 topics Topics must start with a letter or number, can include dashes ('-') and can be up to 35 characters long.

139 lines
5.5 KiB

6 days ago
import asyncio
import websockets
from config.logger import setup_logging
from core.connection import ConnectionHandler
from config.config_loader import get_config_from_api
from core.utils.modules_initialize import initialize_modules
from core.utils.util import check_vad_update, check_asr_update
TAG = __name__
class WebSocketServer:
def __init__(self, config: dict):
self.config = config
self.logger = setup_logging()
self.config_lock = asyncio.Lock()
modules = initialize_modules(
self.logger,
self.config,
"VAD" in self.config["selected_module"],
"ASR" in self.config["selected_module"],
"LLM" in self.config["selected_module"],
False,
"Memory" in self.config["selected_module"],
"Intent" in self.config["selected_module"],
)
self._vad = modules["vad"] if "vad" in modules else None
self._asr = modules["asr"] if "asr" in modules else None
self._llm = modules["llm"] if "llm" in modules else None
self._intent = modules["intent"] if "intent" in modules else None
self._memory = modules["memory"] if "memory" in modules else None
self.active_connections = set()
async def start(self):
server_config = self.config["server"]
host = server_config.get("ip", "0.0.0.0")
port = int(server_config.get("port", 8000))
async with websockets.serve(
self._handle_connection, host, port, process_request=self._http_response
):
await asyncio.Future()
async def _handle_connection(self, websocket):
"""处理新连接每次创建独立的ConnectionHandler"""
# 创建ConnectionHandler时传入当前server实例
handler = ConnectionHandler(
self.config,
self._vad,
self._asr,
self._llm,
self._memory,
self._intent,
self, # 传入server实例
)
self.active_connections.add(handler)
try:
await handler.handle_connection(websocket)
except Exception as e:
self.logger.bind(tag=TAG).error(f"处理连接时出错: {e}")
finally:
# 确保从活动连接集合中移除
self.active_connections.discard(handler)
# 强制关闭连接(如果还没有关闭的话)
try:
# 安全地检查WebSocket状态并关闭
if hasattr(websocket, "closed") and not websocket.closed:
await websocket.close()
elif hasattr(websocket, "state") and websocket.state.name != "CLOSED":
await websocket.close()
else:
# 如果没有closed属性直接尝试关闭
await websocket.close()
except Exception as close_error:
self.logger.bind(tag=TAG).error(
f"服务器端强制关闭连接时出错: {close_error}"
)
async def _http_response(self, websocket, request_headers):
# 检查是否为 WebSocket 升级请求
if request_headers.headers.get("connection", "").lower() == "upgrade":
# 如果是 WebSocket 请求,返回 None 允许握手继续
return None
else:
# 如果是普通 HTTP 请求,返回 "server is running"
return websocket.respond(200, "Server is running\n")
async def update_config(self) -> bool:
"""更新服务器配置并重新初始化组件
Returns:
bool: 更新是否成功
"""
try:
async with self.config_lock:
# 重新获取配置
new_config = get_config_from_api(self.config)
if new_config is None:
self.logger.bind(tag=TAG).error("获取新配置失败")
return False
self.logger.bind(tag=TAG).info(f"获取新配置成功")
# 检查 VAD 和 ASR 类型是否需要更新
update_vad = check_vad_update(self.config, new_config)
update_asr = check_asr_update(self.config, new_config)
self.logger.bind(tag=TAG).info(
f"检查VAD和ASR类型是否需要更新: {update_vad} {update_asr}"
)
# 更新配置
self.config = new_config
# 重新初始化组件
modules = initialize_modules(
self.logger,
new_config,
update_vad,
update_asr,
"LLM" in new_config["selected_module"],
False,
"Memory" in new_config["selected_module"],
"Intent" in new_config["selected_module"],
)
# 更新组件实例
if "vad" in modules:
self._vad = modules["vad"]
if "asr" in modules:
self._asr = modules["asr"]
if "llm" in modules:
self._llm = modules["llm"]
if "intent" in modules:
self._intent = modules["intent"]
if "memory" in modules:
self._memory = modules["memory"]
self.logger.bind(tag=TAG).info(f"更新配置任务执行完毕")
return True
except Exception as e:
self.logger.bind(tag=TAG).error(f"更新服务器配置失败: {str(e)}")
return False