Skip to content

Commit

Permalink
Fix a bug impacting serialization order for call methods with mutiple…
Browse files Browse the repository at this point in the history
… inputs.

This bug made `get_config/from_config` and `to_json/from_json` not idempotent for layers that take multiple inputs in their `call` method.

When the functional model is constructed and the multiple inputs are passed as positional arguments, the `Node` object has multiple `call_args` in a list.

However, by design, serialization only treats the first argument and positional and serializes all the other arguments as keyword arguments. Upon deserialization, the extra arguments are created as keyword arguments. Their order was modified by `tf.nest.flatten`, which sorts dicts by key.

This change preserves the order of keyword arguments, regardless of keys.

Fixes #795

PiperOrigin-RevId: 686641020
  • Loading branch information
hertschuh authored and tensorflower-gardener committed Oct 16, 2024
1 parent 2aa84ae commit ea2cc87
Show file tree
Hide file tree
Showing 2 changed files with 37 additions and 3 deletions.
29 changes: 29 additions & 0 deletions tf_keras/engine/functional_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -691,6 +691,35 @@ def test_multi_input_multi_output_recursion(self):
json_str = model.to_json()
models.model_from_json(json_str)

@test_combinations.generate(
test_combinations.combine(mode=["graph", "eager"])
)
def test_multi_input_layer_call(self):
@object_registration.register_keras_serializable()
class MyLayer(layers.Layer):
def call(self, embedding, query_indices, slot_id, position):
return [embedding, query_indices, slot_id, position]

with self.cached_session():
a = layers.Input(shape=(32,), name="input_a")
b = layers.Input(shape=(32,), name="input_b")
c = layers.Input(shape=(32,), name="input_c")
d = layers.Input(shape=(32,), name="input_d")

output = MyLayer()(a, b, c, d)
model = training_lib.Model(
inputs=[a, b, c, d], outputs=output, name="model"
)

config = model.get_config()
model2 = models.Model.from_config(config)
self.assertEqual(model2.get_config(), config)

model.summary()
json_str = model.to_json()
model2 = models.model_from_json(json_str)
self.assertEqual(model2.to_json(), json_str)

@test_combinations.generate(
test_combinations.combine(mode=["graph", "eager"])
)
Expand Down
11 changes: 8 additions & 3 deletions tf_keras/engine/node.py
Original file line number Diff line number Diff line change
Expand Up @@ -84,9 +84,10 @@ def __init__(self, layer, call_args=None, call_kwargs=None, outputs=None):
self.call_args = call_args
self.call_kwargs = call_kwargs

# Cached for performance.
# Cached for performance. Put kwargs in order of the call method instead
# of using the sorted key order from `tf.nest.flatten`.
self._flat_arguments = tf.nest.flatten(
(self.call_args, self.call_kwargs)
(self.call_args, self.call_kwargs.values())
)
# Used to avoid expensive `nest` operations in the most common case.
self._single_positional_tensor_passed = (
Expand Down Expand Up @@ -176,9 +177,13 @@ def map_arguments(self, tensor_dict):
for kt_id, kt_index in self._keras_inputs_ids_and_indices:
flat_arguments[kt_index] = tensor_dict[kt_id].pop()

# Pack the same way as `self._flat_arguments`, i.e. `kwargs` as a
# list in the original order.
args, kwargs = tf.nest.pack_sequence_as(
(self.call_args, self.call_kwargs), flat_arguments
(self.call_args, self.call_kwargs.values()), flat_arguments
)
# Add the keys to `kwargs` to go from a list to a dict.
kwargs = {k: v for k, v in zip(self.call_kwargs.keys(), kwargs)}
return args, kwargs

def serialize(self, make_node_key, node_conversion_map):
Expand Down

0 comments on commit ea2cc87

Please sign in to comment.