You can not select more than 25 topics Topics must start with a letter or number, can include dashes ('-') and can be up to 35 characters long.

238 lines
7.9 KiB

import logging
import threading
import time
from contextlib import asynccontextmanager
from logging.handlers import RotatingFileHandler
from pathlib import Path
import asyncio
import uvicorn
from fastapi import FastAPI, UploadFile, File, HTTPException
from pymysql.cursors import DictCursor
from Dao.KbDao import KbDao
from Model.KbModel import KbModel, KbFileModel
from Util.MySQLUtil import init_mysql_pool
# Module logger writing to a size-rotated file (1MB per file, 5 backups kept).
logger = logging.getLogger(__name__)
logger.setLevel(logging.INFO)
# RotatingFileHandler opens its file immediately (delay=False), which raises
# FileNotFoundError if the directory is missing, so make sure it exists first.
Path('Logs').mkdir(parents=True, exist_ok=True)
# Explicit UTF-8 so the Chinese log messages are written correctly regardless
# of the platform's locale encoding.
handler = RotatingFileHandler(
    'Logs/document_processor.log',
    maxBytes=1024 * 1024,
    backupCount=5,
    encoding='utf-8',
)
handler.setFormatter(logging.Formatter('%(asctime)s - %(levelname)s - %(message)s'))
logger.addHandler(handler)
@asynccontextmanager
async def lifespan(app: FastAPI):
    """Manage application-wide resources for the FastAPI app.

    Startup:
      * create the MySQL connection pool;
      * expose it as ``app.state.mysql_pool``, where the raw-SQL endpoints read it;
      * wrap it in ``app.state.kb_dao``;
      * start the background document-processing task on the server's event loop.

    Shutdown: cancel and await the background task, then close the pool.

    Args:
        app: The FastAPI application whose ``state`` receives the shared objects.
    """
    pool = await init_mysql_pool()
    # The CRUD endpoints access app.state.mysql_pool directly, so it has to be
    # published next to the DAO; otherwise they fail with AttributeError.
    app.state.mysql_pool = pool
    app.state.kb_dao = KbDao(pool)

    async def document_processor():
        """Poll for unprocessed documents every 10 seconds until cancelled."""
        while True:
            try:
                # TODO: fetch unprocessed documents
                # TODO: process the documents
                # TODO: save the results to ES
                await asyncio.sleep(10)  # check every 10 seconds
            except Exception as e:
                # asyncio.CancelledError derives from BaseException (3.8+), so
                # cancellation at shutdown is not swallowed here.
                logger.exception("文档处理出错: %s", e)
                await asyncio.sleep(10)

    # Exactly one processor instance runs, as a task on the server's own loop.
    # A second copy in a separate thread/event loop would process the same
    # documents twice (and presumably could not share the pool created on this
    # loop). Blocking calls such as time.sleep() must not appear in the
    # coroutine, because they would stall every request on this loop.
    task = asyncio.create_task(document_processor())
    try:
        yield
    finally:
        task.cancel()
        try:
            await task
        except asyncio.CancelledError:
            pass
        # NOTE(review): aiomysql-style pools expose a synchronous close() plus an
        # awaitable wait_closed(); awaiting close() directly raises TypeError
        # there. Support both sync and async close() — confirm against
        # init_mysql_pool.
        closing = pool.close()
        if asyncio.iscoroutine(closing):
            await closing
        wait_closed = getattr(pool, "wait_closed", None)
        if wait_closed is not None:
            await wait_closed()


app = FastAPI(lifespan=lifespan)
@app.post("/kb")
async def create_kb(kb: KbModel):
    """Create a knowledge base by delegating to the DAO stored on ``app.state``.

    Args:
        kb: Knowledge-base payload parsed from the request body.

    Returns:
        Whatever ``KbDao.create_kb`` returns for this payload.
    """
    dao = app.state.kb_dao
    created = await dao.create_kb(kb)
    return created
@app.get("/kb/{kb_id}")
async def read_kb(kb_id: int):
    """Fetch one knowledge-base row by its primary key.

    Args:
        kb_id: ID of the knowledge base.

    Returns:
        The matching ``t_ai_kb`` row, in whatever row format the pool's
        default cursor produces.

    Raises:
        HTTPException: 404 when no row has this ID.
    """
    select_sql = "SELECT * FROM t_ai_kb WHERE id = %s"
    pool = app.state.mysql_pool
    async with pool.acquire() as conn:
        async with conn.cursor() as cur:
            await cur.execute(select_sql, (kb_id,))
            row = await cur.fetchone()
    if not row:
        raise HTTPException(status_code=404, detail="知识库不存在")
    return row
