Skip to content

Commit

Permalink
remove NLL in metric (apache#18794)
Browse files Browse the repository at this point in the history
  • Loading branch information
acphile authored Jul 27, 2020
1 parent 9e77e81 commit 74430a9
Show file tree
Hide file tree
Showing 2 changed files with 26 additions and 85 deletions.
91 changes: 13 additions & 78 deletions python/mxnet/gluon/metric.py
Original file line number Diff line number Diff line change
Expand Up @@ -1352,9 +1352,12 @@ class :math:`k`.
Index of invalid label to ignore when
counting. By default, sets to -1.
If set to `None`, it will include all entries.
axis : int (default -1)
axis : int, default -1
The axis from prediction that was used to
compute softmax. By default use the last axis.
from_logits : boolean, default False
Whether `pred` is expected to be a logits tensor.
By default, we assume that `pred` encodes a probability distribution.
name : str
Name of this metric instance for display.
output_names : list of str, or None
Expand All @@ -1373,12 +1376,13 @@ class :math:`k`.
>>> print ce.get()
('cross-entropy', 0.57159948348999023)
"""
def __init__(self, eps=1e-12, ignore_label=None, axis=-1, name='cross-entropy',
output_names=None, label_names=None):
def __init__(self, eps=1e-12, ignore_label=None, axis=-1, from_logits=False,
name='cross-entropy', output_names=None, label_names=None):
super(CrossEntropy, self).__init__(
name, output_names=output_names, label_names=label_names)
self.ignore_label = ignore_label
self.axis = axis
self.from_logits = from_logits
self.eps = eps

def update(self, labels, preds):
Expand All @@ -1400,6 +1404,8 @@ def update(self, labels, preds):
assert label.size == pred.size/pred.shape[-1], \
"shape mismatch: %s vs. %s"%(label.shape, pred.shape)
label = label.reshape((label.size,))
if self.from_logits:
pred = ndarray.softmax(pred, axis=self.axis)
pred = ndarray.pick(pred.as_in_context(label.ctx), label.astype(dtype='int32'), axis=self.axis)
label = label.as_np_ndarray()
pred = pred.as_np_ndarray()
Expand Down Expand Up @@ -1469,11 +1475,11 @@ class Perplexity(CrossEntropy):
>>> print perp.get()
('Perplexity', 1.7710976285155853)
"""
def __init__(self, eps=1e-12, ignore_label=None, axis=-1, name='perplexity',
output_names=None, label_names=None):
def __init__(self, eps=1e-12, ignore_label=None, axis=-1, from_logits=False,
name='perplexity', output_names=None, label_names=None):
super(Perplexity, self).__init__(
name=name, eps=eps, ignore_label=ignore_label, axis=axis,
output_names=output_names, label_names=label_names)
eps=eps, ignore_label=ignore_label, axis=axis, from_logits=from_logits,
name=name, output_names=output_names, label_names=label_names)

def get(self):
if self.num_inst == 0:
Expand All @@ -1482,77 +1488,6 @@ def get(self):
return (self.name, math.exp(self.sum_metric/self.num_inst))


@register
@alias('nll_loss')
@use_np
class NegativeLogLikelihood(EvalMetric):
"""Computes the negative log-likelihood loss.
The negative log-likelihoodd loss over a batch of sample size :math:`N` is given by
.. math::
-\\sum_{n=1}^{N}\\sum_{k=1}^{K}t_{nk}\\log (y_{nk}),
where :math:`K` is the number of classes, :math:`y_{nk}` is the prediceted probability for
:math:`k`-th class for :math:`n`-th sample. :math:`t_{nk}=1` if and only if sample
:math:`n` belongs to class :math:`k`.
Parameters
----------
eps : float
Negative log-likelihood loss is undefined for predicted value is 0,
so predicted values are added with the small constant.
name : str
Name of this metric instance for display.
output_names : list of str, or None
Name of predictions that should be used when updating with update_dict.
By default include all predictions.
label_names : list of str, or None
Name of labels that should be used when updating with update_dict.
By default include all labels.
Examples
--------
>>> predicts = [mx.nd.array([[0.3, 0.7], [0, 1.], [0.4, 0.6]])]
>>> labels = [mx.nd.array([0, 1, 1])]
>>> nll_loss = mx.gluon.metric.NegativeLogLikelihood()
>>> nll_loss.update(labels, predicts)
>>> print nll_loss.get()
('nll-loss', 0.57159948348999023)
"""
def __init__(self, eps=1e-12, name='nll-loss',
output_names=None, label_names=None):
super(NegativeLogLikelihood, self).__init__(
name, eps=eps,
output_names=output_names, label_names=label_names)
self.eps = eps

