You cannot select more than 25 topics Topics must start with a letter or number, can include dashes ('-') and can be up to 35 characters long.

1344 lines
55 KiB

This file contains ambiguous Unicode characters!

This file contains ambiguous Unicode characters that may be confused with others in your current locale. If your use case is intentional and legitimate, you can safely ignore this warning. Use the Escape button to highlight these characters.

import os
import re
from dataclasses import dataclass
from typing import final
import configparser
from tenacity import (
retry,
stop_after_attempt,
wait_exponential,
retry_if_exception_type,
)
import logging
from ..utils import logger
from ..base import BaseGraphStorage
from ..types import KnowledgeGraph, KnowledgeGraphNode, KnowledgeGraphEdge
import pipmaster as pm
if not pm.is_installed("neo4j"):
pm.install("neo4j")
from neo4j import ( # type: ignore
AsyncGraphDatabase,
exceptions as neo4jExceptions,
AsyncDriver,
AsyncManagedTransaction,
)
from dotenv import load_dotenv
# Use the .env file in the current working directory, which allows a
# different .env file per LightRAG instance. override=False means variables
# already present in the OS environment take precedence over the .env file.
load_dotenv(dotenv_path=".env", override=False)
# Maximum number of graph nodes, from the MAX_GRAPH_NODES environment variable
# (default 1000). Read once at import time, after load_dotenv so .env values
# are visible. Used as the default `max_nodes` of get_knowledge_graph.
MAX_GRAPH_NODES = int(os.getenv("MAX_GRAPH_NODES", 1000))
# Optional config.ini (read as UTF-8); its [neo4j] section supplies fallback
# values when the corresponding NEO4J_* environment variables are unset.
config = configparser.ConfigParser()
config.read("config.ini", "utf-8")
# Set neo4j logger level to ERROR to suppress warning logs
logging.getLogger("neo4j").setLevel(logging.ERROR)
@final
@dataclass
class Neo4JStorage(BaseGraphStorage):
def __init__(self, namespace, global_config, embedding_func):
    """Store the base-class configuration; the driver itself is built in initialize()."""
    base_kwargs = dict(
        namespace=namespace,
        global_config=global_config,
        embedding_func=embedding_func,
    )
    super().__init__(**base_kwargs)
    # No connection yet: initialize() creates the driver, finalize() releases it.
    self._driver = None
async def initialize(self):
    """Create the async Neo4j driver, pick (or create) the database, and ensure an entity_id index.

    Connection settings come from NEO4J_* environment variables, falling back
    to the [neo4j] section of config.ini, and then to the hard-coded defaults
    for the pool size and timeouts. The database name is NEO4J_DATABASE, or
    ``self.namespace`` with every character outside [a-zA-Z0-9-] replaced by '-'.

    The configured database is tried first. If it does not exist, a
    ``CREATE DATABASE`` is attempted. If the server does not support that, the
    loop falls back to ``database=None`` (the server's default database). The
    database finally used is stored in ``self._DATABASE``.

    Raises:
        neo4jExceptions.ServiceUnavailable: if the server cannot be reached.
        neo4jExceptions.AuthError: if authentication fails.
    """
    URI = os.environ.get("NEO4J_URI", config.get("neo4j", "uri", fallback=None))
    USERNAME = os.environ.get(
        "NEO4J_USERNAME", config.get("neo4j", "username", fallback=None)
    )
    PASSWORD = os.environ.get(
        "NEO4J_PASSWORD", config.get("neo4j", "password", fallback=None)
    )
    MAX_CONNECTION_POOL_SIZE = int(
        os.environ.get(
            "NEO4J_MAX_CONNECTION_POOL_SIZE",
            config.get("neo4j", "connection_pool_size", fallback=50),
        )
    )
    CONNECTION_TIMEOUT = float(
        os.environ.get(
            "NEO4J_CONNECTION_TIMEOUT",
            config.get("neo4j", "connection_timeout", fallback=30.0),
        ),
    )
    CONNECTION_ACQUISITION_TIMEOUT = float(
        os.environ.get(
            "NEO4J_CONNECTION_ACQUISITION_TIMEOUT",
            config.get("neo4j", "connection_acquisition_timeout", fallback=30.0),
        ),
    )
    MAX_TRANSACTION_RETRY_TIME = float(
        os.environ.get(
            "NEO4J_MAX_TRANSACTION_RETRY_TIME",
            config.get("neo4j", "max_transaction_retry_time", fallback=30.0),
        ),
    )
    # Database names may only contain letters, digits and '-', so sanitise the namespace.
    DATABASE = os.environ.get(
        "NEO4J_DATABASE", re.sub(r"[^a-zA-Z0-9-]", "-", self.namespace)
    )

    self._driver: AsyncDriver = AsyncGraphDatabase.driver(
        URI,
        auth=(USERNAME, PASSWORD),
        max_connection_pool_size=MAX_CONNECTION_POOL_SIZE,
        connection_timeout=CONNECTION_TIMEOUT,
        connection_acquisition_timeout=CONNECTION_ACQUISITION_TIMEOUT,
        max_transaction_retry_time=MAX_TRANSACTION_RETRY_TIME,
    )

    # Try to connect to the database and create it if it doesn't exist.
    # The second iteration (None) targets the server's default database.
    for database in (DATABASE, None):
        self._DATABASE = database
        # Set to True once the database is confirmed reachable (or was just created).
        connected = False

        try:
            async with self._driver.session(database=database) as session:
                try:
                    # Cheap probe query: returns no rows, but fails if the database is unusable.
                    result = await session.run("MATCH (n) RETURN n LIMIT 0")
                    await result.consume()  # Ensure result is consumed
                    logger.info(f"Connected to {database} at {URI}")
                    connected = True
                except neo4jExceptions.ServiceUnavailable as e:
                    logger.error(
                        f"{database} at {URI} is not available".capitalize()
                    )
                    raise e
        except neo4jExceptions.AuthError as e:
            logger.error(f"Authentication failed for {database} at {URI}")
            raise e
        except neo4jExceptions.ClientError as e:
            # NOTE(review): ClientErrors with any other code are swallowed here and
            # the loop moves on to the default database. If that attempt also fails
            # this way, initialize() returns without raising and self._DATABASE is None.
            if e.code == "Neo.ClientError.Database.DatabaseNotFound":
                logger.info(
                    f"{database} at {URI} not found. Try to create specified database.".capitalize()
                )
                try:
                    # No database argument: run the admin command against the default database.
                    async with self._driver.session() as session:
                        result = await session.run(
                            f"CREATE DATABASE `{database}` IF NOT EXISTS"
                        )
                        await result.consume()  # Ensure result is consumed
                        logger.info(f"{database} at {URI} created".capitalize())
                        connected = True
                except (
                    neo4jExceptions.ClientError,
                    neo4jExceptions.DatabaseError,
                ) as e:
                    # These codes indicate the server cannot create databases;
                    # the loop then retries with the default database.
                    if (
                        e.code
                        == "Neo.ClientError.Statement.UnsupportedAdministrationCommand"
                    ) or (e.code == "Neo.DatabaseError.Statement.ExecutionFailed"):
                        if database is not None:
                            logger.warning(
                                "This Neo4j instance does not support creating databases. Try to use Neo4j Desktop/Enterprise version or DozerDB instead. Fallback to use the default database."
                            )
                    # NOTE(review): when database is not None, errors with other codes
                    # are also swallowed here and the loop falls back silently.
                    if database is None:
                        logger.error(f"Failed to create {database} at {URI}")
                        raise e

        if connected:
            # Create index for base nodes on entity_id if it doesn't exist
            try:
                async with self._driver.session(database=database) as session:
                    # Check if index exists first
                    check_query = """
CALL db.indexes() YIELD name, labelsOrTypes, properties
WHERE labelsOrTypes = ['base'] AND properties = ['entity_id']
RETURN count(*) > 0 AS exists
"""
                    try:
                        check_result = await session.run(check_query)
                        record = await check_result.single()
                        await check_result.consume()

                        # Falsy when no record came back or the flag is False.
                        index_exists = record and record.get("exists", False)

                        if not index_exists:
                            # Create index only if it doesn't exist
                            result = await session.run(
                                "CREATE INDEX FOR (n:base) ON (n.entity_id)"
                            )
                            await result.consume()
                            logger.info(
                                f"Created index for base nodes on entity_id in {database}"
                            )
                    except Exception:
                        # Fallback if db.indexes() is not supported in this Neo4j version
                        # (this also catches any other failure of the check above).
                        result = await session.run(
                            "CREATE INDEX IF NOT EXISTS FOR (n:base) ON (n.entity_id)"
                        )
                        await result.consume()
            except Exception as e:
                # Index creation is best-effort: log and keep the connection.
                logger.warning(f"Failed to create index: {str(e)}")
            # A usable database was found; do not try the default one.
            break
