Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add duplicate node counter functionality and tests #508

Merged
merged 34 commits into from
Aug 1, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
34 commits
Select commit Hold shift + click to select a range
3e19358
Add node counter tests
kajalpatelinfo Jun 18, 2024
ea2402c
CI fixes
kajalpatelinfo Jun 18, 2024
b122aa9
Add comments
kajalpatelinfo Jun 18, 2024
4a52c8d
Remove unnecessary test
kajalpatelinfo Jun 18, 2024
570eda4
Add duplicate node functionality and tests
kajalpatelinfo Jun 23, 2024
d8dbe62
Remove incrementation for DictOfNamedArrays and update tests
kajalpatelinfo Jun 26, 2024
84262cc
Merge branch 'main' into duplicate_node_counts
kajalpatelinfo Jun 26, 2024
178127c
Edit tests to account for not counting DictOfNamedArrays
kajalpatelinfo Jun 26, 2024
326045e
Fix CI tests
kajalpatelinfo Jun 26, 2024
6a0a2a9
Fix comments
kajalpatelinfo Jun 26, 2024
e235f8f
Merge branch 'main' into duplicate_node_counts
kajalpatelinfo Jun 26, 2024
0dca4d7
Clarify wording and clean up
kajalpatelinfo Jun 27, 2024
d695c9f
Merge branch 'main' into duplicate_node_counts
kajalpatelinfo Jun 27, 2024
9489ecf
Move `get_node_multiplicities` to its own mapper
kajalpatelinfo Jun 27, 2024
27d6283
Add autofunction
kajalpatelinfo Jun 27, 2024
a89bf52
Merge branch 'main' into duplicate_node_counts
kajalpatelinfo Jul 4, 2024
e3a2986
Linting
kajalpatelinfo Jul 11, 2024
25c79a6
Add Dict typedef and format
kajalpatelinfo Jul 16, 2024
0b56ea4
Format further
kajalpatelinfo Jul 16, 2024
7f2e3ef
Merge branch 'main' into duplicate_node_counts
kajalpatelinfo Jul 16, 2024
6fdcfe5
Fix CI errors
kajalpatelinfo Jul 22, 2024
b4a8cb8
Merge branch 'main' into duplicate_node_counts
kajalpatelinfo Jul 22, 2024
275c609
Fix wording
kajalpatelinfo Jul 24, 2024
4ca47b2
Implement new DAG generator with guaranteed duplicates
kajalpatelinfo Jul 25, 2024
02917e8
Apply suggestions from code review
kajalpatelinfo Jul 25, 2024
2c39189
Merge branch 'main' into duplicate_node_counts
kajalpatelinfo Jul 25, 2024
7e24f46
Ruff fixes
kajalpatelinfo Jul 26, 2024
900937b
remove prints
majosm Jul 26, 2024
00436f1
Apply suggestions from code review
kajalpatelinfo Jul 30, 2024
8d8066f
Add explicit bool for count_duplicates
kajalpatelinfo Jul 31, 2024
f2b8e02
Merge branch 'main' into duplicate_node_counts
kajalpatelinfo Jul 31, 2024
59ec433
Update test/testlib.py
kajalpatelinfo Aug 1, 2024
d8469b3
Seed random
kajalpatelinfo Aug 1, 2024
572f382
Merge branch 'main' into duplicate_node_counts
majosm Aug 1, 2024
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
107 changes: 96 additions & 11 deletions pytato/analysis/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -61,6 +61,10 @@

.. autofunction:: get_num_nodes

.. autofunction:: get_node_type_counts

.. autofunction:: get_node_multiplicities

.. autofunction:: get_num_call_sites

