You cannot select more than 25 topics. Topics must start with a letter or number, can include dashes ('-'), and can be up to 35 characters long.

149 lines
4.6 KiB

This file contains ambiguous Unicode characters!

This file contains ambiguous Unicode characters that may be confused with others in your current locale. If your use case is intentional and legitimate, you can safely ignore this warning. Use the Escape button to highlight these characters.

import os
import yaml
from collections.abc import Mapping
from config.manage_api_client import init_service, get_server_config, get_agent_models
def get_project_dir():
    """Return the project root directory as a string ending in "/".

    The root is taken to be the parent of the directory that contains
    this file.
    """
    this_file = os.path.abspath(__file__)
    containing_dir = os.path.dirname(this_file)
    root_dir = os.path.dirname(containing_dir)
    return f"{root_dir}/"
def read_config(config_path):
    """Parse the UTF-8 YAML file at ``config_path`` and return the result.

    Args:
        config_path: Path of the YAML file to read.

    Returns:
        Whatever ``yaml.safe_load`` produces for the file's content.
    """
    with open(config_path, "r", encoding="utf-8") as stream:
        return yaml.safe_load(stream)
def load_config():
    """Load the configuration, serving it from the cache when available.

    On a cache miss, the default ``config.yaml`` and the custom
    ``data/.config.yaml`` are read from the project directory. If the custom
    file sets ``manager-api.url``, the configuration is fetched through
    ``get_config_from_api``; otherwise the two files are merged with the
    custom values taking priority. The configured directories are then
    created and the result is cached under the key ``"main_config"``.

    Returns:
        The loaded configuration.
    """
    # NOTE(review): imported at call time rather than at module level,
    # presumably to avoid a circular import — confirm before moving it.
    from core.utils.cache.manager import cache_manager, CacheType

    # Check the cache first.
    cached_config = cache_manager.get(CacheType.CONFIG, "main_config")
    if cached_config is not None:
        return cached_config

    default_config_path = get_project_dir() + "config.yaml"
    custom_config_path = get_project_dir() + "data/.config.yaml"

    # yaml.safe_load yields None for an empty document; treat that as an
    # empty mapping so the lookups below do not fail on None.
    default_config = read_config(default_config_path) or {}
    custom_config = read_config(custom_config_path) or {}

    # A "manager-api:" key with no children is parsed as None as well.
    manager_api = custom_config.get("manager-api") or {}
    if manager_api.get("url"):
        config = get_config_from_api(custom_config)
    else:
        # Merge the two files; custom values override the defaults.
        config = merge_configs(default_config, custom_config)

    # Create the directories the configuration refers to.
    ensure_directories(config)

    # Cache the result for later calls.
    cache_manager.set(CacheType.CONFIG, "main_config", config)
    return config
def get_config_from_api(config):
    """Fetch the configuration from the Java API.

    Args:
        config: Local configuration; its ``manager-api`` section is passed to
            the API client and copied into the result, and its ``server``
            section (when non-empty) replaces the one from the API.

    Returns:
        The configuration returned by the API, flagged with
        ``read_config_from_api`` and patched with the local sections.

    Raises:
        Exception: If the API returns no server configuration.
    """
    # Set up the API client, then ask it for the server configuration.
    init_service(config)
    remote_config = get_server_config()
    if remote_config is None:
        raise Exception("Failed to fetch server config from API")

    remote_config["read_config_from_api"] = True

    api_section = config["manager-api"]
    remote_config["manager-api"] = {
        field: api_section.get(field, "") for field in ("url", "secret")
    }

    # The local file is authoritative for the server section.
    local_server = config.get("server")
    if local_server:
        server_fields = ("ip", "port", "http_port", "vision_explain", "auth_key")
        remote_config["server"] = {
            field: local_server.get(field, "") for field in server_fields
        }

    return remote_config
def get_private_config_from_api(config, device_id, client_id):
    """Fetch the private configuration from the Java API.

    Args:
        config: Configuration containing a ``selected_module`` entry.
        device_id: Device identifier forwarded to ``get_agent_models``.
        client_id: Client identifier forwarded to ``get_agent_models``.

    Returns:
        The value returned by ``get_agent_models``.
    """
    selected_module = config["selected_module"]
    return get_agent_models(device_id, client_id, selected_module)
def ensure_directories(config):
    """Create every directory the configuration refers to.

    Three groups of paths are collected and then created in one pass:

    * the log directory (``log.log_dir``, default ``"tmp"``), joined onto the
      project directory;
    * the ``output_dir`` of every ASR and TTS provider, used as given;
    * the ``output_dir`` of the provider selected in ``selected_module`` for
      ASR, LLM and TTS, joined onto the project directory.

    A directory that cannot be created because of a ``PermissionError`` is
    reported with a printed warning and skipped.

    Args:
        config: The loaded configuration mapping.
    """
    dirs_to_create = set()
    project_dir = get_project_dir()

    # Log file directory.
    log_dir = config.get("log", {}).get("log_dir", "tmp")
    dirs_to_create.add(os.path.join(project_dir, log_dir))

    # Output directories of all ASR/TTS providers.
    for module in ["ASR", "TTS"]:
        providers = config.get(module)
        if providers is None:
            continue
        for provider in providers.values():
            output_dir = provider.get("output_dir", "")
            if output_dir:
                dirs_to_create.add(output_dir)

    # Output directory of the provider chosen for each module type.
    selected_modules = config.get("selected_module", {})
    for module_type in ["ASR", "LLM", "TTS"]:
        selected_provider = selected_modules.get(module_type)
        if not selected_provider:
            continue
        # Look the section up by this loop's module_type; the previous
        # loop's variable must not be reused here.
        module_config = config.get(module_type)
        if module_config is None:
            continue
        # The provider entry lives inside its module section, not at the
        # top level of the configuration.
        provider_config = module_config.get(selected_provider)
        if provider_config is None:
            continue
        output_dir = provider_config.get("output_dir")
        if output_dir:
            dirs_to_create.add(os.path.join(project_dir, output_dir))

    # Create everything in one pass.
    for dir_path in dirs_to_create:
        try:
            os.makedirs(dir_path, exist_ok=True)
        except PermissionError:
            print(f"警告:无法创建目录 {dir_path},请检查写入权限")
def merge_configs(default_config, custom_config):
    """Recursively merge two configurations; custom values take priority.

    Args:
        default_config: The default configuration.
        custom_config: The user-defined configuration.

    Returns:
        The merged configuration. If either argument is not a mapping,
        ``custom_config`` is returned unchanged.
    """
    both_are_mappings = isinstance(default_config, Mapping) and isinstance(
        custom_config, Mapping
    )
    if not both_are_mappings:
        return custom_config

    result = dict(default_config)
    for name, override in custom_config.items():
        base = result.get(name)
        if isinstance(base, Mapping) and isinstance(override, Mapping):
            # Both sides hold a nested section: merge them key by key.
            result[name] = merge_configs(base, override)
        else:
            # Otherwise the custom value replaces the default outright.
            result[name] = override
    return result