You cannot select more than 25 topics. Topics must start with a letter or number, can include dashes ('-'), and can be up to 35 characters long.

180 lines
6.6 KiB

This file contains ambiguous Unicode characters!

This file contains ambiguous Unicode characters that may be confused with others in your current locale. If your use case is intentional and legitimate, you can safely ignore this warning. Use the Escape button to highlight these characters.

# -*- coding: utf-8 -*-
import re
import hashlib
from py2neo import Graph
from openai import OpenAI
from Config import *
class KnowledgeGraph:
    """Link one question to existing ``KnowledgePoint`` nodes in Neo4j via an LLM.

    The LLM is asked for Cypher that MERGEs a ``Question`` node (keyed by the
    first 8 hex chars of the MD5 of the question text) and ``TESTS_KNOWLEDGE``
    relationships to knowledge points. ``run`` returns that Cypher only after
    ``_clean_cypher`` has sanitised it; executing it is left to the caller.
    """

    # First fenced block of the reply; the ``cypher`` tag is optional and
    # case-insensitive, and CRLF line endings are tolerated.
    _FENCE_RE = re.compile(r"```[ \t]*(?:cypher)?[ \t]*\r?\n(.*?)```",
                           re.DOTALL | re.IGNORECASE)
    # Single- or double-quoted literals; blanked out before keyword checks so
    # words inside the question text cannot trigger the filters.
    _STRING_RE = re.compile(r"'(?:[^'\\]|\\.)*'|\"(?:[^\"\\]|\\.)*\"")
    # MERGE's ON CREATE / ON MATCH sub-clauses only set properties on the
    # merged node, so they must not trip the CREATE ban below.
    _ON_SUBCLAUSE_RE = re.compile(r"\bON\s+(?:CREATE|MATCH)\b", re.IGNORECASE)
    # Clauses that may create arbitrary data, destroy data or run procedures.
    _FORBIDDEN_RE = re.compile(
        r"\b(?:CREATE|DELETE|DETACH|REMOVE|DROP|CALL|LOAD)\b", re.IGNORECASE)
    # Cypher keywords are case-insensitive, labels are not.
    _QUESTION_MERGE_RE = re.compile(r"(?i:\bMERGE)\s*\(\s*q\s*:\s*Question\b")
    _KP_MATCH_RE = re.compile(
        r"(?i:\bMATCH)\s*\(\s*(\w+)\s*:\s*KnowledgePoint\b")
    _ID_RE = re.compile(r"\bid\s*:\s*(['\"])(.*?)\1")
    _READ_START_RE = re.compile(r"(?:OPTIONAL\s+)?MATCH\b", re.IGNORECASE)
    _CLAUSE_RE = re.compile(r"\b(MERGE|SET|MATCH|WITH|UNWIND|RETURN)\b",
                            re.IGNORECASE)

    def __init__(self, content: str):
        """Connect to Neo4j and the LLM and load the knowledge-point catalogue.

        Args:
            content: question text to analyse; the first 8 hex chars of its
                MD5 become the Question node ID.
        """
        self.content = content
        self.question_id = hashlib.md5(content.encode()).hexdigest()[:8]
        self.graph = Graph(NEO4J_URI, auth=NEO4J_AUTH)
        self.knowledge_points = self._get_knowledge_points()
        self.client = OpenAI(api_key=MODEL_API_KEY, base_url=MODEL_API_URL)
        print("加载知识点数量:", len(self.knowledge_points))  # debug output

    def _get_knowledge_points(self) -> dict:
        """Load all KnowledgePoint nodes as ``{id: name}``.

        IDs keep exactly the case stored in the database. Nodes without an
        ``id`` are skipped. On failure the error is printed and {} returned.
        """
        try:
            rows = self.graph.run("MATCH (n:KnowledgePoint) RETURN n.id, n.name")
            return {row['n.id']: row['n.name']
                    for row in rows if row['n.id'] is not None}
        except Exception as e:
            print("获取知识点失败:", str(e))
            return {}

    def _make_prompt(self, max_points=None) -> str:
        """Build the system prompt for knowledge-point recognition.

        Args:
            max_points: optional cap on how many knowledge points are listed.
                ``None`` (default) lists all of them: the model can only cite
                IDs it has been shown, and ``_clean_cypher`` rejects any ID
                that is not in ``self.knowledge_points``.

        Returns:
            Prompt text containing the ID:name listing and a Cypher template
            that embeds ``self.question_id``.
        """
        items = list(self.knowledge_points.items())
        shown = items if max_points is None else items[:max_points]
        listing = ", ".join(f"{k}:{v}" for k, v in shown)
        if len(shown) < len(items):
            listing += "..."  # signal that the list was truncated
        lines = [
            "你是一个数学专家,请分析题目考查的知识点,严格:",
            "1. 只使用以下存在的知识点格式ID:名称):",
            listing,
            f"{len(self.knowledge_points)}个可用知识点",
            "2. 按此格式生成Cypher",
            f'MERGE (q:Question {{id: "{self.question_id}"}})',
            'SET q.content = "题目内容"',
            "WITH q",
            'MATCH (kp:KnowledgePoint {id: "知识点ID"})',
            "MERGE (q)-[:TESTS_KNOWLEDGE]->(kp)",
        ]
        return "\n".join(lines)

    def _code_only(self, line: str) -> str:
        """Return *line* with every quoted string literal replaced by ``""``."""
        return self._STRING_RE.sub('""', line)

    def _last_clause(self, line: str) -> str:
        """Return the last MERGE/SET/MATCH/WITH/UNWIND/RETURN keyword in *line* (upper-case), or ""."""
        clauses = self._CLAUSE_RE.findall(self._code_only(line))
        return clauses[-1].upper() if clauses else ""

    def _filter_lines(self, body: str):
        """Keep only safe, validated lines of *body*; return ``(lines, has_question)``."""
        # Upper-cased ID -> ID exactly as stored: matching is case-insensitive,
        # but the emitted ID must be the stored spelling or MATCH finds nothing.
        canonical_ids = {str(k).upper(): str(k)
                         for k in self.knowledge_points if k is not None}
        kept = []
        rejected_vars = set()  # variables whose KnowledgePoint MATCH was dropped
        has_question = False
        for raw in body.split("\n"):
            line = raw.split("//")[0].strip()  # strip // comments and whitespace
            if not line:
                continue
            code_part = self._ON_SUBCLAUSE_RE.sub(" ", self._code_only(line))
            if self._FORBIDDEN_RE.search(code_part):
                continue
            if self._QUESTION_MERGE_RE.search(line):
                # The Question node must exist before anything references it.
                has_question = True
                kept.insert(0, line)
                continue
            kp_match = self._KP_MATCH_RE.search(line)
            if kp_match:
                var = kp_match.group(1)
                id_match = self._ID_RE.search(line, kp_match.end())
                if id_match:
                    stored_id = canonical_ids.get(id_match.group(2).upper())
                    if stored_id is None:
                        print(f"忽略无效知识点ID: {id_match.group(2)}")
                        rejected_vars.add(var)
                        continue
                    line = (line[:id_match.start(2)] + stored_id
                            + line[id_match.end(2):])
                rejected_vars.discard(var)  # variable is (re)bound validly
            elif any(re.search(rf"\b{re.escape(v)}\b", code_part)
                     for v in rejected_vars):
                # e.g. MERGE (q)-[:TESTS_KNOWLEDGE]->(kp) after kp's MATCH was
                # dropped: with kp unbound, MERGE would create a blank node.
                continue
            kept.append(line)
        return kept, has_question

    def _fix_clause_order(self, lines, has_question):
        """Insert the ``WITH q`` Cypher requires before a MATCH and trim an invalid tail."""
        fixed = []
        for line in lines:
            # An update clause (MERGE/SET) may not be followed directly by a
            # MATCH. WITH q keeps the question node in scope, and the MATCH
            # that follows binds the knowledge point afresh.
            if (has_question and fixed and self._READ_START_RE.match(line)
                    and self._last_clause(fixed[-1]) in ("MERGE", "SET")):
                fixed.append("WITH q")
            fixed.append(line)
        # A query may not end with WITH/MATCH/UNWIND (e.g. a WITH left dangling
        # after the knowledge-point lines following it were dropped).
        while fixed and self._last_clause(fixed[-1]) in ("WITH", "MATCH", "UNWIND"):
            fixed.pop()
        return fixed

    def _clean_cypher(self, code: str) -> str:
        """Sanitise the LLM reply into Cypher that is safe to execute.

        Steps:
        * take the first fenced code block (return "" if there is none);
        * drop ``//`` comments, blank lines, and lines using CREATE/DELETE/
          DETACH/REMOVE/DROP/CALL/LOAD outside string literals (MERGE's
          ON CREATE / ON MATCH are allowed);
        * move ``MERGE (q:Question ...)`` lines to the front;
        * validate each ``MATCH (x:KnowledgePoint {id: ...})`` case-insensitively
          against ``self.knowledge_points``, rewriting the ID to its stored
          spelling, and drop later lines that use a rejected variable;
        * insert ``WITH q`` where a MATCH follows MERGE/SET, and trim trailing
          WITH/MATCH/UNWIND lines.

        Args:
            code: raw LLM message content (``None`` is treated as empty).

        Returns:
            Cleaned Cypher text, or "" when nothing usable remains.
        """
        blocks = self._FENCE_RE.findall(code or "")
        if not blocks:
            return ""
        lines, has_question = self._filter_lines(blocks[0])
        return "\n".join(self._fix_clause_order(lines, has_question))

    def run(self) -> str:
        """Ask the LLM for linking Cypher and return the sanitised result.

        Returns:
            Cleaned Cypher (also printed), or "" when the model produced
            nothing usable or a step raised (the error is printed).
        """
        try:
            response = self.client.chat.completions.create(
                model=MODEL_NAME,
                messages=[
                    {"role": "system", "content": self._make_prompt()},
                    {
                        "role": "user",
                        "content": f"题目内容:{self.content}\n请分析考查的知识点只返回Cypher代码",
                    },
                ],
            )
            # content may be None (e.g. refusal / tool call); treat as empty.
            raw_cypher = response.choices[0].message.content or ""
            cleaned_cypher = self._clean_cypher(raw_cypher)
        except Exception as e:
            # Top-level boundary: API/parsing failures are reported, not raised.
            print("知识点分析失败:", str(e))
            return ""
        if cleaned_cypher:
            print("验证通过的Cypher\n", cleaned_cypher)
        return cleaned_cypher

    def query_related_knowledge(self):
        """Print and return the knowledge points linked to this question.

        Returns:
            List of dicts with ``knowledge_id`` and ``knowledge_name`` keys,
            or [] if the query failed (the error is printed).
        """
        # Parameterised rather than string-built Cypher.
        cypher = (
            "MATCH (q:Question {id: $qid})-[:TESTS_KNOWLEDGE]->(kp) "
            "RETURN kp.id AS knowledge_id, kp.name AS knowledge_name"
        )
        try:
            result = self.graph.run(cypher, qid=self.question_id).data()
        except Exception as e:
            print("查询失败:", str(e))
            return []
        if result:
            print(f"题目关联的知识点({self.question_id})")
            for row in result:
                print(f"- {row['knowledge_name']} (ID: {row['knowledge_id']})")
        else:
            print("该题目尚未关联知识点")
        return result
# Manual smoke test: analyse one sample question, write the generated Cypher
# to Neo4j, then read back which knowledge points got linked.
if __name__ == '__main__':
    test_case = """【时间问题】甲乙两车从相距240公里的两地同时出发相向而行甲车时速60公里乙车时速40公里几小时后相遇"""
    kg = KnowledgeGraph(test_case)
    cypher = kg.run()
    if not cypher:
        print("未生成有效Cypher")
    else:
        # Execute the sanitised Cypher returned by run().
        kg.graph.run(cypher)
        print("执行成功!关联知识点:")
        kg.query_related_knowledge()