Files
dsProject/dsLightRag/Volcengine/T3_ChatWithMemory.py
2025-09-07 13:31:59 +08:00

248 lines
9.2 KiB
Python
Raw Blame History

This file contains ambiguous Unicode characters

This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.

import logging
import sys
import time
import json
from Config.Config import VOLC_ACCESSKEY, VOLC_SECRETKEY, VOLC_API_KEY
from Volcengine.Kit.VikingDBMemoryService import VikingDBMemoryService, MEMORY_COLLECTION_NAME
from volcenginesdkarkruntime import Ark
# Module-level logger for this demo; INFO and above are emitted.
_LOG_FORMAT = '%(asctime)s - %(name)s - %(levelname)s - %(message)s'
logger = logging.getLogger('ChatWithMemory')
logger.setLevel(logging.INFO)
# Attach the stream handler only when none is registered yet, so importing
# this module more than once does not duplicate every log line.
if not logger.handlers:
    handler = logging.StreamHandler()
    handler.setFormatter(logging.Formatter(_LOG_FORMAT))
    logger.addHandler(handler)
def initialize_services():
    """Build the memory service and the LLM client used by this demo.

    Credentials come from the ``VOLC_ACCESSKEY``, ``VOLC_SECRETKEY`` and
    ``VOLC_API_KEY`` constants imported from ``Config.Config``.

    Returns:
        A ``(memory_service, llm_client)`` tuple holding a
        ``VikingDBMemoryService`` and an ``Ark`` client.

    Raises:
        ValueError: If any of the three credentials is empty.
    """
    access_key, secret_key, api_key = VOLC_ACCESSKEY, VOLC_SECRETKEY, VOLC_API_KEY

    # Fail fast before constructing any client with an empty credential.
    # NOTE(review): the message names ARK_API_KEY while the imported constant
    # is VOLC_API_KEY -- confirm which name Config.Config actually reads.
    if not (access_key and secret_key and api_key):
        raise ValueError("必须在环境变量中设置 VOLC_ACCESSKEY, VOLC_SECRETKEY, 和 ARK_API_KEY。")

    viking_memory = VikingDBMemoryService(
        ak=access_key,
        sk=secret_key,
        host="api-knowledgebase.mlp.cn-beijing.volces.com",
        region="cn-beijing",
    )
    ark_client = Ark(
        base_url="https://ark.cn-beijing.volces.com/api/v3",
        api_key=api_key,
    )
    return viking_memory, ark_client
def _extract_memories(response):
    """Return the non-empty memory entries (with scores) from a search response."""
    # ``data`` may be absent or null; treat both as "no results".
    data = response.get('data') or {}
    if (data.get('count') or 0) <= 0:
        return []
    return [
        {'memory_info': result['memory_info'], 'score': result.get('score')}
        for result in data.get('result_list') or []
        if result.get('memory_info')
    ]


def search_relevant_memories(memory_service, collection_name, user_id, query,
                             *, max_retries=None, retry_delay=60):
    """Search the memory collection for entries relevant to ``query``.

    The search is limited to ``user_id`` and to the ``sys_event_v1`` /
    ``sys_profile_v1`` memory types, and returns at most three results.
    When the service raises an error whose text contains ``1000023`` (logged
    here as "memory index is still being built"), the call sleeps and retries.

    Args:
        memory_service: Object exposing ``search_memory(collection_name=...,
            query=..., filter=..., limit=...)`` that returns a dict.
        collection_name: Name of the memory collection to query.
        user_id: User whose memories are searched.
        query: Free-text query.
        max_retries: Maximum number of index-building retries. ``None``
            (the default) retries indefinitely, matching the previous
            behaviour.
        retry_delay: Seconds to sleep between retries (default 60).

    Returns:
        A list of ``{'memory_info': ..., 'score': ...}`` dicts; empty when
        nothing matched, on a non-retryable error, or when ``max_retries``
        is exhausted. This function does not raise on search failures.
    """
    logger.info(f"正在搜索与 '{query}' 相关的记忆...")
    filter_params = {
        "user_id": [user_id],
        "memory_type": ["sys_event_v1", "sys_profile_v1"]
    }
    retry_attempt = 0
    while True:
        try:
            response = memory_service.search_memory(
                collection_name=collection_name,
                query=query,
                filter=filter_params,
                limit=3
            )
            # Tolerant parsing: one result lacking a key must not discard
            # the whole result set via the broad handler below.
            memories = _extract_memories(response)
            if memories:
                if retry_attempt > 0:
                    logger.info("重试后搜索成功。")
                logger.info(f"找到 {len(memories)} 条相关记忆:")
                for i, memory in enumerate(memories, 1):
                    score = memory['score']
                    # A missing score cannot be formatted with ``:.3f``.
                    score_text = f"{score:.3f}" if isinstance(score, (int, float)) else str(score)
                    logger.info(
                        f" {i}. (相关度: {score_text}): {json.dumps(memory['memory_info'], ensure_ascii=False, indent=2)}")
            else:
                logger.info("未找到相关记忆。")
            return memories
        except Exception as e:
            # Only the index-building error is worth waiting for.
            if "1000023" not in str(e):
                logger.error(f"搜索记忆时出错 (不可重试): {e}")
                return []
            if max_retries is not None and retry_attempt >= max_retries:
                logger.error(f"搜索记忆时出错 (不可重试): {e}")
                return []
            retry_attempt += 1
            logger.info(f"记忆索引正在构建中。将在{retry_delay}秒后重试... (尝试次数 {retry_attempt})")
            time.sleep(retry_delay)
def handle_conversation_turn(memory_service, llm_client, collection_name, user_id, user_message,
                             conversation_history, model="doubao-seed-1-6-flash-250715"):
    """Process one user turn: look up memories, query the LLM, record the turn.

    Args:
        memory_service: Memory backend passed through to
            ``search_relevant_memories``.
        llm_client: Client exposing ``chat.completions.create(model=...,
            messages=...)``.
        collection_name: Memory collection to search.
        user_id: User whose memories are searched.
        user_message: The user's message for this turn.
        conversation_history: List of ``{"role", "content"}`` dicts for the
            current session. It is mutated in place: the user message and
            the assistant reply are appended.
        model: Model identifier sent to the LLM; defaults to the model that
            was previously hard-coded.

    Returns:
        The assistant's reply text. If the LLM call fails, a fixed apology
        message is returned (and recorded in the history) instead of raising.
    """
    logger.info("\n" + "=" * 60)
    logger.info(f"用户: {user_message}")
    relevant_memories = search_relevant_memories(memory_service, collection_name, user_id, user_message)

    system_prompt = "你是一个富有同情心、善于倾听的AI伙伴拥有长期记忆能力。你的目标是为用户提供情感支持和温暖的陪伴。"
    if relevant_memories:
        # Inject each retrieved memory as one bullet line of JSON.
        memory_context = "\n".join(
            f"- {json.dumps(mem['memory_info'], ensure_ascii=False)}" for mem in relevant_memories)
        system_prompt += f"\n\n这是我们过去的一些对话记忆,请参考:\n{memory_context}\n\n请利用这些信息来更好地理解和回应用户。"

    logger.info("AI正在思考...")
    # Built outside the try so only the remote call is covered by the fallback.
    messages = [{"role": "system", "content": system_prompt}] + conversation_history + [
        {"role": "user", "content": user_message}]
    try:
        completion = llm_client.chat.completions.create(
            model=model,
            messages=messages
        )
        assistant_reply = completion.choices[0].message.content
    except Exception as e:
        # Best-effort: keep the conversation going with a canned reply, but
        # surface the failure at ERROR level so it is not mistaken for info.
        logger.error(f"LLM调用失败: {e}")
        assistant_reply = "抱歉,我现在有点混乱,无法回应。我们可以稍后再聊吗?"

    logger.info(f"伙伴: {assistant_reply}")
    conversation_history.extend([
        {"role": "user", "content": user_message},
        {"role": "assistant", "content": assistant_reply}
    ])
    return assistant_reply
