main
HuangHai 4 months ago
parent 7932179e71
commit 99fa38f1fe

@ -0,0 +1,94 @@
from pymilvus import connections, FieldSchema, CollectionSchema, DataType, Collection, utility
import numpy as np
import time
# 1. 连接 Milvus
connections.connect("default", host="10.10.14.101", port="19530")
# 2. 定义集合的字段和模式
fields = [
FieldSchema(name="id", dtype=DataType.INT64, is_primary=True, auto_id=True),
FieldSchema(name="text", dtype=DataType.VARCHAR, max_length=500), # 存储对话文本
FieldSchema(name="embedding", dtype=DataType.FLOAT_VECTOR, dim=128) # 假设嵌入向量维度为 128
]
schema = CollectionSchema(fields, description="Conversation history collection")
# 3. 创建集合
collection_name = "conversation_history"
if utility.has_collection(collection_name):
utility.drop_collection(collection_name) # 如果集合已存在,先删除
collection = Collection(name=collection_name, schema=schema)
# 4. 创建索引
index_params = {
"index_type": "IVF_FLAT", # 索引类型
"metric_type": "L2", # 距离度量方式
"params": {"nlist": 256} # 增加 nlist 的值
}
collection.create_index("embedding", index_params)
# 5. 加载集合
collection.load()
# 6. 模拟多次对话
# 假设历史对话的嵌入向量已经生成
history = [
{"text": "我今天心情不太好,因为工作压力很大。", "embedding": np.random.random(128).tolist()},
{"text": "我最近在学习 Python感觉很有趣。", "embedding": np.random.random(128).tolist()},
{"text": "我打算周末去爬山,放松一下。", "embedding": np.random.random(128).tolist()},
{"text": "我昨天看了一部很棒的电影,推荐给你。", "embedding": np.random.random(128).tolist()}
]
# 将历史对话插入 Milvus
for item in history:
insert_result = collection.insert([[item["text"]], [item["embedding"]]])
print(f"插入结果: {insert_result}")
# 刷新集合
collection.flush()
# 调试:查询集合中的所有数据
print("集合中的数据:")
result = collection.query(expr="", output_fields=["id", "text"], limit=10)
if result:
for item in result:
print(f"ID: {item['id']}, Text: {item['text']}")
else:
print("集合中没有数据,请检查数据插入步骤。")
# 7. 模拟当前对话
# 假设当前对话的嵌入向量已经生成
current_text = "我最近工作压力很大,想找个方式放松一下。"
current_embedding = np.random.random(128).tolist()
# 8. 查询与当前对话最相关的历史对话
search_params = {
"metric_type": "L2",
"params": {"nprobe": 100} # 增加 nprobe 的值
}
start_time = time.time()
results = collection.search(
data=[current_embedding], # 查询向量
anns_field="embedding", # 查询字段
param=search_params,
limit=2 # 返回最相似的 2 个结果
)
end_time = time.time()
# 9. 输出查询结果
print("当前对话:", current_text)
print("最相关的历史对话:")
if results:
for hits in results:
for hit in hits:
text = collection.query(expr=f"id == {hit.id}", output_fields=["text"])[0]["text"]
print(f"- {text} (距离: {hit.distance})")
else:
print("未找到相关历史对话,请检查查询参数或数据。")
# 输出查询耗时
print(f"查询耗时: {end_time - start_time:.4f}")
# 10. 关闭连接
connections.disconnect("default")
Loading…
Cancel
Save