main
HuangHai 2 weeks ago
parent 3c655d4ca1
commit 0f184f08a6

@ -43,14 +43,16 @@ async def print_stream(stream):
app = FastAPI(lifespan=lifespan)
# 挂载静态文件目录
app.mount("/static", StaticFiles(directory="Static"), name="static")
app.mount("../static", StaticFiles(directory="Static"), name="static")
@app.post("/api/rag")
async def rag(request: fastapi.Request):
data = await request.json()
workspace = data.get("topic") # Chinese, Math
topic = data.get("topic") # Chinese, Math
mode = data.get("mode", "hybrid") # 默认为hybrid模式
# 拼接路径
WORKING_PATH = "./Topic/" + topic
# 查询的问题
query = data.get("query")
# 关闭参考资料
@ -60,12 +62,17 @@ async def rag(request: fastapi.Request):
user_prompt = user_prompt + "\n 4、资料中提到的知识内容需要判断是否与本次问题相关不相关的绝对不要输出"
user_prompt = user_prompt + "\n 5、如果问题与提供的知识库内容不符则明确告诉未在知识库范围内提到"
user_prompt = user_prompt + "\n 6、发现输出内容中包含Latex公式的一定要检查是不是包含了$$或$的包含符号,不能让Latex无包含符号出现"
# 使用PG库后这个是没有用的,但目前的项目代码要求必传,就写一个吧。
WORKING_DIR = f"./output"
async def generate_response_stream(query: str):
try:
rag = await initialize_pg_rag(WORKING_DIR=WORKING_DIR, workspace=workspace)
rag = LightRAG(
working_dir=WORKING_PATH,
llm_model_func=create_llm_model_func(),
embedding_func=create_embedding_func()
)
await rag.initialize_storages()
await initialize_pipeline_status()
resp = await rag.aquery(
query=query,
param=QueryParam(mode=mode, stream=True, user_prompt=user_prompt))
@ -140,7 +147,7 @@ async def get_tree_data():
try:
pg_pool = await init_postgres_pool()
async with pg_pool.acquire() as conn:
# 执行查询
# 执行查询
rows = await conn.fetch("""
SELECT id,
title,
@ -211,22 +218,23 @@ async def update_knowledge(request: fastapi.Request):
SET prerequisite = $1
WHERE id = $2
""",
json.dumps(
[{"id": p["id"], "title": p["title"]} for p in knowledge],
ensure_ascii=False
),
node_id)
json.dumps(
[{"id": p["id"], "title": p["title"]} for p in knowledge],
ensure_ascii=False
),
node_id)
else: # related knowledge
await conn.execute("""
UPDATE knowledge_points
SET related = $1
WHERE id = $2
""",
json.dumps(
[{"id": p["id"], "title": p["title"]} for p in knowledge],
ensure_ascii=False
),
node_id)
json.dumps(
[{"id": p["id"], "title": p["title"]} for p in knowledge],
ensure_ascii=False
),
node_id)
return {"code": 0, "msg": "更新成功"}
except Exception as e:

