# Source file: dsRagAnything/Test/T3_Graph_visual_with_html.py
# Patch metadata: commit 50fad8599f5ae7bbed3823f8a3f755516a64554b,
#   From: HuangHai <10402852@qq.com>, Date: Sun, 6 Jul 2025 22:26:36 +0800,
#   Subject: [PATCH] 'commit' (2 files changed, 253 insertions).
"""Render a GraphML knowledge graph as an interactive HTML page using pyvis.

The script reads ``GRAPHML_PATH`` with networkx, converts the graph to a pyvis
network, gives every node a random colour, uses each ``description`` attribute
as the hover tooltip, and writes the result to ``OUTPUT_HTML``.
"""
import os  # used to create the output directory before writing
import random  # used to pick a random colour per node

# pipmaster is used to check for, and install, the third-party packages below.
import pipmaster as pm

# Install pyvis (interactive network visualisation) when it is not available.
if not pm.is_installed("pyvis"):
    pm.install("pyvis")
# Install networkx (graph data structures and the GraphML reader) when missing.
if not pm.is_installed("networkx"):
    pm.install("networkx")

# These imports must stay below the install step above.
import networkx as nx
from pyvis.network import Network

# Input graph and output page. Keeping the paths in one place stops the
# comments and the code from drifting apart.
GRAPHML_PATH = "Topic/Chinese/graph_chunk_entity_relation.graphml"
OUTPUT_HTML = "./Html/knowledge_graph.html"

# Load the knowledge graph from the GraphML file.
G = nx.read_graphml(GRAPHML_PATH)

# height="100vh" makes the canvas as tall as the viewport;
# notebook=True selects pyvis' notebook mode.
net = Network(height="100vh", notebook=True)

# Copy the networkx graph into the pyvis network.
net.from_nx(G)

# Give every node a random hex colour and, when present, use its
# "description" attribute as the hover tooltip.
for node in net.nodes:
    node["color"] = "#{:06x}".format(random.randint(0, 0xFFFFFF))
    if "description" in node:
        node["title"] = node["description"]

# Same tooltip treatment for edges.
for edge in net.edges:
    if "description" in edge:
        edge["title"] = edge["description"]

# The HTML file is written into a sub-directory; create it first so a fresh
# checkout does not fail with FileNotFoundError.
os.makedirs(os.path.dirname(OUTPUT_HTML), exist_ok=True)

# Generate the HTML visualisation.
net.show(OUTPUT_HTML)

# --- next file in the patch: dsRagAnything/Test/T4_graph_visual_with_neo4j.py ---
# Source file: dsRagAnything/Test/T4_graph_visual_with_neo4j.py
"""Convert a GraphML knowledge graph to JSON and load it into Neo4j.

Pipeline (see ``main``): parse ``graph_chunk_entity_relation.graphml`` from
``WORKING_DIR``, dump the extracted nodes/edges to ``graph_data.json``, then
create the nodes and relationships in Neo4j in batches.
"""
import os  # file path handling
import json  # JSON serialisation
import xml.etree.ElementTree as ET  # GraphML (XML) parsing
from neo4j import GraphDatabase  # Neo4j database driver
# NOTE(review): NEO4J_URI / NEO4J_USERNAME / NEO4J_PASSWORD / BATCH_SIZE_NODES /
# BATCH_SIZE_EDGES are not defined in this file; presumably they come from one
# of the two star imports below - confirm.
from Config.Config import *  # project configuration
from Util.LightRagUtil import *

# Math topic (disabled)
# WORKING_DIR = "./Topic/Math"
# TXT_FILE = "小学数学教学中的若干问题.txt"

# Su Shi topic
WORKING_DIR = "./Topic/Chinese"
TXT_FILE = "sushi.txt"

# Default XML namespace of GraphML documents (used for unprefixed tag names).
_GRAPHML_NS = {"": "http://graphml.graphdrawing.org/xmlns"}

# Key ids to fall back on when the document does not declare a <key> with the
# matching attr.name. These are the ids this script has always assumed.
_DEFAULT_KEY_IDS = {
    ("node", "entity_type"): "d1",
    ("node", "description"): "d2",
    ("node", "source_id"): "d3",
    ("edge", "weight"): "d5",
    ("edge", "description"): "d6",
    ("edge", "keywords"): "d7",
    ("edge", "source_id"): "d8",
}