async def finalize(self):
    """Release the Neo4j driver and the pooled connections it holds.

    Calling this more than once is harmless: once the driver is released,
    ``self._driver`` is ``None`` and later calls return immediately.
    """
    driver = self._driver
    if not driver:
        return
    await driver.close()
    self._driver = None
async def __aexit__(self, exc_type, exc, tb):
    """Async context-manager exit: release the driver through finalize().

    Returns None, so any exception raised in the ``async with`` body propagates.
    """
    return await self.finalize()
async def index_done_callback(self) -> None:
    """No-op hook: Neo4j persists writes on its own, so there is nothing to flush."""
    return None
async def has_node(self, node_id: str) -> bool:
    """
    Check whether a ``base`` node with the given entity_id exists.

    Args:
        node_id: entity_id of the node to look for

    Returns:
        bool: True if at least one matching node exists, False otherwise

    Raises:
        Exception: If there is an error executing the query (logged, then re-raised)
    """
    async with self._driver.session(
        database=self._DATABASE, default_access_mode="READ"
    ) as session:
        # Stays None if session.run() itself fails, so cleanup must not assume a result.
        result = None
        try:
            query = "MATCH (n:base {entity_id: $entity_id}) RETURN count(n) > 0 AS node_exists"
            result = await session.run(query, entity_id=node_id)
            single_result = await result.single()
            return single_result["node_exists"]
        except Exception as e:
            logger.error(f"Error checking node existence for {node_id}: {str(e)}")
            raise
        finally:
            # Only consume a result that was actually obtained; consuming an
            # unbound name would raise UnboundLocalError and hide the real error.
            if result is not None:
                await result.consume()  # Ensure result is fully consumed
async def has_edge(self, source_node_id: str, target_node_id: str) -> bool:
    """
    Check whether any relationship connects two ``base`` nodes, in either direction.

    Args:
        source_node_id: entity_id of the first node
        target_node_id: entity_id of the second node

    Returns:
        bool: True if at least one relationship links the two nodes, False otherwise

    Raises:
        Exception: If there is an error executing the query (logged, then re-raised)
    """
    async with self._driver.session(
        database=self._DATABASE, default_access_mode="READ"
    ) as session:
        # Stays None if session.run() itself fails, so cleanup must not assume a result.
        result = None
        try:
            query = (
                "MATCH (a:base {entity_id: $source_entity_id})-[r]-(b:base {entity_id: $target_entity_id}) "
                "RETURN COUNT(r) > 0 AS edgeExists"
            )
            result = await session.run(
                query,
                source_entity_id=source_node_id,
                target_entity_id=target_node_id,
            )
            single_result = await result.single()
            return single_result["edgeExists"]
        except Exception as e:
            logger.error(
                f"Error checking edge existence between {source_node_id} and {target_node_id}: {str(e)}"
            )
            raise
        finally:
            # Only consume a result that was actually obtained; consuming an
            # unbound name would raise UnboundLocalError and hide the real error.
            if result is not None:
                await result.consume()  # Ensure result is fully consumed
