Files
dsProject/dsLightRag/Volcengine/Kit/VikingDBMemoryService.py
2025-09-07 13:38:24 +08:00

554 lines
24 KiB
Python
Raw Blame History

This file contains ambiguous Unicode characters

This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.

"""
VikingDB memory service client and conversation helper functions.

Install the required SDKs first:
    pip install volcengine
    pip install --upgrade "volcengine-python-sdk[ark]"
"""
import json
import logging
import threading
import time
from dotenv import load_dotenv
from volcengine.ApiInfo import ApiInfo
from volcengine.Credentials import Credentials
from volcengine.ServiceInfo import ServiceInfo
from volcengine.auth.SignerV4 import SignerV4
from volcengine.base.Service import Service
from volcenginesdkarkruntime import Ark
from Config.Config import VOLC_SECRETKEY, VOLC_ACCESSKEY, VOLC_API_KEY
# Name of the memory collection used by the helpers in this module.
MEMORY_COLLECTION_NAME = "dsideal_collection"

# Module logger, emitting INFO and above.
logger = logging.getLogger('CollectionMemory')
logger.setLevel(logging.INFO)

# Attach the stream handler only when the logger has none yet, so the same
# record is not emitted several times if this setup code runs again.
if not logger.handlers:
    handler = logging.StreamHandler()
    handler.setFormatter(
        logging.Formatter('%(asctime)s - %(name)s - %(levelname)s - %(message)s')
    )
    logger.addHandler(handler)
class VikingDBMemoryException(Exception):
    """Error raised when a VikingDB memory API call fails.

    Attributes:
        code: Error code passed by the raiser.
        request_id: Request identifier passed by the raiser.
        message: Display text combining the message, the code and the
            request id; this is also what ``str(exc)`` returns.
    """

    def __init__(self, code, request_id, message=None):
        # Pass the raw arguments on so ``args`` stays
        # (code, request_id, message), matching the constructor signature.
        super().__init__(code, request_id, message)
        self.code = code
        self.request_id = request_id
        # Keep a separator between the code and the request id so the two
        # values do not run together in the rendered text.
        self.message = "{}, code:{}, request_id:{}".format(message, self.code, self.request_id)

    def __str__(self):
        return self.message
class VikingDBMemoryService(Service):
_instance_lock = threading.Lock()
def __new__(cls, *args, **kwargs):
    """Return the single shared instance, creating it on first use.

    The instance is stored on ``VikingDBMemoryService._instance``. Its
    presence is checked once without the lock and once more while holding
    ``_instance_lock`` before the object is created.
    """
    owner = VikingDBMemoryService
    if hasattr(owner, "_instance"):
        return owner._instance
    with owner._instance_lock:
        # Re-check under the lock: another thread may have created the
        # instance between the first check and acquiring the lock.
        if not hasattr(owner, "_instance"):
            owner._instance = object.__new__(cls)
    return owner._instance
def __init__(self, host="api-knowledgebase.mlp.cn-beijing.volces.com", region="cn-beijing", ak="", sk="",
             sts_token="", scheme='https',
             connection_timeout=30, socket_timeout=30):
    """Configure the service and verify connectivity with a Ping call.

    :param host: API host name.
    :param region: Region passed to the credentials object.
    :param ak: Access key; applied only when non-empty.
    :param sk: Secret key; applied only when non-empty.
    :param sts_token: Session token; applied only when non-empty.
    :param scheme: URL scheme.
    :param connection_timeout: Connection timeout handed to ServiceInfo.
    :param socket_timeout: Socket timeout handed to ServiceInfo.
    :raises VikingDBMemoryException: when the Ping request fails.
    """
    self.service_info = VikingDBMemoryService.get_service_info(host, region, scheme, connection_timeout,
                                                               socket_timeout)
    self.api_info = VikingDBMemoryService.get_api_info()
    super(VikingDBMemoryService, self).__init__(self.service_info, self.api_info)
    if ak:
        self.set_ak(ak)
    if sk:
        self.set_sk(sk)
    if sts_token:
        self.set_session_token(session_token=sts_token)
    try:
        self.get_body("Ping", {}, json.dumps({}))
    except Exception as e:
        # Include the underlying error text: the format string needs a
        # placeholder, otherwise the argument to format() is dropped.
        raise VikingDBMemoryException(
            1000028, "missed", "host or region is incorrect: {}".format(str(e))) from None
def setHeader(self, header):
    """Rebuild the API table with the entries of ``header`` added to every API.

    A fresh table is obtained from ``get_api_info()``, each key/value of
    ``header`` is written into every entry's ``header`` mapping, and the
    result replaces ``self.api_info``.
    """
    refreshed = VikingDBMemoryService.get_api_info()
    for info in refreshed.values():
        for name in header:
            info.header[name] = header[name]
    self.api_info = refreshed
@staticmethod
def get_service_info(host, region, scheme, connection_timeout, socket_timeout):
    """Build the ServiceInfo describing the endpoint.

    The credentials are created with empty key strings, the literal 'air'
    (presumably the service name - confirm against the SDK) and ``region``.
    """
    default_headers = {"Host": host}
    credentials = Credentials('', '', 'air', region)
    return ServiceInfo(host, default_headers, credentials, connection_timeout, socket_timeout,
                       scheme=scheme)
@staticmethod
def get_api_info():
    """Return a new mapping of API name to ApiInfo for the memory endpoints.

    Every entry gets its own query, form and header dictionaries, so
    changing one entry's headers does not affect the others.
    """
    routes = (
        ("CreateCollection", "POST", "/api/memory/collection/create"),
        ("GetCollection", "POST", "/api/memory/collection/info"),
        ("DropCollection", "POST", "/api/memory/collection/delete"),
        ("UpdateCollection", "POST", "/api/memory/collection/update"),
        ("SearchMemory", "POST", "/api/memory/search"),
        ("AddSession", "POST", "/api/memory/session/add"),
        ("Ping", "GET", "/api/memory/ping"),
    )
    return {
        name: ApiInfo(method, path, {}, {},
                      {'Accept': 'application/json', 'Content-Type': 'application/json'})
        for name, method, path in routes
    }
def get_body(self, api, params, body):
    """Send a signed request carrying ``body`` and return the JSON reply as text.

    :param api: Key into ``self.api_info``.
    :param params: Query parameters passed to ``prepare_request``.
    :param body: Request body (already serialised).
    :return: The response JSON re-serialised with ``json.dumps``.
    :raises Exception: "no such api" for an unknown ``api``; for any
        non-200 status, an Exception whose argument is the UTF-8 encoded
        response text.
    """
    if api not in self.api_info:
        raise Exception("no such api")
    request = self.prepare_request(self.api_info[api], params)
    request.headers['Content-Type'] = 'application/json'
    request.headers['Traffic-Source'] = 'SDK'
    request.body = body
    SignerV4.sign(request, self.service_info.credentials)
    url = request.build()
    timeouts = (self.service_info.connection_timeout, self.service_info.socket_timeout)
    # The request is always issued through session.get, whatever HTTP
    # method the ApiInfo entry declares.
    response = self.session.get(url, headers=request.headers, data=request.body, timeout=timeouts)
    if response.status_code != 200:
        raise Exception(response.text.encode("utf-8"))
    return json.dumps(response.json())
def get_body_exception(self, api, params, body):
    """Call ``get_body`` and convert failures into VikingDBMemoryException.

    :param api: API name forwarded to ``get_body``.
    :param params: Query parameters forwarded to ``get_body``.
    :param body: Request body forwarded to ``get_body``.
    :return: The text returned by ``get_body``.
    :raises VikingDBMemoryException: when ``get_body`` raises (code,
        request id and message are taken from the JSON error payload when
        it can be parsed) or when the response text is empty.
    """
    try:
        res = self.get_body(api, params, body)
    except Exception as e:
        try:
            # get_body raises Exception(<utf-8 bytes>) for non-200 replies.
            res_json = json.loads(e.args[0].decode("utf-8"))
        except Exception:
            # Only ordinary errors mean "payload is not JSON"; interpreter
            # signals such as KeyboardInterrupt must keep propagating.
            raise VikingDBMemoryException(1000028, "missed", "json load res error, res:{}".format(str(e))) from None
        code = res_json.get("code", 1000028)
        request_id = res_json.get("request_id", 1000028)
        message = res_json.get("message", None)
        raise VikingDBMemoryException(code, request_id, message)
    if res == '':
        raise VikingDBMemoryException(1000028, "missed",
                                      "empty response due to unknown error, please contact customer service") from None
    return res
def get_exception(self, api, params):
    """Call ``self.get`` and convert failures into VikingDBMemoryException.

    ``self.get`` is not defined in this class; presumably it is inherited
    from the Service base class.

    :param api: API name forwarded to ``self.get``.
    :param params: Query parameters forwarded to ``self.get``.
    :return: The text returned by ``self.get``.
    :raises VikingDBMemoryException: when ``self.get`` raises (code,
        request id and message are taken from the JSON error payload when
        it can be parsed) or when the response text is empty.
    """
    try:
        res = self.get(api, params)
    except Exception as e:
        try:
            res_json = json.loads(e.args[0].decode("utf-8"))
        except Exception:
            # Only ordinary errors mean "payload is not JSON"; interpreter
            # signals such as KeyboardInterrupt must keep propagating.
            raise VikingDBMemoryException(1000028, "missed", "json load res error, res:{}".format(str(e))) from None
        code = res_json.get("code", 1000028)
        request_id = res_json.get("request_id", 1000028)
        message = res_json.get("message", None)
        raise VikingDBMemoryException(code, request_id, message)
    if res == '':
        raise VikingDBMemoryException(1000028, "missed",
                                      "empty response due to unknown error, please contact customer service") from None
    return res
def create_collection(self, collection_name, description="", custom_event_type_schemas=None,
                      custom_entity_type_schemas=None, builtin_event_types=None, builtin_entity_types=None):
    """Create a memory collection and return the decoded JSON response.

    The four schema/type arguments default to ``None`` and are sent as
    empty lists in that case, so no list object is shared between calls.

    :param collection_name: Name of the collection to create.
    :param description: Free-text description.
    :param custom_event_type_schemas: Custom event type schemas.
    :param custom_entity_type_schemas: Custom entity type schemas.
    :param builtin_event_types: Built-in event type names.
    :param builtin_entity_types: Built-in entity type names.
    :return: The response parsed with ``json.loads``.
    """
    def _as_list(value):
        return [] if value is None else value

    params = {
        "CollectionName": collection_name, "Description": description,
        "CustomEventTypeSchemas": _as_list(custom_event_type_schemas),
        "CustomEntityTypeSchemas": _as_list(custom_entity_type_schemas),
        "BuiltinEventTypes": _as_list(builtin_event_types),
        "BuiltinEntityTypes": _as_list(builtin_entity_types),
    }
    res = self.json("CreateCollection", {}, json.dumps(params))
    return json.loads(res)
def get_collection(self, collection_name):
    """Fetch the collection's information and return the decoded JSON response."""
    payload = json.dumps({"CollectionName": collection_name})
    return json.loads(self.json("GetCollection", {}, payload))
def drop_collection(self, collection_name):
    """Request deletion of the collection and return the decoded JSON response."""
    payload = json.dumps({"CollectionName": collection_name})
    return json.loads(self.json("DropCollection", {}, payload))
def update_collection(self, collection_name, custom_event_type_schemas=None, custom_entity_type_schemas=None,
                      builtin_event_types=None, builtin_entity_types=None):
    """Update a memory collection and return the decoded JSON response.

    The four schema/type arguments default to ``None`` and are sent as
    empty lists in that case, so no list object is shared between calls.

    :param collection_name: Name of the collection to update.
    :param custom_event_type_schemas: Custom event type schemas.
    :param custom_entity_type_schemas: Custom entity type schemas.
    :param builtin_event_types: Built-in event type names.
    :param builtin_entity_types: Built-in entity type names.
    :return: The response parsed with ``json.loads``.
    """
    def _as_list(value):
        return [] if value is None else value

    params = {
        "CollectionName": collection_name,
        "CustomEventTypeSchemas": _as_list(custom_event_type_schemas),
        "CustomEntityTypeSchemas": _as_list(custom_entity_type_schemas),
        "BuiltinEventTypes": _as_list(builtin_event_types),
        "BuiltinEntityTypes": _as_list(builtin_entity_types),
    }
    res = self.json("UpdateCollection", {}, json.dumps(params))
    return json.loads(res)
def search_memory(self, collection_name, query, filter, limit=10):
    """Run a memory search and return the decoded JSON response.

    :param collection_name: Collection to search.
    :param query: Query text.
    :param filter: Filter object sent as-is under the "filter" key.
    :param limit: Value sent under the "limit" key.
    """
    request = dict(collection_name=collection_name, query=query, limit=limit, filter=filter)
    raw = self.json("SearchMemory", {}, json.dumps(request))
    return json.loads(raw)
def add_session(self, collection_name, session_id, messages, metadata, entities=None):
    """Submit a session's messages and return the decoded JSON response.

    :param collection_name: Target collection.
    :param session_id: Identifier of the session.
    :param messages: Messages sent under the "messages" key.
    :param metadata: Metadata sent under the "metadata" key.
    :param entities: Optional; the "entities" key is sent only when this
        is not ``None``.
    """
    request = dict(
        collection_name=collection_name,
        session_id=session_id,
        messages=messages,
        metadata=metadata,
    )
    if entities is not None:
        request.update(entities=entities)
    raw = self.json("AddSession", {}, json.dumps(request))
    return json.loads(raw)
def handle_conversation_turn(self, llm_client, user_id, user_message, conversation_history):
    """Handle one conversation turn: memory lookup, LLM call, history update.

    Memories are searched in ``MEMORY_COLLECTION_NAME`` and, when any are
    found, appended to the system prompt. If the LLM call fails, a fixed
    apology is used as the reply. The user message and the reply are both
    appended to ``conversation_history`` in place.

    :return: The assistant reply text.
    """
    logger.info("\n" + "=" * 60)
    logger.info(f"用户: {user_message}")

    # Look up memories through this instance's own search method.
    recalled = self.search_relevant_memories(MEMORY_COLLECTION_NAME, user_id, user_message)

    system_prompt = "你是一个富有同情心、善于倾听的AI伙伴拥有长期记忆能力。你的目标是为用户提供情感支持和温暖的陪伴。"
    if recalled:
        memory_context = "\n".join(
            f"- {json.dumps(entry['memory_info'], ensure_ascii=False)}" for entry in recalled
        )
        system_prompt += f"\n\n这是我们过去的一些对话记忆,请参考:\n{memory_context}\n\n请利用这些信息来更好地理解和回应用户。"

    logger.info("AI正在思考...")
    try:
        system_turn = {"role": "system", "content": system_prompt}
        messages = [system_turn] + conversation_history + [{"role": "user", "content": user_message}]
        completion = llm_client.chat.completions.create(
            model="doubao-seed-1-6-flash-250715",
            messages=messages,
        )
        assistant_reply = completion.choices[0].message.content
    except Exception as e:
        logger.info(f"LLM调用失败: {e}")
        assistant_reply = "抱歉,我现在有点混乱,无法回应。我们可以稍后再聊吗?"

    logger.info(f"伙伴: {assistant_reply}")
    conversation_history.extend([
        {"role": "user", "content": user_message},
        {"role": "assistant", "content": assistant_reply},
    ])
    return assistant_reply
def archive_conversation(self, user_id, assistant_id, conversation_history, topic_name):
    """Archive the conversation history into ``MEMORY_COLLECTION_NAME``.

    The session id is ``<topic_name>_<unix seconds>`` and the metadata
    carries the user id, the assistant id and the time in milliseconds.

    :return: True when ``add_session`` succeeds; False when the history is
        empty or ``add_session`` raises.
    """
    if not conversation_history:
        logger.info("没有对话可以归档。")
        return False

    logger.info(f"\n正在归档关于 '{topic_name}' 的对话...")
    session_id = f"{topic_name}_{int(time.time())}"
    metadata = dict(
        default_user_id=user_id,
        default_assistant_id=assistant_id,
        time=int(time.time() * 1000),
    )
    try:
        self.add_session(
            collection_name=MEMORY_COLLECTION_NAME,
            session_id=session_id,
            messages=conversation_history,
            metadata=metadata,
        )
        logger.info(f"对话已成功归档会话ID: {session_id}")
        logger.info("正在等待记忆索引更新...")
    except Exception as e:
        logger.info(f"归档对话失败: {e}")
        return False
    return True
def wait_for_collection_ready(self, timeout=300, interval=10):
    """Poll ``MEMORY_COLLECTION_NAME`` until it reports a ready status.

    :param timeout: Maximum time to keep polling, in seconds.
    :param interval: Pause between polls, in seconds.
    :return: True once the status is READY, RUNNING or ACTIVE; False when
        the timeout elapses first.
    """
    ready_states = ["READY", "RUNNING", "ACTIVE"]
    began_at = time.time()
    while time.time() - began_at < timeout:
        try:
            info = self.get_collection(MEMORY_COLLECTION_NAME)
            # Log the full collection information for debugging.
            logger.info(f"集合详细信息: {json.dumps(info, ensure_ascii=False)}")
            # The status may be under either key spelling; fall back to
            # "UNKNOWN" when neither holds a truthy value.
            status = info.get("Status")
            if not status:
                status = info.get("status")
            if not status:
                status = "UNKNOWN"
            logger.info(f"集合 '{MEMORY_COLLECTION_NAME}' 当前状态: {status}")
            if status in ready_states:
                return True
            time.sleep(interval)
        except Exception as exc:
            logger.error(f"检查集合状态失败: {str(exc)}")
            time.sleep(interval)
    logger.error(f"集合 '{MEMORY_COLLECTION_NAME}'{timeout}秒内未就绪")
    return False
def initialize_services(self, ak=None, sk=None, ark_api_key=None):
    """Create the memory service and the LLM client.

    Any argument that is missing or empty falls back to the matching
    constant imported from ``Config.Config``.

    :return: Tuple ``(memory_service, llm_client)``.
    :raises ValueError: when any of the three values is still empty.
    """
    load_dotenv()
    # Fall back to the configured values for anything the caller left out.
    ak = ak or VOLC_ACCESSKEY
    sk = sk or VOLC_SECRETKEY
    ark_api_key = ark_api_key or VOLC_API_KEY
    if not (ak and sk and ark_api_key):
        raise ValueError("必须提供 VOLC_ACCESSKEY, VOLC_SECRETKEY, 和 VOLC_API_KEY。")
    llm_options = {
        "base_url": "https://ark.cn-beijing.volces.com/api/v3",
        "api_key": ark_api_key,
    }
    memory_service = VikingDBMemoryService(ak=ak, sk=sk)
    llm_client = Ark(**llm_options)
    return memory_service, llm_client
def ensure_collection_exists(self, collection_name, description="",
                             builtin_event_types=None,
                             builtin_entity_types=None):
    """Create the memory collection if it does not exist yet.

    :param collection_name: Collection to look up / create.
    :param description: Description used when creating the collection.
    :param builtin_event_types: Built-in event types used on creation;
        ``None`` means ["sys_event_v1", "sys_profile_collect_v1"].
    :param builtin_entity_types: Built-in entity types used on creation;
        ``None`` means ["sys_profile_v1"].
    :raises Exception: re-raises the lookup error when it is not a
        "collection not exist" error, and any error from creation.
    """
    # Build the default lists per call so no list is shared between calls.
    if builtin_event_types is None:
        builtin_event_types = ["sys_event_v1", "sys_profile_collect_v1"]
    if builtin_entity_types is None:
        builtin_entity_types = ["sys_profile_v1"]
    try:
        self.get_collection(collection_name)
        logger.info(f"记忆集合 '{collection_name}' 已存在。")
    except Exception as e:
        error_message = str(e)
        if "collection not exist" in error_message:
            logger.info(f"记忆集合 '{collection_name}' 未找到,正在创建...")
            try:
                self.create_collection(
                    collection_name=collection_name,
                    description=description,
                    builtin_event_types=builtin_event_types,
                    builtin_entity_types=builtin_entity_types
                )
                logger.info(f"记忆集合 '{collection_name}' 创建成功。")
                logger.info("等待集合准备就绪...")
            except Exception as create_e:
                logger.info(f"创建集合失败: {create_e}")
                raise
        else:
            logger.info(f"检查集合时出错: {e}")
            raise
def search_relevant_memories(self, collection_name, user_id, query, limit=3,
                             max_retries=None, retry_interval=60):
    """Search memories relevant to ``query``, retrying while the index is building.

    An error whose text contains "1000023" is treated as "index still
    building" and the search is retried after ``retry_interval`` seconds.
    Any other error ends the search with an empty result.

    :param collection_name: Collection to search.
    :param user_id: User id placed in the search filter.
    :param query: Query text.
    :param limit: Maximum number of results requested.
    :param max_retries: Maximum number of retries for the "index building"
        error; ``None`` retries without an upper bound.
    :param retry_interval: Seconds to wait before each retry.
    :return: List of ``{'memory_info': ..., 'score': ...}`` dictionaries;
        empty when nothing is found, on a non-retryable error, or when the
        retry budget is used up.
    """
    logger.info(f"正在搜索与 '{query}' 相关的记忆...")
    retry_attempt = 0
    while True:
        try:
            filter_params = {
                "user_id": user_id,  # sent as a plain string
                "memory_type": ["sys_event_v1", "sys_profile_v1"]
            }
            response = self.search_memory(
                collection_name=collection_name,
                query=query,
                filter=filter_params,
                limit=limit
            )
            memories = []
            if response.get('data', {}).get('count', 0) > 0:
                for result in response['data']['result_list']:
                    # Skip results without usable memory content.
                    if 'memory_info' in result and result['memory_info']:
                        memories.append({
                            'memory_info': result['memory_info'],
                            'score': result['score']
                        })
            if memories:
                if retry_attempt > 0:
                    logger.info("重试后搜索成功。")
                logger.info(f"找到 {len(memories)} 条相关记忆:")
                for i, memory in enumerate(memories, 1):
                    logger.info(
                        f" {i}. (相关度: {memory['score']:.3f}): {json.dumps(memory['memory_info'], ensure_ascii=False, indent=2)}")
            else:
                logger.info("未找到相关记忆。")
            return memories
        except Exception as e:
            error_message = str(e)
            if "1000023" not in error_message:
                logger.info(f"搜索记忆时出错 (不可重试): {e}")
                return []
            # Give up once the caller's retry budget is used up, so a
            # long-running index build cannot block the caller forever.
            if max_retries is not None and retry_attempt >= max_retries:
                logger.warning(f"记忆索引仍在构建中,已重试 {retry_attempt} 次,放弃搜索。")
                return []
            retry_attempt += 1
            logger.info(f"记忆索引正在构建中。将在{retry_interval}秒后重试... (尝试次数 {retry_attempt})")
            time.sleep(retry_interval)
def setup_memory_collection(self):
    """Ensure ``MEMORY_COLLECTION_NAME`` exists and wait for it to become ready.

    :return: This service instance when the collection is ready; ``None``
        when it does not become ready or any step raises.
    """
    try:
        self.ensure_collection_exists(MEMORY_COLLECTION_NAME)
        logger.info(f"记忆体 '{MEMORY_COLLECTION_NAME}' 创建/验证成功")
        # Wait for the collection to become ready before handing it out.
        logger.info("等待集合准备就绪...")
        if not self.wait_for_collection_ready():
            logger.info(f"集合 '{MEMORY_COLLECTION_NAME}' 未能就绪")
            return None
        logger.info(f"集合 '{MEMORY_COLLECTION_NAME}' 已就绪")
        return self
    except Exception as e:
        logger.info(f"记忆体创建失败: {e}")
        return None
def initialize_services():
    """Create the memory service and the LLM client from the configured keys.

    The keys are the ``VOLC_ACCESSKEY``, ``VOLC_SECRETKEY`` and
    ``VOLC_API_KEY`` constants imported from ``Config.Config``.

    :return: Tuple ``(memory_service, llm_client)``.
    :raises ValueError: when any of the three values is empty.
    """
    ak = VOLC_ACCESSKEY
    sk = VOLC_SECRETKEY
    ark_api_key = VOLC_API_KEY
    if not all([ak, sk, ark_api_key]):
        # Name the key that is actually checked (VOLC_API_KEY).
        raise ValueError("必须在环境变量中设置 VOLC_ACCESSKEY, VOLC_SECRETKEY, 和 VOLC_API_KEY。")
    memory_service = VikingDBMemoryService(
        ak=ak,
        sk=sk,
        host="api-knowledgebase.mlp.cn-beijing.volces.com",
        region="cn-beijing"
    )
    llm_client = Ark(
        base_url="https://ark.cn-beijing.volces.com/api/v3",
        api_key=ark_api_key,
    )
    return memory_service, llm_client
def search_relevant_memories(memory_service, collection_name, user_id, query, limit=3,
                             max_retries=None, retry_interval=60):
    """Search memories relevant to ``query``, retrying while the index is building.

    An error whose text contains "1000023" is treated as "index still
    building" and the search is retried after ``retry_interval`` seconds.
    Any other error ends the search with an empty result.

    :param memory_service: Object providing ``search_memory``.
    :param collection_name: Collection to search.
    :param user_id: User id placed in the search filter.
    :param query: Query text.
    :param limit: Maximum number of results requested.
    :param max_retries: Maximum number of retries for the "index building"
        error; ``None`` retries without an upper bound.
    :param retry_interval: Seconds to wait before each retry.
    :return: List of ``{'memory_info': ..., 'score': ...}`` dictionaries;
        empty when nothing is found, on a non-retryable error, or when the
        retry budget is used up.
    """
    logger.info(f"正在搜索与 '{query}' 相关的记忆...")
    retry_attempt = 0
    while True:
        try:
            filter_params = {
                # NOTE(review): the VikingDBMemoryService.search_relevant_memories
                # method sends user_id as a plain string, while this function
                # sends a one-element list; confirm which form the API expects.
                "user_id": [user_id],
                "memory_type": ["sys_event_v1", "sys_profile_v1"]
            }
            response = memory_service.search_memory(
                collection_name=collection_name,
                query=query,
                filter=filter_params,
                limit=limit
            )
            memories = []
            if response.get('data', {}).get('count', 0) > 0:
                for result in response['data']['result_list']:
                    # Skip results without usable memory content.
                    if 'memory_info' in result and result['memory_info']:
                        memories.append({
                            'memory_info': result['memory_info'],
                            'score': result['score']
                        })
            if memories:
                if retry_attempt > 0:
                    logger.info("重试后搜索成功。")
                logger.info(f"找到 {len(memories)} 条相关记忆:")
                for i, memory in enumerate(memories, 1):
                    logger.info(
                        f" {i}. (相关度: {memory['score']:.3f}): {json.dumps(memory['memory_info'], ensure_ascii=False, indent=2)}")
            else:
                logger.info("未找到相关记忆。")
            return memories
        except Exception as e:
            error_message = str(e)
            if "1000023" not in error_message:
                logger.info(f"搜索记忆时出错 (不可重试): {e}")
                return []
            # Give up once the caller's retry budget is used up, so a
            # long-running index build cannot block the caller forever.
            if max_retries is not None and retry_attempt >= max_retries:
                logger.warning(f"记忆索引仍在构建中,已重试 {retry_attempt} 次,放弃搜索。")
                return []
            retry_attempt += 1
            logger.info(f"记忆索引正在构建中。将在{retry_interval}秒后重试... (尝试次数 {retry_attempt})")
            time.sleep(retry_interval)
