main
黄海 5 months ago
parent 79bad13113
commit f9bdf90ff6

@ -20,4 +20,8 @@ ALY_SK='oizcTOZ8izbGUouboC00RcmGE8vBQ1'
# Build paths with pathlib's "/" operator rather than string concatenation.
mdWorkingPath = Path(__file__).parent / 'md-file' / 'readme'
DEFAULT_TEMPLATE = mdWorkingPath / 'default.md'  # joined with the / operator
DEFAULT_OUTPUT_DIR = mdWorkingPath / 'output'  # joined with the / operator
# Neo4j connection settings; configure these values in Config.py.
# NOTE(review): the credentials below are committed in plain text — consider
# loading them from the environment or a secrets store instead.
NEO4J_URI = "neo4j://10.10.21.20:7687"
NEO4J_AUTH = ("neo4j", "DsideaL4r5t6y7u")

@ -2,68 +2,69 @@
import re
import time
import hashlib
from typing import Iterator, Tuple
from typing import Iterator, Tuple, Dict
from py2neo import Graph
from openai import OpenAI
from openai.types.chat import ChatCompletionChunk
from Config import *
class KnowledgeGraph:
def __init__(self, content: str):
    """Prepare everything needed to generate Cypher for one question.

    Args:
        content: Raw text of the question to analyse.

    Raises:
        ConnectionError: Propagated from ``_init_graph_connection`` when the
            Neo4j connection test fails.
    """
    self.content = content
    # The chat client is built exactly once; it was previously constructed a
    # second time at the end of this method with identical arguments.
    self.client = OpenAI(api_key=MODEL_API_KEY, base_url=MODEL_API_URL)
    self.question_id = self._generate_question_id()
    self.graph = self._init_graph_connection()
    # Cache the registered nodes as {id: name} so generated Cypher can be
    # validated against them later.
    self.existing_knowledge = self._fetch_existing_nodes("KnowledgePoint")
    self.existing_ability = self._fetch_existing_nodes("AbilityPoint")
def _init_graph_connection(self) -> Graph:
    """Open the Neo4j connection and verify it with a trivial query.

    Returns:
        The connected ``Graph`` object.

    Raises:
        ConnectionError: If creating the graph or running the probe query
            fails; the underlying exception is attached as ``__cause__``.
    """
    try:
        graph = Graph(NEO4J_URI, auth=NEO4J_AUTH)
        # Run a probe query so connection problems surface here rather than
        # on the first real query.
        graph.run("RETURN 1").data()
        print("✅ Neo4j连接成功")
        return graph
    except Exception as e:
        # Chain explicitly so the original driver error stays in the traceback.
        raise ConnectionError(f"❌ 数据库连接失败: {e}") from e
