|
|
|
@ -0,0 +1,188 @@
|
|
|
|
|
import re
|
|
|
|
|
|
|
|
|
|
from fastapi import FastAPI, HTTPException, Depends
|
|
|
|
|
from pydantic import BaseModel
|
|
|
|
|
from sqlalchemy import create_engine, Column, Integer, String, SmallInteger, Date
|
|
|
|
|
from sqlalchemy.orm import declarative_base # 更新导入路径
|
|
|
|
|
from sqlalchemy.orm import sessionmaker, Session
|
|
|
|
|
import uvicorn # 导入 uvicorn
|
|
|
|
|
from starlette.staticfiles import StaticFiles
|
|
|
|
|
|
|
|
|
|
from Text2Sql.Util.VannaUtil import VannaUtil
|
|
|
|
|
|
|
|
|
|
# 数据库连接配置
|
|
|
|
|
DATABASE_URL = "postgresql+psycopg2://postgres:DsideaL147258369@10.10.14.71:5432/szjz_db"
|
|
|
|
|
|
|
|
|
|
# 创建数据库引擎
|
|
|
|
|
engine = create_engine(DATABASE_URL)
|
|
|
|
|
|
|
|
|
|
# 创建 SessionLocal 类
|
|
|
|
|
SessionLocal = sessionmaker(autocommit=False, autoflush=False, bind=engine)
|
|
|
|
|
|
|
|
|
|
# 创建 Base 类
|
|
|
|
|
Base = declarative_base()
|
|
|
|
|
|
|
|
|
|
# 定义 t_bi_question 表模型
|
|
|
|
|
class TBIQuestion(Base):
|
|
|
|
|
__tablename__ = "t_bi_question"
|
|
|
|
|
|
|
|
|
|
id = Column(Integer, primary_key=True, index=True)
|
|
|
|
|
question = Column(String(255), nullable=False)
|
|
|
|
|
state_id = Column(SmallInteger, nullable=False, default=0)
|
|
|
|
|
sql = Column(String(2048))
|
|
|
|
|
bar_chart_x_columns = Column(String(255))
|
|
|
|
|
bar_chart_y_columns = Column(String(255))
|
|
|
|
|
pie_chart_category_columns = Column(String(255))
|
|
|
|
|
pie_chart_value_column = Column(String(255))
|
|
|
|
|
session_id = Column(String(255), nullable=False)
|
|
|
|
|
excel_file_name = Column(String(255))
|
|
|
|
|
bar_chart_file_name = Column(String(255))
|
|
|
|
|
pie_chart_file_name = Column(String(255))
|
|
|
|
|
report_file_name = Column(String(255))
|
|
|
|
|
create_time = Column(Date, nullable=False, default="CURRENT_TIMESTAMP")
|
|
|
|
|
is_system = Column(SmallInteger, nullable=False, default=0)
|
|
|
|
|
is_collect = Column(SmallInteger, nullable=False, default=0)
|
|
|
|
|
|
|
|
|
|
# 创建 Pydantic 模型
|
|
|
|
|
class TBIQuestionCreate(BaseModel):
|
|
|
|
|
question: str
|
|
|
|
|
state_id: int = 0
|
|
|
|
|
sql: str = None
|
|
|
|
|
bar_chart_x_columns: str = None
|
|
|
|
|
bar_chart_y_columns: str = None
|
|
|
|
|
pie_chart_category_columns: str = None
|
|
|
|
|
pie_chart_value_column: str = None
|
|
|
|
|
session_id: str
|
|
|
|
|
excel_file_name: str = None
|
|
|
|
|
bar_chart_file_name: str = None
|
|
|
|
|
pie_chart_file_name: str = None
|
|
|
|
|
report_file_name: str = None
|
|
|
|
|
is_system: int = 0
|
|
|
|
|
is_collect: int = 0
|
|
|
|
|
|
|
|
|
|
class TBIQuestionUpdate(BaseModel):
|
|
|
|
|
question: str = None
|
|
|
|
|
state_id: int = None
|
|
|
|
|
sql: str = None
|
|
|
|
|
bar_chart_x_columns: str = None
|
|
|
|
|
bar_chart_y_columns: str = None
|
|
|
|
|
pie_chart_category_columns: str = None
|
|
|
|
|
pie_chart_value_column: str = None
|
|
|
|
|
session_id: str = None
|
|
|
|
|
excel_file_name: str = None
|
|
|
|
|
bar_chart_file_name: str = None
|
|
|
|
|
pie_chart_file_name: str = None
|
|
|
|
|
report_file_name: str = None
|
|
|
|
|
is_system: int = None
|
|
|
|
|
is_collect: int = None
|
|
|
|
|
|
|
|
|
|
# 初始化 FastAPI
|
|
|
|
|
app = FastAPI()
|
|
|
|
|
|
|
|
|
|
@app.get("/")
|
|
|
|
|
def read_root():
|
|
|
|
|
return {"message": "Hello, World!"}
|
|
|
|
|
|
|
|
|
|
# 获取数据库会话
|
|
|
|
|
def get_db():
|
|
|
|
|
db = SessionLocal()
|
|
|
|
|
try:
|
|
|
|
|
yield db
|
|
|
|
|
finally:
|
|
|
|
|
db.close()
|
|
|
|
|
|
|
|
|
|
# 创建记录
|
|
|
|
|
@app.post("/questions/", response_model=TBIQuestionCreate)
|
|
|
|
|
def create_question(question: TBIQuestionCreate, db: Session = Depends(get_db)):
|
|
|
|
|
db_question = TBIQuestion(**question.dict())
|
|
|
|
|
db.add(db_question)
|
|
|
|
|
db.commit()
|
|
|
|
|
db.refresh(db_question)
|
|
|
|
|
return db_question
|
|
|
|
|
|
|
|
|
|
# 读取记录
|
|
|
|
|
@app.get("/questions/{question_id}", response_model=TBIQuestionCreate)
|
|
|
|
|
def read_question(question_id: int, db: Session = Depends(get_db)):
|
|
|
|
|
db_question = db.query(TBIQuestion).filter(TBIQuestion.id == question_id).first()
|
|
|
|
|
if db_question is None:
|
|
|
|
|
raise HTTPException(status_code=404, detail="Question not found")
|
|
|
|
|
return db_question
|
|
|
|
|
|
|
|
|
|
# 更新记录
|
|
|
|
|
@app.put("/questions/{question_id}", response_model=TBIQuestionCreate)
|
|
|
|
|
def update_question(question_id: int, question: TBIQuestionUpdate, db: Session = Depends(get_db)):
|
|
|
|
|
db_question = db.query(TBIQuestion).filter(TBIQuestion.id == question_id).first()
|
|
|
|
|
if db_question is None:
|
|
|
|
|
raise HTTPException(status_code=404, detail="Question not found")
|
|
|
|
|
for key, value in question.dict().items():
|
|
|
|
|
if value is not None:
|
|
|
|
|
setattr(db_question, key, value)
|
|
|
|
|
db.commit()
|
|
|
|
|
db.refresh(db_question)
|
|
|
|
|
return db_question
|
|
|
|
|
|
|
|
|
|
# 删除记录
|
|
|
|
|
@app.delete("/questions/{question_id}")
|
|
|
|
|
def delete_question(question_id: int, db: Session = Depends(get_db)):
|
|
|
|
|
db_question = db.query(TBIQuestion).filter(TBIQuestion.id == question_id).first()
|
|
|
|
|
if db_question is None:
|
|
|
|
|
raise HTTPException(status_code=404, detail="Question not found")
|
|
|
|
|
db.delete(db_question)
|
|
|
|
|
db.commit()
|
|
|
|
|
return {"message": "Question deleted successfully"}
|
|
|
|
|
|
|
|
|
|
@app.post("/questions/generate_sql")
|
|
|
|
|
def generate_sql(question: str):
|
|
|
|
|
# 指定学段
|
|
|
|
|
question = '''
|
|
|
|
|
查询:
|
|
|
|
|
1、发布时间是2024年度
|
|
|
|
|
2、每个学段,每个科目,上传课程数量,按由多到少排序
|
|
|
|
|
3、字段名: 学段,科目,排名,课程数量
|
|
|
|
|
'''
|
|
|
|
|
common_prompt = '''
|
|
|
|
|
返回的信息要求:
|
|
|
|
|
1、行政区划为NULL 或者是空字符的不参加统计
|
|
|
|
|
2、目标数据库是Postgresql 16
|
|
|
|
|
'''
|
|
|
|
|
question = question + common_prompt
|
|
|
|
|
# 开始查询
|
|
|
|
|
print("开始查询...")
|
|
|
|
|
# 获取完整 SQL
|
|
|
|
|
sql = vn.generate_sql(question)
|
|
|
|
|
print("生成的查询 SQL:\n", sql)
|
|
|
|
|
|
|
|
|
|
# 添加 main 函数
|
|
|
|
|
def main():
|
|
|
|
|
# 开始训练
|
|
|
|
|
print("开始训练...")
|
|
|
|
|
# 打开AreaSchoolLesson.sql文件内容
|
|
|
|
|
with open("Sql/AreaSchoolLessonDDL.sql", "r", encoding="utf-8") as file:
|
|
|
|
|
ddl = file.read()
|
|
|
|
|
# 训练数据
|
|
|
|
|
vn.train(
|
|
|
|
|
ddl=ddl
|
|
|
|
|
)
|
|
|
|
|
# 添加有关业务术语或定义的文档
|
|
|
|
|
# vn.train(documentation="Sql/AreaSchoolLesson.md")
|
|
|
|
|
|
|
|
|
|
# 使用 SQL 进行训练
|
|
|
|
|
with open('Sql/AreaSchoolLessonGenerate.sql', 'r', encoding='utf-8') as file:
|
|
|
|
|
sql_content = file.read()
|
|
|
|
|
# 使用正则表达式提取注释和 SQL 语句
|
|
|
|
|
sql_pattern = r'/\*(.*?)\*/(.*?);'
|
|
|
|
|
sql_snippets = re.findall(sql_pattern, sql_content, re.DOTALL)
|
|
|
|
|
|
|
|
|
|
# 打印提取的注释和 SQL 语句
|
|
|
|
|
for i, (comment, sql) in enumerate(sql_snippets, 1):
|
|
|
|
|
vn.train(sql=comment.strip() + '\n' + sql.strip() + '\n')
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
uvicorn.run("app:app", host="0.0.0.0", port=8000, reload=True)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
# 确保直接运行脚本时启动 FastAPI 应用
|
|
|
|
|
if __name__ == "__main__":
|
|
|
|
|
vn = VannaUtil()
|
|
|
|
|
main()
|