This commit is contained in:
2025-09-07 07:37:09 +08:00
parent 34388b5bd4
commit f27f6e0baa
2 changed files with 99 additions and 100 deletions

View File

@@ -11,7 +11,105 @@ from volcengine.base.Service import Service
from volcengine.ServiceInfo import ServiceInfo from volcengine.ServiceInfo import ServiceInfo
from volcengine.auth.SignerV4 import SignerV4 from volcengine.auth.SignerV4 import SignerV4
from volcengine.base.Request import Request from volcengine.base.Request import Request
import os
import time
from dotenv import load_dotenv
from volcenginesdkarkruntime import Ark
def initialize_services(ak=None, sk=None, ark_api_key=None):
"""初始化记忆数据库服务和LLM客户端"""
load_dotenv()
# 如果参数未提供,尝试从环境变量获取
if not ak:
ak = os.getenv("VOLC_ACCESSKEY")
if not sk:
sk = os.getenv("VOLC_SECRETKEY")
if not ark_api_key:
ark_api_key = os.getenv("VOLC_API_KEY")
if not all([ak, sk, ark_api_key]):
raise ValueError("必须提供 VOLC_ACCESSKEY, VOLC_SECRETKEY, 和 VOLC_API_KEY。")
memory_service = VikingDBMemoryService(ak=ak, sk=sk)
llm_client = Ark(
base_url="https://ark.cn-beijing.volces.com/api/v3",
api_key=ark_api_key,
)
return memory_service, llm_client
def ensure_collection_exists(memory_service, collection_name, description="",
builtin_event_types=["sys_event_v1", "sys_profile_collect_v1"],
builtin_entity_types=["sys_profile_v1"]):
"""检查记忆集合是否存在,如果不存在则创建。"""
try:
memory_service.get_collection(collection_name)
print(f"记忆集合 '{collection_name}' 已存在。")
except Exception as e:
error_message = str(e)
if "collection not exist" in error_message:
print(f"记忆集合 '{collection_name}' 未找到,正在创建...")
try:
memory_service.create_collection(
collection_name=collection_name,
description=description,
builtin_event_types=builtin_event_types,
builtin_entity_types=builtin_entity_types
)
print(f"记忆集合 '{collection_name}' 创建成功。")
print("等待集合准备就绪...")
except Exception as create_e:
print(f"创建集合失败: {create_e}")
raise
else:
print(f"检查集合时出错: {e}")
raise
def search_relevant_memories(memory_service, collection_name, user_id, query, limit=3):
"""搜索与用户查询相关的记忆,并在索引构建中时重试。"""
print(f"正在搜索与 '{query}' 相关的记忆...")
retry_attempt = 0
while True:
try:
filter_params = {
"user_id": [user_id],
"memory_type": ["sys_event_v1", "sys_profile_v1"]
}
response = memory_service.search_memory(
collection_name=collection_name,
query=query,
filter=filter_params,
limit=limit
)
memories = []
if response.get('data', {}).get('count', 0) > 0:
for result in response['data']['result_list']:
if 'memory_info' in result and result['memory_info']:
memories.append({
'memory_info': result['memory_info'],
'score': result['score']
})
if memories:
if retry_attempt > 0:
print("重试后搜索成功。")
print(f"找到 {len(memories)} 条相关记忆:")
for i, memory in enumerate(memories, 1):
print(
f" {i}. (相关度: {memory['score']:.3f}): {json.dumps(memory['memory_info'], ensure_ascii=False, indent=2)}")
else:
print("未找到相关记忆。")
return memories
except Exception as e:
error_message = str(e)
if "1000023" in error_message:
retry_attempt += 1
print(f"记忆索引正在构建中。将在60秒后重试... (尝试次数 {retry_attempt})")
time.sleep(60)
else:
print(f"搜索记忆时出错 (不可重试): {e}")
return []
class VikingDBMemoryException(Exception): class VikingDBMemoryException(Exception):
def __init__(self, code, request_id, message=None): def __init__(self, code, request_id, message=None):

View File

