main
HuangHai 4 weeks ago
parent 93f3ab1ed9
commit fe4955f466

@ -0,0 +1,13 @@
from typing import List, Optional
from Model.KbModel import KbModel, KbFileModel
from Util.MySQLUtil import init_mysql_pool
class KbDao:
def __init__(self, mysql_pool):
self.mysql_pool = mysql_pool
async def create_kb(self, kb: KbModel) -> int:
# 实现创建知识库的数据库操作
pass
# 其他CRUD方法实现...

@ -0,0 +1,14 @@
from pydantic import BaseModel
from typing import Optional
class KbModel(BaseModel):
kb_name: str
short_name: str
is_delete: Optional[int] = 0
class KbFileModel(BaseModel):
file_name: str
ext_name: str
kb_id: int
is_delete: Optional[int] = 0
state: Optional[int] = 0

@ -1,82 +1,64 @@
""" from pathlib import Path
pip install fastapi uvicorn aiomysql
"""
import uvicorn import uvicorn
from fastapi import FastAPI, HTTPException from fastapi import FastAPI, UploadFile, File, HTTPException
from pydantic import BaseModel from pymysql.cursors import DictCursor
from Util.MySQLUtil import * from Dao.KbDao import KbDao
from Model.KbModel import KbModel, KbFileModel
""" from Util.MySQLUtil import init_mysql_pool
API文档访问 http://localhost:8000/docs from contextlib import asynccontextmanager
该实现包含以下功能
@asynccontextmanager
- 知识库(t_ai_kb)的增删改查接口 async def lifespan(app: FastAPI):
- 知识库文件(t_ai_kb_files)的增删改查接口 # 启动时初始化数据库连接池
- 使用MySQLUtil.py中的连接池管理 app.state.kb_dao = KbDao(await init_mysql_pool())
- 自动生成的Swagger文档 yield
""" # 关闭时清理资源
await app.state.kb_dao.mysql_pool.close()
app = FastAPI()
app = FastAPI(lifespan=lifespan)
# 知识库模型
class KbModel(BaseModel):
kb_name: str
short_name: str
is_delete: Optional[int] = 0
# 知识库文件模型
class KbFileModel(BaseModel):
file_name: str
ext_name: str
kb_id: int
is_delete: Optional[int] = 0
state: Optional[int] = 0
@app.on_event("startup")
async def startup_event():
app.state.mysql_pool = await init_mysql_pool()
@app.on_event("shutdown")
async def shutdown_event():
app.state.mysql_pool.close()
await app.state.mysql_pool.wait_closed()
# 知识库CRUD接口
@app.post("/kb") @app.post("/kb")
async def create_kb(kb: KbModel): async def create_kb(kb: KbModel):
async with app.state.mysql_pool.acquire() as conn: return await app.state.kb_dao.create_kb(kb)
async with conn.cursor() as cur:
await cur.execute(
"""INSERT INTO t_ai_kb (kb_name, short_name, is_delete)
VALUES (%s, %s, %s)""",
(kb.kb_name, kb.short_name, kb.is_delete)
)
await conn.commit()
return {"id": cur.lastrowid}
@app.get("/kb/{kb_id}") @app.get("/kb/{kb_id}")
async def read_kb(kb_id: int): async def read_kb(kb_id: int):
"""获取知识库详情
Args:
kb_id: 知识库ID
Returns:
返回知识库详细信息
"""
async with app.state.mysql_pool.acquire() as conn: async with app.state.mysql_pool.acquire() as conn:
async with conn.cursor(DictCursor) as cur: async with conn.cursor() as cur:
await cur.execute("SELECT * FROM t_ai_kb WHERE id = %s", (kb_id,)) await cur.execute("SELECT * FROM t_ai_kb WHERE id = %s", (kb_id,))
result = await cur.fetchone() result = await cur.fetchone()
if not result: if not result:
raise HTTPException(status_code=404, detail="Knowledge base not found") raise HTTPException(status_code=404, detail="知识库不存在")
return result return result
@app.put("/kb/{kb_id}") @app.post("/kb/update/{kb_id}")
async def update_kb(kb_id: int, kb: KbModel): async def update_kb(kb_id: int, kb: KbModel):
"""更新知识库信息
Args:
kb_id: 知识库ID
kb: 包含更新后知识库信息的对象
Returns:
返回更新结果
"""
async with app.state.mysql_pool.acquire() as conn: async with app.state.mysql_pool.acquire() as conn:
async with conn.cursor() as cur: async with conn.cursor() as cur:
await cur.execute( await cur.execute(
"""UPDATE t_ai_kb """UPDATE t_ai_kb
SET kb_name = %s, short_name = %s, is_delete = %s SET kb_name=%s, short_name=%s, is_delete=%s
WHERE id = %s""", WHERE id=%s""",
(kb.kb_name, kb.short_name, kb.is_delete, kb_id) (kb.kb_name, kb.short_name, kb.is_delete, kb_id)
) )
await conn.commit() await conn.commit()
return {"message": "Knowledge base updated"} return {"status": "success", "affected_rows": cur.rowcount}
@app.delete("/kb/{kb_id}") @app.delete("/kb/{kb_id}")
async def delete_kb(kb_id: int): async def delete_kb(kb_id: int):
@ -87,8 +69,14 @@ async def delete_kb(kb_id: int):
return {"message": "Knowledge base deleted"} return {"message": "Knowledge base deleted"}
# 知识库文件CRUD接口 # 知识库文件CRUD接口
@app.post("/kb_files") @app.post("/kb_file")
async def create_kb_file(file: KbFileModel): async def create_kb_file(file: KbFileModel):
"""创建知识库文件
Args:
file: 包含文件名(file_name)扩展名(ext_name)关联知识库ID(kb_id)等信息的对象
Returns:
返回创建成功的文件ID
"""
async with app.state.mysql_pool.acquire() as conn: async with app.state.mysql_pool.acquire() as conn:
async with conn.cursor() as cur: async with conn.cursor() as cur:
await cur.execute( await cur.execute(
@ -110,20 +98,28 @@ async def read_kb_file(file_id: int):
raise HTTPException(status_code=404, detail="File not found") raise HTTPException(status_code=404, detail="File not found")
return result return result
@app.put("/kb_files/{file_id}") # 知识库文件更新接口修改
@app.post("/kb_files/update/{file_id}")
async def update_kb_file(file_id: int, file: KbFileModel): async def update_kb_file(file_id: int, file: KbFileModel):
"""更新知识库文件信息
Args:
file_id: 文件ID
file: 包含更新后文件信息的对象
Returns:
返回更新结果
"""
async with app.state.mysql_pool.acquire() as conn: async with app.state.mysql_pool.acquire() as conn:
async with conn.cursor() as cur: async with conn.cursor() as cur:
await cur.execute( await cur.execute(
"""UPDATE t_ai_kb_files """UPDATE t_ai_kb_files
SET file_name = %s, ext_name = %s, kb_id = %s, SET file_name=%s, ext_name=%s, kb_id=%s,
is_delete = %s, state = %s is_delete=%s, state=%s
WHERE id = %s""", WHERE id=%s""",
(file.file_name, file.ext_name, file.kb_id, (file.file_name, file.ext_name, file.kb_id,
file.is_delete, file.state, file_id) file.is_delete, file.state, file_id)
) )
await conn.commit() await conn.commit()
return {"message": "File updated"} return {"status": "success", "affected_rows": cur.rowcount}
@app.delete("/kb_files/{file_id}") @app.delete("/kb_files/{file_id}")
async def delete_kb_file(file_id: int): async def delete_kb_file(file_id: int):
@ -133,5 +129,67 @@ async def delete_kb_file(file_id: int):
await conn.commit() await conn.commit()
return {"message": "File deleted"} return {"message": "File deleted"}
# 允许的文件类型
ALLOWED_EXTENSIONS = {'txt', 'pdf', 'docx', 'ppt', 'pptx'}
MAX_FILE_SIZE = 100 * 1024 * 1024 # 100MB
@app.post("/upload")
async def upload_file(
kb_id: int,
file: UploadFile = File(...)
):
"""
上传文件接口
Args:
kb_id: 关联的知识库ID
file: 上传的文件对象
Returns:
返回文件保存信息和数据库记录ID
"""
# 检查文件类型
file_ext = Path(file.filename).suffix.lower()[1:]
if file_ext not in ALLOWED_EXTENSIONS:
raise HTTPException(
status_code=400,
detail=f"不支持的文件类型: {file_ext}, 仅支持: {', '.join(ALLOWED_EXTENSIONS)}"
)
# 检查文件大小
file_size = 0
for chunk in file.file:
file_size += len(chunk)
if file_size > MAX_FILE_SIZE:
raise HTTPException(
status_code=400,
detail=f"文件大小超过限制: {MAX_FILE_SIZE/1024/1024}MB"
)
file.file.seek(0)
# 创建Upload目录
upload_dir = Path("Upload")
upload_dir.mkdir(exist_ok=True)
# 保存文件
file_path = upload_dir / file.filename
with open(file_path, "wb") as f:
f.write(await file.read())
# 记录到数据库
async with app.state.mysql_pool.acquire() as conn:
async with conn.cursor() as cur:
await cur.execute(
"""INSERT INTO t_ai_kb_files
(file_name, ext_name, kb_id, file_path, file_size)
VALUES (%s, %s, %s, %s, %s)""",
(file.filename, file_ext, kb_id, str(file_path), file_size)
)
await conn.commit()
return {
"id": cur.lastrowid,
"filename": file.filename,
"size": file_size,
"path": str(file_path)
}
if __name__ == "__main__": if __name__ == "__main__":
uvicorn.run(app, host="0.0.0.0", port=8000) uvicorn.run(app, host="0.0.0.0", port=8000)
Loading…
Cancel
Save