You can not select more than 25 topics Topics must start with a letter or number, can include dashes ('-') and can be up to 35 characters long.

98 lines
3.4 KiB

4 months ago
import uuid
4 months ago
import uvicorn # 导入 uvicorn
4 months ago
from fastapi import FastAPI, HTTPException, Depends, Form
4 months ago
from sqlalchemy.orm import Session
4 months ago
from starlette.responses import JSONResponse
from starlette.staticfiles import StaticFiles
from Model.biModel import *
4 months ago
from Text2Sql.Util.PostgreSQLUtil import PostgreSQLUtil
4 months ago
from Text2Sql.Util.SaveToExcel import save_to_excel
4 months ago
from Text2Sql.Util.VannaUtil import VannaUtil
# Initialise the FastAPI application.
app = FastAPI()

# Serve the local "static" directory under the /static URL prefix.
app.mount(
    "/static",
    StaticFiles(directory="static"),
    name="static",
)

# Shared VannaUtil instance used by the request handlers below.
vn = VannaUtil()
4 months ago
@app.get("/")
def read_root():
    """Return the fixed welcome payload served at the API root."""
    welcome = {"message": "Welcome to Vanna AI SQL !"}
    return welcome
4 months ago
4 months ago
4 months ago
# Create a record
@app.post("/questions/", response_model=TBIQuestionCreate)
def create_question(question: TBIQuestionCreate, db: Session = Depends(get_db)):
    """Insert a new ``TBIQuestion`` row built from the request body.

    Args:
        question: Request payload; its fields are passed to ``TBIQuestion``
            as keyword arguments.
        db: Session injected through ``Depends(get_db)``.

    Returns:
        The stored row, refreshed after the commit.
    """
    fields = question.dict()
    record = TBIQuestion(**fields)
    db.add(record)
    db.commit()
    # Reload the row's state after the commit before handing it back.
    db.refresh(record)
    return record
4 months ago
4 months ago
# Read a record
@app.get("/questions/{question_id}", response_model=TBIQuestionCreate)
def read_question(question_id: int, db: Session = Depends(get_db)):
    """Fetch the question whose ``id`` equals ``question_id``.

    Args:
        question_id: Identifier taken from the URL path.
        db: Session injected through ``Depends(get_db)``.

    Returns:
        The first matching ``TBIQuestion`` row.

    Raises:
        HTTPException: 404 when no row matches.
    """
    matches = db.query(TBIQuestion).filter(TBIQuestion.id == question_id)
    record = matches.first()
    if record is not None:
        return record
    raise HTTPException(status_code=404, detail="Question not found")
4 months ago
4 months ago
# Update a record
@app.put("/questions/{question_id}", response_model=TBIQuestionCreate)
def update_question(question_id: int, question: TBIQuestionUpdate, db: Session = Depends(get_db)):
    """Apply the non-``None`` fields of the payload to an existing question.

    Args:
        question_id: Identifier taken from the URL path.
        question: Update payload; fields whose value is ``None`` are skipped.
        db: Session injected through ``Depends(get_db)``.

    Returns:
        The updated row, refreshed after the commit.

    Raises:
        HTTPException: 404 when no row matches ``question_id``.
    """
    record = (
        db.query(TBIQuestion)
        .filter(TBIQuestion.id == question_id)
        .first()
    )
    if record is None:
        raise HTTPException(status_code=404, detail="Question not found")

    # Keep only the fields that carry a value; None means "leave unchanged".
    changes = {
        name: new_value
        for name, new_value in question.dict().items()
        if new_value is not None
    }
    for name, new_value in changes.items():
        setattr(record, name, new_value)

    db.commit()
    db.refresh(record)
    return record
4 months ago
4 months ago
# Delete a record
@app.delete("/questions/{question_id}")
def delete_question(question_id: int, db: Session = Depends(get_db)):
    """Delete the question whose ``id`` equals ``question_id``.

    Args:
        question_id: Identifier taken from the URL path.
        db: Session injected through ``Depends(get_db)``.

    Returns:
        A dict holding a confirmation message.

    Raises:
        HTTPException: 404 when no row matches.
    """
    lookup = db.query(TBIQuestion).filter(TBIQuestion.id == question_id)
    record = lookup.first()
    if record is None:
        raise HTTPException(status_code=404, detail="Question not found")

    db.delete(record)
    db.commit()
    confirmation = {"message": "Question deleted successfully"}
    return confirmation
4 months ago
4 months ago
# Generate an Excel file from a natural-language question
# http://10.10.21.20:8000/questions/get_excel
# Form parameter: question
@app.post("/questions/get_excel")
def get_excel(question: str = Form(...)):
    """Turn a natural-language question into SQL, run it and export the rows.

    The question is extended with a fixed block of instructions, passed to
    ``vn.generate_sql``, and the resulting SQL is executed through
    ``PostgreSQLUtil``. The rows are written to ``static/<uuid>.xlsx``.

    Args:
        question: Natural-language question submitted as a form field.

    Returns:
        A dict with ``success``, ``message`` and ``download_url``; the URL is
        ``/static/<uuid>.xlsx`` for the workbook that was just written.

    Raises:
        HTTPException: 400 when ``question`` is empty or whitespace only.
    """
    # Without this guard a blank submission would send only the fixed
    # instructions below to the SQL generator and execute whatever comes back.
    if not question.strip():
        raise HTTPException(status_code=400, detail="Question must not be empty")

    # Fixed instructions appended to every question (runtime text, kept as-is).
    common_prompt = '''
返回的信息要求
1行政区划为NULL 或者是空字符的不参加统计
2目标数据库是Postgresql 16
'''
    prompt = question + common_prompt

    # Obtain the complete SQL statement.
    sql = vn.generate_sql(prompt)
    print("生成的查询 SQL:\n", sql)

    # NOTE(review): the generated SQL is executed as-is and derives from
    # user-supplied text; confirm the database role used by PostgreSQLUtil
    # is restricted to read-only access.
    with PostgreSQLUtil() as db:
        rows = db.execute_query(sql)

    # Write the result to a uniquely named workbook inside the static directory.
    file_id = str(uuid.uuid4())
    relative_path = f"static/{file_id}.xlsx"
    save_to_excel(rows, relative_path)

    # Return the static URL of the generated file.
    return {"success": True, "message": "Excel文件生成成功", "download_url": f"/{relative_path}"}
4 months ago
# Start the FastAPI application when this script is executed directly
if __name__ == "__main__":
    server_options = {"host": "0.0.0.0", "port": 8000, "reload": True}
    uvicorn.run("app:app", **server_options)