Deprecate aggregate_logging_outputs API (use reduce_metrics instead) (#1611)

Summary:
Pull Request resolved: #1611

Pull Request resolved: fairinternal/fairseq-py#974

Differential Revision: D19292402

Pulled By: myleott

fbshipit-source-id: d51327584e048d3e39c133e9ef57a791e0329a66
myleott authored and facebook-github-bot committed Jan 11, 2020
1 parent 0ce722d commit 8679339
Showing 23 changed files with 444 additions and 429 deletions.
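Every criterion below follows the same migration: the deprecated `aggregate_logging_outputs`, which returned an aggregated dict, is replaced by `reduce_metrics`, which logs scalars and derived values directly through `fairseq.metrics` and returns nothing. The following is a minimal migration sketch, not part of the commit: `MyCriterion` and `'my_criterion'` are hypothetical names, `forward()` is omitted for brevity, and the `metrics.log_scalar` / `metrics.log_derived` calls mirror the ones added in `cross_entropy.py` below.

```python
# Minimal migration sketch (not part of the commit). 'my_criterion' and
# MyCriterion are hypothetical; forward() is omitted for brevity.
import math

from fairseq import metrics
from fairseq.criterions import FairseqCriterion, register_criterion


@register_criterion('my_criterion')
class MyCriterion(FairseqCriterion):

    # Old, deprecated API: return an aggregated dict.
    # @staticmethod
    # def aggregate_logging_outputs(logging_outputs):
    #     loss_sum = sum(log.get('loss', 0) for log in logging_outputs)
    #     sample_size = sum(log.get('sample_size', 0) for log in logging_outputs)
    #     return {'loss': loss_sum / sample_size / math.log(2), 'sample_size': sample_size}

    # New API: log directly to the metrics aggregator, return nothing.
    @staticmethod
    def reduce_metrics(logging_outputs) -> None:
        loss_sum = sum(log.get('loss', 0) for log in logging_outputs)
        sample_size = sum(log.get('sample_size', 0) for log in logging_outputs)
        metrics.log_scalar('loss', loss_sum / sample_size / math.log(2), sample_size, round=3)
        metrics.log_derived('ppl', lambda meters: round(2**meters['loss'].avg, 3))
```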
28 changes: 12 additions & 16 deletions fairseq/criterions/adaptive_loss.py
@@ -3,12 +3,12 @@
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.


import math

import torch.nn.functional as F

from fairseq import utils
from . import FairseqCriterion, register_criterion
from fairseq import metrics, utils
from fairseq.criterions import FairseqCriterion, register_criterion


@register_criterion('adaptive_loss')
@@ -74,28 +74,24 @@ def forward(self, model, sample, reduce=True):
return loss, sample_size, logging_output

@staticmethod
def aggregate_logging_outputs(logging_outputs):
def reduce_metrics(logging_outputs) -> None:
"""Aggregate logging outputs from data parallel training."""
loss_sum = sum(log.get('loss', 0) for log in logging_outputs)
ntokens = sum(log.get('ntokens', 0) for log in logging_outputs)
nsentences = sum(log.get('nsentences', 0) for log in logging_outputs)
sample_size = sum(log.get('sample_size', 0) for log in logging_outputs)
agg_output = {
'loss': loss_sum / sample_size / math.log(2) if sample_size > 0 else 0.,
'nll_loss': loss_sum / sample_size / math.log(2) if sample_size > 0 else 0.,
'ntokens': ntokens,
'nsentences': nsentences,
'sample_size': sample_size,
}

metrics.log_scalar('loss', loss_sum / sample_size / math.log(2), sample_size, round=3)
if sample_size != ntokens:
agg_output['nll_loss'] = loss_sum / ntokens / math.log(2) if ntokens > 0 else 0.
return agg_output
metrics.log_scalar('nll_loss', loss_sum / ntokens / math.log(2), ntokens, round=3)
metrics.log_derived('ppl', lambda meters: round(2**meters['nll_loss'].avg, 3))
else:
metrics.log_derived('ppl', lambda meters: round(2**meters['loss'].avg, 3))

@staticmethod
def logging_outputs_can_be_summed() -> bool:
"""
Whether the logging outputs returned by `forward` can be summed
across workers prior to calling `aggregate_logging_outputs`.
Setting this to True will improves distributed training speed.
across workers prior to calling `reduce_metrics`. Setting this
to True will improves distributed training speed.
"""
return True
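Because the losses above are logged in base 2 (the `math.log(2)` divisor), the derived `ppl` meter is simply 2 raised to the meter's running average. A quick numeric check with assumed values, not taken from the diff:

```python
# Worked example with assumed numbers (not from the diff).
import math

loss_sum = 693.1      # summed negative log-likelihood in nats, across workers
sample_size = 1000    # e.g. number of target tokens

loss_bits = loss_sum / sample_size / math.log(2)   # value passed to metrics.log_scalar('loss', ...)
ppl = 2 ** loss_bits                               # value produced by the 'ppl' derived meter
print(round(loss_bits, 3), round(ppl, 3))          # ≈ 1.0 bits/token, ppl ≈ 2.0
```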
4 changes: 2 additions & 2 deletions fairseq/criterions/binary_cross_entropy.py
@@ -4,13 +4,13 @@
# LICENSE file in the root directory of this source tree.

import math

import numpy as np
import torch
import torch.nn.functional as F

from fairseq import utils

from . import FairseqCriterion, register_criterion
from fairseq.criterions import FairseqCriterion, register_criterion


@register_criterion('binary_cross_entropy')
6 changes: 5 additions & 1 deletion fairseq/criterions/composite_loss.py
@@ -6,7 +6,7 @@
from torch import nn

from fairseq import utils
from . import FairseqCriterion, register_criterion
from fairseq.criterions import FairseqCriterion, register_criterion


@register_criterion('composite_loss')
@@ -88,4 +88,8 @@ def forward(self, model, sample, reduce=True):
def aggregate_logging_outputs(logging_outputs):
return underlying_criterion.__class__.aggregate_logging_outputs(logging_outputs)

@staticmethod
def reduce_metrics(logging_outputs) -> None:
underlying_criterion.__class__.reduce_metrics(logging_outputs)

return _CompositeLoss(args, task, underlying_criterion)
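The wrapper keeps both entry points: the deprecated `aggregate_logging_outputs` for callers that still use it, and the new `reduce_metrics`, each forwarded to the wrapped criterion's class. A condensed sketch of just that dispatch, assuming only an already-constructed `underlying_criterion`; the closure mirrors the factory above, and the names are otherwise hypothetical:

```python
# Condensed dispatch sketch (not part of the commit).
def build_composite_sketch(underlying_criterion):

    class _Sketch:

        @staticmethod
        def aggregate_logging_outputs(logging_outputs):
            # deprecated path, kept for not-yet-migrated callers
            return underlying_criterion.__class__.aggregate_logging_outputs(logging_outputs)

        @staticmethod
        def reduce_metrics(logging_outputs) -> None:
            # new path; if the wrapped class only implements the old API, this
            # resolves to the FairseqCriterion fallback (see fairseq_criterion.py)
            underlying_criterion.__class__.reduce_metrics(logging_outputs)

    return _Sketch()
```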
28 changes: 12 additions & 16 deletions fairseq/criterions/cross_entropy.py
@@ -4,11 +4,11 @@
# LICENSE file in the root directory of this source tree.

import math
import torch.nn.functional as F

from fairseq import utils
import torch.nn.functional as F

from . import FairseqCriterion, register_criterion
from fairseq import metrics, utils
from fairseq.criterions import FairseqCriterion, register_criterion


@register_criterion('cross_entropy')
@@ -30,7 +30,6 @@ def forward(self, model, sample, reduce=True):
sample_size = sample['target'].size(0) if self.args.sentence_avg else sample['ntokens']
logging_output = {
'loss': utils.item(loss.data) if reduce else loss.data,
'nll_loss': utils.item(loss.data) if reduce else loss.data,
'ntokens': sample['ntokens'],
'nsentences': sample['target'].size(0),
'sample_size': sample_size,
@@ -50,27 +49,24 @@ def compute_loss(self, model, net_output, sample, reduce=True):
return loss, loss

@staticmethod
def aggregate_logging_outputs(logging_outputs):
def reduce_metrics(logging_outputs) -> None:
"""Aggregate logging outputs from data parallel training."""
loss_sum = sum(log.get('loss', 0) for log in logging_outputs)
ntokens = sum(log.get('ntokens', 0) for log in logging_outputs)
nsentences = sum(log.get('nsentences', 0) for log in logging_outputs)
sample_size = sum(log.get('sample_size', 0) for log in logging_outputs)
agg_output = {
'loss': loss_sum / sample_size / math.log(2) if sample_size > 0 else 0.,
'ntokens': ntokens,
'nsentences': nsentences,
'sample_size': sample_size,
}

metrics.log_scalar('loss', loss_sum / sample_size / math.log(2), sample_size, round=3)
if sample_size != ntokens:
agg_output['nll_loss'] = loss_sum / ntokens / math.log(2)
return agg_output
metrics.log_scalar('nll_loss', loss_sum / ntokens / math.log(2), ntokens, round=3)
metrics.log_derived('ppl', lambda meters: round(2**meters['nll_loss'].avg, 3))
else:
metrics.log_derived('ppl', lambda meters: round(2**meters['loss'].avg, 3))

@staticmethod
def logging_outputs_can_be_summed() -> bool:
"""
Whether the logging outputs returned by `forward` can be summed
across workers prior to calling `aggregate_logging_outputs`.
Setting this to True will improves distributed training speed.
across workers prior to calling `reduce_metrics`. Setting this
to True will improves distributed training speed.
"""
return True
29 changes: 26 additions & 3 deletions fairseq/criterions/fairseq_criterion.py
@@ -3,8 +3,12 @@
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.

from typing import Any, Dict, List

from torch.nn.modules.loss import _Loss

from fairseq import metrics, utils


class FairseqCriterion(_Loss):

@@ -34,15 +38,34 @@ def forward(self, model, sample, reduce=True):
raise NotImplementedError

@staticmethod
def aggregate_logging_outputs(logging_outputs):
def aggregate_logging_outputs(
logging_outputs: List[Dict[str, Any]],
) -> Dict[str, Any]:
"""Aggregate logging outputs from data parallel training."""
utils.deprecation_warning(
'The aggregate_logging_outputs API is deprecated. '
'Please use the reduce_metrics API instead.'
)
raise NotImplementedError

@classmethod
def reduce_metrics(cls, logging_outputs: List[Dict[str, Any]]) -> None:
"""Aggregate logging outputs from data parallel training."""
utils.deprecation_warning(
'Criterions should implement the reduce_metrics API. '
'Falling back to deprecated aggregate_logging_outputs API.'
)
agg_logging_outputs = cls.aggregate_logging_outputs(logging_outputs)
for k, v in agg_logging_outputs.items():
if k in {'nsentences', 'ntokens', 'sample_size'}:
continue
metrics.log_scalar(k, v)

@staticmethod
def logging_outputs_can_be_summed() -> bool:
"""
Whether the logging outputs returned by `forward` can be summed
across workers prior to calling `aggregate_logging_outputs`.
Setting this to True will improves distributed training speed.
across workers prior to calling `reduce_metrics`. Setting this
to True will improves distributed training speed.
"""
return False
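The base class is where the backward compatibility lives: `reduce_metrics` is now a classmethod whose default implementation warns, calls the subclass's legacy `aggregate_logging_outputs`, and logs every aggregated value as a scalar except the bookkeeping keys. A minimal sketch of that fallback path, assuming a hypothetical legacy criterion that has not been migrated:

```python
# Fallback-path sketch (not part of the commit); LegacyCriterion is hypothetical.
from fairseq.criterions import FairseqCriterion


class LegacyCriterion(FairseqCriterion):
    """Implements only the deprecated API."""

    @staticmethod
    def aggregate_logging_outputs(logging_outputs):
        sample_size = sum(log.get('sample_size', 0) for log in logging_outputs)
        loss_sum = sum(log.get('loss', 0) for log in logging_outputs)
        return {
            'loss': loss_sum / max(sample_size, 1),
            'sample_size': sample_size,  # filtered out by the fallback
        }


logging_outputs = [{'loss': 2.0, 'sample_size': 4}, {'loss': 6.0, 'sample_size': 12}]

# The inherited classmethod warns about the deprecated API, calls
# LegacyCriterion.aggregate_logging_outputs, and logs 'loss' via
# metrics.log_scalar while skipping 'nsentences'/'ntokens'/'sample_size'.
LegacyCriterion.reduce_metrics(logging_outputs)
```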
25 changes: 11 additions & 14 deletions fairseq/criterions/label_smoothed_cross_entropy.py
@@ -5,9 +5,8 @@

import math

from fairseq import utils

from . import FairseqCriterion, register_criterion
from fairseq import metrics, utils
from fairseq.criterions import FairseqCriterion, register_criterion


def label_smoothed_nll_loss(lprobs, target, epsilon, ignore_index=None, reduce=True):
@@ -76,24 +75,22 @@ def compute_loss(self, model, net_output, sample, reduce=True):
return loss, nll_loss

@staticmethod
def aggregate_logging_outputs(logging_outputs):
def reduce_metrics(logging_outputs) -> None:
"""Aggregate logging outputs from data parallel training."""
loss_sum = sum(log.get('loss', 0) for log in logging_outputs)
nll_loss_sum = sum(log.get('nll_loss', 0) for log in logging_outputs)
ntokens = sum(log.get('ntokens', 0) for log in logging_outputs)
nsentences = sum(log.get('nsentences', 0) for log in logging_outputs)
sample_size = sum(log.get('sample_size', 0) for log in logging_outputs)
return {
'loss': sum(log.get('loss', 0) for log in logging_outputs) / sample_size / math.log(2) if sample_size > 0 else 0.,
'nll_loss': sum(log.get('nll_loss', 0) for log in logging_outputs) / ntokens / math.log(2) if ntokens > 0 else 0.,
'ntokens': ntokens,
'nsentences': nsentences,
'sample_size': sample_size,
}

metrics.log_scalar('loss', loss_sum / sample_size / math.log(2), sample_size, round=3)
metrics.log_scalar('nll_loss', nll_loss_sum / ntokens / math.log(2), ntokens, round=3)
metrics.log_derived('ppl', lambda meters: round(2**meters['nll_loss'].avg, 3))

@staticmethod
def logging_outputs_can_be_summed() -> bool:
"""
Whether the logging outputs returned by `forward` can be summed
across workers prior to calling `aggregate_logging_outputs`.
Setting this to True will improves distributed training speed.
across workers prior to calling `reduce_metrics`. Setting this
to True will improves distributed training speed.
"""
return True
27 changes: 13 additions & 14 deletions fairseq/criterions/label_smoothed_cross_entropy_with_alignment.py
@@ -5,10 +5,10 @@

import math

from fairseq import utils
from fairseq import metrics, utils
from fairseq.criterions import register_criterion

from .label_smoothed_cross_entropy import LabelSmoothedCrossEntropyCriterion
from . import register_criterion


@register_criterion('label_smoothed_cross_entropy_with_alignment')
@@ -75,25 +75,24 @@ def compute_alignment_loss(self, sample, net_output):
return loss

@staticmethod
def aggregate_logging_outputs(logging_outputs):
def reduce_metrics(logging_outputs) -> None:
"""Aggregate logging outputs from data parallel training."""
loss_sum = sum(log.get('loss', 0) for log in logging_outputs)
nll_loss_sum = sum(log.get('nll_loss', 0) for log in logging_outputs)
alignment_loss_sum = sum(log.get('alignment_loss', 0) for log in logging_outputs)
ntokens = sum(log.get('ntokens', 0) for log in logging_outputs)
nsentences = sum(log.get('nsentences', 0) for log in logging_outputs)
sample_size = sum(log.get('sample_size', 0) for log in logging_outputs)
return {
'loss': sum(log.get('loss', 0) for log in logging_outputs) / sample_size / math.log(2) if sample_size > 0 else 0.,
'nll_loss': sum(log.get('nll_loss', 0) for log in logging_outputs) / ntokens / math.log(2) if ntokens > 0 else 0.,
'alignment_loss': sum(log.get('alignment_loss', 0) for log in logging_outputs) / sample_size / math.log(2) if sample_size > 0 else 0.,
'ntokens': ntokens,
'nsentences': nsentences,
'sample_size': sample_size,
}

metrics.log_scalar('loss', loss_sum / sample_size / math.log(2), sample_size, round=3)
metrics.log_scalar('nll_loss', nll_loss_sum / ntokens / math.log(2), ntokens, round=3)
metrics.log_scalar('alignment_loss', alignment_loss_sum / sample_size / math.log(2), sample_size, round=3)
metrics.log_derived('ppl', lambda meters: round(2**meters['nll_loss'].avg, 3))

@staticmethod
def logging_outputs_can_be_summed() -> bool:
"""
Whether the logging outputs returned by `forward` can be summed
across workers prior to calling `aggregate_logging_outputs`.
Setting this to True will improves distributed training speed.
across workers prior to calling `reduce_metrics`. Setting this
to True will improves distributed training speed.
"""
return True
8 changes: 4 additions & 4 deletions fairseq/criterions/legacy_masked_lm.py
@@ -3,13 +3,13 @@
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.


import math

import torch
import torch.nn.functional as F

from fairseq import utils
from . import FairseqCriterion, register_criterion
from fairseq.criterions import FairseqCriterion, register_criterion


def compute_cross_entropy_loss(logits, targets, ignore_index=-100):
@@ -150,7 +150,7 @@ def aggregate_logging_outputs(logging_outputs):
def logging_outputs_can_be_summed() -> bool:
"""
Whether the logging outputs returned by `forward` can be summed
across workers prior to calling `aggregate_logging_outputs`.
Setting this to True will improves distributed training speed.
across workers prior to calling `reduce_metrics`. Setting this
to True will improves distributed training speed.
"""
return True
30 changes: 9 additions & 21 deletions fairseq/criterions/masked_lm.py
@@ -8,9 +8,8 @@
import torch
import torch.nn.functional as F

from fairseq import utils

from . import FairseqCriterion, register_criterion
from fairseq import metrics, utils
from fairseq.criterions import FairseqCriterion, register_criterion


@register_criterion('masked_lm')
@@ -19,11 +18,9 @@ class MaskedLmLoss(FairseqCriterion):
Implementation for the loss used in masked language model (MLM) training.
"""

def __init__(self, args, task):
super().__init__(args, task)

def forward(self, model, sample, reduce=True):
"""Compute the loss for the given sample.
Returns a tuple with three elements:
1) the loss
2) the sample size, which is used as the denominator for the gradient
@@ -56,35 +53,26 @@ def forward(self, model, sample, reduce=True):
)
logging_output = {
'loss': utils.item(loss.data) if reduce else loss.data,
'nll_loss': utils.item(loss.data) if reduce else loss.data,
'ntokens': sample['ntokens'],
'nsentences': sample['nsentences'],
'sample_size': sample_size,
}
return loss, sample_size, logging_output

@staticmethod
def aggregate_logging_outputs(logging_outputs):
def reduce_metrics(logging_outputs) -> None:
"""Aggregate logging outputs from data parallel training."""
loss = sum(log.get('loss', 0) for log in logging_outputs)
ntokens = sum(log.get('ntokens', 0) for log in logging_outputs)
nsentences = sum(log.get('nsentences', 0) for log in logging_outputs)
loss_sum = sum(log.get('loss', 0) for log in logging_outputs)
sample_size = sum(log.get('sample_size', 0) for log in logging_outputs)

agg_output = {
'loss': loss / sample_size / math.log(2),
'nll_loss': sum(log.get('nll_loss', 0) for log in logging_outputs) / sample_size / math.log(2) if ntokens > 0 else 0.,
'ntokens': ntokens,
'nsentences': nsentences,
'sample_size': sample_size,
}
return agg_output
metrics.log_scalar('loss', loss_sum / sample_size / math.log(2), sample_size, round=3)
metrics.log_derived('ppl', lambda meters: round(2**meters['loss'].avg, 3))

@staticmethod
def logging_outputs_can_be_summed() -> bool:
"""
Whether the logging outputs returned by `forward` can be summed
across workers prior to calling `aggregate_logging_outputs`.
Setting this to True will improves distributed training speed.
across workers prior to calling `reduce_metrics`. Setting this
to True will improves distributed training speed.
"""
return True
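`logging_outputs_can_be_summed()` returning True is what lets the trainer add the per-worker logging dicts together before `reduce_metrics` ever sees them, shrinking the cross-worker payload to a single dict. A simplified illustration with assumed values; the real reduction happens inside fairseq's trainer with torch.distributed, not in the criterion:

```python
# Simplified illustration (assumed values, not from the diff).
from collections import defaultdict

worker_outputs = [
    {'loss': 2.0, 'ntokens': 10, 'sample_size': 10},   # worker 0
    {'loss': 6.0, 'ntokens': 30, 'sample_size': 30},   # worker 1
]

# Every field is additive, so the dicts can be summed element-wise first...
summed = defaultdict(float)
for log in worker_outputs:
    for key, value in log.items():
        summed[key] += value

# ...and reduce_metrics then receives one logging output instead of one per
# worker, e.g. MaskedLmLoss.reduce_metrics([dict(summed)]).
print(dict(summed))  # {'loss': 8.0, 'ntokens': 40.0, 'sample_size': 40.0}
```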
(Diffs for the remaining 14 changed files are not shown.)
