Skip to content

Commit

Permalink
Custom dataclass decorator (#11)
Browse files Browse the repository at this point in the history
* create custom dataclass decorator

* fix tests

* add typing extension

* ignore overloads from covarage
  • Loading branch information
cgarciae authored Apr 10, 2023
1 parent 41c01f2 commit fb81642
Show file tree
Hide file tree
Showing 7 changed files with 326 additions and 235 deletions.
12 changes: 7 additions & 5 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -72,24 +72,26 @@ Static fields are not included in the pytree leaves, they
are passed as pytree metadata instead.

### Dataclasses
You can seamlessly use the `dataclasses.dataclass` decorator with `Pytree` classes.
Since `static_field` returns instances of `dataclasses.Field` these it will work as expected:
`simple_pytree` provides a `dataclass` decorator you can use with classes
that contain `static_field`s:

```python
import jax
from dataclasses import dataclass
from simple_pytree import Pytree, static_field
from simple_pytree import Pytree, dataclass, static_field

@dataclass
class Foo(Pytree):
x: int
y: int = static_field(2) # with default value
y: int = static_field(default=2)

foo = Foo(1)
foo = jax.tree_map(lambda x: -x, foo) # y is not modified

assert foo.x == -1 and foo.y == 2
```
`simple_pytree.dataclass` is just a wrapper around `dataclasses.dataclass` but
when used static analysis tools and IDEs will understand that `static_field` is a
field specifier just like `dataclasses.field`.

### Mutability
`Pytree` objects are immutable by default after `__init__`:
Expand Down
352 changes: 194 additions & 158 deletions poetry.lock

Large diffs are not rendered by default.

6 changes: 5 additions & 1 deletion pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -5,12 +5,13 @@ description = ""
authors = ["Cristian Garcia <[email protected]>"]
license = "MIT"
readme = "README.md"
packages = [{include = "simple_pytree"}]
packages = [{ include = "simple_pytree" }]

[tool.poetry.dependencies]
python = ">=3.8,<3.12"
jax = "*"
jaxlib = "*"
typing-extensions = "*"


[tool.poetry.group.dev.dependencies]
Expand All @@ -24,3 +25,6 @@ flax = "*"
[build-system]
requires = ["poetry-core"]
build-backend = "poetry.core.masonry.api"

[tool.coverage.report]
exclude_lines = ["@tp.overload"]
5 changes: 3 additions & 2 deletions simple_pytree/__init__.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
__version__ = "0.1.7"

from .pytree import Pytree, field, static_field
from .dataclass import dataclass, field, static_field
from .pytree import Pytree, PytreeMeta

__all__ = ["Pytree", "field", "static_field"]
__all__ = ["Pytree", "PytreeMeta", "dataclass", "field", "static_field"]
103 changes: 103 additions & 0 deletions simple_pytree/dataclass.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,103 @@
import dataclasses
import typing as tp

import typing_extensions as tpe

A = tp.TypeVar("A")


def field(
*,
default: tp.Any = dataclasses.MISSING,
pytree_node: bool = True,
default_factory: tp.Any = dataclasses.MISSING,
init: bool = True,
repr: bool = True,
hash: tp.Optional[bool] = None,
compare: bool = True,
metadata: tp.Optional[tp.Mapping[str, tp.Any]] = None,
):
if metadata is None:
metadata = {}
else:
metadata = dict(metadata)

if "pytree_node" in metadata:
raise ValueError("'pytree_node' found in metadata")

metadata["pytree_node"] = pytree_node

return dataclasses.field( # type: ignore
default=default,
default_factory=default_factory,
init=init,
repr=repr,
hash=hash,
compare=compare,
metadata=metadata,
)


def static_field(
*,
default: tp.Any = dataclasses.MISSING,
default_factory: tp.Any = dataclasses.MISSING,
init: bool = True,
repr: bool = True,
hash: tp.Optional[bool] = None,
compare: bool = True,
metadata: tp.Optional[tp.Mapping[str, tp.Any]] = None,
):
return field(
default=default,
pytree_node=False,
default_factory=default_factory,
init=init,
repr=repr,
hash=hash,
compare=compare,
metadata=metadata,
)


@tp.overload
def dataclass(cls: tp.Type[A]) -> tp.Type[A]:
...


@tp.overload
def dataclass(
*,
init: bool = True,
repr: bool = True,
eq: bool = True,
order: bool = False,
unsafe_hash: bool = False,
frozen: bool = False,
) -> tp.Callable[[tp.Type[A]], tp.Type[A]]:
...


@tpe.dataclass_transform(field_specifiers=(field, static_field, dataclasses.field))
def dataclass(
cls: tp.Optional[tp.Type[A]] = None,
init: bool = True,
repr: bool = True,
eq: bool = True,
order: bool = False,
unsafe_hash: bool = False,
frozen: bool = False,
) -> tp.Union[tp.Type[A], tp.Callable[[tp.Type[A]], tp.Type[A]]]:
decorator = dataclasses.dataclass(
init=init,
repr=repr,
eq=eq,
order=order,
unsafe_hash=unsafe_hash,
frozen=frozen,
)

if cls is None:
return decorator

return decorator(cls)
54 changes: 0 additions & 54 deletions simple_pytree/pytree.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,60 +11,6 @@
P = tp.TypeVar("P", bound="Pytree")


def field(
default: tp.Any = dataclasses.MISSING,
*,
pytree_node: bool = True,
default_factory: tp.Any = dataclasses.MISSING,
init: bool = True,
repr: bool = True,
hash: tp.Optional[bool] = None,
compare: bool = True,
metadata: tp.Optional[tp.Mapping[str, tp.Any]] = None,
):
if metadata is None:
metadata = {}
else:
metadata = dict(metadata)

if "pytree_node" in metadata:
raise ValueError("'pytree_node' found in metadata")

metadata["pytree_node"] = pytree_node

return dataclasses.field(
default=default,
default_factory=default_factory,
init=init,
repr=repr,
hash=hash,
compare=compare,
metadata=metadata,
)


def static_field(
default: tp.Any = dataclasses.MISSING,
*,
default_factory: tp.Any = dataclasses.MISSING,
init: bool = True,
repr: bool = True,
hash: tp.Optional[bool] = None,
compare: bool = True,
metadata: tp.Optional[tp.Mapping[str, tp.Any]] = None,
):
return field(
default=default,
pytree_node=False,
default_factory=default_factory,
init=init,
repr=repr,
hash=hash,
compare=compare,
metadata=metadata,
)


class PytreeMeta(ABCMeta):
def __call__(self: tp.Type[P], *args: tp.Any, **kwargs: tp.Any) -> P:
obj: P = self.__new__(self, *args, **kwargs)
Expand Down
29 changes: 14 additions & 15 deletions tests/test_pytree.py
Original file line number Diff line number Diff line change
@@ -1,11 +1,10 @@
import dataclasses
from typing import Generic, TypeVar

import jax
import pytest
from flax import serialization

from simple_pytree import Pytree, field, static_field
from simple_pytree import Pytree, dataclass, field, static_field


class TestPytree:
Expand Down Expand Up @@ -36,10 +35,10 @@ def __init__(self, y) -> None:
pytree.x = 4

def test_immutable_pytree_dataclass(self):
@dataclasses.dataclass(frozen=True)
@dataclass(frozen=True)
class Foo(Pytree):
y: int = field()
x: int = static_field(2)
x: int = static_field(default=2)

pytree = Foo(y=3)

Expand All @@ -58,7 +57,7 @@ class Foo(Pytree):
pytree.x = 4

def test_jit(self):
@dataclasses.dataclass
@dataclass
class Foo(Pytree):
a: int
b: int = static_field()
Expand All @@ -79,7 +78,7 @@ def __init__(self, a, b):
self.a = a
self.b = b

@dataclasses.dataclass
@dataclass
class Foo(Pytree):
bar: Bar
c: int
Expand Down Expand Up @@ -125,15 +124,15 @@ def __init__(self, x: T):
MyClass[int]

def test_key_paths(self):
@dataclasses.dataclass
@dataclass
class Bar(Pytree):
a: int = 1
b: int = static_field(2)
b: int = static_field(default=2)

@dataclasses.dataclass
@dataclass
class Foo(Pytree):
x: int = 3
y: int = static_field(4)
y: int = static_field(default=4)
z: Bar = field(default_factory=Bar)

foo = Foo()
Expand Down Expand Up @@ -171,12 +170,12 @@ class Foo(Pytree):
Foo().replace(y=1)

def test_dataclass_inheritance(self):
@dataclasses.dataclass
@dataclass
class A(Pytree):
a: int = 1
b: int = static_field(2)
b: int = static_field(default=2)

@dataclasses.dataclass
@dataclass
class B(A):
c: int = 3

Expand Down Expand Up @@ -224,10 +223,10 @@ def __init__(self, y) -> None:
assert pytree.x == 4

def test_pytree_dataclass(self):
@dataclasses.dataclass
@dataclass
class Foo(Pytree, mutable=True):
y: int = field()
x: int = static_field(2)
x: int = static_field(default=2)

pytree: Foo = Foo(y=3)

Expand Down

0 comments on commit fb81642

Please sign in to comment.