# -*- coding: utf-8 -*-
"""Generate and sanitize Cypher linking a question to existing graph nodes.

A question text is sent to a chat-completion model together with a Cypher
template; the fenced Cypher blocks in the reply are filtered line by line
(no CREATE, only known KP_/AB_ ids, weights within [0, 1]) before being
handed to the executor.
"""
import hashlib
import json
import re
import time
from typing import Dict, Iterator, Tuple

from openai import OpenAI
from openai.types.chat import ChatCompletionChunk

# NOTE(review): Graph, NEO4J_URI, NEO4J_AUTH, MODEL_API_KEY, MODEL_API_URL,
# MODEL_NAME and K2_Neo4jExecutor are not defined in this file; presumably
# they are all provided by this star import -- confirm against that module.
from K2_Neo4jExecutor import *

# Patterns are compiled once because they are applied to every Cypher line.
_KP_ID_PATTERN = re.compile(r'(?i)(kp_[\da-f]{6})')
_AB_ID_PATTERN = re.compile(r'(?i)(ab_[\da-f]{6})')
# The optional sign is captured so that negative weights are rejected
# instead of being mistaken for "no weight present".
_WEIGHT_PATTERN = re.compile(r"weight\s*:\s*([-+]?[0-9.]+)")
_MERGE_QUESTION_PATTERN = re.compile(r'\bMERGE\s*\(q:Question\b', re.IGNORECASE)
_WITH_Q_PATTERN = re.compile(r'\bWITH\s+q\b', re.IGNORECASE)
_CREATE_PATTERN = re.compile(r'\bCREATE\b', re.IGNORECASE)
_CODE_FENCE_PATTERN = re.compile(r"```(?:cypher)?\n(.*?)```", re.DOTALL)


