diff --git a/dsRagAnything/Test.py b/dsRagAnything/T1_Train.py similarity index 96% rename from dsRagAnything/Test.py rename to dsRagAnything/T1_Train.py index 9382f74e..af0504ab 100644 --- a/dsRagAnything/Test.py +++ b/dsRagAnything/T1_Train.py @@ -28,7 +28,7 @@ if not logger.handlers: logger.addHandler(console_handler) # 循环滚动文件处理器(控制在200K左右) file_handler = RotatingFileHandler( - 'raganything.log', + 'lightrag.log', maxBytes=200*1024, # 200KB backupCount=5, # 最多保留5个备份文件 encoding='utf-8', diff --git a/dsRagAnything/T2_Query.py b/dsRagAnything/T2_Query.py new file mode 100644 index 00000000..bf91ebf0 --- /dev/null +++ b/dsRagAnything/T2_Query.py @@ -0,0 +1,46 @@ +import asyncio +import inspect +from Util.LightRagUtil import configure_logging, initialize_rag, print_stream +from lightrag import QueryParam + +# 化学 +data = [ + # {"NAME": "Chemistry", "Q": "硝酸光照分解的化学反应方程式是什么", "ChineseName": "化学"}, + {"NAME": "Chemistry", "Q": "氢气与氧气燃烧的现象", "ChineseName": "化学"}, + {"NAME": "Math", "Q": "氧化铁与硝酸的化学反应方程式是什么", "ChineseName": "数学"}, + {"NAME": "Chinese", "Q": "氧化铁与硝酸的化学反应方程式是什么", "ChineseName": "语文"}, + {"NAME": "JiHe", "Q": "三角形两边之和大于第三边的证明", "ChineseName": "几何"} +] + +# 准备查询的科目 +KEMU = "JiHe" # Chemistry JiHe + +# 查找索引号 +idx = [i for i, d in enumerate(data) if d["NAME"] == KEMU][0] + + +async def main(): + try: + user_prompt = "\n 1、资料中提供化学反应方程式的,一定要严格按提供的Latex公式输出,绝对不允许对Latex公式进行修改 !" + user_prompt = user_prompt + "\n 2、如果资料中提供了图片的,一定要严格按照原文提供图片输出,不允许省略或不输出!" + user_prompt = user_prompt + "\n 3、资料中提到的知识内容,需要判断是否与本次问题相关,不相关的绝对不要输出!" + rag = await initialize_rag('Topic/' + data[idx]["NAME"]) + resp = await rag.aquery( + data[idx]["Q"], + param=QueryParam(mode="hybrid", stream=True, user_prompt=user_prompt), + # hybrid naive + ) + if inspect.isasyncgen(resp): + await print_stream(resp) + else: + print(resp) + except Exception as e: + print(f"An error occurred: {e}") + finally: + if rag: + await rag.finalize_storages() + + +if __name__ == "__main__": + configure_logging() + asyncio.run(main()) diff --git a/dsRagAnything/Util/LightRagUtil.py b/dsRagAnything/Util/LightRagUtil.py new file mode 100644 index 00000000..838c678e --- /dev/null +++ b/dsRagAnything/Util/LightRagUtil.py @@ -0,0 +1,185 @@ +import logging +import logging.config +import os + +import numpy as np +from lightrag import LightRAG +from lightrag.kg.shared_storage import initialize_pipeline_status +from lightrag.llm.openai import openai_complete_if_cache, openai_embed +from lightrag.rerank import custom_rerank +from lightrag.utils import EmbeddingFunc, logger, set_verbose_debug + +import Config.Config +from Config.Config import * + + +async def print_stream(stream): + async for chunk in stream: + if chunk: + print(chunk, end="", flush=True) + + +def configure_logging(): + for logger_name in ["uvicorn", "uvicorn.access", "uvicorn.error", "lightrag"]: + logger_instance = logging.getLogger(logger_name) + logger_instance.handlers = [] + logger_instance.filters = [] + + log_dir = os.getenv("LOG_DIR", os.getcwd()) + log_file_path = os.path.abspath( + os.path.join(log_dir, "./Logs/lightrag.log") + ) + + print(f"\nLightRAG log file: {log_file_path}\n") + os.makedirs(os.path.dirname(log_dir), exist_ok=True) + + log_max_bytes = int(os.getenv("LOG_MAX_BYTES", 10485760)) + log_backup_count = int(os.getenv("LOG_BACKUP_COUNT", 5)) + + logging.config.dictConfig( + { + "version": 1, + "disable_existing_loggers": False, + "formatters": { + "default": { + "format": "%(levelname)s: %(message)s", + }, + "detailed": { + "format": "%(asctime)s - %(name)s - %(levelname)s - %(message)s", + }, + }, + "handlers": { + "console": { + "formatter": "default", + "class": "logging.StreamHandler", + "stream": "ext://sys.stderr", + }, + "file": { + "formatter": "detailed", + "class": "logging.handlers.RotatingFileHandler", + "filename": log_file_path, + "maxBytes": log_max_bytes, + "backupCount": log_backup_count, + "encoding": "utf-8", + }, + }, + "loggers": { + "lightrag": { + "handlers": ["console", "file"], + "level": "INFO", + "propagate": False, + }, + }, + } + ) + logger.setLevel(logging.INFO) + set_verbose_debug(os.getenv("VERBOSE_DEBUG", "false").lower() == "true") + + +async def llm_model_func( + prompt, system_prompt=None, history_messages=None, **kwargs +) -> str: + return await openai_complete_if_cache( + os.getenv("LLM_MODEL", Config.Config.ALY_LLM_MODEL_NAME), + prompt, + system_prompt=system_prompt, + history_messages=history_messages, + api_key=Config.Config.ALY_LLM_API_KEY, + base_url=Config.Config.ALY_LLM_BASE_URL, + **kwargs, + ) + + +async def embedding_func(texts: list[str]) -> np.ndarray: + return await openai_embed( + texts, + model=EMBED_MODEL_NAME, + api_key=EMBED_API_KEY, + base_url=EMBED_BASE_URL + ) + +async def rerank_func(query: str, documents: list, top_k: int = None, **kwargs): + return await custom_rerank( + query=query, + documents=documents, + model=Config.Config.RERANK_MODEL, + base_url=Config.Config.RERANK_BASE_URL, + api_key=Config.Config.RERANK_BINDING_API_KEY, + top_k=top_k or 10, + **kwargs, + ) + +async def initialize_rag(working_dir, graph_storage=None): + if graph_storage is None: + graph_storage = 'NetworkXStorage' + rag = LightRAG( + working_dir=working_dir, + llm_model_func=llm_model_func, + graph_storage=graph_storage, + embedding_func=EmbeddingFunc( + embedding_dim=EMBED_DIM, + max_token_size=EMBED_MAX_TOKEN_SIZE, + func=embedding_func + ), + rerank_model_func=rerank_func, + ) + await rag.initialize_storages() + await initialize_pipeline_status() + return rag + + +def create_llm_model_func(): + def llm_model_func(prompt, system_prompt=None, history_messages=[], **kwargs): + return openai_complete_if_cache( + Config.Config.ALY_LLM_MODEL_NAME, + prompt, + system_prompt=system_prompt, + history_messages=history_messages, + api_key=Config.Config.ALY_LLM_API_KEY, + base_url=Config.Config.ALY_LLM_BASE_URL, + **kwargs, + ) + + return llm_model_func + + +def create_embedding_func(): + return EmbeddingFunc( + embedding_dim=Config.Config.EMBED_DIM, + max_token_size=Config.Config.EMBED_MAX_TOKEN_SIZE, + func=lambda texts: openai_embed( + texts, + model=Config.Config.EMBED_MODEL_NAME, + api_key=Config.Config.EMBED_API_KEY, + base_url=Config.Config.EMBED_BASE_URL, + ), + ) + + + +async def initialize_pg_rag(WORKING_DIR, workspace='default'): + rag = LightRAG( + working_dir=WORKING_DIR, + llm_model_func=llm_model_func, + llm_model_name=Config.Config.ALY_LLM_MODEL_NAME, + llm_model_max_async=4, + #llm_model_max_token_size=32768, + enable_llm_cache_for_entity_extract=True, + embedding_func=EmbeddingFunc( + embedding_dim=Config.Config.EMBED_DIM, + max_token_size=Config.Config.EMBED_MAX_TOKEN_SIZE, + func=embedding_func + ), + rerank_model_func=rerank_func, + kv_storage="PGKVStorage", + doc_status_storage="PGDocStatusStorage", + graph_storage="PGGraphStorage", + vector_storage="PGVectorStorage", + auto_manage_storages_states=False, + vector_db_storage_cls_kwargs={"workspace": workspace} + ) + + await rag.initialize_storages() + await initialize_pipeline_status() + + return rag \ No newline at end of file diff --git a/dsRagAnything/Util/__init__.py b/dsRagAnything/Util/__init__.py new file mode 100644 index 00000000..e69de29b