def _generate_question_id(self) -> str:
"""生成题目唯一标识符"""
return hashlib.md5(self.content.encode()).hexdigest()[:8]
def _fetch_existing_nodes(self, label: str) -> Dict[str, str]:
"""从Neo4j获取已有节点"""
try:
cypher = f"MATCH (n:{label}) RETURN n.id as id, n.name as name"
result = self.graph.run(cypher).data()
return {item['id']: item['name'] for item in result}
except Exception as e:
print(f"❌ 节点查询失败: {str(e)}")
return {}
def _generate_stream(self) -> Iterator[ChatCompletionChunk]:
"""动态化提示词版本"""
system_prompt = f'''请根据题目内容生成Neo4j Cypher语句严格遵循以下规则
# 节点创建规范
1. 知识点节点
- 标签: KnowledgePoint
- 必须属性:
* id: "KP_" + 知识点名称的MD5前6位示例name="分数运算" id="KP_ae3b8c"
* name: 知识点名称从题目内容中提取
* level: 学段小学/初中/高中
2. 能力点节点
- 标签: AbilityPoint
- 必须属性:
* id: "AB_" + 能力名称的MD5前6位
* name: 能力点名称
* category: 能力类型计算/推理/空间想象等
3. 题目节点
- 标签: Question
- 必须属性:
* id: "{self.question_id}"已根据题目内容生成
* content: 题目文本摘要50字内
* difficulty: 难度系数1-5整数
# 关系规则
1. 题目与知识点关系
(q:Question)-[:TESTS_KNOWLEDGE]->(kp:KnowledgePoint)
需设置权重属性 weight0.1-1.0
2. 题目与能力点关系
(q:Question)-[:REQUIRES_ABILITY]->(ab:AbilityPoint)
需设置权重属性 weight
# 生成步骤
1. 先创建约束必须
CREATE CONSTRAINT IF NOT EXISTS FOR (kp:KnowledgePoint) REQUIRE kp.id IS UNIQUE;
CREATE CONSTRAINT IF NOT EXISTS FOR (ab:AbilityPoint) REQUIRE ab.id IS UNIQUE;
2. 使用MERGE创建节点禁止使用CREATE
3. 最后创建关系需先MATCH已存在节点
# 当前题目信息
- 生成的问题ID: {self.question_id}
- 题目内容: "{self.content[:50]}..."已截断'''
"""生成限制性提示词"""
system_prompt = f'''# 严格生成规则
1. 仅允许使用以下预注册节点
- 知识点列表{len(self.existing_knowledge)}
{self._format_node_list(self.existing_knowledge)}
- 能力点列表{len(self.existing_ability)}
{self._format_node_list(self.existing_ability)}
2. 必须遵守的Cypher模式
MERGE (q:Question {{id: "{self.question_id}"}})
SET q.content = "题目内容摘要"
WITH q
MATCH (kp:KnowledgePoint {{id: "KP_xxxxxx"}})
MERGE (q)-[:TESTS_KNOWLEDGE {{weight: 0.8}}]->(kp)
MATCH (ab:AbilityPoint {{id: "AB_xxxxxx"}})
MERGE (q)-[:REQUIRES_ABILITY {{weight: 0.7}}]->(ab)
3. 绝对禁止
- 使用CREATE创建新节点
- 修改已有节点属性
- 使用未注册的ID'''
return self.client.chat.completions.create(
model=MODEL_NAME,
@ -75,39 +76,87 @@ class KnowledgeGraph:
timeout=300
)
def _format_node_list(self, nodes: Dict[str, str]) -> str:
"""格式化节点列表"""
if not nodes:
return " (无相关节点)"
sample = []
for i, (k, v) in enumerate(nodes.items()):
if i >= 5:
sample.append(f" ...(共{len(nodes)}仅显示前5个")
break
sample.append(f" - {k}: {v}")
return '\n'.join(sample)
def _extract_cypher(self, content: str) -> str:
"""增强的Cypher提取处理多代码块"""
cypher_blocks = []
# 匹配所有cypher代码块包含语言声明
pattern = r"```(?:cypher)?\n(.*?)```"
for block in re.findall(pattern, content, re.DOTALL):
# 清理注释和空行
cleaned = [
line.split('//')[0].strip()
for line in block.split('\n')
if line.strip() and not line.strip().startswith('//')
]
"""安全提取Cypher"""
safe_blocks = []
for block in re.findall(r"```(?:cypher)?\n(.*?)```", content, re.DOTALL):
cleaned = self._sanitize_cypher(block)
if cleaned:
cypher_blocks.append('\n'.join(cleaned))
return ';\n\n'.join(cypher_blocks)
safe_blocks.append(cleaned)
return ';\n\n'.join(safe_blocks) if safe_blocks else ""
def _sanitize_cypher(self, cypher: str) -> str:
"""消毒Cypher语句"""
valid_lines = []
for line in cypher.split('\n'):
line = line.split('//')[0].strip()
if not line:
continue
# 检查非法操作
if re.search(r'\bCREATE\b', line, re.IGNORECASE):
continue
# 验证节点ID
if not self._validate_ids(line):
continue
# 验证权重范围
if not self._validate_weight(line):
continue
valid_lines.append(line)
return '\n'.join(valid_lines) if valid_lines else ''
def _validate_ids(self, line: str) -> bool:
"""验证行内的所有ID"""
kp_ids = {id_.upper() for id_ in re.findall(r'kp_[\da-f]{6}', line, re.IGNORECASE)}
ab_ids = {id_.upper() for id_ in re.findall(r'ab_[\da-f]{6}', line, re.IGNORECASE)}
valid_kp = all(kp in self.existing_knowledge for kp in kp_ids)
valid_ab = all(ab in self.existing_ability for ab in ab_ids)
return valid_kp and valid_ab
def _validate_weight(self, line: str) -> bool:
"""验证关系权重是否合法"""
weight_match = re.search(r"weight\s*:\s*([0-9.]+)", line)
if weight_match:
try:
weight = float(weight_match.group(1))
return 0.0 <= weight <= 1.0
except ValueError:
return False
return True # 没有weight属性时视为合法
def run(self) -> Tuple[bool, str, str]:
"""执行生成流程(确保所有路径都有返回值)"""
"""执行安全生成流程"""
if not self.existing_knowledge or not self.existing_ability:
print("❌ 知识库或能力点为空,请检查数据库")
return False, "节点数据为空", ""
start_time = time.time()
spinner = ['', '', '', '', '', '', '', '', '', '']
content_buffer = []
try:
print(f"🚀 开始生成知识点和能力点的总结和插入语句")
print(f"🚀 开始生成(知识点:{len(self.existing_knowledge)}个,能力点:{len(self.existing_ability)}个)")
stream = self._generate_stream()
# 添加流数据检查
if not stream:
print("\n❌ 生成失败:无法获取生成流")
return False, "生成流获取失败", ""
for idx, chunk in enumerate(stream):
print(f"\r{spinner[idx % 10]} 生成中({int(time.time() - start_time)}秒)", end="")
@ -119,20 +168,16 @@ class KnowledgeGraph:
print("\n\n📝 内容生成开始:")
print(content_chunk, end="", flush=True)
# 确保最终返回
if content_buffer:
full_content = ''.join(content_buffer)
cypher_script = self._extract_cypher(full_content)
print(f"\n\n✅ 生成成功!耗时 {int(time.time() - start_time)}")
print("\n================ 完整结果 ================")
print(full_content)
print("\n================ Cypher语句 ===============")
print(cypher_script if cypher_script else "未检测到Cypher语句")
print(f"\n\n✅ 生成完成!耗时 {int(time.time() - start_time)}")
print("\n================ 安全Cypher ===============")
print(cypher_script if cypher_script else "未通过安全检查")
print("==========================================")
return True, full_content, cypher_script
# 添加空内容处理
print("\n⚠️ 生成完成但未获取到有效内容")
return False, "空内容", ""
@ -142,14 +187,18 @@ class KnowledgeGraph:
if __name__ == '__main__':
    # Sample question used as a test case. The stale older fixture and the
    # unguarded save path that were interleaved here have been removed.
    test_content = '''
题目一个长方体的长是8厘米宽是5厘米高是3厘米求它的表面积是多少平方厘米
'''
    try:
        kg = KnowledgeGraph(test_content)
        success, result, cypher = kg.run()
        # Write the file only when generation succeeded and some Cypher
        # survived the safety checks.
        if success and cypher:
            with open("output.cypher", "w", encoding="utf-8") as f:
                f.write(cypher)
            print(f"\nCypher已保存至output.cypherID: {kg.question_id}")
    except Exception as e:
        # Top-level boundary: report the failure instead of a raw traceback.
        print(f"程序初始化失败: {str(e)}")

@ -2,7 +2,7 @@
from py2neo import Graph
import re
from Util import *
from Config import *
class K2_Neo4jExecutor:
def __init__(self, uri, auth):
@ -36,8 +36,8 @@ class K2_Neo4jExecutor:
if __name__ == '__main__':
executor = K2_Neo4jExecutor(
uri="neo4j://10.10.21.20:7687",
auth=("neo4j", "DsideaL4r5t6y7u")
uri=NEO4J_URI,
auth=NEO4J_AUTH
)
# 清库
clear(executor.graph)

Loading…
Cancel
Save