"""FastAPI service exposing CRUD endpoints for ``t_bi_question`` plus a
text-to-SQL endpoint backed by ``VannaUtil``.

Running this file directly trains the Vanna model on the DDL / example SQL
files under ``Sql/`` and then starts uvicorn.
"""
import os
import re
from datetime import date, datetime
from typing import Iterator, Optional, Union

from fastapi import FastAPI, HTTPException, Depends
from pydantic import BaseModel
from sqlalchemy import create_engine, Column, Integer, String, SmallInteger, Date, func
from sqlalchemy.orm import declarative_base  # import location used since SQLAlchemy 1.4
from sqlalchemy.orm import sessionmaker, Session
import uvicorn
from starlette.staticfiles import StaticFiles  # NOTE(review): currently unused in this module
from Text2Sql.Util.VannaUtil import VannaUtil

# Database connection URL. The DATABASE_URL environment variable takes
# precedence over the built-in fallback.
# NOTE(review): the fallback embeds a password in source control; move it to
# the environment / a secret store and rotate the credential.
DATABASE_URL = os.environ.get(
    "DATABASE_URL",
    "postgresql+psycopg2://postgres:DsideaL147258369@10.10.14.71:5432/szjz_db",
)

# Engine and session factory shared by all requests.
engine = create_engine(DATABASE_URL)
SessionLocal = sessionmaker(autocommit=False, autoflush=False, bind=engine)

# Declarative base for ORM models.
Base = declarative_base()

# True when the installed Pydantic exposes the v2 API (model_dump / from_attributes).
_PYDANTIC_V2 = hasattr(BaseModel, "model_dump")


class TBIQuestion(Base):
    """ORM mapping for the ``t_bi_question`` table."""

    __tablename__ = "t_bi_question"

    id = Column(Integer, primary_key=True, index=True)
    question = Column(String(255), nullable=False)
    state_id = Column(SmallInteger, nullable=False, default=0)
    sql = Column(String(2048))
    bar_chart_x_columns = Column(String(255))
    bar_chart_y_columns = Column(String(255))
    pie_chart_category_columns = Column(String(255))
    pie_chart_value_column = Column(String(255))
    session_id = Column(String(255), nullable=False)
    excel_file_name = Column(String(255))
    bar_chart_file_name = Column(String(255))
    pie_chart_file_name = Column(String(255))
    report_file_name = Column(String(255))
    # Rendered as SQL CURRENT_TIMESTAMP in the INSERT so the database supplies
    # the value. (A plain "CURRENT_TIMESTAMP" string would be bound as a
    # literal text value, which is not a valid date.)
    create_time = Column(Date, nullable=False, default=func.current_timestamp())
    is_system = Column(SmallInteger, nullable=False, default=0)
    is_collect = Column(SmallInteger, nullable=False, default=0)


class _ORMModel(BaseModel):
    """Pydantic base that can be populated from ORM objects (v1 and v2)."""

    if _PYDANTIC_V2:
        model_config = {"from_attributes": True}
    else:
        class Config:
            orm_mode = True


class TBIQuestionCreate(_ORMModel):
    """Request body for creating a question. Nullable columns are Optional."""

    question: str
    state_id: int = 0
    sql: Optional[str] = None
    bar_chart_x_columns: Optional[str] = None
    bar_chart_y_columns: Optional[str] = None
    pie_chart_category_columns: Optional[str] = None
    pie_chart_value_column: Optional[str] = None
    session_id: str
    excel_file_name: Optional[str] = None
    bar_chart_file_name: Optional[str] = None
    pie_chart_file_name: Optional[str] = None
    report_file_name: Optional[str] = None
    is_system: int = 0
    is_collect: int = 0


class TBIQuestionRead(TBIQuestionCreate):
    """Response model: the create fields plus server-generated columns."""

    id: int
    # The model declares Date, but the driver may hand back a datetime if the
    # actual column is a timestamp, so accept either.
    create_time: Optional[Union[datetime, date]] = None


class TBIQuestionUpdate(BaseModel):
    """Partial-update body; fields left as None are not modified."""

    question: Optional[str] = None
    state_id: Optional[int] = None
    sql: Optional[str] = None
    bar_chart_x_columns: Optional[str] = None
    bar_chart_y_columns: Optional[str] = None
    pie_chart_category_columns: Optional[str] = None
    pie_chart_value_column: Optional[str] = None
    session_id: Optional[str] = None
    excel_file_name: Optional[str] = None
    bar_chart_file_name: Optional[str] = None
    pie_chart_file_name: Optional[str] = None
    report_file_name: Optional[str] = None
    is_system: Optional[int] = None
    is_collect: Optional[int] = None