async def get_node(self, node_id: str) -> dict[str, str] | None:
    """Fetch the properties of the ``base`` node whose entity_id is ``node_id``.

    If several nodes share the entity_id, a warning is logged and the first
    row returned by the query is used.

    Args:
        node_id: entity_id of the node to look up

    Returns:
        dict: the node's properties. If the node has a "labels" property,
            the value "base" is filtered out of it.
        None: if no matching node exists

    Raises:
        Exception: if the query fails (logged, then re-raised)
    """
    async with self._driver.session(
        database=self._DATABASE, default_access_mode="READ"
    ) as session:
        try:
            cypher = "MATCH (n:base {entity_id: $entity_id}) RETURN n"
            cursor = await session.run(cypher, entity_id=node_id)
            try:
                # Two rows are enough to detect duplicates without reading more.
                rows = await cursor.fetch(2)
                if len(rows) > 1:
                    logger.warning(
                        f"Multiple nodes found with label '{node_id}'. Using first node."
                    )
                if not rows:
                    return None
                props = dict(rows[0]["n"])
                if "labels" in props:
                    props["labels"] = [lbl for lbl in props["labels"] if lbl != "base"]
                return props
            finally:
                await cursor.consume()  # Ensure result is fully consumed
        except Exception as e:
            logger.error(f"Error getting node for {node_id}: {str(e)}")
            raise
async def get_nodes_batch(self, node_ids: list[str]) -> dict[str, dict]:
    """
    Fetch several ``base`` nodes with a single UNWIND query.

    Args:
        node_ids: entity_id values of the nodes to fetch.

    Returns:
        A dict mapping entity_id -> node properties for every node that was
        found. IDs without a matching node are simply absent. As in get_node,
        "base" is filtered out of a "labels" property when one is present.
    """
    async with self._driver.session(
        database=self._DATABASE, default_access_mode="READ"
    ) as session:
        query = """
UNWIND $node_ids AS id
MATCH (n:base {entity_id: id})
RETURN n.entity_id AS entity_id, n
"""
        cursor = await session.run(query, node_ids=node_ids)
        found = {}
        async for row in cursor:
            props = dict(row["n"])
            if "labels" in props:
                props["labels"] = [lbl for lbl in props["labels"] if lbl != "base"]
            found[row["entity_id"]] = props
        await cursor.consume()  # Make sure to consume the result fully
        return found
async def node_degree(self, node_id: str) -> int:
    """Return how many relationships are attached to the node(s) with this entity_id.

    The query aggregates with COUNT(r) and no grouping key, so the counts of
    all nodes sharing the entity_id are summed. If nothing matches, the result
    is 0.

    Args:
        node_id: entity_id of the node

    Returns:
        int: relationship count, or 0 if no node was found

    Raises:
        Exception: if the query fails (logged, then re-raised)
    """
    async with self._driver.session(
        database=self._DATABASE, default_access_mode="READ"
    ) as session:
        try:
            query = """
MATCH (n:base {entity_id: $entity_id})
OPTIONAL MATCH (n)-[r]-()
RETURN COUNT(r) AS degree
"""
            cursor = await session.run(query, entity_id=node_id)
            try:
                row = await cursor.single()
                if row:
                    return row["degree"]
                logger.warning(f"No node found with label '{node_id}'")
                return 0
            finally:
                await cursor.consume()  # Ensure result is fully consumed
        except Exception as e:
            logger.error(f"Error getting node degree for {node_id}: {str(e)}")
            raise
async def node_degrees_batch(self, node_ids: list[str]) -> dict[str, int]:
    """
    Look up the degree of many nodes with one UNWIND query.

    Args:
        node_ids: entity_id values of the nodes to inspect.

    Returns:
        A dict mapping every requested node_id to its relationship count.
        IDs that matched no node are logged as a warning and mapped to 0.
    """
    async with self._driver.session(
        database=self._DATABASE, default_access_mode="READ"
    ) as session:
        query = """
UNWIND $node_ids AS id
MATCH (n:base {entity_id: id})
RETURN n.entity_id AS entity_id, count { (n)--() } AS degree;
"""
        cursor = await session.run(query, node_ids=node_ids)
        degrees = {}
        async for row in cursor:
            degrees[row["entity_id"]] = row["degree"]
        await cursor.consume()  # Ensure result is fully consumed

    # Any ID that produced no row gets degree 0 (warned once per distinct ID).
    for nid in node_ids:
        if nid in degrees:
            continue
        logger.warning(f"No node found with label '{nid}'")
        degrees[nid] = 0
    return degrees
async def edge_degree(self, src_id: str, tgt_id: str) -> int:
    """Return the combined degree of an edge's two endpoints.

    Args:
        src_id: entity_id of the source node
        tgt_id: entity_id of the target node

    Returns:
        int: node_degree(src_id) + node_degree(tgt_id), counting None as 0
    """
    total = 0
    for endpoint in (src_id, tgt_id):
        degree = await self.node_degree(endpoint)
        total += 0 if degree is None else int(degree)
    return total
async def edge_degrees_batch(
    self, edge_pairs: list[tuple[str, str]]
) -> dict[tuple[str, str], int]:
    """
    Compute, for every (src, tgt) pair, the sum of the two endpoint degrees.

    All distinct endpoints are resolved with a single node_degrees_batch call.

    Args:
        edge_pairs: list of (src, tgt) tuples.

    Returns:
        dict mapping each (src, tgt) tuple to degree(src) + degree(tgt).
    """
    endpoints = {src for src, _ in edge_pairs} | {tgt for _, tgt in edge_pairs}
    degree_of = await self.node_degrees_batch(list(endpoints))
    return {
        (src, tgt): degree_of.get(src, 0) + degree_of.get(tgt, 0)
        for src, tgt in edge_pairs
    }
