main
HuangHai 4 months ago
parent 26ae208519
commit 8336aa2b7d

@ -30,4 +30,7 @@ NEO4J_AUTH = ("neo4j", "DsideaL4r5t6y7u")
# Dify # Dify
DIFY_API_KEY = "app-jdd8mRNx3IJqVTYhRhmRfxtl" DIFY_API_KEY = "app-jdd8mRNx3IJqVTYhRhmRfxtl"
DIFY_URL='http://10.10.14.207/v1' DIFY_URL='http://10.10.14.207/v1'
# Vanna Postgresql
VANNA_POSTGRESQL_URI = "postgresql://postgres:DsideaL147258369@10.10.14.71:5432/szjz_db"

@ -1,88 +1,69 @@
import re import re
from typing import List, Dict, Any from typing import List, Dict, Any
import requests import requests
import sqlite3 import psycopg2
from vanna.base import VannaBase from vanna.base import VannaBase
from Config import * from Config import *
class VannaUtil(VannaBase): class VannaUtil(VannaBase):
def __init__(self, db_type='sqlite', db_uri=None): def __init__(self, db_uri=None):
super().__init__() super().__init__()
self.api_key = MODEL_API_KEY self.api_key = MODEL_API_KEY
self.base_url = MODEL_GENERATION_TEXT_URL # 阿里云专用API地址 self.base_url = MODEL_GENERATION_TEXT_URL # 阿里云专用API地址
self.model = QWEN_MODEL_NAME # 根据实际模型名称调整 self.model = QWEN_MODEL_NAME # 根据实际模型名称调整
self.training_data = [] self.training_data = []
self.chat_history = [] self.chat_history = []
self.db_type = db_type self.db_uri = db_uri or VANNA_POSTGRESQL_URI # 默认 PostgreSQL 连接字符串
self.db_uri = db_uri or 'Db/vanna.db' # 默认使用 SQLite
self._init_db() self._init_db()
def _init_db(self): def _init_db(self):
"""初始化数据库连接""" """初始化 PostgreSQL 数据库连接"""
if self.db_type == 'sqlite': self.conn = psycopg2.connect(self.db_uri)
self.conn = sqlite3.connect(self.db_uri) self._create_tables()
self._create_tables()
elif self.db_type == 'postgres':
import psycopg2
self.conn = psycopg2.connect(self.db_uri)
self._create_tables()
else:
raise ValueError(f"Unsupported database type: {self.db_type}")
def _create_tables(self): def _create_tables(self):
"""创建训练数据表""" """创建训练数据表"""
cursor = self.conn.cursor() cursor = self.conn.cursor()
if self.db_type == 'sqlite': cursor.execute('''
cursor.execute(''' CREATE TABLE IF NOT EXISTS training_data (
CREATE TABLE IF NOT EXISTS training_data ( id SERIAL PRIMARY KEY,
id INTEGER PRIMARY KEY AUTOINCREMENT, type TEXT,
type TEXT, question TEXT,
question TEXT, sql TEXT,
sql TEXT, content TEXT
content TEXT )
) ''')
''')
elif self.db_type == 'postgres':
cursor.execute('''
CREATE TABLE IF NOT EXISTS training_data (
id SERIAL PRIMARY KEY,
type TEXT,
question TEXT,
sql TEXT,
content TEXT
)
''')
self.conn.commit() self.conn.commit()
def add_ddl(self, ddl: str, **kwargs) -> None: def add_ddl(self, ddl: str, **kwargs) -> None:
"""添加 DDL""" """添加 DDL"""
cursor = self.conn.cursor() cursor = self.conn.cursor()
cursor.execute('INSERT INTO training_data (type, content) VALUES (?, ?)', ('ddl', ddl)) cursor.execute('INSERT INTO training_data (type, content) VALUES (%s, %s)', ('ddl', ddl))
self.conn.commit() self.conn.commit()
def add_documentation(self, doc: str, **kwargs) -> None: def add_documentation(self, doc: str, **kwargs) -> None:
"""添加文档""" """添加文档"""
cursor = self.conn.cursor() cursor = self.conn.cursor()
cursor.execute('INSERT INTO training_data (type, content) VALUES (?, ?)', ('documentation', doc)) cursor.execute('INSERT INTO training_data (type, content) VALUES (%s, %s)', ('documentation', doc))
self.conn.commit() self.conn.commit()
def add_question_sql(self, question: str, sql: str, **kwargs) -> None: def add_question_sql(self, question: str, sql: str, **kwargs) -> None:
"""添加问答对""" """添加问答对"""
cursor = self.conn.cursor() cursor = self.conn.cursor()
cursor.execute('INSERT INTO training_data (type, question, sql) VALUES (?, ?, ?)', ('qa', question, sql)) cursor.execute('INSERT INTO training_data (type, question, sql) VALUES (%s, %s, %s)', ('qa', question, sql))
self.conn.commit() self.conn.commit()
def get_related_ddl(self, question: str, **kwargs) -> str: def get_related_ddl(self, question: str, **kwargs) -> str:
"""获取相关 DDL""" """获取相关 DDL"""
cursor = self.conn.cursor() cursor = self.conn.cursor()
cursor.execute('SELECT content FROM training_data WHERE type = ?', ('ddl',)) cursor.execute('SELECT content FROM training_data WHERE type = %s', ('ddl',))
return "\n".join(row[0] for row in cursor.fetchall()) return "\n".join(row[0] for row in cursor.fetchall())
def get_related_documentation(self, question: str, **kwargs) -> str: def get_related_documentation(self, question: str, **kwargs) -> str:
"""获取相关文档""" """获取相关文档"""
cursor = self.conn.cursor() cursor = self.conn.cursor()
cursor.execute('SELECT content FROM training_data WHERE type = ?', ('documentation',)) cursor.execute('SELECT content FROM training_data WHERE type = %s', ('documentation',))
return "\n".join(row[0] for row in cursor.fetchall()) return "\n".join(row[0] for row in cursor.fetchall())
def get_training_data(self, **kwargs) -> List[Dict[str, Any]]: def get_training_data(self, **kwargs) -> List[Dict[str, Any]]:
@ -95,7 +76,7 @@ class VannaUtil(VannaBase):
def remove_training_data(self, id: str, **kwargs) -> bool: def remove_training_data(self, id: str, **kwargs) -> bool:
"""删除训练数据""" """删除训练数据"""
cursor = self.conn.cursor() cursor = self.conn.cursor()
cursor.execute('DELETE FROM training_data WHERE id = ?', (id,)) cursor.execute('DELETE FROM training_data WHERE id = %s', (id,))
self.conn.commit() self.conn.commit()
return cursor.rowcount > 0 return cursor.rowcount > 0

