|
|
# -*- coding: utf-8 -*-
|
|
|
import re
|
|
|
import time
|
|
|
import hashlib
|
|
|
from typing import Iterator, Tuple
|
|
|
from openai import OpenAI
|
|
|
from openai.types.chat import ChatCompletionChunk
|
|
|
from Config import *
|
|
|
|
|
|
|
|
|
class KnowledgeGraph:
    """Summarise a question's knowledge/ability points and build Neo4j Cypher.

    The question text is sent to an OpenAI-compatible chat endpoint together
    with a system prompt that asks for a summary plus Cypher statements. The
    streamed reply is echoed to stdout and the fenced Cypher blocks are
    extracted from it.
    """

    # Frames of the progress spinner shown while waiting for the first content.
    _SPINNER_FRAMES = ('⠋', '⠙', '⠹', '⠸', '⠼', '⠴', '⠦', '⠧', '⠇', '⠏')

    # Fenced code block, optionally tagged "cypher" (any case). Trailing
    # spaces after the tag and CRLF line endings are tolerated.
    _FENCE_RE = re.compile(
        r"```(?:cypher)?[ \t]*\r?\n(.*?)```",
        re.DOTALL | re.IGNORECASE,
    )

    def __init__(self, shiti_content: str):
        """Store the question text, create the API client and derive the ID.

        Args:
            shiti_content: Full text of the question to analyse.
        """
        self.shiti_content = shiti_content
        # MODEL_API_KEY / MODEL_API_URL are expected to be supplied by the
        # file-level `from Config import *`.
        self.client = OpenAI(api_key=MODEL_API_KEY, base_url=MODEL_API_URL)
        self.question_id = self._generate_question_id()

    def _generate_question_id(self) -> str:
        """Return a short identifier derived from the question text.

        The ID is the first 8 hex digits of the MD5 digest of the UTF-8
        encoded question, so identical text always yields the same ID.
        """
        digest = hashlib.md5(self.shiti_content.encode("utf-8")).hexdigest()
        return digest[:8]

    def _generate_stream(self) -> "Iterator[ChatCompletionChunk]":
        """Start a streaming chat completion for the question.

        The system prompt asks the model to summarise the knowledge and
        ability points and to emit Cypher with uniqueness constraints,
        MERGE-based nodes and relationships between already-matched nodes.
        The generated question ID is embedded in the prompt's examples.

        Returns:
            The streaming response returned by the client; iterating it
            yields completion chunks.
        """
        system_prompt = f'''
一、总结本题有哪些知识点和能力点。
二、将总结出的知识点,能力点等信息,按以下要求生成Neo4j 5.26+的Cypher语句:
1. 必须包含约束创建:
CREATE CONSTRAINT IF NOT EXISTS FOR (kp:KnowledgePoint) REQUIRE kp.id IS UNIQUE;
CREATE CONSTRAINT IF NOT EXISTS FOR (ab:AbilityPoint) REQUIRE ab.id IS UNIQUE;

2. 节点使用MERGE并包含唯一ID:
// 知识点节点
MERGE (kp:KnowledgePoint {{id: "KP_101"}})
SET kp.name = "长方形周长计算",
kp.level = "小学"

// 题目节点(使用生成的ID:{self.question_id})
MERGE (q:Question {{id: "{self.question_id}"}})
SET q.content = "巧求周长:7个相同小长方形拼图求周长",
q.difficulty = 3

3. 关系基于已存在节点:
MATCH (q:Question {{id: "{self.question_id}"}}), (kp:KnowledgePoint {{id: "KP_101"}})
MERGE (q)-[r:TESTS_KNOWLEDGE]->(kp)
SET r.weight = 0.8'''

        return self.client.chat.completions.create(
            model=MODEL_NAME,
            messages=[
                {"role": "system", "content": system_prompt},
                {"role": "user", "content": self.shiti_content},
            ],
            stream=True,
            timeout=300,
        )

    @staticmethod
    def _strip_line_comment(line: str) -> str:
        """Return ``line`` with a trailing ``//`` comment removed.

        A ``//`` that sits inside a single-quoted, double-quoted or
        backtick-quoted section is left alone, so string values such as
        ``"http://example.com"`` are not truncated.
        """
        quote = None
        pos = 0
        length = len(line)
        while pos < length:
            ch = line[pos]
            if quote is not None:
                if ch == '\\' and quote != '`':
                    # Skip the escaped character inside a string literal.
                    pos += 2
                    continue
                if ch == quote:
                    quote = None
            elif ch in ('"', "'", '`'):
                quote = ch
            elif line.startswith('//', pos):
                return line[:pos]
            pos += 1
        return line

    def _extract_cypher(self, content: str) -> str:
        """Extract the Cypher statements from every fenced code block.

        Blank lines and ``//`` comments (whole-line or trailing) are removed
        from each block. Blocks are then joined with ``";\\n\\n"``; a block
        that already ends in ``;`` does not get a second one, so no empty
        ``;;`` statement is produced. ``/* ... */`` comments are not handled.

        Args:
            content: Full model reply, possibly containing several blocks.

        Returns:
            The combined Cypher text, or ``""`` when no block has any code.
        """
        blocks = []
        for body in self._FENCE_RE.findall(content):
            code_lines = []
            for raw_line in body.split('\n'):
                code = self._strip_line_comment(raw_line).strip()
                if code:
                    code_lines.append(code)
            if code_lines:
                blocks.append('\n'.join(code_lines))

        if not blocks:
            return ""

        # The separator supplies the terminator for every block but the
        # last, so drop terminators those blocks already carry.
        *leading, last = blocks
        trimmed = [block.rstrip('; \t') for block in leading]
        return ';\n\n'.join([block for block in trimmed if block] + [last])

    def _collect_stream(self, stream, start_time: float) -> str:
        """Consume ``stream``, echo its text to stdout and return it joined.

        Args:
            stream: Iterable of completion chunks.
            start_time: ``time.time()`` value used for the elapsed display.
        """
        frames = self._SPINNER_FRAMES
        pieces = []
        for idx, chunk in enumerate(stream):
            if not pieces:
                # Show the spinner only while waiting for the first content:
                # its "\r" would otherwise overwrite the streamed text.
                elapsed = int(time.time() - start_time)
                print(f"\r{frames[idx % len(frames)]} 生成中({elapsed}秒)", end="", flush=True)

            if chunk.choices and chunk.choices[0].delta.content:
                piece = chunk.choices[0].delta.content
                if not pieces:
                    print("\n\n📝 内容生成开始:")
                pieces.append(piece)
                print(piece, end="", flush=True)
        return ''.join(pieces)

    def run(self) -> Tuple[bool, str, str]:
        """Run the generation flow, printing progress and results to stdout.

        Every exception raised during generation is caught and reported, so
        the method always returns a tuple.

        Returns:
            ``(success, text, cypher)``. On success ``text`` is the full model
            reply and ``cypher`` the extracted statements (may be empty). On
            failure ``text`` is a short error message and ``cypher`` is ``""``.
        """
        start_time = time.time()

        try:
            print("🚀 开始生成知识点和能力点的总结和插入语句")
            stream = self._generate_stream()

            # Guard against a missing stream object.
            if not stream:
                print("\n❌ 生成失败:无法获取生成流")
                return False, "生成流获取失败", ""

            full_content = self._collect_stream(stream, start_time)

            # The stream finished without any text.
            if not full_content:
                print("\n⚠️ 生成完成但未获取到有效内容")
                return False, "空内容", ""

            cypher_script = self._extract_cypher(full_content)

            print(f"\n\n✅ 生成成功!耗时 {int(time.time() - start_time)}秒")
            print("\n================ 完整结果 ================")
            print(full_content)
            print("\n================ Cypher语句 ===============")
            print(cypher_script if cypher_script else "未检测到Cypher语句")
            print("==========================================")
            return True, full_content, cypher_script

        except Exception as e:
            print(f"\n\n❌ 生成失败:{str(e)}")
            return False, str(e), ""
|
|
|
|
|
|
|
|
|
if __name__ == '__main__':
    # Demo input: a primary-school perimeter word problem, sent to the model
    # verbatim.
    question_text = '''
下面是一道小学三年级的数学题目,巧求周长:
把7个完全相同的小长方形拼成如图的样子,已知每个小长方形的长是10厘米,则拼成的大长方形的周长是多少厘米?
'''
    graph = KnowledgeGraph(question_text)
    outcome = graph.run()
    succeeded, cypher_text = outcome[0], outcome[2]

    # Write the script only when generation succeeded and Cypher was found.
    if succeeded and cypher_text:
        with open("knowledge_graph.cypher", "w", encoding="utf-8") as out_file:
            out_file.write(cypher_text)
        print(f"\nCypher语句已保存至 knowledge_graph.cypher (题目ID: {graph.question_id})")