You can not select more than 25 topics
Topics must start with a letter or number, can include dashes ('-') and can be up to 35 characters long.
127 lines
4.5 KiB
127 lines
4.5 KiB
from pymilvus import Collection, utility, CollectionSchema
|
|
|
|
|
|
class MilvusCollectionManager:
    """Convenience wrapper around one named Milvus collection.

    The manager stores the collection name and, when the collection already
    exists on the server, a ``Collection`` handle in ``self.collection``.
    When the collection does not exist, ``self.collection`` stays ``None``
    until :meth:`create_collection` is called.
    """

    def __init__(self, collection_name):
        """Bind the manager to *collection_name* and try to attach to it.

        :param collection_name: name of the Milvus collection to manage
        """
        self.collection_name = collection_name
        self.collection = None
        self._load_collection_if_exists()

    def _load_collection_if_exists(self):
        """Attach to the collection if the server reports that it exists.

        When it does not exist a notice is printed and ``self.collection``
        is left unchanged.
        """
        if utility.has_collection(self.collection_name):
            self.collection = Collection(name=self.collection_name)
        else:
            print(f"集合 '{self.collection_name}' 不存在。")

    def _require_collection(self):
        """Return the attached collection handle or fail with one message.

        Every public operation goes through this helper so that an
        unattached manager is always reported the same way.

        :return: the object stored in ``self.collection``
        :raises Exception: if no collection is attached
        """
        if self.collection is None:
            raise Exception("集合未加载,请检查集合是否存在。")
        return self.collection

    def create_collection(self, fields, schema_description):
        """Create the collection, replacing any existing one of the same name.

        WARNING: an existing collection with this name is dropped first, so
        its contents are discarded.

        :param fields: list of field definitions passed to ``CollectionSchema``
        :param schema_description: description text for the schema
        """
        if utility.has_collection(self.collection_name):
            # Start from a clean slate: remove the previous collection.
            utility.drop_collection(self.collection_name)
        schema = CollectionSchema(fields, description=schema_description)
        self.collection = Collection(name=self.collection_name, schema=schema)
        print(f"集合 '{self.collection_name}' 创建成功。")

    def create_index(self, field_name, index_params):
        """Create an index on one field of the attached collection.

        :param field_name: name of the field to index
        :param index_params: index parameters forwarded to ``create_index``
        :raises Exception: if no collection is attached
        """
        self._require_collection().create_index(field_name, index_params)
        print("索引创建成功。")

    def insert_data(self, entities):
        """Insert entities into the attached collection.

        :param entities: data to insert; the original author documents the
            layout as ``[texts, embeddings]``
        :raises Exception: if no collection is attached
        """
        self._require_collection().insert(entities)

    def load_collection(self):
        """Call ``load()`` on the attached collection.

        :raises Exception: if no collection is attached
        """
        self._require_collection().load()

    def query_by_id(self, id):
        """Fetch the scalar fields of the record whose ``id`` equals *id*.

        The requested fields are fixed: ``id``, ``tags``, ``user_input`` and
        ``timestamp``. This is a best-effort call: any failure, including a
        missing collection, is printed and reported as ``None``.

        :param id: value interpolated into the ``id == ...`` filter
        :return: the first matching record, or ``None`` if nothing matched
            or the query failed
        """
        try:
            # The guard sits inside the try so an unattached manager yields
            # the clear "not loaded" message while still returning None.
            rows = self._require_collection().query(
                expr=f"id == {id}",
                output_fields=["id", "tags", "user_input", "timestamp"],
            )
            return rows[0] if rows else None
        except Exception as e:
            print(f"查询失败: {e}")
            return None

    def search(self, data, search_params, expr=None, limit=5):
        """Run a vector search against the ``embedding`` field.

        This is a best-effort call: any failure, including a missing
        collection, is printed and reported as ``None``.

        :param data: a single query vector; it is wrapped in a one-element
            list before being sent
        :param search_params: search parameters forwarded as ``param``
        :param expr: optional filter expression
        :param limit: maximum number of results to request
        :return: whatever ``Collection.search`` returns, or ``None`` on failure
        """
        try:
            # Same guard placement as query_by_id: report, do not raise.
            return self._require_collection().search(
                data=[data],
                anns_field="embedding",
                param=search_params,
                limit=limit,
                expr=expr,
            )
        except Exception as e:
            print(f"搜索失败: {e}")
            return None

    def query_text_by_id(self, id):
        """Return the ``text`` field of the record whose ``id`` equals *id*.

        Unlike :meth:`query_by_id`, errors are not swallowed here.

        :param id: value interpolated into the ``id == ...`` filter
        :return: the ``text`` value of the first match, or ``None`` if
            nothing matched
        :raises Exception: if no collection is attached, or if the
            collection schema has no ``text`` field
        """
        collection = self._require_collection()

        # Refuse early when the schema cannot possibly satisfy the query.
        field_names = [field.name for field in collection.schema.fields]
        if "text" not in field_names:
            raise Exception(f"集合 '{self.collection_name}' 中不存在 'text' 字段,请检查集合定义。")

        result = collection.query(expr=f"id == {id}", output_fields=["text"])
        return result[0]["text"] if result else None
|