diff --git a/tf_keras/engine/functional_test.py b/tf_keras/engine/functional_test.py index 786cc4297..76c65fc98 100644 --- a/tf_keras/engine/functional_test.py +++ b/tf_keras/engine/functional_test.py @@ -691,6 +691,35 @@ def test_multi_input_multi_output_recursion(self): json_str = model.to_json() models.model_from_json(json_str) + @test_combinations.generate( + test_combinations.combine(mode=["graph", "eager"]) + ) + def test_multi_input_layer_call(self): + @object_registration.register_keras_serializable() + class MyLayer(layers.Layer): + def call(self, embedding, query_indices, slot_id, position): + return [embedding, query_indices, slot_id, position] + + with self.cached_session(): + a = layers.Input(shape=(32,), name="input_a") + b = layers.Input(shape=(32,), name="input_b") + c = layers.Input(shape=(32,), name="input_c") + d = layers.Input(shape=(32,), name="input_d") + + output = MyLayer()(a, b, c, d) + model = training_lib.Model( + inputs=[a, b, c, d], outputs=output, name="model" + ) + + config = model.get_config() + model2 = models.Model.from_config(config) + self.assertEqual(model2.get_config(), config) + + model.summary() + json_str = model.to_json() + model2 = models.model_from_json(json_str) + self.assertEqual(model2.to_json(), json_str) + @test_combinations.generate( test_combinations.combine(mode=["graph", "eager"]) ) diff --git a/tf_keras/engine/node.py b/tf_keras/engine/node.py index 2071bdb01..182ee73e2 100644 --- a/tf_keras/engine/node.py +++ b/tf_keras/engine/node.py @@ -84,9 +84,10 @@ def __init__(self, layer, call_args=None, call_kwargs=None, outputs=None): self.call_args = call_args self.call_kwargs = call_kwargs - # Cached for performance. + # Cached for performance. Put kwargs in order of the call method instead + # of using the sorted key order from `tf.nest.flatten`. self._flat_arguments = tf.nest.flatten( - (self.call_args, self.call_kwargs) + (self.call_args, self.call_kwargs.values()) ) # Used to avoid expensive `nest` operations in the most common case. self._single_positional_tensor_passed = ( @@ -176,9 +177,13 @@ def map_arguments(self, tensor_dict): for kt_id, kt_index in self._keras_inputs_ids_and_indices: flat_arguments[kt_index] = tensor_dict[kt_id].pop() + # Pack the same way as `self._flat_arguments`, i.e. `kwargs` as a + # list in the original order. args, kwargs = tf.nest.pack_sequence_as( - (self.call_args, self.call_kwargs), flat_arguments + (self.call_args, self.call_kwargs.values()), flat_arguments ) + # Add the keys to `kwargs` to go from a list to a dict. + kwargs = {k: v for k, v in zip(self.call_kwargs.keys(), kwargs)} return args, kwargs def serialize(self, make_node_key, node_conversion_map):