You can not select more than 25 topics Topics must start with a letter or number, can include dashes ('-') and can be up to 35 characters long.

128 lines
4.0 KiB

# Database.py
import logging
import math
import aiomysql
import asyncio
from config.Config import *
# Module-wide aiomysql connection pool used by the functions in this module;
# stays None until create_pool() assigns it.
pool = None
async def create_pool(loop):
    """Create the shared aiomysql pool from the MySQL settings in ``config.Config``.

    The pool keeps at least one connection, grows up to ``MYSQL_POOL_SIZE``
    connections, and hands out ``DictCursor`` cursors (rows come back as dicts).
    The result is stored in the module-level ``pool``.

    Args:
        loop: Accepted for caller compatibility; not used by this function.
    """
    global pool
    pool_options = dict(
        host=MYSQL_HOST,
        port=MYSQL_PORT,
        user=MYSQL_USER,
        password=MYSQL_PASSWORD,
        db=MYSQL_DB_NAME,
        minsize=1,                        # smallest number of pooled connections
        maxsize=MYSQL_POOL_SIZE,          # largest number of pooled connections
        cursorclass=aiomysql.DictCursor,  # cursors return rows as dicts
    )
    pool = await aiomysql.create_pool(**pool_options)
async def get_connection():
    """Check a connection out of the shared pool and hand it to the caller.

    The connection stays checked out until the caller returns it with
    ``await release_connection(conn)``. A connection that is never returned
    permanently occupies a pool slot.

    Returns:
        A pooled connection.

    Raises:
        Exception: if the pool has not been created yet.
    """
    if pool is None:
        raise Exception("连接池未初始化")
    # Do not return from inside ``async with pool.acquire()``: leaving that
    # block releases the connection back to the pool, so the caller would be
    # using a connection the pool may give to another coroutine at the same time.
    return await pool.acquire()


async def release_connection(conn):
    """Return a connection obtained from ``get_connection`` to the shared pool.

    Raises:
        Exception: if the pool has not been created yet.
    """
    if pool is None:
        raise Exception("连接池未初始化")
    await pool.release(conn)
async def close_pool():
    """Close the shared pool, wait for it to finish, then clear ``pool``.

    This is a no-op when no pool exists. Resetting the global to ``None`` makes
    the ``pool is None`` checks in this module report the closed state. It also
    lets ``create_pool`` build a fresh pool later.
    """
    global pool
    current = pool
    if current is None:
        return
    current.close()
    await current.wait_closed()
    # Only clear the global if nobody installed a new pool while we waited.
    if pool is current:
        pool = None
async def init_database():
    """Build the shared connection pool by calling ``create_pool`` with the running event loop."""
    await create_pool(asyncio.get_running_loop())
async def shutdown_database():
    """Tear down the shared connection pool by awaiting ``close_pool``."""
    await close_pool()
async def find_by_sql(sql: str, params: "tuple | None" = None):
    """Execute ``sql`` on a pooled connection and return all fetched rows.

    Args:
        sql: Statement to run; ``%s`` placeholders are filled from ``params``.
        params: Placeholder values, passed straight to ``cursor.execute``.
            Defaults to ``None``, in which case aiomysql performs no ``%``
            interpolation, so literal ``%`` characters need no escaping.

    Returns:
        The rows from ``fetchall()``. Returns ``None`` when there are no rows,
        when the pool has not been created, or when the query fails; failures
        are logged with their traceback.
    """
    if pool is None:
        logging.error("数据库连接池未创建")
        return None
    try:
        async with pool.acquire() as conn:
            async with conn.cursor() as cur:
                await cur.execute(sql, params)
                rows = await cur.fetchall()
    except Exception as e:
        # Best-effort contract: callers get None instead of an exception.
        # Keep the traceback in the log so the failure can be diagnosed.
        logging.exception("数据库查询错误: %s", e)
        return None
    return rows or None
async def find_one_by_sql(sql: str, params: "tuple | None" = None):
    """Execute ``sql`` on a pooled connection and return only the first row.

    Args:
        sql: Statement to run; ``%s`` placeholders are filled from ``params``.
        params: Placeholder values, passed straight to ``cursor.execute``.
            Defaults to ``None``, in which case aiomysql performs no ``%``
            interpolation, so literal ``%`` characters need no escaping.

    Returns:
        The row from ``fetchone()``. Returns ``None`` when there is no row,
        when the pool has not been created, or when the query fails; failures
        are logged with their traceback.
    """
    if pool is None:
        logging.error("数据库连接池未创建")
        return None
    try:
        async with pool.acquire() as conn:
            async with conn.cursor() as cur:
                await cur.execute(sql, params)
                row = await cur.fetchone()
    except Exception as e:
        # Best-effort contract: callers get None instead of an exception.
        # Keep the traceback in the log so the failure can be diagnosed.
        logging.exception("数据库查询错误: %s", e)
        return None
    return row or None
