You cannot select more than 25 topics. Topics must start with a letter or number, can include dashes ('-'), and can be up to 35 characters long.

205 lines
7.0 KiB

This file contains ambiguous Unicode characters!

This file contains ambiguous Unicode characters that may be confused with others in your current locale. If your use case is intentional and legitimate, you can safely ignore this warning. Use the Escape button to highlight these characters.

# 导入必要的库
import os # 用于文件路径操作
import json # 用于JSON数据处理
import xml.etree.ElementTree as ET # 用于解析XML文件
from neo4j import GraphDatabase # Neo4j数据库驱动
from Config.Config import * # 导入配置文件
from Util.Neo4jExecutor import Neo4jExecutor
# Animation Academy topic: directory used by the script below to locate the
# GraphML input file and to write the converted JSON output file.
WORKING_DIR = "./Topic/DongHua"
def _data_text(element, key, namespace, default=""):
    """Return the text of ``element``'s ``<data key="...">`` child, or ``default``.

    ``default`` is returned both when the child element is absent and when it
    is present but empty (ElementTree reports empty element text as ``None``).
    """
    child = element.find(f"./data[@key='{key}']", namespace)
    if child is None or child.text is None:
        return default
    return child.text


def xml_to_json(xml_file):
    """Convert a GraphML XML file into a JSON-serialisable dictionary.

    Nodes are read from ``<node>`` elements (data keys ``d1``-``d3``) and
    edges from ``<edge>`` elements (data keys ``d5``-``d8``). A missing or
    empty ``<data>`` element falls back to ``""`` (``0.0`` for the edge
    weight) instead of aborting the whole conversion.

    :param xml_file: path of the GraphML file to read
    :return: ``{"nodes": [...], "edges": [...]}``, or ``None`` when the file
        cannot be read or parsed (the error is printed)
    """
    try:
        tree = ET.parse(xml_file)
        root = tree.getroot()

        # Report what was loaded.
        print(f"Root element: {root.tag}")
        print(f"Root attributes: {root.attrib}")

        data = {"nodes": [], "edges": []}

        # GraphML elements live in this default namespace; the empty prefix
        # makes un-prefixed path steps such as "node" and "data" resolve to it.
        namespace = {"": "http://graphml.graphdrawing.org/xmlns"}

        for node in root.findall(".//node", namespace):
            data["nodes"].append(
                {
                    # IDs and entity types may be wrapped in double quotes.
                    "id": node.get("id").strip('"'),
                    "entity_type": _data_text(node, "d1", namespace).strip('"'),
                    "description": _data_text(node, "d2", namespace),
                    "source_id": _data_text(node, "d3", namespace),
                }
            )

        for edge in root.findall(".//edge", namespace):
            # None marks "no usable weight" so it can default to 0.0 below.
            weight_text = _data_text(edge, "d5", namespace, default=None)
            data["edges"].append(
                {
                    "source": edge.get("source").strip('"'),
                    "target": edge.get("target").strip('"'),
                    "weight": float(weight_text) if weight_text is not None else 0.0,
                    "description": _data_text(edge, "d6", namespace),
                    "keywords": _data_text(edge, "d7", namespace),
                    "source_id": _data_text(edge, "d8", namespace),
                }
            )

        print(f"Found {len(data['nodes'])} nodes and {len(data['edges'])} edges")
        return data
    except ET.ParseError as e:
        print(f"Error parsing XML file: {e}")
        return None
    except Exception as e:
        # Best-effort contract: any other failure (unreadable file, missing
        # id/source/target attribute, non-numeric weight) is reported and
        # signalled to the caller as None.
        print(f"An error occurred: {e}")
        return None
def convert_xml_to_json(xml_path, output_path):
    """Convert a GraphML file to JSON and write the result to disk.

    :param xml_path: path of the GraphML/XML file to read
    :param output_path: path of the JSON file to write (UTF-8, indented)
    :return: the converted data, or ``None`` when the input file does not
        exist or the conversion produced no data
    """
    # Guard: nothing to do without an input file.
    if not os.path.exists(xml_path):
        print(f"Error: File not found - {xml_path}")
        return None

    graph = xml_to_json(xml_path)

    # Guard: the parser reports failure with a falsy result.
    if not graph:
        print("Failed to create JSON data")
        return None

    with open(output_path, "w", encoding="utf-8") as out_file:
        # ensure_ascii=False keeps non-ASCII text readable in the output.
        json.dump(graph, out_file, ensure_ascii=False, indent=2)
    print(f"JSON file created: {output_path}")
    return graph