@ -0,0 +1,46 @@
import asyncio
import inspect
from Util.LightRagUtil import configure_logging, initialize_rag, print_stream
from lightrag import QueryParam
# 化学
data = [
# {"NAME": "Chemistry", "Q": "硝酸光照分解的化学反应方程式是什么", "ChineseName": "化学"},
{"NAME": "Chemistry", "Q": "氢气与氧气燃烧的现象", "ChineseName": "化学"},
{"NAME": "Math", "Q": "氧化铁与硝酸的化学反应方程式是什么", "ChineseName": "数学"},
{"NAME": "Chinese", "Q": "氧化铁与硝酸的化学反应方程式是什么", "ChineseName": "语文"},
{"NAME": "JiHe", "Q": "三角形两边之和大于第三边的证明", "ChineseName": "几何"}
]
# 准备查询的科目
KEMU = "JiHe" # Chemistry JiHe
# 查找索引号
idx = [i for i, d in enumerate(data) if d["NAME"] == KEMU][0]
async def main():
try:
user_prompt = "\n 1、资料中提供化学反应方程式的一定要严格按提供的Latex公式输出绝对不允许对Latex公式进行修改 "
user_prompt = user_prompt + "\n 2、如果资料中提供了图片的一定要严格按照原文提供图片输出不允许省略或不输出"
user_prompt = user_prompt + "\n 3、资料中提到的知识内容需要判断是否与本次问题相关不相关的绝对不要输出"
rag = await initialize_rag('Topic/' + data[idx]["NAME"])
resp = await rag.aquery(
data[idx]["Q"],
param=QueryParam(mode="hybrid", stream=True, user_prompt=user_prompt),
# hybrid naive
)
if inspect.isasyncgen(resp):
await print_stream(resp)
else:
print(resp)
except Exception as e:
print(f"An error occurred: {e}")
finally:
if rag:
await rag.finalize_storages()
if __name__ == "__main__":
configure_logging()
asyncio.run(main())

@ -49,10 +49,8 @@ app.mount("/static", StaticFiles(directory="Static"), name="static")
@app.post("/api/rag")
async def rag(request: fastapi.Request):
data = await request.json()
topic = data.get("topic") # Chinese, Math
workspace = data.get("topic") # Chinese, Math
mode = data.get("mode", "hybrid") # 默认为hybrid模式
# 拼接路径
WORKING_PATH = "./Topic/" + topic
# 查询的问题
query = data.get("query")
# 关闭参考资料
@ -62,17 +60,12 @@ async def rag(request: fastapi.Request):
user_prompt = user_prompt + "\n 4、资料中提到的知识内容需要判断是否与本次问题相关不相关的绝对不要输出"
user_prompt = user_prompt + "\n 5、如果问题与提供的知识库内容不符则明确告诉未在知识库范围内提到"
user_prompt = user_prompt + "\n 6、发现输出内容中包含Latex公式的一定要检查是不是包含了$$或$的包含符号,不能让Latex无包含符号出现"
# 使用PG库后这个是没有用的,但目前的项目代码要求必传,就写一个吧。
WORKING_DIR = f"./output"
async def generate_response_stream(query: str):
try:
rag = LightRAG(
working_dir=WORKING_PATH,
llm_model_func=create_llm_model_func(),
embedding_func=create_embedding_func()
)
await rag.initialize_storages()
await initialize_pipeline_status()
rag = await initialize_pg_rag(WORKING_DIR=WORKING_DIR, workspace=workspace)
resp = await rag.aquery(
query=query,
param=QueryParam(mode=mode, stream=True, user_prompt=user_prompt))
@ -147,7 +140,7 @@ async def get_tree_data():
try:
pg_pool = await init_postgres_pool()
async with pg_pool.acquire() as conn:
# 执行查询
# 执行查询
rows = await conn.fetch("""
SELECT id,
title,
@ -218,23 +211,22 @@ async def update_knowledge(request: fastapi.Request):
SET prerequisite = $1
WHERE id = $2
""",
json.dumps(
[{"id": p["id"], "title": p["title"]} for p in knowledge],
ensure_ascii=False
),
node_id)
json.dumps(
[{"id": p["id"], "title": p["title"]} for p in knowledge],
ensure_ascii=False
),
node_id)
else: # related knowledge
await conn.execute("""
UPDATE knowledge_points
SET related = $1
WHERE id = $2
""",
json.dumps(
[{"id": p["id"], "title": p["title"]} for p in knowledge],
ensure_ascii=False
),
node_id)
json.dumps(
[{"id": p["id"], "title": p["title"]} for p in knowledge],
ensure_ascii=False
),
node_id)
return {"code": 0, "msg": "更新成功"}
except Exception as e:

