main
黄海 5 months ago
parent 910c86bb33
commit 67b88130cf

@ -19,19 +19,20 @@ class KnowledgeGraph:
self.client = OpenAI(api_key=MODEL_API_KEY, base_url=MODEL_API_URL) self.client = OpenAI(api_key=MODEL_API_KEY, base_url=MODEL_API_URL)
def _validate_ids(self, line: str) -> bool:
    """Check that every knowledge/ability ID mentioned in *line* already exists.

    The line is lower-cased, then every ASCII word token starting with
    ``kp_`` or ``ab_`` is collected and looked up in
    ``self.existing_knowledge`` / ``self.existing_ability``.

    Tokens are collected regardless of their shape. Extracting only
    well-formed 6-hex IDs would let a malformed ID (wrong length, non-hex
    characters) produce an empty result and slip through as "valid".

    Args:
        line: One line of generated Cypher text.

    Returns:
        True when the line mentions no such token, or when every token is a
        key of the corresponding existing-ID container; False otherwise.

    Note:
        Any identifier that merely starts with ``kp_`` / ``ab_`` (for
        example a Cypher variable named ``kp_1``) is treated as an ID and
        must therefore exist.
    """
    lowered = line.lower()
    # re.ASCII keeps the token boundary at the first non-ASCII character, so
    # an ID directly followed by CJK text is still extracted on its own.
    kp_tokens = set(re.findall(r'\bkp_[a-z0-9_]*', lowered, re.ASCII))
    ab_tokens = set(re.findall(r'\bab_[a-z0-9_]*', lowered, re.ASCII))

    # all() over an empty set is True, so "no IDs mentioned" counts as valid.
    valid_kp = all(token in self.existing_knowledge for token in kp_tokens)
    valid_ab = all(token in self.existing_ability for token in ab_tokens)
    return valid_kp and valid_ab
