diff --git a/dsAiTeachingModel/main.py b/dsAiTeachingModel/main.py index d7b3cfdc..2bacc306 100644 --- a/dsAiTeachingModel/main.py +++ b/dsAiTeachingModel/main.py @@ -19,14 +19,14 @@ logging.basicConfig( ) async def lifespan(app: FastAPI): - # 创建数据库连接池 - await init_database() - + print("Starting up...") # 启动异步任务 asyncio.create_task(train_document_task()) - - yield - await shutdown_database() + try: + yield + finally: + print("Shutting down...") + # 如果有需要在关闭时执行的任务,可以在这里添加 app = FastAPI(lifespan=lifespan) diff --git a/dsAiTeachingModel/utils/Database.py b/dsAiTeachingModel/utils/Database.py index ebfce2d6..78799c35 100644 --- a/dsAiTeachingModel/utils/Database.py +++ b/dsAiTeachingModel/utils/Database.py @@ -1,51 +1,15 @@ # Database.py import datetime import logging -import asyncpg -from Config.Config import * - -# 创建一个全局的连接池 -pool = None - -async def create_pool(): - global pool - pool = await asyncpg.create_pool( - host=POSTGRES_HOST, - port=POSTGRES_PORT, - user=POSTGRES_USER, - password=POSTGRES_PASSWORD, - database=POSTGRES_DATABASE, - min_size=1, # 设置连接池最小连接数 - max_size=10 # 设置连接池最大连接数 - ) - -async def get_connection(): - if pool is None: - raise Exception("连接池未初始化") - async with pool.acquire() as conn: - return conn - -async def close_pool(): - if pool is not None: - await pool.close() - -# 初始化连接池的函数 -async def init_database(): - await create_pool() - -# 关闭连接池的函数 -async def shutdown_database(): - await close_pool() +from utils.PostgreSQLUtil import init_postgres_pool # 根据sql语句查询数据 async def find_by_sql(sql: str, params: tuple): - if pool is None: - logging.error("数据库连接池未创建") - return None try: - async with pool.acquire() as conn: + pg_pool = await init_postgres_pool() + async with pg_pool.acquire() as conn: result = await conn.fetch(sql, *params) # 将 asyncpg.Record 转换为字典 result_dict = [dict(record) for record in result] @@ -101,7 +65,8 @@ async def insert(tableName, param, onlyForParam=False): sql = f"INSERT INTO {tableName} ({column_names}) VALUES ({placeholder_names}) RETURNING id" try: - async with pool.acquire() as conn: + pg_pool = await init_postgres_pool() + async with pg_pool.acquire() as conn: result = await conn.fetchrow(sql, *values) if result: return result['id'] @@ -148,7 +113,8 @@ async def update(table_name, param, property_name, property_value, only_for_para values.append(property_value) try: - async with pool.acquire() as conn: + pg_pool = await init_postgres_pool() + async with pg_pool.acquire() as conn: result = await conn.fetchrow(sql, *values) if result: return result['id'] @@ -190,7 +156,8 @@ async def delete_by_id(table_name, property_name, property_value): logging.debug(sql) # 执行删除 try: - async with pool.acquire() as conn: + pg_pool = await init_postgres_pool() + async with pg_pool.acquire() as conn: result = await conn.execute(sql, property_value) if result: return True @@ -209,9 +176,9 @@ async def delete_by_id(table_name, property_name, property_value): # 执行一个SQL语句 async def execute_sql(sql, params): - logging.debug(sql) try: - async with pool.acquire() as conn: + pg_pool = await init_postgres_pool() + async with pg_pool.acquire() as conn: await conn.fetch(sql, *params) except Exception as e: logging.error(f"数据库查询错误: {e}")