"""FastAPI entry point: knowledge-base CRUD, file upload and a RAG SSE endpoint."""

import asyncio
import functools
import logging
import os
from contextlib import asynccontextmanager
from logging.handlers import RotatingFileHandler

import urllib3
import uvicorn
from elasticsearch import Elasticsearch
from fastapi import FastAPI, UploadFile, File, Request
from fastapi.staticfiles import StaticFiles
from openai import OpenAI
from sse_starlette.sse import EventSourceResponse

from Config import Config
from Dao.KbDao import KbDao
from Util.MySQLUtil import init_mysql_pool

# Logging setup.
# RotatingFileHandler does not create missing parent directories, so make sure
# the log directory exists before the handler opens the file.
os.makedirs('Logs', exist_ok=True)
logger = logging.getLogger(__name__)
logger.setLevel(logging.INFO)
handler = RotatingFileHandler('Logs/start.log', maxBytes=1024*1024, backupCount=5)
handler.setFormatter(logging.Formatter('%(asctime)s - %(levelname)s - %(message)s'))
logger.addHandler(handler)

# Sentinel returned by ``next()`` when the blocking completion stream is exhausted.
_STREAM_END = object()


@asynccontextmanager
async def lifespan(app: FastAPI):
    """Create the shared clients on startup and release them on shutdown.

    Stores on ``app.state``:
      * ``kb_dao`` - a ``KbDao`` wrapping the pool returned by ``init_mysql_pool()``
      * ``es`` - an ``Elasticsearch`` client built from ``Config.ES_CONFIG``
      * ``deepseek_client`` - an ``OpenAI`` client pointed at ``Config.DEEPSEEK_URL``
    """
    # Database connection pool.
    app.state.kb_dao = KbDao(await init_mysql_pool())

    # Elasticsearch client.
    # SECURITY NOTE(review): certificate verification is disabled and the
    # resulting urllib3 warning is silenced. Confirm this is acceptable for
    # the target deployment.
    urllib3.disable_warnings(urllib3.exceptions.InsecureRequestWarning)
    app.state.es = Elasticsearch(
        hosts=Config.ES_CONFIG['hosts'],
        basic_auth=Config.ES_CONFIG['basic_auth'],
        verify_certs=False  # certificate verification disabled
    )

    # DeepSeek client (OpenAI-compatible API).
    app.state.deepseek_client = OpenAI(
        api_key=Config.DEEPSEEK_API_KEY,
        base_url=Config.DEEPSEEK_URL
    )

    try:
        yield
    finally:
        # Run cleanup even when the application exits through an exception,
        # and make sure the pool is closed even if closing ES fails.
        try:
            app.state.es.close()
        finally:
            # NOTE(review): this assumes ``mysql_pool.close()`` is awaitable.
            # If the pool is an aiomysql pool, ``close()`` is synchronous and
            # should be followed by ``await pool.wait_closed()`` - confirm
            # against ``init_mysql_pool``.
            await app.state.kb_dao.mysql_pool.close()


app = FastAPI(lifespan=lifespan)


# Knowledge-base CRUD endpoints
@app.get("/kb")
async def list_kbs():
    """Return the list of all knowledge bases."""
    return await app.state.kb_dao.list_kbs()


@app.post("/kb")
async def create_kb(kb: dict):
    """Create a knowledge base from the request body."""
    return await app.state.kb_dao.create_kb(kb)


@app.get("/kb/{kb_id}")
async def read_kb(kb_id: int):
    """Return the details of one knowledge base."""
    return await app.state.kb_dao.get_kb(kb_id)


@app.post("/kb/update/{kb_id}")
async def update_kb(kb_id: int, kb: dict):
    """Update a knowledge base with the fields in the request body."""
    return await app.state.kb_dao.update_kb(kb_id, kb)


@app.delete("/kb/{kb_id}")
async def delete_kb(kb_id: int):
    """Delete a knowledge base."""
    return await app.state.kb_dao.delete_kb(kb_id)


# Knowledge-base file CRUD endpoints
@app.post("/kb_file")
async def create_kb_file(file: dict):
    """Create a knowledge-base file record from the request body."""
    return await app.state.kb_dao.create_kb_file(file)


@app.get("/kb_files/{file_id}")
async def read_kb_file(file_id: int):
    """Return the details of one file record."""
    return await app.state.kb_dao.get_kb_file(file_id)


@app.post("/kb_files/update/{file_id}")
async def update_kb_file(file_id: int, file: dict):
    """Update a file record with the fields in the request body."""
    return await app.state.kb_dao.update_kb_file(file_id, file)


@app.delete("/kb_files/{file_id}")
async def delete_kb_file(file_id: int):
    """Delete a file record."""
    return await app.state.kb_dao.delete_kb_file(file_id)


