import os
import re
from dataclasses import dataclass
from typing import final
import configparser

from tenacity import (
    retry,
    stop_after_attempt,
    wait_exponential,
    retry_if_exception_type,
)

import logging
from ..utils import logger
from ..base import BaseGraphStorage
from ..types import KnowledgeGraph, KnowledgeGraphNode, KnowledgeGraphEdge
import pipmaster as pm

if not pm.is_installed("neo4j"):
    pm.install("neo4j")

from neo4j import (  # type: ignore
    AsyncGraphDatabase,
    exceptions as neo4jExceptions,
    AsyncDriver,
    AsyncManagedTransaction,
)
from dotenv import load_dotenv

# use the .env that is inside the current folder
# allows to use different .env file for each lightrag instance
# the OS environment variables take precedence over the .env file
load_dotenv(dotenv_path=".env", override=False)

# Get maximum number of graph nodes from environment variable, default is 1000
MAX_GRAPH_NODES = int(os.getenv("MAX_GRAPH_NODES", 1000))

config = configparser.ConfigParser()
config.read("config.ini", "utf-8")

# Set neo4j logger level to ERROR to suppress warning logs
logging.getLogger("neo4j").setLevel(logging.ERROR)


@final
@dataclass
class Neo4JStorage(BaseGraphStorage):
    def __init__(self, namespace, global_config, embedding_func):
        super().__init__(
            namespace=namespace,
            global_config=global_config,
            embedding_func=embedding_func,
        )
        self._driver = None

    async def initialize(self):
        URI = os.environ.get("NEO4J_URI", config.get("neo4j", "uri", fallback=None))
        USERNAME = os.environ.get(
            "NEO4J_USERNAME", config.get("neo4j", "username", fallback=None)
        )
        PASSWORD = os.environ.get(
            "NEO4J_PASSWORD", config.get("neo4j", "password", fallback=None)
        )
        MAX_CONNECTION_POOL_SIZE = int(
            os.environ.get(
                "NEO4J_MAX_CONNECTION_POOL_SIZE",
                config.get("neo4j", "connection_pool_size", fallback=50),
            )
        )
        CONNECTION_TIMEOUT = float(
            os.environ.get(
                "NEO4J_CONNECTION_TIMEOUT",
                config.get("neo4j", "connection_timeout", fallback=30.0),
            ),
        )
        CONNECTION_ACQUISITION_TIMEOUT = float(
            os.environ.get(
                "NEO4J_CONNECTION_ACQUISITION_TIMEOUT",
                config.get("neo4j", "connection_acquisition_timeout", fallback=30.0),
            ),
        )
        MAX_TRANSACTION_RETRY_TIME = float(
            os.environ.get(
                "NEO4J_MAX_TRANSACTION_RETRY_TIME",
                config.get("neo4j", "max_transaction_retry_time", fallback=30.0),
            ),
        )
        DATABASE = os.environ.get(
            "NEO4J_DATABASE", re.sub(r"[^a-zA-Z0-9-]", "-", self.namespace)
        )

        self._driver: AsyncDriver = AsyncGraphDatabase.driver(
            URI,
            auth=(USERNAME, PASSWORD),
            max_connection_pool_size=MAX_CONNECTION_POOL_SIZE,
            connection_timeout=CONNECTION_TIMEOUT,
            connection_acquisition_timeout=CONNECTION_ACQUISITION_TIMEOUT,
            max_transaction_retry_time=MAX_TRANSACTION_RETRY_TIME,
        )

        # Try to connect to the database and create it if it doesn't exist
        for database in (DATABASE, None):
            self._DATABASE = database
            connected = False

            try:
                async with self._driver.session(database=database) as session:
                    try:
                        result = await session.run("MATCH (n) RETURN n LIMIT 0")
                        await result.consume()  # Ensure result is consumed
                        logger.info(f"Connected to {database} at {URI}")
                        connected = True
                    except neo4jExceptions.ServiceUnavailable as e:
                        logger.error(
                            f"{database} at {URI} is not available".capitalize()
                        )
                        raise e
            except neo4jExceptions.AuthError as e:
                logger.error(f"Authentication failed for {database} at {URI}")
                raise e
            except neo4jExceptions.ClientError as e:
                if e.code == "Neo.ClientError.Database.DatabaseNotFound":
                    logger.info(
                        f"{database} at {URI} not found. Try to create specified database.".capitalize()
                    )
                try:
                    async with self._driver.session() as session:
                        result = await session.run(
                            f"CREATE DATABASE `{database}` IF NOT EXISTS"
                        )
                        await result.consume()  # Ensure result is consumed
                        logger.info(f"{database} at {URI} created".capitalize())
                        connected = True
                except (
                    neo4jExceptions.ClientError,
                    neo4jExceptions.DatabaseError,
                ) as e:
                    if (
                        e.code
                        == "Neo.ClientError.Statement.UnsupportedAdministrationCommand"
                    ) or (e.code == "Neo.DatabaseError.Statement.ExecutionFailed"):
                        if database is not None:
                            logger.warning(
                                "This Neo4j instance does not support creating databases. Try to use Neo4j Desktop/Enterprise version or DozerDB instead. Fallback to use the default database."
                            )
                    if database is None:
                        logger.error(f"Failed to create {database} at {URI}")
                        raise e

            if connected:
                # Create index for base nodes on entity_id if it doesn't exist
                try:
                    async with self._driver.session(database=database) as session:
                        # Check if index exists first
                        check_query = """
                        CALL db.indexes() YIELD name, labelsOrTypes, properties
                        WHERE labelsOrTypes = ['base'] AND properties = ['entity_id']
                        RETURN count(*) > 0 AS exists
                        """
                        try:
                            check_result = await session.run(check_query)
                            record = await check_result.single()
                            await check_result.consume()

                            index_exists = record and record.get("exists", False)

                            if not index_exists:
                                # Create index only if it doesn't exist
                                result = await session.run(
                                    "CREATE INDEX FOR (n:base) ON (n.entity_id)"
                                )
                                await result.consume()
                                logger.info(
                                    f"Created index for base nodes on entity_id in {database}"
                                )
                        except Exception:
                            # Fallback if db.indexes() is not supported in this Neo4j version
                            result = await session.run(
                                "CREATE INDEX IF NOT EXISTS FOR (n:base) ON (n.entity_id)"
                            )
                            await result.consume()
                except Exception as e:
                    logger.warning(f"Failed to create index: {str(e)}")
                break
    async def finalize(self):
        """Close the Neo4j driver and release all resources"""
        if self._driver:
            await self._driver.close()
            self._driver = None

    async def __aexit__(self, exc_type, exc, tb):
        """Ensure driver is closed when context manager exits"""
        await self.finalize()

    async def index_done_callback(self) -> None:
        # Neo4j handles persistence automatically
        pass
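    # Illustrative lifecycle (a sketch, not part of the class API): connection
    # settings come from the NEO4J_* environment variables or config.ini read
    # above; initialize() creates the driver and finalize() releases it. The
    # instance name `storage` and the constructor arguments are hypothetical:
    #
    #     storage = Neo4JStorage(namespace="demo", global_config={}, embedding_func=None)
    #     await storage.initialize()   # connects, creating the database if possible
    #     try:
    #         ...  # graph operations
    #     finally:
    #         await storage.finalize()  # closes the driver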
    async def has_node(self, node_id: str) -> bool:
        """
        Check if a node with the given label exists in the database

        Args:
            node_id: Label of the node to check

        Returns:
            bool: True if node exists, False otherwise

        Raises:
            ValueError: If node_id is invalid
            Exception: If there is an error executing the query
        """
        async with self._driver.session(
            database=self._DATABASE, default_access_mode="READ"
        ) as session:
            result = None
            try:
                query = "MATCH (n:base {entity_id: $entity_id}) RETURN count(n) > 0 AS node_exists"
                result = await session.run(query, entity_id=node_id)
                single_result = await result.single()
                await result.consume()  # Ensure result is fully consumed
                return single_result["node_exists"]
            except Exception as e:
                logger.error(f"Error checking node existence for {node_id}: {str(e)}")
                if result:
                    await result.consume()  # Ensure results are consumed even on error
                raise

    async def has_edge(self, source_node_id: str, target_node_id: str) -> bool:
        """
        Check if an edge exists between two nodes

        Args:
            source_node_id: Label of the source node
            target_node_id: Label of the target node

        Returns:
            bool: True if edge exists, False otherwise

        Raises:
            ValueError: If either node_id is invalid
            Exception: If there is an error executing the query
        """
        async with self._driver.session(
            database=self._DATABASE, default_access_mode="READ"
        ) as session:
            result = None
            try:
                query = (
                    "MATCH (a:base {entity_id: $source_entity_id})-[r]-(b:base {entity_id: $target_entity_id}) "
                    "RETURN COUNT(r) > 0 AS edgeExists"
                )
                result = await session.run(
                    query,
                    source_entity_id=source_node_id,
                    target_entity_id=target_node_id,
                )
                single_result = await result.single()
                await result.consume()  # Ensure result is fully consumed
                return single_result["edgeExists"]
            except Exception as e:
                logger.error(
                    f"Error checking edge existence between {source_node_id} and {target_node_id}: {str(e)}"
                )
                if result:
                    await result.consume()  # Ensure results are consumed even on error
                raise

    async def get_node(self, node_id: str) -> dict[str, str] | None:
        """Get node by its label identifier, return only node properties

        Args:
            node_id: The node label to look up

        Returns:
            dict: Node properties if found
            None: If node not found

        Raises:
            ValueError: If node_id is invalid
            Exception: If there is an error executing the query
        """
        async with self._driver.session(
            database=self._DATABASE, default_access_mode="READ"
        ) as session:
            try:
                query = "MATCH (n:base {entity_id: $entity_id}) RETURN n"
                result = await session.run(query, entity_id=node_id)
                try:
                    records = await result.fetch(2)  # Get 2 records for duplication check

                    if len(records) > 1:
                        logger.warning(
                            f"Multiple nodes found with label '{node_id}'. Using first node."
                        )
                    if records:
                        node = records[0]["n"]
                        node_dict = dict(node)
                        # Remove base label from labels list if it exists
                        if "labels" in node_dict:
                            node_dict["labels"] = [
                                label
                                for label in node_dict["labels"]
                                if label != "base"
                            ]
                        # logger.debug(f"Neo4j query node {query} return: {node_dict}")
                        return node_dict
                    return None
                finally:
                    await result.consume()  # Ensure result is fully consumed
            except Exception as e:
                logger.error(f"Error getting node for {node_id}: {str(e)}")
                raise
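    # Illustrative read path (a sketch; `storage` is an initialized instance
    # and the entity names are hypothetical):
    #
    #     if await storage.has_node("Alice"):
    #         props = await storage.get_node("Alice")        # dict of node properties
    #     linked = await storage.has_edge("Alice", "Bob")    # bool, undirected match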
    async def get_nodes_batch(self, node_ids: list[str]) -> dict[str, dict]:
        """
        Retrieve multiple nodes in one query using UNWIND.

        Args:
            node_ids: List of node entity IDs to fetch.

        Returns:
            A dictionary mapping each node_id to its node data (or None if not found).
        """
        async with self._driver.session(
            database=self._DATABASE, default_access_mode="READ"
        ) as session:
            query = """
                UNWIND $node_ids AS id
                MATCH (n:base {entity_id: id})
                RETURN n.entity_id AS entity_id, n
            """
            result = await session.run(query, node_ids=node_ids)
            nodes = {}
            async for record in result:
                entity_id = record["entity_id"]
                node = record["n"]
                node_dict = dict(node)
                # Remove the 'base' label if present in a 'labels' property
                if "labels" in node_dict:
                    node_dict["labels"] = [
                        label for label in node_dict["labels"] if label != "base"
                    ]
                nodes[entity_id] = node_dict
            await result.consume()  # Make sure to consume the result fully
            return nodes

    async def node_degree(self, node_id: str) -> int:
        """Get the degree (number of relationships) of a node with the given label.
        If multiple nodes have the same label, returns the degree of the first node.
        If no node is found, returns 0.

        Args:
            node_id: The label of the node

        Returns:
            int: The number of relationships the node has, or 0 if no node found

        Raises:
            ValueError: If node_id is invalid
            Exception: If there is an error executing the query
        """
        async with self._driver.session(
            database=self._DATABASE, default_access_mode="READ"
        ) as session:
            try:
                query = """
                    MATCH (n:base {entity_id: $entity_id})
                    OPTIONAL MATCH (n)-[r]-()
                    RETURN COUNT(r) AS degree
                """
                result = await session.run(query, entity_id=node_id)
                try:
                    record = await result.single()

                    if not record:
                        logger.warning(f"No node found with label '{node_id}'")
                        return 0

                    degree = record["degree"]
                    # logger.debug(
                    #     f"Neo4j query node degree for {node_id} return: {degree}"
                    # )
                    return degree
                finally:
                    await result.consume()  # Ensure result is fully consumed
            except Exception as e:
                logger.error(f"Error getting node degree for {node_id}: {str(e)}")
                raise

    async def node_degrees_batch(self, node_ids: list[str]) -> dict[str, int]:
        """
        Retrieve the degree for multiple nodes in a single query using UNWIND.

        Args:
            node_ids: List of node labels (entity_id values) to look up.

        Returns:
            A dictionary mapping each node_id to its degree (number of relationships).
            If a node is not found, its degree will be set to 0.
        """
        async with self._driver.session(
            database=self._DATABASE, default_access_mode="READ"
        ) as session:
            query = """
                UNWIND $node_ids AS id
                MATCH (n:base {entity_id: id})
                RETURN n.entity_id AS entity_id, count { (n)--() } AS degree;
            """
            result = await session.run(query, node_ids=node_ids)
            degrees = {}
            async for record in result:
                entity_id = record["entity_id"]
                degrees[entity_id] = record["degree"]
            await result.consume()  # Ensure result is fully consumed

            # For any node_id that did not return a record, set degree to 0.
            for nid in node_ids:
                if nid not in degrees:
                    logger.warning(f"No node found with label '{nid}'")
                    degrees[nid] = 0

            # logger.debug(f"Neo4j batch node degree query returned: {degrees}")
            return degrees
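    # Illustrative batch reads (a sketch; entity IDs are hypothetical). The
    # UNWIND variants above make one round trip for N lookups instead of N:
    #
    #     nodes = await storage.get_nodes_batch(["Alice", "Bob"])       # {id: props}
    #     degrees = await storage.node_degrees_batch(["Alice", "Bob"])  # {id: int}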
    async def edge_degree(self, src_id: str, tgt_id: str) -> int:
        """Get the total degree (sum of relationships) of two nodes.

        Args:
            src_id: Label of the source node
            tgt_id: Label of the target node

        Returns:
            int: Sum of the degrees of both nodes
        """
        src_degree = await self.node_degree(src_id)
        trg_degree = await self.node_degree(tgt_id)

        # Convert None to 0 for addition
        src_degree = 0 if src_degree is None else src_degree
        trg_degree = 0 if trg_degree is None else trg_degree

        degrees = int(src_degree) + int(trg_degree)
        return degrees

    async def edge_degrees_batch(
        self, edge_pairs: list[tuple[str, str]]
    ) -> dict[tuple[str, str], int]:
        """
        Calculate the combined degree for each edge (sum of the source and target
        node degrees) in batch using the already implemented node_degrees_batch.

        Args:
            edge_pairs: List of (src, tgt) tuples.

        Returns:
            A dictionary mapping each (src, tgt) tuple to the sum of their degrees.
        """
        # Collect unique node IDs from all edge pairs.
        unique_node_ids = {src for src, _ in edge_pairs}
        unique_node_ids.update({tgt for _, tgt in edge_pairs})

        # Get degrees for all nodes in one go.
        degrees = await self.node_degrees_batch(list(unique_node_ids))

        # Sum up degrees for each edge pair.
        edge_degrees = {}
        for src, tgt in edge_pairs:
            edge_degrees[(src, tgt)] = degrees.get(src, 0) + degrees.get(tgt, 0)
        return edge_degrees

    async def get_edge(
        self, source_node_id: str, target_node_id: str
    ) -> dict[str, str] | None:
        """Get edge properties between two nodes.

        Args:
            source_node_id: Label of the source node
            target_node_id: Label of the target node

        Returns:
            dict: Edge properties if found (missing keys filled with defaults),
                or default properties if the record cannot be processed
            None: If no edge exists between the two nodes

        Raises:
            ValueError: If either node_id is invalid
            Exception: If there is an error executing the query
        """
        try:
            async with self._driver.session(
                database=self._DATABASE, default_access_mode="READ"
            ) as session:
                query = """
                MATCH (start:base {entity_id: $source_entity_id})-[r]-(end:base {entity_id: $target_entity_id})
                RETURN properties(r) as edge_properties
                """
                result = await session.run(
                    query,
                    source_entity_id=source_node_id,
                    target_entity_id=target_node_id,
                )
                try:
                    records = await result.fetch(2)

                    if len(records) > 1:
                        logger.warning(
                            f"Multiple edges found between '{source_node_id}' and '{target_node_id}'. Using first edge."
                        )
                    if records:
                        try:
                            edge_result = dict(records[0]["edge_properties"])
                            # logger.debug(f"Result: {edge_result}")
                            # Ensure required keys exist with defaults
                            required_keys = {
                                "weight": 0.0,
                                "source_id": None,
                                "description": None,
                                "keywords": None,
                            }
                            for key, default_value in required_keys.items():
                                if key not in edge_result:
                                    edge_result[key] = default_value
                                    logger.warning(
                                        f"Edge between {source_node_id} and {target_node_id} "
                                        f"missing {key}, using default: {default_value}"
                                    )

                            # logger.debug(
                            #     f"{inspect.currentframe().f_code.co_name}:query:{query}:result:{edge_result}"
                            # )
                            return edge_result
                        except (KeyError, TypeError, ValueError) as e:
                            logger.error(
                                f"Error processing edge properties between {source_node_id} "
                                f"and {target_node_id}: {str(e)}"
                            )
                            # Return default edge properties on error
                            return {
                                "weight": 0.0,
                                "source_id": None,
                                "description": None,
                                "keywords": None,
                            }

                    # logger.debug(
                    #     f"{inspect.currentframe().f_code.co_name}: No edge found between {source_node_id} and {target_node_id}"
                    # )
                    # Return None when no edge found
                    return None
                finally:
                    await result.consume()  # Ensure result is fully consumed
        except Exception as e:
            logger.error(
                f"Error in get_edge between {source_node_id} and {target_node_id}: {str(e)}"
            )
            raise
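    # Illustrative edge lookup (a sketch; names are hypothetical). Missing keys
    # are filled with the defaults defined above, and None means no edge exists:
    #
    #     edge = await storage.get_edge("Alice", "Bob")
    #     weight = edge["weight"] if edge else 0.0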
    async def get_edges_batch(
        self, pairs: list[dict[str, str]]
    ) -> dict[tuple[str, str], dict]:
        """
        Retrieve edge properties for multiple (src, tgt) pairs in one query.

        Args:
            pairs: List of dictionaries, e.g. [{"src": "node1", "tgt": "node2"}, ...]

        Returns:
            A dictionary mapping (src, tgt) tuples to their edge properties.
        """
        async with self._driver.session(
            database=self._DATABASE, default_access_mode="READ"
        ) as session:
            query = """
                UNWIND $pairs AS pair
                MATCH (start:base {entity_id: pair.src})-[r:DIRECTED]-(end:base {entity_id: pair.tgt})
                RETURN pair.src AS src_id, pair.tgt AS tgt_id, collect(properties(r)) AS edges
            """
            result = await session.run(query, pairs=pairs)
            edges_dict = {}
            async for record in result:
                src = record["src_id"]
                tgt = record["tgt_id"]
                edges = record["edges"]
                if edges and len(edges) > 0:
                    edge_props = edges[0]  # choose the first if multiple exist
                    # Ensure required keys exist with defaults
                    for key, default in {
                        "weight": 0.0,
                        "source_id": None,
                        "description": None,
                        "keywords": None,
                    }.items():
                        if key not in edge_props:
                            edge_props[key] = default
                    edges_dict[(src, tgt)] = edge_props
                else:
                    # No edge found – set default edge properties
                    edges_dict[(src, tgt)] = {
                        "weight": 0.0,
                        "source_id": None,
                        "description": None,
                        "keywords": None,
                    }
            await result.consume()
            return edges_dict

    async def get_node_edges(self, source_node_id: str) -> list[tuple[str, str]] | None:
        """Retrieves all edges (relationships) for a particular node identified by its label.

        Args:
            source_node_id: Label of the node to get edges for

        Returns:
            list[tuple[str, str]]: List of (source_label, target_label) tuples representing edges
            None: If no edges found

        Raises:
            ValueError: If source_node_id is invalid
            Exception: If there is an error executing the query
        """
        try:
            async with self._driver.session(
                database=self._DATABASE, default_access_mode="READ"
            ) as session:
                results = None
                try:
                    query = """MATCH (n:base {entity_id: $entity_id})
                            OPTIONAL MATCH (n)-[r]-(connected:base)
                            WHERE connected.entity_id IS NOT NULL
                            RETURN n, r, connected"""
                    results = await session.run(query, entity_id=source_node_id)

                    edges = []
                    async for record in results:
                        source_node = record["n"]
                        connected_node = record["connected"]

                        # Skip if either node is None
                        if not source_node or not connected_node:
                            continue

                        source_label = (
                            source_node.get("entity_id")
                            if source_node.get("entity_id")
                            else None
                        )
                        target_label = (
                            connected_node.get("entity_id")
                            if connected_node.get("entity_id")
                            else None
                        )

                        if source_label and target_label:
                            edges.append((source_label, target_label))

                    await results.consume()  # Ensure results are consumed
                    return edges
                except Exception as e:
                    logger.error(
                        f"Error getting edges for node {source_node_id}: {str(e)}"
                    )
                    if results:
                        await results.consume()  # Ensure results are consumed even on error
                    raise
        except Exception as e:
            logger.error(f"Error in get_node_edges for {source_node_id}: {str(e)}")
            raise
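    # Illustrative neighborhood query (a sketch; the entity name is hypothetical):
    #
    #     pairs = await storage.get_node_edges("Alice")
    #     # -> [("Alice", "Bob"), ("Alice", "Carol"), ...] or [] when isolated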
    async def get_nodes_edges_batch(
        self, node_ids: list[str]
    ) -> dict[str, list[tuple[str, str]]]:
        """
        Batch retrieve edges for multiple nodes in one query using UNWIND.
        For each node, returns both outgoing and incoming edges to properly
        represent the undirected graph nature.

        Args:
            node_ids: List of node IDs (entity_id) for which to retrieve edges.

        Returns:
            A dictionary mapping each node ID to its list of edge tuples (source, target).
            For each node, the list includes both:
            - Outgoing edges: (queried_node, connected_node)
            - Incoming edges: (connected_node, queried_node)
        """
        async with self._driver.session(
            database=self._DATABASE, default_access_mode="READ"
        ) as session:
            # Query to get both outgoing and incoming edges
            query = """
                UNWIND $node_ids AS id
                MATCH (n:base {entity_id: id})
                OPTIONAL MATCH (n)-[r]-(connected:base)
                RETURN id AS queried_id, n.entity_id AS node_entity_id,
                       connected.entity_id AS connected_entity_id,
                       startNode(r).entity_id AS start_entity_id
            """
            result = await session.run(query, node_ids=node_ids)

            # Initialize the dictionary with empty lists for each node ID
            edges_dict = {node_id: [] for node_id in node_ids}

            # Process results to include both outgoing and incoming edges
            async for record in result:
                queried_id = record["queried_id"]
                node_entity_id = record["node_entity_id"]
                connected_entity_id = record["connected_entity_id"]
                start_entity_id = record["start_entity_id"]

                # Skip if either node is None
                if not node_entity_id or not connected_entity_id:
                    continue

                # Determine the actual direction of the edge
                # If the start node is the queried node, it's an outgoing edge
                # Otherwise, it's an incoming edge
                if start_entity_id == node_entity_id:
                    # Outgoing edge: (queried_node -> connected_node)
                    edges_dict[queried_id].append(
                        (node_entity_id, connected_entity_id)
                    )
                else:
                    # Incoming edge: (connected_node -> queried_node)
                    edges_dict[queried_id].append(
                        (connected_entity_id, node_entity_id)
                    )

            await result.consume()  # Ensure results are fully consumed
            return edges_dict

    @retry(
        stop=stop_after_attempt(3),
        wait=wait_exponential(multiplier=1, min=4, max=10),
        retry=retry_if_exception_type(
            (
                neo4jExceptions.ServiceUnavailable,
                neo4jExceptions.TransientError,
                neo4jExceptions.WriteServiceUnavailable,
                neo4jExceptions.ClientError,
            )
        ),
    )
    async def upsert_node(self, node_id: str, node_data: dict[str, str]) -> None:
        """
        Upsert a node in the Neo4j database.

        Args:
            node_id: The unique identifier for the node (used as label)
            node_data: Dictionary of node properties
        """
        properties = node_data
        entity_type = properties["entity_type"]
        if "entity_id" not in properties:
            raise ValueError(
                "Neo4j: node properties must contain an 'entity_id' field"
            )

        try:
            async with self._driver.session(database=self._DATABASE) as session:

                async def execute_upsert(tx: AsyncManagedTransaction):
                    query = (
                        """
                        MERGE (n:base {entity_id: $entity_id})
                        SET n += $properties
                        SET n:`%s`
                        """
                        % entity_type
                    )
                    result = await tx.run(
                        query, entity_id=node_id, properties=properties
                    )
                    await result.consume()  # Ensure result is fully consumed

                await session.execute_write(execute_upsert)
        except Exception as e:
            logger.error(f"Error during upsert: {str(e)}")
            raise
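    # Illustrative node upsert (a sketch; property values are hypothetical).
    # Note that node_data must carry both 'entity_id' and 'entity_type':
    #
    #     await storage.upsert_node(
    #         "Alice",
    #         {"entity_id": "Alice", "entity_type": "Person", "description": "..."},
    #     )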
    @retry(
        stop=stop_after_attempt(3),
        wait=wait_exponential(multiplier=1, min=4, max=10),
        retry=retry_if_exception_type(
            (
                neo4jExceptions.ServiceUnavailable,
                neo4jExceptions.TransientError,
                neo4jExceptions.WriteServiceUnavailable,
                neo4jExceptions.ClientError,
            )
        ),
    )
    async def upsert_edge(
        self, source_node_id: str, target_node_id: str, edge_data: dict[str, str]
    ) -> None:
        """
        Upsert an edge and its properties between two nodes identified by their labels.
        Ensures both source and target nodes exist and are unique before creating the edge.
        Uses entity_id property to uniquely identify nodes.

        Args:
            source_node_id (str): Label of the source node (used as identifier)
            target_node_id (str): Label of the target node (used as identifier)
            edge_data (dict): Dictionary of properties to set on the edge

        Raises:
            ValueError: If either source or target node does not exist or is not unique
        """
        try:
            edge_properties = edge_data
            async with self._driver.session(database=self._DATABASE) as session:

                async def execute_upsert(tx: AsyncManagedTransaction):
                    query = """
                    MATCH (source:base {entity_id: $source_entity_id})
                    WITH source
                    MATCH (target:base {entity_id: $target_entity_id})
                    MERGE (source)-[r:DIRECTED]-(target)
                    SET r += $properties
                    RETURN r, source, target
                    """
                    result = await tx.run(
                        query,
                        source_entity_id=source_node_id,
                        target_entity_id=target_node_id,
                        properties=edge_properties,
                    )
                    try:
                        await result.fetch(2)
                    finally:
                        await result.consume()  # Ensure result is consumed

                await session.execute_write(execute_upsert)
        except Exception as e:
            logger.error(f"Error during edge upsert: {str(e)}")
            raise
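    # Illustrative edge upsert (a sketch; values are hypothetical). Both
    # endpoints must already exist as :base nodes, and the relationship type
    # used by this storage is DIRECTED:
    #
    #     await storage.upsert_edge(
    #         "Alice", "Bob", {"weight": 1.0, "keywords": "knows", "source_id": "doc-1"}
    #     )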
    async def get_knowledge_graph(
        self,
        node_label: str,
        max_depth: int = 3,
        max_nodes: int = MAX_GRAPH_NODES,
    ) -> KnowledgeGraph:
        """
        Retrieve a connected subgraph of nodes where the label includes the specified `node_label`.

        Args:
            node_label: Label of the starting node, * means all nodes
            max_depth: Maximum depth of the subgraph. Defaults to 3
            max_nodes: Maximum number of nodes to return by BFS. Defaults to 1000

        Returns:
            KnowledgeGraph object containing nodes and edges, with an is_truncated flag
            indicating whether the graph was truncated due to max_nodes limit
        """
        result = KnowledgeGraph()
        seen_nodes = set()
        seen_edges = set()

        async with self._driver.session(
            database=self._DATABASE, default_access_mode="READ"
        ) as session:
            try:
                if node_label == "*":
                    # First check total node count to determine if graph is truncated
                    count_query = "MATCH (n) RETURN count(n) as total"
                    count_result = None
                    try:
                        count_result = await session.run(count_query)
                        count_record = await count_result.single()

                        if count_record and count_record["total"] > max_nodes:
                            result.is_truncated = True
                            logger.info(
                                f"Graph truncated: {count_record['total']} nodes found, limited to {max_nodes}"
                            )
                    finally:
                        if count_result:
                            await count_result.consume()

                    # Run main query to get nodes with highest degree
                    main_query = """
                    MATCH (n)
                    OPTIONAL MATCH (n)-[r]-()
                    WITH n, COALESCE(count(r), 0) AS degree
                    ORDER BY degree DESC
                    LIMIT $max_nodes
                    WITH collect({node: n}) AS filtered_nodes
                    UNWIND filtered_nodes AS node_info
                    WITH collect(node_info.node) AS kept_nodes, filtered_nodes
                    OPTIONAL MATCH (a)-[r]-(b)
                    WHERE a IN kept_nodes AND b IN kept_nodes
                    RETURN filtered_nodes AS node_info,
                           collect(DISTINCT r) AS relationships
                    """
                    result_set = None
                    try:
                        result_set = await session.run(
                            main_query,
                            {"max_nodes": max_nodes},
                        )
                        record = await result_set.single()
                    finally:
                        if result_set:
                            await result_set.consume()

                else:
                    # return await self._robust_fallback(node_label, max_depth, max_nodes)
                    # First try without limit to check if we need to truncate
                    full_query = """
                    MATCH (start)
                    WHERE start.entity_id = $entity_id
                    WITH start
                    CALL apoc.path.subgraphAll(start, {
                        relationshipFilter: '',
                        minLevel: 0,
                        maxLevel: $max_depth,
                        bfs: true
                    })
                    YIELD nodes, relationships
                    WITH nodes, relationships, size(nodes) AS total_nodes
                    UNWIND nodes AS node
                    WITH collect({node: node}) AS node_info, relationships, total_nodes
                    RETURN node_info, relationships, total_nodes
                    """

                    # Try to get full result
                    full_result = None
                    try:
                        full_result = await session.run(
                            full_query,
                            {
                                "entity_id": node_label,
                                "max_depth": max_depth,
                            },
                        )
                        full_record = await full_result.single()

                        # If no record found, return empty KnowledgeGraph
                        if not full_record:
                            logger.debug(f"No nodes found for entity_id: {node_label}")
                            return result

                        # If record found, check node count
                        total_nodes = full_record["total_nodes"]

                        if total_nodes <= max_nodes:
                            # If node count is within limit, use full result directly
                            logger.debug(
                                f"Using full result with {total_nodes} nodes (no truncation needed)"
                            )
                            record = full_record
                        else:
                            # If node count exceeds limit, set truncated flag and run limited query
                            result.is_truncated = True
                            logger.info(
                                f"Graph truncated: {total_nodes} nodes found, breadth-first search limited to {max_nodes}"
                            )

                            # Run limited query
                            limited_query = """
                            MATCH (start)
                            WHERE start.entity_id = $entity_id
                            WITH start
                            CALL apoc.path.subgraphAll(start, {
                                relationshipFilter: '',
                                minLevel: 0,
                                maxLevel: $max_depth,
                                limit: $max_nodes,
                                bfs: true
                            })
                            YIELD nodes, relationships
                            UNWIND nodes AS node
                            WITH collect({node: node}) AS node_info, relationships
                            RETURN node_info, relationships
                            """
                            result_set = None
                            try:
                                result_set = await session.run(
                                    limited_query,
                                    {
                                        "entity_id": node_label,
                                        "max_depth": max_depth,
                                        "max_nodes": max_nodes,
                                    },
                                )
                                record = await result_set.single()
                            finally:
                                if result_set:
                                    await result_set.consume()
                    finally:
                        if full_result:
                            await full_result.consume()

                if record:
                    # Handle nodes (compatible with multi-label cases)
                    for node_info in record["node_info"]:
                        node = node_info["node"]
                        node_id = node.id
                        if node_id not in seen_nodes:
                            result.nodes.append(
                                KnowledgeGraphNode(
                                    id=f"{node_id}",
                                    labels=[node.get("entity_id")],
                                    properties=dict(node),
                                )
                            )
                            seen_nodes.add(node_id)

                    # Handle relationships (including direction information)
                    for rel in record["relationships"]:
                        edge_id = rel.id
                        if edge_id not in seen_edges:
                            start = rel.start_node
                            end = rel.end_node
                            result.edges.append(
                                KnowledgeGraphEdge(
                                    id=f"{edge_id}",
                                    type=rel.type,
                                    source=f"{start.id}",
                                    target=f"{end.id}",
                                    properties=dict(rel),
                                )
                            )
                            seen_edges.add(edge_id)

                logger.info(
                    f"Subgraph query successful | Node count: {len(result.nodes)} | Edge count: {len(result.edges)}"
                )

            except neo4jExceptions.ClientError as e:
                logger.warning(f"APOC plugin error: {str(e)}")
                if node_label != "*":
                    logger.warning(
                        "Neo4j: falling back to basic Cypher recursive search..."
                    )
                    return await self._robust_fallback(
                        node_label, max_depth, max_nodes
                    )
                else:
                    logger.warning(
                        "Neo4j: APOC plugin error with wildcard query, returning empty result"
                    )

        return result

    async def _robust_fallback(
        self, node_label: str, max_depth: int, max_nodes: int
    ) -> KnowledgeGraph:
        """
        Fallback implementation when APOC plugin is not available or incompatible.
        This method implements the same functionality as get_knowledge_graph
        but uses only basic Cypher queries and true breadth-first traversal
        instead of APOC procedures.
        """
        from collections import deque

        result = KnowledgeGraph()
        visited_nodes = set()
        visited_edges = set()
        visited_edge_pairs = set()

        # Get the starting node's data
        async with self._driver.session(
            database=self._DATABASE, default_access_mode="READ"
        ) as session:
            query = """
            MATCH (n:base {entity_id: $entity_id})
            RETURN id(n) as node_id, n
            """
            node_result = await session.run(query, entity_id=node_label)
            try:
                node_record = await node_result.single()
                if not node_record:
                    return result

                # Create initial KnowledgeGraphNode
                start_node = KnowledgeGraphNode(
                    id=f"{node_record['n'].get('entity_id')}",
                    labels=[node_record["n"].get("entity_id")],
                    properties=dict(node_record["n"]._properties),
                )
            finally:
                await node_result.consume()  # Ensure results are consumed

        # Initialize queue for BFS with (node, edge, depth) tuples
        # edge is None for the starting node
        queue = deque([(start_node, None, 0)])

        # True BFS implementation using a queue
        while queue and len(visited_nodes) < max_nodes:
            # Dequeue the next node to process
            current_node, current_edge, current_depth = queue.popleft()

            # Skip if already visited or exceeds max depth
            if current_node.id in visited_nodes:
                continue

            if current_depth > max_depth:
                logger.debug(
                    f"Skipping node at depth {current_depth} (max_depth: {max_depth})"
                )
                continue

            # Add current node to result
            result.nodes.append(current_node)
            visited_nodes.add(current_node.id)

            # Add edge to result if it exists and not already added
            if current_edge and current_edge.id not in visited_edges:
                result.edges.append(current_edge)
                visited_edges.add(current_edge.id)

            # Stop if we've reached the node limit
            if len(visited_nodes) >= max_nodes:
                result.is_truncated = True
                logger.info(
                    f"Graph truncated: breadth-first search limited to: {max_nodes} nodes"
                )
                break

            # Get all edges and target nodes for the current node (even at max_depth)
            async with self._driver.session(
                database=self._DATABASE, default_access_mode="READ"
            ) as session:
                query = """
                MATCH (a:base {entity_id: $entity_id})-[r]-(b)
                WITH r, b, id(r) as edge_id, id(b) as target_id
                RETURN r, b, edge_id, target_id
                """
                results = await session.run(query, entity_id=current_node.id)

                # Get all records and release database connection
                records = await results.fetch(1000)  # Max neighbor nodes we can handle
                await results.consume()  # Ensure results are consumed

                # Process all neighbors - capture all edges but only queue unvisited nodes
                for record in records:
                    rel = record["r"]
                    edge_id = str(record["edge_id"])

                    if edge_id not in visited_edges:
                        b_node = record["b"]
                        target_id = b_node.get("entity_id")

                        if target_id:  # Only process if target node has entity_id
                            # Create KnowledgeGraphNode for target
                            target_node = KnowledgeGraphNode(
                                id=f"{target_id}",
                                labels=[target_id],
                                properties=dict(b_node._properties),
                            )

                            # Create KnowledgeGraphEdge
                            target_edge = KnowledgeGraphEdge(
                                id=f"{edge_id}",
                                type=rel.type,
                                source=f"{current_node.id}",
                                target=f"{target_id}",
                                properties=dict(rel),
                            )

                            # Sort source_id and target_id to ensure (A,B) and (B,A) are treated as the same edge
                            sorted_pair = tuple(sorted([current_node.id, target_id]))

                            # Check if the same edge already exists (considering undirectedness)
                            if sorted_pair not in visited_edge_pairs:
                                # Only add the edge if the target node is already in the result or will be added
                                if target_id in visited_nodes or (
                                    target_id not in visited_nodes
                                    and current_depth < max_depth
                                ):
                                    result.edges.append(target_edge)
                                    visited_edges.add(edge_id)
                                    visited_edge_pairs.add(sorted_pair)

                            # Only add unvisited nodes to the queue for further expansion
                            if target_id not in visited_nodes:
                                # Only add to queue if we're not at max depth yet
                                if current_depth < max_depth:
                                    # Add node to queue with incremented depth
                                    # Edge is already added to result, so we pass None as edge
                                    queue.append(
                                        (target_node, None, current_depth + 1)
                                    )
                                else:
                                    # At max depth, we've already added the edge but we don't add the node
                                    # This prevents adding nodes beyond max_depth to the result
                                    logger.debug(
                                        f"Node {target_id} beyond max depth {max_depth}, edge added but node not included"
                                    )
                            else:
                                # If target node already exists in result, we don't need to add it again
                                logger.debug(
                                    f"Node {target_id} already visited, edge added but node not queued"
                                )
                        else:
                            logger.warning(
                                f"Skipping edge {edge_id} due to missing entity_id on target node"
                            )

        logger.info(
            f"BFS subgraph query successful | Node count: {len(result.nodes)} | Edge count: {len(result.edges)}"
        )
        return result
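    # Illustrative subgraph retrieval (a sketch; the label is hypothetical).
    # get_knowledge_graph() uses APOC when available and otherwise falls back
    # to the pure-Cypher BFS above; "*" selects the highest-degree nodes:
    #
    #     kg = await storage.get_knowledge_graph("Alice", max_depth=2, max_nodes=200)
    #     everything = await storage.get_knowledge_graph("*")
    #     if kg.is_truncated:
    #         ...  # the subgraph was cut off at max_nodes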
    async def get_all_labels(self) -> list[str]:
        """
        Get all existing node labels in the database

        Returns:
            ["Person", "Company", ...]  # Alphabetically sorted label list
        """
        async with self._driver.session(
            database=self._DATABASE, default_access_mode="READ"
        ) as session:
            # Method 1: Direct metadata query (Available for Neo4j 4.3+)
            # query = "CALL db.labels() YIELD label RETURN label"

            # Method 2: Query compatible with older versions
            query = """
                MATCH (n:base)
                WHERE n.entity_id IS NOT NULL
                RETURN DISTINCT n.entity_id AS label
                ORDER BY label
            """
            result = await session.run(query)
            labels = []
            try:
                async for record in result:
                    labels.append(record["label"])
            finally:
                # Ensure results are consumed even if processing fails
                await result.consume()
            return labels

    @retry(
        stop=stop_after_attempt(3),
        wait=wait_exponential(multiplier=1, min=4, max=10),
        retry=retry_if_exception_type(
            (
                neo4jExceptions.ServiceUnavailable,
                neo4jExceptions.TransientError,
                neo4jExceptions.WriteServiceUnavailable,
                neo4jExceptions.ClientError,
            )
        ),
    )
    async def delete_node(self, node_id: str) -> None:
        """Delete a node with the specified label

        Args:
            node_id: The label of the node to delete
        """

        async def _do_delete(tx: AsyncManagedTransaction):
            query = """
            MATCH (n:base {entity_id: $entity_id})
            DETACH DELETE n
            """
            result = await tx.run(query, entity_id=node_id)
            logger.debug(f"Deleted node with label '{node_id}'")
            await result.consume()  # Ensure result is fully consumed

        try:
            async with self._driver.session(database=self._DATABASE) as session:
                await session.execute_write(_do_delete)
        except Exception as e:
            logger.error(f"Error during node deletion: {str(e)}")
            raise

    @retry(
        stop=stop_after_attempt(3),
        wait=wait_exponential(multiplier=1, min=4, max=10),
        retry=retry_if_exception_type(
            (
                neo4jExceptions.ServiceUnavailable,
                neo4jExceptions.TransientError,
                neo4jExceptions.WriteServiceUnavailable,
                neo4jExceptions.ClientError,
            )
        ),
    )
    async def remove_nodes(self, nodes: list[str]):
        """Delete multiple nodes

        Args:
            nodes: List of node labels to be deleted
        """
        for node in nodes:
            await self.delete_node(node)

    @retry(
        stop=stop_after_attempt(3),
        wait=wait_exponential(multiplier=1, min=4, max=10),
        retry=retry_if_exception_type(
            (
                neo4jExceptions.ServiceUnavailable,
                neo4jExceptions.TransientError,
                neo4jExceptions.WriteServiceUnavailable,
                neo4jExceptions.ClientError,
            )
        ),
    )
    async def remove_edges(self, edges: list[tuple[str, str]]):
        """Delete multiple edges

        Args:
            edges: List of edges to be deleted, each edge is a (source, target) tuple
        """
        for source, target in edges:

            async def _do_delete_edge(tx: AsyncManagedTransaction):
                query = """
                MATCH (source:base {entity_id: $source_entity_id})-[r]-(target:base {entity_id: $target_entity_id})
                DELETE r
                """
                result = await tx.run(
                    query, source_entity_id=source, target_entity_id=target
                )
                logger.debug(f"Deleted edge from '{source}' to '{target}'")
                await result.consume()  # Ensure result is fully consumed

            try:
                async with self._driver.session(database=self._DATABASE) as session:
                    await session.execute_write(_do_delete_edge)
            except Exception as e:
                logger.error(f"Error during edge deletion: {str(e)}")
                raise

    async def drop(self) -> dict[str, str]:
        """Drop all data from storage and clean up resources

        This method will delete all nodes and relationships in the Neo4j database.

        Returns:
            dict[str, str]: Operation status and message
            - On success: {"status": "success", "message": "data dropped"}
            - On failure: {"status": "error", "message": "<error details>"}
        """
        try:
            async with self._driver.session(database=self._DATABASE) as session:
                # Delete all nodes and relationships
                query = "MATCH (n) DETACH DELETE n"
                result = await session.run(query)
                await result.consume()  # Ensure result is fully consumed

                logger.info(
                    f"Process {os.getpid()} drop Neo4j database {self._DATABASE}"
                )
                return {"status": "success", "message": "data dropped"}
        except Exception as e:
            logger.error(f"Error dropping Neo4j database {self._DATABASE}: {e}")
            return {"status": "error", "message": str(e)}
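# Illustrative cleanup calls (a sketch; entity names are hypothetical):
#
#     await storage.delete_node("Alice")              # node + its relationships
#     await storage.remove_edges([("Alice", "Bob")])  # specific edges only
#     status = await storage.drop()                   # wipe the whole database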