Skip to content

Commit

Permalink
Fix bug that assumed frozen-dict keys were strings.
Browse files Browse the repository at this point in the history
PiperOrigin-RevId: 606707497
  • Loading branch information
levskaya authored and Flax Authors committed Feb 13, 2024
1 parent a566196 commit 0431206
Show file tree
Hide file tree
Showing 2 changed files with 13 additions and 1 deletion.
2 changes: 1 addition & 1 deletion flax/core/frozen_dict.py
Original file line number Diff line number Diff line change
Expand Up @@ -322,7 +322,7 @@ def _frozen_dict_state_dict(xs):


def _restore_frozen_dict(xs, states):
diff = set(map(str, xs.keys())).difference(states.keys())
diff = set(map(str, xs.keys())).difference(map(str, states.keys()))
if diff:
raise ValueError(
'The target dict keys and state dict keys do not match, target dict'
Expand Down
12 changes: 12 additions & 0 deletions tests/linen/linen_module_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,7 @@
import contextlib
import copy
import dataclasses
import enum
import functools
import gc
import inspect
Expand Down Expand Up @@ -3037,6 +3038,17 @@ def __call__(self, x):
self.assertIn('next_layer_1', variables['params'])
self.assertNotIn('child_template', variables['params'])

def test_nonstring_keys_in_dict_on_module(self):
class MyEnum(str, enum.Enum):
a = 'a'
b = 'b'
class MyModule(nn.Module):
config: dict[MyEnum, int]
def __call__(self, inputs):
return inputs
module = MyModule(config={MyEnum.a: 1, MyEnum.b: 2})
variables = module.init(jax.random.key(0), jnp.zeros([0]))


class FrozenDictTests(absltest.TestCase):
def test_frozendict_flag(self):
Expand Down

0 comments on commit 0431206

Please sign in to comment.