import inspect
import os
from contextlib import asynccontextmanager
from pathlib import Path

import uvicorn
from fastapi import FastAPI, UploadFile, File, HTTPException
from pymysql.cursors import DictCursor

from Dao.KbDao import KbDao
from Model.KbModel import KbModel, KbFileModel
from Util.MySQLUtil import init_mysql_pool


async def _close_pool(pool) -> None:
    """Close *pool*, accepting either a sync or an async ``close()``.

    aiomysql-style pools expose a synchronous ``close()`` followed by an
    awaitable ``wait_closed()``; awaiting the ``None`` such a ``close()``
    returns raises TypeError. The concrete pool type comes from
    ``init_mysql_pool``, which is not visible here, so both shapes are handled.
    """
    result = pool.close()
    if inspect.isawaitable(result):
        await result
    wait_closed = getattr(pool, "wait_closed", None)
    if wait_closed is not None:
        waited = wait_closed()
        if inspect.isawaitable(waited):
            await waited


@asynccontextmanager
async def lifespan(app: FastAPI):
    """Create the MySQL connection pool on startup and close it on shutdown.

    The pool is stored as ``app.state.mysql_pool`` (used directly by most
    endpoints below) and wrapped in ``app.state.kb_dao`` (used by create_kb).
    """
    # Initialise the database connection pool on startup.
    pool = await init_mysql_pool()
    # The endpoints read app.state.mysql_pool, so it must be published here;
    # storing only kb_dao made every such endpoint fail with AttributeError.
    app.state.mysql_pool = pool
    app.state.kb_dao = KbDao(pool)
    try:
        yield
    finally:
        # Release resources on shutdown.
        await _close_pool(pool)


app = FastAPI(lifespan=lifespan)


@app.post("/kb")
async def create_kb(kb: KbModel):
    """Create a knowledge base and return whatever KbDao.create_kb returns."""
    return await app.state.kb_dao.create_kb(kb)


@app.get("/kb/{kb_id}")
async def read_kb(kb_id: int):
    """Get knowledge base details.

    Args:
        kb_id: Knowledge base ID.

    Returns:
        The t_ai_kb row with this ID.

    Raises:
        HTTPException: 404 if no row has this ID.
    """
    # NOTE(review): unlike read_kb_file this uses the pool's default cursor
    # class; unless init_mysql_pool configures a dict cursor, the row is
    # returned as a tuple rather than a column->value mapping — confirm.
    async with app.state.mysql_pool.acquire() as conn:
        async with conn.cursor() as cur:
            await cur.execute("SELECT * FROM t_ai_kb WHERE id = %s", (kb_id,))
            result = await cur.fetchone()
    if not result:
        raise HTTPException(status_code=404, detail="知识库不存在")
    return result


@app.post("/kb/update/{kb_id}")
async def update_kb(kb_id: int, kb: KbModel):
    """Update a knowledge base's kb_name, short_name and is_delete columns.

    Args:
        kb_id: Knowledge base ID.
        kb: Object carrying the new knowledge base values.

    Returns:
        ``{"status": "success", "affected_rows": n}`` where n is the cursor's
        rowcount for the UPDATE (0 when no row has this ID).
    """
    async with app.state.mysql_pool.acquire() as conn:
        async with conn.cursor() as cur:
            await cur.execute(
                """UPDATE t_ai_kb SET kb_name=%s, short_name=%s, is_delete=%s
                   WHERE id=%s""",
                (kb.kb_name, kb.short_name, kb.is_delete, kb_id)
            )
            affected_rows = cur.rowcount
        await conn.commit()
    return {"status": "success", "affected_rows": affected_rows}


@app.delete("/kb/{kb_id}")
async def delete_kb(kb_id: int):
    """Hard-delete a knowledge base row.

    The success message is returned whether or not a row with this ID existed.
    """
    async with app.state.mysql_pool.acquire() as conn:
        async with conn.cursor() as cur:
            await cur.execute("DELETE FROM t_ai_kb WHERE id = %s", (kb_id,))
        await conn.commit()
    return {"message": "Knowledge base deleted"}


# Knowledge-base file CRUD endpoints.

@app.post("/kb_file")
async def create_kb_file(file: KbFileModel):
    """Create a knowledge-base file record.

    Args:
        file: Object with the file name (file_name), extension (ext_name),
            owning knowledge base ID (kb_id), is_delete and state.

    Returns:
        ``{"id": <new row ID>}``.
    """
    async with app.state.mysql_pool.acquire() as conn:
        async with conn.cursor() as cur:
            await cur.execute(
                """INSERT INTO t_ai_kb_files (file_name, ext_name, kb_id, is_delete, state)
                   VALUES (%s, %s, %s, %s, %s)""",
                (file.file_name, file.ext_name, file.kb_id, file.is_delete, file.state)
            )
            new_id = cur.lastrowid
        await conn.commit()
    return {"id": new_id}


