This commit is contained in:
2025-09-07 08:56:35 +08:00
parent bb94553638
commit 25eb08b1db
7 changed files with 626 additions and 20 deletions

View File

@@ -0,0 +1,53 @@
import json
import requests
from volcengine.base.Request import Request
from volcengine.Credentials import Credentials
from volcengine.auth.SignerV4 import SignerV4
from Config.Config import VOLC_ACCESSKEY, VOLC_SECRETKEY
from Volcengine.VikingDBMemoryService import MEMORY_COLLECTION_NAME
AK = VOLC_ACCESSKEY
SK = VOLC_SECRETKEY
Domain = "api-knowledgebase.mlp.cn-beijing.volces.com"
def prepare_request(method, path, ak, sk, data=None):
r = Request()
r.set_shema("http") # 注意:这里用 http因为 SignerV4 内部会拼 host
r.set_method(method)
r.set_host(Domain)
r.set_path(path)
if data is not None:
r.set_body(json.dumps(data))
# 使用 air 服务和 cn-north-1 区域
credentials = Credentials(ak, sk, 'air', 'cn-north-1')
SignerV4.sign(r, credentials)
return r
def internal_request(method, api, payload, params=None):
req = prepare_request(
method=method,
path=api,
ak=AK,
sk=SK,
data=payload
)
r = requests.request(
method=req.method,
url="{}://{}{}".format(req.schema, req.host, req.path),
headers=req.headers,
data=req.body,
params=params,
)
return r
# 查询记忆库信息
path = '/api/memory/collection/info'
payload = {
"CollectionName": MEMORY_COLLECTION_NAME
}
rsp = internal_request("POST", path, payload)
print(rsp.json())

View File

@@ -7,9 +7,12 @@ from Config.Config import VOLC_ACCESSKEY, VOLC_SECRETKEY
# 控制日志输出 # 控制日志输出
logger = logging.getLogger('CollectionMemory') logger = logging.getLogger('CollectionMemory')
logger.setLevel(logging.INFO) logger.setLevel(logging.INFO)
handler = logging.StreamHandler()
handler.setFormatter(logging.Formatter('%(asctime)s - %(name)s - %(levelname)s - %(message)s')) # 只添加一次处理器,避免重复日志
logger.addHandler(handler) if not logger.handlers:
handler = logging.StreamHandler()
handler.setFormatter(logging.Formatter('%(asctime)s - %(name)s - %(levelname)s - %(message)s'))
logger.addHandler(handler)
def drop_existing_collection(collection_name): def drop_existing_collection(collection_name):
# 初始化记忆库服务 # 初始化记忆库服务

View File

