From 1e8426b2f543c4bb739994bb2fc9e0b5ad117e3f Mon Sep 17 00:00:00 2001 From: Francois Chollet Date: Mon, 25 Nov 2024 16:25:25 +0100 Subject: [PATCH] Fix issue with list/dict losses --- keras/src/losses/losses.py | 3 + keras/src/models/functional.py | 43 +++++++++-- keras/src/models/functional_test.py | 114 +++++++++++++++++++++++++++- keras/src/trainers/compile_utils.py | 75 +++++++++++------- 4 files changed, 200 insertions(+), 35 deletions(-) diff --git a/keras/src/losses/losses.py b/keras/src/losses/losses.py index 3d4d47a7373..cc18b37df65 100644 --- a/keras/src/losses/losses.py +++ b/keras/src/losses/losses.py @@ -44,6 +44,9 @@ def from_config(cls, config): config = serialization_lib.deserialize_keras_object(config) return cls(**config) + def __repr__(self): + return f"" + @keras_export("keras.losses.MeanSquaredError") class MeanSquaredError(LossFunctionWrapper): diff --git a/keras/src/models/functional.py b/keras/src/models/functional.py index 9c71308a651..6d0bd40e352 100644 --- a/keras/src/models/functional.py +++ b/keras/src/models/functional.py @@ -212,7 +212,7 @@ def output_shape(self): def _assert_input_compatibility(self, *args): return super(Model, self)._assert_input_compatibility(*args) - def _maybe_warn_inputs_struct_mismatch(self, inputs): + def _maybe_warn_inputs_struct_mismatch(self, inputs, raise_exception=False): try: # We first normalize to tuples before performing the check to # suppress warnings when encountering mismatched tuples and lists. @@ -225,12 +225,17 @@ def _maybe_warn_inputs_struct_mismatch(self, inputs): model_inputs_struct = tree.map_structure( lambda x: x.name, self._inputs_struct ) - inputs_struct = tree.map_structure(lambda x: f"type({x})", inputs) - warnings.warn( + inputs_struct = tree.map_structure( + lambda x: f"Tensor(shape={x.shape})", inputs + ) + msg = ( "The structure of `inputs` doesn't match the expected " - f"structure: {model_inputs_struct}. " - f"Received: the structure of inputs={inputs_struct}" + f"structure.\nExpected: {model_inputs_struct}\n" + f"Received: inputs={inputs_struct}" ) + if raise_exception: + raise ValueError(msg) + warnings.warn(msg) def _convert_inputs_to_tensors(self, flat_inputs): converted = [] @@ -279,7 +284,33 @@ def _adjust_input_rank(self, flat_inputs): return adjusted def _standardize_inputs(self, inputs): - self._maybe_warn_inputs_struct_mismatch(inputs) + raise_exception = False + if isinstance(inputs, dict) and not isinstance( + self._inputs_struct, dict + ): + # This is to avoid warning + # when we have reconciable dict/list structs + if hasattr(self._inputs_struct, "__len__") and all( + isinstance(i, backend.KerasTensor) for i in self._inputs_struct + ): + expected_keys = set(i.name for i in self._inputs_struct) + keys = set(inputs.keys()) + if expected_keys.issubset(keys): + inputs = [inputs[i.name] for i in self._inputs_struct] + else: + raise_exception = True + elif isinstance(self._inputs_struct, backend.KerasTensor): + if self._inputs_struct.name in inputs: + inputs = [inputs[self._inputs_struct.name]] + else: + raise_exception = True + else: + raise_exception = True + + self._maybe_warn_inputs_struct_mismatch( + inputs, raise_exception=raise_exception + ) + flat_inputs = tree.flatten(inputs) flat_inputs = self._convert_inputs_to_tensors(flat_inputs) return self._adjust_input_rank(flat_inputs) diff --git a/keras/src/models/functional_test.py b/keras/src/models/functional_test.py index 6314985f198..5fab4546cac 100644 --- a/keras/src/models/functional_test.py +++ b/keras/src/models/functional_test.py @@ -137,10 +137,10 @@ def test_basic_flow_as_a_submodel(self): @pytest.mark.requires_trainable_backend def test_named_input_dict_io(self): + # Single input input_a = Input(shape=(3,), batch_size=2, name="a") x = layers.Dense(5)(input_a) outputs = layers.Dense(4)(x) - model = Functional(input_a, outputs) # Eager call @@ -154,6 +154,27 @@ def test_named_input_dict_io(self): out_val = model(in_val) self.assertEqual(out_val.shape, (2, 4)) + # Two inputs + input_a = Input(shape=(3,), batch_size=2, name="a") + input_b = Input(shape=(4,), batch_size=2, name="b") + a = layers.Dense(5)(input_a) + b = layers.Dense(5)(input_b) + x = layers.Concatenate()([a, b]) + outputs = layers.Dense(4)(x) + model = Functional([input_a, input_b], outputs) + + # Eager call + in_val = {"a": np.random.random((2, 3)), "b": np.random.random((2, 4))} + out_val = model(in_val) + self.assertEqual(out_val.shape, (2, 4)) + + # Symbolic call + input_a_2 = Input(shape=(3,), batch_size=2) + input_b_2 = Input(shape=(4,), batch_size=2) + in_val = {"a": input_a_2, "b": input_b_2} + out_val = model(in_val) + self.assertEqual(out_val.shape, (2, 4)) + @pytest.mark.requires_trainable_backend def test_input_dict_with_extra_field(self): input_a = Input(shape=(3,), batch_size=2, name="a") @@ -180,7 +201,7 @@ def test_input_dict_with_extra_field(self): self.assertLen(record, 1) self.assertStartsWith( str(record[0].message), - r"The structure of `inputs` doesn't match the expected structure:", + r"The structure of `inputs` doesn't match the expected structure", ) @parameterized.named_parameters( @@ -547,3 +568,92 @@ def test_layers_setter(self): AttributeError, "`Model.layers` attribute is reserved" ): model.layers = [layers.Dense(4)] + + def test_dict_input_to_list_model(self): + vocabulary_size = 100 + num_tags = 10 + num_departments = 3 + num_samples = 128 + + title = layers.Input(shape=(vocabulary_size,), name="title") + text_body = layers.Input(shape=(vocabulary_size,), name="text_body") + tags = layers.Input(shape=(num_tags,), name="tags") + features = layers.Concatenate()([title, text_body, tags]) + features = layers.Dense(64, activation="relu")(features) + priority = layers.Dense(1, activation="sigmoid", name="priority")( + features + ) + department = layers.Dense( + num_departments, activation="softmax", name="department" + )(features) + model = Functional( + inputs=[title, text_body, tags], outputs=[priority, department] + ) + + title_data = np.random.randint( + 0, 2, size=(num_samples, vocabulary_size) + ) + text_body_data = np.random.randint( + 0, 2, size=(num_samples, vocabulary_size) + ) + tags_data = np.random.randint(0, 2, size=(num_samples, num_tags)) + priority_data = np.random.random(size=(num_samples, 1)) + department_data = np.random.randint( + 0, 2, size=(num_samples, num_departments) + ) + + # List style fit + model.compile( + optimizer="adam", + loss=["mean_squared_error", "categorical_crossentropy"], + metrics=[["mean_absolute_error"], ["accuracy"]], + ) + model.fit( + [title_data, text_body_data, tags_data], + [priority_data, department_data], + epochs=1, + ) + model.evaluate( + [title_data, text_body_data, tags_data], + [priority_data, department_data], + ) + priority_preds, department_preds = model.predict( + [title_data, text_body_data, tags_data] + ) + + # Dict style fit + model.compile( + optimizer="adam", + loss={ + "priority": "mean_squared_error", + "department": "categorical_crossentropy", + }, + metrics={ + "priority": ["mean_absolute_error"], + "department": ["accuracy"], + }, + ) + model.fit( + { + "title": title_data, + "text_body": text_body_data, + "tags": tags_data, + }, + {"priority": priority_data, "department": department_data}, + epochs=1, + ) + model.evaluate( + { + "title": title_data, + "text_body": text_body_data, + "tags": tags_data, + }, + {"priority": priority_data, "department": department_data}, + ) + priority_preds, department_preds = model.predict( + { + "title": title_data, + "text_body": text_body_data, + "tags": tags_data, + } + ) diff --git a/keras/src/trainers/compile_utils.py b/keras/src/trainers/compile_utils.py index 2e094e31e79..4e9aacb8119 100644 --- a/keras/src/trainers/compile_utils.py +++ b/keras/src/trainers/compile_utils.py @@ -428,7 +428,6 @@ def __init__( f"Received instead: loss_weights={loss_weights} " f"of type {type(loss_weights)}" ) - self._user_loss = loss self._user_loss_weights = loss_weights self.built = False @@ -566,8 +565,21 @@ def key_check_fn(key, objs): ) def build(self, y_true, y_pred): - loss = self._user_loss - loss_weights = self._user_loss_weights + if ( + self.output_names + and isinstance(self._user_loss, dict) + and not isinstance(y_pred, dict) + ): + loss = [self._user_loss[name] for name in self.output_names] + if isinstance(self._user_loss_weights, dict): + loss_weights = [ + self._user_loss_weights[name] for name in self.output_names + ] + else: + loss_weights = self._user_loss_weights + else: + loss = self._user_loss + loss_weights = self._user_loss_weights flat_output_names = self.output_names # Pytree leaf container @@ -691,26 +703,30 @@ def call(self, y_true, y_pred, sample_weight=None): try: tree.assert_same_structure(y_pred, y_true, check_types=False) except ValueError: + # Check case where y_true is either flat or leaf + if ( + not tree.is_nested(y_true) + and hasattr(y_pred, "__len__") + and len(y_pred) == 1 + ): + y_true = [y_true] + + # Check case where y_pred is list/tuple and y_true is dict + elif isinstance(y_pred, (list, tuple)) and isinstance(y_true, dict): + if set(self.output_names) == set(y_true.keys()): + y_true = [y_true[name] for name in self.output_names] + try: - # Check case where y_true is either flat or leaf - if ( - not tree.is_nested(y_true) - and hasattr(y_pred, "__len__") - and len(y_pred) == 1 - ): - y_true = [y_true] + y_true = tree.pack_sequence_as(y_pred, y_true) + except: + # Check case where y_true has the same structure but uses + # different (but reconcilable) container types, + # e.g `list` vs `tuple`. try: - y_true = tree.pack_sequence_as(y_pred, y_true) + tree.assert_same_paths(y_true, y_pred) + y_true = tree.pack_sequence_as(y_pred, tree.flatten(y_true)) except: - # Check case where y_true has the same structure but uses - # different (but reconcilable) container types, - # e.g `list` vs `tuple`. try: - tree.assert_same_paths(y_true, y_pred) - y_true = tree.pack_sequence_as( - y_pred, tree.flatten(y_true) - ) - except: # Check case where loss is partially defined over y_pred flat_y_true = tree.flatten(y_true) flat_loss = tree.flatten(self._user_loss) @@ -726,14 +742,18 @@ def call(self, y_true, y_pred, sample_weight=None): ): y_true[i] = y_t y_true = tree.pack_sequence_as(self._user_loss, y_true) - except: - y_true_struct = tree.map_structure(lambda _: "*", y_true) - y_pred_struct = tree.map_structure(lambda _: "*", y_pred) - raise ValueError( - "y_true and y_pred have different structures.\n" - f"y_true: {y_true_struct}\n" - f"y_pred: {y_pred_struct}\n" - ) + except: + y_true_struct = tree.map_structure( + lambda _: "*", y_true + ) + y_pred_struct = tree.map_structure( + lambda _: "*", y_pred + ) + raise ValueError( + "y_true and y_pred have different structures.\n" + f"y_true: {y_true_struct}\n" + f"y_pred: {y_pred_struct}\n" + ) if not self.built: self.build(y_true, y_pred) @@ -774,6 +794,7 @@ def resolve_path(path, object): _sample_weight = resolve_path(path, sample_weight) else: _sample_weight = sample_weight + value = ops.cast( loss_fn(y_t, y_p, _sample_weight), dtype=self.dtype )