Make keras.Model picklable #14748
Conversation
Thanks for the PR. This is valuable functionality.
```python
  Returns:
    keras.Model: a Keras Model instance.
  """
  temp_dir = f"ram://{uuid4()}"
```
Will this work across all systems?
As per tensorflow/tensorflow#48086, there is currently a bug in TF that makes this fail on Windows. It looks like the current momentum is to fix it via tensorflow/tensorflow#48125, which is getting close to being ready.
So for now this would only work on *nix, but once the bug in TF is fixed via that PR or otherwise, this should work on all platforms without any change to this PR / Keras.
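For context, the trick under discussion is: save the model through TF's in-memory `ram://` filesystem, then pack every file written there into a plain Python object that pickle can handle. The same shape can be sketched without TensorFlow, using an ordinary temp directory (the `pack_model`/`unpack_model` names and the `save_fn`/`load_fn` hooks below are illustrative stand-ins, not the PR's actual API):

```python
import os
import tempfile
from uuid import uuid4


def pack_model(model, save_fn):
    # Ask save_fn to write the model into a fresh directory, then slurp
    # every file it produced into a {relative_path: bytes} dict, which
    # pickle can serialize without any filesystem involvement.
    temp_dir = os.path.join(tempfile.mkdtemp(), str(uuid4()))
    save_fn(model, temp_dir)
    packed = {}
    for root, _, filenames in os.walk(temp_dir):
        for filename in filenames:
            path = os.path.join(root, filename)
            with open(path, "rb") as f:
                packed[os.path.relpath(path, temp_dir)] = f.read()
    return packed


def unpack_model(packed, load_fn):
    # Recreate the on-disk layout from the dict and hand it to load_fn.
    temp_dir = os.path.join(tempfile.mkdtemp(), str(uuid4()))
    for rel_path, data in packed.items():
        path = os.path.join(temp_dir, rel_path)
        os.makedirs(os.path.dirname(path), exist_ok=True)
        with open(path, "wb") as f:
            f.write(data)
    return load_fn(temp_dir)
```

In the PR itself, `save_fn`/`load_fn` correspond to Keras's `save_model`/`load_model` and the directory lives in `ram://`, which is exactly why the Windows bug above matters.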
keras/saving/pickle_utils.py
Outdated
```python
from keras.saving.save import load_model


def unpack_model(packed_keras_model):
```
Let's use the names `serialize_as_bytecode` and `deserialize_from_bytecode` to be more explicit.
Sounds great, thank you for the suggestion.
keras/saving/pickle_utils_test.py
Outdated
"""Tests pickle protoocol support. | ||
""" | ||
|
||
@keras_parameterized.run_all_keras_modes |
This parameterization isn't useful here.
Ok, removed
keras/saving/pickle_utils_test.py
Outdated
```python
@keras_parameterized.run_all_keras_modes
def test_pickle_model(self):
    """Test copy.copy, copy.deepcopy and pickle on Functional Model."""
```
Let's test all model types (Sequential, Functional, subclass). We have a parameterization for that: `@keras_parameterized.run_with_all_model_types`. See examples.
I attempted to incorporate this parametrization based on other examples, but I'm not 100% sure I got it right.
It seems to be working; I'm seeing tests run for `sequential`, `subclass`, etc.
LGTM, thank you!
I'm actually seeing an error in the subclass model case:

Correction: this object is actually replicated in …
Yeah, I see that as well. Would you propose implementing that? I think another alternative might be to require pickle protocol >= 3 (the default as of Python 3.4, I believe); I'll test this and report back.
Yes, I think it would be straightforward. We'd have to do it for each of those objects. Changes would have to be replicated in the TF versions of these objects in a separate PR (for consistency).
That is fine too, if that works.
Both of these solutions led to the same problem: somewhere, an object that can't be pickled still sneaks in.

But maybe let's think higher level for a second: this is only happening for subclassed models, and only for untrained models (as per this PR):

```python
class MyModel(keras.Model):
    ...
```

To behave just like …
Did you apply the fix to all 3 objects I listed? I really do expect that 3 objects are the only problem.
I would suggest simply raising a clear error message when people try to pickle an unbuilt model. That would solve the issue.
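That suggestion can be sketched in pure Python, independent of Keras: have `__reduce__` raise for unbuilt models and defer to the default object pickling otherwise. The `built` flag and the `super().__reduce__()` fallback here are stand-ins for the real saved-model bytecode path, not the PR's actual code:

```python
import pickle


class Model:
    def __init__(self):
        self.built = False

    def build(self):
        self.built = True

    def __reduce__(self):
        if not self.built:
            # Unbuilt models hold state we cannot reliably serialize yet,
            # so fail loudly instead of producing a broken pickle.
            raise ValueError(
                "This model has not been built and cannot be pickled. "
                "Build it (e.g. by calling it on a batch of data) first."
            )
        # Stand-in for the real path: defer to default object pickling.
        return super().__reduce__()
```

With this in place, `pickle.dumps` on an unbuilt instance raises a clear `ValueError`, while a built instance round-trips normally.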
Thanks for the update. Some tests are failing: https://source.cloud.google.com/results/invocations/1f3f9ea2-14c6-4d42-907f-74d572c363ac/targets/keras%2Fgithub%2Fubuntu%2Fcpu%2Fpresubmit/log
Apologies, I didn't run the entire test suite before pushing. It looks like the failure is coming from here: keras/keras/tests/model_subclassing_test.py, lines 711 to 724 (at 70d7d07).
Please correct me if I am misinterpreting the logs. The above test instantiates a subclassed model, then tries to copy it before the model has been built.

In my opinion, the existing test is not great; it is testing an API that only partially works and was never explicitly supported/implemented. I would remove that test and only explicitly support built models, as discussed in #14748 (comment). But perhaps I am missing some of the finer details. What do you think, @fchollet?
Can't we just fall back to the previous behavior in that case? It would be bad practice to break an existing, tested behavior. Given how extensively Keras is used at Google, this would be pretty much guaranteed to break some internal builds, which we'd have to resolve on our end before merging, and that could significantly delay the PR.
That is what we were doing, but that fallback is exactly what fails here. In other words, the existing test (in model_subclassing_test.py) exercises behavior that only partially worked to begin with.

At least that is my understanding of the situation. There are 3 options I can think of:
This is the more robust solution of the lot, and I believe it is quite doable -- adding support for the list of objects I mentioned should be enough. If you try this and it fails, I would recommend falling back to:
This is an acceptable option because, for the models that will fail, the user will see an explicit error message -- "your model contains these weird objects that can't be pickled." The user will be left wondering, "wait, why?" but hopefully this will only happen for a small set of models.
@fchollet I think the weak refs can be worked around by doing something like*:

```python
class Model:
    def __getstate__(self):
        state = super().__getstate__()
        state.pop("_compiled_trainable_state", None)
        state.pop("_trackable_saver", None)
        return state

    def __setstate__(self, state):
        super().__setstate__(state)
        self._reset_compile_cache()
        self._trackable_saver = saver_with_op_caching(self)
```

I'm not sure if this is safe or not; I do not have the context for what the expected behavior/use of this data is. In any case, this just led me down a rabbit hole of more unpicklable things around the Keras/TensorFlow codebase:
And some more that I lost track of.

Thus, I would like to propose that we scope/implement this PR as follows:

I think this should take care of the most common use cases without breaking any existing use cases or tests, while also leaving the door open for those loose ends around the codebase to be tied up, so that Keras can promise picklability of unbuilt models in the future. dca259a implements this proposal, and passes the tests.

\* I did not push this because it is not needed unless we wanted to attempt to support pickling of unbuilt models in this PR, which I am proposing we do not.
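The `__getstate__`/`__setstate__` workaround sketched earlier in this thread can be demonstrated with a pure-Python analogue. The `_owner_ref` attribute below is hypothetical, standing in for the weak references inside `_compiled_trainable_state` / `_trackable_saver`: drop the unpicklable attribute before pickling and rebuild it after unpickling.

```python
import pickle
import weakref


class Tracked:
    def __init__(self):
        self.name = "model"
        # Weak references cannot be pickled, so this attribute would
        # normally make pickle.dumps fail on the whole object.
        self._owner_ref = weakref.ref(self)

    def __getstate__(self):
        state = self.__dict__.copy()
        state.pop("_owner_ref", None)  # drop the unpicklable entry
        return state

    def __setstate__(self, state):
        self.__dict__.update(state)
        # Rebuild the dropped attribute from scratch after unpickling.
        self._owner_ref = weakref.ref(self)
```

The caveat raised below still applies: whatever was cached in the dropped attributes is lost, so anything derived from them (e.g. compilation state) has to be recreated after the copy.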
Thanks for the update!
> I think the weak refs can be worked around by doing something like*
I believe this would invalidate model compilation (you'd have to recompile the model after copy). Which of course would not matter for an unbuilt model. Overall the workaround looks mysterious/complex so it's better not to do it, for the sake of future maintainability.
> Thus, I would like to propose that we scope/implement this PR as follows:
That sounds good to me.
keras/engine/training.py
Outdated
```python
# it _may_ be possible to serialize as a plain Python object,
# as long as the constituent parts (layers, optimizers, losses, etc.)
# can be serialized as plain Python objects.
# Thus we call up the MRO to get an implementation of __reduce__
```
MRO?
MRO == Method Resolution Order here, but I admit it's probably not the best term to use, much less abbreviated. How about `Thus we call up the superclass hierarchy to get an implementation of __reduce__`?
LGTM, thanks!
An attempt at porting over from tensorflow/tensorflow#39609.
Is this something that should now go here (in this repo)?
Thanks