You cannot select more than 25 topics. Topics must start with a letter or number, can include dashes ('-'), and can be up to 35 characters long.

137 lines
4.7 KiB

This file contains ambiguous Unicode characters!

This file contains ambiguous Unicode characters that may be confused with others in your current locale. If your use case is intentional and legitimate, you can safely ignore this warning. Use the Escape button to highlight these characters.

"""
pip install fastapi uvicorn aiomysql
"""
import uvicorn
from fastapi import FastAPI, HTTPException
from pydantic import BaseModel
from Util.MySQLUtil import *
"""
API文档访问 http://localhost:8000/docs
该实现包含以下功能:
- 知识库(t_ai_kb)的增删改查接口
- 知识库文件(t_ai_kb_files)的增删改查接口
- 使用MySQLUtil.py中的连接池管理
- 自动生成的Swagger文档
"""
app = FastAPI()
# Knowledge-base model: request body for the /kb endpoints (table t_ai_kb).
class KbModel(BaseModel):
    """Request body for creating or updating a knowledge base."""
    kb_name: str
    short_name: str
    # Written verbatim to t_ai_kb.is_delete; defaults to 0. Because the type is
    # Optional, an explicit null from the client is stored as NULL.
    # NOTE(review): presumably a soft-delete flag (0 = active) — confirm.
    is_delete: Optional[int] = 0
# Knowledge-base file model: request body for the /kb_files endpoints
# (table t_ai_kb_files).
class KbFileModel(BaseModel):
    """Request body for creating or updating a knowledge-base file record."""
    file_name: str
    ext_name: str
    # Id of the owning knowledge base — presumably t_ai_kb.id; it is not
    # validated against that table in this file.
    kb_id: int
    # Written verbatim to t_ai_kb_files.is_delete; an explicit null is stored as NULL.
    is_delete: Optional[int] = 0
    # Written verbatim to t_ai_kb_files.state; the meaning of its values is not
    # defined in this file — TODO document.
    state: Optional[int] = 0
@app.on_event("startup")
async def startup_event():
    """Create the shared MySQL connection pool when the application starts."""
    # Every request handler reaches the pool through app.state.mysql_pool.
    pool = await init_mysql_pool()
    app.state.mysql_pool = pool
@app.on_event("shutdown")
async def shutdown_event():
    """Close the MySQL pool and wait until it has fully shut down."""
    pool = app.state.mysql_pool
    # close() only initiates shutdown; wait_closed() awaits its completion.
    pool.close()
    await pool.wait_closed()
# Knowledge-base CRUD endpoints
@app.post("/kb")
async def create_kb(kb: KbModel):
    """Create a knowledge base and return the id of the new record."""
    insert_sql = (
        "INSERT INTO t_ai_kb (kb_name, short_name, is_delete)\n"
        "VALUES (%s, %s, %s)"
    )
    values = (kb.kb_name, kb.short_name, kb.is_delete)
    async with app.state.mysql_pool.acquire() as conn, conn.cursor() as cur:
        await cur.execute(insert_sql, values)
        await conn.commit()
        # Auto-generated primary key of the row just inserted.
        new_id = cur.lastrowid
    return {"id": new_id}
@app.get("/kb/{kb_id}")
async def read_kb(kb_id: int):
    """Fetch one knowledge base by id; responds with 404 if it does not exist."""
    pool = app.state.mysql_pool
    async with pool.acquire() as conn, conn.cursor(DictCursor) as cur:
        await cur.execute("SELECT * FROM t_ai_kb WHERE id = %s", (kb_id,))
        record = await cur.fetchone()
    # fetchone() yields None when no row matched the id.
    if record:
        return record
    raise HTTPException(status_code=404, detail="Knowledge base not found")
@app.put("/kb/{kb_id}")
async def update_kb(kb_id: int, kb: KbModel):
    """Replace every field of a knowledge base.

    Responds with 404 if no knowledge base exists with the given id.
    """
    async with app.state.mysql_pool.acquire() as conn:
        async with conn.cursor() as cur:
            await cur.execute(
                "UPDATE t_ai_kb\n"
                "SET kb_name = %s, short_name = %s, is_delete = %s\n"
                "WHERE id = %s",
                (kb.kb_name, kb.short_name, kb.is_delete, kb_id),
            )
            # Unless the connection sets CLIENT_FOUND_ROWS (not visible here),
            # MySQL reports *changed* rows, so rowcount is also 0 when the row
            # exists but already holds these values. Confirm the row is really
            # missing before answering 404.
            if cur.rowcount == 0:
                await cur.execute("SELECT 1 FROM t_ai_kb WHERE id = %s", (kb_id,))
                if await cur.fetchone() is None:
                    # End the (empty) transaction so the connection is
                    # returned to the pool clean.
                    await conn.rollback()
                    raise HTTPException(
                        status_code=404, detail="Knowledge base not found"
                    )
        await conn.commit()
    return {"message": "Knowledge base updated"}
