This commit is contained in:
2025-09-07 13:13:21 +08:00
parent 26377b0baa
commit 8ac14dbdc9
6 changed files with 208 additions and 201 deletions

View File

@@ -1,5 +1,10 @@
import json
import logging
import sys
import os
# 添加当前目录到系统路径
sys.path.append(os.path.dirname(os.path.abspath(__file__)))
from VikingDBMemoryService import VikingDBMemoryService, MEMORY_COLLECTION_NAME
from Config.Config import VOLC_ACCESSKEY, VOLC_SECRETKEY

View File

@@ -1,7 +1,12 @@
import logging
import sys
import os
import time
import json
from Volcengine.VikingDBMemoryService import MEMORY_COLLECTION_NAME, initialize_services, ensure_collection_exists
from Volcengine.chat import handle_conversation_turn, archive_conversation
from volcenginesdkarkruntime import Ark
from Config.Config import VOLC_ACCESSKEY, VOLC_SECRETKEY, VOLC_API_KEY
from VikingDBMemoryService import VikingDBMemoryService, MEMORY_COLLECTION_NAME
# 控制日志输出
logger = logging.getLogger('CollectionMemory')
@@ -14,48 +19,97 @@ if not logger.handlers:
logger.addHandler(handler)
# NOTE(review): this span comes from a rendered commit diff that lost its
# +/- markers and indentation, so lines from the old and the new version of
# main() are interleaved (e.g. two different except-bodies follow the single
# `except` below). Reconcile against the real file before running it.
def main():
logger.info("开始端到端记忆测试...")
logger.info("开始创建索引...")
# Initialise the memory-store service
memory_service = VikingDBMemoryService(
ak=VOLC_ACCESSKEY,
sk=VOLC_SECRETKEY,
host="api-knowledgebase.mlp.cn-beijing.volces.com",
region="cn-beijing"
)
# Initialise the LLM client
llm_client = Ark(
base_url="https://ark.cn-beijing.volces.com/api/v3",
api_key=VOLC_API_KEY
)
try:
memory_service, llm_client = initialize_services()
collection_name = MEMORY_COLLECTION_NAME
user_id = "system"
assistant_id = "assistant"
ensure_collection_exists(memory_service, collection_name)
# Make sure the collection exists
logger.info("检查/创建集合...")
memory_service.ensure_collection_exists(collection_name)
# Add test data so that index building is triggered
logger.info("添加测试数据...")
test_messages = [
{"role": "user", "content": "你好,我是测试用户"},
{"role": "assistant", "content": "你好,我是测试助手"}
]
test_metadata = {
"default_user_id": user_id,
"default_assistant_id": assistant_id,
"time": int(time.time() * 1000)
}
session_id = f"test_session_{int(time.time())}"
memory_service.add_session(
collection_name=collection_name,
session_id=session_id,
messages=test_messages,
metadata=test_metadata
)
logger.info("测试数据添加成功,等待索引构建...")
# Wait for the index to become ready, using the same logic as chat.py.backup
max_retries = 30
retry_interval = 60 # seconds, same as chat.py.backup
for retry in range(max_retries):
try:
# Try a search to check whether the index is ready
filter_params = {
"user_id": user_id,
"memory_type": ["sys_event_v1", "sys_profile_v1"]
}
response = memory_service.search_memory(
collection_name=collection_name,
query="测试查询",
filter=filter_params,
limit=1
)
# A successful search means the index is ready
logger.info(f"索引已就绪,找到 {response.get('data', {}).get('count', 0)} 条记录")
break
except Exception as e:
logger.info(f"初始化失败: {e}")
return
error_message = str(e)
if "1000023" in error_message: # same error code as in chat.py.backup
retry_attempt = retry + 1
logger.info(f"记忆索引正在构建中。将在{retry_interval}秒后重试... (尝试次数 {retry_attempt})")
time.sleep(retry_interval)
else:
logger.error(f"搜索时发生错误: {error_message}")
raise
else:
# The loop finished without a break, i.e. the wait timed out
logger.error(f"索引构建超时,已尝试 {max_retries}")
sys.exit(1)
logger.info("\n--- 阶段 1: 初始对话 ---")
initial_conversation_history = []
handle_conversation_turn(
memory_service, llm_client, collection_name, user_id,
"你好我是小明今年18岁但压力好大。",
initial_conversation_history
)
handle_conversation_turn(
memory_service, llm_client, collection_name, user_id,
"马上就要高考了,家里人的期待好高。",
initial_conversation_history
)
logger.info("索引创建和测试完成!")
logger.info("\n--- 阶段 2: 归档记忆 ---")
archive_conversation(
memory_service, collection_name, user_id, assistant_id,
initial_conversation_history, "study_stress_discussion"
)
logger.info("\n--- 阶段 3: 验证记忆 ---")
verification_conversation_history = []
handle_conversation_turn(
memory_service, llm_client, collection_name, user_id,
"我最近很焦虑,不知道该怎么办。",
verification_conversation_history
)
logger.info("\n端到端记忆测试完成!")
except Exception as e:
logger.error(f"操作失败: {e}")
sys.exit(1)
if __name__ == "__main__":

