You cannot select more than 25 topics
Topics must start with a letter or number, can include dashes ('-') and can be up to 35 characters long.
82 lines
2.9 KiB
82 lines
2.9 KiB
# MilvusCollectionManager.py
|
|
from pymilvus import Collection, utility, CollectionSchema
|
|
|
|
|
|
class MilvusCollectionManager:
    """Convenience wrapper around a single named pymilvus ``Collection``.

    The manager keeps the collection name and, once the collection exists,
    a ``Collection`` handle in ``self.collection``. Every data operation
    (insert, index, load, search) requires that handle and raises
    ``RuntimeError`` when it is missing.

    NOTE(review): the ``utility`` calls presumably need an already
    established Milvus connection; this class does not open one itself.
    """

    def __init__(self, collection_name):
        """Remember the collection name and attach to it if it already exists.

        :param collection_name: name of the Milvus collection to manage
        """
        self.collection_name = collection_name
        # Stays None until the collection is found or created.
        self.collection = None
        self._load_collection_if_exists()

    def _load_collection_if_exists(self):
        """Bind ``self.collection`` when the named collection already exists.

        Only a handle is created here; loading the data into memory is done
        separately by :meth:`load_collection`.
        """
        if utility.has_collection(self.collection_name):
            self.collection = Collection(name=self.collection_name)
            print(f"集合 '{self.collection_name}' 已加载。")
        else:
            print(f"集合 '{self.collection_name}' 不存在。")

    def _require_collection(self):
        """Return the bound collection handle or fail with a clear error.

        :return: the ``Collection`` handle stored in ``self.collection``
        :raises RuntimeError: if no collection has been loaded or created yet
        """
        if self.collection is None:
            raise RuntimeError("集合未创建,请先调用 create_collection 方法")
        return self.collection

    def create_collection(self, fields, schema_description, *, drop_existing=True):
        """Create the collection from the given fields and bind to it.

        :param fields: list of field schemas passed to ``CollectionSchema``
        :param schema_description: human-readable description of the schema
        :param drop_existing: when True (the default, matching the historical
            behaviour) an existing collection with the same name is DROPPED
            first, destroying its data. Pass False to keep it; pymilvus is
            then expected to open the existing collection and reject a
            mismatching schema (verify against the pymilvus version in use).
        """
        if drop_existing and utility.has_collection(self.collection_name):
            # Destructive: removes the existing collection and all its data.
            utility.drop_collection(self.collection_name)
        schema = CollectionSchema(fields, description=schema_description)
        self.collection = Collection(name=self.collection_name, schema=schema)
        print(f"集合 '{self.collection_name}' 创建成功。")

    def insert_data(self, entities):
        """Insert entities into the collection.

        :param entities: data to insert, forwarded unchanged to
            ``Collection.insert``
        :return: whatever ``Collection.insert`` returns (previously discarded)
        :raises RuntimeError: if the collection has not been created/loaded
        """
        result = self._require_collection().insert(entities)
        print("数据插入成功。")
        return result

    def create_index(self, field_name, index_params):
        """Create an index on one field of the collection.

        :param field_name: name of the field to index
        :param index_params: index parameters forwarded to
            ``Collection.create_index``
        :raises RuntimeError: if the collection has not been created/loaded
        """
        self._require_collection().create_index(field_name, index_params)
        print("索引创建成功。")

    def load_collection(self):
        """Load the collection into memory via ``Collection.load``.

        :raises RuntimeError: if the collection has not been created/loaded
        """
        self._require_collection().load()
        print("集合已加载到内存。")

    def search(self, query_vector, search_params, limit=2, anns_field="embedding"):
        """Run a vector search for a single query vector.

        :param query_vector: the query vector; it is wrapped in a one-element
            list before being passed as ``data``
        :param search_params: search parameters forwarded as ``param``
        :param limit: maximum number of results to return
        :param anns_field: name of the vector field to search; defaults to
            ``"embedding"``, the value that used to be hard-coded
        :return: the result of ``Collection.search``
        :raises RuntimeError: if the collection has not been created/loaded
        """
        return self._require_collection().search(
            data=[query_vector],
            anns_field=anns_field,
            param=search_params,
            limit=limit,
        )