Skip to content
This repository has been archived by the owner on May 11, 2023. It is now read-only.

Commit

Permalink
Merge pull request #2 from Daniel-Dodd/Python-3.8-and-fixes
Browse files Browse the repository at this point in the history
Python 3.8 and fixes
  • Loading branch information
daniel-dodd authored Mar 22, 2023
2 parents 2a51170 + ee3074e commit 5e2e581
Show file tree
Hide file tree
Showing 4 changed files with 70 additions and 22 deletions.
16 changes: 16 additions & 0 deletions codecov.yml
Original file line number Diff line number Diff line change
@@ -0,0 +1,16 @@
comment: false

coverage:
status:
project:
default:
target: "97"
patch:
default:
target: "100"

ignore:
- "setup.py"
- "*/tests/.*"
- "__init__.py"
- "tests/*.py"
52 changes: 31 additions & 21 deletions mytree/mytree.py
Original file line number Diff line number Diff line change
Expand Up @@ -31,9 +31,13 @@ def __init_subclass__(cls, mutable: bool = False):

def replace(self, **kwargs: Any) -> Mytree:
"""
Replace the values of the fields of the object with the values of the
keyword arguments. A new object will be created with the same
type as the original object.
Replace the values of the fields of the object.
Args:
**kwargs: keyword arguments to replace the fields of the object.
Returns:
Mytree: with the fields replaced.
"""
fields = vars(self)
for key in kwargs:
Expand All @@ -46,31 +50,37 @@ def replace(self, **kwargs: Any) -> Mytree:

def replace_meta(self, **kwargs: Any) -> Mytree:
"""
Replace the values of the fields of the object with the values of the
keyword arguments. If the object is a dataclass, `dataclasses.replace`
will be used. Otherwise, a new object will be created with the same
type as the original object.
Replace the metadata of the fields.
Args:
**kwargs: keyword arguments to replace the metadata of the fields of the object.
Returns:
Mytree: with the metadata of the fields replaced.
"""
fields = vars(self)
for key in kwargs:
if key not in self._pytree__meta.keys():
raise ValueError(f"'{key}' is not a leaf of {type(self).__name__}")
if key not in fields:
raise ValueError(f"'{key}' is not a field of {type(self).__name__}")

pytree = copy(self)
pytree.__dict__.update(_pytree__meta={**pytree._pytree__meta, **kwargs})
return pytree

def update_meta(self, **kwargs: Any) -> Mytree:
"""
Replace the values of the fields of the object with the values of the
keyword arguments. If the object is a dataclass, `dataclasses.replace`
will be used. Otherwise, a new object will be created with the same
type as the original object.
Update the metadata of the fields. The metadata must already exist.
Args:
**kwargs: keyword arguments to replace the fields of the object.
Returns:
Mytree: with the fields replaced.
"""
fields = vars(self)
for key in kwargs:
if key not in self._pytree__meta.keys():
raise ValueError(
f"'{key}' is not an attribute of {type(self).__name__}"
)
if key not in fields:
raise ValueError(f"'{key}' is not a field of {type(self).__name__}")

pytree = copy(self)
new = deepcopy(pytree._pytree__meta)
Expand All @@ -82,13 +92,13 @@ def update_meta(self, **kwargs: Any) -> Mytree:
pytree.__dict__.update(_pytree__meta=new)
return pytree

def replace_trainable(Mytree: Mytree, **kwargs: Dict[str, bool]) -> Mytree:
def replace_trainable(self: Mytree, **kwargs: Dict[str, bool]) -> Mytree:
"""Replace the trainability status of local nodes of the Mytree."""
return Mytree.update_meta(**{k: {"trainable": v} for k, v in kwargs.items()})
return self.update_meta(**{k: {"trainable": v} for k, v in kwargs.items()})

def replace_bijector(Mytree: Mytree, **kwargs: Dict[str, Bijector]) -> Mytree:
def replace_bijector(self: Mytree, **kwargs: Dict[str, Bijector]) -> Mytree:
"""Replace the bijectors of local nodes of the Mytree."""
return Mytree.update_meta(**{k: {"bijector": v} for k, v in kwargs.items()})
return self.update_meta(**{k: {"bijector": v} for k, v in kwargs.items()})

def constrain(self) -> Mytree:
"""Transform model parameters to the constrained space according to their defined bijectors.
Expand Down
2 changes: 1 addition & 1 deletion setup.py
Original file line number Diff line number Diff line change
Expand Up @@ -47,7 +47,7 @@ def read(*local_path: str) -> str:
long_description=read("README.md"),
long_description_content_type="text/markdown",
packages=find_packages(".", exclude=["tests"]),
python_requires=">=3.9",
python_requires=">=3.8",
install_requires=INSTALL_REQUIRES,
extras_require=EXTRA_REQUIRE,
zip_safe=True,
Expand Down
22 changes: 22 additions & 0 deletions tests/test_mytree.py
Original file line number Diff line number Diff line change
Expand Up @@ -172,6 +172,28 @@ def loss_fn(model):
assert grad.weight == 0.0
assert grad.bias == 2.0

new = model.replace_meta(bias={"amazing": True})
assert new.weight == 1.0
assert new.bias == 2.0
assert model.weight == 1.0
assert model.bias == 2.0
assert meta(new).bias == {"amazing": True}
assert meta(model).bias == {}

with pytest.raises(ValueError, match=f"'cool' is not a field of SimpleModel"):
model.replace_meta(cool={"don't": "think so"})

with pytest.raises(ValueError, match=f"'cool' is not a field of SimpleModel"):
model.update_meta(cool={"don't": "think so"})

new = model.update_meta(bias={"amazing": True})
assert new.weight == 1.0
assert new.bias == 2.0
assert model.weight == 1.0
assert model.bias == 2.0
assert meta(new).bias == {"amazing": True}
assert meta(model).bias == {}


@pytest.mark.parametrize("is_dataclass", [True, False])
def test_nested_mytree_structure(is_dataclass):
Expand Down

0 comments on commit 5e2e581

Please sign in to comment.