import asyncio
import logging
import threading
import time
from contextlib import asynccontextmanager
from logging.handlers import RotatingFileHandler
from pathlib import Path

import uvicorn
from fastapi import FastAPI, File, HTTPException, UploadFile
from pymysql.cursors import DictCursor

from Dao.KbDao import KbDao
from Model.KbModel import KbFileModel, KbModel
from Test.T9_TestReadPptx import extract_text_from_pptx
from Util import PdfUtil, WordUtil
from Util.MySQLUtil import init_mysql_pool

# Logger is configured exactly once at module level. (The original file
# duplicated this block at the bottom, attaching a second handler so every
# record was written twice.)
Path("Logs").mkdir(exist_ok=True)  # RotatingFileHandler does not create the directory
logger = logging.getLogger(__name__)
logger.setLevel(logging.INFO)
handler = RotatingFileHandler('Logs/document_processor.log', maxBytes=1024 * 1024, backupCount=5)
handler.setFormatter(logging.Formatter('%(asctime)s - %(levelname)s - %(message)s'))
logger.addHandler(handler)


async def document_processor() -> None:
    """Background loop: poll for unprocessed documents every 10 seconds.

    The actual processing steps (fetch pending rows, extract text with
    PdfUtil / WordUtil / extract_text_from_pptx, index into ES) were only
    sketched as comments in the original source and are kept as TODOs.
    """
    while True:
        try:
            # TODO: fetch unprocessed documents
            # TODO: extract text from the document
            # TODO: save the extracted text to Elasticsearch
            await asyncio.sleep(10)  # poll interval
        except asyncio.CancelledError:
            # Propagate so lifespan() can shut the task down cleanly.
            raise
        except Exception as e:
            logger.error("文档处理出错: %s", e)
            await asyncio.sleep(10)


@asynccontextmanager
async def lifespan(app: FastAPI):
    """Application lifespan.

    Startup: create the MySQL pool, expose it (and a KbDao wrapping it) on
    app.state, and launch the document-processor background task.
    Shutdown: cancel the task and release the pool.
    """
    # One shared pool. The endpoints read app.state.mysql_pool, so it must be
    # set here (the original only set app.state.kb_dao, which made every DB
    # endpoint raise AttributeError).
    app.state.mysql_pool = await init_mysql_pool()
    app.state.kb_dao = KbDao(app.state.mysql_pool)

    # Run the processor as a single asyncio task on the app's own event loop.
    # (The original started it twice: once in a dedicated thread with a
    # private event loop and again via create_task on the main loop.)
    task = asyncio.create_task(document_processor())

    yield

    # Stop the background task before tearing down the pool it may be using.
    task.cancel()
    try:
        await task
    except asyncio.CancelledError:
        pass

    # aiomysql-style shutdown: close() is synchronous, wait_closed() awaits
    # in-flight connections. NOTE(review): confirm against Util.MySQLUtil —
    # the original `await pool.close()` would fail on an aiomysql pool.
    app.state.mysql_pool.close()
    await app.state.mysql_pool.wait_closed()


app = FastAPI(lifespan=lifespan)


@app.post("/kb")
async def create_kb(kb: KbModel):
    """Create a knowledge base via the DAO layer."""
    return await app.state.kb_dao.create_kb(kb)


@app.get("/kb/{kb_id}")
async def read_kb(kb_id: int):
    """Fetch one knowledge base row by id.

    Raises:
        HTTPException(404) when no row matches.
    """
    async with app.state.mysql_pool.acquire() as conn:
        # DictCursor for a JSON-friendly mapping, consistent with read_kb_file.
        async with conn.cursor(DictCursor) as cur:
            await cur.execute("SELECT * FROM t_ai_kb WHERE id = %s", (kb_id,))
            result = await cur.fetchone()
            if not result:
                raise HTTPException(status_code=404, detail="知识库不存在")
            return result


@app.post("/kb/update/{kb_id}")
async def update_kb(kb_id: int, kb: KbModel):
    """Update a knowledge base row; returns the affected row count."""
    async with app.state.mysql_pool.acquire() as conn:
        async with conn.cursor() as cur:
            await cur.execute(
                """UPDATE t_ai_kb
                SET kb_name=%s, short_name=%s, is_delete=%s
                WHERE id=%s""",
                (kb.kb_name, kb.short_name, kb.is_delete, kb_id)
            )
            await conn.commit()
            return {"status": "success", "affected_rows": cur.rowcount}


@app.delete("/kb/{kb_id}")
async def delete_kb(kb_id: int):
    """Hard-delete a knowledge base row."""
    async with app.state.mysql_pool.acquire() as conn:
        async with conn.cursor() as cur:
            await cur.execute("DELETE FROM t_ai_kb WHERE id = %s", (kb_id,))
            await conn.commit()
            return {"message": "Knowledge base deleted"}


