From eaa6cb0f594488e95bb2ea2149dba44eada2f1f0 Mon Sep 17 00:00:00 2001 From: RobertTLange Date: Mon, 6 Mar 2023 17:26:52 +0100 Subject: [PATCH] Fix checkpoint loading for LES --- CHANGELOG.md | 4 ++++ evosax/_version.py | 2 +- evosax/strategies/les.py | 4 +++- evosax/utils/learned_eo.py | 10 +++++++--- 4 files changed, 15 insertions(+), 5 deletions(-) diff --git a/CHANGELOG.md b/CHANGELOG.md index 1f16218..06a0950 100755 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -11,6 +11,10 @@ - [ ] Wavelet Based Encoding (van Steenkiste, 2016) - [ ] CNN Hypernetwork (Ha - start with simple MLP) +### [v0.1.2] - [03/2023] + +- Fix LES checkpoint loading from package data via `pkgutil`. + ### [v0.1.1] - [03/2023] ##### Added diff --git a/evosax/_version.py b/evosax/_version.py index 485f44a..b3f4756 100644 --- a/evosax/_version.py +++ b/evosax/_version.py @@ -1 +1 @@ -__version__ = "0.1.1" +__version__ = "0.1.2" diff --git a/evosax/strategies/les.py b/evosax/strategies/les.py index 2a0b7e7..f1f30c1 100644 --- a/evosax/strategies/les.py +++ b/evosax/strategies/les.py @@ -4,6 +4,7 @@ from flax import struct import jax import jax.numpy as jnp +import pkgutil from ..utils.learned_eo import ( AttentionWeights, EvoPathMLP, @@ -77,7 +78,8 @@ def __init__( net_ckpt_path = os.path.join( os.path.dirname(__file__), f"ckpt/{ckpt_fname}" ) - self.les_net_params = load_pkl_object(net_ckpt_path) + data = pkgutil.get_data(__name__, f"ckpt/{ckpt_fname}") + self.les_net_params = load_pkl_object(data, pkg_load=True) print(f"Loaded pretrained LES model from ckpt: {ckpt_fname}") @property diff --git a/evosax/utils/learned_eo.py b/evosax/utils/learned_eo.py index fad40b5..9cdc46d 100644 --- a/evosax/utils/learned_eo.py +++ b/evosax/utils/learned_eo.py @@ -1,3 +1,4 @@ +from typing import Any import functools import sys import chex @@ -18,10 +19,13 @@ import pickle -def load_pkl_object(filename: str) -> chex.ArrayTree: +def load_pkl_object(filename: Any, pkg_load: bool = False) -> chex.ArrayTree: """Reload pickle objects from path.""" - with open(filename, "rb") as input: - obj = pickle.load(input) + if not pkg_load: + with open(filename, "rb") as input: + obj = pickle.load(input) + else: + obj = pickle.loads(filename) return obj