Skip to content

Commit

Permalink
Fix some linting issues.
Browse files Browse the repository at this point in the history
  • Loading branch information
MTCam committed Mar 16, 2024
1 parent dfaaeed commit b9edcae
Showing 1 changed file with 10 additions and 4 deletions.
14 changes: 10 additions & 4 deletions pytato/analysis/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -26,7 +26,7 @@
"""

from typing import (Mapping, Dict, Union, Set, Tuple, Any, FrozenSet,
TYPE_CHECKING)
Type, TYPE_CHECKING)
from pytato.array import (Array, IndexLambda, Stack, Concatenate, Einsum,
DictOfNamedArrays, NamedArray,
IndexBase, IndexRemappingBase, InputArgumentBase,
Expand Down Expand Up @@ -426,18 +426,24 @@ class NodeTypeCountMapper(CachedWalkMapper):
"""

def __init__(self) -> None:
from collections import defaultdict
super().__init__()
self.counts = {}
self.counts = defaultdict(int)

def get_cache_key(self, expr: ArrayOrNames) -> int:
return id(expr)

def post_visit(self, expr: Any) -> None:
if type(expr) not in self.counts:
self.counts[type(expr)] = 0
self.counts[type(expr)] += 1


def get_num_node_types(outputs: Union[Array, DictOfNamedArrays]) -> int:
"""Returns the number of nodes of each given type in DAG *outputs*."""
def get_num_node_types(outputs: Union[Array, DictOfNamedArrays]) -> Dict[Type, int]:
"""
Returns a dictionary mapping node types to node count for that type
in DAG *outputs*.
"""

from pytato.codegen import normalize_outputs
outputs = normalize_outputs(outputs)
Expand Down

0 comments on commit b9edcae

Please sign in to comment.