You cannot select more than 25 topics. Topics must start with a letter or number, can include dashes ('-'), and can be up to 35 characters long.

245 lines
7.7 KiB

import os
import asyncio
import inspect
import logging
import logging.config
import numpy as np
from lightrag import LightRAG, QueryParam
from lightrag.llm.openai import openai_complete_if_cache, openai_embed
from lightrag.llm.ollama import ollama_embed
from lightrag.utils import EmbeddingFunc, logger, set_verbose_debug
from lightrag.kg.shared_storage import initialize_pipeline_status
from dotenv import load_dotenv
# Read settings from ./.env; variables already present in the process
# environment take precedence (override=False).
load_dotenv(".env", override=False)

# Directory passed to LightRAG as its working_dir (storage files live here).
WORKING_DIR = "./dickens"
def configure_logging():
    """Configure console and rotating-file logging for the demo.

    Clears any handlers/filters already attached to the uvicorn and lightrag
    loggers, then installs a ``dictConfig`` that routes the ``lightrag`` logger
    to stderr (short format) and to a size-rotated log file (detailed format).

    Environment variables:
        LOG_DIR: directory for ``lightrag_compatible_demo.log`` (default: the
            current working directory). Created if it does not exist.
        LOG_MAX_BYTES: rotation threshold in bytes (default 10485760 = 10 MB).
        LOG_BACKUP_COUNT: number of rotated files to keep (default 5).
        VERBOSE_DEBUG: "true" (case-insensitive) enables LightRAG verbose debug.

    Raises:
        ValueError: if LOG_MAX_BYTES or LOG_BACKUP_COUNT is not an integer.
    """
    # Reset any existing handlers so repeated configuration starts clean.
    for logger_name in ["uvicorn", "uvicorn.access", "uvicorn.error", "lightrag"]:
        logger_instance = logging.getLogger(logger_name)
        logger_instance.handlers = []
        logger_instance.filters = []

    # Log directory from the environment, falling back to the current directory.
    log_dir = os.getenv("LOG_DIR", os.getcwd())
    log_file_path = os.path.abspath(
        os.path.join(log_dir, "lightrag_compatible_demo.log")
    )
    print(f"\nLightRAG compatible demo log file: {log_file_path}\n")
    # Create the directory that actually holds the log file. Taking
    # dirname(log_dir) would only create its parent, and for a bare relative
    # name like "logs" it yields "" and os.makedirs("") raises.
    os.makedirs(os.path.dirname(log_file_path), exist_ok=True)

    # Rotation settings; int() raises ValueError on non-numeric values.
    log_max_bytes = int(os.getenv("LOG_MAX_BYTES", 10485760))  # Default 10MB
    log_backup_count = int(os.getenv("LOG_BACKUP_COUNT", 5))  # Default 5 backups

    logging.config.dictConfig(
        {
            "version": 1,
            "disable_existing_loggers": False,
            "formatters": {
                "default": {
                    "format": "%(levelname)s: %(message)s",
                },
                "detailed": {
                    "format": "%(asctime)s - %(name)s - %(levelname)s - %(message)s",
                },
            },
            "handlers": {
                "console": {
                    "formatter": "default",
                    "class": "logging.StreamHandler",
                    "stream": "ext://sys.stderr",
                },
                "file": {
                    "formatter": "detailed",
                    "class": "logging.handlers.RotatingFileHandler",
                    "filename": log_file_path,
                    "maxBytes": log_max_bytes,
                    "backupCount": log_backup_count,
                    "encoding": "utf-8",
                },
            },
            "loggers": {
                "lightrag": {
                    "handlers": ["console", "file"],
                    "level": "INFO",
                    # Keep lightrag records out of the root logger to avoid duplicates.
                    "propagate": False,
                },
            },
        }
    )

    # Set the (lightrag) logger level to INFO.
    logger.setLevel(logging.INFO)
    # Enable verbose debug output only when VERBOSE_DEBUG=true.
    set_verbose_debug(os.getenv("VERBOSE_DEBUG", "false").lower() == "true")
# Ensure LightRAG's working directory exists. exist_ok avoids the check-then-
# create race of exists()/mkdir() and also creates missing parent directories.
os.makedirs(WORKING_DIR, exist_ok=True)
async def llm_model_func(
    prompt, system_prompt=None, history_messages=None, keyword_extraction=False, **kwargs
) -> str:
    """Send a chat completion to DeepSeek via LightRAG's OpenAI-compatible helper.

    Args:
        prompt: user prompt text.
        system_prompt: optional system message.
        history_messages: prior chat messages; ``None`` (default) means none.
        keyword_extraction: accepted for LightRAG's callback signature but not
            forwarded; naming it here keeps it out of ``kwargs``.
        **kwargs: forwarded unchanged to ``openai_complete_if_cache``.

    Returns:
        The result of ``openai_complete_if_cache`` — presumably the completion
        text, or a stream when streaming is requested (verify for your version).

    Raises:
        RuntimeError: if no DeepSeek API key is set in the environment.
    """
    # A fresh list per call; a literal [] default would be shared across calls.
    if history_messages is None:
        history_messages = []

    # Never hard-code secrets: the key that used to be embedded here must be
    # treated as leaked and revoked. The file loads .env at import time, so
    # the key can live there instead.
    api_key = os.getenv("DEEPSEEK_API_KEY") or os.getenv("LLM_BINDING_API_KEY")
    if not api_key:
        raise RuntimeError(
            "No DeepSeek API key found: set DEEPSEEK_API_KEY "
            "(or LLM_BINDING_API_KEY) in the environment or .env file"
        )

    return await openai_complete_if_cache(
        os.getenv("LLM_MODEL", "deepseek-chat"),
        prompt,
        system_prompt=system_prompt,
        history_messages=history_messages,
        api_key=api_key,
        base_url="https://api.deepseek.com",
        **kwargs,
    )
