main
HuangHai 4 weeks ago
parent 93f3ab1ed9
commit fe4955f466

@ -0,0 +1,13 @@
from typing import List, Optional
from Model.KbModel import KbModel, KbFileModel
from Util.MySQLUtil import init_mysql_pool
class KbDao:
def __init__(self, mysql_pool):
self.mysql_pool = mysql_pool
async def create_kb(self, kb: KbModel) -> int:
# 实现创建知识库的数据库操作
pass
# 其他CRUD方法实现...

@ -0,0 +1,14 @@
from pydantic import BaseModel
from typing import Optional
class KbModel(BaseModel):
kb_name: str
short_name: str
is_delete: Optional[int] = 0
class KbFileModel(BaseModel):
file_name: str
ext_name: str
kb_id: int
is_delete: Optional[int] = 0
state: Optional[int] = 0

@ -1,82 +1,64 @@
"""
pip install fastapi uvicorn aiomysql
"""
from pathlib import Path
import uvicorn
from fastapi import FastAPI, HTTPException
from pydantic import BaseModel
from Util.MySQLUtil import *
"""
API文档访问 http://localhost:8000/docs
该实现包含以下功能
- 知识库(t_ai_kb)的增删改查接口
- 知识库文件(t_ai_kb_files)的增删改查接口
- 使用MySQLUtil.py中的连接池管理
- 自动生成的Swagger文档
"""
app = FastAPI()
# 知识库模型
class KbModel(BaseModel):
kb_name: str
short_name: str
is_delete: Optional[int] = 0
# 知识库文件模型
class KbFileModel(BaseModel):
file_name: str
ext_name: str
kb_id: int
is_delete: Optional[int] = 0
state: Optional[int] = 0
@app.on_event("startup")
async def startup_event():
app.state.mysql_pool = await init_mysql_pool()
@app.on_event("shutdown")
async def shutdown_event():
app.state.mysql_pool.close()
await app.state.mysql_pool.wait_closed()
# 知识库CRUD接口
from fastapi import FastAPI, UploadFile, File, HTTPException
from pymysql.cursors import DictCursor
from Dao.KbDao import KbDao
from Model.KbModel import KbModel, KbFileModel
from Util.MySQLUtil import init_mysql_pool
from contextlib import asynccontextmanager
@asynccontextmanager
async def lifespan(app: FastAPI):
# 启动时初始化数据库连接池
app.state.kb_dao = KbDao(await init_mysql_pool())
yield
# 关闭时清理资源
await app.state.kb_dao.mysql_pool.close()
app = FastAPI(lifespan=lifespan)
@app.post("/kb")
async def create_kb(kb: KbModel):
async with app.state.mysql_pool.acquire() as conn:
async with conn.cursor() as cur:
await cur.execute(
"""INSERT INTO t_ai_kb (kb_name, short_name, is_delete)
VALUES (%s, %s, %s)""",
(kb.kb_name, kb.short_name, kb.is_delete)
)
await conn.commit()
return {"id": cur.lastrowid}
return await app.state.kb_dao.create_kb(kb)
@app.get("/kb/{kb_id}")
async def read_kb(kb_id: int):
"""获取知识库详情
Args:
kb_id: 知识库ID
Returns:
返回知识库详细信息
"""
async with app.state.mysql_pool.acquire() as conn:
async with conn.cursor(DictCursor) as cur:
async with conn.cursor() as cur:
await cur.execute("SELECT * FROM t_ai_kb WHERE id = %s", (kb_id,))
result = await cur.fetchone()
if not result:
raise HTTPException(status_code=404, detail="Knowledge base not found")
raise HTTPException(status_code=404, detail="知识库不存在")
return result
@app.put("/kb/{kb_id}")
@app.post("/kb/update/{kb_id}")
async def update_kb(kb_id: int, kb: KbModel):
"""更新知识库信息
Args:
kb_id: 知识库ID
kb: 包含更新后知识库信息的对象
Returns:
返回更新结果
"""
async with app.state.mysql_pool.acquire() as conn:
async with conn.cursor() as cur:
await cur.execute(
"""UPDATE t_ai_kb
SET kb_name = %s, short_name = %s, is_delete = %s
WHERE id = %s""",
SET kb_name=%s, short_name=%s, is_delete=%s
WHERE id=%s""",
(kb.kb_name, kb.short_name, kb.is_delete, kb_id)
)
await conn.commit()
return {"message": "Knowledge base updated"}
return {"status": "success", "affected_rows": cur.rowcount}
@app.delete("/kb/{kb_id}")
async def delete_kb(kb_id: int):
@ -87,8 +69,14 @@ async def delete_kb(kb_id: int):
return {"message": "Knowledge base deleted"}
# 知识库文件CRUD接口
@app.post("/kb_files")
@app.post("/kb_file")
async def create_kb_file(file: KbFileModel):
"""创建知识库文件
Args:
file: 包含文件名(file_name)扩展名(ext_name)关联知识库ID(kb_id)等信息的对象
Returns:
返回创建成功的文件ID
"""
async with app.state.mysql_pool.acquire() as conn:
async with conn.cursor() as cur:
await cur.execute(
@ -110,20 +98,28 @@ async def read_kb_file(file_id: int):
raise HTTPException(status_code=404, detail="File not found")
return result
@app.put("/kb_files/{file_id}")
# 知识库文件更新接口修改
@app.post("/kb_files/update/{file_id}")
async def update_kb_file(file_id: int, file: KbFileModel):
"""更新知识库文件信息
Args:
file_id: 文件ID
file: 包含更新后文件信息的对象
Returns:
返回更新结果
"""
async with app.state.mysql_pool.acquire() as conn:
async with conn.cursor() as cur:
await cur.execute(
"""UPDATE t_ai_kb_files
SET file_name = %s, ext_name = %s, kb_id = %s,
is_delete = %s, state = %s
WHERE id = %s""",
SET file_name=%s, ext_name=%s, kb_id=%s,
is_delete=%s, state=%s
WHERE id=%s""",
(file.file_name, file.ext_name, file.kb_id,
file.is_delete, file.state, file_id)
)
await conn.commit()
return {"message": "File updated"}
return {"status": "success", "affected_rows": cur.rowcount}
@app.delete("/kb_files/{file_id}")
async def delete_kb_file(file_id: int):
@ -133,5 +129,67 @@ async def delete_kb_file(file_id: int):
await conn.commit()
return {"message": "File deleted"}
# 允许的文件类型
ALLOWED_EXTENSIONS = {'txt', 'pdf', 'docx', 'ppt', 'pptx'}
MAX_FILE_SIZE = 100 * 1024 * 1024 # 100MB
@app.post("/upload")
async def upload_file(
kb_id: int,
file: UploadFile = File(...)
):
"""
上传文件接口
Args:
kb_id: 关联的知识库ID
file: 上传的文件对象
Returns:
返回文件保存信息和数据库记录ID
"""
# 检查文件类型
file_ext = Path(file.filename).suffix.lower()[1:]
if file_ext not in ALLOWED_EXTENSIONS:
raise HTTPException(
status_code=400,
detail=f"不支持的文件类型: {file_ext}, 仅支持: {', '.join(ALLOWED_EXTENSIONS)}"
)
# 检查文件大小
file_size = 0
for chunk in file.file:
file_size += len(chunk)
if file_size > MAX_FILE_SIZE:
raise HTTPException(
status_code=400,
detail=f"文件大小超过限制: {MAX_FILE_SIZE/1024/1024}MB"
)
file.file.seek(0)
# 创建Upload目录
upload_dir = Path("Upload")
upload_dir.mkdir(exist_ok=True)
# 保存文件
file_path = upload_dir / file.filename
with open(file_path, "wb") as f:
f.write(await file.read())
# 记录到数据库
async with app.state.mysql_pool.acquire() as conn:
async with conn.cursor() as cur:
await cur.execute(
"""INSERT INTO t_ai_kb_files
(file_name, ext_name, kb_id, file_path, file_size)
VALUES (%s, %s, %s, %s, %s)""",
(file.filename, file_ext, kb_id, str(file_path), file_size)
)
await conn.commit()
return {
"id": cur.lastrowid,
"filename": file.filename,
"size": file_size,
"path": str(file_path)
}
if __name__ == "__main__":
uvicorn.run(app, host="0.0.0.0", port=8000)
Loading…
Cancel
Save