Skip to content

Commit

Permalink
added enum support for tabulate
Browse files Browse the repository at this point in the history
  • Loading branch information
chiamp committed Nov 14, 2023
1 parent 2f6e5ff commit b23a39f
Show file tree
Hide file tree
Showing 2 changed files with 44 additions and 1 deletion.
23 changes: 22 additions & 1 deletion flax/linen/summary.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,7 @@
"""Flax Module summary library."""

import dataclasses
import enum
import io
from abc import ABC, abstractmethod
from types import MappingProxyType
Expand Down Expand Up @@ -42,7 +43,7 @@
import yaml

import flax.linen.module as module_lib
from flax.core import meta, unfreeze
from flax.core import FrozenDict, meta, unfreeze
from flax.core.scope import (
CollectionFilter,
DenyList,
Expand Down Expand Up @@ -310,6 +311,26 @@ def _tabulate_fn(*fn_args, **fn_kwargs):
compute_vjp_flops=compute_vjp_flops,
)

def process_enum(dict_obj):
"""If a key in dict_obj is an Enum, convert the Enum key
to its underlying value since the function `_as_yaml_str`
(called in `_render_table`) does not support Enum key types"""
return {
key.value if isinstance(key, enum.Enum) else key: value
for key, value in dict_obj.items()
}

fn_args = tuple(
process_enum(value) if isinstance(value, (dict, FrozenDict)) else value
for value in fn_args
)
fn_kwargs = {
key: process_enum(value)
if isinstance(value, (dict, FrozenDict))
else value
for key, value in fn_kwargs.items()
}

table = table_fn(rngs, *fn_args, **fn_kwargs, **kwargs)

non_param_cols = [
Expand Down
22 changes: 22 additions & 0 deletions tests/linen/summary_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,8 @@
# See the License for the specific language governing permissions and
# limitations under the License.


import enum
from typing import List

import jax
Expand Down Expand Up @@ -735,6 +737,26 @@ def __call__(self, x):
lines = rep.splitlines()
self.assertIn('Total Parameters: 50', lines[-2])

def test_tabulate_enum(self):
class Net(nn.Module):
@nn.compact
def __call__(self, inputs):
x = inputs['x']
x = nn.Dense(features=2)(x)
return jnp.sum(x)

class InputEnum(str, enum.Enum):
x = 'x'

inputs = {InputEnum.x: jnp.ones((1, 1))}
# test args
lines = Net().tabulate(jax.random.key(0), inputs).split('\n')
self.assertIn('x: \x1b[2mfloat32\x1b[0m[1,1]', lines[5])
# test kwargs
lines = Net().tabulate(jax.random.key(0), inputs=inputs).split('\n')
self.assertIn('inputs:', lines[5])
self.assertIn('x: \x1b[2mfloat32\x1b[0m[1,1]', lines[6])


if __name__ == '__main__':
absltest.main()

0 comments on commit b23a39f

Please sign in to comment.