From 4a2d62f2d3f8081ebc297b312fbfdf4cc74c1906 Mon Sep 17 00:00:00 2001
From: William Falcon
Date: Wed, 12 Aug 2020 08:02:00 -0400
Subject: [PATCH] add weighted average to results obj (#2930)

* track batch size in result obj
---
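For context (a note, not applied by the diff below): the behavioral change is that
reduce_on_epoch_end() now weights torch.mean reductions by the batch sizes recorded
through track_batch_size(), so a short final batch no longer skews epoch-level
metrics. A minimal, self-contained sketch of that reduction; the metric values and
batch sizes here are invented for illustration:

    # illustration only; weighted_mean mirrors the helper added to
    # pytorch_lightning/core/step_result.py in this patch
    import torch

    def weighted_mean(result, weights):
        weights = weights.to(result.device)
        numerator = torch.dot(result.float(), weights.t().float())
        return numerator / weights.sum().float()

    # per-batch metric values and the batch sizes recorded via track_batch_size()
    vals = torch.tensor([1.0, 1.0, 3.0])
    sizes = torch.tensor([32, 32, 8])

    print(torch.mean(vals))            # tensor(1.6667) -- plain mean over batches
    print(weighted_mean(vals, sizes))  # tensor(1.2222) -- the short last batch counts less
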
 pytorch_lightning/core/step_result.py        | 29 ++++++++++++++++++--
 pytorch_lightning/trainer/evaluation_loop.py |  8 +++++-
 pytorch_lightning/trainer/training_loop.py   | 13 +++++----
 3 files changed, 41 insertions(+), 9 deletions(-)

diff --git a/pytorch_lightning/core/step_result.py b/pytorch_lightning/core/step_result.py
index ea62fdab2e9960..eea6e07822e049 100644
--- a/pytorch_lightning/core/step_result.py
+++ b/pytorch_lightning/core/step_result.py
@@ -36,7 +36,8 @@ def __init__(
 
         self['meta'] = {
             '_internal': {
-                '_reduce_on_epoch': False
+                '_reduce_on_epoch': False,
+                'batch_sizes': []
             }
         }
 
@@ -166,6 +167,14 @@ def __set_meta(
         _internal = self['meta']['_internal']
         _internal['_reduce_on_epoch'] = max(_internal['_reduce_on_epoch'], on_epoch)
 
+    def track_batch_size(self, batch_size):
+        meta = self['meta']
+        meta['_internal']['batch_sizes'].append(batch_size)
+
+    def get_batch_sizes(self):
+        meta = self['meta']
+        return torch.tensor(meta['_internal']['batch_sizes'])
+
     def get_callback_metrics(self) -> dict:
         result = {
             'early_stop_on': self.early_stop_on,
@@ -301,18 +310,27 @@ def padded_gather(cls, outputs):
 
     @classmethod
     def reduce_on_epoch_end(cls, outputs):
+        # get the batch sizes for all outputs
+        batch_sizes = torch.stack([x.get_batch_sizes() for x in outputs]).view(-1)
+
         meta = outputs[0]['meta']
 
         result = cls()
         result = recursive_gather(outputs, result)
         recursive_stack(result)
+
         for k, option in meta.items():
             if k == '_internal':
                 continue
 
             if option['on_epoch']:
                 fx = option['reduce_fx']
-                result[k] = fx(result[k])
+                if fx == torch.mean:
+                    reduced_val = weighted_mean(result[k], batch_sizes)
+                else:
+                    reduced_val = fx(result[k])
+
+                result[k] = reduced_val
 
         result['meta'] = meta
         return result
@@ -713,3 +731,10 @@ def get_callback_metrics(self) -> dict:
         }
 
         return result
+
+
+def weighted_mean(result, weights):
+    weights = weights.to(result.device)
+    numerator = torch.dot(result.float(), weights.t().float())
+    result = numerator / weights.sum().float()
+    return result
diff --git a/pytorch_lightning/trainer/evaluation_loop.py b/pytorch_lightning/trainer/evaluation_loop.py
index 433ea970877db6..7e90eb6dc6ec5d 100644
--- a/pytorch_lightning/trainer/evaluation_loop.py
+++ b/pytorch_lightning/trainer/evaluation_loop.py
@@ -331,8 +331,14 @@ def _evaluate(
                 else:
                     output = self.evaluation_forward(model, batch, batch_idx, dataloader_idx, test_mode)
 
+                is_result_obj = isinstance(output, Result)
+
+                # track batch size for weighted average
+                if is_result_obj:
+                    output.track_batch_size(len(batch))
+
                 # allow only EvalResult when using structured results (from val_step)
-                if isinstance(output, Result) and not isinstance(output, EvalResult):
+                if is_result_obj and not isinstance(output, EvalResult):
                     m = 'only EvalResults or dicts are allowed from validation_step'
                     raise MisconfigurationException(m)
 
diff --git a/pytorch_lightning/trainer/training_loop.py b/pytorch_lightning/trainer/training_loop.py
index ec5bd0938d15c7..8bd0dea62341c8 100644
--- a/pytorch_lightning/trainer/training_loop.py
+++ b/pytorch_lightning/trainer/training_loop.py
@@ -848,16 +848,13 @@ def run_training_batch(self, batch, batch_idx):
                 # add metrics to loggers
                 if using_results_obj:
                     metrics_to_log = opt_closure_result.training_step_output.batch_log_metrics
-                else:
-                    metrics_to_log = opt_closure_result.training_step_output.log_metrics
-                batch_log_metrics.append(metrics_to_log)
-
-                # add metrics to progress bar
-                if using_results_obj:
                     step_pbar_metrics = opt_closure_result.training_step_output.batch_pbar_metrics
                 else:
+                    metrics_to_log = opt_closure_result.training_step_output.log_metrics
                     step_pbar_metrics = opt_closure_result.training_step_output.pbar_on_batch_end
 
+                # track metrics
+                batch_log_metrics.append(metrics_to_log)
                 if len(step_pbar_metrics) > 0:
                     self.add_progress_bar_metrics(step_pbar_metrics)
 
@@ -1018,6 +1015,10 @@ def optimizer_closure(self, split_batch, batch_idx, opt_idx, optimizer, hiddens)
             training_step_output_for_epoch_end = training_step_output
             is_result_obj = isinstance(training_step_output, Result)
 
+            # track batch size for weighted average
+            if is_result_obj:
+                training_step_output.track_batch_size(len(split_batch))
+
             # don't allow EvalResult in the training_step
             if isinstance(training_step_output, EvalResult):
                 raise MisconfigurationException('training_step cannot return EvalResult, '