main
HuangHai 4 weeks ago
parent 1a13554c2b
commit 3c6cf034eb

@ -1,13 +1,104 @@
from typing import List, Optional
from Model.KbModel import KbModel, KbFileModel
from Util.MySQLUtil import init_mysql_pool
import logging
from typing import Optional, List, Dict
from aiomysql import DictCursor
class KbDao:
def __init__(self, mysql_pool):
self.mysql_pool = mysql_pool
async def create_kb(self, kb: KbModel) -> int:
# 实现创建知识库的数据库操作
self.logger = logging.getLogger(__name__)
async def create_kb(self, kb: Dict) -> int:
"""创建知识库"""
async with self.mysql_pool.acquire() as conn:
async with conn.cursor(DictCursor) as cur:
await cur.execute(
"INSERT INTO t_ai_kb(name, description) VALUES(%s, %s)",
(kb['name'], kb['description']))
await conn.commit()
return cur.lastrowid
async def get_kb(self, kb_id: int) -> Optional[Dict]:
"""获取知识库详情"""
async with self.mysql_pool.acquire() as conn:
async with conn.cursor(DictCursor) as cur:
await cur.execute(
"SELECT * FROM t_ai_kb WHERE id = %s", (kb_id,))
return await cur.fetchone()
async def update_kb(self, kb_id: int, kb: Dict) -> bool:
"""更新知识库信息"""
async with self.mysql_pool.acquire() as conn:
async with conn.cursor() as cur:
await cur.execute(
"UPDATE t_ai_kb SET name = %s, description = %s WHERE id = %s",
(kb['name'], kb['description'], kb_id))
await conn.commit()
return cur.rowcount > 0
async def delete_kb(self, kb_id: int) -> bool:
"""删除知识库"""
async with self.mysql_pool.acquire() as conn:
async with conn.cursor() as cur:
await cur.execute(
"DELETE FROM t_ai_kb WHERE id = %s", (kb_id,))
await conn.commit()
return cur.rowcount > 0
async def create_kb_file(self, file: Dict) -> int:
"""创建知识库文件记录"""
async with self.mysql_pool.acquire() as conn:
async with conn.cursor(DictCursor) as cur:
await cur.execute(
"""INSERT INTO t_ai_kb_files
(kb_id, file_name, file_path, file_size, file_type, state)
VALUES(%s, %s, %s, %s, %s, %s)""",
(file['kb_id'], file['file_name'], file['file_path'],
file['file_size'], file['file_type'], file['state']))
await conn.commit()
return cur.lastrowid
async def get_kb_file(self, file_id: int) -> Optional[Dict]:
"""获取文件详情"""
async with self.mysql_pool.acquire() as conn:
async with conn.cursor(DictCursor) as cur:
await cur.execute(
"SELECT * FROM t_ai_kb_files WHERE id = %s", (file_id,))
return await cur.fetchone()
async def update_kb_file(self, file_id: int, file: Dict) -> bool:
"""更新文件信息"""
async with self.mysql_pool.acquire() as conn:
async with conn.cursor() as cur:
await cur.execute(
"""UPDATE t_ai_kb_files SET
kb_id = %s, file_name = %s, file_path = %s,
file_size = %s, file_type = %s, state = %s
WHERE id = %s""",
(file['kb_id'], file['file_name'], file['file_path'],
file['file_size'], file['file_type'], file['state'], file_id))
await conn.commit()
return cur.rowcount > 0
async def delete_kb_file(self, file_id: int) -> bool:
"""删除文件记录"""
async with self.mysql_pool.acquire() as conn:
async with conn.cursor() as cur:
await cur.execute(
"DELETE FROM t_ai_kb_files WHERE id = %s", (file_id,))
await conn.commit()
return cur.rowcount > 0
async def handle_upload(self, kb_id: int, file) -> Dict:
"""处理文件上传"""
# 文件保存逻辑
# 数据库记录创建
# 返回文件信息
pass
# 其他CRUD方法实现...
async def get_unprocessed_files(self) -> List[Dict]:
"""获取未处理文件列表"""
async with self.mysql_pool.acquire() as conn:
async with conn.cursor(DictCursor) as cur:
await cur.execute(
"SELECT * FROM t_ai_kb_files WHERE state = 0")
return await cur.fetchall()