async def get_edge(
    self, source_node_id: str, target_node_id: str
) -> dict[str, str] | None:
    """Return the properties of a relationship between two ``base`` nodes.

    The match is undirected and accepts any relationship type. If several
    edges exist, a warning is logged and the first one is used. Missing
    "weight", "source_id", "description" or "keywords" keys are filled with
    defaults, each with a warning.

    Args:
        source_node_id: entity_id of one endpoint
        target_node_id: entity_id of the other endpoint

    Returns:
        dict: the edge properties (defaults only if they cannot be processed)
        None: if no edge connects the two nodes

    Raises:
        Exception: if the query fails (logged, then re-raised)
    """
    fallback = {
        "weight": 0.0,
        "source_id": None,
        "description": None,
        "keywords": None,
    }
    try:
        async with self._driver.session(
            database=self._DATABASE, default_access_mode="READ"
        ) as session:
            query = """
MATCH (start:base {entity_id: $source_entity_id})-[r]-(end:base {entity_id: $target_entity_id})
RETURN properties(r) as edge_properties
"""
            cursor = await session.run(
                query,
                source_entity_id=source_node_id,
                target_entity_id=target_node_id,
            )
            try:
                rows = await cursor.fetch(2)
                if len(rows) > 1:
                    logger.warning(
                        f"Multiple edges found between '{source_node_id}' and '{target_node_id}'. Using first edge."
                    )
                if not rows:
                    return None
                try:
                    props = dict(rows[0]["edge_properties"])
                    for key, default_value in fallback.items():
                        if key in props:
                            continue
                        props[key] = default_value
                        logger.warning(
                            f"Edge between {source_node_id} and {target_node_id} "
                            f"missing {key}, using default: {default_value}"
                        )
                    return props
                except (KeyError, TypeError, ValueError) as e:
                    logger.error(
                        f"Error processing edge properties between {source_node_id} "
                        f"and {target_node_id}: {str(e)}"
                    )
                    return dict(fallback)
            finally:
                await cursor.consume()  # Ensure result is fully consumed
    except Exception as e:
        logger.error(
            f"Error in get_edge between {source_node_id} and {target_node_id}: {str(e)}"
        )
        raise
async def get_edges_batch(
    self, pairs: list[dict[str, str]]
) -> dict[tuple[str, str], dict]:
    """
    Fetch edge properties for many (src, tgt) pairs with one UNWIND query.

    Only DIRECTED relationships are matched, in either direction. When several
    edges connect a pair, the first one collected is used. Missing "weight",
    "source_id", "description" or "keywords" keys are filled with defaults.

    Args:
        pairs: list of dicts such as [{"src": "node1", "tgt": "node2"}, ...]

    Returns:
        dict mapping (src, tgt) to edge properties. The query uses a plain
        MATCH, so pairs with no connecting edge produce no row and are absent
        from the result.
    """
    defaults = {
        "weight": 0.0,
        "source_id": None,
        "description": None,
        "keywords": None,
    }
    async with self._driver.session(
        database=self._DATABASE, default_access_mode="READ"
    ) as session:
        query = """
UNWIND $pairs AS pair
MATCH (start:base {entity_id: pair.src})-[r:DIRECTED]-(end:base {entity_id: pair.tgt})
RETURN pair.src AS src_id, pair.tgt AS tgt_id, collect(properties(r)) AS edges
"""
        cursor = await session.run(query, pairs=pairs)
        by_pair = {}
        async for row in cursor:
            key = (row["src_id"], row["tgt_id"])
            edges = row["edges"]
            if edges:
                props = edges[0]  # choose the first if multiple exist
                for name, value in defaults.items():
                    props.setdefault(name, value)
                by_pair[key] = props
            else:
                # No edge collected: fall back to default edge properties
                by_pair[key] = dict(defaults)
        await cursor.consume()
        return by_pair
async def get_node_edges(self, source_node_id: str) -> list[tuple[str, str]] | None:
    """Return every relationship attached to the ``base`` node with this entity_id.

    Args:
        source_node_id: entity_id of the node whose edges are wanted

    Returns:
        list[tuple[str, str]]: one (queried_entity_id, connected_entity_id)
            pair per relationship, with the queried node always first. The list
            is empty when the node has no edges or does not exist.

    Raises:
        Exception: if the query fails (logged, then re-raised)
    """
    try:
        async with self._driver.session(
            database=self._DATABASE, default_access_mode="READ"
        ) as session:
            # Stays None if session.run() fails, so cleanup must not assume a result.
            results = None
            try:
                query = """MATCH (n:base {entity_id: $entity_id})
OPTIONAL MATCH (n)-[r]-(connected:base)
WHERE connected.entity_id IS NOT NULL
RETURN n, r, connected"""
                results = await session.run(query, entity_id=source_node_id)

                edges = []
                async for record in results:
                    source_node = record["n"]
                    connected_node = record["connected"]

                    # OPTIONAL MATCH yields connected = null when there are no edges
                    if not source_node or not connected_node:
                        continue

                    source_label = source_node.get("entity_id")
                    target_label = connected_node.get("entity_id")

                    # Skip pairs where either side lacks a usable entity_id
                    if source_label and target_label:
                        edges.append((source_label, target_label))

                return edges
            except Exception as e:
                logger.error(
                    f"Error getting edges for node {source_node_id}: {str(e)}"
                )
                raise
            finally:
                # Consume only if the query actually produced a result object;
                # touching an unbound name here used to mask the real error.
                if results is not None:
                    await results.consume()  # Ensure results are consumed
    except Exception as e:
        logger.error(f"Error in get_node_edges for {source_node_id}: {str(e)}")
        raise
