Skip to content
This repository has been archived by the owner on May 11, 2023. It is now read-only.

Python 3.8 and fixes #2

Merged
merged 2 commits into from
Mar 22, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
16 changes: 16 additions & 0 deletions codecov.yml
Original file line number Diff line number Diff line change
@@ -0,0 +1,16 @@
comment: false

coverage:
status:
project:
default:
target: "97"
patch:
default:
target: "100"

ignore:
- "setup.py"
- "*/tests/.*"
- "__init__.py"
- "tests/*.py"
52 changes: 31 additions & 21 deletions mytree/mytree.py
Original file line number Diff line number Diff line change
Expand Up @@ -31,9 +31,13 @@ def __init_subclass__(cls, mutable: bool = False):

def replace(self, **kwargs: Any) -> Mytree:
"""
Replace the values of the fields of the object with the values of the
keyword arguments. A new object will be created with the same
type as the original object.
Replace the values of the fields of the object.

Args:
**kwargs: keyword arguments to replace the fields of the object.

Returns:
Mytree: with the fields replaced.
"""
fields = vars(self)
for key in kwargs:
Expand All @@ -46,31 +50,37 @@ def replace(self, **kwargs: Any) -> Mytree:

def replace_meta(self, **kwargs: Any) -> Mytree:
"""
Replace the values of the fields of the object with the values of the
keyword arguments. If the object is a dataclass, `dataclasses.replace`
will be used. Otherwise, a new object will be created with the same
type as the original object.
Replace the metadata of the fields.

Args:
**kwargs: keyword arguments to replace the metadata of the fields of the object.

Returns:
Mytree: with the metadata of the fields replaced.
"""
fields = vars(self)
for key in kwargs:
if key not in self._pytree__meta.keys():
raise ValueError(f"'{key}' is not a leaf of {type(self).__name__}")
if key not in fields:
raise ValueError(f"'{key}' is not a field of {type(self).__name__}")

pytree = copy(self)
pytree.__dict__.update(_pytree__meta={**pytree._pytree__meta, **kwargs})
return pytree

def update_meta(self, **kwargs: Any) -> Mytree:
"""
Replace the values of the fields of the object with the values of the
keyword arguments. If the object is a dataclass, `dataclasses.replace`
will be used. Otherwise, a new object will be created with the same
type as the original object.
Update the metadata of the fields. The metadata must already exist.

Args:
**kwargs: keyword arguments to replace the fields of the object.

Returns:
Mytree: with the fields replaced.
"""
fields = vars(self)
for key in kwargs:
if key not in self._pytree__meta.keys():
raise ValueError(
f"'{key}' is not an attribute of {type(self).__name__}"
)
if key not in fields:
raise ValueError(f"'{key}' is not a field of {type(self).__name__}")

pytree = copy(self)
new = deepcopy(pytree._pytree__meta)
Expand All @@ -82,13 +92,13 @@ def update_meta(self, **kwargs: Any) -> Mytree:
pytree.__dict__.update(_pytree__meta=new)
return pytree

def replace_trainable(Mytree: Mytree, **kwargs: Dict[str, bool]) -> Mytree:
def replace_trainable(self: Mytree, **kwargs: Dict[str, bool]) -> Mytree:
"""Replace the trainability status of local nodes of the Mytree."""
return Mytree.update_meta(**{k: {"trainable": v} for k, v in kwargs.items()})
return self.update_meta(**{k: {"trainable": v} for k, v in kwargs.items()})

def replace_bijector(Mytree: Mytree, **kwargs: Dict[str, Bijector]) -> Mytree:
def replace_bijector(self: Mytree, **kwargs: Dict[str, Bijector]) -> Mytree:
"""Replace the bijectors of local nodes of the Mytree."""
return Mytree.update_meta(**{k: {"bijector": v} for k, v in kwargs.items()})
return self.update_meta(**{k: {"bijector": v} for k, v in kwargs.items()})

def constrain(self) -> Mytree:
"""Transform model parameters to the constrained space according to their defined bijectors.
Expand Down
2 changes: 1 addition & 1 deletion setup.py
Original file line number Diff line number Diff line change
Expand Up @@ -47,7 +47,7 @@ def read(*local_path: str) -> str:
long_description=read("README.md"),
long_description_content_type="text/markdown",
packages=find_packages(".", exclude=["tests"]),
python_requires=">=3.9",
python_requires=">=3.8",
install_requires=INSTALL_REQUIRES,
extras_require=EXTRA_REQUIRE,
zip_safe=True,
Expand Down
22 changes: 22 additions & 0 deletions tests/test_mytree.py
Original file line number Diff line number Diff line change
Expand Up @@ -172,6 +172,28 @@ def loss_fn(model):
assert grad.weight == 0.0
assert grad.bias == 2.0

new = model.replace_meta(bias={"amazing": True})
assert new.weight == 1.0
assert new.bias == 2.0
assert model.weight == 1.0
assert model.bias == 2.0
assert meta(new).bias == {"amazing": True}
assert meta(model).bias == {}

with pytest.raises(ValueError, match=f"'cool' is not a field of SimpleModel"):
model.replace_meta(cool={"don't": "think so"})

with pytest.raises(ValueError, match=f"'cool' is not a field of SimpleModel"):
model.update_meta(cool={"don't": "think so"})

new = model.update_meta(bias={"amazing": True})
assert new.weight == 1.0
assert new.bias == 2.0
assert model.weight == 1.0
assert model.bias == 2.0
assert meta(new).bias == {"amazing": True}
assert meta(model).bias == {}


@pytest.mark.parametrize("is_dataclass", [True, False])
def test_nested_mytree_structure(is_dataclass):
Expand Down