main
HuangHai 4 months ago
parent 0f31b5597a
commit fd10f458a8

@ -26,27 +26,6 @@ def extract_keywords(text, topK=3):
keywords = jieba.analyse.extract_tags(text, topK=topK)
return keywords
# 构建查询条件
def build_query_expr(session_id, keywords):
    """
    Build the filter expression used when searching stored interactions.

    The expression always restricts results to the given session. If any
    usable keywords are supplied, it additionally requires ``user_input``
    to start with at least one of them (prefix match, ``like 'kw%'``).

    Values are escaped before being embedded in the single-quoted string
    literals, so a quote or backslash inside ``session_id`` or a keyword
    can neither break the expression nor inject extra conditions.

    :param session_id: user session ID
    :param keywords: list of keywords; may be empty or None
    :return: filter expression string
    """
    def _quote(value):
        # Escape backslashes first so the backslashes added for quotes
        # are not escaped a second time.
        return str(value).replace("\\", "\\\\").replace("'", "\\'")

    # Base condition: filter by session_id
    expr = f"session_id == '{_quote(session_id)}'"

    # Keyword conditions: keywords of length <= 1 are too short to be useful
    keyword_conditions = [
        f"user_input like '{_quote(keyword)}%'"
        for keyword in (keywords or [])
        if len(keyword) > 1
    ]
    if keyword_conditions:
        expr += " and (" + " or ".join(keyword_conditions) + ")"
    return expr
# Initialize the Word2Vec model.
# The vectors are loaded from MS_MODEL_PATH in non-binary format
# (binary=False), with MS_MODEL_LIMIT passed as the load limit.
model_path = MS_MODEL_PATH
model = KeyedVectors.load_word2vec_format(model_path, binary=False, limit=MS_MODEL_LIMIT)
@ -93,9 +72,6 @@ client = OpenAI(
base_url="https://dashscope.aliyuncs.com/compatible-mode/v1",
)
# Similarity threshold
SIMILARITY_THRESHOLD = 0.5 # results with a distance below 0.5 are considered highly similar
@app.post("/reply")
async def reply(session_id: str = Form(...), prompt: str = Form(...)):
"""
@ -111,15 +87,7 @@ async def reply(session_id: str = Form(...), prompt: str = Form(...)):
# 将用户输入转换为嵌入向量
current_embedding = text_to_embedding(prompt)
# 提取用户输入的关键词
keywords = extract_keywords(prompt)
print(f"提取的关键词: {keywords}")
# 构建查询条件
expr = build_query_expr(session_id, keywords)
print(f"查询条件: {expr}")
# 查询与当前对话最相关的五条历史交互
# 查询与当前对话最相关的历史交互
search_params = {
"metric_type": "L2", # 使用 L2 距离度量方式
"params": {"nprobe": MS_NPROBE} # 设置 IVF_FLAT 的 nprobe 参数
@ -128,7 +96,6 @@ async def reply(session_id: str = Form(...), prompt: str = Form(...)):
results = collection_manager.search(
data=current_embedding, # 输入向量
search_params=search_params, # 搜索参数
expr=expr, # 查询条件
limit=5 # 返回 5 条结果
)
end_time = time.time()
@ -139,18 +106,12 @@ async def reply(session_id: str = Form(...), prompt: str = Form(...)):
for hits in results:
for hit in hits:
try:
# 过滤低相似度的结果
if hit.distance > SIMILARITY_THRESHOLD:
print(f"跳过低相似度结果,距离: {hit.distance}")
continue
# 查询非向量字段
record = collection_manager.query_by_id(hit.id)
if record:
print(f"查询到的记录: {record}")
# 只添加与当前问题高度相关的历史交互
if any(keyword in record['user_input'] for keyword in keywords):
history_prompt += f"用户: {record['user_input']}\n大模型: {record['model_response']}\n"
# 添加历史交互
history_prompt += f"用户: {record['user_input']}\n大模型: {record['model_response']}\n"
except Exception as e:
print(f"查询失败: {e}")

Loading…
Cancel
Save