You can not select more than 25 topics Topics must start with a letter or number, can include dashes ('-') and can be up to 35 characters long.

141 lines
3.9 KiB

This file contains ambiguous Unicode characters!

This file contains ambiguous Unicode characters that may be confused with others in your current locale. If your use case is intentional and legitimate, you can safely ignore this warning. Use the Escape button to highlight these characters.

import asyncpg
async def delete_question(db: asyncpg.Connection, question_id: str):
    """Delete the t_bi_question row whose ``id`` equals ``question_id``.

    Args:
        db: Open asyncpg connection.
        question_id: Id of the row to remove; bound as parameter $1.
    """
    await db.execute("\nDELETE FROM t_bi_question WHERE id = $1\n", question_id)
async def insert_question(db: asyncpg.Connection, question_id: str, question: str):
    """Insert a new question row.

    The row starts with ``state_id``, ``is_system`` and ``is_collect`` all 0.

    Args:
        db: Open asyncpg connection.
        question_id: Value for the ``id`` column.
        question: Value for the ``question`` column.
    """
    state_id = is_system = is_collect = 0
    statement = (
        "\nINSERT INTO t_bi_question (id, question, state_id, is_system, is_collect)\n"
        "VALUES ($1, $2, $3, $4, $5)\n"
    )
    await db.execute(statement, question_id, question, state_id, is_system, is_collect)
# 修改数据
async def update_question_by_id(db: asyncpg.Connection, question_id: str, **kwargs):
update_fields = {k: v for k, v in kwargs.items() if v is not None}
if not update_fields:
return False
set_clause = ", ".join([f"{field} = ${i + 1}" for i, field in enumerate(update_fields.keys())])
sql = f"""
UPDATE t_bi_question
SET {set_clause}
WHERE id = ${len(update_fields) + 1}
"""
params = list(update_fields.values()) + [question_id]
try:
await db.execute(sql, *params)
return True
except Exception as e:
print(f"更新失败: {e}")
return False
async def get_question_by_id(db: asyncpg.Connection, question_id: str):
    """Fetch the t_bi_question record(s) whose ``id`` equals ``question_id``.

    Args:
        db: Open asyncpg connection.
        question_id: Id to look up; bound as parameter $1.

    Returns:
        Whatever ``db.fetch`` returns for the query (every matching record).
    """
    return await db.fetch("\nSELECT * FROM t_bi_question WHERE id = $1\n", question_id)
async def get_data_by_sql(db: asyncpg.Connection, sql: str):
    """Run ``sql`` as-is and return every record it produces.

    NOTE(review): ``sql`` is executed verbatim with no bound parameters. If it
    can originate from users or generated text, confirm it is restricted
    upstream (e.g. a read-only database role) before reaching this function.

    Args:
        db: Open asyncpg connection.
        sql: Complete SQL text to execute.

    Returns:
        The result of ``db.fetch(sql)``.
    """
    records = await db.fetch(sql)
    return records
async def set_system_recommend_questions(db: asyncpg.Connection, question_id: str, flag: str):
    """Set the ``is_system`` (system-recommended) flag of one question.

    Args:
        db: Open asyncpg connection.
        question_id: Id of the row to update.
        flag: New flag value; converted with ``int()``.

    Returns:
        True after the UPDATE runs. False if ``flag`` cannot be converted to
        int or the statement raises; the error is printed in both cases.
    """
    statement = "\nUPDATE t_bi_question\nSET is_system = $1 WHERE id = $2\n"
    try:
        # The conversion stays inside the try so a bad flag is reported, not raised.
        new_value = int(flag)
        await db.execute(statement, new_value, question_id)
    except Exception as err:
        print(f"更新失败: {err}")
        return False
    return True
async def set_user_collect_questions(db: asyncpg.Connection, question_id: str, flag: str):
    """Set the ``is_collect`` (user-collected) flag of one question.

    Args:
        db: Open asyncpg connection.
        question_id: Id of the row to update.
        flag: New flag value; converted with ``int()``.

    Returns:
        True after the UPDATE runs. False if ``flag`` cannot be converted to
        int or the statement raises; the error is printed in both cases.
    """
    statement = "\nUPDATE t_bi_question\nSET is_collect = $1 WHERE id = $2\n"
    try:
        await db.execute(statement, int(flag), question_id)
    except Exception as err:
        print(f"更新失败: {err}")
        return False
    else:
        return True
async def get_system_recommend_questions(db: asyncpg.Connection, offset: int, limit: int):
    """Return one page of system-recommended questions (``is_system=1``).

    Rows are ordered by ``id`` descending. At most ``limit`` rows are returned
    after skipping ``offset``.
    """
    page_sql = (
        "\nSELECT *\n"
        "FROM t_bi_question where is_system=1 ORDER BY id DESC LIMIT $1 OFFSET $2;\n"
    )
    rows = await db.fetch(page_sql, limit, offset)
    return rows
async def get_system_recommend_questions_count(db: asyncpg.Connection):
    """Return how many questions have ``is_system=1``."""
    count_sql = "\nSELECT COUNT(*)\nFROM t_bi_question where is_system=1;\n"
    total = await db.fetchval(count_sql)
    return total
async def get_user_publish_questions(db: asyncpg.Connection, type_id: int, offset: int, limit: int):
    """Return one page of questions, newest id first.

    Args:
        db: Open asyncpg connection.
        type_id: 1 restricts the page to rows with ``is_collect = 1``; any
            other value returns all rows.
        offset: Number of rows to skip.
        limit: Maximum number of rows to return.
    """
    where_clause = " WHERE is_collect = 1" if type_id == 1 else ""
    page_sql = (
        "\nSELECT *\nFROM t_bi_question\n"
        + where_clause
        + " ORDER BY id DESC LIMIT $1 OFFSET $2;"
    )
    return await db.fetch(page_sql, limit, offset)
async def get_user_publish_questions_count(db: asyncpg.Connection, type_id: int = 0):
    """Return the total number of questions for the user question listing.

    Args:
        db: Open asyncpg connection.
        type_id: Same meaning as in ``get_user_publish_questions``. 1 counts
            only rows with ``is_collect = 1``; any other value (default 0)
            counts every row, exactly as before this parameter existed.
            Pass the same value used for the listing so the pagination
            total matches the rows actually shown.
    """
    query = "\nSELECT COUNT(*) FROM t_bi_question"
    if type_id == 1:
        query += " WHERE is_collect = 1"
    query += ";\n"
    return await db.fetchval(query)
async def get_column_names(db: asyncpg.Connection, sql: str):
    """Return the column names of the result set that ``sql`` would produce.

    The statement is prepared but not executed. The server describes the
    result columns, so no rows are fetched, and the names are still reported
    when the query would return no rows.

    Args:
        db: Open asyncpg connection.
        sql: A single SQL query. Trailing whitespace and semicolons are ignored.

    Returns:
        List of column names in result order. It is [] when the statement
        describes no result columns.
    """
    # Only trailing terminators are removed. The previous replace(";", "")
    # also deleted semicolons inside string literals such as 'a;b'.
    statement = sql.rstrip(" \t\r\n\f\v;")
    # Describing via prepare replaces the old "append ' limit 1' and fetch a
    # row" probe. That probe had three problems:
    #   - it failed on queries that already had a LIMIT;
    #   - a trailing -- comment swallowed the appended LIMIT;
    #   - it returned [] when the result set was empty.
    prepared = await db.prepare(statement)
    return [attribute.name for attribute in prepared.get_attributes()]