Keep "mse" as the metric name in the log (#812)
* add the test

* add the fix

* fix other broken tests

---------

Co-authored-by: Haifeng Jin <[email protected]>
haifeng-jin authored Aug 29, 2023
1 parent 4e35327 commit ca1a177
Showing 4 changed files with 24 additions and 17 deletions.
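
The user-visible effect of the fix, as a minimal sketch (hypothetical tiny model and data shapes, not taken from this commit): compiling with the abbreviated metric name now yields a history/log key under that same abbreviation.

    import numpy as np
    import keras_core

    # Hypothetical toy model; the point is the metric name in the logs.
    model = keras_core.Sequential([keras_core.layers.Dense(1)])
    model.compile(optimizer="sgd", loss="mse", metrics=["mse"])
    history = model.fit(np.ones((8, 4)), np.zeros((8, 1)), verbose=0)

    # Before this commit the key was "mean_squared_error"; after it,
    # the abbreviation the user passed is kept.
    print(sorted(history.history.keys()))  # expected: ["loss", "mse"]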
9 changes: 4 additions & 5 deletions keras_core/models/model_test.py
@@ -220,17 +220,16 @@ def test_functional_list_outputs_list_losses_abbr(self):
         hist = model.fit(x, (y1, y2), batch_size=2, epochs=1, verbose=0)
         hist_keys = sorted(hist.history.keys())
         # TODO `tf.keras` also outputs individual losses for outputs
-        # TODO Align output names with 'bce', `mse`, `mae` of `tf.keras`
         ref_keys = sorted(
             [
                 "loss",
                 # "output_a_loss",
-                "output_a_binary_crossentropy",
-                "output_a_mean_absolute_error",
-                "output_a_mean_squared_error",
+                "output_a_bce",
+                "output_a_mae",
+                "output_a_mse",
                 "output_b_acc",
                 # "output_b_loss",
-                "output_b_mean_squared_error",
+                "output_b_mse",
             ]
         )
         self.assertListEqual(hist_keys, ref_keys)
15 changes: 8 additions & 7 deletions keras_core/trainers/compile_utils.py
@@ -79,13 +79,14 @@ def get_metric(identifier, y_true, y_pred):
         )
 
     if not isinstance(metric_obj, metrics_module.Metric):
-        if isinstance(identifier, str):
-            metric_name = identifier
-        else:
-            metric_name = get_object_name(metric_obj)
-        metric_obj = metrics_module.MeanMetricWrapper(
-            metric_obj, name=metric_name
-        )
+        metric_obj = metrics_module.MeanMetricWrapper(metric_obj)
+
+    if isinstance(identifier, str):
+        metric_name = identifier
+    else:
+        metric_name = get_object_name(metric_obj)
+    metric_obj.name = metric_name
+
     return metric_obj


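For context, the root cause the hunk above addresses: a string such as "mse" already resolves to a built-in Metric instance, so the old code, which only attached the identifier-derived name when wrapping a bare function in MeanMetricWrapper, never renamed it, and the metric's default long name leaked into the logs. The patched flow wraps first and then assigns the name unconditionally. A minimal sketch of the distinction, assuming keras_core exposes a tf.keras-style metrics.get resolver:

    from keras_core import metrics as metrics_module

    # "mse" resolves to a ready-made Metric whose default name is the
    # long form, so the old rename-on-wrap branch never touched it.
    metric_obj = metrics_module.get("mse")
    print(metric_obj.name)  # "mean_squared_error"

    # The patched get_metric now overwrites the name with the exact
    # identifier the user passed, which is what shows up in history.
    metric_obj.name = "mse"
    print(metric_obj.name)  # "mse"
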
5 changes: 3 additions & 2 deletions keras_core/trainers/compile_utils_test.py
@@ -198,7 +198,7 @@ def test_dict_output_case(self):
 
     def test_name_conversions(self):
         compile_metrics = CompileMetrics(
-            metrics=["acc", "accuracy"],
+            metrics=["acc", "accuracy", "mse"],
             weighted_metrics=[],
         )
         y_true = np.array([[0.1, 0.2], [0.3, 0.4], [0.5, 0.6]])
@@ -207,9 +207,10 @@ def test_name_conversions(self):
         compile_metrics.update_state(y_true, y_pred, sample_weight=None)
         result = compile_metrics.result()
         self.assertTrue(isinstance(result, dict))
-        self.assertEqual(len(result), 2)
+        self.assertEqual(len(result), 3)
         self.assertAllClose(result["acc"], 0.333333)
         self.assertAllClose(result["accuracy"], 0.333333)
+        self.assertTrue("mse" in result)
 
 
 class TestCompileLoss(testing.TestCase, parameterized.TestCase):
12 changes: 9 additions & 3 deletions keras_core/trainers/trainer_test.py
@@ -525,7 +525,9 @@ def on_predict_batch_end(self, batch, logs=None):
                 assert keys == ["outputs"]
 
         model = ExampleModel(units=3)
-        model.compile(optimizer="adam", loss="mse", metrics=["mae"])
+        model.compile(
+            optimizer="adam", loss="mse", metrics=["mean_absolute_error"]
+        )
         x = np.ones((16, 4))
         y = np.zeros((16, 3))
         x_test = np.ones((16, 4))
@@ -651,12 +653,16 @@ def test_recompile(self):
         inputs = layers.Input((2,))
         outputs = layers.Dense(3)(inputs)
         model = keras_core.Model(inputs, outputs)
-        model.compile(optimizer="sgd", loss="mse", metrics=["mse"])
+        model.compile(
+            optimizer="sgd", loss="mse", metrics=["mean_squared_error"]
+        )
         history_1 = model.fit(np.ones((3, 2)), np.ones((3, 3))).history
         eval_out_1 = model.evaluate(
             np.ones((3, 2)), np.ones((3, 3)), return_dict=True
         )
-        model.compile(optimizer="sgd", loss="mse", metrics=["mae"])
+        model.compile(
+            optimizer="sgd", loss="mse", metrics=["mean_absolute_error"]
+        )
         history_2 = model.fit(np.ones((3, 2)), np.ones((3, 3))).history
         eval_out_2 = model.evaluate(
             np.ones((3, 2)), np.ones((3, 3)), return_dict=True