@ -1,34 +1,36 @@
import asyncio
import inspect
from Util.LightRagUtil import configure_logging, initialize_rag, print_stream
import logging
from lightrag import QueryParam
# 化学
data = [
# {"NAME": "Chemistry", "Q": "硝酸光照分解的化学反应方程式是什么", "ChineseName": "化学"},
{"NAME": "Chemistry", "Q": "氢气与氧气燃烧的现象", "ChineseName": "化学"},
{"NAME": "Math", "Q": "氧化铁与硝酸的化学反应方程式是什么", "ChineseName": "数学"},
{"NAME": "Chinese", "Q": "氧化铁与硝酸的化学反应方程式是什么", "ChineseName": "语文"},
{"NAME": "JiHe", "Q": "三角形两边之和大于第三边的证明", "ChineseName": "几何"}
]
from Util.LightRagUtil import configure_logging, print_stream, initialize_pg_rag
# 在程序开始时添加以下配置
logging.basicConfig(
level=logging.INFO, # 设置日志级别为INFO
format='%(asctime)s - %(name)s - %(levelname)s - %(message)s'
)
# 或者如果你想更详细地控制日志输出
logger = logging.getLogger('lightrag')
logger.setLevel(logging.INFO)
handler = logging.StreamHandler()
handler.setFormatter(logging.Formatter('%(asctime)s - %(name)s - %(levelname)s - %(message)s'))
logger.addHandler(handler)
# 准备查询的科目
KEMU = "JiHe" # Chemistry JiHe
WORKING_DIR = f"./dsWorking"
# 查找索引号
idx = [i for i, d in enumerate(data) if d["NAME"] == KEMU][0]
logging.basicConfig(format="%(levelname)s:%(message)s", level=logging.INFO)
async def main():
try:
user_prompt = "\n 1、资料中提供化学反应方程式的一定要严格按提供的Latex公式输出绝对不允许对Latex公式进行修改 "
user_prompt = user_prompt + "\n 2、如果资料中提供了图片的一定要严格按照原文提供图片输出不允许省略或不输出"
user_prompt = user_prompt + "\n 3、资料中提到的知识内容需要判断是否与本次问题相关不相关的绝对不要输出"
rag = await initialize_rag('Topic/' + data[idx]["NAME"])
rag = await initialize_pg_rag(WORKING_DIR=WORKING_DIR, workspace='SuShi')
resp = await rag.aquery(
data[idx]["Q"],
param=QueryParam(mode="hybrid", stream=True, user_prompt=user_prompt),
# hybrid naive
# "苏轼的家人都有谁?",
"苏轼与美食",
param=QueryParam(mode="hybrid", stream=True),
)
if inspect.isasyncgen(resp):
await print_stream(resp)

@ -1,48 +0,0 @@
import asyncio
import inspect
import logging
from lightrag import QueryParam
from Util.LightRagUtil import configure_logging, print_stream, initialize_pg_rag
# 在程序开始时添加以下配置
logging.basicConfig(
level=logging.INFO, # 设置日志级别为INFO
format='%(asctime)s - %(name)s - %(levelname)s - %(message)s'
)
# 或者如果你想更详细地控制日志输出
logger = logging.getLogger('lightrag')
logger.setLevel(logging.INFO)
handler = logging.StreamHandler()
handler.setFormatter(logging.Formatter('%(asctime)s - %(name)s - %(levelname)s - %(message)s'))
logger.addHandler(handler)
WORKING_DIR = f"./dsWorking"
logging.basicConfig(format="%(levelname)s:%(message)s", level=logging.INFO)
async def main():
try:
rag = await initialize_pg_rag(WORKING_DIR=WORKING_DIR, workspace='SuShi')
resp = await rag.aquery(
# "苏轼的家人都有谁?",
"苏轼与美食",
param=QueryParam(mode="hybrid", stream=True),
)
if inspect.isasyncgen(resp):
await print_stream(resp)
else:
print(resp)
except Exception as e:
print(f"An error occurred: {e}")
finally:
if rag:
await rag.finalize_storages()
if __name__ == "__main__":
configure_logging()
asyncio.run(main())
Loading…
Cancel
Save