async def get_nodes_edges_batch(
    self, node_ids: list[str]
) -> dict[str, list[tuple[str, str]]]:
    """
    Fetch the edges of several nodes with one UNWIND query.

    The pattern is undirected, so each node's list contains both its outgoing
    and incoming relationships. Each tuple is oriented by the relationship's
    stored direction.

    Args:
        node_ids: entity_id values whose edges should be returned.

    Returns:
        dict mapping every requested node ID to a list of (source, target) tuples:
          - outgoing edge: (queried_node, connected_node)
          - incoming edge: (connected_node, queried_node)
        Nodes that do not exist or have no edges map to an empty list.
    """
    async with self._driver.session(
        database=self._DATABASE, default_access_mode="READ"
    ) as session:
        # startNode(r) is returned so the stored direction can be reconstructed.
        query = """
UNWIND $node_ids AS id
MATCH (n:base {entity_id: id})
OPTIONAL MATCH (n)-[r]-(connected:base)
RETURN id AS queried_id, n.entity_id AS node_entity_id,
connected.entity_id AS connected_entity_id,
startNode(r).entity_id AS start_entity_id
"""
        cursor = await session.run(query, node_ids=node_ids)
        edges_by_node = {nid: [] for nid in node_ids}
        async for row in cursor:
            me = row["node_entity_id"]
            other = row["connected_entity_id"]
            if not me or not other:
                continue
            # Outgoing when the relationship starts at the queried node.
            is_outgoing = row["start_entity_id"] == me
            edge = (me, other) if is_outgoing else (other, me)
            edges_by_node[row["queried_id"]].append(edge)
        await cursor.consume()  # Ensure results are fully consumed
        return edges_by_node
@retry(
    stop=stop_after_attempt(3),
    wait=wait_exponential(multiplier=1, min=4, max=10),
    retry=retry_if_exception_type(
        (
            neo4jExceptions.ServiceUnavailable,
            neo4jExceptions.TransientError,
            neo4jExceptions.WriteServiceUnavailable,
            neo4jExceptions.ClientError,
        )
    ),
)
async def upsert_node(self, node_id: str, node_data: dict[str, str]) -> None:
    """
    Insert or update a ``base`` node keyed by entity_id.

    All of ``node_data`` is merged into the node's properties. The node also
    receives a label named after ``node_data["entity_type"]``.

    Args:
        node_id: entity_id value used to MERGE the node
        node_data: node properties; must contain "entity_id" and "entity_type"

    Raises:
        KeyError: if "entity_type" is missing from node_data
        ValueError: if "entity_id" is missing from node_data
        Exception: any database error (logged, then re-raised). The decorator
            makes up to 3 attempts on the listed neo4j exception types.
    """
    properties = node_data
    entity_type = properties["entity_type"]
    if "entity_id" not in properties:
        raise ValueError("Neo4j: node properties must contain an 'entity_id' field")

    # Labels cannot be passed as query parameters, so entity_type is spliced into
    # the Cypher text inside backticks. A literal backtick inside a quoted
    # identifier must be doubled; otherwise a value containing "`" could close
    # the label early and inject arbitrary Cypher.
    escaped_entity_type = str(entity_type).replace("`", "``")

    try:
        async with self._driver.session(database=self._DATABASE) as session:

            async def execute_upsert(tx: AsyncManagedTransaction):
                query = (
                    """
MERGE (n:base {entity_id: $entity_id})
SET n += $properties
SET n:`%s`
"""
                    % escaped_entity_type
                )
                result = await tx.run(
                    query, entity_id=node_id, properties=properties
                )
                await result.consume()  # Ensure result is fully consumed

            await session.execute_write(execute_upsert)
    except Exception as e:
        logger.error(f"Error during upsert: {str(e)}")
        raise
@retry(
    stop=stop_after_attempt(3),
    wait=wait_exponential(multiplier=1, min=4, max=10),
    retry=retry_if_exception_type(
        (
            neo4jExceptions.ServiceUnavailable,
            neo4jExceptions.TransientError,
            neo4jExceptions.WriteServiceUnavailable,
            neo4jExceptions.ClientError,
        )
    ),
)
async def upsert_edge(
    self, source_node_id: str, target_node_id: str, edge_data: dict[str, str]
) -> None:
    """
    Create or update the DIRECTED relationship between two ``base`` nodes.

    Both endpoints are looked up by entity_id. The MERGE pattern has no arrow,
    so an existing DIRECTED relationship is matched regardless of direction.
    ``edge_data`` is then merged into that relationship's properties. If either
    node is missing, MATCH yields no rows and nothing is written; no error is
    raised in that case.

    Args:
        source_node_id: entity_id of the source node
        target_node_id: entity_id of the target node
        edge_data: properties to merge onto the relationship

    Raises:
        Exception: any database error (logged, then re-raised). The decorator
            makes up to 3 attempts on the listed neo4j exception types.
    """
    cypher = """
MATCH (source:base {entity_id: $source_entity_id})
WITH source
MATCH (target:base {entity_id: $target_entity_id})
MERGE (source)-[r:DIRECTED]-(target)
SET r += $properties
RETURN r, source, target
"""

    async def _write(tx: AsyncManagedTransaction):
        cursor = await tx.run(
            cypher,
            source_entity_id=source_node_id,
            target_entity_id=target_node_id,
            properties=edge_data,
        )
        try:
            await cursor.fetch(2)
        finally:
            await cursor.consume()  # Ensure result is consumed

    try:
        async with self._driver.session(database=self._DATABASE) as session:
            await session.execute_write(_write)
    except Exception as e:
        logger.error(f"Error during edge upsert: {str(e)}")
        raise