.. autoclass:: DirectPredecessorsGetter
Expand Down Expand Up @@ -398,34 +402,115 @@ def map_named_call_result(
@optimize_mapper(drop_args=True, drop_kwargs=True, inline_get_cache_key=True)
class NodeCountMapper(CachedWalkMapper):
"""
Counts the number of nodes in a DAG.
Counts the number of nodes of a given type in a DAG.

.. attribute:: count
.. autoattribute:: expr_type_counts
.. autoattribute:: count_duplicates

The number of nodes.
Dictionary mapping node types to number of nodes of that type.
"""

def __init__(self, count_duplicates: bool = False) -> None:
inducer marked this conversation as resolved.
Show resolved Hide resolved
from collections import defaultdict
super().__init__()
self.expr_type_counts: dict[type[Any], int] = defaultdict(int)
self.count_duplicates = count_duplicates

def get_cache_key(self, expr: ArrayOrNames) -> int | ArrayOrNames:
# Returns unique nodes only if count_duplicates is False
return id(expr) if self.count_duplicates else expr

def post_visit(self, expr: Any) -> None:
if not isinstance(expr, DictOfNamedArrays):
self.expr_type_counts[type(expr)] += 1


def get_node_type_counts(
        outputs: Array | DictOfNamedArrays,
        count_duplicates: bool = False
        ) -> dict[type[Any], int]:
    """
    Count the nodes of the DAG *outputs*, grouped by node type.

    :arg outputs: the array or collection of named arrays to inspect; it is
        passed through :func:`pytato.codegen.normalize_outputs` first.
    :arg count_duplicates: forwarded to :class:`NodeCountMapper`, which
        caches visited nodes by ``id()`` when this is *True* and by the node
        itself otherwise.
    :returns: the mapper's ``expr_type_counts`` mapping from node type to the
        number of visited nodes of that type.

    Instances of `DictOfNamedArrays` are excluded from counting.
    """
    from pytato.codegen import normalize_outputs

    counter = NodeCountMapper(count_duplicates)
    counter(normalize_outputs(outputs))

    return counter.expr_type_counts


def get_num_nodes(
Copy link
Owner

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This breaks compatibility (defaulting to `count_duplicates=True` would retain the existing behavior).

Maybe deprecate not specifying the argument. (@majosm can explain)

outputs: Array | DictOfNamedArrays,
count_duplicates: bool | None = None
) -> int:
"""
Returns the number of nodes in DAG *outputs*.
Instances of `DictOfNamedArrays` are excluded from counting.
"""
if count_duplicates is None:
from warnings import warn
warn(
"The default value of 'count_duplicates' will change "
"from True to False in 2025. "
"For now, pass the desired value explicitly.",
DeprecationWarning, stacklevel=2)
count_duplicates = True

from pytato.codegen import normalize_outputs
outputs = normalize_outputs(outputs)

ncm = NodeCountMapper(count_duplicates)
ncm(outputs)

return sum(ncm.expr_type_counts.values())

# }}}


# {{{ NodeMultiplicityMapper


class NodeMultiplicityMapper(CachedWalkMapper):
"""
Computes the multiplicity of each unique node in a DAG.

The multiplicity of a node `x` is the number of nodes with distinct `id()`\\ s
that equal `x`.

.. autoattribute:: expr_multiplicity_counts
"""
def __init__(self) -> None:
from collections import defaultdict
super().__init__()
self.count = 0
self.expr_multiplicity_counts: dict[Array, int] = defaultdict(int)

def get_cache_key(self, expr: ArrayOrNames) -> int:
# Returns each node, including nodes that are duplicates
return id(expr)

def post_visit(self, expr: Any) -> None:
self.count += 1
if not isinstance(expr, DictOfNamedArrays):
self.expr_multiplicity_counts[expr] += 1


def get_num_nodes(outputs: Array | DictOfNamedArrays) -> int:
"""Returns the number of nodes in DAG *outputs*."""

def get_node_multiplicities(
outputs: Array | DictOfNamedArrays) -> dict[Array, int]:
"""
Returns the multiplicity per `expr`.
"""
from pytato.codegen import normalize_outputs
outputs = normalize_outputs(outputs)

ncm = NodeCountMapper()
ncm(outputs)
nmm = NodeMultiplicityMapper()
nmm(outputs)

return ncm.count
return nmm.expr_multiplicity_counts

# }}}

Expand Down
10 changes: 6 additions & 4 deletions pytato/distributed/verify.py
Original file line number Diff line number Diff line change
Expand Up @@ -194,12 +194,14 @@ def _run_partition_diagnostics(

from pytato.analysis import get_num_nodes
num_nodes_per_part = [get_num_nodes(make_dict_of_named_arrays(
{x: gp.name_to_output[x] for x in part.output_names}))
{x: gp.name_to_output[x] for x in part.output_names}),
count_duplicates=False)
for part in gp.parts.values()]

