You can not select more than 25 topics
Topics must start with a letter or number, can include dashes ('-') and can be up to 35 characters long.
128 lines
4.0 KiB
128 lines
4.0 KiB
# Database.py
|
|
import logging
|
|
import math
|
|
|
|
import aiomysql
|
|
import asyncio
|
|
from config.Config import *
|
|
|
|
# Shared aiomysql connection pool for the whole module; it stays None until
# create_pool() assigns it.
pool: "aiomysql.Pool | None" = None
|
|
|
|
async def create_pool(loop):
    """Build the aiomysql connection pool and store it in the module-level ``pool``.

    Connection settings are read from the ``MYSQL_*`` names, presumably provided
    by the star import of ``config.Config`` (verify there). Cursors created from
    pooled connections return rows as dicts (``aiomysql.DictCursor``).

    :param loop: accepted for API compatibility; this function does not use it.
    """
    global pool
    pool_options = {
        "host": MYSQL_HOST,
        "port": MYSQL_PORT,
        "user": MYSQL_USER,
        "password": MYSQL_PASSWORD,
        "db": MYSQL_DB_NAME,
        # Keep at least one connection open, at most MYSQL_POOL_SIZE.
        "minsize": 1,
        "maxsize": MYSQL_POOL_SIZE,
        # Rows come back as dicts, so callers can use row.get("column").
        "cursorclass": aiomysql.DictCursor,
    }
    pool = await aiomysql.create_pool(**pool_options)
|
|
|
|
async def get_connection():
    """Check a connection out of the global pool and hand ownership to the caller.

    The connection is NOT released automatically. The caller must give it back
    with ``pool.release(conn)`` when finished; otherwise the pool eventually runs
    out of connections.

    :return: an aiomysql connection acquired from ``pool``.
    :raises RuntimeError: if the pool has not been created yet.
    """
    if pool is None:
        raise RuntimeError("连接池未初始化")
    # Do not wrap this in ``async with pool.acquire()``. Leaving that block
    # releases the connection, so returning from inside it would give the caller
    # a connection the pool already considers free and may hand to another
    # coroutine at the same time.
    return await pool.acquire()
|
|
|
|
async def close_pool():
    """Close the global connection pool, if one exists, and wait for it to shut down.

    The module-level ``pool`` is set back to ``None`` before closing. The
    ``pool is None`` checks in the query helpers then report an uninitialised
    pool instead of trying to acquire from a closed one. Calling this when no
    pool exists does nothing.
    """
    global pool
    # Detach first, so nothing can pick up the pool while it is being closed.
    closing_pool, pool = pool, None
    if closing_pool is not None:
        closing_pool.close()
        await closing_pool.wait_closed()
|
|
|
|
# Initialise the connection pool.
async def init_database():
    """Create the global connection pool, passing along the currently running event loop."""
    await create_pool(asyncio.get_running_loop())
|
|
|
|
# Shut down the connection pool.
async def shutdown_database():
    """Close the global connection pool by delegating to close_pool()."""
    return await close_pool()
|
|
|
|
# Run a SQL query and return all rows it produces.
async def find_by_sql(sql: str, params: "tuple | None" = None):
    """Execute ``sql`` on a pooled connection and return every resulting row.

    :param sql: SQL text; use ``%s`` placeholders for values supplied in ``params``.
    :param params: values bound to the placeholders. With ``None`` (the default)
        the SQL is sent as-is, so literal ``%`` characters need no escaping.
        Any tuple, even an empty one, makes the driver %-format the SQL first.
    :return: the rows from ``fetchall()``, or ``None`` when the pool has not been
        created, the query fails (the error is logged with its traceback), or no
        rows match.
    """
    if pool is None:
        logging.error("数据库连接池未创建")
        return None
    try:
        async with pool.acquire() as conn:
            async with conn.cursor() as cur:
                await cur.execute(sql, params)
                rows = await cur.fetchall()
    except Exception as e:
        # Best-effort helper: callers get None rather than an exception, but the
        # traceback is logged so the failure can still be diagnosed.
        logging.exception("数据库查询错误: %s", e)
        return None
    # An empty result set is reported as None, not as an empty sequence.
    return rows or None
