From cac5e7856f39f8df597bc7f8711bf40e444c9430 Mon Sep 17 00:00:00 2001 From: HuangHai <10402852@qq.com> Date: Sat, 19 Jul 2025 11:26:15 +0800 Subject: [PATCH] 'commit' --- .../6、使用PG数据库.md | 12 - ...resql支持工作空间的代码修改.txt | 9 - .../postgres_impl.py | 3364 ----------------- .../第一、二部分数与代数.txt | 0 .../第三部分图形与几何.txt | 0 .../第四部分统计与概率.txt | 0 .../说明.txt | 0 .../下一步需要研究的技术内容.txt | 24 +- 8 files changed, 16 insertions(+), 3393 deletions(-) delete mode 100644 dsLightRag/Doc/T1、Postgresql支持工作空间的代码修改【有问题】/6、使用PG数据库.md delete mode 100644 dsLightRag/Doc/T1、Postgresql支持工作空间的代码修改【有问题】/Postgresql支持工作空间的代码修改.txt delete mode 100644 dsLightRag/Doc/T1、Postgresql支持工作空间的代码修改【有问题】/postgres_impl.py rename dsLightRag/Doc/{T2、史校长资料 => T1、史校长资料}/第一、二部分数与代数.txt (100%) rename dsLightRag/Doc/{T2、史校长资料 => T1、史校长资料}/第三部分图形与几何.txt (100%) rename dsLightRag/Doc/{T2、史校长资料 => T1、史校长资料}/第四部分统计与概率.txt (100%) rename dsLightRag/Doc/{T2、史校长资料 => T1、史校长资料}/说明.txt (100%) diff --git a/dsLightRag/Doc/T1、Postgresql支持工作空间的代码修改【有问题】/6、使用PG数据库.md b/dsLightRag/Doc/T1、Postgresql支持工作空间的代码修改【有问题】/6、使用PG数据库.md deleted file mode 100644 index 436e9292..00000000 --- a/dsLightRag/Doc/T1、Postgresql支持工作空间的代码修改【有问题】/6、使用PG数据库.md +++ /dev/null @@ -1,12 +0,0 @@ -- 渡渡鸟 -``` -docker pull swr.cn-north-4.myhuaweicloud.com/ddn-k8s/docker.io/shangor/postgres-for-rag:v1.0 -docker tag swr.cn-north-4.myhuaweicloud.com/ddn-k8s/docker.io/shangor/postgres-for-rag:v1.0 shangor/postgres-for-rag:latest -``` - -- 启动 -``` -docker ps -a -docker rm -f 067d6ec9324d -docker run -p 5432:5432 -d --name postgres-LightRag shangor/postgres-for-rag sh -c "service postgresql start && sleep infinity" -``` \ No newline at end of file diff --git a/dsLightRag/Doc/T1、Postgresql支持工作空间的代码修改【有问题】/Postgresql支持工作空间的代码修改.txt b/dsLightRag/Doc/T1、Postgresql支持工作空间的代码修改【有问题】/Postgresql支持工作空间的代码修改.txt deleted file mode 100644 index 9d0eb07e..00000000 --- a/dsLightRag/Doc/T1、Postgresql支持工作空间的代码修改【有问题】/Postgresql支持工作空间的代码修改.txt +++ /dev/null @@ -1,9 +0,0 @@ -参考文档 -https://github.com/HKUDS/LightRAG/issues/1244 - -源码路径 -D:\anaconda3\envs\py310\Lib\site-packages\lightrag - -修改的文件 -# 用VScode打开编辑,可以全局搜索 -D:\anaconda3\envs\py310\Lib\site-packages\lightrag\kg\postgres_impl.py \ No newline at end of file diff --git a/dsLightRag/Doc/T1、Postgresql支持工作空间的代码修改【有问题】/postgres_impl.py b/dsLightRag/Doc/T1、Postgresql支持工作空间的代码修改【有问题】/postgres_impl.py deleted file mode 100644 index e6bb7ca6..00000000 --- a/dsLightRag/Doc/T1、Postgresql支持工作空间的代码修改【有问题】/postgres_impl.py +++ /dev/null @@ -1,3364 +0,0 @@ -import asyncio -import json -import os -import re -import datetime -from datetime import timezone -from dataclasses import dataclass, field -from typing import Any, Union, final -import numpy as np -import configparser - -from lightrag.types import KnowledgeGraph, KnowledgeGraphNode, KnowledgeGraphEdge - -from tenacity import ( - retry, - retry_if_exception_type, - stop_after_attempt, - wait_exponential, -) - -from ..base import ( - BaseGraphStorage, - BaseKVStorage, - BaseVectorStorage, - DocProcessingStatus, - DocStatus, - DocStatusStorage, -) -from ..namespace import NameSpace, is_namespace -from ..utils import logger -from ..constants import GRAPH_FIELD_SEP - -import pipmaster as pm - -if not pm.is_installed("asyncpg"): - pm.install("asyncpg") - -import asyncpg # type: ignore -from asyncpg import Pool # type: ignore - -from dotenv import load_dotenv - -# use the .env that is inside the current folder -# allows to use different .env file for each lightrag instance -# the OS environment variables take precedence over 
the .env file -load_dotenv(dotenv_path=".env", override=False) - - -class PostgreSQLDB: - def __init__(self, config: dict[str, Any], **kwargs: Any): - self.host = config["host"] - self.port = config["port"] - self.user = config["user"] - self.password = config["password"] - self.database = config["database"] - self.workspace = config["workspace"] - self.max = int(config["max_connections"]) - self.increment = 1 - self.pool: Pool | None = None - - if self.user is None or self.password is None or self.database is None: - raise ValueError("Missing database user, password, or database") - - async def initdb(self): - try: - self.pool = await asyncpg.create_pool( # type: ignore - user=self.user, - password=self.password, - database=self.database, - host=self.host, - port=self.port, - min_size=1, - max_size=self.max, - ) - - logger.info( - f"PostgreSQL, Connected to database at {self.host}:{self.port}/{self.database}" - ) - except Exception as e: - logger.error( - f"PostgreSQL, Failed to connect database at {self.host}:{self.port}/{self.database}, Got:{e}" - ) - raise - - @staticmethod - async def configure_age(connection: asyncpg.Connection, graph_name: str) -> None: - """Set the Apache AGE environment and creates a graph if it does not exist. - - This method: - - Sets the PostgreSQL `search_path` to include `ag_catalog`, ensuring that Apache AGE functions can be used without specifying the schema. - - Attempts to create a new graph with the provided `graph_name` if it does not already exist. - - Silently ignores errors related to the graph already existing. - - """ - try: - await connection.execute( # type: ignore - 'SET search_path = ag_catalog, "$user", public' - ) - await connection.execute( # type: ignore - f"select create_graph('{graph_name}')" - ) - except ( - asyncpg.exceptions.InvalidSchemaNameError, - asyncpg.exceptions.UniqueViolationError, - ): - pass - - async def _migrate_llm_cache_add_columns(self): - """Add chunk_id and cache_type columns to LIGHTRAG_LLM_CACHE table if they don't exist""" - try: - # Check if both columns exist - check_columns_sql = """ - SELECT column_name - FROM information_schema.columns - WHERE table_name = 'lightrag_llm_cache' - AND column_name IN ('chunk_id', 'cache_type') - """ - - existing_columns = await self.query(check_columns_sql, multirows=True) - existing_column_names = ( - {col["column_name"] for col in existing_columns} - if existing_columns - else set() - ) - - # Add missing chunk_id column - if "chunk_id" not in existing_column_names: - logger.info("Adding chunk_id column to LIGHTRAG_LLM_CACHE table") - add_chunk_id_sql = """ - ALTER TABLE LIGHTRAG_LLM_CACHE - ADD COLUMN chunk_id VARCHAR(255) NULL - """ - await self.execute(add_chunk_id_sql) - logger.info( - "Successfully added chunk_id column to LIGHTRAG_LLM_CACHE table" - ) - else: - logger.info( - "chunk_id column already exists in LIGHTRAG_LLM_CACHE table" - ) - - # Add missing cache_type column - if "cache_type" not in existing_column_names: - logger.info("Adding cache_type column to LIGHTRAG_LLM_CACHE table") - add_cache_type_sql = """ - ALTER TABLE LIGHTRAG_LLM_CACHE - ADD COLUMN cache_type VARCHAR(32) NULL - """ - await self.execute(add_cache_type_sql) - logger.info( - "Successfully added cache_type column to LIGHTRAG_LLM_CACHE table" - ) - - # Migrate existing data using optimized regex pattern - logger.info( - "Migrating existing LLM cache data to populate cache_type field (optimized)" - ) - optimized_update_sql = """ - UPDATE LIGHTRAG_LLM_CACHE - SET cache_type = CASE - WHEN id ~ 
'^[^:]+:[^:]+:' THEN split_part(id, ':', 2) - ELSE 'extract' - END - WHERE cache_type IS NULL - """ - await self.execute(optimized_update_sql) - logger.info("Successfully migrated existing LLM cache data") - else: - logger.info( - "cache_type column already exists in LIGHTRAG_LLM_CACHE table" - ) - - except Exception as e: - logger.warning(f"Failed to add columns to LIGHTRAG_LLM_CACHE: {e}") - - async def _migrate_timestamp_columns(self): - """Migrate timestamp columns in tables to witimezone-free types, assuming original data is in UTC time""" - # Tables and columns that need migration - tables_to_migrate = { - "LIGHTRAG_VDB_ENTITY": ["create_time", "update_time"], - "LIGHTRAG_VDB_RELATION": ["create_time", "update_time"], - "LIGHTRAG_DOC_CHUNKS": ["create_time", "update_time"], - "LIGHTRAG_DOC_STATUS": ["created_at", "updated_at"], - } - - for table_name, columns in tables_to_migrate.items(): - for column_name in columns: - try: - # Check if column exists - check_column_sql = f""" - SELECT column_name, data_type - FROM information_schema.columns - WHERE table_name = '{table_name.lower()}' - AND column_name = '{column_name}' - """ - - column_info = await self.query(check_column_sql) - if not column_info: - logger.warning( - f"Column {table_name}.{column_name} does not exist, skipping migration" - ) - continue - - # Check column type - data_type = column_info.get("data_type") - if data_type == "timestamp without time zone": - logger.debug( - f"Column {table_name}.{column_name} is already witimezone-free, no migration needed" - ) - continue - - # Execute migration, explicitly specifying UTC timezone for interpreting original data - logger.info( - f"Migrating {table_name}.{column_name} from {data_type} to TIMESTAMP(0) type" - ) - migration_sql = f""" - ALTER TABLE {table_name} - ALTER COLUMN {column_name} TYPE TIMESTAMP(0), - ALTER COLUMN {column_name} SET DEFAULT CURRENT_TIMESTAMP - """ - - await self.execute(migration_sql) - logger.info( - f"Successfully migrated {table_name}.{column_name} to timezone-free type" - ) - except Exception as e: - # Log error but don't interrupt the process - logger.warning(f"Failed to migrate {table_name}.{column_name}: {e}") - - async def _migrate_doc_chunks_to_vdb_chunks(self): - """ - Migrate data from LIGHTRAG_DOC_CHUNKS to LIGHTRAG_VDB_CHUNKS if specific conditions are met. - This migration is intended for users who are upgrading and have an older table structure - where LIGHTRAG_DOC_CHUNKS contained a `content_vector` column. - - """ - try: - # 1. Check if the new table LIGHTRAG_VDB_CHUNKS is empty - vdb_chunks_count_sql = "SELECT COUNT(1) as count FROM LIGHTRAG_VDB_CHUNKS" - vdb_chunks_count_result = await self.query(vdb_chunks_count_sql) - if vdb_chunks_count_result and vdb_chunks_count_result["count"] > 0: - logger.info( - "Skipping migration: LIGHTRAG_VDB_CHUNKS already contains data." - ) - return - - # 2. Check if `content_vector` column exists in the old table - check_column_sql = """ - SELECT 1 FROM information_schema.columns - WHERE table_name = 'lightrag_doc_chunks' AND column_name = 'content_vector' - """ - column_exists = await self.query(check_column_sql) - if not column_exists: - logger.info( - "Skipping migration: `content_vector` not found in LIGHTRAG_DOC_CHUNKS" - ) - return - - # 3. 
Check if the old table LIGHTRAG_DOC_CHUNKS has data - doc_chunks_count_sql = "SELECT COUNT(1) as count FROM LIGHTRAG_DOC_CHUNKS" - doc_chunks_count_result = await self.query(doc_chunks_count_sql) - if not doc_chunks_count_result or doc_chunks_count_result["count"] == 0: - logger.info("Skipping migration: LIGHTRAG_DOC_CHUNKS is empty.") - return - - # 4. Perform the migration - logger.info( - "Starting data migration from LIGHTRAG_DOC_CHUNKS to LIGHTRAG_VDB_CHUNKS..." - ) - migration_sql = """ - INSERT INTO LIGHTRAG_VDB_CHUNKS ( - id, workspace, full_doc_id, chunk_order_index, tokens, content, - content_vector, file_path, create_time, update_time - ) - SELECT - id, workspace, full_doc_id, chunk_order_index, tokens, content, - content_vector, file_path, create_time, update_time - FROM LIGHTRAG_DOC_CHUNKS - ON CONFLICT (workspace, id) DO NOTHING; - """ - await self.execute(migration_sql) - logger.info("Data migration to LIGHTRAG_VDB_CHUNKS completed successfully.") - - except Exception as e: - logger.error(f"Failed during data migration to LIGHTRAG_VDB_CHUNKS: {e}") - # Do not re-raise, to allow the application to start - - async def _check_llm_cache_needs_migration(self): - """Check if LLM cache data needs migration by examining any record with old format""" - try: - # Optimized query: directly check for old format records without sorting - check_sql = """ - SELECT 1 FROM LIGHTRAG_LLM_CACHE - WHERE id NOT LIKE '%:%' - LIMIT 1 - """ - result = await self.query(check_sql) - - # If any old format record exists, migration is needed - return result is not None - - except Exception as e: - logger.warning(f"Failed to check LLM cache migration status: {e}") - return False - - async def _migrate_llm_cache_to_flattened_keys(self): - """Optimized version: directly execute single UPDATE migration to migrate old format cache keys to flattened format""" - try: - # Check if migration is needed - check_sql = """ - SELECT COUNT(*) as count FROM LIGHTRAG_LLM_CACHE - WHERE id NOT LIKE '%:%' - """ - result = await self.query(check_sql) - - if not result or result["count"] == 0: - logger.info("No old format LLM cache data found, skipping migration") - return - - old_count = result["count"] - logger.info(f"Found {old_count} old format cache records") - - # Check potential primary key conflicts (optional but recommended) - conflict_check_sql = """ - WITH new_ids AS ( - SELECT - workspace, - mode, - id as old_id, - mode || ':' || - CASE WHEN mode = 'default' THEN 'extract' ELSE 'unknown' END || ':' || - md5(original_prompt) as new_id - FROM LIGHTRAG_LLM_CACHE - WHERE id NOT LIKE '%:%' - ) - SELECT COUNT(*) as conflicts - FROM new_ids n1 - JOIN LIGHTRAG_LLM_CACHE existing - ON existing.workspace = n1.workspace - AND existing.mode = n1.mode - AND existing.id = n1.new_id - WHERE existing.id LIKE '%:%' -- Only check conflicts with existing new format records - """ - - conflict_result = await self.query(conflict_check_sql) - if conflict_result and conflict_result["conflicts"] > 0: - logger.warning( - f"Found {conflict_result['conflicts']} potential ID conflicts with existing records" - ) - # Can choose to continue or abort, here we choose to continue and log warning - - # Execute single UPDATE migration - logger.info("Starting optimized LLM cache migration...") - migration_sql = """ - UPDATE LIGHTRAG_LLM_CACHE - SET - id = mode || ':' || - CASE WHEN mode = 'default' THEN 'extract' ELSE 'unknown' END || ':' || - md5(original_prompt), - cache_type = CASE WHEN mode = 'default' THEN 'extract' ELSE 'unknown' END, - 
update_time = CURRENT_TIMESTAMP - WHERE id NOT LIKE '%:%' - """ - - # Execute migration - await self.execute(migration_sql) - - # Verify migration results - verify_sql = """ - SELECT COUNT(*) as remaining_old FROM LIGHTRAG_LLM_CACHE - WHERE id NOT LIKE '%:%' - """ - verify_result = await self.query(verify_sql) - remaining = verify_result["remaining_old"] if verify_result else -1 - - if remaining == 0: - logger.info( - f"✅ Successfully migrated {old_count} LLM cache records to flattened format" - ) - else: - logger.warning( - f"⚠️ Migration completed but {remaining} old format records remain" - ) - - except Exception as e: - logger.error(f"Optimized LLM cache migration failed: {e}") - raise - - async def _migrate_doc_status_add_chunks_list(self): - """Add chunks_list column to LIGHTRAG_DOC_STATUS table if it doesn't exist""" - try: - # Check if chunks_list column exists - check_column_sql = """ - SELECT column_name - FROM information_schema.columns - WHERE table_name = 'lightrag_doc_status' - AND column_name = 'chunks_list' - """ - - column_info = await self.query(check_column_sql) - if not column_info: - logger.info("Adding chunks_list column to LIGHTRAG_DOC_STATUS table") - add_column_sql = """ - ALTER TABLE LIGHTRAG_DOC_STATUS - ADD COLUMN chunks_list JSONB NULL DEFAULT '[]'::jsonb - """ - await self.execute(add_column_sql) - logger.info( - "Successfully added chunks_list column to LIGHTRAG_DOC_STATUS table" - ) - else: - logger.info( - "chunks_list column already exists in LIGHTRAG_DOC_STATUS table" - ) - except Exception as e: - logger.warning( - f"Failed to add chunks_list column to LIGHTRAG_DOC_STATUS: {e}" - ) - - async def _migrate_text_chunks_add_llm_cache_list(self): - """Add llm_cache_list column to LIGHTRAG_DOC_CHUNKS table if it doesn't exist""" - try: - # Check if llm_cache_list column exists - check_column_sql = """ - SELECT column_name - FROM information_schema.columns - WHERE table_name = 'lightrag_doc_chunks' - AND column_name = 'llm_cache_list' - """ - - column_info = await self.query(check_column_sql) - if not column_info: - logger.info("Adding llm_cache_list column to LIGHTRAG_DOC_CHUNKS table") - add_column_sql = """ - ALTER TABLE LIGHTRAG_DOC_CHUNKS - ADD COLUMN llm_cache_list JSONB NULL DEFAULT '[]'::jsonb - """ - await self.execute(add_column_sql) - logger.info( - "Successfully added llm_cache_list column to LIGHTRAG_DOC_CHUNKS table" - ) - else: - logger.info( - "llm_cache_list column already exists in LIGHTRAG_DOC_CHUNKS table" - ) - except Exception as e: - logger.warning( - f"Failed to add llm_cache_list column to LIGHTRAG_DOC_CHUNKS: {e}" - ) - - async def _migrate_field_lengths(self): - """Migrate database field lengths: entity_name, source_id, target_id, and file_path""" - # Define the field changes needed - field_migrations = [ - { - "table": "LIGHTRAG_VDB_ENTITY", - "column": "entity_name", - "old_type": "character varying(255)", - "new_type": "VARCHAR(512)", - "description": "entity_name from 255 to 512", - }, - { - "table": "LIGHTRAG_VDB_RELATION", - "column": "source_id", - "old_type": "character varying(256)", - "new_type": "VARCHAR(512)", - "description": "source_id from 256 to 512", - }, - { - "table": "LIGHTRAG_VDB_RELATION", - "column": "target_id", - "old_type": "character varying(256)", - "new_type": "VARCHAR(512)", - "description": "target_id from 256 to 512", - }, - { - "table": "LIGHTRAG_DOC_CHUNKS", - "column": "file_path", - "old_type": "character varying(256)", - "new_type": "TEXT", - "description": "file_path to TEXT NULL", - }, - { - 
"table": "LIGHTRAG_VDB_CHUNKS", - "column": "file_path", - "old_type": "character varying(256)", - "new_type": "TEXT", - "description": "file_path to TEXT NULL", - }, - ] - - for migration in field_migrations: - try: - # Check current column definition - check_column_sql = """ - SELECT column_name, data_type, character_maximum_length, is_nullable - FROM information_schema.columns - WHERE table_name = $1 AND column_name = $2 - """ - - column_info = await self.query( - check_column_sql, - { - "table_name": migration["table"].lower(), - "column_name": migration["column"], - }, - ) - - if not column_info: - logger.warning( - f"Column {migration['table']}.{migration['column']} does not exist, skipping migration" - ) - continue - - current_type = column_info.get("data_type", "").lower() - current_length = column_info.get("character_maximum_length") - - # Check if migration is needed - needs_migration = False - - if migration["column"] == "entity_name" and current_length == 255: - needs_migration = True - elif ( - migration["column"] in ["source_id", "target_id"] - and current_length == 256 - ): - needs_migration = True - elif ( - migration["column"] == "file_path" - and current_type == "character varying" - ): - needs_migration = True - - if needs_migration: - logger.info( - f"Migrating {migration['table']}.{migration['column']}: {migration['description']}" - ) - - # Execute the migration - alter_sql = f""" - ALTER TABLE {migration['table']} - ALTER COLUMN {migration['column']} TYPE {migration['new_type']} - """ - - await self.execute(alter_sql) - logger.info( - f"Successfully migrated {migration['table']}.{migration['column']}" - ) - else: - logger.debug( - f"Column {migration['table']}.{migration['column']} already has correct type, no migration needed" - ) - - except Exception as e: - # Log error but don't interrupt the process - logger.warning( - f"Failed to migrate {migration['table']}.{migration['column']}: {e}" - ) - - async def check_tables(self): - # First create all tables - for k, v in TABLES.items(): - try: - await self.query(f"SELECT 1 FROM {k} LIMIT 1") - except Exception: - try: - logger.info(f"PostgreSQL, Try Creating table {k} in database") - await self.execute(v["ddl"]) - logger.info( - f"PostgreSQL, Creation success table {k} in PostgreSQL database" - ) - except Exception as e: - logger.error( - f"PostgreSQL, Failed to create table {k} in database, Please verify the connection with PostgreSQL database, Got: {e}" - ) - raise e - - # Create index for id column in each table - try: - index_name = f"idx_{k.lower()}_id" - check_index_sql = f""" - SELECT 1 FROM pg_indexes - WHERE indexname = '{index_name}' - AND tablename = '{k.lower()}' - """ - index_exists = await self.query(check_index_sql) - - if not index_exists: - create_index_sql = f"CREATE INDEX {index_name} ON {k}(id)" - logger.info(f"PostgreSQL, Creating index {index_name} on table {k}") - await self.execute(create_index_sql) - except Exception as e: - logger.error( - f"PostgreSQL, Failed to create index on table {k}, Got: {e}" - ) - - # Create composite index for (workspace, id) columns in each table - try: - composite_index_name = f"idx_{k.lower()}_workspace_id" - check_composite_index_sql = f""" - SELECT 1 FROM pg_indexes - WHERE indexname = '{composite_index_name}' - AND tablename = '{k.lower()}' - """ - composite_index_exists = await self.query(check_composite_index_sql) - - if not composite_index_exists: - create_composite_index_sql = ( - f"CREATE INDEX {composite_index_name} ON {k}(workspace, id)" - ) - logger.info( 
- f"PostgreSQL, Creating composite index {composite_index_name} on table {k}" - ) - await self.execute(create_composite_index_sql) - except Exception as e: - logger.error( - f"PostgreSQL, Failed to create composite index on table {k}, Got: {e}" - ) - - # After all tables are created, attempt to migrate timestamp fields - try: - await self._migrate_timestamp_columns() - except Exception as e: - logger.error(f"PostgreSQL, Failed to migrate timestamp columns: {e}") - # Don't throw an exception, allow the initialization process to continue - - # Migrate LLM cache table to add chunk_id and cache_type columns if needed - try: - await self._migrate_llm_cache_add_columns() - except Exception as e: - logger.error(f"PostgreSQL, Failed to migrate LLM cache columns: {e}") - # Don't throw an exception, allow the initialization process to continue - - # Finally, attempt to migrate old doc chunks data if needed - try: - await self._migrate_doc_chunks_to_vdb_chunks() - except Exception as e: - logger.error(f"PostgreSQL, Failed to migrate doc_chunks to vdb_chunks: {e}") - - # Check and migrate LLM cache to flattened keys if needed - try: - if await self._check_llm_cache_needs_migration(): - await self._migrate_llm_cache_to_flattened_keys() - except Exception as e: - logger.error(f"PostgreSQL, LLM cache migration failed: {e}") - - # Migrate doc status to add chunks_list field if needed - try: - await self._migrate_doc_status_add_chunks_list() - except Exception as e: - logger.error( - f"PostgreSQL, Failed to migrate doc status chunks_list field: {e}" - ) - - # Migrate text chunks to add llm_cache_list field if needed - try: - await self._migrate_text_chunks_add_llm_cache_list() - except Exception as e: - logger.error( - f"PostgreSQL, Failed to migrate text chunks llm_cache_list field: {e}" - ) - - # Migrate field lengths for entity_name, source_id, target_id, and file_path - try: - await self._migrate_field_lengths() - except Exception as e: - logger.error(f"PostgreSQL, Failed to migrate field lengths: {e}") - - async def query( - self, - sql: str, - params: dict[str, Any] | None = None, - multirows: bool = False, - with_age: bool = False, - graph_name: str | None = None, - ) -> dict[str, Any] | None | list[dict[str, Any]]: - # start_time = time.time() - # logger.info(f"PostgreSQL, Querying:\n{sql}") - - async with self.pool.acquire() as connection: # type: ignore - if with_age and graph_name: - await self.configure_age(connection, graph_name) # type: ignore - elif with_age and not graph_name: - raise ValueError("Graph name is required when with_age is True") - - try: - if params: - rows = await connection.fetch(sql, *params.values()) - else: - rows = await connection.fetch(sql) - - if multirows: - if rows: - columns = [col for col in rows[0].keys()] - data = [dict(zip(columns, row)) for row in rows] - else: - data = [] - else: - if rows: - columns = rows[0].keys() - data = dict(zip(columns, rows[0])) - else: - data = None - - # query_time = time.time() - start_time - # logger.info(f"PostgreSQL, Query result len: {len(data)}") - # logger.info(f"PostgreSQL, Query execution time: {query_time:.4f}s") - - return data - except Exception as e: - logger.error(f"PostgreSQL database, error:{e}") - raise - - async def execute( - self, - sql: str, - data: dict[str, Any] | None = None, - upsert: bool = False, - ignore_if_exists: bool = False, - with_age: bool = False, - graph_name: str | None = None, - ): - try: - async with self.pool.acquire() as connection: # type: ignore - if with_age and graph_name: - await 
self.configure_age(connection, graph_name) - elif with_age and not graph_name: - raise ValueError("Graph name is required when with_age is True") - - if data is None: - await connection.execute(sql) - else: - await connection.execute(sql, *data.values()) - except ( - asyncpg.exceptions.UniqueViolationError, - asyncpg.exceptions.DuplicateTableError, - asyncpg.exceptions.DuplicateObjectError, # Catch "already exists" error - asyncpg.exceptions.InvalidSchemaNameError, # Also catch for AGE extension "already exists" - ) as e: - if ignore_if_exists: - # If the flag is set, just ignore these specific errors - pass - elif upsert: - print("Key value duplicate, but upsert succeeded.") - else: - logger.error(f"Upsert error: {e}") - except Exception as e: - logger.error(f"PostgreSQL database,\nsql:{sql},\ndata:{data},\nerror:{e}") - raise - - -class ClientManager: - _instances: dict[str, Any] = {"db": None, "ref_count": 0} - _lock = asyncio.Lock() - - @staticmethod - #def get_config() -> dict[str, Any]: - def get_config(global_config: dict[str, Any] | None = None) -> dict[str, Any]: - # First try to get workspace from global config - workspace = None - if global_config and "vector_db_storage_cls_kwargs" in global_config: - workspace = global_config["vector_db_storage_cls_kwargs"].get("workspace") - # Read standard config - config = configparser.ConfigParser() - config.read("config.ini", "utf-8") - - return { - "host": os.environ.get( - "POSTGRES_HOST", - config.get("postgres", "host", fallback="localhost"), - ), - "port": os.environ.get( - "POSTGRES_PORT", config.get("postgres", "port", fallback=5432) - ), - "user": os.environ.get( - "POSTGRES_USER", config.get("postgres", "user", fallback="postgres") - ), - "password": os.environ.get( - "POSTGRES_PASSWORD", - config.get("postgres", "password", fallback=None), - ), - "database": os.environ.get( - "POSTGRES_DATABASE", - config.get("postgres", "database", fallback="postgres"), - ), - # Use workspace from global config if available, otherwise fall back to env/config.ini - #"workspace": os.environ.get( - "workspace": workspace or os.environ.get( - "POSTGRES_WORKSPACE", - config.get("postgres", "workspace", fallback="default"), - ), - "max_connections": os.environ.get( - "POSTGRES_MAX_CONNECTIONS", - config.get("postgres", "max_connections", fallback=20), - ), - } - - @classmethod - #async def get_client(cls) -> PostgreSQLDB: - async def get_client(cls, global_config: dict[str, Any] | None = None) -> PostgreSQLDB: - async with cls._lock: - if cls._instances["db"] is None: - #config = ClientManager.get_config() - config = cls.get_config(global_config) - db = PostgreSQLDB(config) - await db.initdb() - await db.check_tables() - cls._instances["db"] = db - cls._instances["ref_count"] = 0 - cls._instances["ref_count"] += 1 - return cls._instances["db"] - - @classmethod - async def release_client(cls, db: PostgreSQLDB): - async with cls._lock: - if db is not None: - if db is cls._instances["db"]: - cls._instances["ref_count"] -= 1 - if cls._instances["ref_count"] == 0: - await db.pool.close() - logger.info("Closed PostgreSQL database connection pool") - cls._instances["db"] = None - else: - await db.pool.close() - - -@final -@dataclass -class PGKVStorage(BaseKVStorage): - db: PostgreSQLDB = field(default=None) - - def __post_init__(self): - self.base_namespace = self.global_config["vector_db_storage_cls_kwargs"].get("workspace") - self._max_batch_size = self.global_config["embedding_batch_num"] - - async def initialize(self): - if self.db is None: - self.db = 
await ClientManager.get_client(self.global_config) - #self.db = await ClientManager.get_client() - # Implement workspace priority: PostgreSQLDB.workspace > self.workspace > "default" - if self.db.workspace: - # Use PostgreSQLDB's workspace (highest priority) - final_workspace = self.db.workspace - elif hasattr(self, "workspace") and self.workspace: - # Use storage class's workspace (medium priority) - final_workspace = self.workspace - self.db.workspace = final_workspace - else: - # Use "default" for compatibility (lowest priority) - final_workspace = "default" - self.db.workspace = final_workspace - - async def finalize(self): - if self.db is not None: - await ClientManager.release_client(self.db) - self.db = None - - ################ QUERY METHODS ################ - async def get_all(self) -> dict[str, Any]: - """Get all data from storage - - Returns: - Dictionary containing all stored data - """ - table_name = namespace_to_table_name(self.namespace) - if not table_name: - logger.error(f"Unknown namespace for get_all: {self.namespace}") - return {} - - sql = f"SELECT * FROM {table_name} WHERE workspace=$1" - params = {"workspace": self.db.workspace} - - try: - results = await self.db.query(sql, params, multirows=True) - - # Special handling for LLM cache to ensure compatibility with _get_cached_extraction_results - if is_namespace(self.namespace, NameSpace.KV_STORE_LLM_RESPONSE_CACHE): - processed_results = {} - for row in results: - create_time = row.get("create_time", 0) - update_time = row.get("update_time", 0) - # Map field names and add cache_type for compatibility - processed_row = { - **row, - "return": row.get("return_value", ""), - "cache_type": row.get("original_prompt", "unknow"), - "original_prompt": row.get("original_prompt", ""), - "chunk_id": row.get("chunk_id"), - "mode": row.get("mode", "default"), - "create_time": create_time, - "update_time": create_time if update_time == 0 else update_time, - } - processed_results[row["id"]] = processed_row - return processed_results - - # For text_chunks namespace, parse llm_cache_list JSON string back to list - if is_namespace(self.namespace, NameSpace.KV_STORE_TEXT_CHUNKS): - processed_results = {} - for row in results: - llm_cache_list = row.get("llm_cache_list", []) - if isinstance(llm_cache_list, str): - try: - llm_cache_list = json.loads(llm_cache_list) - except json.JSONDecodeError: - llm_cache_list = [] - row["llm_cache_list"] = llm_cache_list - create_time = row.get("create_time", 0) - update_time = row.get("update_time", 0) - row["create_time"] = create_time - row["update_time"] = ( - create_time if update_time == 0 else update_time - ) - processed_results[row["id"]] = row - return processed_results - - # For other namespaces, return as-is - return {row["id"]: row for row in results} - except Exception as e: - logger.error(f"Error retrieving all data from {self.namespace}: {e}") - return {} - - async def get_by_id(self, id: str) -> dict[str, Any] | None: - """Get data by id.""" - sql = SQL_TEMPLATES["get_by_id_" + self.namespace] - params = {"workspace": self.db.workspace, "id": id} - response = await self.db.query(sql, params) - - if response and is_namespace(self.namespace, NameSpace.KV_STORE_TEXT_CHUNKS): - # Parse llm_cache_list JSON string back to list - llm_cache_list = response.get("llm_cache_list", []) - if isinstance(llm_cache_list, str): - try: - llm_cache_list = json.loads(llm_cache_list) - except json.JSONDecodeError: - llm_cache_list = [] - response["llm_cache_list"] = llm_cache_list - create_time = 
response.get("create_time", 0) - update_time = response.get("update_time", 0) - response["create_time"] = create_time - response["update_time"] = create_time if update_time == 0 else update_time - - # Special handling for LLM cache to ensure compatibility with _get_cached_extraction_results - if response and is_namespace( - self.namespace, NameSpace.KV_STORE_LLM_RESPONSE_CACHE - ): - create_time = response.get("create_time", 0) - update_time = response.get("update_time", 0) - # Map field names and add cache_type for compatibility - response = { - **response, - "return": response.get("return_value", ""), - "cache_type": response.get("cache_type"), - "original_prompt": response.get("original_prompt", ""), - "chunk_id": response.get("chunk_id"), - "mode": response.get("mode", "default"), - "create_time": create_time, - "update_time": create_time if update_time == 0 else update_time, - } - - return response if response else None - - # Query by id - async def get_by_ids(self, ids: list[str]) -> list[dict[str, Any]]: - """Get data by ids""" - sql = SQL_TEMPLATES["get_by_ids_" + self.namespace].format( - ids=",".join([f"'{id}'" for id in ids]) - ) - params = {"workspace": self.db.workspace} - results = await self.db.query(sql, params, multirows=True) - - if results and is_namespace(self.namespace, NameSpace.KV_STORE_TEXT_CHUNKS): - # Parse llm_cache_list JSON string back to list for each result - for result in results: - llm_cache_list = result.get("llm_cache_list", []) - if isinstance(llm_cache_list, str): - try: - llm_cache_list = json.loads(llm_cache_list) - except json.JSONDecodeError: - llm_cache_list = [] - result["llm_cache_list"] = llm_cache_list - create_time = result.get("create_time", 0) - update_time = result.get("update_time", 0) - result["create_time"] = create_time - result["update_time"] = create_time if update_time == 0 else update_time - - # Special handling for LLM cache to ensure compatibility with _get_cached_extraction_results - if results and is_namespace( - self.namespace, NameSpace.KV_STORE_LLM_RESPONSE_CACHE - ): - processed_results = [] - for row in results: - create_time = row.get("create_time", 0) - update_time = row.get("update_time", 0) - # Map field names and add cache_type for compatibility - processed_row = { - **row, - "return": row.get("return_value", ""), - "cache_type": row.get("cache_type"), - "original_prompt": row.get("original_prompt", ""), - "chunk_id": row.get("chunk_id"), - "mode": row.get("mode", "default"), - "create_time": create_time, - "update_time": create_time if update_time == 0 else update_time, - } - processed_results.append(processed_row) - return processed_results - - return results if results else [] - - async def filter_keys(self, keys: set[str]) -> set[str]: - """Filter out duplicated content""" - sql = SQL_TEMPLATES["filter_keys"].format( - table_name=namespace_to_table_name(self.namespace), - ids=",".join([f"'{id}'" for id in keys]), - ) - params = {"workspace": self.db.workspace} - try: - res = await self.db.query(sql, params, multirows=True) - if res: - exist_keys = [key["id"] for key in res] - else: - exist_keys = [] - new_keys = set([s for s in keys if s not in exist_keys]) - return new_keys - except Exception as e: - logger.error( - f"PostgreSQL database,\nsql:{sql},\nparams:{params},\nerror:{e}" - ) - raise - - ################ INSERT METHODS ################ - async def upsert(self, data: dict[str, dict[str, Any]]) -> None: - logger.debug(f"Inserting {len(data)} to {self.namespace}") - if not data: - return - - if 
is_namespace(self.namespace, NameSpace.KV_STORE_TEXT_CHUNKS): - # Get current UTC time and convert to naive datetime for database storage - current_time = datetime.datetime.now(timezone.utc).replace(tzinfo=None) - for k, v in data.items(): - upsert_sql = SQL_TEMPLATES["upsert_text_chunk"] - _data = { - "workspace": self.db.workspace, - "id": k, - "tokens": v["tokens"], - "chunk_order_index": v["chunk_order_index"], - "full_doc_id": v["full_doc_id"], - "content": v["content"], - "file_path": v["file_path"], - "llm_cache_list": json.dumps(v.get("llm_cache_list", [])), - "create_time": current_time, - "update_time": current_time, - } - await self.db.execute(upsert_sql, _data) - elif is_namespace(self.namespace, NameSpace.KV_STORE_FULL_DOCS): - for k, v in data.items(): - upsert_sql = SQL_TEMPLATES["upsert_doc_full"] - _data = { - "id": k, - "content": v["content"], - "workspace": self.db.workspace, - } - await self.db.execute(upsert_sql, _data) - elif is_namespace(self.namespace, NameSpace.KV_STORE_LLM_RESPONSE_CACHE): - for k, v in data.items(): - upsert_sql = SQL_TEMPLATES["upsert_llm_response_cache"] - _data = { - "workspace": self.db.workspace, - "id": k, # Use flattened key as id - "original_prompt": v["original_prompt"], - "return_value": v["return"], - "mode": v.get("mode", "default"), # Get mode from data - "chunk_id": v.get("chunk_id"), - "cache_type": v.get( - "cache_type", "extract" - ), # Get cache_type from data - } - - await self.db.execute(upsert_sql, _data) - - async def index_done_callback(self) -> None: - # PG handles persistence automatically - pass - - async def delete(self, ids: list[str]) -> None: - """Delete specific records from storage by their IDs - - Args: - ids (list[str]): List of document IDs to be deleted from storage - - Returns: - None - """ - if not ids: - return - - table_name = namespace_to_table_name(self.namespace) - if not table_name: - logger.error(f"Unknown namespace for deletion: {self.namespace}") - return - - delete_sql = f"DELETE FROM {table_name} WHERE workspace=$1 AND id = ANY($2)" - - try: - await self.db.execute( - delete_sql, {"workspace": self.db.workspace, "ids": ids} - ) - logger.debug( - f"Successfully deleted {len(ids)} records from {self.namespace}" - ) - except Exception as e: - logger.error(f"Error while deleting records from {self.namespace}: {e}") - - async def drop_cache_by_modes(self, modes: list[str] | None = None) -> bool: - """Delete specific records from storage by cache mode - - Args: - modes (list[str]): List of cache modes to be dropped from storage - - Returns: - bool: True if successful, False otherwise - """ - if not modes: - return False - - try: - table_name = namespace_to_table_name(self.namespace) - if not table_name: - return False - - if table_name != "LIGHTRAG_LLM_CACHE": - return False - - sql = f""" - DELETE FROM {table_name} - WHERE workspace = $1 AND mode = ANY($2) - """ - params = {"workspace": self.db.workspace, "modes": modes} - - logger.info(f"Deleting cache by modes: {modes}") - await self.db.execute(sql, params) - return True - except Exception as e: - logger.error(f"Error deleting cache by modes {modes}: {e}") - return False - - async def drop(self) -> dict[str, str]: - """Drop the storage""" - try: - table_name = namespace_to_table_name(self.namespace) - if not table_name: - return { - "status": "error", - "message": f"Unknown namespace: {self.namespace}", - } - - drop_sql = SQL_TEMPLATES["drop_specifiy_table_workspace"].format( - table_name=table_name - ) - await self.db.execute(drop_sql, {"workspace": 
self.db.workspace}) - return {"status": "success", "message": "data dropped"} - except Exception as e: - return {"status": "error", "message": str(e)} - - -@final -@dataclass -class PGVectorStorage(BaseVectorStorage): - db: PostgreSQLDB | None = field(default=None) - - def __post_init__(self): - self._max_batch_size = self.global_config["embedding_batch_num"] - self.base_namespace = self.global_config["vector_db_storage_cls_kwargs"].get("workspace") - config = self.global_config.get("vector_db_storage_cls_kwargs", {}) - cosine_threshold = config.get("cosine_better_than_threshold") - if cosine_threshold is None: - raise ValueError( - "cosine_better_than_threshold must be specified in vector_db_storage_cls_kwargs" - ) - self.cosine_better_than_threshold = cosine_threshold - - async def initialize(self): - if self.db is None: - #self.db = await ClientManager.get_client() - self.db = await ClientManager.get_client(self.global_config) - # Implement workspace priority: PostgreSQLDB.workspace > self.workspace > "default" - if self.db.workspace: - # Use PostgreSQLDB's workspace (highest priority) - final_workspace = self.db.workspace - elif hasattr(self, "workspace") and self.workspace: - # Use storage class's workspace (medium priority) - final_workspace = self.workspace - self.db.workspace = final_workspace - else: - # Use "default" for compatibility (lowest priority) - final_workspace = "default" - self.db.workspace = final_workspace - - async def finalize(self): - if self.db is not None: - await ClientManager.release_client(self.db) - self.db = None - - def _upsert_chunks( - self, item: dict[str, Any], current_time: datetime.datetime - ) -> tuple[str, dict[str, Any]]: - try: - upsert_sql = SQL_TEMPLATES["upsert_chunk"] - data: dict[str, Any] = { - "workspace": self.db.workspace, - "id": item["__id__"], - "tokens": item["tokens"], - "chunk_order_index": item["chunk_order_index"], - "full_doc_id": item["full_doc_id"], - "content": item["content"], - "content_vector": json.dumps(item["__vector__"].tolist()), - "file_path": item["file_path"], - "create_time": current_time, - "update_time": current_time, - } - except Exception as e: - logger.error(f"Error to prepare upsert,\nsql: {e}\nitem: {item}") - raise - - return upsert_sql, data - - def _upsert_entities( - self, item: dict[str, Any], current_time: datetime.datetime - ) -> tuple[str, dict[str, Any]]: - upsert_sql = SQL_TEMPLATES["upsert_entity"] - source_id = item["source_id"] - if isinstance(source_id, str) and "" in source_id: - chunk_ids = source_id.split("") - else: - chunk_ids = [source_id] - - data: dict[str, Any] = { - "workspace": self.db.workspace, - "id": item["__id__"], - "entity_name": item["entity_name"], - "content": item["content"], - "content_vector": json.dumps(item["__vector__"].tolist()), - "chunk_ids": chunk_ids, - "file_path": item.get("file_path", None), - "create_time": current_time, - "update_time": current_time, - } - return upsert_sql, data - - def _upsert_relationships( - self, item: dict[str, Any], current_time: datetime.datetime - ) -> tuple[str, dict[str, Any]]: - upsert_sql = SQL_TEMPLATES["upsert_relationship"] - source_id = item["source_id"] - if isinstance(source_id, str) and "" in source_id: - chunk_ids = source_id.split("") - else: - chunk_ids = [source_id] - - data: dict[str, Any] = { - "workspace": self.db.workspace, - "id": item["__id__"], - "source_id": item["src_id"], - "target_id": item["tgt_id"], - "content": item["content"], - "content_vector": json.dumps(item["__vector__"].tolist()), - "chunk_ids": 
chunk_ids, - "file_path": item.get("file_path", None), - "create_time": current_time, - "update_time": current_time, - } - return upsert_sql, data - - async def upsert(self, data: dict[str, dict[str, Any]]) -> None: - logger.debug(f"Inserting {len(data)} to {self.namespace}") - if not data: - return - - # Get current UTC time and convert to naive datetime for database storage - current_time = datetime.datetime.now(timezone.utc).replace(tzinfo=None) - list_data = [ - { - "__id__": k, - **{k1: v1 for k1, v1 in v.items()}, - } - for k, v in data.items() - ] - contents = [v["content"] for v in data.values()] - batches = [ - contents[i : i + self._max_batch_size] - for i in range(0, len(contents), self._max_batch_size) - ] - - embedding_tasks = [self.embedding_func(batch) for batch in batches] - embeddings_list = await asyncio.gather(*embedding_tasks) - - embeddings = np.concatenate(embeddings_list) - for i, d in enumerate(list_data): - d["__vector__"] = embeddings[i] - for item in list_data: - if is_namespace(self.namespace, NameSpace.VECTOR_STORE_CHUNKS): - upsert_sql, data = self._upsert_chunks(item, current_time) - elif is_namespace(self.namespace, NameSpace.VECTOR_STORE_ENTITIES): - upsert_sql, data = self._upsert_entities(item, current_time) - elif is_namespace(self.namespace, NameSpace.VECTOR_STORE_RELATIONSHIPS): - upsert_sql, data = self._upsert_relationships(item, current_time) - else: - raise ValueError(f"{self.namespace} is not supported") - - await self.db.execute(upsert_sql, data) - - #################### query method ############### - async def query( - self, query: str, top_k: int, ids: list[str] | None = None - ) -> list[dict[str, Any]]: - embeddings = await self.embedding_func( - [query], _priority=5 - ) # higher priority for query - embedding = embeddings[0] - embedding_string = ",".join(map(str, embedding)) - # Use parameterized document IDs (None means search across all documents) - sql = SQL_TEMPLATES[self.namespace].format(embedding_string=embedding_string) - params = { - "workspace": self.db.workspace, - "doc_ids": ids, - "better_than_threshold": self.cosine_better_than_threshold, - "top_k": top_k, - } - results = await self.db.query(sql, params=params, multirows=True) - return results - - async def index_done_callback(self) -> None: - # PG handles persistence automatically - pass - - async def delete(self, ids: list[str]) -> None: - """Delete vectors with specified IDs from the storage. - - Args: - ids: List of vector IDs to be deleted - """ - if not ids: - return - - table_name = namespace_to_table_name(self.namespace) - if not table_name: - logger.error(f"Unknown namespace for vector deletion: {self.namespace}") - return - - delete_sql = f"DELETE FROM {table_name} WHERE workspace=$1 AND id = ANY($2)" - - try: - await self.db.execute( - delete_sql, {"workspace": self.db.workspace, "ids": ids} - ) - logger.debug( - f"Successfully deleted {len(ids)} vectors from {self.namespace}" - ) - except Exception as e: - logger.error(f"Error while deleting vectors from {self.namespace}: {e}") - - async def delete_entity(self, entity_name: str) -> None: - """Delete an entity by its name from the vector storage. 
- - Args: - entity_name: The name of the entity to delete - """ - try: - # Construct SQL to delete the entity - delete_sql = """DELETE FROM LIGHTRAG_VDB_ENTITY - WHERE workspace=$1 AND entity_name=$2""" - - await self.db.execute( - delete_sql, {"workspace": self.db.workspace, "entity_name": entity_name} - ) - logger.debug(f"Successfully deleted entity {entity_name}") - except Exception as e: - logger.error(f"Error deleting entity {entity_name}: {e}") - - async def delete_entity_relation(self, entity_name: str) -> None: - """Delete all relations associated with an entity. - - Args: - entity_name: The name of the entity whose relations should be deleted - """ - try: - # Delete relations where the entity is either the source or target - delete_sql = """DELETE FROM LIGHTRAG_VDB_RELATION - WHERE workspace=$1 AND (source_id=$2 OR target_id=$2)""" - - await self.db.execute( - delete_sql, {"workspace": self.db.workspace, "entity_name": entity_name} - ) - logger.debug(f"Successfully deleted relations for entity {entity_name}") - except Exception as e: - logger.error(f"Error deleting relations for entity {entity_name}: {e}") - - async def get_by_id(self, id: str) -> dict[str, Any] | None: - """Get vector data by its ID - - Args: - id: The unique identifier of the vector - - Returns: - The vector data if found, or None if not found - """ - table_name = namespace_to_table_name(self.namespace) - if not table_name: - logger.error(f"Unknown namespace for ID lookup: {self.namespace}") - return None - - query = f"SELECT *, EXTRACT(EPOCH FROM create_time)::BIGINT as created_at FROM {table_name} WHERE workspace=$1 AND id=$2" - params = {"workspace": self.db.workspace, "id": id} - - try: - result = await self.db.query(query, params) - if result: - return dict(result) - return None - except Exception as e: - logger.error(f"Error retrieving vector data for ID {id}: {e}") - return None - - async def get_by_ids(self, ids: list[str]) -> list[dict[str, Any]]: - """Get multiple vector data by their IDs - - Args: - ids: List of unique identifiers - - Returns: - List of vector data objects that were found - """ - if not ids: - return [] - - table_name = namespace_to_table_name(self.namespace) - if not table_name: - logger.error(f"Unknown namespace for IDs lookup: {self.namespace}") - return [] - - ids_str = ",".join([f"'{id}'" for id in ids]) - query = f"SELECT *, EXTRACT(EPOCH FROM create_time)::BIGINT as created_at FROM {table_name} WHERE workspace=$1 AND id IN ({ids_str})" - params = {"workspace": self.db.workspace} - - try: - results = await self.db.query(query, params, multirows=True) - return [dict(record) for record in results] - except Exception as e: - logger.error(f"Error retrieving vector data for IDs {ids}: {e}") - return [] - - async def drop(self) -> dict[str, str]: - """Drop the storage""" - try: - table_name = namespace_to_table_name(self.namespace) - if not table_name: - return { - "status": "error", - "message": f"Unknown namespace: {self.namespace}", - } - - drop_sql = SQL_TEMPLATES["drop_specifiy_table_workspace"].format( - table_name=table_name - ) - await self.db.execute(drop_sql, {"workspace": self.db.workspace}) - return {"status": "success", "message": "data dropped"} - except Exception as e: - return {"status": "error", "message": str(e)} - - -@final -@dataclass -class PGDocStatusStorage(DocStatusStorage): - db: PostgreSQLDB = field(default=None) - - def _format_datetime_with_timezone(self, dt): - """Convert datetime to ISO format string with timezone info""" - if dt is None: - return None - 
# If no timezone info, assume it's UTC time (as stored in database) - if dt.tzinfo is None: - dt = dt.replace(tzinfo=timezone.utc) - # If datetime already has timezone info, keep it as is - return dt.isoformat() - - async def initialize(self): - if self.db is None: - #self.db = await ClientManager.get_client() - self.db = await ClientManager.get_client(self.global_config) - # Implement workspace priority: PostgreSQLDB.workspace > self.workspace > "default" - if self.db.workspace: - # Use PostgreSQLDB's workspace (highest priority) - final_workspace = self.db.workspace - elif hasattr(self, "workspace") and self.workspace: - # Use storage class's workspace (medium priority) - final_workspace = self.workspace - self.db.workspace = final_workspace - else: - # Use "default" for compatibility (lowest priority) - final_workspace = global_config["vector_db_storage_cls_kwargs"].get("workspace") - self.db.workspace = final_workspace - - async def finalize(self): - if self.db is not None: - await ClientManager.release_client(self.db) - self.db = None - - async def filter_keys(self, keys: set[str]) -> set[str]: - """Filter out duplicated content""" - sql = SQL_TEMPLATES["filter_keys"].format( - table_name=namespace_to_table_name(self.namespace), - ids=",".join([f"'{id}'" for id in keys]), - ) - params = {"workspace": self.db.workspace} - try: - res = await self.db.query(sql, params, multirows=True) - if res: - exist_keys = [key["id"] for key in res] - else: - exist_keys = [] - new_keys = set([s for s in keys if s not in exist_keys]) - # print(f"keys: {keys}") - # print(f"new_keys: {new_keys}") - return new_keys - except Exception as e: - logger.error( - f"PostgreSQL database,\nsql:{sql},\nparams:{params},\nerror:{e}" - ) - raise - - async def get_by_id(self, id: str) -> Union[dict[str, Any], None]: - sql = "select * from LIGHTRAG_DOC_STATUS where workspace=$1 and id=$2" - params = {"workspace": self.db.workspace, "id": id} - result = await self.db.query(sql, params, True) - if result is None or result == []: - return None - else: - # Parse chunks_list JSON string back to list - chunks_list = result[0].get("chunks_list", []) - if isinstance(chunks_list, str): - try: - chunks_list = json.loads(chunks_list) - except json.JSONDecodeError: - chunks_list = [] - - # Convert datetime objects to ISO format strings with timezone info - created_at = self._format_datetime_with_timezone(result[0]["created_at"]) - updated_at = self._format_datetime_with_timezone(result[0]["updated_at"]) - - return dict( - content=result[0]["content"], - content_length=result[0]["content_length"], - content_summary=result[0]["content_summary"], - status=result[0]["status"], - chunks_count=result[0]["chunks_count"], - created_at=created_at, - updated_at=updated_at, - file_path=result[0]["file_path"], - chunks_list=chunks_list, - ) - - async def get_by_ids(self, ids: list[str]) -> list[dict[str, Any]]: - """Get doc_chunks data by multiple IDs.""" - if not ids: - return [] - - sql = "SELECT * FROM LIGHTRAG_DOC_STATUS WHERE workspace=$1 AND id = ANY($2)" - params = {"workspace": self.db.workspace, "ids": ids} - - results = await self.db.query(sql, params, True) - - if not results: - return [] - - processed_results = [] - for row in results: - # Parse chunks_list JSON string back to list - chunks_list = row.get("chunks_list", []) - if isinstance(chunks_list, str): - try: - chunks_list = json.loads(chunks_list) - except json.JSONDecodeError: - chunks_list = [] - - # Convert datetime objects to ISO format strings with timezone info - 
created_at = self._format_datetime_with_timezone(row["created_at"]) - updated_at = self._format_datetime_with_timezone(row["updated_at"]) - - processed_results.append( - { - "content": row["content"], - "content_length": row["content_length"], - "content_summary": row["content_summary"], - "status": row["status"], - "chunks_count": row["chunks_count"], - "created_at": created_at, - "updated_at": updated_at, - "file_path": row["file_path"], - "chunks_list": chunks_list, - } - ) - - return processed_results - - async def get_status_counts(self) -> dict[str, int]: - """Get counts of documents in each status""" - sql = """SELECT status as "status", COUNT(1) as "count" - FROM LIGHTRAG_DOC_STATUS - where workspace=$1 GROUP BY STATUS - """ - result = await self.db.query(sql, {"workspace": self.db.workspace}, True) - counts = {} - for doc in result: - counts[doc["status"]] = doc["count"] - return counts - - async def get_docs_by_status( - self, status: DocStatus - ) -> dict[str, DocProcessingStatus]: - """all documents with a specific status""" - sql = "select * from LIGHTRAG_DOC_STATUS where workspace=$1 and status=$2" - params = {"workspace": self.db.workspace, "status": status.value} - result = await self.db.query(sql, params, True) - - docs_by_status = {} - for element in result: - # Parse chunks_list JSON string back to list - chunks_list = element.get("chunks_list", []) - if isinstance(chunks_list, str): - try: - chunks_list = json.loads(chunks_list) - except json.JSONDecodeError: - chunks_list = [] - - # Convert datetime objects to ISO format strings with timezone info - created_at = self._format_datetime_with_timezone(element["created_at"]) - updated_at = self._format_datetime_with_timezone(element["updated_at"]) - - docs_by_status[element["id"]] = DocProcessingStatus( - content=element["content"], - content_summary=element["content_summary"], - content_length=element["content_length"], - status=element["status"], - created_at=created_at, - updated_at=updated_at, - chunks_count=element["chunks_count"], - file_path=element["file_path"], - chunks_list=chunks_list, - ) - - return docs_by_status - - async def index_done_callback(self) -> None: - # PG handles persistence automatically - pass - - async def delete(self, ids: list[str]) -> None: - """Delete specific records from storage by their IDs - - Args: - ids (list[str]): List of document IDs to be deleted from storage - - Returns: - None - """ - if not ids: - return - - table_name = namespace_to_table_name(self.namespace) - if not table_name: - logger.error(f"Unknown namespace for deletion: {self.namespace}") - return - - delete_sql = f"DELETE FROM {table_name} WHERE workspace=$1 AND id = ANY($2)" - - try: - await self.db.execute( - delete_sql, {"workspace": self.db.workspace, "ids": ids} - ) - logger.debug( - f"Successfully deleted {len(ids)} records from {self.namespace}" - ) - except Exception as e: - logger.error(f"Error while deleting records from {self.namespace}: {e}") - - async def upsert(self, data: dict[str, dict[str, Any]]) -> None: - """Update or insert document status - - Args: - data: dictionary of document IDs and their status data - """ - logger.debug(f"Inserting {len(data)} to {self.namespace}") - if not data: - return - - def parse_datetime(dt_str): - """Parse datetime and ensure it's stored as UTC time in database""" - if dt_str is None: - return None - if isinstance(dt_str, (datetime.date, datetime.datetime)): - # If it's a datetime object - if isinstance(dt_str, datetime.datetime): - # If no timezone info, assume it's 
UTC - if dt_str.tzinfo is None: - dt_str = dt_str.replace(tzinfo=timezone.utc) - # Convert to UTC and remove timezone info for storage - return dt_str.astimezone(timezone.utc).replace(tzinfo=None) - return dt_str - try: - # Process ISO format string with timezone - dt = datetime.datetime.fromisoformat(dt_str) - # If no timezone info, assume it's UTC - if dt.tzinfo is None: - dt = dt.replace(tzinfo=timezone.utc) - # Convert to UTC and remove timezone info for storage - return dt.astimezone(timezone.utc).replace(tzinfo=None) - except (ValueError, TypeError): - logger.warning(f"Unable to parse datetime string: {dt_str}") - return None - - # Modified SQL to include created_at, updated_at, and chunks_list in both INSERT and UPDATE operations - # All fields are updated from the input data in both INSERT and UPDATE cases - sql = """insert into LIGHTRAG_DOC_STATUS(workspace,id,content,content_summary,content_length,chunks_count,status,file_path,chunks_list,created_at,updated_at) - values($1,$2,$3,$4,$5,$6,$7,$8,$9,$10,$11) - on conflict(id,workspace) do update set - content = EXCLUDED.content, - content_summary = EXCLUDED.content_summary, - content_length = EXCLUDED.content_length, - chunks_count = EXCLUDED.chunks_count, - status = EXCLUDED.status, - file_path = EXCLUDED.file_path, - chunks_list = EXCLUDED.chunks_list, - created_at = EXCLUDED.created_at, - updated_at = EXCLUDED.updated_at""" - for k, v in data.items(): - # Remove timezone information, store utc time in db - created_at = parse_datetime(v.get("created_at")) - updated_at = parse_datetime(v.get("updated_at")) - - # chunks_count and chunks_list are optional - await self.db.execute( - sql, - { - "workspace": self.db.workspace, - "id": k, - "content": v["content"], - "content_summary": v["content_summary"], - "content_length": v["content_length"], - "chunks_count": v["chunks_count"] if "chunks_count" in v else -1, - "status": v["status"], - "file_path": v["file_path"], - "chunks_list": json.dumps(v.get("chunks_list", [])), - "created_at": created_at, # Use the converted datetime object - "updated_at": updated_at, # Use the converted datetime object - }, - ) - - async def drop(self) -> dict[str, str]: - """Drop the storage""" - try: - table_name = namespace_to_table_name(self.namespace) - if not table_name: - return { - "status": "error", - "message": f"Unknown namespace: {self.namespace}", - } - - drop_sql = SQL_TEMPLATES["drop_specifiy_table_workspace"].format( - table_name=table_name - ) - await self.db.execute(drop_sql, {"workspace": self.db.workspace}) - return {"status": "success", "message": "data dropped"} - except Exception as e: - return {"status": "error", "message": str(e)} - - -class PGGraphQueryException(Exception): - """Exception for the AGE queries.""" - - def __init__(self, exception: Union[str, dict[str, Any]]) -> None: - if isinstance(exception, dict): - self.message = exception["message"] if "message" in exception else "unknown" - self.details = exception["details"] if "details" in exception else "unknown" - else: - self.message = exception - self.details = "unknown" - - def get_message(self) -> str: - return self.message - - def get_details(self) -> Any: - return self.details - - -@final -@dataclass -class PGGraphStorage(BaseGraphStorage): - def __post_init__(self): - # Graph name will be dynamically generated in initialize() based on workspace - self.db: PostgreSQLDB | None = None - - def _get_workspace_graph_name(self) -> str: - """ - Generate graph name based on workspace and namespace for data isolation. 
- Rules: - - If workspace is empty or "default": graph_name = namespace - - If workspace has other value: graph_name = workspace_namespace - - Args: - None - - Returns: - str: The graph name for the current workspace - """ - workspace = getattr(self, "workspace", None) - namespace = self.namespace - - if workspace and workspace.strip() and workspace.strip().lower() != "default": - # Ensure names comply with PostgreSQL identifier specifications - safe_workspace = re.sub(r"[^a-zA-Z0-9_]", "_", workspace.strip()) - safe_namespace = re.sub(r"[^a-zA-Z0-9_]", "_", namespace) - return f"{safe_workspace}_{safe_namespace}" - else: - # When workspace is empty or "default", use namespace directly - return re.sub(r"[^a-zA-Z0-9_]", "_", namespace) - - @staticmethod - def _normalize_node_id(node_id: str) -> str: - """ - Normalize node ID to ensure special characters are properly handled in Cypher queries. - - Args: - node_id: The original node ID - - Returns: - Normalized node ID suitable for Cypher queries - """ - # Escape backslashes - normalized_id = node_id - normalized_id = normalized_id.replace("\\", "\\\\") - normalized_id = normalized_id.replace('"', '\\"') - return normalized_id - - async def initialize(self): - if self.db is None: - self.db = await ClientManager.get_client() - # Implement workspace priority: PostgreSQLDB.workspace > self.workspace > None - if self.db.workspace: - # Use PostgreSQLDB's workspace (highest priority) - final_workspace = self.db.workspace - elif hasattr(self, "workspace") and self.workspace: - # Use storage class's workspace (medium priority) - final_workspace = self.workspace - self.db.workspace = final_workspace - else: - # Use None for compatibility (lowest priority) - final_workspace = None - self.db.workspace = final_workspace - - # Dynamically generate graph name based on workspace - self.workspace = self.db.workspace - self.graph_name = self._get_workspace_graph_name() - - # Log the graph initialization for debugging - logger.info( - f"PostgreSQL Graph initialized: workspace='{self.workspace}', graph_name='{self.graph_name}'" - ) - - # Execute each statement separately and ignore errors - queries = [ - f"SELECT create_graph('{self.graph_name}')", - f"SELECT create_vlabel('{self.graph_name}', 'base');", - f"SELECT create_elabel('{self.graph_name}', 'DIRECTED');", - # f'CREATE INDEX CONCURRENTLY vertex_p_idx ON {self.graph_name}."_ag_label_vertex" (id)', - f'CREATE INDEX CONCURRENTLY vertex_idx_node_id ON {self.graph_name}."_ag_label_vertex" (ag_catalog.agtype_access_operator(properties, \'"entity_id"\'::agtype))', - # f'CREATE INDEX CONCURRENTLY edge_p_idx ON {self.graph_name}."_ag_label_edge" (id)', - f'CREATE INDEX CONCURRENTLY edge_sid_idx ON {self.graph_name}."_ag_label_edge" (start_id)', - f'CREATE INDEX CONCURRENTLY edge_eid_idx ON {self.graph_name}."_ag_label_edge" (end_id)', - f'CREATE INDEX CONCURRENTLY edge_seid_idx ON {self.graph_name}."_ag_label_edge" (start_id,end_id)', - f'CREATE INDEX CONCURRENTLY directed_p_idx ON {self.graph_name}."DIRECTED" (id)', - f'CREATE INDEX CONCURRENTLY directed_eid_idx ON {self.graph_name}."DIRECTED" (end_id)', - f'CREATE INDEX CONCURRENTLY directed_sid_idx ON {self.graph_name}."DIRECTED" (start_id)', - f'CREATE INDEX CONCURRENTLY directed_seid_idx ON {self.graph_name}."DIRECTED" (start_id,end_id)', - f'CREATE INDEX CONCURRENTLY entity_p_idx ON {self.graph_name}."base" (id)', - f'CREATE INDEX CONCURRENTLY entity_idx_node_id ON {self.graph_name}."base" (ag_catalog.agtype_access_operator(properties, 
\'"entity_id"\'::agtype))', - f'CREATE INDEX CONCURRENTLY entity_node_id_gin_idx ON {self.graph_name}."base" using gin(properties)', - f'ALTER TABLE {self.graph_name}."DIRECTED" CLUSTER ON directed_sid_idx', - ] - - for query in queries: - # Use the new flag to silently ignore "already exists" errors - # at the source, preventing log spam. - await self.db.execute( - query, - upsert=True, - ignore_if_exists=True, # Pass the new flag - with_age=True, - graph_name=self.graph_name, - ) - - async def finalize(self): - if self.db is not None: - await ClientManager.release_client(self.db) - self.db = None - - async def index_done_callback(self) -> None: - # PG handles persistence automatically - pass - - @staticmethod - def _record_to_dict(record: asyncpg.Record) -> dict[str, Any]: - """ - Convert a record returned from an age query to a dictionary - - Args: - record (): a record from an age query result - - Returns: - dict[str, Any]: a dictionary representation of the record where - the dictionary key is the field name and the value is the - value converted to a python type - """ - # result holder - d = {} - - # prebuild a mapping of vertex_id to vertex mappings to be used - # later to build edges - vertices = {} - for k in record.keys(): - v = record[k] - # agtype comes back '{key: value}::type' which must be parsed - if isinstance(v, str) and "::" in v: - if v.startswith("[") and v.endswith("]"): - if "::vertex" not in v: - continue - v = v.replace("::vertex", "") - vertexes = json.loads(v) - for vertex in vertexes: - vertices[vertex["id"]] = vertex.get("properties") - else: - dtype = v.split("::")[-1] - v = v.split("::")[0] - if dtype == "vertex": - vertex = json.loads(v) - vertices[vertex["id"]] = vertex.get("properties") - - # iterate returned fields and parse appropriately - for k in record.keys(): - v = record[k] - if isinstance(v, str) and "::" in v: - if v.startswith("[") and v.endswith("]"): - if "::vertex" in v: - v = v.replace("::vertex", "") - d[k] = json.loads(v) - - elif "::edge" in v: - v = v.replace("::edge", "") - d[k] = json.loads(v) - else: - print("WARNING: unsupported type") - continue - - else: - dtype = v.split("::")[-1] - v = v.split("::")[0] - if dtype == "vertex": - d[k] = json.loads(v) - elif dtype == "edge": - d[k] = json.loads(v) - else: - d[k] = v # Keep as string - - return d - - @staticmethod - def _format_properties( - properties: dict[str, Any], _id: Union[str, None] = None - ) -> str: - """ - Convert a dictionary of properties to a string representation that - can be used in a cypher query insert/merge statement. 
- - Args: - properties (dict[str,str]): a dictionary containing node/edge properties - _id (Union[str, None]): the id of the node or None if none exists - - Returns: - str: the properties dictionary as a properly formatted string - """ - props = [] - # wrap property key in backticks to escape - for k, v in properties.items(): - prop = f"`{k}`: {json.dumps(v)}" - props.append(prop) - if _id is not None and "id" not in properties: - props.append( - f"id: {json.dumps(_id)}" if isinstance(_id, str) else f"id: {_id}" - ) - return "{" + ", ".join(props) + "}" - - async def _query( - self, - query: str, - readonly: bool = True, - upsert: bool = False, - ) -> list[dict[str, Any]]: - """ - Query the graph by taking a cypher query, converting it to an - age compatible query, executing it and converting the result - - Args: - query (str): a cypher query to be executed - - Returns: - list[dict[str, Any]]: a list of dictionaries containing the result set - """ - try: - if readonly: - data = await self.db.query( - query, - multirows=True, - with_age=True, - graph_name=self.graph_name, - ) - else: - data = await self.db.execute( - query, - upsert=upsert, - with_age=True, - graph_name=self.graph_name, - ) - - except Exception as e: - raise PGGraphQueryException( - { - "message": f"Error executing graph query: {query}", - "wrapped": query, - "detail": str(e), - } - ) from e - - if data is None: - result = [] - # decode records - else: - result = [self._record_to_dict(d) for d in data] - - return result - - async def has_node(self, node_id: str) -> bool: - entity_name_label = self._normalize_node_id(node_id) - - query = """SELECT * FROM cypher('%s', $$ - MATCH (n:base {entity_id: "%s"}) - RETURN count(n) > 0 AS node_exists - $$) AS (node_exists bool)""" % (self.graph_name, entity_name_label) - - single_result = (await self._query(query))[0] - - return single_result["node_exists"] - - async def has_edge(self, source_node_id: str, target_node_id: str) -> bool: - src_label = self._normalize_node_id(source_node_id) - tgt_label = self._normalize_node_id(target_node_id) - - query = """SELECT * FROM cypher('%s', $$ - MATCH (a:base {entity_id: "%s"})-[r]-(b:base {entity_id: "%s"}) - RETURN COUNT(r) > 0 AS edge_exists - $$) AS (edge_exists bool)""" % ( - self.graph_name, - src_label, - tgt_label, - ) - - single_result = (await self._query(query))[0] - - return single_result["edge_exists"] - - async def get_node(self, node_id: str) -> dict[str, str] | None: - """Get node by its label identifier, return only node properties""" - - label = self._normalize_node_id(node_id) - query = """SELECT * FROM cypher('%s', $$ - MATCH (n:base {entity_id: "%s"}) - RETURN n - $$) AS (n agtype)""" % (self.graph_name, label) - record = await self._query(query) - if record: - node = record[0] - node_dict = node["n"]["properties"] - - # Process string result, parse it to JSON dictionary - if isinstance(node_dict, str): - try: - node_dict = json.loads(node_dict) - except json.JSONDecodeError: - logger.warning(f"Failed to parse node string: {node_dict}") - - return node_dict - return None - - async def node_degree(self, node_id: str) -> int: - label = self._normalize_node_id(node_id) - - query = """SELECT * FROM cypher('%s', $$ - MATCH (n:base {entity_id: "%s"})-[r]-() - RETURN count(r) AS total_edge_count - $$) AS (total_edge_count integer)""" % (self.graph_name, label) - record = (await self._query(query))[0] - if record: - edge_count = int(record["total_edge_count"]) - return edge_count - - async def edge_degree(self, src_id: str, 
tgt_id: str) -> int: - src_degree = await self.node_degree(src_id) - trg_degree = await self.node_degree(tgt_id) - - # Convert None to 0 for addition - src_degree = 0 if src_degree is None else src_degree - trg_degree = 0 if trg_degree is None else trg_degree - - degrees = int(src_degree) + int(trg_degree) - - return degrees - - async def get_edge( - self, source_node_id: str, target_node_id: str - ) -> dict[str, str] | None: - """Get edge properties between two nodes""" - - src_label = self._normalize_node_id(source_node_id) - tgt_label = self._normalize_node_id(target_node_id) - - query = """SELECT * FROM cypher('%s', $$ - MATCH (a:base {entity_id: "%s"})-[r]-(b:base {entity_id: "%s"}) - RETURN properties(r) as edge_properties - LIMIT 1 - $$) AS (edge_properties agtype)""" % ( - self.graph_name, - src_label, - tgt_label, - ) - record = await self._query(query) - if record and record[0] and record[0]["edge_properties"]: - result = record[0]["edge_properties"] - - # Process string result, parse it to JSON dictionary - if isinstance(result, str): - try: - result = json.loads(result) - except json.JSONDecodeError: - logger.warning(f"Failed to parse edge string: {result}") - - return result - - async def get_node_edges(self, source_node_id: str) -> list[tuple[str, str]] | None: - """ - Retrieves all edges (relationships) for a particular node identified by its label. - :return: list of dictionaries containing edge information - """ - label = self._normalize_node_id(source_node_id) - - query = """SELECT * FROM cypher('%s', $$ - MATCH (n:base {entity_id: "%s"}) - OPTIONAL MATCH (n)-[]-(connected:base) - RETURN n.entity_id AS source_id, connected.entity_id AS connected_id - $$) AS (source_id text, connected_id text)""" % ( - self.graph_name, - label, - ) - - results = await self._query(query) - edges = [] - for record in results: - source_id = record["source_id"] - connected_id = record["connected_id"] - - if source_id and connected_id: - edges.append((source_id, connected_id)) - - return edges - - @retry( - stop=stop_after_attempt(3), - wait=wait_exponential(multiplier=1, min=4, max=10), - retry=retry_if_exception_type((PGGraphQueryException,)), - ) - async def upsert_node(self, node_id: str, node_data: dict[str, str]) -> None: - """ - Upsert a node in the Neo4j database. - - Args: - node_id: The unique identifier for the node (used as label) - node_data: Dictionary of node properties - """ - if "entity_id" not in node_data: - raise ValueError( - "PostgreSQL: node properties must contain an 'entity_id' field" - ) - - label = self._normalize_node_id(node_id) - properties = self._format_properties(node_data) - - query = """SELECT * FROM cypher('%s', $$ - MERGE (n:base {entity_id: "%s"}) - SET n += %s - RETURN n - $$) AS (n agtype)""" % ( - self.graph_name, - label, - properties, - ) - - try: - await self._query(query, readonly=False, upsert=True) - - except Exception: - logger.error(f"POSTGRES, upsert_node error on node_id: `{node_id}`") - raise - - @retry( - stop=stop_after_attempt(3), - wait=wait_exponential(multiplier=1, min=4, max=10), - retry=retry_if_exception_type((PGGraphQueryException,)), - ) - async def upsert_edge( - self, source_node_id: str, target_node_id: str, edge_data: dict[str, str] - ) -> None: - """ - Upsert an edge and its properties between two nodes identified by their labels. 
- - Args: - source_node_id (str): Label of the source node (used as identifier) - target_node_id (str): Label of the target node (used as identifier) - edge_data (dict): dictionary of properties to set on the edge - """ - src_label = self._normalize_node_id(source_node_id) - tgt_label = self._normalize_node_id(target_node_id) - edge_properties = self._format_properties(edge_data) - - query = """SELECT * FROM cypher('%s', $$ - MATCH (source:base {entity_id: "%s"}) - WITH source - MATCH (target:base {entity_id: "%s"}) - MERGE (source)-[r:DIRECTED]-(target) - SET r += %s - SET r += %s - RETURN r - $$) AS (r agtype)""" % ( - self.graph_name, - src_label, - tgt_label, - edge_properties, - edge_properties, # https://github.com/HKUDS/LightRAG/issues/1438#issuecomment-2826000195 - ) - - try: - await self._query(query, readonly=False, upsert=True) - - except Exception: - logger.error( - f"POSTGRES, upsert_edge error on edge: `{source_node_id}`-`{target_node_id}`" - ) - raise - - async def delete_node(self, node_id: str) -> None: - """ - Delete a node from the graph. - - Args: - node_id (str): The ID of the node to delete. - """ - label = self._normalize_node_id(node_id) - - query = """SELECT * FROM cypher('%s', $$ - MATCH (n:base {entity_id: "%s"}) - DETACH DELETE n - $$) AS (n agtype)""" % (self.graph_name, label) - - try: - await self._query(query, readonly=False) - except Exception as e: - logger.error("Error during node deletion: {%s}", e) - raise - - async def remove_nodes(self, node_ids: list[str]) -> None: - """ - Remove multiple nodes from the graph. - - Args: - node_ids (list[str]): A list of node IDs to remove. - """ - node_ids = [self._normalize_node_id(node_id) for node_id in node_ids] - node_id_list = ", ".join([f'"{node_id}"' for node_id in node_ids]) - - query = """SELECT * FROM cypher('%s', $$ - MATCH (n:base) - WHERE n.entity_id IN [%s] - DETACH DELETE n - $$) AS (n agtype)""" % (self.graph_name, node_id_list) - - try: - await self._query(query, readonly=False) - except Exception as e: - logger.error("Error during node removal: {%s}", e) - raise - - async def remove_edges(self, edges: list[tuple[str, str]]) -> None: - """ - Remove multiple edges from the graph. - - Args: - edges (list[tuple[str, str]]): A list of edges to remove, where each edge is a tuple of (source_node_id, target_node_id). - """ - for source, target in edges: - src_label = self._normalize_node_id(source) - tgt_label = self._normalize_node_id(target) - - query = """SELECT * FROM cypher('%s', $$ - MATCH (a:base {entity_id: "%s"})-[r]-(b:base {entity_id: "%s"}) - DELETE r - $$) AS (r agtype)""" % (self.graph_name, src_label, tgt_label) - - try: - await self._query(query, readonly=False) - logger.debug(f"Deleted edge from '{source}' to '{target}'") - except Exception as e: - logger.error(f"Error during edge deletion: {str(e)}") - raise - - async def get_nodes_batch(self, node_ids: list[str]) -> dict[str, dict]: - """ - Retrieve multiple nodes in one query using UNWIND. - - Args: - node_ids: List of node entity IDs to fetch. - - Returns: - A dictionary mapping each node_id to its node data (or None if not found). 
- """ - if not node_ids: - return {} - - # Format node IDs for the query - formatted_ids = ", ".join( - ['"' + self._normalize_node_id(node_id) + '"' for node_id in node_ids] - ) - - query = """SELECT * FROM cypher('%s', $$ - UNWIND [%s] AS node_id - MATCH (n:base {entity_id: node_id}) - RETURN node_id, n - $$) AS (node_id text, n agtype)""" % (self.graph_name, formatted_ids) - - results = await self._query(query) - - # Build result dictionary - nodes_dict = {} - for result in results: - if result["node_id"] and result["n"]: - node_dict = result["n"]["properties"] - - # Process string result, parse it to JSON dictionary - if isinstance(node_dict, str): - try: - node_dict = json.loads(node_dict) - except json.JSONDecodeError: - logger.warning( - f"Failed to parse node string in batch: {node_dict}" - ) - - # Remove the 'base' label if present in a 'labels' property - # if "labels" in node_dict: - # node_dict["labels"] = [ - # label for label in node_dict["labels"] if label != "base" - # ] - - nodes_dict[result["node_id"]] = node_dict - - return nodes_dict - - async def node_degrees_batch(self, node_ids: list[str]) -> dict[str, int]: - """ - Retrieve the degree for multiple nodes in a single query using UNWIND. - Calculates the total degree by counting distinct relationships. - Uses separate queries for outgoing and incoming edges. - - Args: - node_ids: List of node labels (entity_id values) to look up. - - Returns: - A dictionary mapping each node_id to its degree (total number of relationships). - If a node is not found, its degree will be set to 0. - """ - if not node_ids: - return {} - - # Format node IDs for the query - formatted_ids = ", ".join( - ['"' + self._normalize_node_id(node_id) + '"' for node_id in node_ids] - ) - - outgoing_query = """SELECT * FROM cypher('%s', $$ - UNWIND [%s] AS node_id - MATCH (n:base {entity_id: node_id}) - OPTIONAL MATCH (n)-[r]->(a) - RETURN node_id, count(a) AS out_degree - $$) AS (node_id text, out_degree bigint)""" % ( - self.graph_name, - formatted_ids, - ) - - incoming_query = """SELECT * FROM cypher('%s', $$ - UNWIND [%s] AS node_id - MATCH (n:base {entity_id: node_id}) - OPTIONAL MATCH (n)<-[r]-(b) - RETURN node_id, count(b) AS in_degree - $$) AS (node_id text, in_degree bigint)""" % ( - self.graph_name, - formatted_ids, - ) - - outgoing_results = await self._query(outgoing_query) - incoming_results = await self._query(incoming_query) - - out_degrees = {} - in_degrees = {} - - for result in outgoing_results: - if result["node_id"] is not None: - out_degrees[result["node_id"]] = int(result["out_degree"]) - - for result in incoming_results: - if result["node_id"] is not None: - in_degrees[result["node_id"]] = int(result["in_degree"]) - - degrees_dict = {} - for node_id in node_ids: - out_degree = out_degrees.get(node_id, 0) - in_degree = in_degrees.get(node_id, 0) - degrees_dict[node_id] = out_degree + in_degree - - return degrees_dict - - async def edge_degrees_batch( - self, edges: list[tuple[str, str]] - ) -> dict[tuple[str, str], int]: - """ - Calculate the combined degree for each edge (sum of the source and target node degrees) - in batch using the already implemented node_degrees_batch. 
- - Args: - edges: List of (source_node_id, target_node_id) tuples - - Returns: - Dictionary mapping edge tuples to their combined degrees - """ - if not edges: - return {} - - # Use node_degrees_batch to get all node degrees efficiently - all_nodes = set() - for src, tgt in edges: - all_nodes.add(src) - all_nodes.add(tgt) - - node_degrees = await self.node_degrees_batch(list(all_nodes)) - - # Calculate edge degrees - edge_degrees_dict = {} - for src, tgt in edges: - src_degree = node_degrees.get(src, 0) - tgt_degree = node_degrees.get(tgt, 0) - edge_degrees_dict[(src, tgt)] = src_degree + tgt_degree - - return edge_degrees_dict - - async def get_edges_batch( - self, pairs: list[dict[str, str]] - ) -> dict[tuple[str, str], dict]: - """ - Retrieve edge properties for multiple (src, tgt) pairs in one query. - Get forward and backward edges seperately and merge them before return - - Args: - pairs: List of dictionaries, e.g. [{"src": "node1", "tgt": "node2"}, ...] - - Returns: - A dictionary mapping (src, tgt) tuples to their edge properties. - """ - if not pairs: - return {} - - src_nodes = [] - tgt_nodes = [] - for pair in pairs: - src_nodes.append(self._normalize_node_id(pair["src"])) - tgt_nodes.append(self._normalize_node_id(pair["tgt"])) - - src_array = ", ".join([f'"{src}"' for src in src_nodes]) - tgt_array = ", ".join([f'"{tgt}"' for tgt in tgt_nodes]) - - forward_query = f"""SELECT * FROM cypher('{self.graph_name}', $$ - WITH [{src_array}] AS sources, [{tgt_array}] AS targets - UNWIND range(0, size(sources)-1) AS i - MATCH (a:base {{entity_id: sources[i]}})-[r]->(b:base {{entity_id: targets[i]}}) - RETURN sources[i] AS source, targets[i] AS target, properties(r) AS edge_properties - $$) AS (source text, target text, edge_properties agtype)""" - - backward_query = f"""SELECT * FROM cypher('{self.graph_name}', $$ - WITH [{src_array}] AS sources, [{tgt_array}] AS targets - UNWIND range(0, size(sources)-1) AS i - MATCH (a:base {{entity_id: sources[i]}})<-[r]-(b:base {{entity_id: targets[i]}}) - RETURN sources[i] AS source, targets[i] AS target, properties(r) AS edge_properties - $$) AS (source text, target text, edge_properties agtype)""" - - forward_results = await self._query(forward_query) - backward_results = await self._query(backward_query) - - edges_dict = {} - - for result in forward_results: - if result["source"] and result["target"] and result["edge_properties"]: - edge_props = result["edge_properties"] - - # Process string result, parse it to JSON dictionary - if isinstance(edge_props, str): - try: - edge_props = json.loads(edge_props) - except json.JSONDecodeError: - logger.warning( - f"Failed to parse edge properties string: {edge_props}" - ) - continue - - edges_dict[(result["source"], result["target"])] = edge_props - - for result in backward_results: - if result["source"] and result["target"] and result["edge_properties"]: - edge_props = result["edge_properties"] - - # Process string result, parse it to JSON dictionary - if isinstance(edge_props, str): - try: - edge_props = json.loads(edge_props) - except json.JSONDecodeError: - logger.warning( - f"Failed to parse edge properties string: {edge_props}" - ) - continue - - edges_dict[(result["source"], result["target"])] = edge_props - - return edges_dict - - async def get_nodes_edges_batch( - self, node_ids: list[str] - ) -> dict[str, list[tuple[str, str]]]: - """ - Get all edges (both outgoing and incoming) for multiple nodes in a single batch operation. 
- - Args: - node_ids: List of node IDs to get edges for - - Returns: - Dictionary mapping node IDs to lists of (source, target) edge tuples - """ - if not node_ids: - return {} - - # Format node IDs for the query - formatted_ids = ", ".join( - ['"' + self._normalize_node_id(node_id) + '"' for node_id in node_ids] - ) - - outgoing_query = """SELECT * FROM cypher('%s', $$ - UNWIND [%s] AS node_id - MATCH (n:base {entity_id: node_id}) - OPTIONAL MATCH (n:base)-[]->(connected:base) - RETURN node_id, connected.entity_id AS connected_id - $$) AS (node_id text, connected_id text)""" % ( - self.graph_name, - formatted_ids, - ) - - incoming_query = """SELECT * FROM cypher('%s', $$ - UNWIND [%s] AS node_id - MATCH (n:base {entity_id: node_id}) - OPTIONAL MATCH (n:base)<-[]-(connected:base) - RETURN node_id, connected.entity_id AS connected_id - $$) AS (node_id text, connected_id text)""" % ( - self.graph_name, - formatted_ids, - ) - - outgoing_results = await self._query(outgoing_query) - incoming_results = await self._query(incoming_query) - - nodes_edges_dict = {node_id: [] for node_id in node_ids} - - for result in outgoing_results: - if result["node_id"] and result["connected_id"]: - nodes_edges_dict[result["node_id"]].append( - (result["node_id"], result["connected_id"]) - ) - - for result in incoming_results: - if result["node_id"] and result["connected_id"]: - nodes_edges_dict[result["node_id"]].append( - (result["connected_id"], result["node_id"]) - ) - - return nodes_edges_dict - - async def get_all_labels(self) -> list[str]: - """ - Get all labels (node IDs) in the graph. - - Returns: - list[str]: A list of all labels in the graph. - """ - query = ( - """SELECT * FROM cypher('%s', $$ - MATCH (n:base) - WHERE n.entity_id IS NOT NULL - RETURN DISTINCT n.entity_id AS label - ORDER BY n.entity_id - $$) AS (label text)""" - % self.graph_name - ) - - results = await self._query(query) - labels = [] - for result in results: - if result and isinstance(result, dict) and "label" in result: - labels.append(result["label"]) - return labels - - async def get_nodes_by_chunk_ids(self, chunk_ids: list[str]) -> list[dict]: - """ - Retrieves nodes from the graph that are associated with a given list of chunk IDs. - This method uses a Cypher query with UNWIND to efficiently find all nodes - where the `source_id` property contains any of the specified chunk IDs. - """ - # The string representation of the list for the cypher query - chunk_ids_str = json.dumps(chunk_ids) - - query = f""" - SELECT * FROM cypher('{self.graph_name}', $$ - UNWIND {chunk_ids_str} AS chunk_id - MATCH (n:base) - WHERE n.source_id IS NOT NULL AND chunk_id IN split(n.source_id, '{GRAPH_FIELD_SEP}') - RETURN n - $$) AS (n agtype); - """ - results = await self._query(query) - - # Build result list - nodes = [] - for result in results: - if result["n"]: - node_dict = result["n"]["properties"] - - # Process string result, parse it to JSON dictionary - if isinstance(node_dict, str): - try: - node_dict = json.loads(node_dict) - except json.JSONDecodeError: - logger.warning( - f"Failed to parse node string in batch: {node_dict}" - ) - - node_dict["id"] = node_dict["entity_id"] - nodes.append(node_dict) - - return nodes - - async def get_edges_by_chunk_ids(self, chunk_ids: list[str]) -> list[dict]: - """ - Retrieves edges from the graph that are associated with a given list of chunk IDs. - This method uses a Cypher query with UNWIND to efficiently find all edges - where the `source_id` property contains any of the specified chunk IDs. 
- """ - chunk_ids_str = json.dumps(chunk_ids) - - query = f""" - SELECT * FROM cypher('{self.graph_name}', $$ - UNWIND {chunk_ids_str} AS chunk_id - MATCH (a:base)-[r]-(b:base) - WHERE r.source_id IS NOT NULL AND chunk_id IN split(r.source_id, '{GRAPH_FIELD_SEP}') - RETURN DISTINCT r, startNode(r) AS source, endNode(r) AS target - $$) AS (edge agtype, source agtype, target agtype); - """ - results = await self._query(query) - edges = [] - if results: - for item in results: - edge_agtype = item["edge"]["properties"] - # Process string result, parse it to JSON dictionary - if isinstance(edge_agtype, str): - try: - edge_agtype = json.loads(edge_agtype) - except json.JSONDecodeError: - logger.warning( - f"Failed to parse edge string in batch: {edge_agtype}" - ) - - source_agtype = item["source"]["properties"] - # Process string result, parse it to JSON dictionary - if isinstance(source_agtype, str): - try: - source_agtype = json.loads(source_agtype) - except json.JSONDecodeError: - logger.warning( - f"Failed to parse node string in batch: {source_agtype}" - ) - - target_agtype = item["target"]["properties"] - # Process string result, parse it to JSON dictionary - if isinstance(target_agtype, str): - try: - target_agtype = json.loads(target_agtype) - except json.JSONDecodeError: - logger.warning( - f"Failed to parse node string in batch: {target_agtype}" - ) - - if edge_agtype and source_agtype and target_agtype: - edge_properties = edge_agtype - edge_properties["source"] = source_agtype["entity_id"] - edge_properties["target"] = target_agtype["entity_id"] - edges.append(edge_properties) - return edges - - async def _bfs_subgraph( - self, node_label: str, max_depth: int, max_nodes: int - ) -> KnowledgeGraph: - """ - Implements a true breadth-first search algorithm for subgraph retrieval. - This method is used as a fallback when the standard Cypher query is too slow - or when we need to guarantee BFS ordering. 
- - Args: - node_label: Label of the starting node - max_depth: Maximum depth of the subgraph - max_nodes: Maximum number of nodes to return - - Returns: - KnowledgeGraph object containing nodes and edges - """ - from collections import deque - - result = KnowledgeGraph() - visited_nodes = set() - visited_node_ids = set() - visited_edges = set() - visited_edge_pairs = set() - - # Get starting node data - label = self._normalize_node_id(node_label) - query = """SELECT * FROM cypher('%s', $$ - MATCH (n:base {entity_id: "%s"}) - RETURN id(n) as node_id, n - $$) AS (node_id bigint, n agtype)""" % (self.graph_name, label) - - node_result = await self._query(query) - if not node_result or not node_result[0].get("n"): - return result - - # Create initial KnowledgeGraphNode - start_node_data = node_result[0]["n"] - entity_id = start_node_data["properties"]["entity_id"] - internal_id = str(start_node_data["id"]) - - start_node = KnowledgeGraphNode( - id=internal_id, - labels=[entity_id], - properties=start_node_data["properties"], - ) - - # Initialize BFS queue, each element is a tuple of (node, depth) - queue = deque([(start_node, 0)]) - - visited_nodes.add(entity_id) - visited_node_ids.add(internal_id) - result.nodes.append(start_node) - - result.is_truncated = False - - # BFS search main loop - while queue: - # Get all nodes at the current depth - current_level_nodes = [] - current_depth = None - - # Determine current depth - if queue: - current_depth = queue[0][1] - - # Extract all nodes at current depth from the queue - while queue and queue[0][1] == current_depth: - node, depth = queue.popleft() - if depth > max_depth: - continue - current_level_nodes.append(node) - - if not current_level_nodes: - continue - - # Check depth limit - if current_depth > max_depth: - continue - - # Prepare node IDs list - node_ids = [node.labels[0] for node in current_level_nodes] - formatted_ids = ", ".join( - [f'"{self._normalize_node_id(node_id)}"' for node_id in node_ids] - ) - - # Construct batch query for outgoing edges - outgoing_query = f"""SELECT * FROM cypher('{self.graph_name}', $$ - UNWIND [{formatted_ids}] AS node_id - MATCH (n:base {{entity_id: node_id}}) - OPTIONAL MATCH (n)-[r]->(neighbor:base) - RETURN node_id AS current_id, - id(n) AS current_internal_id, - id(neighbor) AS neighbor_internal_id, - neighbor.entity_id AS neighbor_id, - id(r) AS edge_id, - r, - neighbor, - true AS is_outgoing - $$) AS (current_id text, current_internal_id bigint, neighbor_internal_id bigint, - neighbor_id text, edge_id bigint, r agtype, neighbor agtype, is_outgoing bool)""" - - # Construct batch query for incoming edges - incoming_query = f"""SELECT * FROM cypher('{self.graph_name}', $$ - UNWIND [{formatted_ids}] AS node_id - MATCH (n:base {{entity_id: node_id}}) - OPTIONAL MATCH (n)<-[r]-(neighbor:base) - RETURN node_id AS current_id, - id(n) AS current_internal_id, - id(neighbor) AS neighbor_internal_id, - neighbor.entity_id AS neighbor_id, - id(r) AS edge_id, - r, - neighbor, - false AS is_outgoing - $$) AS (current_id text, current_internal_id bigint, neighbor_internal_id bigint, - neighbor_id text, edge_id bigint, r agtype, neighbor agtype, is_outgoing bool)""" - - # Execute queries - outgoing_results = await self._query(outgoing_query) - incoming_results = await self._query(incoming_query) - - # Combine results - neighbors = outgoing_results + incoming_results - - # Create mapping from node ID to node object - node_map = {node.labels[0]: node for node in current_level_nodes} - - # Process all results in a 
single loop - for record in neighbors: - if not record.get("neighbor") or not record.get("r"): - continue - - # Get current node information - current_entity_id = record["current_id"] - current_node = node_map[current_entity_id] - - # Get neighbor node information - neighbor_entity_id = record["neighbor_id"] - neighbor_internal_id = str(record["neighbor_internal_id"]) - is_outgoing = record["is_outgoing"] - - # Determine edge direction - if is_outgoing: - source_id = current_node.id - target_id = neighbor_internal_id - else: - source_id = neighbor_internal_id - target_id = current_node.id - - if not neighbor_entity_id: - continue - - # Get edge and node information - b_node = record["neighbor"] - rel = record["r"] - edge_id = str(record["edge_id"]) - - # Create neighbor node object - neighbor_node = KnowledgeGraphNode( - id=neighbor_internal_id, - labels=[neighbor_entity_id], - properties=b_node["properties"], - ) - - # Sort entity_ids to ensure (A,B) and (B,A) are treated as the same edge - sorted_pair = tuple(sorted([current_entity_id, neighbor_entity_id])) - - # Create edge object - edge = KnowledgeGraphEdge( - id=edge_id, - type=rel["label"], - source=source_id, - target=target_id, - properties=rel["properties"], - ) - - if neighbor_internal_id in visited_node_ids: - # Add backward edge if neighbor node is already visited - if ( - edge_id not in visited_edges - and sorted_pair not in visited_edge_pairs - ): - result.edges.append(edge) - visited_edges.add(edge_id) - visited_edge_pairs.add(sorted_pair) - else: - if len(visited_node_ids) < max_nodes and current_depth < max_depth: - # Add new node to result and queue - result.nodes.append(neighbor_node) - visited_nodes.add(neighbor_entity_id) - visited_node_ids.add(neighbor_internal_id) - - # Add node to queue with incremented depth - queue.append((neighbor_node, current_depth + 1)) - - # Add forward edge - if ( - edge_id not in visited_edges - and sorted_pair not in visited_edge_pairs - ): - result.edges.append(edge) - visited_edges.add(edge_id) - visited_edge_pairs.add(sorted_pair) - else: - if current_depth < max_depth: - result.is_truncated = True - - return result - - async def get_knowledge_graph( - self, - node_label: str, - max_depth: int = 3, - max_nodes: int = None, - ) -> KnowledgeGraph: - """ - Retrieve a connected subgraph of nodes where the label includes the specified `node_label`. 
- - Args: - node_label: Label of the starting node, * means all nodes - max_depth: Maximum depth of the subgraph, Defaults to 3 - max_nodes: Maximum nodes to return, Defaults to global_config max_graph_nodes - - Returns: - KnowledgeGraph object containing nodes and edges, with an is_truncated flag - indicating whether the graph was truncated due to max_nodes limit - """ - # Use global_config max_graph_nodes as default if max_nodes is None - if max_nodes is None: - max_nodes = self.global_config.get("max_graph_nodes", 1000) - else: - # Limit max_nodes to not exceed global_config max_graph_nodes - max_nodes = min(max_nodes, self.global_config.get("max_graph_nodes", 1000)) - kg = KnowledgeGraph() - - # Handle wildcard query - get all nodes - if node_label == "*": - # First check total node count to determine if graph should be truncated - count_query = f"""SELECT * FROM cypher('{self.graph_name}', $$ - MATCH (n:base) - RETURN count(distinct n) AS total_nodes - $$) AS (total_nodes bigint)""" - - count_result = await self._query(count_query) - total_nodes = count_result[0]["total_nodes"] if count_result else 0 - is_truncated = total_nodes > max_nodes - - # Get max_nodes with highest degrees - query_nodes = f"""SELECT * FROM cypher('{self.graph_name}', $$ - MATCH (n:base) - OPTIONAL MATCH (n)-[r]->() - RETURN id(n) as node_id, count(r) as degree - $$) AS (node_id BIGINT, degree BIGINT) - ORDER BY degree DESC - LIMIT {max_nodes}""" - node_results = await self._query(query_nodes) - - node_ids = [str(result["node_id"]) for result in node_results] - - logger.info(f"Total nodes: {total_nodes}, Selected nodes: {len(node_ids)}") - - if node_ids: - formatted_ids = ", ".join(node_ids) - # Construct batch query for subgraph within max_nodes - query = f"""SELECT * FROM cypher('{self.graph_name}', $$ - WITH [{formatted_ids}] AS node_ids - MATCH (a) - WHERE id(a) IN node_ids - OPTIONAL MATCH (a)-[r]->(b) - WHERE id(b) IN node_ids - RETURN a, r, b - $$) AS (a AGTYPE, r AGTYPE, b AGTYPE)""" - results = await self._query(query) - - # Process query results, deduplicate nodes and edges - nodes_dict = {} - edges_dict = {} - for result in results: - # Process node a - if result.get("a") and isinstance(result["a"], dict): - node_a = result["a"] - node_id = str(node_a["id"]) - if node_id not in nodes_dict and "properties" in node_a: - nodes_dict[node_id] = KnowledgeGraphNode( - id=node_id, - labels=[node_a["properties"]["entity_id"]], - properties=node_a["properties"], - ) - - # Process node b - if result.get("b") and isinstance(result["b"], dict): - node_b = result["b"] - node_id = str(node_b["id"]) - if node_id not in nodes_dict and "properties" in node_b: - nodes_dict[node_id] = KnowledgeGraphNode( - id=node_id, - labels=[node_b["properties"]["entity_id"]], - properties=node_b["properties"], - ) - - # Process edge r - if result.get("r") and isinstance(result["r"], dict): - edge = result["r"] - edge_id = str(edge["id"]) - if edge_id not in edges_dict: - edges_dict[edge_id] = KnowledgeGraphEdge( - id=edge_id, - type=edge["label"], - source=str(edge["start_id"]), - target=str(edge["end_id"]), - properties=edge["properties"], - ) - - kg = KnowledgeGraph( - nodes=list(nodes_dict.values()), - edges=list(edges_dict.values()), - is_truncated=is_truncated, - ) - else: - # For single node query, use BFS algorithm - kg = await self._bfs_subgraph(node_label, max_depth, max_nodes) - - logger.info( - f"Subgraph query successful | Node count: {len(kg.nodes)} | Edge count: {len(kg.edges)}" - ) - else: - # For non-wildcard 
queries, use the BFS algorithm - kg = await self._bfs_subgraph(node_label, max_depth, max_nodes) - logger.info( - f"Subgraph query for '{node_label}' successful | Node count: {len(kg.nodes)} | Edge count: {len(kg.edges)}" - ) - - return kg - - async def drop(self) -> dict[str, str]: - """Drop the storage""" - try: - drop_query = f"""SELECT * FROM cypher('{self.graph_name}', $$ - MATCH (n) - DETACH DELETE n - $$) AS (result agtype)""" - - await self._query(drop_query, readonly=False) - return { - "status": "success", - "message": f"workspace '{self.workspace}' graph data dropped", - } - except Exception as e: - logger.error(f"Error dropping graph: {e}") - return {"status": "error", "message": str(e)} - - -NAMESPACE_TABLE_MAP = { - NameSpace.KV_STORE_FULL_DOCS: "LIGHTRAG_DOC_FULL", - NameSpace.KV_STORE_TEXT_CHUNKS: "LIGHTRAG_DOC_CHUNKS", - NameSpace.VECTOR_STORE_CHUNKS: "LIGHTRAG_VDB_CHUNKS", - NameSpace.VECTOR_STORE_ENTITIES: "LIGHTRAG_VDB_ENTITY", - NameSpace.VECTOR_STORE_RELATIONSHIPS: "LIGHTRAG_VDB_RELATION", - NameSpace.DOC_STATUS: "LIGHTRAG_DOC_STATUS", - NameSpace.KV_STORE_LLM_RESPONSE_CACHE: "LIGHTRAG_LLM_CACHE", -} - - -def namespace_to_table_name(namespace: str) -> str: - for k, v in NAMESPACE_TABLE_MAP.items(): - if is_namespace(namespace, k): - return v - - -TABLES = { - "LIGHTRAG_DOC_FULL": { - "ddl": """CREATE TABLE LIGHTRAG_DOC_FULL ( - id VARCHAR(255), - workspace VARCHAR(255), - doc_name VARCHAR(1024), - content TEXT, - meta JSONB, - create_time TIMESTAMP(0) DEFAULT CURRENT_TIMESTAMP, - update_time TIMESTAMP(0) DEFAULT CURRENT_TIMESTAMP, - CONSTRAINT LIGHTRAG_DOC_FULL_PK PRIMARY KEY (workspace, id) - )""" - }, - "LIGHTRAG_DOC_CHUNKS": { - "ddl": """CREATE TABLE LIGHTRAG_DOC_CHUNKS ( - id VARCHAR(255), - workspace VARCHAR(255), - full_doc_id VARCHAR(256), - chunk_order_index INTEGER, - tokens INTEGER, - content TEXT, - file_path TEXT NULL, - llm_cache_list JSONB NULL DEFAULT '[]'::jsonb, - create_time TIMESTAMP(0) DEFAULT CURRENT_TIMESTAMP, - update_time TIMESTAMP(0) DEFAULT CURRENT_TIMESTAMP, - CONSTRAINT LIGHTRAG_DOC_CHUNKS_PK PRIMARY KEY (workspace, id) - )""" - }, - "LIGHTRAG_VDB_CHUNKS": { - "ddl": """CREATE TABLE LIGHTRAG_VDB_CHUNKS ( - id VARCHAR(255), - workspace VARCHAR(255), - full_doc_id VARCHAR(256), - chunk_order_index INTEGER, - tokens INTEGER, - content TEXT, - content_vector VECTOR, - file_path TEXT NULL, - create_time TIMESTAMP(0) DEFAULT CURRENT_TIMESTAMP, - update_time TIMESTAMP(0) DEFAULT CURRENT_TIMESTAMP, - CONSTRAINT LIGHTRAG_VDB_CHUNKS_PK PRIMARY KEY (workspace, id) - )""" - }, - "LIGHTRAG_VDB_ENTITY": { - "ddl": """CREATE TABLE LIGHTRAG_VDB_ENTITY ( - id VARCHAR(255), - workspace VARCHAR(255), - entity_name VARCHAR(512), - content TEXT, - content_vector VECTOR, - create_time TIMESTAMP(0) DEFAULT CURRENT_TIMESTAMP, - update_time TIMESTAMP(0) DEFAULT CURRENT_TIMESTAMP, - chunk_ids VARCHAR(255)[] NULL, - file_path TEXT NULL, - CONSTRAINT LIGHTRAG_VDB_ENTITY_PK PRIMARY KEY (workspace, id) - )""" - }, - "LIGHTRAG_VDB_RELATION": { - "ddl": """CREATE TABLE LIGHTRAG_VDB_RELATION ( - id VARCHAR(255), - workspace VARCHAR(255), - source_id VARCHAR(512), - target_id VARCHAR(512), - content TEXT, - content_vector VECTOR, - create_time TIMESTAMP(0) DEFAULT CURRENT_TIMESTAMP, - update_time TIMESTAMP(0) DEFAULT CURRENT_TIMESTAMP, - chunk_ids VARCHAR(255)[] NULL, - file_path TEXT NULL, - CONSTRAINT LIGHTRAG_VDB_RELATION_PK PRIMARY KEY (workspace, id) - )""" - }, - "LIGHTRAG_LLM_CACHE": { - "ddl": """CREATE TABLE LIGHTRAG_LLM_CACHE ( - workspace varchar(255) NOT NULL, 
- id varchar(255) NOT NULL, - mode varchar(32) NOT NULL, - original_prompt TEXT, - return_value TEXT, - chunk_id VARCHAR(255) NULL, - create_time TIMESTAMP DEFAULT CURRENT_TIMESTAMP, - update_time TIMESTAMP DEFAULT CURRENT_TIMESTAMP, - CONSTRAINT LIGHTRAG_LLM_CACHE_PK PRIMARY KEY (workspace, mode, id) - )""" - }, - "LIGHTRAG_DOC_STATUS": { - "ddl": """CREATE TABLE LIGHTRAG_DOC_STATUS ( - workspace varchar(255) NOT NULL, - id varchar(255) NOT NULL, - content TEXT NULL, - content_summary varchar(255) NULL, - content_length int4 NULL, - chunks_count int4 NULL, - status varchar(64) NULL, - file_path TEXT NULL, - chunks_list JSONB NULL DEFAULT '[]'::jsonb, - created_at TIMESTAMP DEFAULT CURRENT_TIMESTAMP, - updated_at TIMESTAMP DEFAULT CURRENT_TIMESTAMP, - CONSTRAINT LIGHTRAG_DOC_STATUS_PK PRIMARY KEY (workspace, id) - )""" - }, -} - - -SQL_TEMPLATES = { - # SQL for KVStorage - "get_by_id_full_docs": """SELECT id, COALESCE(content, '') as content - FROM LIGHTRAG_DOC_FULL WHERE workspace=$1 AND id=$2 - """, - "get_by_id_text_chunks": """SELECT id, tokens, COALESCE(content, '') as content, - chunk_order_index, full_doc_id, file_path, - COALESCE(llm_cache_list, '[]'::jsonb) as llm_cache_list, - EXTRACT(EPOCH FROM create_time)::BIGINT as create_time, - EXTRACT(EPOCH FROM update_time)::BIGINT as update_time - FROM LIGHTRAG_DOC_CHUNKS WHERE workspace=$1 AND id=$2 - """, - "get_by_id_llm_response_cache": """SELECT id, original_prompt, return_value, mode, chunk_id, cache_type, - EXTRACT(EPOCH FROM create_time)::BIGINT as create_time, - EXTRACT(EPOCH FROM update_time)::BIGINT as update_time - FROM LIGHTRAG_LLM_CACHE WHERE workspace=$1 AND id=$2 - """, - "get_by_mode_id_llm_response_cache": """SELECT id, original_prompt, return_value, mode, chunk_id - FROM LIGHTRAG_LLM_CACHE WHERE workspace=$1 AND mode=$2 AND id=$3 - """, - "get_by_ids_full_docs": """SELECT id, COALESCE(content, '') as content - FROM LIGHTRAG_DOC_FULL WHERE workspace=$1 AND id IN ({ids}) - """, - "get_by_ids_text_chunks": """SELECT id, tokens, COALESCE(content, '') as content, - chunk_order_index, full_doc_id, file_path, - COALESCE(llm_cache_list, '[]'::jsonb) as llm_cache_list, - EXTRACT(EPOCH FROM create_time)::BIGINT as create_time, - EXTRACT(EPOCH FROM update_time)::BIGINT as update_time - FROM LIGHTRAG_DOC_CHUNKS WHERE workspace=$1 AND id IN ({ids}) - """, - "get_by_ids_llm_response_cache": """SELECT id, original_prompt, return_value, mode, chunk_id, cache_type, - EXTRACT(EPOCH FROM create_time)::BIGINT as create_time, - EXTRACT(EPOCH FROM update_time)::BIGINT as update_time - FROM LIGHTRAG_LLM_CACHE WHERE workspace=$1 AND id IN ({ids}) - """, - "filter_keys": "SELECT id FROM {table_name} WHERE workspace=$1 AND id IN ({ids})", - "upsert_doc_full": """INSERT INTO LIGHTRAG_DOC_FULL (id, content, workspace) - VALUES ($1, $2, $3) - ON CONFLICT (workspace,id) DO UPDATE - SET content = $2, update_time = CURRENT_TIMESTAMP - """, - "upsert_llm_response_cache": """INSERT INTO LIGHTRAG_LLM_CACHE(workspace,id,original_prompt,return_value,mode,chunk_id,cache_type) - VALUES ($1, $2, $3, $4, $5, $6, $7) - ON CONFLICT (workspace,mode,id) DO UPDATE - SET original_prompt = EXCLUDED.original_prompt, - return_value=EXCLUDED.return_value, - mode=EXCLUDED.mode, - chunk_id=EXCLUDED.chunk_id, - cache_type=EXCLUDED.cache_type, - update_time = CURRENT_TIMESTAMP - """, - "upsert_text_chunk": """INSERT INTO LIGHTRAG_DOC_CHUNKS (workspace, id, tokens, - chunk_order_index, full_doc_id, content, file_path, llm_cache_list, - create_time, update_time) - VALUES 
($1, $2, $3, $4, $5, $6, $7, $8, $9, $10) - ON CONFLICT (workspace,id) DO UPDATE - SET tokens=EXCLUDED.tokens, - chunk_order_index=EXCLUDED.chunk_order_index, - full_doc_id=EXCLUDED.full_doc_id, - content = EXCLUDED.content, - file_path=EXCLUDED.file_path, - llm_cache_list=EXCLUDED.llm_cache_list, - update_time = EXCLUDED.update_time - """, - # SQL for VectorStorage - "upsert_chunk": """INSERT INTO LIGHTRAG_VDB_CHUNKS (workspace, id, tokens, - chunk_order_index, full_doc_id, content, content_vector, file_path, - create_time, update_time) - VALUES ($1, $2, $3, $4, $5, $6, $7, $8, $9, $10) - ON CONFLICT (workspace,id) DO UPDATE - SET tokens=EXCLUDED.tokens, - chunk_order_index=EXCLUDED.chunk_order_index, - full_doc_id=EXCLUDED.full_doc_id, - content = EXCLUDED.content, - content_vector=EXCLUDED.content_vector, - file_path=EXCLUDED.file_path, - update_time = EXCLUDED.update_time - """, - "upsert_entity": """INSERT INTO LIGHTRAG_VDB_ENTITY (workspace, id, entity_name, content, - content_vector, chunk_ids, file_path, create_time, update_time) - VALUES ($1, $2, $3, $4, $5, $6::varchar[], $7, $8, $9) - ON CONFLICT (workspace,id) DO UPDATE - SET entity_name=EXCLUDED.entity_name, - content=EXCLUDED.content, - content_vector=EXCLUDED.content_vector, - chunk_ids=EXCLUDED.chunk_ids, - file_path=EXCLUDED.file_path, - update_time=EXCLUDED.update_time - """, - "upsert_relationship": """INSERT INTO LIGHTRAG_VDB_RELATION (workspace, id, source_id, - target_id, content, content_vector, chunk_ids, file_path, create_time, update_time) - VALUES ($1, $2, $3, $4, $5, $6, $7::varchar[], $8, $9, $10) - ON CONFLICT (workspace,id) DO UPDATE - SET source_id=EXCLUDED.source_id, - target_id=EXCLUDED.target_id, - content=EXCLUDED.content, - content_vector=EXCLUDED.content_vector, - chunk_ids=EXCLUDED.chunk_ids, - file_path=EXCLUDED.file_path, - update_time = EXCLUDED.update_time - """, - "relationships": """ - WITH relevant_chunks AS ( - SELECT id as chunk_id - FROM LIGHTRAG_VDB_CHUNKS - WHERE $2::varchar[] IS NULL OR full_doc_id = ANY($2::varchar[]) - ) - SELECT source_id as src_id, target_id as tgt_id, EXTRACT(EPOCH FROM create_time)::BIGINT as created_at - FROM ( - SELECT r.id, r.source_id, r.target_id, r.create_time, 1 - (r.content_vector <=> '[{embedding_string}]'::vector) as distance - FROM LIGHTRAG_VDB_RELATION r - JOIN relevant_chunks c ON c.chunk_id = ANY(r.chunk_ids) - WHERE r.workspace=$1 - ) filtered - WHERE distance>$3 - ORDER BY distance DESC - LIMIT $4 - """, - "entities": """ - WITH relevant_chunks AS ( - SELECT id as chunk_id - FROM LIGHTRAG_VDB_CHUNKS - WHERE $2::varchar[] IS NULL OR full_doc_id = ANY($2::varchar[]) - ) - SELECT entity_name, EXTRACT(EPOCH FROM create_time)::BIGINT as created_at FROM - ( - SELECT e.id, e.entity_name, e.create_time, 1 - (e.content_vector <=> '[{embedding_string}]'::vector) as distance - FROM LIGHTRAG_VDB_ENTITY e - JOIN relevant_chunks c ON c.chunk_id = ANY(e.chunk_ids) - WHERE e.workspace=$1 - ) as chunk_distances - WHERE distance>$3 - ORDER BY distance DESC - LIMIT $4 - """, - "chunks": """ - WITH relevant_chunks AS ( - SELECT id as chunk_id - FROM LIGHTRAG_VDB_CHUNKS - WHERE $2::varchar[] IS NULL OR full_doc_id = ANY($2::varchar[]) - ) - SELECT id, content, file_path, EXTRACT(EPOCH FROM create_time)::BIGINT as created_at FROM - ( - SELECT id, content, file_path, create_time, 1 - (content_vector <=> '[{embedding_string}]'::vector) as distance - FROM LIGHTRAG_VDB_CHUNKS - WHERE workspace=$1 - AND id IN (SELECT chunk_id FROM relevant_chunks) - ) as chunk_distances - 
-            WHERE distance>$3
-            ORDER BY distance DESC
-            LIMIT $4
-    """,
-    # DROP tables
-    "drop_specifiy_table_workspace": """
-       DELETE FROM {table_name} WHERE workspace=$1
-    """,
-}
diff --git a/dsLightRag/Doc/T2、史校长资料/第一、二部分数与代数.txt b/dsLightRag/Doc/T1、史校长资料/第一、二部分数与代数.txt
similarity index 100%
rename from dsLightRag/Doc/T2、史校长资料/第一、二部分数与代数.txt
rename to dsLightRag/Doc/T1、史校长资料/第一、二部分数与代数.txt
diff --git a/dsLightRag/Doc/T2、史校长资料/第三部分图形与几何.txt b/dsLightRag/Doc/T1、史校长资料/第三部分图形与几何.txt
similarity index 100%
rename from dsLightRag/Doc/T2、史校长资料/第三部分图形与几何.txt
rename to dsLightRag/Doc/T1、史校长资料/第三部分图形与几何.txt
diff --git a/dsLightRag/Doc/T2、史校长资料/第四部分统计与概率.txt b/dsLightRag/Doc/T1、史校长资料/第四部分统计与概率.txt
similarity index 100%
rename from dsLightRag/Doc/T2、史校长资料/第四部分统计与概率.txt
rename to dsLightRag/Doc/T1、史校长资料/第四部分统计与概率.txt
diff --git a/dsLightRag/Doc/T2、史校长资料/说明.txt b/dsLightRag/Doc/T1、史校长资料/说明.txt
similarity index 100%
rename from dsLightRag/Doc/T2、史校长资料/说明.txt
rename to dsLightRag/Doc/T1、史校长资料/说明.txt
diff --git a/dsLightRag/Doc/下一步需要研究的技术内容.txt b/dsLightRag/Doc/下一步需要研究的技术内容.txt
index cc7fcf20..cb75675a 100644
--- a/dsLightRag/Doc/下一步需要研究的技术内容.txt
+++ b/dsLightRag/Doc/下一步需要研究的技术内容.txt
@@ -1,9 +1,13 @@
 I. Further study of LightRAG maintenance
 https://github.com/HKUDS/LightRAG/blob/main/README-zh.md
-(1) Edit entities and relations
-(2) Deletion functionality
-(3) Entity merging
-(4) Try using python+postgresql+age to draw the knowledge graph for a specified topic
+(1) Retrieval, maintenance, and visual presentation of the entities, chunks, and relations built from documents
+(2) Add entities and relations
+(3) Edit entities and relations
+(4) Entity merging (merging entities with different names, re-maintaining their descriptions, etc.)
+
+Huang Hai:
+Clear, well-designed relation maintenance is where this beats Huawei's offering, as far as I can see: Huawei's solution links entities directly to text chunks, which is clearly wrong;
+the three-way Entity + Chunk + Relation model used here is the sensible one.
 
 
 II. Two tools are needed for document ingestion:
@@ -12,10 +16,14 @@
 https://www.53ai.com/news/OpenSourceLLM/2024071482650.html
 
 2. PDF to Office
-Simple use of MinerU with Python
-https://blog.csdn.net/make_progress/article/details/145802697
+The PaddlePaddle (飞桨) pipeline that Wu Bin tested
+
+Wu Bin: the two items above need demos
 
-https://www.cnblogs.com/Rainy7/p/12275952.html
+III. Build a large model that can solve problems for junior-high mathematics
+(1) Using QWen Math, QWen VL, and QVQ
+Wu Bin: build a demo showing what progress we have made on junior-high math
 
-III. Using QWen Math and QVQ
+(2) Can we interact with the knowledge base over several rounds: decompose the problem step by step, extract the key points the knowledge base mentions, and guide the large model through multiple rounds of deep analysis, so as to improve its problem-solving ability?
+TODO
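Item III(2) above asks whether several rounds of knowledge-base interaction could improve problem solving. Below is a minimal, untested sketch of that loop, assuming the query interface shown in the LightRAG README (`LightRAG` plus `QueryParam(mode="hybrid")`). `llm_complete` is a hypothetical placeholder for whatever QWen Math / QWen VL / QVQ endpoint ends up being used, and the prompts and two-phase structure are illustrative only, not part of this commit.

```
# Illustrative sketch only: decompose a problem, query the LightRAG knowledge
# base once per sub-question, then ask the solver model for a grounded answer.
from lightrag import LightRAG, QueryParam


def llm_complete(prompt: str) -> str:
    """Hypothetical stand-in for a QWen Math / QWen VL / QVQ call."""
    raise NotImplementedError


def solve_with_kb(rag: LightRAG, problem: str, max_steps: int = 3) -> str:
    # Round 1: ask the model to break the problem into sub-questions.
    plan = llm_complete(
        "Decompose this junior-high math problem into at most "
        f"{max_steps} sub-questions, one per line:\n{problem}"
    )
    sub_questions = [s.strip() for s in plan.splitlines() if s.strip()][:max_steps]

    # Rounds 2..n: pull the relevant key points from the knowledge base
    # (query API as documented in the LightRAG README).
    notes = []
    for question in sub_questions:
        answer = rag.query(question, param=QueryParam(mode="hybrid"))
        notes.append(f"Sub-question: {question}\nKnowledge base says: {answer}")

    # Final round: solve step by step, citing only the retrieved key points.
    return llm_complete(
        "Solve the problem step by step, using only the key points below.\n\n"
        + "\n\n".join(notes)
        + f"\n\nProblem: {problem}"
    )
```

Whether the extra rounds actually help would have to be measured against the demo set mentioned in item III(1).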