import logging
import logging.config
import os
import numpy as np
from lightrag import LightRAG
from lightrag.kg.shared_storage import initialize_pipeline_status
from lightrag.llm.openai import openai_complete_if_cache, openai_embed
from lightrag.utils import EmbeddingFunc, logger, set_verbose_debug
from openai import OpenAI
from Config.Config import *


async def print_stream(stream):
    """Print chunks from an async token stream as they arrive."""
    async for chunk in stream:
        if chunk:
            print(chunk, end="", flush=True)


def configure_logging():
    """Configure logging for LightRAG: console output plus a rotating log file."""
    # Reset handlers and filters left over from uvicorn or a previous run.
    for logger_name in ["uvicorn", "uvicorn.access", "uvicorn.error", "lightrag"]:
        logger_instance = logging.getLogger(logger_name)
        logger_instance.handlers = []
        logger_instance.filters = []

    log_dir = os.getenv("LOG_DIR", os.getcwd())
    log_file_path = os.path.abspath(os.path.join(log_dir, "Logs", "lightrag.log"))
    print(f"\nLightRAG log file: {log_file_path}\n")
    # Ensure the directory that will hold the log file exists.
    os.makedirs(os.path.dirname(log_file_path), exist_ok=True)

    log_max_bytes = int(os.getenv("LOG_MAX_BYTES", 10485760))  # 10 MiB per file
    log_backup_count = int(os.getenv("LOG_BACKUP_COUNT", 5))  # rotated files to keep

    logging.config.dictConfig(
        {
            "version": 1,
            "disable_existing_loggers": False,
            "formatters": {
                "default": {
                    "format": "%(levelname)s: %(message)s",
                },
                "detailed": {
                    "format": "%(asctime)s - %(name)s - %(levelname)s - %(message)s",
                },
            },
            "handlers": {
                "console": {
                    "formatter": "default",
                    "class": "logging.StreamHandler",
                    "stream": "ext://sys.stderr",
                },
                "file": {
                    "formatter": "detailed",
                    "class": "logging.handlers.RotatingFileHandler",
                    "filename": log_file_path,
                    "maxBytes": log_max_bytes,
                    "backupCount": log_backup_count,
                    "encoding": "utf-8",
                },
            },
            "loggers": {
                "lightrag": {
                    "handlers": ["console", "file"],
                    "level": "INFO",
                    "propagate": False,
                },
            },
        }
    )

    logger.setLevel(logging.INFO)
    set_verbose_debug(os.getenv("VERBOSE_DEBUG", "false").lower() == "true")
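

# A minimal startup sketch (not part of the original file; the LOG_DIR value is
# hypothetical). configure_logging() is meant to be called once before any
# LightRAG activity; LOG_DIR, LOG_MAX_BYTES, LOG_BACKUP_COUNT and VERBOSE_DEBUG
# are read from the environment at call time.
def _logging_example():
    os.environ.setdefault("LOG_DIR", "/tmp/lightrag-demo")  # hypothetical location
    configure_logging()
    logging.getLogger("lightrag").info("logging configured")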


async def llm_model_func(
    prompt, system_prompt=None, history_messages=None, **kwargs
) -> str:
    # Delegate to the OpenAI-compatible completion helper; history_messages is
    # normalized to a list, which openai_complete_if_cache expects.
    return await openai_complete_if_cache(
        os.getenv("LLM_MODEL", LLM_MODEL_NAME),
        prompt,
        system_prompt=system_prompt,
        history_messages=history_messages or [],
        api_key=LLM_API_KEY,
        base_url=LLM_BASE_URL,
        **kwargs,
    )


async def embedding_func(texts: list[str]) -> np.ndarray:
    return await openai_embed(
        texts,
        model=EMBED_MODEL_NAME,
        api_key=EMBED_API_KEY,
        base_url=EMBED_BASE_URL,
    )


async def initialize_rag(working_dir):
    rag = LightRAG(
        working_dir=working_dir,
        llm_model_func=llm_model_func,
        embedding_func=EmbeddingFunc(
            embedding_dim=EMBED_DIM,
            max_token_size=EMBED_MAX_TOKEN_SIZE,
            func=embedding_func,
        ),
    )
    # Storage backends and the shared pipeline status must both be initialized
    # before the instance can ingest documents or answer queries.
    await rag.initialize_storages()
    await initialize_pipeline_status()
    return rag
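

# A minimal end-to-end sketch (not in the original file; the working directory
# and sample text are hypothetical). It shows how configure_logging,
# initialize_rag and print_stream fit together around a streaming hybrid
# query; run it with asyncio.run(_rag_example()).
async def _rag_example():
    from lightrag import QueryParam

    configure_logging()
    rag = await initialize_rag("./rag_storage")  # hypothetical working dir
    await rag.ainsert("LightRAG is a retrieval-augmented generation library.")
    result = await rag.aquery(
        "What is LightRAG?",
        param=QueryParam(mode="hybrid", stream=True),
    )
    # With stream=True, aquery yields an async iterator of text chunks;
    # some cache hits may still come back as a plain string.
    if isinstance(result, str):
        print(result)
    else:
        await print_stream(result)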


def create_llm_model_func():
    # Factory variant used by the multimodal pipeline; history_messages
    # defaults to None rather than a shared mutable list.
    def llm_model_func(prompt, system_prompt=None, history_messages=None, **kwargs):
        return openai_complete_if_cache(
            LLM_MODEL_NAME,
            prompt,
            system_prompt=system_prompt,
            history_messages=history_messages or [],
            api_key=LLM_API_KEY,
            base_url=LLM_BASE_URL,
            **kwargs,
        )

    return llm_model_func


def create_embedding_func():
    # Note: embedding_dim and max_token_size are hard-coded here and must match
    # the model behind EMBED_MODEL_NAME (initialize_rag above uses the
    # EMBED_DIM / EMBED_MAX_TOKEN_SIZE constants instead).
    return EmbeddingFunc(
        embedding_dim=1024,
        max_token_size=8192,
        func=lambda texts: openai_embed(
            texts,
            model=EMBED_MODEL_NAME,
            api_key=EMBED_API_KEY,
            base_url=EMBED_BASE_URL,
        ),
    )
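

# A hypothetical sanity check (not in the original): embed a probe string and
# confirm the endpoint's vector width matches the hard-coded embedding_dim.
async def _check_embedding_dim():
    func = create_embedding_func()
    vectors = await func(["dimension probe"])
    assert vectors.shape[1] == func.embedding_dim, (
        f"endpoint returned {vectors.shape[1]}-d vectors, "
        f"expected {func.embedding_dim}"
    )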


def create_vision_model_func(llm_model_func):
    def vision_model_func(
        prompt, system_prompt=None, history_messages=None, image_data=None, **kwargs
    ):
        # Text-only calls fall through to the regular LLM function.
        if not image_data:
            return llm_model_func(
                prompt, system_prompt, history_messages or [], **kwargs
            )
        # For image input, build an explicit multimodal message list; the
        # empty positional prompt is a placeholder, since the `messages`
        # payload is what gets sent to the vision model.
        messages = []
        if system_prompt:
            messages.append({"role": "system", "content": system_prompt})
        messages.append(
            {
                "role": "user",
                "content": [
                    {"type": "text", "text": prompt},
                    {
                        "type": "image_url",
                        "image_url": {
                            "url": f"data:image/jpeg;base64,{image_data}"
                        },
                    },
                ],
            }
        )
        return openai_complete_if_cache(
            VISION_MODEL_NAME,
            "",
            system_prompt=None,
            history_messages=[],
            messages=messages,
            api_key=VISION_API_KEY,
            base_url=VISION_BASE_URL,
            **kwargs,
        )

    return vision_model_func
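

# A wiring sketch (not in the original): the three factories are designed to
# be composed, with the vision function falling back to the text LLM. The
# resulting trio is what a multimodal pipeline such as RAG-Anything consumes;
# how they are passed in depends on that library's constructor, so only the
# composition itself is shown here.
def _build_multimodal_functions():
    llm = create_llm_model_func()
    vision = create_vision_model_func(llm)
    embed = create_embedding_func()
    return llm, vision, embed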


def format_exam_content(raw_text, output_path):
    """
    Format raw OCR output of an exam paper into a standard question layout.

    Args:
        raw_text: the raw text produced by OCR
        output_path: path of the output file

    Returns:
        The formatted exam content.
    """
    client = OpenAI(
        api_key=LLM_API_KEY,
        base_url=LLM_BASE_URL,
    )
    prompt = """
I will provide an exam paper in markdown format. For each question, please extract:
1. Question number
2. Question text (detect math automatically and wrap formulas in $ or $$)
3. Options (if any)
4. Answer
5. Explanation

Requirements:
- Output the questions one by one; do not use tables
- Detect mathematical expressions automatically and wrap them correctly in $ or $$
- Make sure special characters inside formulas are escaped correctly
- Output nothing beyond the question content itself

The content follows:
"""
    prompt += raw_text
    completion = client.chat.completions.create(
        model="deepseek-v3",
        messages=[
            {"role": "system", "content": "You are a helpful assistant."},
            {"role": "user", "content": prompt},
        ],
    )
    formatted_content = completion.choices[0].message.content
    with open(output_path, "w", encoding="utf-8") as f:
        f.write(formatted_content)
    return formatted_content
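

# A usage sketch (the input and output paths are hypothetical): feed an OCR'd
# markdown exam into format_exam_content and write the cleaned questions out.
def _format_example():
    with open("ocr_output.md", encoding="utf-8") as f:  # hypothetical input
        raw = f.read()
    formatted = format_exam_content(raw, "exam_formatted.md")
    print(formatted[:200])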