import asyncio
import os
from dataclasses import dataclass, field
from typing import Any, Union, final
import time

import numpy as np

from lightrag.types import KnowledgeGraph, KnowledgeGraphNode, KnowledgeGraphEdge

from ..base import BaseGraphStorage, BaseKVStorage, BaseVectorStorage
from ..namespace import NameSpace, is_namespace
from ..utils import logger

import pipmaster as pm
import configparser

if not pm.is_installed("pymysql"):
    pm.install("pymysql")
if not pm.is_installed("sqlalchemy"):
    pm.install("sqlalchemy")

from sqlalchemy import create_engine, text  # type: ignore


def sanitize_sensitive_info(data: dict) -> dict:
    """Return a copy of `data` with connection credentials masked for logging."""
    sanitized_data = data.copy()
    sensitive_fields = [
        "password",
        "user",
        "host",
        "database",
        "port",
        "ssl_verify_cert",
        "ssl_verify_identity",
    ]
    for field_name in sensitive_fields:
        if field_name in sanitized_data:
            sanitized_data[field_name] = "***"
    return sanitized_data


class TiDB:
    def __init__(self, config, **kwargs):
        self.host = config.get("host", None)
        self.port = config.get("port", None)
        self.user = config.get("user", None)
        self.password = config.get("password", None)
        self.database = config.get("database", None)
        self.workspace = config.get("workspace", None)
        connection_string = (
            f"mysql+pymysql://{self.user}:{self.password}@{self.host}:{self.port}/{self.database}"
            f"?ssl_verify_cert=true&ssl_verify_identity=true"
        )

        try:
            self.engine = create_engine(connection_string)
            logger.info("Connected to TiDB database")
        except Exception as e:
            logger.error("Failed to connect to TiDB database")
            logger.error(f"TiDB database error: {e}")
            raise

    async def _migrate_timestamp_columns(self):
        """Migrate timestamp columns in tables to timezone-aware types, assuming original data is in UTC"""
        # Not implemented yet
        pass

    async def check_tables(self):
        # First create all tables that do not exist yet
        for k, v in TABLES.items():
            try:
                await self.query(f"SELECT 1 FROM {k}")
            except Exception as e:
                logger.error("Failed to check table in TiDB database")
                logger.error(f"TiDB database error: {e}")
                try:
                    await self.execute(v["ddl"])
                    logger.info("Created table in TiDB database")
                except Exception as e:
                    logger.error("Failed to create table in TiDB database")
                    logger.error(f"TiDB database error: {e}")

        # After all tables are created, try to migrate timestamp fields
        try:
            await self._migrate_timestamp_columns()
        except Exception as e:
            logger.error(f"TiDB, Failed to migrate timestamp columns: {e}")
            # Don't raise exceptions, allow initialization process to continue

    async def query(
        self, sql: str, params: dict = None, multirows: bool = False
    ) -> Union[dict, list, None]:
        if params is None:
            params = {"workspace": self.workspace}
        else:
            params.update({"workspace": self.workspace})
        with self.engine.connect() as conn, conn.begin():
            try:
                result = conn.execute(text(sql), params)
            except Exception as e:
                sanitized_params = sanitize_sensitive_info(params)
                sanitized_error = sanitize_sensitive_info({"error": str(e)})
                logger.error(
                    f"Tidb database,\nsql:{sql},\nparams:{sanitized_params},\nerror:{sanitized_error}"
                )
                raise
            if multirows:
                rows = result.all()
                if rows:
                    data = [dict(zip(result.keys(), row)) for row in rows]
                else:
                    data = []
            else:
                row = result.first()
                if row:
                    data = dict(zip(result.keys(), row))
                else:
                    data = None
            return data

    async def execute(self, sql: str, data: list | dict = None):
        try:
            with self.engine.connect() as conn, conn.begin():
                if data is None:
                    conn.execute(text(sql))
                else:
                    conn.execute(text(sql), parameters=data)
        except Exception as e:
            sanitized_data = sanitize_sensitive_info(data) if data else None
            sanitized_error = sanitize_sensitive_info({"error": str(e)})
            logger.error(
                f"Tidb database,\nsql:{sql},\ndata:{sanitized_data},\nerror:{sanitized_error}"
            )
            raise
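
# Illustrative usage only (not called anywhere in this module): the thin TiDB
# wrapper above exposes `query` for reads and `execute` for writes. Note that
# `query` always injects the configured workspace into the bind parameters, so
# statements may reference :workspace without supplying it explicitly.
async def _example_raw_access(db: TiDB):  # pragma: no cover - example sketch
    # Single row as a dict (or None) when multirows is False
    doc = await db.query(
        "SELECT doc_id FROM LIGHTRAG_DOC_FULL WHERE workspace = :workspace LIMIT 1"
    )
    # List of dicts (possibly empty) when multirows is True
    chunks = await db.query(
        "SELECT chunk_id FROM LIGHTRAG_DOC_CHUNKS WHERE workspace = :workspace",
        multirows=True,
    )
    return doc, chunks
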
class ClientManager:
    _instances: dict[str, Any] = {"db": None, "ref_count": 0}
    _lock = asyncio.Lock()

    @staticmethod
    def get_config() -> dict[str, Any]:
        config = configparser.ConfigParser()
        config.read("config.ini", "utf-8")

        return {
            "host": os.environ.get(
                "TIDB_HOST",
                config.get("tidb", "host", fallback="localhost"),
            ),
            "port": os.environ.get(
                "TIDB_PORT", config.get("tidb", "port", fallback=4000)
            ),
            "user": os.environ.get(
                "TIDB_USER",
                config.get("tidb", "user", fallback=None),
            ),
            "password": os.environ.get(
                "TIDB_PASSWORD",
                config.get("tidb", "password", fallback=None),
            ),
            "database": os.environ.get(
                "TIDB_DATABASE",
                config.get("tidb", "database", fallback=None),
            ),
            "workspace": os.environ.get(
                "TIDB_WORKSPACE",
                config.get("tidb", "workspace", fallback="default"),
            ),
        }

    @classmethod
    async def get_client(cls) -> TiDB:
        async with cls._lock:
            if cls._instances["db"] is None:
                config = ClientManager.get_config()
                db = TiDB(config)
                await db.check_tables()
                cls._instances["db"] = db
                cls._instances["ref_count"] = 0
            cls._instances["ref_count"] += 1
            return cls._instances["db"]

    @classmethod
    async def release_client(cls, db: TiDB):
        async with cls._lock:
            if db is not None:
                if db is cls._instances["db"]:
                    cls._instances["ref_count"] -= 1
                    if cls._instances["ref_count"] == 0:
                        cls._instances["db"] = None
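
# Illustrative sketch of the intended ClientManager lifecycle: storages call
# get_client() in initialize() and release_client() in finalize(), and the
# shared client is dropped once the reference count returns to zero. Assumes
# TIDB_* environment variables (or config.ini) provide valid credentials.
async def _example_client_lifecycle():  # pragma: no cover - example sketch
    db = await ClientManager.get_client()  # creates missing tables on first use
    try:
        await db.query("SELECT 1")
    finally:
        await ClientManager.release_client(db)
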
@final
@dataclass
class TiDBKVStorage(BaseKVStorage):
    db: TiDB = field(default=None)

    def __post_init__(self):
        self._data = {}
        # Guards the in-memory cache read by get_all() (the lock was
        # referenced there but never created).
        self._storage_lock = asyncio.Lock()
        self._max_batch_size = self.global_config["embedding_batch_num"]

    async def initialize(self):
        if self.db is None:
            self.db = await ClientManager.get_client()

    async def finalize(self):
        if self.db is not None:
            await ClientManager.release_client(self.db)
            self.db = None

    ################ QUERY METHODS ################

    async def get_all(self) -> dict[str, Any]:
        """Get all data from storage

        Returns:
            Dictionary containing all stored data
        """
        async with self._storage_lock:
            return dict(self._data)

    async def get_by_id(self, id: str) -> dict[str, Any] | None:
        """Fetch doc_full data by id."""
        SQL = SQL_TEMPLATES["get_by_id_" + self.namespace]
        params = {"id": id}
        response = await self.db.query(SQL, params)
        return response if response else None

    # Query by ids
    async def get_by_ids(self, ids: list[str]) -> list[dict[str, Any]]:
        """Fetch doc_chunks data by id"""
        SQL = SQL_TEMPLATES["get_by_ids_" + self.namespace].format(
            ids=",".join([f"'{id}'" for id in ids])
        )
        return await self.db.query(SQL, multirows=True)

    async def filter_keys(self, keys: set[str]) -> set[str]:
        """Return the subset of `keys` that does not yet exist in storage."""
        SQL = SQL_TEMPLATES["filter_keys"].format(
            table_name=namespace_to_table_name(self.namespace),
            id_field=namespace_to_id(self.namespace),
            ids=",".join([f"'{id}'" for id in keys]),
        )
        try:
            res = await self.db.query(SQL, multirows=True)
        except Exception as e:
            logger.error(f"Tidb database,\nsql:{SQL},\nkeys:{keys},\nerror:{e}")
            raise
        exist_keys = [key["id"] for key in res] if res else []
        return set(s for s in keys if s not in exist_keys)

    ################ INSERT full_doc AND chunks ################
    async def upsert(self, data: dict[str, dict[str, Any]]) -> None:
        logger.info(f"Inserting {len(data)} records to {self.namespace}")
        if not data:
            return

        left_data = {k: v for k, v in data.items() if k not in self._data}
        self._data.update(left_data)

        if is_namespace(self.namespace, NameSpace.KV_STORE_TEXT_CHUNKS):
            list_data = [
                {
                    "__id__": k,
                    **{k1: v1 for k1, v1 in v.items()},
                }
                for k, v in data.items()
            ]
            contents = [v["content"] for v in data.values()]
            batches = [
                contents[i : i + self._max_batch_size]
                for i in range(0, len(contents), self._max_batch_size)
            ]
            embeddings_list = await asyncio.gather(
                *[self.embedding_func(batch) for batch in batches]
            )
            embeddings = np.concatenate(embeddings_list)
            for i, d in enumerate(list_data):
                d["__vector__"] = embeddings[i]

            # Get current time as UNIX timestamp
            current_time = int(time.time())

            merge_sql = SQL_TEMPLATES["upsert_chunk"]
            data = []
            for item in list_data:
                data.append(
                    {
                        "id": item["__id__"],
                        "content": item["content"],
                        "tokens": item["tokens"],
                        "chunk_order_index": item["chunk_order_index"],
                        "full_doc_id": item["full_doc_id"],
                        "content_vector": f"{item['__vector__'].tolist()}",
                        "workspace": self.db.workspace,
                        "timestamp": current_time,
                    }
                )
            await self.db.execute(merge_sql, data)

        if is_namespace(self.namespace, NameSpace.KV_STORE_FULL_DOCS):
            merge_sql = SQL_TEMPLATES["upsert_doc_full"]
            data = []
            for k, v in self._data.items():
                data.append(
                    {
                        "id": k,
                        "content": v["content"],
                        "workspace": self.db.workspace,
                    }
                )
            await self.db.execute(merge_sql, data)
        return left_data

    async def index_done_callback(self) -> None:
        # TiDB handles persistence automatically
        pass

    async def delete(self, ids: list[str]) -> None:
        """Delete records with specified IDs from the storage.

        Args:
            ids: List of record IDs to be deleted
        """
        if not ids:
            return
        try:
            table_name = namespace_to_table_name(self.namespace)
            id_field = namespace_to_id(self.namespace)
            if not table_name or not id_field:
                logger.error(f"Unknown namespace for deletion: {self.namespace}")
                return

            ids_list = ",".join([f"'{id}'" for id in ids])
            delete_sql = f"DELETE FROM {table_name} WHERE workspace = :workspace AND {id_field} IN ({ids_list})"

            await self.db.execute(delete_sql, {"workspace": self.db.workspace})
            logger.info(
                f"Successfully deleted {len(ids)} records from {self.namespace}"
            )
        except Exception as e:
            logger.error(f"Error deleting records from {self.namespace}: {e}")

    async def drop_cache_by_modes(self, modes: list[str] | None = None) -> bool:
        """Delete specific records from storage by cache mode

        Args:
            modes (list[str]): List of cache modes to be dropped from storage

        Returns:
            bool: True if successful, False otherwise
        """
        if not modes:
            return False
        try:
            table_name = namespace_to_table_name(self.namespace)
            if not table_name:
                return False
            if table_name != "LIGHTRAG_LLM_CACHE":
                return False

            # Build MySQL-style IN clause
            modes_list = ", ".join([f"'{mode}'" for mode in modes])
            sql = f"""
            DELETE FROM {table_name}
            WHERE workspace = :workspace
            AND mode IN ({modes_list})
            """
            logger.info(f"Deleting cache by modes: {modes}")
            await self.db.execute(sql, {"workspace": self.db.workspace})
            return True
        except Exception as e:
            logger.error(f"Error deleting cache by modes {modes}: {e}")
            return False

    async def drop(self) -> dict[str, str]:
        """Drop the storage"""
        try:
            table_name = namespace_to_table_name(self.namespace)
            if not table_name:
                return {
                    "status": "error",
                    "message": f"Unknown namespace: {self.namespace}",
                }

            drop_sql = SQL_TEMPLATES["drop_specifiy_table_workspace"].format(
                table_name=table_name
            )
            await self.db.execute(drop_sql, {"workspace": self.db.workspace})
            return {"status": "success", "message": "data dropped"}
        except Exception as e:
            return {"status": "error", "message": str(e)}
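
# Sketch of the batching arithmetic used by TiDBKVStorage.upsert(): contents
# are sliced into at most embedding_batch_num items per batch, embedded
# concurrently with asyncio.gather, then re-joined with np.concatenate so row
# i keeps embedding i. This helper only illustrates the slicing step.
def _example_embedding_batches(
    contents: list[str], max_batch_size: int
) -> list[list[str]]:  # pragma: no cover - example sketch
    # e.g. 5 contents with max_batch_size=2 -> [[c0, c1], [c2, c3], [c4]]
    return [
        contents[i : i + max_batch_size]
        for i in range(0, len(contents), max_batch_size)
    ]
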
@final
@dataclass
class TiDBVectorDBStorage(BaseVectorStorage):
    db: TiDB | None = field(default=None)

    def __post_init__(self):
        self._client_file_name = os.path.join(
            self.global_config["working_dir"], f"vdb_{self.namespace}.json"
        )
        self._max_batch_size = self.global_config["embedding_batch_num"]
        config = self.global_config.get("vector_db_storage_cls_kwargs", {})
        cosine_threshold = config.get("cosine_better_than_threshold")
        if cosine_threshold is None:
            raise ValueError(
                "cosine_better_than_threshold must be specified in vector_db_storage_cls_kwargs"
            )
        self.cosine_better_than_threshold = cosine_threshold

    async def initialize(self):
        if self.db is None:
            self.db = await ClientManager.get_client()

    async def finalize(self):
        if self.db is not None:
            await ClientManager.release_client(self.db)
            self.db = None

    async def query(
        self, query: str, top_k: int, ids: list[str] | None = None
    ) -> list[dict[str, Any]]:
        """Search from tidb vector"""
        embeddings = await self.embedding_func(
            [query], _priority=5
        )  # higher priority for query
        embedding = embeddings[0]

        embedding_string = "[" + ", ".join(map(str, embedding.tolist())) + "]"

        params = {
            "embedding_string": embedding_string,
            "top_k": top_k,
            "better_than_threshold": self.cosine_better_than_threshold,
        }

        results = await self.db.query(
            SQL_TEMPLATES[self.namespace], params=params, multirows=True
        )
        if not results:
            return []
        return results

    ###### INSERT entities And relationships ######
    async def upsert(self, data: dict[str, dict[str, Any]]) -> None:
        if not data:
            return
        logger.info(f"Inserting {len(data)} vectors to {self.namespace}")

        # Get current time as UNIX timestamp
        current_time = int(time.time())

        list_data = [
            {
                "id": k,
                "timestamp": current_time,
                **{k1: v1 for k1, v1 in v.items()},
            }
            for k, v in data.items()
        ]
        contents = [v["content"] for v in data.values()]
        batches = [
            contents[i : i + self._max_batch_size]
            for i in range(0, len(contents), self._max_batch_size)
        ]
        embedding_tasks = [self.embedding_func(batch) for batch in batches]
        embeddings_list = await asyncio.gather(*embedding_tasks)
        embeddings = np.concatenate(embeddings_list)
        for i, d in enumerate(list_data):
            d["content_vector"] = embeddings[i]

        if is_namespace(self.namespace, NameSpace.VECTOR_STORE_CHUNKS):
            for item in list_data:
                param = {
                    "id": item["id"],
                    "content": item["content"],
                    "tokens": item.get("tokens", 0),
                    "chunk_order_index": item.get("chunk_order_index", 0),
                    "full_doc_id": item.get("full_doc_id", ""),
                    "content_vector": f"{item['content_vector'].tolist()}",
                    "workspace": self.db.workspace,
                    "timestamp": item["timestamp"],
                }
                await self.db.execute(SQL_TEMPLATES["upsert_chunk"], param)
        elif is_namespace(self.namespace, NameSpace.VECTOR_STORE_ENTITIES):
            for item in list_data:
                param = {
                    "id": item["id"],
                    "name": item["entity_name"],
                    "content": item["content"],
                    "content_vector": f"{item['content_vector'].tolist()}",
                    "workspace": self.db.workspace,
                    "timestamp": item["timestamp"],
                }
                await self.db.execute(SQL_TEMPLATES["upsert_entity"], param)
        elif is_namespace(self.namespace, NameSpace.VECTOR_STORE_RELATIONSHIPS):
            for item in list_data:
                param = {
                    "id": item["id"],
                    "source_name": item["src_id"],
                    "target_name": item["tgt_id"],
                    "content": item["content"],
                    "content_vector": f"{item['content_vector'].tolist()}",
                    "workspace": self.db.workspace,
                    "timestamp": item["timestamp"],
                }
                await self.db.execute(SQL_TEMPLATES["upsert_relationship"], param)

    async def get_by_status(self, status: str) -> Union[list[dict[str, Any]], None]:
        SQL = SQL_TEMPLATES["get_by_status_" + self.namespace]
        params = {"workspace": self.db.workspace, "status": status}
        return await self.db.query(SQL, params, multirows=True)

    async def delete(self, ids: list[str]) -> None:
        """Delete vectors with specified IDs from the storage.

        Args:
            ids: List of vector IDs to be deleted
        """
        if not ids:
            return

        table_name = namespace_to_table_name(self.namespace)
        id_field = namespace_to_id(self.namespace)
        if not table_name or not id_field:
            logger.error(f"Unknown namespace for vector deletion: {self.namespace}")
            return

        ids_list = ",".join([f"'{id}'" for id in ids])
        delete_sql = f"DELETE FROM {table_name} WHERE workspace = :workspace AND {id_field} IN ({ids_list})"

        try:
            await self.db.execute(delete_sql, {"workspace": self.db.workspace})
            logger.debug(
                f"Successfully deleted {len(ids)} vectors from {self.namespace}"
            )
        except Exception as e:
            logger.error(f"Error while deleting vectors from {self.namespace}: {e}")

    async def delete_entity(self, entity_name: str) -> None:
        """Delete an entity by its name from the vector storage.

        Args:
            entity_name: The name of the entity to delete
        """
        try:
            # Construct SQL to delete the entity
            delete_sql = """DELETE FROM LIGHTRAG_GRAPH_NODES
                            WHERE workspace = :workspace AND name = :entity_name"""
            await self.db.execute(
                delete_sql, {"workspace": self.db.workspace, "entity_name": entity_name}
            )
            logger.debug(f"Successfully deleted entity {entity_name}")
        except Exception as e:
            logger.error(f"Error deleting entity {entity_name}: {e}")

    async def delete_entity_relation(self, entity_name: str) -> None:
        """Delete all relations associated with an entity.

        Args:
            entity_name: The name of the entity whose relations should be deleted
        """
        try:
            # Delete relations where the entity is either the source or target
            delete_sql = """DELETE FROM LIGHTRAG_GRAPH_EDGES
                            WHERE workspace = :workspace
                            AND (source_name = :entity_name OR target_name = :entity_name)"""
            await self.db.execute(
                delete_sql, {"workspace": self.db.workspace, "entity_name": entity_name}
            )
            logger.debug(f"Successfully deleted relations for entity {entity_name}")
        except Exception as e:
            logger.error(f"Error deleting relations for entity {entity_name}: {e}")

    async def index_done_callback(self) -> None:
        # TiDB handles persistence automatically
        pass

    async def drop(self) -> dict[str, str]:
        """Drop the storage"""
        try:
            table_name = namespace_to_table_name(self.namespace)
            if not table_name:
                return {
                    "status": "error",
                    "message": f"Unknown namespace: {self.namespace}",
                }

            drop_sql = SQL_TEMPLATES["drop_specifiy_table_workspace"].format(
                table_name=table_name
            )
            await self.db.execute(drop_sql, {"workspace": self.db.workspace})
            return {"status": "success", "message": "data dropped"}
        except Exception as e:
            return {"status": "error", "message": str(e)}

    async def get_by_id(self, id: str) -> dict[str, Any] | None:
        """Get vector data by its ID

        Args:
            id: The unique identifier of the vector

        Returns:
            The vector data if found, or None if not found
        """
        try:
            # Determine which table to query based on namespace
            if self.namespace == NameSpace.VECTOR_STORE_ENTITIES:
                sql_template = """
                    SELECT entity_id as id, name as entity_name, entity_type, description, content,
                           UNIX_TIMESTAMP(createtime) as created_at
                    FROM LIGHTRAG_GRAPH_NODES
                    WHERE entity_id = :entity_id AND workspace = :workspace
                """
                params = {"entity_id": id, "workspace": self.db.workspace}
            elif self.namespace == NameSpace.VECTOR_STORE_RELATIONSHIPS:
                sql_template = """
                    SELECT relation_id as id, source_name as src_id, target_name as tgt_id,
                           keywords, description, content,
                           UNIX_TIMESTAMP(createtime) as created_at
                    FROM LIGHTRAG_GRAPH_EDGES
                    WHERE relation_id = :relation_id AND workspace = :workspace
                """
                params = {"relation_id": id, "workspace": self.db.workspace}
            elif self.namespace == NameSpace.VECTOR_STORE_CHUNKS:
                sql_template = """
                    SELECT chunk_id as id, content, tokens, chunk_order_index, full_doc_id,
                           UNIX_TIMESTAMP(createtime) as created_at
                    FROM LIGHTRAG_DOC_CHUNKS
                    WHERE chunk_id = :chunk_id AND workspace = :workspace
                """
                params = {"chunk_id": id, "workspace": self.db.workspace}
            else:
                logger.warning(
                    f"Namespace {self.namespace} not supported for get_by_id"
                )
                return None

            result = await self.db.query(sql_template, params=params)
            return result
        except Exception as e:
            logger.error(f"Error retrieving vector data for ID {id}: {e}")
            return None

    async def get_by_ids(self, ids: list[str]) -> list[dict[str, Any]]:
        """Get multiple vector data by their IDs

        Args:
            ids: List of unique identifiers

        Returns:
            List of vector data objects that were found
        """
        if not ids:
            return []
        try:
            # Format IDs for SQL IN clause
            ids_str = ", ".join([f"'{id}'" for id in ids])

            # Determine which table to query based on namespace
            if self.namespace == NameSpace.VECTOR_STORE_ENTITIES:
                sql_template = f"""
                    SELECT entity_id as id, name as entity_name, entity_type, description, content,
                           UNIX_TIMESTAMP(createtime) as created_at
                    FROM LIGHTRAG_GRAPH_NODES
                    WHERE entity_id IN ({ids_str}) AND workspace = :workspace
                """
            elif self.namespace == NameSpace.VECTOR_STORE_RELATIONSHIPS:
                sql_template = f"""
                    SELECT relation_id as id, source_name as src_id, target_name as tgt_id,
                           keywords, description, content,
                           UNIX_TIMESTAMP(createtime) as created_at
                    FROM LIGHTRAG_GRAPH_EDGES
                    WHERE relation_id IN ({ids_str}) AND workspace = :workspace
                """
            elif self.namespace == NameSpace.VECTOR_STORE_CHUNKS:
                sql_template = f"""
                    SELECT chunk_id as id, content, tokens, chunk_order_index, full_doc_id,
                           UNIX_TIMESTAMP(createtime) as created_at
                    FROM LIGHTRAG_DOC_CHUNKS
                    WHERE chunk_id IN ({ids_str}) AND workspace = :workspace
                """
            else:
                logger.warning(
                    f"Namespace {self.namespace} not supported for get_by_ids"
                )
                return []

            params = {"workspace": self.db.workspace}
            results = await self.db.query(sql_template, params=params, multirows=True)
            return results if results else []
        except Exception as e:
            logger.error(f"Error retrieving vector data for IDs {ids}: {e}")
            return []
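
# Illustrative only: TiDB's VEC_COSINE_DISTANCE receives the query vector as a
# JSON-array string, matching how content_vector values are serialized in the
# upsert paths above (f"{vector.tolist()}"). A hypothetical 4-dim query vector:
def _example_vector_literal() -> str:  # pragma: no cover - example sketch
    embedding = np.array([0.1, 0.2, 0.3, 0.4])
    # -> "[0.1, 0.2, 0.3, 0.4]"
    return "[" + ", ".join(map(str, embedding.tolist())) + "]"
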
@final
@dataclass
class TiDBGraphStorage(BaseGraphStorage):
    db: TiDB = field(default=None)

    def __post_init__(self):
        self._max_batch_size = self.global_config["embedding_batch_num"]

    async def initialize(self):
        if self.db is None:
            self.db = await ClientManager.get_client()

    async def finalize(self):
        if self.db is not None:
            await ClientManager.release_client(self.db)
            self.db = None

    #################### upsert method ################
    async def upsert_node(self, node_id: str, node_data: dict[str, str]) -> None:
        entity_name = node_id
        entity_type = node_data["entity_type"]
        description = node_data["description"]
        source_id = node_data["source_id"]
        logger.debug(f"entity_name:{entity_name}, entity_type:{entity_type}")

        content = entity_name + description
        contents = [content]
        batches = [
            contents[i : i + self._max_batch_size]
            for i in range(0, len(contents), self._max_batch_size)
        ]
        embeddings_list = await asyncio.gather(
            *[self.embedding_func(batch) for batch in batches]
        )
        embeddings = np.concatenate(embeddings_list)
        content_vector = embeddings[0]
        sql = SQL_TEMPLATES["upsert_node"]
        data = {
            "workspace": self.db.workspace,
            "name": entity_name,
            "entity_type": entity_type,
            "description": description,
            "source_chunk_id": source_id,
            "content": content,
            "content_vector": f"{content_vector.tolist()}",
        }
        await self.db.execute(sql, data)

    async def upsert_edge(
        self, source_node_id: str, target_node_id: str, edge_data: dict[str, str]
    ) -> None:
        source_name = source_node_id
        target_name = target_node_id
        weight = edge_data["weight"]
        keywords = edge_data["keywords"]
        description = edge_data["description"]
        source_chunk_id = edge_data["source_id"]
        logger.debug(
            f"source_name:{source_name}, target_name:{target_name}, keywords: {keywords}"
        )

        content = keywords + source_name + target_name + description
        contents = [content]
        batches = [
            contents[i : i + self._max_batch_size]
            for i in range(0, len(contents), self._max_batch_size)
        ]
        embeddings_list = await asyncio.gather(
            *[self.embedding_func(batch) for batch in batches]
        )
        embeddings = np.concatenate(embeddings_list)
        content_vector = embeddings[0]
        merge_sql = SQL_TEMPLATES["upsert_edge"]
        data = {
            "workspace": self.db.workspace,
            "source_name": source_name,
            "target_name": target_name,
            "weight": weight,
            "keywords": keywords,
            "description": description,
            "source_chunk_id": source_chunk_id,
            "content": content,
            "content_vector": f"{content_vector.tolist()}",
        }
        await self.db.execute(merge_sql, data)

    # Query
    async def has_node(self, node_id: str) -> bool:
        sql = SQL_TEMPLATES["has_entity"]
        param = {"name": node_id, "workspace": self.db.workspace}
        has = await self.db.query(sql, param)
        return has["cnt"] != 0

    async def has_edge(self, source_node_id: str, target_node_id: str) -> bool:
        sql = SQL_TEMPLATES["has_relationship"]
        param = {
            "source_name": source_node_id,
            "target_name": target_node_id,
            "workspace": self.db.workspace,
        }
        has = await self.db.query(sql, param)
        return has["cnt"] != 0

    async def node_degree(self, node_id: str) -> int:
        sql = SQL_TEMPLATES["node_degree"]
        param = {"name": node_id, "workspace": self.db.workspace}
        result = await self.db.query(sql, param)
        return result["cnt"]

    async def edge_degree(self, src_id: str, tgt_id: str) -> int:
        return await self.node_degree(src_id) + await self.node_degree(tgt_id)

    async def get_node(self, node_id: str) -> dict[str, str] | None:
        sql = SQL_TEMPLATES["get_node"]
        param = {"name": node_id, "workspace": self.db.workspace}
        return await self.db.query(sql, param)

    async def get_edge(
        self, source_node_id: str, target_node_id: str
    ) -> dict[str, str] | None:
        sql = SQL_TEMPLATES["get_edge"]
        param = {
            "source_name": source_node_id,
            "target_name": target_node_id,
            "workspace": self.db.workspace,
        }
        return await self.db.query(sql, param)

    async def get_node_edges(self, source_node_id: str) -> list[tuple[str, str]] | None:
        sql = SQL_TEMPLATES["get_node_edges"]
        param = {"source_name": source_node_id, "workspace": self.db.workspace}
        res = await self.db.query(sql, param, multirows=True)
        if res:
            return [(i["source_name"], i["target_name"]) for i in res]
        return []

    async def index_done_callback(self) -> None:
        # TiDB handles persistence automatically
        pass

    async def drop(self) -> dict[str, str]:
        """Drop the storage"""
        try:
            # Run as two separate statements: the pymysql driver does not
            # accept a single multi-statement string by default.
            await self.db.execute(
                "DELETE FROM LIGHTRAG_GRAPH_EDGES WHERE workspace = :workspace",
                {"workspace": self.db.workspace},
            )
            await self.db.execute(
                "DELETE FROM LIGHTRAG_GRAPH_NODES WHERE workspace = :workspace",
                {"workspace": self.db.workspace},
            )
            return {"status": "success", "message": "graph data dropped"}
        except Exception as e:
            return {"status": "error", "message": str(e)}

    async def delete_node(self, node_id: str) -> None:
        """Delete a node and all its related edges

        Args:
            node_id: The ID of the node to delete
        """
        # First delete all edges related to this node
        await self.db.execute(
            SQL_TEMPLATES["delete_node_edges"],
            {"name": node_id, "workspace": self.db.workspace},
        )
        # Then delete the node itself
        await self.db.execute(
            SQL_TEMPLATES["delete_node"],
            {"name": node_id, "workspace": self.db.workspace},
        )
        logger.debug(
            f"Node {node_id} and its related edges have been deleted from the graph"
        )

    async def get_all_labels(self) -> list[str]:
        """Get all entity types (labels) in the database

        Returns:
            List of labels sorted alphabetically
        """
        result = await self.db.query(
            SQL_TEMPLATES["get_all_labels"],
            {"workspace": self.db.workspace},
            multirows=True,
        )
        if not result:
            return []
        # Extract all labels
        return [item["label"] for item in result]

    async def get_knowledge_graph(
        self, node_label: str, max_depth: int = 5
    ) -> KnowledgeGraph:
        """Get a connected subgraph of nodes matching the specified label

        The maximum number of returned nodes is limited by the MAX_GRAPH_NODES
        environment variable (default: 1000).

        Args:
            node_label: The node label to match
            max_depth: Maximum depth of the subgraph (currently unused by this
                SQL-based implementation)

        Returns:
            KnowledgeGraph object containing nodes and edges
        """
        result = KnowledgeGraph()
        MAX_GRAPH_NODES = int(os.getenv("MAX_GRAPH_NODES", 1000))

        # Get matching nodes
        if node_label == "*":
            # Handle special case: get all nodes
            node_results = await self.db.query(
                SQL_TEMPLATES["get_all_nodes"],
                {"workspace": self.db.workspace, "max_nodes": MAX_GRAPH_NODES},
                multirows=True,
            )
        else:
            # Get nodes matching the label
            label_pattern = f"%{node_label}%"
            node_results = await self.db.query(
                SQL_TEMPLATES["get_matching_nodes"],
                {"workspace": self.db.workspace, "label_pattern": label_pattern},
                multirows=True,
            )
        if not node_results:
            logger.warning(f"No nodes found matching label {node_label}")
            return result

        # Limit the number of returned nodes
        if len(node_results) > MAX_GRAPH_NODES:
            node_results = node_results[:MAX_GRAPH_NODES]

        # Extract node names for edge query
        node_names = [node["name"] for node in node_results]
        node_names_str = ",".join([f"'{name}'" for name in node_names])

        # Add nodes to result
        for node in node_results:
            node_properties = {
                k: v for k, v in node.items() if k not in ["id", "name", "entity_type"]
            }
            result.nodes.append(
                KnowledgeGraphNode(
                    id=node["name"],
                    labels=[node["entity_type"]]
                    if node.get("entity_type")
                    else [node["name"]],
                    properties=node_properties,
                )
            )

        # Get related edges
        edge_results = await self.db.query(
            SQL_TEMPLATES["get_related_edges"].format(node_names=node_names_str),
            {"workspace": self.db.workspace},
            multirows=True,
        )

        if edge_results:
            # Add edges to result
            for edge in edge_results:
                # Only include edges between selected nodes
                if (
                    edge["source_name"] in node_names
                    and edge["target_name"] in node_names
                ):
                    edge_id = f"{edge['source_name']}-{edge['target_name']}"
                    edge_properties = {
                        k: v
                        for k, v in edge.items()
                        if k not in ["id", "source_name", "target_name"]
                    }
                    result.edges.append(
                        KnowledgeGraphEdge(
                            id=edge_id,
                            type="RELATED",
                            source=edge["source_name"],
                            target=edge["target_name"],
                            properties=edge_properties,
                        )
                    )

        logger.info(
            f"Subgraph query successful | Node count: {len(result.nodes)} | Edge count: {len(result.edges)}"
        )
        return result

    async def remove_nodes(self, nodes: list[str]):
        """Delete multiple nodes

        Args:
            nodes: List of node IDs to delete
        """
        for node_id in nodes:
            await self.delete_node(node_id)

    async def remove_edges(self, edges: list[tuple[str, str]]):
        """Delete multiple edges

        Args:
            edges: List of edges to delete, each edge is a (source, target) tuple
        """
        for source, target in edges:
            await self.db.execute(
                SQL_TEMPLATES["remove_multiple_edges"],
                {"source": source, "target": target, "workspace": self.db.workspace},
            )
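
# Illustrative round trip (never invoked on import): the dict keys expected by
# upsert_node/upsert_edge mirror the per-entity and per-relation fields used
# above; the literal values here are placeholders.
async def _example_graph_roundtrip(graph: TiDBGraphStorage):  # pragma: no cover
    await graph.upsert_node(
        "Alice",
        {"entity_type": "person", "description": "demo node", "source_id": "chunk-1"},
    )
    await graph.upsert_edge(
        "Alice",
        "Bob",
        {
            "weight": 1.0,
            "keywords": "knows",
            "description": "demo edge",
            "source_id": "chunk-1",
        },
    )
    # Degree counts edges where the node appears as source or target
    return await graph.node_degree("Alice")
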
N_T = {
    NameSpace.KV_STORE_FULL_DOCS: "LIGHTRAG_DOC_FULL",
    NameSpace.KV_STORE_TEXT_CHUNKS: "LIGHTRAG_DOC_CHUNKS",
    NameSpace.VECTOR_STORE_CHUNKS: "LIGHTRAG_DOC_CHUNKS",
    NameSpace.VECTOR_STORE_ENTITIES: "LIGHTRAG_GRAPH_NODES",
    NameSpace.VECTOR_STORE_RELATIONSHIPS: "LIGHTRAG_GRAPH_EDGES",
}

N_ID = {
    NameSpace.KV_STORE_FULL_DOCS: "doc_id",
    NameSpace.KV_STORE_TEXT_CHUNKS: "chunk_id",
    NameSpace.VECTOR_STORE_CHUNKS: "chunk_id",
    NameSpace.VECTOR_STORE_ENTITIES: "entity_id",
    NameSpace.VECTOR_STORE_RELATIONSHIPS: "relation_id",
}


def namespace_to_table_name(namespace: str) -> str | None:
    for k, v in N_T.items():
        if is_namespace(namespace, k):
            return v
    return None


def namespace_to_id(namespace: str) -> str | None:
    for k, v in N_ID.items():
        if is_namespace(namespace, k):
            return v
    return None
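
# For example (illustrative), a text-chunks namespace resolves to the
# ("LIGHTRAG_DOC_CHUNKS", "chunk_id") pair; unrecognized namespaces make both
# helpers return None, which callers above report as an unknown namespace.
def _example_namespace_lookup(namespace: str):  # pragma: no cover - example sketch
    return namespace_to_table_name(namespace), namespace_to_id(namespace)
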
TABLES = {
    "LIGHTRAG_DOC_FULL": {
        "ddl": """
        CREATE TABLE LIGHTRAG_DOC_FULL (
            `id` BIGINT PRIMARY KEY AUTO_RANDOM,
            `doc_id` VARCHAR(256) NOT NULL,
            `workspace` varchar(1024),
            `content` LONGTEXT,
            `meta` JSON,
            `createtime` TIMESTAMP DEFAULT CURRENT_TIMESTAMP,
            `updatetime` TIMESTAMP DEFAULT NULL,
            UNIQUE KEY (`doc_id`)
        );
        """
    },
    "LIGHTRAG_DOC_CHUNKS": {
        "ddl": """
        CREATE TABLE LIGHTRAG_DOC_CHUNKS (
            `id` BIGINT PRIMARY KEY AUTO_RANDOM,
            `chunk_id` VARCHAR(256) NOT NULL,
            `full_doc_id` VARCHAR(256) NOT NULL,
            `workspace` varchar(1024),
            `chunk_order_index` INT,
            `tokens` INT,
            `content` LONGTEXT,
            `content_vector` VECTOR,
            `createtime` TIMESTAMP,
            `updatetime` TIMESTAMP,
            UNIQUE KEY (`chunk_id`)
        );
        """
    },
    "LIGHTRAG_GRAPH_NODES": {
        "ddl": """
        CREATE TABLE LIGHTRAG_GRAPH_NODES (
            `id` BIGINT PRIMARY KEY AUTO_RANDOM,
            `entity_id` VARCHAR(256),
            `workspace` varchar(1024),
            `name` VARCHAR(2048),
            `entity_type` VARCHAR(1024),
            `description` LONGTEXT,
            `source_chunk_id` VARCHAR(256),
            `content` LONGTEXT,
            `content_vector` VECTOR,
            `createtime` TIMESTAMP,
            `updatetime` TIMESTAMP,
            KEY (`entity_id`)
        );
        """
    },
    "LIGHTRAG_GRAPH_EDGES": {
        "ddl": """
        CREATE TABLE LIGHTRAG_GRAPH_EDGES (
            `id` BIGINT PRIMARY KEY AUTO_RANDOM,
            `relation_id` VARCHAR(256),
            `workspace` varchar(1024),
            `source_name` VARCHAR(2048),
            `target_name` VARCHAR(2048),
            `weight` DECIMAL,
            `keywords` TEXT,
            `description` LONGTEXT,
            `source_chunk_id` varchar(256),
            `content` LONGTEXT,
            `content_vector` VECTOR,
            `createtime` TIMESTAMP,
            `updatetime` TIMESTAMP,
            KEY (`relation_id`)
        );
        """
    },
    # `return` is a reserved word in MySQL/TiDB, so the column names are
    # backquoted; workspace and mode are the columns drop_cache_by_modes()
    # filters on.
    "LIGHTRAG_LLM_CACHE": {
        "ddl": """
        CREATE TABLE LIGHTRAG_LLM_CACHE (
            `id` BIGINT PRIMARY KEY AUTO_INCREMENT,
            `workspace` varchar(1024),
            `mode` VARCHAR(32),
            `send` TEXT,
            `return` TEXT,
            `model` VARCHAR(1024),
            `createtime` DATETIME DEFAULT CURRENT_TIMESTAMP,
            `updatetime` DATETIME DEFAULT NULL
        );
        """
    },
}
node_id, "workspace": self.db.workspace}, ) # Then delete the node itself await self.db.execute( SQL_TEMPLATES["delete_node"], {"name": node_id, "workspace": self.db.workspace}, ) logger.debug( f"Node {node_id} and its related edges have been deleted from the graph" ) async def get_all_labels(self) -> list[str]: """Get all entity types (labels) in the database Returns: List of labels sorted alphabetically """ result = await self.db.query( SQL_TEMPLATES["get_all_labels"], {"workspace": self.db.workspace}, multirows=True, ) if not result: return [] # Extract all labels return [item["label"] for item in result] async def get_knowledge_graph( self, node_label: str, max_depth: int = 5 ) -> KnowledgeGraph: """ Get a connected subgraph of nodes matching the specified label Maximum number of nodes is limited by MAX_GRAPH_NODES environment variable (default: 1000) Args: node_label: The node label to match max_depth: Maximum depth of the subgraph Returns: KnowledgeGraph object containing nodes and edges """ result = KnowledgeGraph() MAX_GRAPH_NODES = int(os.getenv("MAX_GRAPH_NODES", 1000)) # Get matching nodes if node_label == "*": # Handle special case, get all nodes node_results = await self.db.query( SQL_TEMPLATES["get_all_nodes"], {"workspace": self.db.workspace, "max_nodes": MAX_GRAPH_NODES}, multirows=True, ) else: # Get nodes matching the label label_pattern = f"%{node_label}%" node_results = await self.db.query( SQL_TEMPLATES["get_matching_nodes"], {"workspace": self.db.workspace, "label_pattern": label_pattern}, multirows=True, ) if not node_results: logger.warning(f"No nodes found matching label {node_label}") return result # Limit the number of returned nodes if len(node_results) > MAX_GRAPH_NODES: node_results = node_results[:MAX_GRAPH_NODES] # Extract node names for edge query node_names = [node["name"] for node in node_results] node_names_str = ",".join([f"'{name}'" for name in node_names]) # Add nodes to result for node in node_results: node_properties = { k: v for k, v in node.items() if k not in ["id", "name", "entity_type"] } result.nodes.append( KnowledgeGraphNode( id=node["name"], labels=[node["entity_type"]] if node.get("entity_type") else [node["name"]], properties=node_properties, ) ) # Get related edges edge_results = await self.db.query( SQL_TEMPLATES["get_related_edges"].format(node_names=node_names_str), {"workspace": self.db.workspace}, multirows=True, ) if edge_results: # Add edges to result for edge in edge_results: # Only include edges related to selected nodes if ( edge["source_name"] in node_names and edge["target_name"] in node_names ): edge_id = f"{edge['source_name']}-{edge['target_name']}" edge_properties = { k: v for k, v in edge.items() if k not in ["id", "source_name", "target_name"] } result.edges.append( KnowledgeGraphEdge( id=edge_id, type="RELATED", source=edge["source_name"], target=edge["target_name"], properties=edge_properties, ) ) logger.info( f"Subgraph query successful | Node count: {len(result.nodes)} | Edge count: {len(result.edges)}" ) return result async def remove_nodes(self, nodes: list[str]): """Delete multiple nodes Args: nodes: List of node IDs to delete """ for node_id in nodes: await self.delete_node(node_id) async def remove_edges(self, edges: list[tuple[str, str]]): """Delete multiple edges Args: edges: List of edges to delete, each edge is a (source, target) tuple """ for source, target in edges: await self.db.execute( SQL_TEMPLATES["remove_multiple_edges"], {"source": source, "target": target, "workspace": self.db.workspace}, ) N_T = 
    "upsert_doc_full": """
        INSERT INTO LIGHTRAG_DOC_FULL (doc_id, content, workspace)
        VALUES (:id, :content, :workspace)
        ON DUPLICATE KEY UPDATE
            content = VALUES(content),
            workspace = VALUES(workspace),
            updatetime = CURRENT_TIMESTAMP
    """,
    "upsert_chunk": """
        INSERT INTO LIGHTRAG_DOC_CHUNKS(chunk_id, content, tokens, chunk_order_index,
            full_doc_id, content_vector, workspace, createtime, updatetime)
        VALUES (:id, :content, :tokens, :chunk_order_index, :full_doc_id,
            :content_vector, :workspace, FROM_UNIXTIME(:timestamp), FROM_UNIXTIME(:timestamp))
        ON DUPLICATE KEY UPDATE
            content = VALUES(content),
            tokens = VALUES(tokens),
            chunk_order_index = VALUES(chunk_order_index),
            full_doc_id = VALUES(full_doc_id),
            content_vector = VALUES(content_vector),
            workspace = VALUES(workspace),
            updatetime = FROM_UNIXTIME(:timestamp)
    """,
    # SQL for VectorStorage
    "entities": """
        SELECT n.name as entity_name, UNIX_TIMESTAMP(n.createtime) as created_at
        FROM (
            SELECT entity_id as id, name, createtime,
                VEC_COSINE_DISTANCE(content_vector, :embedding_string) as distance
            FROM LIGHTRAG_GRAPH_NODES
            WHERE workspace = :workspace
        ) n
        WHERE n.distance > :better_than_threshold
        ORDER BY n.distance DESC
        LIMIT :top_k
    """,
    "relationships": """
        SELECT e.source_name as src_id, e.target_name as tgt_id, UNIX_TIMESTAMP(e.createtime) as created_at
        FROM (
            SELECT source_name, target_name, createtime,
                VEC_COSINE_DISTANCE(content_vector, :embedding_string) as distance
            FROM LIGHTRAG_GRAPH_EDGES
            WHERE workspace = :workspace
        ) e
        WHERE e.distance > :better_than_threshold
        ORDER BY e.distance DESC
        LIMIT :top_k
    """,
    "chunks": """
        SELECT c.id, UNIX_TIMESTAMP(c.createtime) as created_at
        FROM (
            SELECT chunk_id as id, createtime,
                VEC_COSINE_DISTANCE(content_vector, :embedding_string) as distance
            FROM LIGHTRAG_DOC_CHUNKS
            WHERE workspace = :workspace
        ) c
        WHERE c.distance > :better_than_threshold
        ORDER BY c.distance DESC
        LIMIT :top_k
    """,
    "has_entity": """
        SELECT COUNT(id) AS cnt FROM LIGHTRAG_GRAPH_NODES
        WHERE name = :name AND workspace = :workspace
    """,
    "has_relationship": """
        SELECT COUNT(id) AS cnt FROM LIGHTRAG_GRAPH_EDGES
        WHERE source_name = :source_name AND target_name = :target_name AND workspace = :workspace
    """,
    "upsert_entity": """
        INSERT INTO LIGHTRAG_GRAPH_NODES(entity_id, name, content, content_vector, workspace, createtime, updatetime)
        VALUES(:id, :name, :content, :content_vector, :workspace, FROM_UNIXTIME(:timestamp), FROM_UNIXTIME(:timestamp))
        ON DUPLICATE KEY UPDATE
            content = VALUES(content),
            content_vector = VALUES(content_vector),
            updatetime = FROM_UNIXTIME(:timestamp)
    """,
    "upsert_relationship": """
        INSERT INTO LIGHTRAG_GRAPH_EDGES(relation_id, source_name, target_name, content, content_vector, workspace, createtime, updatetime)
        VALUES(:id, :source_name, :target_name, :content, :content_vector, :workspace, FROM_UNIXTIME(:timestamp), FROM_UNIXTIME(:timestamp))
        ON DUPLICATE KEY UPDATE
            content = VALUES(content),
            content_vector = VALUES(content_vector),
            updatetime = FROM_UNIXTIME(:timestamp)
    """,
    # SQL for GraphStorage
    "get_node": """
        SELECT entity_id AS id, workspace, name, entity_type, description,
               source_chunk_id AS source_id, content, content_vector
        FROM LIGHTRAG_GRAPH_NODES
        WHERE name = :name AND workspace = :workspace
    """,
    "get_edge": """
        SELECT relation_id AS id, workspace, source_name, target_name, weight,
               keywords, description, source_chunk_id AS source_id, content, content_vector
        FROM LIGHTRAG_GRAPH_EDGES
        WHERE source_name = :source_name AND target_name = :target_name AND workspace = :workspace
    """,
    "get_node_edges": """
        SELECT relation_id AS id, workspace, source_name, target_name, weight,
               keywords, description, source_chunk_id, content, content_vector
        FROM LIGHTRAG_GRAPH_EDGES
        WHERE source_name = :source_name AND workspace = :workspace
    """,
    "node_degree": """
        SELECT COUNT(id) AS cnt FROM LIGHTRAG_GRAPH_EDGES
        WHERE workspace = :workspace AND :name IN (source_name, target_name)
    """,
    "upsert_node": """
        INSERT INTO LIGHTRAG_GRAPH_NODES(name, content, content_vector, workspace, source_chunk_id, entity_type, description)
        VALUES(:name, :content, :content_vector, :workspace, :source_chunk_id, :entity_type, :description)
        ON DUPLICATE KEY UPDATE
            name = VALUES(name),
            content = VALUES(content),
            content_vector = VALUES(content_vector),
            workspace = VALUES(workspace),
            updatetime = CURRENT_TIMESTAMP,
            source_chunk_id = VALUES(source_chunk_id),
            entity_type = VALUES(entity_type),
            description = VALUES(description)
    """,
    "upsert_edge": """
        INSERT INTO LIGHTRAG_GRAPH_EDGES(source_name, target_name, content, content_vector,
            workspace, weight, keywords, description, source_chunk_id)
        VALUES(:source_name, :target_name, :content, :content_vector,
            :workspace, :weight, :keywords, :description, :source_chunk_id)
        ON DUPLICATE KEY UPDATE
            source_name = VALUES(source_name),
            target_name = VALUES(target_name),
            content = VALUES(content),
            content_vector = VALUES(content_vector),
            workspace = VALUES(workspace),
            updatetime = CURRENT_TIMESTAMP,
            weight = VALUES(weight),
            keywords = VALUES(keywords),
            description = VALUES(description),
            source_chunk_id = VALUES(source_chunk_id)
    """,
    "delete_node": """
        DELETE FROM LIGHTRAG_GRAPH_NODES
        WHERE name = :name AND workspace = :workspace
    """,
    "delete_node_edges": """
        DELETE FROM LIGHTRAG_GRAPH_EDGES
        WHERE (source_name = :name OR target_name = :name) AND workspace = :workspace
    """,
    "get_all_labels": """
        SELECT DISTINCT entity_type as label
        FROM LIGHTRAG_GRAPH_NODES
        WHERE workspace = :workspace
        ORDER BY entity_type
    """,
    "get_matching_nodes": """
        SELECT * FROM LIGHTRAG_GRAPH_NODES
        WHERE name LIKE :label_pattern AND workspace = :workspace
        ORDER BY name
    """,
    "get_all_nodes": """
        SELECT * FROM LIGHTRAG_GRAPH_NODES
        WHERE workspace = :workspace
        ORDER BY name
        LIMIT :max_nodes
    """,
    # node_names is interpolated via str.format() by get_knowledge_graph(),
    # so it must be a {placeholder} rather than a bind parameter.
    "get_related_edges": """
        SELECT * FROM LIGHTRAG_GRAPH_EDGES
        WHERE (source_name IN ({node_names}) OR target_name IN ({node_names}))
        AND workspace = :workspace
    """,
    "remove_multiple_edges": """
        DELETE FROM LIGHTRAG_GRAPH_EDGES
        WHERE (source_name = :source AND target_name = :target)
        AND workspace = :workspace
    """,
    # Drop tables
    "drop_specifiy_table_workspace": "DELETE FROM {table_name} WHERE workspace = :workspace",
}
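
# End-to-end sketch, kept as a comment because the dataclass fields below are
# normally filled in by LightRAG itself; the embedding callable and config
# values are hypothetical placeholders:
#
#   storage = TiDBKVStorage(
#       namespace="full_docs",
#       global_config={"embedding_batch_num": 32},
#       embedding_func=my_embedding_func,
#   )
#   await storage.initialize()          # acquires the shared TiDB client
#   await storage.upsert({"doc-1": {"content": "hello world"}})
#   print(await storage.get_by_id("doc-1"))
#   await storage.finalize()            # releases the shared client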