|
|
|
@ -26,27 +26,6 @@ def extract_keywords(text, topK=3):
|
|
|
|
|
keywords = jieba.analyse.extract_tags(text, topK=topK)
|
|
|
|
|
return keywords
|
|
|
|
|
|
|
|
|
|
# 构建查询条件
|
|
|
|
|
def build_query_expr(session_id, keywords):
    """
    Build the filter expression used to narrow a history search.

    The expression always restricts results to the given session. When
    usable keywords are supplied, it additionally requires ``user_input``
    to start with at least one of them (prefix ``like`` match).

    :param session_id: user session ID; interpolated as a quoted string literal
    :param keywords: iterable of keyword strings, may be empty or ``None``
    :return: filter expression string
    """
    # Base condition: restrict to the caller's session. The value comes from
    # the request, so it is escaped before being embedded in the literal.
    expr = f"session_id == '{_escape_expr_literal(session_id)}'"

    # Keyword conditions: keywords of length <= 1 are dropped as too short;
    # the rest become prefix matches joined with "or".
    keyword_conditions = [
        f"user_input like '{_escape_expr_literal(keyword)}%'"
        for keyword in (keywords or ())
        if len(keyword) > 1
    ]
    if keyword_conditions:
        expr += " and (" + " or ".join(keyword_conditions) + ")"
    return expr


def _escape_expr_literal(value):
    """Escape backslashes and single quotes so *value* stays inside one quoted literal."""
    # NOTE(review): assumes the expression parser accepts backslash escapes
    # inside single-quoted strings -- confirm against the deployed Milvus version.
    return str(value).replace("\\", "\\\\").replace("'", "\\'")
|
|
|
|
|
|
|
|
|
|
# Initialize the Word2Vec model
|
|
|
|
|
model_path = MS_MODEL_PATH
|
|
|
|
|
# Load the vectors in the text (non-binary) word2vec format; MS_MODEL_LIMIT is passed as `limit`.
model = KeyedVectors.load_word2vec_format(model_path, binary=False, limit=MS_MODEL_LIMIT)
|
|
|
|
@ -93,9 +72,6 @@ client = OpenAI(
|
|
|
|
|
base_url="https://dashscope.aliyuncs.com/compatible-mode/v1",
|
|
|
|
|
)
|
|
|
|
|
|
|
|
|
|
# Set the similarity threshold
|
|
|
|
|
SIMILARITY_THRESHOLD = 0.5 # results with a distance below 0.5 are considered highly similar
|
|
|
|
|
|
|
|
|
|
@app.post("/reply")
|
|
|
|
|
async def reply(session_id: str = Form(...), prompt: str = Form(...)):
|
|
|
|
|
"""
|
|
|
|
@ -111,15 +87,7 @@ async def reply(session_id: str = Form(...), prompt: str = Form(...)):
|
|
|
|
|
# 将用户输入转换为嵌入向量
|
|
|
|
|
current_embedding = text_to_embedding(prompt)
|
|
|
|
|
|
|
|
|
|
# 提取用户输入的关键词
|
|
|
|
|
keywords = extract_keywords(prompt)
|
|
|
|
|
print(f"提取的关键词: {keywords}")
|
|
|
|
|
|
|
|
|
|
# 构建查询条件
|
|
|
|
|
expr = build_query_expr(session_id, keywords)
|
|
|
|
|
print(f"查询条件: {expr}")
|
|
|
|
|
|
|
|
|
|
# 查询与当前对话最相关的五条历史交互
|
|
|
|
|
# 查询与当前对话最相关的历史交互
|
|
|
|
|
search_params = {
|
|
|
|
|
"metric_type": "L2", # 使用 L2 距离度量方式
|
|
|
|
|
"params": {"nprobe": MS_NPROBE} # 设置 IVF_FLAT 的 nprobe 参数
|
|
|
|
@ -128,7 +96,6 @@ async def reply(session_id: str = Form(...), prompt: str = Form(...)):
|
|
|
|
|
results = collection_manager.search(
|
|
|
|
|
data=current_embedding, # 输入向量
|
|
|
|
|
search_params=search_params, # 搜索参数
|
|
|
|
|
expr=expr, # 查询条件
|
|
|
|
|
limit=5 # 返回 5 条结果
|
|
|
|
|
)
|
|
|
|
|
end_time = time.time()
|
|
|
|
@ -139,18 +106,12 @@ async def reply(session_id: str = Form(...), prompt: str = Form(...)):
|
|
|
|
|
for hits in results:
|
|
|
|
|
for hit in hits:
|
|
|
|
|
try:
|
|
|
|
|
# 过滤低相似度的结果
|
|
|
|
|
if hit.distance > SIMILARITY_THRESHOLD:
|
|
|
|
|
print(f"跳过低相似度结果,距离: {hit.distance}")
|
|
|
|
|
continue
|
|
|
|
|
|
|
|
|
|
# 查询非向量字段
|
|
|
|
|
record = collection_manager.query_by_id(hit.id)
|
|
|
|
|
if record:
|
|
|
|
|
print(f"查询到的记录: {record}")
|
|
|
|
|
# 只添加与当前问题高度相关的历史交互
|
|
|
|
|
if any(keyword in record['user_input'] for keyword in keywords):
|
|
|
|
|
history_prompt += f"用户: {record['user_input']}\n大模型: {record['model_response']}\n"
|
|
|
|
|
# 添加历史交互
|
|
|
|
|
history_prompt += f"用户: {record['user_input']}\n大模型: {record['model_response']}\n"
|
|
|
|
|
except Exception as e:
|
|
|
|
|
print(f"查询失败: {e}")
|
|
|
|
|
|
|
|
|
|