class KnowledgeGraph:
    """Cypher validation helpers added to ``KnowledgeGraph`` by this patch.

    NOTE(review): only the methods that are fully visible in the patch are
    reproduced here.  The constructor, ``run`` and the other members of the
    class lie outside the visible text and are intentionally not
    reconstructed.  The methods below read ``self.existing_knowledge`` and
    ``self.existing_ability``, which the constructor populates.
    """

    def _validate_ids(self, line: str) -> bool:
        """Check that every knowledge/ability id mentioned in *line* is known.

        Ids are recognised case-insensitively as ``kp_``/``ab_`` followed by
        six hexadecimal digits and are lower-cased before the membership test.

        Args:
            line: A single line of Cypher text.

        Returns:
            True when all ids found in the line exist (a line without any id
            is therefore valid), False otherwise.
        """
        kp_ids = set(re.findall(r'(?i)(kp_[\da-f]{6})', line))
        ab_ids = set(re.findall(r'(?i)(ab_[\da-f]{6})', line))

        # assumes the keys of both lookup tables are lower-case ids -- TODO confirm
        valid_kp = all(kp.lower() in self.existing_knowledge for kp in kp_ids)
        valid_ab = all(ab.lower() in self.existing_ability for ab in ab_ids)
        return valid_kp and valid_ab

    def _validate_weight(self, line: str) -> bool:
        """Check that a relationship weight, if present, lies in [0.0, 1.0].

        Args:
            line: A single line of Cypher text.

        Returns:
            True when the line has no ``weight:`` property or its value is a
            number between 0.0 and 1.0 inclusive; False when the value is out
            of range or is not a parseable number.
        """
        # The optional sign is captured on purpose: without it a negative
        # weight would not match at all and would be accepted as "no weight".
        weight_match = re.search(r"weight\s*:\s*([-+]?[0-9.]+)", line)
        if weight_match is None:
            return True  # no weight property on this line

        try:
            weight = float(weight_match.group(1))
        except ValueError:
            # e.g. "1..2" or a lone "." matches the pattern but is no number
            return False
        return 0.0 <= weight <= 1.0

    def _validate_cypher_structure(self, cypher: str) -> bool:
        """Check that a ``MERGE (q:Question`` is accompanied by ``WITH q``.

        Args:
            cypher: The complete Cypher text.

        Returns:
            True when the text contains no ``MERGE (q:Question`` clause, or
            contains both that clause and a ``WITH q`` clause; False when the
            MERGE is present without the WITH.  Always a real ``bool``.
        """
        has_merge = re.search(
            r'\bMERGE\s*\(q:Question\b', cypher, re.IGNORECASE
        ) is not None
        has_with = re.search(r'\bWITH\s+q\b', cypher, re.IGNORECASE) is not None
        return has_with or not has_merge

    def _sanitize_cypher(self, cypher: str) -> str:
        """Drop every line of *cypher* that fails the safety checks.

        A line is kept unconditionally when it starts with ``SET`` (after
        stripping whitespace).  Any other line is removed when it contains a
        ``CREATE`` keyword, references an unknown id, or carries an illegal
        weight.  Progress and every removal (with its reasons) are printed.

        Args:
            cypher: Multi-line Cypher text.

        Returns:
            The surviving lines joined with newlines.
        """
        valid_lines = []
        for line_num, line in enumerate(cypher.split('\n'), 1):
            # Debug trace of the line being examined.
            print(f"正在处理第{line_num}行: {line[:50]}...")

            # SET clauses are passed through without further checks.
            # NOTE(review): this also skips the CREATE/id/weight checks for
            # anything else on the same line -- confirm that is intended.
            if line.strip().startswith("SET"):
                valid_lines.append(line)
                continue

            # Collect every reason for rejecting the line.
            filter_reason = []
            if re.search(r'\bCREATE\b', line, re.IGNORECASE):
                filter_reason.append("包含CREATE语句")
            if not self._validate_ids(line):
                filter_reason.append("存在无效ID")
            if not self._validate_weight(line):
                filter_reason.append("权重值非法")

            if filter_reason:
                # Report the offending line itself; the loop variable is `line`.
                print(f"第{line_num}行被过滤: {line[:50]}... | 原因: {', '.join(filter_reason)}")
                continue

            valid_lines.append(line)

        return '\n'.join(valid_lines)