|
|
import uuid
|
|
|
|
|
|
import uvicorn # 导入 uvicorn
|
|
|
from fastapi import FastAPI, HTTPException, Depends, Form
|
|
|
from sqlalchemy.orm import Session
|
|
|
from starlette.responses import JSONResponse
|
|
|
from starlette.staticfiles import StaticFiles
|
|
|
|
|
|
from Model.biModel import *
|
|
|
from Text2Sql.Util.PostgreSQLUtil import PostgreSQLUtil
|
|
|
from Text2Sql.Util.SaveToExcel import save_to_excel
|
|
|
from Text2Sql.Util.VannaUtil import VannaUtil
|
|
|
|
|
|
# Create the FastAPI application object that every route below registers on.
app = FastAPI()

# Configure the static-files directory: files under the local "static"
# directory are served at the /static URL prefix (get_excel writes its
# generated .xlsx files there and hands out /static/... download URLs).
app.mount(path="/static", app=StaticFiles(directory="static"), name="static")

# Single VannaUtil instance, built once at import time and reused by
# get_excel to turn natural-language questions into SQL.
vn = VannaUtil()
|
|
|
@app.get("/")
def read_root():
    """Root endpoint: respond with a fixed JSON welcome message."""
    greeting = "Welcome to Vanna AI SQL !"
    return {"message": greeting}
|
|
|
|
|
|
|
|
|
# Create a record.
@app.post("/questions/", response_model=TBIQuestionCreate)
def create_question(question: TBIQuestionCreate, db: Session = Depends(get_db)):
    """Insert a new TBIQuestion row built from the request payload.

    Every field of the payload is passed as a keyword argument to the
    ``TBIQuestion`` ORM constructor. The row is committed, then refreshed so
    that values assigned by the database (presumably the generated ``id``)
    are loaded before the instance is returned.
    """
    payload_fields = question.dict()
    new_row = TBIQuestion(**payload_fields)
    db.add(new_row)
    db.commit()
    # Re-read the committed row so database-side values are visible.
    db.refresh(new_row)
    return new_row
|
|
|
|
|
|
|
|
|
# Read a record.
@app.get("/questions/{question_id}", response_model=TBIQuestionCreate)
def read_question(question_id: int, db: Session = Depends(get_db)):
    """Fetch the TBIQuestion whose ``id`` equals ``question_id``.

    Raises:
        HTTPException: 404 ("Question not found") when no row matches.
    """
    found = (
        db.query(TBIQuestion)
        .filter(TBIQuestion.id == question_id)
        .first()
    )
    if found is not None:
        return found
    raise HTTPException(status_code=404, detail="Question not found")
|
|
|
|
|
|
|
|
|
# Update a record.
@app.put("/questions/{question_id}", response_model=TBIQuestionCreate)
def update_question(question_id: int, question: TBIQuestionUpdate, db: Session = Depends(get_db)):
    """Partially update the TBIQuestion whose ``id`` equals ``question_id``.

    Only payload fields whose value is not ``None`` are copied onto the row,
    so this endpoint cannot set a column back to NULL. The row is committed
    and refreshed before being returned.

    Raises:
        HTTPException: 404 ("Question not found") when no row matches.
    """
    target = (
        db.query(TBIQuestion)
        .filter(TBIQuestion.id == question_id)
        .first()
    )
    if target is None:
        raise HTTPException(status_code=404, detail="Question not found")

    # Keep only the fields the caller actually supplied a value for.
    provided = {name: val for name, val in question.dict().items() if val is not None}
    for name, val in provided.items():
        setattr(target, name, val)

    db.commit()
    db.refresh(target)
    return target
|
|
|
|
|
|
|
|
|
# Delete a record.
@app.delete("/questions/{question_id}")
def delete_question(question_id: int, db: Session = Depends(get_db)):
    """Delete the TBIQuestion whose ``id`` equals ``question_id`` and commit.

    Returns:
        A JSON confirmation message.

    Raises:
        HTTPException: 404 ("Question not found") when no row matches.
    """
    doomed = (
        db.query(TBIQuestion)
        .filter(TBIQuestion.id == question_id)
        .first()
    )
    if doomed is None:
        raise HTTPException(status_code=404, detail="Question not found")

    db.delete(doomed)
    db.commit()
    confirmation = {"message": "Question deleted successfully"}
    return confirmation