@@ -1,14 +1,20 @@
import json import json
import logging import logging
import time
from Config.Config import VOLC_ACCESSKEY, VOLC_SECRETKEY from Config.Config import VOLC_ACCESSKEY, VOLC_SECRETKEY
from VikingDBMemoryService import VikingDBMemoryService, MEMORY_COLLECTION_NAME from VikingDBMemoryService import VikingDBMemoryService, MEMORY_COLLECTION_NAME, VikingDBMemoryException
# 控制日志输出 # 控制日志输出
logger = logging.getLogger('CollectionMemory') logger = logging.getLogger('CollectionMemory')
logger.setLevel(logging.INFO) logger.setLevel(logging.INFO)
handler = logging.StreamHandler()
handler.setFormatter(logging.Formatter('%(asctime)s - %(name)s - %(levelname)s - %(message)s')) # 只添加一次处理器,避免重复日志
logger.addHandler(handler) if not logger.handlers:
handler = logging.StreamHandler()
handler.setFormatter(logging.Formatter('%(asctime)s - %(name)s - %(levelname)s - %(message)s'))
logger.addHandler(handler)
def create_memory_collection(collection_name, description="情感陪伴记忆库"): def create_memory_collection(collection_name, description="情感陪伴记忆库"):
@@ -44,14 +50,43 @@ def create_memory_collection(collection_name, description="情感陪伴记忆库
logger.info(f"创建响应: {json.dumps(response, ensure_ascii=False, indent=2)}") logger.info(f"创建响应: {json.dumps(response, ensure_ascii=False, indent=2)}")
logger.info(f"集合 '{collection_name}' 创建成功") logger.info(f"集合 '{collection_name}' 创建成功")
# 等待集合就绪 # 等待集合就绪 - 修改为模拟chat.py的索引就绪检查机制
logger.info("等待集合初始化完成...") logger.info("等待集合初始化完成...")
# 将独立函数调用改为实例方法调用 max_retries = 30 # 最多重试30次
if memory_service.wait_for_collection_ready(): retry_interval = 10 # 每10秒重试一次
logger.info(f"集合 '{collection_name}' 已就绪,可以开始使用") retry_count = 0
# 增加初始延迟,避免创建后立即检查
logger.info(f"初始延迟30秒等待索引构建...")
time.sleep(30)
while retry_count < max_retries:
try:
# 尝试执行需要索引的操作模拟chat.py中的搜索逻辑
filter_params = {"memory_type": ["sys_event_v1"]}
memory_service.search_memory(
collection_name=collection_name,
query="test",
filter=filter_params,
limit=1
)
# 如果没有抛出索引错误,则认为就绪
logger.info(f"集合 '{collection_name}' 索引构建完成,已就绪")
return True return True
except VikingDBMemoryException as e:
error_msg = str(e)
# 检查是否是索引未就绪相关错误
if "index not exist" in error_msg or "need to add messages" in error_msg:
retry_count += 1
logger.info(f"索引尚未就绪,等待中... (重试 {retry_count}/{max_retries})")
time.sleep(retry_interval)
else: else:
logger.info(f"集合 '{collection_name}' 初始化超时") logger.error(f"检查索引就绪状态时发生错误: {str(e)}")
return False
except Exception as e:
logger.error(f"检查过程发生意外错误: {str(e)}")
return False
logger.error(f"集合 '{collection_name}' 索引构建超时")
return False return False
except Exception as e: except Exception as e:

View File

@@ -0,0 +1,107 @@
import json
import logging
import time
from Volcengine.VikingDBMemoryService import initialize_services, MEMORY_COLLECTION_NAME
from volcenginesdkarkruntime import Ark
from Config.Config import VOLC_API_KEY
# 配置日志
logging.basicConfig(level=logging.INFO)
logger = logging.getLogger('MemoryChatSimulator')
def simulate_chat_session(memory_service, llm_client, user_id, session_topic):
"""模拟完整对话会话:创建对话→存储记忆→搜索记忆"""
# 初始化对话历史
conversation_history = []
logger.info(f"\n=== 开始新对话会话: {session_topic} ===")
# 模拟两轮对话
user_messages = [
f"你好,我是{user_id}{session_topic}",
"我最近感觉压力很大,不知道该怎么办"
]
for message in user_messages:
# 处理对话轮次
assistant_reply = memory_service.handle_conversation_turn(
llm_client=llm_client,
user_id=user_id,
user_message=message,
conversation_history=conversation_history
)
time.sleep(1) # 模拟思考延迟
# 归档对话记忆
logger.info("\n=== 归档对话记忆 ===")
archive_success = memory_service.archive_conversation(
user_id=user_id,
assistant_id="simulated_assistant",
conversation_history=conversation_history,
topic_name=session_topic
)
# 搜索相关记忆
if archive_success:
logger.info("\n=== 搜索相关记忆 ===")
# 等待索引更新
logger.info("等待记忆库索引更新...")
time.sleep(5) # 实际环境可能需要更长时间
# 搜索与当前主题相关的记忆
search_query = f"{session_topic}相关的问题"
relevant_memories = memory_service.search_relevant_memories(
collection_name=MEMORY_COLLECTION_NAME,
user_id=user_id,
query=search_query,
limit=3
)
logger.info(f"\n搜索到 '{search_query}' 的相关记忆 ({len(relevant_memories)}条):")
for i, memory in enumerate(relevant_memories, 1):
logger.info(f"\n记忆 {i} (相关度: {memory['score']:.3f}):")
logger.info(f"内容: {memory['memory_info']}")
return archive_success
def main():
"""主函数:初始化服务并执行模拟测试"""
logger.info("===== 记忆库存储与搜索模拟测试 ====")
try:
# 初始化服务
logger.info("初始化记忆库服务和LLM客户端...")
memory_service, llm_client = initialize_services()
# 确保集合就绪
logger.info(f"检查集合 '{MEMORY_COLLECTION_NAME}' 是否就绪...")
# 先验证集合是否存在
try:
collection_info = memory_service.get_collection(MEMORY_COLLECTION_NAME)
logger.info(f"集合 '{MEMORY_COLLECTION_NAME}' 存在,详细信息: {json.dumps(collection_info, ensure_ascii=False)}...")
except Exception as e:
logger.error(f"集合 '{MEMORY_COLLECTION_NAME}' 不存在: {str(e)}")
logger.error("请先运行T2_CreateIndex.py创建集合")
return
if not memory_service.wait_for_collection_ready(timeout=120):
logger.error("集合未就绪,无法继续测试")
return
# 模拟两个不同用户的对话会话
user_sessions = [
("student_001", "学习压力问题"),
("teacher_001", "教学方法讨论")
]
for user_id, topic in user_sessions:
simulate_chat_session(memory_service, llm_client, user_id, topic)
time.sleep(2) # 会话间隔
logger.info("\n===== 模拟测试完成 ====")
except Exception as e:
logger.error(f"测试过程中发生错误: {str(e)}", exc_info=True)
if __name__ == "__main__":
main()

View File

@@ -377,17 +377,22 @@ class VikingDBMemoryService(Service):
start_time = time.time() start_time = time.time()
while time.time() - start_time < timeout: while time.time() - start_time < timeout:
try: try:
# 使用类中定义的集合名称常量
collection_info = self.get_collection(MEMORY_COLLECTION_NAME) collection_info = self.get_collection(MEMORY_COLLECTION_NAME)
status = collection_info.get("Status", "UNKNOWN") # 打印完整集合信息用于调试
logger.info(f"集合详细信息: {json.dumps(collection_info, ensure_ascii=False)}")
# 尝试多种可能的状态字段名
status = collection_info.get("Status") or collection_info.get("status") or "UNKNOWN"
logger.info(f"集合 '{MEMORY_COLLECTION_NAME}' 当前状态: {status}") logger.info(f"集合 '{MEMORY_COLLECTION_NAME}' 当前状态: {status}")
if status == "READY":
# 检查是否为就绪状态可能的值READY, RUNNING, ACTIVE等
if status in ["READY", "RUNNING", "ACTIVE"]:
return True return True
time.sleep(interval) time.sleep(interval)
except Exception as e: except Exception as e:
logger.info(f"检查集合状态失败: {e}") logger.error(f"检查集合状态失败: {str(e)}")
time.sleep(interval) time.sleep(interval)
logger.info(f"集合 '{MEMORY_COLLECTION_NAME}'{timeout}秒内未就绪") logger.error(f"集合 '{MEMORY_COLLECTION_NAME}'{timeout}秒内未就绪")
return False return False
def setup_memory_collection(self): def setup_memory_collection(self):

View File

@@ -0,0 +1,403 @@
import json
import threading
from volcengine.ApiInfo import ApiInfo
from volcengine.Credentials import Credentials
from volcengine.base.Service import Service
from volcengine.ServiceInfo import ServiceInfo
from volcengine.auth.SignerV4 import SignerV4
from volcengine.base.Request import Request
from Config.Config import VOLC_SECRETKEY, VOLC_API_KEY, VOLC_ACCESSKEY
class VikingDBMemoryException(Exception):
def __init__(self, code, request_id, message=None):
self.code = code
self.request_id = request_id
self.message = "{}, code:{}request_id:{}".format(message, self.code, self.request_id)
def __str__(self):
return self.message
class VikingDBMemoryService(Service):
_instance_lock = threading.Lock()
def __new__(cls, *args, **kwargs):
if not hasattr(VikingDBMemoryService, "_instance"):
with VikingDBMemoryService._instance_lock:
if not hasattr(VikingDBMemoryService, "_instance"):
VikingDBMemoryService._instance = object.__new__(cls)
return VikingDBMemoryService._instance
def __init__(self, host="api-knowledgebase.mlp.cn-beijing.volces.com", region="cn-beijing", ak="", sk="",
sts_token="", scheme='https',
connection_timeout=30, socket_timeout=30):
self.service_info = VikingDBMemoryService.get_service_info(host, region, scheme, connection_timeout,
socket_timeout)
self.api_info = VikingDBMemoryService.get_api_info()
super(VikingDBMemoryService, self).__init__(self.service_info, self.api_info)
if ak:
self.set_ak(ak)
if sk:
self.set_sk(sk)
if sts_token:
self.set_session_token(session_token=sts_token)
try:
self.get_body("Ping", {}, json.dumps({}))
except Exception as e:
raise VikingDBMemoryException(1000028, "missed", "host or region is incorrect".format(str(e))) from None
def setHeader(self, header):
api_info = VikingDBMemoryService.get_api_info()
for key in api_info:
for item in header:
api_info[key].header[item] = header[item]
self.api_info = api_info
@staticmethod
def get_service_info(host, region, scheme, connection_timeout, socket_timeout):
service_info = ServiceInfo(host, {"Host": host},
Credentials('', '', 'air', region), connection_timeout, socket_timeout,
scheme=scheme)
return service_info
@staticmethod
def get_api_info():
api_info = {
"CreateCollection": ApiInfo("POST", "/api/memory/collection/create", {}, {},
{'Accept': 'application/json', 'Content-Type': 'application/json'}),
"GetCollection": ApiInfo("POST", "/api/memory/collection/info", {}, {},
{'Accept': 'application/json', 'Content-Type': 'application/json'}),
"DropCollection": ApiInfo("POST", "/api/memory/collection/delete", {}, {},
{'Accept': 'application/json', 'Content-Type': 'application/json'}),
"UpdateCollection": ApiInfo("POST", "/api/memory/collection/update", {}, {},
{'Accept': 'application/json', 'Content-Type': 'application/json'}),
"SearchMemory": ApiInfo("POST", "/api/memory/search", {}, {},
{'Accept': 'application/json', 'Content-Type': 'application/json'}),
"AddSession": ApiInfo("POST", "/api/memory/session/add", {}, {},
{'Accept': 'application/json', 'Content-Type': 'application/json'}),
"Ping": ApiInfo("GET", "/api/memory/ping", {}, {},
{'Accept': 'application/json', 'Content-Type': 'application/json'}),
}
return api_info
def get_body(self, api, params, body):
if not (api in self.api_info):
raise Exception("no such api")
api_info = self.api_info[api]
r = self.prepare_request(api_info, params)
r.headers['Content-Type'] = 'application/json'
r.headers['Traffic-Source'] = 'SDK'
r.body = body
SignerV4.sign(r, self.service_info.credentials)
url = r.build()
resp = self.session.get(url, headers=r.headers, data=r.body,
timeout=(self.service_info.connection_timeout, self.service_info.socket_timeout))
if resp.status_code == 200:
return json.dumps(resp.json())
else:
raise Exception(resp.text.encode("utf-8"))
def get_body_exception(self, api, params, body):
try:
res = self.get_body(api, params, body)
except Exception as e:
try:
res_json = json.loads(e.args[0].decode("utf-8"))
except:
raise VikingDBMemoryException(1000028, "missed", "json load res error, res:{}".format(str(e))) from None
code = res_json.get("code", 1000028)
request_id = res_json.get("request_id", 1000028)
message = res_json.get("message", None)
raise VikingDBMemoryException(code, request_id, message)
if res == '':
raise VikingDBMemoryException(1000028, "missed",
"empty response due to unknown error, please contact customer service") from None
return res
def get_exception(self, api, params):
try:
res = self.get(api, params)
except Exception as e:
try:
res_json = json.loads(e.args[0].decode("utf-8"))
except:
raise VikingDBMemoryException(1000028, "missed", "json load res error, res:{}".format(str(e))) from None
code = res_json.get("code", 1000028)
request_id = res_json.get("request_id", 1000028)
message = res_json.get("message", None)
raise VikingDBMemoryException(code, request_id, message)
if res == '':
raise VikingDBMemoryException(1000028, "missed",
"empty response due to unknown error, please contact customer service") from None
return res
def create_collection(self, collection_name, description="", custom_event_type_schemas=[],
custom_entity_type_schemas=[], builtin_event_types=[], builtin_entity_types=[]):
params = {
"CollectionName": collection_name, "Description": description,
"CustomEventTypeSchemas": custom_event_type_schemas, "CustomEntityTypeSchemas": custom_entity_type_schemas,
"BuiltinEventTypes": builtin_event_types, "BuiltinEntityTypes": builtin_entity_types,
}
res = self.json("CreateCollection", {}, json.dumps(params))
return json.loads(res)
def get_collection(self, collection_name):
params = {"CollectionName": collection_name}
res = self.json("GetCollection", {}, json.dumps(params))
return json.loads(res)
def drop_collection(self, collection_name):
params = {"CollectionName": collection_name}
res = self.json("DropCollection", {}, json.dumps(params))
return json.loads(res)
def update_collection(self, collection_name, custom_event_type_schemas=[], custom_entity_type_schemas=[],
builtin_event_types=[], builtin_entity_types=[]):
params = {
"CollectionName": collection_name,
"CustomEventTypeSchemas": custom_event_type_schemas, "CustomEntityTypeSchemas": custom_entity_type_schemas,
"BuiltinEventTypes": builtin_event_types, "BuiltinEntityTypes": builtin_entity_types,
}
res = self.json("UpdateCollection", {}, json.dumps(params))
return json.loads(res)
def search_memory(self, collection_name, query, filter, limit=10):
params = {
"collection_name": collection_name,
"query": query,
"limit": limit,
"filter": filter,
}
res = self.json("SearchMemory", {}, json.dumps(params))
return json.loads(res)
def add_session(self, collection_name, session_id, messages, metadata, entities=None):
params = {
"collection_name": collection_name,
"session_id": session_id,
"messages": messages,
"metadata": metadata,
}
if entities is not None:
params["entities"] = entities
res = self.json("AddSession", {}, json.dumps(params))
return json.loads(res)
import json
import os
import time
from dotenv import load_dotenv
from volcenginesdkarkruntime import Ark
def initialize_services():
load_dotenv()
ak = VOLC_ACCESSKEY
sk = VOLC_SECRETKEY
ark_api_key = VOLC_API_KEY
if not all([ak, sk, ark_api_key]):
raise ValueError("必须在环境变量中设置 VOLC_ACCESSKEY, VOLC_SECRETKEY, 和 ARK_API_KEY。")
memory_service = VikingDBMemoryService(ak=ak, sk=sk)
llm_client = Ark(
base_url="https://ark.cn-beijing.volces.com/api/v3",
api_key=ark_api_key,
)
return memory_service, llm_client
def ensure_collection_exists(memory_service, collection_name):
"""检查记忆集合是否存在,如果不存在则创建。"""
try:
memory_service.get_collection(collection_name)
print(f"记忆集合 '{collection_name}' 已存在。")
except Exception as e:
error_message = str(e)
if "collection not exist" in error_message:
print(f"记忆集合 '{collection_name}' 未找到,正在创建...")
try:
memory_service.create_collection(
collection_name=collection_name,
description="中文情感陪伴场景测试",
builtin_event_types=["sys_event_v1", "sys_profile_collect_v1"],
builtin_entity_types=["sys_profile_v1"]
)
print(f"记忆集合 '{collection_name}' 创建成功。")
print("等待集合准备就绪...")
except Exception as create_e:
print(f"创建集合失败: {create_e}")
raise
else:
print(f"检查集合时出错: {e}")
raise
def search_relevant_memories(memory_service, collection_name, user_id, query):
"""搜索与用户查询相关的记忆,并在索引构建中时重试。"""
print(f"正在搜索与 '{query}' 相关的记忆...")
retry_attempt = 0
while True:
try:
filter_params = {
"user_id": [user_id],
"memory_type": ["sys_event_v1", "sys_profile_v1"]
}
response = memory_service.search_memory(
collection_name=collection_name,
query=query,
filter=filter_params,
limit=3
)
memories = []
if response.get('data', {}).get('count', 0) > 0:
for result in response['data']['result_list']:
if 'memory_info' in result and result['memory_info']:
memories.append({
'memory_info': result['memory_info'],
'score': result['score']
})
if memories:
if retry_attempt > 0:
print("重试后搜索成功。")
print(f"找到 {len(memories)} 条相关记忆:")
for i, memory in enumerate(memories, 1):
print(
f" {i}. (相关度: {memory['score']:.3f}): {json.dumps(memory['memory_info'], ensure_ascii=False, indent=2)}")
else:
print("未找到相关记忆。")
return memories
except Exception as e:
error_message = str(e)
if "1000023" in error_message:
retry_attempt += 1
print(f"记忆索引正在构建中。将在60秒后重试... (尝试次数 {retry_attempt})")
time.sleep(60)
else:
print(f"搜索记忆时出错 (不可重试): {e}")
return []
def handle_conversation_turn(memory_service, llm_client, collection_name, user_id, user_message, conversation_history):
"""处理一轮对话包括记忆搜索和LLM响应。"""
print("\n" + "=" * 60)
print(f"用户: {user_message}")
relevant_memories = search_relevant_memories(memory_service, collection_name, user_id, user_message)
system_prompt = "你是一个富有同情心、善于倾听的AI伙伴拥有长期记忆能力。你的目标是为用户提供情感支持和温暖的陪伴。"
if relevant_memories:
memory_context = "\n".join(
[f"- {json.dumps(mem['memory_info'], ensure_ascii=False)}" for mem in relevant_memories])
system_prompt += f"\n\n这是我们过去的一些对话记忆,请参考:\n{memory_context}\n\n请利用这些信息来更好地理解和回应用户。"
print("AI正在思考...")
try:
messages = [{"role": "system", "content": system_prompt}] + conversation_history + [
{"role": "user", "content": user_message}]
completion = llm_client.chat.completions.create(
model="doubao-seed-1-6-flash-250715",
messages=messages
)
assistant_reply = completion.choices[0].message.content
except Exception as e:
print(f"LLM调用失败: {e}")
assistant_reply = "抱歉,我现在有点混乱,无法回应。我们可以稍后再聊吗?"
print(f"伙伴: {assistant_reply}")
conversation_history.extend([
{"role": "user", "content": user_message},
{"role": "assistant", "content": assistant_reply}
])
return assistant_reply
def archive_conversation(memory_service, collection_name, user_id, assistant_id, conversation_history, topic_name):
"""将对话历史归档到记忆数据库。"""
if not conversation_history:
print("没有对话可以归档。")
return False
print(f"\n正在归档关于 '{topic_name}' 的对话...")
session_id = f"{topic_name}_{int(time.time())}"
metadata = {
"default_user_id": user_id,
"default_assistant_id": assistant_id,
"time": int(time.time() * 1000)
}
try:
memory_service.add_session(
collection_name=collection_name,
session_id=session_id,
messages=conversation_history,
metadata=metadata
)
print(f"对话已成功归档会话ID: {session_id}")
print("正在等待记忆索引更新...")
return True
except Exception as e:
print(f"归档对话失败: {e}")
return False
def main():
print("开始端到端记忆测试...")
try:
memory_service, llm_client = initialize_services()
collection_name = "emotional_support"
user_id = "xiaoming"
assistant_id = "assistant"
ensure_collection_exists(memory_service, collection_name)
except Exception as e:
print(f"初始化失败: {e}")
return
print("\n--- 阶段 1: 初始对话 ---")
initial_conversation_history = []
handle_conversation_turn(
memory_service, llm_client, collection_name, user_id,
"你好我是小明今年18岁但压力好大。",
initial_conversation_history
)
handle_conversation_turn(
memory_service, llm_client, collection_name, user_id,
"马上就要高考了,家里人的期待好高。",
initial_conversation_history
)
print("\n--- 阶段 2: 归档记忆 ---")
archive_conversation(
memory_service, collection_name, user_id, assistant_id,
initial_conversation_history, "study_stress_discussion"
)
print("\n--- 阶段 3: 验证记忆 ---")
verification_conversation_history = []
handle_conversation_turn(
memory_service, llm_client, collection_name, user_id,
"我最近很焦虑,不知道该怎么办。",
verification_conversation_history
)
print("\n端到端记忆测试完成!")
if __name__ == "__main__":
main()