main
HuangHai 4 months ago
parent 8ab96bb9e6
commit 26896588e1

@ -1,5 +1,6 @@
import asyncpg
# 删除数据
async def delete_question(db: asyncpg.Connection, question_id: str):
delete_sql = """
@ -7,6 +8,7 @@ async def delete_question(db: asyncpg.Connection, question_id: str):
"""
await db.execute(delete_sql, question_id)
# 插入数据
async def insert_question(db: asyncpg.Connection, question_id: str, question: str):
insert_sql = """
@ -15,13 +17,14 @@ async def insert_question(db: asyncpg.Connection, question_id: str, question: st
"""
await db.execute(insert_sql, question_id, question, 0, 0, 0)
# 修改数据
async def update_question_by_id(db: asyncpg.Connection, question_id: str, **kwargs):
update_fields = {k: v for k, v in kwargs.items() if v is not None}
if not update_fields:
return False
set_clause = ", ".join([f"{field} = ${i+1}" for i, field in enumerate(update_fields.keys())])
set_clause = ", ".join([f"{field} = ${i + 1}" for i, field in enumerate(update_fields.keys())])
sql = f"""
UPDATE t_bi_question
SET {set_clause}
@ -36,6 +39,7 @@ async def update_question_by_id(db: asyncpg.Connection, question_id: str, **kwar
print(f"更新失败: {e}")
return False
# 根据问题 ID 查询 SQL
async def get_question_by_id(db: asyncpg.Connection, question_id: str):
select_sql = """
@ -44,6 +48,7 @@ async def get_question_by_id(db: asyncpg.Connection, question_id: str):
_data = await db.fetch(select_sql, question_id)
return _data
# 保存系统推荐
async def set_system_recommend_questions(db: asyncpg.Connection, question_id: str, flag: str):
sql = """
@ -57,6 +62,7 @@ async def set_system_recommend_questions(db: asyncpg.Connection, question_id: st
print(f"更新失败: {e}")
return False
# 设置用户收藏
async def set_user_collect_questions(db: asyncpg.Connection, question_id: str, flag: str):
sql = """
@ -70,6 +76,7 @@ async def set_user_collect_questions(db: asyncpg.Connection, question_id: str, f
print(f"更新失败: {e}")
return False
# 查询系统推荐问题
async def get_system_recommend_questions(db: asyncpg.Connection):
sql = """
@ -78,10 +85,26 @@ async def get_system_recommend_questions(db: asyncpg.Connection):
_data = await db.fetch(sql)
return _data
# 查询用户收藏问题
async def get_user_collect_questions(db: asyncpg.Connection):
sql = """
SELECT * FROM t_bi_question WHERE is_collect = 1
"""
_data = await db.fetch(sql)
return _data
return _data
# 获取数据集的字段名称
async def get_column_names(db: asyncpg.Connection, sql: str):
# 执行查询(添加 LIMIT 1
sql = sql.replace(";", "")
sql = sql + ' limit 1'
result = await db.fetchrow(sql)
# 获取列名
# 获取列名
if result:
column_names = list(result.keys())
return column_names
else:
return []

@ -19,7 +19,7 @@ def save_to_excel(data, filename):
df.to_excel(writer, index=False, sheet_name='Sheet1')
# 获取工作表对象
workbook = writer.book
#workbook = writer.book
worksheet = writer.sheets['Sheet1']
# 定义边框样式

@ -257,6 +257,17 @@ async def get_user_collect(db: asyncpg.Connection = Depends(get_db)): # 添加
return {"success": True, "data": user_collect_questions}
# 获取数据库字段名
@app.get("/questions/get_data_scheme_by_id")
async def get_data_scheme_by_id(
question_id: str = Query(None, description="问题IDGET请求"),
db: asyncpg.Connection = Depends(get_db)
):
sql = (await get_question_by_id(db, question_id))[0]['sql']
column_names = (await get_column_names(db, sql))
return {"success": True, "data": column_names}
# 启动 FastAPI
if __name__ == "__main__":
uvicorn.run("app:app", host="0.0.0.0", port=8000, reload=True, workers=4)
uvicorn.run("app:app", host="0.0.0.0", port=8000, workers=4)

Loading…
Cancel
Save