107 lines
3.9 KiB
Python
107 lines
3.9 KiB
Python
|
import json
|
|||
|
import logging
|
|||
|
import time
|
|||
|
from Volcengine.VikingDBMemoryService import initialize_services, MEMORY_COLLECTION_NAME
|
|||
|
from volcenginesdkarkruntime import Ark
|
|||
|
from Config.Config import VOLC_API_KEY
|
|||
|
|
|||
|
# 配置日志
|
|||
|
logging.basicConfig(level=logging.INFO)
|
|||
|
logger = logging.getLogger('MemoryChatSimulator')
|
|||
|
|
|||
|
def simulate_chat_session(memory_service, llm_client, user_id, session_topic):
|
|||
|
"""模拟完整对话会话:创建对话→存储记忆→搜索记忆"""
|
|||
|
# 初始化对话历史
|
|||
|
conversation_history = []
|
|||
|
logger.info(f"\n=== 开始新对话会话: {session_topic} ===")
|
|||
|
|
|||
|
# 模拟两轮对话
|
|||
|
user_messages = [
|
|||
|
f"你好,我是{user_id},{session_topic}",
|
|||
|
"我最近感觉压力很大,不知道该怎么办"
|
|||
|
]
|
|||
|
|
|||
|
for message in user_messages:
|
|||
|
# 处理对话轮次
|
|||
|
assistant_reply = memory_service.handle_conversation_turn(
|
|||
|
llm_client=llm_client,
|
|||
|
user_id=user_id,
|
|||
|
user_message=message,
|
|||
|
conversation_history=conversation_history
|
|||
|
)
|
|||
|
time.sleep(1) # 模拟思考延迟
|
|||
|
|
|||
|
# 归档对话记忆
|
|||
|
logger.info("\n=== 归档对话记忆 ===")
|
|||
|
archive_success = memory_service.archive_conversation(
|
|||
|
user_id=user_id,
|
|||
|
assistant_id="simulated_assistant",
|
|||
|
conversation_history=conversation_history,
|
|||
|
topic_name=session_topic
|
|||
|
)
|
|||
|
|
|||
|
# 搜索相关记忆
|
|||
|
if archive_success:
|
|||
|
logger.info("\n=== 搜索相关记忆 ===")
|
|||
|
# 等待索引更新
|
|||
|
logger.info("等待记忆库索引更新...")
|
|||
|
time.sleep(5) # 实际环境可能需要更长时间
|
|||
|
|
|||
|
# 搜索与当前主题相关的记忆
|
|||
|
search_query = f"{session_topic}相关的问题"
|
|||
|
relevant_memories = memory_service.search_relevant_memories(
|
|||
|
collection_name=MEMORY_COLLECTION_NAME,
|
|||
|
user_id=user_id,
|
|||
|
query=search_query,
|
|||
|
limit=3
|
|||
|
)
|
|||
|
|
|||
|
logger.info(f"\n搜索到 '{search_query}' 的相关记忆 ({len(relevant_memories)}条):")
|
|||
|
for i, memory in enumerate(relevant_memories, 1):
|
|||
|
logger.info(f"\n记忆 {i} (相关度: {memory['score']:.3f}):")
|
|||
|
logger.info(f"内容: {memory['memory_info']}")
|
|||
|
|
|||
|
return archive_success
|
|||
|
|
|||
|
def main():
|
|||
|
"""主函数:初始化服务并执行模拟测试"""
|
|||
|
logger.info("===== 记忆库存储与搜索模拟测试 ====")
|
|||
|
|
|||
|
try:
|
|||
|
# 初始化服务
|
|||
|
logger.info("初始化记忆库服务和LLM客户端...")
|
|||
|
memory_service, llm_client = initialize_services()
|
|||
|
|
|||
|
# 确保集合就绪
|
|||
|
logger.info(f"检查集合 '{MEMORY_COLLECTION_NAME}' 是否就绪...")
|
|||
|
|
|||
|
# 先验证集合是否存在
|
|||
|
try:
|
|||
|
collection_info = memory_service.get_collection(MEMORY_COLLECTION_NAME)
|
|||
|
logger.info(f"集合 '{MEMORY_COLLECTION_NAME}' 存在,详细信息: {json.dumps(collection_info, ensure_ascii=False)}...")
|
|||
|
except Exception as e:
|
|||
|
logger.error(f"集合 '{MEMORY_COLLECTION_NAME}' 不存在: {str(e)}")
|
|||
|
logger.error("请先运行T2_CreateIndex.py创建集合")
|
|||
|
return
|
|||
|
|
|||
|
if not memory_service.wait_for_collection_ready(timeout=120):
|
|||
|
logger.error("集合未就绪,无法继续测试")
|
|||
|
return
|
|||
|
|
|||
|
# 模拟两个不同用户的对话会话
|
|||
|
user_sessions = [
|
|||
|
("student_001", "学习压力问题"),
|
|||
|
("teacher_001", "教学方法讨论")
|
|||
|
]
|
|||
|
|
|||
|
for user_id, topic in user_sessions:
|
|||
|
simulate_chat_session(memory_service, llm_client, user_id, topic)
|
|||
|
time.sleep(2) # 会话间隔
|
|||
|
|
|||
|
logger.info("\n===== 模拟测试完成 ====")
|
|||
|
|
|||
|
except Exception as e:
|
|||
|
logger.error(f"测试过程中发生错误: {str(e)}", exc_info=True)
|
|||
|
|
|||
|
if __name__ == "__main__":
|
|||
|
main()
|