|
|
|
@ -2,68 +2,69 @@
|
|
|
|
|
import re
|
|
|
|
|
import time
|
|
|
|
|
import hashlib
|
|
|
|
|
from typing import Iterator, Tuple
|
|
|
|
|
from typing import Iterator, Tuple, Dict
|
|
|
|
|
from py2neo import Graph
|
|
|
|
|
from openai import OpenAI
|
|
|
|
|
from openai.types.chat import ChatCompletionChunk
|
|
|
|
|
from Config import *
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
class KnowledgeGraph:
|
|
|
|
|
def __init__(self, content: str):
|
|
|
|
|
self.content = content
|
|
|
|
|
self.client = OpenAI(api_key=MODEL_API_KEY, base_url=MODEL_API_URL)
|
|
|
|
|
self.question_id = self._generate_question_id()
|
|
|
|
|
self.graph = self._init_graph_connection()
|
|
|
|
|
self.existing_knowledge = self._fetch_existing_nodes("KnowledgePoint")
|
|
|
|
|
self.existing_ability = self._fetch_existing_nodes("AbilityPoint")
|
|
|
|
|
self.client = OpenAI(api_key=MODEL_API_KEY, base_url=MODEL_API_URL)
|
|
|
|
|
|
|
|
|
|
def _init_graph_connection(self) -> Graph:
|
|
|
|
|
"""初始化并测试数据库连接"""
|
|
|
|
|
try:
|
|
|
|
|
graph = Graph(NEO4J_URI, auth=NEO4J_AUTH)
|
|
|
|
|
graph.run("RETURN 1").data()
|
|
|
|
|
print("✅ Neo4j连接成功")
|
|
|
|
|
return graph
|
|
|
|
|
except Exception as e:
|
|
|
|
|
raise ConnectionError(f"❌ 数据库连接失败: {str(e)}")
|
|
|
|
|
|
|
|
|
|
def _generate_question_id(self) -> str:
|
|
|
|
|
"""生成题目唯一标识符"""
|
|
|
|
|
return hashlib.md5(self.content.encode()).hexdigest()[:8]
|
|
|
|
|
|
|
|
|
|
def _fetch_existing_nodes(self, label: str) -> Dict[str, str]:
|
|
|
|
|
"""从Neo4j获取已有节点"""
|
|
|
|
|
try:
|
|
|
|
|
cypher = f"MATCH (n:{label}) RETURN n.id as id, n.name as name"
|
|
|
|
|
result = self.graph.run(cypher).data()
|
|
|
|
|
return {item['id']: item['name'] for item in result}
|
|
|
|
|
except Exception as e:
|
|
|
|
|
print(f"❌ 节点查询失败: {str(e)}")
|
|
|
|
|
return {}
|
|
|
|
|
|
|
|
|
|
def _generate_stream(self) -> Iterator[ChatCompletionChunk]:
|
|
|
|
|
"""动态化提示词版本"""
|
|
|
|
|
system_prompt = f'''请根据题目内容生成Neo4j Cypher语句,严格遵循以下规则:
|
|
|
|
|
# 节点创建规范
|
|
|
|
|
1. 知识点节点:
|
|
|
|
|
- 标签: KnowledgePoint
|
|
|
|
|
- 必须属性:
|
|
|
|
|
* id: "KP_" + 知识点名称的MD5前6位(示例:name="分数运算" → id="KP_ae3b8c")
|
|
|
|
|
* name: 知识点名称(从题目内容中提取)
|
|
|
|
|
* level: 学段(小学/初中/高中)
|
|
|
|
|
|
|
|
|
|
2. 能力点节点:
|
|
|
|
|
- 标签: AbilityPoint
|
|
|
|
|
- 必须属性:
|
|
|
|
|
* id: "AB_" + 能力名称的MD5前6位
|
|
|
|
|
* name: 能力点名称
|
|
|
|
|
* category: 能力类型(计算/推理/空间想象等)
|
|
|
|
|
|
|
|
|
|
3. 题目节点:
|
|
|
|
|
- 标签: Question
|
|
|
|
|
- 必须属性:
|
|
|
|
|
* id: "{self.question_id}"(已根据题目内容生成)
|
|
|
|
|
* content: 题目文本摘要(50字内)
|
|
|
|
|
* difficulty: 难度系数(1-5整数)
|
|
|
|
|
|
|
|
|
|
# 关系规则
|
|
|
|
|
1. 题目与知识点关系:
|
|
|
|
|
(q:Question)-[:TESTS_KNOWLEDGE]->(kp:KnowledgePoint)
|
|
|
|
|
需设置权重属性 weight(0.1-1.0)
|
|
|
|
|
|
|
|
|
|
2. 题目与能力点关系:
|
|
|
|
|
(q:Question)-[:REQUIRES_ABILITY]->(ab:AbilityPoint)
|
|
|
|
|
需设置权重属性 weight
|
|
|
|
|
|
|
|
|
|
# 生成步骤
|
|
|
|
|
1. 先创建约束(必须):
|
|
|
|
|
CREATE CONSTRAINT IF NOT EXISTS FOR (kp:KnowledgePoint) REQUIRE kp.id IS UNIQUE;
|
|
|
|
|
CREATE CONSTRAINT IF NOT EXISTS FOR (ab:AbilityPoint) REQUIRE ab.id IS UNIQUE;
|
|
|
|
|
|
|
|
|
|
2. 使用MERGE创建节点(禁止使用CREATE)
|
|
|
|
|
|
|
|
|
|
3. 最后创建关系(需先MATCH已存在节点)
|
|
|
|
|
|
|
|
|
|
# 当前题目信息
|
|
|
|
|
- 生成的问题ID: {self.question_id}
|
|
|
|
|
- 题目内容: "{self.content[:50]}..."(已截断)'''
|
|
|
|
|
"""生成限制性提示词"""
|
|
|
|
|
system_prompt = f'''# 严格生成规则
|
|
|
|
|
1. 仅允许使用以下预注册节点:
|
|
|
|
|
- 知识点列表(共{len(self.existing_knowledge)}个):
|
|
|
|
|
{self._format_node_list(self.existing_knowledge)}
|
|
|
|
|
- 能力点列表(共{len(self.existing_ability)}个):
|
|
|
|
|
{self._format_node_list(self.existing_ability)}
|
|
|
|
|
|
|
|
|
|
2. 必须遵守的Cypher模式:
|
|
|
|
|
MERGE (q:Question {{id: "{self.question_id}"}})
|
|
|
|
|
SET q.content = "题目内容摘要"
|
|
|
|
|
|
|
|
|
|
WITH q
|
|
|
|
|
MATCH (kp:KnowledgePoint {{id: "KP_xxxxxx"}})
|
|
|
|
|
MERGE (q)-[:TESTS_KNOWLEDGE {{weight: 0.8}}]->(kp)
|
|
|
|
|
|
|
|
|
|
MATCH (ab:AbilityPoint {{id: "AB_xxxxxx"}})
|
|
|
|
|
MERGE (q)-[:REQUIRES_ABILITY {{weight: 0.7}}]->(ab)
|
|
|
|
|
|
|
|
|
|
3. 绝对禁止:
|
|
|
|
|
- 使用CREATE创建新节点
|
|
|
|
|
- 修改已有节点属性
|
|
|
|
|
- 使用未注册的ID'''
|
|
|
|
|
|
|
|
|
|
return self.client.chat.completions.create(
|
|
|
|
|
model=MODEL_NAME,
|
|
|
|
@ -75,39 +76,87 @@ class KnowledgeGraph:
|
|
|
|
|
timeout=300
|
|
|
|
|
)
|
|
|
|
|
|
|
|
|
|
def _format_node_list(self, nodes: Dict[str, str]) -> str:
|
|
|
|
|
"""格式化节点列表"""
|
|
|
|
|
if not nodes:
|
|
|
|
|
return " (无相关节点)"
|
|
|
|
|
|
|
|
|
|
sample = []
|
|
|
|
|
for i, (k, v) in enumerate(nodes.items()):
|
|
|
|
|
if i >= 5:
|
|
|
|
|
sample.append(f" ...(共{len(nodes)}个,仅显示前5个)")
|
|
|
|
|
break
|
|
|
|
|
sample.append(f" - {k}: {v}")
|
|
|
|
|
return '\n'.join(sample)
|
|
|
|
|
|
|
|
|
|
def _extract_cypher(self, content: str) -> str:
|
|
|
|
|
"""增强的Cypher提取(处理多代码块)"""
|
|
|
|
|
cypher_blocks = []
|
|
|
|
|
# 匹配所有cypher代码块(包含语言声明)
|
|
|
|
|
pattern = r"```(?:cypher)?\n(.*?)```"
|
|
|
|
|
|
|
|
|
|
for block in re.findall(pattern, content, re.DOTALL):
|
|
|
|
|
# 清理注释和空行
|
|
|
|
|
cleaned = [
|
|
|
|
|
line.split('//')[0].strip()
|
|
|
|
|
for line in block.split('\n')
|
|
|
|
|
if line.strip() and not line.strip().startswith('//')
|
|
|
|
|
]
|
|
|
|
|
"""安全提取Cypher"""
|
|
|
|
|
safe_blocks = []
|
|
|
|
|
for block in re.findall(r"```(?:cypher)?\n(.*?)```", content, re.DOTALL):
|
|
|
|
|
cleaned = self._sanitize_cypher(block)
|
|
|
|
|
if cleaned:
|
|
|
|
|
cypher_blocks.append('\n'.join(cleaned))
|
|
|
|
|
|
|
|
|
|
return ';\n\n'.join(cypher_blocks)
|
|
|
|
|
safe_blocks.append(cleaned)
|
|
|
|
|
return ';\n\n'.join(safe_blocks) if safe_blocks else ""
|
|
|
|
|
|
|
|
|
|
def _sanitize_cypher(self, cypher: str) -> str:
|
|
|
|
|
"""消毒Cypher语句"""
|
|
|
|
|
valid_lines = []
|
|
|
|
|
for line in cypher.split('\n'):
|
|
|
|
|
line = line.split('//')[0].strip()
|
|
|
|
|
if not line:
|
|
|
|
|
continue
|
|
|
|
|
|
|
|
|
|
# 检查非法操作
|
|
|
|
|
if re.search(r'\bCREATE\b', line, re.IGNORECASE):
|
|
|
|
|
continue
|
|
|
|
|
|
|
|
|
|
# 验证节点ID
|
|
|
|
|
if not self._validate_ids(line):
|
|
|
|
|
continue
|
|
|
|
|
|
|
|
|
|
# 验证权重范围
|
|
|
|
|
if not self._validate_weight(line):
|
|
|
|
|
continue
|
|
|
|
|
|
|
|
|
|
valid_lines.append(line)
|
|
|
|
|
|
|
|
|
|
return '\n'.join(valid_lines) if valid_lines else ''
|
|
|
|
|
|
|
|
|
|
def _validate_ids(self, line: str) -> bool:
|
|
|
|
|
"""验证行内的所有ID"""
|
|
|
|
|
kp_ids = {id_.upper() for id_ in re.findall(r'kp_[\da-f]{6}', line, re.IGNORECASE)}
|
|
|
|
|
ab_ids = {id_.upper() for id_ in re.findall(r'ab_[\da-f]{6}', line, re.IGNORECASE)}
|
|
|
|
|
|
|
|
|
|
valid_kp = all(kp in self.existing_knowledge for kp in kp_ids)
|
|
|
|
|
valid_ab = all(ab in self.existing_ability for ab in ab_ids)
|
|
|
|
|
|
|
|
|
|
return valid_kp and valid_ab
|
|
|
|
|
|
|
|
|
|
def _validate_weight(self, line: str) -> bool:
|
|
|
|
|
"""验证关系权重是否合法"""
|
|
|
|
|
weight_match = re.search(r"weight\s*:\s*([0-9.]+)", line)
|
|
|
|
|
if weight_match:
|
|
|
|
|
try:
|
|
|
|
|
weight = float(weight_match.group(1))
|
|
|
|
|
return 0.0 <= weight <= 1.0
|
|
|
|
|
except ValueError:
|
|
|
|
|
return False
|
|
|
|
|
return True # 没有weight属性时视为合法
|
|
|
|
|
|
|
|
|
|
def run(self) -> Tuple[bool, str, str]:
|
|
|
|
|
"""执行生成流程(确保所有路径都有返回值)"""
|
|
|
|
|
"""执行安全生成流程"""
|
|
|
|
|
if not self.existing_knowledge or not self.existing_ability:
|
|
|
|
|
print("❌ 知识库或能力点为空,请检查数据库")
|
|
|
|
|
return False, "节点数据为空", ""
|
|
|
|
|
|
|
|
|
|
start_time = time.time()
|
|
|
|
|
spinner = ['⠋', '⠙', '⠹', '⠸', '⠼', '⠴', '⠦', '⠧', '⠇', '⠏']
|
|
|
|
|
content_buffer = []
|
|
|
|
|
|
|
|
|
|
try:
|
|
|
|
|
print(f"🚀 开始生成知识点和能力点的总结和插入语句")
|
|
|
|
|
print(f"🚀 开始生成(知识点:{len(self.existing_knowledge)}个,能力点:{len(self.existing_ability)}个)")
|
|
|
|
|
stream = self._generate_stream()
|
|
|
|
|
|
|
|
|
|
# 添加流数据检查
|
|
|
|
|
if not stream:
|
|
|
|
|
print("\n❌ 生成失败:无法获取生成流")
|
|
|
|
|
return False, "生成流获取失败", ""
|
|
|
|
|
|
|
|
|
|
for idx, chunk in enumerate(stream):
|
|
|
|
|
print(f"\r{spinner[idx % 10]} 生成中({int(time.time() - start_time)}秒)", end="")
|
|
|
|
|
|
|
|
|
@ -119,20 +168,16 @@ class KnowledgeGraph:
|
|
|
|
|
print("\n\n📝 内容生成开始:")
|
|
|
|
|
print(content_chunk, end="", flush=True)
|
|
|
|
|
|
|
|
|
|
# 确保最终返回
|
|
|
|
|
if content_buffer:
|
|
|
|
|
full_content = ''.join(content_buffer)
|
|
|
|
|
cypher_script = self._extract_cypher(full_content)
|
|
|
|
|
|
|
|
|
|
print(f"\n\n✅ 生成成功!耗时 {int(time.time() - start_time)}秒")
|
|
|
|
|
print("\n================ 完整结果 ================")
|
|
|
|
|
print(full_content)
|
|
|
|
|
print("\n================ Cypher语句 ===============")
|
|
|
|
|
print(cypher_script if cypher_script else "未检测到Cypher语句")
|
|
|
|
|
print(f"\n\n✅ 生成完成!耗时 {int(time.time() - start_time)}秒")
|
|
|
|
|
print("\n================ 安全Cypher ===============")
|
|
|
|
|
print(cypher_script if cypher_script else "未通过安全检查")
|
|
|
|
|
print("==========================================")
|
|
|
|
|
return True, full_content, cypher_script
|
|
|
|
|
|
|
|
|
|
# 添加空内容处理
|
|
|
|
|
print("\n⚠️ 生成完成但未获取到有效内容")
|
|
|
|
|
return False, "空内容", ""
|
|
|
|
|
|
|
|
|
@ -142,14 +187,18 @@ class KnowledgeGraph:
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
if __name__ == '__main__':
|
|
|
|
|
shiti_content = '''
|
|
|
|
|
下面是一道小学三年级的数学题目,巧求周长:
|
|
|
|
|
把7个完全相同的小长方形拼成如图的样子,已知每个小长方形的长是10厘米,则拼成的大长方形的周长是多少厘米?
|
|
|
|
|
# 测试用例
|
|
|
|
|
test_content = '''
|
|
|
|
|
题目:一个长方体的长是8厘米,宽是5厘米,高是3厘米,求它的表面积是多少平方厘米?
|
|
|
|
|
'''
|
|
|
|
|
kg = KnowledgeGraph(shiti_content)
|
|
|
|
|
success, result, cypher = kg.run()
|
|
|
|
|
|
|
|
|
|
if success and cypher:
|
|
|
|
|
with open("knowledge_graph.cypher", "w", encoding="utf-8") as f:
|
|
|
|
|
f.write(cypher)
|
|
|
|
|
print(f"\nCypher语句已保存至 knowledge_graph.cypher (题目ID: {kg.question_id})")
|
|
|
|
|
try:
|
|
|
|
|
kg = KnowledgeGraph(test_content)
|
|
|
|
|
success, result, cypher = kg.run()
|
|
|
|
|
|
|
|
|
|
if success and cypher:
|
|
|
|
|
with open("output.cypher", "w", encoding="utf-8") as f:
|
|
|
|
|
f.write(cypher)
|
|
|
|
|
print(f"\nCypher已保存至output.cypher(ID: {kg.question_id})")
|
|
|
|
|
except Exception as e:
|
|
|
|
|
print(f"程序初始化失败: {str(e)}")
|