def archive_conversation(memory_service, collection_name, user_id, assistant_id, conversation_history, topic_name):
    """Store the conversation history as one session in the memory collection.

    Args:
        memory_service: Object exposing ``add_session(collection_name=...,
            session_id=..., messages=..., metadata=...)``.
        collection_name: Target memory collection.
        user_id: Stored as ``default_user_id`` in the session metadata.
        assistant_id: Stored as ``default_assistant_id`` in the metadata.
        conversation_history: List of ``{"role", "content"}`` message dicts.
        topic_name: Prefix of the generated session id.

    Returns:
        ``True`` if the session was added, ``False`` if the history was
        empty or the service call failed. Failures are logged, not raised.
    """
    if not conversation_history:
        logger.info("没有对话可以归档。")
        return False

    logger.info(f"\n正在归档关于 '{topic_name}' 的对话...")
    # Read the clock once so the session id (seconds) and the metadata
    # timestamp (milliseconds) always describe the same instant.
    now = time.time()
    session_id = f"{topic_name}_{int(now)}"
    metadata = {
        "default_user_id": user_id,
        "default_assistant_id": assistant_id,
        "time": int(now * 1000)
    }
    try:
        memory_service.add_session(
            collection_name=collection_name,
            session_id=session_id,
            messages=conversation_history,
            metadata=metadata
        )
    except Exception as e:
        logger.error(f"归档对话失败: {e}")
        return False

    logger.info(f"对话已成功归档会话ID: {session_id}")
    logger.info("正在等待记忆索引更新...")
    return True
def main():
    """Run an end-to-end check that the assistant recalls stored user facts.

    Steps: initialise the services, tell the model the user's details in one
    conversation turn, archive that conversation to memory, wait briefly,
    then ask about the user in a fresh conversation and check that the reply
    mentions the expected keywords. If the check fails, the raw result of a
    direct memory search is logged for diagnosis.

    Any exception is logged and the process exits with status 1.
    """
    logger.info("开始测试大模型记忆功能...")
    try:
        memory_service, llm_client = initialize_services()
        collection_name = MEMORY_COLLECTION_NAME
        user_id = "liming"
        assistant_id = "assistant"

        # Tell the model the user's details.
        logger.info("告知大模型用户信息...")
        user_info = "李明男生15岁家住长春"
        logger.info("记录用户信息到记忆体...")
        conversation_history = []
        response = handle_conversation_turn(
            memory_service=memory_service,
            llm_client=llm_client,
            collection_name=collection_name,
            user_id=user_id,
            user_message=f"请记住以下用户信息:{user_info}",
            conversation_history=conversation_history
        )
        logger.info(f"模型回复: {response}")

        # Persist the turn so it can be retrieved as long-term memory.
        archive_conversation(
            memory_service=memory_service,
            collection_name=collection_name,
            user_id=user_id,
            assistant_id=assistant_id,
            conversation_history=conversation_history,
            topic_name="user_info"
        )

        # Give the index a moment to pick up the new session.
        logger.info("等待索引更新...")
        time.sleep(5)

        # Ask in a fresh conversation so the answer can only come from memory.
        logger.info("验证大模型是否记住个人信息...")
        test_conversation_history = []
        test_response = handle_conversation_turn(
            memory_service=memory_service,
            llm_client=llm_client,
            collection_name=collection_name,
            user_id=user_id,
            user_message="请告诉我李明的个人信息",
            conversation_history=test_conversation_history
        )
        logger.info(f"测试回复: {test_response}")

        # Every keyword must be non-empty: an empty string is a substring of
        # any text and would make the check pass vacuously.
        keywords = ["李明", "男", "15", "长春"]
        # Guard against a None reply, which would make ``in`` raise TypeError.
        reply_text = test_response or ""
        found_keywords = [kw for kw in keywords if kw in reply_text]
        if len(found_keywords) == len(keywords):
            logger.info("✅ 大模型成功记住了用户信息!")
        else:
            logger.info(f"❌ 大模型可能没有完全记住用户信息。找到的关键词: {found_keywords}")
            # Query the memory store directly to see what was indexed.
            logger.info("尝试直接搜索记忆...")
            filter_params = {
                "user_id": [user_id],
                "memory_type": ["sys_event_v1", "sys_profile_v1"]
            }
            search_result = memory_service.search_memory(
                collection_name=collection_name,
                query="李明 15岁 长春 男生",
                filter=filter_params,
                limit=5
            )
            logger.info(f"搜索结果: {search_result}")
    except Exception as e:
        logger.error(f"操作失败: {e}")
        sys.exit(1)


# Run the demo only when executed as a script, not when imported.
if __name__ == "__main__":
    main()