Skip to content

Commit

Permalink
Merge pull request #26 from jkminder/fix-match-iterator
Browse files Browse the repository at this point in the history
added the possibility to return iterators from the match_* functions
  • Loading branch information
jkminder authored Aug 25, 2024
2 parents de3d500 + 9e04e8f commit 70bec53
Show file tree
Hide file tree
Showing 3 changed files with 197 additions and 77 deletions.
78 changes: 1 addition & 77 deletions data2neo/neo4j/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@
from typing import List, Union

from .graph_elements import Node, Relationship, Subgraph, Attribute
from .cypher import cypher_join, _match_clause, encode_value, encode_key
from .matching import match_nodes, match_relationships

def create(graph: Subgraph, session: Session):
"""
Expand Down Expand Up @@ -47,79 +47,3 @@ def pull(graph: Subgraph, session: Session):
"""
session.execute_read(graph.__db_pull__)


def match_nodes(session: Session, *labels: List[str], **properties: dict):
"""
Matches nodes in the database.
Args:
labels (List[str]): The labels to match.
session (Session): The `session <https://neo4j.com/docs/api/python-driver/current/api.html#session>`_ to use.
properties (dict): The properties to match.
"""
flat_params = [tuple(labels),]
data = []
for k, v in properties.items():
data.append(v)
flat_params.append(k)

if len(data) > 1:
data = [data]

unwind = "UNWIND $data as r" if len(data) > 0 else ""
clause = cypher_join(unwind, _match_clause('n', tuple(flat_params), "r"), "RETURN n, LABELS(n), ID(n)", data=data)

records = session.run(*clause).data()
# Convert to Node
out = []
for record in records:
node = Node.from_dict(record['LABELS(n)'], record['n'], identity=record['ID(n)'])
out.append(node)
return out


def match_relationships(session: Session, from_node: Node =None, to_node:Node =None, rel_type: str =None, **properties: dict):
"""
Matches relationships in the database.
Args:
session (Session): The `session <https://neo4j.com/docs/api/python-driver/current/api.html#session>`_ to use.
from_node (Node): The node to match the relationship from (Default: None)
to_node (Node): The node to match the relationship to (Default: None)
rel_type (str): The type of the relationship to match (Default: None)
properties (dict): The properties to match.
"""
if from_node is not None:
assert from_node.identity is not None, "from_node must have an identity"

if to_node is not None:
assert to_node.identity is not None, "to_node must have an identity"

params = ""
for k, v in properties.items():
if params != "":
params += ", "
params += f"{encode_key(k)}: {encode_value(v)}"

clauses = []
if from_node is not None:
clauses.append(f"ID(from_node) = {from_node.identity}")
if to_node is not None:
clauses.append(f"ID(to_node) = {to_node.identity}")
if rel_type is not None:
clauses.append(f"type(r) = {encode_value(rel_type)}")

clause = cypher_join(
f"MATCH (from_node)-[r {{{params}}}]->(to_node)",
"WHERE" if len(clauses) > 0 else "",
" AND ".join(clauses),
"RETURN PROPERTIES(r), TYPE(r), ID(r), from_node, LABELS(from_node), ID(from_node), to_node, LABELS(to_node), ID(to_node)"
)
records = session.run(*clause).data()
out = []
for record in records:
fn = Node.from_dict(record['LABELS(from_node)'], record['from_node'], identity=record['ID(from_node)']) if from_node is None else from_node
tn = Node.from_dict(record['LABELS(to_node)'], record['to_node'], identity=record['ID(to_node)']) if to_node is None else to_node
rel = Relationship.from_dict(fn, tn, record['TYPE(r)'], record['PROPERTIES(r)'], identity=record['ID(r)'])
out.append(rel)
return out
120 changes: 120 additions & 0 deletions data2neo/neo4j/matching.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,120 @@
from neo4j import Session
from typing import List, Union

from .graph_elements import Node, Relationship, Subgraph, Attribute
from .cypher import cypher_join, _match_clause, encode_value, encode_key
from abc import ABC, abstractmethod

class ResultIterator(ABC):
def __init__(self, count, match):
self._count = count
self._match = match

def __len__(self):
return self._count

@abstractmethod
def __iter__(self):
pass

class NodeIterator(ResultIterator):
def __iter__(self):
for record in self._match:
node = Node.from_dict(record['LABELS(n)'], record['n'], identity=record['ID(n)'])
yield node

class RelationshipIterator(ResultIterator):
def __iter__(self):
for record in self._match:
fn = Node.from_dict(record['LABELS(from_node)'], record['from_node'], identity=record['ID(from_node)'])
tn = Node.from_dict(record['LABELS(to_node)'], record['to_node'], identity=record['ID(to_node)'])
rel = Relationship.from_dict(fn, tn, record['TYPE(r)'], record['PROPERTIES(r)'], identity=record['ID(r)'])
yield rel

def match_nodes(session: Session, *labels: List[str], return_iterator=False, **properties: dict):
"""
Matches nodes in the database.
Args:
session (Session): The `session <https://neo4j.com/docs/api/python-driver/current/api.html#session>`_ to use.
labels (List[str]): The labels to match.
return_iterator (bool): Whether to return an iterator or a list (Default: False)
properties (dict): The properties to match.
"""
flat_params = [tuple(labels),]
data = []
for k, v in properties.items():
data.append(v)
flat_params.append(k)

if len(data) > 1:
data = [data]

unwind = "UNWIND $data as r" if len(data) > 0 else ""