# ---- knowledge-base file CRUD -------------------------------------------

@app.post("/kb_file")
async def create_kb_file(file: KbFileModel):
    """Insert a t_ai_kb_files row; returns the new row id."""
    async with app.state.mysql_pool.acquire() as conn:
        async with conn.cursor() as cur:
            await cur.execute(
                """INSERT INTO t_ai_kb_files
                (file_name, ext_name, kb_id, is_delete, state)
                VALUES (%s, %s, %s, %s, %s)""",
                (file.file_name, file.ext_name, file.kb_id, file.is_delete, file.state)
            )
            await conn.commit()
            return {"id": cur.lastrowid}


@app.get("/kb_files/{file_id}")
async def read_kb_file(file_id: int):
    """Fetch one file row by id; 404 when absent."""
    async with app.state.mysql_pool.acquire() as conn:
        async with conn.cursor(DictCursor) as cur:
            await cur.execute("SELECT * FROM t_ai_kb_files WHERE id = %s", (file_id,))
            result = await cur.fetchone()
            if not result:
                raise HTTPException(status_code=404, detail="File not found")
            return result


@app.post("/kb_files/update/{file_id}")
async def update_kb_file(file_id: int, file: KbFileModel):
    """Update a file row; returns the affected row count."""
    async with app.state.mysql_pool.acquire() as conn:
        async with conn.cursor() as cur:
            await cur.execute(
                """UPDATE t_ai_kb_files
                SET file_name=%s, ext_name=%s, kb_id=%s, is_delete=%s, state=%s
                WHERE id=%s""",
                (file.file_name, file.ext_name, file.kb_id, file.is_delete, file.state, file_id)
            )
            await conn.commit()
            return {"status": "success", "affected_rows": cur.rowcount}


@app.delete("/kb_files/{file_id}")
async def delete_kb_file(file_id: int):
    """Hard-delete a file row."""
    async with app.state.mysql_pool.acquire() as conn:
        async with conn.cursor() as cur:
            await cur.execute("DELETE FROM t_ai_kb_files WHERE id = %s", (file_id,))
            await conn.commit()
            return {"message": "File deleted"}


# Allowed upload types and size cap.
ALLOWED_EXTENSIONS = {'txt', 'pdf', 'docx', 'ppt', 'pptx'}
MAX_FILE_SIZE = 100 * 1024 * 1024  # 100MB
_CHUNK = 64 * 1024  # read/write granularity for uploads


@app.post("/upload")
async def upload_file(
        kb_id: int,
        file: UploadFile = File(...)
):
    """Upload a document into ./Upload and register it in t_ai_kb_files.

    Args:
        kb_id: id of the knowledge base the file belongs to.
        file: the uploaded file object.

    Returns:
        The new row id plus saved filename, size and path.

    Raises:
        HTTPException(400) on a disallowed extension or oversized file.
    """
    # Use only the basename of the client-supplied name: the raw value is
    # untrusted and could contain path components ("../..") that would let
    # the upload escape the Upload directory.
    safe_name = Path(file.filename).name
    file_ext = Path(safe_name).suffix.lower()[1:]
    if file_ext not in ALLOWED_EXTENSIONS:
        raise HTTPException(
            status_code=400,
            detail=f"不支持的文件类型: {file_ext}, 仅支持: {', '.join(ALLOWED_EXTENSIONS)}"
        )

    # Size check in fixed-size chunks. (The original iterated the file object
    # line-by-line, which can pull an arbitrarily long "line" into memory.)
    file_size = 0
    while chunk := file.file.read(_CHUNK):
        file_size += len(chunk)
        if file_size > MAX_FILE_SIZE:
            raise HTTPException(
                status_code=400,
                detail=f"文件大小超过限制: {MAX_FILE_SIZE/1024/1024}MB"
            )
    file.file.seek(0)

    upload_dir = Path("Upload")
    upload_dir.mkdir(exist_ok=True)

    # Stream to disk instead of buffering the whole upload in memory.
    file_path = upload_dir / safe_name
    with open(file_path, "wb") as f:
        while chunk := await file.read(_CHUNK):
            f.write(chunk)

    async with app.state.mysql_pool.acquire() as conn:
        async with conn.cursor() as cur:
            await cur.execute(
                """INSERT INTO t_ai_kb_files
                (file_name, ext_name, kb_id, file_path, file_size)
                VALUES (%s, %s, %s, %s, %s)""",
                (safe_name, file_ext, kb_id, str(file_path), file_size)
            )
            await conn.commit()
            return {
                "id": cur.lastrowid,
                "filename": safe_name,
                "size": file_size,
                "path": str(file_path)
            }


if __name__ == "__main__":
    uvicorn.run(app, host="0.0.0.0", port=8000)