class KnowledgeGraph:
    """Build a sanitized Cypher script for one question.

    Attributes:
        content: The raw question text.
        question_id: First 8 hex characters of the MD5 of ``content``.
        graph: The connection returned by ``_init_graph_connection``.
        existing_knowledge: ``{id: name}`` of ``KnowledgePoint`` nodes.
        existing_ability: ``{id: name}`` of ``AbilityPoint`` nodes.
        client: The OpenAI client used for generation.
    """

    def __init__(self, content: str):
        """Connect to the graph, load the known nodes and create the client.

        Raises:
            ConnectionError: If the graph connection test fails.
        """
        self.content = content
        self.question_id = self._generate_question_id()
        self.graph = self._init_graph_connection()
        self.existing_knowledge = self._fetch_existing_nodes("KnowledgePoint")
        self.existing_ability = self._fetch_existing_nodes("AbilityPoint")
        self.client = OpenAI(api_key=MODEL_API_KEY, base_url=MODEL_API_URL)

    @staticmethod
    def _lowercase_ids(nodes: Dict[str, str]) -> set:
        """Return the keys of *nodes* lower-cased, skipping ``None`` ids."""
        return {str(node_id).lower() for node_id in nodes if node_id is not None}

    def _validate_ids(self, line: str) -> bool:
        """Return True when every KP_/AB_ id in *line* is a known node id.

        The comparison is case-insensitive on BOTH sides: ids found in the
        line and ids loaded from the database are lower-cased before the
        membership test. A line containing no ids is valid.
        """
        found_kp = set(_KP_ID_PATTERN.findall(line))
        found_ab = set(_AB_ID_PATTERN.findall(line))
        if not found_kp and not found_ab:
            # Nothing to check; avoid building the lookup sets.
            return True

        known_kp = self._lowercase_ids(self.existing_knowledge)
        known_ab = self._lowercase_ids(self.existing_ability)
        kp_ok = all(kp.lower() in known_kp for kp in found_kp)
        ab_ok = all(ab.lower() in known_ab for ab in found_ab)
        return kp_ok and ab_ok

    def _init_graph_connection(self) -> Graph:
        """Open the graph connection and verify it with ``RETURN 1``.

        Raises:
            ConnectionError: If creating the connection or running the test
                query fails; the original exception is chained as the cause.
        """
        try:
            graph = Graph(NEO4J_URI, auth=NEO4J_AUTH)
            graph.run("RETURN 1").data()
            print("✅ Neo4j连接成功")
            return graph
        except Exception as e:
            raise ConnectionError(f"❌ 数据库连接失败: {str(e)}") from e

    def _validate_weight(self, line: str) -> bool:
        """Return True when the ``weight:`` value in *line* is within [0, 1].

        A line without a numeric ``weight:`` property is considered valid.
        A value that cannot be parsed as a float (e.g. ``1.2.3``) or that
        lies outside the range, including negative values, is invalid.
        """
        match = _WEIGHT_PATTERN.search(line)
        if match is None:
            return True
        try:
            weight = float(match.group(1))
        except ValueError:
            return False
        return 0.0 <= weight <= 1.0

    def _validate_cypher_structure(self, cypher: str) -> bool:
        """Return False only if *cypher* MERGEs ``(q:Question`` without ``WITH q``.

        Only the presence of both fragments is checked, not their order.
        """
        has_merge = _MERGE_QUESTION_PATTERN.search(cypher) is not None
        has_with = _WITH_Q_PATTERN.search(cypher) is not None
        return (not has_merge) or has_with

    def _generate_question_id(self) -> str:
        """Return the first 8 hex characters of the MD5 of the question text."""
        return hashlib.md5(self.content.encode()).hexdigest()[:8]

    def _fetch_existing_nodes(self, label: str) -> Dict[str, str]:
        """Return ``{id: name}`` for every node carrying *label*.

        Best effort: on any query failure the error is printed and an empty
        dict is returned; ``run`` refuses to proceed with empty node sets.
        """
        try:
            cypher = f"MATCH (n:{label}) RETURN n.id as id, n.name as name"
            result = self.graph.run(cypher).data()
            return {item['id']: item['name'] for item in result}
        except Exception as e:
            print(f"❌ 节点查询失败: {str(e)}")
            return {}

    def _generate_stream(self) -> Iterator[ChatCompletionChunk]:
        """Start a streaming chat completion for the question.

        The system prompt carries a Cypher template (with this question's
        id filled in) that the model is asked to follow; the question text
        is sent as the user message.
        """
        # The Cypher example must stay one statement per line: the `//`
        # comments would otherwise swallow the rest of the example.
        system_prompt = f'''
        将题目中涉及到的小学数学知识点、能力点进行总结,并且按照以下格式生成在neo4j-community-5.26.2上的语句:
        // 在示例中强调WITH的必要性
        MERGE (q:Question {{id: "{self.question_id}"}})
        SET q.content = "..."
        // 必须使用WITH传递上下文
        WITH q
        MATCH (kp:KnowledgePoint {{id: "KP_xxxxxx"}})
        MERGE (q)-[:TESTS_KNOWLEDGE]->(kp)
        // 多个关系需要继续使用WITH
        WITH q
        MATCH (ab:AbilityPoint {{id: "AB_xxxxxx"}})
        MERGE (q)-[:REQUIRES_ABILITY]->(ab)
        '''
        return self.client.chat.completions.create(
            model=MODEL_NAME,
            messages=[
                {"role": "system", "content": system_prompt},
                {"role": "user", "content": self.content}
            ],
            stream=True,
            timeout=300
        )

    def _format_node_list(self, nodes: Dict[str, str]) -> str:
        """Render up to five ``id: name`` entries of *nodes*, one per line.

        When more than five nodes exist, a trailing line states the total.
        An empty mapping yields a fixed placeholder line.
        """
        if not nodes:
            return " (无相关节点)"
        sample = []
        for i, (k, v) in enumerate(nodes.items()):
            if i >= 5:
                sample.append(f" ...(共{len(nodes)}个,仅显示前5个)")
                break
            sample.append(f" - {k}: {v}")
        return '\n'.join(sample)

    def _extract_cypher(self, content: str) -> str:
        """Return the sanitized contents of all fenced code blocks in *content*.

        Blocks fenced with ``` or ```cypher are sanitized individually;
        blocks that end up empty are dropped and the rest are joined with
        ``;`` and a blank line. Returns ``""`` when nothing survives.
        """
        safe_blocks = []
        for block in _CODE_FENCE_PATTERN.findall(content):
            cleaned = self._sanitize_cypher(block)
            if cleaned:
                safe_blocks.append(cleaned)
        return ';\n\n'.join(safe_blocks) if safe_blocks else ""

    def _sanitize_cypher(self, cypher: str) -> str:
        """Drop unsafe lines from *cypher* and return the remaining lines.

        Lines whose stripped form starts with ``SET`` are kept unchecked.
        Any other line is removed when it contains ``CREATE``, references an
        unknown KP_/AB_ id, or carries an out-of-range weight. Progress and
        every rejection (with its reasons) are printed.
        """
        valid_lines = []
        for line_num, line in enumerate(cypher.split('\n'), 1):
            # Debug trace of the line being examined.
            print(f"正在处理第{line_num}行: {line[:50]}...")

            # SET clauses carry free text (question content) and are allowed.
            if line.strip().startswith("SET"):
                valid_lines.append(line)
                continue

            # Collect every reason so the log explains the rejection fully.
            filter_reason = []
            if _CREATE_PATTERN.search(line):
                filter_reason.append("包含CREATE语句")
            if not self._validate_ids(line):
                filter_reason.append("存在无效ID")
            if not self._validate_weight(line):
                filter_reason.append("权重值非法")

            if filter_reason:
                print(f"第{line_num}行被过滤: {line[:50]}... | 原因: {', '.join(filter_reason)}")
                continue
            valid_lines.append(line)
        return '\n'.join(valid_lines)

    def run(self) -> Tuple[bool, str, str]:
        """Stream the model reply and extract a sanitized Cypher script.

        Returns:
            ``(success, text, cypher)``. On success ``text`` is the full
            model reply and ``cypher`` the sanitized script (possibly
            empty). On failure ``text`` is a short reason or the error
            message and ``cypher`` is ``""``.
        """
        if not self.existing_knowledge or not self.existing_ability:
            print("❌ 知识库或能力点为空,请检查数据库")
            return False, "节点数据为空", ""

        start_time = time.time()
        spinner = ['⠋', '⠙', '⠹', '⠸', '⠼', '⠴', '⠦', '⠧', '⠇', '⠏']
        content_buffer = []
        try:
            print(f"🚀 开始生成(知识点:{len(self.existing_knowledge)}个,能力点:{len(self.existing_ability)}个)")
            stream = self._generate_stream()
            for idx, chunk in enumerate(stream):
                print(f"\r{spinner[idx % len(spinner)]} 生成中({int(time.time() - start_time)}秒)", end="")
                if chunk.choices and chunk.choices[0].delta.content:
                    content_chunk = chunk.choices[0].delta.content
                    content_buffer.append(content_chunk)
                    if len(content_buffer) == 1:
                        # Print the header once, before the first fragment.
                        print("\n\n📝 内容生成开始:")
                    print(content_chunk, end="", flush=True)

            if content_buffer:
                full_content = ''.join(content_buffer)
                cypher_script = self._extract_cypher(full_content)
                print(f"\n\n✅ 生成完成!耗时 {int(time.time() - start_time)}秒")
                return True, full_content, cypher_script

            print("\n⚠️ 生成完成但未获取到有效内容")
            return False, "空内容", ""
        except Exception as e:
            # Top-level boundary: report any failure to the caller as a
            # result tuple instead of propagating it.
            error_msg = str(e)
            print(f"\n\n❌ 生成失败:{error_msg}")
            return False, error_msg, ""


def main() -> None:
    """Run the generator on a sample question and execute the result."""
    executor = K2_Neo4jExecutor(
        uri=NEO4J_URI,
        auth=NEO4J_AUTH
    )

    # Sample question used as a smoke test.
    test_content = '''
    题目:一个长方体的长是8厘米,宽是5厘米,高是3厘米,求它的表面积是多少平方厘米?
    '''

    try:
        kg = KnowledgeGraph(test_content)
        success, _full_text, cypher = kg.run()
        if not success or not cypher:
            # Never send an empty or failed script to the database.
            print("⚠️ 未生成可执行的Cypher,跳过数据插入")
            return
        executor.execute_cypher_text(cypher)
        print("恭喜,执行数据插入完成!")
    except Exception as e:
        print(f"程序初始化失败: {str(e)}")


if __name__ == '__main__':
    main()