def _init_graph_connection(self) -> Graph: def _init_graph_connection(self) -> Graph:
"""初始化并测试数据库连接""" """初始化并测试数据库连接"""
try: try:
@ -74,22 +75,32 @@ class KnowledgeGraph:
return {} return {}
def _generate_stream(self) -> Iterator[ChatCompletionChunk]: def _generate_stream(self) -> Iterator[ChatCompletionChunk]:
"""生成限制性提示词(添加现有节点示例)""" """强化提示词限制"""
# 在提示词中添加现有节点示例 # 生成现有ID列表的提示
knowledge_samples = '\n'.join([f"KP_{k[3:]}: {v}" for k, v in list(self.existing_knowledge.items())[:5]]) existing_kp_ids = '\n'.join([f"- {k}" for k in list(self.existing_knowledge.keys())[:5]])
ability_samples = '\n'.join([f"AB_{a[3:]}: {v}" for a, v in list(self.existing_ability.items())[:5]]) existing_ab_ids = '\n'.join([f"- {k}" for k in list(self.existing_ability.keys())[:5]])
system_prompt = f''' 将题目中涉及到的小学数学知识点、能力点进行总结并且按照以下格式生成在neo4j-community-5.26.2上的语句: system_prompt = f'''
请严格使用以下已有节点ID一定不要创建新ID 将题目中涉及到的小学数学知识点能力点进行总结并且按照以下格式生成在neo4j-community-5.26.2上的语句
现有知识点示例KP_后接6位小写字母/数字 重要限制条件违反将导致执行失败
{knowledge_samples} 1. 禁止创建新节点只能使用以下现有ID
现有能力点示例AB_后接6位小写字母/数字 2. 现有知识点ID列表
{ability_samples} {existing_kp_ids}
生成格式要求 ...
3. 现有能力点ID列表
{existing_ab_ids}
...
4. 必须使用MATCH定位已有节点后才能建立关系
生成格式示例注意WITH子句
MERGE (q:Question {{id: "{self.question_id}"}}) MERGE (q:Question {{id: "{self.question_id}"}})
SET q.content = "...", SET q.content = "题目内容",
q.name = "{self.content[:10]}" // 直接截取前10字符 q.name = "前10字符"
// ... rest of the template ...
WITH q
MATCH (kp1:KnowledgePoint {{id: "kp_3f5g6h"}})
WHERE kp1 IS NOT NULL
MERGE (q)-[:TESTS_KNOWLEDGE]->(kp1)
''' '''
return self.client.chat.completions.create( return self.client.chat.completions.create(
@ -126,41 +137,22 @@ class KnowledgeGraph:
return ';\n\n'.join(safe_blocks) if safe_blocks else "" return ';\n\n'.join(safe_blocks) if safe_blocks else ""
def _sanitize_cypher(self, cypher: str) -> str:
    """Accept or reject a generated Cypher script as a whole.

    Two checks are applied:

    1. Every node pattern inside a MERGE clause must either carry the
       ``Question`` label or be a bare variable reference such as ``(q)``
       or ``(kp1)``. A node with any other label, or an unlabeled node with
       a property map, is rejected. Bare references are allowed so that
       ``MERGE (q)-[:REL]->(kp1)`` passes.
    2. Every line containing a MATCH keyword (any letter case) must contain
       an ``{id: "..."}`` property filter.

    String literals are blanked out before keywords and node patterns are
    searched, so text inside quoted values cannot trigger or hide a check.

    Args:
        cypher: The Cypher script to check.

    Returns:
        The unchanged script when both checks pass, otherwise ``""``.
    """
    string_literal = re.compile(r'"(?:\\.|[^"\\])*"|\'(?:\\.|[^\'\\])*\'')
    # Marks where the pattern part of a MERGE clause ends.
    clause_break = re.compile(
        r'\b(?:MATCH|MERGE|WITH|WHERE|SET|RETURN|CREATE|DELETE|REMOVE|UNWIND|ON)\b|;',
        re.IGNORECASE,
    )
    # "(var:Label {" with the variable, label and property map all optional.
    node_pattern = re.compile(r'\(\s*(\w*)\s*(?::\s*(\w+))?\s*(\{)?')

    scrubbed = string_literal.sub('""', cypher)
    for merge in re.finditer(r'\bMERGE\b', scrubbed, re.IGNORECASE):
        tail = scrubbed[merge.end():]
        stop = clause_break.search(tail)
        pattern = tail[:stop.start()] if stop else tail
        for variable, label, has_props in node_pattern.findall(pattern):
            is_question = label.lower() == 'question'
            is_bare_reference = bool(variable) and not label and not has_props
            if not (is_question or is_bare_reference):
                print("⚠️ 检测到非法MERGE语句已过滤")
                return ""

    # NOTE(review): CREATE statements are not rejected here; confirm whether
    # that is handled elsewhere before the script is executed.
    for line in cypher.split('\n'):
        has_match = re.search(r'\bMATCH\b', string_literal.sub('""', line), re.IGNORECASE)
        if has_match and not re.search(r'\{id:\s*".+?"\}', line):
            print(f"⚠️ 检测到无ID的MATCH语句: {line[:50]}")
            return ""

    # Hand the accepted script back instead of falling off the end with None.
    return cypher
def run(self) -> Tuple[bool, str, str]: def run(self) -> Tuple[bool, str, str]:
"""执行安全生成流程""" """执行安全生成流程(修正返回三元组)"""
if not self.existing_knowledge or not self.existing_ability: if not self.existing_knowledge or not self.existing_ability:
print("❌ 知识库或能力点为空,请检查数据库") print("❌ 知识库或能力点为空,请检查数据库")
return False, "节点数据为空", "" return False, "节点数据为空", "" # 保持三元组格式
start_time = time.time() start_time = time.time()
spinner = ['', '', '', '', '', '', '', '', '', ''] spinner = ['', '', '', '', '', '', '', '', '', '']
@ -183,16 +175,15 @@ class KnowledgeGraph:
if content_buffer: if content_buffer:
full_content = ''.join(content_buffer) full_content = ''.join(content_buffer)
cypher_script = self._extract_cypher(full_content) # 新增提取步骤
print(f"\n\n✅ 生成完成!耗时 {int(time.time() - start_time)}") print(f"\n\n✅ 生成完成!耗时 {int(time.time() - start_time)}")
return True, full_content return True, full_content, cypher_script # 返回三元组
print("\n⚠️ 生成完成但未获取到有效内容") print("\n⚠️ 生成完成但未获取到有效内容")
return False, "空内容" return False, "空内容", "" # 修正为三元组
# 修改run方法中的异常处理
except Exception as e: except Exception as e:
# 修正后的代码
error_msg = str(e) if not isinstance(e, dict) else json.dumps(e) error_msg = str(e) if not isinstance(e, dict) else json.dumps(e)
print(f"\n\n❌ 生成失败:{error_msg}") print(f"\n\n❌ 生成失败:{error_msg}")
return False, error_msg, "" return False, error_msg, "" # 保持三元组格式

@ -26,6 +26,14 @@ if __name__ == '__main__':
uri=NEO4J_URI, uri=NEO4J_URI,
auth=NEO4J_AUTH auth=NEO4J_AUTH
) )
# 新增数据库约束确保节点必须带ID
init_script = """
CREATE CONSTRAINT IF NOT EXISTS FOR (kp:KnowledgePoint) REQUIRE kp.id IS UNIQUE;
CREATE CONSTRAINT IF NOT EXISTS FOR (ab:AbilityPoint) REQUIRE ab.id IS UNIQUE;
CREATE CONSTRAINT IF NOT EXISTS FOR (q:Question) REQUIRE q.id IS UNIQUE;
"""
executor.execute_cypher_text(init_script)
# 使用示例 # 使用示例
question_blocks = split_questions('ShiTi.md') question_blocks = split_questions('ShiTi.md')
@ -35,7 +43,7 @@ if __name__ == '__main__':
print("-" * 50) print("-" * 50)
try: try:
kg = KnowledgeGraph(block) kg = KnowledgeGraph(block)
success, cypher = kg.run() success, cypher, result = kg.run()
# 替换一些特殊符号 # 替换一些特殊符号
cypher = cypher.replace('```neo4j', '').replace('```', '').replace('```cypher', '') cypher = cypher.replace('```neo4j', '').replace('```', '').replace('```cypher', '')
print(cypher) print(cypher)

Loading…
Cancel
Save