main
黄海 5 months ago
parent ea738b77e2
commit 498971b9f3

@ -0,0 +1,84 @@
import logging
from connectors.mysql_connector import MySQLConnector
from connectors.clickhouse_connector import ClickHouseConnector
from mappers.data_mapper import DataMapper
from services.sync_service import SyncService, SyncType
from utils.logger import configure_logger
from apscheduler.schedulers.background import BackgroundScheduler
import time
from config.db_config import MYSQL_CONFIG, CH_CONFIG, print_config
from concurrent.futures import ThreadPoolExecutor
# Tables to synchronize from MySQL to ClickHouse.
tables = ['t_station', 't_equipment', 't_equipment_charge_order']
# Worker-thread count for the sync thread pools.
MAX_WORKERS = 8
# Project-wide logger (console setup comes from utils.logger).
logger = configure_logger()
# Attach a file handler. Log messages contain non-ASCII (Chinese) text,
# so pin the encoding to UTF-8 instead of relying on the platform default,
# which can raise UnicodeEncodeError on ASCII/latin locales.
file_handler = logging.FileHandler('sync.log', encoding='utf-8')
file_handler.setLevel(logging.INFO)
formatter = logging.Formatter('%(asctime)s - %(levelname)s - %(message)s')
file_handler.setFormatter(formatter)
logger.addHandler(file_handler)
# 同步函数
def sync(table, sync_type):
    """Synchronize a single table from MySQL into ClickHouse.

    Args:
        table: Name of the table to synchronize (same name on both sides).
        sync_type: SyncType.FULL (truncate the target first) or
            SyncType.INCREMENTAL.

    All failures are logged rather than raised, so this is safe to run as a
    thread-pool task whose result is never inspected.
    """
    logger.info(f"开始同步表: {table}, 同步类型: {sync_type}")
    # Fresh connectors per call so worker threads never share a connection.
    mysql_conn = MySQLConnector(MYSQL_CONFIG)
    ch_conn = ClickHouseConnector(CH_CONFIG)
    try:
        if sync_type == SyncType.FULL:
            # Full sync starts from an empty target table. This now runs
            # inside the try block: previously a ClickHouse failure here
            # escaped the worker thread unlogged (executor.map never
            # consumes results) and skipped the finally-disconnect.
            ch_conn.connect().execute(f"TRUNCATE TABLE {table}")
            logger.info(f"已清空目标表: {table}")
            # NOTE(review): the original code drops the connection right
            # after the truncate and SyncService still works, so the
            # service appears to (re)connect on its own — confirm before
            # removing this call.
            ch_conn.disconnect()
        # Mapper translates MySQL rows/types into ClickHouse form.
        mapper = DataMapper(mysql_conn, table)
        service = SyncService(mysql_conn=mysql_conn, ch_conn=ch_conn, mapper=mapper)
        # Pass the SyncType enum through unchanged; rows move in batches of 5000.
        service.sync_data(batch_size=5000, table=table, sync_type=sync_type)
        logger.info(f"同步成功: {table}")
    except Exception as e:
        logger.error(f"同步失败: {str(e)}", exc_info=True)
    finally:
        # Always release the ClickHouse connection, success or failure.
        ch_conn.disconnect()
# 全量同步
def full_sync():
    """Run a FULL (truncate-and-reload) sync of every configured table."""
    with ThreadPoolExecutor(max_workers=MAX_WORKERS) as pool:
        # One task per table; sync() handles its own errors, so results
        # are intentionally not collected.
        for table in tables:
            pool.submit(sync, table, SyncType.FULL)
# 增量同步
def incremental_sync():
    """Run an INCREMENTAL sync of every configured table."""
    with ThreadPoolExecutor(max_workers=MAX_WORKERS) as pool:
        # One task per table; sync() handles its own errors, so results
        # are intentionally not collected.
        for table in tables:
            pool.submit(sync, table, SyncType.INCREMENTAL)
if __name__ == "__main__":
    # Show the effective DB configuration at startup.
    print_config()
    # Run one incremental pass immediately so data is fresh before the
    # first scheduled tick.
    incremental_sync()
    scheduler = BackgroundScheduler()
    # Full sync every day at 02:00.
    scheduler.add_job(full_sync, 'cron', hour=2, minute=0)
    # Incremental sync at the top of every OTHER hour. The original used
    # hour='*', which also fired at 02:00 and ran concurrently with the
    # full sync (racing its TRUNCATE); excluding hour 2 matches the
    # stated intent of "every hour except 2 o'clock".
    scheduler.add_job(incremental_sync, 'cron', hour='0-1,3-23', minute=0,
                      id='incremental_sync_job', replace_existing=True)
    scheduler.start()
    try:
        # Keep the main thread alive so the background scheduler can run.
        while True:
            time.sleep(1)
    except (KeyboardInterrupt, SystemExit):
        scheduler.shutdown()

