From 3e19358d7cf2f57a45bccd6759a61ca72890eab7 Mon Sep 17 00:00:00 2001 From: kajalpatelinfo Date: Mon, 17 Jun 2024 19:08:57 -0600 Subject: [PATCH 1/8] Add node counter tests --- pytato/analysis/__init__.py | 35 +++++++++++---- test/test_pytato.py | 90 ++++++++++++++++++++++++++++++++++++- test/testlib.py | 26 ++++++++++- 3 files changed, 139 insertions(+), 12 deletions(-) diff --git a/pytato/analysis/__init__.py b/pytato/analysis/__init__.py index 5bf374746..41c21aa9d 100644 --- a/pytato/analysis/__init__.py +++ b/pytato/analysis/__init__.py @@ -26,7 +26,7 @@ """ from typing import (Mapping, Dict, Union, Set, Tuple, Any, FrozenSet, - TYPE_CHECKING) + Type, TYPE_CHECKING) from pytato.array import (Array, IndexLambda, Stack, Concatenate, Einsum, DictOfNamedArrays, NamedArray, IndexBase, IndexRemappingBase, InputArgumentBase, @@ -49,6 +49,8 @@ .. autofunction:: get_num_nodes +.. autofunction:: get_node_type_counts + .. autofunction:: get_num_call_sites .. autoclass:: DirectPredecessorsGetter @@ -381,23 +383,38 @@ def map_named_call_result(self, expr: NamedCallResult) -> FrozenSet[Array]: @optimize_mapper(drop_args=True, drop_kwargs=True, inline_get_cache_key=True) class NodeCountMapper(CachedWalkMapper): """ - Counts the number of nodes in a DAG. + Counts the number of nodes of a given type in a DAG. - .. attribute:: count + .. attribute:: counts - The number of nodes. + Dictionary mapping node types to number of nodes of that type. """ def __init__(self) -> None: + from collections import defaultdict super().__init__() - self.count = 0 + self.counts = defaultdict(int) def get_cache_key(self, expr: ArrayOrNames) -> int: - return id(expr) + # does NOT account for duplicate nodes + return expr def post_visit(self, expr: Any) -> None: - self.count += 1 + self.counts[type(expr)] += 1 + +def get_node_type_counts(outputs: Union[Array, DictOfNamedArrays]) -> Dict[Type, int]: + """ + Returns a dictionary mapping node types to node count for that type + in DAG *outputs*. + """ + + from pytato.codegen import normalize_outputs + outputs = normalize_outputs(outputs) + + ncm = NodeCountMapper() + ncm(outputs) + return ncm.counts def get_num_nodes(outputs: Union[Array, DictOfNamedArrays]) -> int: """Returns the number of nodes in DAG *outputs*.""" @@ -408,7 +425,7 @@ def get_num_nodes(outputs: Union[Array, DictOfNamedArrays]) -> int: ncm = NodeCountMapper() ncm(outputs) - return ncm.count + return sum(ncm.counts.values()) # }}} @@ -463,4 +480,4 @@ def get_num_call_sites(outputs: Union[Array, DictOfNamedArrays]) -> int: # }}} -# vim: fdm=marker +# vim: fdm=marker \ No newline at end of file diff --git a/test/test_pytato.py b/test/test_pytato.py index 8939073cb..cea10480c 100644 --- a/test/test_pytato.py +++ b/test/test_pytato.py @@ -26,7 +26,6 @@ """ import sys - import numpy as np import pytest import attrs @@ -585,7 +584,7 @@ def test_repr_array_is_deterministic(): assert repr(dag) == repr(dag) -def test_nodecountmapper(): +def test_node_count_mapper(): from testlib import RandomDAGContext, make_random_dag from pytato.analysis import get_num_nodes @@ -600,6 +599,93 @@ def test_nodecountmapper(): assert get_num_nodes(dag)-1 == len(pt.transform.DependencyMapper()(dag)) +def test_empty_dag_count(): + from pytato.analysis import get_num_nodes, get_node_type_counts + + empty_dag = pt.make_dict_of_named_arrays({}) + + # Verify that get_num_nodes returns 0 for an empty DAG + assert get_num_nodes(empty_dag) - 1 == 0 + + counts = get_node_type_counts(empty_dag) + assert len(counts) == 1 + +def test_single_node_dag_count(): + from pytato.analysis import get_num_nodes, get_node_type_counts + + data = np.random.rand(4, 4) + single_node_dag = pt.make_dict_of_named_arrays({"result": pt.make_data_wrapper(data)}) + + # Get counts per node type + node_counts = get_node_type_counts(single_node_dag) + + # Assert that there is only one node of type DataWrapper and one node of DictOfNamedArrays + # DictOfNamedArrays is automatically added + assert node_counts == {pt.DataWrapper: 1, pt.DictOfNamedArrays: 1} + assert sum(node_counts.values()) - 1 == 1 # Total node count is 1 + + # Get total number of nodes + total_nodes = get_num_nodes(single_node_dag) + + assert total_nodes - 1 == 1 + + +def test_small_dag_count(): + from pytato.analysis import get_num_nodes, get_node_type_counts + + # Make a DAG using two nodes and one operation + a = pt.make_placeholder(name="a", shape=(2, 2), dtype=np.float64) + b = a + 1 + dag = pt.make_dict_of_named_arrays({"result": b}) # b = a + 1 + + # Verify that get_num_nodes returns 2 for a DAG with two nodes + assert get_num_nodes(dag) - 1 == 2 + + counts = get_node_type_counts(dag) + assert len(counts) - 1 == 2 + assert counts[pt.array.Placeholder] == 1 # "a" + assert counts[pt.array.IndexLambda] == 1 # single operation + + +def test_large_dag_count(): + from pytato.analysis import get_num_nodes, get_node_type_counts + from testlib import make_large_dag + + iterations = 100 + dag = make_large_dag(iterations, seed=42) + + # Verify that the number of nodes is equal to iterations + 1 (placeholder) + assert get_num_nodes(dag) - 1 == iterations + 1 + + # Verify that the counts dictionary has correct counts for the complicated DAG + counts = get_node_type_counts(dag) + assert len(counts) >= 1 + assert counts[pt.array.Placeholder] == 1 + assert counts[pt.array.IndexLambda] == 100 # 100 operations + assert sum(counts.values()) - 1 == iterations + 1 + + +def test_random_dag_count(): + from testlib import get_random_pt_dag + from pytato.analysis import get_num_nodes + for i in range(80): + dag = get_random_pt_dag(seed=i, axis_len=5) + + # Subtract 1 since NodeCountMapper counts an extra one for DictOfNamedArrays. + assert get_num_nodes(dag) - 1 == len(pt.transform.DependencyMapper()(dag)) + +def test_random_dag_with_comm_count(): + from testlib import get_random_pt_dag_with_send_recv_nodes + from pytato.analysis import get_num_nodes + rank = 0 + size = 2 + for i in range(10): + dag = get_random_pt_dag_with_send_recv_nodes(seed=i, rank=rank, size=size) + + # Subtract 1 since NodeCountMapper counts an extra one for DictOfNamedArrays. + assert get_num_nodes(dag) - 1 == len(pt.transform.DependencyMapper()(dag)) + + def test_rec_get_user_nodes(): x1 = pt.make_placeholder("x1", shape=(10, 4)) x2 = pt.make_placeholder("x2", shape=(10, 4)) diff --git a/test/testlib.py b/test/testlib.py index 5cd1342d3..cdf827e96 100644 --- a/test/testlib.py +++ b/test/testlib.py @@ -311,6 +311,30 @@ def gen_comm(rdagc: RandomDAGContext) -> pt.Array: convert_dws_to_placeholders=convert_dws_to_placeholders, additional_generators=[(comm_fake_probability, gen_comm)]) +def make_large_dag(iterations: int, seed: int = 0) -> pt.DictOfNamedArrays: + """ + Builds a DAG with emphasis on number of operations. + """ + import random + import operator + + rng = np.random.default_rng(seed) + random.seed(seed) + + # Begin with a placeholder + a = pt.make_placeholder(name="a", shape=(2, 2), dtype=np.float64) + current = a + + # Will randomly choose from the operators + operations = [operator.add, operator.sub, operator.mul, operator.truediv] + + for _ in range(iterations): + operation = random.choice(operations) + value = rng.uniform(1, 10) + current = operation(current, value) + + return pt.make_dict_of_named_arrays({"result": current}) + # }}} @@ -369,4 +393,4 @@ class QuuxTag(TestlibTag): # }}} -# vim: foldmethod=marker +# vim: foldmethod=marker \ No newline at end of file From ea2402c9fe933dbf30cee8a71cf5247387461e65 Mon Sep 17 00:00:00 2001 From: kajalpatelinfo Date: Mon, 17 Jun 2024 19:41:47 -0600 Subject: [PATCH 2/8] CI fixes --- doc/conf.py | 1 + pytato/analysis/__init__.py | 10 ++++++---- test/test_pytato.py | 14 ++++++++------ test/testlib.py | 37 +++++++++++++++++++------------------ 4 files changed, 34 insertions(+), 28 deletions(-) diff --git a/doc/conf.py b/doc/conf.py index 081642f1d..e6f7ac0c0 100644 --- a/doc/conf.py +++ b/doc/conf.py @@ -46,6 +46,7 @@ nitpick_ignore_regex = [ ["py:class", r"numpy.(u?)int[\d]+"], ["py:class", r"typing_extensions(.+)"], + ["py:class", r"numpy.bool_"], # As of 2023-10-05, it doesn't look like there's sphinx documentation # available. ["py:class", r"immutabledict(.*)"], diff --git a/pytato/analysis/__init__.py b/pytato/analysis/__init__.py index 41c21aa9d..5da4ea70e 100644 --- a/pytato/analysis/__init__.py +++ b/pytato/analysis/__init__.py @@ -393,16 +393,17 @@ class NodeCountMapper(CachedWalkMapper): def __init__(self) -> None: from collections import defaultdict super().__init__() - self.counts = defaultdict(int) + self.counts = defaultdict(int) # type: Dict[Type[Any], int] - def get_cache_key(self, expr: ArrayOrNames) -> int: + def get_cache_key(self, expr: ArrayOrNames) -> ArrayOrNames: # does NOT account for duplicate nodes return expr def post_visit(self, expr: Any) -> None: self.counts[type(expr)] += 1 -def get_node_type_counts(outputs: Union[Array, DictOfNamedArrays]) -> Dict[Type, int]: + +def get_node_type_counts(outputs: Union[Array, DictOfNamedArrays]) -> Dict[Type[Any], int]: """ Returns a dictionary mapping node types to node count for that type in DAG *outputs*. @@ -416,6 +417,7 @@ def get_node_type_counts(outputs: Union[Array, DictOfNamedArrays]) -> Dict[Type, return ncm.counts + def get_num_nodes(outputs: Union[Array, DictOfNamedArrays]) -> int: """Returns the number of nodes in DAG *outputs*.""" @@ -480,4 +482,4 @@ def get_num_call_sites(outputs: Union[Array, DictOfNamedArrays]) -> int: # }}} -# vim: fdm=marker \ No newline at end of file +# vim: fdm=marker diff --git a/test/test_pytato.py b/test/test_pytato.py index cea10480c..51e5381d5 100644 --- a/test/test_pytato.py +++ b/test/test_pytato.py @@ -610,6 +610,7 @@ def test_empty_dag_count(): counts = get_node_type_counts(empty_dag) assert len(counts) == 1 + def test_single_node_dag_count(): from pytato.analysis import get_num_nodes, get_node_type_counts @@ -626,7 +627,7 @@ def test_single_node_dag_count(): # Get total number of nodes total_nodes = get_num_nodes(single_node_dag) - + assert total_nodes - 1 == 1 @@ -636,15 +637,15 @@ def test_small_dag_count(): # Make a DAG using two nodes and one operation a = pt.make_placeholder(name="a", shape=(2, 2), dtype=np.float64) b = a + 1 - dag = pt.make_dict_of_named_arrays({"result": b}) # b = a + 1 + dag = pt.make_dict_of_named_arrays({"result": b}) # b = a + 1 # Verify that get_num_nodes returns 2 for a DAG with two nodes assert get_num_nodes(dag) - 1 == 2 counts = get_node_type_counts(dag) assert len(counts) - 1 == 2 - assert counts[pt.array.Placeholder] == 1 # "a" - assert counts[pt.array.IndexLambda] == 1 # single operation + assert counts[pt.array.Placeholder] == 1 # "a" + assert counts[pt.array.IndexLambda] == 1 # single operation def test_large_dag_count(): @@ -661,7 +662,7 @@ def test_large_dag_count(): counts = get_node_type_counts(dag) assert len(counts) >= 1 assert counts[pt.array.Placeholder] == 1 - assert counts[pt.array.IndexLambda] == 100 # 100 operations + assert counts[pt.array.IndexLambda] == 100 # 100 operations assert sum(counts.values()) - 1 == iterations + 1 @@ -670,10 +671,11 @@ def test_random_dag_count(): from pytato.analysis import get_num_nodes for i in range(80): dag = get_random_pt_dag(seed=i, axis_len=5) - + # Subtract 1 since NodeCountMapper counts an extra one for DictOfNamedArrays. assert get_num_nodes(dag) - 1 == len(pt.transform.DependencyMapper()(dag)) + def test_random_dag_with_comm_count(): from testlib import get_random_pt_dag_with_send_recv_nodes from pytato.analysis import get_num_nodes diff --git a/test/testlib.py b/test/testlib.py index cdf827e96..e15489c4b 100644 --- a/test/testlib.py +++ b/test/testlib.py @@ -311,29 +311,30 @@ def gen_comm(rdagc: RandomDAGContext) -> pt.Array: convert_dws_to_placeholders=convert_dws_to_placeholders, additional_generators=[(comm_fake_probability, gen_comm)]) + def make_large_dag(iterations: int, seed: int = 0) -> pt.DictOfNamedArrays: - """ - Builds a DAG with emphasis on number of operations. - """ - import random - import operator + """ + Builds a DAG with emphasis on number of operations. + """ + import random + import operator - rng = np.random.default_rng(seed) - random.seed(seed) + rng = np.random.default_rng(seed) + random.seed(seed) - # Begin with a placeholder - a = pt.make_placeholder(name="a", shape=(2, 2), dtype=np.float64) - current = a + # Begin with a placeholder + a = pt.make_placeholder(name="a", shape=(2, 2), dtype=np.float64) + current = a - # Will randomly choose from the operators - operations = [operator.add, operator.sub, operator.mul, operator.truediv] + # Will randomly choose from the operators + operations = [operator.add, operator.sub, operator.mul, operator.truediv] - for _ in range(iterations): - operation = random.choice(operations) - value = rng.uniform(1, 10) - current = operation(current, value) + for _ in range(iterations): + operation = random.choice(operations) + value = rng.uniform(1, 10) + current = operation(current, value) - return pt.make_dict_of_named_arrays({"result": current}) + return pt.make_dict_of_named_arrays({"result": current}) # }}} @@ -393,4 +394,4 @@ class QuuxTag(TestlibTag): # }}} -# vim: foldmethod=marker \ No newline at end of file +# vim: foldmethod=marker From b122aa9a6e87419bd5ddc77ac9ca098f6091d7c0 Mon Sep 17 00:00:00 2001 From: kajalpatelinfo Date: Mon, 17 Jun 2024 19:49:18 -0600 Subject: [PATCH 3/8] Add comments --- doc/conf.py | 1 - test/testlib.py | 1 + 2 files changed, 1 insertion(+), 1 deletion(-) diff --git a/doc/conf.py b/doc/conf.py index e6f7ac0c0..081642f1d 100644 --- a/doc/conf.py +++ b/doc/conf.py @@ -46,7 +46,6 @@ nitpick_ignore_regex = [ ["py:class", r"numpy.(u?)int[\d]+"], ["py:class", r"typing_extensions(.+)"], - ["py:class", r"numpy.bool_"], # As of 2023-10-05, it doesn't look like there's sphinx documentation # available. ["py:class", r"immutabledict(.*)"], diff --git a/test/testlib.py b/test/testlib.py index e15489c4b..a208f0816 100644 --- a/test/testlib.py +++ b/test/testlib.py @@ -334,6 +334,7 @@ def make_large_dag(iterations: int, seed: int = 0) -> pt.DictOfNamedArrays: value = rng.uniform(1, 10) current = operation(current, value) + # DAG should have `iterations` number of operations return pt.make_dict_of_named_arrays({"result": current}) # }}} From 4a52c8d8ce400b71461e97424bb62fa43e5ce28e Mon Sep 17 00:00:00 2001 From: kajalpatelinfo Date: Tue, 18 Jun 2024 10:08:34 -0600 Subject: [PATCH 4/8] Remove unnecessary test --- test/test_pytato.py | 15 --------------- 1 file changed, 15 deletions(-) diff --git a/test/test_pytato.py b/test/test_pytato.py index 51e5381d5..962fe337e 100644 --- a/test/test_pytato.py +++ b/test/test_pytato.py @@ -584,21 +584,6 @@ def test_repr_array_is_deterministic(): assert repr(dag) == repr(dag) -def test_node_count_mapper(): - from testlib import RandomDAGContext, make_random_dag - from pytato.analysis import get_num_nodes - - axis_len = 5 - - for i in range(10): - rdagc = RandomDAGContext(np.random.default_rng(seed=i), - axis_len=axis_len, use_numpy=False) - dag = make_random_dag(rdagc) - - # Subtract 1 since NodeCountMapper counts an extra one for DictOfNamedArrays. - assert get_num_nodes(dag)-1 == len(pt.transform.DependencyMapper()(dag)) - - def test_empty_dag_count(): from pytato.analysis import get_num_nodes, get_node_type_counts From d8dbe62f5a17a477b84592d8817d11b547bf59ee Mon Sep 17 00:00:00 2001 From: kajalpatelinfo Date: Tue, 25 Jun 2024 18:16:20 -0600 Subject: [PATCH 5/8] Remove incrementation for DictOfNamedArrays and update tests --- pytato/analysis/__init__.py | 11 +++++++++-- test/test_pytato.py | 23 +++++++++++------------ 2 files changed, 20 insertions(+), 14 deletions(-) diff --git a/pytato/analysis/__init__.py b/pytato/analysis/__init__.py index 5da4ea70e..4938f6794 100644 --- a/pytato/analysis/__init__.py +++ b/pytato/analysis/__init__.py @@ -400,13 +400,16 @@ def get_cache_key(self, expr: ArrayOrNames) -> ArrayOrNames: return expr def post_visit(self, expr: Any) -> None: - self.counts[type(expr)] += 1 + if type(expr) is not DictOfNamedArrays: + self.counts[type(expr)] += 1 def get_node_type_counts(outputs: Union[Array, DictOfNamedArrays]) -> Dict[Type[Any], int]: """ Returns a dictionary mapping node types to node count for that type in DAG *outputs*. + + `DictOfNamedArrays` are added when *outputs* is normalized and ignored. """ from pytato.codegen import normalize_outputs @@ -419,7 +422,11 @@ def get_node_type_counts(outputs: Union[Array, DictOfNamedArrays]) -> Dict[Type[ def get_num_nodes(outputs: Union[Array, DictOfNamedArrays]) -> int: - """Returns the number of nodes in DAG *outputs*.""" + """ + Returns the number of nodes in DAG *outputs*. + + `DictOfNamedArrays` are added when *outputs* is normalized and ignored. + """ from pytato.codegen import normalize_outputs outputs = normalize_outputs(outputs) diff --git a/test/test_pytato.py b/test/test_pytato.py index 962fe337e..3d9e2a684 100644 --- a/test/test_pytato.py +++ b/test/test_pytato.py @@ -590,10 +590,10 @@ def test_empty_dag_count(): empty_dag = pt.make_dict_of_named_arrays({}) # Verify that get_num_nodes returns 0 for an empty DAG - assert get_num_nodes(empty_dag) - 1 == 0 + assert get_num_nodes(empty_dag) == 0 counts = get_node_type_counts(empty_dag) - assert len(counts) == 1 + assert len(counts) == 0 def test_single_node_dag_count(): @@ -606,14 +606,13 @@ def test_single_node_dag_count(): node_counts = get_node_type_counts(single_node_dag) # Assert that there is only one node of type DataWrapper and one node of DictOfNamedArrays - # DictOfNamedArrays is automatically added - assert node_counts == {pt.DataWrapper: 1, pt.DictOfNamedArrays: 1} - assert sum(node_counts.values()) - 1 == 1 # Total node count is 1 + assert node_counts == {pt.DataWrapper: 1} + assert sum(node_counts.values()) == 1 # Total node count is 1 # Get total number of nodes total_nodes = get_num_nodes(single_node_dag) - assert total_nodes - 1 == 1 + assert total_nodes == 1 def test_small_dag_count(): @@ -625,10 +624,10 @@ def test_small_dag_count(): dag = pt.make_dict_of_named_arrays({"result": b}) # b = a + 1 # Verify that get_num_nodes returns 2 for a DAG with two nodes - assert get_num_nodes(dag) - 1 == 2 + assert get_num_nodes(dag) == 2 counts = get_node_type_counts(dag) - assert len(counts) - 1 == 2 + assert len(counts) == 2 assert counts[pt.array.Placeholder] == 1 # "a" assert counts[pt.array.IndexLambda] == 1 # single operation @@ -641,14 +640,14 @@ def test_large_dag_count(): dag = make_large_dag(iterations, seed=42) # Verify that the number of nodes is equal to iterations + 1 (placeholder) - assert get_num_nodes(dag) - 1 == iterations + 1 + assert get_num_nodes(dag) == iterations + 1 # Verify that the counts dictionary has correct counts for the complicated DAG counts = get_node_type_counts(dag) assert len(counts) >= 1 assert counts[pt.array.Placeholder] == 1 assert counts[pt.array.IndexLambda] == 100 # 100 operations - assert sum(counts.values()) - 1 == iterations + 1 + assert sum(counts.values()) == iterations + 1 def test_random_dag_count(): @@ -658,7 +657,7 @@ def test_random_dag_count(): dag = get_random_pt_dag(seed=i, axis_len=5) # Subtract 1 since NodeCountMapper counts an extra one for DictOfNamedArrays. - assert get_num_nodes(dag) - 1 == len(pt.transform.DependencyMapper()(dag)) + assert get_num_nodes(dag) == len(pt.transform.DependencyMapper()(dag)) def test_random_dag_with_comm_count(): @@ -670,7 +669,7 @@ def test_random_dag_with_comm_count(): dag = get_random_pt_dag_with_send_recv_nodes(seed=i, rank=rank, size=size) # Subtract 1 since NodeCountMapper counts an extra one for DictOfNamedArrays. - assert get_num_nodes(dag) - 1 == len(pt.transform.DependencyMapper()(dag)) + assert get_num_nodes(dag) == len(pt.transform.DependencyMapper()(dag)) def test_rec_get_user_nodes(): From 6a0a2a9f44f58d1d4ba7e39122a4f2577ad0b157 Mon Sep 17 00:00:00 2001 From: kajalpatelinfo Date: Tue, 25 Jun 2024 18:39:57 -0600 Subject: [PATCH 6/8] Fix comments --- test/test_pytato.py | 11 +++++------ 1 file changed, 5 insertions(+), 6 deletions(-) diff --git a/test/test_pytato.py b/test/test_pytato.py index 3d9e2a684..e202fd2c3 100644 --- a/test/test_pytato.py +++ b/test/test_pytato.py @@ -600,12 +600,13 @@ def test_single_node_dag_count(): from pytato.analysis import get_num_nodes, get_node_type_counts data = np.random.rand(4, 4) - single_node_dag = pt.make_dict_of_named_arrays({"result": pt.make_data_wrapper(data)}) + single_node_dag = pt.make_dict_of_named_arrays( + {"result": pt.make_data_wrapper(data)}) # Get counts per node type node_counts = get_node_type_counts(single_node_dag) - # Assert that there is only one node of type DataWrapper and one node of DictOfNamedArrays + # Assert that there is only one node of type DataWrapper assert node_counts == {pt.DataWrapper: 1} assert sum(node_counts.values()) == 1 # Total node count is 1 @@ -642,7 +643,6 @@ def test_large_dag_count(): # Verify that the number of nodes is equal to iterations + 1 (placeholder) assert get_num_nodes(dag) == iterations + 1 - # Verify that the counts dictionary has correct counts for the complicated DAG counts = get_node_type_counts(dag) assert len(counts) >= 1 assert counts[pt.array.Placeholder] == 1 @@ -656,7 +656,6 @@ def test_random_dag_count(): for i in range(80): dag = get_random_pt_dag(seed=i, axis_len=5) - # Subtract 1 since NodeCountMapper counts an extra one for DictOfNamedArrays. assert get_num_nodes(dag) == len(pt.transform.DependencyMapper()(dag)) @@ -666,9 +665,9 @@ def test_random_dag_with_comm_count(): rank = 0 size = 2 for i in range(10): - dag = get_random_pt_dag_with_send_recv_nodes(seed=i, rank=rank, size=size) + dag = get_random_pt_dag_with_send_recv_nodes( + seed=i, rank=rank, size=size) - # Subtract 1 since NodeCountMapper counts an extra one for DictOfNamedArrays. assert get_num_nodes(dag) == len(pt.transform.DependencyMapper()(dag)) From 0dca4d7c4295179f8e946f9411349d2dfa43512e Mon Sep 17 00:00:00 2001 From: kajalpatelinfo Date: Wed, 26 Jun 2024 19:50:38 -0600 Subject: [PATCH 7/8] Clarify wording and clean up --- pytato/analysis/__init__.py | 6 +++--- test/test_pytato.py | 1 - 2 files changed, 3 insertions(+), 4 deletions(-) diff --git a/pytato/analysis/__init__.py b/pytato/analysis/__init__.py index 4938f6794..3a112501b 100644 --- a/pytato/analysis/__init__.py +++ b/pytato/analysis/__init__.py @@ -400,7 +400,7 @@ def get_cache_key(self, expr: ArrayOrNames) -> ArrayOrNames: return expr def post_visit(self, expr: Any) -> None: - if type(expr) is not DictOfNamedArrays: + if not isinstance(expr, DictOfNamedArrays): self.counts[type(expr)] += 1 @@ -409,7 +409,7 @@ def get_node_type_counts(outputs: Union[Array, DictOfNamedArrays]) -> Dict[Type[ Returns a dictionary mapping node types to node count for that type in DAG *outputs*. - `DictOfNamedArrays` are added when *outputs* is normalized and ignored. + Instances of `DictOfNamedArrays` are excluded from counting. """ from pytato.codegen import normalize_outputs @@ -425,7 +425,7 @@ def get_num_nodes(outputs: Union[Array, DictOfNamedArrays]) -> int: """ Returns the number of nodes in DAG *outputs*. - `DictOfNamedArrays` are added when *outputs* is normalized and ignored. + Instances of `DictOfNamedArrays` are excluded from counting. """ from pytato.codegen import normalize_outputs diff --git a/test/test_pytato.py b/test/test_pytato.py index e202fd2c3..23aebd9e4 100644 --- a/test/test_pytato.py +++ b/test/test_pytato.py @@ -608,7 +608,6 @@ def test_single_node_dag_count(): # Assert that there is only one node of type DataWrapper assert node_counts == {pt.DataWrapper: 1} - assert sum(node_counts.values()) == 1 # Total node count is 1 # Get total number of nodes total_nodes = get_num_nodes(single_node_dag) From 1444c50d77badc5d947cb2aa536eaaca1861c5df Mon Sep 17 00:00:00 2001 From: kajalpatelinfo Date: Tue, 2 Jul 2024 18:05:48 -0600 Subject: [PATCH 8/8] Formatting --- pytato/analysis/__init__.py | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/pytato/analysis/__init__.py b/pytato/analysis/__init__.py index 3a112501b..b822d6622 100644 --- a/pytato/analysis/__init__.py +++ b/pytato/analysis/__init__.py @@ -404,7 +404,8 @@ def post_visit(self, expr: Any) -> None: self.counts[type(expr)] += 1 -def get_node_type_counts(outputs: Union[Array, DictOfNamedArrays]) -> Dict[Type[Any], int]: +def get_node_type_counts(outputs: Union[Array, DictOfNamedArrays] + ) -> Dict[Type[Any], int]: """ Returns a dictionary mapping node types to node count for that type in DAG *outputs*.