You cannot select more than 25 topics. Topics must start with a letter or number, can include dashes ('-'), and can be up to 35 characters long.

182 lines
5.8 KiB

This file contains ambiguous Unicode characters!

This file contains ambiguous Unicode characters that may be confused with others in your current locale. If your use case is intentional and legitimate, you can safely ignore this warning. Use the Escape button to highlight these characters.

import os
import argparse
import requests
import yaml
import time
# Module-level cache for the loaded configuration; stays None until
# load_config() stores the config it has loaded.
_config_cache = None
def get_project_dir():
    """Return the project root directory as a string ending in "/".

    The root is the parent of the directory that contains this file.
    """
    this_file = os.path.abspath(__file__)
    containing_dir = os.path.dirname(this_file)
    project_root = os.path.dirname(containing_dir)
    return project_root + "/"
def read_config(config_path):
    """Parse the UTF-8 YAML file at ``config_path`` and return the result.

    Args:
        config_path: Path of the YAML file to read.

    Returns:
        Whatever ``yaml.safe_load`` produces for the file contents.
    """
    with open(config_path, mode="r", encoding="utf-8") as config_stream:
        return yaml.safe_load(config_stream)
def load_config():
    """Load the server configuration, caching it for subsequent calls.

    The config path comes from the ``--config_path`` command-line option,
    defaulting to the file chosen by ``get_config_file()``. When the loaded
    config has a non-empty ``manager-api.url``, the config is replaced by the
    one fetched through ``get_config_from_api``. The directories referenced by
    the final config are created before it is cached and returned.

    Returns:
        The loaded configuration mapping (the cached one on later calls).
    """
    global _config_cache
    if _config_cache is not None:
        return _config_cache

    parser = argparse.ArgumentParser(description="Server configuration")
    parser.add_argument("--config_path", type=str, default=get_config_file())
    args = parser.parse_args()

    # An empty YAML file loads as None; fall back to an empty mapping so the
    # lookups below do not fail with an opaque AttributeError.
    config = read_config(args.config_path) or {}

    # A bare `manager-api:` key loads as None, which `.get(key, {})` would
    # hand back unchanged, so normalise it explicitly.
    manager_api = config.get("manager-api") or {}
    if manager_api.get("url"):
        config = get_config_from_api(config)

    # Create the directories the configuration refers to.
    ensure_directories(config)

    _config_cache = config
    return config
def get_config_file():
    """Choose which configuration file to load, preferring the private one.

    Returns:
        str: "data/.config.yaml" when that file exists under the project
        directory, otherwise "config.yaml". The returned path is relative.
    """
    default_config_file = "config.yaml"
    private_config_file = "data/." + default_config_file
    if os.path.exists(get_project_dir() + private_config_file):
        return private_config_file
    return default_config_file
def _make_api_request(
    api_url,
    secret,
    endpoint,
    json_data=None,
    max_retries=10,
    retry_delay=2,
    timeout=30,
):
    """POST ``json_data`` to a manager-api endpoint and return its ``data`` field.

    Request-level failures (``requests`` exceptions) and non-200 responses are
    retried. A 200 response whose ``code`` is not 0 raises immediately and is
    not retried.

    Args:
        api_url: Base URL of the manager-api.
        secret: manager-api secret. It is only validated here; this function
            does not attach it to the request (callers put it in
            ``json_data``).
        endpoint: Path appended directly to ``api_url``.
        json_data: Object sent as the JSON request body.
        max_retries: Total number of attempts; values below 1 are treated
            as 1.
        retry_delay: Seconds to sleep between attempts.
        timeout: Per-request timeout in seconds handed to ``requests.post``,
            so an unresponsive server cannot block forever.

    Returns:
        The value of the ``data`` key in the JSON response.

    Raises:
        Exception: If the url or secret is missing, the secret is still the
            unconfigured placeholder, the API reports a non-zero ``code``,
            or every attempt fails.
    """
    if not api_url or not secret:
        raise Exception("manager-api的url或secret配置错误")
    # The previous check was `"" in secret`, which is true for every string
    # and therefore rejected every secret.
    # NOTE(review): this assumes the unconfigured template secret contains
    # the character "你" - confirm against the shipped config template.
    if "你" in secret:
        raise Exception("请先配置manager-api的secret")

    total_attempts = max(1, max_retries)
    for attempt in range(total_attempts):
        has_attempts_left = attempt < total_attempts - 1
        try:
            response = requests.post(
                f"{api_url}{endpoint}", json=json_data, timeout=timeout
            )
            if response.status_code == 200:
                result = response.json()
                if result.get("code") != 0:
                    raise Exception(f"API返回错误: {result.get('msg', '未知错误')}")
                return result.get("data")

            error_msg = f"manager-api请求失败状态码: {response.status_code}"
            try:
                error_data = response.json()
                if "msg" in error_data:
                    error_msg = f"{error_msg}, 错误信息: {error_data['msg']}"
            except (ValueError, TypeError):
                # ValueError: the body is not JSON; TypeError: the JSON is not
                # a mapping. Either way, fall back to the raw response text.
                error_msg = f"{error_msg}, 响应内容: {response.text}"

            if has_attempts_left:
                print(f"请求manager-api失败正在重试 ({attempt + 1}/{total_attempts})...")
                time.sleep(retry_delay)
            else:
                raise Exception(error_msg)
        except requests.exceptions.RequestException as e:
            if has_attempts_left:
                print(f"请求manager-api异常正在重试 ({attempt + 1}/{total_attempts})...")
                time.sleep(retry_delay)
            else:
                # Chain the cause so the underlying transport error is kept.
                raise Exception(f"manager-api请求异常: {str(e)}") from e