logger.info(f"find_distributed_partition: Split {get_num_nodes(outputs)} nodes "
f"into {len(gp.parts)} parts, with {num_nodes_per_part} nodes in each "
"partition.")
logger.info("find_distributed_partition: "
f"Split {get_num_nodes(outputs, count_duplicates=False)} nodes "
f"into {len(gp.parts)} parts, with {num_nodes_per_part} nodes in each "
"partition.")

# }}}

Expand Down
7 changes: 4 additions & 3 deletions test/test_codegen.py
Original file line number Diff line number Diff line change
Expand Up @@ -1611,7 +1611,7 @@ def get_np_input_args():
_, (pt_result,) = knl(cq)

from pytato.analysis import get_num_nodes
print(get_num_nodes(pt_dag))
print(get_num_nodes(pt_dag, count_duplicates=False))

np.testing.assert_allclose(pt_result, np_result)

Expand All @@ -1637,8 +1637,9 @@ def test_zero_size_cl_array_dedup(ctx_factory):

dedup_dw_out = pt.transform.deduplicate_data_wrappers(out)

num_nodes_old = pt.analysis.get_num_nodes(out)
num_nodes_new = pt.analysis.get_num_nodes(dedup_dw_out)
num_nodes_old = pt.analysis.get_num_nodes(out, count_duplicates=True)
num_nodes_new = pt.analysis.get_num_nodes(
dedup_dw_out, count_duplicates=True)
# 'x2' would be merged with 'x1' as both of them point to the same data
# 'x3' would be merged with 'x4' as both of them point to the same data
assert num_nodes_new == (num_nodes_old - 2)
Expand Down
167 changes: 159 additions & 8 deletions test/test_pytato.py
Original file line number Diff line number Diff line change
Expand Up @@ -598,20 +598,171 @@ def test_repr_array_is_deterministic():
assert repr(dag) == repr(dag)


def test_nodecountmapper():
from testlib import RandomDAGContext, make_random_dag
def test_empty_dag_count():
    """An empty collection of named arrays has no nodes of any type."""
    from pytato.analysis import get_node_type_counts, get_num_nodes

    no_outputs = pt.make_dict_of_named_arrays({})

    # With nothing to visit, the total node count must be zero ...
    total = get_num_nodes(no_outputs, count_duplicates=False)
    assert total == 0

    # ... and the per-type mapping must not contain any entry.
    per_type = get_node_type_counts(no_outputs)
    assert len(per_type) == 0


def test_single_node_dag_count():
    """A DAG holding one data wrapper reports exactly one ``DataWrapper``."""
    from pytato.analysis import get_node_type_counts, get_num_nodes

    wrapped = pt.make_data_wrapper(np.random.rand(4, 4))
    dag = pt.make_dict_of_named_arrays({"result": wrapped})

    # Per-type view: the single DataWrapper is the only entry.
    per_type = get_node_type_counts(dag)
    assert per_type == {pt.DataWrapper: 1}

    # Aggregate view: one node in total.
    assert get_num_nodes(dag, count_duplicates=False) == 1


def test_small_dag_count():
    """``a + 1`` is counted as one placeholder plus one index lambda."""
    from pytato.analysis import get_node_type_counts, get_num_nodes

    # Two-node DAG: placeholder "a" feeding a single arithmetic operation.
    placeholder = pt.make_placeholder(name="a", shape=(2, 2), dtype=np.float64)
    dag = pt.make_dict_of_named_arrays({"result": placeholder + 1})

    assert get_num_nodes(dag, count_duplicates=False) == 2

    per_type = get_node_type_counts(dag)
    # Check the number of distinct types before indexing into the mapping.
    assert len(per_type) == 2
    assert per_type[pt.array.Placeholder] == 1  # "a"
    assert per_type[pt.array.IndexLambda] == 1  # the addition


