You can not select more than 25 topics Topics must start with a letter or number, can include dashes ('-') and can be up to 35 characters long.

180 lines
6.6 KiB

5 months ago
# -*- coding: utf-8 -*-
import re
import hashlib
from py2neo import Graph
from openai import OpenAI
from Config import *
class KnowledgeGraph:
    """Link a math question to existing KnowledgePoint nodes in Neo4j via an LLM.

    The constructor connects to Neo4j (``NEO4J_URI``/``NEO4J_AUTH``) and to the
    OpenAI-compatible endpoint (``MODEL_API_URL``) configured in ``Config``, and
    loads the available knowledge points. ``run`` asks the model for Cypher and
    returns a sanitised version; executing it is left to the caller.
    """

    # First fenced code block of the reply. Any language tag (```cypher,
    # ```Cypher, bare ```) and CRLF line endings are accepted.
    _FENCE_RE = re.compile(r"```[^\n]*\n(.*?)```", re.DOTALL)
    # Any reference to the KnowledgePoint label, e.g. "(kp:KnowledgePoint".
    _KP_LABEL_RE = re.compile(r":\s*KnowledgePoint\b")
    # A KnowledgePoint property map with a literal id; group 2 is the id value.
    _KP_ID_RE = re.compile(r":\s*KnowledgePoint\s*\{[^}]*?\bid\s*:\s*(['\"])(.*?)\1")
    # Variable bound to a KnowledgePoint node, e.g. "kp" in "(kp:KnowledgePoint".
    _KP_VAR_RE = re.compile(r"\(\s*(\w+)\s*:[^()]*?\bKnowledgePoint\b")
    # Target variable of a relationship pattern, e.g. "kp" in "->(kp)".
    _REL_TARGET_RE = re.compile(r"->\s*\(\s*(\w+)\s*\)")
    # Leading keywords of write clauses; Cypher requires a WITH between one of
    # these and a following MATCH.
    _WRITE_CLAUSES = ("MERGE", "SET", "ON")
    # Keywords that must never reach the database from model output
    # (plain substring test, same style as the original CREATE check).
    _FORBIDDEN_KEYWORDS = ("CREATE", "DELETE", "DETACH", "DROP", "REMOVE")

    def __init__(self, content: str):
        """Connect to Neo4j and the LLM endpoint for one question.

        Args:
            content: The question text to analyse.
        """
        self.content = content
        # Short id derived from the text, so the same question maps to the same
        # Question node (MD5 is used as a fingerprint here, not for security).
        self.question_id = hashlib.md5(content.encode()).hexdigest()[:8]
        self.graph = Graph(NEO4J_URI, auth=NEO4J_AUTH)
        self.knowledge_points = self._get_knowledge_points()
        self.client = OpenAI(api_key=MODEL_API_KEY, base_url=MODEL_API_URL)
        print("加载知识点数量:", len(self.knowledge_points))  # debug output

    def _get_knowledge_points(self) -> dict:
        """Load every KnowledgePoint as ``{id: name}``, keeping the stored ID case.

        Nodes without an ``id`` property are skipped; a ``None`` key would be
        useless for matching. Returns ``{}`` (after printing the error) if the
        query fails.
        """
        try:
            return {
                row['n.id']: row['n.name']
                for row in self.graph.run(
                    "MATCH (n:KnowledgePoint) WHERE n.id IS NOT NULL RETURN n.id, n.name"
                )
            }
        except Exception as e:
            print("获取知识点失败:", str(e))
            return {}

    def _make_prompt(self, limit=None) -> str:
        """Build the system prompt that restricts the LLM to known knowledge points.

        Args:
            limit: Maximum number of ``ID:name`` pairs to list. ``None`` (the
                default) lists every loaded knowledge point. The model can only
                answer with IDs it has been shown, so a short sample makes most
                valid answers impossible.

        Returns:
            The prompt text, including a Cypher template bound to this
            question's id.
        """
        items = list(self.knowledge_points.items())
        shown = items if limit is None else items[:limit]
        listing = "\n".join(f"{kp_id}:{name}" for kp_id, name in shown)
        more = "..." if len(shown) < len(items) else ""
        return f"""你是一个数学专家,请分析题目考查的知识点,严格:
1. 只使用以下存在的知识点格式ID:名称
{listing}{more}
{len(self.knowledge_points)}个可用知识点
2. 按此格式生成Cypher
MERGE (q:Question {{id: "{self.question_id}"}})
SET q.content = "题目内容"
WITH q
MATCH (kp:KnowledgePoint {{id: "知识点ID"}})
MERGE (q)-[:TESTS_KNOWLEDGE]->(kp)
3. 若考查多个知识点, 对每个知识点重复最后三行(WITH q / MATCH / MERGE)"""

    def _resolve_kp_line(self, line: str, canonical_ids: dict):
        """Validate one line that references KnowledgePoint nodes.

        Only MATCH lines are accepted. Every KnowledgePoint node on the line must
        carry a literal id that exists in the database (compared
        case-insensitively). Each id is rewritten to its exact stored spelling.

        Args:
            line: A comment-stripped Cypher line containing ":KnowledgePoint".
            canonical_ids: Mapping of upper-cased id -> id as stored.

        Returns:
            The rewritten line, or None if the line is rejected.
        """
        if not line.upper().startswith('MATCH'):
            # MERGE creates every unbound node of its pattern when the whole
            # pattern is missing, so KnowledgePoints may only be MATCHed.
            print(f"忽略非MATCH的知识点引用: {line}")
            return None
        id_matches = list(self._KP_ID_RE.finditer(line))
        if len(id_matches) != len(self._KP_LABEL_RE.findall(line)):
            # e.g. "MATCH (kp:KnowledgePoint)" has no id and could match many nodes.
            print(f"忽略缺少ID的知识点引用: {line}")
            return None
        pieces = []
        cursor = 0
        for match in id_matches:
            stored_id = canonical_ids.get(match.group(2).upper())
            if stored_id is None:
                print(f"忽略无效知识点ID: {match.group(2)}")
                return None
            # Replace only the id span, not every occurrence of the text.
            pieces.append(line[cursor:match.start(2)])
            pieces.append(stored_id)
            cursor = match.end(2)
        pieces.append(line[cursor:])
        return "".join(pieces)

    def _clean_cypher(self, code: str) -> str:
        """Sanitise LLM-generated Cypher before it is executed.

        The first fenced code block of ``code`` is processed line by line:

        - ``//`` comments are stripped.
        - Lines containing a forbidden write keyword are dropped.
        - ``MERGE (q:Question ...)`` lines are moved to the front.
        - KnowledgePoint lines are validated by ``_resolve_kp_line``.
        - ``WITH q`` is inserted where a MATCH follows a write clause.
        - TESTS_KNOWLEDGE lines whose target variable is not bound by an
          accepted MATCH are dropped; otherwise MERGE would create an
          anonymous node.

        Args:
            code: Raw model reply (may be None).

        Returns:
            The cleaned Cypher, or "" if there is no fenced block or no
            Question MERGE line.
        """
        blocks = self._FENCE_RE.findall(code or "")
        if not blocks:
            return ""

        # Upper-cased id -> id exactly as stored, for case-insensitive lookup.
        canonical_ids = {str(kp_id).upper(): str(kp_id) for kp_id in self.knowledge_points}
        question_lines = []
        body = []
        bound_kp = set()  # KnowledgePoint variables currently in scope

        for raw_line in blocks[0].splitlines():
            line = raw_line.split('//')[0].strip()
            if not line:
                continue
            upper_line = line.upper()
            if any(word in upper_line for word in self._FORBIDDEN_KEYWORDS):
                continue

            if self._KP_LABEL_RE.search(line):
                kp_vars = set(self._KP_VAR_RE.findall(line))
                resolved = self._resolve_kp_line(line, canonical_ids)
                if resolved is None:
                    # Relationships to these variables must now be dropped too.
                    bound_kp -= kp_vars
                    continue
                # When body is empty, the Question MERGE precedes this line.
                previous = body[-1] if body else "MERGE"
                if re.match(r"\w*", previous).group().upper() in self._WRITE_CLAUSES:
                    body.append("WITH q")
                    bound_kp.clear()  # WITH q drops every other variable
                body.append(resolved)
                bound_kp |= kp_vars
                continue

            if 'MERGE (q:Question' in line:
                question_lines.insert(0, line)
                continue

            if upper_line.startswith('WITH'):
                # Keep only KP variables still projected by this WITH.
                if '*' not in line:
                    bound_kp = {
                        var for var in bound_kp
                        if re.search(rf"\b{re.escape(var)}\b", line[4:])
                    }
                body.append(line)
                continue

            if 'TESTS_KNOWLEDGE' in line:
                target = self._REL_TARGET_RE.search(line)
                if not target or target.group(1) not in bound_kp:
                    print(f"忽略未绑定知识点的关系: {line}")
                    continue

            body.append(line)

        if not question_lines:
            print("未找到Question节点, 忽略生成的Cypher")
            return ""
        # A query cannot end with WITH (possible when the last MATCH was rejected).
        while body and body[-1].upper().startswith('WITH'):
            body.pop()
        return '\n'.join(question_lines + body)

    def run(self) -> str:
        """Ask the LLM which knowledge points the question tests.

        Returns:
            Validated Cypher that links the Question node to its knowledge
            points, or "" if the call failed or nothing valid was produced.
        """
        try:
            response = self.client.chat.completions.create(
                model=MODEL_NAME,
                messages=[
                    {
                        "role": "system",
                        "content": self._make_prompt()
                    },
                    {
                        "role": "user",
                        "content": f"题目内容:{self.content}\n请分析考查的知识点只返回Cypher代码"
                    }
                ]
            )
            raw_cypher = response.choices[0].message.content or ""
            cleaned_cypher = self._clean_cypher(raw_cypher)
            if cleaned_cypher:
                print("验证通过的Cypher\n", cleaned_cypher)
            return cleaned_cypher
        except Exception as e:
            print("知识点分析失败:", str(e))
            return ""

    def query_related_knowledge(self):
        """Print and return the knowledge points linked to this question.

        Returns:
            A list of ``{"knowledge_id": ..., "knowledge_name": ...}`` dicts,
            or [] if none are linked or the query failed.
        """
        # Parameterised query instead of formatting the id into the Cypher text.
        cypher = """
        MATCH (q:Question {id: $question_id})-[:TESTS_KNOWLEDGE]->(kp)
        RETURN kp.id AS knowledge_id, kp.name AS knowledge_name
        """
        try:
            result = self.graph.run(cypher, question_id=self.question_id).data()
            if result:
                print(f"题目关联的知识点({self.question_id}):")
                for row in result:
                    print(f"- {row['knowledge_name']} (ID: {row['knowledge_id']})")
            else:
                print("该题目尚未关联知识点")
            return result
        except Exception as e:
            print("查询失败:", str(e))
            return []
# Manual smoke test: analyse one sample question, persist the links it
# produces, then read them back from the database.
if __name__ == '__main__':
    sample_question = """【时间问题】甲乙两车从相距240公里的两地同时出发相向而行甲车时速60公里乙车时速40公里几小时后相遇"""
    linker = KnowledgeGraph(sample_question)
    generated = linker.run()
    if not generated:
        print("未生成有效Cypher")
    else:
        # Write the validated Question -> KnowledgePoint links.
        linker.graph.run(generated)
        print("执行成功!关联知识点:")
        # Read back what was stored, as a check.
        linker.query_related_knowledge()