Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Improves fingerprint definition for Modules in nn.jit. #3736

Merged
1 commit merged into from
Mar 4, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
153 changes: 124 additions & 29 deletions flax/linen/transforms.py
Original file line number Diff line number Diff line change
Expand Up @@ -40,19 +40,20 @@
import functools
import inspect
from typing import (
Any,
Callable,
Dict,
Iterable,
Mapping,
Optional,
Sequence,
Tuple,
Type,
TypeVar,
Union,
Any,
Callable,
Dict,
Iterable,
Mapping,
Optional,
Sequence,
Tuple,
Type,
TypeVar,
Union,
)

from flax import core
from flax import errors, struct, traceback_util
from flax import serialization
from flax.core import Scope, lift, meta
Expand Down Expand Up @@ -460,15 +461,19 @@ class _HashableProxy:
"""

module: Module
hash_key: int

@classmethod
def from_module(cls, module: Module) -> '_HashableProxy':
fingerprint = _module_pytree_hash(module)
hash_key = hash(fingerprint)
return cls(module, hash_key)

def __hash__(self):
key = _module_pytree_hash(self.module)
return hash(key)
return self.hash_key

def __eq__(self, other):
return isinstance(other, _HashableProxy) and _module_pytree_hash(
self.module
) == _module_pytree_hash(other.module)
return isinstance(other, _HashableProxy) and self.hash_key == other.hash_key


def _recursive_freeze(name: str, x: Any, seen_modules: dict[FlaxId, int]):
Expand Down Expand Up @@ -500,8 +505,6 @@ def _module_pytree_hash_recursive(
# if it's a new module, we add it to the cache and give it
# a new index
seen_modules[module._id] = len(seen_modules)
# TODO(cgarciae): define a way for the user of nn.jit to define
# what fields it wants to ignore per Module instance.
values = []
for field in dataclasses.fields(module):
if field.init and field.name not in ('parent', 'name'):
Expand All @@ -514,32 +517,124 @@ def _module_pytree_hash_recursive(
else:
_check_field_is_hashable(field.name, value)
values.append((field.name, value))
# add state static values
static_state = (
('in_compact_method', module._state.in_compact_method),
('in_setup', module._state.in_setup),
('setup_called', module._state.setup_called),
('is_initialized', module._state.is_initialized),
('autoname_cursor', tuple(module._state.autoname_cursor.items())),
)
_check_field_is_hashable('_state', static_state)
values.append(('_state', static_state))
# add scope static values
scope = module.scope
if scope is not None:
static_scope = (
('mutable', _filter_static(scope.mutable)),
('flags', _freeze_collections(scope.flags)),
('rng_counts', _freeze_collections(scope.rng_counters.items())),
('reservations', _freeze_collections(scope.reservations)),
)
_check_field_is_hashable('scope', static_scope)
values.append(('scope', static_scope))
# build static tuple
static = tuple(values)

return type(module), static


def _check_field_is_hashable(name: str, x: Any):
"""Checks if a field is hashable."""
def _filter_static(_filter: CollectionFilter, /):
  """Recursively convert a collection filter into a tuple-based form.

  The conversion rules, checked in this order, are:

  * ``bool`` and ``str`` filters are returned unchanged.
  * A ``core.DenyList`` becomes ``('_denylist', <converted deny filter>)``.
  * A ``Mapping`` becomes a tuple of ``(key, <converted value>)`` pairs.
  * Any other ``Iterable`` becomes a tuple of its converted elements.

  Args:
    _filter: the filter to convert (positional-only).

  Returns:
    The filter itself (bool/str) or a nested tuple mirroring its structure.

  Raises:
    ValueError: if ``_filter`` matches none of the cases above.
  """
  # bool/str must be tested first: str is itself an Iterable and would
  # otherwise be exploded into a tuple of characters.
  if isinstance(_filter, (bool, str)):
    return _filter
  if isinstance(_filter, core.DenyList):
    return ('_denylist', _filter_static(_filter.deny))
  if isinstance(_filter, Mapping):
    return tuple((key, _filter_static(sub)) for key, sub in _filter.items())
  if isinstance(_filter, Iterable):
    return tuple(_filter_static(item) for item in _filter)
  raise ValueError(f'Unknown filter type: {type(_filter)}')


def _freeze_collections(x: Any) -> Any:
if isinstance(x, str):
return x
if isinstance(x, Mapping):
return tuple((k, _freeze_collections(v)) for k, v in x.items())
elif isinstance(x, Iterable):
return tuple(_freeze_collections(v) for v in x)
else:
return x


def _is_hashable(x):
try:
hash(x)
except Exception as e:
return True
except Exception:
return False


def _recursive_find_unhashable_iter(
  path: tuple[str, ...], obj: Any, seen_ids: set[int]
) -> Iterable[tuple[tuple[str, ...], Any]]:
  """Yield ``(path, value)`` pairs for unhashable values reachable from ``obj``.

  Children are discovered in one of two ways: through
  ``serialization.to_state_dict`` when ``serialization.is_serializable(obj)``
  is true, otherwise through dataclass fields when ``obj`` is a dataclass
  (skipping ``parent`` and ``name`` on ``Module`` instances). Every unhashable
  child is yielded and then descended into. An object with no discoverable
  children is yielded itself if it is unhashable.

  Args:
    path: path components leading to ``obj``.
    obj: the object to inspect.
    seen_ids: ``id()`` values of objects already visited; updated in place so
      that each object is visited at most once.

  Yields:
    Tuples of ``(path, value)`` where ``value`` is not hashable.
  """
  obj_id = id(obj)
  if obj_id in seen_ids:
    return
  seen_ids.add(obj_id)

  if serialization.is_serializable(obj):
    children = serialization.to_state_dict(obj).items()
  elif dataclasses.is_dataclass(obj):
    skipped = ('parent', 'name') if isinstance(obj, Module) else ()
    # Lazy on purpose: each attribute is read only when its turn comes.
    children = (
      (field.name, getattr(obj, field.name))
      for field in dataclasses.fields(obj)
      if field.name not in skipped
    )
  else:
    # Leaf object: report it directly, there is nothing to descend into.
    if not _is_hashable(obj):
      yield path, obj
    return

  for key, child in children:
    child_path = (*path, key)
    if not _is_hashable(child):
      yield child_path, child
    yield from _recursive_find_unhashable_iter(child_path, child, seen_ids)


def _recursive_find_unhashable(name: str, obj: Any):
  """Lazily list the unhashable values found under ``obj``.

  Args:
    name: label used as the first path component.
    obj: the object to search.

  Returns:
    A generator of ``(path, value)`` pairs, where ``path`` is the
    ``'/'``-joined sequence of path components starting at ``name``.
  """
  found = _recursive_find_unhashable_iter((name,), obj, set())
  return (('/'.join(parts), value) for parts, value in found)


def _check_field_is_hashable(name: str, x: Any):
"""Checks if a field is hashable."""
if not _is_hashable(x):
unhashable_paths = ''.join(
f'\n- {path}: {type(value)}'
for path, value in _recursive_find_unhashable(name, x)
)
if dataclasses.is_dataclass(x):
if x.__hash__ is None:
raise ValueError(
f"field '{name}' of type '{type(x)}' is a dataclass but its not"
' hashable, using `dataclass(frozen=True, eq=True)` to make it'
' immutable and hashable, or `dataclass(unsafe_hash=True)` for'
' mutable types.'
) from e
f' mutable types. Unhashable fields: {unhashable_paths}'
)
else:
raise ValueError(
f"field '{name}' of type '{type(x)}' is a hashable dataclass but"
' hashing failed, this probably means that at least one of its'
' fields is not hashable.'
) from e
' fields is not hashable. Unhashable fields:'
f' {unhashable_paths}'
)
else:
raise ValueError(f"type '{type(x)}' is not hashable.") from e
raise ValueError(
f"type '{x}' is not hashable. Unhashable fields: {unhashable_paths}"
)


def decorator_lift_transform_jit(class_fn, **trafo_kwargs):
Expand Down Expand Up @@ -610,8 +705,8 @@ def core_fn(
)
module_scopes = module_scopes[0]

# get a hash for the Module by using its repr as a proxy
hash_key = _HashableProxy(self)
# get a hashable proxy object for the Module
hash_key = _HashableProxy.from_module(self)

return trafo_fn(module_scopes, hash_key, *args, **kwargs)

Expand Down Expand Up @@ -684,7 +779,7 @@ def core_fn(scopes, module_hash, *args, **kwargs):
module_scopes, args, kwargs = get_module_scopes(self, args, kwargs)

# get a hashable proxy object for the Module
hash_key = _HashableProxy(self)
hash_key = _HashableProxy.from_module(self)

ret = trafo_fn(module_scopes, hash_key, *args, **kwargs)
return ret
Expand Down
13 changes: 8 additions & 5 deletions tests/linen/linen_transforms_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -1801,14 +1801,14 @@ def f(obj):
n += 1
return None

f(_HashableProxy(nn.Dense(10)))
f(_HashableProxy.from_module(nn.Dense(10)))
self.assertEqual(n, 1)
f(_HashableProxy(nn.Dense(10)))
f(_HashableProxy.from_module(nn.Dense(10)))
self.assertEqual(n, 1)

f(_HashableProxy(nn.Dense(20)))
f(_HashableProxy.from_module(nn.Dense(20)))
self.assertEqual(n, 2)
f(_HashableProxy(nn.Dense(20)))
f(_HashableProxy.from_module(nn.Dense(20)))
self.assertEqual(n, 2)

def test_jit_reuse(self):
Expand Down Expand Up @@ -1948,7 +1948,10 @@ def __init__(self, a: int):

def __hash__(self):
# test object is not being passed as static
raise Exception('immutable')
raise Exception('this should not be called')

def __eq__(self, __value, /):
raise Exception('this should not be called')

def to_dict(node: Node):
return {'a': node.a}
Expand Down
Loading