import logging
from contextlib import asynccontextmanager
from logging.handlers import RotatingFileHandler

import jieba  # jieba tokenizer for Chinese word segmentation
import uvicorn
from fastapi import FastAPI, Request, HTTPException
from pydantic import BaseModel, Field, ValidationError
from fastapi.staticfiles import StaticFiles
from openai import OpenAI
from sse_starlette.sse import EventSourceResponse
from gensim.models import KeyedVectors
from starlette.responses import StreamingResponse

from Config import Config
from Config.Config import MS_MODEL_PATH, MS_MODEL_LIMIT, MS_HOST, MS_PORT, MS_MAX_CONNECTIONS, MS_NPROBE, DEEPSEEK_API_KEY, DEEPSEEK_URL, MS_COLLECTION_NAME
from Milvus.Utils.MilvusCollectionManager import MilvusCollectionManager
from Milvus.Utils.MilvusConnectionPool import MilvusConnectionPool
from docx import Document
from io import BytesIO
import html2text
import urllib.parse

# Initialize logging
logger = logging.getLogger(__name__)
logger.setLevel(logging.INFO)
handler = RotatingFileHandler('Logs/start.log', maxBytes=1024 * 1024, backupCount=5)
handler.setFormatter(logging.Formatter('%(asctime)s - %(levelname)s - %(message)s'))
logger.addHandler(handler)

# 1. Load the pre-trained Word2Vec model
model = KeyedVectors.load_word2vec_format(MS_MODEL_PATH, binary=False, limit=MS_MODEL_LIMIT)
logger.info(f"Word2Vec model loaded, vector dimension: {model.vector_size}")


@asynccontextmanager
async def lifespan(app: FastAPI):
    # Initialize the Milvus connection pool
    app.state.milvus_pool = MilvusConnectionPool(host=MS_HOST, port=MS_PORT, max_connections=MS_MAX_CONNECTIONS)
    # Initialize the collection manager and load the collection into memory
    app.state.collection_manager = MilvusCollectionManager(MS_COLLECTION_NAME)
    app.state.collection_manager.load_collection()
    # Initialize the DeepSeek client (OpenAI-compatible API)
    app.state.deepseek_client = OpenAI(
        api_key=Config.DEEPSEEK_API_KEY,
        base_url=Config.DEEPSEEK_URL
    )
    yield
    # Close the Milvus connection pool on shutdown
    app.state.milvus_pool.close()


app = FastAPI(lifespan=lifespan)

# Mount the static file directory
app.mount("/static", StaticFiles(directory="Static"), name="static")


def text_to_embedding(text):
    """Convert a piece of text to an embedding by averaging its word vectors."""
    words = jieba.lcut(text)  # segment the text with jieba
    logger.info(f"Text: {text}, segmentation result: {words}")
    embeddings = [model[word] for word in words if word in model]
    logger.info(f"Number of valid word vectors: {len(embeddings)}")
    if embeddings:
        avg_embedding = sum(embeddings) / len(embeddings)
        logger.info(f"Averaged vector (first 5 dims): {avg_embedding[:5]}...")
        return avg_embedding
    else:
        logger.warning("No words found in the vocabulary, returning a zero vector")
        return [0.0] * model.vector_size
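# A minimal usage sketch (assuming the Word2Vec model above loaded successfully):
#   vec = text_to_embedding("小学数学中有哪些模型?")
#   assert len(vec) == model.vector_size  # averaged word vectors, or the zero-vector fallback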
async def generate_stream(client, milvus_pool, collection_manager, query):
    """Retrieve related records from Milvus and generate the SSE answer stream."""
    # Get a connection from the pool
    connection = milvus_pool.get_connection()
    try:
        # 1. Convert the query text into a vector
        current_embedding = text_to_embedding(query)
        # 2. Search for related records
        search_params = {
            "metric_type": "L2",  # use L2 distance
            "params": {"nprobe": MS_NPROBE}  # nprobe parameter for the IVF_FLAT index
        }
        results = collection_manager.search(current_embedding, search_params, limit=5)  # return the top 5 hits
        # 3. Process the search results and build the context
        logger.info("Most relevant historical dialogues:")
        context = ""
        if results:
            for hits in results:
                for hit in hits:
                    try:
                        # Query the non-vector fields by primary key
                        record = collection_manager.query_by_id(hit.id)
                        logger.info(f"ID: {hit.id}")
                        logger.info(f"Tags: {record['tags']}")
                        logger.info(f"User question: {record['user_input']}")
                        # Prefer the full content stored in the tags, fall back to the raw user input
                        full_content = record['tags'].get('full_content', record['user_input'])
                        context = context + full_content
                        logger.info(f"Timestamp: {record['timestamp']}")
                        logger.info(f"Distance: {hit.distance}")
                        logger.info("-" * 40)  # separator
                    except Exception as e:
                        logger.error(f"Query failed: {e}")
        else:
            logger.warning("No related historical dialogue found; check the query parameters or the data.")

        # 4. Build the prompt; it instructs the model to answer in Chinese, in HTML, based on the retrieved context
        prompt = f"""
        信息检索与回答助手
        根据以下关于'{query}'的相关信息:

        基本信息
        - 语言: 中文
        - 描述: 根据提供的材料检索信息并回答问题
        - 特点: 快速准确提取关键信息,清晰简洁地回答

        相关信息
        {context}

        回答要求
        1. 依托给定的资料,快速准确地回答问题,可以添加一些额外的信息,但请勿重复内容。
        2. 使用HTML格式返回,包含适当的段落、列表和标题标签
        3. 确保内容结构清晰,便于前端展示
        """

        # 5. Ask DeepSeek for the answer and emit it as a single SSE event
        response = client.chat.completions.create(
            model="deepseek-chat",
            messages=[
                {"role": "system", "content": "你是一个专业的文档整理助手"},
                {"role": "user", "content": prompt}
            ],
            temperature=0.3,
            stream=False
        )
        yield {"data": response.choices[0].message.content}
    except Exception as e:
        yield {"data": f"生成报告时出错: {str(e)}"}
    finally:
        # Release the connection back to the pool
        milvus_pool.release_connection(connection)


# Manual test page: http://10.10.21.22:8000/static/ai.html
# Sample query: 小学数学中有哪些模型?


class QueryRequest(BaseModel):
    query: str = Field(..., description="The user's question")


# Note: /api/save-word currently reads the raw JSON body ('html_content') directly instead of using this model.
class SaveWordRequest(BaseModel):
    html: str = Field(..., description="HTML content to be saved as a Word document")


@app.post("/api/save-word")
async def save_to_word(request: Request):
    """Convert the posted HTML answer into a downloadable Word document."""
    try:
        # Parse request data (the body is expected to carry 'html_content')
        try:
            data = await request.json()
            html_content = data.get('html_content', '')
            if not html_content:
                raise ValueError("Empty HTML content")
        except Exception as e:
            logger.error(f"Request parsing failed: {str(e)}")
            raise HTTPException(status_code=400, detail=f"Invalid request: {str(e)}")

        # Convert HTML to plain text
        try:
            text_maker = html2text.HTML2Text()
            text_maker.ignore_links = True
            text_maker.ignore_images = True
            text_content = text_maker.handle(html_content)
        except Exception as e:
            logger.error(f"HTML conversion failed: {str(e)}")
            raise HTTPException(status_code=400, detail=f"HTML processing error: {str(e)}")

        # Create the Word document
        try:
            doc = Document()
            doc.add_heading('小学数学问答', 0)
            for para in text_content.split('\n\n'):
                if para.strip():
                    doc.add_paragraph(para.strip())
        except Exception as e:
            logger.error(f"Document creation failed: {str(e)}")
            raise HTTPException(status_code=500, detail=f"Document creation error: {str(e)}")

        # Save the document to an in-memory stream
        try:
            stream = BytesIO()
            doc.save(stream)
            stream.seek(0)
        except Exception as e:
            logger.error(f"Document saving failed: {str(e)}")
            raise HTTPException(status_code=500, detail=f"Document saving error: {str(e)}")

        # Return the document as a file download
        filename = "小学数学问答.docx"
        encoded_filename = urllib.parse.quote(filename)
        return StreamingResponse(
            stream,
            media_type="application/vnd.openxmlformats-officedocument.wordprocessingml.document",
            headers={"Content-Disposition": f"attachment; filename*=UTF-8''{encoded_filename}"}
        )
    except HTTPException:
        raise
    except Exception as e:
        logger.error(f"Unexpected error: {str(e)}")
        raise HTTPException(status_code=500, detail="Internal server error")
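# Example request to /api/save-word (a sketch; the field name matches what the handler above
# reads from the raw JSON body):
#   POST /api/save-word
#   {"html_content": "<h2>答案</h2><p>...</p>"}
# The response streams back a .docx attachment named 小学数学问答.docx.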
@app.post("/api/rag")
async def rag_stream(request: Request):
    """RAG + DeepSeek endpoint: retrieve context from Milvus and stream the answer as SSE."""
    try:
        data = await request.json()
        query_request = QueryRequest(**data)
    except ValidationError as e:
        logger.error(f"Request body validation failed: {e.errors()}")
        raise HTTPException(status_code=422, detail=e.errors())
    except Exception as e:
        logger.error(f"Request parsing failed: {str(e)}")
        raise HTTPException(status_code=400, detail="Invalid request format")

    # Wrap the async generator in an SSE response instead of returning only its first chunk
    return EventSourceResponse(generate_stream(
        request.app.state.deepseek_client,
        request.app.state.milvus_pool,
        request.app.state.collection_manager,
        query_request.query
    ))


if __name__ == "__main__":
    uvicorn.run(app, host="0.0.0.0", port=8000)
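# Example client call for /api/rag (a sketch; host and port assume the uvicorn settings above,
# and the response is a Server-Sent Events stream produced by generate_stream):
#   curl -N -X POST http://localhost:8000/api/rag \
#        -H "Content-Type: application/json" \
#        -d '{"query": "小学数学中有哪些模型?"}'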