def update(self, labels, preds):
"""Updates the internal evaluation result.
Parameters
----------
labels : list of `NDArray`
The labels of the data.
preds : list of `NDArray`
Predicted values.
"""
labels, preds = check_label_shapes(labels, preds, True)

for label, pred in zip(labels, preds):
label = label.as_np_ndarray()
pred = pred.as_np_ndarray().as_in_ctx(label.ctx)

label = label.reshape(-1)
num_examples = pred.shape[0]
assert label.shape[0] == num_examples, (label.shape[0], num_examples)
prob = pred[numpy.arange(num_examples, dtype=numpy.int64), numpy.int64(label)]
nll = (-numpy.log(prob + self.eps)).sum()
self.sum_metric += nll
self.num_inst += num_examples


@register
@alias('pearsonr')
@use_np
Expand Down
20 changes: 13 additions & 7 deletions tests/python/unittest/test_metric.py
Original file line number Diff line number Diff line change
Expand Up @@ -39,19 +39,25 @@ def test_metrics():
check_metric('perplexity', axis=-1)
check_metric('pearsonr')
check_metric('pcc')
check_metric('nll_loss')
check_metric('ce')
check_metric('loss')
composite = mx.gluon.metric.create(['acc', 'f1'])
check_metric(composite)

def test_nll_loss():
metric = mx.gluon.metric.create('nll_loss')
def test_ce():
metric = mx.gluon.metric.create('ce')
pred = mx.nd.array([[0.2, 0.3, 0.5], [0.6, 0.1, 0.3]])
label = mx.nd.array([2, 1])
metric.update([label], [pred])
_, loss = metric.get()
expected_loss = -(np.log(pred[0][2].asscalar()) + np.log(pred[1][1].asscalar())) / 2
assert loss == expected_loss
metric = mx.gluon.metric.create('ce', from_logits=True)
pred = mx.nd.log(pred)
metric.update([label], [pred])
_, loss = metric.get()
np.testing.assert_almost_equal(loss, expected_loss)


def test_acc():
pred = mx.nd.array([[0.3, 0.7], [0, 1.], [0.4, 0.6]])
Expand Down Expand Up @@ -159,15 +165,15 @@ def test_multiclass_f1():
macroF1.update([label11, label12], [pred11, pred12])
assert microF1.num_inst == 6
assert macroF1.num_inst == 6

# from sklearn.metrics import f1_score
# overall_pred = [0, 1, 2, 0, 1, 2]
# overall_label = [0, 2, 1, 0, 0, 1]
fmacro = 0.26666666666666666 #f1_score(overall_label, overall_pred, average="macro")
fmicro = 0.3333333333333333 #f1_score(overall_label, overall_pred, average="micro")
np.testing.assert_almost_equal(microF1.get()[1], fmicro)
np.testing.assert_almost_equal(macroF1.get()[1], fmacro)

@xfail_when_nonstandard_decimal_separator
def test_multilabel_f1():
microF1 = mx.gluon.metric.create("f1", class_type="multilabel", average="micro")
Expand All @@ -183,7 +189,7 @@ def test_multilabel_f1():
macroF1.update([label], [pred])
microF1.update([label], [pred])
assert macroF1.get()[1] == 0.5 # one class is 1.0, the other is 0. (divided by 0)
np.testing.assert_almost_equal(microF1.get()[1], 2.0 / 3)
np.testing.assert_almost_equal(microF1.get()[1], 2.0 / 3)
macroF1.reset()
microF1.reset()

Expand All @@ -209,7 +215,7 @@ def test_mcc():
microMCC = mx.gluon.metric.create("mcc")

assert np.isnan(microMCC.get()[1])

# check divide by zero
pred = mx.nd.array([[0.9, 0.1],
[0.8, 0.2]])
Expand Down

0 comments on commit 74430a9

Please sign in to comment.