Skip to content

Commit

Permalink
add weighted average to results obj (Lightning-AI#2930)
Browse files Browse the repository at this point in the history
* track batch size in result obj
  • Loading branch information
williamFalcon authored and atee committed Aug 17, 2020
1 parent d2a746a commit 4a2d62f
Show file tree
Hide file tree
Showing 3 changed files with 41 additions and 9 deletions.
29 changes: 27 additions & 2 deletions pytorch_lightning/core/step_result.py
Original file line number Diff line number Diff line change
Expand Up @@ -36,7 +36,8 @@ def __init__(

self['meta'] = {
'_internal': {
'_reduce_on_epoch': False
'_reduce_on_epoch': False,
'batch_sizes': []
}
}

Expand Down Expand Up @@ -166,6 +167,14 @@ def __set_meta(
_internal = self['meta']['_internal']
_internal['_reduce_on_epoch'] = max(_internal['_reduce_on_epoch'], on_epoch)

def track_batch_size(self, batch_size):
    """Record the size of the current batch for later weighted reduction."""
    self['meta']['_internal']['batch_sizes'].append(batch_size)

def get_batch_sizes(self):
    """Return every tracked batch size as a 1-D tensor."""
    sizes = self['meta']['_internal']['batch_sizes']
    return torch.tensor(sizes)

def get_callback_metrics(self) -> dict:
result = {
'early_stop_on': self.early_stop_on,
Expand Down Expand Up @@ -301,18 +310,27 @@ def padded_gather(cls, outputs):

@classmethod
def reduce_on_epoch_end(cls, outputs):
    """Reduce a list of per-step results into a single epoch-end result.

    Metrics flagged ``on_epoch`` are reduced with their configured
    ``reduce_fx``; when that fx is ``torch.mean`` the reduction is a
    batch-size-weighted mean so uneven final batches don't skew the
    epoch metric.

    Args:
        outputs: list of result objects collected over the epoch; each
            carries a ``'meta'`` dict and tracked batch sizes.

    Returns:
        A new result instance holding the reduced metrics plus the
        original ``'meta'``.
    """
    # gather the batch sizes from every step into one flat tensor
    # NOTE(review): torch.stack assumes every output tracked the same
    # number of batch sizes — verify against callers
    batch_sizes = torch.stack([x.get_batch_sizes() for x in outputs]).view(-1)

    meta = outputs[0]['meta']
    result = cls()
    result = recursive_gather(outputs, result)
    recursive_stack(result)

    for k, option in meta.items():
        # '_internal' is bookkeeping, not a user metric
        if k == '_internal':
            continue

        if option['on_epoch']:
            fx = option['reduce_fx']
            # apply the reduce fx exactly once (the raw stacked values are
            # reduced here; reducing first and then weighting would be wrong)
            if fx == torch.mean:
                reduced_val = weighted_mean(result[k], batch_sizes)
            else:
                reduced_val = fx(result[k])

            result[k] = reduced_val

    result['meta'] = meta
    return result
Expand Down Expand Up @@ -713,3 +731,10 @@ def get_callback_metrics(self) -> dict:
}

return result


def weighted_mean(result, weights):
    """Compute the mean of ``result`` weighted by ``weights``.

    Args:
        result: 1-D tensor of values to average (one entry per batch).
        weights: 1-D tensor of non-negative weights (e.g. batch sizes),
            same length as ``result``. Moved to ``result``'s device.

    Returns:
        Scalar tensor equal to ``sum(result * weights) / sum(weights)``.
        NOTE(review): divides by zero if all weights are 0 — confirm
        callers always track at least one non-empty batch.
    """
    weights = weights.to(result.device)
    # torch.dot requires 1-D inputs; the former `.t()` was a no-op on 1-D
    # tensors and has been dropped
    numerator = torch.dot(result.float(), weights.float())
    return numerator / weights.sum().float()
8 changes: 7 additions & 1 deletion pytorch_lightning/trainer/evaluation_loop.py
Original file line number Diff line number Diff line change
Expand Up @@ -331,8 +331,14 @@ def _evaluate(
else:
output = self.evaluation_forward(model, batch, batch_idx, dataloader_idx, test_mode)

is_result_obj = isinstance(output, Result)

# track batch size for weighted average
if is_result_obj:
output.track_batch_size(len(batch))

# allow only EvalResult when using structured results (from val_step)
if isinstance(output, Result) and not isinstance(output, EvalResult):
if is_result_obj and not isinstance(output, EvalResult):
m = 'only EvalResults or dicts are allowed from validation_step'
raise MisconfigurationException(m)

Expand Down
13 changes: 7 additions & 6 deletions pytorch_lightning/trainer/training_loop.py
Original file line number Diff line number Diff line change
Expand Up @@ -848,16 +848,13 @@ def run_training_batch(self, batch, batch_idx):
# add metrics to loggers
if using_results_obj:
metrics_to_log = opt_closure_result.training_step_output.batch_log_metrics
else:
metrics_to_log = opt_closure_result.training_step_output.log_metrics
batch_log_metrics.append(metrics_to_log)

# add metrics to progress bar
if using_results_obj:
step_pbar_metrics = opt_closure_result.training_step_output.batch_pbar_metrics
else:
metrics_to_log = opt_closure_result.training_step_output.log_metrics
step_pbar_metrics = opt_closure_result.training_step_output.pbar_on_batch_end

# track metrics
batch_log_metrics.append(metrics_to_log)
if len(step_pbar_metrics) > 0:
self.add_progress_bar_metrics(step_pbar_metrics)

Expand Down Expand Up @@ -1018,6 +1015,10 @@ def optimizer_closure(self, split_batch, batch_idx, opt_idx, optimizer, hiddens)
training_step_output_for_epoch_end = training_step_output
is_result_obj = isinstance(training_step_output, Result)

# track batch size for weighted average
if is_result_obj:
training_step_output.track_batch_size(len(split_batch))

# don't allow EvalResult in the training_step
if isinstance(training_step_output, EvalResult):
raise MisconfigurationException('training_step cannot return EvalResult, '
Expand Down

0 comments on commit 4a2d62f

Please sign in to comment.