Skip to content

Commit

Permalink
Fix log_every_n_steps check in ThroughputMonitor (#19470)
Browse files Browse the repository at this point in the history
  • Loading branch information
awaelchli committed Mar 1, 2024
1 parent dde9818 commit 32cd1f9
Show file tree
Hide file tree
Showing 3 changed files with 22 additions and 37 deletions.
1 change: 1 addition & 0 deletions src/lightning/pytorch/CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,7 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).
### Fixed

- Fixed an issue with CSVLogger trying to append to file from a previous run when the version is set manually ([#19446](https://github.com/Lightning-AI/lightning/pull/19446))
- Fixed the divisibility check for `Trainer.accumulate_grad_batches` and `Trainer.log_every_n_steps` in ThroughputMonitor ([#19470](https://github.com/Lightning-AI/lightning/pull/19470))


## [2.2.0] - 2024-02-08
Expand Down
19 changes: 4 additions & 15 deletions src/lightning/pytorch/callbacks/throughput_monitor.py
Original file line number Diff line number Diff line change
Expand Up @@ -93,21 +93,10 @@ def setup(self, trainer: "Trainer", pl_module: "LightningModule", stage: str) ->
dtype = _plugin_to_compute_dtype(trainer.precision_plugin)
self.available_flops = get_available_flops(trainer.strategy.root_device, dtype)

if stage == TrainerFn.FITTING:
if trainer.accumulate_grad_batches % trainer.log_every_n_steps != 0:
raise ValueError(
"The `ThroughputMonitor` only logs when gradient accumulation is finished. You set"
f" `Trainer(accumulate_grad_batches={trainer.accumulate_grad_batches},"
f" log_every_n_steps={trainer.log_every_n_steps})` but these are not divisible and thus will not"
" log anything."
)

if trainer.enable_validation:
# `fit` includes validation inside
throughput = Throughput(
available_flops=self.available_flops, world_size=trainer.world_size, **self.kwargs
)
self._throughputs[RunningStage.VALIDATING] = throughput
if stage == TrainerFn.FITTING and trainer.enable_validation:
# `fit` includes validation inside
throughput = Throughput(available_flops=self.available_flops, world_size=trainer.world_size, **self.kwargs)
self._throughputs[RunningStage.VALIDATING] = throughput

throughput = Throughput(available_flops=self.available_flops, world_size=trainer.world_size, **self.kwargs)
stage = trainer.state.stage
Expand Down
39 changes: 17 additions & 22 deletions tests/tests_pytorch/callbacks/test_throughput_monitor.py
Original file line number Diff line number Diff line change
Expand Up @@ -160,7 +160,8 @@ def test_throughput_monitor_fit_no_length_fn(tmp_path):
]


def test_throughput_monitor_fit_gradient_accumulation(tmp_path):
@pytest.mark.parametrize("log_every_n_steps", [1, 3])
def test_throughput_monitor_fit_gradient_accumulation(log_every_n_steps, tmp_path):
logger_mock = Mock()
logger_mock.save_dir = tmp_path
monitor = ThroughputMonitor(length_fn=lambda x: 3 * 2, batch_size_fn=lambda x: 3, window_size=4, separator="|")
Expand All @@ -174,26 +175,8 @@ def test_throughput_monitor_fit_gradient_accumulation(tmp_path):
limit_train_batches=5,
limit_val_batches=0,
max_epochs=2,
log_every_n_steps=3,
log_every_n_steps=log_every_n_steps,
accumulate_grad_batches=2,
num_sanity_val_steps=2,
enable_checkpointing=False,
enable_model_summary=False,
enable_progress_bar=False,
)
with pytest.raises(ValueError, match="not divisible"):
trainer.fit(model)

trainer = Trainer(
devices=1,
logger=logger_mock,
callbacks=monitor,
limit_train_batches=5,
limit_val_batches=0,
max_epochs=2,
log_every_n_steps=1,
accumulate_grad_batches=2,
num_sanity_val_steps=2,
enable_checkpointing=False,
enable_model_summary=False,
enable_progress_bar=False,
Expand All @@ -211,9 +194,19 @@ def test_throughput_monitor_fit_gradient_accumulation(tmp_path):
"train|device|flops_per_sec": 10.0,
"train|device|mfu": 0.1,
}
assert logger_mock.log_metrics.mock_calls == [

all_log_calls = [
call(
metrics={"train|time": 2.5, "train|batches": 2, "train|samples": 6, "train|lengths": 12, "epoch": 0}, step=0
metrics={
# The very first batch doesn't have the *_per_sec metrics yet
**(expected if log_every_n_steps > 1 else {}),
"train|time": 2.5,
"train|batches": 2,
"train|samples": 6,
"train|lengths": 12,
"epoch": 0,
},
step=0,
),
call(
metrics={
Expand Down Expand Up @@ -271,6 +264,8 @@ def test_throughput_monitor_fit_gradient_accumulation(tmp_path):
step=5,
),
]
expected_log_calls = all_log_calls[(log_every_n_steps - 1) :: log_every_n_steps]
assert logger_mock.log_metrics.mock_calls == expected_log_calls


@pytest.mark.parametrize("fn", ["validate", "test", "predict"])
Expand Down

0 comments on commit 32cd1f9

Please sign in to comment.