main
HuangHai 3 weeks ago
parent eefda6c0e2
commit a26781be57

@ -128,7 +128,7 @@ class StreamLLMClient:
3. 返回严格JSON格式
{{
"problem_types": ["题型"],
"knowledge_points": ["匹配的知识点"],
"knowledge_points.sql": ["匹配的知识点"],
"literacy_points": ["匹配的素养点"]
}}
@ -205,13 +205,13 @@ class ProblemAnalyzer:
print("\n📊 分析结果:")
print(f" 题型: {analysis.get('problem_types', [])}")
print(f" 知识点: {analysis.get('knowledge_points', [])}")
print(f" 知识点: {analysis.get('knowledge_points.sql', [])}")
print(f" 素养点: {analysis.get('literacy_points', [])}")
self.kg.store_analysis(
question_id=self.question_id,
content=self.content,
knowledge=analysis.get('knowledge_points', []),
knowledge=analysis.get('knowledge_points.sql', []),
literacy=analysis.get('literacy_points', [])
)

@ -31,7 +31,7 @@ def query_all_questions():
OPTIONAL MATCH (q)-[:DEVELOPS_LITERACY]->(lp:LiteracyNode)
RETURN
q.content AS content,
collect(DISTINCT {id: kp.id, name: kp.name}) AS knowledge_points,
collect(DISTINCT {id: kp.id, name: kp.name}) AS knowledge_points.sql,
collect(DISTINCT {id: lp.value, title: lp.title}) AS literacy_points
""", qid=qid).data()
@ -39,7 +39,7 @@ def query_all_questions():
result = {
"question_id": qid,
"content": data[0]['content'],
"knowledge_points": data[0]['knowledge_points'],
"knowledge_points.sql": data[0]['knowledge_points.sql'],
"literacy_points": data[0]['literacy_points']
}
results.append(result)
@ -49,8 +49,8 @@ def query_all_questions():
print(f"📚 试题ID: {qid}")
print(f"📝 内容全文: {result['content']}") # 新增完整内容输出
print(f"🔍 内容摘要: {result['content'][:50]}...") # 保留摘要显示
print(f"🧠 知识点: {[kp['name'] for kp in result['knowledge_points']]}")
print(f"🆔 知识点ID: {[kp['id'] for kp in result['knowledge_points']]}")
print(f"🧠 知识点: {[kp['name'] for kp in result['knowledge_points.sql']]}")
print(f"🆔 知识点ID: {[kp['id'] for kp in result['knowledge_points.sql']]}")
print(f"🌟 素养点: {[lp['title'] for lp in result['literacy_points']]}")
print(f"🔢 素养点ID: {[lp['id'] for lp in result['literacy_points']]}")
print('=' * 90)

@ -0,0 +1,63 @@
import json
import logging
from typing import Dict, List
from Util.MySQLUtil import init_mysql_pool
# 配置日志
logging.basicConfig(level=logging.INFO, format="%(asctime)s - %(levelname)s - %(message)s")
logger = logging.getLogger(__name__)
async def process_node(node: Dict, parent_id: str = None) -> List[Dict]:
    """Flatten one knowledge-point node together with all of its descendants.

    The node itself is emitted first, followed by its subtree in depth-first
    order.  Every entry is a row-shaped dict with the keys ``id``, ``title``,
    ``parent_id``, ``is_leaf``, ``prerequisite`` and ``related``; the last two
    are JSON-encoded strings (``"[]"`` when the node has no ``PREREQUISITE`` /
    ``RELATED_TO`` entry).

    :param node: tree node; must provide ``value``, ``title`` and ``isLeaf``.
    :param parent_id: ``value`` of the parent node, ``None`` for a root node.
    :return: flat list of row dicts for this node and its subtree.
    """
    node_id = node["value"]
    title = node["title"]
    leaf = node["isLeaf"]

    rows = [{
        "id": node_id,
        "title": title,
        "parent_id": parent_id,
        "is_leaf": leaf,
        "prerequisite": json.dumps(node.get("PREREQUISITE", [])),
        "related": json.dumps(node.get("RELATED_TO", [])),
    }]

    # Leaves, and branch nodes that carry no "children" key, contribute
    # nothing beyond their own row.
    if leaf or "children" not in node:
        return rows

    for sub_node in node["children"]:
        rows += await process_node(sub_node, node_id)
    return rows
async def insert_knowledge_points(mysql_pool, knowledge_points: List[Dict]):
    """Insert all knowledge-point rows into ``knowledge_points`` and commit.

    The rows are handed to a single ``executemany`` call rather than one
    ``execute`` per row, so the driver is free to batch the INSERTs
    (see the aiomysql ``Cursor.executemany`` documentation).

    :param mysql_pool: connection pool exposing ``acquire()``
        (presumably an aiomysql pool -- confirm against ``init_mysql_pool``).
    :param knowledge_points: row dicts with the keys ``id``, ``title``,
        ``parent_id``, ``is_leaf``, ``prerequisite`` and ``related``.
    :raises KeyError: if a row dict lacks one of those keys.
    """
    insert_sql = """INSERT INTO knowledge_points
        (id, title, parent_id, is_leaf, prerequisite, related)
        VALUES (%s, %s, %s, %s, %s, %s)"""

    # Build the parameter tuples up front so a malformed row fails before a
    # connection is borrowed from the pool.
    params = [
        (
            point["id"], point["title"], point["parent_id"],
            point["is_leaf"], point["prerequisite"], point["related"],
        )
        for point in knowledge_points
    ]

    async with mysql_pool.acquire() as conn:
        await conn.ping()
        async with conn.cursor() as cur:
            await cur.executemany(insert_sql, params)
        await conn.commit()
async def main(json_path: str = "d:\\dsWork\\dsProject\\dsRag\\Neo4j\\小学数学知识点体系.json"):
    """Load the knowledge-point tree from a JSON file and store it in MySQL.

    The file is read and flattened first, so a missing or malformed file
    fails before any database connection is opened.  The pool is always
    closed afterwards, even when the insert raises.

    :param json_path: path of the JSON file; its ``data.tree`` entry must be
        a list of root nodes accepted by ``process_node``.
    """
    with open(json_path, "r", encoding="utf-8") as f:
        data = json.load(f)

    # Flatten every root node (and its subtree) into insertable rows.
    knowledge_points = []
    for node in data["data"]["tree"]:
        knowledge_points.extend(await process_node(node))

    mysql_pool = await init_mysql_pool()
    try:
        await insert_knowledge_points(mysql_pool, knowledge_points)
    finally:
        # NOTE(review): assumes init_mysql_pool() returns an aiomysql Pool,
        # whose shutdown is close() followed by wait_closed() -- confirm.
        mysql_pool.close()
        await mysql_pool.wait_closed()

    logger.info("成功插入 %s 条知识点数据", len(knowledge_points))


if __name__ == "__main__":
    import asyncio

    asyncio.run(main())

@ -0,0 +1,10 @@
-- Knowledge-point tree, one row per node (table comment: primary-school
-- mathematics knowledge-point system).
--   * id         : primary key of the node.
--   * parent_id  : id of the parent node; nullable (presumably NULL for root
--                  nodes -- confirm with the loader). No foreign key is declared.
--   * sort       : optional ordering column.
--   * prerequisite / related : JSON columns; per their column comments they
--                  are meant for first-level nodes only.
--   * idx_parent_id : secondary index for looking up rows by parent_id.
CREATE TABLE knowledge_points (
id VARCHAR(32) PRIMARY KEY COMMENT '知识点唯一标识',
title VARCHAR(100) NOT NULL COMMENT '知识点标题',
parent_id VARCHAR(32) COMMENT '父节点ID',
is_leaf BOOLEAN NOT NULL COMMENT '是否为叶子节点',
sort INT COMMENT '排序字段',
prerequisite JSON COMMENT '先修知识点(仅一级节点)',
related JSON COMMENT '相关知识点(仅一级节点)',
KEY idx_parent_id (parent_id)
) ENGINE=InnoDB DEFAULT CHARSET=utf8mb4 COMMENT='小学数学知识点体系';

@ -2,10 +2,11 @@
pip install aiomysql
"""
import logging
from typing import Optional, Dict, List
from aiomysql import create_pool
from Config.Config import *
# 配置日志
logging.basicConfig(level=logging.INFO, format="%(asctime)s - %(levelname)s - %(message)s")
logger = logging.getLogger(__name__)
@ -41,241 +42,5 @@ async def save_chat_to_mysql(mysql_pool, person_id, prompt, result, audio_url, d
await conn.commit()
# 清空表
async def truncate_chat_log(mysql_pool):
    """Empty the ``t_chat_log`` table with ``TRUNCATE TABLE`` and log that it was done.

    :param mysql_pool: MySQL connection pool exposing ``acquire()``.
    """
    statement = "TRUNCATE TABLE t_chat_log"
    async with mysql_pool.acquire() as connection:
        # Ping before use; presumably this revives a stale pooled
        # connection -- confirm against the driver's ping() behaviour.
        await connection.ping()
        async with connection.cursor() as cursor:
            await cursor.execute(statement)
            await connection.commit()
            logger.info("表 t_chat_log 已清空。")
from aiomysql import DictCursor
# 分页查询聊天记录
async def get_chat_log_by_session(mysql_pool, person_id, page=1, page_size=10):
    """
    Return one page of a person's chat log.

    Rows are selected with ``ORDER BY id DESC``.  Only when ``page == 1`` is
    the fetched page reversed, so that the newest message of the first page
    comes last; other pages keep the descending order.

    :param mysql_pool: MySQL connection pool
    :param person_id: value matched against ``t_chat_log.person_id``
    :param page: 1-based page number (used as given; it is not recomputed)
    :param page_size: number of records per page
    :return: dict with ``data`` (list of record dicts), ``total``, ``page``,
             ``page_size`` and ``total_pages``
    :raises ValueError: if ``mysql_pool`` is falsy
    """
    if not mysql_pool:
        raise ValueError("MySQL 连接池未初始化")

    async with mysql_pool.acquire() as conn:
        await conn.ping()
        async with conn.cursor(DictCursor) as cur:  # rows come back as dicts
            # Total number of records for this person.
            await cur.execute(
                "SELECT COUNT(*) FROM t_chat_log WHERE person_id = %s",
                (person_id,)
            )
            total = (await cur.fetchone())['COUNT(*)']

            # Ceiling division gives the page count.
            total_pages = (total + page_size - 1) // page_size
            offset = (page - 1) * page_size

            # Page of records, newest first.
            await cur.execute(
                "SELECT id, person_id, user_input, model_response, audio_url, duration,input_type,output_type,input_image_type,image_width,image_height, create_time "
                "FROM t_chat_log WHERE person_id = %s ORDER BY id DESC LIMIT %s OFFSET %s",
                (person_id, page_size, offset)
            )
            records = await cur.fetchall()

    # First page only: flip the order so the newest message is shown last.
    if page == 1 and records:
        records = list(records)[::-1]

    result = []
    for record in records:
        created = record['create_time']
        result.append({
            "id": record['id'],
            "person_id": record['person_id'],
            "user_input": record['user_input'],
            "model_response": record['model_response'],
            "audio_url": record['audio_url'],
            "duration": record['duration'],
            "input_type": record['input_type'],
            "output_type": record['output_type'],
            "image_width": record['image_width'],
            "image_height": record['image_height'],
            "input_image_type": record['input_image_type'],
            # A NULL timestamp is passed through instead of crashing on
            # .strftime, matching get_chat_logs_by_risk_flag.
            "create_time": created.strftime("%Y-%m-%d %H:%M:%S") if created else created,
        })

    return {
        "data": result,
        "total": total,
        "page": page,
        "page_size": page_size,
        "total_pages": total_pages
    }
# 获取指定会话的最后一条记录的 id
async def get_last_chat_log_id(mysql_pool, person_id):
    """
    Look up the id of the newest ``t_chat_log`` row for a person.

    :param mysql_pool: MySQL connection pool
    :param person_id: value matched against ``t_chat_log.person_id``
    :return: the highest ``id`` for that person, or ``None`` when no row matches
    """
    query = "SELECT id FROM t_chat_log WHERE person_id = %s ORDER BY id DESC LIMIT 1"
    async with mysql_pool.acquire() as connection:
        await connection.ping()
        async with connection.cursor() as cursor:
            await cursor.execute(query, (person_id,))
            newest = await cursor.fetchone()

    if not newest:
        return None
    return newest[0]
# 更新为危险的记录
async def update_risk(mysql_pool, person_id, risk_memo):
    """
    Flag the newest chat-log row of a person as risky.

    Sets ``risk_flag = 1`` and stores ``risk_memo`` (with newlines and every
    ``"NO"`` substring removed) on the row returned by
    ``get_last_chat_log_id``.  Logs a warning and changes nothing when the
    person has no chat-log row.

    :param mysql_pool: MySQL connection pool
    :param person_id: person whose newest record is updated
    :param risk_memo: risk description text
    """
    # Resolve the target id BEFORE borrowing a connection here: the helper
    # acquires its own connection, and holding one while waiting for a second
    # can deadlock once the pool is exhausted.
    last_id = await get_last_chat_log_id(mysql_pool, person_id)
    if last_id is None:
        logger.warning(f"未找到 person_id={person_id} 的记录。")
        return

    cleaned_memo = risk_memo.replace('\n', '').replace("NO", "")
    async with mysql_pool.acquire() as conn:
        await conn.ping()
        async with conn.cursor() as cur:
            await cur.execute(
                "UPDATE t_chat_log SET risk_flag = 1, risk_memo = %s WHERE id = %s",
                (cleaned_memo, last_id)
            )
            await conn.commit()
    logger.info(f"已更新 person_id={person_id} 的最后一条记录 (id={last_id}) 的 risk_flag 和 risk_memo。")
# 查询用户信息
async def get_user_by_login_name(mysql_pool, login_name: str) -> Optional[Dict]:
    """
    Fetch one ``t_base_person`` row by login name.

    :param mysql_pool: MySQL connection pool
    :param login_name: value matched against ``t_base_person.login_name``
    :return: the first matching row as a ``{column_name: value}`` dict,
             or ``None`` when no row matches
    """
    async with mysql_pool.acquire() as conn:
        await conn.ping()
        async with conn.cursor() as cursor:
            await cursor.execute(
                "SELECT * FROM t_base_person WHERE login_name = %s",
                (login_name,),
            )
            person_row = await cursor.fetchone()
            if person_row:
                # Pair each column name from the cursor metadata with its value.
                column_names = [meta[0] for meta in cursor.description]
                return dict(zip(column_names, person_row))
            return None
# 显示统计分析页面
async def get_chat_logs_summary(mysql_pool, risk_flag: int, offset: int, page_size: int) -> (List[Dict], int):
    """
    Summarise chat logs per person for a given risk flag.

    Each returned dict holds the person's ``t_base_person`` columns plus
    ``cnt``, the number of that person's chat-log rows with the given flag;
    persons are ordered by that count, highest first.

    :param mysql_pool: MySQL connection pool
    :param risk_flag: value matched against ``t_chat_log.risk_flag``
    :param offset: number of persons to skip
    :param page_size: maximum number of persons to return
    :return: ``(logs, total)`` where ``total`` is the number of distinct
             persons having at least one matching chat-log row
    """
    page_sql = """
        SELECT tbp.*, COUNT(*) AS cnt
        FROM t_chat_log AS tcl
        INNER JOIN t_base_person AS tbp ON tcl.person_id = tbp.person_id
        WHERE tcl.risk_flag = %s
        GROUP BY tcl.person_id
        ORDER BY COUNT(*) DESC
        LIMIT %s OFFSET %s
        """
    count_sql = """
        SELECT COUNT(DISTINCT tcl.person_id)
        FROM t_chat_log AS tcl
        INNER JOIN t_base_person AS tbp ON tcl.person_id = tbp.person_id
        WHERE tcl.risk_flag = %s
        """

    async with mysql_pool.acquire() as conn:
        await conn.ping()
        async with conn.cursor() as cursor:
            await cursor.execute(page_sql, (risk_flag, page_size, offset))
            page_rows = await cursor.fetchall()
            # Read the column names now: the count query below replaces
            # cursor.description.
            column_names = [meta[0] for meta in cursor.description]

            await cursor.execute(count_sql, (risk_flag,))
            count_row = await cursor.fetchone()
            total = count_row[0]

    logs = []
    for page_row in page_rows or ():
        logs.append(dict(zip(column_names, page_row)))
    return logs, total
async def get_chat_logs_by_risk_flag(mysql_pool, risk_flag: int, person_id: str, offset: int, page_size: int) -> (
        List[Dict], int):
    """
    Page through one person's chat-log rows that carry a given risk flag.

    Rows are returned newest first (``ORDER BY tcl.id DESC``) and joined with
    ``t_base_person`` for the login and display name.  ``create_time`` is
    formatted as ``YYYY-MM-DD HH:MM:SS`` when it is set.

    :param mysql_pool: MySQL connection pool
    :param risk_flag: value matched against ``t_chat_log.risk_flag``
    :param person_id: value matched against ``t_chat_log.person_id``
    :param offset: number of rows to skip
    :param page_size: maximum number of rows to return
    :return: ``(logs, total)``; ``total`` is the full number of matching rows
             even when the requested page is empty
    """
    # The alias is written in lower case throughout: MySQL treats table
    # aliases as case-sensitive on Unix, so "TCL.ID" would not resolve there.
    sql = """
        SELECT tcl.id, tcl.user_input, tcl.model_response, tcl.audio_url, tcl.duration,
        tcl.create_time, tcl.risk_flag, tcl.risk_memo, tcl.risk_result,
        tbp.person_id, tbp.login_name, tbp.person_name
        FROM t_chat_log AS tcl
        INNER JOIN t_base_person AS tbp ON tcl.person_id = tbp.person_id
        WHERE tcl.risk_flag = %s and tcl.person_id=%s ORDER BY tcl.id DESC
        LIMIT %s OFFSET %s
        """
    count_sql = """
        SELECT COUNT(*)
        FROM t_chat_log AS tcl
        INNER JOIN t_base_person AS tbp ON tcl.person_id = tbp.person_id
        WHERE tcl.risk_flag = %s and tcl.person_id=%s
        """

    async with mysql_pool.acquire() as conn:
        await conn.ping()
        async with conn.cursor() as cursor:
            await cursor.execute(sql, (risk_flag, person_id, page_size, offset))
            rows = await cursor.fetchall()
            # Capture the column names before the count query overwrites
            # cursor.description.
            columns = [column[0] for column in cursor.description]

            await cursor.execute(count_sql, (risk_flag, person_id))
            total = (await cursor.fetchone())[0]

    logs = []
    for row in rows or ():
        log = dict(zip(columns, row))
        if log["create_time"]:
            log["create_time"] = log["create_time"].strftime("%Y-%m-%d %H:%M:%S")
        logs.append(log)

    # Always report the real total; an empty page (e.g. an offset past the
    # end) must not be reported as "0 matching rows".
    return logs, total

Loading…
Cancel
Save