You can not select more than 25 topics Topics must start with a letter or number, can include dashes ('-') and can be up to 35 characters long.

93 lines
2.8 KiB

import asyncio
import logging
import os
import time
from lightrag import LightRAG, QueryParam
from lightrag.utils import EmbeddingFunc
from lightrag.kg.shared_storage import initialize_pipeline_status
from Config.Config import EMBED_DIM, EMBED_MAX_TOKEN_SIZE, LLM_MODEL_NAME
from Util.LightRagUtil import embedding_func, llm_model_func
# Configure logging once, at program start. logging.basicConfig() does nothing
# once the root logger already has a handler, so a second call (with a
# different format) would silently be ignored; call it exactly once.
logging.basicConfig(
    level=logging.INFO,  # log level INFO
    format='%(asctime)s - %(name)s - %(levelname)s - %(message)s'
)

# Finer-grained control over the 'lightrag' logger: give it its own handler and
# stop propagation to the root logger, otherwise every record would be printed
# twice (once by this handler, once by the root handler installed above).
logger = logging.getLogger('lightrag')
logger.setLevel(logging.INFO)
handler = logging.StreamHandler()
handler.setFormatter(logging.Formatter('%(asctime)s - %(name)s - %(levelname)s - %(message)s'))
logger.addHandler(handler)
logger.propagate = False

ROOT_DIR = '.'
WORKING_DIR = f"{ROOT_DIR}/dickens-pg"
# exist_ok=True avoids the race between an existence check and mkdir, and also
# creates missing parent directories.
os.makedirs(WORKING_DIR, exist_ok=True)

# PostgreSQL / Apache AGE connection settings. setdefault() keeps any value
# already present in the environment, so credentials can be supplied
# externally instead of being edited here; the literals are only fallbacks.
# NOTE(review): presumably these are read by the PG* storage backends at
# connection time — not visible in this file.
os.environ.setdefault("AGE_GRAPH_NAME", "dickens")
os.environ.setdefault("POSTGRES_HOST", "10.10.14.208")
os.environ.setdefault("POSTGRES_PORT", "5432")
os.environ.setdefault("POSTGRES_USER", "postgres")
os.environ.setdefault("POSTGRES_PASSWORD", "postgres")
os.environ.setdefault("POSTGRES_DATABASE", "rag")
async def initialize_rag():
    """Create a LightRAG instance backed by the PostgreSQL storage classes.

    The instance is built with ``auto_manage_storages_states=False``, so this
    coroutine performs the setup explicitly: it initializes the instance's
    storages and then the shared pipeline status before handing it back.

    Returns:
        The initialized ``LightRAG`` instance.
    """
    embedder = EmbeddingFunc(
        embedding_dim=EMBED_DIM,
        max_token_size=EMBED_MAX_TOKEN_SIZE,
        func=embedding_func,
    )
    # Every storage role uses its PostgreSQL-backed implementation.
    pg_backends = {
        "kv_storage": "PGKVStorage",
        "doc_status_storage": "PGDocStatusStorage",
        "graph_storage": "PGGraphStorage",
        "vector_storage": "PGVectorStorage",
    }
    instance = LightRAG(
        working_dir=WORKING_DIR,
        llm_model_func=llm_model_func,
        llm_model_name=LLM_MODEL_NAME,
        llm_model_max_async=4,
        llm_model_max_token_size=32768,
        enable_llm_cache_for_entity_extract=True,
        embedding_func=embedder,
        auto_manage_storages_states=False,
        **pg_backends,
    )

    # Storage state is managed manually (see auto_manage_storages_states above).
    await instance.initialize_storages()
    await initialize_pipeline_status()
    return instance
async def main():
    """Run a timed "naive"-mode smoke-test query and print its answer.

    Failures are logged with a full traceback (previously they were silently
    swallowed). Because the RAG instance is created with
    ``auto_manage_storages_states=False``, its storages are finalized
    explicitly on the way out, whenever initialization got that far.
    """
    rag = None
    try:
        rag = await initialize_rag()
        # Ingestion is currently disabled; re-enable to (re)load the corpus:
        # with open(f"{ROOT_DIR}/book.txt", "r", encoding="utf-8") as f:
        #     await rag.ainsert(f.read())
        print("==== Trying to test the rag queries ====")
        print("**** Start Naive Query ****")
        # perf_counter is monotonic, unlike time.time(), so it suits elapsed timing.
        start_time = time.perf_counter()
        # Perform naive search
        response = await rag.aquery(
            "What are the top themes in this story?", param=QueryParam(mode="naive")
        )
        elapsed = time.perf_counter() - start_time
        print(response)
        print(f"Naive Query Time: {elapsed} seconds")
    except Exception:
        # Top-level boundary: record the failure with its traceback instead of
        # discarding it.
        logger.exception("Main execution error")
    finally:
        # Release storage resources; replaces the former fixed sleep(0.1).
        if rag is not None:
            await rag.finalize_storages()
# Script entry point: drive the async main() to completion on a fresh event loop.
if __name__ == "__main__":
    entry_coro = main()
    asyncio.run(entry_coro)