Skip to content

Commit

Permalink
Fix issue with list/dict losses
Browse files Browse the repository at this point in the history
  • Loading branch information
fchollet committed Nov 25, 2024
1 parent bef0a9e commit 1e8426b
Show file tree
Hide file tree
Showing 4 changed files with 200 additions and 35 deletions.
3 changes: 3 additions & 0 deletions keras/src/losses/losses.py
Original file line number Diff line number Diff line change
Expand Up @@ -44,6 +44,9 @@ def from_config(cls, config):
config = serialization_lib.deserialize_keras_object(config)
return cls(**config)

def __repr__(self):
return f"<LossFunctionWrapper({self.fn}, kwargs={self._fn_kwargs})>"


@keras_export("keras.losses.MeanSquaredError")
class MeanSquaredError(LossFunctionWrapper):
Expand Down
43 changes: 37 additions & 6 deletions keras/src/models/functional.py
Original file line number Diff line number Diff line change
Expand Up @@ -212,7 +212,7 @@ def output_shape(self):
def _assert_input_compatibility(self, *args):
return super(Model, self)._assert_input_compatibility(*args)

def _maybe_warn_inputs_struct_mismatch(self, inputs):
def _maybe_warn_inputs_struct_mismatch(self, inputs, raise_exception=False):
try:
# We first normalize to tuples before performing the check to
# suppress warnings when encountering mismatched tuples and lists.
Expand All @@ -225,12 +225,17 @@ def _maybe_warn_inputs_struct_mismatch(self, inputs):
model_inputs_struct = tree.map_structure(
lambda x: x.name, self._inputs_struct
)
inputs_struct = tree.map_structure(lambda x: f"type({x})", inputs)
warnings.warn(
inputs_struct = tree.map_structure(
lambda x: f"Tensor(shape={x.shape})", inputs
)
msg = (
"The structure of `inputs` doesn't match the expected "
f"structure: {model_inputs_struct}. "
f"Received: the structure of inputs={inputs_struct}"
f"structure.\nExpected: {model_inputs_struct}\n"
f"Received: inputs={inputs_struct}"
)
if raise_exception:
raise ValueError(msg)
warnings.warn(msg)

def _convert_inputs_to_tensors(self, flat_inputs):
converted = []
Expand Down Expand Up @@ -279,7 +284,33 @@ def _adjust_input_rank(self, flat_inputs):
return adjusted

def _standardize_inputs(self, inputs):
self._maybe_warn_inputs_struct_mismatch(inputs)
raise_exception = False
if isinstance(inputs, dict) and not isinstance(
self._inputs_struct, dict
):
# This is to avoid warning
# when we have reconciable dict/list structs
if hasattr(self._inputs_struct, "__len__") and all(
isinstance(i, backend.KerasTensor) for i in self._inputs_struct
):
expected_keys = set(i.name for i in self._inputs_struct)
keys = set(inputs.keys())
if expected_keys.issubset(keys):
inputs = [inputs[i.name] for i in self._inputs_struct]
else:
raise_exception = True
elif isinstance(self._inputs_struct, backend.KerasTensor):
if self._inputs_struct.name in inputs:
inputs = [inputs[self._inputs_struct.name]]
else:
raise_exception = True
else:
raise_exception = True

self._maybe_warn_inputs_struct_mismatch(
inputs, raise_exception=raise_exception
)

flat_inputs = tree.flatten(inputs)
flat_inputs = self._convert_inputs_to_tensors(flat_inputs)
return self._adjust_input_rank(flat_inputs)
Expand Down
114 changes: 112 additions & 2 deletions keras/src/models/functional_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -137,10 +137,10 @@ def test_basic_flow_as_a_submodel(self):

@pytest.mark.requires_trainable_backend
def test_named_input_dict_io(self):
# Single input
input_a = Input(shape=(3,), batch_size=2, name="a")
x = layers.Dense(5)(input_a)
outputs = layers.Dense(4)(x)

model = Functional(input_a, outputs)

# Eager call
Expand All @@ -154,6 +154,27 @@ def test_named_input_dict_io(self):
out_val = model(in_val)
self.assertEqual(out_val.shape, (2, 4))

# Two inputs
input_a = Input(shape=(3,), batch_size=2, name="a")
input_b = Input(shape=(4,), batch_size=2, name="b")
a = layers.Dense(5)(input_a)
b = layers.Dense(5)(input_b)
x = layers.Concatenate()([a, b])
outputs = layers.Dense(4)(x)
model = Functional([input_a, input_b], outputs)

# Eager call
in_val = {"a": np.random.random((2, 3)), "b": np.random.random((2, 4))}
out_val = model(in_val)
self.assertEqual(out_val.shape, (2, 4))

# Symbolic call
input_a_2 = Input(shape=(3,), batch_size=2)
input_b_2 = Input(shape=(4,), batch_size=2)
in_val = {"a": input_a_2, "b": input_b_2}
out_val = model(in_val)
self.assertEqual(out_val.shape, (2, 4))

@pytest.mark.requires_trainable_backend
def test_input_dict_with_extra_field(self):
input_a = Input(shape=(3,), batch_size=2, name="a")
Expand All @@ -180,7 +201,7 @@ def test_input_dict_with_extra_field(self):
self.assertLen(record, 1)
self.assertStartsWith(
str(record[0].message),
r"The structure of `inputs` doesn't match the expected structure:",
r"The structure of `inputs` doesn't match the expected structure",
)

@parameterized.named_parameters(
Expand Down Expand Up @@ -547,3 +568,92 @@ def test_layers_setter(self):
AttributeError, "`Model.layers` attribute is reserved"
):
model.layers = [layers.Dense(4)]

def test_dict_input_to_list_model(self):
vocabulary_size = 100
num_tags = 10
num_departments = 3
num_samples = 128

title = layers.Input(shape=(vocabulary_size,), name="title")
text_body = layers.Input(shape=(vocabulary_size,), name="text_body")
tags = layers.Input(shape=(num_tags,), name="tags")
features = layers.Concatenate()([title, text_body, tags])
features = layers.Dense(64, activation="relu")(features)
priority = layers.Dense(1, activation="sigmoid", name="priority")(
features
)
department = layers.Dense(
num_departments, activation="softmax", name="department"
)(features)
model = Functional(
inputs=[title, text_body, tags], outputs=[priority, department]
)

title_data = np.random.randint(
0, 2, size=(num_samples, vocabulary_size)
)
text_body_data = np.random.randint(
0, 2, size=(num_samples, vocabulary_size)
)
tags_data = np.random.randint(0, 2, size=(num_samples, num_tags))
priority_data = np.random.random(size=(num_samples, 1))
department_data = np.random.randint(
0, 2, size=(num_samples, num_departments)
)

# List style fit
model.compile(
optimizer="adam",
loss=["mean_squared_error", "categorical_crossentropy"],
metrics=[["mean_absolute_error"], ["accuracy"]],
)
model.fit(
[title_data, text_body_data, tags_data],
[priority_data, department_data],
epochs=1,
)
model.evaluate(
[title_data, text_body_data, tags_data],
[priority_data, department_data],
)
priority_preds, department_preds = model.predict(
[title_data, text_body_data, tags_data]
)

# Dict style fit
model.compile(
optimizer="adam",
loss={
"priority": "mean_squared_error",
"department": "categorical_crossentropy",
},
metrics={
"priority": ["mean_absolute_error"],
"department": ["accuracy"],
},
)
model.fit(
{
"title": title_data,
"text_body": text_body_data,
"tags": tags_data,
},
{"priority": priority_data, "department": department_data},
epochs=1,
)
model.evaluate(
{
"title": title_data,
"text_body": text_body_data,
"tags": tags_data,
},
{"priority": priority_data, "department": department_data},
)
priority_preds, department_preds = model.predict(
{
"title": title_data,
"text_body": text_body_data,
"tags": tags_data,
}
)
75 changes: 48 additions & 27 deletions keras/src/trainers/compile_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -428,7 +428,6 @@ def __init__(
f"Received instead: loss_weights={loss_weights} "
f"of type {type(loss_weights)}"
)

self._user_loss = loss
self._user_loss_weights = loss_weights
self.built = False
Expand Down Expand Up @@ -566,8 +565,21 @@ def key_check_fn(key, objs):
)

