Skip to content

Commit

Permalink
Merge pull request #44 from RobertTLange/develop
Browse files Browse the repository at this point in the history
Fix checkpoint loading for LES
  • Loading branch information
RobertTLange authored Mar 6, 2023
2 parents adece30 + eaa6cb0 commit 76700ee
Show file tree
Hide file tree
Showing 4 changed files with 15 additions and 5 deletions.
4 changes: 4 additions & 0 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,10 @@
- [ ] Wavelet Based Encoding (van Steenkiste, 2016)
- [ ] CNN Hypernetwork (Ha - start with simple MLP)

### [v0.1.2] - [03/2023]

- Fix LES checkpoint loading from package data via `pkgutil`.

### [v0.1.1] - [03/2023]

##### Added
Expand Down
2 changes: 1 addition & 1 deletion evosax/_version.py
Original file line number Diff line number Diff line change
@@ -1 +1 @@
__version__ = "0.1.1"
__version__ = "0.1.2"
4 changes: 3 additions & 1 deletion evosax/strategies/les.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@
from flax import struct
import jax
import jax.numpy as jnp
import pkgutil
from ..utils.learned_eo import (
AttentionWeights,
EvoPathMLP,
Expand Down Expand Up @@ -77,7 +78,8 @@ def __init__(
net_ckpt_path = os.path.join(
os.path.dirname(__file__), f"ckpt/{ckpt_fname}"
)
self.les_net_params = load_pkl_object(net_ckpt_path)
data = pkgutil.get_data(__name__, f"ckpt/{ckpt_fname}")
self.les_net_params = load_pkl_object(data, pkg_load=True)
print(f"Loaded pretrained LES model from ckpt: {ckpt_fname}")

@property
Expand Down
10 changes: 7 additions & 3 deletions evosax/utils/learned_eo.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,4 @@
from typing import Any
import functools
import sys
import chex
Expand All @@ -18,10 +19,13 @@
import pickle


def load_pkl_object(filename: str) -> chex.ArrayTree:
def load_pkl_object(filename: Any, pkg_load: bool = False) -> chex.ArrayTree:
"""Reload pickle objects from path."""
with open(filename, "rb") as input:
obj = pickle.load(input)
if not pkg_load:
with open(filename, "rb") as input:
obj = pickle.load(input)
else:
obj = pickle.loads(filename)
return obj


Expand Down

0 comments on commit 76700ee

Please sign in to comment.