diff --git a/flax/linen/transforms.py b/flax/linen/transforms.py index 2589a588c6..0659756590 100644 --- a/flax/linen/transforms.py +++ b/flax/linen/transforms.py @@ -36,21 +36,22 @@ A lifted transformation can be applied to a ``Module`` class or a function that takes a ``Module`` instance as its first argument. """ + import dataclasses import functools import inspect from typing import ( - Any, - Callable, - Dict, - Iterable, - Mapping, - Optional, - Sequence, - Tuple, - Type, - TypeVar, - Union, + Any, + Callable, + Dict, + Iterable, + Mapping, + Optional, + Sequence, + Tuple, + Type, + TypeVar, + Union, ) from flax import core @@ -465,7 +466,7 @@ class _HashableProxy: @classmethod def from_module(cls, module: Module) -> '_HashableProxy': - fingerprint = _module_pytree_hash(module) + fingerprint = _module_fingerprint(module) hash_key = hash(fingerprint) return cls(module, hash_key) @@ -476,165 +477,95 @@ def __eq__(self, other): return isinstance(other, _HashableProxy) and self.hash_key == other.hash_key -def _recursive_freeze(name: str, x: Any, seen_modules: dict[FlaxId, int]): - if isinstance(x, dict): - return tuple( - (k, _recursive_freeze(k, v, seen_modules)) for k, v in x.items() - ) - elif isinstance(x, Module): - return _module_pytree_hash_recursive(x, seen_modules) - else: - _check_field_is_hashable(name, x) - return x +def _module_fingerprint(module: Module) -> tuple[type[Any], Any]: + return _fingerprint_recursive(module, (), {}) -def _module_pytree_hash(module: Module) -> tuple[type[Any], Any]: - return _module_pytree_hash_recursive(module, {}) +def _fingerprint_recursive( + obj: Any, path: tuple[str, ...], seen_modules: dict[FlaxId, int] +) -> Any: + """Creates a hashable representation for a Module by traversing its structure recursively.""" + def _get_fingerprint(name: str, value: Any) -> tuple[str, Any]: + return name, _fingerprint_recursive(value, (*path, name), seen_modules) -def _module_pytree_hash_recursive( - module: Module, seen_modules: dict[FlaxId, int] -) -> tuple[type[Any], Any]: - """Creates a hashable representation for a Module by travering its structure recursively.""" - static: Any - if module._id in seen_modules: - # if we have already seen the module we just use the index - # as its static component - static = seen_modules[module._id] - else: - # if its a new module we add it to the cache and give it - # a new index - seen_modules[module._id] = len(seen_modules) - values = [] - for field in dataclasses.fields(module): - if field.init and field.name not in ('parent', 'name'): - value = getattr(module, field.name) - if isinstance(value, Module): - value = _module_pytree_hash_recursive(value, seen_modules) - elif serialization.is_serializable(value): - value = serialization.to_state_dict(value) - value = _recursive_freeze(field.name, value, seen_modules) - else: - _check_field_is_hashable(field.name, value) - values.append((field.name, value)) - # add state static values - static_state = ( - ('in_compact_method', module._state.in_compact_method), - ('in_setup', module._state.in_setup), - ('setup_called', module._state.setup_called), - ('is_initialized', module._state.is_initialized), - ('autoname_cursor', tuple(module._state.autoname_cursor.items())), - ) - _check_field_is_hashable('_state', static_state) - values.append(('_state', static_state)) - # add scope static values - scope = module.scope - if scope is not None: - static_scope = ( - ('mutable', _filter_static(scope.mutable)), - ('flags', _freeze_collections(scope.flags)), - ('rng_counts', _freeze_collections(scope.rng_counters.items())), - ('reservations', _freeze_collections(scope.reservations)), + if isinstance(obj, str): + return obj + elif isinstance(obj, Module): + fingerprint: Any + if obj._id in seen_modules: + # if we have already seen the module we just use the index + # as its static component + fingerprint = seen_modules[obj._id] + else: + # if its a new module we add it to the cache and give it + # a new index + seen_modules[obj._id] = len(seen_modules) + # TODO(cgarciae): define a way for the user of nn.jit to define + # what fields it wants to ignore per Module instance. + fingerprints = [] + for field in dataclasses.fields(obj): + if not hasattr(obj, field.name): + continue + if field.name not in ('parent', 'name'): + value = getattr(obj, field.name) + fingerprints.append(_get_fingerprint(field.name, value)) + # add state fingerprint + state_fingerprint = ( + _get_fingerprint('in_compact_method', obj._state.in_compact_method), + _get_fingerprint('in_setup', obj._state.in_setup), + _get_fingerprint('setup_called', obj._state.setup_called), + _get_fingerprint('is_initialized', obj._state.is_initialized), + _get_fingerprint('autoname_cursor', obj._state.autoname_cursor), ) - _check_field_is_hashable('scope', static_scope) - values.append(('scope', static_scope)) - # build static tuple - static = tuple(values) - - return type(module), static - - -def _filter_static(_filter: CollectionFilter, /): - if isinstance(_filter, (bool, str)): - return _filter - elif isinstance(_filter, core.DenyList): - return ('_denylist', _filter_static(_filter.deny)) - elif isinstance(_filter, Mapping): - return tuple((k, _filter_static(v)) for k, v in _filter.items()) - elif isinstance(_filter, Iterable): - return tuple(map(_filter_static, _filter)) - else: - raise ValueError(f'Unknown filter type: {type(_filter)}') - - -def _freeze_collections(x: Any) -> Any: - if isinstance(x, str): - return x - if isinstance(x, Mapping): - return tuple((k, _freeze_collections(v)) for k, v in x.items()) - elif isinstance(x, Iterable): - return tuple(_freeze_collections(v) for v in x) - else: - return x - - -def _is_hashable(x): - try: - hash(x) - return True - except Exception: - return False - - -def _recursive_find_unhashable_iter( - path: tuple[str, ...], obj: Any, seen_ids: set[int] -) -> Iterable[tuple[tuple[str, ...], Any]]: - if id(obj) in seen_ids: - return - seen_ids.add(id(obj)) - if serialization.is_serializable(obj): - for k, v in serialization.to_state_dict(obj).items(): - next_path = (*path, k) - if not _is_hashable(v): - yield next_path, v - yield from _recursive_find_unhashable_iter(next_path, v, seen_ids) + fingerprints.append(('_state', state_fingerprint)) + # add scope fingerprint + scope = obj.scope + if scope is not None: + static_scope = ( + _get_fingerprint('mutable', scope.mutable), + _get_fingerprint('flags', scope.flags), + _get_fingerprint('rng_counts', scope.rng_counters), + _get_fingerprint('reservations', scope.reservations), + ) + _check_field_is_hashable((*path, 'scope'), static_scope) + fingerprints.append(('scope', static_scope)) + fingerprint = tuple(fingerprints) + return type(obj), fingerprint elif dataclasses.is_dataclass(obj): - ignore = ('parent', 'name') if isinstance(obj, Module) else () + fingerprints = [] for field in dataclasses.fields(obj): - if field.name in ignore: + if not hasattr(obj, field.name): continue value = getattr(obj, field.name) - next_path = (*path, field.name) - if not _is_hashable(value): - yield next_path, value - yield from _recursive_find_unhashable_iter(next_path, value, seen_ids) - elif not _is_hashable(obj): - yield path, obj - - -def _recursive_find_unhashable(name: str, obj: Any): - return ( - ('/'.join(path), obj) - for path, obj in _recursive_find_unhashable_iter((name,), obj, set()) - ) + value_fingerprint = _get_fingerprint(field.name, value) + fingerprints.append((field.name, value_fingerprint)) + return type(obj), tuple(fingerprints) + elif isinstance(obj, core.DenyList): + return type(obj), _get_fingerprint('deny', obj.deny) + elif isinstance(obj, dict): + fingerprint = tuple((k, _get_fingerprint(k, v)) for k, v in obj.items()) + return fingerprint + elif serialization.is_serializable(obj): + state = serialization.to_state_dict(obj) + fingerprint = _fingerprint_recursive(state, path, seen_modules) + return type(obj), fingerprint + elif isinstance(obj, Mapping): + return tuple((k, _get_fingerprint(k, v)) for k, v in obj.items()) + elif isinstance(obj, Iterable): + return tuple(_get_fingerprint(str(i), v) for i, v in enumerate(obj)) + else: + _check_field_is_hashable(path, obj) + return obj -def _check_field_is_hashable(name: str, x: Any): +def _check_field_is_hashable(path: tuple[str, ...], x: Any): """Checks if a field is hashable.""" - if not _is_hashable(x): - unhashable_paths = ''.join( - f'\n- {path}: {type(value)}' - for path, value in _recursive_find_unhashable(name, x) - ) - if dataclasses.is_dataclass(x): - if x.__hash__ is None: - raise ValueError( - f"field '{name}' of type '{type(x)}' is a dataclass but its not" - ' hashable, using `dataclass(frozen=True, eq=True)` to make it' - ' immutable and hashable, or `dataclass(unsafe_hash=True)` for' - f' mutable types. Unhashable fields: {unhashable_paths}' - ) - else: - raise ValueError( - f"field '{name}' of type '{type(x)}' is a hashable dataclass but" - ' hashing failed, this probably means that at least one of its' - ' fields is not hashable. Unhashable fields:' - f' {unhashable_paths}' - ) - else: - raise ValueError( - f"type '{x}' is not hashable. Unhashable fields: {unhashable_paths}" - ) + try: + hash(x) + except Exception as e: + path_name = '/'.join(path) + raise ValueError(f"Value at '{path_name}' is not hashable: {e}") from e def decorator_lift_transform_jit(class_fn, **trafo_kwargs): @@ -664,12 +595,12 @@ def wrapped_fn(self: Module, *args, **kwargs): # make a scope-function to transform def core_fn( - prewrapped_fn, - class_fn, - scopes, - module_hash, - *args, - **kwargs, + prewrapped_fn, + class_fn, + scopes, + module_hash, + *args, + **kwargs, ): # self = hash_key.obj self: Module = module_hash.module @@ -683,8 +614,8 @@ def core_fn( return res core_fns = [ - functools.partial(core_fn, prewrapped_fn, class_fn) - for prewrapped_fn, class_fn in zip(prewrapped_fns, class_fns) + functools.partial(core_fn, prewrapped_fn, class_fn) + for prewrapped_fn, class_fn in zip(prewrapped_fns, class_fns) ] # here we apply the given lifting transform to the scope-ingesting fn @@ -700,8 +631,8 @@ def core_fn( # constructors. Therefore, there is no obvious API for specifying # arguments per lifted Module. raise NotImplementedError( - 'This transform does not yet support' - ' Modules that include other Modules passed as arguments.' + 'This transform does not yet support' + ' Modules that include other Modules passed as arguments.' ) module_scopes = module_scopes[0] @@ -732,12 +663,12 @@ def module_class_lift_transform_jit(module_class, methods=None, **trafo_kwargs): class_trafo_args = {k: ((), v) for k, v in methods.items()} else: raise ValueError( - 'transform methods argument must be None, tuple, list, or dict.' + 'transform methods argument must be None, tuple, list, or dict.' ) # Handle partially initialized module class constructors. if isinstance(module_class, functools.partial) and issubclass( - module_class.func, Module + module_class.func, Module ): partial_object = module_class module_class = module_class.func @@ -761,9 +692,9 @@ def core_fn(scopes, module_hash, *args, **kwargs): self: Module = module_hash.module # make a clone of self using its arguments attrs = { - f.name: getattr(self, f.name) - for f in dataclasses.fields(self) - if f.name != 'parent' and f.init + f.name: getattr(self, f.name) + for f in dataclasses.fields(self) + if f.name != 'parent' and f.init } # we reference module_class, not self.__class__ to avoid infinite loop cloned = module_class(parent=None, **attrs) @@ -787,19 +718,19 @@ def core_fn(scopes, module_hash, *args, **kwargs): return wrapped_fn transformed_fns = { - fn_name: create_trans_fn(fn_name, fn_trafo_args) - for fn_name, fn_trafo_args in class_trafo_args.items() + fn_name: create_trans_fn(fn_name, fn_trafo_args) + for fn_name, fn_trafo_args in class_trafo_args.items() } # construct new dynamic class w. transformed methods transformed_cls = type( - transform.__name__.capitalize() + module_class.__name__, - (module_class,), - transformed_fns, + transform.__name__.capitalize() + module_class.__name__, + (module_class,), + transformed_fns, ) # Handle partially initialized module class constructors. if partial_object is not None: transformed_cls = functools.partial( - transformed_cls, *partial_object.args, **partial_object.keywords + transformed_cls, *partial_object.args, **partial_object.keywords ) return transformed_cls @@ -959,15 +890,15 @@ def vmap( def jit( - target: Target, - variables: CollectionFilter = True, - rngs: PRNGSequenceFilter = True, - static_argnums: Union[int, Iterable[int]] = (), - static_argnames: Union[str, Iterable[str]] = (), - donate_argnums: Union[int, Iterable[int]] = (), - device=None, - backend: Union[str, None] = None, - methods=None, + target: Target, + variables: CollectionFilter = True, + rngs: PRNGSequenceFilter = True, + static_argnums: Union[int, Iterable[int]] = (), + static_argnames: Union[str, Iterable[str]] = (), + donate_argnums: Union[int, Iterable[int]] = (), + device=None, + backend: Union[str, None] = None, + methods=None, ) -> Target: """Lifted version of ``jax.jit``. @@ -1016,27 +947,27 @@ def jit( # TODO(marcvanzee): Improve docstrings (#1977). if _is_module_class(target): return module_class_lift_transform_jit( - target, - variables=variables, - rngs=rngs, - static_argnums=static_argnums, - static_argnames=static_argnames, - donate_argnums=donate_argnums, - device=device, - backend=backend, - methods=methods, + target, + variables=variables, + rngs=rngs, + static_argnums=static_argnums, + static_argnames=static_argnames, + donate_argnums=donate_argnums, + device=device, + backend=backend, + methods=methods, ) # we presume this is being used as a function decorator in class definition elif callable(target) and not isinstance(target, Module): return decorator_lift_transform_jit( - target, - variables=variables, - rngs=rngs, - static_argnums=static_argnums, - static_argnames=static_argnames, - donate_argnums=donate_argnums, - device=device, - backend=backend, + target, + variables=variables, + rngs=rngs, + static_argnums=static_argnums, + static_argnames=static_argnames, + donate_argnums=donate_argnums, + device=device, + backend=backend, ) else: raise errors.TransformTargetError(target) @@ -1138,9 +1069,9 @@ def remat_scan( policy: Optional[Callable[..., bool]] = None, variable_broadcast: CollectionFilter = False, variable_carry: CollectionFilter = False, - variable_axes: Mapping[ - CollectionFilter, InOutScanAxis - ] = FrozenDict({True: 0}), + variable_axes: Mapping[CollectionFilter, InOutScanAxis] = FrozenDict( + {True: 0} + ), split_rngs: Mapping[PRNGSequenceFilter, bool] = FrozenDict({True: True}), ) -> Target: """Combines remat and scan for memory efficiency and constant time compilation. @@ -1200,9 +1131,7 @@ def remat_scan( def scan( target: Target, - variable_axes: Mapping[ - CollectionFilter, InOutScanAxis - ] = FrozenDict(), + variable_axes: Mapping[CollectionFilter, InOutScanAxis] = FrozenDict(), variable_broadcast: CollectionFilter = False, variable_carry: CollectionFilter = False, split_rngs: Mapping[PRNGSequenceFilter, bool] = FrozenDict(),