|
|
|
@ -1,27 +1,26 @@
|
|
|
|
|
import asyncio
|
|
|
|
|
import urllib.parse
|
|
|
|
|
from contextlib import asynccontextmanager
|
|
|
|
|
from io import BytesIO
|
|
|
|
|
from logging.handlers import RotatingFileHandler
|
|
|
|
|
|
|
|
|
|
import html2text
|
|
|
|
|
import jieba # 导入 jieba 分词库
|
|
|
|
|
import uvicorn
|
|
|
|
|
from docx import Document
|
|
|
|
|
from fastapi import FastAPI, Request, HTTPException
|
|
|
|
|
from pydantic import BaseModel, Field, ValidationError
|
|
|
|
|
from fastapi.staticfiles import StaticFiles
|
|
|
|
|
from openai import OpenAI
|
|
|
|
|
from sse_starlette.sse import EventSourceResponse
|
|
|
|
|
from gensim.models import KeyedVectors
|
|
|
|
|
from openai import OpenAI
|
|
|
|
|
from pydantic import BaseModel, Field, ValidationError
|
|
|
|
|
from starlette.responses import StreamingResponse
|
|
|
|
|
|
|
|
|
|
from Config import Config
|
|
|
|
|
from Config.Config import MS_MODEL_PATH, MS_MODEL_LIMIT, MS_HOST, MS_PORT, MS_MAX_CONNECTIONS, MS_NPROBE, DEEPSEEK_API_KEY, DEEPSEEK_URL, MS_COLLECTION_NAME
|
|
|
|
|
from Config.Config import MS_MODEL_PATH, MS_MODEL_LIMIT, MS_HOST, MS_PORT, MS_MAX_CONNECTIONS, MS_NPROBE, \
|
|
|
|
|
MS_COLLECTION_NAME
|
|
|
|
|
from Milvus.Utils.MilvusCollectionManager import MilvusCollectionManager
|
|
|
|
|
from Milvus.Utils.MilvusConnectionPool import *
|
|
|
|
|
from Milvus.Utils.MilvusConnectionPool import MilvusConnectionPool
|
|
|
|
|
from docx import Document
|
|
|
|
|
from docx.shared import Inches
|
|
|
|
|
from io import BytesIO
|
|
|
|
|
import html2text
|
|
|
|
|
import urllib.parse
|
|
|
|
|
|
|
|
|
|
# 初始化日志
|
|
|
|
|
logger = logging.getLogger(__name__)
|
|
|
|
|
logger.setLevel(logging.INFO)
|
|
|
|
@ -93,7 +92,7 @@ async def generate_stream(client, milvus_pool, collection_manager, query):
|
|
|
|
|
results = collection_manager.search(current_embedding, search_params, limit=5) # 返回 2 条结果
|
|
|
|
|
|
|
|
|
|
# 3. 处理搜索结果
|
|
|
|
|
logger.info("最相关的历史对话:")
|
|
|
|
|
logger.info("最相关的知识库内容:")
|
|
|
|
|
context = ""
|
|
|
|
|
if results:
|
|
|
|
|
for hits in results:
|
|
|
|
@ -101,17 +100,18 @@ async def generate_stream(client, milvus_pool, collection_manager, query):
|
|
|
|
|
try:
|
|
|
|
|
# 查询非向量字段
|
|
|
|
|
record = collection_manager.query_by_id(hit.id)
|
|
|
|
|
logger.info(f"ID: {hit.id}")
|
|
|
|
|
logger.info(f"标签: {record['tags']}")
|
|
|
|
|
logger.info(f"用户问题: {record['user_input']}")
|
|
|
|
|
|
|
|
|
|
# 获取完整内容
|
|
|
|
|
full_content = record['tags'].get('full_content', record['user_input'])
|
|
|
|
|
context = context + full_content
|
|
|
|
|
|
|
|
|
|
logger.info(f"时间: {record['timestamp']}")
|
|
|
|
|
logger.info(f"距离: {hit.distance}")
|
|
|
|
|
logger.info("-" * 40) # 分隔线
|
|
|
|
|
if hit.distance < 0.88: # 设置距离阈值
|
|
|
|
|
logger.info(f"ID: {hit.id}")
|
|
|
|
|
logger.info(f"标签: {record['tags']}")
|
|
|
|
|
logger.info(f"用户问题: {record['user_input']}")
|
|
|
|
|
|
|
|
|
|
logger.info(f"时间: {record['timestamp']}")
|
|
|
|
|
logger.info(f"距离: {hit.distance}")
|
|
|
|
|
logger.info("-" * 40) # 分隔线
|
|
|
|
|
# 获取完整内容
|
|
|
|
|
full_content = record['tags'].get('full_content', record['user_input'])
|
|
|
|
|
context = context + full_content
|
|
|
|
|
|
|
|
|
|
except Exception as e:
|
|
|
|
|
logger.error(f"查询失败: {e}")
|
|
|
|
|
else:
|
|
|
|
|