clause = cypher_join(unwind, _match_clause('n', tuple(flat_params), "r"), "RETURN n, LABELS(n), ID(n)", data=data)
count_clause = cypher_join(unwind, _match_clause('n', tuple(flat_params), "r"), "RETURN count(n)", data=data)

count = session.run(*count_clause).single().value()

match = session.run(*clause)
iterator = NodeIterator(count, match)

if return_iterator:
return iterator
else:
return list(iterator)

def match_relationships(session: Session, from_node: Node =None, to_node:Node =None, rel_type: str =None, return_iterator=False, **properties: dict):
"""
Matches relationships in the database.
Args:
session (Session): The `session <https://neo4j.com/docs/api/python-driver/current/api.html#session>`_ to use.
from_node (Node): The node to match the relationship from (Default: None)
to_node (Node): The node to match the relationship to (Default: None)
rel_type (str): The type of the relationship to match (Default: None)
return_iterator (bool): Whether to return an iterator or a list (Default: False)
properties (dict): The properties to match.
"""
if from_node is not None:
assert from_node.identity is not None, "from_node must have an identity"

if to_node is not None:
assert to_node.identity is not None, "to_node must have an identity"

params = ""
for k, v in properties.items():
if params != "":
params += ", "
params += f"{encode_key(k)}: {encode_value(v)}"

clauses = []
if from_node is not None:
clauses.append(f"ID(from_node) = {from_node.identity}")
if to_node is not None:
clauses.append(f"ID(to_node) = {to_node.identity}")
if rel_type is not None:
clauses.append(f"type(r) = {encode_value(rel_type)}")

clause = cypher_join(
f"MATCH (from_node)-[r {{{params}}}]->(to_node)",
"WHERE" if len(clauses) > 0 else "",
" AND ".join(clauses),
"RETURN PROPERTIES(r), TYPE(r), ID(r), from_node, LABELS(from_node), ID(from_node), to_node, LABELS(to_node), ID(to_node)"
)
count_clause = cypher_join(
f"MATCH (from_node)-[r {{{params}}}]->(to_node)",
"WHERE" if len(clauses) > 0 else "",
" AND ".join(clauses),
"RETURN count(r)"
)
count = session.run(*count_clause).single().value()

match = session.run(*clause)

if return_iterator:
return RelationshipIterator(count, match)
else:
return list(RelationshipIterator(count, match))
76 changes: 76 additions & 0 deletions tests/unit/neo4j/test_match.py
Original file line number Diff line number Diff line change
Expand Up @@ -78,6 +78,43 @@ def test_match_nodes(session):
nodes = match_nodes(session, name="test1", anotherattr="test")
assert(len(nodes) == 1)
assert(check_node(nodes, 1))

def test_match_nodes_with_iterator(session):
# match by single label
nodes = match_nodes(session, "test", return_iterator=True)
assert(len(nodes) == 2)
nodes = list(nodes)
assert(len(nodes) == 2)
assert(check_node(nodes, 1))
assert(check_node(nodes, 2))

# match by multiple labels
nodes = match_nodes(session, "test", "second", return_iterator=True)
assert(len(nodes) == 1)
nodes = list(nodes)
assert(len(nodes) == 1)
assert(check_node(nodes, 1))

# match by properties with no label
nodes = match_nodes(session, name="test3", return_iterator=True)
assert(len(nodes) == 1)
nodes = list(nodes)
assert(len(nodes) == 1)
assert(check_node(nodes, 3))

# match by properties with label
nodes = match_nodes(session, "test", name="test1", return_iterator=True)
assert(len(nodes) == 1)
nodes = list(nodes)
assert(len(nodes) == 1)
assert(check_node(nodes, 1))

# match by two properties
nodes = match_nodes(session, name="test1", anotherattr="test", return_iterator=True)
assert(len(nodes) == 1)
nodes = list(nodes)
assert(len(nodes) == 1)
assert(check_node(nodes, 1))

def test_match_relationships(session):
# match by type
Expand Down Expand Up @@ -109,3 +146,42 @@ def test_match_relationships(session):
assert(len(rels) == 1)
assert(check_rel(rels, 1))

def test_match_relationships_with_iterator(session):
# match by type
rels = match_relationships(session, rel_type="to", return_iterator=True)
assert(len(rels) == 2)
rels = list(rels)
assert(len(rels) == 2)
assert(check_rel(rels, 1))
assert(check_rel(rels, 2))

# match by properties
rels = match_relationships(session, rel_type="to", id=1, return_iterator=True)
assert(len(rels) == 1)
rels = list(rels)
assert(len(rels) == 1)
assert(check_rel(rels, 1))

# match by multiple properties
rels = match_relationships(session, rel_type="to", id=2, anotherattr="test", return_iterator=True)
assert(len(rels) == 1)
rels = list(rels)
assert(len(rels) == 1)
assert(check_rel(rels,2))

# match by from node
n1 = match_nodes(session, "test", id=1)[0]
rels = match_relationships(session, from_node=n1, return_iterator=True)
assert(len(rels) == 2)
rels = list(rels)
assert(len(rels) == 2)
assert(check_rel(rels, 1))
assert(check_rel(rels, 2))

# match by to node
n2 = match_nodes(session, "test", id=2)[0]
rels = match_relationships(session, to_node=n2, return_iterator=True)
assert(len(rels) == 1)
rels = list(rels)
assert(len(rels) == 1)
assert(check_rel(rels, 1))

0 comments on commit 70bec53

Please sign in to comment.