@ -1,5 +1,46 @@
from pydantic import BaseModel
from typing import Optional
from datetime import datetime
class KbBase(BaseModel):
name: str
description: Optional[str] = None
class KbCreate(KbBase):
pass
class KbUpdate(KbBase):
pass
class Kb(KbBase):
id: int
create_time: datetime
update_time: datetime
class Config:
orm_mode = True
class KbFileBase(BaseModel):
kb_id: int
file_name: str
file_path: str
file_size: int
file_type: str
state: int = 0
class KbFileCreate(KbFileBase):
pass
class KbFileUpdate(KbFileBase):
pass
class KbFile(KbFileBase):
id: int
create_time: datetime
update_time: datetime
class Config:
orm_mode = True
class KbModel(BaseModel):
kb_name: str

@ -1,28 +1,27 @@
import asyncio
import logging
import threading
import time
from contextlib import asynccontextmanager
from logging.handlers import RotatingFileHandler
from pathlib import Path
import asyncio
import uvicorn
from fastapi import FastAPI, UploadFile, File, HTTPException
from pymysql.cursors import DictCursor
from fastapi import FastAPI, UploadFile, File
from Dao.KbDao import KbDao
from Model.KbModel import KbModel, KbFileModel
from Util.MySQLUtil import init_mysql_pool
# 初始化日志
logger = logging.getLogger(__name__)
logger.setLevel(logging.INFO)
handler = RotatingFileHandler('Logs/document_processor.log', maxBytes=1024*1024, backupCount=5)
handler = RotatingFileHandler('Logs/start.log', maxBytes=1024*1024, backupCount=5)
handler.setFormatter(logging.Formatter('%(asctime)s - %(levelname)s - %(message)s'))
logger.addHandler(handler)
@asynccontextmanager
async def lifespan(app: FastAPI):
# 启动时初始化数据库连接池
# 初始化数据库连接池
app.state.kb_dao = KbDao(await init_mysql_pool())
# 启动文档处理线程
# 启动文档处理任务
async def document_processor():
while True:
try:
@ -33,206 +32,63 @@ async def lifespan(app: FastAPI):
except Exception as e:
logger.error(f"文档处理出错: {e}")
await asyncio.sleep(10)
time.sleep(10) # 每10秒检查一次
def run_async_in_thread():
loop = asyncio.new_event_loop()
asyncio.set_event_loop(loop)
try:
loop.run_until_complete(document_processor())
finally:
loop.close()
processor_thread = threading.Thread(
target=run_async_in_thread,
daemon=True
)
processor_thread.start()
# 启动文档处理任务
task = asyncio.create_task(document_processor())
yield
# 关闭时取消任务
task.cancel()
# 关闭时清理资源
# 关闭数据库连接池
await app.state.kb_dao.mysql_pool.close()
app = FastAPI(lifespan=lifespan)
# 知识库CRUD接口
@app.post("/kb")
async def create_kb(kb: KbModel):
async def create_kb(kb: dict):
"""创建知识库"""
return await app.state.kb_dao.create_kb(kb)
@app.get("/kb/{kb_id}")
async def read_kb(kb_id: int):
"""获取知识库详情
Args:
kb_id: 知识库ID
Returns:
返回知识库详细信息
"""
async with app.state.mysql_pool.acquire() as conn:
async with conn.cursor() as cur:
await cur.execute("SELECT * FROM t_ai_kb WHERE id = %s", (kb_id,))
result = await cur.fetchone()
if not result:
raise HTTPException(status_code=404, detail="知识库不存在")
return result
"""获取知识库详情"""
return await app.state.kb_dao.get_kb(kb_id)
@app.post("/kb/update/{kb_id}")
async def update_kb(kb_id: int, kb: KbModel):
"""更新知识库信息
Args:
kb_id: 知识库ID
kb: 包含更新后知识库信息的对象
Returns:
返回更新结果
"""
async with app.state.mysql_pool.acquire() as conn:
async with conn.cursor() as cur:
await cur.execute(
"""UPDATE t_ai_kb
SET kb_name=%s, short_name=%s, is_delete=%s
WHERE id=%s""",
(kb.kb_name, kb.short_name, kb.is_delete, kb_id)
)
await conn.commit()
return {"status": "success", "affected_rows": cur.rowcount}
async def update_kb(kb_id: int, kb: dict):
"""更新知识库信息"""
return await app.state.kb_dao.update_kb(kb_id, kb)
@app.delete("/kb/{kb_id}")
async def delete_kb(kb_id: int):
async with app.state.mysql_pool.acquire() as conn:
async with conn.cursor() as cur:
await cur.execute("DELETE FROM t_ai_kb WHERE id = %s", (kb_id,))
await conn.commit()
return {"message": "Knowledge base deleted"}
"""删除知识库"""
return await app.state.kb_dao.delete_kb(kb_id)
# 知识库文件CRUD接口
@app.post("/kb_file")
async def create_kb_file(file: KbFileModel):
"""创建知识库文件
Args:
file: 包含文件名(file_name)扩展名(ext_name)关联知识库ID(kb_id)等信息的对象
Returns:
返回创建成功的文件ID
"""
async with app.state.mysql_pool.acquire() as conn:
async with conn.cursor() as cur:
await cur.execute(
"""INSERT INTO t_ai_kb_files
(file_name, ext_name, kb_id, is_delete, state)
VALUES (%s, %s, %s, %s, %s)""",
(file.file_name, file.ext_name, file.kb_id, file.is_delete, file.state)
)
await conn.commit()
return {"id": cur.lastrowid}
async def create_kb_file(file: dict):
"""创建知识库文件记录"""
return await app.state.kb_dao.create_kb_file(file)
@app.get("/kb_files/{file_id}")
async def read_kb_file(file_id: int):
async with app.state.mysql_pool.acquire() as conn:
async with conn.cursor(DictCursor) as cur:
await cur.execute("SELECT * FROM t_ai_kb_files WHERE id = %s", (file_id,))
result = await cur.fetchone()
if not result:
raise HTTPException(status_code=404, detail="File not found")
return result
"""获取文件详情"""
return await app.state.kb_dao.get_kb_file(file_id)
# 知识库文件更新接口修改
@app.post("/kb_files/update/{file_id}")
async def update_kb_file(file_id: int, file: KbFileModel):
"""更新知识库文件信息
Args:
file_id: 文件ID
file: 包含更新后文件信息的对象
Returns:
返回更新结果
"""
async with app.state.mysql_pool.acquire() as conn:
async with conn.cursor() as cur:
await cur.execute(
"""UPDATE t_ai_kb_files
SET file_name=%s, ext_name=%s, kb_id=%s,
is_delete=%s, state=%s
WHERE id=%s""",
(file.file_name, file.ext_name, file.kb_id,
file.is_delete, file.state, file_id)
)
await conn.commit()
return {"status": "success", "affected_rows": cur.rowcount}
async def update_kb_file(file_id: int, file: dict):
"""更新文件信息"""
return await app.state.kb_dao.update_kb_file(file_id, file)
@app.delete("/kb_files/{file_id}")
async def delete_kb_file(file_id: int):
async with app.state.mysql_pool.acquire() as conn:
async with conn.cursor() as cur:
await cur.execute("DELETE FROM t_ai_kb_files WHERE id = %s", (file_id,))
await conn.commit()
return {"message": "File deleted"}
# 允许的文件类型
ALLOWED_EXTENSIONS = {'txt', 'pdf', 'docx', 'ppt', 'pptx'}
MAX_FILE_SIZE = 100 * 1024 * 1024 # 100MB
"""删除文件记录"""
return await app.state.kb_dao.delete_kb_file(file_id)
# 文件上传接口
@app.post("/upload")
async def upload_file(
kb_id: int,
file: UploadFile = File(...)
):
"""
上传文件接口
Args:
kb_id: 关联的知识库ID
file: 上传的文件对象
Returns:
返回文件保存信息和数据库记录ID
"""
# 检查文件类型
file_ext = Path(file.filename).suffix.lower()[1:]
if file_ext not in ALLOWED_EXTENSIONS:
raise HTTPException(
status_code=400,
detail=f"不支持的文件类型: {file_ext}, 仅支持: {', '.join(ALLOWED_EXTENSIONS)}"
)
# 检查文件大小
file_size = 0
for chunk in file.file:
file_size += len(chunk)
if file_size > MAX_FILE_SIZE:
raise HTTPException(
status_code=400,
detail=f"文件大小超过限制: {MAX_FILE_SIZE/1024/1024}MB"
)
file.file.seek(0)
# 创建Upload目录
upload_dir = Path("Upload")
upload_dir.mkdir(exist_ok=True)
# 保存文件
file_path = upload_dir / file.filename
with open(file_path, "wb") as f:
f.write(await file.read())
# 记录到数据库
async with app.state.mysql_pool.acquire() as conn:
async with conn.cursor() as cur:
await cur.execute(
"""INSERT INTO t_ai_kb_files
(file_name, ext_name, kb_id, file_path, file_size)
VALUES (%s, %s, %s, %s, %s)""",
(file.filename, file_ext, kb_id, str(file_path), file_size)
)
await conn.commit()
return {
"id": cur.lastrowid,
"filename": file.filename,
"size": file_size,
"path": str(file_path)
}
async def upload_file(kb_id: int, file: UploadFile = File(...)):
"""文件上传接口"""
return await app.state.kb_dao.handle_upload(kb_id, file)
if __name__ == "__main__":
uvicorn.run(app, host="0.0.0.0", port=8000)
Loading…
Cancel
Save