Skip to content

Commit

Permalink
Copy loss and metric to prevent side effect
Browse files Browse the repository at this point in the history
  • Loading branch information
Malo committed Apr 4, 2022
1 parent 2db5acf commit 1c57df1
Show file tree
Hide file tree
Showing 2 changed files with 27 additions and 2 deletions.
4 changes: 2 additions & 2 deletions keras/engine/compile_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -107,7 +107,7 @@ def __init__(self, losses, loss_weights=None, output_names=None):
self._user_losses = losses
self._user_loss_weights = loss_weights

self._losses = losses
self._losses = copy.copy(losses)
self._loss_weights = loss_weights
self._per_output_metrics = None # Per-output losses become metrics.
self._loss_metric = metrics_mod.Mean(name='loss') # Total loss.
Expand Down Expand Up @@ -309,7 +309,7 @@ def __init__(self, metrics=None, weighted_metrics=None, output_names=None,
self._user_metrics = metrics
self._user_weighted_metrics = weighted_metrics

self._metrics = metrics
self._metrics = copy.copy(metrics)
self._weighted_metrics = weighted_metrics
self._built = False

Expand Down
25 changes: 25 additions & 0 deletions keras/engine/compile_utils_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -195,6 +195,19 @@ def test_nested_structure(self):
self.assertEqual(b_1_metric.name, 'b_1_loss')
self.assertEqual(b_1_metric.result().numpy(), 0.5)

def test_no_input_mutation(self):
loss = {'a': 'mae'}
loss_container = compile_utils.LossesContainer(loss)

y_t = {'a': tf.zeros((10, 1))}
y_p = {'a': tf.ones((10, 1)), 'b': tf.zeros((10, 1))}
sw = tf.convert_to_tensor([0, 0, 0, 0, 0, 1, 1, 1, 1, 1])

total_loss = loss_container(y_t, y_p, sample_weight=sw)
self.assertIsInstance(total_loss, tf.Tensor)
self.assertEqual(total_loss.numpy(), 0.5)
self.assertLen(loss, 1)

def test_broadcast_single_loss(self):
loss_container = compile_utils.LossesContainer('mse')

Expand Down Expand Up @@ -585,6 +598,18 @@ def test_nested_structure(self):
self.assertEqual(b_1_mse_metric.name, 'b_1_mse')
self.assertEqual(b_1_mse_metric.result().numpy(), 4.)

def test_no_input_mutation(self):
metric = {'a': 'mae'}
metric_container = compile_utils.MetricsContainer(metric)

y_t = {'a': tf.zeros((10, 1))}
y_p = {'a': tf.ones((10, 1)), 'b': tf.zeros((10, 1))}

metric_container.update_state(y_t, y_p)
self.assertLen(metric, 1)
mae_metric = metric_container.metrics[0]
self.assertEqual(mae_metric.result().numpy(), 1.)

def test_crossentropy(self):
metric_container = compile_utils.MetricsContainer('crossentropy')
y_t, y_p = tf.ones((10, 1)), tf.ones((10, 1))
Expand Down

0 comments on commit 1c57df1

Please sign in to comment.