@app.post("/kb/update/{kb_id}")
async def update_kb(kb_id: int, kb: KbModel):
    """Overwrite the name, short name and delete flag of one knowledge base.

    Args:
        kb_id: ID of the knowledge base to update.
        kb: Payload carrying the new ``kb_name``, ``short_name`` and ``is_delete``.

    Returns:
        ``{"status": "success", "affected_rows": n}`` where ``n`` is the
        cursor's rowcount after the UPDATE.
    """
    update_sql = """UPDATE t_ai_kb
SET kb_name=%s, short_name=%s, is_delete=%s
WHERE id=%s"""
    values = (kb.kb_name, kb.short_name, kb.is_delete, kb_id)
    async with app.state.mysql_pool.acquire() as conn, conn.cursor() as cur:
        await cur.execute(update_sql, values)
        await conn.commit()
        affected = cur.rowcount
    return {"status": "success", "affected_rows": affected}
@app.delete("/kb/{kb_id}")
async def delete_kb(kb_id: int):
    """Hard-delete one ``t_ai_kb`` row.

    The response always reports success, whether or not a row matched
    ``kb_id``.

    NOTE(review): rows in ``t_ai_kb_files`` that reference this knowledge base
    are not touched here.
    """
    pool = app.state.mysql_pool
    async with pool.acquire() as conn, conn.cursor() as cur:
        await cur.execute("DELETE FROM t_ai_kb WHERE id = %s", (kb_id,))
        await conn.commit()
    return {"message": "Knowledge base deleted"}
# Knowledge-base file CRUD endpoints
@app.post("/kb_file")
async def create_kb_file(file: KbFileModel):
    """Insert a new ``t_ai_kb_files`` row.

    Args:
        file: Payload providing ``file_name``, ``ext_name``, ``kb_id``,
            ``is_delete`` and ``state``.

    Returns:
        ``{"id": <lastrowid of the INSERT>}``.
    """
    insert_sql = """INSERT INTO t_ai_kb_files
(file_name, ext_name, kb_id, is_delete, state)
VALUES (%s, %s, %s, %s, %s)"""
    pool = app.state.mysql_pool
    async with pool.acquire() as conn:
        async with conn.cursor() as cur:
            params = (file.file_name, file.ext_name, file.kb_id, file.is_delete, file.state)
            await cur.execute(insert_sql, params)
            await conn.commit()
            new_id = cur.lastrowid
    return {"id": new_id}
@app.get("/kb_files/{file_id}")
async def read_kb_file(file_id: int):
    """Fetch one ``t_ai_kb_files`` row by ID using a dict-style cursor.

    Args:
        file_id: ID of the knowledge-base file.

    Returns:
        The matching row as produced by ``DictCursor``.

    Raises:
        HTTPException: 404 when no row has this ID.
    """
    # NOTE(review): DictCursor is imported from pymysql.cursors. If the pool is
    # aiomysql, Connection.cursor() only accepts aiomysql.Cursor subclasses and
    # would raise TypeError here — confirm which driver init_mysql_pool uses.
    async with app.state.mysql_pool.acquire() as conn, conn.cursor(DictCursor) as cur:
        await cur.execute("SELECT * FROM t_ai_kb_files WHERE id = %s", (file_id,))
        record = await cur.fetchone()
    if not record:
        raise HTTPException(status_code=404, detail="File not found")
    return record
# Knowledge-base file update endpoint (revised)
@app.post("/kb_files/update/{file_id}")
async def update_kb_file(file_id: int, file: KbFileModel):
    """Overwrite every editable column of one ``t_ai_kb_files`` row.

    Args:
        file_id: ID of the file row to update.
        file: Payload with the new ``file_name``, ``ext_name``, ``kb_id``,
            ``is_delete`` and ``state``.

    Returns:
        ``{"status": "success", "affected_rows": n}`` where ``n`` is the
        cursor's rowcount after the UPDATE.
    """
    update_sql = """UPDATE t_ai_kb_files
SET file_name=%s, ext_name=%s, kb_id=%s,
is_delete=%s, state=%s
WHERE id=%s"""
    pool = app.state.mysql_pool
    async with pool.acquire() as conn, conn.cursor() as cur:
        new_values = (
            file.file_name,
            file.ext_name,
            file.kb_id,
            file.is_delete,
            file.state,
            file_id,
        )
        await cur.execute(update_sql, new_values)
        await conn.commit()
        changed = cur.rowcount
    return {"status": "success", "affected_rows": changed}
