From 1d7aad846f0932a8e8020bae7c178e97198118ca Mon Sep 17 00:00:00 2001 From: Misha Lvovsky Date: Sun, 24 Dec 2023 00:54:23 -0500 Subject: [PATCH 1/3] added static_keynames support --- chex/_src/dataclass.py | 34 ++++++++++++++++++++++++++++------ 1 file changed, 28 insertions(+), 6 deletions(-) diff --git a/chex/_src/dataclass.py b/chex/_src/dataclass.py index c262d40..7fa85c6 100644 --- a/chex/_src/dataclass.py +++ b/chex/_src/dataclass.py @@ -96,6 +96,7 @@ def dataclass( frozen=False, kw_only: bool = False, mappable_dataclass=True, # pylint: disable=redefined-outer-name + static_keynames=[], ): """JAX-friendly wrapper for :py:func:`dataclasses.dataclass`. @@ -129,7 +130,7 @@ def dataclass( def dcls(cls): # Make sure to create a separate _Dataclass instance for each `cls`. return _Dataclass( - init, repr, eq, order, unsafe_hash, frozen, kw_only, mappable_dataclass + init, repr, eq, order, unsafe_hash, frozen, kw_only, mappable_dataclass, static_keynames )(cls) if cls is None: @@ -150,6 +151,7 @@ def __init__( frozen=False, kw_only=False, mappable_dataclass=True, # pylint: disable=redefined-outer-name + static_keynames=[], ): self.init = init self.repr = repr # pylint: disable=redefined-builtin @@ -159,6 +161,7 @@ def __init__( self.frozen = frozen self.kw_only = kw_only self.mappable_dataclass = mappable_dataclass + self.static_keynames = static_keynames def __call__(self, cls): """Forwards class to dataclasses's wrapper and registers it with JAX.""" @@ -233,14 +236,18 @@ def _init(self, *args, **kwargs): setattr(dcls, "__getstate__", _getstate) setattr(dcls, "__setstate__", _setstate) setattr(dcls, "__init__", _init) + setattr(dcls, "_static_keynames", self.static_keynames) return dcls -def _dataclass_unflatten(dcls, keys, values): +def _dataclass_unflatten(dcls, aux_pack, values): """Creates a chex dataclass from a flatten jax.tree_util representation.""" dcls_object = dcls.__new__(dcls) + keys, static_keynames, static_keyvals = aux_pack attribute_dict = dict(zip(keys, values)) + # Add static items to the attribute dict. + attribute_dict.update(zip(static_keynames, static_keyvals)) # Looping over fields instead of keys & values preserves the field order. # Using dataclasses.fields fails because dataclass uids change after # serialisation (eg, with cloudpickle). @@ -256,11 +263,27 @@ def _dataclass_unflatten(dcls, keys, values): def _flatten_with_path(dcls): path = [] keys = [] + + static_keynames = [] + static_keyvals = [] + for k, v in sorted(dcls.__dict__.items()): k = jax.tree_util.GetAttrKey(k) - path.append((k, v)) - keys.append(k) - return path, keys + # Store the static keys separately. + if k.name in dcls._static_keynames: + static_keynames.append(k) + static_keyvals.append(v) + else: + path.append((k, v)) + keys.append(k) + return path, (keys, static_keynames, static_keyvals) + + +def flatten(dcls): + paths, (keys, static_keynames, static_keyvals) = _flatten_with_path(dcls) + vals = [p[1] for p in paths] + keys = [k.name for k in keys] + return vals, (keys, static_keynames, static_keyvals) @functools.cache @@ -277,7 +300,6 @@ def register_dataclass_type_with_jax_tree_util(data_class): constructable from keyword arguments corresponding to the members exposed in instance.__dict__. """ - flatten = lambda d: jax.util.unzip2(sorted(d.__dict__.items()))[::-1] unflatten = functools.partial(_dataclass_unflatten, data_class) try: jax.tree_util.register_pytree_with_keys( From 31999bb5fcc959aa0841586495d109989216aa9f Mon Sep 17 00:00:00 2001 From: Misha Lvovsky Date: Sun, 24 Dec 2023 01:15:07 -0500 Subject: [PATCH 2/3] made tests happy and added new field to docstring --- chex/_src/dataclass.py | 21 ++++++++++++++++----- 1 file changed, 16 insertions(+), 5 deletions(-) diff --git a/chex/_src/dataclass.py b/chex/_src/dataclass.py index 7fa85c6..69a7b97 100644 --- a/chex/_src/dataclass.py +++ b/chex/_src/dataclass.py @@ -96,7 +96,7 @@ def dataclass( frozen=False, kw_only: bool = False, mappable_dataclass=True, # pylint: disable=redefined-outer-name - static_keynames=[], + static_keynames=None, ): """JAX-friendly wrapper for :py:func:`dataclasses.dataclass`. @@ -123,6 +123,8 @@ def dataclass( As a side-effect, e.g. `np.testing.assert_array_equal` will only check the field names are equal and not the content. Use `chex.assert_tree_*` instead. + static_keynames: A list of field names that should be ignored by JAX + transformations. Returns: A JAX-friendly dataclass. @@ -130,7 +132,15 @@ def dataclass( def dcls(cls): # Make sure to create a separate _Dataclass instance for each `cls`. return _Dataclass( - init, repr, eq, order, unsafe_hash, frozen, kw_only, mappable_dataclass, static_keynames + init, + repr, + eq, + order, + unsafe_hash, + frozen, + kw_only, + mappable_dataclass, + static_keynames, )(cls) if cls is None: @@ -151,7 +161,7 @@ def __init__( frozen=False, kw_only=False, mappable_dataclass=True, # pylint: disable=redefined-outer-name - static_keynames=[], + static_keynames=None, ): self.init = init self.repr = repr # pylint: disable=redefined-builtin @@ -236,7 +246,7 @@ def _init(self, *args, **kwargs): setattr(dcls, "__getstate__", _getstate) setattr(dcls, "__setstate__", _setstate) setattr(dcls, "__init__", _init) - setattr(dcls, "_static_keynames", self.static_keynames) + setattr(dcls, "static_keynames", self.static_keynames) return dcls @@ -270,7 +280,8 @@ def _flatten_with_path(dcls): for k, v in sorted(dcls.__dict__.items()): k = jax.tree_util.GetAttrKey(k) # Store the static keys separately. - if k.name in dcls._static_keynames: + if (dcls.static_keynames is not None and + k.name in dcls.static_keynames): static_keynames.append(k) static_keyvals.append(v) else: From 7a3446e1da9dc160758d7c7a631484cf46c50e9c Mon Sep 17 00:00:00 2001 From: Misha Lvovsky Date: Sun, 24 Dec 2023 01:48:41 -0500 Subject: [PATCH 3/3] fixed a minor mistake leading to incorrect unflattening for static keys --- chex/_src/dataclass.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/chex/_src/dataclass.py b/chex/_src/dataclass.py index 69a7b97..de12133 100644 --- a/chex/_src/dataclass.py +++ b/chex/_src/dataclass.py @@ -282,7 +282,7 @@ def _flatten_with_path(dcls): # Store the static keys separately. if (dcls.static_keynames is not None and k.name in dcls.static_keynames): - static_keynames.append(k) + static_keynames.append(k.name) static_keyvals.append(v) else: path.append((k, v))