You cannot select more than 25 topics. Topics must start with a letter or number, can include dashes ('-'), and can be up to 35 characters long.

71 lines
2.6 KiB

This file contains ambiguous Unicode characters!

This file contains ambiguous Unicode characters that may be confused with others in your current locale. If your use case is intentional and legitimate, you can safely ignore this warning. Use the Escape button to highlight these characters.

# pip install gensim
import time
from WxMini.Milvus.Utils.MilvusCollectionManager import MilvusCollectionManager
from WxMini.Milvus.Utils.MilvusConnectionPool import *
from WxMini.Milvus.Config.MulvusConfig import *
from gensim.models import KeyedVectors
# Load the pre-trained Word2Vec model.
# Reference on working with word-vector models through gensim's KeyedVectors:
# https://www.cnblogs.com/bill-h/p/14655224.html
model_path = "D:/Tencent_AILab_ChineseEmbedding/Tencent_AILab_ChineseEmbedding.txt"  # replace with the path to your Word2Vec model
# Read the text-format vector file, limited to its first 10000 words.
model = KeyedVectors.load_word2vec_format(
    model_path,
    binary=False,
    limit=10000,
)
# 将文本转换为嵌入向量
def text_to_embedding(text):
    """Convert a piece of text into a single embedding vector.

    The text is split on whitespace, and the vectors of every token found in
    the module-level Word2Vec ``model`` are averaged.

    Args:
        text: Input string whose tokens are separated by whitespace.

    Returns:
        A vector of length ``model.vector_size``: the mean of the known
        tokens' vectors, or an all-zero float32 numpy vector when none of the
        tokens is in the model's vocabulary.
    """
    import numpy as np  # local import so the block brings its own dependency

    # NOTE(review): str.split() separates on whitespace only, so an
    # unsegmented Chinese sentence becomes one token and will usually miss
    # the vocabulary. Confirm whether callers pre-segment their input.
    vectors = [model[word] for word in text.split() if word in model]
    if not vectors:
        # Previously a plain Python list was returned here while the averaged
        # path returns a numpy vector; use one type for both paths so callers
        # can treat the result uniformly.
        return np.zeros(model.vector_size, dtype=np.float32)
    return sum(vectors) / len(vectors)  # mean of the word vectors
# Interactive similarity search: read one sentence, embed it, query the
# collection for the closest stored entries and print them with the elapsed
# search time. The pool and the borrowed connection are cleaned up in
# ``finally`` blocks so that a failure in any step cannot leak them.

# 1. Manage Milvus connections through a connection pool.
milvus_pool = MilvusConnectionPool(host=MS_HOST, port=MS_PORT, max_connections=MS_MAX_CONNECTIONS)
try:
    # 2. Borrow a connection from the pool.
    connection = milvus_pool.get_connection()
    try:
        # 3. Initialise the collection manager.
        collection_name = MS_COLLECTION_NAME
        collection_manager = MilvusCollectionManager(collection_name)
        # 4. Load the collection into memory.
        collection_manager.load_collection()
        # 5. Read one sentence from the user (e.g. "I'm not in a great mood today").
        input_text = input("请输入一句话:")
        # 6. Convert the text into an embedding vector.
        current_embedding = text_to_embedding(input_text)
        # 7. Search for the stored conversations closest to the input.
        search_params = {
            "metric_type": "L2",  # use the L2 distance metric
            "params": {"nprobe": 100},  # nprobe parameter for IVF_FLAT
        }
        start_time = time.time()
        results = collection_manager.search(current_embedding, search_params, limit=5)  # return 5 results
        end_time = time.time()
        # 8. Print the search results.
        print("最相关的历史对话:")
        if results:
            for hits in results:
                for hit in hits:
                    # Best effort per hit: a failed lookup is reported and
                    # the remaining hits are still printed.
                    try:
                        text = collection_manager.query_text_by_id(hit.id)
                        print(f"- {text} (距离: {hit.distance})")
                    except Exception as e:
                        print(f"查询失败: {e}")
        else:
            print("未找到相关历史对话,请检查查询参数或数据。")
        # 9. Print how long the search took.
        print(f"查询耗时: {end_time - start_time:.4f}")
    finally:
        # 10. Always hand the connection back to the pool.
        milvus_pool.release_connection(connection)
finally:
    # 11. Always close the connection pool.
    milvus_pool.close()