import logging
import urllib.parse
from contextlib import asynccontextmanager
from io import BytesIO
from logging.handlers import RotatingFileHandler

import html2text
import jieba  # jieba for Chinese word segmentation
import uvicorn
from docx import Document
from fastapi import FastAPI, Request, HTTPException
from fastapi.staticfiles import StaticFiles
from gensim.models import KeyedVectors
from openai import OpenAI
from pydantic import BaseModel, Field, ValidationError
from starlette.responses import StreamingResponse

from Config import Config
from Config.Config import MS_MODEL_PATH, MS_MODEL_LIMIT, MS_HOST, MS_PORT, MS_MAX_CONNECTIONS, MS_NPROBE, \
    MS_COLLECTION_NAME
from Milvus.Utils.MilvusCollectionManager import MilvusCollectionManager
from Milvus.Utils.MilvusConnectionPool import *
from Milvus.Utils.MilvusConnectionPool import MilvusConnectionPool

# Initialize logging: rotate Logs/start.log at 1 MB, keep 5 backups
logger = logging.getLogger(__name__)
logger.setLevel(logging.INFO)
handler = RotatingFileHandler('Logs/start.log', maxBytes=1024 * 1024, backupCount=5)
handler.setFormatter(logging.Formatter('%(asctime)s - %(levelname)s - %(message)s'))
logger.addHandler(handler)

# 1. Load the pre-trained Word2Vec model (text format, capped at MS_MODEL_LIMIT words)
model = KeyedVectors.load_word2vec_format(MS_MODEL_PATH, binary=False, limit=MS_MODEL_LIMIT)
logger.info(f"模型加载成功,词向量维度: {model.vector_size}")


@asynccontextmanager
async def lifespan(app: FastAPI):
    # Initialize the Milvus connection pool
    app.state.milvus_pool = MilvusConnectionPool(host=MS_HOST, port=MS_PORT, max_connections=MS_MAX_CONNECTIONS)
    # Initialize the collection manager and load the collection into memory
    app.state.collection_manager = MilvusCollectionManager(MS_COLLECTION_NAME)
    app.state.collection_manager.load_collection()
    # Initialize the DeepSeek client (OpenAI-compatible endpoint)
    app.state.deepseek_client = OpenAI(
        api_key=Config.DEEPSEEK_API_KEY,
        base_url=Config.DEEPSEEK_URL
    )
    yield
    # Close the Milvus connection pool on shutdown
    app.state.milvus_pool.close()


app = FastAPI(lifespan=lifespan)

# Mount the static file directory
app.mount("/static", StaticFiles(directory="Static"), name="static")


# Convert a piece of text into an embedding vector (mean of its word vectors)
def text_to_embedding(text):
    words = jieba.lcut(text)  # Tokenize with jieba
    logger.info(f"文本: {text}, 分词结果: {words}")
    embeddings = [model[word] for word in words if word in model]
    logger.info(f"有效词向量数量: {len(embeddings)}")
    if embeddings:
        avg_embedding = sum(embeddings) / len(embeddings)
        logger.info(f"生成的平均向量: {avg_embedding[:5]}...")  # Log the first 5 dimensions
        return avg_embedding
    else:
        logger.warning("未找到有效词,返回零向量")
        return [0.0] * model.vector_size
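
# --- Illustrative sketch (assumption: not part of the original service logic) ---
# text_to_embedding() relies on gensim returning numpy arrays, so sum()/len()
# averages element-wise. The hypothetical helper below shows the same averaging
# written explicitly with numpy; the name _average_vectors is illustrative only
# and is not referenced elsewhere in this module.
import numpy as np


def _average_vectors(vectors, dim):
    """Return the element-wise mean of `vectors`, or a zero vector of length `dim`."""
    if not vectors:
        return [0.0] * dim
    return np.mean(np.asarray(vectors), axis=0)
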
async def generate_stream(client, milvus_pool, collection_manager, query):
    """Retrieve context from Milvus and generate an answer with DeepSeek (yielded as a single chunk)."""
    # Borrow a connection from the pool
    connection = milvus_pool.get_connection()
    try:
        # 1. Convert the query text into a vector
        current_embedding = text_to_embedding(query)
        # 2. Search for related data
        search_params = {
            "metric_type": "L2",  # Use L2 distance
            "params": {"nprobe": MS_NPROBE}  # nprobe parameter for the IVF_FLAT index
        }
        results = collection_manager.search(current_embedding, search_params, limit=5)  # Return the top 5 hits
        # 3. Process the search results
        logger.info("最相关的知识库内容:")
        context = ""
        if results:
            for hits in results:
                for hit in hits:
                    try:
                        # Query the non-vector fields by primary key
                        record = collection_manager.query_by_id(hit.id)
                        if hit.distance < 0.88:  # Distance threshold
                            logger.info(f"ID: {hit.id}")
                            logger.info(f"标签: {record['tags']}")
                            logger.info(f"用户问题: {record['user_input']}")
                            logger.info(f"时间: {record['timestamp']}")
                            logger.info(f"距离: {hit.distance}")
                            logger.info("-" * 40)  # Separator
                            # Prefer the stored full content, fall back to the user input
                            full_content = record['tags'].get('full_content', record['user_input'])
                            context = context + full_content
                    except Exception as e:
                        logger.error(f"查询失败: {e}")
        else:
            logger.warning("未找到相关历史对话,请检查查询参数或数据。")

        # 4. Build the prompt from the retrieved context
        prompt = f"""
        信息检索与回答助手
        根据以下关于'{query}'的相关信息:

        基本信息
        - 语言: 中文
        - 描述: 根据提供的材料检索信息并回答问题
        - 特点: 快速准确提取关键信息,清晰简洁地回答

        相关信息
        {context}

        回答要求
        1. 依托给定的资料,快速准确地回答问题,可以添加一些额外的信息,但请勿重复内容。
        2. 使用HTML格式返回,包含适当的段落、列表和标题标签
        3. 确保内容结构清晰,便于前端展示
        """

        # 5. Call DeepSeek (non-streaming) and yield the answer as one chunk
        response = client.chat.completions.create(
            model="deepseek-chat",
            messages=[
                {"role": "system", "content": "你是一个专业的文档整理助手"},
                {"role": "user", "content": prompt}
            ],
            temperature=0.3,
            stream=False
        )
        yield {"data": response.choices[0].message.content}
    except Exception as e:
        yield {"data": f"生成报告时出错: {str(e)}"}
    finally:
        # Release the connection back to the pool
        milvus_pool.release_connection(connection)


"""
Manual test page: http://10.10.21.22:8000/static/ai.html
Queries covered by the knowledge base:
小学数学中有哪些模型?
帮我写一下 “如何理解点、线、面、体、角”的教学设计
Queries not covered by the knowledge base:
你知道黄海是谁吗?
"""


class QueryRequest(BaseModel):
    query: str = Field(..., description="用户查询的问题")


class SaveWordRequest(BaseModel):
    html: str = Field(..., description="要保存为Word的HTML内容")


@app.post("/api/save-word")
async def save_to_word(request: Request):
    try:
        # Parse request data
        try:
            data = await request.json()
            html_content = data.get('html_content', '')
            if not html_content:
                raise ValueError("Empty HTML content")
        except Exception as e:
            logger.error(f"Request parsing failed: {str(e)}")
            raise HTTPException(status_code=400, detail=f"Invalid request: {str(e)}")

        # Convert HTML to plain text
        try:
            text_maker = html2text.HTML2Text()
            text_maker.ignore_links = True
            text_maker.ignore_images = True
            text_content = text_maker.handle(html_content)
        except Exception as e:
            logger.error(f"HTML conversion failed: {str(e)}")
            raise HTTPException(status_code=400, detail=f"HTML processing error: {str(e)}")

        # Create the Word document
        try:
            doc = Document()
            doc.add_heading('小学数学问答', 0)
            for para in text_content.split('\n\n'):
                if para.strip():
                    doc.add_paragraph(para.strip())
        except Exception as e:
            logger.error(f"Document creation failed: {str(e)}")
            raise HTTPException(status_code=500, detail=f"Document creation error: {str(e)}")

        # Save the document into an in-memory stream
        try:
            stream = BytesIO()
            doc.save(stream)
            stream.seek(0)
        except Exception as e:
            logger.error(f"Document saving failed: {str(e)}")
            raise HTTPException(status_code=500, detail=f"Document saving error: {str(e)}")

        # Return the file as a download (RFC 5987 encoded filename)
        filename = "小学数学问答.docx"
        encoded_filename = urllib.parse.quote(filename)
        return StreamingResponse(
            stream,
            media_type="application/vnd.openxmlformats-officedocument.wordprocessingml.document",
            headers={"Content-Disposition": f"attachment; filename*=UTF-8''{encoded_filename}"}
        )

    except HTTPException:
        raise
    except Exception as e:
        logger.error(f"Unexpected error: {str(e)}")
        raise HTTPException(status_code=500, detail="Internal server error")


@app.post("/api/rag")
async def rag_stream(request: Request):
    """RAG + DeepSeek endpoint."""
    try:
        data = await request.json()
        query_request = QueryRequest(**data)
    except ValidationError as e:
        logger.error(f"请求体验证失败: {e.errors()}")
        raise HTTPException(status_code=422, detail=e.errors())
    except Exception as e:
        logger.error(f"请求解析失败: {str(e)}")
        raise HTTPException(status_code=400, detail="无效的请求格式")
    # generate_stream yields a single chunk, so return the first (and only) one
    async for chunk in generate_stream(
            request.app.state.deepseek_client,
            request.app.state.milvus_pool,
            request.app.state.collection_manager,
            query_request.query
    ):
        return chunk


if __name__ == "__main__":
    uvicorn.run(app, host="0.0.0.0", port=8000)
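
# --- Illustrative client sketch (assumption: the host/port below are hypothetical) ---
# A minimal way to exercise the /api/rag endpoint from another process, e.g.:
#
#   import requests
#   resp = requests.post("http://127.0.0.1:8000/api/rag",
#                        json={"query": "小学数学中有哪些模型?"})
#   print(resp.json())
#
# The /api/save-word endpoint expects a JSON body like {"html_content": "<h1>...</h1>"}
# and responds with a .docx attachment.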