Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Fix filtration logic for eval results #10810

Merged
merged 12 commits into from
Dec 3, 2021
3 changes: 3 additions & 0 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -193,6 +193,9 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).
- Fixed a consolidation error in Lite when attempting to save the state dict of a sharded optimizer ([#10746](https://github.com/PyTorchLightning/pytorch-lightning/pull/10746))


- Fixed an issue to return the results for each dataloader separately instead of duplicating them for each ([#10810](https://github.com/PyTorchLightning/pytorch-lightning/pull/10810))


-


Expand Down
3 changes: 2 additions & 1 deletion pytorch_lightning/core/lightning.py
Original file line number Diff line number Diff line change
Expand Up @@ -430,7 +430,8 @@ def log(
on_epoch=on_epoch,
reduce_fx=reduce_fx,
enable_graph=enable_graph,
dataloader_idx=(self._current_dataloader_idx if add_dataloader_idx else None),
add_dataloader_idx=add_dataloader_idx,
dataloader_idx=self._current_dataloader_idx,
batch_size=batch_size,
sync_dist=sync_dist and distributed_available(),
sync_dist_fn=self.trainer.training_type_plugin.reduce or sync_ddp,
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -155,21 +155,19 @@ def update_eval_step_metrics(self) -> None:
# increment the step even if nothing was logged
self._increment_eval_log_step()

@staticmethod
def _filter_metrics_for_dataloader(
    dl_idx: int, metrics: _OUT_DICT, metric_prefix: str = "dataloader_idx"
) -> _OUT_DICT:
    """Keep only the metrics that belong to the dataloader at ``dl_idx``.

    A key is retained when it carries no dataloader suffix at all, or when its
    suffix is exactly ``f"{metric_prefix}_{dl_idx}"``.
    """
    wanted_suffix = f"{metric_prefix}_{dl_idx}"
    filtered = {}
    for key, value in metrics.items():
        # keys without any dataloader marker apply to every dataloader
        if metric_prefix in key and not key.endswith(wanted_suffix):
            continue  # logged for a different dataloader
        filtered[key] = value
    return filtered

def _prepare_eval_loop_results(self, metrics: _OUT_DICT) -> None:
def _prepare_eval_loop_results(self) -> None:
if self.trainer.sanity_checking:
return

on_step = not self._epoch_end_reached
num_dataloaders = self.trainer._evaluation_loop.num_dataloaders
has_been_initialized = len(self.eval_loop_results) == num_dataloaders
for dl_idx in range(self.trainer._evaluation_loop.num_dataloaders):
# remove callback metrics that don't belong to this dataloader
callback_metrics = self._filter_metrics_for_dataloader(dl_idx, metrics)
assert self.trainer._results is not None
rohitgr7 marked this conversation as resolved.
Show resolved Hide resolved
for dl_idx in range(num_dataloaders):
callback_metrics = self.trainer._results.metrics(
on_step, dataloader_idx=dl_idx if num_dataloaders > 1 else None
)["callback"]
rohitgr7 marked this conversation as resolved.
Show resolved Hide resolved

if has_been_initialized:
self.eval_loop_results[dl_idx].update(callback_metrics)
else:
Expand All @@ -183,7 +181,7 @@ def update_eval_epoch_metrics(self) -> List[_OUT_DICT]:
# log all the metrics as a single dict
self.log_metrics(metrics["log"])

self._prepare_eval_loop_results(metrics["callback"])
self._prepare_eval_loop_results()

# log results of evaluation
if (
Expand Down
14 changes: 10 additions & 4 deletions pytorch_lightning/trainer/connectors/logger_connector/result.py
Original file line number Diff line number Diff line change
Expand Up @@ -113,6 +113,7 @@ class _Metadata:
on_epoch: bool = True
reduce_fx: Callable = torch.mean
enable_graph: bool = False
add_dataloader_idx: bool = True
dataloader_idx: Optional[int] = None
metric_attribute: Optional[str] = None
_sync: Optional[_Sync] = None
Expand Down Expand Up @@ -432,6 +433,7 @@ def log(
sync_dist: bool = False,
sync_dist_fn: Callable = _Sync.no_op,
sync_dist_group: Optional[Any] = None,
add_dataloader_idx: bool = True,
dataloader_idx: Optional[int] = None,
batch_size: Optional[int] = None,
metric_attribute: Optional[str] = None,
Expand All @@ -449,7 +451,7 @@ def log(
# storage key
key = f"{fx}.{name}"
# add dataloader_suffix to both key and fx
if dataloader_idx is not None:
if add_dataloader_idx and dataloader_idx is not None:
carmocca marked this conversation as resolved.
Show resolved Hide resolved
key += f".{dataloader_idx}"
fx += f".{dataloader_idx}"

Expand All @@ -462,6 +464,7 @@ def log(
on_epoch=on_epoch,
reduce_fx=reduce_fx,
enable_graph=enable_graph,
add_dataloader_idx=add_dataloader_idx,
dataloader_idx=dataloader_idx,
metric_attribute=metric_attribute,
)
Expand Down Expand Up @@ -527,14 +530,15 @@ def valid_items(self) -> Generator:
def _forked_name(self, result_metric: ResultMetric, on_step: bool) -> Tuple[str, str]:
    """Return the display name and its step/epoch-forked variant for a metric.

    When the metric was logged with ``add_dataloader_idx=True`` and a
    ``dataloader_idx`` is present, the ``DATALOADER_SUFFIX`` (e.g.
    ``"/dataloader_idx_0"``) is appended to both names.
    """
    name = result_metric.meta.name
    forked_name = result_metric.meta.forked_name(on_step)
    # fixed typo: was `add_datalaoder_idx`
    add_dataloader_idx = result_metric.meta.add_dataloader_idx
    dl_idx = result_metric.meta.dataloader_idx
    if add_dataloader_idx and dl_idx is not None:
        dataloader_suffix = self.DATALOADER_SUFFIX.format(dl_idx)
        name += dataloader_suffix
        forked_name += dataloader_suffix
    return name, forked_name

def metrics(self, on_step: bool) -> _METRICS:
def metrics(self, on_step: bool, dataloader_idx: Optional[int] = None) -> _METRICS:
metrics = _METRICS(callback={}, log={}, pbar={})

for _, result_metric in self.valid_items():
Expand Down Expand Up @@ -564,7 +568,9 @@ def any_tensor(_: Any) -> None:
metrics["log"][forked_name] = value

# populate callback metrics. callback metrics don't take `_step` forked metrics
if self.training or result_metric.meta.on_epoch and not on_step:
if (self.training or result_metric.meta.on_epoch and not on_step) and (
dataloader_idx is None or result_metric.meta.dataloader_idx == dataloader_idx
):
rohitgr7 marked this conversation as resolved.
Show resolved Hide resolved
metrics["callback"][name] = value
metrics["callback"][forked_name] = value

Expand Down
2 changes: 1 addition & 1 deletion pytorch_lightning/trainer/trainer.py
Original file line number Diff line number Diff line change
Expand Up @@ -1354,7 +1354,7 @@ def __set_ckpt_path(self, ckpt_path: Optional[str], model_provided: bool, model_
" The best model of the previous `fit` call will be used."
f" You can pass `{fn}(ckpt_path='best')` to use and best model"
" checkpoint and avoid this warning or"
" `ckpt_path=trainer.model_checkpoint.last_model_path` to use the last model."
" `ckpt_path=trainer.checkpoint_callback.last_model_path` to use the last model."
)
ckpt_path = "best"

Expand Down
65 changes: 38 additions & 27 deletions tests/trainer/logging_/test_eval_loop_logging.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,7 +23,6 @@

from pytorch_lightning import callbacks, Trainer
from pytorch_lightning.loggers import TensorBoardLogger
from pytorch_lightning.trainer.connectors.logger_connector import LoggerConnector
from pytorch_lightning.utilities.exceptions import MisconfigurationException
from tests.helpers import BoringModel, RandomDataset
from tests.helpers.runif import RunIf
Expand Down Expand Up @@ -676,32 +675,6 @@ def val_dataloader(self):
trainer.fit(model)


@pytest.mark.parametrize(
    ["kwargs", "expected"],
    [
        # no suffix present: everything is kept
        ({"dl_idx": 0, "metrics": {"acc": 123}}, {"acc": 123}),
        (
            {"dl_idx": 0, "metrics": {"acc/dataloader_idx_0": 123, "acc/dataloader_idx_1": 321}},
            {"acc/dataloader_idx_0": 123},
        ),
        # multi-digit indices must not match on a shared prefix
        (
            {"dl_idx": 10, "metrics": {"acc/dataloader_idx_1": 123, "acc/dataloader_idx_10": 321}},
            {"acc/dataloader_idx_10": 321},
        ),
        # the prefix may also appear inside the metric name itself
        (
            {"dl_idx": 3, "metrics": {"top_3_acc/dataloader_idx_0": 123, "top_3_acc/dataloader_idx_3": 321}},
            {"top_3_acc/dataloader_idx_3": 321},
        ),
        # theoretical case, as `/dataloader_idx_3` would have been added
        ({"dl_idx": 3, "metrics": {"top_3_acc": 123}}, {"top_3_acc": 123}),
    ],
)
def test_filter_metrics_for_dataloader(kwargs, expected):
    """Logged metrics should only include metrics from the concerned dataloader."""
    actual = LoggerConnector._filter_metrics_for_dataloader(**kwargs)
    assert actual == expected


@RunIf(min_gpus=1)
def test_evaluation_move_metrics_to_cpu_and_outputs(tmpdir):
class TestModel(BoringModel):
Expand All @@ -723,3 +696,41 @@ def validation_epoch_end(self, outputs):
model = TestModel()
trainer = Trainer(default_root_dir=tmpdir, limit_val_batches=2, move_metrics_to_cpu=True, gpus=1)
trainer.validate(model, verbose=False)


def test_logging_results_with_no_dataloader_idx(tmpdir):
    """Keys logged with ``add_dataloader_idx=False`` keep their bare names per dataloader."""
    n_loaders = 2
    key_common = "test_log_common"
    key_per_loader = "test_log_no_dl_idx_{}"
    key_first = "test_log_a_class"
    key_second = "test_log_b_class"

    class CustomBoringModel(BoringModel):
        def test_step(self, batch, batch_idx, dataloader_idx):
            logits = self.layer(batch)
            loss = torch.nn.functional.mse_loss(logits, torch.randn(*logits.shape))
            # default behavior: a dataloader suffix gets appended to this key
            self.log(key_common, loss)
            # suffix-free keys, made unique via formatting or branching
            self.log(key_per_loader.format(dataloader_idx), loss, add_dataloader_idx=False)
            self.log(key_first if dataloader_idx == 0 else key_second, loss, add_dataloader_idx=False)
            return {"y": loss}

        def test_dataloader(self):
            loaders = []
            for _ in range(n_loaders):
                loaders.append(torch.utils.data.DataLoader(RandomDataset(32, 64)))
            return loaders

    model = CustomBoringModel()
    model.test_epoch_end = None
    trainer = Trainer(default_root_dir=tmpdir, fast_dev_run=2)
    results = trainer.test(model)

    assert len(results) == n_loaders
    for idx in range(n_loaders):
        expected_keys = {
            f"{key_common}/dataloader_idx_{idx}",
            key_per_loader.format(idx),
            key_first if idx == 0 else key_second,
        }
        assert results[idx].keys() == expected_keys