Add pickle support for Keras model #19555

Merged · 7 commits · Apr 20, 2024
26 changes: 26 additions & 0 deletions keras/src/models/model.py
@@ -1,8 +1,10 @@
import inspect
import io
import json
import typing
import warnings

import keras.src.saving.saving_lib as saving_lib
from keras.src import backend
from keras.src import utils
from keras.src.api_export import keras_export
@@ -348,6 +350,30 @@ def load_weights(self, filepath, skip_mismatch=False, **kwargs):
self, filepath, skip_mismatch=skip_mismatch, **kwargs
)

# Note: renaming this function will cause old pickles to be broken.
# This is probably not a huge deal, as pickle should not be a recommended
# saving format -- it should only be supported for use with distributed
# computing frameworks.
@classmethod
def _unpickle_model(cls, bytesio):
# pickle is not safe regardless of what you do.
return saving_lib._load_model_from_fileobj(
bytesio, custom_objects=None, compile=True, safe_mode=False
)

def __reduce__(self):
"""__reduce__ is used to customize the behavior of `pickle.pickle()`.

The method returns a tuple of two elements: a function, and a tuple of
arguments to pass to that function. In this case we just leverage the
keras saving library."""
buf = io.BytesIO()
saving_lib._save_model_to_fileobj(self, buf, "h5")
return (
self._unpickle_model,
(buf,),
)

def quantize(self, mode):
"""Quantize the weights of the model.

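For context on the hook above: a plain pickle round trip now works on any Keras model, because pickle.dumps() triggers Model.__reduce__ (H5 serialization into an in-memory buffer) and pickle.loads() calls Model._unpickle_model. A minimal sketch of the resulting behavior, not part of the diff (the toy architecture below is arbitrary):

import pickle

import numpy as np
import keras

# Any Keras model works the same way; a tiny functional model keeps it short.
inputs = keras.Input(shape=(3,))
model = keras.Model(inputs, keras.layers.Dense(1)(inputs))
model.compile(optimizer="sgd", loss="mse")

# Round trip through pickle and confirm predictions are unchanged.
restored = pickle.loads(pickle.dumps(model))
x = np.random.rand(8, 3)
np.testing.assert_allclose(
    model.predict(x, verbose=0), restored.predict(x, verbose=0)
)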
27 changes: 26 additions & 1 deletion keras/src/models/model_test.py
@@ -1,3 +1,5 @@
import pickle

import numpy as np
import pytest
from absl.testing import parameterized
@@ -116,6 +118,29 @@ def call(self, x):
)
self.assertIsInstance(new_model, Functional)

@parameterized.named_parameters(
("single_output_1", _get_model_single_output),
("single_output_2", _get_model_single_output),
("single_output_3", _get_model_single_output),
("single_output_4", _get_model_single_output),
("single_list_output_1", _get_model_single_output_list),
("single_list_output_2", _get_model_single_output_list),
("single_list_output_3", _get_model_single_output_list),
("single_list_output_4", _get_model_single_output_list),
)
def test_functional_pickling(self, model_fn):
model = model_fn()
self.assertIsInstance(model, Functional)
model.compile()
x = np.random.rand(8, 3)

reloaded_pickle = pickle.loads(pickle.dumps(model))

pred_reloaded = reloaded_pickle.predict(x)
pred = model.predict(x)

self.assertAllClose(np.array(pred_reloaded), np.array(pred))

@parameterized.named_parameters(
("single_output_1", _get_model_single_output, None),
("single_output_2", _get_model_single_output, "list"),
@@ -138,7 +163,7 @@ def test_functional_single_output(self, model_fn, loss_type):
loss = [loss]
elif loss_type == "dict":
loss = {"output_a": loss}
elif loss_type == "dict_lsit":
elif loss_type == "dict_list":
loss = {"output_a": [loss]}
model.compile(
optimizer="sgd",
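The comment in model.py frames pickle support as intended for distributed computing frameworks rather than as a recommended saving format. A hypothetical sketch of that use case, assuming a standard-library process pool that ships arguments to workers via pickle (the helper name predict_in_worker is illustrative, not from the PR):

import multiprocessing as mp
from concurrent.futures import ProcessPoolExecutor

import numpy as np
import keras


def predict_in_worker(model, x):
    # The model reaches the worker process via pickle, i.e. through the
    # __reduce__ / _unpickle_model pair added in this PR.
    return model.predict(x, verbose=0)


if __name__ == "__main__":
    inputs = keras.Input(shape=(3,))
    model = keras.Model(inputs, keras.layers.Dense(1)(inputs))
    x = np.random.rand(4, 3)
    # "spawn" is a conservative assumption to avoid fork-related issues with
    # some backends; the PR itself does not prescribe any particular setup.
    ctx = mp.get_context("spawn")
    with ProcessPoolExecutor(max_workers=1, mp_context=ctx) as pool:
        print(pool.submit(predict_in_worker, model, x).result().shape)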