@app.delete("/kb_files/{file_id}")
async def delete_kb_file(file_id: int):
    """Hard-delete one ``t_ai_kb_files`` row.

    The response always reports success, whether or not a row matched
    ``file_id``. Any file saved on disk for this row is not removed here.
    """
    delete_sql = "DELETE FROM t_ai_kb_files WHERE id = %s"
    async with app.state.mysql_pool.acquire() as conn:
        async with conn.cursor() as cur:
            await cur.execute(delete_sql, (file_id,))
            await conn.commit()
    return {"message": "File deleted"}
# Allowed upload file extensions (lowercase, without the leading dot).
ALLOWED_EXTENSIONS = {'txt', 'pdf', 'docx', 'ppt', 'pptx'}
MAX_FILE_SIZE = 100 * 1024 * 1024  # 100MB
# Chunk size used when copying an upload to disk, so a large file is never
# held in memory all at once.
UPLOAD_CHUNK_SIZE = 1024 * 1024  # 1MB


@app.post("/upload")
async def upload_file(
    kb_id: int,
    file: UploadFile = File(...)
):
    """Validate an uploaded file, save it under ``Upload/`` and record it in the database.

    An existing file with the same name in ``Upload/`` is overwritten.

    Args:
        kb_id: ID of the knowledge base the file belongs to.
        file: The uploaded file.

    Returns:
        dict with the new ``t_ai_kb_files`` row ``id``, the stored ``filename``,
        the ``size`` in bytes and the ``path`` it was saved to.

    Raises:
        HTTPException: 400 if the extension is not in ``ALLOWED_EXTENSIONS`` or
            the file is larger than ``MAX_FILE_SIZE``.
    """
    # Keep only the final path component of the client-supplied name, so values
    # like "../../x.pdf" or "/abs/x.pdf" cannot write outside the upload directory.
    filename = Path(file.filename or "").name

    # Check the file type
    file_ext = Path(filename).suffix.lower()[1:]
    if file_ext not in ALLOWED_EXTENSIONS:
        raise HTTPException(
            status_code=400,
            # sorted(): set iteration order differs between processes.
            detail=f"不支持的文件类型: {file_ext}, 仅支持: {', '.join(sorted(ALLOWED_EXTENSIONS))}"
        )

    # Check the size before anything is written to disk, by seeking to the end
    # of the spooled upload instead of iterating over its contents.
    file.file.seek(0, 2)  # whence=2: offset relative to end of file
    file_size = file.file.tell()
    file.file.seek(0)
    if file_size > MAX_FILE_SIZE:
        raise HTTPException(
            status_code=400,
            detail=f"文件大小超过限制: {MAX_FILE_SIZE/1024/1024}MB"
        )

    # Create the Upload directory
    upload_dir = Path("Upload")
    upload_dir.mkdir(exist_ok=True)

    # Save the file, copying it in chunks rather than reading it into memory whole.
    file_path = upload_dir / filename
    with open(file_path, "wb") as f:
        while True:
            chunk = await file.read(UPLOAD_CHUNK_SIZE)
            if not chunk:
                break
            f.write(chunk)

    # Record it in the database
    async with app.state.mysql_pool.acquire() as conn:
        async with conn.cursor() as cur:
            await cur.execute(
                """INSERT INTO t_ai_kb_files
(file_name, ext_name, kb_id, file_path, file_size)
VALUES (%s, %s, %s, %s, %s)""",
                (filename, file_ext, kb_id, str(file_path), file_size)
            )
            await conn.commit()
            file_id = cur.lastrowid
    return {
        "id": file_id,
        "filename": filename,
        "size": file_size,
        "path": str(file_path)
    }
if __name__ == "__main__":
    # When executed as a script, serve the app on all interfaces, port 8000.
    uvicorn.run(app, port=8000, host="0.0.0.0")