|
|
|
from pathlib import Path
|
|
|
|
|
|
|
|
import uvicorn
|
|
|
|
from fastapi import FastAPI, UploadFile, File, HTTPException
|
|
|
|
from pymysql.cursors import DictCursor
|
|
|
|
|
|
|
|
from Dao.KbDao import KbDao
|
|
|
|
from Model.KbModel import KbModel, KbFileModel
|
|
|
|
from Util.MySQLUtil import init_mysql_pool
|
|
|
|
from contextlib import asynccontextmanager
|
|
|
|
|
|
|
|
@asynccontextmanager
async def lifespan(app: FastAPI):
    """Create the MySQL pool on startup and release it on shutdown.

    The pool returned by ``init_mysql_pool`` is wrapped in a ``KbDao`` stored
    on ``app.state.kb_dao`` and is also published as ``app.state.mysql_pool``
    so that code reading either attribute finds the same pool.

    Args:
        app: The FastAPI application whose ``state`` receives the resources.

    Yields:
        None. Control returns to the framework while requests are served.
    """
    import inspect

    pool = await init_mysql_pool()
    app.state.mysql_pool = pool
    app.state.kb_dao = KbDao(pool)
    try:
        yield
    finally:
        # NOTE(review): the concrete pool type comes from init_mysql_pool and
        # is not visible here. aiomysql-style pools expose a synchronous
        # close() followed by an awaitable wait_closed(); other pools make
        # close() itself awaitable. Handle both instead of awaiting blindly.
        dao_pool = app.state.kb_dao.mysql_pool
        for step in (dao_pool.close, getattr(dao_pool, "wait_closed", None)):
            if step is None:
                continue
            outcome = step()
            if inspect.isawaitable(outcome):
                await outcome


app = FastAPI(lifespan=lifespan)
|
|
|
|
|
|
|
|
|
|
|
|
@app.post("/kb")
async def create_kb(kb: KbModel):
    """Create a knowledge base by delegating to the shared ``KbDao``.

    Args:
        kb: Request body describing the knowledge base to create.

    Returns:
        Whatever ``KbDao.create_kb`` returns for the given model.
    """
    dao = app.state.kb_dao
    created = await dao.create_kb(kb)
    return created
|
|
|
|
|
|
|
|
@app.get("/kb/{kb_id}")
async def read_kb(kb_id: int):
    """Fetch a single knowledge base row by its ID.

    Args:
        kb_id: Primary key of the row in ``t_ai_kb``.

    Returns:
        The matching row, in whatever shape the connection's default cursor
        produces.

    Raises:
        HTTPException: 404 when no row has the given ID.
    """
    # Reach the pool through the DAO that the lifespan hook stores on
    # ``app.state.kb_dao``.
    pool = app.state.kb_dao.mysql_pool
    async with pool.acquire() as conn:
        async with conn.cursor() as cur:
            await cur.execute("SELECT * FROM t_ai_kb WHERE id = %s", (kb_id,))
            row = await cur.fetchone()
    if not row:
        raise HTTPException(status_code=404, detail="知识库不存在")
    return row
|
|
|
|
|
|
|
|
@app.post("/kb/update/{kb_id}")
async def update_kb(kb_id: int, kb: KbModel):
    """Overwrite the stored fields of a knowledge base.

    Args:
        kb_id: Primary key of the ``t_ai_kb`` row to update.
        kb: Model carrying the new ``kb_name``, ``short_name`` and
            ``is_delete`` values.

    Returns:
        A dict with ``status`` set to ``"success"`` and ``affected_rows`` set
        to the cursor's row count for the UPDATE statement.
    """
    # Reach the pool through the DAO that the lifespan hook stores on
    # ``app.state.kb_dao``.
    pool = app.state.kb_dao.mysql_pool
    async with pool.acquire() as conn:
        async with conn.cursor() as cur:
            await cur.execute(
                """UPDATE t_ai_kb
                SET kb_name=%s, short_name=%s, is_delete=%s
                WHERE id=%s""",
                (kb.kb_name, kb.short_name, kb.is_delete, kb_id),
            )
            # Read the count while the cursor is still open.
            affected_rows = cur.rowcount
        await conn.commit()
    return {"status": "success", "affected_rows": affected_rows}
