main
HuangHai 5 days ago
parent bbb33fc77f
commit 43dbe63336

@ -19,7 +19,7 @@ from starlette.staticfiles import StaticFiles
from Util.LightRagUtil import *
from Util.PostgreSQLUtil import init_postgres_pool
# 想更详细地控制日志输出
# 日志输出
logger = logging.getLogger('lightrag')
logger.setLevel(logging.INFO)
handler = logging.StreamHandler()
@ -34,7 +34,7 @@ async def lifespan(app: FastAPI):
app = FastAPI(lifespan=lifespan)
# 挂载静态文件目录
app.mount("/static", StaticFiles(directory="Static"), name="static")
app.mount("../static", StaticFiles(directory="Static"), name="static")
# 访问根的跳转
@ -47,7 +47,6 @@ async def redirect_to_ai():
async def rag(request: fastapi.Request):
data = await request.json()
workspace = data.get("topic", "ShiJi") # Chinese, Math ,ShiJi 默认是少年读史记
logger.info("工作空间:" + workspace)
# 查询的问题
query = data.get("query")
@ -100,13 +99,12 @@ async def rag(request: fastapi.Request):
"""
# 使用PG库后这个是没有用的,但目前的项目代码要求必传,就写一个吧。
WORKING_DIR = 'WorkingPath/' + workspace
WORKING_DIR = 'output/' + workspace
if not os.path.exists(WORKING_DIR):
os.makedirs(WORKING_DIR)
async def generate_response_stream(query: str):
try:
logger.info("workspace=" + workspace)
rag = await initialize_pg_rag(WORKING_DIR=WORKING_DIR, workspace=workspace)
resp = await rag.aquery(
@ -285,10 +283,10 @@ async def render_html(request: fastapi.Request):
html_content = html_content.replace("```", "")
# 创建临时文件
filename = f"relation_{uuid.uuid4().hex}.html"
filepath = os.path.join('static/temp', filename)
filepath = os.path.join('../static/temp', filename)
# 确保temp目录存在
os.makedirs('static/temp', exist_ok=True)
os.makedirs('../static/temp', exist_ok=True)
# 写入文件
with open(filepath, 'w', encoding='utf-8') as f:

@ -25,12 +25,12 @@ tasks = [
# { # 化学
# "workspace": "Chemistry", "docx_name": "Chemistry.docx",
# },
#{ # 数学
# "workspace": "Math", "docx_name": "Math.docx",
#},
{ # 几何
"workspace": "JiHe", "docx_name": "JiHe.docx",
{ # 数学
"workspace": "Math", "docx_name": "Math.docx",
},
#{ # 几何
# "workspace": "JiHe", "docx_name": "JiHe.docx",
#},
# { # 史记
# "workspace": "ShiJi", "docx_name": "ShiJi.docx",
# },
@ -45,7 +45,7 @@ tasks = [
# }
]
for task in tasks:
task["docx_path"] = "./static/Txt/" + task["docx_name"] # 3、文档路径 python是按引用传递的&
task["docx_path"] = "../static/Txt/" + task["docx_name"] # 3、文档路径 python是按引用传递的&
async def main():

@ -14,21 +14,20 @@ LLM_MODEL_NAME = "deepseek-chat"
# LLM_MODEL_NAME = "deepseek-v3"
#LLM_MODEL_NAME = "deepseek-r1" # 使用更牛B的r1模型
# 免费的嵌入模型
EMBED_MODEL_NAME = "BAAI/bge-m3"
EMBED_API_KEY = "sk-pbqibyjwhrgmnlsmdygplahextfaclgnedetybccknxojlyl"
EMBED_BASE_URL = "https://api.siliconflow.cn/v1"
EMBED_DIM = 1024
EMBED_MAX_TOKEN_SIZE = 8192
# 图数据库
NEO4J_URI = "bolt://localhost:7687"
NEO4J_USERNAME = "neo4j"
NEO4J_PASSWORD = "DsideaL147258369"
NEO4J_AUTH = (NEO4J_USERNAME, NEO4J_PASSWORD)
# POSTGRESQL配置信息
AGE_GRAPH_NAME = "dickens"
POSTGRES_HOST = "10.10.14.208"
POSTGRES_PORT = 5432
POSTGRES_USER = "postgres"
POSTGRES_PASSWORD = "postgres"
POSTGRES_DATABASE = "rag"
# 免费的重排模型
RERANK_MODEL='BAAI/bge-reranker-v2-m3'
RERANK_BINDING_HOST='https://api.siliconflow.cn/v1/rerank'
RERANK_BINDING_API_KEY='sk-pbqibyjwhrgmnlsmdygplahextfaclgnedetybccknxojlyl'

@ -13,4 +13,7 @@ global.index-url='https://mirrors.aliyun.com/pypi/simple/'
pip freeze > requirements.txt
# 新机器安装包
pip install -r D:\dsWork\dsProject\dsRag\requirements.txt
pip install -r D:\dsWork\dsProject\dsRag\requirements.txt
# 更新指定的包
pip install --upgrade lightrag-hku

@ -14,7 +14,6 @@ from starlette.responses import StreamingResponse
from starlette.staticfiles import StaticFiles
from Util.LightRagUtil import *
from Util.PostgreSQLUtil import init_postgres_pool
# 更详细地控制日志输出
logger = logging.getLogger('lightrag')
@ -68,7 +67,7 @@ async def rag(request: fastapi.Request):
await initialize_pipeline_status()
resp = await rag.aquery(
query=query,
param=QueryParam(mode=mode, stream=True, user_prompt=user_prompt))
param=QueryParam(mode=mode, stream=True, user_prompt=user_prompt,enable_rerank=False))
# hybrid naive
async for chunk in resp:

File diff suppressed because one or more lines are too long

@ -5,7 +5,10 @@ import numpy as np
from lightrag import LightRAG
from lightrag.kg.shared_storage import initialize_pipeline_status
from lightrag.llm.openai import openai_complete_if_cache, openai_embed
from lightrag.rerank import jina_rerank
from lightrag.utils import EmbeddingFunc, logger, set_verbose_debug
import Config.Config
from Config.Config import *
@ -95,6 +98,18 @@ async def embedding_func(texts: list[str]) -> np.ndarray:
)
async def my_rerank_func(query: str, documents: list, top_k: int = None, **kwargs):
return await jina_rerank(
query=query,
documents=documents,
model=Config.Config.RERANK_MODEL,
base_url=RERANK_BINDING_HOST,
api_key=Config.Config.RERANK_BINDING_API_KEY,
top_k=top_k or 10,
**kwargs
)
async def initialize_rag(working_dir, graph_storage=None):
if graph_storage is None:
graph_storage = 'NetworkXStorage'
@ -107,6 +122,8 @@ async def initialize_rag(working_dir, graph_storage=None):
max_token_size=EMBED_MAX_TOKEN_SIZE,
func=embedding_func
),
# 重排模型
rerank_model_func=my_rerank_func,
)
await rag.initialize_storages()
@ -140,39 +157,4 @@ def create_embedding_func():
api_key=EMBED_API_KEY,
base_url=EMBED_BASE_URL,
),
)
# AGE
os.environ["AGE_GRAPH_NAME"] = AGE_GRAPH_NAME
os.environ["POSTGRES_HOST"] = POSTGRES_HOST
os.environ["POSTGRES_PORT"] = str(POSTGRES_PORT)
os.environ["POSTGRES_USER"] = POSTGRES_USER
os.environ["POSTGRES_PASSWORD"] = POSTGRES_PASSWORD
os.environ["POSTGRES_DATABASE"] = POSTGRES_DATABASE
async def initialize_pg_rag(WORKING_DIR, workspace):
rag = LightRAG(
working_dir=WORKING_DIR,
llm_model_func=llm_model_func,
llm_model_name=LLM_MODEL_NAME,
llm_model_max_async=4,
llm_model_max_token_size=32768,
enable_llm_cache_for_entity_extract=True,
embedding_func=EmbeddingFunc(
embedding_dim=EMBED_DIM,
max_token_size=EMBED_MAX_TOKEN_SIZE,
func=embedding_func
),
kv_storage="PGKVStorage",
doc_status_storage="PGDocStatusStorage",
graph_storage="PGGraphStorage",
vector_storage="PGVectorStorage",
auto_manage_storages_states=False,
vector_db_storage_cls_kwargs={"workspace": workspace})
await rag.initialize_storages()
await initialize_pipeline_status()
return rag
)

@ -1,67 +0,0 @@
"""
pip install asyncpg
"""
import asyncpg
from Config.Config import *
# PostgreSQL 配置
POSTGRES_CONFIG = {
"host": POSTGRES_HOST,
"port": POSTGRES_PORT,
"user": POSTGRES_USER,
"password": POSTGRES_PASSWORD,
"database": POSTGRES_DATABASE,
"min_size": 1, # 设置为0表示不保留空闲连接
"max_size": 20,
"command_timeout": 60
}
# 初始化 PostgreSQL 连接池
async def init_postgres_pool():
return await asyncpg.create_pool(**POSTGRES_CONFIG)
# 查询示例
async def query_example(pool):
async with pool.acquire() as conn:
# 执行查询
rows = await conn.fetch('SELECT * FROM your_table LIMIT 10')
# 处理结果
for row in rows:
print(dict(row))
return rows
# 插入示例
async def insert_example(pool, data):
async with pool.acquire() as conn:
# 执行插入
stmt = """
INSERT INTO your_table (column1, column2)
VALUES ($1, $2)
"""
# 使用参数化查询防止SQL注入
inserted_id = await conn.fetchval(stmt, data['value1'], data['value2'])
print(f"插入成功!")
return inserted_id
# 使用示例
async def main():
pool = await init_postgres_pool()
# 查询示例
await query_example(pool)
# 插入示例
#data_to_insert = {'value1': 'test1', 'value2': 'test2'}
#await insert_example(pool, data_to_insert)
await pool.close()
if __name__ == '__main__':
import asyncio
asyncio.run(main())

@ -1,37 +0,0 @@
from pyvis.network import Network
import json
from selenium import webdriver
import time
# 读取您的JSON知识图谱数据
with open(r'D:\dsWork\dsProject\dsLightRag\Doc\史校长资料\技术一-知识图谱源文件\middle_school_math_graph.json',encoding='utf-8') as f:
data = json.load(f)
# 创建网络图
net = Network(height='800px', width='100%', notebook=False)
# 添加节点
for node in data['nodes']:
net.add_node(node['id'], label=node['id'][:20]+'...' if len(node['id'])>20 else node['id'])
# 添加边(如果有的话)
if 'edges' in data:
for edge in data['edges']:
net.add_edge(edge['source'], edge['target'])
# 保存为HTML文件
html_path = 'knowledge_graph_json.html'
net.show(html_path)
# 使用selenium截图保存为PNG
options = webdriver.ChromeOptions()
options.add_argument('--headless')
options.add_argument('--disable-gpu')
options.add_argument('--window-size=1200,900')
driver = webdriver.Chrome(options=options)
driver.get(f'file:///{os.path.abspath(html_path)}')
time.sleep(3) # 等待页面加载
driver.save_screenshot('knowledge_graph_json.png')
driver.quit()
Loading…
Cancel
Save