# File upload endpoint
@app.post("/upload")
async def upload_file(kb_id: int, file: UploadFile = File(...)):
    """Hand an uploaded file for knowledge base ``kb_id`` to the DAO."""
    return await app.state.kb_dao.handle_upload(kb_id, file)


def search_related_data(es, query):
    """Collect Elasticsearch hits related to ``query`` into one context string.

    Runs two blocking searches (at most 5 hits each) and concatenates the
    ``text`` field of every hit, each prefixed with a label and its score.

    Args:
        es: Elasticsearch client exposing ``search``.
        query: The user's query text.

    Returns:
        The concatenated context string; empty when neither search has hits.
    """
    # Full-text ``match`` on ``content`` using the ``ik_smart`` analyzer.
    # This is not a vector/kNN query, although the label emitted below still
    # calls it a vector-similarity result.
    primary_results = es.search(
        index=Config.ES_CONFIG['default_index'],
        body={
            "query": {
                "match": {
                    "content": {
                        "query": query,
                        "analyzer": "ik_smart"
                    }
                }
            },
            "size": 5
        }
    )

    # ``match`` against the ``text.keyword`` sub-field of the raw text index.
    exact_results = es.search(
        index="raw_texts",
        body={
            "query": {
                "match": {
                    "text.keyword": query
                }
            },
            "size": 5
        }
    )

    # NOTE(review): the first search matches on ``content`` but the hit text is
    # read from ``_source['text']`` - assumes documents in the default index
    # carry a ``text`` field; confirm against the index mapping.
    sections = [
        f"向量相似度结果(score={hit['_score']}):\n{hit['_source']['text']}\n\n"
        for hit in primary_results['hits']['hits']
    ]
    sections.extend(
        f"文本精确匹配结果(score={hit['_score']}):\n{hit['_source']['text']}\n\n"
        for hit in exact_results['hits']['hits']
    )
    return "".join(sections)


async def _run_blocking(func, *args, **kwargs):
    """Run a blocking callable in the default thread pool and return its result."""
    loop = asyncio.get_running_loop()
    return await loop.run_in_executor(None, functools.partial(func, *args, **kwargs))


async def generate_stream(client, es, query):
    """Yield SSE events containing a streamed report about ``query``.

    Retrieves context with ``search_related_data``, asks the chat model for a
    structured report and yields one ``{"data": <text>}`` event per non-empty
    content delta. The ES searches and the synchronous completion stream run in
    worker threads so the event loop stays free for other requests.

    Args:
        client: Synchronous OpenAI-compatible client.
        es: Elasticsearch client.
        query: The user's query text.

    Yields:
        ``{"data": str}`` events. On failure a single event carrying the error
        message is yielded and the stream ends.
    """
    try:
        context = await _run_blocking(search_related_data, es, query)

        # NOTE(review): the prompt sections are separated by single spaces, as
        # they appear in this file; confirm whether line breaks were intended.
        prompt = (
            f"根据以下关于'{query}'的相关信息,整理一份结构化的报告: "
            "要求: 1. 分章节组织内容 2. 包含关键数据和事实 3. 语言简洁专业 "
            f"相关信息: {context}"
        )

        response = await _run_blocking(
            client.chat.completions.create,
            model="deepseek-chat",
            messages=[
                {"role": "system", "content": "你是一个专业的文档整理助手"},
                {"role": "user", "content": prompt}
            ],
            temperature=0.3,
            stream=True
        )

        chunks = iter(response)
        while True:
            # ``next`` with a sentinel default: StopIteration cannot be
            # propagated through an executor future into a coroutine.
            chunk = await _run_blocking(next, chunks, _STREAM_END)
            if chunk is _STREAM_END:
                break
            # Skip chunks with an empty ``choices`` list rather than letting
            # an IndexError abort the whole stream.
            if not chunk.choices:
                continue
            text = chunk.choices[0].delta.content
            if text:
                yield {"data": text}
    except Exception as e:
        # Top-level boundary for the stream: record the traceback and report
        # the failure to the client in-band.
        logger.exception("RAG stream generation failed for query %r", query)
        yield {"data": f"生成报告时出错: {str(e)}"}


@app.get("/api/rag")
async def rag_stream(query: str, request: Request):
    """Stream a RAG + DeepSeek generated report for ``query`` over SSE."""
    return EventSourceResponse(
        generate_stream(request.app.state.deepseek_client, request.app.state.es, query)
    )


app.mount("/static", StaticFiles(directory="Static"), name="static")

if __name__ == "__main__":
    uvicorn.run(app, host="0.0.0.0", port=8000)