def test_large_dag_count():
    """The DAG from ``make_large_dag`` has one node per iteration plus one."""
    from testlib import make_large_dag

    from pytato.analysis import get_node_type_counts, get_num_nodes

    iterations = 100
    dag = make_large_dag(iterations, seed=42)

    # One node per iteration, plus the placeholder.
    expected_total = iterations + 1
    assert get_num_nodes(dag, count_duplicates=False) == expected_total

    per_type = get_node_type_counts(dag)
    assert len(per_type) >= 1
    assert per_type[pt.array.Placeholder] == 1
    assert per_type[pt.array.IndexLambda] == 100  # 100 operations
    assert sum(per_type.values()) == expected_total


def test_random_dag_count():
from testlib import get_random_pt_dag

from pytato.analysis import get_num_nodes
for i in range(80):
dag = get_random_pt_dag(seed=i, axis_len=5)

axis_len = 5
assert get_num_nodes(dag, count_duplicates=False) == len(
pt.transform.DependencyMapper()(dag))


def test_random_dag_with_comm_count():
from testlib import get_random_pt_dag_with_send_recv_nodes

from pytato.analysis import get_num_nodes
rank = 0
size = 2
for i in range(10):
rdagc = RandomDAGContext(np.random.default_rng(seed=i),
axis_len=axis_len, use_numpy=False)
dag = make_random_dag(rdagc)
dag = get_random_pt_dag_with_send_recv_nodes(
seed=i, rank=rank, size=size)

assert get_num_nodes(dag, count_duplicates=False) == len(
pt.transform.DependencyMapper()(dag))


def test_small_dag_with_duplicates_count():
    """Counts with and without duplicates agree once multiplicities are removed."""
    from testlib import make_small_dag_with_duplicates

    from pytato.analysis import (
        get_node_multiplicities,
        get_node_type_counts,
        get_num_nodes,
    )

    dag = make_small_dag_with_duplicates()

    # Total number of expressions when duplicates are counted individually.
    total_with_dups = get_num_nodes(dag, count_duplicates=True)
    assert total_with_dups == 4

    # At least one unique expression must occur more than once.
    multiplicities = get_node_multiplicities(dag)
    assert any(m > 1 for m in multiplicities.values())

    # Number of surplus copies, i.e. occurrences beyond the first of each node.
    surplus = sum(m - 1 for m in multiplicities.values())

    per_type = get_node_type_counts(dag, count_duplicates=True)
    for node_type, expected in (
            (pt.array.Placeholder, 1),
            (pt.array.IndexLambda, 3)):
        assert per_type[node_type] == expected

    # Removing the surplus copies must give the number of unique nodes.
    unique_total = total_with_dups - surplus
    assert unique_total == len(pt.transform.DependencyMapper()(dag))
    assert unique_total == get_num_nodes(dag, count_duplicates=False)


def test_large_dag_with_duplicates_count():
from testlib import make_large_dag_with_duplicates

from pytato.analysis import (
get_node_multiplicities,
get_node_type_counts,
get_num_nodes,
)

iterations = 100
dag = make_large_dag_with_duplicates(iterations, seed=42)

# Get the number of expressions, including duplicates
node_count = get_num_nodes(dag, count_duplicates=True)

# Get the number of occurrences of each unique expression
node_multiplicity = get_node_multiplicities(dag)
assert any(count > 1 for count in node_multiplicity.values())

expected_node_count = sum(count for count in node_multiplicity.values())
assert node_count == expected_node_count

# Get difference in duplicates
num_duplicates = sum(count - 1 for count in node_multiplicity.values())

counts = get_node_type_counts(dag, count_duplicates=True)

assert counts[pt.array.Placeholder] == 1
assert sum(counts.values()) == expected_node_count

# Subtract 1 since NodeCountMapper counts an extra one for DictOfNamedArrays.
assert get_num_nodes(dag)-1 == len(pt.transform.DependencyMapper()(dag))
# Check that duplicates are correctly calculated
assert node_count - num_duplicates == len(
pt.transform.DependencyMapper()(dag))
assert node_count - num_duplicates == get_num_nodes(
dag, count_duplicates=False)


def test_rec_get_user_nodes():
Expand Down
Loading
Loading