|
|
|
@ -0,0 +1,94 @@
|
|
|
|
|
from pymilvus import connections, FieldSchema, CollectionSchema, DataType, Collection, utility
|
|
|
|
|
import numpy as np
|
|
|
|
|
import time
|
|
|
|
|
|
|
|
|
|
# 1. 连接 Milvus
|
|
|
|
|
connections.connect("default", host="10.10.14.101", port="19530")
|
|
|
|
|
|
|
|
|
|
# 2. 定义集合的字段和模式
|
|
|
|
|
fields = [
|
|
|
|
|
FieldSchema(name="id", dtype=DataType.INT64, is_primary=True, auto_id=True),
|
|
|
|
|
FieldSchema(name="text", dtype=DataType.VARCHAR, max_length=500), # 存储对话文本
|
|
|
|
|
FieldSchema(name="embedding", dtype=DataType.FLOAT_VECTOR, dim=128) # 假设嵌入向量维度为 128
|
|
|
|
|
]
|
|
|
|
|
schema = CollectionSchema(fields, description="Conversation history collection")
|
|
|
|
|
|
|
|
|
|
# 3. 创建集合
|
|
|
|
|
collection_name = "conversation_history"
|
|
|
|
|
if utility.has_collection(collection_name):
|
|
|
|
|
utility.drop_collection(collection_name) # 如果集合已存在,先删除
|
|
|
|
|
|
|
|
|
|
collection = Collection(name=collection_name, schema=schema)
|
|
|
|
|
|
|
|
|
|
# 4. 创建索引
|
|
|
|
|
index_params = {
|
|
|
|
|
"index_type": "IVF_FLAT", # 索引类型
|
|
|
|
|
"metric_type": "L2", # 距离度量方式
|
|
|
|
|
"params": {"nlist": 256} # 增加 nlist 的值
|
|
|
|
|
}
|
|
|
|
|
collection.create_index("embedding", index_params)
|
|
|
|
|
|
|
|
|
|
# 5. 加载集合
|
|
|
|
|
collection.load()
|
|
|
|
|
|
|
|
|
|
# 6. 模拟多次对话
|
|
|
|
|
# 假设历史对话的嵌入向量已经生成
|
|
|
|
|
history = [
|
|
|
|
|
{"text": "我今天心情不太好,因为工作压力很大。", "embedding": np.random.random(128).tolist()},
|
|
|
|
|
{"text": "我最近在学习 Python,感觉很有趣。", "embedding": np.random.random(128).tolist()},
|
|
|
|
|
{"text": "我打算周末去爬山,放松一下。", "embedding": np.random.random(128).tolist()},
|
|
|
|
|
{"text": "我昨天看了一部很棒的电影,推荐给你。", "embedding": np.random.random(128).tolist()}
|
|
|
|
|
]
|
|
|
|
|
|
|
|
|
|
# 将历史对话插入 Milvus
|
|
|
|
|
for item in history:
|
|
|
|
|
insert_result = collection.insert([[item["text"]], [item["embedding"]]])
|
|
|
|
|
print(f"插入结果: {insert_result}")
|
|
|
|
|
|
|
|
|
|
# 刷新集合
|
|
|
|
|
collection.flush()
|
|
|
|
|
|
|
|
|
|
# 调试:查询集合中的所有数据
|
|
|
|
|
print("集合中的数据:")
|
|
|
|
|
result = collection.query(expr="", output_fields=["id", "text"], limit=10)
|
|
|
|
|
if result:
|
|
|
|
|
for item in result:
|
|
|
|
|
print(f"ID: {item['id']}, Text: {item['text']}")
|
|
|
|
|
else:
|
|
|
|
|
print("集合中没有数据,请检查数据插入步骤。")
|
|
|
|
|
|
|
|
|
|
# 7. 模拟当前对话
|
|
|
|
|
# 假设当前对话的嵌入向量已经生成
|
|
|
|
|
current_text = "我最近工作压力很大,想找个方式放松一下。"
|
|
|
|
|
current_embedding = np.random.random(128).tolist()
|
|
|
|
|
|
|
|
|
|
# 8. 查询与当前对话最相关的历史对话
|
|
|
|
|
search_params = {
|
|
|
|
|
"metric_type": "L2",
|
|
|
|
|
"params": {"nprobe": 100} # 增加 nprobe 的值
|
|
|
|
|
}
|
|
|
|
|
start_time = time.time()
|
|
|
|
|
results = collection.search(
|
|
|
|
|
data=[current_embedding], # 查询向量
|
|
|
|
|
anns_field="embedding", # 查询字段
|
|
|
|
|
param=search_params,
|
|
|
|
|
limit=2 # 返回最相似的 2 个结果
|
|
|
|
|
)
|
|
|
|
|
end_time = time.time()
|
|
|
|
|
|
|
|
|
|
# 9. 输出查询结果
|
|
|
|
|
print("当前对话:", current_text)
|
|
|
|
|
print("最相关的历史对话:")
|
|
|
|
|
if results:
|
|
|
|
|
for hits in results:
|
|
|
|
|
for hit in hits:
|
|
|
|
|
text = collection.query(expr=f"id == {hit.id}", output_fields=["text"])[0]["text"]
|
|
|
|
|
print(f"- {text} (距离: {hit.distance})")
|
|
|
|
|
else:
|
|
|
|
|
print("未找到相关历史对话,请检查查询参数或数据。")
|
|
|
|
|
|
|
|
|
|
# 输出查询耗时
|
|
|
|
|
print(f"查询耗时: {end_time - start_time:.4f} 秒")
|
|
|
|
|
|
|
|
|
|
# 10. 关闭连接
|
|
|
|
|
connections.disconnect("default")
|