diff --git a/keras/src/models/model.py b/keras/src/models/model.py index ca12459354f..3ee5abd4c72 100644 --- a/keras/src/models/model.py +++ b/keras/src/models/model.py @@ -1,8 +1,10 @@ import inspect +import io import json import typing import warnings +import keras.src.saving.saving_lib as saving_lib from keras.src import backend from keras.src import utils from keras.src.api_export import keras_export @@ -348,6 +350,30 @@ def load_weights(self, filepath, skip_mismatch=False, **kwargs): self, filepath, skip_mismatch=skip_mismatch, **kwargs ) + # Note: renaming this function will cause old pickles to be broken. + # This is probably not a huge deal, as pickle should not be a recommended + # saving format -- it should only be supported for use with distributed + # computing frameworks. + @classmethod + def _unpickle_model(cls, bytesio): + # pickle is not safe regardless of what you do. + return saving_lib._load_model_from_fileobj( + bytesio, custom_objects=None, compile=True, safe_mode=False + ) + + def __reduce__(self): + """__reduce__ is used to customize the behavior of `pickle.pickle()`. + + The method returns a tuple of two elements: a function, and a list of + arguments to pass to that function. In this case we just leverage the + keras saving library.""" + buf = io.BytesIO() + saving_lib._save_model_to_fileobj(self, buf, "h5") + return ( + self._unpickle_model, + (buf,), + ) + def quantize(self, mode): """Quantize the weights of the model. diff --git a/keras/src/models/model_test.py b/keras/src/models/model_test.py index 871fc4bff19..7fa91c5b95d 100644 --- a/keras/src/models/model_test.py +++ b/keras/src/models/model_test.py @@ -1,3 +1,5 @@ +import pickle + import numpy as np import pytest from absl.testing import parameterized @@ -116,6 +118,29 @@ def call(self, x): ) self.assertIsInstance(new_model, Functional) + @parameterized.named_parameters( + ("single_output_1", _get_model_single_output), + ("single_output_2", _get_model_single_output), + ("single_output_3", _get_model_single_output), + ("single_output_4", _get_model_single_output), + ("single_list_output_1", _get_model_single_output_list), + ("single_list_output_2", _get_model_single_output_list), + ("single_list_output_3", _get_model_single_output_list), + ("single_list_output_4", _get_model_single_output_list), + ) + def test_functional_pickling(self, model_fn): + model = model_fn() + self.assertIsInstance(model, Functional) + model.compile() + x = np.random.rand(8, 3) + + reloaded_pickle = pickle.loads(pickle.dumps(model)) + + pred_reloaded = reloaded_pickle.predict(x) + pred = model.predict(x) + + self.assertAllClose(np.array(pred_reloaded), np.array(pred)) + @parameterized.named_parameters( ("single_output_1", _get_model_single_output, None), ("single_output_2", _get_model_single_output, "list"), @@ -138,7 +163,7 @@ def test_functional_single_output(self, model_fn, loss_type): loss = [loss] elif loss_type == "dict": loss = {"output_a": loss} - elif loss_type == "dict_lsit": + elif loss_type == "dict_list": loss = {"output_a": [loss]} model.compile( optimizer="sgd",