main · HuangHai · 7 days ago
parent 3406032f25
commit a45677716f

@@ -36,17 +36,18 @@ app = FastAPI(lifespan=lifespan)
# Mount the static files directory
app.mount("/static", StaticFiles(directory="Static"), name="static")

# Redirect requests for the root path
@app.get("/")
async def redirect_to_ai():
    return fastapi.responses.RedirectResponse(url="/static/ai.html")

@app.post("/api/rag")
async def rag(request: fastapi.Request):
    data = await request.json()
    logger.info(f"Received request: {data}")
    workspace = data.get("topic", "ShiJi")  # Chinese, Math, ShiJi; the default is ShiJi (Records of the Grand Historian for young readers)
    mode = data.get("mode", "hybrid")  # defaults to "hybrid" mode
    logger.info("Workspace: " + workspace)
    # The question to query
    query = data.get("query")
@@ -108,9 +109,7 @@ async def rag(request: fastapi.Request):
    rag = await initialize_pg_rag(WORKING_DIR=WORKING_DIR, workspace=workspace)
    resp = await rag.aquery(
        query=query,
        param=QueryParam(mode=mode, stream=True, user_prompt=user_prompt))
        # hybrid naive
        param=QueryParam(mode="hybrid", stream=True, user_prompt=user_prompt))
    async for chunk in resp:
        if not chunk:
            continue
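
A minimal client sketch for the /api/rag endpoint above. The payload keys topic, mode, and query match the handler in this diff; the host and port are assumptions, not taken from the repo:

import asyncio
import httpx

async def ask(query: str, topic: str = "ShiJi", mode: str = "hybrid") -> None:
    payload = {"topic": topic, "mode": mode, "query": query}
    async with httpx.AsyncClient(timeout=None) as client:
        # The endpoint streams its answer, so read it chunk by chunk.
        async with client.stream("POST", "http://localhost:8000/api/rag", json=payload) as resp:
            resp.raise_for_status()
            async for chunk in resp.aiter_text():
                print(chunk, end="", flush=True)

asyncio.run(ask("Who was Sima Qian?"))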
@@ -184,15 +183,15 @@ async def get_tree_data():
    async with pg_pool.acquire() as conn:
        # Execute the query
        rows = await conn.fetch("""
            SELECT id,
                   title,
                   parent_id,
                   is_leaf,
                   prerequisite,
                   related
            FROM knowledge_points
            ORDER BY parent_id, id
            """)
            SELECT id,
                   title,
                   parent_id,
                   is_leaf,
                   prerequisite,
                   related
            FROM knowledge_points
            ORDER BY parent_id, id
            """)
        # Build the node map
        nodes = {}
        for row in rows:
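
A hedged sketch of how rows like these typically fold into a tree. The column names follow the SELECT above; build_tree itself is illustrative, not the repo's implementation:

from typing import Any

def build_tree(rows: list[dict[str, Any]]) -> list[dict[str, Any]]:
    # Index every row by id, giving each node an empty children list.
    nodes = {r["id"]: {**r, "children": []} for r in rows}
    roots = []
    for node in nodes.values():
        parent = nodes.get(node["parent_id"])
        if parent is None:
            roots.append(node)  # no parent row, so it is a top-level node
        else:
            parent["children"].append(node)
    return roots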
@@ -249,10 +248,10 @@ async def update_knowledge(request: fastapi.Request):
        async with pg_pool.acquire() as conn:
            if update_type == 'prerequisite':
                await conn.execute("""
                    UPDATE knowledge_points
                    SET prerequisite = $1
                    WHERE id = $2
                    """,
                    UPDATE knowledge_points
                    SET prerequisite = $1
                    WHERE id = $2
                    """,
                    json.dumps(
                        [{"id": p["id"], "title": p["title"]} for p in knowledge],
                        ensure_ascii=False
@@ -260,10 +259,10 @@ async def update_knowledge(request: fastapi.Request):
                    node_id)
            else:  # related knowledge
                await conn.execute("""
                    UPDATE knowledge_points
                    SET related = $1
                    WHERE id = $2
                    """,
                    UPDATE knowledge_points
                    SET related = $1
                    WHERE id = $2
                    """,
                    json.dumps(
                        [{"id": p["id"], "title": p["title"]} for p in knowledge],
                        ensure_ascii=False
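
A hedged sketch of the request body this handler appears to expect. The key names node_id, update_type, and knowledge are inferred from the variables in the diff, not confirmed against the full source:

import json

payload = {
    "node_id": 42,                  # hypothetical id
    "update_type": "prerequisite",  # or "related"
    "knowledge": [{"id": 7, "title": "example title"}],
}
# The handler stores the list as JSON text in the prerequisite/related
# column, with ensure_ascii=False so Chinese titles stay readable:
stored = json.dumps(
    [{"id": p["id"], "title": p["title"]} for p in payload["knowledge"]],
    ensure_ascii=False,
)
print(stored)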
@@ -310,10 +309,11 @@ async def get_sources(page: int = 1, limit: int = 10):
        offset = (page - 1) * limit
        rows = await conn.fetch(
            """
            SELECT id, account_id,account_name, created_at
            FROM t_wechat_source
            ORDER BY created_at DESC
            LIMIT $1 OFFSET $2
            SELECT id, account_id, account_name, created_at
            FROM t_wechat_source
            ORDER BY created_at DESC
            LIMIT $1
            OFFSET $2
            """,
            limit, offset
        )
@@ -352,11 +352,16 @@ async def get_articles(page: int = 1, limit: int = 10):
        offset = (page - 1) * limit
        rows = await conn.fetch(
            """
            SELECT a.id, a.title, a.source as name,
            a.publish_time, a.collection_time,a.url
            SELECT a.id,
                   a.title,
                   a.source as name,
                   a.publish_time,
                   a.collection_time,
                   a.url
            FROM t_wechat_articles a
            ORDER BY a.collection_time DESC
            LIMIT $1 OFFSET $2
            ORDER BY a.collection_time DESC
            LIMIT $1
            OFFSET $2
            """,
            limit, offset
        )
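
Both paginated endpoints above share the same 1-based LIMIT/OFFSET arithmetic; a minimal sketch:

def page_window(page: int = 1, limit: int = 10) -> tuple[int, int]:
    # Page 1 starts at offset 0, page 2 at offset `limit`, and so on.
    offset = (page - 1) * limit
    return limit, offset

assert page_window(1, 10) == (10, 0)
assert page_window(3, 10) == (10, 20)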

@@ -152,14 +152,14 @@ os.environ["POSTGRES_PASSWORD"] = POSTGRES_PASSWORD
os.environ["POSTGRES_DATABASE"] = POSTGRES_DATABASE
async def initialize_pg_rag(WORKING_DIR, workspace='default'):
async def initialize_pg_rag(WORKING_DIR, workspace):
rag = LightRAG(
working_dir=WORKING_DIR,
llm_model_func=llm_model_func,
llm_model_name=LLM_MODEL_NAME,
llm_model_max_async=4,
llm_model_max_token_size=32768,
enable_llm_cache_for_entity_extract=True,
enable_llm_cache_for_entity_extract=False, # 这里黄海修改了一下,不知道是不是有用
embedding_func=EmbeddingFunc(
embedding_dim=EMBED_DIM,
max_token_size=EMBED_MAX_TOKEN_SIZE,
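
A hedged usage sketch of initialize_pg_rag, mirroring the aquery call pattern from the /api/rag handler earlier in the diff. It assumes this module's initialize_pg_rag is in scope; the WORKING_DIR value and user_prompt=None are assumptions:

import asyncio
from lightrag import QueryParam  # LightRAG's query options, as used above

async def main() -> None:
    rag = await initialize_pg_rag(WORKING_DIR="./rag_storage", workspace="ShiJi")
    resp = await rag.aquery(
        query="What is the Records of the Grand Historian about?",
        param=QueryParam(mode="hybrid", stream=True, user_prompt=None))
    async for chunk in resp:
        if not chunk:
            continue
        print(chunk, end="", flush=True)

asyncio.run(main())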
