Skip to content

Commit

Permalink
Improves fingerprint definition for Modules in nn.jit.
Browse files Browse the repository at this point in the history
PiperOrigin-RevId: 607465305
  • Loading branch information
Cristian Garcia authored and Flax Authors committed Mar 4, 2024
1 parent 7d6de00 commit c338072
Show file tree
Hide file tree
Showing 2 changed files with 132 additions and 34 deletions.
153 changes: 124 additions & 29 deletions flax/linen/transforms.py
Original file line number Diff line number Diff line change
Expand Up @@ -40,19 +40,20 @@
import functools
import inspect
from typing import (
Any,
Callable,
Dict,
Iterable,
Mapping,
Optional,
Sequence,
Tuple,
Type,
TypeVar,
Union,
Any,
Callable,
Dict,
Iterable,
Mapping,
Optional,
Sequence,
Tuple,
Type,
TypeVar,
Union,
)

from flax import core
from flax import errors, struct, traceback_util
from flax import serialization
from flax.core import Scope, lift, meta
Expand Down Expand Up @@ -460,15 +461,19 @@ class _HashableProxy:
"""

module: Module
hash_key: int

@classmethod
def from_module(cls, module: Module) -> '_HashableProxy':
fingerprint = _module_pytree_hash(module)
hash_key = hash(fingerprint)
return cls(module, hash_key)

def __hash__(self):
key = _module_pytree_hash(self.module)
return hash(key)
return self.hash_key

def __eq__(self, other):
return isinstance(other, _HashableProxy) and _module_pytree_hash(
self.module
) == _module_pytree_hash(other.module)
return isinstance(other, _HashableProxy) and self.hash_key == other.hash_key


def _recursive_freeze(name: str, x: Any, seen_modules: dict[FlaxId, int]):
Expand Down Expand Up @@ -500,8 +505,6 @@ def _module_pytree_hash_recursive(
# if it's a new module we add it to the cache and give it
# a new index
seen_modules[module._id] = len(seen_modules)
# TODO(cgarciae): define a way for the user of nn.jit to define
# what fields it wants to ignore per Module instance.
values = []
for field in dataclasses.fields(module):
if field.init and field.name not in ('parent', 'name'):
Expand All @@ -514,32 +517,124 @@ def _module_pytree_hash_recursive(
else:
_check_field_is_hashable(field.name, value)
values.append((field.name, value))
# add state static values
static_state = (
('in_compact_method', module._state.in_compact_method),
('in_setup', module._state.in_setup),
('setup_called', module._state.setup_called),
('is_initialized', module._state.is_initialized),
('autoname_cursor', tuple(module._state.autoname_cursor.items())),
)
_check_field_is_hashable('_state', static_state)
values.append(('_state', static_state))
# add scope static values
scope = module.scope
if scope is not None:
static_scope = (
('mutable', _filter_static(scope.mutable)),
('flags', _freeze_collections(scope.flags)),
('rng_counts', _freeze_collections(scope.rng_counters.items())),
('reservations', _freeze_collections(scope.reservations)),
)
_check_field_is_hashable('scope', static_scope)
values.append(('scope', static_scope))
# build static tuple
static = tuple(values)

return type(module), static


def _check_field_is_hashable(name: str, x: Any):
"""Checks if a field is hashable."""
def _filter_static(_filter: CollectionFilter, /):
  """Converts a collection filter into a nested tuple of static values.

  Booleans and strings are returned unchanged. A ``core.DenyList`` becomes
  ``('_denylist', <converted deny filter>)``, a mapping becomes a tuple of
  ``(key, <converted value>)`` pairs and any other iterable becomes a tuple
  of converted elements.

  Args:
    _filter: the collection filter to convert (positional-only).

  Returns:
    The filter itself for ``bool``/``str`` inputs, otherwise a nested tuple.

  Raises:
    ValueError: if the filter is none of the supported kinds.
  """
  if isinstance(_filter, (bool, str)):
    return _filter
  if isinstance(_filter, core.DenyList):
    return ('_denylist', _filter_static(_filter.deny))
  # Mappings are iterable too, so they must be matched before `Iterable`.
  if isinstance(_filter, Mapping):
    return tuple(
      (key, _filter_static(sub_filter)) for key, sub_filter in _filter.items()
    )
  if isinstance(_filter, Iterable):
    return tuple(_filter_static(sub_filter) for sub_filter in _filter)
  raise ValueError(f'Unknown filter type: {type(_filter)}')


def _freeze_collections(x: Any) -> Any:
if isinstance(x, str):
return x
if isinstance(x, Mapping):
return tuple((k, _freeze_collections(v)) for k, v in x.items())
elif isinstance(x, Iterable):
return tuple(_freeze_collections(v) for v in x)
else:
return x


def _is_hashable(x):
try:
hash(x)
except Exception as e:
return True
except Exception:
return False


def _has_traversable_children(obj: Any) -> bool:
  """Returns True if the unhashable-value search can descend into ``obj``."""
  # `dataclasses.is_dataclass` is also True for dataclass *classes*; only
  # instances carry field values, so classes are treated as leaves.
  return serialization.is_serializable(obj) or (
    dataclasses.is_dataclass(obj) and not isinstance(obj, type)
  )


def _recursive_find_unhashable_iter(
  path: tuple[str, ...], obj: Any, seen_ids: set[int]
) -> Iterable[tuple[tuple[str, ...], Any]]:
  """Yields ``(path, value)`` for unhashable values reachable from ``obj``.

  Objects accepted by ``serialization.is_serializable`` are traversed through
  the items of their state dict, and dataclass instances through their fields
  (``parent`` and ``name`` are skipped for ``Module`` instances). Any other
  object is a leaf. Each unhashable value is reported once per path at which
  it is encountered; unhashable containers are reported and then descended
  into.

  Args:
    path: path components leading to ``obj``.
    obj: the object to inspect.
    seen_ids: ``id``s of containers already visited; updated in place so that
      shared or cyclic containers are only traversed once.

  Yields:
    ``(path, value)`` pairs where ``value`` is not hashable.
  """
  if id(obj) in seen_ids:
    return
  seen_ids.add(id(obj))

  if serialization.is_serializable(obj):
    children = serialization.to_state_dict(obj).items()
  elif dataclasses.is_dataclass(obj) and not isinstance(obj, type):
    ignore = ('parent', 'name') if isinstance(obj, Module) else ()
    children = (
      (field.name, getattr(obj, field.name))
      for field in dataclasses.fields(obj)
      if field.name not in ignore
    )
  else:
    # Leaf reached directly (e.g. the root object itself).
    if not _is_hashable(obj):
      yield path, obj
    return

  for key, child in children:
    child_path = (*path, key)
    if not _is_hashable(child):
      yield child_path, child
    # Only descend into containers: recursing into a leaf would report it a
    # second time under the very same path.
    if _has_traversable_children(child):
      yield from _recursive_find_unhashable_iter(child_path, child, seen_ids)


def _recursive_find_unhashable(name: str, obj: Any):
  """Lazily yields ``('/'-joined path, value)`` for unhashable parts of ``obj``.

  Args:
    name: label used as the first path component.
    obj: the object to search.

  Returns:
    A generator of ``(path_string, value)`` pairs, built from the results of
    ``_recursive_find_unhashable_iter`` started with a fresh set of seen ids.
  """
  found = _recursive_find_unhashable_iter((name,), obj, set())
  return (('/'.join(parts), value) for parts, value in found)


def _check_field_is_hashable(name: str, x: Any):
"""Checks if a field is hashable."""
if not _is_hashable(x):
unhashable_paths = ''.join(
f'\n- {path}: {type(value)}'
for path, value in _recursive_find_unhashable(name, x)
)
if dataclasses.is_dataclass(x):
if x.__hash__ is None:
raise ValueError(
f"field '{name}' of type '{type(x)}' is a dataclass but its not"
' hashable, using `dataclass(frozen=True, eq=True)` to make it'
' immutable and hashable, or `dataclass(unsafe_hash=True)` for'
' mutable types.'
) from e
f' mutable types. Unhashable fields: {unhashable_paths}'
)
else:
raise ValueError(
f"field '{name}' of type '{type(x)}' is a hashable dataclass but"
' hashing failed, this probably means that at least one of its'
' fields is not hashable.'
) from e
' fields is not hashable. Unhashable fields:'
f' {unhashable_paths}'
)
else:
raise ValueError(f"type '{type(x)}' is not hashable.") from e
raise ValueError(
f"type '{x}' is not hashable. Unhashable fields: {unhashable_paths}"
)


def decorator_lift_transform_jit(class_fn, **trafo_kwargs):
Expand Down Expand Up @@ -610,8 +705,8 @@ def core_fn(
)
module_scopes = module_scopes[0]

# get a hash for the Module by using its repr as a proxy
hash_key = _HashableProxy(self)
# get a hashable proxy object for the Module
hash_key = _HashableProxy.from_module(self)

return trafo_fn(module_scopes, hash_key, *args, **kwargs)

Expand Down Expand Up @@ -684,7 +779,7 @@ def core_fn(scopes, module_hash, *args, **kwargs):
module_scopes, args, kwargs = get_module_scopes(self, args, kwargs)

# get a hashable proxy object for the Module
hash_key = _HashableProxy(self)
hash_key = _HashableProxy.from_module(self)

ret = trafo_fn(module_scopes, hash_key, *args, **kwargs)
return ret
Expand Down
13 changes: 8 additions & 5 deletions tests/linen/linen_transforms_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -1801,14 +1801,14 @@ def f(obj):
n += 1
return None

f(_HashableProxy(nn.Dense(10)))
f(_HashableProxy.from_module(nn.Dense(10)))
self.assertEqual(n, 1)
f(_HashableProxy(nn.Dense(10)))
f(_HashableProxy.from_module(nn.Dense(10)))
self.assertEqual(n, 1)

f(_HashableProxy(nn.Dense(20)))
f(_HashableProxy.from_module(nn.Dense(20)))
self.assertEqual(n, 2)
f(_HashableProxy(nn.Dense(20)))
f(_HashableProxy.from_module(nn.Dense(20)))
self.assertEqual(n, 2)

def test_jit_reuse(self):
Expand Down Expand Up @@ -1948,7 +1948,10 @@ def __init__(self, a: int):

def __hash__(self):
# test object is not being passed as static
raise Exception('immutable')
raise Exception('this should not be called')

def __eq__(self, __value, /):
raise Exception('this should not be called')

def to_dict(node: Node):
return {'a': node.a}
Expand Down

0 comments on commit c338072

Please sign in to comment.