main
HuangHai 5 days ago
parent 43dbe63336
commit 1e365914bd

@ -29,5 +29,5 @@ NEO4J_AUTH = (NEO4J_USERNAME, NEO4J_PASSWORD)
# Free rerank model
RERANK_MODEL='BAAI/bge-reranker-v2-m3'
RERANK_BINDING_HOST='https://api.siliconflow.cn/v1/rerank'
RERANK_BASE_URL='https://api.siliconflow.cn/v1/rerank'
RERANK_BINDING_API_KEY='sk-pbqibyjwhrgmnlsmdygplahextfaclgnedetybccknxojlyl'
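For a quick sanity check of these settings before they are wired into LightRAG, the same constants can be passed straight to custom_rerank (which this commit switches to later). A minimal sketch, assuming the values above are importable as Config.Config.RERANK_MODEL, Config.Config.RERANK_BASE_URL, and Config.Config.RERANK_BINDING_API_KEY, as the later hunks do; the sample documents are placeholders:

# Smoke-test the rerank endpoint configured above (sketch)
import asyncio

from lightrag.rerank import custom_rerank

import Config.Config


async def check_rerank() -> None:
    docs = [
        {"content": "Reranking reorders retrieved chunks by relevance."},
        {"content": "Completely unrelated text about cooking."},
    ]
    ranked = await custom_rerank(
        query="what does reranking do",
        documents=docs,
        model=Config.Config.RERANK_MODEL,
        base_url=Config.Config.RERANK_BASE_URL,
        api_key=Config.Config.RERANK_BINDING_API_KEY,
        top_k=2,
    )
    for doc in ranked:
        print(doc.get("rerank_score"), doc.get("content"))


if __name__ == "__main__":
    asyncio.run(check_rerank())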

@ -57,18 +57,13 @@ async def rag(request: fastapi.Request):
    async def generate_response_stream(query: str):
        try:
            rag = LightRAG(
                working_dir=WORKING_PATH,
                llm_model_func=create_llm_model_func(),
                embedding_func=create_embedding_func()
            )
            rag = await initialize_rag(WORKING_PATH)
            await rag.initialize_storages()
            await initialize_pipeline_status()
            resp = await rag.aquery(
                query=query,
                param=QueryParam(mode=mode, stream=True, user_prompt=user_prompt,enable_rerank=False))
            # hybrid naive
                param=QueryParam(mode=mode, stream=True, user_prompt=user_prompt, enable_rerank=True))
            async for chunk in resp:
                if not chunk:
@ -227,14 +222,12 @@ async def update_knowledge(request: fastapi.Request):
            ),
            node_id)
        return {"code": 0, "msg": "更新成功"}
    except Exception as e:
        logger.error(f"更新知识失败: {str(e)}")
        return {"code": 1, "msg": str(e)}


@app.post("/api/render_html")
async def render_html(request: fastapi.Request):
    data = await request.json()
@ -351,6 +344,5 @@ async def get_articles(page: int = 1, limit: int = 10):
        return {"code": 1, "msg": str(e)}


if __name__ == "__main__":
    uvicorn.run(app, host="0.0.0.0", port=8000)
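For context, generate_response_stream above only yields chunks; the handler still has to hand it to FastAPI as a streaming body. A minimal sketch of that wiring, assuming it lives in the same module (which already defines app, WORKING_PATH, initialize_rag, and QueryParam); the route path, request field names, and defaults are assumptions, since they are not visible in this diff:

from fastapi.responses import StreamingResponse


@app.post("/api/rag")  # assumed route path
async def rag(request: fastapi.Request):
    data = await request.json()
    query = data.get("query", "")
    mode = data.get("mode", "hybrid")
    user_prompt = data.get("user_prompt", "")

    async def generate_response_stream(q: str):
        rag_instance = await initialize_rag(WORKING_PATH)
        resp = await rag_instance.aquery(
            query=q,
            param=QueryParam(mode=mode, stream=True, user_prompt=user_prompt, enable_rerank=True),
        )
        async for chunk in resp:
            if chunk:
                yield chunk

    # Stream chunks back to the client as they arrive
    return StreamingResponse(generate_response_stream(query), media_type="text/event-stream")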

@ -0,0 +1,230 @@
"""
LightRAG Rerank Integration Example
This example demonstrates how to use rerank functionality with LightRAG
to improve retrieval quality across different query modes.
Configuration Required:
1. Set your LLM API key and base URL in llm_model_func()
2. Set your embedding API key and base URL in embedding_func()
3. Set your rerank API key and base URL in the rerank configuration
4. Or use environment variables (.env file):
- RERANK_MODEL=your_rerank_model
- RERANK_BINDING_HOST=your_rerank_endpoint
- RERANK_BINDING_API_KEY=your_rerank_api_key
Note: Rerank is now controlled per query via the 'enable_rerank' parameter (default: True)
"""
import asyncio
import os
import numpy as np
from lightrag import LightRAG, QueryParam
from lightrag.rerank import custom_rerank, RerankModel
from lightrag.llm.openai import openai_complete_if_cache, openai_embed
from lightrag.utils import EmbeddingFunc, setup_logger
from lightrag.kg.shared_storage import initialize_pipeline_status
# Set up your working directory
WORKING_DIR = "./test_rerank"
setup_logger("test_rerank")
if not os.path.exists(WORKING_DIR):
    os.mkdir(WORKING_DIR)


async def llm_model_func(
    prompt, system_prompt=None, history_messages=[], **kwargs
) -> str:
    return await openai_complete_if_cache(
        "gpt-4o-mini",
        prompt,
        system_prompt=system_prompt,
        history_messages=history_messages,
        api_key="your_llm_api_key_here",
        base_url="https://api.your-llm-provider.com/v1",
        **kwargs,
    )


async def embedding_func(texts: list[str]) -> np.ndarray:
    return await openai_embed(
        texts,
        model="text-embedding-3-large",
        api_key="your_embedding_api_key_here",
        base_url="https://api.your-embedding-provider.com/v1",
    )


async def my_rerank_func(query: str, documents: list, top_k: int = None, **kwargs):
    """Custom rerank function with all settings included"""
    return await custom_rerank(
        query=query,
        documents=documents,
        model="BAAI/bge-reranker-v2-m3",
        base_url="https://api.your-rerank-provider.com/v1/rerank",
        api_key="your_rerank_api_key_here",
        top_k=top_k or 10,  # Default top_k if not provided
        **kwargs,
    )


async def create_rag_with_rerank():
    """Create LightRAG instance with rerank configuration"""
    # Get embedding dimension
    test_embedding = await embedding_func(["test"])
    embedding_dim = test_embedding.shape[1]
    print(f"Detected embedding dimension: {embedding_dim}")

    # Method 1: Using custom rerank function
    rag = LightRAG(
        working_dir=WORKING_DIR,
        llm_model_func=llm_model_func,
        embedding_func=EmbeddingFunc(
            embedding_dim=embedding_dim,
            max_token_size=8192,
            func=embedding_func,
        ),
        # Rerank Configuration - provide the rerank function
        rerank_model_func=my_rerank_func,
    )

    await rag.initialize_storages()
    await initialize_pipeline_status()
    return rag


async def test_rerank_with_different_settings():
    """
    Test rerank functionality with different enable_rerank settings
    """
    print("🚀 Setting up LightRAG with Rerank functionality...")
    rag = await create_rag_with_rerank()

    # Insert sample documents
    sample_docs = [
        "Reranking improves retrieval quality by re-ordering documents based on relevance.",
        "LightRAG is a powerful retrieval-augmented generation system with multiple query modes.",
        "Vector databases enable efficient similarity search in high-dimensional embedding spaces.",
        "Natural language processing has evolved with large language models and transformers.",
        "Machine learning algorithms can learn patterns from data without explicit programming.",
    ]

    print("📄 Inserting sample documents...")
    await rag.ainsert(sample_docs)

    query = "How does reranking improve retrieval quality?"
    print(f"\n🔍 Testing query: '{query}'")
    print("=" * 80)

    # Test with rerank enabled (default)
    print("\n📊 Testing with enable_rerank=True (default):")
    result_with_rerank = await rag.aquery(
        query,
        param=QueryParam(
            mode="naive",
            top_k=10,
            chunk_top_k=5,
            enable_rerank=True,  # Explicitly enable rerank
        ),
    )
    print(f" Result length: {len(result_with_rerank)} characters")
    print(f" Preview: {result_with_rerank[:100]}...")

    # Test with rerank disabled
    print("\n📊 Testing with enable_rerank=False:")
    result_without_rerank = await rag.aquery(
        query,
        param=QueryParam(
            mode="naive",
            top_k=10,
            chunk_top_k=5,
            enable_rerank=False,  # Disable rerank
        ),
    )
    print(f" Result length: {len(result_without_rerank)} characters")
    print(f" Preview: {result_without_rerank[:100]}...")

    # Test with default settings (enable_rerank defaults to True)
    print("\n📊 Testing with default settings (enable_rerank defaults to True):")
    result_default = await rag.aquery(
        query, param=QueryParam(mode="naive", top_k=10, chunk_top_k=5)
    )
    print(f" Result length: {len(result_default)} characters")
    print(f" Preview: {result_default[:100]}...")


async def test_direct_rerank():
    """Test rerank function directly"""
    print("\n🔧 Direct Rerank API Test")
    print("=" * 40)

    documents = [
        {"content": "Reranking significantly improves retrieval quality"},
        {"content": "LightRAG supports advanced reranking capabilities"},
        {"content": "Vector search finds semantically similar documents"},
        {"content": "Natural language processing with modern transformers"},
        {"content": "The quick brown fox jumps over the lazy dog"},
    ]

    query = "rerank improve quality"
    print(f"Query: '{query}'")
    print(f"Documents: {len(documents)}")

    try:
        reranked_docs = await custom_rerank(
            query=query,
            documents=documents,
            model="BAAI/bge-reranker-v2-m3",
            base_url="https://api.your-rerank-provider.com/v1/rerank",
            api_key="your_rerank_api_key_here",
            top_k=3,
        )

        print("\n✅ Rerank Results:")
        for i, doc in enumerate(reranked_docs):
            score = doc.get("rerank_score", "N/A")
            content = doc.get("content", "")[:60]
            # Only apply the float format when a numeric score is present
            score_str = f"{score:.4f}" if isinstance(score, (int, float)) else str(score)
            print(f" {i+1}. Score: {score_str} | {content}...")
    except Exception as e:
        print(f"❌ Rerank failed: {e}")


async def main():
    """Main example function"""
    print("🎯 LightRAG Rerank Integration Example")
    print("=" * 60)

    try:
        # Test rerank with different enable_rerank settings
        await test_rerank_with_different_settings()

        # Test direct rerank
        await test_direct_rerank()

        print("\n✅ Example completed successfully!")
        print("\n💡 Key Points:")
        print(" ✓ Rerank is now controlled per query via 'enable_rerank' parameter")
        print(" ✓ Default value for enable_rerank is True")
        print(" ✓ Rerank function is configured at LightRAG initialization")
        print(" ✓ Per-query enable_rerank setting overrides default behavior")
        print(
            " ✓ If enable_rerank=True but no rerank model is configured, a warning is issued"
        )
        print(" ✓ Monitor API usage and costs when using rerank services")
    except Exception as e:
        print(f"\n❌ Example failed: {e}")
        import traceback

        traceback.print_exc()


if __name__ == "__main__":
    asyncio.run(main())
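The example imports RerankModel but only exercises the custom-function route ("Method 1"). A possible shape for a "Method 2" that binds provider settings once through RerankModel is sketched below; the field names (rerank_func, kwargs) and the rerank_model.rerank method are assumptions inferred from the import, so verify them against lightrag.rerank in your installed version:

# Method 2 (sketch, assuming it lives in this same module): bind rerank settings once
from lightrag import LightRAG
from lightrag.rerank import RerankModel, custom_rerank
from lightrag.utils import EmbeddingFunc

rerank_model = RerankModel(  # field names assumed; check lightrag.rerank
    rerank_func=custom_rerank,
    kwargs={
        "model": "BAAI/bge-reranker-v2-m3",
        "base_url": "https://api.your-rerank-provider.com/v1/rerank",
        "api_key": "your_rerank_api_key_here",
    },
)

rag = LightRAG(
    working_dir=WORKING_DIR,  # reuses WORKING_DIR and the funcs defined in the example above
    llm_model_func=llm_model_func,
    embedding_func=EmbeddingFunc(embedding_dim=3072, max_token_size=8192, func=embedding_func),
    rerank_model_func=rerank_model.rerank,  # assumed to be an awaitable rerank(query, documents, ...)
)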

File diff suppressed because one or more lines are too long

@ -524,5 +524,18 @@
"create_time": 1752817655,
"update_time": 1752817655,
"_id": "hybrid:keywords:4ac555f6e82f71e609e48a23a7d40ffa"
},
"hybrid:keywords:cd10e7df77fd85a76d0824323f22e4b0": {
"return": "{\"high_level_keywords\": [\"\\u6559\\u5b66\\u8bbe\\u8ba1\", \"\\u51e0\\u4f55\\u6982\\u5ff5\", \"\\u6570\\u5b66\\u6559\\u80b2\"], \"low_level_keywords\": [\"\\u70b9\", \"\\u7ebf\", \"\\u9762\", \"\\u4f53\", \"\\u89d2\"]}",
"cache_type": "keywords",
"chunk_id": null,
"embedding": null,
"embedding_shape": null,
"embedding_min": null,
"embedding_max": null,
"original_prompt": "帮我写一下 如何理解点、线、面、体、角 的教学设计",
"create_time": 1752819291,
"update_time": 1752819291,
"_id": "hybrid:keywords:cd10e7df77fd85a76d0824323f22e4b0"
}
}
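The "return" field in the new entry stores the keyword-extraction result as an escaped JSON string; the cached prompt asks for an instructional design on how to understand points, lines, planes, solids, and angles, and the escaped keywords decode to 教学设计 / 几何概念 / 数学教育 (high level) and 点 / 线 / 面 / 体 / 角 (low level). A small sketch of reading the entry back out of the cache file; the file name and the flat key layout are assumptions based on the hunk above:

import json

# Path to the LLM response cache shown above (assumed; lives in the RAG working directory)
CACHE_PATH = "./kv_store_llm_response_cache.json"

with open(CACHE_PATH, encoding="utf-8") as f:
    cache = json.load(f)

entry = cache["hybrid:keywords:cd10e7df77fd85a76d0824323f22e4b0"]
keywords = json.loads(entry["return"])   # the \uXXXX escapes decode to Chinese keywords
print(keywords["high_level_keywords"])   # ['教学设计', '几何概念', '数学教育']
print(keywords["low_level_keywords"])    # ['点', '线', '面', '体', '角']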

@ -5,7 +5,7 @@ import numpy as np
from lightrag import LightRAG
from lightrag.kg.shared_storage import initialize_pipeline_status
from lightrag.llm.openai import openai_complete_if_cache, openai_embed
from lightrag.rerank import jina_rerank
from lightrag.rerank import jina_rerank, custom_rerank
from lightrag.utils import EmbeddingFunc, logger, set_verbose_debug
import Config.Config
@ -98,18 +98,18 @@ async def embedding_func(texts: list[str]) -> np.ndarray:
    )


async def my_rerank_func(query: str, documents: list, top_k: int = None, **kwargs):
    return await jina_rerank(
async def rerank_func(query: str, documents: list, top_k: int = None, **kwargs):
    return await custom_rerank(
        query=query,
        documents=documents,
        model=Config.Config.RERANK_MODEL,
        base_url=RERANK_BINDING_HOST,
        base_url=Config.Config.RERANK_BASE_URL,
        api_key=Config.Config.RERANK_BINDING_API_KEY,
        top_k=top_k or 10,
        **kwargs
        **kwargs,
    )


async def initialize_rag(working_dir, graph_storage=None):
    if graph_storage is None:
        graph_storage = 'NetworkXStorage'
@ -122,8 +122,7 @@ async def initialize_rag(working_dir, graph_storage=None):
            max_token_size=EMBED_MAX_TOKEN_SIZE,
            func=embedding_func
        ),
        # Rerank model
        rerank_model_func=my_rerank_func,
        rerank_model_func=rerank_func,
    )

    await rag.initialize_storages()
@ -158,3 +157,15 @@ def create_embedding_func():
            base_url=EMBED_BASE_URL,
        ),
    )


# Jina-based rerank alternative (not used by initialize_rag above, which is wired to the
# custom_rerank-based rerank_func; a distinct name keeps that definition from being shadowed)
async def jina_rerank_func(query: str, documents: list, top_k: int = None, **kwargs):
    return await jina_rerank(
        query=query,
        documents=documents,
        model=Config.Config.RERANK_MODEL,
        base_url=Config.Config.RERANK_BASE_URL,
        api_key=Config.Config.RERANK_BINDING_API_KEY,
        top_k=top_k or 10,
        **kwargs
    )
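As a usage reference, initialize_rag above can be awaited with just a working directory (NetworkX graph storage by default) and then queried with enable_rerank toggled per query. A minimal sketch, assuming it runs in the same module; the "./kb" path is a placeholder, and the 'Neo4JStorage' backend name (suggested by the NEO4J_* settings in Config) should be verified against your LightRAG version:

import asyncio

from lightrag import QueryParam


async def demo() -> None:
    rag = await initialize_rag("./kb")  # default: NetworkXStorage
    # rag = await initialize_rag("./kb", graph_storage="Neo4JStorage")  # assumed Neo4j backend name

    answer = await rag.aquery(
        "How should points, lines, planes, solids, and angles be understood?",
        param=QueryParam(mode="hybrid", enable_rerank=True),
    )
    print(answer)


if __name__ == "__main__":
    asyncio.run(demo())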