main
HuangHai 4 months ago
parent 26ae208519
commit 8336aa2b7d

@ -30,4 +30,7 @@ NEO4J_AUTH = ("neo4j", "DsideaL4r5t6y7u")
# Dify
DIFY_API_KEY = "app-jdd8mRNx3IJqVTYhRhmRfxtl"
DIFY_URL='http://10.10.14.207/v1'
DIFY_URL='http://10.10.14.207/v1'
# Vanna Postgresql
VANNA_POSTGRESQL_URI = "postgresql://postgres:DsideaL147258369@10.10.14.71:5432/szjz_db"

@ -1,88 +1,69 @@
import re
from typing import List, Dict, Any
import requests
import sqlite3
import psycopg2
from vanna.base import VannaBase
from Config import *
class VannaUtil(VannaBase):
def __init__(self, db_type='sqlite', db_uri=None):
def __init__(self, db_uri=None):
super().__init__()
self.api_key = MODEL_API_KEY
self.base_url = MODEL_GENERATION_TEXT_URL # 阿里云专用API地址
self.model = QWEN_MODEL_NAME # 根据实际模型名称调整
self.training_data = []
self.chat_history = []
self.db_type = db_type
self.db_uri = db_uri or 'Db/vanna.db' # 默认使用 SQLite
self.db_uri = db_uri or VANNA_POSTGRESQL_URI # 默认 PostgreSQL 连接字符串
self._init_db()
def _init_db(self):
"""初始化数据库连接"""
if self.db_type == 'sqlite':
self.conn = sqlite3.connect(self.db_uri)
self._create_tables()
elif self.db_type == 'postgres':
import psycopg2
self.conn = psycopg2.connect(self.db_uri)
self._create_tables()
else:
raise ValueError(f"Unsupported database type: {self.db_type}")
"""初始化 PostgreSQL 数据库连接"""
self.conn = psycopg2.connect(self.db_uri)
self._create_tables()
def _create_tables(self):
"""创建训练数据表"""
cursor = self.conn.cursor()
if self.db_type == 'sqlite':
cursor.execute('''
CREATE TABLE IF NOT EXISTS training_data (
id INTEGER PRIMARY KEY AUTOINCREMENT,
type TEXT,
question TEXT,
sql TEXT,
content TEXT
)
''')
elif self.db_type == 'postgres':
cursor.execute('''
CREATE TABLE IF NOT EXISTS training_data (
id SERIAL PRIMARY KEY,
type TEXT,
question TEXT,
sql TEXT,
content TEXT
)
''')
cursor.execute('''
CREATE TABLE IF NOT EXISTS training_data (
id SERIAL PRIMARY KEY,
type TEXT,
question TEXT,
sql TEXT,
content TEXT
)
''')
self.conn.commit()
def add_ddl(self, ddl: str, **kwargs) -> None:
"""添加 DDL"""
cursor = self.conn.cursor()
cursor.execute('INSERT INTO training_data (type, content) VALUES (?, ?)', ('ddl', ddl))
cursor.execute('INSERT INTO training_data (type, content) VALUES (%s, %s)', ('ddl', ddl))
self.conn.commit()
def add_documentation(self, doc: str, **kwargs) -> None:
"""添加文档"""
cursor = self.conn.cursor()
cursor.execute('INSERT INTO training_data (type, content) VALUES (?, ?)', ('documentation', doc))
cursor.execute('INSERT INTO training_data (type, content) VALUES (%s, %s)', ('documentation', doc))
self.conn.commit()
def add_question_sql(self, question: str, sql: str, **kwargs) -> None:
"""添加问答对"""
cursor = self.conn.cursor()
cursor.execute('INSERT INTO training_data (type, question, sql) VALUES (?, ?, ?)', ('qa', question, sql))
cursor.execute('INSERT INTO training_data (type, question, sql) VALUES (%s, %s, %s)', ('qa', question, sql))
self.conn.commit()
def get_related_ddl(self, question: str, **kwargs) -> str:
"""获取相关 DDL"""
cursor = self.conn.cursor()
cursor.execute('SELECT content FROM training_data WHERE type = ?', ('ddl',))
cursor.execute('SELECT content FROM training_data WHERE type = %s', ('ddl',))
return "\n".join(row[0] for row in cursor.fetchall())
def get_related_documentation(self, question: str, **kwargs) -> str:
"""获取相关文档"""
cursor = self.conn.cursor()
cursor.execute('SELECT content FROM training_data WHERE type = ?', ('documentation',))
cursor.execute('SELECT content FROM training_data WHERE type = %s', ('documentation',))
return "\n".join(row[0] for row in cursor.fetchall())
def get_training_data(self, **kwargs) -> List[Dict[str, Any]]:
@ -95,7 +76,7 @@ class VannaUtil(VannaBase):
def remove_training_data(self, id: str, **kwargs) -> bool:
"""删除训练数据"""
cursor = self.conn.cursor()
cursor.execute('DELETE FROM training_data WHERE id = ?', (id,))
cursor.execute('DELETE FROM training_data WHERE id = %s', (id,))
self.conn.commit()
return cursor.rowcount > 0

@ -15,7 +15,7 @@ from Text2Sql.Util.VannaUtil import VannaUtil
app = FastAPI()
# 配置静态文件目录
app.mount("/static", StaticFiles(directory="static"), name="static")
vn = VannaUtil()
@app.get("/")
def read_root():
return {"message": "Welcome to Vanna AI SQL !"}
@ -70,14 +70,7 @@ def delete_question(question_id: int, db: Session = Depends(get_db)):
# 参数question
@app.post("/questions/get_excel")
def get_excel(question: str = Form(...)):
vn = VannaUtil()
# 指定学段
# question = '''
# 查询:
# 1、发布时间是2024年度
# 2、每个学段每个科目上传课程数量按由多到少排序
# 3、字段名: 学段,科目,排名,课程数量
# '''
common_prompt = '''
返回的信息要求
1行政区划为NULL 或者是空字符的不参加统计
@ -90,7 +83,6 @@ def get_excel(question: str = Form(...)):
# 执行SQL查询
with PostgreSQLUtil() as db:
_data = db.execute_query(sql)
print(_data)
# 在static目录下生成一个guid号的临时文件
uuidStr = str(uuid.uuid4())
filename = f"static/{uuidStr}.xlsx"

Loading…
Cancel
Save