# -*- coding: utf-8 -*-
"""Link a maths question to existing KnowledgePoint nodes in Neo4j.

An LLM is asked which knowledge points the question tests and to answer with
Cypher; the reply is sanitised against the knowledge points stored in the
graph before it is handed back for execution.
"""
import hashlib
import re

from openai import OpenAI
from py2neo import Graph

from Config import MODEL_API_KEY, MODEL_API_URL, MODEL_NAME, NEO4J_AUTH, NEO4J_URI

# Fenced code block in the model reply; the "cypher" tag is optional, any case.
_CYPHER_FENCE_RE = re.compile(r"```(?:cypher)?[ \t]*\r?\n(.*?)```", re.DOTALL | re.IGNORECASE)
# `id: "..."` / `id:'...'` inside a node pattern; group 2 is the ID itself.
_KP_ID_RE = re.compile(r"""\bid\s*:\s*(['"])(.*?)\1""")
# Lines containing any of these are never sent to the database: the generated
# query is only supposed to MERGE/SET the question and its relationships.
_FORBIDDEN_RE = re.compile(r"CREATE|DELETE|DETACH|REMOVE|DROP|CALL|LOAD\s+CSV", re.IGNORECASE)
# Detects whether a WITH clause still carries the `kp` variable.
_KP_VAR_RE = re.compile(r"\bkp\b")
# Detects an updating clause on a line (used to spot dangling read-only tails).
_UPDATE_RE = re.compile(r"\b(?:MERGE|SET)\b", re.IGNORECASE)


