From 3c6cf034eb7737708bdbe4952f903bf67dc6e301 Mon Sep 17 00:00:00 2001 From: HuangHai <10402852@qq.com> Date: Tue, 24 Jun 2025 20:20:41 +0800 Subject: [PATCH] 'commit' --- dsRag/Dao/KbDao.py | 105 +++++++++- dsRag/Dao/__pycache__/KbDao.cpython-310.pyc | Bin 756 -> 4881 bytes dsRag/Model/KbModel.py | 41 ++++ dsRag/Start.py | 212 ++++---------------- 4 files changed, 173 insertions(+), 185 deletions(-) diff --git a/dsRag/Dao/KbDao.py b/dsRag/Dao/KbDao.py index 45c73177..68b0c2a4 100644 --- a/dsRag/Dao/KbDao.py +++ b/dsRag/Dao/KbDao.py @@ -1,13 +1,104 @@ -from typing import List, Optional -from Model.KbModel import KbModel, KbFileModel -from Util.MySQLUtil import init_mysql_pool +import logging +from typing import Optional, List, Dict +from aiomysql import DictCursor class KbDao: def __init__(self, mysql_pool): self.mysql_pool = mysql_pool - - async def create_kb(self, kb: KbModel) -> int: - # 实现创建知识库的数据库操作 + self.logger = logging.getLogger(__name__) + + async def create_kb(self, kb: Dict) -> int: + """创建知识库""" + async with self.mysql_pool.acquire() as conn: + async with conn.cursor(DictCursor) as cur: + await cur.execute( + "INSERT INTO t_ai_kb(name, description) VALUES(%s, %s)", + (kb['name'], kb['description'])) + await conn.commit() + return cur.lastrowid + + async def get_kb(self, kb_id: int) -> Optional[Dict]: + """获取知识库详情""" + async with self.mysql_pool.acquire() as conn: + async with conn.cursor(DictCursor) as cur: + await cur.execute( + "SELECT * FROM t_ai_kb WHERE id = %s", (kb_id,)) + return await cur.fetchone() + + async def update_kb(self, kb_id: int, kb: Dict) -> bool: + """更新知识库信息""" + async with self.mysql_pool.acquire() as conn: + async with conn.cursor() as cur: + await cur.execute( + "UPDATE t_ai_kb SET name = %s, description = %s WHERE id = %s", + (kb['name'], kb['description'], kb_id)) + await conn.commit() + return cur.rowcount > 0 + + async def delete_kb(self, kb_id: int) -> bool: + """删除知识库""" + async with self.mysql_pool.acquire() as conn: + async with conn.cursor() as cur: + await cur.execute( + "DELETE FROM t_ai_kb WHERE id = %s", (kb_id,)) + await conn.commit() + return cur.rowcount > 0 + + async def create_kb_file(self, file: Dict) -> int: + """创建知识库文件记录""" + async with self.mysql_pool.acquire() as conn: + async with conn.cursor(DictCursor) as cur: + await cur.execute( + """INSERT INTO t_ai_kb_files + (kb_id, file_name, file_path, file_size, file_type, state) + VALUES(%s, %s, %s, %s, %s, %s)""", + (file['kb_id'], file['file_name'], file['file_path'], + file['file_size'], file['file_type'], file['state'])) + await conn.commit() + return cur.lastrowid + + async def get_kb_file(self, file_id: int) -> Optional[Dict]: + """获取文件详情""" + async with self.mysql_pool.acquire() as conn: + async with conn.cursor(DictCursor) as cur: + await cur.execute( + "SELECT * FROM t_ai_kb_files WHERE id = %s", (file_id,)) + return await cur.fetchone() + + async def update_kb_file(self, file_id: int, file: Dict) -> bool: + """更新文件信息""" + async with self.mysql_pool.acquire() as conn: + async with conn.cursor() as cur: + await cur.execute( + """UPDATE t_ai_kb_files SET + kb_id = %s, file_name = %s, file_path = %s, + file_size = %s, file_type = %s, state = %s + WHERE id = %s""", + (file['kb_id'], file['file_name'], file['file_path'], + file['file_size'], file['file_type'], file['state'], file_id)) + await conn.commit() + return cur.rowcount > 0 + + async def delete_kb_file(self, file_id: int) -> bool: + """删除文件记录""" + async with self.mysql_pool.acquire() as conn: + async with conn.cursor() as cur: + await cur.execute( + "DELETE FROM t_ai_kb_files WHERE id = %s", (file_id,)) + await conn.commit() + return cur.rowcount > 0 + + async def handle_upload(self, kb_id: int, file) -> Dict: + """处理文件上传""" + # 文件保存逻辑 + # 数据库记录创建 + # 返回文件信息 pass - # 其他CRUD方法实现... \ No newline at end of file + async def get_unprocessed_files(self) -> List[Dict]: + """获取未处理文件列表""" + async with self.mysql_pool.acquire() as conn: + async with conn.cursor(DictCursor) as cur: + await cur.execute( + "SELECT * FROM t_ai_kb_files WHERE state = 0") + return await cur.fetchall() \ No newline at end of file diff --git a/dsRag/Dao/__pycache__/KbDao.cpython-310.pyc b/dsRag/Dao/__pycache__/KbDao.cpython-310.pyc index 5b3956185250945c7d698a6a304d615efd47431d..575a5666550910cb267de9aa8821cda8d5ac5b1e 100644 GIT binary patch literal 4881 zcmb_gUu+yl8Q+!;#@9n(zI};O`2^&XrR1AX_YE=NlLgxw%dp$VmaTeowe_7 zo!LE=oE)i9LWy_)Bt$AGRc)yt!UO38Dkv>Nf;S#`M*9e{>q>~1Jb-8+eBbOJ@A}r9 zRI#V~ZuZ-6W@qL%zwbA5?&xS)f$Q(Do0tB!OHuxcH|a}-o0IVMj)UO}CvC;VU)m-e zYSNBssvTxBfK;o?G%Ax7%z7?5yMY z0h@4IzCnUA{+{Xy&lN3PCSMAL0i z7-f8689_J&UvC`DGFetE#V0G&S68vjxB`_)!_b0q=Zbf^Z7sQOJ1Df>`FW>1AC%^8 z|16Yk5fm+}+w9nu73lcX76uKNXSe4>0b0JTh*9|Y*TIS77rA%A6^rn5UbqW(%ZJib zbN(Xq?xO7D(WNWcn$vZB%R+cO21X}iq+baG;)elUoDFng`#sUkj0`K_;`;;E{v()c zFq#Erg{&&iC}&`)omPI_xJs94IMT#-^s2h7t}srQ30L>Qd~k%BuF|j$QJ{vl8npm- z>Vkejq0k;vm~vf>qw;*y%CV@Pcls4V6#RS@z83~+w5b>l=pe@c?4&N5-`*)$e9JF7^LHpn)+H)j@vAT(I9+G*!6k;dShz&va z<#en>SMBYA0c?Et*K6Oqm7u$E=V$BJzt!kJK0SGM^3+WIsrn03=e`o7T)%L7a%!^f z@cOd=>1|!YC$JfTq?WjehoEE@=WM@q$?e*=Bd8w1=ljqg3|o=Xf!zR84IbUcmSUK6 zQhqQ(YyXZQ8VMl^gIkzT;tN+yLc>47O%5>UB?vS{%wt#CGFwrBgjAA{A_G<&j{mzwxW}Tkj-b-TU3@`n5Za{)v~)Pdq;} znV>X1Ia5cnNt80Al%6b#q6)JmcB3I+20=wU#a=LG5y;nadtKkqLjFlmk%ZeLTk0+G zkb@YUaIh*lI3_t*9N=J5XM6Jf9$L%6_{ETXDU1(G)DdHnp&ngHL#Fa%UcICq#y`A9N=sz|A^aDhdlg|VtZc)GWTi?F--n$#Wd}r;qKWy~h%uzxMC(o;w>p9mU zsVGP47$WyDQE{p1Uy6#J(+3@u0sfUGD0)87$H;Flqp+y5>8~M>aX~40SyaTAMMZpB zRK%ABntYkSG#pFplThBC-IYWw<*+{o!QQ99#QB?&s#+!VLHXl~4pRHiXvNjGNUf|K z6l72ThvJGyv14ivGghcu5gd_d8M&er0g{m}ft`${=_r)9AL5QsZDcDaF!C`l{atA# zlFtU@DB;UGRLJ4W#Q`~bHdl_QWDDiU7Rr$=lmlBJofT4ifxYAwnJRsvbe7pC4B0-8ve6N+P%)sQ(KRSJ-+|B+=* zw(>x#j(0-h324?}5|p)@s4?06^<<9BM95^vpNvo-#kY!0C6NwUc?v^NcKQ*|vyU@G z!95%{E`5xb{zK057Wo^!)a-J|vAw0X z+vHz^cIC;AzRI-z?vT2fh!*xCzOYa640`XvO%8b{u9D^G@R}Uyo{$0vN4ixKAJ5W< zxV-&go-eg$d@fC}A!N*mot#5e3RsjVfAWj9)MT#Re0^i}=Z*g1;Tba~%G1Ws?S8Y} zj*;1Rr@sf&X|KB^+?MTmHV>EEEX=a}Utm)>G8_N>o9T9-Zo}9u5U!wXSsj=60Cp=F zv8>CzW;^_!K;SV};7L7DovtrX*9ApWM<1Joe;UN&cz*!RXVE;3<`|mKqxk|FTo~dR zG@Hv!mhiMGaP{Ib#KX;2Q7Mut8!wL+tL18Gd=!2w@LP~kMo|-rdx@K}yvFs1a|# zL-aX#3tu_$3Y?g62t;C}nVp~YeBX>+_xoKy`Sw+O_{0EyP;e*#24|?vDH;We87P@x zB^R7g%uP^+A}k{j`8Y6fnTUkJD^Q_|wxFU1E>aawAxnN@56HL+FN&sh>FZTn)KzXU zlZ@w=wbI5Njpr{5qkWJucU)9OE6YuDYvih~&4M}{N{GQ3YIA}{gJ22-S4{COh(HB8 zRH2Tpk%@}8p@>z2JW>~p=jVC7;CN;)>ebJIE6~63t?@Nv^ z=PtIoUEAtEmUvC@oz7B&FF5cNY>KX@k#S4p?O>nqS=+e8m z_B8OhN!D?3yIB?0jio(UdTg%m?|$^>LHPcbPP@UowlJg1&GgL$d6vfaeGc$EydO4w TzGHn#>uP8}Lm09id%}MMaEYP; diff --git a/dsRag/Model/KbModel.py b/dsRag/Model/KbModel.py index df15cf51..bb735d76 100644 --- a/dsRag/Model/KbModel.py +++ b/dsRag/Model/KbModel.py @@ -1,5 +1,46 @@ from pydantic import BaseModel from typing import Optional +from datetime import datetime + +class KbBase(BaseModel): + name: str + description: Optional[str] = None + +class KbCreate(KbBase): + pass + +class KbUpdate(KbBase): + pass + +class Kb(KbBase): + id: int + create_time: datetime + update_time: datetime + + class Config: + orm_mode = True + +class KbFileBase(BaseModel): + kb_id: int + file_name: str + file_path: str + file_size: int + file_type: str + state: int = 0 + +class KbFileCreate(KbFileBase): + pass + +class KbFileUpdate(KbFileBase): + pass + +class KbFile(KbFileBase): + id: int + create_time: datetime + update_time: datetime + + class Config: + orm_mode = True class KbModel(BaseModel): kb_name: str diff --git a/dsRag/Start.py b/dsRag/Start.py index b62ef342..5100a2a2 100644 --- a/dsRag/Start.py +++ b/dsRag/Start.py @@ -1,28 +1,27 @@ +import asyncio import logging -import threading -import time from contextlib import asynccontextmanager from logging.handlers import RotatingFileHandler -from pathlib import Path -import asyncio + import uvicorn -from fastapi import FastAPI, UploadFile, File, HTTPException -from pymysql.cursors import DictCursor +from fastapi import FastAPI, UploadFile, File + from Dao.KbDao import KbDao -from Model.KbModel import KbModel, KbFileModel from Util.MySQLUtil import init_mysql_pool + +# 初始化日志 logger = logging.getLogger(__name__) logger.setLevel(logging.INFO) -handler = RotatingFileHandler('Logs/document_processor.log', maxBytes=1024*1024, backupCount=5) +handler = RotatingFileHandler('Logs/start.log', maxBytes=1024*1024, backupCount=5) handler.setFormatter(logging.Formatter('%(asctime)s - %(levelname)s - %(message)s')) logger.addHandler(handler) @asynccontextmanager async def lifespan(app: FastAPI): - # 启动时初始化数据库连接池 + # 初始化数据库连接池 app.state.kb_dao = KbDao(await init_mysql_pool()) - # 启动文档处理线程 + # 启动文档处理任务 async def document_processor(): while True: try: @@ -33,206 +32,63 @@ async def lifespan(app: FastAPI): except Exception as e: logger.error(f"文档处理出错: {e}") await asyncio.sleep(10) - - time.sleep(10) # 每10秒检查一次 - - - def run_async_in_thread(): - loop = asyncio.new_event_loop() - asyncio.set_event_loop(loop) - try: - loop.run_until_complete(document_processor()) - finally: - loop.close() - - processor_thread = threading.Thread( - target=run_async_in_thread, - daemon=True - ) - processor_thread.start() - # 启动文档处理任务 task = asyncio.create_task(document_processor()) yield # 关闭时取消任务 task.cancel() - - # 关闭时清理资源 + # 关闭数据库连接池 await app.state.kb_dao.mysql_pool.close() app = FastAPI(lifespan=lifespan) - +# 知识库CRUD接口 @app.post("/kb") -async def create_kb(kb: KbModel): +async def create_kb(kb: dict): + """创建知识库""" return await app.state.kb_dao.create_kb(kb) @app.get("/kb/{kb_id}") async def read_kb(kb_id: int): - """获取知识库详情 - Args: - kb_id: 知识库ID - Returns: - 返回知识库详细信息 - """ - async with app.state.mysql_pool.acquire() as conn: - async with conn.cursor() as cur: - await cur.execute("SELECT * FROM t_ai_kb WHERE id = %s", (kb_id,)) - result = await cur.fetchone() - if not result: - raise HTTPException(status_code=404, detail="知识库不存在") - return result + """获取知识库详情""" + return await app.state.kb_dao.get_kb(kb_id) @app.post("/kb/update/{kb_id}") -async def update_kb(kb_id: int, kb: KbModel): - """更新知识库信息 - Args: - kb_id: 知识库ID - kb: 包含更新后知识库信息的对象 - Returns: - 返回更新结果 - """ - async with app.state.mysql_pool.acquire() as conn: - async with conn.cursor() as cur: - await cur.execute( - """UPDATE t_ai_kb - SET kb_name=%s, short_name=%s, is_delete=%s - WHERE id=%s""", - (kb.kb_name, kb.short_name, kb.is_delete, kb_id) - ) - await conn.commit() - return {"status": "success", "affected_rows": cur.rowcount} +async def update_kb(kb_id: int, kb: dict): + """更新知识库信息""" + return await app.state.kb_dao.update_kb(kb_id, kb) @app.delete("/kb/{kb_id}") async def delete_kb(kb_id: int): - async with app.state.mysql_pool.acquire() as conn: - async with conn.cursor() as cur: - await cur.execute("DELETE FROM t_ai_kb WHERE id = %s", (kb_id,)) - await conn.commit() - return {"message": "Knowledge base deleted"} + """删除知识库""" + return await app.state.kb_dao.delete_kb(kb_id) # 知识库文件CRUD接口 @app.post("/kb_file") -async def create_kb_file(file: KbFileModel): - """创建知识库文件 - Args: - file: 包含文件名(file_name)、扩展名(ext_name)、关联知识库ID(kb_id)等信息的对象 - Returns: - 返回创建成功的文件ID - """ - async with app.state.mysql_pool.acquire() as conn: - async with conn.cursor() as cur: - await cur.execute( - """INSERT INTO t_ai_kb_files - (file_name, ext_name, kb_id, is_delete, state) - VALUES (%s, %s, %s, %s, %s)""", - (file.file_name, file.ext_name, file.kb_id, file.is_delete, file.state) - ) - await conn.commit() - return {"id": cur.lastrowid} +async def create_kb_file(file: dict): + """创建知识库文件记录""" + return await app.state.kb_dao.create_kb_file(file) @app.get("/kb_files/{file_id}") async def read_kb_file(file_id: int): - async with app.state.mysql_pool.acquire() as conn: - async with conn.cursor(DictCursor) as cur: - await cur.execute("SELECT * FROM t_ai_kb_files WHERE id = %s", (file_id,)) - result = await cur.fetchone() - if not result: - raise HTTPException(status_code=404, detail="File not found") - return result + """获取文件详情""" + return await app.state.kb_dao.get_kb_file(file_id) -# 知识库文件更新接口修改 @app.post("/kb_files/update/{file_id}") -async def update_kb_file(file_id: int, file: KbFileModel): - """更新知识库文件信息 - Args: - file_id: 文件ID - file: 包含更新后文件信息的对象 - Returns: - 返回更新结果 - """ - async with app.state.mysql_pool.acquire() as conn: - async with conn.cursor() as cur: - await cur.execute( - """UPDATE t_ai_kb_files - SET file_name=%s, ext_name=%s, kb_id=%s, - is_delete=%s, state=%s - WHERE id=%s""", - (file.file_name, file.ext_name, file.kb_id, - file.is_delete, file.state, file_id) - ) - await conn.commit() - return {"status": "success", "affected_rows": cur.rowcount} +async def update_kb_file(file_id: int, file: dict): + """更新文件信息""" + return await app.state.kb_dao.update_kb_file(file_id, file) @app.delete("/kb_files/{file_id}") async def delete_kb_file(file_id: int): - async with app.state.mysql_pool.acquire() as conn: - async with conn.cursor() as cur: - await cur.execute("DELETE FROM t_ai_kb_files WHERE id = %s", (file_id,)) - await conn.commit() - return {"message": "File deleted"} - -# 允许的文件类型 -ALLOWED_EXTENSIONS = {'txt', 'pdf', 'docx', 'ppt', 'pptx'} -MAX_FILE_SIZE = 100 * 1024 * 1024 # 100MB + """删除文件记录""" + return await app.state.kb_dao.delete_kb_file(file_id) +# 文件上传接口 @app.post("/upload") -async def upload_file( - kb_id: int, - file: UploadFile = File(...) -): - """ - 上传文件接口 - Args: - kb_id: 关联的知识库ID - file: 上传的文件对象 - Returns: - 返回文件保存信息和数据库记录ID - """ - # 检查文件类型 - file_ext = Path(file.filename).suffix.lower()[1:] - if file_ext not in ALLOWED_EXTENSIONS: - raise HTTPException( - status_code=400, - detail=f"不支持的文件类型: {file_ext}, 仅支持: {', '.join(ALLOWED_EXTENSIONS)}" - ) - - # 检查文件大小 - file_size = 0 - for chunk in file.file: - file_size += len(chunk) - if file_size > MAX_FILE_SIZE: - raise HTTPException( - status_code=400, - detail=f"文件大小超过限制: {MAX_FILE_SIZE/1024/1024}MB" - ) - file.file.seek(0) - - # 创建Upload目录 - upload_dir = Path("Upload") - upload_dir.mkdir(exist_ok=True) - - # 保存文件 - file_path = upload_dir / file.filename - with open(file_path, "wb") as f: - f.write(await file.read()) - - # 记录到数据库 - async with app.state.mysql_pool.acquire() as conn: - async with conn.cursor() as cur: - await cur.execute( - """INSERT INTO t_ai_kb_files - (file_name, ext_name, kb_id, file_path, file_size) - VALUES (%s, %s, %s, %s, %s)""", - (file.filename, file_ext, kb_id, str(file_path), file_size) - ) - await conn.commit() - return { - "id": cur.lastrowid, - "filename": file.filename, - "size": file_size, - "path": str(file_path) - } +async def upload_file(kb_id: int, file: UploadFile = File(...)): + """文件上传接口""" + return await app.state.kb_dao.handle_upload(kb_id, file) if __name__ == "__main__": uvicorn.run(app, host="0.0.0.0", port=8000) \ No newline at end of file