@app.delete("/kb/{kb_id}")
async def delete_kb(kb_id: int):
    """Delete a knowledge base; responds with 404 if it does not exist."""
    # NOTE(review): this is a hard DELETE even though t_ai_kb has an is_delete
    # column, and t_ai_kb_files rows with this kb_id are not touched here —
    # unless the schema cascades, they are left orphaned. Confirm intent.
    async with app.state.mysql_pool.acquire() as conn:
        async with conn.cursor() as cur:
            await cur.execute("DELETE FROM t_ai_kb WHERE id = %s", (kb_id,))
            # For DELETE, rowcount is the number of rows removed.
            deleted = cur.rowcount
        # Commit even when nothing matched so no transaction is left open.
        await conn.commit()
    if deleted == 0:
        raise HTTPException(status_code=404, detail="Knowledge base not found")
    return {"message": "Knowledge base deleted"}
# Knowledge-base file CRUD endpoints
@app.post("/kb_files")
async def create_kb_file(file: KbFileModel):
    """Register a knowledge-base file and return the id of the new record."""
    # NOTE(review): kb_id is not checked against t_ai_kb here; presumably a
    # foreign key (if any) enforces it — confirm.
    statement = (
        "INSERT INTO t_ai_kb_files\n"
        "(file_name, ext_name, kb_id, is_delete, state)\n"
        "VALUES (%s, %s, %s, %s, %s)"
    )
    row_values = (file.file_name, file.ext_name, file.kb_id,
                  file.is_delete, file.state)
    async with app.state.mysql_pool.acquire() as conn, conn.cursor() as cur:
        await cur.execute(statement, row_values)
        await conn.commit()
        inserted_id = cur.lastrowid
    return {"id": inserted_id}
@app.get("/kb_files/{file_id}")
async def read_kb_file(file_id: int):
    """Fetch one knowledge-base file record by id; responds with 404 if absent."""
    pool = app.state.mysql_pool
    async with pool.acquire() as conn, conn.cursor(DictCursor) as cur:
        await cur.execute("SELECT * FROM t_ai_kb_files WHERE id = %s", (file_id,))
        record = await cur.fetchone()
    # fetchone() yields None when no row matched the id.
    if record:
        return record
    raise HTTPException(status_code=404, detail="File not found")
@app.put("/kb_files/{file_id}")
async def update_kb_file(file_id: int, file: KbFileModel):
    """Replace every field of a knowledge-base file record.

    Responds with 404 if no file exists with the given id.
    """
    async with app.state.mysql_pool.acquire() as conn:
        async with conn.cursor() as cur:
            await cur.execute(
                "UPDATE t_ai_kb_files\n"
                "SET file_name = %s, ext_name = %s, kb_id = %s,\n"
                "is_delete = %s, state = %s\n"
                "WHERE id = %s",
                (file.file_name, file.ext_name, file.kb_id,
                 file.is_delete, file.state, file_id),
            )
            # Unless the connection sets CLIENT_FOUND_ROWS (not visible here),
            # MySQL reports *changed* rows, so rowcount is also 0 when the row
            # exists but already holds these values. Confirm the row is really
            # missing before answering 404.
            if cur.rowcount == 0:
                await cur.execute(
                    "SELECT 1 FROM t_ai_kb_files WHERE id = %s", (file_id,)
                )
                if await cur.fetchone() is None:
                    # End the (empty) transaction so the connection is
                    # returned to the pool clean.
                    await conn.rollback()
                    raise HTTPException(status_code=404, detail="File not found")
        await conn.commit()
    return {"message": "File updated"}
@app.delete("/kb_files/{file_id}")
async def delete_kb_file(file_id: int):
    """Delete a knowledge-base file record; responds with 404 if it does not exist."""
    async with app.state.mysql_pool.acquire() as conn:
        async with conn.cursor() as cur:
            await cur.execute("DELETE FROM t_ai_kb_files WHERE id = %s", (file_id,))
            # For DELETE, rowcount is the number of rows removed.
            deleted = cur.rowcount
        # Commit even when nothing matched so no transaction is left open.
        await conn.commit()
    if deleted == 0:
        raise HTTPException(status_code=404, detail="File not found")
    return {"message": "File deleted"}
if __name__ == "__main__":
    # Serve on every network interface, port 8000 (Swagger UI at /docs).
    uvicorn.run(app, port=8000, host="0.0.0.0")