From 2da180029f10246bcc82c2d6e000e4ff6e863e23 Mon Sep 17 00:00:00 2001 From: Fabien Hertschuh Date: Tue, 29 Oct 2024 13:50:19 -0700 Subject: [PATCH] Fix serialization / deserialization. - Serialization was not taking the registered name and package from the registry. - Deserialization was selecting symbols by postfix as a fallback. PiperOrigin-RevId: 691149084 --- tf_keras/saving/object_registration_test.py | 10 ++++++---- tf_keras/saving/serialization_lib.py | 11 +---------- 2 files changed, 7 insertions(+), 14 deletions(-) diff --git a/tf_keras/saving/object_registration_test.py b/tf_keras/saving/object_registration_test.py index 8cbf39a8b..57f4bf003 100644 --- a/tf_keras/saving/object_registration_test.py +++ b/tf_keras/saving/object_registration_test.py @@ -112,11 +112,13 @@ def from_config(cls, config): self.assertEqual(5, new_inst._val) def test_serialize_custom_function(self): - @object_registration.register_keras_serializable() + @object_registration.register_keras_serializable( + package="Test", name="func" + ) def my_fn(): return 42 - serialized_name = "Custom>my_fn" + serialized_name = "Test>func" class_name = object_registration._GLOBAL_CUSTOM_NAMES[my_fn] self.assertEqual(serialized_name, class_name) fn_class_name = object_registration.get_registered_name(my_fn) @@ -124,9 +126,9 @@ def my_fn(): config = serialization_lib.serialize_keras_object(my_fn) if tf.__internal__.tf2.enabled(): - self.assertEqual("my_fn", config["config"]) + self.assertEqual(serialized_name, config["config"]) else: - self.assertEqual(class_name, config) + self.assertEqual(serialized_name, config) fn = serialization_lib.deserialize_keras_object(config) self.assertEqual(42, fn()) diff --git a/tf_keras/saving/serialization_lib.py b/tf_keras/saving/serialization_lib.py index a0f4fb8b0..0bfb43e97 100644 --- a/tf_keras/saving/serialization_lib.py +++ b/tf_keras/saving/serialization_lib.py @@ -378,7 +378,7 @@ def _get_class_or_fn_config(obj): """Return the object's config depending on its type.""" # Functions / lambdas: if isinstance(obj, types.FunctionType): - return obj.__name__ + return object_registration.get_registered_name(obj) # All classes: if hasattr(obj, "get_config"): config = obj.get_config() @@ -789,15 +789,6 @@ def _retrieve_class_or_fn( if obj is not None: return obj - # Retrieval of registered custom function in a package - filtered_dict = { - k: v - for k, v in custom_objects.items() - if k.endswith(full_config["config"]) - } - if filtered_dict: - return next(iter(filtered_dict.values())) - # Otherwise, attempt to retrieve the class object given the `module` # and `class_name`. Import the module, find the class. try: