You cannot select more than 25 topics. Topics must start with a letter or number, can include dashes ('-'), and can be up to 35 characters long.

109 lines
3.8 KiB

import re
from typing import List, Tuple
from py2neo import Graph, Node, Relationship, Subgraph
from Config import Config
def clear(db):
    """Delete all data, then drop every constraint and non-LOOKUP index.

    Args:
        db: Object exposing ``run(cypher)``; the value returned for the
            ``SHOW ...`` queries must offer ``data()`` yielding a list of
            dicts with a ``'name'`` key.

    Errors raised by the initial data deletion propagate to the caller.
    Errors raised while listing or dropping constraints/indexes are printed
    and swallowed, so schema cleanup is best-effort.
    """

    def quoted(name):
        # Inside a backtick-quoted Cypher identifier a literal backtick is
        # written as two backticks; without this a name containing "`"
        # would terminate the identifier early and break the statement.
        return "`" + str(name).replace("`", "``") + "`"

    # Clear all data.
    db.run("MATCH (n) DETACH DELETE n")

    # Drop constraints and indexes step by step.
    try:
        # Drop constraints.
        constraints = db.run("SHOW CONSTRAINTS YIELD name").data()
        for constr in constraints:
            db.run(f"DROP CONSTRAINT {quoted(constr['name'])}")

        # Drop indexes; the query itself filters out LOOKUP-type indexes.
        indexes = db.run(
            "SHOW INDEXES YIELD name, type WHERE type <> 'LOOKUP'"
        ).data()
        for idx in indexes:
            db.run(f"DROP INDEX {quoted(idx['name'])}")
    except Exception as e:
        # Deliberate best-effort: report the failure and carry on.
        print(f"删除操作失败: {e}")
def create_subgraph(db: Graph, nodes: List[Node], relations: List[Tuple[Node, str, Node]]) -> None:
    """Create nodes and relationships together as one subgraph.

    Args:
        db: Graph on which ``create`` is called.
        nodes: Nodes to include in the subgraph.
        relations: ``(start, rel_type, end)`` triples; each one is turned
            into a ``Relationship`` before the subgraph is built.
    """
    edges = []
    for source, label, target in relations:
        edges.append(Relationship(source, label, target))

    # Everything is handed to the graph in a single create call.
    db.create(Subgraph(nodes=nodes, relationships=edges))
def tx_create(db: Graph, nodes: List[Node], relations: List[Tuple[Node, str, Node]]) -> None:
    """Create nodes and relationships inside a single transaction.

    Args:
        db: Graph providing ``begin()``, ``commit(tx)`` and ``rollback(tx)``.
        nodes: Nodes to include in the subgraph.
        relations: ``(start, rel_type, end)`` triples; each one is turned
            into a ``Relationship`` before the subgraph is built.

    Raises:
        Exception: Any error raised while beginning the transaction,
            building the subgraph, creating it or committing is printed
            and then re-raised unchanged.
    """
    # Stays None until a transaction really exists. Without this, a failing
    # ``db.begin()`` made the handler reference an unbound ``tx`` and the
    # resulting UnboundLocalError hid the original error.
    tx = None
    try:
        tx = db.begin()
        subgraph = Subgraph(
            nodes=nodes,
            relationships=[
                Relationship(start, rel_type, end)
                for start, rel_type, end in relations
            ],
        )
        tx.create(subgraph)
        db.commit(tx)
    except Exception as e:
        # Only roll back when a transaction was actually opened.
        if tx is not None:
            db.rollback(tx)
        print(f"事务操作失败: {str(e)}")
        raise
class Neo4jExecutor:
    """Run Cypher scripts, given as text or as a file, against a Graph.

    A script is split into statements wherever a semicolon is followed by
    optional whitespace and a newline. Statements separated by ``;`` on the
    same line are therefore sent as one statement, and a final ``;`` with no
    newline after it stays attached to the last statement.
    """

    # Class-level connection defaults read from Config.
    NEO4J_URI = Config.NEO4J_URI
    NEO4J_AUTH = Config.NEO4J_AUTH

    def __init__(self, uri=None, auth=None):
        """Open the graph connection.

        Args:
            uri: Connection URI; a falsy value selects ``NEO4J_URI``.
            auth: Credentials; a falsy value selects ``NEO4J_AUTH``.
        """
        self.graph = Graph(uri or self.NEO4J_URI,
                           auth=auth or self.NEO4J_AUTH)

    @classmethod
    def create_default(cls):
        """Create an executor using the class-level default configuration."""
        return cls(cls.NEO4J_URI, cls.NEO4J_AUTH)

    def _run_statements(self, script: str, stats: dict) -> dict:
        """Split ``script`` into statements and run each one on the graph.

        ``stats`` is updated in place, so counts gathered before an
        unexpected error remain visible to the caller. A failing statement
        is counted and printed, and the remaining statements still run.
        """
        statements = [
            part.strip()
            for part in re.split(r';\s*\n', script)
            if part.strip()
        ]
        stats['total'] = len(statements)
        for stmt in statements:
            try:
                self.graph.run(stmt)
            except Exception as e:
                stats['failed'] += 1
                print(f"执行失败: {stmt[:50]}... \n错误: {str(e)[:100]}")
            else:
                stats['success'] += 1
        return stats

    def execute_cypher_text(self, cypher_text: str) -> dict:
        """Execute Cypher text directly.

        Args:
            cypher_text: Script containing one or more statements.

        Returns:
            Dict with the keys ``'total'``, ``'success'`` and ``'failed'``.
            This method reports errors by printing and does not raise.
        """
        stats = {'total': 0, 'success': 0, 'failed': 0}
        try:
            self._run_statements(cypher_text, stats)
        except Exception as e:
            # Reached when the failure is not tied to a single statement
            # (e.g. the input is not a string), so no statement variable
            # exists here; show a preview of the input text instead.
            print(f"执行失败: {str(cypher_text)[:100]}... \n完整错误: {str(e)}")
        return stats

    def execute_cypher_file(self, file_path: str) -> dict:
        """Execute the Cypher script stored in a file.

        Args:
            file_path: Path of the script; it is read as UTF-8.

        Returns:
            Dict with the keys ``'total'``, ``'success'`` and ``'failed'``.
            If the file cannot be read all counts are zero. This method
            reports errors by printing and does not raise.
        """
        stats = {'total': 0, 'success': 0, 'failed': 0}
        try:
            with open(file_path, 'r', encoding='utf-8') as f:
                cypher_script = f.read()
            self._run_statements(cypher_script, stats)
        except Exception as e:
            # Boundary handler: keeps the never-raise contract for callers.
            print(f"文件错误: {str(e)}")
        return stats