You cannot select more than 25 topics. Topics must start with a letter or number, can include dashes ('-'), and can be up to 35 characters long.

105 lines
3.1 KiB

import json
import pymysql
from py2neo import Graph
from Config import Config
# Connect to the MySQL database
def get_mysql_connection():
    """Open and return a new pymysql connection.

    Host, port, user, password and database name are read from ``Config``;
    the connection charset is fixed to ``utf8mb4``. The caller is
    responsible for closing the returned connection.
    """
    settings = {
        "host": Config.MYSQL_HOST,
        "port": Config.MYSQL_PORT,
        "user": Config.MYSQL_USER,
        "password": Config.MYSQL_PASSWORD,
        "database": Config.MYSQL_DB_NAME,
        "charset": "utf8mb4",
    }
    return pymysql.connect(**settings)
# Fetch the knowledge-point rows from MySQL
def fetch_knowledge_points():
    """Return every row of the ``knowledge_points`` table.

    The selected columns are, in order: id, title, parent_id, is_leaf,
    prerequisite, related. The result is whatever ``cursor.fetchall()``
    yields. The connection is closed before this function returns, whether
    or not the query succeeded.
    """
    conn = get_mysql_connection()
    try:
        with conn.cursor() as cur:
            cur.execute(
                "SELECT id, title, parent_id, is_leaf, prerequisite, related FROM knowledge_points"
            )
            rows = cur.fetchall()
    finally:
        # Always release the connection, even if the query raised.
        conn.close()
    return rows
# Helpers used by generate_cypher to render Python values as Cypher text
def _cypher_string(value):
    """Return *value* rendered as a single-quoted Cypher string literal.

    Cypher escapes characters with a backslash, so backslashes and single
    quotes are backslash-escaped. Line breaks are written as ``\\n`` /
    ``\\r`` escapes so that one statement always stays on one line.
    """
    text = str(value)
    # The backslash must be handled first, otherwise the backslashes added
    # by the later replacements would themselves be doubled.
    for raw, escaped in (("\\", "\\\\"), ("'", "\\'"), ("\n", "\\n"), ("\r", "\\r")):
        text = text.replace(raw, escaped)
    return f"'{text}'"


def _cypher_scalar(value):
    """Return *value* rendered as a bare Cypher literal (null/boolean/number)."""
    if value is None:
        return "null"
    # bool is checked before the generic fallback because str(True) is "True".
    if isinstance(value, bool):
        return "true" if value else "false"
    return str(value)


def _link_statements(knowledge_points, column, rel_type):
    """Build MERGE statements for the id lists stored as JSON in *column*.

    Each non-empty cell is parsed as a JSON list of objects carrying an
    ``id`` key; one ``(a)-[:rel_type]->(b)`` statement is produced per entry.
    """
    statements = []
    for point in knowledge_points:
        source = _cypher_string(point[0])
        raw = point[column]
        # An empty/NULL cell and a JSON ``null`` both mean "no links".
        targets = (json.loads(raw) if raw else None) or []
        for target in targets:
            target_id = _cypher_string(target["id"])
            statements.append(
                f"MATCH (a:KnowledgePoint {{id: {source}}}), "
                f"(b:KnowledgePoint {{id: {target_id}}}) "
                f"MERGE (a)-[:{rel_type}]->(b);"
            )
    return statements


# Generate the Cypher statements
def generate_cypher(knowledge_points):
    """Translate knowledge-point rows into newline-separated Cypher statements.

    Args:
        knowledge_points: Iterable-of-rows that can be traversed repeatedly
            (e.g. a list or tuple). Each row is indexed by position:
            0 = id, 1 = title, 2 = parent_id, 3 = is_leaf,
            4 = prerequisite (JSON text), 5 = related (JSON text).

    Returns:
        A single string with one Cypher statement per line, in this order:
        node MERGEs, HAS_SUB_POINT links, PREREQUISITE links, RELATED links.
        An empty input yields an empty string.

    Raises:
        json.JSONDecodeError: If a prerequisite/related cell is not valid JSON.
    """
    statements = []

    # Create every node first so the later MATCH clauses can find both ends.
    for point in knowledge_points:
        node_id, title, is_leaf = point[0], point[1], point[3]
        name = "null" if title is None else _cypher_string(title)
        statements.append(
            f"MERGE (n:KnowledgePoint {{id: {_cypher_string(node_id)}}}) "
            f"SET n.name = {name}, n.is_leaf = {_cypher_scalar(is_leaf)};"
        )

    # Parent/child relationships; a falsy parent_id means "no parent".
    for point in knowledge_points:
        parent_id = point[2]
        if parent_id:
            parent = _cypher_string(parent_id)
            child = _cypher_string(point[0])
            statements.append(
                f"MATCH (parent:KnowledgePoint {{id: {parent}}}), "
                f"(child:KnowledgePoint {{id: {child}}}) "
                f"MERGE (parent)-[:HAS_SUB_POINT]->(child);"
            )

    # Prerequisite links, then related links.
    statements.extend(_link_statements(knowledge_points, 4, "PREREQUISITE"))
    statements.extend(_link_statements(knowledge_points, 5, "RELATED"))

    return "\n".join(statements)
# Main entry point
def main():
    """Rebuild the Neo4j knowledge graph from the MySQL knowledge points.

    The rows are fetched and converted to Cypher before Neo4j is touched.
    Then every existing node and relationship is deleted and the generated
    statements are executed one at a time. A success message is printed at
    the end.
    """
    # Read and translate the MySQL data first: if the fetch or the Cypher
    # generation fails, the existing graph has not been wiped yet.
    knowledge_points = fetch_knowledge_points()
    cypher_statements = generate_cypher(knowledge_points)

    # Connect to Neo4j
    graph = Graph(Config.NEO4J_URI, auth=Config.NEO4J_AUTH)

    # Clear all existing data
    graph.run("MATCH (n) DETACH DELETE n")

    # Execute the statements one by one, skipping blank lines
    for statement in cypher_statements.split('\n'):
        if statement.strip():
            graph.run(statement)

    print("知识图谱数据导入成功!")
# Run the import only when this file is executed as a script
if __name__ == "__main__":
    main()