HuangHai 1 week ago
commit 305b126d32

@@ -78,3 +78,38 @@ async def get_question(request: Request):
    print(select_question_sql)
    page = await get_page_data_by_sql(select_question_sql, page_number, page_size)
    return {"success": True, "message": "Query succeeded!", "data": page}
# [TeachingModel-5] Ask a question
@router.post("/sendQuestion")
async def send_question(request: Request):
    # Get request parameters
    bureau_id = await get_request_str_param(request, "bureau_id", True, True)
    person_id = await get_request_str_param(request, "person_id", True, True)
    theme_id = await get_request_num_param(request, "theme_id", True, True, None)
    question = await get_request_str_param(request, "question", True, True)
    theme_object = await find_by_id("t_ai_teaching_model_theme", "id", theme_id)
    if theme_object is None:
        return {"success": False, "message": "Theme does not exist!"}
    # Save the question to the person's question history
    param = {
        "stage_id": int(theme_object["stage_id"]),
        "subject_id": int(theme_object["subject_id"]),
        "theme_id": theme_id,
        "question": question,
        "question_type": 2,
        "question_person_id": person_id,
        "person_id": person_id,
        "bureau_id": bureau_id,
    }
    question_id = await insert("t_ai_teaching_model_question", param)
    # TODO: update the theme's invocation count
    # TODO: send the question to the RAG service
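The two TODO comments are left unimplemented by this commit. A sketch of one way to finish them, assuming a call_count column on t_ai_teaching_model_theme and an execute-style helper alongside insert/find_by_id (column, helper, and default workspace are all assumptions, not part of the commit); the RAG call itself just mirrors the /api/rag handler further down:

# Sketch only: call_count and the execute() helper are assumptions.
async def handle_question(theme_id, question, workspace="ShiJi"):
    # Bump the theme's invocation count (assumed column and helper).
    await execute(
        "UPDATE t_ai_teaching_model_theme SET call_count = call_count + 1 WHERE id = $1",
        theme_id)
    # Ask RAG the same way the /api/rag handler below does.
    rag = await initialize_pg_rag(WORKING_DIR=WORKING_DIR, workspace=workspace)
    resp = await rag.aquery(
        query=question,
        param=QueryParam(mode="hybrid", stream=True))
    answer = ""
    async for chunk in resp:
        if chunk:
            answer += chunk
    return answer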

@@ -37,12 +37,17 @@ app = FastAPI(lifespan=lifespan)
app.mount("/static", StaticFiles(directory="Static"), name="static")
# Redirect the root path to the AI page
@app.get("/")
async def redirect_to_ai():
    return fastapi.responses.RedirectResponse(url="/static/ai.html")
@app.post("/api/rag")
async def rag(request: fastapi.Request):
    data = await request.json()
    logger.info(f"Received request: {data}")
    workspace = data.get("topic", "ShiJi")  # Chinese, Math, or ShiJi; defaults to ShiJi (少年读史记)
    logger.info("Workspace: " + workspace)
    # The question to query
    query = data.get("query")
@@ -104,9 +109,7 @@ async def rag(request: fastapi.Request):
    rag = await initialize_pg_rag(WORKING_DIR=WORKING_DIR, workspace=workspace)
    resp = await rag.aquery(
        query=query,
        param=QueryParam(mode="hybrid", stream=True, user_prompt=user_prompt))
    # supported modes include "hybrid" and "naive"; this commit pins "hybrid"
    async for chunk in resp:
        if not chunk:
            continue
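Since the handler streams its answer chunk by chunk, a caller has to read the response body incrementally. A minimal client sketch, assuming the service listens on 127.0.0.1:8000 and streams plain text (host, port, and framing are assumptions):

# Minimal client sketch; host/port and response framing are assumed.
import asyncio
import httpx

async def ask(query: str, topic: str = "ShiJi") -> None:
    async with httpx.AsyncClient(timeout=None) as client:
        async with client.stream(
                "POST", "http://127.0.0.1:8000/api/rag",
                json={"topic": topic, "query": query}) as resp:
            async for chunk in resp.aiter_text():
                print(chunk, end="", flush=True)

asyncio.run(ask("Who was Sima Qian?"))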
@@ -306,10 +309,11 @@ async def get_sources(page: int = 1, limit: int = 10):
    offset = (page - 1) * limit
    rows = await conn.fetch(
        """
        SELECT id, account_id, account_name, created_at
        FROM t_wechat_source
        ORDER BY created_at DESC
        LIMIT $1
        OFFSET $2
        """,
        limit, offset
    )
@@ -348,11 +352,16 @@ async def get_articles(page: int = 1, limit: int = 10):
    offset = (page - 1) * limit
    rows = await conn.fetch(
        """
        SELECT a.id,
               a.title,
               a.source as name,
               a.publish_time,
               a.collection_time,
               a.url
        FROM t_wechat_articles a
        ORDER BY a.collection_time DESC
        LIMIT $1
        OFFSET $2
        """,
        limit, offset
    )
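Both pagination hunks share the same 1-based LIMIT/OFFSET arithmetic: page 3 with limit 10 gives OFFSET 20, i.e. rows 21-30. A small illustrative helper (not part of the commit) that both endpoints could share:

# Illustrative only; this helper is not part of the commit.
def page_window(page: int, limit: int) -> tuple[int, int]:
    """Map a 1-based page number to the (LIMIT, OFFSET) pair used above."""
    return limit, (page - 1) * limit

# page_window(1, 10) -> (10, 0)   rows 1-10
# page_window(3, 10) -> (10, 20)  rows 21-30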

@@ -152,14 +152,14 @@ os.environ["POSTGRES_PASSWORD"] = POSTGRES_PASSWORD
os.environ["POSTGRES_DATABASE"] = POSTGRES_DATABASE
async def initialize_pg_rag(WORKING_DIR, workspace):
    rag = LightRAG(
        working_dir=WORKING_DIR,
        llm_model_func=llm_model_func,
        llm_model_name=LLM_MODEL_NAME,
        llm_model_max_async=4,
        llm_model_max_token_size=32768,
        enable_llm_cache_for_entity_extract=False,  # HuangHai changed this; unsure whether it has any effect
        embedding_func=EmbeddingFunc(
            embedding_dim=EMBED_DIM,
            max_token_size=EMBED_MAX_TOKEN_SIZE,
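Two things change in this hunk: workspace loses its 'default' fallback, so every caller must now pass it explicitly (as the /api/rag handler above already does), and enable_llm_cache_for_entity_extract is flipped to False, which by its name disables caching of LLM responses during entity extraction; the in-line comment marks the flip as experimental. A minimal call sketch (the storage path is illustrative):

# Minimal usage sketch; "./rag_storage" is an illustrative path.
rag = await initialize_pg_rag(WORKING_DIR="./rag_storage", workspace="ShiJi")
resp = await rag.aquery(
    query="What does the book cover?",
    param=QueryParam(mode="hybrid", stream=True))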