View File

@@ -4,9 +4,10 @@ pip install --upgrade "volcengine-python-sdk[ark]"
"""
import json
import logging
import os
import threading
import time
import sys
import os
from dotenv import load_dotenv
from volcengine.ApiInfo import ApiInfo
@@ -21,6 +22,9 @@ from Config.Config import VOLC_SECRETKEY, VOLC_ACCESSKEY, VOLC_API_KEY
# Configure the module logger
logger = logging.getLogger('CollectionMemory')
logger.setLevel(logging.INFO)
# Attach the handler only once so that log lines are not duplicated
if not logger.handlers:
handler = logging.StreamHandler()
handler.setFormatter(logging.Formatter('%(asctime)s - %(name)s - %(levelname)s - %(message)s'))
logger.addHandler(handler)
@@ -28,100 +32,6 @@ logger.addHandler(handler)
# Name of the memory collection
MEMORY_COLLECTION_NAME="dsideal_collection"
def initialize_services(ak=None, sk=None, ark_api_key=None):
    """Create the memory-database service and the LLM client.

    Each credential that is not supplied (or is falsy) is replaced by the
    corresponding ``VOLC_*`` constant available at module level.

    Returns:
        A ``(memory_service, llm_client)`` tuple.

    Raises:
        ValueError: if any of the three credentials is still empty after the
            fallback has been applied.
    """
    load_dotenv()

    # Fall back to the configured constants for anything the caller left out.
    access_key = ak or VOLC_ACCESSKEY
    secret_key = sk or VOLC_SECRETKEY
    api_key = ark_api_key or VOLC_API_KEY

    if not (access_key and secret_key and api_key):
        raise ValueError("必须提供 VOLC_ACCESSKEY, VOLC_SECRETKEY, 和 VOLC_API_KEY。")

    service = VikingDBMemoryService(ak=access_key, sk=secret_key)
    client = Ark(
        base_url="https://ark.cn-beijing.volces.com/api/v3",
        api_key=api_key,
    )
    return service, client
def ensure_collection_exists(memory_service, collection_name, description="",
                             builtin_event_types=None,
                             builtin_entity_types=None):
    """Check that a memory collection exists and create it if it does not.

    Args:
        memory_service: object exposing ``get_collection`` and
            ``create_collection``.
        collection_name: name of the collection to look up / create.
        description: description passed on when the collection is created.
        builtin_event_types: event types for a newly created collection;
            ``None`` selects ``["sys_event_v1", "sys_profile_collect_v1"]``.
        builtin_entity_types: entity types for a newly created collection;
            ``None`` selects ``["sys_profile_v1"]``.

    Raises:
        Exception: whatever ``get_collection`` raised when the message does
            not contain "collection not exist", or whatever
            ``create_collection`` raised.
    """
    # Build fresh lists on every call: list literals used as default values
    # are shared between calls, so a callee mutating them would silently
    # change the defaults for every later call.
    if builtin_event_types is None:
        builtin_event_types = ["sys_event_v1", "sys_profile_collect_v1"]
    if builtin_entity_types is None:
        builtin_entity_types = ["sys_profile_v1"]

    try:
        memory_service.get_collection(collection_name)
        print(f"记忆集合 '{collection_name}' 已存在。")
    except Exception as e:
        error_message = str(e)
        if "collection not exist" not in error_message:
            # Any other lookup failure is not ours to handle.
            print(f"检查集合时出错: {e}")
            raise

        print(f"记忆集合 '{collection_name}' 未找到,正在创建...")
        try:
            memory_service.create_collection(
                collection_name=collection_name,
                description=description,
                builtin_event_types=builtin_event_types,
                builtin_entity_types=builtin_entity_types
            )
            print(f"记忆集合 '{collection_name}' 创建成功。")
            print("等待集合准备就绪...")
        except Exception as create_e:
            print(f"创建集合失败: {create_e}")
            raise
def search_relevant_memories(memory_service, collection_name, user_id, query, limit=3,
                             max_retries=None, retry_interval=60):
    """Search for memories related to ``query``, retrying while the index builds.

    Args:
        memory_service: object exposing ``search_memory``.
        collection_name: collection to search in.
        user_id: user id used in the search filter.
        query: free-text query.
        limit: maximum number of results requested from the service.
        max_retries: maximum number of retries after error code "1000023";
            ``None`` (the default) retries without limit, as before.
        retry_interval: seconds to sleep between retries.

    Returns:
        A list of ``{'memory_info': ..., 'score': ...}`` dicts. The list is
        empty when nothing was found, when a non-retryable error occurred,
        or when ``max_retries`` was exhausted.
    """
    print(f"正在搜索与 '{query}' 相关的记忆...")
    retry_attempt = 0
    while True:
        try:
            filter_params = {
                "user_id": user_id,  # passed as a string
                "memory_type": ["sys_event_v1", "sys_profile_v1"]
            }
            response = memory_service.search_memory(
                collection_name=collection_name,
                query=query,
                filter=filter_params,
                limit=limit
            )

            memories = []
            if response.get('data', {}).get('count', 0) > 0:
                # Keep only results that actually carry memory content.
                memories = [
                    {'memory_info': result['memory_info'], 'score': result['score']}
                    for result in response['data']['result_list']
                    if result.get('memory_info')
                ]

            if memories:
                if retry_attempt > 0:
                    print("重试后搜索成功。")
                print(f"找到 {len(memories)} 条相关记忆:")
                for i, memory in enumerate(memories, 1):
                    print(
                        f" {i}. (相关度: {memory['score']:.3f}): {json.dumps(memory['memory_info'], ensure_ascii=False, indent=2)}")
            else:
                print("未找到相关记忆。")
            return memories
        except Exception as e:
            error_message = str(e)
            if "1000023" not in error_message:
                print(f"搜索记忆时出错 (不可重试): {e}")
                return []
            # "1000023" is treated as "index still building": retry, but do
            # not spin forever when the caller asked for a bound.
            if max_retries is not None and retry_attempt >= max_retries:
                print(f"记忆索引在重试 {max_retries} 次后仍未就绪,放弃搜索。")
                return []
            retry_attempt += 1
            print(f"记忆索引正在构建中。将在{retry_interval}秒后重试... (尝试次数 {retry_attempt})")
            time.sleep(retry_interval)
class VikingDBMemoryException(Exception):
def __init__(self, code, request_id, message=None):
@@ -309,7 +219,8 @@ class VikingDBMemoryService(Service):
logger.info("\n" + "=" * 60)
logger.info(f"用户: {user_message}")
relevant_memories = self.search_relevant_memories(user_id, user_message)
# 修复调用正确的search_relevant_memories方法
relevant_memories = self.search_relevant_memories(MEMORY_COLLECTION_NAME, user_id, user_message)
system_prompt = "你是一个富有同情心、善于倾听的AI伙伴拥有长期记忆能力。你的目标是为用户提供情感支持和温暖的陪伴。"
if relevant_memories:
@@ -395,10 +306,105 @@ class VikingDBMemoryService(Service):
logger.error(f"集合 '{MEMORY_COLLECTION_NAME}'{timeout}秒内未就绪")
return False
def initialize_services(self, ak=None, sk=None, ark_api_key=None):
    """Create a memory-database service and an LLM client.

    Each credential that is not supplied (or is falsy) is replaced by the
    corresponding ``VOLC_*`` constant available at module level. ``self`` is
    not used.

    Returns:
        A ``(memory_service, llm_client)`` tuple.

    Raises:
        ValueError: if any of the three credentials is still empty after the
            fallback has been applied.
    """
    load_dotenv()

    # Fall back to the configured constants for anything the caller left out.
    access_key = ak or VOLC_ACCESSKEY
    secret_key = sk or VOLC_SECRETKEY
    api_key = ark_api_key or VOLC_API_KEY

    if not (access_key and secret_key and api_key):
        raise ValueError("必须提供 VOLC_ACCESSKEY, VOLC_SECRETKEY, 和 VOLC_API_KEY。")

    service = VikingDBMemoryService(ak=access_key, sk=secret_key)
    client = Ark(
        base_url="https://ark.cn-beijing.volces.com/api/v3",
        api_key=api_key,
    )
    return service, client
def ensure_collection_exists(self, collection_name, description="",
                             builtin_event_types=None,
                             builtin_entity_types=None):
    """Check that a memory collection exists and create it if it does not.

    Args:
        collection_name: name of the collection to look up / create.
        description: description passed on when the collection is created.
        builtin_event_types: event types for a newly created collection;
            ``None`` selects ``["sys_event_v1", "sys_profile_collect_v1"]``.
        builtin_entity_types: entity types for a newly created collection;
            ``None`` selects ``["sys_profile_v1"]``.

    Raises:
        Exception: whatever ``get_collection`` raised when the message does
            not contain "collection not exist", or whatever
            ``create_collection`` raised.
    """
    # Build fresh lists on every call: list literals used as default values
    # are shared between calls, so a callee mutating them would silently
    # change the defaults for every later call.
    if builtin_event_types is None:
        builtin_event_types = ["sys_event_v1", "sys_profile_collect_v1"]
    if builtin_entity_types is None:
        builtin_entity_types = ["sys_profile_v1"]

    try:
        self.get_collection(collection_name)
        logger.info(f"记忆集合 '{collection_name}' 已存在。")
    except Exception as e:
        error_message = str(e)
        if "collection not exist" not in error_message:
            # Any other lookup failure is not ours to handle.
            logger.info(f"检查集合时出错: {e}")
            raise

        logger.info(f"记忆集合 '{collection_name}' 未找到,正在创建...")
        try:
            self.create_collection(
                collection_name=collection_name,
                description=description,
                builtin_event_types=builtin_event_types,
                builtin_entity_types=builtin_entity_types
            )
            logger.info(f"记忆集合 '{collection_name}' 创建成功。")
            logger.info("等待集合准备就绪...")
        except Exception as create_e:
            logger.info(f"创建集合失败: {create_e}")
            raise
def search_relevant_memories(self, collection_name, user_id, query, limit=3,
                             max_retries=None, retry_interval=60):
    """Search for memories related to ``query``, retrying while the index builds.

    Args:
        collection_name: collection to search in.
        user_id: user id used in the search filter.
        query: free-text query.
        limit: maximum number of results requested from the service.
        max_retries: maximum number of retries after error code "1000023";
            ``None`` (the default) retries without limit, as before.
        retry_interval: seconds to sleep between retries.

    Returns:
        A list of ``{'memory_info': ..., 'score': ...}`` dicts. The list is
        empty when nothing was found, when a non-retryable error occurred,
        or when ``max_retries`` was exhausted.
    """
    logger.info(f"正在搜索与 '{query}' 相关的记忆...")
    retry_attempt = 0
    while True:
        try:
            filter_params = {
                "user_id": user_id,  # passed as a string
                "memory_type": ["sys_event_v1", "sys_profile_v1"]
            }
            response = self.search_memory(
                collection_name=collection_name,
                query=query,
                filter=filter_params,
                limit=limit
            )

            memories = []
            if response.get('data', {}).get('count', 0) > 0:
                # Keep only results that actually carry memory content.
                memories = [
                    {'memory_info': result['memory_info'], 'score': result['score']}
                    for result in response['data']['result_list']
                    if result.get('memory_info')
                ]

            if memories:
                if retry_attempt > 0:
                    logger.info("重试后搜索成功。")
                logger.info(f"找到 {len(memories)} 条相关记忆:")
                for i, memory in enumerate(memories, 1):
                    logger.info(
                        f" {i}. (相关度: {memory['score']:.3f}): {json.dumps(memory['memory_info'], ensure_ascii=False, indent=2)}")
            else:
                logger.info("未找到相关记忆。")
            return memories
        except Exception as e:
            error_message = str(e)
            if "1000023" not in error_message:
                logger.info(f"搜索记忆时出错 (不可重试): {e}")
                return []
            # "1000023" is treated as "index still building": retry, but do
            # not spin forever when the caller asked for a bound.
            if max_retries is not None and retry_attempt >= max_retries:
                logger.error(f"记忆索引在重试 {max_retries} 次后仍未就绪,放弃搜索。")
                return []
            retry_attempt += 1
            logger.info(f"记忆索引正在构建中。将在{retry_interval}秒后重试... (尝试次数 {retry_attempt})")
            time.sleep(retry_interval)
def setup_memory_collection(self):
"""独立封装记忆体创建逻辑返回memory_service供测试使用"""
try:
ensure_collection_exists(self, MEMORY_COLLECTION_NAME)
self.ensure_collection_exists(MEMORY_COLLECTION_NAME)
logger.info(f"记忆体 '{MEMORY_COLLECTION_NAME}' 创建/验证成功")
# 添加集合就绪等待
@@ -412,57 +418,3 @@ class VikingDBMemoryService(Service):
except Exception as e:
logger.info(f"记忆体创建失败: {e}")
return None
def run_end_to_end_test(self):
    """Drive the end-to-end memory test in three phases.

    Phase 1 holds a short initial conversation, phase 2 archives it, and
    phase 3 starts a fresh conversation. The method returns early, without
    raising, when the collection setup yields nothing or when initialisation
    fails.
    """
    logger.info("开始端到端记忆测试...")

    user_id = "xiaoming"  # user id: Xiao Ming
    assistant_id = "assistant1"  # assistant id: assistant 1

    try:
        # Collection creation is delegated to the dedicated setup method.
        ready_service = self.setup_memory_collection()
        if not ready_service:
            return

        llm_client = Ark(
            base_url="https://ark.cn-beijing.volces.com/api/v3",
            api_key=VOLC_API_KEY
        )
    except Exception as e:
        logger.info(f"初始化失败: {e}")
        return

    logger.info("\n--- 阶段 1: 初始对话 ---")
    initial_conversation_history = []
    opening_messages = (
        "你好我是小明今年18岁但压力好大。",
        "马上就要高考了,家里人的期待好高。",
    )
    for user_message in opening_messages:
        self.handle_conversation_turn(
            llm_client, user_id, user_message, initial_conversation_history
        )

    logger.info("\n--- 阶段 2: 归档记忆 ---")
    self.archive_conversation(
        user_id, assistant_id,
        initial_conversation_history, "study_stress_discussion"
    )

    logger.info("\n--- 阶段 3: 验证记忆 ---")
    verification_conversation_history = []
    self.handle_conversation_turn(
        llm_client, user_id,
        "我最近很焦虑,不知道该怎么办。",
        verification_conversation_history
    )

    logger.info("\n端到端记忆测试完成!")
if __name__ == "__main__":
# Initialise the services
memory_service, _ = initialize_services()
# Run the end-to-end test
memory_service.run_end_to_end_test()

View File

@@ -1,12 +1,13 @@
import json
import threading
import time
from dotenv import load_dotenv
from volcengine.ApiInfo import ApiInfo
from volcengine.Credentials import Credentials
from volcengine.base.Service import Service
from volcengine.ServiceInfo import ServiceInfo
from volcengine.auth.SignerV4 import SignerV4
from volcengine.base.Request import Request
from volcengine.base.Service import Service
from volcenginesdkarkruntime import Ark
from Config.Config import VOLC_SECRETKEY, VOLC_API_KEY, VOLC_ACCESSKEY
@@ -193,11 +194,6 @@ class VikingDBMemoryService(Service):
return json.loads(res)
import json
import os
import time
from dotenv import load_dotenv
from volcenginesdkarkruntime import Ark
def initialize_services():