import asyncio
import logging
from contextlib import asynccontextmanager
from logging.handlers import RotatingFileHandler

import urllib3
import uvicorn
from elasticsearch import Elasticsearch
from fastapi import FastAPI, File, Request, UploadFile
from fastapi.staticfiles import StaticFiles
from openai import OpenAI
from sse_starlette.sse import EventSourceResponse

from Config import Config
from Dao.KbDao import KbDao
from Util.MySQLUtil import init_mysql_pool
# Initialize logging: rotate at 1 MiB, keep 5 backups (assumes Logs/ exists)
logger = logging.getLogger(__name__)
logger.setLevel(logging.INFO)
handler = RotatingFileHandler('Logs/start.log', maxBytes=1024 * 1024, backupCount=5)
handler.setFormatter(logging.Formatter('%(asctime)s - %(levelname)s - %(message)s'))
logger.addHandler(handler)


@asynccontextmanager
async def lifespan(app: FastAPI):
    # Initialize the MySQL connection pool
    app.state.kb_dao = KbDao(await init_mysql_pool())

    # Initialize the Elasticsearch connection; TLS certificate verification
    # is disabled, which is only appropriate for local development
    urllib3.disable_warnings(urllib3.exceptions.InsecureRequestWarning)
    app.state.es = Elasticsearch(
        hosts=Config.ES_CONFIG['hosts'],
        basic_auth=Config.ES_CONFIG['basic_auth'],
        verify_certs=False  # disable certificate verification
    )

    # Initialize the DeepSeek client (OpenAI-compatible API)
    app.state.deepseek_client = OpenAI(
        api_key=Config.DEEPSEEK_API_KEY,
        base_url=Config.DEEPSEEK_URL
    )

    # Start the background document-processing task
    async def document_processor():
        while True:
            try:
                # Fetch unprocessed documents
                # Process them
                # Save the results to ES
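                # NOTE: the three steps above are placeholders in the original
                # file. A rough sketch of the intended pipeline; the KbDao
                # method names and the helper below are assumptions, not
                # defined anywhere in this project:
                #   files = await app.state.kb_dao.list_unprocessed_files()  # hypothetical
                #   for f in files:
                #       for chunk in split_into_chunks(f):                   # hypothetical helper
                #           app.state.es.index(index=Config.ES_CONFIG['default_index'],
                #                              document={"text": chunk})
                #       await app.state.kb_dao.mark_processed(f)             # hypothetical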
                await asyncio.sleep(10)
            except Exception as e:
                logger.error(f"Document processing failed: {e}")
                await asyncio.sleep(10)

    task = asyncio.create_task(document_processor())
    yield

    # Cancel the background task on shutdown
    task.cancel()
    # Close the MySQL connection pool (assumes close() is awaitable on this pool)
    await app.state.kb_dao.mysql_pool.close()


app = FastAPI(lifespan=lifespan)


# Knowledge-base CRUD endpoints
@app.get("/kb")
async def list_kbs():
    """List all knowledge bases."""
    return await app.state.kb_dao.list_kbs()


@app.post("/kb")
async def create_kb(kb: dict):
    """Create a knowledge base."""
    return await app.state.kb_dao.create_kb(kb)


@app.get("/kb/{kb_id}")
async def read_kb(kb_id: int):
    """Get knowledge-base details."""
    return await app.state.kb_dao.get_kb(kb_id)


@app.post("/kb/update/{kb_id}")
async def update_kb(kb_id: int, kb: dict):
    """Update knowledge-base information."""
    return await app.state.kb_dao.update_kb(kb_id, kb)


@app.delete("/kb/{kb_id}")
async def delete_kb(kb_id: int):
    """Delete a knowledge base."""
    return await app.state.kb_dao.delete_kb(kb_id)


# Knowledge-base file CRUD endpoints
@app.post("/kb_file")
async def create_kb_file(file: dict):
    """Create a knowledge-base file record."""
    return await app.state.kb_dao.create_kb_file(file)


@app.get("/kb_files/{file_id}")
async def read_kb_file(file_id: int):
    """Get file details."""
    return await app.state.kb_dao.get_kb_file(file_id)


@app.post("/kb_files/update/{file_id}")
async def update_kb_file(file_id: int, file: dict):
    """Update file information."""
    return await app.state.kb_dao.update_kb_file(file_id, file)


@app.delete("/kb_files/{file_id}")
async def delete_kb_file(file_id: int):
    """Delete a file record."""
    return await app.state.kb_dao.delete_kb_file(file_id)


# File upload endpoint
@app.post("/upload")
async def upload_file(kb_id: int, file: UploadFile = File(...)):
    """Upload a file into a knowledge base."""
    return await app.state.kb_dao.handle_upload(kb_id, file)


def search_related_data(es, query):
    """Search ES for data related to the query."""
    # Analyzed full-text search (the original comment called this a "vector
    # search", but it is a standard match query with the ik_smart analyzer)
    vector_results = es.search(
        index=Config.ES_CONFIG['default_index'],
        body={
            "query": {
                "match": {
                    "content": {
                        "query": query,
                        "analyzer": "ik_smart"
                    }
                }
            },
            "size": 5
        }
    )

    # Exact text search against the keyword sub-field
    text_results = es.search(
        index="raw_texts",
        body={
            "query": {
                "match": {
                    "text.keyword": query
                }
            },
            "size": 5
        }
    )

    # Merge the two result sets into a single context string
    context = ""
    for hit in vector_results['hits']['hits']:
        context += f"Analyzed match result (score={hit['_score']}):\n{hit['_source']['text']}\n\n"
    for hit in text_results['hits']['hits']:
        context += f"Exact match result (score={hit['_score']}):\n{hit['_source']['text']}\n\n"
    return context
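
# The function above performs two lexical searches despite the "vector"
# naming. For reference, a true vector search in Elasticsearch 8.x would use
# the kNN API, roughly as sketched below; the "embedding" field name and the
# index mapping are assumptions, not defined in this project:
#
#   def knn_search(es, query_vector):
#       return es.search(
#           index=Config.ES_CONFIG['default_index'],
#           knn={
#               "field": "embedding",        # dense_vector field (assumed)
#               "query_vector": query_vector,
#               "k": 5,
#               "num_candidates": 50
#           }
#       )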


async def generate_stream(client, es, query):
    """Generate an SSE stream of report text."""
    context = search_related_data(es, query)
    prompt = f"""Based on the following information about '{query}', produce a structured report.
Requirements:
1. Organize the content into sections
2. Include key data and facts
3. Keep the language concise and professional
Related information:
{context}"""
    try:
        response = client.chat.completions.create(
            model="deepseek-chat",
            messages=[
                {"role": "system", "content": "You are a professional document-organization assistant"},
                {"role": "user", "content": prompt}
            ],
            temperature=0.3,
            stream=True
        )
        for chunk in response:
            if chunk.choices[0].delta.content:
                yield {"data": chunk.choices[0].delta.content}
            await asyncio.sleep(0.01)
    except Exception as e:
        yield {"data": f"Error while generating the report: {str(e)}"}


@app.get("/api/rag")
async def rag_stream(query: str, request: Request):
    """RAG + DeepSeek streaming endpoint (SSE)."""
    return EventSourceResponse(
        generate_stream(request.app.state.deepseek_client, request.app.state.es, query)
    )


app.mount("/static", StaticFiles(directory="Static"), name="static")

if __name__ == "__main__":
    uvicorn.run(app, host="0.0.0.0", port=8000)
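
# Example client (illustrative sketch, not part of the service): one way to
# consume the /api/rag SSE endpoint from Python. Assumes the server runs on
# localhost:8000 and that the optional `httpx` package is installed.
#
#   import httpx
#
#   def consume_rag_stream(query: str) -> None:
#       with httpx.stream("GET", "http://localhost:8000/api/rag",
#                         params={"query": query}, timeout=None) as resp:
#           for line in resp.iter_lines():
#               if line.startswith("data:"):  # SSE frames arrive as "data: <chunk>"
#                   print(line[len("data:"):].strip(), end="", flush=True)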