205 lines
7.0 KiB
Python
205 lines
7.0 KiB
Python
# 导入必要的库
|
||
import os # 用于文件路径操作
|
||
import json # 用于JSON数据处理
|
||
import xml.etree.ElementTree as ET # 用于解析XML文件
|
||
from neo4j import GraphDatabase # Neo4j数据库驱动
|
||
from Config.Config import * # 导入配置文件
|
||
from Util.Neo4jExecutor import Neo4jExecutor
|
||
|
||
# Animation School ("DongHua") topic: directory that holds the input GraphML
# export and the JSON file generated from it.
WORKING_DIR = "./Topic/DongHua"
|
||
|
||
# GraphML <data> key ids this script was originally written against. They are
# used for any field whose key is not declared by name in the file's <key>
# elements.
_NODE_DEFAULT_KEYS = {"entity_type": "d1", "description": "d2", "source_id": "d3"}
_EDGE_DEFAULT_KEYS = {
    "weight": "d5",
    "description": "d6",
    "keywords": "d7",
    "source_id": "d8",
}


def _resolve_data_keys(root, namespace, domain, defaults):
    """Return ``{field: key id}`` for *domain* (``"node"`` or ``"edge"``).

    Prefers the id of a declared ``<key for=domain attr.name=field>``. This
    keeps files readable when their writer numbered the keys differently.
    When no such declaration exists, the id given in *defaults* is used.
    """
    declared = {}
    for key in root.findall("key", namespace):
        # A <key> without a "for" attribute applies to all elements (GraphML default).
        if key.get("for", "all") not in (domain, "all"):
            continue
        name, key_id = key.get("attr.name"), key.get("id")
        if name and key_id:
            declared.setdefault(name, key_id)
    return {field: declared.get(field, default) for field, default in defaults.items()}


def _data_values(element, namespace):
    """Map each direct ``<data key=...>`` child of *element* to its text.

    If a key repeats, the first child wins. Empty elements map to None.
    """
    values = {}
    for child in element.findall("data", namespace):
        values.setdefault(child.get("key"), child.text)
    return values


def xml_to_json(xml_file):
    """Convert a GraphML file into a dict of nodes and edges.

    Field values are located through the file's ``<key>`` declarations (by
    ``attr.name``). When a field is not declared, the fixed ids d1-d3 (nodes)
    and d5-d8 (edges) are used instead. A missing or empty value becomes
    ``""``, or ``0.0`` for weight. Surrounding double quotes are stripped from
    node ids, edge endpoints and entity types.

    :param xml_file: path of the GraphML (XML) file to read
    :return: ``{"nodes": [...], "edges": [...]}``, or None if parsing or
             conversion fails (the error is printed)
    """
    try:
        tree = ET.parse(xml_file)
        root = tree.getroot()

        print(f"Root element: {root.tag}")
        print(f"Root attributes: {root.attrib}")

        data = {"nodes": [], "edges": []}
        namespace = {"": "http://graphml.graphdrawing.org/xmlns"}
        node_keys = _resolve_data_keys(root, namespace, "node", _NODE_DEFAULT_KEYS)
        edge_keys = _resolve_data_keys(root, namespace, "edge", _EDGE_DEFAULT_KEYS)

        for node in root.findall(".//node", namespace):
            values = _data_values(node, namespace)
            data["nodes"].append(
                {
                    "id": (node.get("id") or "").strip('"'),
                    "entity_type": (values.get(node_keys["entity_type"]) or "").strip('"'),
                    "description": values.get(node_keys["description"]) or "",
                    "source_id": values.get(node_keys["source_id"]) or "",
                }
            )

        for edge in root.findall(".//edge", namespace):
            values = _data_values(edge, namespace)
            weight_text = values.get(edge_keys["weight"])
            data["edges"].append(
                {
                    "source": (edge.get("source") or "").strip('"'),
                    "target": (edge.get("target") or "").strip('"'),
                    # A missing or blank weight becomes 0.0. A non-numeric one
                    # raises ValueError, which the handler below reports.
                    "weight": float(weight_text)
                    if weight_text and weight_text.strip()
                    else 0.0,
                    "description": values.get(edge_keys["description"]) or "",
                    "keywords": values.get(edge_keys["keywords"]) or "",
                    "source_id": values.get(edge_keys["source_id"]) or "",
                }
            )

        print(f"Found {len(data['nodes'])} nodes and {len(data['edges'])} edges")
        return data
    except ET.ParseError as e:
        print(f"Error parsing XML file: {e}")
        return None
    except Exception as e:
        print(f"An error occurred: {e}")
        return None
|
||
|
||
|
||
def convert_xml_to_json(xml_path, output_path):
    """Convert a GraphML file with ``xml_to_json`` and save the result as JSON.

    The JSON is written as UTF-8 with non-ASCII characters kept as-is and a
    two-space indent.

    :param xml_path: path of the GraphML (XML) file to read
    :param output_path: path of the JSON file to (over)write
    :return: the converted dict, or None when the input file does not exist
             or conversion fails
    """
    if os.path.exists(xml_path):
        converted = xml_to_json(xml_path)
    else:
        print(f"Error: File not found - {xml_path}")
        return None

    if not converted:
        print("Failed to create JSON data")
        return None

    with open(output_path, "w", encoding="utf-8") as out_file:
        json.dump(converted, out_file, ensure_ascii=False, indent=2)
    print(f"JSON file created: {output_path}")
    return converted
