From 1ecbcb87adc20b9be9b566c772349a4fca9a40d3 Mon Sep 17 00:00:00 2001 From: Luke Wood Date: Fri, 19 Apr 2024 01:44:47 -0400 Subject: [PATCH 1/7] Implement unit tests for pickling --- keras/src/models/model.py | 27 +++++++++++++++++++++++++++ keras/src/models/model_test.py | 14 +++++++++++++- 2 files changed, 40 insertions(+), 1 deletion(-) diff --git a/keras/src/models/model.py b/keras/src/models/model.py index ca12459354f..f160bfa8017 100644 --- a/keras/src/models/model.py +++ b/keras/src/models/model.py @@ -2,6 +2,7 @@ import json import typing import warnings +import io from keras.src import backend from keras.src import utils @@ -12,6 +13,8 @@ from keras.src.trainers import trainer as base_trainer from keras.src.utils import summary_utils from keras.src.utils import traceback_utils +import keras.src.saving.saving_lib as saving_lib + if backend.backend() == "tensorflow": from keras.src.backend.tensorflow.trainer import ( @@ -348,6 +351,30 @@ def load_weights(self, filepath, skip_mismatch=False, **kwargs): self, filepath, skip_mismatch=skip_mismatch, **kwargs ) + # Note: renaming this function will cause old pickles to be broken. + # This is probably not a huge deal, as pickle should not be a recommended + # saving format -- it should only be supported for use with distributed + # computing frameworks. + @classmethod + def _depickle_model(cls, bytesio): + # pickle is not safe regardless of what you do. + return saving_lib._load_model_from_fileobj( + bytesio, custom_objects=None, compile=True, safe_mode=False + ) + + def __reduce__(self): + """__reduce__ is used to customize the behavior of `pickle.pickle()`. + + The method returns a tuple of two elements: a function, and a list of + arguments to pass to that function. In this case we just leverage the + keras saving library.""" + buf = io.BytesIO() + saving_lib._save_model_to_fileobj(self, buf, "h5") + return ( + self._depickle_model, + (buf,), + ) + def quantize(self, mode): """Quantize the weights of the model. diff --git a/keras/src/models/model_test.py b/keras/src/models/model_test.py index 871fc4bff19..38e427b687f 100644 --- a/keras/src/models/model_test.py +++ b/keras/src/models/model_test.py @@ -9,6 +9,7 @@ from keras.src.models.functional import Functional from keras.src.models.model import Model from keras.src.models.model import model_from_json +import pickle def _get_model(): @@ -138,7 +139,7 @@ def test_functional_single_output(self, model_fn, loss_type): loss = [loss] elif loss_type == "dict": loss = {"output_a": loss} - elif loss_type == "dict_lsit": + elif loss_type == "dict_list": loss = {"output_a": [loss]} model.compile( optimizer="sgd", @@ -171,6 +172,17 @@ def test_functional_single_output(self, model_fn, loss_type): ) self.assertListEqual(hist_keys, ref_keys) + reloaded_pickle = pickle.loads(pickle.dumps(model)) + # self.assertAllClose fails for some dtypes + + pred_reloaded = reloaded_pickle.predict(x) + pred = model.predict(x) + if isinstance(pred, dict): + for key in pred: + np.testing.assert_allclose(pred_reloaded[key], pred[key]) + else: + np.testing.assert_allclose(pred_reloaded, pred) + def test_functional_list_outputs_list_losses(self): model = _get_model_multi_outputs_list() self.assertIsInstance(model, Functional) From de7f76252c69c9f2d87b01dcf7f13b1b08febf1a Mon Sep 17 00:00:00 2001 From: Luke Wood Date: Fri, 19 Apr 2024 01:48:05 -0400 Subject: [PATCH 2/7] Reformat model_test --- keras/src/models/model_test.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/keras/src/models/model_test.py b/keras/src/models/model_test.py index 38e427b687f..81f2d898ffd 100644 --- a/keras/src/models/model_test.py +++ b/keras/src/models/model_test.py @@ -174,11 +174,11 @@ def test_functional_single_output(self, model_fn, loss_type): reloaded_pickle = pickle.loads(pickle.dumps(model)) # self.assertAllClose fails for some dtypes - + pred_reloaded = reloaded_pickle.predict(x) pred = model.predict(x) if isinstance(pred, dict): - for key in pred: + for key in pred: np.testing.assert_allclose(pred_reloaded[key], pred[key]) else: np.testing.assert_allclose(pred_reloaded, pred) From 9ae1a60b36bc8b61b799a749ac8bb52dbda01346 Mon Sep 17 00:00:00 2001 From: Luke Wood Date: Fri, 19 Apr 2024 01:51:52 -0400 Subject: [PATCH 3/7] Reformat model_test --- keras/src/models/model.py | 5 ++--- keras/src/models/model_test.py | 6 ++++-- 2 files changed, 6 insertions(+), 5 deletions(-) diff --git a/keras/src/models/model.py b/keras/src/models/model.py index f160bfa8017..dae91275b16 100644 --- a/keras/src/models/model.py +++ b/keras/src/models/model.py @@ -1,9 +1,10 @@ import inspect +import io import json import typing import warnings -import io +import keras.src.saving.saving_lib as saving_lib from keras.src import backend from keras.src import utils from keras.src.api_export import keras_export @@ -13,8 +14,6 @@ from keras.src.trainers import trainer as base_trainer from keras.src.utils import summary_utils from keras.src.utils import traceback_utils -import keras.src.saving.saving_lib as saving_lib - if backend.backend() == "tensorflow": from keras.src.backend.tensorflow.trainer import ( diff --git a/keras/src/models/model_test.py b/keras/src/models/model_test.py index 81f2d898ffd..9d42effad2b 100644 --- a/keras/src/models/model_test.py +++ b/keras/src/models/model_test.py @@ -1,3 +1,5 @@ +import pickle + import numpy as np import pytest from absl.testing import parameterized @@ -9,7 +11,6 @@ from keras.src.models.functional import Functional from keras.src.models.model import Model from keras.src.models.model import model_from_json -import pickle def _get_model(): @@ -173,10 +174,11 @@ def test_functional_single_output(self, model_fn, loss_type): self.assertListEqual(hist_keys, ref_keys) reloaded_pickle = pickle.loads(pickle.dumps(model)) - # self.assertAllClose fails for some dtypes pred_reloaded = reloaded_pickle.predict(x) pred = model.predict(x) + + # self.assertAllClose fails for some dtypes, so we use np if isinstance(pred, dict): for key in pred: np.testing.assert_allclose(pred_reloaded[key], pred[key]) From 1ccee9d9826f2480c9676a47c6b502aa77d83b4e Mon Sep 17 00:00:00 2001 From: Luke Wood Date: Fri, 19 Apr 2024 10:50:26 -0400 Subject: [PATCH 4/7] Rename depickle to unpickle --- keras/src/models/model.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/keras/src/models/model.py b/keras/src/models/model.py index dae91275b16..3ee5abd4c72 100644 --- a/keras/src/models/model.py +++ b/keras/src/models/model.py @@ -355,7 +355,7 @@ def load_weights(self, filepath, skip_mismatch=False, **kwargs): # saving format -- it should only be supported for use with distributed # computing frameworks. @classmethod - def _depickle_model(cls, bytesio): + def _unpickle_model(cls, bytesio): # pickle is not safe regardless of what you do. return saving_lib._load_model_from_fileobj( bytesio, custom_objects=None, compile=True, safe_mode=False @@ -370,7 +370,7 @@ def __reduce__(self): buf = io.BytesIO() saving_lib._save_model_to_fileobj(self, buf, "h5") return ( - self._depickle_model, + self._unpickle_model, (buf,), ) From 3c8e6fd4f97622ff8f174e74eab7df735c5dbae9 Mon Sep 17 00:00:00 2001 From: Luke Wood Date: Fri, 19 Apr 2024 16:40:11 -0400 Subject: [PATCH 5/7] Rename depickle to unpickle --- keras/src/models/model_test.py | 37 +++++++++++++++++++++++----------- 1 file changed, 25 insertions(+), 12 deletions(-) diff --git a/keras/src/models/model_test.py b/keras/src/models/model_test.py index 9d42effad2b..7257d8b4fdb 100644 --- a/keras/src/models/model_test.py +++ b/keras/src/models/model_test.py @@ -118,6 +118,31 @@ def call(self, x): ) self.assertIsInstance(new_model, Functional) + @parameterized.named_parameters( + ("single_output_1", _get_model_single_output), + ("single_output_2", _get_model_single_output), + ("single_output_3", _get_model_single_output), + ("single_output_4", _get_model_single_output), + ("single_list_output_1", _get_model_single_output_list), + ("single_list_output_2", _get_model_single_output_list), + ("single_list_output_3", _get_model_single_output_list), + ("single_list_output_4", _get_model_single_output_list), + ) + def test_functional_pickling(self, model_fn): + model = model_fn() + self.assertIsInstance(model, Functional) + model.compile() + x = np.random.rand(8, 3) + + reloaded_pickle = pickle.loads(pickle.dumps(model)) + + pred_reloaded = reloaded_pickle.predict(x) + pred = model.predict(x) + + # self.assertAllClose fails for some dtypes, so we use np + self.assertAllClose(np.array(pred_reloaded), np.array(pred)) + + @parameterized.named_parameters( ("single_output_1", _get_model_single_output, None), ("single_output_2", _get_model_single_output, "list"), @@ -173,18 +198,6 @@ def test_functional_single_output(self, model_fn, loss_type): ) self.assertListEqual(hist_keys, ref_keys) - reloaded_pickle = pickle.loads(pickle.dumps(model)) - - pred_reloaded = reloaded_pickle.predict(x) - pred = model.predict(x) - - # self.assertAllClose fails for some dtypes, so we use np - if isinstance(pred, dict): - for key in pred: - np.testing.assert_allclose(pred_reloaded[key], pred[key]) - else: - np.testing.assert_allclose(pred_reloaded, pred) - def test_functional_list_outputs_list_losses(self): model = _get_model_multi_outputs_list() self.assertIsInstance(model, Functional) From 1c87cc318d5de0629d3aa51098ee7d4260546cb7 Mon Sep 17 00:00:00 2001 From: Luke Wood Date: Fri, 19 Apr 2024 16:41:04 -0400 Subject: [PATCH 6/7] Reformat --- keras/src/models/model_test.py | 1 - 1 file changed, 1 deletion(-) diff --git a/keras/src/models/model_test.py b/keras/src/models/model_test.py index 7257d8b4fdb..7f389197930 100644 --- a/keras/src/models/model_test.py +++ b/keras/src/models/model_test.py @@ -142,7 +142,6 @@ def test_functional_pickling(self, model_fn): # self.assertAllClose fails for some dtypes, so we use np self.assertAllClose(np.array(pred_reloaded), np.array(pred)) - @parameterized.named_parameters( ("single_output_1", _get_model_single_output, None), ("single_output_2", _get_model_single_output, "list"), From 0994d27a79103700031f1761b78a76926a0bd63c Mon Sep 17 00:00:00 2001 From: Luke Wood Date: Fri, 19 Apr 2024 16:45:04 -0400 Subject: [PATCH 7/7] remove a comment --- keras/src/models/model_test.py | 1 - 1 file changed, 1 deletion(-) diff --git a/keras/src/models/model_test.py b/keras/src/models/model_test.py index 7f389197930..7fa91c5b95d 100644 --- a/keras/src/models/model_test.py +++ b/keras/src/models/model_test.py @@ -139,7 +139,6 @@ def test_functional_pickling(self, model_fn): pred_reloaded = reloaded_pickle.predict(x) pred = model.predict(x) - # self.assertAllClose fails for some dtypes, so we use np self.assertAllClose(np.array(pred_reloaded), np.array(pred)) @parameterized.named_parameters(