From 498971b9f3965779b2c8b31a1522b673ba719a1d Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?=E9=BB=84=E6=B5=B7?= <10402852@qq.com> Date: Mon, 24 Feb 2025 12:00:12 +0800 Subject: [PATCH] 'commit' --- AI/SyncData/C3_ThreadPoolSyncData.py | 84 ++++++++++++++++++++ AI/Vanna/VannaQwen.py | 110 +++++++++++++++++++++++++++ 2 files changed, 194 insertions(+) create mode 100644 AI/SyncData/C3_ThreadPoolSyncData.py create mode 100644 AI/Vanna/VannaQwen.py diff --git a/AI/SyncData/C3_ThreadPoolSyncData.py b/AI/SyncData/C3_ThreadPoolSyncData.py new file mode 100644 index 00000000..e3bc8242 --- /dev/null +++ b/AI/SyncData/C3_ThreadPoolSyncData.py @@ -0,0 +1,84 @@ +import logging +from connectors.mysql_connector import MySQLConnector +from connectors.clickhouse_connector import ClickHouseConnector +from mappers.data_mapper import DataMapper +from services.sync_service import SyncService, SyncType +from utils.logger import configure_logger +from apscheduler.schedulers.background import BackgroundScheduler +import time +from config.db_config import MYSQL_CONFIG, CH_CONFIG, print_config +from concurrent.futures import ThreadPoolExecutor + +# 要同步的表 +tables = ['t_station', 't_equipment', 't_equipment_charge_order'] +# 线程数量 +MAX_WORKERS =8 + +# 配置日志记录 +logger = configure_logger() + +# 添加文件处理器 +file_handler = logging.FileHandler('sync.log') +file_handler.setLevel(logging.INFO) +formatter = logging.Formatter('%(asctime)s - %(levelname)s - %(message)s') +file_handler.setFormatter(formatter) +logger.addHandler(file_handler) + +# 同步函数 +def sync(table, sync_type): + logger.info(f"开始同步表: {table}, 同步类型: {sync_type}") + + # 初始化组件 + mysql_conn = MySQLConnector(MYSQL_CONFIG) + ch_conn = ClickHouseConnector(CH_CONFIG) + + if sync_type == SyncType.FULL: + # 清空目标表 + ch_conn.connect().execute(f"TRUNCATE TABLE {table}") + logger.info(f"已清空目标表: {table}") + ch_conn.disconnect() + + # 创建数据映射器 + mapper = DataMapper(mysql_conn, table) + + # 创建同步服务 + service = SyncService(mysql_conn=mysql_conn, ch_conn=ch_conn, mapper=mapper) + + try: + # 同步数据 + service.sync_data(batch_size=5000, table=table, sync_type=sync_type) # 确保这里传递的是 SyncType 枚举 + logger.info(f"同步成功: {table}") + except Exception as e: + logger.error(f"同步失败: {str(e)}", exc_info=True) + finally: + ch_conn.disconnect() + +# 全量同步 +def full_sync(): + with ThreadPoolExecutor(max_workers=MAX_WORKERS) as executor: # 设置最大线程数 + executor.map(lambda table: sync(table, SyncType.FULL), tables) + +# 增量同步 +def incremental_sync(): + with ThreadPoolExecutor(max_workers=MAX_WORKERS) as executor: # 设置最大线程数 + executor.map(lambda table: sync(table, SyncType.INCREMENTAL), tables) + +if __name__ == "__main__": + # 输出配置信息 + print_config() + # 先执行一次增量同步 + incremental_sync() + # 创建调度器 + scheduler = BackgroundScheduler() + # 每天2点执行全量同步 + scheduler.add_job(full_sync, 'cron', hour=2, minute=0) + # 每小时执行增量同步,排除2点 + scheduler.add_job(incremental_sync, 'cron', hour='*', minute=0, id='incremental_sync_job', replace_existing=True) + # 启动调度器 + scheduler.start() + try: + # 保持主线程运行 + while True: + time.sleep(1) + except (KeyboardInterrupt, SystemExit): + scheduler.shutdown() \ No newline at end of file diff --git a/AI/Vanna/VannaQwen.py b/AI/Vanna/VannaQwen.py new file mode 100644 index 00000000..253d7bac --- /dev/null +++ b/AI/Vanna/VannaQwen.py @@ -0,0 +1,110 @@ +import os +import requests +from vanna.base import VannaBase +from typing import List, Dict, Any +from Config import * + + +class QwenChat(VannaBase): + def __init__(self): + super().__init__() + self.api_key = MODEL_API_KEY + self.model = "qwen-max" + self.training_data = [] + self.chat_history = [] # 新增:用于存储对话历史 + + # ---------- 必须实现的抽象方法 ---------- + def add_ddl(self, ddl: str, **kwargs) -> None: + self.training_data.append({"type": "ddl", "content": ddl}) + + def add_documentation(self, doc: str, **kwargs) -> None: + self.training_data.append({"type": "documentation", "content": doc}) + + def add_question_sql(self, question: str, sql: str, **kwargs) -> None: + self.training_data.append({"type": "qa", "question": question, "sql": sql}) + + def get_related_ddl(self, question: str, **kwargs) -> str: + return "\n".join([item["content"] for item in self.training_data if item["type"] == "ddl"]) + + def get_related_documentation(self, question: str, **kwargs) -> str: + return "\n".join([item["content"] for item in self.training_data if item["type"] == "documentation"]) + + def get_training_data(self, **kwargs) -> List[Dict[str, Any]]: + return self.training_data + + # ---------- 新增的必需方法 ---------- + def system_message(self, message: str) -> None: + """系统消息(用于初始化对话)""" + self.chat_history.append({"role": "system", "content": message}) + + def user_message(self, message: str) -> None: + """用户消息""" + self.chat_history.append({"role": "user", "content": message}) + + def assistant_message(self, message: str) -> None: + """助手消息""" + self.chat_history.append({"role": "assistant", "content": message}) + + def submit_prompt(self, prompt: str, **kwargs) -> str: + """提交提示词(这里直接调用 Qwen API)""" + return self.generate_sql(question=prompt) + + # ---------- 其他需要实现的方法 ---------- + def remove_training_data(self, id: str, **kwargs) -> bool: + return True + + def generate_embedding(self, text: str, **kwargs) -> List[float]: + return [] + + def get_similar_question_sql(self, question: str, **kwargs) -> List[Dict[str, Any]]: + return [] + + # ---------- 原有生成 SQL 的方法 ---------- + def generate_sql(self, question: str, **kwargs) -> str: + url = "https://dashscope.aliyuncs.com/api/v1/services/aigc/text-generation/generation" + headers = { + "Authorization": f"Bearer {self.api_key}", + "Content-Type": "application/json", + } + prompt = f"基于以下 MySQL 表结构,生成 SQL 查询。只需返回 SQL 代码,不要解释。\n\n表结构:\n{self.get_related_ddl(question)}\n\n问题:{question}" + + data = { + "model": self.model, + "input": {"messages": [{"role": "user", "content": prompt}]}, + "parameters": {} + } + + response = requests.post(url, headers=headers, json=data) + if response.status_code == 200: + return response.json()["output"]["text"] + else: + raise Exception(f"Qwen API 错误: {response.text}") + + +# ---------------------------- +# 初始化并使用实例 +# ---------------------------- +vn = QwenChat() + +# 添加训练数据 +vn.train( + ddl=""" + CREATE TABLE employees ( + id INT PRIMARY KEY, + name VARCHAR(100), + department VARCHAR(50), + salary DECIMAL(10, 2), + hire_date DATE + ); + """ +) + +vn.train(documentation="部门 'Sales' 的员工工资超过 50000") + +# 生成 SQL +question = "显示 2020 年后入职的 Sales 部门员工姓名和工资,且工资大于 50000" +try: + sql = vn.generate_sql(question=question) + print("生成的 SQL:\n", sql) +except Exception as e: + print("错误:", str(e)) \ No newline at end of file