diff --git a/dsRag/Start.py b/dsRag/Start.py index 32e0a190..ea7fcbf5 100644 --- a/dsRag/Start.py +++ b/dsRag/Start.py @@ -5,6 +5,7 @@ import uuid from contextlib import asynccontextmanager from io import BytesIO from logging.handlers import RotatingFileHandler +from typing import List import jieba # 导入 jieba 分词库 import uvicorn @@ -83,7 +84,7 @@ def text_to_embedding(text): return [0.0] * model.vector_size -async def generate_stream(client, milvus_pool, collection_manager, query): +async def generate_stream(client, milvus_pool, collection_manager, query, documents): # 从连接池获取连接 connection = milvus_pool.get_connection() try: @@ -95,8 +96,17 @@ async def generate_stream(client, milvus_pool, collection_manager, query): "metric_type": "L2", # 使用 L2 距离度量方式 "params": {"nprobe": MS_NPROBE} # 设置 IVF_FLAT 的 nprobe 参数 } + # 动态生成expr表达式 + if documents: + conditions = [f"array_contains(tags['tags'], '{doc}')" for doc in documents] + expr = " OR ".join(conditions) + else: + expr = "" # 如果没有选择文档,返回空字符串 # 7. 将文本转换为嵌入向量 - results = collection_manager.search(current_embedding, search_params, limit=5) # 返回 5 条结果 + results = collection_manager.search(current_embedding, + search_params, + expr=expr, # 使用in操作符 + limit=5) # 返回 5 条结果 # 3. 处理搜索结果 logger.info("最相关的知识库内容:") @@ -176,6 +186,7 @@ http://10.10.21.22:8000/static/ai.html class QueryRequest(BaseModel): query: str = Field(..., description="用户查询的问题") + documents: List[str] = Field(..., description="用户上传的文档") class SaveWordRequest(BaseModel): @@ -249,7 +260,8 @@ async def rag_stream(request: Request): request.app.state.deepseek_client, request.app.state.milvus_pool, request.app.state.collection_manager, - query_request.query + query_request.query, + query_request.documents ): return chunk diff --git a/dsRag/static/ai.html b/dsRag/static/ai.html index a7736596..6b1ff779 100644 --- a/dsRag/static/ai.html +++ b/dsRag/static/ai.html @@ -182,10 +182,9 @@
请在下方输入您的问题,答案将在此处显示
您也可以点击"示例问题"快速体验