diff --git a/AI/Config.py b/AI/Config.py index 077f66c4..08f7bfdf 100644 --- a/AI/Config.py +++ b/AI/Config.py @@ -20,4 +20,8 @@ ALY_SK='oizcTOZ8izbGUouboC00RcmGE8vBQ1' # 正确路径拼接方式 mdWorkingPath = Path(__file__).parent / 'md-file' / 'readme' DEFAULT_TEMPLATE = mdWorkingPath / 'default.md' # 使用 / 运算符 -DEFAULT_OUTPUT_DIR = mdWorkingPath / 'output' # 使用 / 运算符 \ No newline at end of file +DEFAULT_OUTPUT_DIR = mdWorkingPath / 'output' # 使用 / 运算符 + +# 请在Config.py中配置以下参数 +NEO4J_URI = "neo4j://10.10.21.20:7687" +NEO4J_AUTH = ("neo4j", "DsideaL4r5t6y7u") \ No newline at end of file diff --git a/AI/Neo4j/K1_KnowledgeGraph.py b/AI/Neo4j/K1_KnowledgeGraph.py index 27b823ca..b137b70e 100644 --- a/AI/Neo4j/K1_KnowledgeGraph.py +++ b/AI/Neo4j/K1_KnowledgeGraph.py @@ -2,68 +2,69 @@ import re import time import hashlib -from typing import Iterator, Tuple +from typing import Iterator, Tuple, Dict +from py2neo import Graph from openai import OpenAI from openai.types.chat import ChatCompletionChunk from Config import * - class KnowledgeGraph: def __init__(self, content: str): self.content = content - self.client = OpenAI(api_key=MODEL_API_KEY, base_url=MODEL_API_URL) self.question_id = self._generate_question_id() + self.graph = self._init_graph_connection() + self.existing_knowledge = self._fetch_existing_nodes("KnowledgePoint") + self.existing_ability = self._fetch_existing_nodes("AbilityPoint") + self.client = OpenAI(api_key=MODEL_API_KEY, base_url=MODEL_API_URL) + + def _init_graph_connection(self) -> Graph: + """初始化并测试数据库连接""" + try: + graph = Graph(NEO4J_URI, auth=NEO4J_AUTH) + graph.run("RETURN 1").data() + print("✅ Neo4j连接成功") + return graph + except Exception as e: + raise ConnectionError(f"❌ 数据库连接失败: {str(e)}") def _generate_question_id(self) -> str: """生成题目唯一标识符""" return hashlib.md5(self.content.encode()).hexdigest()[:8] + def _fetch_existing_nodes(self, label: str) -> Dict[str, str]: + """从Neo4j获取已有节点""" + try: + cypher = f"MATCH (n:{label}) RETURN n.id as id, n.name as name" + result = self.graph.run(cypher).data() + return {item['id']: item['name'] for item in result} + except Exception as e: + print(f"❌ 节点查询失败: {str(e)}") + return {} + def _generate_stream(self) -> Iterator[ChatCompletionChunk]: - """动态化提示词版本""" - system_prompt = f'''请根据题目内容生成Neo4j Cypher语句,严格遵循以下规则: - # 节点创建规范 - 1. 知识点节点: - - 标签: KnowledgePoint - - 必须属性: - * id: "KP_" + 知识点名称的MD5前6位(示例:name="分数运算" → id="KP_ae3b8c") - * name: 知识点名称(从题目内容中提取) - * level: 学段(小学/初中/高中) - - 2. 能力点节点: - - 标签: AbilityPoint - - 必须属性: - * id: "AB_" + 能力名称的MD5前6位 - * name: 能力点名称 - * category: 能力类型(计算/推理/空间想象等) - - 3. 题目节点: - - 标签: Question - - 必须属性: - * id: "{self.question_id}"(已根据题目内容生成) - * content: 题目文本摘要(50字内) - * difficulty: 难度系数(1-5整数) - - # 关系规则 - 1. 题目与知识点关系: - (q:Question)-[:TESTS_KNOWLEDGE]->(kp:KnowledgePoint) - 需设置权重属性 weight(0.1-1.0) - - 2. 题目与能力点关系: - (q:Question)-[:REQUIRES_ABILITY]->(ab:AbilityPoint) - 需设置权重属性 weight - - # 生成步骤 - 1. 先创建约束(必须): - CREATE CONSTRAINT IF NOT EXISTS FOR (kp:KnowledgePoint) REQUIRE kp.id IS UNIQUE; - CREATE CONSTRAINT IF NOT EXISTS FOR (ab:AbilityPoint) REQUIRE ab.id IS UNIQUE; - - 2. 使用MERGE创建节点(禁止使用CREATE) - - 3. 最后创建关系(需先MATCH已存在节点) - - # 当前题目信息 - - 生成的问题ID: {self.question_id} - - 题目内容: "{self.content[:50]}..."(已截断)''' + """生成限制性提示词""" + system_prompt = f'''# 严格生成规则 +1. 仅允许使用以下预注册节点: + - 知识点列表(共{len(self.existing_knowledge)}个): +{self._format_node_list(self.existing_knowledge)} + - 能力点列表(共{len(self.existing_ability)}个): +{self._format_node_list(self.existing_ability)} + +2. 必须遵守的Cypher模式: +MERGE (q:Question {{id: "{self.question_id}"}}) +SET q.content = "题目内容摘要" + +WITH q +MATCH (kp:KnowledgePoint {{id: "KP_xxxxxx"}}) +MERGE (q)-[:TESTS_KNOWLEDGE {{weight: 0.8}}]->(kp) + +MATCH (ab:AbilityPoint {{id: "AB_xxxxxx"}}) +MERGE (q)-[:REQUIRES_ABILITY {{weight: 0.7}}]->(ab) + +3. 绝对禁止: +- 使用CREATE创建新节点 +- 修改已有节点属性 +- 使用未注册的ID''' return self.client.chat.completions.create( model=MODEL_NAME, @@ -75,39 +76,87 @@ class KnowledgeGraph: timeout=300 ) + def _format_node_list(self, nodes: Dict[str, str]) -> str: + """格式化节点列表""" + if not nodes: + return " (无相关节点)" + + sample = [] + for i, (k, v) in enumerate(nodes.items()): + if i >= 5: + sample.append(f" ...(共{len(nodes)}个,仅显示前5个)") + break + sample.append(f" - {k}: {v}") + return '\n'.join(sample) + def _extract_cypher(self, content: str) -> str: - """增强的Cypher提取(处理多代码块)""" - cypher_blocks = [] - # 匹配所有cypher代码块(包含语言声明) - pattern = r"```(?:cypher)?\n(.*?)```" - - for block in re.findall(pattern, content, re.DOTALL): - # 清理注释和空行 - cleaned = [ - line.split('//')[0].strip() - for line in block.split('\n') - if line.strip() and not line.strip().startswith('//') - ] + """安全提取Cypher""" + safe_blocks = [] + for block in re.findall(r"```(?:cypher)?\n(.*?)```", content, re.DOTALL): + cleaned = self._sanitize_cypher(block) if cleaned: - cypher_blocks.append('\n'.join(cleaned)) - - return ';\n\n'.join(cypher_blocks) + safe_blocks.append(cleaned) + return ';\n\n'.join(safe_blocks) if safe_blocks else "" + + def _sanitize_cypher(self, cypher: str) -> str: + """消毒Cypher语句""" + valid_lines = [] + for line in cypher.split('\n'): + line = line.split('//')[0].strip() + if not line: + continue + + # 检查非法操作 + if re.search(r'\bCREATE\b', line, re.IGNORECASE): + continue + + # 验证节点ID + if not self._validate_ids(line): + continue + + # 验证权重范围 + if not self._validate_weight(line): + continue + + valid_lines.append(line) + + return '\n'.join(valid_lines) if valid_lines else '' + + def _validate_ids(self, line: str) -> bool: + """验证行内的所有ID""" + kp_ids = {id_.upper() for id_ in re.findall(r'kp_[\da-f]{6}', line, re.IGNORECASE)} + ab_ids = {id_.upper() for id_ in re.findall(r'ab_[\da-f]{6}', line, re.IGNORECASE)} + + valid_kp = all(kp in self.existing_knowledge for kp in kp_ids) + valid_ab = all(ab in self.existing_ability for ab in ab_ids) + + return valid_kp and valid_ab + + def _validate_weight(self, line: str) -> bool: + """验证关系权重是否合法""" + weight_match = re.search(r"weight\s*:\s*([0-9.]+)", line) + if weight_match: + try: + weight = float(weight_match.group(1)) + return 0.0 <= weight <= 1.0 + except ValueError: + return False + return True # 没有weight属性时视为合法 def run(self) -> Tuple[bool, str, str]: - """执行生成流程(确保所有路径都有返回值)""" + """执行安全生成流程""" + if not self.existing_knowledge or not self.existing_ability: + print("❌ 知识库或能力点为空,请检查数据库") + return False, "节点数据为空", "" + start_time = time.time() spinner = ['⠋', '⠙', '⠹', '⠸', '⠼', '⠴', '⠦', '⠧', '⠇', '⠏'] content_buffer = [] try: - print(f"🚀 开始生成知识点和能力点的总结和插入语句") + print(f"🚀 开始生成(知识点:{len(self.existing_knowledge)}个,能力点:{len(self.existing_ability)}个)") stream = self._generate_stream() - # 添加流数据检查 - if not stream: - print("\n❌ 生成失败:无法获取生成流") - return False, "生成流获取失败", "" - for idx, chunk in enumerate(stream): print(f"\r{spinner[idx % 10]} 生成中({int(time.time() - start_time)}秒)", end="") @@ -119,20 +168,16 @@ class KnowledgeGraph: print("\n\n📝 内容生成开始:") print(content_chunk, end="", flush=True) - # 确保最终返回 if content_buffer: full_content = ''.join(content_buffer) cypher_script = self._extract_cypher(full_content) - print(f"\n\n✅ 生成成功!耗时 {int(time.time() - start_time)}秒") - print("\n================ 完整结果 ================") - print(full_content) - print("\n================ Cypher语句 ===============") - print(cypher_script if cypher_script else "未检测到Cypher语句") + print(f"\n\n✅ 生成完成!耗时 {int(time.time() - start_time)}秒") + print("\n================ 安全Cypher ===============") + print(cypher_script if cypher_script else "未通过安全检查") print("==========================================") return True, full_content, cypher_script - # 添加空内容处理 print("\n⚠️ 生成完成但未获取到有效内容") return False, "空内容", "" @@ -142,14 +187,18 @@ class KnowledgeGraph: if __name__ == '__main__': - shiti_content = ''' - 下面是一道小学三年级的数学题目,巧求周长: - 把7个完全相同的小长方形拼成如图的样子,已知每个小长方形的长是10厘米,则拼成的大长方形的周长是多少厘米? + # 测试用例 + test_content = ''' + 题目:一个长方体的长是8厘米,宽是5厘米,高是3厘米,求它的表面积是多少平方厘米? ''' - kg = KnowledgeGraph(shiti_content) - success, result, cypher = kg.run() - if success and cypher: - with open("knowledge_graph.cypher", "w", encoding="utf-8") as f: - f.write(cypher) - print(f"\nCypher语句已保存至 knowledge_graph.cypher (题目ID: {kg.question_id})") \ No newline at end of file + try: + kg = KnowledgeGraph(test_content) + success, result, cypher = kg.run() + + if success and cypher: + with open("output.cypher", "w", encoding="utf-8") as f: + f.write(cypher) + print(f"\nCypher已保存至output.cypher(ID: {kg.question_id})") + except Exception as e: + print(f"程序初始化失败: {str(e)}") \ No newline at end of file diff --git a/AI/Neo4j/K2_Neo4jExecutor.py b/AI/Neo4j/K2_Neo4jExecutor.py index c789b5b3..1c207d4c 100644 --- a/AI/Neo4j/K2_Neo4jExecutor.py +++ b/AI/Neo4j/K2_Neo4jExecutor.py @@ -2,7 +2,7 @@ from py2neo import Graph import re from Util import * - +from Config import * class K2_Neo4jExecutor: def __init__(self, uri, auth): @@ -36,8 +36,8 @@ class K2_Neo4jExecutor: if __name__ == '__main__': executor = K2_Neo4jExecutor( - uri="neo4j://10.10.21.20:7687", - auth=("neo4j", "DsideaL4r5t6y7u") + uri=NEO4J_URI, + auth=NEO4J_AUTH ) # 清库 clear(executor.graph) diff --git a/AI/__pycache__/Config.cpython-310.pyc b/AI/__pycache__/Config.cpython-310.pyc index ebf35c3d..f97f8284 100644 Binary files a/AI/__pycache__/Config.cpython-310.pyc and b/AI/__pycache__/Config.cpython-310.pyc differ