|
|
|
@ -19,19 +19,20 @@ class KnowledgeGraph:
|
|
|
|
|
self.client = OpenAI(api_key=MODEL_API_KEY, base_url=MODEL_API_URL)
|
|
|
|
|
|
|
|
|
|
def _validate_ids(self, line: str) -> bool:
|
|
|
|
|
"""修正后的ID验证(严格匹配数据库格式)"""
|
|
|
|
|
# 调整正则表达式匹配6位小写hex格式
|
|
|
|
|
"""强化ID验证(严格过滤非法节点)"""
|
|
|
|
|
# 调整正则表达式严格匹配格式
|
|
|
|
|
found_ids = {
|
|
|
|
|
'kp': set(re.findall(r'(kp_[a-f0-9]{6})', line.lower())),
|
|
|
|
|
'ab': set(re.findall(r'(ab_[a-f0-9]{6})', line.lower()))
|
|
|
|
|
'kp': set(re.findall(r'\b(kp_[a-f0-9]{6})\b', line.lower())),
|
|
|
|
|
'ab': set(re.findall(r'\b(ab_[a-f0-9]{6})\b', line.lower()))
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
# 直接检查小写形式
|
|
|
|
|
valid_kp = all(kp in self.existing_knowledge for kp in found_ids['kp'])
|
|
|
|
|
valid_ab = all(ab in self.existing_ability for ab in found_ids['ab'])
|
|
|
|
|
# 严格检查存在性(空集合视为有效)
|
|
|
|
|
valid_kp = not found_ids['kp'] or all(kp in self.existing_knowledge for kp in found_ids['kp'])
|
|
|
|
|
valid_ab = not found_ids['ab'] or all(ab in self.existing_ability for ab in found_ids['ab'])
|
|
|
|
|
|
|
|
|
|
return valid_kp and valid_ab
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
def _init_graph_connection(self) -> Graph:
|
|
|
|
|
"""初始化并测试数据库连接"""
|
|
|
|
|
try:
|
|
|
|
@ -74,22 +75,32 @@ class KnowledgeGraph:
|
|
|
|
|
return {}
|
|
|
|
|
|
|
|
|
|
def _generate_stream(self) -> Iterator[ChatCompletionChunk]:
|
|
|
|
|
"""生成限制性提示词(添加现有节点示例)"""
|
|
|
|
|
# 在提示词中添加现有节点示例
|
|
|
|
|
knowledge_samples = '\n'.join([f"KP_{k[3:]}: {v}" for k, v in list(self.existing_knowledge.items())[:5]])
|
|
|
|
|
ability_samples = '\n'.join([f"AB_{a[3:]}: {v}" for a, v in list(self.existing_ability.items())[:5]])
|
|
|
|
|
|
|
|
|
|
system_prompt = f''' 将题目中涉及到的小学数学知识点、能力点进行总结,并且按照以下格式生成在neo4j-community-5.26.2上的语句:
|
|
|
|
|
请严格使用以下已有节点ID(一定不要创建新ID):
|
|
|
|
|
现有知识点示例(KP_后接6位小写字母/数字):
|
|
|
|
|
{knowledge_samples}
|
|
|
|
|
现有能力点示例(AB_后接6位小写字母/数字):
|
|
|
|
|
{ability_samples}
|
|
|
|
|
生成格式要求:
|
|
|
|
|
"""强化提示词限制"""
|
|
|
|
|
# 生成现有ID列表的提示
|
|
|
|
|
existing_kp_ids = '\n'.join([f"- {k}" for k in list(self.existing_knowledge.keys())[:5]])
|
|
|
|
|
existing_ab_ids = '\n'.join([f"- {k}" for k in list(self.existing_ability.keys())[:5]])
|
|
|
|
|
|
|
|
|
|
system_prompt = f'''
|
|
|
|
|
将题目中涉及到的小学数学知识点、能力点进行总结,并且按照以下格式生成在neo4j-community-5.26.2上的语句:
|
|
|
|
|
重要限制条件(违反将导致执行失败):
|
|
|
|
|
1. 禁止创建新节点(只能使用以下现有ID)
|
|
|
|
|
2. 现有知识点ID列表:
|
|
|
|
|
{existing_kp_ids}
|
|
|
|
|
...
|
|
|
|
|
3. 现有能力点ID列表:
|
|
|
|
|
{existing_ab_ids}
|
|
|
|
|
...
|
|
|
|
|
4. 必须使用MATCH定位已有节点后才能建立关系
|
|
|
|
|
|
|
|
|
|
生成格式示例(注意WITH子句):
|
|
|
|
|
MERGE (q:Question {{id: "{self.question_id}"}})
|
|
|
|
|
SET q.content = "...",
|
|
|
|
|
q.name = "{self.content[:10]}" // 直接截取前10字符
|
|
|
|
|
// ... rest of the template ...
|
|
|
|
|
SET q.content = "题目内容",
|
|
|
|
|
q.name = "前10字符"
|
|
|
|
|
|
|
|
|
|
WITH q
|
|
|
|
|
MATCH (kp1:KnowledgePoint {{id: "kp_3f5g6h"}})
|
|
|
|
|
WHERE kp1 IS NOT NULL
|
|
|
|
|
MERGE (q)-[:TESTS_KNOWLEDGE]->(kp1)
|
|
|
|
|
'''
|
|
|
|
|
|
|
|
|
|
return self.client.chat.completions.create(
|
|
|
|
@ -126,41 +137,22 @@ class KnowledgeGraph:
|
|
|
|
|
return ';\n\n'.join(safe_blocks) if safe_blocks else ""
|
|
|
|
|
|
|
|
|
|
def _sanitize_cypher(self, cypher: str) -> str:
|
|
|
|
|
valid_lines = []
|
|
|
|
|
for line_num, line in enumerate(cypher.split('\n'), 1):
|
|
|
|
|
# 添加调试输出
|
|
|
|
|
print(f"正在处理第{line_num}行: {line[:50]}...")
|
|
|
|
|
|
|
|
|
|
# 允许SET子句中的内容
|
|
|
|
|
if line.strip().startswith("SET"):
|
|
|
|
|
valid_lines.append(line)
|
|
|
|
|
continue
|
|
|
|
|
|
|
|
|
|
# 记录过滤原因
|
|
|
|
|
filter_reason = []
|
|
|
|
|
|
|
|
|
|
if re.search(r'\bCREATE\b', line, re.IGNORECASE):
|
|
|
|
|
filter_reason.append("包含CREATE语句")
|
|
|
|
|
|
|
|
|
|
if not self._validate_ids(line):
|
|
|
|
|
filter_reason.append("存在无效ID")
|
|
|
|
|
|
|
|
|
|
if not self._validate_weight(line):
|
|
|
|
|
filter_reason.append("权重值非法")
|
|
|
|
|
|
|
|
|
|
if filter_reason:
|
|
|
|
|
print(f"第{line_num}行被过滤: {line[:50]}... | 原因: {', '.join(filter_reason)}")
|
|
|
|
|
continue
|
|
|
|
|
|
|
|
|
|
valid_lines.append(line)
|
|
|
|
|
# 新增过滤条件:禁止MERGE非Question节点
|
|
|
|
|
if re.search(r'\bMERGE\s*\((?!q:Question)', cypher, re.IGNORECASE):
|
|
|
|
|
print("⚠️ 检测到非法MERGE语句,已过滤")
|
|
|
|
|
return ""
|
|
|
|
|
|
|
|
|
|
return '\n'.join(valid_lines)
|
|
|
|
|
# 新增过滤条件:验证所有MATCH的节点是否带ID
|
|
|
|
|
for line in cypher.split('\n'):
|
|
|
|
|
if 'MATCH' in line and not re.search(r'\{id:\s*".+?"\}', line):
|
|
|
|
|
print(f"⚠️ 检测到无ID的MATCH语句: {line[:50]}")
|
|
|
|
|
return ""
|
|
|
|
|
|
|
|
|
|
def run(self) -> Tuple[bool, str, str]:
|
|
|
|
|
"""执行安全生成流程"""
|
|
|
|
|
"""执行安全生成流程(修正返回三元组)"""
|
|
|
|
|
if not self.existing_knowledge or not self.existing_ability:
|
|
|
|
|
print("❌ 知识库或能力点为空,请检查数据库")
|
|
|
|
|
return False, "节点数据为空", ""
|
|
|
|
|
return False, "节点数据为空", "" # 保持三元组格式
|
|
|
|
|
|
|
|
|
|
start_time = time.time()
|
|
|
|
|
spinner = ['⠋', '⠙', '⠹', '⠸', '⠼', '⠴', '⠦', '⠧', '⠇', '⠏']
|
|
|
|
@ -183,16 +175,15 @@ class KnowledgeGraph:
|
|
|
|
|
|
|
|
|
|
if content_buffer:
|
|
|
|
|
full_content = ''.join(content_buffer)
|
|
|
|
|
cypher_script = self._extract_cypher(full_content) # 新增提取步骤
|
|
|
|
|
print(f"\n\n✅ 生成完成!耗时 {int(time.time() - start_time)}秒")
|
|
|
|
|
return True, full_content
|
|
|
|
|
return True, full_content, cypher_script # 返回三元组
|
|
|
|
|
|
|
|
|
|
print("\n⚠️ 生成完成但未获取到有效内容")
|
|
|
|
|
return False, "空内容"
|
|
|
|
|
return False, "空内容", "" # 修正为三元组
|
|
|
|
|
|
|
|
|
|
# 修改run方法中的异常处理
|
|
|
|
|
except Exception as e:
|
|
|
|
|
# 修正后的代码
|
|
|
|
|
error_msg = str(e) if not isinstance(e, dict) else json.dumps(e)
|
|
|
|
|
print(f"\n\n❌ 生成失败:{error_msg}")
|
|
|
|
|
return False, error_msg, ""
|
|
|
|
|
return False, error_msg, "" # 保持三元组格式
|
|
|
|
|
|
|
|
|
|