def _resolve_key_ids(root):
    """Return a {(domain, attr name): key id} map for the attributes we read.

    GraphML declares each attribute once as ``<key id=.. for=.. attr.name=..>``
    and the id numbering depends on the writer, so ids are looked up by name
    first and only fall back to ``_DEFAULT_KEY_IDS`` when undeclared.
    """
    declared = {}
    for key in root.findall(".//key", _GRAPHML_NS):
        key_id = key.get("id")
        attr_name = key.get("attr.name")
        if key_id and attr_name:
            # A missing "for" attribute means the key applies to all elements.
            declared[(key.get("for", "all"), attr_name)] = key_id

    resolved = {}
    for (domain, attr_name), fallback in _DEFAULT_KEY_IDS.items():
        resolved[(domain, attr_name)] = declared.get(
            (domain, attr_name), declared.get(("all", attr_name), fallback)
        )
    return resolved


def _data_text(element, key_id):
    """Return the text of ``element``'s ``<data key=key_id>`` child.

    Returns "" both when the child is absent and when it is present but empty
    (ElementTree reports empty text as ``None``).
    """
    child = element.find(f"./data[@key='{key_id}']", _GRAPHML_NS)
    if child is None or child.text is None:
        return ""
    return child.text


def xml_to_json(xml_file):
    """
    Convert a GraphML (XML) file into a JSON-serialisable dict.

    :param xml_file: path of the GraphML file to read
    :return: ``{"nodes": [...], "edges": [...]}``, or None when the file
             cannot be read or parsed (the error is printed)
    """
    try:
        tree = ET.parse(xml_file)
        root = tree.getroot()

        # Show what was parsed.
        print(f"Root element: {root.tag}")
        print(f"Root attributes: {root.attrib}")

        key_ids = _resolve_key_ids(root)
        data = {"nodes": [], "edges": []}

        # Collect every node; surrounding double quotes are stripped from ids
        # and entity types.
        for node in root.findall(".//node", _GRAPHML_NS):
            data["nodes"].append(
                {
                    "id": node.get("id").strip('"'),
                    "entity_type": _data_text(
                        node, key_ids["node", "entity_type"]
                    ).strip('"'),
                    "description": _data_text(node, key_ids["node", "description"]),
                    "source_id": _data_text(node, key_ids["node", "source_id"]),
                }
            )

        # Collect every edge; a missing or empty weight becomes 0.0.
        for edge in root.findall(".//edge", _GRAPHML_NS):
            weight_text = _data_text(edge, key_ids["edge", "weight"]).strip()
            data["edges"].append(
                {
                    "source": edge.get("source").strip('"'),
                    "target": edge.get("target").strip('"'),
                    "weight": float(weight_text) if weight_text else 0.0,
                    "description": _data_text(edge, key_ids["edge", "description"]),
                    "keywords": _data_text(edge, key_ids["edge", "keywords"]),
                    "source_id": _data_text(edge, key_ids["edge", "source_id"]),
                }
            )

        print(f"Found {len(data['nodes'])} nodes and {len(data['edges'])} edges")

        return data
    except ET.ParseError as e:
        print(f"Error parsing XML file: {e}")
        return None
    except Exception as e:
        # Deliberately broad: callers rely on None (not an exception) for any
        # failure, e.g. an unreadable file or a non-numeric weight.
        print(f"An error occurred: {e}")
        return None


def convert_xml_to_json(xml_path, output_path):
    """
    Convert a GraphML file to JSON and write the result to disk.

    :param xml_path: path of the input GraphML file
    :param output_path: path of the JSON file to write (UTF-8)
    :return: the converted data, or None when the input is missing or the
             conversion failed
    """
    if not os.path.exists(xml_path):
        print(f"Error: File not found - {xml_path}")
        return None

    json_data = xml_to_json(xml_path)
    if json_data is None:
        print("Failed to create JSON data")
        return None

    with open(output_path, "w", encoding="utf-8") as f:
        json.dump(json_data, f, ensure_ascii=False, indent=2)
    print(f"JSON file created: {output_path}")
    return json_data


