You cannot select more than 25 topics. Topics must start with a letter or number, can include dashes ('-'), and can be up to 35 characters long.

188 lines
6.3 KiB

This file contains ambiguous Unicode characters!

This file contains ambiguous Unicode characters that may be confused with others in your current locale. If your use case is intentional and legitimate, you can safely ignore this warning. Use the Escape button to highlight these characters.

import os
import re
from typing import Optional

import uvicorn  # ASGI server used by main()
from fastapi import FastAPI, HTTPException, Depends
from pydantic import BaseModel
from sqlalchemy import create_engine, Column, Integer, String, SmallInteger, Date
from sqlalchemy import func
from sqlalchemy.orm import declarative_base  # updated import path
from sqlalchemy.orm import sessionmaker, Session
from starlette.staticfiles import StaticFiles

from Text2Sql.Util.VannaUtil import VannaUtil
# Database connection settings.
# The URL can be overridden with the DATABASE_URL environment variable; the
# literal below is only the fallback used when that variable is not set.
# NOTE(review): the fallback embeds a plaintext password — move it out of
# source control and rely on the environment variable instead.
DATABASE_URL = os.environ.get(
    "DATABASE_URL",
    "postgresql+psycopg2://postgres:DsideaL147258369@10.10.14.71:5432/szjz_db",
)
# Engine bound to DATABASE_URL; all sessions below use it.
engine = create_engine(DATABASE_URL)
# Session factory: no autocommit and no autoflush; one session per request (see get_db).
SessionLocal = sessionmaker(autocommit=False, autoflush=False, bind=engine)
# Declarative base class for the ORM models in this module.
Base = declarative_base()
# ORM model for the t_bi_question table.
class TBIQuestion(Base):
    """ORM mapping for one row of ``t_bi_question``.

    Judging by column names, a row holds a user question, its SQL, the
    columns selected for bar/pie charts and the names of exported files.
    """

    __tablename__ = "t_bi_question"

    id = Column(Integer, primary_key=True, index=True)
    question = Column(String(255), nullable=False)
    # NOTE(review): state code; the meaning of its values is not visible here.
    state_id = Column(SmallInteger, nullable=False, default=0)
    # Presumably the SQL generated for the question.
    sql = Column(String(2048))
    bar_chart_x_columns = Column(String(255))
    bar_chart_y_columns = Column(String(255))
    pie_chart_category_columns = Column(String(255))
    pie_chart_value_column = Column(String(255))
    session_id = Column(String(255), nullable=False)
    excel_file_name = Column(String(255))
    bar_chart_file_name = Column(String(255))
    pie_chart_file_name = Column(String(255))
    report_file_name = Column(String(255))
    # Creation date. A SQL-expression default is rendered inline into the
    # INSERT, so the database evaluates CURRENT_DATE itself. A plain string
    # default such as "CURRENT_TIMESTAMP" would be sent as a literal value,
    # which PostgreSQL does not accept as date input.
    create_time = Column(Date, nullable=False, default=func.current_date())
    # Presumably 0/1 flags; confirm the allowed values.
    is_system = Column(SmallInteger, nullable=False, default=0)
    is_collect = Column(SmallInteger, nullable=False, default=0)
# Pydantic schemas for the /questions endpoints.
class TBIQuestionCreate(BaseModel):
    """Request body for creating a question; also the response model of the CRUD endpoints.

    Nullable columns are typed ``Optional[...]`` so that NULLs read back from
    the database, and explicit JSON ``null`` from clients, pass validation
    (``str = None`` was only implicitly optional under pydantic v1).
    """

    question: str
    state_id: int = 0
    sql: Optional[str] = None
    bar_chart_x_columns: Optional[str] = None
    bar_chart_y_columns: Optional[str] = None
    pie_chart_category_columns: Optional[str] = None
    pie_chart_value_column: Optional[str] = None
    session_id: str
    excel_file_name: Optional[str] = None
    bar_chart_file_name: Optional[str] = None
    pie_chart_file_name: Optional[str] = None
    report_file_name: Optional[str] = None
    is_system: int = 0
    is_collect: int = 0

    class Config:
        # The endpoints return TBIQuestion ORM rows, which must be validated
        # against this model by attribute access. `orm_mode` is the pydantic v1
        # key and `from_attributes` the v2 key; v2 warns about `orm_mode`.
        orm_mode = True
        from_attributes = True
class TBIQuestionUpdate(BaseModel):
    """Partial-update body for an existing question.

    Every field is optional and defaults to ``None``. The update handler only
    applies fields whose value is not ``None``.
    """

    question: Optional[str] = None
    state_id: Optional[int] = None
    sql: Optional[str] = None
    bar_chart_x_columns: Optional[str] = None
    bar_chart_y_columns: Optional[str] = None
    pie_chart_category_columns: Optional[str] = None
    pie_chart_value_column: Optional[str] = None
    session_id: Optional[str] = None
    excel_file_name: Optional[str] = None
    bar_chart_file_name: Optional[str] = None
    pie_chart_file_name: Optional[str] = None
    report_file_name: Optional[str] = None
    is_system: Optional[int] = None
    is_collect: Optional[int] = None
# The FastAPI application; every route decorator below registers on it.
app = FastAPI()


@app.get("/")
def read_root():
    """Return a fixed greeting payload for the root path."""
    greeting = "Hello, World!"
    return dict(message=greeting)
def get_db():
    """FastAPI dependency that yields a fresh SQLAlchemy session.

    The session is closed once the consumer finishes with it, including when
    an exception is thrown back into the generator.
    """
    # Session's context manager calls close() on exit.
    with SessionLocal() as session:
        yield session
