From b055f15c08589f5ffa1298e749bcb61fae8205e1 Mon Sep 17 00:00:00 2001 From: Julian Minder Date: Mon, 29 Jul 2024 09:51:59 +0000 Subject: [PATCH 1/2] added the possibility to return iterators from the match_* functions --- data2neo/neo4j/__init__.py | 78 +-------------------- data2neo/neo4j/matching.py | 120 +++++++++++++++++++++++++++++++++ tests/unit/neo4j/test_match.py | 76 +++++++++++++++++++++ 3 files changed, 197 insertions(+), 77 deletions(-) create mode 100644 data2neo/neo4j/matching.py diff --git a/data2neo/neo4j/__init__.py b/data2neo/neo4j/__init__.py index 1cabe1e..6380a5c 100644 --- a/data2neo/neo4j/__init__.py +++ b/data2neo/neo4j/__init__.py @@ -2,7 +2,7 @@ from typing import List, Union from .graph_elements import Node, Relationship, Subgraph, Attribute -from .cypher import cypher_join, _match_clause, encode_value, encode_key +from .matching import match_nodes, match_relationships def create(graph: Subgraph, session: Session): """ @@ -47,79 +47,3 @@ def pull(graph: Subgraph, session: Session): """ session.execute_read(graph.__db_pull__) - -def match_nodes(session: Session, *labels: List[str], **properties: dict): - """ - Matches nodes in the database. - - Args: - labels (List[str]): The labels to match. - session (Session): The `session `_ to use. - properties (dict): The properties to match. - """ - flat_params = [tuple(labels),] - data = [] - for k, v in properties.items(): - data.append(v) - flat_params.append(k) - - if len(data) > 1: - data = [data] - - unwind = "UNWIND $data as r" if len(data) > 0 else "" - clause = cypher_join(unwind, _match_clause('n', tuple(flat_params), "r"), "RETURN n, LABELS(n), ID(n)", data=data) - - records = session.run(*clause).data() - # Convert to Node - out = [] - for record in records: - node = Node.from_dict(record['LABELS(n)'], record['n'], identity=record['ID(n)']) - out.append(node) - return out - - -def match_relationships(session: Session, from_node: Node =None, to_node:Node =None, rel_type: str =None, **properties: dict): - """ - Matches relationships in the database. - - Args: - session (Session): The `session `_ to use. - from_node (Node): The node to match the relationship from (Default: None) - to_node (Node): The node to match the relationship to (Default: None) - rel_type (str): The type of the relationship to match (Default: None) - properties (dict): The properties to match. - """ - if from_node is not None: - assert from_node.identity is not None, "from_node must have an identity" - - if to_node is not None: - assert to_node.identity is not None, "to_node must have an identity" - - params = "" - for k, v in properties.items(): - if params != "": - params += ", " - params += f"{encode_key(k)}: {encode_value(v)}" - - clauses = [] - if from_node is not None: - clauses.append(f"ID(from_node) = {from_node.identity}") - if to_node is not None: - clauses.append(f"ID(to_node) = {to_node.identity}") - if rel_type is not None: - clauses.append(f"type(r) = {encode_value(rel_type)}") - - clause = cypher_join( - f"MATCH (from_node)-[r {{{params}}}]->(to_node)", - "WHERE" if len(clauses) > 0 else "", - " AND ".join(clauses), - "RETURN PROPERTIES(r), TYPE(r), ID(r), from_node, LABELS(from_node), ID(from_node), to_node, LABELS(to_node), ID(to_node)" - ) - records = session.run(*clause).data() - out = [] - for record in records: - fn = Node.from_dict(record['LABELS(from_node)'], record['from_node'], identity=record['ID(from_node)']) if from_node is None else from_node - tn = Node.from_dict(record['LABELS(to_node)'], record['to_node'], identity=record['ID(to_node)']) if to_node is None else to_node - rel = Relationship.from_dict(fn, tn, record['TYPE(r)'], record['PROPERTIES(r)'], identity=record['ID(r)']) - out.append(rel) - return out \ No newline at end of file diff --git a/data2neo/neo4j/matching.py b/data2neo/neo4j/matching.py new file mode 100644 index 0000000..15bca34 --- /dev/null +++ b/data2neo/neo4j/matching.py @@ -0,0 +1,120 @@ +from neo4j import Session +from typing import List, Union + +from .graph_elements import Node, Relationship, Subgraph, Attribute +from .cypher import cypher_join, _match_clause, encode_value, encode_key +from abc import ABC, abstractmethod + +class ResultIterator(ABC): + def __init__(self, count, match): + self._count = count + self._match = match + + def __len__(self): + return self._count + + @abstractmethod + def __iter__(self): + pass + +class NodeIterator(ResultIterator): + def __iter__(self): + for record in self._match: + node = Node.from_dict(record['LABELS(n)'], record['n'], identity=record['ID(n)']) + yield node + +class RelationshipIterator(ResultIterator): + def __iter__(self): + for record in self._match: + fn = Node.from_dict(record['LABELS(from_node)'], record['from_node'], identity=record['ID(from_node)']) + tn = Node.from_dict(record['LABELS(to_node)'], record['to_node'], identity=record['ID(to_node)']) + rel = Relationship.from_dict(fn, tn, record['TYPE(r)'], record['PROPERTIES(r)'], identity=record['ID(r)']) + yield rel + +def match_nodes(session: Session, *labels: List[str], return_iterator=False, **properties: dict): + """ + Matches nodes in the database. + + Args: + session (Session): The `session `_ to use. + labels (List[str]): The labels to match. + return_iterator (bool): Whether to return an iterator or a list (Default: False) + properties (dict): The properties to match. + """ + flat_params = [tuple(labels),] + data = [] + for k, v in properties.items(): + data.append(v) + flat_params.append(k) + + if len(data) > 1: + data = [data] + + unwind = "UNWIND $data as r" if len(data) > 0 else "" + + + clause = cypher_join(unwind, _match_clause('n', tuple(flat_params), "r"), "RETURN n, LABELS(n), ID(n)", data=data) + count_clause = cypher_join(unwind, _match_clause('n', tuple(flat_params), "r"), "RETURN count(n)", data=data) + + count = session.run(*count_clause).single().value() + + match = session.run(*clause) + iterator = NodeIterator(count, match) + + if return_iterator: + return iterator + else: + return list(iterator) + +def match_relationships(session: Session, from_node: Node =None, to_node:Node =None, rel_type: str =None, return_iterator=False, **properties: dict): + """ + Matches relationships in the database. + + Args: + session (Session): The `session `_ to use. + from_node (Node): The node to match the relationship from (Default: None) + to_node (Node): The node to match the relationship to (Default: None) + rel_type (str): The type of the relationship to match (Default: None) + return_iterator (bool): Whether to return an iterator or a list (Default: False) + properties (dict): The properties to match. + """ + if from_node is not None: + assert from_node.identity is not None, "from_node must have an identity" + + if to_node is not None: + assert to_node.identity is not None, "to_node must have an identity" + + params = "" + for k, v in properties.items(): + if params != "": + params += ", " + params += f"{encode_key(k)}: {encode_value(v)}" + + clauses = [] + if from_node is not None: + clauses.append(f"ID(from_node) = {from_node.identity}") + if to_node is not None: + clauses.append(f"ID(to_node) = {to_node.identity}") + if rel_type is not None: + clauses.append(f"type(r) = {encode_value(rel_type)}") + + clause = cypher_join( + f"MATCH (from_node)-[r {{{params}}}]->(to_node)", + "WHERE" if len(clauses) > 0 else "", + " AND ".join(clauses), + "RETURN PROPERTIES(r), TYPE(r), ID(r), from_node, LABELS(from_node), ID(from_node), to_node, LABELS(to_node), ID(to_node)" + ) + count_clause = cypher_join( + f"MATCH (from_node)-[r {{{params}}}]->(to_node)", + "WHERE" if len(clauses) > 0 else "", + " AND ".join(clauses), + "RETURN count(r)" + ) + count = session.run(*count_clause).single().value() + + match = session.run(*clause) + + if return_iterator: + return RelationshipIterator(count, match) + else: + return list(RelationshipIterator(count, match)) \ No newline at end of file diff --git a/tests/unit/neo4j/test_match.py b/tests/unit/neo4j/test_match.py index 1c79322..62eaf9c 100644 --- a/tests/unit/neo4j/test_match.py +++ b/tests/unit/neo4j/test_match.py @@ -78,6 +78,43 @@ def test_match_nodes(session): nodes = match_nodes(session, name="test1", anotherattr="test") assert(len(nodes) == 1) assert(check_node(nodes, 1)) + +def test_match_nodes_with_iterator(session): + # match by single label + nodes = match_nodes(session, "test", return_iterator=True) + assert(len(nodes) == 2) + nodes = list(nodes) + assert(len(nodes) == 2) + assert(check_node(nodes, 1)) + assert(check_node(nodes, 2)) + + # match by multiple labels + nodes = match_nodes(session, "test", "second", return_iterator=True) + assert(len(nodes) == 1) + nodes = list(nodes) + assert(len(nodes) == 1) + assert(check_node(nodes, 1)) + + # match by properties with no label + nodes = match_nodes(session, name="test3", return_iterator=True) + assert(len(nodes) == 1) + nodes = list(nodes) + assert(len(nodes) == 1) + assert(check_node(nodes, 3)) + + # match by properties with label + nodes = match_nodes(session, "test", name="test1", return_iterator=True) + assert(len(nodes) == 1) + nodes = list(nodes) + assert(len(nodes) == 1) + assert(check_node(nodes, 1)) + + # match by two properties + nodes = match_nodes(session, name="test1", anotherattr="test", return_iterator=True) + assert(len(nodes) == 1) + nodes = list(nodes) + assert(len(nodes) == 1) + assert(check_node(nodes, 1)) def test_match_relationships(session): # match by type @@ -109,3 +146,42 @@ def test_match_relationships(session): assert(len(rels) == 1) assert(check_rel(rels, 1)) +def test_match_relationships_with_iterator(session): + # match by type + rels = match_relationships(session, rel_type="to", return_iterator=True) + assert(len(rels) == 2) + assert(len(rels) == 2) + assert(check_rel(rels, 1)) + rels = list(rels) + assert(check_rel(rels, 2)) + + # match by properties + rels = match_relationships(session, rel_type="to", id=1, return_iterator=True) + assert(len(rels) == 1) + rels = list(rels) + assert(len(rels) == 1) + assert(check_rel(rels, 1)) + + # match by multiple properties + rels = match_relationships(session, rel_type="to", id=2, anotherattr="test", return_iterator=True) + assert(len(rels) == 1) + rels = list(rels) + assert(len(rels) == 1) + assert(check_rel(rels,2)) + + # match by from node + n1 = match_nodes(session, "test", id=1)[0] + rels = match_relationships(session, from_node=n1, return_iterator=True) + assert(len(rels) == 2) + rels = list(rels) + assert(len(rels) == 2) + assert(check_rel(rels, 1)) + assert(check_rel(rels, 2)) + + # match by to node + n2 = match_nodes(session, "test", id=2)[0] + rels = match_relationships(session, to_node=n2, return_iterator=True) + assert(len(rels) == 1) + rels = list(rels) + assert(len(rels) == 1) + assert(check_rel(rels, 1)) From 9e04e8f0bc5c13b9fbf292f9e27993ceb38f7419 Mon Sep 17 00:00:00 2001 From: Julian Minder Date: Mon, 29 Jul 2024 10:41:46 +0000 Subject: [PATCH 2/2] fixed bug in test --- tests/unit/neo4j/test_match.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/tests/unit/neo4j/test_match.py b/tests/unit/neo4j/test_match.py index 62eaf9c..7bad686 100644 --- a/tests/unit/neo4j/test_match.py +++ b/tests/unit/neo4j/test_match.py @@ -150,9 +150,9 @@ def test_match_relationships_with_iterator(session): # match by type rels = match_relationships(session, rel_type="to", return_iterator=True) assert(len(rels) == 2) + rels = list(rels) assert(len(rels) == 2) assert(check_rel(rels, 1)) - rels = list(rels) assert(check_rel(rels, 2)) # match by properties