def _model_to_dict(model: BaseModel, **kwargs) -> dict:
    """Dump a Pydantic model to a dict on either Pydantic v1 or v2."""
    if _PYDANTIC_V2:
        return model.model_dump(**kwargs)
    return model.dict(**kwargs)


# FastAPI application
app = FastAPI()


@app.get("/")
def read_root():
    """Return a static greeting (usable as a trivial health check)."""
    return {"message": "Hello, World!"}


def get_db() -> Iterator[Session]:
    """Yield a database session and always close it after the request."""
    db = SessionLocal()
    try:
        yield db
    finally:
        db.close()


def _get_question_or_404(db: Session, question_id: int) -> TBIQuestion:
    """Load a question by primary key or raise HTTP 404."""
    db_question = db.query(TBIQuestion).filter(TBIQuestion.id == question_id).first()
    if db_question is None:
        raise HTTPException(status_code=404, detail="Question not found")
    return db_question


@app.post("/questions/", response_model=TBIQuestionRead)
def create_question(question: TBIQuestionCreate, db: Session = Depends(get_db)):
    """Insert a new question and return the stored row (including its id)."""
    db_question = TBIQuestion(**_model_to_dict(question))
    db.add(db_question)
    db.commit()
    db.refresh(db_question)
    return db_question


@app.get("/questions/{question_id}", response_model=TBIQuestionRead)
def read_question(question_id: int, db: Session = Depends(get_db)):
    """Return a single question; 404 if it does not exist."""
    return _get_question_or_404(db, question_id)


@app.put("/questions/{question_id}", response_model=TBIQuestionRead)
def update_question(question_id: int, question: TBIQuestionUpdate, db: Session = Depends(get_db)):
    """Apply the non-None fields of ``question`` to an existing row."""
    db_question = _get_question_or_404(db, question_id)
    # None means "leave unchanged", so a field cannot be cleared via this endpoint.
    for key, value in _model_to_dict(question, exclude_none=True).items():
        setattr(db_question, key, value)
    db.commit()
    db.refresh(db_question)
    return db_question


@app.delete("/questions/{question_id}")
def delete_question(question_id: int, db: Session = Depends(get_db)):
    """Delete a question; 404 if it does not exist."""
    db_question = _get_question_or_404(db, question_id)
    db.delete(db_question)
    db.commit()
    return {"message": "Question deleted successfully"}


# Shared VannaUtil instance, created on first use. Lazy creation is required
# because uvicorn (reload=True) re-imports this module in a worker process
# where the ``if __name__ == "__main__"`` block never runs.
# NOTE(review): whether training done in main() is visible to that worker
# depends on how VannaUtil persists its training data — confirm.
vn = None


def get_vanna() -> VannaUtil:
    """Return the shared VannaUtil instance, creating it if needed."""
    global vn
    if vn is None:
        vn = VannaUtil()
    return vn


# Instructions appended to every natural-language question before SQL generation.
COMMON_PROMPT = '''
返回的信息要求:
1、行政区划为NULL 或者是空字符的不参加统计
2、目标数据库是Postgresql 16
'''

# Matches "/* comment */ sql-statement;" pairs in the training SQL file.
_SQL_SNIPPET_RE = re.compile(r'/\*(.*?)\*/(.*?);', re.DOTALL)


@app.post("/questions/generate_sql")
def generate_sql(question: str):
    """Generate SQL for ``question`` via Vanna and return it.

    Raises:
        HTTPException: 400 if ``question`` is empty or whitespace.
    """
    if not question or not question.strip():
        raise HTTPException(status_code=400, detail="question must not be empty")

    full_question = question + COMMON_PROMPT

    print("开始查询...")
    sql = get_vanna().generate_sql(full_question)
    print("生成的查询 SQL:\n", sql)
    return {"question": question, "sql": sql}


def main():
    """Train Vanna on the DDL and example SQL files, then start the API server.

    File paths are relative to the current working directory.
    """
    vanna = get_vanna()

    print("开始训练...")

    # Train on the table DDL.
    with open("Sql/AreaSchoolLessonDDL.sql", "r", encoding="utf-8") as file:
        ddl = file.read()
    vanna.train(ddl=ddl)

    # NOTE(review): business documentation training was disabled here; if
    # re-enabled, pass the file *contents*, not the path, as `documentation`.

    # Train on each "/* comment */ sql;" example pair.
    with open('Sql/AreaSchoolLessonGenerate.sql', 'r', encoding='utf-8') as file:
        sql_content = file.read()
    for comment, sql in _SQL_SNIPPET_RE.findall(sql_content):
        vanna.train(sql=comment.strip() + '\n' + sql.strip() + '\n')

    # NOTE(review): "app:app" assumes this file is importable as module `app`.
    uvicorn.run("app:app", host="0.0.0.0", port=8000, reload=True)


# Start training and the FastAPI app only when run as a script.
if __name__ == "__main__":
    main()