diff --git a/flax/linen/transforms.py b/flax/linen/transforms.py
index 9e44aba5ec..2589a588c6 100644
--- a/flax/linen/transforms.py
+++ b/flax/linen/transforms.py
@@ -40,19 +40,20 @@
 import functools
 import inspect
 from typing import (
-    Any,
-    Callable,
-    Dict,
-    Iterable,
-    Mapping,
-    Optional,
-    Sequence,
-    Tuple,
-    Type,
-    TypeVar,
-    Union,
+  Any,
+  Callable,
+  Dict,
+  Iterable,
+  Mapping,
+  Optional,
+  Sequence,
+  Tuple,
+  Type,
+  TypeVar,
+  Union,
 )
 
+from flax import core
 from flax import errors, struct, traceback_util
 from flax import serialization
 from flax.core import Scope, lift, meta
@@ -460,15 +461,20 @@ class _HashableProxy:
   """
 
   module: Module
+  hash_key: int
+
+  @classmethod
+  def from_module(cls, module: Module) -> '_HashableProxy':
+    """Creates a proxy that stores the hash of the module's fingerprint."""
+    fingerprint = _module_pytree_hash(module)
+    hash_key = hash(fingerprint)
+    return cls(module, hash_key)
 
   def __hash__(self):
-    key = _module_pytree_hash(self.module)
-    return hash(key)
+    return self.hash_key
 
   def __eq__(self, other):
-    return isinstance(other, _HashableProxy) and _module_pytree_hash(
-        self.module
-    ) == _module_pytree_hash(other.module)
+    return isinstance(other, _HashableProxy) and self.hash_key == other.hash_key
 
 
 def _recursive_freeze(name: str, x: Any, seen_modules: dict[FlaxId, int]):
@@ -500,8 +506,6 @@ def _module_pytree_hash_recursive(
     # if its a new module we add it to the cache and give it
     # a new index
     seen_modules[module._id] = len(seen_modules)
-    # TODO(cgarciae): define a way for the user of nn.jit to define
-    # what fields it wants to ignore per Module instance.
     values = []
     for field in dataclasses.fields(module):
       if field.init and field.name not in ('parent', 'name'):
@@ -514,32 +518,129 @@ def _module_pytree_hash_recursive(
         else:
           _check_field_is_hashable(field.name, value)
           values.append((field.name, value))
+    # add state static values
+    static_state = (
+      ('in_compact_method', module._state.in_compact_method),
+      ('in_setup', module._state.in_setup),
+      ('setup_called', module._state.setup_called),
+      ('is_initialized', module._state.is_initialized),
+      ('autoname_cursor', tuple(module._state.autoname_cursor.items())),
+    )
+    _check_field_is_hashable('_state', static_state)
+    values.append(('_state', static_state))
+    # add scope static values
+    scope = module.scope
+    if scope is not None:
+      static_scope = (
+        ('mutable', _filter_static(scope.mutable)),
+        ('flags', _freeze_collections(scope.flags)),
+        ('rng_counts', _freeze_collections(scope.rng_counters.items())),
+        ('reservations', _freeze_collections(scope.reservations)),
+      )
+      _check_field_is_hashable('scope', static_scope)
+      values.append(('scope', static_scope))
+    # build static tuple
     static = tuple(values)
 
   return type(module), static
 
 
-def _check_field_is_hashable(name: str, x: Any):
-  """Checks if a field is hashable."""
+def _filter_static(_filter: CollectionFilter, /):
+  """Converts a collection filter into nested tuples of static values."""
+  if isinstance(_filter, (bool, str)):
+    return _filter
+  elif isinstance(_filter, core.DenyList):
+    return ('_denylist', _filter_static(_filter.deny))
+  elif isinstance(_filter, Mapping):
+    return tuple((k, _filter_static(v)) for k, v in _filter.items())
+  elif isinstance(_filter, Iterable):
+    return tuple(map(_filter_static, _filter))
+  else:
+    raise ValueError(f'Unknown filter type: {type(_filter)}')
+
+
+def _freeze_collections(x: Any) -> Any:
+  """Recursively converts mappings and iterables (except str) into tuples."""
+  if isinstance(x, str):
+    return x
+  if isinstance(x, Mapping):
+    return tuple((k, _freeze_collections(v)) for k, v in x.items())
+  elif isinstance(x, Iterable):
+    return tuple(_freeze_collections(v) for v in x)
+  else:
+    return x
+
+
+def _is_hashable(x):
+  """Returns True if `hash(x)` succeeds and False if it raises."""
   try:
     hash(x)
-  except Exception as e:
+    return True
+  except Exception:
+    return False
+
+
+def _recursive_find_unhashable_iter(
+  path: tuple[str, ...], obj: Any, seen_ids: set[int]
+) -> Iterable[tuple[tuple[str, ...], Any]]:
+  """Yields `(path, value)` for unhashable values found while walking `obj`."""
+  if id(obj) in seen_ids:
+    return
+  seen_ids.add(id(obj))
+  if serialization.is_serializable(obj):
+    for k, v in serialization.to_state_dict(obj).items():
+      next_path = (*path, k)
+      if not _is_hashable(v):
+        yield next_path, v
+      yield from _recursive_find_unhashable_iter(next_path, v, seen_ids)
+  elif dataclasses.is_dataclass(obj):
+    ignore = ('parent', 'name') if isinstance(obj, Module) else ()
+    for field in dataclasses.fields(obj):
+      if field.name in ignore:
+        continue
+      value = getattr(obj, field.name)
+      next_path = (*path, field.name)
+      if not _is_hashable(value):
+        yield next_path, value
+      yield from _recursive_find_unhashable_iter(next_path, value, seen_ids)
+  elif not _is_hashable(obj):
+    yield path, obj
+
+
+def _recursive_find_unhashable(name: str, obj: Any):
+  """Same as `_recursive_find_unhashable_iter` but with '/'-joined paths."""
+  return (
+    ('/'.join(path), obj)
+    for path, obj in _recursive_find_unhashable_iter((name,), obj, set())
+  )
+
+
+def _check_field_is_hashable(name: str, x: Any):
+  """Checks if a field is hashable."""
+  if not _is_hashable(x):
+    unhashable_paths = ''.join(
+      f'\n- {path}: {type(value)}'
+      for path, value in _recursive_find_unhashable(name, x)
+    )
     if dataclasses.is_dataclass(x):
       if x.__hash__ is None:
         raise ValueError(
             f"field '{name}' of type '{type(x)}' is a dataclass but its not"
             ' hashable, using `dataclass(frozen=True, eq=True)` to make it'
             ' immutable and hashable, or `dataclass(unsafe_hash=True)` for'
-            ' mutable types.'
-        ) from e
+            f' mutable types. Unhashable fields: {unhashable_paths}'
+        )
       else:
         raise ValueError(
             f"field '{name}' of type '{type(x)}' is a hashable dataclass but"
             ' hashing failed, this probably means that at least one of its'
-            ' fields is not hashable.'
-        ) from e
+            ' fields is not hashable. Unhashable fields:'
+            f' {unhashable_paths}'
+        )
     else:
