You cannot select more than 25 topics. Topics must start with a letter or number, can include dashes ('-'), and can be up to 35 characters long.

191 lines
7.9 KiB

This file contains ambiguous Unicode characters!

This file contains ambiguous Unicode characters that may be confused with others in your current locale. If your use case is intentional and legitimate, you can safely ignore this warning. Use the Escape button to highlight these characters.

# -*- coding: utf-8 -*-
import hashlib
import json
import re
import time
from typing import Dict, Iterator, Tuple

from openai import OpenAI
from openai.types.chat import ChatCompletionChunk

from Neo4j.K2_Neo4jExecutor import *
class KnowledgeGraph:
    """Link one math question to existing knowledge/ability nodes in Neo4j.

    On construction the question text is hashed into a short ID, a Neo4j
    connection is opened and verified, the existing ``KnowledgePoint`` and
    ``AbilityPoint`` nodes are loaded as ``{id: name}`` maps, and an OpenAI
    client is created. :meth:`run` streams an LLM completion for the
    question and extracts a sanitized Cypher script from it.

    NOTE(review): ``Graph``, ``NEO4J_URI``, ``NEO4J_AUTH``, ``MODEL_API_KEY``,
    ``MODEL_API_URL`` and ``MODEL_NAME`` come from the star import of
    ``Neo4j.K2_Neo4jExecutor``; ``Graph`` looks like py2neo's — confirm.
    """

    # Console spinner frames, shown only while waiting for the first token.
    _SPINNER = ('⠋', '⠙', '⠹', '⠸', '⠼', '⠴', '⠦', '⠧', '⠇', '⠏')
    # Node IDs checked by _validate_ids: "kp_"/"ab_" plus six hex digits.
    _KP_ID_RE = re.compile(r'\b(kp_[a-f0-9]{6})\b')
    _AB_ID_RE = re.compile(r'\b(ab_[a-f0-9]{6})\b')
    # A ``` fence with an optional language tag; group 1 is the fenced body.
    _FENCE_RE = re.compile(r"```[ \t]*(?:[A-Za-z]+)?[ \t]*\r?\n(.*?)```", re.DOTALL)
    # A MERGE whose first node pattern is neither "(q:Question ...)" nor a
    # bare, already-bound variable such as "(q)" / "(kp1)" -- i.e. a MERGE
    # that could create a brand-new non-Question node.
    _ILLEGAL_MERGE_RE = re.compile(
        r'\bMERGE\s*\((?!\s*q\s*:\s*Question\b)(?!\s*[A-Za-z_]\w*\s*\))',
        re.IGNORECASE,
    )
    # Every MATCH line must pin its node with an explicit {id: "..."} map.
    _MATCH_ID_RE = re.compile(r'\{id:\s*".+?"\}')
    # Labels are interpolated into Cypher text, so only plain identifiers pass.
    _LABEL_RE = re.compile(r'[A-Za-z_][A-Za-z0-9_]*')

    def __init__(self, content: str):
        """Prepare everything needed to generate Cypher for *content*.

        Args:
            content: The question text sent to the model as the user message.

        Raises:
            ConnectionError: If the Neo4j connection test fails.
        """
        self.content = content
        self.question_id = self._generate_question_id()
        # The graph must exist before the node caches below are fetched.
        self.graph = self._init_graph_connection()
        self.existing_knowledge = self._fetch_existing_nodes("KnowledgePoint")
        self.existing_ability = self._fetch_existing_nodes("AbilityPoint")
        self.client = OpenAI(api_key=MODEL_API_KEY, base_url=MODEL_API_URL)

    def _validate_ids(self, line: str) -> bool:
        """Return True if every ``kp_``/``ab_`` ID in *line* is a known node.

        IDs are matched on the lower-cased line. A line that contains no such
        IDs is considered valid.
        """
        lowered = line.lower()
        kp_ids = set(self._KP_ID_RE.findall(lowered))
        ab_ids = set(self._AB_ID_RE.findall(lowered))
        # all() over an empty set is True, so ID-free lines pass.
        return (all(kp in self.existing_knowledge for kp in kp_ids)
                and all(ab in self.existing_ability for ab in ab_ids))

    def _init_graph_connection(self) -> "Graph":
        """Open the Neo4j connection and verify it with ``RETURN 1``.

        Raises:
            ConnectionError: If connecting or the test query fails; the
                original exception is chained as the cause.
        """
        try:
            graph = Graph(NEO4J_URI, auth=NEO4J_AUTH)
            graph.run("RETURN 1").data()
        except Exception as e:
            raise ConnectionError(f"❌ 数据库连接失败: {str(e)}") from e
        print("✅ Neo4j连接成功")
        return graph

    def _validate_weight(self, line: str) -> bool:
        """Return True unless *line* sets a ``weight:`` outside [0.0, 1.0].

        A line without a weight is valid. A weight that cannot be parsed as a
        float (e.g. ``1.2.3``) is invalid. A signed value is captured, so
        negative weights are rejected rather than silently ignored.
        """
        match = re.search(r"weight\s*:\s*([-+]?[0-9.]+)", line)
        if match is None:
            return True
        try:
            weight = float(match.group(1))
        except ValueError:
            return False
        return 0.0 <= weight <= 1.0

    def _validate_cypher_structure(self, cypher: str) -> bool:
        """Return True unless ``MERGE (q:Question`` appears without ``WITH q``."""
        has_merge = re.search(r'\bMERGE\s*\(q:Question\b', cypher, re.IGNORECASE)
        has_with = re.search(r'\bWITH\s+q\b', cypher, re.IGNORECASE)
        return has_merge is None or has_with is not None

    def _generate_question_id(self) -> str:
        """Return the first 8 hex chars of the MD5 of the UTF-8 question text.

        The same text always yields the same ID; MD5 is used as a fingerprint
        here, not for security.
        """
        return hashlib.md5(self.content.encode()).hexdigest()[:8]

    def _fetch_existing_nodes(self, label: str) -> Dict[str, str]:
        """Return ``{n.id: n.name}`` for all nodes with *label*; {} on any error.

        The label is interpolated into the query text, so it must be a plain
        identifier. Rows whose ``id`` is null are skipped. Errors are printed
        and swallowed so that callers can detect an empty cache.
        """
        try:
            if not self._LABEL_RE.fullmatch(label):
                raise ValueError(f"非法标签: {label!r}")
            cypher = f"MATCH (n:{label}) RETURN n.id as id, n.name as name"
            result = self.graph.run(cypher).data()
            return {item['id']: item['name'] for item in result if item['id'] is not None}
        except Exception as e:
            print(f"❌ 节点查询失败: {str(e)}")
            return {}

    def _generate_stream(self) -> Iterator["ChatCompletionChunk"]:
        """Start a streaming chat completion that should answer with Cypher.

        The system prompt lists the first five existing knowledge/ability IDs
        and an example script. The question text is the user message.

        NOTE(review): the prompt says only the listed IDs may be used, yet at
        most five of each are listed (IDs only, no names). The model
        therefore cannot pick relevant nodes from larger graphs. Consider
        listing more IDs, with names, e.g. via ``_format_node_list``.
        """
        kp_sample = list(self.existing_knowledge)[:5]
        ab_sample = list(self.existing_ability)[:5]
        existing_kp_ids = '\n'.join(f"- {k}" for k in kp_sample)
        existing_ab_ids = '\n'.join(f"- {k}" for k in ab_sample)
        system_prompt = f'''
将题目中涉及到的小学数学知识点、能力点进行总结并且按照以下格式生成在neo4j-community-5.26.2上的语句:
重要限制条件(违反将导致执行失败):
1. 只输出cypher脚本不要输出其它内容,也不要加代码块的起始终止符
2. 禁止创建新节点只能使用以下现有ID
3. 现有知识点ID列表
{existing_kp_ids}
...
4. 现有能力点ID列表
{existing_ab_ids}
...
5. 必须使用MATCH定位已有节点后才能建立关系
生成格式示例注意WITH子句
MERGE (q:Question {{id: "{self.question_id}"}})
SET q.content = "题目内容",
q.name = "前10字符"
WITH q
MATCH (kp1:KnowledgePoint {{id: "kp_3f5g6h"}})
WHERE kp1 IS NOT NULL
MERGE (q)-[:TESTS_KNOWLEDGE]->(kp1)
'''
        return self.client.chat.completions.create(
            model=MODEL_NAME,
            messages=[
                {"role": "system", "content": system_prompt},
                {"role": "user", "content": self.content}
            ],
            stream=True,
            timeout=300
        )

    def _format_node_list(self, nodes: Dict[str, str], limit: int = 5) -> str:
        """Render up to *limit* ``id: name`` lines, plus a summary if truncated.

        Args:
            nodes: ``{id: name}`` mapping to display.
            limit: Maximum number of entries to list (expected to be >= 0).

        Returns:
            One ``" - id: name"`` line per shown entry, followed by a
            ``"(共N个,仅显示前K个)"`` line when more than *limit* entries
            exist. Returns ``" (无相关节点)"`` when *nodes* is empty.
        """
        if not nodes:
            return " (无相关节点)"
        lines = [f" - {k}: {v}" for k, v in list(nodes.items())[:limit]]
        if len(nodes) > limit:
            lines.append(f" ...(共{len(nodes)}个,仅显示前{limit}个)")
        return '\n'.join(lines)

    def _extract_cypher(self, content: str) -> str:
        """Pull Cypher out of a model reply and keep only scripts that pass checks.

        Fenced blocks are used when present. Otherwise the whole reply is
        treated as one script, because the prompt explicitly asks for bare
        Cypher without fences. Surviving scripts are joined with ``;``
        separators. Returns "" if nothing passes :meth:`_sanitize_cypher`.
        """
        blocks = self._FENCE_RE.findall(content) or [content]
        safe_blocks = [s for s in (self._sanitize_cypher(b) for b in blocks) if s]
        return ';\n\n'.join(safe_blocks)

    def _sanitize_cypher(self, cypher: str) -> str:
        """Return the cleaned script if it passes every safety check, else "".

        The script is stripped of surrounding whitespace and trailing ``;``,
        so that callers can join scripts with ``;``. It is rejected, with a
        printed warning, when any of the following holds:

        * a MERGE's first node is neither ``(q:Question ...)`` nor a bare
          bound variable such as ``(q)``;
        * ``MERGE (q:Question`` appears without ``WITH q``;
        * a ``MATCH`` line has no ``{id: "..."}`` map;
        * a line references a ``kp_``/``ab_`` ID that is not in the
          database;
        * a line sets a weight outside [0, 1].
        """
        script = cypher.strip().rstrip(';').rstrip()
        if not script:
            return ""
        if self._ILLEGAL_MERGE_RE.search(script):
            print("⚠️ 检测到非法MERGE语句已过滤")
            return ""
        if not self._validate_cypher_structure(script):
            print("⚠️ 检测到缺少WITH q的MERGE语句已过滤")
            return ""
        for line in script.split('\n'):
            if 'MATCH' in line and not self._MATCH_ID_RE.search(line):
                print(f"⚠️ 检测到无ID的MATCH语句: {line[:50]}")
                return ""
            if not self._validate_ids(line):
                print(f"⚠️ 检测到不存在的节点ID: {line[:50]}")
                return ""
            if not self._validate_weight(line):
                print(f"⚠️ 检测到非法权重: {line[:50]}")
                return ""
        return script

    def run(self) -> Tuple[bool, str, str]:
        """Stream the model's answer for this question and extract its Cypher.

        Returns:
            ``(success, text, cypher)``. On success, ``text`` is the full
            model output and ``cypher`` is the sanitized script, which may be
            "" if nothing passed the checks. On failure, ``text`` is an error
            message and ``cypher`` is "".
        """
        if not self.existing_knowledge or not self.existing_ability:
            print("❌ 知识库或能力点为空,请检查数据库")
            return False, "节点数据为空", ""

        start_time = time.time()
        content_buffer = []
        try:
            print(f"🚀 开始生成(知识点:{len(self.existing_knowledge)}个,能力点:{len(self.existing_ability)}个)")
            for idx, chunk in enumerate(self._generate_stream()):
                if not content_buffer:
                    # Animate only until text arrives; afterwards the "\r"
                    # redraw would overwrite the content being streamed.
                    frame = self._SPINNER[idx % len(self._SPINNER)]
                    print(f"\r{frame} 生成中({int(time.time() - start_time)}秒)", end="", flush=True)
                if not (chunk.choices and chunk.choices[0].delta.content):
                    continue
                content_chunk = chunk.choices[0].delta.content
                if not content_buffer:
                    print("\n\n📝 内容生成开始:")
                content_buffer.append(content_chunk)
                print(content_chunk, end="", flush=True)

            if not content_buffer:
                print("\n⚠️ 生成完成但未获取到有效内容")
                return False, "空内容", ""

            full_content = ''.join(content_buffer)
            cypher_script = self._extract_cypher(full_content)
            print(f"\n\n✅ 生成完成!耗时 {int(time.time() - start_time)}")
            return True, full_content, cypher_script
        except Exception as e:
            # Top-level boundary: report any API/stream failure to the caller.
            error_msg = str(e) or type(e).__name__
            print(f"\n\n❌ 生成失败:{error_msg}")
            return False, error_msg, ""