main
HuangHai 3 weeks ago
parent 31b80f0c19
commit 39afdbb181

@ -8,7 +8,6 @@ EMBED_MAX_TOKEN_SIZE = 8192
LLM_API_KEY="sk-44ae895eeb614aa1a9c6460579e322f1" LLM_API_KEY="sk-44ae895eeb614aa1a9c6460579e322f1"
LLM_BASE_URL = "https://api.deepseek.com" LLM_BASE_URL = "https://api.deepseek.com"
LLM_MODEL_NAME = "deepseek-chat" LLM_MODEL_NAME = "deepseek-chat"
# 视觉模型 # 视觉模型
VISION_API_KEY = "sk-pbqibyjwhrgmnlsmdygplahextfaclgnedetybccknxojlyl" VISION_API_KEY = "sk-pbqibyjwhrgmnlsmdygplahextfaclgnedetybccknxojlyl"
VISION_BASE_URL = "https://api.siliconflow.cn/v1/chat/completions" VISION_BASE_URL = "https://api.siliconflow.cn/v1/chat/completions"

@ -7,11 +7,13 @@ from logging.handlers import RotatingFileHandler
import fastapi import fastapi
import uvicorn import uvicorn
from fastapi import FastAPI from fastapi import FastAPI
from lightrag.kg.shared_storage import initialize_pipeline_status
from raganything import RAGAnything
from sse_starlette import EventSourceResponse from sse_starlette import EventSourceResponse
from starlette.staticfiles import StaticFiles from starlette.staticfiles import StaticFiles
from Util.RagUtil import initialize_rag from Util.RagUtil import initialize_rag, create_llm_model_func, create_vision_model_func, create_embedding_func
from lightrag import QueryParam from lightrag import QueryParam, LightRAG
# 初始化日志 # 初始化日志
logger = logging.getLogger(__name__) logger = logging.getLogger(__name__)
@ -65,12 +67,32 @@ app.mount("/static", StaticFiles(directory="Static"), name="static")
async def rag(request: fastapi.Request): async def rag(request: fastapi.Request):
data = await request.json() data = await request.json()
query = data.get("query") query = data.get("query")
lightrag_working_dir = "./rag_storage"
async def generate_response_stream(query: str): async def generate_response_stream(query: str):
try: try:
resp = await request.app.state.rag.aquery( llm_model_func = create_llm_model_func()
query=query, vision_model_func = create_vision_model_func(llm_model_func)
param=QueryParam(mode="hybrid", stream=True)) embedding_func = create_embedding_func()
lightrag_instance = LightRAG(
working_dir=lightrag_working_dir,
llm_model_func=llm_model_func,
embedding_func=embedding_func
)
await lightrag_instance.initialize_storages()
await initialize_pipeline_status()
rag = RAGAnything(
lightrag=lightrag_instance,
vision_model_func=vision_model_func,
)
resp = await rag.aquery(
"平台安全的保证方法有哪些?",
mode="hybrid"
)
print("查询结果:", resp)
async for chunk in resp: async for chunk in resp:
if not chunk: if not chunk:

@ -1,4 +1,3 @@
import httpx
import numpy as np import numpy as np
from lightrag import LightRAG from lightrag import LightRAG
from lightrag.kg.shared_storage import initialize_pipeline_status from lightrag.kg.shared_storage import initialize_pipeline_status

File diff suppressed because one or more lines are too long
Loading…
Cancel
Save