async def get_knowledge_graph(
    self,
    node_label: str,
    max_depth: int = 3,
    max_nodes: int = MAX_GRAPH_NODES,
) -> KnowledgeGraph:
    """
    Retrieve a connected subgraph starting from the node whose entity_id equals `node_label`.

    Args:
        node_label: entity_id of the starting node. "*" means the whole graph:
            the `max_nodes` highest-degree nodes (any label) and the
            relationships among them are returned.
        max_depth: Maximum traversal depth from the starting node (ignored for
            "*"). Defaults to 3.
        max_nodes: Maximum number of nodes to return. Defaults to MAX_GRAPH_NODES
            (environment variable, 1000 if unset).

    Returns:
        KnowledgeGraph object containing nodes and edges, with an is_truncated flag
        indicating whether the graph was truncated due to max_nodes limit.
        Node/edge ids are Neo4j internal ids rendered as strings, except on
        the fallback path below.

    Notes:
        A specific node_label is expanded with apoc.path.subgraphAll. If any
        query raises a neo4j ClientError (e.g. APOC unavailable), the method
        delegates to _robust_fallback() for a specific label. For "*" it
        returns without nodes or edges.
    """
    result = KnowledgeGraph()
    # Deduplicate by Neo4j internal id.
    seen_nodes = set()
    seen_edges = set()

    async with self._driver.session(
        database=self._DATABASE, default_access_mode="READ"
    ) as session:
        try:
            if node_label == "*":
                # First check total node count to determine if graph is truncated
                # (counts all nodes, not only :base ones).
                count_query = "MATCH (n) RETURN count(n) as total"
                count_result = None
                try:
                    count_result = await session.run(count_query)
                    count_record = await count_result.single()

                    if count_record and count_record["total"] > max_nodes:
                        result.is_truncated = True
                        logger.info(
                            f"Graph truncated: {count_record['total']} nodes found, limited to {max_nodes}"
                        )
                finally:
                    if count_result:
                        await count_result.consume()

                # Run main query to get nodes with highest degree, then every
                # relationship whose both endpoints are among the kept nodes.
                main_query = """
MATCH (n)
OPTIONAL MATCH (n)-[r]-()
WITH n, COALESCE(count(r), 0) AS degree
ORDER BY degree DESC
LIMIT $max_nodes
WITH collect({node: n}) AS filtered_nodes
UNWIND filtered_nodes AS node_info
WITH collect(node_info.node) AS kept_nodes, filtered_nodes
OPTIONAL MATCH (a)-[r]-(b)
WHERE a IN kept_nodes AND b IN kept_nodes
RETURN filtered_nodes AS node_info,
collect(DISTINCT r) AS relationships
"""
                result_set = None
                try:
                    result_set = await session.run(
                        main_query,
                        {"max_nodes": max_nodes},
                    )
                    record = await result_set.single()
                finally:
                    if result_set:
                        await result_set.consume()

            else:
                # return await self._robust_fallback(node_label, max_depth, max_nodes)
                # First try without limit to check if we need to truncate
                full_query = """
MATCH (start)
WHERE start.entity_id = $entity_id
WITH start
CALL apoc.path.subgraphAll(start, {
relationshipFilter: '',
minLevel: 0,
maxLevel: $max_depth,
bfs: true
})
YIELD nodes, relationships
WITH nodes, relationships, size(nodes) AS total_nodes
UNWIND nodes AS node
WITH collect({node: node}) AS node_info, relationships, total_nodes
RETURN node_info, relationships, total_nodes
"""

                # Try to get full result
                full_result = None
                try:
                    full_result = await session.run(
                        full_query,
                        {
                            "entity_id": node_label,
                            "max_depth": max_depth,
                        },
                    )
                    full_record = await full_result.single()

                    # If no record found, return empty KnowledgeGraph
                    if not full_record:
                        logger.debug(f"No nodes found for entity_id: {node_label}")
                        return result

                    # If record found, check node count
                    total_nodes = full_record["total_nodes"]

                    if total_nodes <= max_nodes:
                        # If node count is within limit, use full result directly
                        logger.debug(
                            f"Using full result with {total_nodes} nodes (no truncation needed)"
                        )
                        record = full_record
                    else:
                        # If node count exceeds limit, set truncated flag and run limited query
                        result.is_truncated = True
                        logger.info(
                            f"Graph truncated: {total_nodes} nodes found, breadth-first search limited to {max_nodes}"
                        )

                        # Run limited query; APOC's `limit` option is presumably what
                        # caps the node count here (confirm against the APOC docs).
                        limited_query = """
MATCH (start)
WHERE start.entity_id = $entity_id
WITH start
CALL apoc.path.subgraphAll(start, {
relationshipFilter: '',
minLevel: 0,
maxLevel: $max_depth,
limit: $max_nodes,
bfs: true
})
YIELD nodes, relationships
UNWIND nodes AS node
WITH collect({node: node}) AS node_info, relationships
RETURN node_info, relationships
"""
                        result_set = None
                        try:
                            result_set = await session.run(
                                limited_query,
                                {
                                    "entity_id": node_label,
                                    "max_depth": max_depth,
                                    "max_nodes": max_nodes,
                                },
                            )
                            record = await result_set.single()
                        finally:
                            if result_set:
                                await result_set.consume()
                finally:
                    if full_result:
                        await full_result.consume()

            if record:
                # Handle nodes (compatible with multi-label cases).
                # NOTE(review): `.id` is the legacy integer id; newer neo4j drivers
                # deprecate it in favour of `element_id` - confirm the driver version.
                for node_info in record["node_info"]:
                    node = node_info["node"]
                    node_id = node.id
                    if node_id not in seen_nodes:
                        result.nodes.append(
                            KnowledgeGraphNode(
                                id=f"{node_id}",
                                labels=[node.get("entity_id")],
                                properties=dict(node),
                            )
                        )
                        seen_nodes.add(node_id)

                # Handle relationships (including direction information)
                for rel in record["relationships"]:
                    edge_id = rel.id
                    if edge_id not in seen_edges:
                        start = rel.start_node
                        end = rel.end_node
                        result.edges.append(
                            KnowledgeGraphEdge(
                                id=f"{edge_id}",
                                type=rel.type,
                                source=f"{start.id}",
                                target=f"{end.id}",
                                properties=dict(rel),
                            )
                        )
                        seen_edges.add(edge_id)

                logger.info(
                    f"Subgraph query successful | Node count: {len(result.nodes)} | Edge count: {len(result.edges)}"
                )

        except neo4jExceptions.ClientError as e:
            # Any ClientError (not only a missing APOC plugin) ends up here.
            logger.warning(f"APOC plugin error: {str(e)}")
            if node_label != "*":
                logger.warning(
                    "Neo4j: falling back to basic Cypher recursive search..."
                )
                # NOTE(review): the fallback uses entity_id strings as node ids,
                # unlike the internal ids produced above.
                return await self._robust_fallback(node_label, max_depth, max_nodes)
            else:
                logger.warning(
                    "Neo4j: APOC plugin error with wildcard query, returning empty result"
                )

    return result
