import json
import os
import configparser
from contextlib import asynccontextmanager
from dataclasses import dataclass
from typing import Any, final

import pipmaster as pm

# Install the redis client on demand if it is missing.
if not pm.is_installed("redis"):
    pm.install("redis")

# aioredis is a deprecated library and has been replaced by redis-py's
# built-in asyncio support.
from redis.asyncio import Redis, ConnectionPool  # type: ignore
from redis.exceptions import RedisError, ConnectionError  # type: ignore

from lightrag.base import BaseKVStorage
from lightrag.utils import logger

config = configparser.ConfigParser()
config.read("config.ini", encoding="utf-8")

# Constants for the Redis connection pool
MAX_CONNECTIONS = 50
SOCKET_TIMEOUT = 5.0  # seconds allowed per socket operation
SOCKET_CONNECT_TIMEOUT = 3.0  # seconds allowed to establish a connection
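
# The connection URL is resolved from the REDIS_URI environment variable when
# set, otherwise from config.ini (see __post_init__ below). An illustrative
# config.ini entry — a sketch, adjust the URI for your deployment:
#
#   [redis]
#   uri = redis://localhost:6379
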
@final
@dataclass
class RedisKVStorage(BaseKVStorage):
    def __post_init__(self):
        redis_url = os.environ.get(
            "REDIS_URI", config.get("redis", "uri", fallback="redis://localhost:6379")
        )
        # Create a connection pool with limits
        self._pool = ConnectionPool.from_url(
            redis_url,
            max_connections=MAX_CONNECTIONS,
            decode_responses=True,
            socket_timeout=SOCKET_TIMEOUT,
            socket_connect_timeout=SOCKET_CONNECT_TIMEOUT,
        )
        self._redis = Redis(connection_pool=self._pool)
        logger.info(
            f"Initialized Redis connection pool for {self.namespace} with max {MAX_CONNECTIONS} connections"
        )

    @asynccontextmanager
    async def _get_redis_connection(self):
        """Safe context manager for Redis operations."""
        try:
            yield self._redis
        except ConnectionError as e:
            logger.error(f"Redis connection error in {self.namespace}: {e}")
            raise
        except RedisError as e:
            logger.error(f"Redis operation error in {self.namespace}: {e}")
            raise
        except Exception as e:
            logger.error(
                f"Unexpected error in Redis operation for {self.namespace}: {e}"
            )
            raise

    async def close(self):
        """Close the Redis connection pool to prevent resource leaks."""
        if hasattr(self, "_redis") and self._redis:
            await self._redis.close()
            await self._pool.disconnect()
            logger.debug(f"Closed Redis connection pool for {self.namespace}")

    async def __aenter__(self):
        """Support for async context manager."""
        return self

    async def __aexit__(self, exc_type, exc_val, exc_tb):
        """Ensure Redis resources are cleaned up when exiting context."""
        await self.close()

    async def get_by_id(self, id: str) -> dict[str, Any] | None:
        async with self._get_redis_connection() as redis:
            try:
                data = await redis.get(f"{self.namespace}:{id}")
                return json.loads(data) if data else None
            except json.JSONDecodeError as e:
                logger.error(f"JSON decode error for id {id}: {e}")
                return None

    async def get_by_ids(self, ids: list[str]) -> list[dict[str, Any]]:
        async with self._get_redis_connection() as redis:
            try:
                # Batch the GETs in a pipeline to avoid one round trip per key.
                pipe = redis.pipeline()
                for id in ids:
                    pipe.get(f"{self.namespace}:{id}")
                results = await pipe.execute()
                return [json.loads(result) if result else None for result in results]
            except json.JSONDecodeError as e:
                logger.error(f"JSON decode error in batch get: {e}")
                return [None] * len(ids)

    async def filter_keys(self, keys: set[str]) -> set[str]:
        """Return the subset of keys that do not already exist in storage."""
        async with self._get_redis_connection() as redis:
            # Materialize the set so pipeline results can be matched back by
            # index (sets are unordered and not subscriptable).
            keys_list = list(keys)
            pipe = redis.pipeline()
            for key in keys_list:
                pipe.exists(f"{self.namespace}:{key}")
            results = await pipe.execute()
            existing_ids = {keys_list[i] for i, exists in enumerate(results) if exists}
            return keys - existing_ids

    async def upsert(self, data: dict[str, dict[str, Any]]) -> None:
        if not data:
            return
        logger.info(f"Inserting {len(data)} items to {self.namespace}")
        async with self._get_redis_connection() as redis:
            try:
                pipe = redis.pipeline()
                for k, v in data.items():
                    pipe.set(f"{self.namespace}:{k}", json.dumps(v))
                await pipe.execute()

                for k in data:
                    data[k]["_id"] = k
            except (TypeError, ValueError) as e:
                # json.dumps raises TypeError/ValueError on unserializable
                # values; the json module has no JSONEncodeError.
                logger.error(f"JSON encode error during upsert: {e}")
                raise

    async def index_done_callback(self) -> None:
        # Redis handles persistence automatically
        pass

    async def delete(self, ids: list[str]) -> None:
        """Delete entries with specified IDs"""
        if not ids:
            return

        async with self._get_redis_connection() as redis:
            pipe = redis.pipeline()
            for id in ids:
                pipe.delete(f"{self.namespace}:{id}")

            results = await pipe.execute()
            deleted_count = sum(results)
            logger.info(
                f"Deleted {deleted_count} of {len(ids)} entries from {self.namespace}"
            )

    async def drop_cache_by_modes(self, modes: list[str] | None = None) -> bool:
        """Delete specific records from storage by cache mode

        Important notes for Redis storage:
        1. This will immediately delete the specified cache modes from Redis

        Args:
            modes (list[str]): List of cache modes to be dropped from storage

        Returns:
            True: if the cache drop succeeded
            False: if the cache drop failed
        """
        if not modes:
            return False

        try:
            await self.delete(modes)
            return True
        except Exception:
            return False

    async def drop(self) -> dict[str, str]:
        """Drop the storage by removing all keys under the current namespace.

        Returns:
            dict[str, str]: Status of the operation with keys 'status' and 'message'
        """
        async with self._get_redis_connection() as redis:
            try:
                # NOTE: KEYS is O(N) over the whole keyspace and blocks the
                # server; acceptable for small datasets, but SCAN scales better.
                keys = await redis.keys(f"{self.namespace}:*")

                if keys:
                    pipe = redis.pipeline()
                    for key in keys:
                        pipe.delete(key)
                    results = await pipe.execute()
                    deleted_count = sum(results)

                    logger.info(f"Dropped {deleted_count} keys from {self.namespace}")
                    return {
                        "status": "success",
                        "message": f"{deleted_count} keys dropped",
                    }
                else:
                    logger.info(f"No keys found to drop in {self.namespace}")
                    return {"status": "success", "message": "no keys to drop"}
            except Exception as e:
                logger.error(f"Error dropping keys from {self.namespace}: {e}")
                return {"status": "error", "message": str(e)}