def process_in_batches(tx, query, data, batch_size):
    """
    Run ``query`` once per slice of ``data``.

    Each slice is passed as the ``$nodes`` parameter when the query refers to
    ``$nodes`` and as ``$edges`` otherwise.

    :param tx: Neo4j transaction object (anything with a ``run`` method)
    :param query: Cypher statement to execute
    :param data: list of items to send
    :param batch_size: maximum number of items per ``run`` call (>= 1)
    :raises ValueError: if ``batch_size`` is smaller than 1
    """
    if batch_size < 1:
        raise ValueError(
            f"batch_size must be a positive integer, got {batch_size!r}"
        )

    # Look for the parameter reference itself, not just the word "nodes",
    # which may appear elsewhere in a query.
    param_name = "nodes" if "$nodes" in query else "edges"
    for start in range(0, len(data), batch_size):
        tx.run(query, {param_name: data[start: start + batch_size]})


def main():
    """Convert the GraphML graph in WORKING_DIR to JSON and import it into Neo4j."""
    # Input and output locations.
    xml_file = os.path.join(WORKING_DIR, "graph_chunk_entity_relation.graphml")
    json_file = os.path.join(WORKING_DIR, "graph_data.json")

    # Convert the XML to JSON; stop if that failed (already reported).
    json_data = convert_xml_to_json(xml_file, json_file)
    if json_data is None:
        return

    nodes = json_data.get("nodes", [])
    edges = json_data.get("edges", [])

    # Creates/updates one node per entry, then swaps the temporary :Entity
    # label for a label equal to the node id (requires the APOC plugin).
    # NOTE(review): because :Entity is removed again, a second run's MERGE
    # presumably will not find the existing node and will create a duplicate
    # - confirm.
    create_nodes_query = """
    UNWIND $nodes AS node
    MERGE (e:Entity {id: node.id})
    SET e.entity_type = node.entity_type,
        e.description = node.description,
        e.source_id = node.source_id,
        e.displayName = node.id
    REMOVE e:Entity
    WITH e, node
    CALL apoc.create.addLabels(e, [node.id]) YIELD node AS labeledNode
    RETURN count(*)
    """

    # Creates one relationship per edge. The relationship type is derived from
    # the keywords: one of five known verbs, else the first comma-separated
    # keyword with double quotes removed.
    create_edges_query = """
    UNWIND $edges AS edge
    MATCH (source {id: edge.source})
    MATCH (target {id: edge.target})
    WITH source, target, edge,
         CASE
            WHEN edge.keywords CONTAINS 'lead' THEN 'lead'
            WHEN edge.keywords CONTAINS 'participate' THEN 'participate'
            WHEN edge.keywords CONTAINS 'uses' THEN 'uses'
            WHEN edge.keywords CONTAINS 'located' THEN 'located'
            WHEN edge.keywords CONTAINS 'occurs' THEN 'occurs'
            ELSE REPLACE(SPLIT(edge.keywords, ',')[0], '"', '')
         END AS relType
    CALL apoc.create.relationship(source, relType, {
      weight: edge.weight,
      description: edge.description,
      keywords: edge.keywords,
      source_id: edge.source_id
    }, target) YIELD rel
    RETURN count(*)
    """

    # Replaces each node's labels with its entity_type and sets displayName.
    # Note that MATCH (n) has no filter, so it touches every node in the
    # database, not only the ones imported above.
    set_displayname_and_labels_query = """
    MATCH (n)
    SET n.displayName = n.id
    WITH n
    CALL apoc.create.setLabels(n, [n.entity_type]) YIELD node
    RETURN count(*)
    """

    # Open the Neo4j driver connection.
    driver = GraphDatabase.driver(NEO4J_URI, auth=(NEO4J_USERNAME, NEO4J_PASSWORD))

    try:
        with driver.session() as session:
            # Insert nodes in batches.
            session.execute_write(
                process_in_batches, create_nodes_query, nodes, BATCH_SIZE_NODES
            )

            # Insert edges in batches.
            session.execute_write(
                process_in_batches, create_edges_query, edges, BATCH_SIZE_EDGES
            )

            # Set display names and labels. Consume the result explicitly so
            # that a failure in this auto-commit query is raised here and
            # reported below instead of being dropped when the session closes.
            session.run(set_displayname_and_labels_query).consume()

    except Exception as e:
        # Top-level boundary of the script: report and fall through to cleanup.
        print(f"Error occurred: {e}")

    finally:
        # Always release the driver's connections.
        driver.close()


if __name__ == "__main__":
    main()