class KnowledgeGraph:
    """Identify and store the knowledge points tested by one question."""

    # Number of known knowledge points quoted as examples in the system prompt.
    # NOTE(review): the model only sees these IDs, so for any other knowledge
    # point it must guess an ID; guesses that do not exist in the graph are
    # discarded by _clean_cypher.
    PROMPT_EXAMPLE_COUNT = 5

    def __init__(self, content: str):
        """Connect to Neo4j and the model API and load all knowledge points.

        Args:
            content: Full text of the question to analyse.
        """
        self.content = content
        # Deterministic short ID: the same question text always maps to the
        # same Question node ID, so MERGE-ing it again reuses that node.
        self.question_id = hashlib.md5(content.encode()).hexdigest()[:8]
        self.graph = Graph(NEO4J_URI, auth=NEO4J_AUTH)
        self.knowledge_points = self._get_knowledge_points()
        self.client = OpenAI(api_key=MODEL_API_KEY, base_url=MODEL_API_URL)
        print("加载知识点数量:", len(self.knowledge_points))

    def _get_knowledge_points(self) -> dict:
        """Return ``{id: name}`` for every KnowledgePoint node.

        IDs are kept exactly as stored (no case folding) because generated
        Cypher must match them verbatim. On any query error the error is
        printed and an empty dict is returned.
        """
        try:
            cursor = self.graph.run("MATCH (n:KnowledgePoint) RETURN n.id, n.name")
            return {row["n.id"]: row["n.name"] for row in cursor}
        except Exception as e:
            print("获取知识点失败:", str(e))
            return {}

    def _make_prompt(self) -> str:
        """Build the system prompt for knowledge-point identification.

        Quotes the first ``PROMPT_EXAMPLE_COUNT`` knowledge points as
        ``ID:name`` examples, states the total number available, and gives the
        Cypher template (with this question's ID) the model should follow.
        """
        example_ids = list(self.knowledge_points)[: self.PROMPT_EXAMPLE_COUNT]
        examples = ", ".join(f"{k}:{self.knowledge_points[k]}" for k in example_ids)
        return f"""你是一个数学专家,请分析题目考查的知识点,严格:
1. 只使用以下存在的知识点(格式:ID:名称):
{examples}...
共{len(self.knowledge_points)}个可用知识点
2. 按此格式生成Cypher:
MERGE (q:Question {{id: "{self.question_id}"}})
SET q.content = "题目内容"
WITH q
MATCH (kp:KnowledgePoint {{id: "知识点ID"}})
MERGE (q)-[:TESTS_KNOWLEDGE]->(kp)"""

    def _clean_cypher(self, code: str) -> str:
        """Sanitise model-generated Cypher before it is executed.

        Only the first fenced code block in *code* is used. Its lines are
        filtered and normalised so that the returned query:

        * contains no line with CREATE/DELETE/DETACH/REMOVE/DROP/CALL/LOAD CSV;
        * starts with the ``MERGE (q:Question ...)`` line;
        * only matches knowledge-point IDs that exist in the graph (compared
          case-insensitively), rewritten to their exact stored spelling;
        * drops a ``MERGE (q)-[:TESTS_KNOWLEDGE]...`` line when ``kp`` is not
          bound to a validated knowledge point, so MERGE never invents nodes;
        * has a WITH clause before every knowledge-point MATCH (Cypher forbids
          MATCH directly after MERGE/SET), so several points can be linked;
        * does not end in a dangling WITH/MATCH.

        Returns "" if there is no fenced block or no Question MERGE line.

        NOTE(review): question text interpolated by the model into
        ``SET q.content = "..."`` is not escaped; a quote inside the text can
        still break the statement.
        """
        blocks = _CYPHER_FENCE_RE.findall(code)
        if not blocks:
            return ""

        # Upper-cased ID -> ID exactly as stored in Neo4j.
        stored_ids = {kp_id.upper(): kp_id for kp_id in self.knowledge_points}

        question_line = None  # the MERGE (q:Question ...) line, emitted first
        body = []             # every other accepted line, in original order
        kp_bound = False      # is `kp` currently bound to a validated point?

        for raw in blocks[0].splitlines():
            # Strip `//` comments and whitespace.
            line = raw.split("//")[0].strip()
            if not line or _FORBIDDEN_RE.search(line):
                continue

            if "MERGE (q:Question" in line:
                if question_line is None:
                    question_line = line
                continue

            starts_with_with = line.upper().startswith("WITH")
            if starts_with_with:
                # WITH keeps only the variables it lists.
                kp_bound = bool(_KP_VAR_RE.search(line))

            if "MATCH (kp:KnowledgePoint" in line:
                id_match = _KP_ID_RE.search(line)
                kp_id = id_match.group(2) if id_match else ""
                stored_id = stored_ids.get(kp_id.upper()) if id_match else None
                if stored_id is None:
                    print(f"忽略无效知识点ID: {kp_id}")
                    kp_bound = False
                    continue
                # Replace only the quoted ID with its stored spelling.
                start, end = id_match.span(2)
                line = line[:start] + stored_id + line[end:]
                # MATCH may not follow MERGE/SET directly; `WITH q` also
                # releases any previous `kp` so it can be rebound here.
                if not starts_with_with and not (body and body[-1].upper().startswith("WITH")):
                    body.append("WITH q")
                kp_bound = True

            if "MERGE (q)-[:TESTS_KNOWLEDGE]" in line and not kp_bound:
                # Its MATCH was rejected: an unbound `kp` would make MERGE
                # create a brand-new anonymous node.
                continue

            body.append(line)

        if question_line is None:
            # Without the Question node, `q` would be unbound below.
            return ""

        # A query may not end with a read-only WITH/MATCH clause.
        while body and body[-1].upper().startswith(("WITH", "MATCH")) and not _UPDATE_RE.search(body[-1]):
            body.pop()

        return "\n".join([question_line, *body])

    def run(self) -> str:
        """Ask the model which knowledge points are tested; return clean Cypher.

        The Cypher is printed and returned but not executed. Returns "" when
        the model call fails or nothing survives _clean_cypher.
        """
        try:
            response = self.client.chat.completions.create(
                model=MODEL_NAME,
                messages=[
                    {"role": "system", "content": self._make_prompt()},
                    {
                        "role": "user",
                        "content": f"题目内容:{self.content}\n请分析考查的知识点,只返回Cypher代码",
                    },
                ],
            )
            # `content` may be None (e.g. refusals); treat that as no Cypher.
            raw_cypher = response.choices[0].message.content or ""
            cleaned_cypher = self._clean_cypher(raw_cypher)
            if cleaned_cypher:
                print("验证通过的Cypher:\n", cleaned_cypher)
                return cleaned_cypher
            return ""
        except Exception as e:
            print("知识点分析失败:", str(e))
            return ""

    def query_related_knowledge(self):
        """Print and return the knowledge points linked to this question.

        Returns a list of dicts with keys ``knowledge_id`` and
        ``knowledge_name``; an empty list if none are linked or the query fails.
        """
        # Parameterised rather than string-built Cypher.
        cypher = """
        MATCH (q:Question {id: $qid})-[:TESTS_KNOWLEDGE]->(kp)
        RETURN kp.id AS knowledge_id, kp.name AS knowledge_name
        """
        try:
            result = self.graph.run(cypher, qid=self.question_id).data()
            if result:
                print(f"题目关联的知识点({self.question_id}):")
                for row in result:
                    print(f"- {row['knowledge_name']} (ID: {row['knowledge_id']})")
            else:
                print("该题目尚未关联知识点")
            return result
        except Exception as e:
            print("查询失败:", str(e))
            return []


# Manual smoke test: analyse one question, write the links, read them back.
if __name__ == "__main__":
    test_case = """【时间问题】甲乙两车从相距240公里的两地同时出发相向而行,甲车时速60公里,乙车时速40公里,几小时后相遇?"""
    kg = KnowledgeGraph(test_case)
    cypher = kg.run()
    if cypher:
        kg.graph.run(cypher)
        print("执行成功!关联知识点:")
        kg.query_related_knowledge()
    else:
        print("未生成有效Cypher")