diff --git a/dsRag/Start.py b/dsRag/Start.py index 4092d5d9..d8807fbf 100644 --- a/dsRag/Start.py +++ b/dsRag/Start.py @@ -1,27 +1,26 @@ -import asyncio +import urllib.parse from contextlib import asynccontextmanager +from io import BytesIO from logging.handlers import RotatingFileHandler +import html2text import jieba # 导入 jieba 分词库 import uvicorn +from docx import Document from fastapi import FastAPI, Request, HTTPException -from pydantic import BaseModel, Field, ValidationError from fastapi.staticfiles import StaticFiles -from openai import OpenAI -from sse_starlette.sse import EventSourceResponse from gensim.models import KeyedVectors +from openai import OpenAI +from pydantic import BaseModel, Field, ValidationError from starlette.responses import StreamingResponse from Config import Config -from Config.Config import MS_MODEL_PATH, MS_MODEL_LIMIT, MS_HOST, MS_PORT, MS_MAX_CONNECTIONS, MS_NPROBE, DEEPSEEK_API_KEY, DEEPSEEK_URL, MS_COLLECTION_NAME +from Config.Config import MS_MODEL_PATH, MS_MODEL_LIMIT, MS_HOST, MS_PORT, MS_MAX_CONNECTIONS, MS_NPROBE, \ + MS_COLLECTION_NAME from Milvus.Utils.MilvusCollectionManager import MilvusCollectionManager from Milvus.Utils.MilvusConnectionPool import * from Milvus.Utils.MilvusConnectionPool import MilvusConnectionPool -from docx import Document -from docx.shared import Inches -from io import BytesIO -import html2text -import urllib.parse + # 初始化日志 logger = logging.getLogger(__name__) logger.setLevel(logging.INFO) @@ -93,7 +92,7 @@ async def generate_stream(client, milvus_pool, collection_manager, query): results = collection_manager.search(current_embedding, search_params, limit=5) # 返回 2 条结果 # 3. 处理搜索结果 - logger.info("最相关的历史对话:") + logger.info("最相关的知识库内容:") context = "" if results: for hits in results: @@ -101,17 +100,18 @@ async def generate_stream(client, milvus_pool, collection_manager, query): try: # 查询非向量字段 record = collection_manager.query_by_id(hit.id) - logger.info(f"ID: {hit.id}") - logger.info(f"标签: {record['tags']}") - logger.info(f"用户问题: {record['user_input']}") - - # 获取完整内容 - full_content = record['tags'].get('full_content', record['user_input']) - context = context + full_content - - logger.info(f"时间: {record['timestamp']}") - logger.info(f"距离: {hit.distance}") - logger.info("-" * 40) # 分隔线 + if hit.distance < 0.88: # 设置距离阈值 + logger.info(f"ID: {hit.id}") + logger.info(f"标签: {record['tags']}") + logger.info(f"用户问题: {record['user_input']}") + + logger.info(f"时间: {record['timestamp']}") + logger.info(f"距离: {hit.distance}") + logger.info("-" * 40) # 分隔线 + # 获取完整内容 + full_content = record['tags'].get('full_content', record['user_input']) + context = context + full_content + except Exception as e: logger.error(f"查询失败: {e}") else: