You can not select more than 25 topics Topics must start with a letter or number, can include dashes ('-') and can be up to 35 characters long.

1273 lines
49 KiB

import asyncio
import os
from dataclasses import dataclass, field
from typing import Any, Union, final
import time
import numpy as np
from lightrag.types import KnowledgeGraph, KnowledgeGraphNode, KnowledgeGraphEdge
from ..base import BaseGraphStorage, BaseKVStorage, BaseVectorStorage
from ..namespace import NameSpace, is_namespace
from ..utils import logger
import pipmaster as pm
import configparser
if not pm.is_installed("pymysql"):
pm.install("pymysql")
if not pm.is_installed("sqlalchemy"):
pm.install("sqlalchemy")
from sqlalchemy import create_engine, text # type: ignore
def sanitize_sensitive_info(data: dict) -> dict:
sanitized_data = data.copy()
sensitive_fields = [
"password",
"user",
"host",
"database",
"port",
"ssl_verify_cert",
"ssl_verify_identity",
]
for field_name in sensitive_fields:
if field_name in sanitized_data:
sanitized_data[field_name] = "***"
return sanitized_data
class TiDB:
def __init__(self, config, **kwargs):
self.host = config.get("host", None)
self.port = config.get("port", None)
self.user = config.get("user", None)
self.password = config.get("password", None)
self.database = config.get("database", None)
self.workspace = config.get("workspace", None)
connection_string = (
f"mysql+pymysql://{self.user}:{self.password}@{self.host}:{self.port}/{self.database}"
f"?ssl_verify_cert=true&ssl_verify_identity=true"
)
try:
self.engine = create_engine(connection_string)
logger.info("Connected to TiDB database")
except Exception as e:
logger.error("Failed to connect to TiDB database")
logger.error(f"TiDB database error: {e}")
raise
async def _migrate_timestamp_columns(self):
"""Migrate timestamp columns in tables to timezone-aware types, assuming original data is in UTC"""
# Not implemented yet
pass
async def check_tables(self):
# First create all tables
for k, v in TABLES.items():
try:
await self.query(f"SELECT 1 FROM {k}".format(k=k))
except Exception as e:
logger.error("Failed to check table in TiDB database")
logger.error(f"TiDB database error: {e}")
try:
await self.execute(v["ddl"])
logger.info("Created table in TiDB database")
except Exception as e:
logger.error("Failed to create table in TiDB database")
logger.error(f"TiDB database error: {e}")
# After all tables are created, try to migrate timestamp fields
try:
await self._migrate_timestamp_columns()
except Exception as e:
logger.error(f"TiDB, Failed to migrate timestamp columns: {e}")
# Don't raise exceptions, allow initialization process to continue
async def query(
self, sql: str, params: dict = None, multirows: bool = False
) -> Union[dict, None]:
if params is None:
params = {"workspace": self.workspace}
else:
params.update({"workspace": self.workspace})
with self.engine.connect() as conn, conn.begin():
try:
result = conn.execute(text(sql), params)
except Exception as e:
sanitized_params = sanitize_sensitive_info(params)
sanitized_error = sanitize_sensitive_info({"error": str(e)})
logger.error(
f"Tidb database,\nsql:{sql},\nparams:{sanitized_params},\nerror:{sanitized_error}"
)
raise
if multirows:
rows = result.all()
if rows:
data = [dict(zip(result.keys(), row)) for row in rows]
else:
data = []
else:
row = result.first()
if row:
data = dict(zip(result.keys(), row))
else:
data = None
return data
async def execute(self, sql: str, data: list | dict = None):
# logger.info("go into TiDBDB execute method")
try:
with self.engine.connect() as conn, conn.begin():
if data is None:
conn.execute(text(sql))
else:
conn.execute(text(sql), parameters=data)
except Exception as e:
sanitized_data = sanitize_sensitive_info(data) if data else None
sanitized_error = sanitize_sensitive_info({"error": str(e)})
logger.error(
f"Tidb database,\nsql:{sql},\ndata:{sanitized_data},\nerror:{sanitized_error}"
)
raise
class ClientManager:
_instances: dict[str, Any] = {"db": None, "ref_count": 0}
_lock = asyncio.Lock()
@staticmethod
def get_config() -> dict[str, Any]:
config = configparser.ConfigParser()
config.read("config.ini", "utf-8")
return {
"host": os.environ.get(
"TIDB_HOST",
config.get("tidb", "host", fallback="localhost"),
),
"port": os.environ.get(
"TIDB_PORT", config.get("tidb", "port", fallback=4000)
),
"user": os.environ.get(
"TIDB_USER",
config.get("tidb", "user", fallback=None),
),
"password": os.environ.get(
"TIDB_PASSWORD",
config.get("tidb", "password", fallback=None),
),
"database": os.environ.get(
"TIDB_DATABASE",
config.get("tidb", "database", fallback=None),
),
"workspace": os.environ.get(
"TIDB_WORKSPACE",
config.get("tidb", "workspace", fallback="default"),
),
}
@classmethod
async def get_client(cls) -> TiDB:
async with cls._lock:
if cls._instances["db"] is None:
config = ClientManager.get_config()
db = TiDB(config)
await db.check_tables()
cls._instances["db"] = db
cls._instances["ref_count"] = 0
cls._instances["ref_count"] += 1
return cls._instances["db"]
@classmethod
async def release_client(cls, db: TiDB):
async with cls._lock:
if db is not None:
if db is cls._instances["db"]:
cls._instances["ref_count"] -= 1
if cls._instances["ref_count"] == 0:
cls._instances["db"] = None
@final
@dataclass
class TiDBKVStorage(BaseKVStorage):
db: TiDB = field(default=None)
def __post_init__(self):
self._data = {}
self._max_batch_size = self.global_config["embedding_batch_num"]
async def initialize(self):
if self.db is None:
self.db = await ClientManager.get_client()
async def finalize(self):
if self.db is not None:
await ClientManager.release_client(self.db)
self.db = None
################ QUERY METHODS ################
async def get_all(self) -> dict[str, Any]:
"""Get all data from storage
Returns:
Dictionary containing all stored data
"""
async with self._storage_lock:
return dict(self._data)
async def get_by_id(self, id: str) -> dict[str, Any] | None:
"""Fetch doc_full data by id."""
SQL = SQL_TEMPLATES["get_by_id_" + self.namespace]
params = {"id": id}
response = await self.db.query(SQL, params)
return response if response else None
# Query by id
async def get_by_ids(self, ids: list[str]) -> list[dict[str, Any]]:
"""Fetch doc_chunks data by id"""
SQL = SQL_TEMPLATES["get_by_ids_" + self.namespace].format(
ids=",".join([f"'{id}'" for id in ids])
)
return await self.db.query(SQL, multirows=True)
async def filter_keys(self, keys: set[str]) -> set[str]:
SQL = SQL_TEMPLATES["filter_keys"].format(
table_name=namespace_to_table_name(self.namespace),
id_field=namespace_to_id(self.namespace),
ids=",".join([f"'{id}'" for id in keys]),
)
try:
await self.db.query(SQL)
except Exception as e:
logger.error(f"Tidb database,\nsql:{SQL},\nkeys:{keys},\nerror:{e}")
res = await self.db.query(SQL, multirows=True)
if res:
exist_keys = [key["id"] for key in res]
data = set([s for s in keys if s not in exist_keys])
else:
exist_keys = []
data = set([s for s in keys if s not in exist_keys])
return data
################ INSERT full_doc AND chunks ################
async def upsert(self, data: dict[str, dict[str, Any]]) -> None:
logger.info(f"Inserting {len(data)} to {self.namespace}")
if not data:
return
left_data = {k: v for k, v in data.items() if k not in self._data}
self._data.update(left_data)
if is_namespace(self.namespace, NameSpace.KV_STORE_TEXT_CHUNKS):
list_data = [
{
"__id__": k,
**{k1: v1 for k1, v1 in v.items()},
}
for k, v in data.items()
]
contents = [v["content"] for v in data.values()]
batches = [
contents[i : i + self._max_batch_size]
for i in range(0, len(contents), self._max_batch_size)
]
embeddings_list = await asyncio.gather(
*[self.embedding_func(batch) for batch in batches]
)
embeddings = np.concatenate(embeddings_list)
for i, d in enumerate(list_data):
d["__vector__"] = embeddings[i]
# Get current time as UNIX timestamp
current_time = int(time.time())
merge_sql = SQL_TEMPLATES["upsert_chunk"]
data = []
for item in list_data:
data.append(
{
"id": item["__id__"],
"content": item["content"],
"tokens": item["tokens"],
"chunk_order_index": item["chunk_order_index"],
"full_doc_id": item["full_doc_id"],
"content_vector": f"{item['__vector__'].tolist()}",
"workspace": self.db.workspace,
"timestamp": current_time,
}
)
await self.db.execute(merge_sql, data)
if is_namespace(self.namespace, NameSpace.KV_STORE_FULL_DOCS):
merge_sql = SQL_TEMPLATES["upsert_doc_full"]
data = []
for k, v in self._data.items():
data.append(
{
"id": k,
"content": v["content"],
"workspace": self.db.workspace,
}
)
await self.db.execute(merge_sql, data)
return left_data
async def index_done_callback(self) -> None:
# Ti handles persistence automatically
pass
async def delete(self, ids: list[str]) -> None:
"""Delete records with specified IDs from the storage.
Args:
ids: List of record IDs to be deleted
"""
if not ids:
return
try:
table_name = namespace_to_table_name(self.namespace)
id_field = namespace_to_id(self.namespace)
if not table_name or not id_field:
logger.error(f"Unknown namespace for deletion: {self.namespace}")
return
ids_list = ",".join([f"'{id}'" for id in ids])
delete_sql = f"DELETE FROM {table_name} WHERE workspace = :workspace AND {id_field} IN ({ids_list})"
await self.db.execute(delete_sql, {"workspace": self.db.workspace})
logger.info(
f"Successfully deleted {len(ids)} records from {self.namespace}"
)
except Exception as e:
logger.error(f"Error deleting records from {self.namespace}: {e}")
async def drop_cache_by_modes(self, modes: list[str] | None = None) -> bool:
"""Delete specific records from storage by cache mode
Args:
modes (list[str]): List of cache modes to be dropped from storage
Returns:
bool: True if successful, False otherwise
"""
if not modes:
return False
try:
table_name = namespace_to_table_name(self.namespace)
if not table_name:
return False
if table_name != "LIGHTRAG_LLM_CACHE":
return False
# Build MySQL style IN query
modes_list = ", ".join([f"'{mode}'" for mode in modes])
sql = f"""
DELETE FROM {table_name}
WHERE workspace = :workspace
AND mode IN ({modes_list})
"""
logger.info(f"Deleting cache by modes: {modes}")
await self.db.execute(sql, {"workspace": self.db.workspace})
return True
except Exception as e:
logger.error(f"Error deleting cache by modes {modes}: {e}")
return False
async def drop(self) -> dict[str, str]:
"""Drop the storage"""
try:
table_name = namespace_to_table_name(self.namespace)
if not table_name:
return {
"status": "error",
"message": f"Unknown namespace: {self.namespace}",
}
drop_sql = SQL_TEMPLATES["drop_specifiy_table_workspace"].format(
table_name=table_name
)
await self.db.execute(drop_sql, {"workspace": self.db.workspace})
return {"status": "success", "message": "data dropped"}
except Exception as e:
return {"status": "error", "message": str(e)}
@final
@dataclass
class TiDBVectorDBStorage(BaseVectorStorage):
db: TiDB | None = field(default=None)
def __post_init__(self):
self._client_file_name = os.path.join(
self.global_config["working_dir"], f"vdb_{self.namespace}.json"
)
self._max_batch_size = self.global_config["embedding_batch_num"]
config = self.global_config.get("vector_db_storage_cls_kwargs", {})
cosine_threshold = config.get("cosine_better_than_threshold")
if cosine_threshold is None:
raise ValueError(
"cosine_better_than_threshold must be specified in vector_db_storage_cls_kwargs"
)
self.cosine_better_than_threshold = cosine_threshold
async def initialize(self):
if self.db is None:
self.db = await ClientManager.get_client()
async def finalize(self):
if self.db is not None:
await ClientManager.release_client(self.db)
self.db = None
async def query(
self, query: str, top_k: int, ids: list[str] | None = None
) -> list[dict[str, Any]]:
"""Search from tidb vector"""
embeddings = await self.embedding_func(
[query], _priority=5
) # higher priority for query
embedding = embeddings[0]
embedding_string = "[" + ", ".join(map(str, embedding.tolist())) + "]"
params = {
"embedding_string": embedding_string,
"top_k": top_k,
"better_than_threshold": self.cosine_better_than_threshold,
}
results = await self.db.query(
SQL_TEMPLATES[self.namespace], params=params, multirows=True
)
if not results:
return []
return results
###### INSERT entities And relationships ######
async def upsert(self, data: dict[str, dict[str, Any]]) -> None:
logger.info(f"Inserting {len(data)} to {self.namespace}")
if not data:
return
logger.info(f"Inserting {len(data)} vectors to {self.namespace}")
# Get current time as UNIX timestamp
import time
current_time = int(time.time())
list_data = [
{
"id": k,
"timestamp": current_time,
**{k1: v1 for k1, v1 in v.items()},
}
for k, v in data.items()
]
contents = [v["content"] for v in data.values()]
batches = [
contents[i : i + self._max_batch_size]
for i in range(0, len(contents), self._max_batch_size)
]
embedding_tasks = [self.embedding_func(batch) for batch in batches]
embeddings_list = await asyncio.gather(*embedding_tasks)
embeddings = np.concatenate(embeddings_list)
for i, d in enumerate(list_data):
d["content_vector"] = embeddings[i]
if is_namespace(self.namespace, NameSpace.VECTOR_STORE_CHUNKS):
for item in list_data:
param = {
"id": item["id"],
"content": item["content"],
"tokens": item.get("tokens", 0),
"chunk_order_index": item.get("chunk_order_index", 0),
"full_doc_id": item.get("full_doc_id", ""),
"content_vector": f"{item['content_vector'].tolist()}",
"workspace": self.db.workspace,
"timestamp": item["timestamp"],
}
await self.db.execute(SQL_TEMPLATES["upsert_chunk"], param)
elif is_namespace(self.namespace, NameSpace.VECTOR_STORE_ENTITIES):
for item in list_data:
param = {
"id": item["id"],
"name": item["entity_name"],
"content": item["content"],
"content_vector": f"{item['content_vector'].tolist()}",
"workspace": self.db.workspace,
"timestamp": item["timestamp"],
}
await self.db.execute(SQL_TEMPLATES["upsert_entity"], param)
elif is_namespace(self.namespace, NameSpace.VECTOR_STORE_RELATIONSHIPS):
for item in list_data:
param = {
"id": item["id"],
"source_name": item["src_id"],
"target_name": item["tgt_id"],
"content": item["content"],
"content_vector": f"{item['content_vector'].tolist()}",
"workspace": self.db.workspace,
"timestamp": item["timestamp"],
}
await self.db.execute(SQL_TEMPLATES["upsert_relationship"], param)
async def get_by_status(self, status: str) -> Union[list[dict[str, Any]], None]:
SQL = SQL_TEMPLATES["get_by_status_" + self.namespace]
params = {"workspace": self.db.workspace, "status": status}
return await self.db.query(SQL, params, multirows=True)
async def delete(self, ids: list[str]) -> None:
"""Delete vectors with specified IDs from the storage.
Args:
ids: List of vector IDs to be deleted
"""
if not ids:
return
table_name = namespace_to_table_name(self.namespace)
id_field = namespace_to_id(self.namespace)
if not table_name or not id_field:
logger.error(f"Unknown namespace for vector deletion: {self.namespace}")
return
ids_list = ",".join([f"'{id}'" for id in ids])
delete_sql = f"DELETE FROM {table_name} WHERE workspace = :workspace AND {id_field} IN ({ids_list})"
try:
await self.db.execute(delete_sql, {"workspace": self.db.workspace})
logger.debug(
f"Successfully deleted {len(ids)} vectors from {self.namespace}"
)
except Exception as e:
logger.error(f"Error while deleting vectors from {self.namespace}: {e}")
async def delete_entity(self, entity_name: str) -> None:
"""Delete an entity by its name from the vector storage.
Args:
entity_name: The name of the entity to delete
"""
try:
# Construct SQL to delete the entity
delete_sql = """DELETE FROM LIGHTRAG_GRAPH_NODES
WHERE workspace = :workspace AND name = :entity_name"""
await self.db.execute(
delete_sql, {"workspace": self.db.workspace, "entity_name": entity_name}
)
logger.debug(f"Successfully deleted entity {entity_name}")
except Exception as e:
logger.error(f"Error deleting entity {entity_name}: {e}")
async def delete_entity_relation(self, entity_name: str) -> None:
"""Delete all relations associated with an entity.
Args:
entity_name: The name of the entity whose relations should be deleted
"""
try:
# Delete relations where the entity is either the source or target
delete_sql = """DELETE FROM LIGHTRAG_GRAPH_EDGES
WHERE workspace = :workspace AND (source_name = :entity_name OR target_name = :entity_name)"""
await self.db.execute(
delete_sql, {"workspace": self.db.workspace, "entity_name": entity_name}
)
logger.debug(f"Successfully deleted relations for entity {entity_name}")
except Exception as e:
logger.error(f"Error deleting relations for entity {entity_name}: {e}")
async def index_done_callback(self) -> None:
# Ti handles persistence automatically
pass
async def drop(self) -> dict[str, str]:
"""Drop the storage"""
try:
table_name = namespace_to_table_name(self.namespace)
if not table_name:
return {
"status": "error",
"message": f"Unknown namespace: {self.namespace}",
}
drop_sql = SQL_TEMPLATES["drop_specifiy_table_workspace"].format(
table_name=table_name
)
await self.db.execute(drop_sql, {"workspace": self.db.workspace})
return {"status": "success", "message": "data dropped"}
except Exception as e:
return {"status": "error", "message": str(e)}
async def get_by_id(self, id: str) -> dict[str, Any] | None:
"""Get vector data by its ID
Args:
id: The unique identifier of the vector
Returns:
The vector data if found, or None if not found
"""
try:
# Determine which table to query based on namespace
if self.namespace == NameSpace.VECTOR_STORE_ENTITIES:
sql_template = """
SELECT entity_id as id, name as entity_name, entity_type, description, content,
UNIX_TIMESTAMP(createtime) as created_at
FROM LIGHTRAG_GRAPH_NODES
WHERE entity_id = :entity_id AND workspace = :workspace
"""
params = {"entity_id": id, "workspace": self.db.workspace}
elif self.namespace == NameSpace.VECTOR_STORE_RELATIONSHIPS:
sql_template = """
SELECT relation_id as id, source_name as src_id, target_name as tgt_id,
keywords, description, content, UNIX_TIMESTAMP(createtime) as created_at
FROM LIGHTRAG_GRAPH_EDGES
WHERE relation_id = :relation_id AND workspace = :workspace
"""
params = {"relation_id": id, "workspace": self.db.workspace}
elif self.namespace == NameSpace.VECTOR_STORE_CHUNKS:
sql_template = """
SELECT chunk_id as id, content, tokens, chunk_order_index, full_doc_id,
UNIX_TIMESTAMP(createtime) as created_at
FROM LIGHTRAG_DOC_CHUNKS
WHERE chunk_id = :chunk_id AND workspace = :workspace
"""
params = {"chunk_id": id, "workspace": self.db.workspace}
else:
logger.warning(
f"Namespace {self.namespace} not supported for get_by_id"
)
return None
result = await self.db.query(sql_template, params=params)
return result
except Exception as e:
logger.error(f"Error retrieving vector data for ID {id}: {e}")
return None
async def get_by_ids(self, ids: list[str]) -> list[dict[str, Any]]:
"""Get multiple vector data by their IDs
Args:
ids: List of unique identifiers
Returns:
List of vector data objects that were found
"""
if not ids:
return []
try:
# Format IDs for SQL IN clause
ids_str = ", ".join([f"'{id}'" for id in ids])
# Determine which table to query based on namespace
if self.namespace == NameSpace.VECTOR_STORE_ENTITIES:
sql_template = f"""
SELECT entity_id as id, name as entity_name, entity_type, description, content,
UNIX_TIMESTAMP(createtime) as created_at
FROM LIGHTRAG_GRAPH_NODES
WHERE entity_id IN ({ids_str}) AND workspace = :workspace
"""
elif self.namespace == NameSpace.VECTOR_STORE_RELATIONSHIPS:
sql_template = f"""
SELECT relation_id as id, source_name as src_id, target_name as tgt_id,
keywords, description, content, UNIX_TIMESTAMP(createtime) as created_at
FROM LIGHTRAG_GRAPH_EDGES
WHERE relation_id IN ({ids_str}) AND workspace = :workspace
"""
elif self.namespace == NameSpace.VECTOR_STORE_CHUNKS:
sql_template = f"""
SELECT chunk_id as id, content, tokens, chunk_order_index, full_doc_id,
UNIX_TIMESTAMP(createtime) as created_at
FROM LIGHTRAG_DOC_CHUNKS
WHERE chunk_id IN ({ids_str}) AND workspace = :workspace
"""
else:
logger.warning(
f"Namespace {self.namespace} not supported for get_by_ids"
)
return []
params = {"workspace": self.db.workspace}
results = await self.db.query(sql_template, params=params, multirows=True)
return results if results else []
except Exception as e:
logger.error(f"Error retrieving vector data for IDs {ids}: {e}")
return []
@final
@dataclass
class TiDBGraphStorage(BaseGraphStorage):
db: TiDB = field(default=None)
def __post_init__(self):
self._max_batch_size = self.global_config["embedding_batch_num"]
async def initialize(self):
if self.db is None:
self.db = await ClientManager.get_client()
async def finalize(self):
if self.db is not None:
await ClientManager.release_client(self.db)
self.db = None
#################### upsert method ################
async def upsert_node(self, node_id: str, node_data: dict[str, str]) -> None:
entity_name = node_id
entity_type = node_data["entity_type"]
description = node_data["description"]
source_id = node_data["source_id"]
logger.debug(f"entity_name:{entity_name}, entity_type:{entity_type}")
content = entity_name + description
contents = [content]
batches = [
contents[i : i + self._max_batch_size]
for i in range(0, len(contents), self._max_batch_size)
]
embeddings_list = await asyncio.gather(
*[self.embedding_func(batch) for batch in batches]
)
embeddings = np.concatenate(embeddings_list)
content_vector = embeddings[0]
sql = SQL_TEMPLATES["upsert_node"]
data = {
"workspace": self.db.workspace,
"name": entity_name,
"entity_type": entity_type,
"description": description,
"source_chunk_id": source_id,
"content": content,
"content_vector": f"{content_vector.tolist()}",
}
await self.db.execute(sql, data)
async def upsert_edge(
self, source_node_id: str, target_node_id: str, edge_data: dict[str, str]
) -> None:
source_name = source_node_id
target_name = target_node_id
weight = edge_data["weight"]
keywords = edge_data["keywords"]
description = edge_data["description"]
source_chunk_id = edge_data["source_id"]
logger.debug(
f"source_name:{source_name}, target_name:{target_name}, keywords: {keywords}"
)
content = keywords + source_name + target_name + description
contents = [content]
batches = [
contents[i : i + self._max_batch_size]
for i in range(0, len(contents), self._max_batch_size)
]
embeddings_list = await asyncio.gather(
*[self.embedding_func(batch) for batch in batches]
)
embeddings = np.concatenate(embeddings_list)
content_vector = embeddings[0]
merge_sql = SQL_TEMPLATES["upsert_edge"]
data = {
"workspace": self.db.workspace,
"source_name": source_name,
"target_name": target_name,
"weight": weight,
"keywords": keywords,
"description": description,
"source_chunk_id": source_chunk_id,
"content": content,
"content_vector": f"{content_vector.tolist()}",
}
await self.db.execute(merge_sql, data)
# Query
async def has_node(self, node_id: str) -> bool:
sql = SQL_TEMPLATES["has_entity"]
param = {"name": node_id, "workspace": self.db.workspace}
has = await self.db.query(sql, param)
return has["cnt"] != 0
async def has_edge(self, source_node_id: str, target_node_id: str) -> bool:
sql = SQL_TEMPLATES["has_relationship"]
param = {
"source_name": source_node_id,
"target_name": target_node_id,
"workspace": self.db.workspace,
}
has = await self.db.query(sql, param)
return has["cnt"] != 0
async def node_degree(self, node_id: str) -> int:
sql = SQL_TEMPLATES["node_degree"]
param = {"name": node_id, "workspace": self.db.workspace}
result = await self.db.query(sql, param)
return result["cnt"]
async def edge_degree(self, src_id: str, tgt_id: str) -> int:
degree = await self.node_degree(src_id) + await self.node_degree(tgt_id)
return degree
async def get_node(self, node_id: str) -> dict[str, str] | None:
sql = SQL_TEMPLATES["get_node"]
param = {"name": node_id, "workspace": self.db.workspace}
return await self.db.query(sql, param)
async def get_edge(
self, source_node_id: str, target_node_id: str
) -> dict[str, str] | None:
sql = SQL_TEMPLATES["get_edge"]
param = {
"source_name": source_node_id,
"target_name": target_node_id,
"workspace": self.db.workspace,
}
return await self.db.query(sql, param)
async def get_node_edges(self, source_node_id: str) -> list[tuple[str, str]] | None:
sql = SQL_TEMPLATES["get_node_edges"]
param = {"source_name": source_node_id, "workspace": self.db.workspace}
res = await self.db.query(sql, param, multirows=True)
if res:
data = [(i["source_name"], i["target_name"]) for i in res]
return data
else:
return []
async def index_done_callback(self) -> None:
# Ti handles persistence automatically
pass
async def drop(self) -> dict[str, str]:
"""Drop the storage"""
try:
drop_sql = """
DELETE FROM LIGHTRAG_GRAPH_EDGES WHERE workspace = :workspace;
DELETE FROM LIGHTRAG_GRAPH_NODES WHERE workspace = :workspace;
"""
await self.db.execute(drop_sql, {"workspace": self.db.workspace})
return {"status": "success", "message": "graph data dropped"}
except Exception as e:
return {"status": "error", "message": str(e)}
async def delete_node(self, node_id: str) -> None:
"""Delete a node and all its related edges
Args:
node_id: The ID of the node to delete
"""
# First delete all edges related to this node
await self.db.execute(
SQL_TEMPLATES["delete_node_edges"],
{"name": node_id, "workspace": self.db.workspace},
)
# Then delete the node itself
await self.db.execute(
SQL_TEMPLATES["delete_node"],
{"name": node_id, "workspace": self.db.workspace},
)
logger.debug(
f"Node {node_id} and its related edges have been deleted from the graph"
)
async def get_all_labels(self) -> list[str]:
"""Get all entity types (labels) in the database
Returns:
List of labels sorted alphabetically
"""
result = await self.db.query(
SQL_TEMPLATES["get_all_labels"],
{"workspace": self.db.workspace},
multirows=True,
)
if not result:
return []
# Extract all labels
return [item["label"] for item in result]
async def get_knowledge_graph(
self, node_label: str, max_depth: int = 5
) -> KnowledgeGraph:
"""
Get a connected subgraph of nodes matching the specified label
Maximum number of nodes is limited by MAX_GRAPH_NODES environment variable (default: 1000)
Args:
node_label: The node label to match
max_depth: Maximum depth of the subgraph
Returns:
KnowledgeGraph object containing nodes and edges
"""
result = KnowledgeGraph()
MAX_GRAPH_NODES = int(os.getenv("MAX_GRAPH_NODES", 1000))
# Get matching nodes
if node_label == "*":
# Handle special case, get all nodes
node_results = await self.db.query(
SQL_TEMPLATES["get_all_nodes"],
{"workspace": self.db.workspace, "max_nodes": MAX_GRAPH_NODES},
multirows=True,
)
else:
# Get nodes matching the label
label_pattern = f"%{node_label}%"
node_results = await self.db.query(
SQL_TEMPLATES["get_matching_nodes"],
{"workspace": self.db.workspace, "label_pattern": label_pattern},
multirows=True,
)
if not node_results:
logger.warning(f"No nodes found matching label {node_label}")
return result
# Limit the number of returned nodes
if len(node_results) > MAX_GRAPH_NODES:
node_results = node_results[:MAX_GRAPH_NODES]
# Extract node names for edge query
node_names = [node["name"] for node in node_results]
node_names_str = ",".join([f"'{name}'" for name in node_names])
# Add nodes to result
for node in node_results:
node_properties = {
k: v for k, v in node.items() if k not in ["id", "name", "entity_type"]
}
result.nodes.append(
KnowledgeGraphNode(
id=node["name"],
labels=[node["entity_type"]]
if node.get("entity_type")
else [node["name"]],
properties=node_properties,
)
)
# Get related edges
edge_results = await self.db.query(
SQL_TEMPLATES["get_related_edges"].format(node_names=node_names_str),
{"workspace": self.db.workspace},
multirows=True,
)
if edge_results:
# Add edges to result
for edge in edge_results:
# Only include edges related to selected nodes
if (
edge["source_name"] in node_names
and edge["target_name"] in node_names
):
edge_id = f"{edge['source_name']}-{edge['target_name']}"
edge_properties = {
k: v
for k, v in edge.items()
if k not in ["id", "source_name", "target_name"]
}
result.edges.append(
KnowledgeGraphEdge(
id=edge_id,
type="RELATED",
source=edge["source_name"],
target=edge["target_name"],
properties=edge_properties,
)
)
logger.info(
f"Subgraph query successful | Node count: {len(result.nodes)} | Edge count: {len(result.edges)}"
)
return result
async def remove_nodes(self, nodes: list[str]):
"""Delete multiple nodes
Args:
nodes: List of node IDs to delete
"""
for node_id in nodes:
await self.delete_node(node_id)
async def remove_edges(self, edges: list[tuple[str, str]]):
"""Delete multiple edges
Args:
edges: List of edges to delete, each edge is a (source, target) tuple
"""
for source, target in edges:
await self.db.execute(
SQL_TEMPLATES["remove_multiple_edges"],
{"source": source, "target": target, "workspace": self.db.workspace},
)
N_T = {
NameSpace.KV_STORE_FULL_DOCS: "LIGHTRAG_DOC_FULL",
NameSpace.KV_STORE_TEXT_CHUNKS: "LIGHTRAG_DOC_CHUNKS",
NameSpace.VECTOR_STORE_CHUNKS: "LIGHTRAG_DOC_CHUNKS",
NameSpace.VECTOR_STORE_ENTITIES: "LIGHTRAG_GRAPH_NODES",
NameSpace.VECTOR_STORE_RELATIONSHIPS: "LIGHTRAG_GRAPH_EDGES",
}
N_ID = {
NameSpace.KV_STORE_FULL_DOCS: "doc_id",
NameSpace.KV_STORE_TEXT_CHUNKS: "chunk_id",
NameSpace.VECTOR_STORE_CHUNKS: "chunk_id",
NameSpace.VECTOR_STORE_ENTITIES: "entity_id",
NameSpace.VECTOR_STORE_RELATIONSHIPS: "relation_id",
}
def namespace_to_table_name(namespace: str) -> str:
for k, v in N_T.items():
if is_namespace(namespace, k):
return v
def namespace_to_id(namespace: str) -> str:
for k, v in N_ID.items():
if is_namespace(namespace, k):
return v
TABLES = {
"LIGHTRAG_DOC_FULL": {
"ddl": """
CREATE TABLE LIGHTRAG_DOC_FULL (
`id` BIGINT PRIMARY KEY AUTO_RANDOM,
`doc_id` VARCHAR(256) NOT NULL,
`workspace` varchar(1024),
`content` LONGTEXT,
`meta` JSON,
`createtime` TIMESTAMP DEFAULT CURRENT_TIMESTAMP,
`updatetime` TIMESTAMP DEFAULT NULL,
UNIQUE KEY (`doc_id`)
);
"""
},
"LIGHTRAG_DOC_CHUNKS": {
"ddl": """
CREATE TABLE LIGHTRAG_DOC_CHUNKS (
`id` BIGINT PRIMARY KEY AUTO_RANDOM,
`chunk_id` VARCHAR(256) NOT NULL,
`full_doc_id` VARCHAR(256) NOT NULL,
`workspace` varchar(1024),
`chunk_order_index` INT,
`tokens` INT,
`content` LONGTEXT,
`content_vector` VECTOR,
`createtime` TIMESTAMP,
`updatetime` TIMESTAMP,
UNIQUE KEY (`chunk_id`)
);
"""
},
"LIGHTRAG_GRAPH_NODES": {
"ddl": """
CREATE TABLE LIGHTRAG_GRAPH_NODES (
`id` BIGINT PRIMARY KEY AUTO_RANDOM,
`entity_id` VARCHAR(256),
`workspace` varchar(1024),
`name` VARCHAR(2048),
`entity_type` VARCHAR(1024),
`description` LONGTEXT,
`source_chunk_id` VARCHAR(256),
`content` LONGTEXT,
`content_vector` VECTOR,
`createtime` TIMESTAMP,
`updatetime` TIMESTAMP,
KEY (`entity_id`)
);
"""
},
"LIGHTRAG_GRAPH_EDGES": {
"ddl": """
CREATE TABLE LIGHTRAG_GRAPH_EDGES (
`id` BIGINT PRIMARY KEY AUTO_RANDOM,
`relation_id` VARCHAR(256),
`workspace` varchar(1024),
`source_name` VARCHAR(2048),
`target_name` VARCHAR(2048),
`weight` DECIMAL,
`keywords` TEXT,
`description` LONGTEXT,
`source_chunk_id` varchar(256),
`content` LONGTEXT,
`content_vector` VECTOR,
`createtime` TIMESTAMP,
`updatetime` TIMESTAMP,
KEY (`relation_id`)
);
"""
},
"LIGHTRAG_LLM_CACHE": {
"ddl": """
CREATE TABLE LIGHTRAG_LLM_CACHE (
id BIGINT PRIMARY KEY AUTO_INCREMENT,
send TEXT,
return TEXT,
model VARCHAR(1024),
createtime DATETIME DEFAULT CURRENT_TIMESTAMP,
updatetime DATETIME DEFAULT NULL
);
"""
},
}
SQL_TEMPLATES = {
# SQL for KVStorage
"get_by_id_full_docs": "SELECT doc_id as id, IFNULL(content, '') AS content FROM LIGHTRAG_DOC_FULL WHERE doc_id = :id AND workspace = :workspace",
"get_by_id_text_chunks": "SELECT chunk_id as id, tokens, IFNULL(content, '') AS content, chunk_order_index, full_doc_id FROM LIGHTRAG_DOC_CHUNKS WHERE chunk_id = :id AND workspace = :workspace",
"get_by_ids_full_docs": "SELECT doc_id as id, IFNULL(content, '') AS content FROM LIGHTRAG_DOC_FULL WHERE doc_id IN ({ids}) AND workspace = :workspace",
"get_by_ids_text_chunks": "SELECT chunk_id as id, tokens, IFNULL(content, '') AS content, chunk_order_index, full_doc_id FROM LIGHTRAG_DOC_CHUNKS WHERE chunk_id IN ({ids}) AND workspace = :workspace",
"filter_keys": "SELECT {id_field} AS id FROM {table_name} WHERE {id_field} IN ({ids}) AND workspace = :workspace",
# SQL for Merge operations (TiDB version with INSERT ... ON DUPLICATE KEY UPDATE)
"upsert_doc_full": """
INSERT INTO LIGHTRAG_DOC_FULL (doc_id, content, workspace)
VALUES (:id, :content, :workspace)
ON DUPLICATE KEY UPDATE content = VALUES(content), workspace = VALUES(workspace), updatetime = CURRENT_TIMESTAMP
""",
"upsert_chunk": """
INSERT INTO LIGHTRAG_DOC_CHUNKS(chunk_id, content, tokens, chunk_order_index, full_doc_id, content_vector, workspace, createtime, updatetime)
VALUES (:id, :content, :tokens, :chunk_order_index, :full_doc_id, :content_vector, :workspace, FROM_UNIXTIME(:timestamp), FROM_UNIXTIME(:timestamp))
ON DUPLICATE KEY UPDATE
content = VALUES(content), tokens = VALUES(tokens), chunk_order_index = VALUES(chunk_order_index),
full_doc_id = VALUES(full_doc_id), content_vector = VALUES(content_vector), workspace = VALUES(workspace), updatetime = FROM_UNIXTIME(:timestamp)
""",
# SQL for VectorStorage
"entities": """SELECT n.name as entity_name, UNIX_TIMESTAMP(n.createtime) as created_at FROM
(SELECT entity_id as id, name, createtime, VEC_COSINE_DISTANCE(content_vector,:embedding_string) as distance
FROM LIGHTRAG_GRAPH_NODES WHERE workspace = :workspace) n
WHERE n.distance>:better_than_threshold ORDER BY n.distance DESC LIMIT :top_k
""",
"relationships": """SELECT e.source_name as src_id, e.target_name as tgt_id, UNIX_TIMESTAMP(e.createtime) as created_at FROM
(SELECT source_name, target_name, createtime, VEC_COSINE_DISTANCE(content_vector, :embedding_string) as distance
FROM LIGHTRAG_GRAPH_EDGES WHERE workspace = :workspace) e
WHERE e.distance>:better_than_threshold ORDER BY e.distance DESC LIMIT :top_k
""",
"chunks": """SELECT c.id, UNIX_TIMESTAMP(c.createtime) as created_at FROM
(SELECT chunk_id as id, createtime, VEC_COSINE_DISTANCE(content_vector, :embedding_string) as distance
FROM LIGHTRAG_DOC_CHUNKS WHERE workspace = :workspace) c
WHERE c.distance>:better_than_threshold ORDER BY c.distance DESC LIMIT :top_k
""",
"has_entity": """
SELECT COUNT(id) AS cnt FROM LIGHTRAG_GRAPH_NODES WHERE name = :name AND workspace = :workspace
""",
"has_relationship": """
SELECT COUNT(id) AS cnt FROM LIGHTRAG_GRAPH_EDGES WHERE source_name = :source_name AND target_name = :target_name AND workspace = :workspace
""",
"upsert_entity": """
INSERT INTO LIGHTRAG_GRAPH_NODES(entity_id, name, content, content_vector, workspace, createtime, updatetime)
VALUES(:id, :name, :content, :content_vector, :workspace, FROM_UNIXTIME(:timestamp), FROM_UNIXTIME(:timestamp))
ON DUPLICATE KEY UPDATE
content = VALUES(content),
content_vector = VALUES(content_vector),
updatetime = FROM_UNIXTIME(:timestamp)
""",
"upsert_relationship": """
INSERT INTO LIGHTRAG_GRAPH_EDGES(relation_id, source_name, target_name, content, content_vector, workspace, createtime, updatetime)
VALUES(:id, :source_name, :target_name, :content, :content_vector, :workspace, FROM_UNIXTIME(:timestamp), FROM_UNIXTIME(:timestamp))
ON DUPLICATE KEY UPDATE
content = VALUES(content),
content_vector = VALUES(content_vector),
updatetime = FROM_UNIXTIME(:timestamp)
""",
# SQL for GraphStorage
"get_node": """
SELECT entity_id AS id, workspace, name, entity_type, description, source_chunk_id AS source_id, content, content_vector
FROM LIGHTRAG_GRAPH_NODES WHERE name = :name AND workspace = :workspace
""",
"get_edge": """
SELECT relation_id AS id, workspace, source_name, target_name, weight, keywords, description, source_chunk_id AS source_id, content, content_vector
FROM LIGHTRAG_GRAPH_EDGES WHERE source_name = :source_name AND target_name = :target_name AND workspace = :workspace
""",
"get_node_edges": """
SELECT relation_id AS id, workspace, source_name, target_name, weight, keywords, description, source_chunk_id, content, content_vector
FROM LIGHTRAG_GRAPH_EDGES WHERE source_name = :source_name AND workspace = :workspace
""",
"node_degree": """
SELECT COUNT(id) AS cnt FROM LIGHTRAG_GRAPH_EDGES WHERE workspace = :workspace AND :name IN (source_name, target_name)
""",
"upsert_node": """
INSERT INTO LIGHTRAG_GRAPH_NODES(name, content, content_vector, workspace, source_chunk_id, entity_type, description)
VALUES(:name, :content, :content_vector, :workspace, :source_chunk_id, :entity_type, :description)
ON DUPLICATE KEY UPDATE
name = VALUES(name), content = VALUES(content), content_vector = VALUES(content_vector),
workspace = VALUES(workspace), updatetime = CURRENT_TIMESTAMP,
source_chunk_id = VALUES(source_chunk_id), entity_type = VALUES(entity_type), description = VALUES(description)
""",
"upsert_edge": """
INSERT INTO LIGHTRAG_GRAPH_EDGES(source_name, target_name, content, content_vector,
workspace, weight, keywords, description, source_chunk_id)
VALUES(:source_name, :target_name, :content, :content_vector,
:workspace, :weight, :keywords, :description, :source_chunk_id)
ON DUPLICATE KEY UPDATE
source_name = VALUES(source_name), target_name = VALUES(target_name), content = VALUES(content),
content_vector = VALUES(content_vector), workspace = VALUES(workspace), updatetime = CURRENT_TIMESTAMP,
weight = VALUES(weight), keywords = VALUES(keywords), description = VALUES(description),
source_chunk_id = VALUES(source_chunk_id)
""",
"delete_node": """
DELETE FROM LIGHTRAG_GRAPH_NODES
WHERE name = :name AND workspace = :workspace
""",
"delete_node_edges": """
DELETE FROM LIGHTRAG_GRAPH_EDGES
WHERE (source_name = :name OR target_name = :name) AND workspace = :workspace
""",
"get_all_labels": """
SELECT DISTINCT entity_type as label
FROM LIGHTRAG_GRAPH_NODES
WHERE workspace = :workspace
ORDER BY entity_type
""",
"get_matching_nodes": """
SELECT * FROM LIGHTRAG_GRAPH_NODES
WHERE name LIKE :label_pattern AND workspace = :workspace
ORDER BY name
""",
"get_all_nodes": """
SELECT * FROM LIGHTRAG_GRAPH_NODES
WHERE workspace = :workspace
ORDER BY name
LIMIT :max_nodes
""",
"get_related_edges": """
SELECT * FROM LIGHTRAG_GRAPH_EDGES
WHERE (source_name IN (:node_names) OR target_name IN (:node_names))
AND workspace = :workspace
""",
"remove_multiple_edges": """
DELETE FROM LIGHTRAG_GRAPH_EDGES
WHERE (source_name = :source AND target_name = :target)
AND workspace = :workspace
""",
# Drop tables
"drop_specifiy_table_workspace": "DELETE FROM {table_name} WHERE workspace = :workspace",
}