|
|
|
@ -1,88 +1,69 @@
|
|
|
|
|
import re
|
|
|
|
|
from typing import List, Dict, Any
|
|
|
|
|
import requests
|
|
|
|
|
import sqlite3
|
|
|
|
|
import psycopg2
|
|
|
|
|
from vanna.base import VannaBase
|
|
|
|
|
from Config import *
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
class VannaUtil(VannaBase):
|
|
|
|
|
def __init__(self, db_type='sqlite', db_uri=None):
|
|
|
|
|
def __init__(self, db_uri=None):
|
|
|
|
|
super().__init__()
|
|
|
|
|
self.api_key = MODEL_API_KEY
|
|
|
|
|
self.base_url = MODEL_GENERATION_TEXT_URL # 阿里云专用API地址
|
|
|
|
|
self.model = QWEN_MODEL_NAME # 根据实际模型名称调整
|
|
|
|
|
self.training_data = []
|
|
|
|
|
self.chat_history = []
|
|
|
|
|
self.db_type = db_type
|
|
|
|
|
self.db_uri = db_uri or 'Db/vanna.db' # 默认使用 SQLite
|
|
|
|
|
self.db_uri = db_uri or VANNA_POSTGRESQL_URI # 默认 PostgreSQL 连接字符串
|
|
|
|
|
self._init_db()
|
|
|
|
|
|
|
|
|
|
def _init_db(self):
|
|
|
|
|
"""初始化数据库连接"""
|
|
|
|
|
if self.db_type == 'sqlite':
|
|
|
|
|
self.conn = sqlite3.connect(self.db_uri)
|
|
|
|
|
self._create_tables()
|
|
|
|
|
elif self.db_type == 'postgres':
|
|
|
|
|
import psycopg2
|
|
|
|
|
self.conn = psycopg2.connect(self.db_uri)
|
|
|
|
|
self._create_tables()
|
|
|
|
|
else:
|
|
|
|
|
raise ValueError(f"Unsupported database type: {self.db_type}")
|
|
|
|
|
"""初始化 PostgreSQL 数据库连接"""
|
|
|
|
|
self.conn = psycopg2.connect(self.db_uri)
|
|
|
|
|
self._create_tables()
|
|
|
|
|
|
|
|
|
|
def _create_tables(self):
|
|
|
|
|
"""创建训练数据表"""
|
|
|
|
|
cursor = self.conn.cursor()
|
|
|
|
|
if self.db_type == 'sqlite':
|
|
|
|
|
cursor.execute('''
|
|
|
|
|
CREATE TABLE IF NOT EXISTS training_data (
|
|
|
|
|
id INTEGER PRIMARY KEY AUTOINCREMENT,
|
|
|
|
|
type TEXT,
|
|
|
|
|
question TEXT,
|
|
|
|
|
sql TEXT,
|
|
|
|
|
content TEXT
|
|
|
|
|
)
|
|
|
|
|
''')
|
|
|
|
|
elif self.db_type == 'postgres':
|
|
|
|
|
cursor.execute('''
|
|
|
|
|
CREATE TABLE IF NOT EXISTS training_data (
|
|
|
|
|
id SERIAL PRIMARY KEY,
|
|
|
|
|
type TEXT,
|
|
|
|
|
question TEXT,
|
|
|
|
|
sql TEXT,
|
|
|
|
|
content TEXT
|
|
|
|
|
)
|
|
|
|
|
''')
|
|
|
|
|
cursor.execute('''
|
|
|
|
|
CREATE TABLE IF NOT EXISTS training_data (
|
|
|
|
|
id SERIAL PRIMARY KEY,
|
|
|
|
|
type TEXT,
|
|
|
|
|
question TEXT,
|
|
|
|
|
sql TEXT,
|
|
|
|
|
content TEXT
|
|
|
|
|
)
|
|
|
|
|
''')
|
|
|
|
|
self.conn.commit()
|
|
|
|
|
|
|
|
|
|
def add_ddl(self, ddl: str, **kwargs) -> None:
|
|
|
|
|
"""添加 DDL"""
|
|
|
|
|
cursor = self.conn.cursor()
|
|
|
|
|
cursor.execute('INSERT INTO training_data (type, content) VALUES (?, ?)', ('ddl', ddl))
|
|
|
|
|
cursor.execute('INSERT INTO training_data (type, content) VALUES (%s, %s)', ('ddl', ddl))
|
|
|
|
|
self.conn.commit()
|
|
|
|
|
|
|
|
|
|
def add_documentation(self, doc: str, **kwargs) -> None:
|
|
|
|
|
"""添加文档"""
|
|
|
|
|
cursor = self.conn.cursor()
|
|
|
|
|
cursor.execute('INSERT INTO training_data (type, content) VALUES (?, ?)', ('documentation', doc))
|
|
|
|
|
cursor.execute('INSERT INTO training_data (type, content) VALUES (%s, %s)', ('documentation', doc))
|
|
|
|
|
self.conn.commit()
|
|
|
|
|
|
|
|
|
|
def add_question_sql(self, question: str, sql: str, **kwargs) -> None:
|
|
|
|
|
"""添加问答对"""
|
|
|
|
|
cursor = self.conn.cursor()
|
|
|
|
|
cursor.execute('INSERT INTO training_data (type, question, sql) VALUES (?, ?, ?)', ('qa', question, sql))
|
|
|
|
|
cursor.execute('INSERT INTO training_data (type, question, sql) VALUES (%s, %s, %s)', ('qa', question, sql))
|
|
|
|
|
self.conn.commit()
|
|
|
|
|
|
|
|
|
|
def get_related_ddl(self, question: str, **kwargs) -> str:
|
|
|
|
|
"""获取相关 DDL"""
|
|
|
|
|
cursor = self.conn.cursor()
|
|
|
|
|
cursor.execute('SELECT content FROM training_data WHERE type = ?', ('ddl',))
|
|
|
|
|
cursor.execute('SELECT content FROM training_data WHERE type = %s', ('ddl',))
|
|
|
|
|
return "\n".join(row[0] for row in cursor.fetchall())
|
|
|
|
|
|
|
|
|
|
def get_related_documentation(self, question: str, **kwargs) -> str:
|
|
|
|
|
"""获取相关文档"""
|
|
|
|
|
cursor = self.conn.cursor()
|
|
|
|
|
cursor.execute('SELECT content FROM training_data WHERE type = ?', ('documentation',))
|
|
|
|
|
cursor.execute('SELECT content FROM training_data WHERE type = %s', ('documentation',))
|
|
|
|
|
return "\n".join(row[0] for row in cursor.fetchall())
|
|
|
|
|
|
|
|
|
|
def get_training_data(self, **kwargs) -> List[Dict[str, Any]]:
|
|
|
|
@ -95,7 +76,7 @@ class VannaUtil(VannaBase):
|
|
|
|
|
def remove_training_data(self, id: str, **kwargs) -> bool:
|
|
|
|
|
"""删除训练数据"""
|
|
|
|
|
cursor = self.conn.cursor()
|
|
|
|
|
cursor.execute('DELETE FROM training_data WHERE id = ?', (id,))
|
|
|
|
|
cursor.execute('DELETE FROM training_data WHERE id = %s', (id,))
|
|
|
|
|
self.conn.commit()
|
|
|
|
|
return cursor.rowcount > 0
|
|
|
|
|
|
|
|
|
|