async def print_stream(stream):
    """Echo every non-empty chunk of an async stream to stdout as it arrives.

    Chunks are written without separators and flushed immediately so that
    streamed answers appear incrementally.
    """
    async for piece in stream:
        if not piece:
            continue  # skip empty / falsy chunks
        print(piece, end="", flush=True)
async def embedding_func(texts: list[str]) -> np.ndarray:
    """Embed ``texts`` with BAAI/bge-m3 through SiliconFlow's OpenAI-compatible API.

    Args:
        texts: strings to embed.

    Returns:
        The array returned by ``openai_embed`` (one row per input text,
        presumably — verify against the lightrag version in use).

    Raises:
        RuntimeError: if no SiliconFlow API key is set in the environment.
    """
    # The key that used to be hard-coded here must be treated as leaked and
    # revoked; read it from the environment (.env is loaded at import time).
    api_key = os.getenv("SILICONFLOW_API_KEY") or os.getenv("EMBEDDING_BINDING_API_KEY")
    if not api_key:
        raise RuntimeError(
            "No SiliconFlow API key found: set SILICONFLOW_API_KEY "
            "(or EMBEDDING_BINDING_API_KEY) in the environment or .env file"
        )

    return await openai_embed(
        texts,
        model="BAAI/bge-m3",
        api_key=api_key,
        base_url="https://api.siliconflow.cn/v1",
    )
async def initialize_rag():
    """Create a ready-to-use LightRAG instance rooted at WORKING_DIR.

    Wires in ``llm_model_func`` for generation and ``embedding_func`` for
    vectors, then initialises storages and the shared pipeline status before
    returning the instance.
    """
    # embedding_dim must match the vector width produced by embedding_func.
    # To use a local Ollama embedder instead, point ``func`` at a wrapper
    # around ollama_embed (imported at the top of this file).
    embedder = EmbeddingFunc(
        embedding_dim=1024,
        max_token_size=8192,
        func=embedding_func,
    )
    instance = LightRAG(
        working_dir=WORKING_DIR,
        llm_model_func=llm_model_func,
        embedding_func=embedder,
    )

    await instance.initialize_storages()
    await initialize_pipeline_status()
    return instance
async def _run_query(rag, question, mode):
    """Print a banner for ``mode``, run ``question`` with streaming, and echo the answer."""
    print("\n=====================")
    print(f"Query mode: {mode}")
    print("=====================")
    resp = await rag.aquery(
        question,
        param=QueryParam(mode=mode, stream=True),
    )
    # With stream=True the answer may come back as an async generator.
    if inspect.isasyncgen(resp):
        await print_stream(resp)
    else:
        print(resp)


async def main():
    """Run the demo: an embedding self-test followed by a hybrid-mode query.

    Any exception is reported (message plus traceback) rather than propagated,
    and storages are always finalised if the RAG instance was created.
    """
    import traceback

    # Bind before ``try`` so ``finally`` never touches an unbound name when
    # initialize_rag() itself fails (that used to mask the real error).
    rag = None
    try:
        rag = await initialize_rag()

        # Sanity-check the embedding function and report its output width.
        test_text = ["This is a test string for embedding."]
        embedding = await rag.embedding_func(test_text)
        embedding_dim = embedding.shape[1]
        print("\n=======================")
        print("Test embedding function")
        print("========================")
        print(f"Test dict: {test_text}")
        print(f"Detected embedding dimension: {embedding_dim}\n\n")

        # To (re)build the index first, insert a document, e.g.:
        #     with open("./sushi.txt", "r", encoding="utf-8") as f:
        #         await rag.ainsert(f.read())
        # Other query modes used with this demo: "naive", "local", "global".
        await _run_query(rag, "苏轼与王安石是什么关系?", "hybrid")
    except Exception as e:
        # Top-level boundary: report, but keep the traceback for debugging.
        print(f"An error occurred: {e}")
        traceback.print_exc()
    finally:
        if rag is not None:
            await rag.finalize_storages()
if __name__ == "__main__":
    # Handlers must exist before main() runs so LightRAG's log output during
    # the demo reaches both the console and the log file.
    configure_logging()

    asyncio.run(main())

    print("\nDone!")