|
|
|
|
|
|
|
|
|
# Generate an Excel file from a natural-language question.
# Example request: POST http://10.10.21.20:8000/questions/get_excel
# Form field: question

# Instructions appended to every question before SQL generation. This text is
# sent to the model at runtime and is kept byte-for-byte. In English: "Requirements
# for the returned information: 1. rows whose administrative division is NULL
# or an empty string are excluded from the statistics; 2. the target database
# is PostgreSQL 16."
_EXCEL_PROMPT_SUFFIX = '''



返回的信息要求:



1、行政区划为NULL 或者是空字符的不参加统计



2、目标数据库是Postgresql 16



'''


@app.post("/questions/get_excel")
def get_excel(question: str = Form(...)):
    """Turn ``question`` into SQL with Vanna, run it, and export the rows to Excel.

    Pipeline: the fixed prompt suffix is appended to ``question``;
    ``vn.generate_sql`` produces the SQL; ``PostgreSQLUtil`` executes it; and
    ``save_to_excel`` writes the result to ``static/<uuid4>.xlsx``. That file
    is reachable through the /static mount.

    Returns:
        On success: ``{"success": True, "message": ..., "download_url": ...}``.
        On failure: a JSONResponse with ``{"success": False, "message": ...}``
        (plus the generated ``sql`` when there is one). The status is 400 when
        no usable SQL was generated, and 500 when generation, the query, or
        the export raised.
    """
    import logging  # local import: the module header does not import logging

    logger = logging.getLogger(__name__)
    prompt = question + _EXCEL_PROMPT_SUFFIX

    # Step 1: natural language -> SQL.
    try:
        sql = vn.generate_sql(prompt)
    except Exception:
        # Endpoint boundary: log the traceback and report the failure in the
        # same {"success": ...} shape the success path uses.
        logger.exception("SQL generation failed for question: %s", question)
        return JSONResponse(status_code=500, content={"success": False, "message": "SQL生成失败"})

    print("生成的查询 SQL:\n", sql)
    if not isinstance(sql, str) or not sql.strip():
        return JSONResponse(status_code=400, content={"success": False, "message": "未能根据问题生成SQL"})

    # Step 2: execute the SQL and export the rows to a uniquely named file
    # under static/.
    # SECURITY(review): `sql` is model-generated from untrusted form input and
    # is executed verbatim; nothing here restricts it to read-only statements.
    # Ensure PostgreSQLUtil connects with a read-only role and a statement
    # timeout.
    # TODO(review): generated files are never deleted; static/ grows unbounded.
    file_id = str(uuid.uuid4())
    filename = f"static/{file_id}.xlsx"
    try:
        with PostgreSQLUtil() as db:
            _data = db.execute_query(sql)
            # Export while the connection is still open, in case the query
            # result is lazily backed by it.
            save_to_excel(_data, filename)
    except Exception:
        logger.exception("Query execution or Excel export failed; sql=%s", sql)
        return JSONResponse(
            status_code=500,
            content={"success": False, "message": "Excel文件生成失败", "sql": sql},
        )

    # Return the static-file URL for the generated workbook.
    return {"success": True, "message": "Excel文件生成成功", "download_url": f"/static/{file_id}.xlsx"}
|
|
|
|
|
|
|
|
|
|
|
|
# When this script is executed directly, start the FastAPI app with uvicorn.
if __name__ == "__main__":
    # Bind to all interfaces on port 8000 with auto-reload enabled. The
    # "module:attribute" import string "app:app" presumes this file is
    # importable as the module `app` from the working directory.
    uvicorn.run("app:app", reload=True, host="0.0.0.0", port=8000)
|