Skip to content

Commit

Permalink
MAINT/ENH: SaveModel based serialization (#128)
Browse files Browse the repository at this point in the history
Co-authored-by: Scott <[email protected]>
  • Loading branch information
adriangb and stsievert committed Jan 16, 2021
1 parent 739fd1d commit 08f8239
Show file tree
Hide file tree
Showing 6 changed files with 276 additions and 87 deletions.
12 changes: 6 additions & 6 deletions scikeras/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,12 +3,12 @@
__author__ = """Adrian Garcia Badaracco"""
__version__ = "0.2.1"

from tensorflow import keras as _keras

# Monkey patch log_cosh reference
# See https://github.com/tensorflow/tensorflow/pull/42097
# Will be removed whenever the
# min supported version of tf incorporates the fix
from tensorflow.python import keras # noqa
from scikeras import _saving_utils


keras.metrics.log_cosh = keras.metrics.logcosh
_keras.Model.__reduce__ = _saving_utils.pack_keras_model
_keras.losses.Loss.__reduce__ = _saving_utils.pack_keras_loss
_keras.metrics.Metric.__reduce__ = _saving_utils.pack_keras_metric
_keras.optimizers.Optimizer.__reduce__ = _saving_utils.pack_keras_optimizer
140 changes: 140 additions & 0 deletions scikeras/_saving_utils.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,140 @@
import os
import tempfile
import zipfile

from io import BytesIO
from types import MethodType
from uuid import uuid4 as uuid

import numpy as np

from tensorflow import io as tf_io
from tensorflow import keras
from tensorflow.keras.models import load_model


def _get_temp_folder():
if os.name == "nt":
# the RAM-based filesystem is not fully supported on
# Windows yet, we save to a temp folder on disk instead
return tempfile.mkdtemp()
else:
return f"ram://{tempfile.mkdtemp()}"


def _temp_create_all_weights(self, var_list):
"""A hack to restore weights in optimizers that use slots.
See https://tensorflow.org/api_docs/python/tf/keras/optimizers/Optimizer#slots_2
"""
self._create_all_weights_orig(var_list)
try:
self.set_weights(self._restored_weights)
except ValueError:
# Weights don't match, eg. when optimizer was pickled before any training
# or a completely new dataset is being used right after pickling
pass
delattr(self, "_restored_weights")
self._create_all_weights = self._create_all_weights_orig


def _restore_optimizer_weights(optimizer, weights) -> None:
optimizer._restored_weights = weights
optimizer._create_all_weights_orig = optimizer._create_all_weights
# MethodType is used to "bind" the _temp_create_all_weights method
# to the "live" optimizer object
optimizer._create_all_weights = MethodType(_temp_create_all_weights, optimizer)


def unpack_keras_model(packed_keras_model, optimizer_weights):
"""Reconstruct a model from the result of __reduce__
"""
temp_dir = _get_temp_folder()
b = BytesIO(packed_keras_model)
with zipfile.ZipFile(b, "r", zipfile.ZIP_DEFLATED) as zf:
for path in zf.namelist():
dest = os.path.join(temp_dir, path)
tf_io.gfile.makedirs(os.path.dirname(dest))
with tf_io.gfile.GFile(dest, "wb") as f:
f.write(zf.read(path))
model: keras.Model = load_model(temp_dir)
for root, _, filenames in tf_io.gfile.walk(temp_dir):
for filename in filenames:
if filename.startswith("ram://"):
# Currently, tf.io.gfile.walk returns
# the entire path for the ram:// filesystem
dest = filename
else:
dest = os.path.join(root, filename)
tf_io.gfile.remove(dest)
_restore_optimizer_weights(model.optimizer, optimizer_weights)
return model


def pack_keras_model(model):
"""Support for Pythons's Pickle protocol.
"""
temp_dir = _get_temp_folder()
model.save(temp_dir)
b = BytesIO()
with zipfile.ZipFile(b, "w", zipfile.ZIP_DEFLATED) as zf:
for root, _, filenames in tf_io.gfile.walk(temp_dir):
for filename in filenames:
if filename.startswith("ram://"):
# Currently, tf.io.gfile.walk returns
# the entire path for the ram:// filesystem
dest = filename
else:
dest = os.path.join(root, filename)
with tf_io.gfile.GFile(dest, "rb") as f:
zf.writestr(os.path.relpath(dest, temp_dir), f.read())
tf_io.gfile.remove(dest)
b.seek(0)
return (
unpack_keras_model,
(np.asarray(memoryview(b.read())), model.optimizer.get_weights()),
)


def unpack_keras_optimizer(opt_serialized, weights):
"""Reconstruct optimizer.
"""
optimizer: keras.optimizers.Optimizer = keras.optimizers.deserialize(opt_serialized)
_restore_optimizer_weights(optimizer, weights)
return optimizer


def pack_keras_optimizer(optimizer: keras.optimizers.Optimizer):
"""Support for Pythons's Pickle protocol in Keras Optimizers.
"""
opt_serialized = keras.optimizers.serialize(optimizer)
weights = optimizer.get_weights()
return unpack_keras_optimizer, (opt_serialized, weights)


def unpack_keras_metric(metric_serialized):
"""Reconstruct metric.
"""
metric: keras.metrics.Metric = keras.metrics.deserialize(metric_serialized)
return metric


def pack_keras_metric(metric: keras.metrics.Metric):
"""Support for Pythons's Pickle protocol in Keras Metrics.
"""
metric_serialized = keras.metrics.serialize(metric)
return unpack_keras_metric, (metric_serialized,)


def unpack_keras_loss(loss_serialized):
"""Reconstruct loss.
"""
loss: keras.losses.Loss = keras.losses.deserialize(loss_serialized)
return loss


def pack_keras_loss(loss: keras.losses.Loss):
"""Support for Pythons's Pickle protocol in Keras Losses.
"""
loss_serialized = keras.losses.serialize(loss)
return unpack_keras_loss, (loss_serialized,)
59 changes: 1 addition & 58 deletions scikeras/_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,15 +4,11 @@
import warnings

from inspect import isclass
from typing import Any, Callable, Dict, Iterable, List, Union
from typing import Any, Callable, Dict, Iterable, Union

import numpy as np
import tensorflow as tf

from tensorflow.keras.layers import deserialize as deserialize_layer
from tensorflow.keras.layers import serialize as serialize_layer
from tensorflow.python.keras.saving import saving_utils


class TFRandomState:
def __init__(self, seed):
Expand Down Expand Up @@ -53,59 +49,6 @@ def __exit__(self, type, value, traceback):
tf.random.set_seed(None) # TODO: can we revert instead of unset?


def unpack_keras_model(model, training_config, weights):
"""Creates a new Keras model object using the input
parameters.
Returns
-------
Model
A copy of the input Keras Model,
compiled if the original was compiled.
"""
restored_model = deserialize_layer(model)
if training_config is not None:
restored_model.compile(
**saving_utils.compile_args_from_training_config(training_config)
)
restored_model.set_weights(weights)
restored_model.__reduce_ex__ = pack_keras_model.__get__(restored_model)
return restored_model


def pack_keras_model(model_obj, protocol):
"""Pickle a Keras Model.
Arguments:
model_obj: an instance of a Keras Model.
protocol: pickle protocol version, ignored.
Returns
-------
Pickled model
A tuple following the pickle protocol.
"""
model_metadata = saving_utils.model_metadata(model_obj)
training_config = model_metadata.get("training_config", None)
model = serialize_layer(model_obj)
weights = model_obj.get_weights()
return (unpack_keras_model, (model, training_config, weights))


def make_model_picklable(model_obj):
"""Makes a Keras Model object picklable without cloning.
Arguments:
model_obj: an instance of a Keras Model.
Returns
-------
Model
The input model, but directly picklable.
"""
model_obj.__reduce_ex__ = pack_keras_model.__get__(model_obj)


def route_params(
params: Dict[str, Any],
destination: str,
Expand Down
12 changes: 1 addition & 11 deletions scikeras/wrappers.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,14 +22,13 @@
from tensorflow.keras import metrics as metrics_module
from tensorflow.keras import optimizers as optimizers_module
from tensorflow.keras.models import Model
from tensorflow.python.keras.utils.generic_utils import register_keras_serializable
from tensorflow.keras.utils import register_keras_serializable

from scikeras._utils import (
TFRandomState,
_class_from_strings,
accepts_kwargs,
has_param,
make_model_picklable,
route_params,
unflatten_params,
)
Expand Down Expand Up @@ -217,12 +216,6 @@ def __init__(
**kwargs,
):

# ensure prebuilt model can be serialized
if isinstance(model, Model):
make_model_picklable(model)
if isinstance(build_fn, Model):
make_model_picklable(build_fn)

# Parse hardcoded params
self.model = model
self.build_fn = build_fn
Expand Down Expand Up @@ -425,9 +418,6 @@ def _build_keras_model(self):
else:
model = final_build_fn(**build_params)

# make serializable
make_model_picklable(model)

# compile model if user gave us an un-compiled model
if not (hasattr(model, "loss") and hasattr(model, "optimizer")):
if compile_kwargs is None:
Expand Down
3 changes: 1 addition & 2 deletions tests/test_api.py
Original file line number Diff line number Diff line change
Expand Up @@ -413,8 +413,7 @@ def test_warm_start():

@register_keras_serializable(name="CustomMetric")
class CustomMetric(metrics_module.MeanAbsoluteError):
def __reduce__(self):
return metrics_module.deserialize, (metrics_module.serialize(self),)
pass


class TestPartialFit:
Expand Down
Loading

0 comments on commit 08f8239

Please sign in to comment.