-      raise ValueError(f"type '{type(x)}' is not hashable.") from e
+      raise ValueError(
+        f"type '{type(x)}' is not hashable. Unhashable fields: {unhashable_paths}"
+      )
 
 
 def decorator_lift_transform_jit(class_fn, **trafo_kwargs):
@@ -610,8 +711,8 @@ def core_fn(
         )
       module_scopes = module_scopes[0]
 
-    # get a hash for the Module by using its repr as a proxy
-    hash_key = _HashableProxy(self)
+    # get a hashable proxy object for the Module
+    hash_key = _HashableProxy.from_module(self)
 
     return trafo_fn(module_scopes, hash_key, *args, **kwargs)
 
@@ -684,7 +785,7 @@ def core_fn(scopes, module_hash, *args, **kwargs):
     module_scopes, args, kwargs = get_module_scopes(self, args, kwargs)
 
-    # get a hash for the Module by using its repr as a proxy
-    hash_key = _HashableProxy(self)
+    # get a hashable proxy object for the Module
+    hash_key = _HashableProxy.from_module(self)
     ret = trafo_fn(module_scopes, hash_key, *args, **kwargs)
     return ret
 
diff --git a/tests/linen/linen_transforms_test.py b/tests/linen/linen_transforms_test.py
index 53453d6fd4..12f8bb123a 100644
--- a/tests/linen/linen_transforms_test.py
+++ b/tests/linen/linen_transforms_test.py
@@ -1801,14 +1801,14 @@ def f(obj):
       n += 1
       return None
 
-    f(_HashableProxy(nn.Dense(10)))
+    f(_HashableProxy.from_module(nn.Dense(10)))
     self.assertEqual(n, 1)
-    f(_HashableProxy(nn.Dense(10)))
+    f(_HashableProxy.from_module(nn.Dense(10)))
     self.assertEqual(n, 1)
 
-    f(_HashableProxy(nn.Dense(20)))
+    f(_HashableProxy.from_module(nn.Dense(20)))
     self.assertEqual(n, 2)
-    f(_HashableProxy(nn.Dense(20)))
+    f(_HashableProxy.from_module(nn.Dense(20)))
     self.assertEqual(n, 2)
 
   def test_jit_reuse(self):
@@ -1948,7 +1948,10 @@ def __init__(self, a: int):
 
       def __hash__(self):
         # test object is not being passed as static
-        raise Exception('immutable')
+        raise Exception('this should not be called')
+
+      def __eq__(self, __value, /):
+        raise Exception('this should not be called')
 
     def to_dict(node: Node):
       return {'a': node.a}