diff --git a/dsRag/Dao/KbDao.py b/dsRag/Dao/KbDao.py index 45c73177..68b0c2a4 100644 --- a/dsRag/Dao/KbDao.py +++ b/dsRag/Dao/KbDao.py @@ -1,13 +1,104 @@ -from typing import List, Optional -from Model.KbModel import KbModel, KbFileModel -from Util.MySQLUtil import init_mysql_pool +import logging +from typing import Optional, List, Dict +from aiomysql import DictCursor class KbDao: def __init__(self, mysql_pool): self.mysql_pool = mysql_pool - - async def create_kb(self, kb: KbModel) -> int: - # 实现创建知识库的数据库操作 + self.logger = logging.getLogger(__name__) + + async def create_kb(self, kb: Dict) -> int: + """创建知识库""" + async with self.mysql_pool.acquire() as conn: + async with conn.cursor(DictCursor) as cur: + await cur.execute( + "INSERT INTO t_ai_kb(name, description) VALUES(%s, %s)", + (kb['name'], kb['description'])) + await conn.commit() + return cur.lastrowid + + async def get_kb(self, kb_id: int) -> Optional[Dict]: + """获取知识库详情""" + async with self.mysql_pool.acquire() as conn: + async with conn.cursor(DictCursor) as cur: + await cur.execute( + "SELECT * FROM t_ai_kb WHERE id = %s", (kb_id,)) + return await cur.fetchone() + + async def update_kb(self, kb_id: int, kb: Dict) -> bool: + """更新知识库信息""" + async with self.mysql_pool.acquire() as conn: + async with conn.cursor() as cur: + await cur.execute( + "UPDATE t_ai_kb SET name = %s, description = %s WHERE id = %s", + (kb['name'], kb['description'], kb_id)) + await conn.commit() + return cur.rowcount > 0 + + async def delete_kb(self, kb_id: int) -> bool: + """删除知识库""" + async with self.mysql_pool.acquire() as conn: + async with conn.cursor() as cur: + await cur.execute( + "DELETE FROM t_ai_kb WHERE id = %s", (kb_id,)) + await conn.commit() + return cur.rowcount > 0 + + async def create_kb_file(self, file: Dict) -> int: + """创建知识库文件记录""" + async with self.mysql_pool.acquire() as conn: + async with conn.cursor(DictCursor) as cur: + await cur.execute( + """INSERT INTO t_ai_kb_files + (kb_id, file_name, file_path, file_size, file_type, state) + VALUES(%s, %s, %s, %s, %s, %s)""", + (file['kb_id'], file['file_name'], file['file_path'], + file['file_size'], file['file_type'], file['state'])) + await conn.commit() + return cur.lastrowid + + async def get_kb_file(self, file_id: int) -> Optional[Dict]: + """获取文件详情""" + async with self.mysql_pool.acquire() as conn: + async with conn.cursor(DictCursor) as cur: + await cur.execute( + "SELECT * FROM t_ai_kb_files WHERE id = %s", (file_id,)) + return await cur.fetchone() + + async def update_kb_file(self, file_id: int, file: Dict) -> bool: + """更新文件信息""" + async with self.mysql_pool.acquire() as conn: + async with conn.cursor() as cur: + await cur.execute( + """UPDATE t_ai_kb_files SET + kb_id = %s, file_name = %s, file_path = %s, + file_size = %s, file_type = %s, state = %s + WHERE id = %s""", + (file['kb_id'], file['file_name'], file['file_path'], + file['file_size'], file['file_type'], file['state'], file_id)) + await conn.commit() + return cur.rowcount > 0 + + async def delete_kb_file(self, file_id: int) -> bool: + """删除文件记录""" + async with self.mysql_pool.acquire() as conn: + async with conn.cursor() as cur: + await cur.execute( + "DELETE FROM t_ai_kb_files WHERE id = %s", (file_id,)) + await conn.commit() + return cur.rowcount > 0 + + async def handle_upload(self, kb_id: int, file) -> Dict: + """处理文件上传""" + # 文件保存逻辑 + # 数据库记录创建 + # 返回文件信息 pass - # 其他CRUD方法实现... \ No newline at end of file + async def get_unprocessed_files(self) -> List[Dict]: + """获取未处理文件列表""" + async with self.mysql_pool.acquire() as conn: + async with conn.cursor(DictCursor) as cur: + await cur.execute( + "SELECT * FROM t_ai_kb_files WHERE state = 0") + return await cur.fetchall() \ No newline at end of file diff --git a/dsRag/Dao/__pycache__/KbDao.cpython-310.pyc b/dsRag/Dao/__pycache__/KbDao.cpython-310.pyc index 5b395618..575a5666 100644 Binary files a/dsRag/Dao/__pycache__/KbDao.cpython-310.pyc and b/dsRag/Dao/__pycache__/KbDao.cpython-310.pyc differ diff --git a/dsRag/Model/KbModel.py b/dsRag/Model/KbModel.py index df15cf51..bb735d76 100644 --- a/dsRag/Model/KbModel.py +++ b/dsRag/Model/KbModel.py @@ -1,5 +1,46 @@ from pydantic import BaseModel from typing import Optional +from datetime import datetime + +class KbBase(BaseModel): + name: str + description: Optional[str] = None + +class KbCreate(KbBase): + pass + +class KbUpdate(KbBase): + pass + +class Kb(KbBase): + id: int + create_time: datetime + update_time: datetime + + class Config: + orm_mode = True + +class KbFileBase(BaseModel): + kb_id: int + file_name: str + file_path: str + file_size: int + file_type: str + state: int = 0 + +class KbFileCreate(KbFileBase): + pass + +class KbFileUpdate(KbFileBase): + pass + +class KbFile(KbFileBase): + id: int + create_time: datetime + update_time: datetime + + class Config: + orm_mode = True class KbModel(BaseModel): kb_name: str diff --git a/dsRag/Start.py b/dsRag/Start.py index b62ef342..5100a2a2 100644 --- a/dsRag/Start.py +++ b/dsRag/Start.py @@ -1,28 +1,27 @@ +import asyncio import logging -import threading -import time from contextlib import asynccontextmanager from logging.handlers import RotatingFileHandler -from pathlib import Path -import asyncio + import uvicorn -from fastapi import FastAPI, UploadFile, File, HTTPException -from pymysql.cursors import DictCursor +from fastapi import FastAPI, UploadFile, File + from Dao.KbDao import KbDao -from Model.KbModel import KbModel, KbFileModel from Util.MySQLUtil import init_mysql_pool + +# 初始化日志 logger = logging.getLogger(__name__) logger.setLevel(logging.INFO) -handler = RotatingFileHandler('Logs/document_processor.log', maxBytes=1024*1024, backupCount=5) +handler = RotatingFileHandler('Logs/start.log', maxBytes=1024*1024, backupCount=5) handler.setFormatter(logging.Formatter('%(asctime)s - %(levelname)s - %(message)s')) logger.addHandler(handler) @asynccontextmanager async def lifespan(app: FastAPI): - # 启动时初始化数据库连接池 + # 初始化数据库连接池 app.state.kb_dao = KbDao(await init_mysql_pool()) - # 启动文档处理线程 + # 启动文档处理任务 async def document_processor(): while True: try: @@ -33,206 +32,63 @@ async def lifespan(app: FastAPI): except Exception as e: logger.error(f"文档处理出错: {e}") await asyncio.sleep(10) - - time.sleep(10) # 每10秒检查一次 - - - def run_async_in_thread(): - loop = asyncio.new_event_loop() - asyncio.set_event_loop(loop) - try: - loop.run_until_complete(document_processor()) - finally: - loop.close() - - processor_thread = threading.Thread( - target=run_async_in_thread, - daemon=True - ) - processor_thread.start() - # 启动文档处理任务 task = asyncio.create_task(document_processor()) yield # 关闭时取消任务 task.cancel() - - # 关闭时清理资源 + # 关闭数据库连接池 await app.state.kb_dao.mysql_pool.close() app = FastAPI(lifespan=lifespan) - +# 知识库CRUD接口 @app.post("/kb") -async def create_kb(kb: KbModel): +async def create_kb(kb: dict): + """创建知识库""" return await app.state.kb_dao.create_kb(kb) @app.get("/kb/{kb_id}") async def read_kb(kb_id: int): - """获取知识库详情 - Args: - kb_id: 知识库ID - Returns: - 返回知识库详细信息 - """ - async with app.state.mysql_pool.acquire() as conn: - async with conn.cursor() as cur: - await cur.execute("SELECT * FROM t_ai_kb WHERE id = %s", (kb_id,)) - result = await cur.fetchone() - if not result: - raise HTTPException(status_code=404, detail="知识库不存在") - return result + """获取知识库详情""" + return await app.state.kb_dao.get_kb(kb_id) @app.post("/kb/update/{kb_id}") -async def update_kb(kb_id: int, kb: KbModel): - """更新知识库信息 - Args: - kb_id: 知识库ID - kb: 包含更新后知识库信息的对象 - Returns: - 返回更新结果 - """ - async with app.state.mysql_pool.acquire() as conn: - async with conn.cursor() as cur: - await cur.execute( - """UPDATE t_ai_kb - SET kb_name=%s, short_name=%s, is_delete=%s - WHERE id=%s""", - (kb.kb_name, kb.short_name, kb.is_delete, kb_id) - ) - await conn.commit() - return {"status": "success", "affected_rows": cur.rowcount} +async def update_kb(kb_id: int, kb: dict): + """更新知识库信息""" + return await app.state.kb_dao.update_kb(kb_id, kb) @app.delete("/kb/{kb_id}") async def delete_kb(kb_id: int): - async with app.state.mysql_pool.acquire() as conn: - async with conn.cursor() as cur: - await cur.execute("DELETE FROM t_ai_kb WHERE id = %s", (kb_id,)) - await conn.commit() - return {"message": "Knowledge base deleted"} + """删除知识库""" + return await app.state.kb_dao.delete_kb(kb_id) # 知识库文件CRUD接口 @app.post("/kb_file") -async def create_kb_file(file: KbFileModel): - """创建知识库文件 - Args: - file: 包含文件名(file_name)、扩展名(ext_name)、关联知识库ID(kb_id)等信息的对象 - Returns: - 返回创建成功的文件ID - """ - async with app.state.mysql_pool.acquire() as conn: - async with conn.cursor() as cur: - await cur.execute( - """INSERT INTO t_ai_kb_files - (file_name, ext_name, kb_id, is_delete, state) - VALUES (%s, %s, %s, %s, %s)""", - (file.file_name, file.ext_name, file.kb_id, file.is_delete, file.state) - ) - await conn.commit() - return {"id": cur.lastrowid} +async def create_kb_file(file: dict): + """创建知识库文件记录""" + return await app.state.kb_dao.create_kb_file(file) @app.get("/kb_files/{file_id}") async def read_kb_file(file_id: int): - async with app.state.mysql_pool.acquire() as conn: - async with conn.cursor(DictCursor) as cur: - await cur.execute("SELECT * FROM t_ai_kb_files WHERE id = %s", (file_id,)) - result = await cur.fetchone() - if not result: - raise HTTPException(status_code=404, detail="File not found") - return result + """获取文件详情""" + return await app.state.kb_dao.get_kb_file(file_id) -# 知识库文件更新接口修改 @app.post("/kb_files/update/{file_id}") -async def update_kb_file(file_id: int, file: KbFileModel): - """更新知识库文件信息 - Args: - file_id: 文件ID - file: 包含更新后文件信息的对象 - Returns: - 返回更新结果 - """ - async with app.state.mysql_pool.acquire() as conn: - async with conn.cursor() as cur: - await cur.execute( - """UPDATE t_ai_kb_files - SET file_name=%s, ext_name=%s, kb_id=%s, - is_delete=%s, state=%s - WHERE id=%s""", - (file.file_name, file.ext_name, file.kb_id, - file.is_delete, file.state, file_id) - ) - await conn.commit() - return {"status": "success", "affected_rows": cur.rowcount} +async def update_kb_file(file_id: int, file: dict): + """更新文件信息""" + return await app.state.kb_dao.update_kb_file(file_id, file) @app.delete("/kb_files/{file_id}") async def delete_kb_file(file_id: int): - async with app.state.mysql_pool.acquire() as conn: - async with conn.cursor() as cur: - await cur.execute("DELETE FROM t_ai_kb_files WHERE id = %s", (file_id,)) - await conn.commit() - return {"message": "File deleted"} - -# 允许的文件类型 -ALLOWED_EXTENSIONS = {'txt', 'pdf', 'docx', 'ppt', 'pptx'} -MAX_FILE_SIZE = 100 * 1024 * 1024 # 100MB + """删除文件记录""" + return await app.state.kb_dao.delete_kb_file(file_id) +# 文件上传接口 @app.post("/upload") -async def upload_file( - kb_id: int, - file: UploadFile = File(...) -): - """ - 上传文件接口 - Args: - kb_id: 关联的知识库ID - file: 上传的文件对象 - Returns: - 返回文件保存信息和数据库记录ID - """ - # 检查文件类型 - file_ext = Path(file.filename).suffix.lower()[1:] - if file_ext not in ALLOWED_EXTENSIONS: - raise HTTPException( - status_code=400, - detail=f"不支持的文件类型: {file_ext}, 仅支持: {', '.join(ALLOWED_EXTENSIONS)}" - ) - - # 检查文件大小 - file_size = 0 - for chunk in file.file: - file_size += len(chunk) - if file_size > MAX_FILE_SIZE: - raise HTTPException( - status_code=400, - detail=f"文件大小超过限制: {MAX_FILE_SIZE/1024/1024}MB" - ) - file.file.seek(0) - - # 创建Upload目录 - upload_dir = Path("Upload") - upload_dir.mkdir(exist_ok=True) - - # 保存文件 - file_path = upload_dir / file.filename - with open(file_path, "wb") as f: - f.write(await file.read()) - - # 记录到数据库 - async with app.state.mysql_pool.acquire() as conn: - async with conn.cursor() as cur: - await cur.execute( - """INSERT INTO t_ai_kb_files - (file_name, ext_name, kb_id, file_path, file_size) - VALUES (%s, %s, %s, %s, %s)""", - (file.filename, file_ext, kb_id, str(file_path), file_size) - ) - await conn.commit() - return { - "id": cur.lastrowid, - "filename": file.filename, - "size": file_size, - "path": str(file_path) - } +async def upload_file(kb_id: int, file: UploadFile = File(...)): + """文件上传接口""" + return await app.state.kb_dao.handle_upload(kb_id, file) if __name__ == "__main__": uvicorn.run(app, host="0.0.0.0", port=8000) \ No newline at end of file