def process_in_batches(tx, query, data, batch_size):
    """Run ``query`` once per consecutive slice of ``data``.

    Each slice holds at most ``batch_size`` items and is passed as the
    ``nodes`` query parameter when the query text contains ``"nodes"``,
    otherwise as the ``edges`` parameter.

    :param tx: transaction object exposing ``run(query, parameters)``
    :param query: Cypher query text
    :param data: sequence of items to send
    :param batch_size: maximum number of items per ``run`` call
    """
    slices = (
        data[start: start + batch_size]
        for start in range(0, len(data), batch_size)
    )
    for chunk in slices:
        # The parameter name is derived from the query text itself.
        param_name = "nodes" if "nodes" in query else "edges"
        tx.run(query, {param_name: chunk})
if __name__ == "__main__":
    # Script entry point: convert the topic's GraphML file to JSON, clear the
    # Neo4j database, then load the nodes and edges in batches.

    # Input GraphML and output JSON both live in the topic working directory.
    xml_file = os.path.join(WORKING_DIR, "graph_chunk_entity_relation.graphml")
    json_file = os.path.join(WORKING_DIR, "graph_data.json")

    # Convert BEFORE touching the database: a missing or unparseable input
    # must not leave the graph wiped with nothing to load in its place.
    json_data = convert_xml_to_json(xml_file, json_file)
    if json_data is None:
        # Non-zero status so callers can tell the import did not happen.
        raise SystemExit(1)

    # The input is usable, so it is now safe to clear the existing graph.
    executor = Neo4jExecutor.create_default()
    executor.graph.run("MATCH (n) DETACH DELETE n")
    print("清库成功")

    # Node and edge records to load.
    nodes = json_data.get("nodes", [])
    edges = json_data.get("edges", [])

    # Upsert one node per record keyed by id, copy its properties, drop the
    # temporary Entity label and add a label named after the node id.
    create_nodes_query = """
    UNWIND $nodes AS node
    MERGE (e:Entity {id: node.id})
    SET e.entity_type = node.entity_type,
    e.description = node.description,
    e.source_id = node.source_id,
    e.displayName = node.id
    REMOVE e:Entity
    WITH e, node
    CALL apoc.create.addLabels(e, [node.id]) YIELD node AS labeledNode
    RETURN count(*)
    """

    # Match both endpoints by id and create a relationship whose type is the
    # first listed keyword contained in edge.keywords, or otherwise the first
    # comma-separated keyword with double quotes removed.
    create_edges_query = """
    UNWIND $edges AS edge
    MATCH (source {id: edge.source})
    MATCH (target {id: edge.target})
    WITH source, target, edge,
    CASE
    WHEN edge.keywords CONTAINS 'lead' THEN 'lead'
    WHEN edge.keywords CONTAINS 'participate' THEN 'participate'
    WHEN edge.keywords CONTAINS 'uses' THEN 'uses'
    WHEN edge.keywords CONTAINS 'located' THEN 'located'
    WHEN edge.keywords CONTAINS 'occurs' THEN 'occurs'
    ELSE REPLACE(SPLIT(edge.keywords, ',')[0], '"', '')
    END AS relType
    CALL apoc.create.relationship(source, relType, {
    weight: edge.weight,
    description: edge.description,
    keywords: edge.keywords,
    source_id: edge.source_id
    }, target) YIELD rel
    RETURN count(*)
    """

    # Set displayName to the id on every node and replace its labels with a
    # single label taken from entity_type.
    set_displayname_and_labels_query = """
    MATCH (n)
    SET n.displayName = n.id
    WITH n
    CALL apoc.create.setLabels(n, [n.entity_type]) YIELD node
    RETURN count(*)
    """

    # Connection settings and batch sizes come from the Config star-import.
    driver = GraphDatabase.driver(NEO4J_URI, auth=(NEO4J_USERNAME, NEO4J_PASSWORD))
    try:
        with driver.session() as session:
            # Insert nodes in batches inside a managed write transaction.
            session.execute_write(
                process_in_batches, create_nodes_query, nodes, BATCH_SIZE_NODES
            )
            # Insert edges in batches inside a managed write transaction.
            session.execute_write(
                process_in_batches, create_edges_query, edges, BATCH_SIZE_EDGES
            )
            # Consume the auto-commit result explicitly so that an error
            # raised while the query executes is reported by the handler
            # below rather than being left on an unread result.
            session.run(set_displayname_and_labels_query).consume()
    except Exception as e:
        # Top-level boundary: report the failure; the driver is still closed.
        print(f"Error occurred: {e}")
    finally:
        # Always release the driver's connections.
        driver.close()