diff --git a/flax/linen/summary.py b/flax/linen/summary.py index 56fcf2051a..fccf7ea794 100644 --- a/flax/linen/summary.py +++ b/flax/linen/summary.py @@ -15,6 +15,7 @@ """Flax Module summary library.""" import dataclasses +import enum import io from abc import ABC, abstractmethod from types import MappingProxyType @@ -42,7 +43,7 @@ import yaml import flax.linen.module as module_lib -from flax.core import meta, unfreeze +from flax.core import FrozenDict, meta, unfreeze from flax.core.scope import ( CollectionFilter, DenyList, @@ -709,12 +710,17 @@ def _normalize_structure(obj): if isinstance(obj, (tuple, list)): return tuple(map(_normalize_structure, obj)) elif isinstance(obj, Mapping): - return {k: _normalize_structure(v) for k, v in obj.items()} + return { + _normalize_structure(k): _normalize_structure(v) for k, v in obj.items() + } elif dataclasses.is_dataclass(obj): return { f.name: _normalize_structure(getattr(obj, f.name)) for f in dataclasses.fields(obj) } + elif isinstance(obj, enum.Enum): + # `yaml.safe_dump` does not support Enum key types so extract the underlying value + return obj.value else: return obj diff --git a/tests/linen/summary_test.py b/tests/linen/summary_test.py index 6d4be643df..3bcaf71e1d 100644 --- a/tests/linen/summary_test.py +++ b/tests/linen/summary_test.py @@ -12,6 +12,8 @@ # See the License for the specific language governing permissions and # limitations under the License. + +import enum from typing import List import jax @@ -735,6 +737,26 @@ def __call__(self, x): lines = rep.splitlines() self.assertIn('Total Parameters: 50', lines[-2]) + def test_tabulate_enum(self): + class Net(nn.Module): + @nn.compact + def __call__(self, inputs): + x = inputs['x'] + x = nn.Dense(features=2)(x) + return jnp.sum(x) + + class InputEnum(str, enum.Enum): + x = 'x' + + inputs = {InputEnum.x: jnp.ones((1, 1))} + # test args + lines = Net().tabulate(jax.random.key(0), inputs).split('\n') + self.assertIn('x: \x1b[2mfloat32\x1b[0m[1,1]', lines[5]) + # test kwargs + lines = Net().tabulate(jax.random.key(0), inputs=inputs).split('\n') + self.assertIn('inputs:', lines[5]) + self.assertIn('x: \x1b[2mfloat32\x1b[0m[1,1]', lines[6]) + if __name__ == '__main__': absltest.main()