@app.get("/kb_files/{file_id}")
async def read_kb_file(file_id: int):
    """Return the t_ai_kb_files row with the given ID.

    Raises:
        HTTPException: 404 if no row has this ID.
    """
    # NOTE(review): DictCursor comes from pymysql. If init_mysql_pool returns an
    # aiomysql pool, conn.cursor() only accepts subclasses of aiomysql.Cursor
    # (i.e. aiomysql.DictCursor) and raises TypeError here — confirm.
    async with app.state.mysql_pool.acquire() as conn:
        async with conn.cursor(DictCursor) as cur:
            await cur.execute("SELECT * FROM t_ai_kb_files WHERE id = %s", (file_id,))
            result = await cur.fetchone()
    if not result:
        raise HTTPException(status_code=404, detail="File not found")
    return result


# Knowledge-base file update endpoint.
@app.post("/kb_files/update/{file_id}")
async def update_kb_file(file_id: int, file: KbFileModel):
    """Update all editable columns of a knowledge-base file record.

    Args:
        file_id: File ID.
        file: Object carrying the new file values.

    Returns:
        ``{"status": "success", "affected_rows": n}`` where n is the cursor's
        rowcount for the UPDATE (0 when no row has this ID).
    """
    async with app.state.mysql_pool.acquire() as conn:
        async with conn.cursor() as cur:
            await cur.execute(
                """UPDATE t_ai_kb_files
                   SET file_name=%s, ext_name=%s, kb_id=%s, is_delete=%s, state=%s
                   WHERE id=%s""",
                (file.file_name, file.ext_name, file.kb_id, file.is_delete, file.state, file_id)
            )
            affected_rows = cur.rowcount
        await conn.commit()
    return {"status": "success", "affected_rows": affected_rows}


@app.delete("/kb_files/{file_id}")
async def delete_kb_file(file_id: int):
    """Hard-delete a knowledge-base file record (the file on disk is not removed).

    The success message is returned whether or not a row with this ID existed.
    """
    async with app.state.mysql_pool.acquire() as conn:
        async with conn.cursor() as cur:
            await cur.execute("DELETE FROM t_ai_kb_files WHERE id = %s", (file_id,))
        await conn.commit()
    return {"message": "File deleted"}


# Allowed upload file extensions (lower-case, without the leading dot).
ALLOWED_EXTENSIONS = {'txt', 'pdf', 'docx', 'ppt', 'pptx'}
MAX_FILE_SIZE = 100 * 1024 * 1024  # 100MB
# Read size used when streaming an upload to disk.
UPLOAD_CHUNK_SIZE = 1024 * 1024  # 1MB
# Directory (relative to the working directory) where uploads are stored.
UPLOAD_DIR = Path("Upload")


@app.post("/upload")
async def upload_file(
    kb_id: int,
    file: UploadFile = File(...)
):
    """Upload a file, save it under the Upload directory and record it in t_ai_kb_files.

    Args:
        kb_id: ID of the knowledge base the file belongs to.
        file: The uploaded file.

    Returns:
        A dict with the new database row ID (``id``), the stored file name
        (``filename``), the size in bytes (``size``) and the save path (``path``).

    Raises:
        HTTPException: 400 if the file name is missing, its extension is not in
            ALLOWED_EXTENSIONS, or the file is larger than MAX_FILE_SIZE.
    """
    if not file.filename:
        raise HTTPException(status_code=400, detail="文件名不能为空")

    # Keep only the final path component: the client controls the filename,
    # and a value such as "../../app.pdf" would otherwise escape UPLOAD_DIR.
    filename = Path(file.filename).name

    # Check the file type.
    file_ext = Path(filename).suffix.lower()[1:]
    if file_ext not in ALLOWED_EXTENSIONS:
        raise HTTPException(
            status_code=400,
            detail=f"不支持的文件类型: {file_ext}, 仅支持: {', '.join(ALLOWED_EXTENSIONS)}"
        )

    # Check the file size by seeking to the end of the spooled upload instead of
    # iterating it (iterating a binary file splits on b"\n", which can produce
    # arbitrarily large chunks and reads the whole upload just to measure it).
    file.file.seek(0, os.SEEK_END)
    file_size = file.file.tell()
    file.file.seek(0)
    if file_size > MAX_FILE_SIZE:
        raise HTTPException(
            status_code=400,
            detail=f"文件大小超过限制: {MAX_FILE_SIZE/1024/1024}MB"
        )

    UPLOAD_DIR.mkdir(exist_ok=True)

    # Stream to disk in chunks rather than holding the whole upload in memory.
    # NOTE: "wb" truncates an existing file, so a later upload with the same
    # name overwrites the earlier one on disk while a new DB row is inserted.
    file_path = UPLOAD_DIR / filename
    with open(file_path, "wb") as out:
        while True:
            chunk = await file.read(UPLOAD_CHUNK_SIZE)
            if not chunk:
                break
            out.write(chunk)

    # Record the upload in the database.
    async with app.state.mysql_pool.acquire() as conn:
        async with conn.cursor() as cur:
            await cur.execute(
                """INSERT INTO t_ai_kb_files (file_name, ext_name, kb_id, file_path, file_size)
                   VALUES (%s, %s, %s, %s, %s)""",
                (filename, file_ext, kb_id, str(file_path), file_size)
            )
            new_id = cur.lastrowid
        await conn.commit()

    return {
        "id": new_id,
        "filename": filename,
        "size": file_size,
        "path": str(file_path)
    }


if __name__ == "__main__":
    uvicorn.run(app, host="0.0.0.0", port=8000)