async def get_total_data_count(total_data_sql):
    """Return how many rows ``total_data_sql`` produces.

    The statement is wrapped as a derived table and counted with
    ``select count(1) as count``. ``None`` is passed as the parameter set, so
    aiomysql performs no ``%`` interpolation. A literal ``%`` in the SQL
    (e.g. ``LIKE 'a%'``) is therefore sent as-is and no longer raises a
    formatting error.
    NOTE: callers that doubled ``%%`` to work around the old ``()`` parameter
    set must now write a single ``%``.

    Returns:
        The count. Returns 0 when ``find_one_by_sql`` yields nothing, which
        includes query errors (it logs them and returns ``None``).
    """
    total_data_count_sql = "select count(1) as count from (" + total_data_sql + ") as temp_table"
    result = await find_one_by_sql(total_data_count_sql, None)
    return result.get("count") if result else 0
def get_page_by_total_row(total_data_count, page_number, page_size):
    """Translate a row count and a requested page into pagination values.

    Args:
        total_data_count: Total number of rows available.
        page_number: Requested 1-based page. Values <= 0 become 1. Values past
            the last page are clamped to it when there is at least one page.
        page_size: Rows per page; 0 yields ``total_page == 0``.

    Returns:
        ``(total_data_count, total_page, offset, limit)``. ``offset`` is the
        number of rows to skip and ``limit`` equals ``page_size``. The clamped
        page number itself is not returned.
    """
    # Ceiling division in integer arithmetic. The float-based floor(x / y)
    # loses precision once counts exceed 2**53.
    total_page = (total_data_count + page_size - 1) // page_size if page_size != 0 else 0
    if page_number <= 0:
        page_number = 1
    if 0 < total_page < page_number:
        page_number = total_page
    offset = page_size * (page_number - 1)
    return total_data_count, total_page, offset, page_size
def _empty_page(page_number, page_size):
    """Build the pagination envelope returned when there are no rows to show."""
    return {"page_number": page_number, "page_size": page_size, "total_row": 0, "total_page": 0, "list": []}


async def get_page_data_by_sql(total_data_sql: str, page_number: int, page_size: int):
    """Count the rows of ``total_data_sql`` and fetch one page of them.

    Args:
        total_data_sql: Complete SELECT statement without placeholders. Trailing
            semicolons/whitespace are stripped so the statement can be wrapped
            in a count subquery and suffixed with ``LIMIT``.
        page_number: Requested 1-based page. It is echoed back unchanged, even
            when ``get_page_by_total_row`` clamps it to fetch a valid page.
        page_size: Rows per page.

    Returns:
        ``None`` if the pool has not been created. Otherwise a dict with
        ``page_number``, ``page_size``, ``total_row``, ``total_page`` and
        ``list``. ``total_row``/``total_page`` are 0 and ``list`` is empty when
        the count is 0 or the page query yields no rows.
    """
    if pool is None:
        logging.error("数据库连接池未创建")
        return None
    # Strip only the trailing statement terminator. Global replace() calls
    # would also rewrite ';' or ' FROM ' occurring inside string literals.
    total_data_sql = total_data_sql.rstrip("; \t\r\n")
    total_data_count = await get_total_data_count(total_data_sql)
    if total_data_count == 0:
        return _empty_page(page_number, page_size)
    total_row, total_page, offset, limit = get_page_by_total_row(total_data_count, page_number, page_size)
    # %d renders offset/limit as integers, so only numbers are spliced into the SQL.
    page_data_sql = total_data_sql + " LIMIT %d, %d " % (offset, limit)
    logging.debug(page_data_sql)
    # Pass None rather than () so a literal '%' in the caller's SQL
    # (e.g. LIKE 'a%') is not treated as a format placeholder by cursor.execute.
    page_data = await find_by_sql(page_data_sql, None)
    if not page_data:
        return _empty_page(page_number, page_size)
    return {
        "page_number": page_number,
        "page_size": page_size,
        "total_row": total_row,
        "total_page": total_page,
        "list": page_data,
    }