def get_config_from_api(config):
    """Fetch the server configuration from the manager-api (Java API).

    Args:
        config: Local configuration; its "manager-api" entry supplies the
            "url" and "secret" used for the request.

    Returns:
        The configuration returned by the "/config/server-base" endpoint,
        with "read_config_from_api" set to True and the local "manager-api"
        url/secret copied in.
    """
    # NOTE(review): an earlier comment here pointed at a local
    # config_from_api.yaml (under main/xiaozhi-server) and said its contents
    # reportedly have to be copied over for this path to work - unverified.
    manager_api = config["manager-api"]
    base_url = manager_api.get("url", "")
    api_secret = manager_api.get("secret", "")

    request_body = {"secret": api_secret}
    remote_config = _make_api_request(
        base_url, api_secret, "/config/server-base", request_body
    )

    # Mark the config as API-sourced and keep the connection details on it.
    remote_config["read_config_from_api"] = True
    remote_config["manager-api"] = {"url": base_url, "secret": api_secret}
    return remote_config
def get_private_config_from_api(config, device_id, client_id):
    """Fetch the private configuration from the manager-api (Java API).

    Args:
        config: Configuration holding the "manager-api" url/secret and the
            "selected_module" entry forwarded to the API.
        device_id: Sent to the API as "macAddress".
        client_id: Sent to the API as "clientId".

    Returns:
        The data returned by the "/config/agent-models" endpoint.
    """
    manager_api = config["manager-api"]
    base_url = manager_api.get("url", "")
    api_secret = manager_api.get("secret", "")

    request_body = {
        "secret": api_secret,
        "macAddress": device_id,
        "clientId": client_id,
        "selectedModule": config["selected_module"],
    }
    return _make_api_request(
        base_url, api_secret, "/config/agent-models", request_body
    )
def _dict_section(mapping, key):
    """Return ``mapping[key]`` when it is a dict, otherwise an empty dict."""
    value = mapping.get(key)
    return value if isinstance(value, dict) else {}


def ensure_directories(config):
    """Create every directory the configuration refers to.

    Sections or provider entries that are missing or null (for example a
    bare ``ASR:`` key in the YAML) are skipped instead of raising.

    Args:
        config: Configuration mapping. The keys read are "log" ("log_dir"),
            "ASR"/"TTS" providers ("output_dir"), and "selected_module".
    """
    dirs_to_create = set()
    project_dir = get_project_dir()

    # Log directory, resolved against the project root ("tmp" by default).
    log_dir = _dict_section(config, "log").get("log_dir", "tmp")
    if log_dir is None:
        # `log_dir:` with no value loads as None; use the default instead.
        log_dir = "tmp"
    dirs_to_create.add(os.path.join(project_dir, log_dir))

    # Output directories of every ASR/TTS provider. These are added as
    # configured (not joined with the project root), so a relative path
    # resolves against the current working directory.
    for module in ("ASR", "TTS"):
        for provider in _dict_section(config, module).values():
            if not isinstance(provider, dict):
                continue
            output_dir = provider.get("output_dir", "")
            if output_dir:
                dirs_to_create.add(output_dir)

    # Output directory of each selected provider, resolved against the
    # project root.
    selected_modules = _dict_section(config, "selected_module")
    for module_type in ("ASR", "LLM", "TTS"):
        selected_provider = selected_modules.get(module_type)
        if not selected_provider:
            continue
        provider_config = _dict_section(
            _dict_section(config, module_type), selected_provider
        )
        output_dir = provider_config.get("output_dir")
        if output_dir:
            dirs_to_create.add(os.path.join(project_dir, output_dir))

    # Create everything in one pass; a permission failure only warns.
    for dir_path in dirs_to_create:
        try:
            os.makedirs(dir_path, exist_ok=True)
        except PermissionError:
            print(f"警告:无法创建目录 {dir_path},请检查写入权限")