You cannot select more than 25 topics. Topics must start with a letter or number, can include dashes ('-'), and can be up to 35 characters long.

134 lines
5.3 KiB

This file contains ambiguous Unicode characters!

This file contains ambiguous Unicode characters that may be confused with others in your current locale. If your use case is intentional and legitimate, you can safely ignore this warning. Use the Escape button to highlight these characters.

# -*- coding: utf-8 -*-
import re
import time
import hashlib
from typing import Iterator, Tuple
from openai import OpenAI
from openai.types.chat import ChatCompletionChunk
from Config import *
class KnowledgeGraph:
    """Build a knowledge-point summary and a Neo4j Cypher script for one exercise.

    The exercise text is sent to the chat model configured by ``MODEL_API_KEY``,
    ``MODEL_API_URL`` and ``MODEL_NAME`` (pulled in via ``from Config import *``).
    The streamed reply is echoed to the console, and the Cypher statements are
    extracted from its fenced code blocks.
    """

    # Any fenced block: group 1 is the info string after the opening fence
    # (e.g. "cypher", possibly with trailing spaces or a "\r"), group 2 the body.
    # Every fence is consumed, not only cypher ones, so opening and closing
    # fences stay correctly paired when the reply also has blocks in other languages.
    _FENCE_RE = re.compile(r"```([^\n`]*)\n(.*?)```", re.DOTALL)
    # Info strings (after strip + lower) treated as Cypher: untagged or "cypher".
    _CYPHER_TAGS = ("", "cypher")
    # Braille spinner frames shown while waiting for the first content chunk.
    _SPINNER = "⠋⠙⠹⠸⠼⠴⠦⠧⠇⠏"

    def __init__(self, shiti_content: str):
        """Store the exercise text, create the API client and derive the question ID.

        Args:
            shiti_content: Full text of the exercise to analyse.
        """
        self.shiti_content = shiti_content
        self.client = OpenAI(api_key=MODEL_API_KEY, base_url=MODEL_API_URL)
        self.question_id = self._generate_question_id()

    def _generate_question_id(self) -> str:
        """Return a short deterministic ID for the exercise.

        The ID is the first 8 hex digits of the MD5 digest of the encoded text,
        so identical exercise texts always map to the same Question node ID.
        """
        return hashlib.md5(self.shiti_content.encode()).hexdigest()[:8]

    def _generate_stream(self) -> Iterator[ChatCompletionChunk]:
        """Request a streamed completion that summarises the exercise as Cypher.

        The system prompt asks for the knowledge/ability points of the exercise
        and for Neo4j 5.26+ Cypher: uniqueness constraints, MERGE on IDs, and
        relationships built on already-existing nodes. It also asks for every
        statement to sit in a ```cypher fence and end with a semicolon, because
        ``_extract_cypher`` only reads fenced blocks.

        Returns:
            The chunk iterator returned by the OpenAI client (300 s timeout).
        """
        system_prompt = f'''
一、总结本题有哪些知识点和能力点。
二、将总结出的知识点能力点等信息按以下要求生成Neo4j 5.26+的Cypher语句
1. 必须包含约束创建:
CREATE CONSTRAINT IF NOT EXISTS FOR (kp:KnowledgePoint) REQUIRE kp.id IS UNIQUE;
CREATE CONSTRAINT IF NOT EXISTS FOR (ab:AbilityPoint) REQUIRE ab.id IS UNIQUE;
2. 节点使用MERGE并包含唯一ID
// 知识点节点
MERGE (kp:KnowledgePoint {{id: "KP_101"}})
SET kp.name = "长方形周长计算",
kp.level = "小学"
// 题目节点使用生成的ID{self.question_id}
MERGE (q:Question {{id: "{self.question_id}"}})
SET q.content = "巧求周长7个相同小长方形拼图求周长",
q.difficulty = 3
3. 关系基于已存在节点:
MATCH (q:Question {{id: "{self.question_id}"}}), (kp:KnowledgePoint {{id: "KP_101"}})
MERGE (q)-[r:TESTS_KNOWLEDGE]->(kp)
SET r.weight = 0.8
三、所有Cypher语句必须放在```cypher代码块中,且每条语句以分号结尾。'''
        return self.client.chat.completions.create(
            model=MODEL_NAME,
            messages=[
                {"role": "system", "content": system_prompt},
                {"role": "user", "content": self.shiti_content}
            ],
            stream=True,
            timeout=300
        )

    @staticmethod
    def _strip_line_comment(line: str) -> str:
        """Return *line* without a trailing ``//`` comment.

        A ``//`` inside a quoted literal (single, double or backtick quotes, with
        backslash escapes) is kept, so values such as ``"http://..."`` survive.
        """
        quote = None
        i = 0
        while i < len(line):
            ch = line[i]
            if quote:
                if ch == "\\":
                    i += 2  # skip the escaped character as well
                    continue
                if ch == quote:
                    quote = None
            elif ch in "\"'`":
                quote = ch
            elif line.startswith("//", i):
                return line[:i]
            i += 1
        return line

    def _extract_cypher(self, content: str) -> str:
        """Collect the Cypher statements from the fenced blocks in *content*.

        Only untagged blocks and blocks tagged ``cypher`` (any case) are used.
        ``//`` comments and blank lines are removed. Blocks are joined with
        ``";\\n\\n"`` and the script ends with a single ``";"``, so every statement,
        including the last one, is terminated exactly once.

        Returns:
            The combined script, or ``""`` when no Cypher block is found.
        """
        cypher_blocks = []
        for tag, block in self._FENCE_RE.findall(content):
            if tag.strip().lower() not in self._CYPHER_TAGS:
                continue
            stripped = (self._strip_line_comment(line).strip() for line in block.splitlines())
            cleaned = [code for code in stripped if code]
            if cleaned:
                # Drop the block's own terminator; the join re-adds exactly one.
                cypher_blocks.append('\n'.join(cleaned).rstrip(';').rstrip())
        if not cypher_blocks:
            return ""
        return ';\n\n'.join(cypher_blocks) + ';'

    def run(self) -> Tuple[bool, str, str]:
        """Stream the model reply, print progress, and extract the Cypher script.

        Every path returns a tuple instead of raising:

        * ``(True, full_reply, cypher_script)`` on success; ``cypher_script`` is
          ``""`` when the reply contains no Cypher block;
        * ``(False, "生成流获取失败", "")`` if ``_generate_stream`` gives a falsy value;
        * ``(False, "空内容", "")`` if the stream yields no text;
        * ``(False, str(error), "")`` if any exception is raised.
        """
        start_time = time.time()
        content_buffer = []
        try:
            print("🚀 开始生成知识点和能力点的总结和插入语句")
            stream = self._generate_stream()
            # Defensive guard against a falsy result from _generate_stream.
            if not stream:
                print("\n❌ 生成失败:无法获取生成流")
                return False, "生成流获取失败", ""
            for idx, chunk in enumerate(stream):
                if not content_buffer:
                    # Animate only until text arrives: the leading "\r" would
                    # otherwise overwrite the line of streamed text being printed.
                    frame = self._SPINNER[idx % len(self._SPINNER)]
                    elapsed = int(time.time() - start_time)
                    print(f"\r{frame} 生成中({elapsed}秒)", end="", flush=True)
                # Skip chunks with no choices or with empty delta text.
                if chunk.choices and chunk.choices[0].delta.content:
                    content_chunk = chunk.choices[0].delta.content
                    content_buffer.append(content_chunk)
                    if len(content_buffer) == 1:
                        print("\n\n📝 内容生成开始:")
                    print(content_chunk, end="", flush=True)
            if not content_buffer:
                print("\n⚠️ 生成完成但未获取到有效内容")
                return False, "空内容", ""
            full_content = ''.join(content_buffer)
            cypher_script = self._extract_cypher(full_content)
            print(f"\n\n✅ 生成成功!耗时 {int(time.time() - start_time)}秒")
            print("\n================ 完整结果 ================")
            print(full_content)
            print("\n================ Cypher语句 ===============")
            print(cypher_script if cypher_script else "未检测到Cypher语句")
            print("==========================================")
            return True, full_content, cypher_script
        except Exception as e:  # top-level boundary: report and return, never raise
            print(f"\n\n❌ 生成失败:{str(e)}")
            return False, str(e), ""
if __name__ == '__main__':
    # Demo run: analyse a sample exercise and save the generated Cypher to disk.
    problem_text = '''
下面是一道小学三年级的数学题目,巧求周长:
把7个完全相同的小长方形拼成如图的样子已知每个小长方形的长是10厘米则拼成的大长方形的周长是多少厘米
'''
    graph = KnowledgeGraph(problem_text)
    ok, _full_text, cypher_text = graph.run()
    output_path = "knowledge_graph.cypher"
    # Only write a file when generation succeeded and produced Cypher.
    if ok and cypher_text:
        with open(output_path, "w", encoding="utf-8") as handle:
            handle.write(cypher_text)
        print(f"\nCypher语句已保存至 {output_path} (题目ID: {graph.question_id})")