|
||
|
||
|
||
def process_in_batches(tx, query, data, batch_size, param_name=None):
    """Run *query* once for each consecutive slice of *data* on *tx*.

    Each ``tx.run`` call binds one slice of at most *batch_size* items to a
    single Cypher parameter. When this function is passed to
    ``session.execute_write``, all batches share that one transaction. Batching
    therefore bounds the size of each statement's parameters, not the size of
    the transaction.

    :param tx: Neo4j transaction (anything exposing ``run(query, params)``)
    :param query: Cypher statement that reads the batch from ``$nodes`` or ``$edges``
    :param data: sliceable sequence of items to send, processed in order
    :param batch_size: maximum number of items per ``tx.run`` call; must be >= 1
    :param param_name: name of the Cypher parameter to bind. When None it is
        inferred as ``"nodes"`` if the query references ``$nodes``, otherwise
        ``"edges"``.
    :raises ValueError: if *batch_size* is less than 1
    """
    if batch_size < 1:
        # A negative step makes range() yield nothing, so the data would be
        # dropped silently.
        raise ValueError(f"batch_size must be >= 1, got {batch_size}")
    if param_name is None:
        # Match the parameter reference itself, not just the word "nodes", so
        # an edge query that merely mentions "nodes" is not bound as $nodes.
        param_name = "nodes" if "$nodes" in query else "edges"
    for start in range(0, len(data), batch_size):
        tx.run(query, {param_name: data[start:start + batch_size]})
|
||
|
||
|
||
|
||
# Upsert each entity by id (MERGE on :Entity) and copy its attributes. Then
# swap the temporary Entity label for a label named after the node id.
_CREATE_NODES_QUERY = """
UNWIND $nodes AS node
MERGE (e:Entity {id: node.id})
SET e.entity_type = node.entity_type,
    e.description = node.description,
    e.source_id = node.source_id,
    e.displayName = node.id
REMOVE e:Entity
WITH e, node
CALL apoc.create.addLabels(e, [node.id]) YIELD node AS labeledNode
RETURN count(*)
"""

# Connect existing nodes by id. The relationship type is the first of
# lead/participate/uses/located/occurs contained in keywords. Otherwise it is
# the first comma-separated keyword with double quotes removed.
# NOTE(review): endpoints are matched without a label (the node query removes
# :Entity), so each lookup probably cannot use an index and scans all nodes.
# NOTE(review): an empty keywords string yields an empty relType. Confirm that
# apoc.create.relationship accepts it; otherwise the whole edge transaction
# fails.
_CREATE_EDGES_QUERY = """
UNWIND $edges AS edge
MATCH (source {id: edge.source})
MATCH (target {id: edge.target})
WITH source, target, edge,
     CASE
         WHEN edge.keywords CONTAINS 'lead' THEN 'lead'
         WHEN edge.keywords CONTAINS 'participate' THEN 'participate'
         WHEN edge.keywords CONTAINS 'uses' THEN 'uses'
         WHEN edge.keywords CONTAINS 'located' THEN 'located'
         WHEN edge.keywords CONTAINS 'occurs' THEN 'occurs'
         ELSE REPLACE(SPLIT(edge.keywords, ',')[0], '"', '')
     END AS relType
CALL apoc.create.relationship(source, relType, {
    weight: edge.weight,
    description: edge.description,
    keywords: edge.keywords,
    source_id: edge.source_id
}, target) YIELD rel
RETURN count(*)
"""

# For every node, copy id into displayName and set its labels to [entity_type].
# NOTE(review): confirm behaviour when entity_type is empty or null.
_SET_DISPLAYNAME_AND_LABELS_QUERY = """
MATCH (n)
SET n.displayName = n.id
WITH n
CALL apoc.create.setLabels(n, [n.entity_type]) YIELD node
RETURN count(*)
"""


def main():
    """Convert the topic's GraphML export to JSON and load it into Neo4j.

    The steps are:

    1. Convert ``graph_chunk_entity_relation.graphml`` under WORKING_DIR and
       write ``graph_data.json`` next to it.
    2. Delete every node and relationship in the target database.
    3. Insert nodes and then edges in batches.
    4. Relabel every node by its entity type.

    :return: process exit status: 0 on success, 1 if conversion or loading failed
    """
    xml_file = os.path.join(WORKING_DIR, "graph_chunk_entity_relation.graphml")
    json_file = os.path.join(WORKING_DIR, "graph_data.json")

    # Convert before touching the database. The wipe below is destructive, so
    # only perform it once there is data to load in its place.
    json_data = convert_xml_to_json(xml_file, json_file)
    if json_data is None:
        return 1

    executor = Neo4jExecutor.create_default()
    executor.graph.run("MATCH (n) DETACH DELETE n")
    print("清库成功")  # "database cleared"

    nodes = json_data.get("nodes", [])
    edges = json_data.get("edges", [])

    driver = GraphDatabase.driver(NEO4J_URI, auth=(NEO4J_USERNAME, NEO4J_PASSWORD))
    try:
        with driver.session() as session:
            session.execute_write(
                process_in_batches, _CREATE_NODES_QUERY, nodes, BATCH_SIZE_NODES
            )
            session.execute_write(
                process_in_batches, _CREATE_EDGES_QUERY, edges, BATCH_SIZE_EDGES
            )
            # consume() waits for this auto-commit query to finish and raises
            # any error right here.
            session.run(_SET_DISPLAYNAME_AND_LABELS_QUERY).consume()
    except Exception as e:
        print(f"Error occurred: {e}")
        return 1
    finally:
        driver.close()
    return 0


if __name__ == "__main__":
    raise SystemExit(main())
|