Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

adding static_keynames support #327

Open
wants to merge 3 commits into
base: master
Choose a base branch
from
Open
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
45 changes: 39 additions & 6 deletions chex/_src/dataclass.py
Original file line number Diff line number Diff line change
Expand Up @@ -96,6 +96,7 @@ def dataclass(
frozen=False,
kw_only: bool = False,
mappable_dataclass=True, # pylint: disable=redefined-outer-name
static_keynames=None,
):
"""JAX-friendly wrapper for :py:func:`dataclasses.dataclass`.

Expand All @@ -122,14 +123,24 @@ def dataclass(
As a side-effect, e.g. `np.testing.assert_array_equal` will only check
the field names are equal and not the content. Use `chex.assert_tree_*`
instead.
static_keynames: A list of field names that should be ignored by JAX
transformations.

Returns:
A JAX-friendly dataclass.
"""
def dcls(cls):
# Make sure to create a separate _Dataclass instance for each `cls`.
return _Dataclass(
init, repr, eq, order, unsafe_hash, frozen, kw_only, mappable_dataclass
init,
repr,
eq,
order,
unsafe_hash,
frozen,
kw_only,
mappable_dataclass,
static_keynames,
)(cls)

if cls is None:
Expand All @@ -150,6 +161,7 @@ def __init__(
frozen=False,
kw_only=False,
mappable_dataclass=True, # pylint: disable=redefined-outer-name
static_keynames=None,
):
self.init = init
self.repr = repr # pylint: disable=redefined-builtin
Expand All @@ -159,6 +171,7 @@ def __init__(
self.frozen = frozen
self.kw_only = kw_only
self.mappable_dataclass = mappable_dataclass
self.static_keynames = static_keynames

def __call__(self, cls):
"""Forwards class to dataclasses's wrapper and registers it with JAX."""
Expand Down Expand Up @@ -233,14 +246,18 @@ def _init(self, *args, **kwargs):
setattr(dcls, "__getstate__", _getstate)
setattr(dcls, "__setstate__", _setstate)
setattr(dcls, "__init__", _init)
setattr(dcls, "static_keynames", self.static_keynames)

return dcls


def _dataclass_unflatten(dcls, keys, values):
def _dataclass_unflatten(dcls, aux_pack, values):
"""Creates a chex dataclass from a flatten jax.tree_util representation."""
dcls_object = dcls.__new__(dcls)
keys, static_keynames, static_keyvals = aux_pack
attribute_dict = dict(zip(keys, values))
# Add static items to the attribute dict.
attribute_dict.update(zip(static_keynames, static_keyvals))
# Looping over fields instead of keys & values preserves the field order.
# Using dataclasses.fields fails because dataclass uids change after
# serialisation (eg, with cloudpickle).
Expand All @@ -256,11 +273,28 @@ def _dataclass_unflatten(dcls, keys, values):
def _flatten_with_path(dcls):
path = []
keys = []

static_keynames = []
static_keyvals = []

for k, v in sorted(dcls.__dict__.items()):
k = jax.tree_util.GetAttrKey(k)
path.append((k, v))
keys.append(k)
return path, keys
# Store the static keys separately.
if (dcls.static_keynames is not None and
k.name in dcls.static_keynames):
static_keynames.append(k.name)
static_keyvals.append(v)
else:
path.append((k, v))
keys.append(k)
return path, (keys, static_keynames, static_keyvals)


def flatten(dcls):
paths, (keys, static_keynames, static_keyvals) = _flatten_with_path(dcls)
vals = [p[1] for p in paths]
keys = [k.name for k in keys]
return vals, (keys, static_keynames, static_keyvals)


@functools.cache
Expand All @@ -277,7 +311,6 @@ def register_dataclass_type_with_jax_tree_util(data_class):
constructable from keyword arguments corresponding to the members exposed
in instance.__dict__.
"""
flatten = lambda d: jax.util.unzip2(sorted(d.__dict__.items()))[::-1]
unflatten = functools.partial(_dataclass_unflatten, data_class)
try:
jax.tree_util.register_pytree_with_keys(
Expand Down
Loading