
Make keras.Model picklable #14748

Merged (22 commits) Jul 16, 2021

Conversation

adriangb (Contributor):
An attempt at porting over from tensorflow/tensorflow#39609.

Is this something that should now go here (in this repo)?

Thanks

@google-cla google-cla bot added the cla: yes label Jun 18, 2021
fchollet (Member) left a comment:
Thanks for the PR. This is valuable functionality.

    Returns:
        keras.Model: a Keras Model instance.
    """
    temp_dir = f"ram://{uuid4()}"
fchollet (Member):

Will this work across all systems?

adriangb (Contributor, author):

As per tensorflow/tensorflow#48086, there is currently a bug in TF that makes this fail on Windows. It looks like the current momentum is to fix it via tensorflow/tensorflow#48125, which is getting close to being ready.

So for now this would only work on *nix, but once the bug in TF is fixed via that PR or otherwise, this should work on all platforms without any change to this PR / Keras.

from keras.saving.save import load_model


def unpack_model(packed_keras_model):
fchollet (Member):

Let's use the names serialize_as_bytecode and deserialize_from_bytecode to be more explicit.

adriangb (Contributor, author):

Sounds great, thank you for the suggestion
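The pack/unpack idea under discussion — save the model to a scratch location, then slurp the resulting files back as raw bytes — can be sketched without TensorFlow. This is an illustrative stand-in only: the function names and the zip container are my assumptions, while the PR itself saves via SavedModel into TF's in-memory `ram://` filesystem.

```python
import io
import os
import zipfile


def serialize_dir_as_bytes(dir_path):
    # Pack every file under dir_path (e.g. a SavedModel directory)
    # into an in-memory zip archive and return the raw bytes.
    buf = io.BytesIO()
    with zipfile.ZipFile(buf, "w") as archive:
        for root, _, files in os.walk(dir_path):
            for name in files:
                full = os.path.join(root, name)
                archive.write(full, os.path.relpath(full, dir_path))
    return buf.getvalue()


def deserialize_bytes_to_dir(data, dir_path):
    # Unpack bytes produced above back into a directory tree,
    # from which the model could then be reloaded.
    with zipfile.ZipFile(io.BytesIO(data)) as archive:
        archive.extractall(dir_path)
```

Once a directory tree can round-trip through bytes, pickling reduces to handing those bytes to a reconstructor function.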

"""Tests pickle protocol support."""

@keras_parameterized.run_all_keras_modes
fchollet (Member):

This parameterization isn't useful here.

adriangb (Contributor, author):

Ok, removed


@keras_parameterized.run_all_keras_modes
def test_pickle_model(self):
"""Test copy.copy, copy.deepcopy and pickle on Functional Model."""
fchollet (Member):

Let's test all model types (Sequential, Functional, subclass). We have a parameterization for that. @keras_parameterized.run_with_all_model_types. See examples.

adriangb (Contributor, author):

I attempted to incorporate this parametrization based on other examples, but I'm not 100% sure I got it right.

adriangb (Contributor, author):

It seems to be working, I'm seeing tests run for sequential, subclass, etc.
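The round-trip the test exercises can be sketched as a small helper. This is a guess at the shape of the test's `roundtrip` utility (the real test applies it to Keras models built by the parameterization, not plain objects):

```python
import copy
import pickle


def roundtrip(obj):
    # Chain copy.copy, copy.deepcopy and every pickle protocol, so a
    # single pass exercises all supported duplication paths.
    obj = copy.copy(obj)
    obj = copy.deepcopy(obj)
    for protocol in range(pickle.HIGHEST_PROTOCOL + 1):
        obj = pickle.loads(pickle.dumps(obj, protocol=protocol))
    return obj
```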

fchollet (Member) left a comment:

LGTM, thank you!

@fchollet fchollet added the ready to pull Ready to be merged into the codebase label Jun 20, 2021
@fchollet fchollet added ready to pull Ready to be merged into the codebase and removed ready to pull Ready to be merged into the codebase labels Jun 20, 2021
fchollet (Member):

I'm actually seeing an error in the subclass model case:

Traceback (most recent call last):
  File "<embedded stdlib>/copyreg.py", line 69, in _reduce_ex
    getstate = self.__getstate__
AttributeError: 'ObjectIdentityDictionary' object has no attribute '__getstate__'

During handling of the above exception, another exception occurred:

Traceback (most recent call last):
  File "absl/testing/parameterized.py", line 316, in bound_param_test
    return test_method(self, *testcase_params)
  File "keras/keras_parameterized.py", line 284, in decorated
    _test_subclass_model_type(f, self, *args, **kwargs)
  File "keras/keras_parameterized.py", line 301, in _test_subclass_model_type
    f(test_or_class, *args, **kwargs)
  File "keras/saving/pickle_utils_test.py", line 47, in test_pickle_model
    model = roundtrip(original_model)
  File "keras/saving/pickle_utils_test.py", line 38, in roundtrip
    model = pickle.loads(pickle.dumps(model, protocol=protocol))
  File "<embedded stdlib>/copyreg.py", line 72, in _reduce_ex
    raise TypeError("a class that defines __slots__ without "
TypeError: a class that defines __slots__ without defining __getstate__ cannot be pickled

ObjectIdentityDictionary appears to be a class defined in tensorflow/python/util/object_identity.py. It may be that we have to fix it there first before we can proceed with this PR.

fchollet (Member):

Correction: this object is actually replicated in keras/utils/object_identity.py.

adriangb (Contributor, author) commented Jun 22, 2021:

Yeah, I see that as well.

Would you propose implementing __{get,set}state__ for those objects? If it's just those, that seems reasonable.

I think another alternative might be to require pickle protocol >=3 (default as of Python 3.4 I believe); I'll test this and report back.

fchollet (Member):

> Would you propose implementing __{get,set}state__ for those objects? If it's just those, that seems reasonable.

Yes, I think it would be straightforward. We'd have to do it for ObjectIdentityDictionary, _ObjectIdentityWrapper, ObjectIdentitySet. Some have weakrefs so it will require a little bit of thinking but still very straightforward.

Changes would have to be replicated in the TF versions of these objects in a separate PR (for consistency).

> I think another alternative might be to require pickle protocol >=3 (default as of Python 3.4 I believe); I'll test this and report back.

That is fine too if that works.
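For a `__slots__` class like ObjectIdentityDictionary, the state hooks fchollet describes would look roughly like this. `SlottedMapping` is a toy stand-in of my own; the real objects also wrap keys for identity semantics and some hold weakrefs, which need extra care.

```python
import pickle


class SlottedMapping:
    # Toy stand-in for ObjectIdentityDictionary: because it defines
    # __slots__ (and hence has no __dict__), older pickle protocols
    # need explicit state hooks to serialize it.
    __slots__ = ("_storage",)

    def __init__(self):
        self._storage = {}

    def __setitem__(self, key, value):
        self._storage[key] = value

    def __getitem__(self, key):
        return self._storage[key]

    def __getstate__(self):
        # Return the slot contents explicitly, since there is no __dict__.
        return {"_storage": self._storage}

    def __setstate__(self, state):
        self._storage = state["_storage"]


m = SlottedMapping()
m["lr"] = 0.01
restored = pickle.loads(pickle.dumps(m, protocol=1))
```

Without the two hooks, protocol-1 pickling of such a class is exactly what raises the "a class that defines __slots__ without defining __getstate__ cannot be pickled" error from the traceback above.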

@fchollet fchollet self-assigned this Jun 22, 2021
adriangb (Contributor, author):
> require pickle protocol >=3

> implementing __{get,set}state__ for those objects

Both of these solutions led to the same problem: somewhere, a weakref object is being pickled (or an attempt is made to pickle one). I can't tell where, because this happens inside cPickle, so there is no traceback. I haven't had any luck setting up a test.py to debug manually in pdb or otherwise; it seems that some protobuf compiling and such is needed, which I guess Bazel does automatically.

But maybe let's think higher level for a second: this is only happening for subclassed models, and only for untrained models (as per this PR, if Model.built is False copying/pickling is delegated to object because SavedModel doesn't support unbuilt models; please correct me if this is wrong). Why would subclassed models behave any differently than non-subclassed models? That is, I'd expect:

class MyModel(keras.Model):
    ...

To behave just like keras.Model. So there must be some other stuff going on with the subclassing?

fchollet (Member):

> Both of these solutions led to the same problem

Did you apply the fix to all 3 objects I listed? I really do expect that 3 objects are the only problem.

> But maybe let's think higher level for a second: this is only happening for subclassed models, and only for untrained models (as per this PR, if Model.built is False copying/pickling is delegated to object because SavedModel doesn't support unbuilt models

I would suggest simply raising a clear error message when people try to pickle an unbuilt model. That would solve the issue.

fchollet (Member) commented Jul 4, 2021

adriangb (Contributor, author) commented Jul 4, 2021:

Apologies, I didn't run the entire test suite before pushing.

It looks like the failure is coming from here:

class MyModel(keras.Model):
    def __init__(self):
        super(MyModel, self).__init__()
        self.my_variable = tf.Variable(0.0, trainable=False)
        self.layer = keras.layers.Dense(4)

    def call(self, obs):
        return self.layer(obs)

model = MyModel()
model.my_variable.assign_add(1.0)
new_model = copy.deepcopy(model)

Please correct me if I am misinterpreting the logs.

The above test instantiates a subclassed model, then tries to copy.deepcopy it without fitting or calling .build.
Before this PR, that would implicitly fall back to object.__reduce__, but with this PR it is now routed to SavedModel.
But as discussed above, SavedModel does not support unbuilt models.
And, also as discussed in #14748 (comment), object.__reduce__ does not support all types of models.

In my opinion, the existing test is not great; it is testing an API that only partially works, and was never explicitly supported/implemented. I would remove that test and only explicitly support built models, as discussed in #14748 (comment). But perhaps I am missing some of the finer details.

What do you think @fchollet ?

fchollet (Member) commented Jul 5, 2021:

> Before this PR, that would implicitly fall back to object.__reduce__, but with this PR it is now routed to SavedModel.

Can't we just fall back to object.__reduce__ when the model is unbuilt? That way we preserve the existing behavior, while adding robust support for pickling / copying built models.

It would be bad practice to break an existing, tested behavior. Given how extensively Keras is used at Google this would be pretty much guaranteed to break some internal builds, which we'd have to resolve on our end before merging the PR, which could significantly delay merging the PR.

adriangb (Contributor, author) commented Jul 5, 2021:

> Can't we just fall back to object.__reduce__ when the model is unbuilt?

That is what we were doing, but object.__reduce__ doesn't support models that have those weakref-based wrappers (which I guess are part of the parametrized tests but not the existing copy.copy test that is now failing), which is why back at #14748 (comment) we had decided to only support built models.

In other words, the existing test (keras/keras/tests/model_subclassing_test.py -> test_deepcopy) is testing with a class of model that seems to work fine with object.__reduce__, but the naïve behavior of object.__reduce__ can't support all Keras models, since we already know that some of the ones generated by parametrization can't be copied via object.__reduce__.

At least that is my understanding of the situation.

There are 3 options I can think of:

  1. Don't support unbuilt models and remove the existing test. As you say above, this is probably a bad idea.
  2. Revert this PR back to 28c1187 (when we were supporting unbuilt models) and exclude the subclassed parametrization in the new test, maybe manually testing a simple subclassed model like the one in keras/keras/tests/model_subclassing_test.py -> test_deepcopy.
  3. Revert this PR back to 28c1187 and fix the pickling of those wrapper objects and any other issues that get uncovered once that's fixed (which I attempted to do and failed, but can try again).

fchollet (Member) commented Jul 7, 2021:

> Revert this PR back to 28c1187 and fix the pickling of those wrapper objects and any other issues that get uncovered once that's fixed (which I attempted to do and failed, but can try again).

This is the more robust solution of the lot, and I believe it is quite doable -- adding support for the list of objects I mentioned should be enough.

If you try this and it fails, I would recommend falling back to:

> exclude the subclassed parametrization in the new test, maybe manually testing a simple subclassed model like the one in keras/keras/tests/model_subclassing_test.py -> test_deepcopy.

This is an acceptable option because, for the models that fail, the user will see an explicit error message: "your model contains these weird objects that can't be pickled." The user will be left wondering, "wait, why?", but hopefully this will only happen for a small set of models.

adriangb (Contributor, author):
@fchollet I think the weak refs can be worked around by doing something like*

class Model:

    def __getstate__(self):
        state = super().__getstate__()
        state.pop("_compiled_trainable_state", None)
        state.pop("_trackable_saver", None)
        return state

    def __setstate__(self, state):
        super().__setstate__(state)
        self._reset_compile_cache()
        self._trackable_saver = saver_with_op_caching(self)

I'm not sure if this is safe or not, I do not have the context for what the expected behavior/use of this data is.
For what it's worth, bazel test -c opt -- //keras/saving/... //keras/engine/... //keras/tests/... still passes all tests.

In any case, this just led me down a rabbit hole of more unpicklable things around the Keras/TensorFlow codebase:

And some more that I lost track of.
For those two, the fix consists of moving the locals to the module level.
For the Metric issue, maybe using keras.metrics.serialize would work.
Unfortunately, I do not have time to keep digging deeper and make multiple PRs across both repos.
And I'd also fear breaking something that might not even be tested, these fixes start to get deep into implementation details which I'm afraid may only be obvious to those with extensive experience in the codebase.

Thus, I would like to propose that we scope/implement this PR as follows:

  • Built models can be copied, deepcopied and pickled (backed by SavedModel)
  • Unbuilt models can be copied and deepcopied, but we make no promises about pickling (this will depend on what losses / optimizers / etc. the model is constructed with).

I think this should take care of the most common use cases without breaking any existing use cases or tests, while also leaving the door open for those loose ends around the codebase to be tied up so that Keras can promise picklability of unbuilt models in the future.

dca259a implements this proposal, and passes bazel test -c opt -- //keras/saving/... //keras/engine/... //keras/tests/... locally for me.

* I did not push this because it is not needed unless we wanted to attempt to support pickling of unbuilt models in this PR, which I am proposing we do not.

fchollet (Member) left a comment:

Thanks for the update!

> I think the weak refs can be worked around by doing something like*

I believe this would invalidate model compilation (you'd have to recompile the model after copy). Which of course would not matter for an unbuilt model. Overall the workaround looks mysterious/complex so it's better not to do it, for the sake of future maintainability.

> Thus, I would like to propose that we scope/implement this PR as follows:

That sounds good to me.

# it _may_ be possible to serialize as a plain Python object,
# as long as the constituent parts (layers, optimizers, losses, etc.)
# can be serialized as plain Python objects.
# Thus we call up the MRO to get an implementation of __reduce__
fchollet (Member):

MRO?

adriangb (Contributor, author):

MRO == Method Resolution Order here, but I admit it's probably not the best term to use, much less abbreviated.

How about "Thus we call up the superclass hierarchy to get an implementation of __reduce__"?
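The delegation being discussed — custom serialization for built models, the inherited behavior otherwise — can be sketched with a toy class. `FakeModel` and `_rebuild` are illustrative names of my own; in the actual PR, the built path round-trips through SavedModel bytes rather than pickling the instance dict.

```python
import pickle


def _rebuild(payload):
    # Reconstructor for the "built" path; stands in for deserializing
    # SavedModel bytes back into a model. Must live at module level so
    # pickle can locate it by name.
    model = FakeModel()
    model.__dict__.update(pickle.loads(payload))
    return model


class FakeModel:
    def __init__(self):
        self.built = False
        self.weights = None

    def build(self):
        self.built = True
        self.weights = [1.0, 2.0]

    def __reduce__(self):
        if self.built:
            # Built models: custom reconstructor plus a byte payload.
            return (_rebuild, (pickle.dumps(self.__dict__),))
        # Unbuilt models: call up the superclass hierarchy for the
        # default object implementation of __reduce__.
        return super().__reduce__()
```

Because `super().__reduce__()` preserves the stock behavior, unbuilt models still copy and deepcopy exactly as they did before the change, which is the scoping adriangb proposed above.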

Additional review threads on keras/engine/training.py and keras/saving/pickle_utils.py were marked outdated and resolved.
adriangb (Contributor, author) commented Jul 12, 2021:

@fchollet thank you for that last round of review. I pushed the requested changes in 8a96eed

fchollet (Member) left a comment:

LGTM, thanks!

@fchollet fchollet added kokoro:force-run and removed ready to pull Ready to be merged into the codebase labels Jul 13, 2021
@qlzh727 qlzh727 added the ready to pull Ready to be merged into the codebase label Jul 15, 2021
@copybara-service copybara-service bot merged commit a86bc99 into keras-team:master Jul 16, 2021
@adriangb adriangb deleted the make-model-picklable branch July 16, 2021 18:34
Labels: cla: yes · ready to pull (Ready to be merged into the codebase)

5 participants