@ -0,0 +1,110 @@
import os
import requests
from vanna.base import VannaBase
from typing import List, Dict, Any
from Config import *
class QwenChat(VannaBase):
    """Minimal Vanna backend that generates SQL via the Qwen (DashScope) API.

    Training data (DDL, documentation, question/SQL pairs) is kept in a
    plain in-memory list; no embeddings or vector search are implemented,
    so "related" lookups simply return everything of the matching type.
    """

    def __init__(self):
        super().__init__()
        self.api_key = MODEL_API_KEY       # DashScope API key from Config
        self.model = "qwen-max"            # Qwen model name sent to the API
        self.training_data = []            # list of {"type": ..., ...} records
        self.chat_history = []             # conversation turns ({"role", "content"})

    # ---------- required abstract methods ----------
    def add_ddl(self, ddl: str, **kwargs) -> None:
        self.training_data.append({"type": "ddl", "content": ddl})

    def add_documentation(self, doc: str, **kwargs) -> None:
        self.training_data.append({"type": "documentation", "content": doc})

    def add_question_sql(self, question: str, sql: str, **kwargs) -> None:
        self.training_data.append({"type": "qa", "question": question, "sql": sql})

    def get_related_ddl(self, question: str, **kwargs) -> str:
        # No similarity search: returns ALL stored DDL regardless of question.
        return "\n".join([item["content"] for item in self.training_data if item["type"] == "ddl"])

    def get_related_documentation(self, question: str, **kwargs) -> str:
        # No similarity search: returns ALL stored documentation.
        return "\n".join([item["content"] for item in self.training_data if item["type"] == "documentation"])

    def get_training_data(self, **kwargs) -> List[Dict[str, Any]]:
        return self.training_data

    # ---------- chat-history hooks ----------
    def system_message(self, message: str) -> None:
        """Record a system message (used to initialize the conversation)."""
        self.chat_history.append({"role": "system", "content": message})

    def user_message(self, message: str) -> None:
        """Record a user message."""
        self.chat_history.append({"role": "user", "content": message})

    def assistant_message(self, message: str) -> None:
        """Record an assistant message."""
        self.chat_history.append({"role": "assistant", "content": message})

    def submit_prompt(self, prompt: str, **kwargs) -> str:
        """Submit a prompt (delegates directly to the Qwen API)."""
        return self.generate_sql(question=prompt)

    # ---------- other methods the interface requires ----------
    def remove_training_data(self, id: str, **kwargs) -> bool:
        # Not implemented; pretend success so callers don't break.
        return True

    def generate_embedding(self, text: str, **kwargs) -> List[float]:
        # Embeddings are not supported by this backend.
        return []

    def get_similar_question_sql(self, question: str, **kwargs) -> List[Dict[str, Any]]:
        # Similarity search is not supported by this backend.
        return []

    # ---------- SQL generation via the DashScope HTTP API ----------
    def generate_sql(self, question: str, **kwargs) -> str:
        """Call the Qwen API and return the generated SQL text.

        Raises:
            Exception: when the API responds with a non-200 status.
            requests.Timeout / requests.RequestException: on network failure.
        """
        url = "https://dashscope.aliyuncs.com/api/v1/services/aigc/text-generation/generation"
        headers = {
            "Authorization": f"Bearer {self.api_key}",
            "Content-Type": "application/json",
        }
        prompt = f"基于以下 MySQL 表结构,生成 SQL 查询。只需返回 SQL 代码,不要解释。\n\n表结构:\n{self.get_related_ddl(question)}\n\n问题:{question}"
        data = {
            "model": self.model,
            "input": {"messages": [{"role": "user", "content": prompt}]},
            "parameters": {}
        }
        # timeout added: without it a stalled API call blocks the caller
        # forever (requests has no default timeout).
        response = requests.post(url, headers=headers, json=data, timeout=60)
        if response.status_code == 200:
            return response.json()["output"]["text"]
        else:
            raise Exception(f"Qwen API 错误: {response.text}")
# ----------------------------
# Initialize and use an instance
# ----------------------------
vn = QwenChat()
# Register training data: the table DDL that generate_sql() will embed
# in its prompt via get_related_ddl().
vn.train(
    ddl="""
    CREATE TABLE employees (
        id INT PRIMARY KEY,
        name VARCHAR(100),
        department VARCHAR(50),
        salary DECIMAL(10, 2),
        hire_date DATE
    );
    """
)
vn.train(documentation="部门 'Sales' 的员工工资超过 50000")
# Generate SQL for a natural-language question (performs a network call
# to the Qwen API; failures are printed, not raised).
question = "显示 2020 年后入职的 Sales 部门员工姓名和工资,且工资大于 50000"
try:
    sql = vn.generate_sql(question=question)
    print("生成的 SQL:\n", sql)
except Exception as e:
    print("错误:", str(e))
Loading…
Cancel
Save