import os
import yaml
from collections.abc import Mapping
from config.manage_api_client import init_service, get_server_config, get_agent_models


def get_project_dir():
    """Return the project root directory, with a trailing "/".

    The root is the parent of the directory that contains this file.
    """
    return os.path.dirname(os.path.dirname(os.path.abspath(__file__))) + "/"


def read_config(config_path):
    """Parse the YAML file at *config_path* with ``yaml.safe_load``.

    Returns the parsed document as-is; an empty file yields ``None``
    rather than an empty dict.
    """
    with open(config_path, "r", encoding="utf-8") as file:
        config = yaml.safe_load(file)
    return config


def load_config():
    """Load the main configuration, preferring the cached copy.

    Reads ``config.yaml`` (defaults) and ``data/.config.yaml`` (user
    overrides) from the project directory. If the override file sets
    ``manager-api.url``, the configuration is fetched from the manager API
    instead; otherwise the two files are deep-merged with the overrides
    taking precedence. The directories referenced by the resulting config
    are created, and the config is cached under ``"main_config"``.

    Returns:
        The configuration mapping.

    Raises:
        FileNotFoundError: if either YAML file is missing.
        Exception: from get_config_from_api when the API returns no config.
    """
    # Local import — presumably to avoid an import cycle (TODO confirm).
    from core.utils.cache.manager import cache_manager, CacheType

    # Serve from the cache when a configuration was already loaded.
    cached_config = cache_manager.get(CacheType.CONFIG, "main_config")
    if cached_config is not None:
        return cached_config

    default_config_path = get_project_dir() + "config.yaml"
    custom_config_path = get_project_dir() + "data/.config.yaml"

    # An empty YAML file parses to None; normalise to {} so the lookups
    # and the merge below never operate on None.
    default_config = read_config(default_config_path) or {}
    custom_config = read_config(custom_config_path) or {}

    # A bare "manager-api:" key (no children) parses to None as well.
    manager_api = custom_config.get("manager-api") or {}
    if manager_api.get("url"):
        config = get_config_from_api(custom_config)
    else:
        # Deep-merge: values from the user file override the defaults.
        config = merge_configs(default_config, custom_config)

    # Create the directories the configuration refers to.
    ensure_directories(config)

    # Cache for subsequent calls.
    cache_manager.set(CacheType.CONFIG, "main_config", config)
    return config


def get_config_from_api(config):
    """Fetch the server configuration from the manager (Java) API.

    Args:
        config: The local configuration; must contain a "manager-api"
            mapping, and may contain a "server" mapping.

    Returns:
        The API configuration, with ``read_config_from_api`` set to True,
        ``manager-api`` url/secret copied from the local config, and — when
        the local config has a non-empty "server" section — the ``server``
        section replaced by the local values (missing keys become "").

    Raises:
        Exception: if the API returns no configuration.
    """
    # Initialise the API client from the local configuration.
    init_service(config)
    # Fetch the server configuration.
    config_data = get_server_config()
    if config_data is None:
        raise Exception("Failed to fetch server config from API")
    config_data["read_config_from_api"] = True
    config_data["manager-api"] = {
        "url": config["manager-api"].get("url", ""),
        "secret": config["manager-api"].get("secret", ""),
    }
    # The local "server" settings take precedence over the API's.
    if config.get("server"):
        config_data["server"] = {
            "ip": config["server"].get("ip", ""),
            "port": config["server"].get("port", ""),
            "http_port": config["server"].get("http_port", ""),
            "vision_explain": config["server"].get("vision_explain", ""),
            "auth_key": config["server"].get("auth_key", ""),
        }
    return config_data


def get_private_config_from_api(config, device_id, client_id):
    """Fetch the per-device agent model configuration from the manager API.

    Delegates to get_agent_models with the device id, client id and the
    ``selected_module`` section of *config*.
    """
    return get_agent_models(device_id, client_id, config["selected_module"])


def ensure_directories(config):
    """Create every directory referenced by the configuration.

    Collected directories:
      * the log directory (``log.log_dir``, default ``"tmp"``), joined to
        the project directory;
      * ``output_dir`` of every ASR and TTS provider entry, used as written;
      * ``output_dir`` of the provider named in ``selected_module`` for
        ASR, LLM and TTS, joined to the project directory.

    A PermissionError for one directory prints a warning and does not stop
    the remaining directories from being created.
    """
    dirs_to_create = set()
    project_dir = get_project_dir()

    # Log file directory.
    log_dir = (config.get("log") or {}).get("log_dir", "tmp")
    dirs_to_create.add(os.path.join(project_dir, log_dir))

    # Output directories of every configured ASR/TTS provider.
    for module in ["ASR", "TTS"]:
        module_config = config.get(module)
        if not isinstance(module_config, Mapping):
            continue
        for provider in module_config.values():
            # Skip malformed entries (e.g. an empty mapping parsed as None).
            if not isinstance(provider, Mapping):
                continue
            output_dir = provider.get("output_dir", "")
            if output_dir:
                dirs_to_create.add(output_dir)

    # Model directories of the providers chosen in selected_module.
    selected_modules = config.get("selected_module") or {}
    for module_type in ["ASR", "LLM", "TTS"]:
        selected_provider = selected_modules.get(module_type)
        if not selected_provider:
            continue
        # Look the provider up inside its own module section (not at the
        # top level of the config, and not via a leftover loop variable).
        module_config = config.get(module_type)
        if not isinstance(module_config, Mapping):
            continue
        provider_config = module_config.get(selected_provider)
        if not isinstance(provider_config, Mapping):
            continue
        output_dir = provider_config.get("output_dir")
        if output_dir:
            dirs_to_create.add(os.path.join(project_dir, output_dir))

    # Create all collected directories.
    for dir_path in dirs_to_create:
        try:
            os.makedirs(dir_path, exist_ok=True)
        except PermissionError:
            print(f"警告:无法创建目录 {dir_path},请检查写入权限")


def merge_configs(default_config, custom_config):
    """
    Recursively merge two configurations; custom_config takes precedence.

    Nested mappings present on both sides are merged key by key; any other
    value from custom_config replaces the default. If either argument is
    not a mapping, custom_config is returned unchanged.

    Args:
        default_config: the default configuration
        custom_config: the user-supplied configuration

    Returns:
        The merged configuration (a new dict; inputs are not mutated).
    """
    if not isinstance(default_config, Mapping) or not isinstance(
        custom_config, Mapping
    ):
        return custom_config

    merged = dict(default_config)

    for key, value in custom_config.items():
        if (
            key in merged
            and isinstance(merged[key], Mapping)
            and isinstance(value, Mapping)
        ):
            merged[key] = merge_configs(merged[key], value)
        else:
            merged[key] = value

    return merged