You can not select more than 25 topics Topics must start with a letter or number, can include dashes ('-') and can be up to 35 characters long.

101 lines
3.6 KiB

from pymilvus import Collection, utility, CollectionSchema
class MilvusCollectionManager:
    """Convenience wrapper around a single pymilvus ``Collection``.

    The manager looks the collection up by name on construction and exposes
    helpers for (re)creating it, building an index, inserting data, loading it
    into memory, vector search, and fetching a stored ``text`` value by ``id``.
    Every operation except ``create_collection`` requires the collection to be
    bound; otherwise a ``RuntimeError`` (a subclass of ``Exception``) is raised.

    NOTE(review): assumes a pymilvus connection has already been established
    before this class is instantiated — confirm at the call site.
    """

    def __init__(self, collection_name):
        """Create a manager for ``collection_name`` and bind it if it exists.

        :param collection_name: Name of the Milvus collection to manage.
        """
        self.collection_name = collection_name
        # Stays ``None`` until the collection is found on the server or created.
        self.collection = None
        self._load_collection_if_exists()

    def _load_collection_if_exists(self):
        """Bind ``self.collection`` if the server already has this collection.

        This only obtains a ``Collection`` handle; loading the data into memory
        is a separate step (``load_collection``).
        """
        if utility.has_collection(self.collection_name):
            self.collection = Collection(name=self.collection_name)
            print(f"集合 '{self.collection_name}' 已加载。")
        else:
            print(f"集合 '{self.collection_name}' 不存在。")

    def _require_collection(self):
        """Return the bound collection, raising if none is bound.

        Shared guard for every operation that needs an existing collection.

        :return: The bound pymilvus ``Collection``.
        :raises RuntimeError: If the collection was never found or created.
        """
        if self.collection is None:
            raise RuntimeError("集合未加载,请检查集合是否存在。")
        return self.collection

    @staticmethod
    def _quote_string(value):
        """Return ``value`` as a double-quoted filter-expression string literal.

        Backslashes and double quotes are escaped so the value cannot terminate
        the literal early or alter the surrounding expression.
        """
        escaped = str(value).replace("\\", "\\\\").replace('"', '\\"')
        return f'"{escaped}"'

    def create_collection(self, fields, schema_description):
        """(Re)create the collection from ``fields``.

        WARNING: if a collection with this name already exists it is dropped
        first, discarding all of its data.

        :param fields: Field definitions passed to ``CollectionSchema``.
        :param schema_description: Human-readable description of the schema.
        """
        if utility.has_collection(self.collection_name):
            # Drop the existing collection so the new schema takes effect.
            utility.drop_collection(self.collection_name)
        schema = CollectionSchema(fields, description=schema_description)
        self.collection = Collection(name=self.collection_name, schema=schema)
        print(f"集合 '{self.collection_name}' 创建成功。")

    def create_index(self, field_name, index_params):
        """Build an index on ``field_name``.

        :param field_name: Name of the field to index.
        :param index_params: Index parameters forwarded to ``Collection.create_index``.
        :raises RuntimeError: If no collection is bound.
        """
        self._require_collection().create_index(field_name, index_params)
        print("索引创建成功。")

    def insert_data(self, entities):
        """Insert ``entities`` into the collection.

        :param entities: Data forwarded to ``Collection.insert``; per the
            original contract, column-ordered as ``[texts, embeddings]``.
        :raises RuntimeError: If no collection is bound.
        """
        self._require_collection().insert(entities)
        print("数据插入成功。")

    def load_collection(self):
        """Load the collection into memory so it can be searched or queried.

        :raises RuntimeError: If no collection is bound.
        """
        self._require_collection().load()
        print("集合已加载到内存。")

    def search(self, query_embedding, search_params, limit=2, *,
               anns_field="embedding", expr=None, output_fields=None):
        """Run a vector similarity search for a single query vector.

        :param query_embedding: The query vector.
        :param search_params: Search parameters forwarded as ``param``.
        :param limit: Maximum number of hits to return.
        :param anns_field: Vector field to search (defaults to ``"embedding"``).
        :param expr: Optional boolean filter expression.
        :param output_fields: Optional list of scalar fields to return with hits.
        :return: The result object returned by ``Collection.search``.
        :raises RuntimeError: If no collection is bound.
        """
        collection = self._require_collection()
        return collection.search(
            data=[query_embedding],
            anns_field=anns_field,
            param=search_params,
            limit=limit,
            expr=expr,
            output_fields=output_fields,
        )

    def query_text_by_id(self, id):
        """Return the ``text`` value of the entity whose ``id`` equals ``id``.

        :param id: Value of the ``id`` field to look up. String values are
            quoted and escaped when the ``id`` field is a VARCHAR field.
        :return: The stored text, or ``None`` if no entity matches.
        :raises RuntimeError: If no collection is bound, or the collection has
            no ``text`` field.
        """
        collection = self._require_collection()
        # Verify the schema actually has the field we are about to return.
        fields_by_name = {field.name: field for field in collection.schema.fields}
        if "text" not in fields_by_name:
            raise RuntimeError(
                f"集合 '{self.collection_name}' 中不存在 'text' 字段,请检查集合定义。"
            )
        id_field = fields_by_name.get("id")
        id_dtype = getattr(getattr(id_field, "dtype", None), "name", None)
        if id_dtype in ("VARCHAR", "STRING"):
            # String keys must be quoted, otherwise the expression is invalid
            # (or, for crafted input, rewrites the filter).
            literal = self._quote_string(id)
        else:
            # Numeric keys are interpolated exactly as before.
            literal = id
        result = collection.query(expr=f"id == {literal}", output_fields=["text"])
        return result[0]["text"] if result else None