|
|
|
|
|
|
|
|
@app.delete("/kb/{kb_id}")
async def delete_kb(kb_id: int):
    """Delete the knowledge base row with the given ID from ``t_ai_kb``.

    Args:
        kb_id: Primary key of the row to delete.

    Returns:
        A fixed confirmation message; it is returned whether or not a row
        matched the ID.
    """
    # Reach the pool through the DAO that the lifespan hook stores on
    # ``app.state.kb_dao``.
    pool = app.state.kb_dao.mysql_pool
    async with pool.acquire() as conn:
        async with conn.cursor() as cur:
            await cur.execute("DELETE FROM t_ai_kb WHERE id = %s", (kb_id,))
        await conn.commit()
    return {"message": "Knowledge base deleted"}
|
|
|
|
|
|
|
|
# Knowledge base file CRUD endpoints
@app.post("/kb_file")
async def create_kb_file(file: KbFileModel):
    """Insert a knowledge base file record into ``t_ai_kb_files``.

    Args:
        file: Model providing ``file_name``, ``ext_name``, ``kb_id``,
            ``is_delete`` and ``state`` for the new row.

    Returns:
        A dict whose ``id`` is the cursor's ``lastrowid`` after the INSERT.
    """
    # Reach the pool through the DAO that the lifespan hook stores on
    # ``app.state.kb_dao``.
    pool = app.state.kb_dao.mysql_pool
    async with pool.acquire() as conn:
        async with conn.cursor() as cur:
            await cur.execute(
                """INSERT INTO t_ai_kb_files
                (file_name, ext_name, kb_id, is_delete, state)
                VALUES (%s, %s, %s, %s, %s)""",
                (file.file_name, file.ext_name, file.kb_id, file.is_delete, file.state),
            )
            # Read the generated key while the cursor is still open.
            new_id = cur.lastrowid
        await conn.commit()
    return {"id": new_id}
|
|
|
|
|
|
|
|
@app.get("/kb_files/{file_id}")
async def read_kb_file(file_id: int):
    """Fetch a single knowledge base file row by its ID.

    Args:
        file_id: Primary key of the row in ``t_ai_kb_files``.

    Returns:
        The matching row as returned by a ``DictCursor``.

    Raises:
        HTTPException: 404 when no row has the given ID.
    """
    # Reach the pool through the DAO that the lifespan hook stores on
    # ``app.state.kb_dao``.
    pool = app.state.kb_dao.mysql_pool
    async with pool.acquire() as conn:
        # NOTE(review): DictCursor is imported from pymysql; confirm the pool's
        # driver accepts that cursor class.
        async with conn.cursor(DictCursor) as cur:
            await cur.execute("SELECT * FROM t_ai_kb_files WHERE id = %s", (file_id,))
            row = await cur.fetchone()
    if not row:
        raise HTTPException(status_code=404, detail="File not found")
    return row
|
|
|
|
|
|
|
|
# Update endpoint for knowledge base files
@app.post("/kb_files/update/{file_id}")
async def update_kb_file(file_id: int, file: KbFileModel):
    """Overwrite the stored fields of a knowledge base file record.

    Args:
        file_id: Primary key of the ``t_ai_kb_files`` row to update.
        file: Model carrying the new ``file_name``, ``ext_name``, ``kb_id``,
            ``is_delete`` and ``state`` values.

    Returns:
        A dict with ``status`` set to ``"success"`` and ``affected_rows`` set
        to the cursor's row count for the UPDATE statement.
    """
    # Reach the pool through the DAO that the lifespan hook stores on
    # ``app.state.kb_dao``.
    pool = app.state.kb_dao.mysql_pool
    async with pool.acquire() as conn:
        async with conn.cursor() as cur:
            await cur.execute(
                """UPDATE t_ai_kb_files
                SET file_name=%s, ext_name=%s, kb_id=%s,
                is_delete=%s, state=%s
                WHERE id=%s""",
                (file.file_name, file.ext_name, file.kb_id,
                 file.is_delete, file.state, file_id),
            )
            # Read the count while the cursor is still open.
            affected_rows = cur.rowcount
        await conn.commit()
    return {"status": "success", "affected_rows": affected_rows}
