|
|
|
@ -17,8 +17,9 @@ from Text2Sql.Util.SaveToExcel import save_to_excel
|
|
|
|
|
from Text2Sql.Util.VannaUtil import VannaUtil
|
|
|
|
|
from fastapi.middleware.cors import CORSMiddleware # 导入跨域中间件
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
vn = VannaUtil()
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
# 初始化 FastAPI 应用
|
|
|
|
|
@asynccontextmanager
|
|
|
|
|
async def lifespan(app: FastAPI):
|
|
|
|
@ -36,6 +37,7 @@ async def lifespan(app: FastAPI):
|
|
|
|
|
# 关闭时释放连接池
|
|
|
|
|
await app.state.pool.close()
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
# 初始化 FastAPI
|
|
|
|
|
app = FastAPI(lifespan=lifespan)
|
|
|
|
|
|
|
|
|
@ -51,11 +53,18 @@ app.add_middleware(
|
|
|
|
|
# 挂载静态文件目录
|
|
|
|
|
app.mount("/static", StaticFiles(directory="static"), name="static")
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
# 依赖注入连接池
|
|
|
|
|
async def get_db():
|
|
|
|
|
async with app.state.pool.acquire() as connection:
|
|
|
|
|
yield connection
|
|
|
|
|
|
|
|
|
|
# 初始化 OpenAI 客户端
|
|
|
|
|
client = AsyncOpenAI(
|
|
|
|
|
api_key=MODEL_API_KEY,
|
|
|
|
|
base_url=MODEL_API_URL,
|
|
|
|
|
)
|
|
|
|
|
|
|
|
|
|
@app.post("/questions/get_excel")
|
|
|
|
|
async def get_excel(question_id: str = Form(...), question_str: str = Form(...),
|
|
|
|
|
db: asyncpg.Connection = Depends(get_db)):
|
|
|
|
@ -70,13 +79,18 @@ async def get_excel(question_id: str = Form(...), question_str: str = Form(...),
|
|
|
|
|
'''
|
|
|
|
|
question = question_str + common_prompt
|
|
|
|
|
|
|
|
|
|
# 先删除后插入,防止重复插入
|
|
|
|
|
await delete_question(db, question_id)
|
|
|
|
|
await insert_question(db, question_id, question)
|
|
|
|
|
|
|
|
|
|
# 获取完整 SQL
|
|
|
|
|
sql = vn.generate_sql(question)
|
|
|
|
|
print("生成的查询 SQL:\n", sql)
|
|
|
|
|
# 先删除后插入,防止重复插入
|
|
|
|
|
await delete_question(db, question_id)
|
|
|
|
|
await insert_question(db, question_id, question)
|
|
|
|
|
# 检查,如果sql为空,则返回错误信息
|
|
|
|
|
if not sql:
|
|
|
|
|
return {"success": False, "message": "无法生成相应的SQL语句!"}
|
|
|
|
|
# 检查,如果SQL无法正确执行,返回错误消息
|
|
|
|
|
if not await get_data_by_sql(db, sql):
|
|
|
|
|
return {"success": False, "message": "无法生成相应的SQL语句!"}
|
|
|
|
|
|
|
|
|
|
# 更新question_id
|
|
|
|
|
await update_question_by_id(db, question_id=question_id, sql=sql, state_id=1)
|
|
|
|
@ -94,12 +108,6 @@ async def get_excel(question_id: str = Form(...), question_str: str = Form(...),
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
# http://10.10.21.20:8000/questions/get_docx_stream?question_id_get=af15d834-e7f5-46b4-a0f6-15f1f888f443
|
|
|
|
|
# 初始化 OpenAI 客户端
|
|
|
|
|
client = AsyncOpenAI(
|
|
|
|
|
api_key=MODEL_API_KEY,
|
|
|
|
|
base_url=MODEL_API_URL,
|
|
|
|
|
)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
@app.api_route("/questions/get_docx_stream", methods=["POST", "GET"])
|
|
|
|
|
async def get_docx_stream(
|
|
|
|
|