191 lines
7.9 KiB
Python
191 lines
7.9 KiB
Python
# -*- coding: utf-8 -*-
|
||
import hashlib
import json
import re
import time
from typing import Dict, Iterator, Tuple

from openai import OpenAI
from openai.types.chat import ChatCompletionChunk

from Neo4j.K2_Neo4jExecutor import *
|
||
|
||
|
||
class KnowledgeGraph:
    """Tag a math question with existing knowledge/ability points via an LLM.

    The question text is sent to an OpenAI-compatible chat model, which is
    asked to emit a Cypher script linking a ``Question`` node to
    ``KnowledgePoint`` / ``AbilityPoint`` nodes that already exist in Neo4j.
    The generated script is validated by ``_sanitize_cypher`` and returned from
    ``run``; this class does not execute the script itself.
    """

    # Pattern part of each MERGE clause, up to the next clause keyword, a
    # statement separator, or the end of the script.
    _MERGE_CLAUSE_RE = re.compile(
        r"\bMERGE\b(.*?)(?=\b(?:MATCH|MERGE|WITH|SET|ON|WHERE|RETURN|CREATE|"
        r"UNWIND|DELETE|DETACH|REMOVE|CALL|FOREACH)\b|;|\Z)",
        re.IGNORECASE | re.DOTALL,
    )
    # A node pattern "( ... )"; the lookbehind skips function calls like timestamp().
    _NODE_PATTERN_RE = re.compile(r"(?<!\w)\(([^()]*)\)")
    # A MATCH clause followed by a pattern (so "ON MATCH SET" is not included).
    _MATCH_CLAUSE_RE = re.compile(r"\bMATCH\s*\(", re.IGNORECASE)
    # An inline id property map such as {id: "kp_1a2b3c"} (either quote style).
    _ID_MAP_RE = re.compile(r"""\{\s*id\s*:\s*(["']).+?\1\s*\}""")

    def __init__(self, content: str):
        """Connect to Neo4j and load the candidate nodes for one question.

        :param content: raw question text; it is also hashed into ``question_id``.
        :raises ConnectionError: if the Neo4j connection test fails.
        """
        self.content = content
        self.question_id = self._generate_question_id()
        self.graph = self._init_graph_connection()
        # id -> name maps of the only nodes the model is allowed to reference.
        self.existing_knowledge = self._fetch_existing_nodes("KnowledgePoint")
        self.existing_ability = self._fetch_existing_nodes("AbilityPoint")
        self.client = OpenAI(api_key=MODEL_API_KEY, base_url=MODEL_API_URL)

    def _validate_ids(self, line: str) -> bool:
        """Return False if *line* references a kp_/ab_ id that is not in the database.

        IDs are matched as ``kp_`` / ``ab_`` followed by six alphanumerics (the
        format used in the prompt example) and compared case-insensitively.
        A line with no such ids is valid.
        """
        lowered = line.lower()
        checks = (
            (r"\b(kp_[a-z0-9]{6})\b", self.existing_knowledge),
            (r"\b(ab_[a-z0-9]{6})\b", self.existing_ability),
        )
        for pattern, known in checks:
            found = set(re.findall(pattern, lowered))
            if not found:
                continue
            known_ids = {str(k).lower() for k in known}
            if not found <= known_ids:
                return False
        return True

    def _init_graph_connection(self) -> "Graph":
        """Open the Neo4j connection and verify it with a trivial query.

        :raises ConnectionError: if connecting or the test query fails; the
            original exception is chained as ``__cause__``.
        """
        try:
            graph = Graph(NEO4J_URI, auth=NEO4J_AUTH)
            graph.run("RETURN 1").data()
        except Exception as e:
            raise ConnectionError(f"❌ 数据库连接失败: {str(e)}") from e
        print("✅ Neo4j连接成功")
        return graph

    def _validate_weight(self, line: str) -> bool:
        """Return False if any ``weight`` value on *line* lies outside [0.0, 1.0].

        Both map syntax (``{weight: 0.8}``) and assignments
        (``SET r.weight = 0.8``) are checked, including negative literals.
        A malformed number is invalid; a line without a weight is valid.
        """
        for raw in re.findall(r"weight\s*[:=]\s*(-?[0-9.]+)", line):
            try:
                weight = float(raw)
            except ValueError:
                return False
            if not 0.0 <= weight <= 1.0:
                return False
        return True

    def _validate_cypher_structure(self, cypher: str) -> bool:
        """Require a ``WITH q`` clause whenever the script MERGEs the Question node.

        Cypher needs ``WITH`` between an updating clause (MERGE) and a following
        MATCH, and the prompt's template relies on it.
        """
        has_merge = re.search(r'\bMERGE\s*\(q:Question\b', cypher, re.IGNORECASE)
        has_with = re.search(r'\bWITH\s+q\b', cypher, re.IGNORECASE)
        return not has_merge or has_with is not None

    def _generate_question_id(self) -> str:
        """Derive a stable 8-hex-char question id from the content (MD5 prefix, not for security)."""
        return hashlib.md5(self.content.encode("utf-8")).hexdigest()[:8]

    def _fetch_existing_nodes(self, label: str) -> Dict[str, str]:
        """Load all nodes with *label* as an ``{id: name}`` dict.

        Rows whose ``id`` is null are skipped because they cannot be referenced.
        Returns an empty dict (after printing the error) if the query fails.
        """
        try:
            # Labels cannot be passed as Cypher parameters, so only a plain
            # identifier is allowed to be interpolated into the query text.
            if not re.fullmatch(r"\w+", label):
                raise ValueError(f"invalid label: {label!r}")
            cypher = f"MATCH (n:{label}) RETURN n.id as id, n.name as name"
            result = self.graph.run(cypher).data()
            return {item['id']: item['name'] for item in result if item['id'] is not None}
        except Exception as e:
            print(f"❌ 节点查询失败: {str(e)}")
            return {}

    def _generate_stream(self) -> "Iterator[ChatCompletionChunk]":
        """Start a streaming chat completion that asks the model for a Cypher script.

        The system prompt lists candidate nodes as ``id: name`` so the model can
        match the question's concepts to them.
        """
        # NOTE(review): _format_node_list shows at most 5 nodes per label, so the
        # model cannot reference any node beyond those; revisit if the lists grow.
        existing_kp = self._format_node_list(self.existing_knowledge)
        existing_ab = self._format_node_list(self.existing_ability)

        system_prompt = f'''
将题目中涉及到的小学数学知识点、能力点进行总结,并且按照以下格式生成在neo4j-community-5.26.2上的语句:

重要限制条件(违反将导致执行失败):
1. 只输出cypher脚本,不要输出其它内容,也不要加代码块的起始终止符
2. 禁止创建新节点(只能使用以下现有ID)
3. 现有知识点ID列表:
{existing_kp}
4. 现有能力点ID列表:
{existing_ab}
5. 必须使用MATCH定位已有节点后才能建立关系

生成格式示例(注意WITH子句):
MERGE (q:Question {{id: "{self.question_id}"}})
SET q.content = "题目内容",
    q.name = "前10字符"
WITH q
MATCH (kp1:KnowledgePoint {{id: "kp_3f5g6h"}})
WHERE kp1 IS NOT NULL
MERGE (q)-[:TESTS_KNOWLEDGE]->(kp1)
'''

        return self.client.chat.completions.create(
            model=MODEL_NAME,
            messages=[
                {"role": "system", "content": system_prompt},
                {"role": "user", "content": self.content},
            ],
            stream=True,
            timeout=300,
        )

    def _format_node_list(self, nodes: Dict[str, str], limit: int = 5) -> str:
        """Render up to *limit* nodes as `` - id: name`` lines for the prompt.

        When there are more than *limit* nodes, a trailing line states the total.
        """
        if not nodes:
            return " (无相关节点)"
        sample = [f" - {k}: {v}" for k, v in list(nodes.items())[:limit]]
        if len(nodes) > limit:
            sample.append(f" ...(共{len(nodes)}个,仅显示前{limit}个)")
        return '\n'.join(sample)

    def _extract_cypher(self, content: str) -> str:
        """Extract and validate the Cypher script from the model's response.

        Fenced code blocks are used when present (any info string). Otherwise
        the whole response is treated as the script, because the prompt tells the
        model not to emit fences. Blocks that fail ``_sanitize_cypher`` are
        dropped. The surviving blocks are joined with ``;``; any trailing
        semicolons are removed first so the join does not produce ``;;``.
        """
        blocks = re.findall(r"```[^\n]*\n(.*?)```", content, re.DOTALL)
        if not blocks:
            blocks = [content]
        safe_blocks = []
        for block in blocks:
            cleaned = self._sanitize_cypher(block)
            if cleaned:
                safe_blocks.append(cleaned.rstrip().rstrip(';'))
        return ';\n\n'.join(safe_blocks)

    def _has_illegal_merge(self, cypher: str) -> bool:
        """Return True if any MERGE clause could create a node other than the Question.

        MERGE creates the whole pattern when it does not match, including any
        node in it that is not already bound. Therefore, every node in a MERGE
        pattern must be either a bare, previously bound variable (``(q)``,
        ``(kp1)``) or the ``q:Question`` node itself.
        """
        for clause in self._MERGE_CLAUSE_RE.finditer(cypher):
            for node in self._NODE_PATTERN_RE.findall(clause.group(1)):
                node = node.strip()
                if re.fullmatch(r"\w+", node):
                    continue
                if re.match(r"q\s*:\s*Question\b", node, re.IGNORECASE):
                    continue
                return True
        return False

    def _sanitize_cypher(self, cypher: str) -> str:
        """Validate one Cypher script; return it stripped if safe, otherwise "".

        A script is rejected (with a printed warning) when:
          - a MERGE could create a node other than the Question node;
          - a MATCH does not pin a node by an inline ``{id: "..."}``;
          - it references a kp_/ab_ id that is not in the database;
          - a relationship weight lies outside [0, 1];
          - it MERGEs the Question node without a following ``WITH q``.
        """
        cypher = cypher.strip()
        if not cypher:
            return ""

        if self._has_illegal_merge(cypher):
            print("⚠️ 检测到非法MERGE语句,已过滤")
            return ""

        for line in cypher.splitlines():
            if self._MATCH_CLAUSE_RE.search(line) and not self._ID_MAP_RE.search(line):
                print(f"⚠️ 检测到无ID的MATCH语句: {line[:50]}")
                return ""
            if not self._validate_ids(line):
                print(f"⚠️ 检测到不存在的节点ID: {line[:50]}")
                return ""
            if not self._validate_weight(line):
                print(f"⚠️ 检测到非法的关系权重: {line[:50]}")
                return ""

        if not self._validate_cypher_structure(cypher):
            print("⚠️ 缺少WITH q子句,已过滤")
            return ""

        return cypher

    def run(self) -> Tuple[bool, str, str]:
        """Stream a generation from the model and extract a validated Cypher script.

        :returns: ``(success, text, cypher)``. On success, ``text`` is the raw
            model output and ``cypher`` is the validated script ("" if nothing
            passed validation). On failure, ``text`` is an error message and
            ``cypher`` is "".
        """
        if not self.existing_knowledge or not self.existing_ability:
            print("❌ 知识库或能力点为空,请检查数据库")
            return False, "节点数据为空", ""

        start_time = time.time()
        spinner = ['⠋', '⠙', '⠹', '⠸', '⠼', '⠴', '⠦', '⠧', '⠇', '⠏']
        content_buffer = []

        try:
            print(f"🚀 开始生成(知识点:{len(self.existing_knowledge)}个,能力点:{len(self.existing_ability)}个)")
            stream = self._generate_stream()

            for idx, chunk in enumerate(stream):
                # Animate only until content starts arriving; afterwards the "\r"
                # redraw would overwrite and interleave with the streamed text.
                if not content_buffer:
                    print(f"\r{spinner[idx % len(spinner)]} 生成中({int(time.time() - start_time)}秒)",
                          end="", flush=True)

                if chunk.choices and chunk.choices[0].delta.content:
                    content_chunk = chunk.choices[0].delta.content
                    content_buffer.append(content_chunk)
                    if len(content_buffer) == 1:
                        print("\n\n📝 内容生成开始:")
                    print(content_chunk, end="", flush=True)

            if content_buffer:
                full_content = ''.join(content_buffer)
                cypher_script = self._extract_cypher(full_content)
                print(f"\n\n✅ 生成完成!耗时 {int(time.time() - start_time)}秒")
                return True, full_content, cypher_script

            print("\n⚠️ 生成完成但未获取到有效内容")
            return False, "空内容", ""

        except Exception as e:
            # Top-level boundary: report the failure through the result triple.
            error_msg = str(e)
            print(f"\n\n❌ 生成失败:{error_msg}")
            return False, error_msg, ""
|