diff --git a/dsLightRag/Volcengine/T2_CreateIndex.py b/dsLightRag/Volcengine/T2_CreateIndex.py index e8f461f8..468cd7a7 100644 --- a/dsLightRag/Volcengine/T2_CreateIndex.py +++ b/dsLightRag/Volcengine/T2_CreateIndex.py @@ -14,7 +14,8 @@ from dotenv import load_dotenv from volcenginesdkarkruntime import Ark from Config.Config import VOLC_SECRETKEY, VOLC_API_KEY, VOLC_ACCESSKEY -from Volcengine.VikingDBMemoryService import MEMORY_COLLECTION_NAME +from Volcengine.VikingDBMemoryService import MEMORY_COLLECTION_NAME, VikingDBMemoryException, VikingDBMemoryService, \ + initialize_services, ensure_collection_exists # 控制日志输出 logger = logging.getLogger('CollectionMemory') @@ -26,231 +27,6 @@ if not logger.handlers: handler.setFormatter(logging.Formatter('%(asctime)s - %(name)s - %(levelname)s - %(message)s')) logger.addHandler(handler) -class VikingDBMemoryException(Exception): - def __init__(self, code, request_id, message=None): - self.code = code - self.request_id = request_id - self.message = "{}, code:{},request_id:{}".format(message, self.code, self.request_id) - - def __str__(self): - return self.message - - -class VikingDBMemoryService(Service): - _instance_lock = threading.Lock() - - def __new__(cls, *args, **kwargs): - if not hasattr(VikingDBMemoryService, "_instance"): - with VikingDBMemoryService._instance_lock: - if not hasattr(VikingDBMemoryService, "_instance"): - VikingDBMemoryService._instance = object.__new__(cls) - return VikingDBMemoryService._instance - - def __init__(self, host="api-knowledgebase.mlp.cn-beijing.volces.com", region="cn-beijing", ak="", sk="", - sts_token="", scheme='https', - connection_timeout=30, socket_timeout=30): - self.service_info = VikingDBMemoryService.get_service_info(host, region, scheme, connection_timeout, - socket_timeout) - self.api_info = VikingDBMemoryService.get_api_info() - super(VikingDBMemoryService, self).__init__(self.service_info, self.api_info) - if ak: - self.set_ak(ak) - if sk: - self.set_sk(sk) - if sts_token: - self.set_session_token(session_token=sts_token) - try: - self.get_body("Ping", {}, json.dumps({})) - except Exception as e: - raise VikingDBMemoryException(1000028, "missed", "host or region is incorrect".format(str(e))) from None - - def setHeader(self, header): - api_info = VikingDBMemoryService.get_api_info() - for key in api_info: - for item in header: - api_info[key].header[item] = header[item] - self.api_info = api_info - - @staticmethod - def get_service_info(host, region, scheme, connection_timeout, socket_timeout): - service_info = ServiceInfo(host, {"Host": host}, - Credentials('', '', 'air', region), connection_timeout, socket_timeout, - scheme=scheme) - return service_info - - @staticmethod - def get_api_info(): - api_info = { - "CreateCollection": ApiInfo("POST", "/api/memory/collection/create", {}, {}, - {'Accept': 'application/json', 'Content-Type': 'application/json'}), - "GetCollection": ApiInfo("POST", "/api/memory/collection/info", {}, {}, - {'Accept': 'application/json', 'Content-Type': 'application/json'}), - "DropCollection": ApiInfo("POST", "/api/memory/collection/delete", {}, {}, - {'Accept': 'application/json', 'Content-Type': 'application/json'}), - "UpdateCollection": ApiInfo("POST", "/api/memory/collection/update", {}, {}, - {'Accept': 'application/json', 'Content-Type': 'application/json'}), - - "SearchMemory": ApiInfo("POST", "/api/memory/search", {}, {}, - {'Accept': 'application/json', 'Content-Type': 'application/json'}), - "AddSession": ApiInfo("POST", "/api/memory/session/add", {}, {}, - {'Accept': 'application/json', 'Content-Type': 'application/json'}), - - "Ping": ApiInfo("GET", "/api/memory/ping", {}, {}, - {'Accept': 'application/json', 'Content-Type': 'application/json'}), - } - return api_info - - def get_body(self, api, params, body): - if not (api in self.api_info): - raise Exception("no such api") - api_info = self.api_info[api] - r = self.prepare_request(api_info, params) - r.headers['Content-Type'] = 'application/json' - r.headers['Traffic-Source'] = 'SDK' - r.body = body - - SignerV4.sign(r, self.service_info.credentials) - - url = r.build() - resp = self.session.get(url, headers=r.headers, data=r.body, - timeout=(self.service_info.connection_timeout, self.service_info.socket_timeout)) - if resp.status_code == 200: - return json.dumps(resp.json()) - else: - raise Exception(resp.text.encode("utf-8")) - - def get_body_exception(self, api, params, body): - try: - res = self.get_body(api, params, body) - except Exception as e: - try: - res_json = json.loads(e.args[0].decode("utf-8")) - except: - raise VikingDBMemoryException(1000028, "missed", "json load res error, res:{}".format(str(e))) from None - code = res_json.get("code", 1000028) - request_id = res_json.get("request_id", 1000028) - message = res_json.get("message", None) - - raise VikingDBMemoryException(code, request_id, message) - - if res == '': - raise VikingDBMemoryException(1000028, "missed", - "empty response due to unknown error, please contact customer service") from None - return res - - def get_exception(self, api, params): - try: - res = self.get(api, params) - except Exception as e: - try: - res_json = json.loads(e.args[0].decode("utf-8")) - except: - raise VikingDBMemoryException(1000028, "missed", "json load res error, res:{}".format(str(e))) from None - code = res_json.get("code", 1000028) - request_id = res_json.get("request_id", 1000028) - message = res_json.get("message", None) - raise VikingDBMemoryException(code, request_id, message) - if res == '': - raise VikingDBMemoryException(1000028, "missed", - "empty response due to unknown error, please contact customer service") from None - return res - - def create_collection(self, collection_name, description="", custom_event_type_schemas=[], - custom_entity_type_schemas=[], builtin_event_types=[], builtin_entity_types=[]): - params = { - "CollectionName": collection_name, "Description": description, - "CustomEventTypeSchemas": custom_event_type_schemas, "CustomEntityTypeSchemas": custom_entity_type_schemas, - "BuiltinEventTypes": builtin_event_types, "BuiltinEntityTypes": builtin_entity_types, - } - res = self.json("CreateCollection", {}, json.dumps(params)) - return json.loads(res) - - def get_collection(self, collection_name): - params = {"CollectionName": collection_name} - res = self.json("GetCollection", {}, json.dumps(params)) - return json.loads(res) - - def drop_collection(self, collection_name): - params = {"CollectionName": collection_name} - res = self.json("DropCollection", {}, json.dumps(params)) - return json.loads(res) - - def update_collection(self, collection_name, custom_event_type_schemas=[], custom_entity_type_schemas=[], - builtin_event_types=[], builtin_entity_types=[]): - params = { - "CollectionName": collection_name, - "CustomEventTypeSchemas": custom_event_type_schemas, "CustomEntityTypeSchemas": custom_entity_type_schemas, - "BuiltinEventTypes": builtin_event_types, "BuiltinEntityTypes": builtin_entity_types, - } - res = self.json("UpdateCollection", {}, json.dumps(params)) - return json.loads(res) - - def search_memory(self, collection_name, query, filter, limit=10): - params = { - "collection_name": collection_name, - "query": query, - "limit": limit, - "filter": filter, - } - res = self.json("SearchMemory", {}, json.dumps(params)) - return json.loads(res) - - def add_session(self, collection_name, session_id, messages, metadata, entities=None): - params = { - "collection_name": collection_name, - "session_id": session_id, - "messages": messages, - "metadata": metadata, - } - if entities is not None: - params["entities"] = entities - res = self.json("AddSession", {}, json.dumps(params)) - return json.loads(res) - - - - -def initialize_services(): - load_dotenv() - ak = VOLC_ACCESSKEY - sk = VOLC_SECRETKEY - ark_api_key = VOLC_API_KEY - - if not all([ak, sk, ark_api_key]): - raise ValueError("必须在环境变量中设置 VOLC_ACCESSKEY, VOLC_SECRETKEY, 和 ARK_API_KEY。") - - memory_service = VikingDBMemoryService(ak=ak, sk=sk) - llm_client = Ark( - base_url="https://ark.cn-beijing.volces.com/api/v3", - api_key=ark_api_key, - ) - return memory_service, llm_client - - -def ensure_collection_exists(memory_service, collection_name): - """检查记忆集合是否存在,如果不存在则创建。""" - try: - memory_service.get_collection(collection_name) - logger.info(f"记忆集合 '{collection_name}' 已存在。") - except Exception as e: - error_message = str(e) - if "collection not exist" in error_message: - logger.info(f"记忆集合 '{collection_name}' 未找到,正在创建...") - try: - memory_service.create_collection( - collection_name=collection_name, - description="中文情感陪伴场景测试", - builtin_event_types=["sys_event_v1", "sys_profile_collect_v1"], - builtin_entity_types=["sys_profile_v1"] - ) - logger.info(f"记忆集合 '{collection_name}' 创建成功。") - logger.info("等待集合准备就绪...") - except Exception as create_e: - logger.info(f"创建集合失败: {create_e}") - raise - else: - logger.info(f"检查集合时出错: {e}") - raise def search_relevant_memories(memory_service, collection_name, user_id, query): diff --git a/dsLightRag/Volcengine/T3_Chat.py b/dsLightRag/Volcengine/T3_Chat.py deleted file mode 100644 index b0bbee72..00000000 --- a/dsLightRag/Volcengine/T3_Chat.py +++ /dev/null @@ -1,107 +0,0 @@ -import json -import logging -import time -from Volcengine.VikingDBMemoryService import initialize_services, MEMORY_COLLECTION_NAME -from volcenginesdkarkruntime import Ark -from Config.Config import VOLC_API_KEY - -# 配置日志 -logging.basicConfig(level=logging.INFO) -logger = logging.getLogger('MemoryChatSimulator') - -def simulate_chat_session(memory_service, llm_client, user_id, session_topic): - """模拟完整对话会话:创建对话→存储记忆→搜索记忆""" - # 初始化对话历史 - conversation_history = [] - logger.info(f"\n=== 开始新对话会话: {session_topic} ===") - - # 模拟两轮对话 - user_messages = [ - f"你好,我是{user_id},{session_topic}", - "我最近感觉压力很大,不知道该怎么办" - ] - - for message in user_messages: - # 处理对话轮次 - assistant_reply = memory_service.handle_conversation_turn( - llm_client=llm_client, - user_id=user_id, - user_message=message, - conversation_history=conversation_history - ) - time.sleep(1) # 模拟思考延迟 - - # 归档对话记忆 - logger.info("\n=== 归档对话记忆 ===") - archive_success = memory_service.archive_conversation( - user_id=user_id, - assistant_id="simulated_assistant", - conversation_history=conversation_history, - topic_name=session_topic - ) - - # 搜索相关记忆 - if archive_success: - logger.info("\n=== 搜索相关记忆 ===") - # 等待索引更新 - logger.info("等待记忆库索引更新...") - time.sleep(5) # 实际环境可能需要更长时间 - - # 搜索与当前主题相关的记忆 - search_query = f"{session_topic}相关的问题" - relevant_memories = memory_service.search_relevant_memories( - collection_name=MEMORY_COLLECTION_NAME, - user_id=user_id, - query=search_query, - limit=3 - ) - - logger.info(f"\n搜索到 '{search_query}' 的相关记忆 ({len(relevant_memories)}条):") - for i, memory in enumerate(relevant_memories, 1): - logger.info(f"\n记忆 {i} (相关度: {memory['score']:.3f}):") - logger.info(f"内容: {memory['memory_info']}") - - return archive_success - -def main(): - """主函数:初始化服务并执行模拟测试""" - logger.info("===== 记忆库存储与搜索模拟测试 ====") - - try: - # 初始化服务 - logger.info("初始化记忆库服务和LLM客户端...") - memory_service, llm_client = initialize_services() - - # 确保集合就绪 - logger.info(f"检查集合 '{MEMORY_COLLECTION_NAME}' 是否就绪...") - - # 先验证集合是否存在 - try: - collection_info = memory_service.get_collection(MEMORY_COLLECTION_NAME) - logger.info(f"集合 '{MEMORY_COLLECTION_NAME}' 存在,详细信息: {json.dumps(collection_info, ensure_ascii=False)}...") - except Exception as e: - logger.error(f"集合 '{MEMORY_COLLECTION_NAME}' 不存在: {str(e)}") - logger.error("请先运行T2_CreateIndex.py创建集合") - return - - if not memory_service.wait_for_collection_ready(timeout=120): - logger.error("集合未就绪,无法继续测试") - return - - # 模拟两个不同用户的对话会话 - user_sessions = [ - ("student_001", "学习压力问题"), - ("teacher_001", "教学方法讨论") - ] - - for user_id, topic in user_sessions: - simulate_chat_session(memory_service, llm_client, user_id, topic) - time.sleep(2) # 会话间隔 - - logger.info("\n===== 模拟测试完成 ====") - - except Exception as e: - logger.error(f"测试过程中发生错误: {str(e)}", exc_info=True) - -if __name__ == "__main__": - main() \ No newline at end of file