@app.post("/questions/", response_model=TBIQuestionCreate)
def create_question(question: TBIQuestionCreate, db: Session = Depends(get_db)):
    """Insert a t_bi_question row built from the request body and return it.

    The row is refreshed after commit so database-assigned values (such as
    the primary key) are loaded before serialization.
    """
    payload = question.dict()
    new_row = TBIQuestion(**payload)
    db.add(new_row)
    db.commit()
    db.refresh(new_row)
    return new_row
@app.get("/questions/{question_id}", response_model=TBIQuestionCreate)
def read_question(question_id: int, db: Session = Depends(get_db)):
    """Return the question with primary key *question_id*.

    Raises:
        HTTPException: 404 when no such row exists.
    """
    found = (
        db.query(TBIQuestion)
        .filter(TBIQuestion.id == question_id)
        .first()
    )
    if found is not None:
        return found
    raise HTTPException(status_code=404, detail="Question not found")
@app.put("/questions/{question_id}", response_model=TBIQuestionCreate)
def update_question(question_id: int, question: TBIQuestionUpdate, db: Session = Depends(get_db)):
    """Apply a partial update to question *question_id* and return the row.

    Only fields whose value is not ``None`` are written, so explicit
    nulls in the body leave the stored value unchanged.

    Raises:
        HTTPException: 404 when no such row exists.
    """
    target = db.query(TBIQuestion).filter(TBIQuestion.id == question_id).first()
    if target is None:
        raise HTTPException(status_code=404, detail="Question not found")
    changes = {name: val for name, val in question.dict().items() if val is not None}
    for name, val in changes.items():
        setattr(target, name, val)
    db.commit()
    db.refresh(target)
    return target
@app.delete("/questions/{question_id}")
def delete_question(question_id: int, db: Session = Depends(get_db)):
    """Delete question *question_id* and return a confirmation message.

    Raises:
        HTTPException: 404 when no such row exists.
    """
    doomed = db.query(TBIQuestion).filter_by(id=question_id).first()
    if doomed is None:
        raise HTTPException(status_code=404, detail="Question not found")
    db.delete(doomed)
    db.commit()
    return {"message": "Question deleted successfully"}
def _get_vanna():
    """Return the module-global Vanna helper ``vn``, creating it if unbound.

    ``vn`` is only assigned by the ``__main__`` guard. When the module is
    imported by name (e.g. by ``uvicorn.run("app:app", ...)``) that guard does
    not run, and referencing ``vn`` directly would raise NameError.
    """
    global vn
    try:
        return vn
    except NameError:
        # NOTE(review): a fresh instance only knows whatever VannaUtil
        # persists outside this process — confirm training data survives.
        vn = VannaUtil()
        return vn


@app.post("/questions/generate_sql")
def generate_sql(question: str):
    """Generate SQL for the natural-language *question* with Vanna.

    The caller's question is suffixed with a fixed block of requirements and
    passed to ``generate_sql`` on the Vanna helper. An example of a suitable
    question (translated): "courses published in 2024; per school section and
    subject, number of uploaded courses, ranked from most to fewest; columns:
    section, subject, rank, course count".

    Returns:
        ``{"question": question, "sql": <generated SQL>}``.
    """
    # Requirements appended to every question:
    # 1. rows whose administrative region is NULL or '' are excluded;
    # 2. the target database is PostgreSQL 16.
    common_prompt = '''
返回的信息要求:
1、行政区划为NULL 或者是空字符的不参加统计
2、目标数据库是Postgresql 16
'''
    prompt = question + common_prompt
    print("开始查询...")
    sql = _get_vanna().generate_sql(prompt)
    print("生成的查询 SQL:\n", sql)
    return {"question": question, "sql": sql}
# Splits the example-SQL file into (comment, statement) pairs. Each statement
# must be preceded by a /* ... */ comment and terminated by ';'.
_SQL_SNIPPET_PATTERN = re.compile(r'/\*(.*?)\*/(.*?);', re.DOTALL)


def _train_vanna(vanna):
    """Train *vanna* on the DDL file, then on each annotated example query.

    Paths are relative to the current working directory.
    """
    print("开始训练...")
    with open("Sql/AreaSchoolLessonDDL.sql", "r", encoding="utf-8") as file:
        ddl = file.read()
    vanna.train(ddl=ddl)
    # Training on business-term documentation is currently disabled:
    # vanna.train(documentation="Sql/AreaSchoolLesson.md")
    with open('Sql/AreaSchoolLessonGenerate.sql', 'r', encoding='utf-8') as file:
        sql_content = file.read()
    # Each snippet is trained as "<comment>\n<statement>\n".
    for comment, sql in _SQL_SNIPPET_PATTERN.findall(sql_content):
        vanna.train(sql=comment.strip() + '\n' + sql.strip() + '\n')


def main():
    """Train the global Vanna helper ``vn``, then start the API with uvicorn.

    Expects the module-global ``vn`` to be bound, as the ``__main__`` guard does.
    """
    _train_vanna(vn)
    uvicorn.run("app:app", host="0.0.0.0", port=8000, reload=True)
# Script entry point: build the Vanna helper, train it and start the FastAPI app.
if __name__ == "__main__":
    # Bound as the module-global `vn` before main() runs; main() and the
    # generate_sql endpoint look the helper up by this name.
    # NOTE(review): uvicorn.run("app:app", reload=True) imports the app by
    # module name, presumably in a separate reloader process where this guard
    # does not run — confirm how the served endpoints obtain a trained `vn`.
    vn = VannaUtil()
    main()