from pymilvus import connections, FieldSchema, CollectionSchema, DataType, Collection, utility import numpy as np import time # 1. 连接 Milvus connections.connect("default", host="10.10.14.101", port="19530") # 2. 定义集合的字段和模式 fields = [ FieldSchema(name="id", dtype=DataType.INT64, is_primary=True, auto_id=True), FieldSchema(name="text", dtype=DataType.VARCHAR, max_length=500), # 存储对话文本 FieldSchema(name="embedding", dtype=DataType.FLOAT_VECTOR, dim=128) # 假设嵌入向量维度为 128 ] schema = CollectionSchema(fields, description="Conversation history collection") # 3. 创建集合 collection_name = "conversation_history" if utility.has_collection(collection_name): utility.drop_collection(collection_name) # 如果集合已存在,先删除 collection = Collection(name=collection_name, schema=schema) # 4. 创建索引 index_params = { "index_type": "IVF_FLAT", # 索引类型 "metric_type": "L2", # 距离度量方式 "params": {"nlist": 256} # 增加 nlist 的值 } collection.create_index("embedding", index_params) # 5. 加载集合 collection.load() # 6. 模拟多次对话 # 假设历史对话的嵌入向量已经生成 history = [ {"text": "我今天心情不太好,因为工作压力很大。", "embedding": np.random.random(128).tolist()}, {"text": "我最近在学习 Python,感觉很有趣。", "embedding": np.random.random(128).tolist()}, {"text": "我打算周末去爬山,放松一下。", "embedding": np.random.random(128).tolist()}, {"text": "我昨天看了一部很棒的电影,推荐给你。", "embedding": np.random.random(128).tolist()} ] # 将历史对话插入 Milvus for item in history: insert_result = collection.insert([[item["text"]], [item["embedding"]]]) print(f"插入结果: {insert_result}") # 刷新集合 collection.flush() # 调试:查询集合中的所有数据 print("集合中的数据:") result = collection.query(expr="", output_fields=["id", "text"], limit=10) if result: for item in result: print(f"ID: {item['id']}, Text: {item['text']}") else: print("集合中没有数据,请检查数据插入步骤。") # 7. 模拟当前对话 # 假设当前对话的嵌入向量已经生成 current_text = "我最近工作压力很大,想找个方式放松一下。" current_embedding = np.random.random(128).tolist() # 8. 查询与当前对话最相关的历史对话 search_params = { "metric_type": "L2", "params": {"nprobe": 100} # 增加 nprobe 的值 } start_time = time.time() results = collection.search( data=[current_embedding], # 查询向量 anns_field="embedding", # 查询字段 param=search_params, limit=2 # 返回最相似的 2 个结果 ) end_time = time.time() # 9. 输出查询结果 print("当前对话:", current_text) print("最相关的历史对话:") if results: for hits in results: for hit in hits: text = collection.query(expr=f"id == {hit.id}", output_fields=["text"])[0]["text"] print(f"- {text} (距离: {hit.distance})") else: print("未找到相关历史对话,请检查查询参数或数据。") # 输出查询耗时 print(f"查询耗时: {end_time - start_time:.4f} 秒") # 10. 关闭连接 connections.disconnect("default")