|
|
|
|
|
|
|
|
@app.delete("/kb_files/{file_id}")
async def delete_kb_file(file_id: int):
    """Delete the file record with the given ID from ``t_ai_kb_files``.

    Only the database row is removed; this handler does not touch any file
    on disk.

    Args:
        file_id: Primary key of the row to delete.

    Returns:
        A fixed confirmation message; it is returned whether or not a row
        matched the ID.
    """
    # Reach the pool through the DAO that the lifespan hook stores on
    # ``app.state.kb_dao``.
    pool = app.state.kb_dao.mysql_pool
    async with pool.acquire() as conn:
        async with conn.cursor() as cur:
            await cur.execute("DELETE FROM t_ai_kb_files WHERE id = %s", (file_id,))
        await conn.commit()
    return {"message": "File deleted"}
|
|
|
|
|
|
|
|
# File extensions accepted for upload (compared lower-cased, without the dot).
ALLOWED_EXTENSIONS = {
    'docx',
    'pdf',
    'ppt',
    'pptx',
    'txt',
}

# Largest accepted upload, in bytes (100 MB).
MAX_FILE_SIZE = 100 * 1024 ** 2
|
|
|
|
|
|
|
|
@app.post("/upload")
async def upload_file(
    kb_id: int,
    file: UploadFile = File(...)
):
    """Validate an uploaded file, store it under ``Upload/`` and record it.

    Args:
        kb_id: ID of the knowledge base the file is associated with.
        file: The uploaded file object.

    Returns:
        A dict with the new row's ``id``, the stored ``filename``, the
        ``size`` in bytes and the ``path`` the file was written to.

    Raises:
        HTTPException: 400 when the extension is not in
            ``ALLOWED_EXTENSIONS`` or the content exceeds ``MAX_FILE_SIZE``.
    """
    # The client controls the filename. Keep only its final path component
    # (treating backslashes as separators too) so a value such as
    # "../../x.txt" cannot escape the upload directory. A missing filename
    # becomes "" and is rejected by the extension check below.
    raw_name = (file.filename or "").replace("\\", "/")
    safe_name = Path(raw_name).name

    # Check the file type.
    file_ext = Path(safe_name).suffix.lower()[1:]
    if file_ext not in ALLOWED_EXTENSIONS:
        raise HTTPException(
            status_code=400,
            detail=f"不支持的文件类型: {file_ext}, 仅支持: {', '.join(sorted(ALLOWED_EXTENSIONS))}"
        )

    # Check the file size in fixed-size chunks so that memory use stays
    # bounded and oversized uploads are rejected as soon as the limit is
    # crossed, before anything is written to disk.
    chunk_size = 1024 * 1024
    file_size = 0
    while True:
        chunk = await file.read(chunk_size)
        if not chunk:
            break
        file_size += len(chunk)
        if file_size > MAX_FILE_SIZE:
            raise HTTPException(
                status_code=400,
                detail=f"文件大小超过限制: {MAX_FILE_SIZE/1024/1024}MB"
            )
    await file.seek(0)

    # Create the Upload directory if needed.
    upload_dir = Path("Upload")
    upload_dir.mkdir(exist_ok=True)

    # Save the file, streaming it chunk by chunk.
    # NOTE(review): an upload whose name matches an existing file overwrites
    # it; confirm that this is intended.
    file_path = upload_dir / safe_name
    with open(file_path, "wb") as f:
        while True:
            chunk = await file.read(chunk_size)
            if not chunk:
                break
            f.write(chunk)

    # Record the upload in the database. Reach the pool through the DAO that
    # the lifespan hook stores on ``app.state.kb_dao``.
    pool = app.state.kb_dao.mysql_pool
    async with pool.acquire() as conn:
        async with conn.cursor() as cur:
            await cur.execute(
                """INSERT INTO t_ai_kb_files
                (file_name, ext_name, kb_id, file_path, file_size)
                VALUES (%s, %s, %s, %s, %s)""",
                (safe_name, file_ext, kb_id, str(file_path), file_size),
            )
            # Read the generated key while the cursor is still open.
            new_id = cur.lastrowid
        await conn.commit()

    return {
        "id": new_id,
        "filename": safe_name,
        "size": file_size,
        "path": str(file_path)
    }
|
|
|
|
|
|
|
|
if __name__ == "__main__":
    # Run the app with uvicorn when executed as a script: all interfaces,
    # port 8000.
    uvicorn.run(
        app,
        port=8000,
        host="0.0.0.0",
    )
|