diff --git a/dsLightRag/Doc/Postgresql支持工作空间的代码修改.txt b/dsLightRag/Doc/Postgresql支持工作空间的代码修改.txt new file mode 100644 index 00000000..5ab3f8c2 --- /dev/null +++ b/dsLightRag/Doc/Postgresql支持工作空间的代码修改.txt @@ -0,0 +1,8 @@ +https://github.com/HKUDS/LightRAG/issues/1244 + +# 源码路径 +D:\anaconda3\envs\py310\Lib\site-packages\lightrag + +# 用VScode打开编辑,可以全局搜索 + +D:\anaconda3\envs\py310\Lib\site-packages\lightrag\kg\postgres_impl.py \ No newline at end of file diff --git a/dsLightRag/Doc/[Feature Suggestion]_ Change _ Pass Workspace Name in PgSQL Storages · Issue #1244 · HKUDS_LightRAG.pdf b/dsLightRag/Doc/[Feature Suggestion]_ Change _ Pass Workspace Name in PgSQL Storages · Issue #1244 · HKUDS_LightRAG.pdf new file mode 100644 index 00000000..c900a9c1 Binary files /dev/null and b/dsLightRag/Doc/[Feature Suggestion]_ Change _ Pass Workspace Name in PgSQL Storages · Issue #1244 · HKUDS_LightRAG.pdf differ diff --git a/dsLightRag/Doc/postgres_impl.py b/dsLightRag/Doc/postgres_impl.py new file mode 100644 index 00000000..f02ad79f --- /dev/null +++ b/dsLightRag/Doc/postgres_impl.py @@ -0,0 +1,2575 @@ +import asyncio +import json +import os +import datetime +from datetime import timezone +from dataclasses import dataclass, field +from typing import Any, Union, final +import numpy as np +import configparser + +from lightrag.types import KnowledgeGraph, KnowledgeGraphNode, KnowledgeGraphEdge + +from tenacity import ( + retry, + retry_if_exception_type, + stop_after_attempt, + wait_exponential, +) + +from ..base import ( + BaseGraphStorage, + BaseKVStorage, + BaseVectorStorage, + DocProcessingStatus, + DocStatus, + DocStatusStorage, +) +from ..namespace import NameSpace, is_namespace +from ..utils import logger + +import pipmaster as pm + +if not pm.is_installed("asyncpg"): + pm.install("asyncpg") + +import asyncpg # type: ignore +from asyncpg import Pool # type: ignore + +from dotenv import load_dotenv + +# use the .env that is inside the current folder +# allows to use different .env file for each lightrag instance +# the OS environment variables take precedence over the .env file +load_dotenv(dotenv_path=".env", override=False) + +# Get maximum number of graph nodes from environment variable, default is 1000 +MAX_GRAPH_NODES = int(os.getenv("MAX_GRAPH_NODES", 1000)) + + +class PostgreSQLDB: + def __init__(self, config: dict[str, Any], **kwargs: Any): + self.host = config["host"] + self.port = config["port"] + self.user = config["user"] + self.password = config["password"] + self.database = config["database"] + self.workspace = config["workspace"] + self.max = int(config["max_connections"]) + self.increment = 1 + self.pool: Pool | None = None + + if self.user is None or self.password is None or self.database is None: + raise ValueError("Missing database user, password, or database") + + async def initdb(self): + try: + self.pool = await asyncpg.create_pool( # type: ignore + user=self.user, + password=self.password, + database=self.database, + host=self.host, + port=self.port, + min_size=1, + max_size=self.max, + ) + + logger.info( + f"PostgreSQL, Connected to database at {self.host}:{self.port}/{self.database}" + ) + except Exception as e: + logger.error( + f"PostgreSQL, Failed to connect database at {self.host}:{self.port}/{self.database}, Got:{e}" + ) + raise + + @staticmethod + async def configure_age(connection: asyncpg.Connection, graph_name: str) -> None: + """Set the Apache AGE environment and creates a graph if it does not exist. 
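A minimal usage sketch of the workspace behaviour this patch targets (issue #1244): the modified ClientManager.get_config() further below reads the workspace from global_config["vector_db_storage_cls_kwargs"], so a per-instance workspace can be passed through the LightRAG constructor. The storage class names match the PG implementations in this file; the omitted model functions and paths are assumptions about the caller's setup, not part of this patch.

from lightrag import LightRAG

rag_tenant_a = LightRAG(
    working_dir="./rag_storage/tenant_a",
    kv_storage="PGKVStorage",
    vector_storage="PGVectorStorage",
    graph_storage="PGGraphStorage",
    doc_status_storage="PGDocStatusStorage",
    # llm_model_func / embedding_func omitted for brevity
    vector_db_storage_cls_kwargs={
        "workspace": "tenant_a",              # picked up by the patched get_config()
        "cosine_better_than_threshold": 0.2,  # required by PGVectorStorage below
    },
)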
+ + This method: + - Sets the PostgreSQL `search_path` to include `ag_catalog`, ensuring that Apache AGE functions can be used without specifying the schema. + - Attempts to create a new graph with the provided `graph_name` if it does not already exist. + - Silently ignores errors related to the graph already existing. + + """ + try: + await connection.execute( # type: ignore + 'SET search_path = ag_catalog, "$user", public' + ) + await connection.execute( # type: ignore + f"select create_graph('{graph_name}')" + ) + except ( + asyncpg.exceptions.InvalidSchemaNameError, + asyncpg.exceptions.UniqueViolationError, + ): + pass + + async def _migrate_llm_cache_add_chunk_id(self): + """Add chunk_id column to LIGHTRAG_LLM_CACHE table if it doesn't exist""" + try: + # Check if chunk_id column exists + check_column_sql = """ + SELECT column_name + FROM information_schema.columns + WHERE table_name = 'lightrag_llm_cache' + AND column_name = 'chunk_id' + """ + + column_info = await self.query(check_column_sql) + if not column_info: + logger.info("Adding chunk_id column to LIGHTRAG_LLM_CACHE table") + add_column_sql = """ + ALTER TABLE LIGHTRAG_LLM_CACHE + ADD COLUMN chunk_id VARCHAR(255) NULL + """ + await self.execute(add_column_sql) + logger.info( + "Successfully added chunk_id column to LIGHTRAG_LLM_CACHE table" + ) + else: + logger.info( + "chunk_id column already exists in LIGHTRAG_LLM_CACHE table" + ) + except Exception as e: + logger.warning(f"Failed to add chunk_id column to LIGHTRAG_LLM_CACHE: {e}") + + async def _migrate_timestamp_columns(self): + """Migrate timestamp columns in tables to timezone-aware types, assuming original data is in UTC time""" + # Tables and columns that need migration + tables_to_migrate = { + "LIGHTRAG_VDB_ENTITY": ["create_time", "update_time"], + "LIGHTRAG_VDB_RELATION": ["create_time", "update_time"], + "LIGHTRAG_DOC_CHUNKS": ["create_time", "update_time"], + } + + for table_name, columns in tables_to_migrate.items(): + for column_name in columns: + try: + # Check if column exists + check_column_sql = f""" + SELECT column_name, data_type + FROM information_schema.columns + WHERE table_name = '{table_name.lower()}' + AND column_name = '{column_name}' + """ + + column_info = await self.query(check_column_sql) + if not column_info: + logger.warning( + f"Column {table_name}.{column_name} does not exist, skipping migration" + ) + continue + + # Check column type + data_type = column_info.get("data_type") + if data_type == "timestamp with time zone": + logger.info( + f"Column {table_name}.{column_name} is already timezone-aware, no migration needed" + ) + continue + + # Execute migration, explicitly specifying UTC timezone for interpreting original data + logger.info( + f"Migrating {table_name}.{column_name} to timezone-aware type" + ) + migration_sql = f""" + ALTER TABLE {table_name} + ALTER COLUMN {column_name} TYPE TIMESTAMP(0) WITH TIME ZONE + USING {column_name} AT TIME ZONE 'UTC' + """ + + await self.execute(migration_sql) + logger.info( + f"Successfully migrated {table_name}.{column_name} to timezone-aware type" + ) + except Exception as e: + # Log error but don't interrupt the process + logger.warning(f"Failed to migrate {table_name}.{column_name}: {e}") + + async def check_tables(self): + # First create all tables + for k, v in TABLES.items(): + try: + await self.query(f"SELECT 1 FROM {k} LIMIT 1") + except Exception: + try: + logger.info(f"PostgreSQL, Try Creating table {k} in database") + await self.execute(v["ddl"]) + logger.info( + f"PostgreSQL, Creation 
success table {k} in PostgreSQL database" + ) + except Exception as e: + logger.error( + f"PostgreSQL, Failed to create table {k} in database, Please verify the connection with PostgreSQL database, Got: {e}" + ) + raise e + + # Create index for id column in each table + try: + index_name = f"idx_{k.lower()}_id" + check_index_sql = f""" + SELECT 1 FROM pg_indexes + WHERE indexname = '{index_name}' + AND tablename = '{k.lower()}' + """ + index_exists = await self.query(check_index_sql) + + if not index_exists: + create_index_sql = f"CREATE INDEX {index_name} ON {k}(id)" + logger.info(f"PostgreSQL, Creating index {index_name} on table {k}") + await self.execute(create_index_sql) + except Exception as e: + logger.error( + f"PostgreSQL, Failed to create index on table {k}, Got: {e}" + ) + + # After all tables are created, attempt to migrate timestamp fields + try: + await self._migrate_timestamp_columns() + except Exception as e: + logger.error(f"PostgreSQL, Failed to migrate timestamp columns: {e}") + # Don't throw an exception, allow the initialization process to continue + + # Migrate LLM cache table to add chunk_id field if needed + try: + await self._migrate_llm_cache_add_chunk_id() + except Exception as e: + logger.error(f"PostgreSQL, Failed to migrate LLM cache chunk_id field: {e}") + # Don't throw an exception, allow the initialization process to continue + + async def query( + self, + sql: str, + params: dict[str, Any] | None = None, + multirows: bool = False, + with_age: bool = False, + graph_name: str | None = None, + ) -> dict[str, Any] | None | list[dict[str, Any]]: + # start_time = time.time() + # logger.info(f"PostgreSQL, Querying:\n{sql}") + + async with self.pool.acquire() as connection: # type: ignore + if with_age and graph_name: + await self.configure_age(connection, graph_name) # type: ignore + elif with_age and not graph_name: + raise ValueError("Graph name is required when with_age is True") + + try: + if params: + rows = await connection.fetch(sql, *params.values()) + else: + rows = await connection.fetch(sql) + + if multirows: + if rows: + columns = [col for col in rows[0].keys()] + data = [dict(zip(columns, row)) for row in rows] + else: + data = [] + else: + if rows: + columns = rows[0].keys() + data = dict(zip(columns, rows[0])) + else: + data = None + + # query_time = time.time() - start_time + # logger.info(f"PostgreSQL, Query result len: {len(data)}") + # logger.info(f"PostgreSQL, Query execution time: {query_time:.4f}s") + + return data + except Exception as e: + logger.error(f"PostgreSQL database, error:{e}") + raise + + async def execute( + self, + sql: str, + data: dict[str, Any] | None = None, + upsert: bool = False, + ignore_if_exists: bool = False, + with_age: bool = False, + graph_name: str | None = None, + ): + try: + async with self.pool.acquire() as connection: # type: ignore + if with_age and graph_name: + await self.configure_age(connection, graph_name) + elif with_age and not graph_name: + raise ValueError("Graph name is required when with_age is True") + + if data is None: + await connection.execute(sql) + else: + await connection.execute(sql, *data.values()) + except ( + asyncpg.exceptions.UniqueViolationError, + asyncpg.exceptions.DuplicateTableError, + asyncpg.exceptions.DuplicateObjectError, # Catch "already exists" error + asyncpg.exceptions.InvalidSchemaNameError, # Also catch for AGE extension "already exists" + ) as e: + if ignore_if_exists: + # If the flag is set, just ignore these specific errors + pass + elif upsert: + print("Key value 
duplicate, but upsert succeeded.") + else: + logger.error(f"Upsert error: {e}") + except Exception as e: + logger.error(f"PostgreSQL database,\nsql:{sql},\ndata:{data},\nerror:{e}") + raise + + +class ClientManager: + _instances: dict[str, Any] = {"db": None, "ref_count": 0} + _lock = asyncio.Lock() + + @staticmethod + #def get_config() -> dict[str, Any]: + def get_config(global_config: dict[str, Any] | None = None) -> dict[str, Any]: + # First try to get workspace from global config + workspace = None + if global_config and "vector_db_storage_cls_kwargs" in global_config: + workspace = global_config["vector_db_storage_cls_kwargs"].get("workspace") + + # Read standard config + config = configparser.ConfigParser() + config.read("config.ini", "utf-8") + + return { + "host": os.environ.get( + "POSTGRES_HOST", + config.get("postgres", "host", fallback="localhost"), + ), + "port": os.environ.get( + "POSTGRES_PORT", config.get("postgres", "port", fallback=5432) + ), + "user": os.environ.get( + "POSTGRES_USER", config.get("postgres", "user", fallback="postgres") + ), + "password": os.environ.get( + "POSTGRES_PASSWORD", + config.get("postgres", "password", fallback=None), + ), + "database": os.environ.get( + "POSTGRES_DATABASE", + config.get("postgres", "database", fallback="postgres"), + ), + # Use workspace from global config if available, otherwise fall back to env/config.ini + "workspace": workspace or os.environ.get( + "POSTGRES_WORKSPACE", + config.get("postgres", "workspace", fallback="default"), + ), + "max_connections": os.environ.get( + "POSTGRES_MAX_CONNECTIONS", + config.get("postgres", "max_connections", fallback=20), + ), + } + + @classmethod + #async def get_client(cls) -> PostgreSQLDB: + async def get_client(cls, global_config: dict[str, Any] | None = None) -> PostgreSQLDB: + async with cls._lock: + if cls._instances["db"] is None: + #config = ClientManager.get_config() + config = cls.get_config(global_config) + db = PostgreSQLDB(config) + await db.initdb() + await db.check_tables() + cls._instances["db"] = db + cls._instances["ref_count"] = 0 + cls._instances["ref_count"] += 1 + return cls._instances["db"] + + @classmethod + async def release_client(cls, db: PostgreSQLDB): + async with cls._lock: + if db is not None: + if db is cls._instances["db"]: + cls._instances["ref_count"] -= 1 + if cls._instances["ref_count"] == 0: + await db.pool.close() + logger.info("Closed PostgreSQL database connection pool") + cls._instances["db"] = None + else: + await db.pool.close() + + +@final +@dataclass +class PGKVStorage(BaseKVStorage): + db: PostgreSQLDB = field(default=None) + + def __post_init__(self): + namespace_prefix = self.global_config.get("namespace_prefix") + self.base_namespace = self.namespace.replace(namespace_prefix, "") + self._max_batch_size = self.global_config["embedding_batch_num"] + + async def initialize(self): + if self.db is None: + #self.db = await ClientManager.get_client() + self.db = await ClientManager.get_client(self.global_config) + + async def finalize(self): + if self.db is not None: + await ClientManager.release_client(self.db) + self.db = None + + ################ QUERY METHODS ################ + async def get_all(self) -> dict[str, Any]: + """Get all data from storage + + Returns: + Dictionary containing all stored data + """ + table_name = namespace_to_table_name(self.namespace) + if not table_name: + logger.error(f"Unknown namespace for get_all: {self.namespace}") + return {} + + sql = f"SELECT * FROM {table_name} WHERE workspace=$1" + params = 
{"workspace": self.db.workspace} + + try: + results = await self.db.query(sql, params, multirows=True) + + if is_namespace(self.namespace, NameSpace.KV_STORE_LLM_RESPONSE_CACHE): + result_dict = {} + for row in results: + mode = row["mode"] + if mode not in result_dict: + result_dict[mode] = {} + result_dict[mode][row["id"]] = row + return result_dict + else: + return {row["id"]: row for row in results} + except Exception as e: + logger.error(f"Error retrieving all data from {self.namespace}: {e}") + return {} + + async def get_by_id(self, id: str) -> dict[str, Any] | None: + """Get doc_full data by id.""" + sql = SQL_TEMPLATES["get_by_id_" + self.namespace] + params = {"workspace": self.db.workspace, "id": id} + if is_namespace(self.namespace, NameSpace.KV_STORE_LLM_RESPONSE_CACHE): + array_res = await self.db.query(sql, params, multirows=True) + res = {} + for row in array_res: + res[row["id"]] = row + return res if res else None + else: + response = await self.db.query(sql, params) + return response if response else None + + async def get_by_mode_and_id(self, mode: str, id: str) -> Union[dict, None]: + """Specifically for llm_response_cache.""" + sql = SQL_TEMPLATES["get_by_mode_id_" + self.namespace] + params = {"workspace": self.db.workspace, "mode": mode, "id": id} + if is_namespace(self.namespace, NameSpace.KV_STORE_LLM_RESPONSE_CACHE): + array_res = await self.db.query(sql, params, multirows=True) + res = {} + for row in array_res: + res[row["id"]] = row + return res + else: + return None + + # Query by id + async def get_by_ids(self, ids: list[str]) -> list[dict[str, Any]]: + """Get doc_chunks data by id""" + sql = SQL_TEMPLATES["get_by_ids_" + self.namespace].format( + ids=",".join([f"'{id}'" for id in ids]) + ) + params = {"workspace": self.db.workspace} + if is_namespace(self.namespace, NameSpace.KV_STORE_LLM_RESPONSE_CACHE): + array_res = await self.db.query(sql, params, multirows=True) + modes = set() + dict_res: dict[str, dict] = {} + for row in array_res: + modes.add(row["mode"]) + for mode in modes: + if mode not in dict_res: + dict_res[mode] = {} + for row in array_res: + dict_res[row["mode"]][row["id"]] = row + return [{k: v} for k, v in dict_res.items()] + else: + return await self.db.query(sql, params, multirows=True) + + async def get_by_status(self, status: str) -> Union[list[dict[str, Any]], None]: + """Specifically for llm_response_cache.""" + SQL = SQL_TEMPLATES["get_by_status_" + self.namespace] + params = {"workspace": self.db.workspace, "status": status} + return await self.db.query(SQL, params, multirows=True) + + async def filter_keys(self, keys: set[str]) -> set[str]: + """Filter out duplicated content""" + sql = SQL_TEMPLATES["filter_keys"].format( + table_name=namespace_to_table_name(self.namespace), + ids=",".join([f"'{id}'" for id in keys]), + ) + params = {"workspace": self.db.workspace} + try: + res = await self.db.query(sql, params, multirows=True) + if res: + exist_keys = [key["id"] for key in res] + else: + exist_keys = [] + new_keys = set([s for s in keys if s not in exist_keys]) + return new_keys + except Exception as e: + logger.error( + f"PostgreSQL database,\nsql:{sql},\nparams:{params},\nerror:{e}" + ) + raise + + ################ INSERT METHODS ################ + async def upsert(self, data: dict[str, dict[str, Any]]) -> None: + logger.debug(f"Inserting {len(data)} to {self.namespace}") + if not data: + return + + if is_namespace(self.namespace, NameSpace.KV_STORE_TEXT_CHUNKS): + pass + elif is_namespace(self.namespace, 
NameSpace.KV_STORE_FULL_DOCS): + for k, v in data.items(): + upsert_sql = SQL_TEMPLATES["upsert_doc_full"] + _data = { + "id": k, + "content": v["content"], + "workspace": self.db.workspace, + } + await self.db.execute(upsert_sql, _data) + elif is_namespace(self.namespace, NameSpace.KV_STORE_LLM_RESPONSE_CACHE): + for mode, items in data.items(): + for k, v in items.items(): + upsert_sql = SQL_TEMPLATES["upsert_llm_response_cache"] + _data = { + "workspace": self.db.workspace, + "id": k, + "original_prompt": v["original_prompt"], + "return_value": v["return"], + "mode": mode, + "chunk_id": v.get("chunk_id"), + } + + await self.db.execute(upsert_sql, _data) + + async def index_done_callback(self) -> None: + # PG handles persistence automatically + pass + + async def delete(self, ids: list[str]) -> None: + """Delete specific records from storage by their IDs + + Args: + ids (list[str]): List of document IDs to be deleted from storage + + Returns: + None + """ + if not ids: + return + + table_name = namespace_to_table_name(self.namespace) + if not table_name: + logger.error(f"Unknown namespace for deletion: {self.namespace}") + return + + delete_sql = f"DELETE FROM {table_name} WHERE workspace=$1 AND id = ANY($2)" + + try: + await self.db.execute( + delete_sql, {"workspace": self.db.workspace, "ids": ids} + ) + logger.debug( + f"Successfully deleted {len(ids)} records from {self.namespace}" + ) + except Exception as e: + logger.error(f"Error while deleting records from {self.namespace}: {e}") + + async def drop_cache_by_modes(self, modes: list[str] | None = None) -> bool: + """Delete specific records from storage by cache mode + + Args: + modes (list[str]): List of cache modes to be dropped from storage + + Returns: + bool: True if successful, False otherwise + """ + if not modes: + return False + + try: + table_name = namespace_to_table_name(self.namespace) + if not table_name: + return False + + if table_name != "LIGHTRAG_LLM_CACHE": + return False + + sql = f""" + DELETE FROM {table_name} + WHERE workspace = $1 AND mode = ANY($2) + """ + params = {"workspace": self.db.workspace, "modes": modes} + + logger.info(f"Deleting cache by modes: {modes}") + await self.db.execute(sql, params) + return True + except Exception as e: + logger.error(f"Error deleting cache by modes {modes}: {e}") + return False + + async def drop(self) -> dict[str, str]: + """Drop the storage""" + try: + table_name = namespace_to_table_name(self.namespace) + if not table_name: + return { + "status": "error", + "message": f"Unknown namespace: {self.namespace}", + } + + drop_sql = SQL_TEMPLATES["drop_specifiy_table_workspace"].format( + table_name=table_name + ) + await self.db.execute(drop_sql, {"workspace": self.db.workspace}) + return {"status": "success", "message": "data dropped"} + except Exception as e: + return {"status": "error", "message": str(e)} + + +@final +@dataclass +class PGVectorStorage(BaseVectorStorage): + db: PostgreSQLDB | None = field(default=None) + + def __post_init__(self): + self._max_batch_size = self.global_config["embedding_batch_num"] + namespace_prefix = self.global_config.get("namespace_prefix") + self.base_namespace = self.namespace.replace(namespace_prefix, "") + config = self.global_config.get("vector_db_storage_cls_kwargs", {}) + cosine_threshold = config.get("cosine_better_than_threshold") + if cosine_threshold is None: + raise ValueError( + "cosine_better_than_threshold must be specified in vector_db_storage_cls_kwargs" + ) + self.cosine_better_than_threshold = cosine_threshold + + 
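For reference, a sketch of the nested payload that PGKVStorage.upsert() expects in the LLM-response-cache branch above (mode -> cache id -> record). The cache id and prompt text here are purely illustrative.

llm_cache_payload = {
    "global": {                                    # cache mode
        "cache-4f2a1c": {                          # illustrative cache id
            "original_prompt": "What is LightRAG?",
            "return": "LightRAG is a lightweight RAG framework ...",  # stored as return_value
            "chunk_id": None,                      # optional; column added by the migration above
        },
    },
}
# await llm_cache_storage.upsert(llm_cache_payload)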
async def initialize(self): + if self.db is None: + #self.db = await ClientManager.get_client() + self.db = await ClientManager.get_client(self.global_config) + + async def finalize(self): + if self.db is not None: + await ClientManager.release_client(self.db) + self.db = None + + def _upsert_chunks( + self, item: dict[str, Any], current_time: datetime.datetime + ) -> tuple[str, dict[str, Any]]: + try: + upsert_sql = SQL_TEMPLATES["upsert_chunk"] + data: dict[str, Any] = { + "workspace": self.db.workspace, + "id": item["__id__"], + "tokens": item["tokens"], + "chunk_order_index": item["chunk_order_index"], + "full_doc_id": item["full_doc_id"], + "content": item["content"], + "content_vector": json.dumps(item["__vector__"].tolist()), + "file_path": item["file_path"], + "create_time": current_time, + "update_time": current_time, + } + except Exception as e: + logger.error(f"Error to prepare upsert,\nsql: {e}\nitem: {item}") + raise + + return upsert_sql, data + + def _upsert_entities( + self, item: dict[str, Any], current_time: datetime.datetime + ) -> tuple[str, dict[str, Any]]: + upsert_sql = SQL_TEMPLATES["upsert_entity"] + source_id = item["source_id"] + if isinstance(source_id, str) and "" in source_id: + chunk_ids = source_id.split("") + else: + chunk_ids = [source_id] + + data: dict[str, Any] = { + "workspace": self.db.workspace, + "id": item["__id__"], + "entity_name": item["entity_name"], + "content": item["content"], + "content_vector": json.dumps(item["__vector__"].tolist()), + "chunk_ids": chunk_ids, + "file_path": item.get("file_path", None), + "create_time": current_time, + "update_time": current_time, + } + return upsert_sql, data + + def _upsert_relationships( + self, item: dict[str, Any], current_time: datetime.datetime + ) -> tuple[str, dict[str, Any]]: + upsert_sql = SQL_TEMPLATES["upsert_relationship"] + source_id = item["source_id"] + if isinstance(source_id, str) and "" in source_id: + chunk_ids = source_id.split("") + else: + chunk_ids = [source_id] + + data: dict[str, Any] = { + "workspace": self.db.workspace, + "id": item["__id__"], + "source_id": item["src_id"], + "target_id": item["tgt_id"], + "content": item["content"], + "content_vector": json.dumps(item["__vector__"].tolist()), + "chunk_ids": chunk_ids, + "file_path": item.get("file_path", None), + "create_time": current_time, + "update_time": current_time, + } + return upsert_sql, data + + async def upsert(self, data: dict[str, dict[str, Any]]) -> None: + logger.debug(f"Inserting {len(data)} to {self.namespace}") + if not data: + return + + # Get current time with UTC timezone + current_time = datetime.datetime.now(timezone.utc) + list_data = [ + { + "__id__": k, + **{k1: v1 for k1, v1 in v.items()}, + } + for k, v in data.items() + ] + contents = [v["content"] for v in data.values()] + batches = [ + contents[i : i + self._max_batch_size] + for i in range(0, len(contents), self._max_batch_size) + ] + + embedding_tasks = [self.embedding_func(batch) for batch in batches] + embeddings_list = await asyncio.gather(*embedding_tasks) + + embeddings = np.concatenate(embeddings_list) + for i, d in enumerate(list_data): + d["__vector__"] = embeddings[i] + for item in list_data: + if is_namespace(self.namespace, NameSpace.VECTOR_STORE_CHUNKS): + upsert_sql, data = self._upsert_chunks(item, current_time) + elif is_namespace(self.namespace, NameSpace.VECTOR_STORE_ENTITIES): + upsert_sql, data = self._upsert_entities(item, current_time) + elif is_namespace(self.namespace, NameSpace.VECTOR_STORE_RELATIONSHIPS): + upsert_sql, 
data = self._upsert_relationships(item, current_time) + else: + raise ValueError(f"{self.namespace} is not supported") + + await self.db.execute(upsert_sql, data) + + #################### query method ############### + async def query( + self, query: str, top_k: int, ids: list[str] | None = None + ) -> list[dict[str, Any]]: + embeddings = await self.embedding_func( + [query], _priority=5 + ) # higher priority for query + embedding = embeddings[0] + embedding_string = ",".join(map(str, embedding)) + # Use parameterized document IDs (None means search across all documents) + sql = SQL_TEMPLATES[self.namespace].format(embedding_string=embedding_string) + params = { + "workspace": self.db.workspace, + "doc_ids": ids, + "better_than_threshold": self.cosine_better_than_threshold, + "top_k": top_k, + } + results = await self.db.query(sql, params=params, multirows=True) + return results + + async def index_done_callback(self) -> None: + # PG handles persistence automatically + pass + + async def delete(self, ids: list[str]) -> None: + """Delete vectors with specified IDs from the storage. + + Args: + ids: List of vector IDs to be deleted + """ + if not ids: + return + + table_name = namespace_to_table_name(self.namespace) + if not table_name: + logger.error(f"Unknown namespace for vector deletion: {self.namespace}") + return + + delete_sql = f"DELETE FROM {table_name} WHERE workspace=$1 AND id = ANY($2)" + + try: + await self.db.execute( + delete_sql, {"workspace": self.db.workspace, "ids": ids} + ) + logger.debug( + f"Successfully deleted {len(ids)} vectors from {self.namespace}" + ) + except Exception as e: + logger.error(f"Error while deleting vectors from {self.namespace}: {e}") + + async def delete_entity(self, entity_name: str) -> None: + """Delete an entity by its name from the vector storage. + + Args: + entity_name: The name of the entity to delete + """ + try: + # Construct SQL to delete the entity + delete_sql = """DELETE FROM LIGHTRAG_VDB_ENTITY + WHERE workspace=$1 AND entity_name=$2""" + + await self.db.execute( + delete_sql, {"workspace": self.db.workspace, "entity_name": entity_name} + ) + logger.debug(f"Successfully deleted entity {entity_name}") + except Exception as e: + logger.error(f"Error deleting entity {entity_name}: {e}") + + async def delete_entity_relation(self, entity_name: str) -> None: + """Delete all relations associated with an entity. 
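A stripped-down illustration of the embedding batching used in upsert() above: contents are split into _max_batch_size batches, embedded concurrently, and concatenated back in the original order. This is a sketch of the pattern, not the storage code itself.

import asyncio
import numpy as np

async def embed_in_batches(contents, embedding_func, max_batch_size):
    batches = [
        contents[i : i + max_batch_size]
        for i in range(0, len(contents), max_batch_size)
    ]
    # One embedding call per batch, awaited concurrently, then reassembled in order
    embeddings_list = await asyncio.gather(*(embedding_func(b) for b in batches))
    return np.concatenate(embeddings_list)  # shape: (len(contents), embedding_dim)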
+ + Args: + entity_name: The name of the entity whose relations should be deleted + """ + try: + # Delete relations where the entity is either the source or target + delete_sql = """DELETE FROM LIGHTRAG_VDB_RELATION + WHERE workspace=$1 AND (source_id=$2 OR target_id=$2)""" + + await self.db.execute( + delete_sql, {"workspace": self.db.workspace, "entity_name": entity_name} + ) + logger.debug(f"Successfully deleted relations for entity {entity_name}") + except Exception as e: + logger.error(f"Error deleting relations for entity {entity_name}: {e}") + + async def get_by_id(self, id: str) -> dict[str, Any] | None: + """Get vector data by its ID + + Args: + id: The unique identifier of the vector + + Returns: + The vector data if found, or None if not found + """ + table_name = namespace_to_table_name(self.namespace) + if not table_name: + logger.error(f"Unknown namespace for ID lookup: {self.namespace}") + return None + + query = f"SELECT *, EXTRACT(EPOCH FROM create_time)::BIGINT as created_at FROM {table_name} WHERE workspace=$1 AND id=$2" + params = {"workspace": self.db.workspace, "id": id} + + try: + result = await self.db.query(query, params) + if result: + return dict(result) + return None + except Exception as e: + logger.error(f"Error retrieving vector data for ID {id}: {e}") + return None + + async def get_by_ids(self, ids: list[str]) -> list[dict[str, Any]]: + """Get multiple vector data by their IDs + + Args: + ids: List of unique identifiers + + Returns: + List of vector data objects that were found + """ + if not ids: + return [] + + table_name = namespace_to_table_name(self.namespace) + if not table_name: + logger.error(f"Unknown namespace for IDs lookup: {self.namespace}") + return [] + + ids_str = ",".join([f"'{id}'" for id in ids]) + query = f"SELECT *, EXTRACT(EPOCH FROM create_time)::BIGINT as created_at FROM {table_name} WHERE workspace=$1 AND id IN ({ids_str})" + params = {"workspace": self.db.workspace} + + try: + results = await self.db.query(query, params, multirows=True) + return [dict(record) for record in results] + except Exception as e: + logger.error(f"Error retrieving vector data for IDs {ids}: {e}") + return [] + + async def drop(self) -> dict[str, str]: + """Drop the storage""" + try: + table_name = namespace_to_table_name(self.namespace) + if not table_name: + return { + "status": "error", + "message": f"Unknown namespace: {self.namespace}", + } + + drop_sql = SQL_TEMPLATES["drop_specifiy_table_workspace"].format( + table_name=table_name + ) + await self.db.execute(drop_sql, {"workspace": self.db.workspace}) + return {"status": "success", "message": "data dropped"} + except Exception as e: + return {"status": "error", "message": str(e)} + + +@final +@dataclass +class PGDocStatusStorage(DocStatusStorage): + db: PostgreSQLDB = field(default=None) + + async def initialize(self): + if self.db is None: + #self.db = await ClientManager.get_client() + self.db = await ClientManager.get_client(self.global_config) + + async def finalize(self): + if self.db is not None: + await ClientManager.release_client(self.db) + self.db = None + + async def filter_keys(self, keys: set[str]) -> set[str]: + """Filter out duplicated content""" + sql = SQL_TEMPLATES["filter_keys"].format( + table_name=namespace_to_table_name(self.namespace), + ids=",".join([f"'{id}'" for id in keys]), + ) + params = {"workspace": self.db.workspace} + try: + res = await self.db.query(sql, params, multirows=True) + if res: + exist_keys = [key["id"] for key in res] + else: + exist_keys = [] + new_keys = 
set([s for s in keys if s not in exist_keys]) + print(f"keys: {keys}") + print(f"new_keys: {new_keys}") + return new_keys + except Exception as e: + logger.error( + f"PostgreSQL database,\nsql:{sql},\nparams:{params},\nerror:{e}" + ) + raise + + async def get_by_id(self, id: str) -> Union[dict[str, Any], None]: + sql = "select * from LIGHTRAG_DOC_STATUS where workspace=$1 and id=$2" + params = {"workspace": self.db.workspace, "id": id} + result = await self.db.query(sql, params, True) + if result is None or result == []: + return None + else: + return dict( + content=result[0]["content"], + content_length=result[0]["content_length"], + content_summary=result[0]["content_summary"], + status=result[0]["status"], + chunks_count=result[0]["chunks_count"], + created_at=result[0]["created_at"], + updated_at=result[0]["updated_at"], + file_path=result[0]["file_path"], + ) + + async def get_by_ids(self, ids: list[str]) -> list[dict[str, Any]]: + """Get doc_chunks data by multiple IDs.""" + if not ids: + return [] + + sql = "SELECT * FROM LIGHTRAG_DOC_STATUS WHERE workspace=$1 AND id = ANY($2)" + params = {"workspace": self.db.workspace, "ids": ids} + + results = await self.db.query(sql, params, True) + + if not results: + return [] + return [ + { + "content": row["content"], + "content_length": row["content_length"], + "content_summary": row["content_summary"], + "status": row["status"], + "chunks_count": row["chunks_count"], + "created_at": row["created_at"], + "updated_at": row["updated_at"], + "file_path": row["file_path"], + } + for row in results + ] + + async def get_status_counts(self) -> dict[str, int]: + """Get counts of documents in each status""" + sql = """SELECT status as "status", COUNT(1) as "count" + FROM LIGHTRAG_DOC_STATUS + where workspace=$1 GROUP BY STATUS + """ + result = await self.db.query(sql, {"workspace": self.db.workspace}, True) + counts = {} + for doc in result: + counts[doc["status"]] = doc["count"] + return counts + + async def get_docs_by_status( + self, status: DocStatus + ) -> dict[str, DocProcessingStatus]: + """all documents with a specific status""" + sql = "select * from LIGHTRAG_DOC_STATUS where workspace=$1 and status=$2" + params = {"workspace": self.db.workspace, "status": status.value} + result = await self.db.query(sql, params, True) + docs_by_status = { + element["id"]: DocProcessingStatus( + content=element["content"], + content_summary=element["content_summary"], + content_length=element["content_length"], + status=element["status"], + created_at=element["created_at"], + updated_at=element["updated_at"], + chunks_count=element["chunks_count"], + file_path=element["file_path"], + ) + for element in result + } + return docs_by_status + + async def index_done_callback(self) -> None: + # PG handles persistence automatically + pass + + async def delete(self, ids: list[str]) -> None: + """Delete specific records from storage by their IDs + + Args: + ids (list[str]): List of document IDs to be deleted from storage + + Returns: + None + """ + if not ids: + return + + table_name = namespace_to_table_name(self.namespace) + if not table_name: + logger.error(f"Unknown namespace for deletion: {self.namespace}") + return + + delete_sql = f"DELETE FROM {table_name} WHERE workspace=$1 AND id = ANY($2)" + + try: + await self.db.execute( + delete_sql, {"workspace": self.db.workspace, "ids": ids} + ) + logger.debug( + f"Successfully deleted {len(ids)} records from {self.namespace}" + ) + except Exception as e: + logger.error(f"Error while deleting records from 
{self.namespace}: {e}") + + async def upsert(self, data: dict[str, dict[str, Any]]) -> None: + """Update or insert document status + + Args: + data: dictionary of document IDs and their status data + """ + logger.debug(f"Inserting {len(data)} to {self.namespace}") + if not data: + return + + def parse_datetime(dt_str): + if dt_str is None: + return None + if isinstance(dt_str, (datetime.date, datetime.datetime)): + # If it's a datetime object without timezone info, remove timezone info + if isinstance(dt_str, datetime.datetime): + # Remove timezone info, return naive datetime object + return dt_str.replace(tzinfo=None) + return dt_str + try: + # Process ISO format string with timezone + dt = datetime.datetime.fromisoformat(dt_str) + # Remove timezone info, return naive datetime object + return dt.replace(tzinfo=None) + except (ValueError, TypeError): + logger.warning(f"Unable to parse datetime string: {dt_str}") + return None + + # Modified SQL to include created_at and updated_at in both INSERT and UPDATE operations + # Both fields are updated from the input data in both INSERT and UPDATE cases + sql = """insert into LIGHTRAG_DOC_STATUS(workspace,id,content,content_summary,content_length,chunks_count,status,file_path,created_at,updated_at) + values($1,$2,$3,$4,$5,$6,$7,$8,$9,$10) + on conflict(id,workspace) do update set + content = EXCLUDED.content, + content_summary = EXCLUDED.content_summary, + content_length = EXCLUDED.content_length, + chunks_count = EXCLUDED.chunks_count, + status = EXCLUDED.status, + file_path = EXCLUDED.file_path, + created_at = EXCLUDED.created_at, + updated_at = EXCLUDED.updated_at""" + for k, v in data.items(): + # Remove timezone information, store utc time in db + created_at = parse_datetime(v.get("created_at")) + updated_at = parse_datetime(v.get("updated_at")) + + # chunks_count is optional + await self.db.execute( + sql, + { + "workspace": self.db.workspace, + "id": k, + "content": v["content"], + "content_summary": v["content_summary"], + "content_length": v["content_length"], + "chunks_count": v["chunks_count"] if "chunks_count" in v else -1, + "status": v["status"], + "file_path": v["file_path"], + "created_at": created_at, # Use the converted datetime object + "updated_at": updated_at, # Use the converted datetime object + }, + ) + + async def drop(self) -> dict[str, str]: + """Drop the storage""" + try: + table_name = namespace_to_table_name(self.namespace) + if not table_name: + return { + "status": "error", + "message": f"Unknown namespace: {self.namespace}", + } + + drop_sql = SQL_TEMPLATES["drop_specifiy_table_workspace"].format( + table_name=table_name + ) + await self.db.execute(drop_sql, {"workspace": self.db.workspace}) + return {"status": "success", "message": "data dropped"} + except Exception as e: + return {"status": "error", "message": str(e)} + + +class PGGraphQueryException(Exception): + """Exception for the AGE queries.""" + + def __init__(self, exception: Union[str, dict[str, Any]]) -> None: + if isinstance(exception, dict): + self.message = exception["message"] if "message" in exception else "unknown" + self.details = exception["details"] if "details" in exception else "unknown" + else: + self.message = exception + self.details = "unknown" + + def get_message(self) -> str: + return self.message + + def get_details(self) -> Any: + return self.details + + +@final +@dataclass +class PGGraphStorage(BaseGraphStorage): + def __post_init__(self): + self.graph_name = self.namespace or os.environ.get("AGE_GRAPH_NAME", "lightrag") + self.db: 
PostgreSQLDB | None = None + + @staticmethod + def _normalize_node_id(node_id: str) -> str: + """ + Normalize node ID to ensure special characters are properly handled in Cypher queries. + + Args: + node_id: The original node ID + + Returns: + Normalized node ID suitable for Cypher queries + """ + # Escape backslashes + normalized_id = node_id + normalized_id = normalized_id.replace("\\", "\\\\") + normalized_id = normalized_id.replace('"', '\\"') + return normalized_id + + async def initialize(self): + if self.db is None: + self.db = await ClientManager.get_client() + + # Execute each statement separately and ignore errors + queries = [ + f"SELECT create_graph('{self.graph_name}')", + f"SELECT create_vlabel('{self.graph_name}', 'base');", + f"SELECT create_elabel('{self.graph_name}', 'DIRECTED');", + # f'CREATE INDEX CONCURRENTLY vertex_p_idx ON {self.graph_name}."_ag_label_vertex" (id)', + f'CREATE INDEX CONCURRENTLY vertex_idx_node_id ON {self.graph_name}."_ag_label_vertex" (ag_catalog.agtype_access_operator(properties, \'"entity_id"\'::agtype))', + # f'CREATE INDEX CONCURRENTLY edge_p_idx ON {self.graph_name}."_ag_label_edge" (id)', + f'CREATE INDEX CONCURRENTLY edge_sid_idx ON {self.graph_name}."_ag_label_edge" (start_id)', + f'CREATE INDEX CONCURRENTLY edge_eid_idx ON {self.graph_name}."_ag_label_edge" (end_id)', + f'CREATE INDEX CONCURRENTLY edge_seid_idx ON {self.graph_name}."_ag_label_edge" (start_id,end_id)', + f'CREATE INDEX CONCURRENTLY directed_p_idx ON {self.graph_name}."DIRECTED" (id)', + f'CREATE INDEX CONCURRENTLY directed_eid_idx ON {self.graph_name}."DIRECTED" (end_id)', + f'CREATE INDEX CONCURRENTLY directed_sid_idx ON {self.graph_name}."DIRECTED" (start_id)', + f'CREATE INDEX CONCURRENTLY directed_seid_idx ON {self.graph_name}."DIRECTED" (start_id,end_id)', + f'CREATE INDEX CONCURRENTLY entity_p_idx ON {self.graph_name}."base" (id)', + f'CREATE INDEX CONCURRENTLY entity_idx_node_id ON {self.graph_name}."base" (ag_catalog.agtype_access_operator(properties, \'"entity_id"\'::agtype))', + f'CREATE INDEX CONCURRENTLY entity_node_id_gin_idx ON {self.graph_name}."base" using gin(properties)', + f'ALTER TABLE {self.graph_name}."DIRECTED" CLUSTER ON directed_sid_idx', + ] + + for query in queries: + # Use the new flag to silently ignore "already exists" errors + # at the source, preventing log spam. 
+ await self.db.execute( + query, + upsert=True, + ignore_if_exists=True, # Pass the new flag + with_age=True, + graph_name=self.graph_name, + ) + + async def finalize(self): + if self.db is not None: + await ClientManager.release_client(self.db) + self.db = None + + async def index_done_callback(self) -> None: + # PG handles persistence automatically + pass + + @staticmethod + def _record_to_dict(record: asyncpg.Record) -> dict[str, Any]: + """ + Convert a record returned from an age query to a dictionary + + Args: + record (): a record from an age query result + + Returns: + dict[str, Any]: a dictionary representation of the record where + the dictionary key is the field name and the value is the + value converted to a python type + """ + # result holder + d = {} + + # prebuild a mapping of vertex_id to vertex mappings to be used + # later to build edges + vertices = {} + for k in record.keys(): + v = record[k] + # agtype comes back '{key: value}::type' which must be parsed + if isinstance(v, str) and "::" in v: + if v.startswith("[") and v.endswith("]"): + if "::vertex" not in v: + continue + v = v.replace("::vertex", "") + vertexes = json.loads(v) + for vertex in vertexes: + vertices[vertex["id"]] = vertex.get("properties") + else: + dtype = v.split("::")[-1] + v = v.split("::")[0] + if dtype == "vertex": + vertex = json.loads(v) + vertices[vertex["id"]] = vertex.get("properties") + + # iterate returned fields and parse appropriately + for k in record.keys(): + v = record[k] + if isinstance(v, str) and "::" in v: + if v.startswith("[") and v.endswith("]"): + if "::vertex" in v: + v = v.replace("::vertex", "") + d[k] = json.loads(v) + + elif "::edge" in v: + v = v.replace("::edge", "") + d[k] = json.loads(v) + else: + print("WARNING: unsupported type") + continue + + else: + dtype = v.split("::")[-1] + v = v.split("::")[0] + if dtype == "vertex": + d[k] = json.loads(v) + elif dtype == "edge": + d[k] = json.loads(v) + else: + d[k] = v # Keep as string + + return d + + @staticmethod + def _format_properties( + properties: dict[str, Any], _id: Union[str, None] = None + ) -> str: + """ + Convert a dictionary of properties to a string representation that + can be used in a cypher query insert/merge statement. 
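A worked example of the property formatting described here: keys are wrapped in backticks and values JSON-encoded, producing a Cypher map literal (the property values are illustrative).

import json

properties = {"entity_id": "Alan Turing", "entity_type": "person", "rank": 3}
props = ", ".join(f"`{k}`: {json.dumps(v)}" for k, v in properties.items())
cypher_map = "{" + props + "}"
# cypher_map == '{`entity_id`: "Alan Turing", `entity_type`: "person", `rank`: 3}'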
+ + Args: + properties (dict[str,str]): a dictionary containing node/edge properties + _id (Union[str, None]): the id of the node or None if none exists + + Returns: + str: the properties dictionary as a properly formatted string + """ + props = [] + # wrap property key in backticks to escape + for k, v in properties.items(): + prop = f"`{k}`: {json.dumps(v)}" + props.append(prop) + if _id is not None and "id" not in properties: + props.append( + f"id: {json.dumps(_id)}" if isinstance(_id, str) else f"id: {_id}" + ) + return "{" + ", ".join(props) + "}" + + async def _query( + self, + query: str, + readonly: bool = True, + upsert: bool = False, + ) -> list[dict[str, Any]]: + """ + Query the graph by taking a cypher query, converting it to an + age compatible query, executing it and converting the result + + Args: + query (str): a cypher query to be executed + + Returns: + list[dict[str, Any]]: a list of dictionaries containing the result set + """ + try: + if readonly: + data = await self.db.query( + query, + multirows=True, + with_age=True, + graph_name=self.graph_name, + ) + else: + data = await self.db.execute( + query, + upsert=upsert, + with_age=True, + graph_name=self.graph_name, + ) + + except Exception as e: + raise PGGraphQueryException( + { + "message": f"Error executing graph query: {query}", + "wrapped": query, + "detail": str(e), + } + ) from e + + if data is None: + result = [] + # decode records + else: + result = [self._record_to_dict(d) for d in data] + + return result + + async def has_node(self, node_id: str) -> bool: + entity_name_label = self._normalize_node_id(node_id) + + query = """SELECT * FROM cypher('%s', $$ + MATCH (n:base {entity_id: "%s"}) + RETURN count(n) > 0 AS node_exists + $$) AS (node_exists bool)""" % (self.graph_name, entity_name_label) + + single_result = (await self._query(query))[0] + + return single_result["node_exists"] + + async def has_edge(self, source_node_id: str, target_node_id: str) -> bool: + src_label = self._normalize_node_id(source_node_id) + tgt_label = self._normalize_node_id(target_node_id) + + query = """SELECT * FROM cypher('%s', $$ + MATCH (a:base {entity_id: "%s"})-[r]-(b:base {entity_id: "%s"}) + RETURN COUNT(r) > 0 AS edge_exists + $$) AS (edge_exists bool)""" % ( + self.graph_name, + src_label, + tgt_label, + ) + + single_result = (await self._query(query))[0] + + return single_result["edge_exists"] + + async def get_node(self, node_id: str) -> dict[str, str] | None: + """Get node by its label identifier, return only node properties""" + + label = self._normalize_node_id(node_id) + query = """SELECT * FROM cypher('%s', $$ + MATCH (n:base {entity_id: "%s"}) + RETURN n + $$) AS (n agtype)""" % (self.graph_name, label) + record = await self._query(query) + if record: + node = record[0] + node_dict = node["n"]["properties"] + + # Process string result, parse it to JSON dictionary + if isinstance(node_dict, str): + try: + import json + + node_dict = json.loads(node_dict) + except json.JSONDecodeError: + logger.warning(f"Failed to parse node string: {node_dict}") + + return node_dict + return None + + async def node_degree(self, node_id: str) -> int: + label = self._normalize_node_id(node_id) + + query = """SELECT * FROM cypher('%s', $$ + MATCH (n:base {entity_id: "%s"})-[r]-() + RETURN count(r) AS total_edge_count + $$) AS (total_edge_count integer)""" % (self.graph_name, label) + record = (await self._query(query))[0] + if record: + edge_count = int(record["total_edge_count"]) + return edge_count + + async def edge_degree(self, 
src_id: str, tgt_id: str) -> int: + src_degree = await self.node_degree(src_id) + trg_degree = await self.node_degree(tgt_id) + + # Convert None to 0 for addition + src_degree = 0 if src_degree is None else src_degree + trg_degree = 0 if trg_degree is None else trg_degree + + degrees = int(src_degree) + int(trg_degree) + + return degrees + + async def get_edge( + self, source_node_id: str, target_node_id: str + ) -> dict[str, str] | None: + """Get edge properties between two nodes""" + + src_label = self._normalize_node_id(source_node_id) + tgt_label = self._normalize_node_id(target_node_id) + + query = """SELECT * FROM cypher('%s', $$ + MATCH (a:base {entity_id: "%s"})-[r]-(b:base {entity_id: "%s"}) + RETURN properties(r) as edge_properties + LIMIT 1 + $$) AS (edge_properties agtype)""" % ( + self.graph_name, + src_label, + tgt_label, + ) + record = await self._query(query) + if record and record[0] and record[0]["edge_properties"]: + result = record[0]["edge_properties"] + + # Process string result, parse it to JSON dictionary + if isinstance(result, str): + try: + import json + + result = json.loads(result) + except json.JSONDecodeError: + logger.warning(f"Failed to parse edge string: {result}") + + return result + + async def get_node_edges(self, source_node_id: str) -> list[tuple[str, str]] | None: + """ + Retrieves all edges (relationships) for a particular node identified by its label. + :return: list of dictionaries containing edge information + """ + label = self._normalize_node_id(source_node_id) + + query = """SELECT * FROM cypher('%s', $$ + MATCH (n:base {entity_id: "%s"}) + OPTIONAL MATCH (n)-[]-(connected:base) + RETURN n.entity_id AS source_id, connected.entity_id AS connected_id + $$) AS (source_id text, connected_id text)""" % ( + self.graph_name, + label, + ) + + results = await self._query(query) + edges = [] + for record in results: + source_id = record["source_id"] + connected_id = record["connected_id"] + + if source_id and connected_id: + edges.append((source_id, connected_id)) + + return edges + + @retry( + stop=stop_after_attempt(3), + wait=wait_exponential(multiplier=1, min=4, max=10), + retry=retry_if_exception_type((PGGraphQueryException,)), + ) + async def upsert_node(self, node_id: str, node_data: dict[str, str]) -> None: + """ + Upsert a node in the Neo4j database. + + Args: + node_id: The unique identifier for the node (used as label) + node_data: Dictionary of node properties + """ + if "entity_id" not in node_data: + raise ValueError( + "PostgreSQL: node properties must contain an 'entity_id' field" + ) + + label = self._normalize_node_id(node_id) + properties = self._format_properties(node_data) + + query = """SELECT * FROM cypher('%s', $$ + MERGE (n:base {entity_id: "%s"}) + SET n += %s + RETURN n + $$) AS (n agtype)""" % ( + self.graph_name, + label, + properties, + ) + + try: + await self._query(query, readonly=False, upsert=True) + + except Exception: + logger.error(f"POSTGRES, upsert_node error on node_id: `{node_id}`") + raise + + @retry( + stop=stop_after_attempt(3), + wait=wait_exponential(multiplier=1, min=4, max=10), + retry=retry_if_exception_type((PGGraphQueryException,)), + ) + async def upsert_edge( + self, source_node_id: str, target_node_id: str, edge_data: dict[str, str] + ) -> None: + """ + Upsert an edge and its properties between two nodes identified by their labels. 
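For clarity, the Cypher that upsert_node() above renders for a sample entity, assuming an AGE graph named "chunk_entity_relation" (the real graph name comes from the namespace at runtime) and the property formatting shown earlier.

graph_name = "chunk_entity_relation"   # illustrative; derived from the namespace at runtime
label = "Alan Turing"
properties = '{`entity_id`: "Alan Turing", `entity_type`: "person"}'

query = """SELECT * FROM cypher('%s', $$
             MERGE (n:base {entity_id: "%s"})
             SET n += %s
             RETURN n
           $$) AS (n agtype)""" % (graph_name, label, properties)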
+ + Args: + source_node_id (str): Label of the source node (used as identifier) + target_node_id (str): Label of the target node (used as identifier) + edge_data (dict): dictionary of properties to set on the edge + """ + src_label = self._normalize_node_id(source_node_id) + tgt_label = self._normalize_node_id(target_node_id) + edge_properties = self._format_properties(edge_data) + + query = """SELECT * FROM cypher('%s', $$ + MATCH (source:base {entity_id: "%s"}) + WITH source + MATCH (target:base {entity_id: "%s"}) + MERGE (source)-[r:DIRECTED]-(target) + SET r += %s + SET r += %s + RETURN r + $$) AS (r agtype)""" % ( + self.graph_name, + src_label, + tgt_label, + edge_properties, + edge_properties, # https://github.com/HKUDS/LightRAG/issues/1438#issuecomment-2826000195 + ) + + try: + await self._query(query, readonly=False, upsert=True) + + except Exception: + logger.error( + f"POSTGRES, upsert_edge error on edge: `{source_node_id}`-`{target_node_id}`" + ) + raise + + async def delete_node(self, node_id: str) -> None: + """ + Delete a node from the graph. + + Args: + node_id (str): The ID of the node to delete. + """ + label = self._normalize_node_id(node_id) + + query = """SELECT * FROM cypher('%s', $$ + MATCH (n:base {entity_id: "%s"}) + DETACH DELETE n + $$) AS (n agtype)""" % (self.graph_name, label) + + try: + await self._query(query, readonly=False) + except Exception as e: + logger.error("Error during node deletion: {%s}", e) + raise + + async def remove_nodes(self, node_ids: list[str]) -> None: + """ + Remove multiple nodes from the graph. + + Args: + node_ids (list[str]): A list of node IDs to remove. + """ + node_ids = [self._normalize_node_id(node_id) for node_id in node_ids] + node_id_list = ", ".join([f'"{node_id}"' for node_id in node_ids]) + + query = """SELECT * FROM cypher('%s', $$ + MATCH (n:base) + WHERE n.entity_id IN [%s] + DETACH DELETE n + $$) AS (n agtype)""" % (self.graph_name, node_id_list) + + try: + await self._query(query, readonly=False) + except Exception as e: + logger.error("Error during node removal: {%s}", e) + raise + + async def remove_edges(self, edges: list[tuple[str, str]]) -> None: + """ + Remove multiple edges from the graph. + + Args: + edges (list[tuple[str, str]]): A list of edges to remove, where each edge is a tuple of (source_node_id, target_node_id). + """ + for source, target in edges: + src_label = self._normalize_node_id(source) + tgt_label = self._normalize_node_id(target) + + query = """SELECT * FROM cypher('%s', $$ + MATCH (a:base {entity_id: "%s"})-[r]-(b:base {entity_id: "%s"}) + DELETE r + $$) AS (r agtype)""" % (self.graph_name, src_label, tgt_label) + + try: + await self._query(query, readonly=False) + logger.debug(f"Deleted edge from '{source}' to '{target}'") + except Exception as e: + logger.error(f"Error during edge deletion: {str(e)}") + raise + + async def get_nodes_batch(self, node_ids: list[str]) -> dict[str, dict]: + """ + Retrieve multiple nodes in one query using UNWIND. + + Args: + node_ids: List of node entity IDs to fetch. + + Returns: + A dictionary mapping each node_id to its node data (or None if not found). 
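A hedged usage sketch for the batch node lookup: in the implementation above, IDs that are not matched in the graph are simply left out of the returned mapping, so callers can detect missing entities by set difference (names here are illustrative).

async def fetch_entities(graph_storage, entity_ids: list[str]) -> dict[str, dict]:
    nodes = await graph_storage.get_nodes_batch(entity_ids)
    missing = [eid for eid in entity_ids if eid not in nodes]
    if missing:
        logger.debug(f"Entities not found in graph: {missing}")
    return nodes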
+ """ + if not node_ids: + return {} + + # Format node IDs for the query + formatted_ids = ", ".join( + ['"' + self._normalize_node_id(node_id) + '"' for node_id in node_ids] + ) + + query = """SELECT * FROM cypher('%s', $$ + UNWIND [%s] AS node_id + MATCH (n:base {entity_id: node_id}) + RETURN node_id, n + $$) AS (node_id text, n agtype)""" % (self.graph_name, formatted_ids) + + results = await self._query(query) + + # Build result dictionary + nodes_dict = {} + for result in results: + if result["node_id"] and result["n"]: + node_dict = result["n"]["properties"] + + # Process string result, parse it to JSON dictionary + if isinstance(node_dict, str): + try: + import json + + node_dict = json.loads(node_dict) + except json.JSONDecodeError: + logger.warning( + f"Failed to parse node string in batch: {node_dict}" + ) + + # Remove the 'base' label if present in a 'labels' property + if "labels" in node_dict: + node_dict["labels"] = [ + label for label in node_dict["labels"] if label != "base" + ] + nodes_dict[result["node_id"]] = node_dict + + return nodes_dict + + async def node_degrees_batch(self, node_ids: list[str]) -> dict[str, int]: + """ + Retrieve the degree for multiple nodes in a single query using UNWIND. + Calculates the total degree by counting distinct relationships. + Uses separate queries for outgoing and incoming edges. + + Args: + node_ids: List of node labels (entity_id values) to look up. + + Returns: + A dictionary mapping each node_id to its degree (total number of relationships). + If a node is not found, its degree will be set to 0. + """ + if not node_ids: + return {} + + # Format node IDs for the query + formatted_ids = ", ".join( + ['"' + self._normalize_node_id(node_id) + '"' for node_id in node_ids] + ) + + outgoing_query = """SELECT * FROM cypher('%s', $$ + UNWIND [%s] AS node_id + MATCH (n:base {entity_id: node_id}) + OPTIONAL MATCH (n)-[r]->(a) + RETURN node_id, count(a) AS out_degree + $$) AS (node_id text, out_degree bigint)""" % ( + self.graph_name, + formatted_ids, + ) + + incoming_query = """SELECT * FROM cypher('%s', $$ + UNWIND [%s] AS node_id + MATCH (n:base {entity_id: node_id}) + OPTIONAL MATCH (n)<-[r]-(b) + RETURN node_id, count(b) AS in_degree + $$) AS (node_id text, in_degree bigint)""" % ( + self.graph_name, + formatted_ids, + ) + + outgoing_results = await self._query(outgoing_query) + incoming_results = await self._query(incoming_query) + + out_degrees = {} + in_degrees = {} + + for result in outgoing_results: + if result["node_id"] is not None: + out_degrees[result["node_id"]] = int(result["out_degree"]) + + for result in incoming_results: + if result["node_id"] is not None: + in_degrees[result["node_id"]] = int(result["in_degree"]) + + degrees_dict = {} + for node_id in node_ids: + out_degree = out_degrees.get(node_id, 0) + in_degree = in_degrees.get(node_id, 0) + degrees_dict[node_id] = out_degree + in_degree + + return degrees_dict + + async def edge_degrees_batch( + self, edges: list[tuple[str, str]] + ) -> dict[tuple[str, str], int]: + """ + Calculate the combined degree for each edge (sum of the source and target node degrees) + in batch using the already implemented node_degrees_batch. 
+ + Args: + edges: List of (source_node_id, target_node_id) tuples + + Returns: + Dictionary mapping edge tuples to their combined degrees + """ + if not edges: + return {} + + # Use node_degrees_batch to get all node degrees efficiently + all_nodes = set() + for src, tgt in edges: + all_nodes.add(src) + all_nodes.add(tgt) + + node_degrees = await self.node_degrees_batch(list(all_nodes)) + + # Calculate edge degrees + edge_degrees_dict = {} + for src, tgt in edges: + src_degree = node_degrees.get(src, 0) + tgt_degree = node_degrees.get(tgt, 0) + edge_degrees_dict[(src, tgt)] = src_degree + tgt_degree + + return edge_degrees_dict + + async def get_edges_batch( + self, pairs: list[dict[str, str]] + ) -> dict[tuple[str, str], dict]: + """ + Retrieve edge properties for multiple (src, tgt) pairs in one query. + Get forward and backward edges seperately and merge them before return + + Args: + pairs: List of dictionaries, e.g. [{"src": "node1", "tgt": "node2"}, ...] + + Returns: + A dictionary mapping (src, tgt) tuples to their edge properties. + """ + if not pairs: + return {} + + src_nodes = [] + tgt_nodes = [] + for pair in pairs: + src_nodes.append(self._normalize_node_id(pair["src"])) + tgt_nodes.append(self._normalize_node_id(pair["tgt"])) + + src_array = ", ".join([f'"{src}"' for src in src_nodes]) + tgt_array = ", ".join([f'"{tgt}"' for tgt in tgt_nodes]) + + forward_query = f"""SELECT * FROM cypher('{self.graph_name}', $$ + WITH [{src_array}] AS sources, [{tgt_array}] AS targets + UNWIND range(0, size(sources)-1) AS i + MATCH (a:base {{entity_id: sources[i]}})-[r:DIRECTED]->(b:base {{entity_id: targets[i]}}) + RETURN sources[i] AS source, targets[i] AS target, properties(r) AS edge_properties + $$) AS (source text, target text, edge_properties agtype)""" + + backward_query = f"""SELECT * FROM cypher('{self.graph_name}', $$ + WITH [{src_array}] AS sources, [{tgt_array}] AS targets + UNWIND range(0, size(sources)-1) AS i + MATCH (a:base {{entity_id: sources[i]}})<-[r:DIRECTED]-(b:base {{entity_id: targets[i]}}) + RETURN sources[i] AS source, targets[i] AS target, properties(r) AS edge_properties + $$) AS (source text, target text, edge_properties agtype)""" + + forward_results = await self._query(forward_query) + backward_results = await self._query(backward_query) + + edges_dict = {} + + for result in forward_results: + if result["source"] and result["target"] and result["edge_properties"]: + edge_props = result["edge_properties"] + + # Process string result, parse it to JSON dictionary + if isinstance(edge_props, str): + try: + import json + + edge_props = json.loads(edge_props) + except json.JSONDecodeError: + logger.warning( + f"Failed to parse edge properties string: {edge_props}" + ) + continue + + edges_dict[(result["source"], result["target"])] = edge_props + + for result in backward_results: + if result["source"] and result["target"] and result["edge_properties"]: + edge_props = result["edge_properties"] + + # Process string result, parse it to JSON dictionary + if isinstance(edge_props, str): + try: + import json + + edge_props = json.loads(edge_props) + except json.JSONDecodeError: + logger.warning( + f"Failed to parse edge properties string: {edge_props}" + ) + continue + + edges_dict[(result["source"], result["target"])] = edge_props + + return edges_dict + + async def get_nodes_edges_batch( + self, node_ids: list[str] + ) -> dict[str, list[tuple[str, str]]]: + """ + Get all edges (both outgoing and incoming) for multiple nodes in a single batch operation. 
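Input/output sketch for get_edges_batch() above: pairs are dicts with "src"/"tgt" keys, forward and backward DIRECTED matches are merged, and the result is keyed by (src, tgt) tuples. Entity names are illustrative.

pairs = [
    {"src": "Alan Turing", "tgt": "Enigma"},
    {"src": "Alan Turing", "tgt": "Bletchley Park"},
]
# edge_map = await graph_storage.get_edges_batch(pairs)
# edge_map.get(("Alan Turing", "Enigma")) -> properties dict of the matched edge, if any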
+ + Args: + node_ids: List of node IDs to get edges for + + Returns: + Dictionary mapping node IDs to lists of (source, target) edge tuples + """ + if not node_ids: + return {} + + # Format node IDs for the query + formatted_ids = ", ".join( + ['"' + self._normalize_node_id(node_id) + '"' for node_id in node_ids] + ) + + outgoing_query = """SELECT * FROM cypher('%s', $$ + UNWIND [%s] AS node_id + MATCH (n:base {entity_id: node_id}) + OPTIONAL MATCH (n:base)-[]->(connected:base) + RETURN node_id, connected.entity_id AS connected_id + $$) AS (node_id text, connected_id text)""" % ( + self.graph_name, + formatted_ids, + ) + + incoming_query = """SELECT * FROM cypher('%s', $$ + UNWIND [%s] AS node_id + MATCH (n:base {entity_id: node_id}) + OPTIONAL MATCH (n:base)<-[]-(connected:base) + RETURN node_id, connected.entity_id AS connected_id + $$) AS (node_id text, connected_id text)""" % ( + self.graph_name, + formatted_ids, + ) + + outgoing_results = await self._query(outgoing_query) + incoming_results = await self._query(incoming_query) + + nodes_edges_dict = {node_id: [] for node_id in node_ids} + + for result in outgoing_results: + if result["node_id"] and result["connected_id"]: + nodes_edges_dict[result["node_id"]].append( + (result["node_id"], result["connected_id"]) + ) + + for result in incoming_results: + if result["node_id"] and result["connected_id"]: + nodes_edges_dict[result["node_id"]].append( + (result["connected_id"], result["node_id"]) + ) + + return nodes_edges_dict + + async def get_all_labels(self) -> list[str]: + """ + Get all labels (node IDs) in the graph. + + Returns: + list[str]: A list of all labels in the graph. + """ + query = ( + """SELECT * FROM cypher('%s', $$ + MATCH (n:base) + WHERE n.entity_id IS NOT NULL + RETURN DISTINCT n.entity_id AS label + ORDER BY n.entity_id + $$) AS (label text)""" + % self.graph_name + ) + + results = await self._query(query) + labels = [] + for result in results: + if result and isinstance(result, dict) and "label" in result: + labels.append(result["label"]) + return labels + + async def _bfs_subgraph( + self, node_label: str, max_depth: int, max_nodes: int + ) -> KnowledgeGraph: + """ + Implements a true breadth-first search algorithm for subgraph retrieval. + This method is used as a fallback when the standard Cypher query is too slow + or when we need to guarantee BFS ordering. 
+ + Args: + node_label: Label of the starting node + max_depth: Maximum depth of the subgraph + max_nodes: Maximum number of nodes to return + + Returns: + KnowledgeGraph object containing nodes and edges + """ + from collections import deque + + result = KnowledgeGraph() + visited_nodes = set() + visited_node_ids = set() + visited_edges = set() + visited_edge_pairs = set() + + # Get starting node data + label = self._normalize_node_id(node_label) + query = """SELECT * FROM cypher('%s', $$ + MATCH (n:base {entity_id: "%s"}) + RETURN id(n) as node_id, n + $$) AS (node_id bigint, n agtype)""" % (self.graph_name, label) + + node_result = await self._query(query) + if not node_result or not node_result[0].get("n"): + return result + + # Create initial KnowledgeGraphNode + start_node_data = node_result[0]["n"] + entity_id = start_node_data["properties"]["entity_id"] + internal_id = str(start_node_data["id"]) + + start_node = KnowledgeGraphNode( + id=internal_id, + labels=[entity_id], + properties=start_node_data["properties"], + ) + + # Initialize BFS queue, each element is a tuple of (node, depth) + queue = deque([(start_node, 0)]) + + visited_nodes.add(entity_id) + visited_node_ids.add(internal_id) + result.nodes.append(start_node) + + result.is_truncated = False + + # BFS search main loop + while queue: + # Get all nodes at the current depth + current_level_nodes = [] + current_depth = None + + # Determine current depth + if queue: + current_depth = queue[0][1] + + # Extract all nodes at current depth from the queue + while queue and queue[0][1] == current_depth: + node, depth = queue.popleft() + if depth > max_depth: + continue + current_level_nodes.append(node) + + if not current_level_nodes: + continue + + # Check depth limit + if current_depth > max_depth: + continue + + # Prepare node IDs list + node_ids = [node.labels[0] for node in current_level_nodes] + formatted_ids = ", ".join( + [f'"{self._normalize_node_id(node_id)}"' for node_id in node_ids] + ) + + # Construct batch query for outgoing edges + outgoing_query = f"""SELECT * FROM cypher('{self.graph_name}', $$ + UNWIND [{formatted_ids}] AS node_id + MATCH (n:base {{entity_id: node_id}}) + OPTIONAL MATCH (n)-[r]->(neighbor:base) + RETURN node_id AS current_id, + id(n) AS current_internal_id, + id(neighbor) AS neighbor_internal_id, + neighbor.entity_id AS neighbor_id, + id(r) AS edge_id, + r, + neighbor, + true AS is_outgoing + $$) AS (current_id text, current_internal_id bigint, neighbor_internal_id bigint, + neighbor_id text, edge_id bigint, r agtype, neighbor agtype, is_outgoing bool)""" + + # Construct batch query for incoming edges + incoming_query = f"""SELECT * FROM cypher('{self.graph_name}', $$ + UNWIND [{formatted_ids}] AS node_id + MATCH (n:base {{entity_id: node_id}}) + OPTIONAL MATCH (n)<-[r]-(neighbor:base) + RETURN node_id AS current_id, + id(n) AS current_internal_id, + id(neighbor) AS neighbor_internal_id, + neighbor.entity_id AS neighbor_id, + id(r) AS edge_id, + r, + neighbor, + false AS is_outgoing + $$) AS (current_id text, current_internal_id bigint, neighbor_internal_id bigint, + neighbor_id text, edge_id bigint, r agtype, neighbor agtype, is_outgoing bool)""" + + # Execute queries + outgoing_results = await self._query(outgoing_query) + incoming_results = await self._query(incoming_query) + + # Combine results + neighbors = outgoing_results + incoming_results + + # Create mapping from node ID to node object + node_map = {node.labels[0]: node for node in current_level_nodes} + + # Process all results in a 
single loop + for record in neighbors: + if not record.get("neighbor") or not record.get("r"): + continue + + # Get current node information + current_entity_id = record["current_id"] + current_node = node_map[current_entity_id] + + # Get neighbor node information + neighbor_entity_id = record["neighbor_id"] + neighbor_internal_id = str(record["neighbor_internal_id"]) + is_outgoing = record["is_outgoing"] + + # Determine edge direction + if is_outgoing: + source_id = current_node.id + target_id = neighbor_internal_id + else: + source_id = neighbor_internal_id + target_id = current_node.id + + if not neighbor_entity_id: + continue + + # Get edge and node information + b_node = record["neighbor"] + rel = record["r"] + edge_id = str(record["edge_id"]) + + # Create neighbor node object + neighbor_node = KnowledgeGraphNode( + id=neighbor_internal_id, + labels=[neighbor_entity_id], + properties=b_node["properties"], + ) + + # Sort entity_ids to ensure (A,B) and (B,A) are treated as the same edge + sorted_pair = tuple(sorted([current_entity_id, neighbor_entity_id])) + + # Create edge object + edge = KnowledgeGraphEdge( + id=edge_id, + type=rel["label"], + source=source_id, + target=target_id, + properties=rel["properties"], + ) + + if neighbor_internal_id in visited_node_ids: + # Add backward edge if neighbor node is already visited + if ( + edge_id not in visited_edges + and sorted_pair not in visited_edge_pairs + ): + result.edges.append(edge) + visited_edges.add(edge_id) + visited_edge_pairs.add(sorted_pair) + else: + if len(visited_node_ids) < max_nodes and current_depth < max_depth: + # Add new node to result and queue + result.nodes.append(neighbor_node) + visited_nodes.add(neighbor_entity_id) + visited_node_ids.add(neighbor_internal_id) + + # Add node to queue with incremented depth + queue.append((neighbor_node, current_depth + 1)) + + # Add forward edge + if ( + edge_id not in visited_edges + and sorted_pair not in visited_edge_pairs + ): + result.edges.append(edge) + visited_edges.add(edge_id) + visited_edge_pairs.add(sorted_pair) + else: + if current_depth < max_depth: + result.is_truncated = True + + return result + + async def get_knowledge_graph( + self, + node_label: str, + max_depth: int = 3, + max_nodes: int = MAX_GRAPH_NODES, + ) -> KnowledgeGraph: + """ + Retrieve a connected subgraph of nodes where the label includes the specified `node_label`. 
+ + Args: + node_label: Label of the starting node, * means all nodes + max_depth: Maximum depth of the subgraph, Defaults to 3 + max_nodes: Maxiumu nodes to return, Defaults to 1000 + + Returns: + KnowledgeGraph object containing nodes and edges, with an is_truncated flag + indicating whether the graph was truncated due to max_nodes limit + """ + kg = KnowledgeGraph() + + # Handle wildcard query - get all nodes + if node_label == "*": + # First check total node count to determine if graph should be truncated + count_query = f"""SELECT * FROM cypher('{self.graph_name}', $$ + MATCH (n:base) + RETURN count(distinct n) AS total_nodes + $$) AS (total_nodes bigint)""" + + count_result = await self._query(count_query) + total_nodes = count_result[0]["total_nodes"] if count_result else 0 + is_truncated = total_nodes > max_nodes + + # Get max_nodes with highest degrees + query_nodes = f"""SELECT * FROM cypher('{self.graph_name}', $$ + MATCH (n:base) + OPTIONAL MATCH (n)-[r]->() + RETURN id(n) as node_id, count(r) as degree + $$) AS (node_id BIGINT, degree BIGINT) + ORDER BY degree DESC + LIMIT {max_nodes}""" + node_results = await self._query(query_nodes) + + node_ids = [str(result["node_id"]) for result in node_results] + + logger.info(f"Total nodes: {total_nodes}, Selected nodes: {len(node_ids)}") + + if node_ids: + formatted_ids = ", ".join(node_ids) + # Construct batch query for subgraph within max_nodes + query = f"""SELECT * FROM cypher('{self.graph_name}', $$ + WITH [{formatted_ids}] AS node_ids + MATCH (a) + WHERE id(a) IN node_ids + OPTIONAL MATCH (a)-[r]->(b) + WHERE id(b) IN node_ids + RETURN a, r, b + $$) AS (a AGTYPE, r AGTYPE, b AGTYPE)""" + results = await self._query(query) + + # Process query results, deduplicate nodes and edges + nodes_dict = {} + edges_dict = {} + for result in results: + # Process node a + if result.get("a") and isinstance(result["a"], dict): + node_a = result["a"] + node_id = str(node_a["id"]) + if node_id not in nodes_dict and "properties" in node_a: + nodes_dict[node_id] = KnowledgeGraphNode( + id=node_id, + labels=[node_a["properties"]["entity_id"]], + properties=node_a["properties"], + ) + + # Process node b + if result.get("b") and isinstance(result["b"], dict): + node_b = result["b"] + node_id = str(node_b["id"]) + if node_id not in nodes_dict and "properties" in node_b: + nodes_dict[node_id] = KnowledgeGraphNode( + id=node_id, + labels=[node_b["properties"]["entity_id"]], + properties=node_b["properties"], + ) + + # Process edge r + if result.get("r") and isinstance(result["r"], dict): + edge = result["r"] + edge_id = str(edge["id"]) + if edge_id not in edges_dict: + edges_dict[edge_id] = KnowledgeGraphEdge( + id=edge_id, + type=edge["label"], + source=str(edge["start_id"]), + target=str(edge["end_id"]), + properties=edge["properties"], + ) + + kg = KnowledgeGraph( + nodes=list(nodes_dict.values()), + edges=list(edges_dict.values()), + is_truncated=is_truncated, + ) + else: + # For single node query, use BFS algorithm + kg = await self._bfs_subgraph(node_label, max_depth, max_nodes) + + logger.info( + f"Subgraph query successful | Node count: {len(kg.nodes)} | Edge count: {len(kg.edges)}" + ) + else: + # For non-wildcard queries, use the BFS algorithm + kg = await self._bfs_subgraph(node_label, max_depth, max_nodes) + logger.info( + f"Subgraph query for '{node_label}' successful | Node count: {len(kg.nodes)} | Edge count: {len(kg.edges)}" + ) + + return kg + + async def drop(self) -> dict[str, str]: + """Drop the storage""" + try: + drop_query = 
f"""SELECT * FROM cypher('{self.graph_name}', $$ + MATCH (n) + DETACH DELETE n + $$) AS (result agtype)""" + + await self._query(drop_query, readonly=False) + return {"status": "success", "message": "graph data dropped"} + except Exception as e: + logger.error(f"Error dropping graph: {e}") + return {"status": "error", "message": str(e)} + + +NAMESPACE_TABLE_MAP = { + NameSpace.KV_STORE_FULL_DOCS: "LIGHTRAG_DOC_FULL", + NameSpace.KV_STORE_TEXT_CHUNKS: "LIGHTRAG_DOC_CHUNKS", + NameSpace.VECTOR_STORE_CHUNKS: "LIGHTRAG_DOC_CHUNKS", + NameSpace.VECTOR_STORE_ENTITIES: "LIGHTRAG_VDB_ENTITY", + NameSpace.VECTOR_STORE_RELATIONSHIPS: "LIGHTRAG_VDB_RELATION", + NameSpace.DOC_STATUS: "LIGHTRAG_DOC_STATUS", + NameSpace.KV_STORE_LLM_RESPONSE_CACHE: "LIGHTRAG_LLM_CACHE", +} + + +def namespace_to_table_name(namespace: str) -> str: + for k, v in NAMESPACE_TABLE_MAP.items(): + if is_namespace(namespace, k): + return v + + +TABLES = { + "LIGHTRAG_DOC_FULL": { + "ddl": """CREATE TABLE LIGHTRAG_DOC_FULL ( + id VARCHAR(255), + workspace VARCHAR(255), + doc_name VARCHAR(1024), + content TEXT, + meta JSONB, + create_time TIMESTAMP(0), + update_time TIMESTAMP(0), + CONSTRAINT LIGHTRAG_DOC_FULL_PK PRIMARY KEY (workspace, id) + )""" + }, + "LIGHTRAG_DOC_CHUNKS": { + "ddl": """CREATE TABLE LIGHTRAG_DOC_CHUNKS ( + id VARCHAR(255), + workspace VARCHAR(255), + full_doc_id VARCHAR(256), + chunk_order_index INTEGER, + tokens INTEGER, + content TEXT, + content_vector VECTOR, + file_path VARCHAR(256), + create_time TIMESTAMP(0) WITH TIME ZONE, + update_time TIMESTAMP(0) WITH TIME ZONE, + CONSTRAINT LIGHTRAG_DOC_CHUNKS_PK PRIMARY KEY (workspace, id) + )""" + }, + "LIGHTRAG_VDB_ENTITY": { + "ddl": """CREATE TABLE LIGHTRAG_VDB_ENTITY ( + id VARCHAR(255), + workspace VARCHAR(255), + entity_name VARCHAR(255), + content TEXT, + content_vector VECTOR, + create_time TIMESTAMP(0) WITH TIME ZONE, + update_time TIMESTAMP(0) WITH TIME ZONE, + chunk_ids VARCHAR(255)[] NULL, + file_path TEXT NULL, + CONSTRAINT LIGHTRAG_VDB_ENTITY_PK PRIMARY KEY (workspace, id) + )""" + }, + "LIGHTRAG_VDB_RELATION": { + "ddl": """CREATE TABLE LIGHTRAG_VDB_RELATION ( + id VARCHAR(255), + workspace VARCHAR(255), + source_id VARCHAR(256), + target_id VARCHAR(256), + content TEXT, + content_vector VECTOR, + create_time TIMESTAMP(0) WITH TIME ZONE, + update_time TIMESTAMP(0) WITH TIME ZONE, + chunk_ids VARCHAR(255)[] NULL, + file_path TEXT NULL, + CONSTRAINT LIGHTRAG_VDB_RELATION_PK PRIMARY KEY (workspace, id) + )""" + }, + "LIGHTRAG_LLM_CACHE": { + "ddl": """CREATE TABLE LIGHTRAG_LLM_CACHE ( + workspace varchar(255) NOT NULL, + id varchar(255) NOT NULL, + mode varchar(32) NOT NULL, + original_prompt TEXT, + return_value TEXT, + chunk_id VARCHAR(255) NULL, + create_time TIMESTAMP DEFAULT CURRENT_TIMESTAMP, + update_time TIMESTAMP, + CONSTRAINT LIGHTRAG_LLM_CACHE_PK PRIMARY KEY (workspace, mode, id) + )""" + }, + "LIGHTRAG_DOC_STATUS": { + "ddl": """CREATE TABLE LIGHTRAG_DOC_STATUS ( + workspace varchar(255) NOT NULL, + id varchar(255) NOT NULL, + content TEXT NULL, + content_summary varchar(255) NULL, + content_length int4 NULL, + chunks_count int4 NULL, + status varchar(64) NULL, + file_path TEXT NULL, + created_at timestamp with time zone DEFAULT CURRENT_TIMESTAMP NULL, + updated_at timestamp with time zone DEFAULT CURRENT_TIMESTAMP NULL, + CONSTRAINT LIGHTRAG_DOC_STATUS_PK PRIMARY KEY (workspace, id) + )""" + }, +} + + +SQL_TEMPLATES = { + # SQL for KVStorage + "get_by_id_full_docs": """SELECT id, COALESCE(content, '') as content + FROM LIGHTRAG_DOC_FULL 
WHERE workspace=$1 AND id=$2 + """, + "get_by_id_text_chunks": """SELECT id, tokens, COALESCE(content, '') as content, + chunk_order_index, full_doc_id, file_path + FROM LIGHTRAG_DOC_CHUNKS WHERE workspace=$1 AND id=$2 + """, + "get_by_id_llm_response_cache": """SELECT id, original_prompt, COALESCE(return_value, '') as "return", mode, chunk_id + FROM LIGHTRAG_LLM_CACHE WHERE workspace=$1 AND mode=$2 + """, + "get_by_mode_id_llm_response_cache": """SELECT id, original_prompt, COALESCE(return_value, '') as "return", mode, chunk_id + FROM LIGHTRAG_LLM_CACHE WHERE workspace=$1 AND mode=$2 AND id=$3 + """, + "get_by_ids_full_docs": """SELECT id, COALESCE(content, '') as content + FROM LIGHTRAG_DOC_FULL WHERE workspace=$1 AND id IN ({ids}) + """, + "get_by_ids_text_chunks": """SELECT id, tokens, COALESCE(content, '') as content, + chunk_order_index, full_doc_id, file_path + FROM LIGHTRAG_DOC_CHUNKS WHERE workspace=$1 AND id IN ({ids}) + """, + "get_by_ids_llm_response_cache": """SELECT id, original_prompt, COALESCE(return_value, '') as "return", mode, chunk_id + FROM LIGHTRAG_LLM_CACHE WHERE workspace=$1 AND mode= IN ({ids}) + """, + "filter_keys": "SELECT id FROM {table_name} WHERE workspace=$1 AND id IN ({ids})", + "upsert_doc_full": """INSERT INTO LIGHTRAG_DOC_FULL (id, content, workspace) + VALUES ($1, $2, $3) + ON CONFLICT (workspace,id) DO UPDATE + SET content = $2, update_time = CURRENT_TIMESTAMP + """, + "upsert_llm_response_cache": """INSERT INTO LIGHTRAG_LLM_CACHE(workspace,id,original_prompt,return_value,mode,chunk_id) + VALUES ($1, $2, $3, $4, $5, $6) + ON CONFLICT (workspace,mode,id) DO UPDATE + SET original_prompt = EXCLUDED.original_prompt, + return_value=EXCLUDED.return_value, + mode=EXCLUDED.mode, + chunk_id=EXCLUDED.chunk_id, + update_time = CURRENT_TIMESTAMP + """, + "upsert_chunk": """INSERT INTO LIGHTRAG_DOC_CHUNKS (workspace, id, tokens, + chunk_order_index, full_doc_id, content, content_vector, file_path, + create_time, update_time) + VALUES ($1, $2, $3, $4, $5, $6, $7, $8, $9, $10) + ON CONFLICT (workspace,id) DO UPDATE + SET tokens=EXCLUDED.tokens, + chunk_order_index=EXCLUDED.chunk_order_index, + full_doc_id=EXCLUDED.full_doc_id, + content = EXCLUDED.content, + content_vector=EXCLUDED.content_vector, + file_path=EXCLUDED.file_path, + update_time = EXCLUDED.update_time + """, + # SQL for VectorStorage + "upsert_entity": """INSERT INTO LIGHTRAG_VDB_ENTITY (workspace, id, entity_name, content, + content_vector, chunk_ids, file_path, create_time, update_time) + VALUES ($1, $2, $3, $4, $5, $6::varchar[], $7, $8, $9) + ON CONFLICT (workspace,id) DO UPDATE + SET entity_name=EXCLUDED.entity_name, + content=EXCLUDED.content, + content_vector=EXCLUDED.content_vector, + chunk_ids=EXCLUDED.chunk_ids, + file_path=EXCLUDED.file_path, + update_time=EXCLUDED.update_time + """, + "upsert_relationship": """INSERT INTO LIGHTRAG_VDB_RELATION (workspace, id, source_id, + target_id, content, content_vector, chunk_ids, file_path, create_time, update_time) + VALUES ($1, $2, $3, $4, $5, $6, $7::varchar[], $8, $9, $10) + ON CONFLICT (workspace,id) DO UPDATE + SET source_id=EXCLUDED.source_id, + target_id=EXCLUDED.target_id, + content=EXCLUDED.content, + content_vector=EXCLUDED.content_vector, + chunk_ids=EXCLUDED.chunk_ids, + file_path=EXCLUDED.file_path, + update_time = EXCLUDED.update_time + """, + "relationships": """ + WITH relevant_chunks AS ( + SELECT id as chunk_id + FROM LIGHTRAG_DOC_CHUNKS + WHERE $2::varchar[] IS NULL OR full_doc_id = ANY($2::varchar[]) + ) + SELECT source_id as 
src_id, target_id as tgt_id, EXTRACT(EPOCH FROM create_time)::BIGINT as created_at
+        FROM (
+            SELECT r.id, r.source_id, r.target_id, r.create_time, 1 - (r.content_vector <=> '[{embedding_string}]'::vector) as distance
+            FROM LIGHTRAG_VDB_RELATION r
+            JOIN relevant_chunks c ON c.chunk_id = ANY(r.chunk_ids)
+            WHERE r.workspace=$1
+        ) filtered
+        WHERE distance>$3
+        ORDER BY distance DESC
+        LIMIT $4
+    """,
+    "entities": """
+        WITH relevant_chunks AS (
+            SELECT id as chunk_id
+            FROM LIGHTRAG_DOC_CHUNKS
+            WHERE $2::varchar[] IS NULL OR full_doc_id = ANY($2::varchar[])
+        )
+        SELECT entity_name, EXTRACT(EPOCH FROM create_time)::BIGINT as created_at FROM
+            (
+                SELECT e.id, e.entity_name, e.create_time, 1 - (e.content_vector <=> '[{embedding_string}]'::vector) as distance
+                FROM LIGHTRAG_VDB_ENTITY e
+                JOIN relevant_chunks c ON c.chunk_id = ANY(e.chunk_ids)
+                WHERE e.workspace=$1
+            ) as chunk_distances
+        WHERE distance>$3
+        ORDER BY distance DESC
+        LIMIT $4
+    """,
+    "chunks": """
+        WITH relevant_chunks AS (
+            SELECT id as chunk_id
+            FROM LIGHTRAG_DOC_CHUNKS
+            WHERE $2::varchar[] IS NULL OR full_doc_id = ANY($2::varchar[])
+        )
+        SELECT id, content, file_path, EXTRACT(EPOCH FROM create_time)::BIGINT as created_at FROM
+            (
+                SELECT id, content, file_path, create_time, 1 - (content_vector <=> '[{embedding_string}]'::vector) as distance
+                FROM LIGHTRAG_DOC_CHUNKS
+                WHERE workspace=$1
+                AND id IN (SELECT chunk_id FROM relevant_chunks)
+            ) as chunk_distances
+        WHERE distance>$3
+        ORDER BY distance DESC
+        LIMIT $4
+    """,
+    # DROP tables
+    "drop_specifiy_table_workspace": """
+        DELETE FROM {table_name} WHERE workspace=$1
+    """,
+}
diff --git a/dsLightRag/Util/LightRagUtil.py b/dsLightRag/Util/LightRagUtil.py
index 3f6c588d..5a5b012c 100644
--- a/dsLightRag/Util/LightRagUtil.py
+++ b/dsLightRag/Util/LightRagUtil.py
@@ -169,7 +169,8 @@ async def initialize_pg_rag(WORKING_DIR):
         doc_status_storage="PGDocStatusStorage",
         graph_storage="PGGraphStorage",
         vector_storage="PGVectorStorage",
-        auto_manage_storages_states=False
+        auto_manage_storages_states=False,
+        workspace='dsideal'
     )
     await rag.initialize_storages()
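
# Why the workspace argument works: every table created by the DDL above keys its rows
# on (workspace, id), and every SQL template filters on workspace=$1, so the
# workspace='dsideal' passed in initialize_pg_rag keeps one tenant's rows invisible to
# another. The snippet below is a minimal standalone sketch (not part of LightRAG) that
# exercises the "upsert_doc_full" and "get_by_id_full_docs" statements directly with
# asyncpg to demonstrate that isolation. The connection settings and the second
# workspace name 'tenant_b' are illustrative assumptions, and the LIGHTRAG_DOC_FULL
# table is assumed to have been created already (e.g. by check_tables()).

import asyncio
import asyncpg

async def main():
    # Assumed local connection settings; adjust to your own PostgreSQL instance.
    conn = await asyncpg.connect(
        host="localhost", port=5432,
        user="postgres", password="postgres", database="lightrag",
    )

    # Same statement as the "upsert_doc_full" template above.
    upsert = """INSERT INTO LIGHTRAG_DOC_FULL (id, content, workspace)
                VALUES ($1, $2, $3)
                ON CONFLICT (workspace, id) DO UPDATE
                SET content = $2, update_time = CURRENT_TIMESTAMP"""

    # The same document id can exist once per workspace because the
    # primary key is (workspace, id).
    await conn.execute(upsert, "doc-1", "content for workspace dsideal", "dsideal")
    await conn.execute(upsert, "doc-1", "content for another tenant", "tenant_b")

    # Reads must always filter on workspace, exactly like "get_by_id_full_docs".
    row = await conn.fetchrow(
        "SELECT id, COALESCE(content, '') as content FROM LIGHTRAG_DOC_FULL "
        "WHERE workspace=$1 AND id=$2",
        "dsideal", "doc-1",
    )
    print(dict(row))  # only the dsideal copy comes back

    await conn.close()

asyncio.run(main())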