def build(self, y_true, y_pred):
loss = self._user_loss
loss_weights = self._user_loss_weights
if (
self.output_names
and isinstance(self._user_loss, dict)
and not isinstance(y_pred, dict)
):
loss = [self._user_loss[name] for name in self.output_names]
if isinstance(self._user_loss_weights, dict):
loss_weights = [
self._user_loss_weights[name] for name in self.output_names
]
else:
loss_weights = self._user_loss_weights
else:
loss = self._user_loss
loss_weights = self._user_loss_weights
flat_output_names = self.output_names

# Pytree leaf container
Expand Down Expand Up @@ -691,26 +703,30 @@ def call(self, y_true, y_pred, sample_weight=None):
try:
tree.assert_same_structure(y_pred, y_true, check_types=False)
except ValueError:
# Check case where y_true is either flat or leaf
if (
not tree.is_nested(y_true)
and hasattr(y_pred, "__len__")
and len(y_pred) == 1
):
y_true = [y_true]

# Check case where y_pred is list/tuple and y_true is dict
elif isinstance(y_pred, (list, tuple)) and isinstance(y_true, dict):
if set(self.output_names) == set(y_true.keys()):
y_true = [y_true[name] for name in self.output_names]

try:
# Check case where y_true is either flat or leaf
if (
not tree.is_nested(y_true)
and hasattr(y_pred, "__len__")
and len(y_pred) == 1
):
y_true = [y_true]
y_true = tree.pack_sequence_as(y_pred, y_true)
except:
# Check case where y_true has the same structure but uses
# different (but reconcilable) container types,
# e.g `list` vs `tuple`.
try:
y_true = tree.pack_sequence_as(y_pred, y_true)
tree.assert_same_paths(y_true, y_pred)
y_true = tree.pack_sequence_as(y_pred, tree.flatten(y_true))
except:
# Check case where y_true has the same structure but uses
# different (but reconcilable) container types,
# e.g `list` vs `tuple`.
try:
tree.assert_same_paths(y_true, y_pred)
y_true = tree.pack_sequence_as(
y_pred, tree.flatten(y_true)
)
except:
# Check case where loss is partially defined over y_pred
flat_y_true = tree.flatten(y_true)
flat_loss = tree.flatten(self._user_loss)
Expand All @@ -726,14 +742,18 @@ def call(self, y_true, y_pred, sample_weight=None):
):
y_true[i] = y_t
y_true = tree.pack_sequence_as(self._user_loss, y_true)
except:
y_true_struct = tree.map_structure(lambda _: "*", y_true)
y_pred_struct = tree.map_structure(lambda _: "*", y_pred)
raise ValueError(
"y_true and y_pred have different structures.\n"
f"y_true: {y_true_struct}\n"
f"y_pred: {y_pred_struct}\n"
)
except:
y_true_struct = tree.map_structure(
lambda _: "*", y_true
)
y_pred_struct = tree.map_structure(
lambda _: "*", y_pred
)
raise ValueError(
"y_true and y_pred have different structures.\n"
f"y_true: {y_true_struct}\n"
f"y_pred: {y_pred_struct}\n"
)

if not self.built:
self.build(y_true, y_pred)
Expand Down Expand Up @@ -774,6 +794,7 @@ def resolve_path(path, object):
_sample_weight = resolve_path(path, sample_weight)
else:
_sample_weight = sample_weight

value = ops.cast(
loss_fn(y_t, y_p, _sample_weight), dtype=self.dtype
)
Expand Down

0 comments on commit 1e8426b

Please sign in to comment.