import numpy as np

from lightrag import LightRAG
from lightrag.kg.shared_storage import initialize_pipeline_status
from lightrag.llm.openai import openai_complete_if_cache, openai_embed
from lightrag.utils import EmbeddingFunc

# Star import supplies the LLM_*, VISION_*, and EMBED_* settings used below.
from Config.Config import *

def create_llm_model_func():
    """Build an LLM completion closure bound to the configured OpenAI-compatible endpoint."""

    def llm_model_func(prompt, system_prompt=None, history_messages=None, **kwargs):
        return openai_complete_if_cache(
            LLM_MODEL_NAME,
            prompt,
            system_prompt=system_prompt,
            # Avoid a mutable default argument; fall back to an empty history.
            history_messages=history_messages or [],
            api_key=LLM_API_KEY,
            base_url=LLM_BASE_URL,
            **kwargs,
        )

    return llm_model_func

def create_vision_model_func(llm_model_func):
    """Build a vision-capable closure; text-only calls fall through to the LLM closure."""

    def vision_model_func(
        prompt, system_prompt=None, history_messages=None, image_data=None, **kwargs
    ):
        if image_data:
            # Build the message list explicitly so that no None entry is sent
            # to the API when system_prompt is absent.
            messages = []
            if system_prompt:
                messages.append({"role": "system", "content": system_prompt})
            messages.append(
                {
                    "role": "user",
                    "content": [
                        {"type": "text", "text": prompt},
                        {
                            "type": "image_url",
                            "image_url": {
                                "url": f"data:image/jpeg;base64,{image_data}"
                            },
                        },
                    ],
                }
            )
            return openai_complete_if_cache(
                VISION_MODEL_NAME,
                "",
                system_prompt=None,
                history_messages=[],
                messages=messages,
                api_key=VISION_API_KEY,
                base_url=VISION_BASE_URL,
                **kwargs,
            )
        # No image attached: delegate to the plain LLM closure.
        return llm_model_func(prompt, system_prompt, history_messages or [], **kwargs)

    return vision_model_func

def create_embedding_func():
    """Build the embedding wrapper used by LightRAG.

    Note: the dimensions here are hard-coded, while initialize_rag() below reads
    EMBED_DIM / EMBED_MAX_TOKEN_SIZE from Config; keep the two in sync.
    """
    return EmbeddingFunc(
        embedding_dim=1024,
        max_token_size=8192,
        func=lambda texts: openai_embed(
            texts,
            model=EMBED_MODEL_NAME,
            api_key=EMBED_API_KEY,
            base_url=EMBED_BASE_URL,
        ),
    )
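
# A minimal usage sketch (not part of the original module) showing how the
# three factory closures above compose. "image.jpg", the prompts, and the
# helper name _demo_factories are illustrative placeholders; valid credentials
# in Config.Config are assumed.
async def _demo_factories():
    import base64

    llm = create_llm_model_func()
    vision = create_vision_model_func(llm)
    embedder = create_embedding_func()

    # Plain text completion via the LLM closure.
    print(await llm("Say hello in one word."))

    # Vision call: image bytes are passed as base64 (placeholder path).
    with open("image.jpg", "rb") as f:
        image_b64 = base64.b64encode(f.read()).decode()
    print(await vision("Describe this image.", image_data=image_b64))

    # EmbeddingFunc is awaitable and returns a numpy array of vectors.
    vectors = await embedder(["hello world"])
    print(vectors.shape)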

async def llm_model_func(
    prompt, system_prompt=None, history_messages=None, **kwargs
) -> str:
    return await openai_complete_if_cache(
        LLM_MODEL_NAME,
        prompt,
        system_prompt=system_prompt,
        history_messages=history_messages,
        api_key=LLM_API_KEY,
        base_url=LLM_BASE_URL,
        **kwargs,
    )

async def embedding_func(texts: list[str]) -> np.ndarray:
    return await openai_embed(
        texts,
        model=EMBED_MODEL_NAME,
        api_key=EMBED_API_KEY,
        base_url=EMBED_BASE_URL,
    )

async def initialize_rag(working_dir):
    """Create a LightRAG instance and initialize its storages and pipeline status."""
    rag = LightRAG(
        working_dir=working_dir,
        llm_model_func=llm_model_func,
        embedding_func=EmbeddingFunc(
            embedding_dim=EMBED_DIM,
            max_token_size=EMBED_MAX_TOKEN_SIZE,
            func=embedding_func,
        ),
    )

    # Storage backends and the shared pipeline status must be initialized
    # before the instance is used.
    await rag.initialize_storages()
    await initialize_pipeline_status()

    return rag

async def print_stream(stream):
    """Print streamed response chunks as they arrive."""
    async for chunk in stream:
        if chunk:
            print(chunk, end="", flush=True)
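
# A minimal end-to-end sketch (an assumption, not part of the original module):
# run as a script with valid settings in Config.Config. The working directory,
# inserted text, and query are placeholders; QueryParam(stream=True) makes
# aquery return an async generator, which print_stream above consumes.
if __name__ == "__main__":
    import asyncio
    import inspect

    from lightrag import QueryParam

    async def main():
        rag = await initialize_rag("./rag_storage")  # placeholder working dir
        await rag.ainsert("LightRAG builds a knowledge graph over inserted text.")
        resp = await rag.aquery(
            "What does LightRAG build?",
            param=QueryParam(mode="hybrid", stream=True),
        )
        if inspect.isasyncgen(resp):
            await print_stream(resp)
        else:
            print(resp)

    asyncio.run(main())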