Skip to content

Commit

Permalink
Add evaluation_loss to the estimator base class. (apache#16888)
Browse files Browse the repository at this point in the history
* Add evaluation_loss to the estimator base class.

* Update the base estimator class to support the separate evaluation loss.

* Add evaluation loss to the base estimator class.

* Add unittest for evaluation loss in the test_evaluation function

* Update estimator.py

* Update estimator.py
  • Loading branch information
liuzh47 authored and ptrendx committed Nov 25, 2019
1 parent 39821a4 commit 1232c75
Show file tree
Hide file tree
Showing 2 changed files with 12 additions and 3 deletions.
11 changes: 9 additions & 2 deletions python/mxnet/gluon/contrib/estimator/estimator.py
Original file line number Diff line number Diff line change
Expand Up @@ -59,6 +59,9 @@ class Estimator(object):
Trainer to apply optimizer on network parameters.
context : Context or list of Context
Device(s) to run the training on.
evaluation_loss: gluon.loss.loss
Loss (objective) function to calculate during evaluation. If set evaluation_loss
None, it will use the same loss function as self.loss
"""

Expand All @@ -85,12 +88,16 @@ def __init__(self, net,
metrics=None,
initializer=None,
trainer=None,
context=None):
context=None,
evaluation_loss=None):
self.net = net
self.loss = self._check_loss(loss)
self._train_metrics = _check_metrics(metrics)
self._add_default_training_metrics()
self._add_validation_metrics()
self.evaluation_loss = self.loss
if evaluation_loss is not None:
self.evaluation_loss = self._check_loss(evaluation_loss)

self.logger = logging.Logger(name='Estimator', level=logging.INFO)
self.logger.addHandler(logging.StreamHandler(sys.stdout))
Expand Down Expand Up @@ -228,7 +235,7 @@ def evaluate_batch(self,
"""
data, label = self._get_data_and_label(val_batch, self.context, batch_axis)
pred = [self.net(x) for x in data]
loss = [self.loss(y_hat, y) for y_hat, y in zip(pred, label)]
loss = [self.evaluation_loss(y_hat, y) for y_hat, y in zip(pred, label)]
# update metrics
for metric in val_metrics:
if isinstance(metric, metric_loss):
Expand Down
4 changes: 3 additions & 1 deletion tests/python/unittest/test_gluon_estimator.py
Original file line number Diff line number Diff line change
Expand Up @@ -83,13 +83,15 @@ def test_validation():
ctx = mx.cpu()
loss = gluon.loss.L2Loss()
acc = mx.metric.Accuracy()
evaluation_loss = gluon.loss.L1Loss()
net.initialize(ctx=ctx)
trainer = gluon.Trainer(net.collect_params(), 'sgd', {'learning_rate': 0.001})
est = Estimator(net=net,
loss=loss,
metrics=acc,
trainer=trainer,
context=ctx)
context=ctx,
evaluation_loss=evaluation_loss)
# Input dataloader
est.fit(train_data=dataloader,
val_data=dataloader,
Expand Down

0 comments on commit 1232c75

Please sign in to comment.