You cannot select more than 25 topics. Topics must start with a letter or number, can include dashes ('-'), and can be up to 35 characters long.

106 lines
3.7 KiB

This file contains ambiguous Unicode characters!

This file contains ambiguous Unicode characters that may be confused with others in your current locale. If your use case is intentional and legitimate, you can safely ignore this warning. Use the Escape button to highlight these characters.

import logging
import os
import uuid

import uvicorn  # ASGI server used by the __main__ block
from fastapi import FastAPI, HTTPException, Depends, Form
from sqlalchemy.orm import Session
from starlette.responses import JSONResponse
from starlette.staticfiles import StaticFiles

from Model.biModel import *
from Text2Sql.Util.PostgreSQLUtil import PostgreSQLUtil
from Text2Sql.Util.SaveToExcel import save_to_excel
from Text2Sql.Util.VannaUtil import VannaUtil
# Initialize the FastAPI application.
app = FastAPI()

# Serve files from the "static" directory (relative to the current working
# directory) under the /static URL prefix; the Excel exports are written here.
# StaticFiles checks at construction time that the directory exists and raises
# RuntimeError otherwise, so create it first: it only holds generated files and
# may be absent on a fresh checkout.
os.makedirs("static", exist_ok=True)
app.mount("/static", StaticFiles(directory="static"), name="static")
@app.get("/")
def read_root():
    """Return a fixed welcome message for the API root."""
    greeting = "Welcome to Vanna AI SQL !"
    return {"message": greeting}
# Create a record.
@app.post("/questions/", response_model=TBIQuestionCreate)
def create_question(question: TBIQuestionCreate, db: Session = Depends(get_db)):
    """Insert a new TBIQuestion built from the request body and return it.

    The row is committed, then refreshed from the database so that any
    server-generated column values are loaded onto the returned instance.
    """
    fields = question.dict()
    new_row = TBIQuestion(**fields)
    db.add(new_row)
    db.commit()
    # Reload attributes from the database after the commit.
    db.refresh(new_row)
    return new_row
# Read a record.
@app.get("/questions/{question_id}", response_model=TBIQuestionCreate)
def read_question(question_id: int, db: Session = Depends(get_db)):
    """Return the TBIQuestion whose ``id`` equals *question_id*.

    Raises:
        HTTPException: 404 if no such row exists.
    """
    matches = db.query(TBIQuestion).filter(TBIQuestion.id == question_id)
    found = matches.first()
    if found is not None:
        return found
    raise HTTPException(status_code=404, detail="Question not found")
# Update a record.
@app.put("/questions/{question_id}", response_model=TBIQuestionCreate)
def update_question(question_id: int, question: TBIQuestionUpdate, db: Session = Depends(get_db)):
    """Apply the non-None fields of *question* to an existing TBIQuestion.

    Any field whose value is None in the request model is skipped. As a result,
    this endpoint cannot be used to set a column to NULL.

    Raises:
        HTTPException: 404 if no row with ``id == question_id`` exists.
    """
    target = db.query(TBIQuestion).filter(TBIQuestion.id == question_id).first()
    if target is None:
        raise HTTPException(status_code=404, detail="Question not found")

    # Only fields that carry a value are copied onto the ORM object.
    changes = {field: val for field, val in question.dict().items() if val is not None}
    for field, val in changes.items():
        setattr(target, field, val)

    db.commit()
    db.refresh(target)
    return target
# Delete a record.
@app.delete("/questions/{question_id}")
def delete_question(question_id: int, db: Session = Depends(get_db)):
    """Delete the TBIQuestion whose ``id`` equals *question_id*.

    Returns:
        A dict with a confirmation ``message``.

    Raises:
        HTTPException: 404 if no such row exists.
    """
    doomed = db.query(TBIQuestion).filter(TBIQuestion.id == question_id).first()
    if doomed is not None:
        db.delete(doomed)
        db.commit()
        return {"message": "Question deleted successfully"}
    raise HTTPException(status_code=404, detail="Question not found")
# Generate an Excel file from a natural-language question.
# Example: POST http://10.10.21.20:8000/questions/get_excel
# Form parameter: question
@app.post("/questions/get_excel")
def get_excel(question: str = Form(...)):
    """Translate *question* into SQL, run it, and export the result to Excel.

    Steps:
    1. Append fixed instructions for the SQL generator to the question. They
       exclude rows whose administrative division is NULL or an empty string,
       and they state that the target database is PostgreSQL 16.
    2. Pass the prompt to ``VannaUtil.generate_sql``.
    3. Execute the generated SQL through ``PostgreSQLUtil``.
    4. Write the result with ``save_to_excel`` to ``static/<uuid4>.xlsx``.

    Example question (originally written in Chinese): "Query: 1. published in
    2024; 2. number of uploaded courses per school stage and subject, sorted
    from most to fewest; 3. columns: school stage, subject, rank, course count".

    Returns:
        dict with ``success``, ``message`` and ``download_url`` (the file's URL
        under the ``/static`` mount).

    Raises:
        HTTPException: 422 when the generator returns no SQL for the question.
    """
    logger = logging.getLogger(__name__)
    vn = VannaUtil()

    # Constraints appended to every question. This text is part of the prompt
    # sent to the SQL generator, so it is kept verbatim.
    common_prompt = (
        "\n"
        "返回的信息要求:\n"
        "1、行政区划为NULL 或者是空字符的不参加统计\n"
        "2、目标数据库是Postgresql 16\n"
    )
    question = question + common_prompt

    sql = vn.generate_sql(question)
    # NOTE(review): assumes generate_sql returns the SQL text, and something
    # empty or None when it cannot produce a query. Confirm against VannaUtil.
    if not sql or not sql.strip():
        raise HTTPException(status_code=422, detail="Could not generate SQL for the given question")
    logger.info("Generated SQL:\n%s", sql)

    # SECURITY: model-generated SQL is executed verbatim. Nothing here prevents
    # data-modifying statements, so run it with a read-only database role.
    with PostgreSQLUtil() as db:
        result = db.execute_query(sql)
    # Result sets can be large; only render them when DEBUG logging is enabled.
    logger.debug("Query result: %s", result)

    # Write to a uniquely named file inside the directory served at /static.
    file_id = str(uuid.uuid4())
    filename = f"static/{file_id}.xlsx"
    save_to_excel(result, filename)

    # Return the public URL of the generated file.
    return {"success": True, "message": "Excel文件生成成功", "download_url": f"/static/{file_id}.xlsx"}
# Start the FastAPI app with uvicorn when this file is run directly as a script.
if __name__ == "__main__":
    # With reload=True, uvicorn needs a "module:attribute" import string rather
    # than the app object. Derive the module name from this file instead of
    # hard-coding "app", so the script keeps working if the file is renamed.
    # The result is identical ("app:app") when the file is app.py.
    # NOTE(review): reload=True and binding to 0.0.0.0 are development
    # settings; confirm before deploying.
    module_name = os.path.splitext(os.path.basename(__file__))[0]
    uvicorn.run(f"{module_name}:app", host="0.0.0.0", port=8000, reload=True)