def handle_conversation_turn(memory_service, llm_client, collection_name, user_id, user_message, conversation_history):
    """Handle one conversation turn: memory lookup, LLM call, history update.

    Memories are searched in ``collection_name`` and, when any are found,
    appended to the system prompt. If the LLM call fails, a fixed apology
    is used as the reply. The user message and the reply are both appended
    to ``conversation_history`` in place.

    :return: The assistant reply text.
    """
    logger.info("\n" + "=" * 60)
    logger.info(f"用户: {user_message}")

    recalled = search_relevant_memories(memory_service, collection_name, user_id, user_message)

    system_prompt = "你是一个富有同情心、善于倾听的AI伙伴拥有长期记忆能力。你的目标是为用户提供情感支持和温暖的陪伴。"
    if recalled:
        memory_context = "\n".join(
            f"- {json.dumps(entry['memory_info'], ensure_ascii=False)}" for entry in recalled
        )
        system_prompt += f"\n\n这是我们过去的一些对话记忆,请参考:\n{memory_context}\n\n请利用这些信息来更好地理解和回应用户。"

    logger.info("AI正在思考...")
    try:
        system_turn = {"role": "system", "content": system_prompt}
        messages = [system_turn] + conversation_history + [{"role": "user", "content": user_message}]
        completion = llm_client.chat.completions.create(
            model="doubao-seed-1-6-flash-250715",
            messages=messages,
        )
        assistant_reply = completion.choices[0].message.content
    except Exception as e:
        logger.info(f"LLM调用失败: {e}")
        assistant_reply = "抱歉,我现在有点混乱,无法回应。我们可以稍后再聊吗?"

    logger.info(f"伙伴: {assistant_reply}")
    conversation_history.extend([
        {"role": "user", "content": user_message},
        {"role": "assistant", "content": assistant_reply},
    ])
    return assistant_reply
def archive_conversation(memory_service, collection_name, user_id, assistant_id, conversation_history, topic_name):
    """Archive the conversation history into the memory database.

    The session id is ``<topic_name>_<unix seconds>`` and the metadata
    carries the user id, the assistant id and the time in milliseconds.

    :return: True when ``add_session`` succeeds; False when the history is
        empty or ``add_session`` raises.
    """
    if not conversation_history:
        logger.info("没有对话可以归档。")
        return False

    logger.info(f"\n正在归档关于 '{topic_name}' 的对话...")
    session_id = f"{topic_name}_{int(time.time())}"
    metadata = dict(
        default_user_id=user_id,
        default_assistant_id=assistant_id,
        time=int(time.time() * 1000),
    )
    try:
        memory_service.add_session(
            collection_name=collection_name,
            session_id=session_id,
            messages=conversation_history,
            metadata=metadata,
        )
        logger.info(f"对话已成功归档会话ID: {session_id}")
        logger.info("正在等待记忆索引更新...")
    except Exception as e:
        logger.info(f"归档对话失败: {e}")
        return False
    return True