You cannot select more than 25 topics. Topics must start with a letter or number, can include dashes ('-'), and can be up to 35 characters long.

110 lines
3.6 KiB

This file contains ambiguous Unicode characters!

This file contains ambiguous Unicode characters that may be confused with others in your current locale. If your use case is intentional and legitimate, you can safely ignore this warning. Use the Escape button to highlight these characters.

import re
import uvicorn # 导入 uvicorn
from fastapi import FastAPI, HTTPException, Depends
from sqlalchemy.orm import Session
from starlette.responses import StreamingResponse
from Text2Sql.Util.PostgreSQLUtil import PostgreSQLUtil
from Text2Sql.Util.SaveToExcel import save_to_excel_stream
from Text2Sql.Util.VannaUtil import VannaUtil
from Model.biModel import *
# The FastAPI application object; every route in this file registers on it.
app: FastAPI = FastAPI()
@app.get("/")
def read_root():
    """Return a static greeting payload for the root URL."""
    greeting = "Hello, World!"
    return {"message": greeting}
# Per-request database session dependency.
def get_db():
    """Yield a session from ``SessionLocal`` and always close it afterwards.

    Used by the route handlers via ``Depends(get_db)``; the ``finally``
    clause closes the session even when the handler raises.
    """
    session = SessionLocal()
    try:
        yield session
    finally:
        session.close()
# Create a question record.
@app.post("/questions/", response_model=TBIQuestionCreate)
def create_question(question: TBIQuestionCreate, db: Session = Depends(get_db)):
    """Insert a new TBIQuestion built from the request body and return it.

    The row is committed and then refreshed so any column values produced by
    the database are loaded onto the returned object.
    """
    payload = question.dict()
    new_row = TBIQuestion(**payload)
    db.add(new_row)
    db.commit()
    db.refresh(new_row)
    return new_row
# Read a single question record.
@app.get("/questions/{question_id}", response_model=TBIQuestionCreate)
def read_question(question_id: int, db: Session = Depends(get_db)):
    """Return the first TBIQuestion whose ``id`` equals *question_id*.

    Raises:
        HTTPException: 404 when no matching row exists.
    """
    matches = db.query(TBIQuestion).filter(TBIQuestion.id == question_id)
    found = matches.first()
    if found is not None:
        return found
    raise HTTPException(status_code=404, detail="Question not found")
# Update a question record.
@app.put("/questions/{question_id}", response_model=TBIQuestionCreate)
def update_question(question_id: int, question: TBIQuestionUpdate, db: Session = Depends(get_db)):
    """Copy the non-None fields of *question* onto an existing TBIQuestion.

    Fields that are None in the request body are skipped, so this endpoint
    cannot be used to set a column to NULL.

    Raises:
        HTTPException: 404 when no row has ``id == question_id``.
    """
    target = db.query(TBIQuestion).filter(TBIQuestion.id == question_id).first()
    if target is None:
        raise HTTPException(status_code=404, detail="Question not found")
    changes = {field: val for field, val in question.dict().items() if val is not None}
    for field, val in changes.items():
        setattr(target, field, val)
    db.commit()
    db.refresh(target)
    return target
# Delete a question record.
@app.delete("/questions/{question_id}")
def delete_question(question_id: int, db: Session = Depends(get_db)):
    """Delete the TBIQuestion with ``id == question_id``.

    Returns:
        A dict with a confirmation ``message``.

    Raises:
        HTTPException: 404 when no matching row exists.
    """
    record = db.query(TBIQuestion).filter(TBIQuestion.id == question_id).first()
    if record is not None:
        db.delete(record)
        db.commit()
        return {"message": "Question deleted successfully"}
    raise HTTPException(status_code=404, detail="Question not found")
# Generate an Excel export from a natural-language question.

# Lazily created VannaUtil shared by all requests in this process.
_vanna = None


def _get_vanna():
    """Return the process-wide VannaUtil, creating it on first use.

    It must be created inside the serving process: uvicorn loads the app via
    the "app:app" import string, so a global assigned under this file's
    ``if __name__ == "__main__":`` guard is never visible to the handlers.
    """
    global _vanna
    if _vanna is None:
        # NOTE(review): sync handlers run in a thread pool, so two concurrent
        # first requests may each construct an instance (the last one wins).
        # Add a lock if VannaUtil construction is expensive or not repeatable.
        _vanna = VannaUtil()
    return _vanna


# Leading keyword of a read-only query.
_READ_ONLY_SQL = re.compile(r"^\s*(select|with)\b", re.IGNORECASE)


def _is_read_only_sql(sql):
    """Return True if *sql* looks like a single SELECT/WITH statement.

    This is a coarse sanity check on model output, not a security boundary:
    PostgreSQL allows data-modifying statements inside WITH, so the database
    account used for these queries should itself be read-only.
    """
    if not sql:
        return False
    statement = sql.strip().rstrip(";").rstrip()
    # Any remaining ';' means several statements were generated.
    return ";" not in statement and _READ_ONLY_SQL.match(statement) is not None


@app.post("/questions/get_excel")
def get_excel(question: str):
    """Turn a natural-language *question* into SQL, run it and return an .xlsx.

    Example question (originally written in Chinese)::

        Query:
        1. publish time is in 2024
        2. number of uploaded courses per school stage and subject,
           sorted from most to least
        3. column names: stage, subject, rank, course count

    Raises:
        HTTPException: 400 if the question is blank or no single read-only
            query could be generated from it.
    """
    from urllib.parse import quote

    if not question.strip():
        raise HTTPException(status_code=400, detail="Question must not be empty")
    # Requirements appended to every question: rows whose administrative
    # region is NULL or an empty string are excluded; target is PostgreSQL 16.
    common_prompt = '''
返回的信息要求:
1、行政区划为NULL 或者是空字符的不参加统计
2、目标数据库是Postgresql 16
'''
    prompt = question + common_prompt
    # Generate the full SQL statement.
    sql = _get_vanna().generate_sql(prompt)
    print("生成的查询 SQL:\n", sql)
    # The question comes straight from the request and the SQL is model
    # output: refuse to execute anything but a single read-only statement.
    if not _is_read_only_sql(sql):
        raise HTTPException(
            status_code=400,
            detail="Could not generate a read-only SQL query for this question",
        )
    with PostgreSQLUtil() as db:
        rows = db.execute_query(sql)
    excel_stream = save_to_excel_stream(rows)
    # Starlette encodes header values as latin-1, so the Chinese file name
    # must go in the RFC 6266 ``filename*`` parameter (UTF-8, percent-encoded),
    # with a plain ASCII ``filename`` fallback for older clients.
    download_name = "导出信息.xlsx"
    content_disposition = (
        'attachment; filename="export.xlsx"; '
        f"filename*=UTF-8''{quote(download_name)}"
    )
    return StreamingResponse(
        excel_stream,
        media_type="application/vnd.openxmlformats-officedocument.spreadsheetml.sheet",
        headers={"Content-Disposition": content_disposition},
    )
# Start the FastAPI server when this file is executed directly.
# NOTE: no VannaUtil is created here. uvicorn loads the app through the
# "app:app" import string, which imports this file again as the module
# ``app`` (in a separate process when reload=True). Any global assigned under
# this guard is therefore invisible to the request handlers, and building it
# here only wastes start-up time.
if __name__ == "__main__":
    # "app:app" assumes this file is named app.py. reload=True is a
    # development convenience and host 0.0.0.0 listens on all interfaces.
    uvicorn.run("app:app", host="0.0.0.0", port=8000, reload=True)