Skip to content

Commit

Permalink
Add pickle support for Keras model (#19555)
Browse the repository at this point in the history
* Implement unit tests for pickling

* Reformat model_test

* Reformat model_test

* Rename depickle to unpickle

* Rename depickle to unpickle

* Reformat

* remove a comment
  • Loading branch information
LukeWood authored Apr 20, 2024
1 parent 29d10d1 commit 6e42834
Show file tree
Hide file tree
Showing 2 changed files with 52 additions and 1 deletion.
26 changes: 26 additions & 0 deletions keras/src/models/model.py
Original file line number Diff line number Diff line change
@@ -1,8 +1,10 @@
import inspect
import io
import json
import typing
import warnings

import keras.src.saving.saving_lib as saving_lib
from keras.src import backend
from keras.src import utils
from keras.src.api_export import keras_export
Expand Down Expand Up @@ -348,6 +350,30 @@ def load_weights(self, filepath, skip_mismatch=False, **kwargs):
self, filepath, skip_mismatch=skip_mismatch, **kwargs
)

# Note: renaming this function will cause old pickles to be broken.
# This is probably not a huge deal, as pickle should not be a recommended
# saving format -- it should only be supported for use with distributed
# computing frameworks.
@classmethod
def _unpickle_model(cls, bytesio):
# pickle is not safe regardless of what you do.
return saving_lib._load_model_from_fileobj(
bytesio, custom_objects=None, compile=True, safe_mode=False
)

    def __reduce__(self):
        """Customize how `pickle.dumps()` serializes this model.

        Returns a 2-tuple ``(callable, args)``; at load time pickle calls
        ``callable(*args)`` to reconstruct the object. Here we serialize
        the entire model into an in-memory HDF5 buffer via the Keras
        saving library and register `_unpickle_model` as the reconstructor.
        """
        buf = io.BytesIO()
        # Write the full model (architecture + weights) into the buffer
        # using the "h5" format.
        saving_lib._save_model_to_fileobj(self, buf, "h5")
        # NOTE(review): `buf`'s stream position is at the end of the data
        # here; reconstruction presumably relies on the reader seeking as
        # needed — confirm against `_load_model_from_fileobj`.
        return (
            self._unpickle_model,
            (buf,),
        )

def quantize(self, mode):
"""Quantize the weights of the model.
Expand Down
27 changes: 26 additions & 1 deletion keras/src/models/model_test.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,5 @@
import pickle

import numpy as np
import pytest
from absl.testing import parameterized
Expand Down Expand Up @@ -116,6 +118,29 @@ def call(self, x):
)
self.assertIsInstance(new_model, Functional)

@parameterized.named_parameters(
("single_output_1", _get_model_single_output),
("single_output_2", _get_model_single_output),
("single_output_3", _get_model_single_output),
("single_output_4", _get_model_single_output),
("single_list_output_1", _get_model_single_output_list),
("single_list_output_2", _get_model_single_output_list),
("single_list_output_3", _get_model_single_output_list),
("single_list_output_4", _get_model_single_output_list),
)
def test_functional_pickling(self, model_fn):
model = model_fn()
self.assertIsInstance(model, Functional)
model.compile()
x = np.random.rand(8, 3)

reloaded_pickle = pickle.loads(pickle.dumps(model))

pred_reloaded = reloaded_pickle.predict(x)
pred = model.predict(x)

self.assertAllClose(np.array(pred_reloaded), np.array(pred))

@parameterized.named_parameters(
("single_output_1", _get_model_single_output, None),
("single_output_2", _get_model_single_output, "list"),
Expand All @@ -138,7 +163,7 @@ def test_functional_single_output(self, model_fn, loss_type):
loss = [loss]
elif loss_type == "dict":
loss = {"output_a": loss}
elif loss_type == "dict_lsit":
elif loss_type == "dict_list":
loss = {"output_a": [loss]}
model.compile(
optimizer="sgd",
Expand Down

0 comments on commit 6e42834

Please sign in to comment.