diff --git a/dsRagAnything/Config/Config.py b/dsRagAnything/Config/Config.py index 534265a9..5ac7e634 100644 --- a/dsRagAnything/Config/Config.py +++ b/dsRagAnything/Config/Config.py @@ -6,15 +6,15 @@ EMBED_DIM = 1024 EMBED_MAX_TOKEN_SIZE = 8192 # 大模型 【DeepSeek深度求索官方】 -LLM_API_KEY="sk-44ae895eeb614aa1a9c6460579e322f1" +LLM_API_KEY = "sk-44ae895eeb614aa1a9c6460579e322f1" LLM_BASE_URL = "https://api.deepseek.com" LLM_MODEL_NAME = "deepseek-chat" # 阿里云提供的大模型服务 -#LLM_API_KEY="sk-f6da0c787eff4b0389e4ad03a35a911f" -#LLM_BASE_URL = "https://dashscope.aliyuncs.com/compatible-mode/v1" -#LLM_MODEL_NAME = "qwen-plus" # 不要使用通义千问,会导致化学方程式不正确! -#LLM_MODEL_NAME = "deepseek-v3" +# LLM_API_KEY="sk-f6da0c787eff4b0389e4ad03a35a911f" +# LLM_BASE_URL = "https://dashscope.aliyuncs.com/compatible-mode/v1" +# LLM_MODEL_NAME = "qwen-plus" # 不要使用通义千问,会导致化学方程式不正确! +# LLM_MODEL_NAME = "deepseek-v3" # 视觉模型 VISION_API_KEY = "sk-pbqibyjwhrgmnlsmdygplahextfaclgnedetybccknxojlyl" @@ -23,10 +23,11 @@ VISION_MODEL_NAME = "GLM-4.1V-9B-Thinking" # Neo4j NEO4J_URI = "bolt://localhost:7687" -NEO4J_USERNAME="neo4j" -NEO4J_PASSWORD="DsideaL147258369" +NEO4J_USERNAME = "neo4j" +NEO4J_PASSWORD = "DsideaL147258369" +NEO4J_AUTH = (NEO4J_USERNAME, NEO4J_PASSWORD) # 写入Neo4j的批次大小 -BATCH_SIZE_NODES=100 +BATCH_SIZE_NODES = 100 # 写入Neo4j的批次大小 -BATCH_SIZE_EDGES=100 \ No newline at end of file +BATCH_SIZE_EDGES = 100 diff --git a/dsRagAnything/Config/__pycache__/Config.cpython-310.pyc b/dsRagAnything/Config/__pycache__/Config.cpython-310.pyc index 8a6d0825..7a234c82 100644 Binary files a/dsRagAnything/Config/__pycache__/Config.cpython-310.pyc and b/dsRagAnything/Config/__pycache__/Config.cpython-310.pyc differ diff --git a/dsRagAnything/T3_WriteToNeo4j.py b/dsRagAnything/T3_WriteToNeo4j.py index fd6e61db..c3e8ebe4 100644 --- a/dsRagAnything/T3_WriteToNeo4j.py +++ b/dsRagAnything/T3_WriteToNeo4j.py @@ -4,10 +4,10 @@ import json # 用于JSON数据处理 import xml.etree.ElementTree as ET # 用于解析XML文件 from neo4j import GraphDatabase # Neo4j数据库驱动 from Config.Config import * # 导入配置文件 +from Util.Neo4jExecutor import Neo4jExecutor # 动画学院 WORKING_DIR = "./Topic/DongHua" -#TXT_FILE = "sushi.txt" def xml_to_json(xml_file): """ @@ -113,7 +113,13 @@ def process_in_batches(tx, query, data, batch_size): tx.run(query, {"nodes": batch} if "nodes" in query else {"edges": batch}) -def main(): + +if __name__ == "__main__": + # 创建Neo4jExecutor实例 + executor = Neo4jExecutor.create_default() + executor.graph.run("MATCH (n) DETACH DELETE n") + print("清库成功") + # 文件路径设置 xml_file = os.path.join(WORKING_DIR, "graph_chunk_entity_relation.graphml") json_file = os.path.join(WORKING_DIR, "graph_data.json") @@ -121,7 +127,7 @@ def main(): # 将XML转换为JSON json_data = convert_xml_to_json(xml_file, json_file) if json_data is None: - return + exit(0) # 加载节点和边数据 nodes = json_data.get("nodes", []) @@ -129,47 +135,47 @@ def main(): # Neo4j查询语句 create_nodes_query = """ - UNWIND $nodes AS node - MERGE (e:Entity {id: node.id}) - SET e.entity_type = node.entity_type, - e.description = node.description, - e.source_id = node.source_id, - e.displayName = node.id - REMOVE e:Entity - WITH e, node - CALL apoc.create.addLabels(e, [node.id]) YIELD node AS labeledNode - RETURN count(*) - """ + UNWIND $nodes AS node + MERGE (e:Entity {id: node.id}) + SET e.entity_type = node.entity_type, + e.description = node.description, + e.source_id = node.source_id, + e.displayName = node.id + REMOVE e:Entity + WITH e, node + CALL apoc.create.addLabels(e, [node.id]) YIELD node AS labeledNode + RETURN count(*) + """ create_edges_query = """ - UNWIND $edges AS edge - MATCH (source {id: edge.source}) - MATCH (target {id: edge.target}) - WITH source, target, edge, - CASE - WHEN edge.keywords CONTAINS 'lead' THEN 'lead' - WHEN edge.keywords CONTAINS 'participate' THEN 'participate' - WHEN edge.keywords CONTAINS 'uses' THEN 'uses' - WHEN edge.keywords CONTAINS 'located' THEN 'located' - WHEN edge.keywords CONTAINS 'occurs' THEN 'occurs' - ELSE REPLACE(SPLIT(edge.keywords, ',')[0], '"', '') - END AS relType - CALL apoc.create.relationship(source, relType, { - weight: edge.weight, - description: edge.description, - keywords: edge.keywords, - source_id: edge.source_id - }, target) YIELD rel - RETURN count(*) - """ + UNWIND $edges AS edge + MATCH (source {id: edge.source}) + MATCH (target {id: edge.target}) + WITH source, target, edge, + CASE + WHEN edge.keywords CONTAINS 'lead' THEN 'lead' + WHEN edge.keywords CONTAINS 'participate' THEN 'participate' + WHEN edge.keywords CONTAINS 'uses' THEN 'uses' + WHEN edge.keywords CONTAINS 'located' THEN 'located' + WHEN edge.keywords CONTAINS 'occurs' THEN 'occurs' + ELSE REPLACE(SPLIT(edge.keywords, ',')[0], '"', '') + END AS relType + CALL apoc.create.relationship(source, relType, { + weight: edge.weight, + description: edge.description, + keywords: edge.keywords, + source_id: edge.source_id + }, target) YIELD rel + RETURN count(*) + """ set_displayname_and_labels_query = """ - MATCH (n) - SET n.displayName = n.id - WITH n - CALL apoc.create.setLabels(n, [n.entity_type]) YIELD node - RETURN count(*) - """ + MATCH (n) + SET n.displayName = n.id + WITH n + CALL apoc.create.setLabels(n, [n.entity_type]) YIELD node + RETURN count(*) + """ # 创建Neo4j驱动连接 driver = GraphDatabase.driver(NEO4J_URI, auth=(NEO4J_USERNAME, NEO4J_PASSWORD)) @@ -196,7 +202,3 @@ def main(): finally: # 关闭数据库连接 driver.close() - - -if __name__ == "__main__": - main() diff --git a/dsRagAnything/Util/Neo4jExecutor.py b/dsRagAnything/Util/Neo4jExecutor.py new file mode 100644 index 00000000..28f55cdb --- /dev/null +++ b/dsRagAnything/Util/Neo4jExecutor.py @@ -0,0 +1,109 @@ +import re +from typing import List, Tuple + +from py2neo import Graph, Node, Relationship, Subgraph +from Config import Config + +def clear(db): + # 清空数据 + db.run("MATCH (n) DETACH DELETE n") + + # 分步删除约束和索引 + try: + # 删除约束 + constraints = db.run("SHOW CONSTRAINTS YIELD name").data() + for constr in constraints: + db.run(f"DROP CONSTRAINT `{constr['name']}`") + + # 删除索引 + indexes = db.run("SHOW INDEXES YIELD name, type WHERE type <> 'LOOKUP'").data() + for idx in indexes: + db.run(f"DROP INDEX `{idx['name']}`") + except Exception as e: + print(f"删除操作失败: {e}") + +def create_subgraph(db: Graph, nodes: List[Node], relations: List[Tuple[Node, str, Node]]) -> None: + """统一创建子图""" + subgraph = Subgraph( + nodes=nodes, + relationships=[Relationship(start, rel_type, end) for start, rel_type, end in relations] + ) + db.create(subgraph) + +def tx_create(db: Graph, nodes: List[Node], relations: List[Tuple[Node, str, Node]]) -> None: + """事务方式创建数据""" + try: + tx = db.begin() + subgraph = Subgraph( + nodes=nodes, + relationships=[Relationship(start, rel_type, end) for start, rel_type, end in relations] + ) + tx.create(subgraph) + db.commit(tx) + except Exception as e: + db.rollback(tx) + print(f"事务操作失败: {str(e)}") + raise + +class Neo4jExecutor: + # 添加类变量存储连接配置 + NEO4J_URI = Config.NEO4J_URI + NEO4J_AUTH = Config.NEO4J_AUTH + + def __init__(self, uri=None, auth=None): + # 使用默认配置或传入参数 + self.graph = Graph(uri or self.NEO4J_URI, + auth=auth or self.NEO4J_AUTH) + + @classmethod + def create_default(cls): + """使用默认配置创建执行器""" + return cls(cls.NEO4J_URI, cls.NEO4J_AUTH) + + # 新增文本执行方法 + def execute_cypher_text(self, cypher_text: str) -> dict: + """直接执行Cypher文本""" + stats = {'total': 0, 'success': 0, 'failed': 0} + try: + statements = re.split(r';\s*\n', cypher_text) + statements = [s.strip() for s in statements if s.strip()] + + stats['total'] = len(statements) + + for stmt in statements: + try: + self.graph.run(stmt) + stats['success'] += 1 + except Exception as e: + stats['failed'] += 1 + print(f"执行失败: {stmt[:50]}... \n错误: {str(e)[:100]}") + + return stats + except Exception as e: + print(f"执行失败: {stmt[:100]}... \n完整错误: {str(e)}") # 原为str(e)[:100] + return stats + + def execute_cypher_file(self, file_path: str) -> dict: # 确保方法名称正确 + """执行Cypher文件""" + stats = {'total': 0, 'success': 0, 'failed': 0} + try: + with open(file_path, 'r', encoding='utf-8') as f: + cypher_script = f.read() + statements = re.split(r';\s*\n', cypher_script) + statements = [s.strip() for s in statements if s.strip()] + + stats['total'] = len(statements) + + for stmt in statements: + try: + self.graph.run(stmt) + stats['success'] += 1 + except Exception as e: + stats['failed'] += 1 + print(f"执行失败: {stmt[:50]}... \n错误: {str(e)[:100]}") + + return stats + + except Exception as e: + print(f"文件错误: {str(e)}") + return stats \ No newline at end of file diff --git a/dsRagAnything/Util/__pycache__/Neo4jExecutor.cpython-310.pyc b/dsRagAnything/Util/__pycache__/Neo4jExecutor.cpython-310.pyc new file mode 100644 index 00000000..ebc431ff Binary files /dev/null and b/dsRagAnything/Util/__pycache__/Neo4jExecutor.cpython-310.pyc differ