async def _robust_fallback(
    self, node_label: str, max_depth: int, max_nodes: int
) -> KnowledgeGraph:
    """
    Fallback for get_knowledge_graph when the APOC-based query fails.

    Performs a breadth-first traversal from the ``base`` node whose entity_id
    equals ``node_label``, using only plain Cypher. It issues one neighbour
    query, in its own session, per visited node.

    Args:
        node_label: entity_id of the starting node
        max_depth: nodes deeper than this are not added to the result
        max_nodes: traversal stops once this many nodes have been collected,
            and ``is_truncated`` is set on the result

    Returns:
        KnowledgeGraph whose node ids are entity_id strings and whose edge ids
        are Neo4j internal relationship ids converted to str. It is empty if
        the starting node does not exist.
    """
    from collections import deque

    result = KnowledgeGraph()
    visited_nodes = set()  # entity_ids already added to result.nodes
    visited_edges = set()  # stringified internal relationship ids already added
    visited_edge_pairs = set()  # sorted (entity_id, entity_id) pairs already added

    # Get the starting node's data
    async with self._driver.session(
        database=self._DATABASE, default_access_mode="READ"
    ) as session:
        query = """
MATCH (n:base {entity_id: $entity_id})
RETURN id(n) as node_id, n
"""
        node_result = await session.run(query, entity_id=node_label)
        try:
            node_record = await node_result.single()
            if not node_record:
                return result

            # Create initial KnowledgeGraphNode
            start_node = KnowledgeGraphNode(
                id=f"{node_record['n'].get('entity_id')}",
                labels=[node_record["n"].get("entity_id")],
                properties=dict(node_record["n"]._properties),
            )
        finally:
            await node_result.consume()  # Ensure results are consumed

    # Initialize queue for BFS with (node, edge, depth) tuples
    # edge is None for the starting node (and, below, for every queued node)
    queue = deque([(start_node, None, 0)])

    # True BFS implementation using a queue
    while queue and len(visited_nodes) < max_nodes:
        # Dequeue the next node to process
        current_node, current_edge, current_depth = queue.popleft()

        # Skip if already visited or exceeds max depth
        if current_node.id in visited_nodes:
            continue

        if current_depth > max_depth:
            logger.debug(
                f"Skipping node at depth {current_depth} (max_depth: {max_depth})"
            )
            continue

        # Add current node to result
        result.nodes.append(current_node)
        visited_nodes.add(current_node.id)

        # Add edge to result if it exists and not already added
        if current_edge and current_edge.id not in visited_edges:
            result.edges.append(current_edge)
            visited_edges.add(current_edge.id)

        # Stop if we've reached the node limit.
        # NOTE(review): edges to nodes that were queued but never dequeued stay
        # in result.edges, so a truncated result can reference missing nodes.
        if len(visited_nodes) >= max_nodes:
            result.is_truncated = True
            logger.info(
                f"Graph truncated: breadth-first search limited to: {max_nodes} nodes"
            )
            break

        # Get all edges and target nodes for the current node (even at max_depth).
        # Neighbours are not restricted to the :base label.
        async with self._driver.session(
            database=self._DATABASE, default_access_mode="READ"
        ) as session:
            query = """
MATCH (a:base {entity_id: $entity_id})-[r]-(b)
WITH r, b, id(r) as edge_id, id(b) as target_id
RETURN r, b, edge_id, target_id
"""
            results = await session.run(query, entity_id=current_node.id)

            # Get all records and release database connection.
            # Neighbours beyond the first 1000 are silently ignored.
            records = await results.fetch(1000)  # Max neighbor nodes we can handle
            await results.consume()  # Ensure results are consumed

            # Process all neighbors - capture all edges but only queue unvisited nodes
            for record in records:
                rel = record["r"]
                edge_id = str(record["edge_id"])

                if edge_id not in visited_edges:
                    b_node = record["b"]
                    # Node identity in this traversal is the entity_id property,
                    # not the internal id returned as target_id.
                    target_id = b_node.get("entity_id")

                    if target_id:  # Only process if target node has entity_id
                        # Create KnowledgeGraphNode for target
                        target_node = KnowledgeGraphNode(
                            id=f"{target_id}",
                            labels=[target_id],
                            properties=dict(b_node._properties),
                        )

                        # Create KnowledgeGraphEdge
                        target_edge = KnowledgeGraphEdge(
                            id=f"{edge_id}",
                            type=rel.type,
                            source=f"{current_node.id}",
                            target=f"{target_id}",
                            properties=dict(rel),
                        )

                        # Sort source_id and target_id to ensure (A,B) and (B,A) are treated as the same edge
                        sorted_pair = tuple(sorted([current_node.id, target_id]))

                        # Check if the same edge already exists (considering undirectedness)
                        if sorted_pair not in visited_edge_pairs:
                            # Only add the edge if the target node is already in the result or will be added
                            if target_id in visited_nodes or (
                                target_id not in visited_nodes
                                and current_depth < max_depth
                            ):
                                result.edges.append(target_edge)
                                visited_edges.add(edge_id)
                                visited_edge_pairs.add(sorted_pair)

                        # Only add unvisited nodes to the queue for further expansion
                        if target_id not in visited_nodes:
                            # Only add to queue if we're not at max depth yet
                            if current_depth < max_depth:
                                # Add node to queue with incremented depth
                                # Edge is already added to result, so we pass None as edge
                                queue.append((target_node, None, current_depth + 1))
                            else:
                                # At max depth, we don't add the node.
                                # This prevents adding nodes beyond max_depth to the result
                                # NOTE(review): the edge was NOT added on this path either
                                # (the condition above is false), so the debug text is misleading.
                                logger.debug(
                                    f"Node {target_id} beyond max depth {max_depth}, edge added but node not included"
                                )
                        else:
                            # If target node already exists in result, we don't need to add it again
                            logger.debug(
                                f"Node {target_id} already visited, edge added but node not queued"
                            )
                    else:
                        logger.warning(
                            f"Skipping edge {edge_id} due to missing entity_id on target node"
                        )

    logger.info(
        f"BFS subgraph query successful | Node count: {len(result.nodes)} | Edge count: {len(result.edges)}"
    )
    return result
