You cannot select more than 25 topics. Topics must start with a letter or number, can include dashes ('-'), and can be up to 35 characters long.

94 lines
3.3 KiB

This file contains ambiguous Unicode characters!

This file contains ambiguous Unicode characters that may be confused with others in your current locale. If your use case is intentional and legitimate, you can safely ignore this warning. Use the Escape button to highlight these characters.

from pymilvus import connections, FieldSchema, CollectionSchema, DataType, Collection, utility
import numpy as np
import time
# 1. Connect to the Milvus server and register the connection as "default".
connections.connect(alias="default", host="10.10.14.101", port="19530")

# 2. Describe the collection: an auto-generated INT64 primary key, the
#    conversation text, and the embedding vector of that text.
id_field = FieldSchema(name="id", dtype=DataType.INT64, is_primary=True, auto_id=True)
text_field = FieldSchema(name="text", dtype=DataType.VARCHAR, max_length=500)  # conversation text
embedding_field = FieldSchema(name="embedding", dtype=DataType.FLOAT_VECTOR, dim=128)  # embedding dimension assumed to be 128
fields = [id_field, text_field, embedding_field]
schema = CollectionSchema(fields, description="Conversation history collection")
# 3. Create the collection, starting from a clean slate.
collection_name = "conversation_history"
if utility.has_collection(collection_name):
    # A collection with this name already exists: drop it first so the new
    # schema applies. This discards whatever the old collection stored.
    utility.drop_collection(collection_name)
# This script reads (query/search) right after it writes. Strong consistency
# makes those reads observe the freshly inserted rows; with the default level
# they may be served from a slightly stale view.
collection = Collection(
    name=collection_name,
    schema=schema,
    consistency_level="Strong",
)
# 4. Build an index on the vector field.
index_params = dict(
    index_type="IVF_FLAT",  # index type
    metric_type="L2",  # distance metric
    params=dict(nlist=256),  # nlist value raised (original author's tuning)
)
collection.create_index(field_name="embedding", index_params=index_params)

# 5. Load the collection.
collection.load()
# 6. Simulate several earlier conversation turns.
# The embeddings are random placeholders standing in for vectors that a real
# embedding model would have produced.
history = [
    {"text": "我今天心情不太好,因为工作压力很大。", "embedding": np.random.random(128).tolist()},
    {"text": "我最近在学习 Python感觉很有趣。", "embedding": np.random.random(128).tolist()},
    {"text": "我打算周末去爬山,放松一下。", "embedding": np.random.random(128).tolist()},
    {"text": "我昨天看了一部很棒的电影,推荐给你。", "embedding": np.random.random(128).tolist()},
]
# Insert the whole history in a single column-based call instead of one call
# per row. Columns follow the schema order of the non-auto fields: text first,
# then embedding (the primary key is generated automatically).
texts = [item["text"] for item in history]
embeddings = [item["embedding"] for item in history]
insert_result = collection.insert([texts, embeddings])
print(f"插入结果: {insert_result}")
# Flush the collection after inserting.
collection.flush()
# Debug: list (up to 10 of) the rows currently stored in the collection.
print("集合中的数据:")
result = collection.query(expr="", output_fields=["id", "text"], limit=10)
if result:
    for row in result:
        print(f"ID: {row['id']}, Text: {row['text']}")
else:
    # Nothing came back: point the user at the insert step.
    print("集合中没有数据,请检查数据插入步骤。")
# 7. Simulate the current conversation turn.
# The embedding is a random placeholder standing in for a vector that a real
# embedding model would have produced.
current_text = "我最近工作压力很大,想找个方式放松一下。"
current_embedding = np.random.random(128).tolist()
# 8. Search for the history entries most similar to the current turn.
search_params = {
    "metric_type": "L2",
    "params": {"nprobe": 100},  # nprobe value raised (original author's tuning)
}
# perf_counter is the clock intended for measuring short intervals; time.time
# can jump if the system clock is adjusted.
start_time = time.perf_counter()
results = collection.search(
    data=[current_embedding],  # query vector
    anns_field="embedding",  # vector field to search
    param=search_params,
    limit=2,  # return the 2 most similar results
    output_fields=["text"],  # return the text with each hit (no per-hit lookup needed)
)
end_time = time.perf_counter()
# 9. Print the search results.
print("当前对话:", current_text)
print("最相关的历史对話:".replace("話", "话"))
found_any = False
for hits in results:  # one hit list per query vector
    for hit in hits:
        found_any = True
        text = hit.entity.get("text")
        print(f"- {text} (距离: {hit.distance})")
if not found_any:
    # Decided per hit rather than by testing `results` itself: `results` holds
    # one hit list per query vector, and that list can be empty.
    print("未找到相关历史对话,请检查查询参数或数据。")
# Report how long the search call took, in seconds.
print(f"查询耗时: {end_time - start_time:.4f}")
# 10. Close the connection registered under the "default" alias.
connections.disconnect(alias="default")