Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add node type counter and tests #507

Closed
wants to merge 10 commits into from
47 changes: 37 additions & 10 deletions pytato/analysis/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -26,7 +26,7 @@
"""

from typing import (Mapping, Dict, Union, Set, Tuple, Any, FrozenSet,
TYPE_CHECKING)
Type, TYPE_CHECKING)
from pytato.array import (Array, IndexLambda, Stack, Concatenate, Einsum,
DictOfNamedArrays, NamedArray,
IndexBase, IndexRemappingBase, InputArgumentBase,
Expand All @@ -49,6 +49,8 @@

.. autofunction:: get_num_nodes

.. autofunction:: get_node_type_counts

.. autofunction:: get_num_call_sites

.. autoclass:: DirectPredecessorsGetter
Expand Down Expand Up @@ -381,34 +383,59 @@ def map_named_call_result(self, expr: NamedCallResult) -> FrozenSet[Array]:
@optimize_mapper(drop_args=True, drop_kwargs=True, inline_get_cache_key=True)
class NodeCountMapper(CachedWalkMapper):
"""
Counts the number of nodes in a DAG.
Counts the number of nodes of a given type in a DAG.

.. attribute:: count
.. attribute:: counts

The number of nodes.
Dictionary mapping node types to number of nodes of that type.
"""

def __init__(self) -> None:
from collections import defaultdict
super().__init__()
self.count = 0
self.counts = defaultdict(int) # type: Dict[Type[Any], int]

def get_cache_key(self, expr: ArrayOrNames) -> int:
return id(expr)
def get_cache_key(self, expr: ArrayOrNames) -> ArrayOrNames:
# does NOT account for duplicate nodes
return expr

def post_visit(self, expr: Any) -> None:
self.count += 1
if not isinstance(expr, DictOfNamedArrays):
self.counts[type(expr)] += 1


def get_node_type_counts(outputs: Union[Array, DictOfNamedArrays]
) -> Dict[Type[Any], int]:
"""
Returns a dictionary mapping node types to node count for that type
in DAG *outputs*.

Instances of `DictOfNamedArrays` are excluded from counting.
"""

from pytato.codegen import normalize_outputs
outputs = normalize_outputs(outputs)
kajalpatelinfo marked this conversation as resolved.
Show resolved Hide resolved

ncm = NodeCountMapper()
ncm(outputs)

return ncm.counts


def get_num_nodes(outputs: Union[Array, DictOfNamedArrays]) -> int:
"""Returns the number of nodes in DAG *outputs*."""
"""
Returns the number of nodes in DAG *outputs*.

Instances of `DictOfNamedArrays` are excluded from counting.
"""

from pytato.codegen import normalize_outputs
outputs = normalize_outputs(outputs)

ncm = NodeCountMapper()
ncm(outputs)

return ncm.count
return sum(ncm.counts.values())

# }}}

Expand Down
88 changes: 79 additions & 9 deletions test/test_pytato.py
Original file line number Diff line number Diff line change
Expand Up @@ -26,7 +26,6 @@
"""

import sys

import numpy as np
import pytest
import attrs
Expand Down Expand Up @@ -585,19 +584,90 @@ def test_repr_array_is_deterministic():
assert repr(dag) == repr(dag)


def test_nodecountmapper():
from testlib import RandomDAGContext, make_random_dag
def test_empty_dag_count():
from pytato.analysis import get_num_nodes, get_node_type_counts

empty_dag = pt.make_dict_of_named_arrays({})

# Verify that get_num_nodes returns 0 for an empty DAG
assert get_num_nodes(empty_dag) == 0

counts = get_node_type_counts(empty_dag)
assert len(counts) == 0


def test_single_node_dag_count():
from pytato.analysis import get_num_nodes, get_node_type_counts

data = np.random.rand(4, 4)
single_node_dag = pt.make_dict_of_named_arrays(
{"result": pt.make_data_wrapper(data)})

# Get counts per node type
node_counts = get_node_type_counts(single_node_dag)

# Assert that there is only one node of type DataWrapper
assert node_counts == {pt.DataWrapper: 1}

# Get total number of nodes
total_nodes = get_num_nodes(single_node_dag)

assert total_nodes == 1


def test_small_dag_count():
from pytato.analysis import get_num_nodes, get_node_type_counts

# Make a DAG using two nodes and one operation
a = pt.make_placeholder(name="a", shape=(2, 2), dtype=np.float64)
b = a + 1
dag = pt.make_dict_of_named_arrays({"result": b}) # b = a + 1

# Verify that get_num_nodes returns 2 for a DAG with two nodes
assert get_num_nodes(dag) == 2

counts = get_node_type_counts(dag)
assert len(counts) == 2
assert counts[pt.array.Placeholder] == 1 # "a"
assert counts[pt.array.IndexLambda] == 1 # single operation


def test_large_dag_count():
from pytato.analysis import get_num_nodes, get_node_type_counts
from testlib import make_large_dag

iterations = 100
dag = make_large_dag(iterations, seed=42)

# Verify that the number of nodes is equal to iterations + 1 (placeholder)
assert get_num_nodes(dag) == iterations + 1

counts = get_node_type_counts(dag)
assert len(counts) >= 1
assert counts[pt.array.Placeholder] == 1
assert counts[pt.array.IndexLambda] == 100 # 100 operations
assert sum(counts.values()) == iterations + 1


def test_random_dag_count():
from testlib import get_random_pt_dag
from pytato.analysis import get_num_nodes
for i in range(80):
dag = get_random_pt_dag(seed=i, axis_len=5)

axis_len = 5
assert get_num_nodes(dag) == len(pt.transform.DependencyMapper()(dag))


def test_random_dag_with_comm_count():
from testlib import get_random_pt_dag_with_send_recv_nodes
from pytato.analysis import get_num_nodes
rank = 0
size = 2
for i in range(10):
rdagc = RandomDAGContext(np.random.default_rng(seed=i),
axis_len=axis_len, use_numpy=False)
dag = make_random_dag(rdagc)
dag = get_random_pt_dag_with_send_recv_nodes(
seed=i, rank=rank, size=size)

# Subtract 1 since NodeCountMapper counts an extra one for DictOfNamedArrays.
assert get_num_nodes(dag)-1 == len(pt.transform.DependencyMapper()(dag))
assert get_num_nodes(dag) == len(pt.transform.DependencyMapper()(dag))


def test_rec_get_user_nodes():
Expand Down
26 changes: 26 additions & 0 deletions test/testlib.py
Original file line number Diff line number Diff line change
Expand Up @@ -311,6 +311,32 @@ def gen_comm(rdagc: RandomDAGContext) -> pt.Array:
convert_dws_to_placeholders=convert_dws_to_placeholders,
additional_generators=[(comm_fake_probability, gen_comm)])


def make_large_dag(iterations: int, seed: int = 0) -> pt.DictOfNamedArrays:
"""
Builds a DAG with emphasis on number of operations.
"""
import random
import operator

rng = np.random.default_rng(seed)
random.seed(seed)

# Begin with a placeholder
a = pt.make_placeholder(name="a", shape=(2, 2), dtype=np.float64)
current = a

# Will randomly choose from the operators
operations = [operator.add, operator.sub, operator.mul, operator.truediv]

for _ in range(iterations):
operation = random.choice(operations)
value = rng.uniform(1, 10)
current = operation(current, value)

# DAG should have `iterations` number of operations
return pt.make_dict_of_named_arrays({"result": current})

# }}}


Expand Down
Loading