@@ -1,11 +1,5 @@
import json import json
import os from VikingDBMemoryService import VikingDBMemoryService, initialize_services, ensure_collection_exists, search_relevant_memories
import time
from dotenv import load_dotenv
from volcenginesdkarkruntime import Ark
from Config.Config import VOLC_ACCESSKEY, VOLC_SECRETKEY, VOLC_API_KEY
from VikingDBMemoryService import VikingDBMemoryService
""" """
在记忆库准备好后,我们先模拟一段包含两轮的完整对话。 在记忆库准备好后,我们先模拟一段包含两轮的完整对话。
@@ -13,96 +7,6 @@ from VikingDBMemoryService import VikingDBMemoryService
AI 就能用刚写入的记忆来回答。 AI 就能用刚写入的记忆来回答。
注意:首次写入需要 35 分钟建立索引,这段时间内检索会报错。 注意:首次写入需要 35 分钟建立索引,这段时间内检索会报错。
""" """
def initialize_services():
load_dotenv()
ak = VOLC_ACCESSKEY
sk = VOLC_SECRETKEY
ark_api_key = VOLC_API_KEY
if not all([ak, sk, ark_api_key]):
raise ValueError("必须在环境变量中设置 VOLC_ACCESSKEY, VOLC_SECRETKEY, 和 ARK_API_KEY。")
memory_service = VikingDBMemoryService(ak=ak, sk=sk)
llm_client = Ark(
base_url="https://ark.cn-beijing.volces.com/api/v3",
api_key=ark_api_key,
)
return memory_service, llm_client
def ensure_collection_exists(memory_service, collection_name):
"""检查记忆集合是否存在,如果不存在则创建。"""
try:
memory_service.get_collection(collection_name)
print(f"记忆集合 '{collection_name}' 已存在。")
except Exception as e:
error_message = str(e)
if "collection not exist" in error_message:
print(f"记忆集合 '{collection_name}' 未找到,正在创建...")
try:
memory_service.create_collection(
collection_name=collection_name,
description="中文情感陪伴场景测试",
builtin_event_types=["sys_event_v1", "sys_profile_collect_v1"],
builtin_entity_types=["sys_profile_v1"]
)
print(f"记忆集合 '{collection_name}' 创建成功。")
print("等待集合准备就绪...")
except Exception as create_e:
print(f"创建集合失败: {create_e}")
raise
else:
print(f"检查集合时出错: {e}")
raise
def search_relevant_memories(memory_service, collection_name, user_id, query):
"""搜索与用户查询相关的记忆,并在索引构建中时重试。"""
print(f"正在搜索与 '{query}' 相关的记忆...")
retry_attempt = 0
while True:
try:
filter_params = {
"user_id": [user_id],
"memory_type": ["sys_event_v1", "sys_profile_v1"]
}
response = memory_service.search_memory(
collection_name=collection_name,
query=query,
filter=filter_params,
limit=3
)
memories = []
if response.get('data', {}).get('count', 0) > 0:
for result in response['data']['result_list']:
if 'memory_info' in result and result['memory_info']:
memories.append({
'memory_info': result['memory_info'],
'score': result['score']
})
if memories:
if retry_attempt > 0:
print("重试后搜索成功。")
print(f"找到 {len(memories)} 条相关记忆:")
for i, memory in enumerate(memories, 1):
print(
f" {i}. (相关度: {memory['score']:.3f}): {json.dumps(memory['memory_info'], ensure_ascii=False, indent=2)}")
else:
print("未找到相关记忆。")
return memories
except Exception as e:
error_message = str(e)
if "1000023" in error_message:
retry_attempt += 1
print(f"记忆索引正在构建中。将在60秒后重试... (尝试次数 {retry_attempt})")
time.sleep(60)
else:
print(f"搜索记忆时出错 (不可重试): {e}")
return []
def handle_conversation_turn(memory_service, llm_client, collection_name, user_id, user_message, conversation_history): def handle_conversation_turn(memory_service, llm_client, collection_name, user_id, user_message, conversation_history):
"""处理一轮对话包括记忆搜索和LLM响应。""" """处理一轮对话包括记忆搜索和LLM响应。"""
@@ -139,7 +43,6 @@ def handle_conversation_turn(memory_service, llm_client, collection_name, user_i
]) ])
return assistant_reply return assistant_reply
def archive_conversation(memory_service, collection_name, user_id, assistant_id, conversation_history, topic_name): def archive_conversation(memory_service, collection_name, user_id, assistant_id, conversation_history, topic_name):
"""将对话历史归档到记忆数据库。""" """将对话历史归档到记忆数据库。"""
if not conversation_history: if not conversation_history:
@@ -168,7 +71,6 @@ def archive_conversation(memory_service, collection_name, user_id, assistant_id,
print(f"归档对话失败: {e}") print(f"归档对话失败: {e}")
return False return False
def main(): def main():
print("开始端到端记忆测试...") print("开始端到端记忆测试...")
@@ -211,6 +113,5 @@ def main():
print("\n端到端记忆测试完成!") print("\n端到端记忆测试完成!")
if __name__ == "__main__": if __name__ == "__main__":
main() main()