@ -15,7 +15,7 @@ from Text2Sql.Util.VannaUtil import VannaUtil
app = FastAPI() app = FastAPI()
# 配置静态文件目录 # 配置静态文件目录
app.mount("/static", StaticFiles(directory="static"), name="static") app.mount("/static", StaticFiles(directory="static"), name="static")
vn = VannaUtil()
@app.get("/") @app.get("/")
def read_root(): def read_root():
return {"message": "Welcome to Vanna AI SQL !"} return {"message": "Welcome to Vanna AI SQL !"}
@ -70,14 +70,7 @@ def delete_question(question_id: int, db: Session = Depends(get_db)):
# 参数question # 参数question
@app.post("/questions/get_excel") @app.post("/questions/get_excel")
def get_excel(question: str = Form(...)): def get_excel(question: str = Form(...)):
vn = VannaUtil()
# 指定学段
# question = '''
# 查询:
# 1、发布时间是2024年度
# 2、每个学段每个科目上传课程数量按由多到少排序
# 3、字段名: 学段,科目,排名,课程数量
# '''
common_prompt = ''' common_prompt = '''
返回的信息要求 返回的信息要求
1行政区划为NULL 或者是空字符的不参加统计 1行政区划为NULL 或者是空字符的不参加统计
@ -90,7 +83,6 @@ def get_excel(question: str = Form(...)):
# 执行SQL查询 # 执行SQL查询
with PostgreSQLUtil() as db: with PostgreSQLUtil() as db:
_data = db.execute_query(sql) _data = db.execute_query(sql)
print(_data)
# 在static目录下生成一个guid号的临时文件 # 在static目录下生成一个guid号的临时文件
uuidStr = str(uuid.uuid4()) uuidStr = str(uuid.uuid4())
filename = f"static/{uuidStr}.xlsx" filename = f"static/{uuidStr}.xlsx"

Loading…
Cancel
Save