Skip to content

Commit

Permalink
added enum support for tabulate
Browse files Browse the repository at this point in the history
  • Loading branch information
chiamp committed Nov 15, 2023
1 parent 2f6e5ff commit 47019f6
Show file tree
Hide file tree
Showing 2 changed files with 30 additions and 2 deletions.
10 changes: 8 additions & 2 deletions flax/linen/summary.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,7 @@
"""Flax Module summary library."""

import dataclasses
import enum
import io
from abc import ABC, abstractmethod
from types import MappingProxyType
Expand Down Expand Up @@ -42,7 +43,7 @@
import yaml

import flax.linen.module as module_lib
from flax.core import meta, unfreeze
from flax.core import FrozenDict, meta, unfreeze
from flax.core.scope import (
CollectionFilter,
DenyList,
Expand Down Expand Up @@ -709,12 +710,17 @@ def _normalize_structure(obj):
if isinstance(obj, (tuple, list)):
return tuple(map(_normalize_structure, obj))
elif isinstance(obj, Mapping):
return {k: _normalize_structure(v) for k, v in obj.items()}
return {
_normalize_structure(k): _normalize_structure(v) for k, v in obj.items()
}
elif dataclasses.is_dataclass(obj):
return {
f.name: _normalize_structure(getattr(obj, f.name))
for f in dataclasses.fields(obj)
}
elif isinstance(obj, enum.Enum):
# `yaml.safe_dump` does not support Enum key types so extract the underlying value
return obj.value
else:
return obj

Expand Down
22 changes: 22 additions & 0 deletions tests/linen/summary_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,8 @@
# See the License for the specific language governing permissions and
# limitations under the License.


import enum
from typing import List

import jax
Expand Down Expand Up @@ -735,6 +737,26 @@ def __call__(self, x):
lines = rep.splitlines()
self.assertIn('Total Parameters: 50', lines[-2])

def test_tabulate_enum(self):
class Net(nn.Module):
@nn.compact
def __call__(self, inputs):
x = inputs['x']
x = nn.Dense(features=2)(x)
return jnp.sum(x)

class InputEnum(str, enum.Enum):
x = 'x'

inputs = {InputEnum.x: jnp.ones((1, 1))}
# test args
lines = Net().tabulate(jax.random.key(0), inputs).split('\n')
self.assertIn('x: \x1b[2mfloat32\x1b[0m[1,1]', lines[5])
# test kwargs
lines = Net().tabulate(jax.random.key(0), inputs=inputs).split('\n')
self.assertIn('inputs:', lines[5])
self.assertIn('x: \x1b[2mfloat32\x1b[0m[1,1]', lines[6])


if __name__ == '__main__':
absltest.main()

0 comments on commit 47019f6

Please sign in to comment.