main
黄海 5 months ago
parent 4580388489
commit b05f7f18d7

@ -14,11 +14,14 @@ class KnowledgeGraph:
self.knowledge_points = self._get_knowledge_points() self.knowledge_points = self._get_knowledge_points()
self.client = OpenAI(api_key=MODEL_API_KEY, base_url=MODEL_API_URL) self.client = OpenAI(api_key=MODEL_API_KEY, base_url=MODEL_API_URL)
#self.knowledge_points = self._get_knowledge_points()
print("加载知识点数量:", len(self.knowledge_points)) # 添加调试信息
def _get_knowledge_points(self) -> dict: def _get_knowledge_points(self) -> dict:
"""修复字段名称错误""" """保持ID原始大小写"""
try: try:
# 确保返回字段与节点属性名称匹配 # 移除lower()转换
return {row['n.id'].lower(): row['n.name'] # 修改为n.id return {row['n.id']: row['n.name'] # 直接使用原始ID
for row in self.graph.run("MATCH (n:KnowledgePoint) RETURN n.id, n.name")} for row in self.graph.run("MATCH (n:KnowledgePoint) RETURN n.id, n.name")}
except Exception as e: except Exception as e:
print(f"获取知识点失败:", str(e)) print(f"获取知识点失败:", str(e))
@ -42,45 +45,65 @@ MATCH (kp:KnowledgePoint {{id: "知识点ID"}})
MERGE (q)-[:TESTS_KNOWLEDGE]->(kp)""" MERGE (q)-[:TESTS_KNOWLEDGE]->(kp)"""
def _clean_cypher(self, code: str) -> str:
    """Extract and sanitise the Cypher statements contained in a model reply.

    Only the first fenced code block (three backticks, optionally tagged
    ``cypher``) of ``code`` is considered. Its lines are filtered and
    rearranged as follows:

    * ``//`` comments and blank lines are removed;
    * every line containing ``CREATE`` (case-insensitive) is dropped;
    * ``MERGE (q:Question ...`` lines are moved to the front;
    * ``MATCH (kp:KnowledgePoint ...`` lines are kept only when their ``id``
      matches, case-insensitively, a key of ``self.knowledge_points``; the
      ID literal is then rewritten to the key exactly as it is stored;
    * a ``WITH q`` line is placed right after the leading Question MERGE
      unless a ``WITH`` already follows it.

    Args:
        code: Raw text returned by the model.

    Returns:
        The cleaned statements joined by newlines, or ``""`` when ``code``
        contains no fenced code block.
    """
    blocks = re.findall(r"```(?:cypher)?\n(.*?)```", code, re.DOTALL)
    if not blocks:
        return ""

    # Case-insensitive lookup table: upper-cased ID -> ID exactly as stored.
    # Rewriting to the stored spelling (instead of blindly upper-casing)
    # keeps the MATCH working when the stored IDs are not upper case.
    stored_ids = {
        key.upper(): key
        for key in self.knowledge_points
        if isinstance(key, str)
    }

    safe = []
    has_question = False
    for raw_line in blocks[0].split('\n'):
        # Strip trailing // comments and surrounding whitespace.
        # NOTE(review): this also truncates string literals that contain
        # "//" (e.g. URLs) - confirm such content never appears here.
        line = raw_line.split('//')[0].strip()
        if not line:
            continue

        # Never let generated code CREATE anything.
        if 'CREATE' in line.upper():
            continue

        # The Question node must be merged before anything refers to q.
        if 'MERGE (q:Question' in line:
            has_question = True
            safe.insert(0, line)
            continue

        if 'MATCH (kp:KnowledgePoint' in line:
            id_match = re.search(r"id: ['\"](.*?)['\"]", line)
            if id_match:
                original_id = id_match.group(1)
                stored_id = stored_ids.get(original_id.upper())
                if stored_id is None:
                    # NOTE(review): a following
                    # MERGE (q)-[:TESTS_KNOWLEDGE]->(kp) line is still kept
                    # and would then reference an unbound kp - confirm
                    # whether it should be dropped as well.
                    print(f"忽略无效知识点ID: {original_id}")
                    continue
                # Replace only the ID literal itself; str.replace() would
                # also rewrite other occurrences of the same text in the
                # line (variable names, labels).
                start, end = id_match.span(1)
                line = line[:start] + stored_id + line[end:]

        # No "WITH q" is injected in front of the relationship MERGE: doing
        # so would drop the kp bound by the preceding MATCH from scope.
        safe.append(line)

    # has_question implies safe[0] is a Question MERGE (it is always inserted
    # at index 0); Cypher needs a WITH between that MERGE and a later MATCH.
    if has_question and (len(safe) < 2 or not safe[1].startswith('WITH')):
        safe.insert(1, "WITH q")

    return '\n'.join(safe)
def run(self) -> str: def run(self) -> str:
@ -134,15 +157,24 @@ def query_related_knowledge(self):
# 测试用例 # 测试用例
if __name__ == '__main__':
    # Manual smoke test: generate Cypher for one sample question, write it
    # to the graph, then list the knowledge points linked to the question.
    test_case = """【时间问题】甲乙两车从相距240公里的两地同时出发相向而行..."""
    kg = KnowledgeGraph(test_case)
    cypher = kg.run()
    if cypher:
        # Insert the generated data into the graph.
        kg.graph.run(cypher)
        print("执行成功!关联知识点:")
        kg.query_related_knowledge()
    else:
        print("未生成有效Cypher")
# # 临时诊断
# print("当前知识库中是否存在该ID",
# 'f0333b305f7246b5a06d03d4e3ff55a9' in kg.knowledge_points)
#
# # 直接查询数据库
# test_cypher = '''
# MATCH (kp:KnowledgePoint)
# WHERE kp.id = 'f0333b305f7246b5a06d03d4e3ff55a9'
# RETURN kp.id, kp.name
# '''
# print("直接查询结果:", kg.graph.run(test_cypher).data())
Loading…
Cancel
Save