|
|
|
|
|
|
# Run a SQL query and return only its first row.
async def find_one_by_sql(sql: str, params: "tuple | None" = None):
    """Execute ``sql`` on a pooled connection and return the first resulting row.

    :param sql: SQL text; use ``%s`` placeholders for values supplied in ``params``.
    :param params: values bound to the placeholders. With ``None`` (the default)
        the SQL is sent as-is, so literal ``%`` characters need no escaping.
        Any tuple, even an empty one, makes the driver %-format the SQL first.
    :return: the row from ``fetchone()``, or ``None`` when the pool has not been
        created, the query fails (the error is logged with its traceback), or no
        row matches.
    """
    if pool is None:
        logging.error("数据库连接池未创建")
        return None
    try:
        async with pool.acquire() as conn:
            async with conn.cursor() as cur:
                await cur.execute(sql, params)
                row = await cur.fetchone()
    except Exception as e:
        # Best-effort helper: callers get None rather than an exception, but the
        # traceback is logged so the failure can still be diagnosed.
        logging.exception("数据库查询错误: %s", e)
        return None
    return row or None
|
|
|
|
# Count how many rows a query produces.
async def get_total_data_count(total_data_sql):
    """Return the number of rows ``total_data_sql`` yields, or 0 if counting fails.

    The query is wrapped as a derived table
    (``select count(1) as count from (<sql>) as temp_table``), so any SELECT can
    be counted without being rewritten.

    ``total_data_sql`` is spliced into the SQL text verbatim. It must be trusted
    SQL, never text built from user input.
    """
    count_sql = "select count(1) as count from (" + total_data_sql + ") as temp_table"
    # Pass no params (None rather than ()): the SQL has no placeholders, and an
    # empty tuple would still make the driver %-format it. A literal '%'
    # (e.g. LIKE 'a%') would then raise, and the count would silently come back as 0.
    row = await find_one_by_sql(count_sql, None)
    if not row:
        return 0
    return row.get("count")
|
|
|
|
|
|
def get_page_by_total_row(total_data_count, page_number, page_size):
    """Compute paging values for a result set of ``total_data_count`` rows.

    ``page_number`` is 1-based. Values below 1 are raised to 1. Values past the
    last page are lowered to the last page, but only when there is at least one
    page.

    :return: ``(total_data_count, total_page, offset, limit)``. ``offset`` and
        ``limit`` are ready for a ``LIMIT offset, limit`` clause. A ``page_size``
        of 0 yields ``total_page == 0`` and ``limit == 0``.
    """
    # Exact integer ceiling division; dividing through float loses precision
    # once counts exceed 2**53.
    total_page = (total_data_count + page_size - 1) // page_size if page_size != 0 else 0
    if page_number <= 0:
        page_number = 1
    if 0 < total_page < page_number:
        page_number = total_page
    offset = (page_number - 1) * page_size
    return total_data_count, total_page, offset, page_size
|
|
|
|
|
|
async def get_page_data_by_sql(total_data_sql: str, page_number: int, page_size: int):
    """Fetch one page of the rows produced by ``total_data_sql``, plus totals.

    ``total_data_sql`` is embedded in the generated SQL verbatim. A count query
    wraps it as a subquery, and the page query appends ``LIMIT``. It must
    therefore be trusted SQL, never text built from user input.

    :param total_data_sql: a SELECT without its own LIMIT clause; trailing
        semicolons and surrounding whitespace are removed.
    :param page_number: 1-based page to fetch. It is clamped to a valid page when
        computing the offset, but echoed back unchanged in the result.
    :param page_size: rows per page.
    :return: ``None`` if the pool has not been created. Otherwise a dict with
        ``page_number``, ``page_size``, ``total_row``, ``total_page`` and
        ``list`` (the rows). When the count is 0 or the page query returns no
        rows, the totals are reported as 0 and ``list`` is empty.
    """
    if pool is None:
        logging.error("数据库连接池未创建")
        return None

    # Remove only trailing statement terminators. The rest of the SQL text is
    # passed through untouched, so string literals are never altered.
    base_sql = total_data_sql.strip()
    while base_sql.endswith(";"):
        base_sql = base_sql[:-1].rstrip()

    # Count the rows first.
    total_data_count = await get_total_data_count(base_sql)
    if total_data_count == 0:
        return {"page_number": page_number, "page_size": page_size, "total_row": 0, "total_page": 0, "list": []}

    total_row, total_page, offset, limit = get_page_by_total_row(total_data_count, page_number, page_size)

    # %d guarantees only integers reach the LIMIT clause.
    page_data_sql = base_sql + " LIMIT %d, %d " % (offset, limit)
    logging.debug("分页查询SQL: %s", page_data_sql)

    # No params (None): the SQL has no placeholders, and an empty tuple would make
    # the driver %-format it, which breaks on literal '%' characters.
    page_data = await find_by_sql(page_data_sql, None)
    if not page_data:
        return {"page_number": page_number, "page_size": page_size, "total_row": 0, "total_page": 0, "list": []}
    return {"page_number": page_number, "page_size": page_size, "total_row": total_row, "total_page": total_page, "list": page_data}
|