Skip to content

Commit

Permalink
add function to compute the max node depth of a DAG
Browse files Browse the repository at this point in the history
  • Loading branch information
majosm committed Jul 19, 2023
1 parent 3805500 commit e832946
Showing 1 changed file with 48 additions and 0 deletions.
48 changes: 48 additions & 0 deletions pytato/analysis/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -48,6 +48,7 @@
.. autofunction:: is_einsum_similar_to_subscript
.. autofunction:: get_num_nodes
.. autofunction:: get_max_node_depth
.. autofunction:: get_num_call_sites
Expand Down Expand Up @@ -400,6 +401,53 @@ def get_num_nodes(outputs: Union[Array, DictOfNamedArrays]) -> int:
# }}}


# {{{ NodeMaxDepthMapper

@optimize_mapper(drop_args=True, drop_kwargs=True, inline_get_cache_key=True)
class NodeMaxDepthMapper(CachedWalkMapper):
"""
Finds the maximum depth of a node in a DAG.
.. attribute:: max_depth
The depth of the deepest node.
"""

def __init__(self) -> None:
super().__init__()
self.depth = 0
self.max_depth = 0

# FIXME: Do I need this?
# type-ignore-reason: dropped the extra `*args, **kwargs`.
def get_cache_key(self, expr: ArrayOrNames) -> int: # type: ignore[override]
return id(expr)

def rec(self, expr: ArrayOrNames, *args: Any, **kwargs: Any) -> None:
"""Call the mapper method of *expr* and return the result."""
self.depth += 1
self.max_depth = max(self.max_depth, self.depth)

try:
super().rec(expr, *args, **kwargs)
finally:
self.depth -= 1


def get_max_node_depth(outputs: Union[Array, DictOfNamedArrays]) -> int:
"""Finds the maximum depth of a node in *outputs*."""

from pytato.codegen import normalize_outputs
outputs = normalize_outputs(outputs)

nmdm = NodeMaxDepthMapper()
nmdm(outputs)

return nmdm.max_depth

# }}}


# {{{ CallSiteCountMapper

@optimize_mapper(drop_args=True, drop_kwargs=True, inline_get_cache_key=True)
Expand Down

0 comments on commit e832946

Please sign in to comment.