async def get_all_labels(self) -> list[str]:
    """
    Get the distinct ``entity_id`` values of all ``:base`` nodes.

    Despite the method name, this does not return Neo4j node labels (that
    is what ``CALL db.labels()`` would yield). It returns the ``entity_id``
    property of every ``:base`` node that has one, which is the value this
    storage uses as a node's display label.

    Returns:
        Distinct, non-null ``entity_id`` values, ordered ascending by the
        Cypher ``ORDER BY`` clause, e.g. ``["Acme Corp", "Alice", ...]``.
    """
    query = """
MATCH (n:base)
WHERE n.entity_id IS NOT NULL
RETURN DISTINCT n.entity_id AS label
ORDER BY label
"""
    async with self._driver.session(
        database=self._DATABASE, default_access_mode="READ"
    ) as session:
        result = await session.run(query)
        try:
            return [record["label"] async for record in result]
        finally:
            # Ensure results are consumed even if processing fails
            await result.consume()
@retry(
    stop=stop_after_attempt(3),
    wait=wait_exponential(multiplier=1, min=4, max=10),
    # ClientError is deliberately not retried: it reports a problem with the
    # request itself (syntax, constraint, authorization), which a retry
    # cannot fix; retrying would only delay the same error by the backoff.
    retry=retry_if_exception_type(
        (
            neo4jExceptions.ServiceUnavailable,
            neo4jExceptions.TransientError,
            neo4jExceptions.WriteServiceUnavailable,
        )
    ),
)
async def delete_node(self, node_id: str) -> None:
    """Delete the ``:base`` node whose ``entity_id`` equals ``node_id``.

    The node is removed together with all of its relationships
    (``DETACH DELETE``). If no node matches, the query matches nothing.

    Args:
        node_id: The ``entity_id`` of the node to delete.

    Raises:
        Exception: Any error from the driver is logged and re-raised
            (after the retry policy above is exhausted, for retryable ones).
    """

    async def _do_delete(tx: AsyncManagedTransaction):
        query = """
MATCH (n:base {entity_id: $entity_id})
DETACH DELETE n
"""
        result = await tx.run(query, entity_id=node_id)
        await result.consume()  # Ensure result is fully consumed
        # Log only once the query has actually run to completion.
        logger.debug(f"Deleted node with label '{node_id}'")

    try:
        async with self._driver.session(database=self._DATABASE) as session:
            await session.execute_write(_do_delete)
    except Exception as e:
        logger.error(f"Error during node deletion: {str(e)}")
        raise
async def remove_nodes(self, nodes: list[str]):
    """Delete multiple nodes by ``entity_id``, one at a time.

    Each deletion goes through ``delete_node``, which carries its own
    ``@retry`` policy. This method is not wrapped in a second ``@retry``.
    An outer retry would multiply the attempts (up to 3 x 3, with compounded
    backoff). It would also re-run deletions that had already succeeded.

    Args:
        nodes: ``entity_id`` values of the nodes to delete.
    """
    for node in nodes:
        await self.delete_node(node)
@retry(
    stop=stop_after_attempt(3),
    wait=wait_exponential(multiplier=1, min=4, max=10),
    # ClientError is deliberately not retried: it reports a problem with the
    # request itself (syntax, constraint, authorization), which a retry
    # cannot fix; retrying would only delay the same error by the backoff.
    retry=retry_if_exception_type(
        (
            neo4jExceptions.ServiceUnavailable,
            neo4jExceptions.TransientError,
            neo4jExceptions.WriteServiceUnavailable,
        )
    ),
)
async def remove_edges(self, edges: list[tuple[str, str]]):
    """Delete all relationships between each given pair of nodes.

    The pattern ``-[r]-`` is undirected, so relationships in either direction
    between the two ``:base`` nodes are deleted. Each pair is deleted in its
    own write transaction, sequentially, over one session. If the whole call
    is retried, already-deleted pairs simply match nothing.

    Args:
        edges: ``(source_entity_id, target_entity_id)`` tuples to delete.

    Raises:
        Exception: Any error from the driver is logged and re-raised
            (after the retry policy above is exhausted, for retryable ones).
    """
    if not edges:
        return

    async def _do_delete_edge(
        tx: AsyncManagedTransaction, source_entity_id: str, target_entity_id: str
    ):
        query = """
MATCH (source:base {entity_id: $source_entity_id})-[r]-(target:base {entity_id: $target_entity_id})
DELETE r
"""
        result = await tx.run(
            query,
            source_entity_id=source_entity_id,
            target_entity_id=target_entity_id,
        )
        await result.consume()  # Ensure result is fully consumed
        # Log only once the query has actually run to completion.
        logger.debug(
            f"Deleted edge from '{source_entity_id}' to '{target_entity_id}'"
        )

    try:
        async with self._driver.session(database=self._DATABASE) as session:
            for source, target in edges:
                # Extra positional args are forwarded to the tx function.
                await session.execute_write(_do_delete_edge, source, target)
    except Exception as e:
        logger.error(f"Error during edge deletion: {str(e)}")
        raise
async def drop(self) -> dict[str, str]:
    """Erase every node and relationship in the configured Neo4j database.

    The Cypher statement has no label filter, so it removes all nodes in
    ``self._DATABASE`` (with their relationships), not only ``:base`` nodes.
    Errors are not raised; they are logged and reported in the result.

    Returns:
        dict[str, str]: Operation status and message
        - On success: {"status": "success", "message": "data dropped"}
        - On failure: {"status": "error", "message": "<error details>"}
    """
    wipe_statement = "MATCH (n) DETACH DELETE n"
    try:
        async with self._driver.session(database=self._DATABASE) as session:
            summary_source = await session.run(wipe_statement)
            # Drain the result so the statement is fully executed.
            await summary_source.consume()
            logger.info(
                f"Process {os.getpid()} drop Neo4j database {self._DATABASE}"
            )
            return {"status": "success", "message": "data dropped"}
    except Exception as exc:
        logger.error(f"Error dropping Neo4j database {self._DATABASE}: {exc}")
        return {"status": "error", "message": str(exc)}