You cannot select more than 25 topics. Topics must start with a letter or number, can include dashes ('-'), and can be up to 35 characters long.

104 lines
4.5 KiB

import logging
from typing import Optional, List, Dict
from aiomysql import DictCursor
class KbDao:
    """Async data-access object for knowledge bases and their files.

    Wraps the ``t_ai_kb`` (knowledge bases) and ``t_ai_kb_files`` (files
    belonging to a knowledge base) tables. Every method borrows one
    connection from ``mysql_pool`` for a single statement; write methods
    commit immediately, so each call is its own transaction.
    """

    # ``t_ai_kb_files.state`` value that get_unprocessed_files() selects on.
    STATE_UNPROCESSED = 0

    def __init__(self, mysql_pool):
        """Bind the DAO to a connection pool.

        Args:
            mysql_pool: pool whose ``acquire()`` yields a connection usable
                with ``async with`` (an aiomysql pool in this project).
        """
        self.mysql_pool = mysql_pool
        self.logger = logging.getLogger(__name__)

    # ------------------------------------------------------------------
    # Private helpers: one pooled connection + one statement per call.
    # ------------------------------------------------------------------

    async def _execute_write(self, sql: str, args: tuple) -> tuple:
        """Execute one write statement, commit it, and return
        ``(rowcount, lastrowid)`` as reported by the cursor."""
        async with self.mysql_pool.acquire() as conn:
            async with conn.cursor() as cur:
                await cur.execute(sql, args)
                await conn.commit()
                # Read both counters while the cursor is still open.
                return cur.rowcount, cur.lastrowid

    async def _fetch_one(self, sql: str,
                         args: Optional[tuple] = None) -> Optional[Dict]:
        """Run a SELECT and return the first row as a dict, or None."""
        async with self.mysql_pool.acquire() as conn:
            async with conn.cursor(DictCursor) as cur:
                await cur.execute(sql, args)
                return await cur.fetchone()

    async def _fetch_all(self, sql: str,
                         args: Optional[tuple] = None) -> List[Dict]:
        """Run a SELECT and return every row as a list of dicts."""
        async with self.mysql_pool.acquire() as conn:
            async with conn.cursor(DictCursor) as cur:
                await cur.execute(sql, args)
                # fetchall() hands back the driver's raw row container, which
                # can be a tuple (e.g. ``()`` when nothing matched); normalise
                # so callers always receive a list as annotated.
                return list(await cur.fetchall())

    # ------------------------------------------------------------------
    # t_ai_kb
    # ------------------------------------------------------------------

    async def create_kb(self, kb: Dict) -> int:
        """Insert a knowledge base and return the new row id.

        Args:
            kb: must contain the keys ``'name'`` and ``'description'``.

        Returns:
            The cursor's ``lastrowid`` for the INSERT.
        """
        _, new_id = await self._execute_write(
            "INSERT INTO t_ai_kb(name, description) VALUES(%s, %s)",
            (kb['name'], kb['description']))
        return new_id

    async def get_kb(self, kb_id: int) -> Optional[Dict]:
        """Return the knowledge-base row with ``id = kb_id``, or None."""
        return await self._fetch_one(
            "SELECT * FROM t_ai_kb WHERE id = %s", (kb_id,))

    async def update_kb(self, kb_id: int, kb: Dict) -> bool:
        """Overwrite the name and description of knowledge base ``kb_id``.

        Args:
            kb_id: id of the row to update.
            kb: must contain the keys ``'name'`` and ``'description'``.

        Returns:
            True if the driver reported at least one affected row.
            NOTE: by default MySQL counts rows *changed*, not rows *matched*,
            so this is False when the row exists but already holds these
            values (unless the connection enables CLIENT_FOUND_ROWS).
        """
        affected, _ = await self._execute_write(
            "UPDATE t_ai_kb SET name = %s, description = %s WHERE id = %s",
            (kb['name'], kb['description'], kb_id))
        return affected > 0

    async def delete_kb(self, kb_id: int) -> bool:
        """Delete knowledge base ``kb_id``; True if a row was removed."""
        affected, _ = await self._execute_write(
            "DELETE FROM t_ai_kb WHERE id = %s", (kb_id,))
        return affected > 0

    # ------------------------------------------------------------------
    # t_ai_kb_files
    # ------------------------------------------------------------------

    async def create_kb_file(self, file: Dict) -> int:
        """Insert a knowledge-base file record and return the new row id.

        Args:
            file: must contain ``'kb_id'``, ``'file_name'``, ``'file_path'``,
                ``'file_size'``, ``'file_type'`` and ``'state'``.
        """
        _, new_id = await self._execute_write(
            """INSERT INTO t_ai_kb_files
            (kb_id, file_name, file_path, file_size, file_type, state)
            VALUES(%s, %s, %s, %s, %s, %s)""",
            (file['kb_id'], file['file_name'], file['file_path'],
             file['file_size'], file['file_type'], file['state']))
        return new_id

    async def get_kb_file(self, file_id: int) -> Optional[Dict]:
        """Return the file row with ``id = file_id``, or None."""
        return await self._fetch_one(
            "SELECT * FROM t_ai_kb_files WHERE id = %s", (file_id,))

    async def update_kb_file(self, file_id: int, file: Dict) -> bool:
        """Overwrite every column of file record ``file_id``.

        Args:
            file_id: id of the row to update.
            file: same required keys as create_kb_file().

        Returns:
            True if the driver reported at least one affected row (same
            changed-vs-matched caveat as update_kb()).
        """
        affected, _ = await self._execute_write(
            """UPDATE t_ai_kb_files SET
            kb_id = %s, file_name = %s, file_path = %s,
            file_size = %s, file_type = %s, state = %s
            WHERE id = %s""",
            (file['kb_id'], file['file_name'], file['file_path'],
             file['file_size'], file['file_type'], file['state'], file_id))
        return affected > 0

    async def delete_kb_file(self, file_id: int) -> bool:
        """Delete file record ``file_id``; True if a row was removed."""
        affected, _ = await self._execute_write(
            "DELETE FROM t_ai_kb_files WHERE id = %s", (file_id,))
        return affected > 0

    async def handle_upload(self, kb_id: int, file) -> Optional[Dict]:
        """Handle a file uploaded to knowledge base ``kb_id``.

        TODO: not implemented. Intended steps: save the file, create the
        t_ai_kb_files record, and return the file info. Until then this
        logs a warning and returns None (the original stub also returned None).
        """
        self.logger.warning(
            "KbDao.handle_upload is not implemented (kb_id=%s)", kb_id)
        return None

    async def get_unprocessed_files(self) -> List[Dict]:
        """Return every file row whose state is STATE_UNPROCESSED, as a list."""
        return await self._fetch_all(
            "SELECT * FROM t_ai_kb_files WHERE state = %s",
            (self.STATE_UNPROCESSED,))