diff --git a/pytato/analysis/__init__.py b/pytato/analysis/__init__.py index d0fd6ef1e..baf43a9f1 100644 --- a/pytato/analysis/__init__.py +++ b/pytato/analysis/__init__.py @@ -48,6 +48,7 @@ .. autofunction:: is_einsum_similar_to_subscript .. autofunction:: get_num_nodes +.. autofunction:: get_max_node_depth .. autofunction:: get_num_call_sites @@ -400,6 +401,53 @@ def get_num_nodes(outputs: Union[Array, DictOfNamedArrays]) -> int: # }}} +# {{{ NodeMaxDepthMapper + +@optimize_mapper(drop_args=True, drop_kwargs=True, inline_get_cache_key=True) +class NodeMaxDepthMapper(CachedWalkMapper): + """ + Finds the maximum depth of a node in a DAG. + + .. attribute:: max_depth + + The depth of the deepest node. + """ + + def __init__(self) -> None: + super().__init__() + self.depth = 0 + self.max_depth = 0 + + # FIXME: Do I need this? + # type-ignore-reason: dropped the extra `*args, **kwargs`. + def get_cache_key(self, expr: ArrayOrNames) -> int: # type: ignore[override] + return id(expr) + + def rec(self, expr: ArrayOrNames, *args: Any, **kwargs: Any) -> None: + """Call the mapper method of *expr* and return the result.""" + self.depth += 1 + self.max_depth = max(self.max_depth, self.depth) + + try: + super().rec(expr, *args, **kwargs) + finally: + self.depth -= 1 + + +def get_max_node_depth(outputs: Union[Array, DictOfNamedArrays]) -> int: + """Finds the maximum depth of a node in *outputs*.""" + + from pytato.codegen import normalize_outputs + outputs = normalize_outputs(outputs) + + nmdm = NodeMaxDepthMapper() + nmdm(outputs) + + return nmdm.max_depth + +# }}} + + # {{{ CallSiteCountMapper @optimize_mapper(drop_args=True, drop_kwargs=True, inline_get_cache_key=True)