From 4da7c53917ffbf3afd37e769c396a7af89f329e8 Mon Sep 17 00:00:00 2001 From: rohitgr7 Date: Sun, 7 Aug 2022 01:50:56 +0530 Subject: [PATCH 1/5] profile batch transfer and gradient clipping hooks --- src/pytorch_lightning/core/module.py | 27 ++++++++++++------- .../plugins/precision/precision_plugin.py | 8 +++--- .../trainer/connectors/data_connector.py | 14 +++++----- 3 files changed, 29 insertions(+), 20 deletions(-) diff --git a/src/pytorch_lightning/core/module.py b/src/pytorch_lightning/core/module.py index f58503edd88cb..a8caa95dfa5f1 100644 --- a/src/pytorch_lightning/core/module.py +++ b/src/pytorch_lightning/core/module.py @@ -37,7 +37,6 @@ from pytorch_lightning.core.optimizer import LightningOptimizer from pytorch_lightning.core.saving import ModelIO from pytorch_lightning.loggers import Logger, LoggerCollection -from pytorch_lightning.trainer.connectors.data_connector import _DataHookSelector from pytorch_lightning.trainer.connectors.logger_connector.fx_validator import _FxValidator from pytorch_lightning.utilities import _IS_WINDOWS, _TORCH_GREATER_EQUAL_1_10, GradClipAlgorithmType, warnings from pytorch_lightning.utilities.apply_func import apply_to_collection, convert_to_tensors @@ -291,16 +290,24 @@ def _apply_batch_transfer_handler( self, batch: Any, device: Optional[torch.device] = None, dataloader_idx: int = 0 ) -> Any: device = device or self.device - datahook_selector = ( - _DataHookSelector(self, None) if self._trainer is None else self.trainer._data_connector._datahook_selector - ) - hook = datahook_selector.get_hook("on_before_batch_transfer") - batch = hook(batch, dataloader_idx) - hook = datahook_selector.get_hook("transfer_batch_to_device") - batch = hook(batch, device, dataloader_idx) - hook = datahook_selector.get_hook("on_after_batch_transfer") - batch = hook(batch, dataloader_idx) + def call_hook(hook_name, **kwargs): + if self._trainer: + datahook_selector = self._trainer._data_connector._datahook_selector + obj = datahook_selector.get_instance(hook_name) + trainer_method = ( + self._trainer._call_lightning_module_hook + if isinstance(obj, self.__class__) + else self._trainer._call_lightning_datamodule_hook + ) + return trainer_method(hook_name, **kwargs) + else: + hook = getattr(self, hook_name) + return hook(**kwargs) + + batch = call_hook("on_before_batch_transfer", batch=batch, dataloader_idx=dataloader_idx) + batch = call_hook("transfer_batch_to_device", batch=batch, device=device, dataloader_idx=dataloader_idx) + batch = call_hook("on_after_batch_transfer", batch=batch, dataloader_idx=dataloader_idx) return batch def print(self, *args, **kwargs) -> None: diff --git a/src/pytorch_lightning/plugins/precision/precision_plugin.py b/src/pytorch_lightning/plugins/precision/precision_plugin.py index 60dfb1ab6c92f..c8b20218f9abd 100644 --- a/src/pytorch_lightning/plugins/precision/precision_plugin.py +++ b/src/pytorch_lightning/plugins/precision/precision_plugin.py @@ -182,9 +182,11 @@ def _clip_gradients( if not isinstance(model, pl.LightningModule) or not model.automatic_optimization: # the configuration validator disallows clipping on manual return - model.configure_gradient_clipping( - optimizer, - optimizer_idx, + + model.trainer._call_lightning_module_hook( + "configure_gradient_clipping", + optimizer=optimizer, + optimizer_idx=optimizer_idx, gradient_clip_val=clip_val, gradient_clip_algorithm=gradient_clip_algorithm, ) diff --git a/src/pytorch_lightning/trainer/connectors/data_connector.py b/src/pytorch_lightning/trainer/connectors/data_connector.py 
index e1aca404722db..9cea805c1da12 100644 --- a/src/pytorch_lightning/trainer/connectors/data_connector.py +++ b/src/pytorch_lightning/trainer/connectors/data_connector.py @@ -527,7 +527,7 @@ def is_module(self) -> bool: @dataclass class _DataHookSelector: - """Stores the info about the shared DataHooks within LightningModule and LightningDataModule. + """Stores the info about the shared DataHooks within ``LightningModule`` and ``LightningDataModule``. The hook source can be @@ -535,8 +535,8 @@ class _DataHookSelector: 2. a method from the :class:`~pytorch_lightning.core.datamodule.LightningDataModule`, Arguments: - model: A LightningModule - datamodule: A LightningDataModule + model: A ``LightningModule`` + datamodule: A ``LightningDataModule`` """ model: "pl.LightningModule" @@ -545,7 +545,7 @@ class _DataHookSelector: default=("on_before_batch_transfer", "transfer_batch_to_device", "on_after_batch_transfer") ) - def get_hook(self, hook_name: str) -> Callable: + def get_instance(self, hook_name: str) -> Callable: if hook_name not in self._valid_hooks: raise ValueError( f"`{hook_name}` is not a shared hook within `LightningModule` and `LightningDataModule`." @@ -553,7 +553,7 @@ def get_hook(self, hook_name: str) -> Callable: ) if self.datamodule is None: - return getattr(self.model, hook_name) + return self.model if is_overridden(hook_name, self.datamodule): if is_overridden(hook_name, self.model): @@ -561,11 +561,11 @@ def get_hook(self, hook_name: str) -> Callable: f"You have overridden `{hook_name}` in both `LightningModule` and `LightningDataModule`." " It will use the implementation from `LightningDataModule` instance." ) - return getattr(self.datamodule, hook_name) + return self.datamodule if is_overridden(hook_name, self.model): warning_cache.warn( f"You have overridden `{hook_name}` in `LightningModule` but have passed in a" " `LightningDataModule`. It will use the implementation from `LightningModule` instance." 
) - return getattr(self.model, hook_name) + return self.model From 43ec3f851aebecc5dc0b52c7e0e7f44d6cace2de Mon Sep 17 00:00:00 2001 From: rohitgr7 Date: Sun, 7 Aug 2022 15:47:50 +0530 Subject: [PATCH 2/5] fix tests --- .../logger_connector/fx_validator.py | 5 +++ tests/tests_pytorch/models/test_hooks.py | 32 +++++++++++-------- .../trainer/connectors/test_data_connector.py | 26 +++++++-------- .../trainer/logging_/test_logger_connector.py | 10 +++--- 4 files changed, 41 insertions(+), 32 deletions(-) diff --git a/src/pytorch_lightning/trainer/connectors/logger_connector/fx_validator.py b/src/pytorch_lightning/trainer/connectors/logger_connector/fx_validator.py index 6f60ba6f1aa2f..56ad53ef4ba04 100644 --- a/src/pytorch_lightning/trainer/connectors/logger_connector/fx_validator.py +++ b/src/pytorch_lightning/trainer/connectors/logger_connector/fx_validator.py @@ -44,6 +44,8 @@ class _LogOptions(TypedDict): allowed_on_step=(False, True), allowed_on_epoch=(False, True), default_on_step=True, default_on_epoch=False ), "lr_scheduler_step": None, + "configure_gradient_clipping": None, + "clip_gradients": None, "on_before_zero_grad": _LogOptions( allowed_on_step=(False, True), allowed_on_epoch=(False, True), default_on_step=True, default_on_epoch=False ), @@ -98,6 +100,9 @@ class _LogOptions(TypedDict): "on_epoch_end": _LogOptions( allowed_on_step=(False,), allowed_on_epoch=(True,), default_on_step=False, default_on_epoch=True ), + "on_before_batch_transfer": None, + "transfer_batch_to_device": None, + "on_after_batch_transfer": None, "on_batch_start": _LogOptions( allowed_on_step=(False, True), allowed_on_epoch=(False, True), default_on_step=True, default_on_epoch=False ), diff --git a/tests/tests_pytorch/models/test_hooks.py b/tests/tests_pytorch/models/test_hooks.py index a2235c592d5fb..7ae3d4b978c34 100644 --- a/tests/tests_pytorch/models/test_hooks.py +++ b/tests/tests_pytorch/models/test_hooks.py @@ -296,9 +296,9 @@ def _auto_train_batch( for i in range(current_batch, batches): out.extend( [ - dict(name="on_before_batch_transfer", args=(ANY, 0)), - dict(name="transfer_batch_to_device", args=(ANY, device, 0)), - dict(name="on_after_batch_transfer", args=(ANY, 0)), + dict(name="on_before_batch_transfer", kwargs=dict(batch=ANY, dataloader_idx=0)), + dict(name="transfer_batch_to_device", kwargs=dict(batch=ANY, device=device, dataloader_idx=0)), + dict(name="on_after_batch_transfer", kwargs=dict(batch=ANY, dataloader_idx=0)), dict(name="Callback.on_batch_start", args=(trainer, model)), dict(name="Callback.on_train_batch_start", args=(trainer, model, ANY, i)), dict(name="on_train_batch_start", args=(ANY, i)), @@ -325,8 +325,9 @@ def _auto_train_batch( ), dict( name="configure_gradient_clipping", - args=(ANY, 0), - kwargs=dict(gradient_clip_val=None, gradient_clip_algorithm=None), + kwargs=dict( + optimizer=ANY, optimizer_idx=0, gradient_clip_val=None, gradient_clip_algorithm=None + ), ), # this is after because it refers to the `LightningModule.optimizer_step` hook which encapsulates # the actual call to `PrecisionPlugin.optimizer_step` @@ -354,9 +355,9 @@ def _manual_train_batch(trainer, model, batches, device=torch.device("cpu"), **k for i in range(batches): out.extend( [ - dict(name="on_before_batch_transfer", args=(ANY, 0)), - dict(name="transfer_batch_to_device", args=(ANY, device, 0)), - dict(name="on_after_batch_transfer", args=(ANY, 0)), + dict(name="on_before_batch_transfer", kwargs=dict(batch=ANY, dataloader_idx=0)), + dict(name="transfer_batch_to_device", kwargs=dict(batch=ANY, 
device=device, dataloader_idx=0)), + dict(name="on_after_batch_transfer", kwargs=dict(batch=ANY, dataloader_idx=0)), dict(name="Callback.on_batch_start", args=(trainer, model)), dict(name="Callback.on_train_batch_start", args=(trainer, model, ANY, i)), dict(name="on_train_batch_start", args=(ANY, i)), @@ -405,9 +406,9 @@ def _eval_batch(fn, trainer, model, batches, key, device=torch.device("cpu")): for i in range(batches): out.extend( [ - dict(name="on_before_batch_transfer", args=(ANY, 0)), - dict(name="transfer_batch_to_device", args=(ANY, device, 0)), - dict(name="on_after_batch_transfer", args=(ANY, 0)), + dict(name="on_before_batch_transfer", kwargs=dict(batch=ANY, dataloader_idx=0)), + dict(name="transfer_batch_to_device", kwargs=dict(batch=ANY, device=device, dataloader_idx=0)), + dict(name="on_after_batch_transfer", kwargs=dict(batch=ANY, dataloader_idx=0)), dict(name=f"Callback.on_{fn}_batch_start", args=(trainer, model, ANY, i, 0)), dict(name=f"on_{fn}_batch_start", args=(ANY, i, 0)), dict(name="forward", args=(ANY,)), @@ -425,9 +426,12 @@ def _predict_batch(trainer, model, batches): for i in range(batches): out.extend( [ - dict(name="on_before_batch_transfer", args=(ANY, 0)), - dict(name="transfer_batch_to_device", args=(ANY, torch.device("cpu"), 0)), - dict(name="on_after_batch_transfer", args=(ANY, 0)), + dict(name="on_before_batch_transfer", kwargs=dict(batch=ANY, dataloader_idx=0)), + dict( + name="transfer_batch_to_device", + kwargs=dict(batch=ANY, device=torch.device("cpu"), dataloader_idx=0), + ), + dict(name="on_after_batch_transfer", kwargs=dict(batch=ANY, dataloader_idx=0)), dict(name="Callback.on_predict_batch_start", args=(trainer, model, ANY, i, 0)), dict(name="on_predict_batch_start", args=(ANY, i, 0)), dict(name="forward", args=(ANY,)), diff --git a/tests/tests_pytorch/trainer/connectors/test_data_connector.py b/tests/tests_pytorch/trainer/connectors/test_data_connector.py index 52ef4c4db6d8d..f2a98daa9c5ad 100644 --- a/tests/tests_pytorch/trainer/connectors/test_data_connector.py +++ b/tests/tests_pytorch/trainer/connectors/test_data_connector.py @@ -470,34 +470,34 @@ def test_no_datamodule_no_overridden(self, hook_name): model, _, trainer = self.reset_instances() trainer._data_connector.attach_datamodule(model, datamodule=None) with no_warning_call(match=f"have overridden `{hook_name}` in"): - hook = trainer._data_connector._datahook_selector.get_hook(hook_name) + instance = trainer._data_connector._datahook_selector.get_instance(hook_name) - assert hook == getattr(model, hook_name) + assert instance is model def test_with_datamodule_no_overridden(self, hook_name): model, dm, trainer = self.reset_instances() trainer._data_connector.attach_datamodule(model, datamodule=dm) with no_warning_call(match=f"have overridden `{hook_name}` in"): - hook = trainer._data_connector._datahook_selector.get_hook(hook_name) + instance = trainer._data_connector._datahook_selector.get_instance(hook_name) - assert hook == getattr(model, hook_name) + assert instance is model def test_override_model_hook(self, hook_name): model, dm, trainer = self.reset_instances() trainer._data_connector.attach_datamodule(model, datamodule=dm) with no_warning_call(match=f"have overridden `{hook_name}` in"): - hook = trainer._data_connector._datahook_selector.get_hook(hook_name) + instance = trainer._data_connector._datahook_selector.get_instance(hook_name) - assert hook == getattr(model, hook_name) + assert instance is model def test_override_datamodule_hook(self, hook_name): model, dm, trainer = 
self.reset_instances() trainer._data_connector.attach_datamodule(model, datamodule=dm) setattr(dm, hook_name, self.overridden_func) with no_warning_call(match=f"have overridden `{hook_name}` in"): - hook = trainer._data_connector._datahook_selector.get_hook(hook_name) + instance = trainer._data_connector._datahook_selector.get_instance(hook_name) - assert hook == getattr(dm, hook_name) + assert instance is dm def test_override_both_model_and_datamodule(self, hook_name): model, dm, trainer = self.reset_instances() @@ -505,24 +505,24 @@ def test_override_both_model_and_datamodule(self, hook_name): setattr(model, hook_name, self.overridden_func) setattr(dm, hook_name, self.overridden_func) with pytest.warns(UserWarning, match=f"have overridden `{hook_name}` in both"): - hook = trainer._data_connector._datahook_selector.get_hook(hook_name) + instance = trainer._data_connector._datahook_selector.get_instance(hook_name) - assert hook == getattr(dm, hook_name) + assert instance is dm def test_with_datamodule_override_model(self, hook_name): model, dm, trainer = self.reset_instances() trainer._data_connector.attach_datamodule(model, datamodule=dm) setattr(model, hook_name, self.overridden_func) with pytest.warns(UserWarning, match=f"have overridden `{hook_name}` in `LightningModule`"): - hook = trainer._data_connector._datahook_selector.get_hook(hook_name) + instance = trainer._data_connector._datahook_selector.get_instance(hook_name) - assert hook == getattr(model, hook_name) + assert instance is model def test_invalid_hook_passed_in_datahook_selector(): dh_selector = _DataHookSelector(BoringModel(), None) with pytest.raises(ValueError, match="is not a shared hook"): - dh_selector.get_hook("setup") + dh_selector.get_instance("setup") def test_eval_distributed_sampler_warning(tmpdir): diff --git a/tests/tests_pytorch/trainer/logging_/test_logger_connector.py b/tests/tests_pytorch/trainer/logging_/test_logger_connector.py index 760e8eea2a85c..c2be22c61244b 100644 --- a/tests/tests_pytorch/trainer/logging_/test_logger_connector.py +++ b/tests/tests_pytorch/trainer/logging_/test_logger_connector.py @@ -187,11 +187,6 @@ def __init__(self, not_supported): { "log", "log_dict", - # the following are problematic as they do have `self._current_fx_name` defined some times but - # not others depending on where they were called. 
So we cannot reliably `self.log` in them - "on_before_batch_transfer", - "transfer_batch_to_device", - "on_after_batch_transfer", } ) # remove `nn.Module` hooks @@ -227,6 +222,9 @@ def test_fx_validator_integration(tmpdir): "on_pretrain_routine_end": "You can't", "train_dataloader": "You can't", "val_dataloader": "You can't", + "on_before_batch_transfer": "You can't", + "transfer_batch_to_device": "You can't", + "on_after_batch_transfer": "You can't", "on_validation_end": "You can't", "on_train_end": "You can't", "on_fit_end": "You can't", @@ -238,6 +236,8 @@ def test_fx_validator_integration(tmpdir): "on_validation_model_eval": "You can't", "on_validation_model_train": "You can't", "lr_scheduler_step": "You can't", + "configure_gradient_clipping": "You can't", + "clip_gradients": "You can't", "on_save_checkpoint": "You can't", "on_load_checkpoint": "You can't", "on_exception": "You can't", From 58beda0dc488bad987eaf29300a2f3f83f4bf43b Mon Sep 17 00:00:00 2001 From: rohitgr7 Date: Wed, 10 Aug 2022 14:09:45 +0530 Subject: [PATCH 3/5] rev positional args --- src/pytorch_lightning/core/module.py | 12 ++++---- .../plugins/precision/precision_plugin.py | 4 +-- tests/tests_pytorch/models/test_hooks.py | 29 +++++++++---------- 3 files changed, 22 insertions(+), 23 deletions(-) diff --git a/src/pytorch_lightning/core/module.py b/src/pytorch_lightning/core/module.py index a8caa95dfa5f1..612bcc72d2806 100644 --- a/src/pytorch_lightning/core/module.py +++ b/src/pytorch_lightning/core/module.py @@ -291,7 +291,7 @@ def _apply_batch_transfer_handler( ) -> Any: device = device or self.device - def call_hook(hook_name, **kwargs): + def call_hook(hook_name, *args): if self._trainer: datahook_selector = self._trainer._data_connector._datahook_selector obj = datahook_selector.get_instance(hook_name) @@ -300,14 +300,14 @@ def call_hook(hook_name, **kwargs): if isinstance(obj, self.__class__) else self._trainer._call_lightning_datamodule_hook ) - return trainer_method(hook_name, **kwargs) + return trainer_method(hook_name, *args) else: hook = getattr(self, hook_name) - return hook(**kwargs) + return hook(*args) - batch = call_hook("on_before_batch_transfer", batch=batch, dataloader_idx=dataloader_idx) - batch = call_hook("transfer_batch_to_device", batch=batch, device=device, dataloader_idx=dataloader_idx) - batch = call_hook("on_after_batch_transfer", batch=batch, dataloader_idx=dataloader_idx) + batch = call_hook("on_before_batch_transfer", batch, dataloader_idx) + batch = call_hook("transfer_batch_to_device", batch, device, dataloader_idx) + batch = call_hook("on_after_batch_transfer", batch, dataloader_idx) return batch def print(self, *args, **kwargs) -> None: diff --git a/src/pytorch_lightning/plugins/precision/precision_plugin.py b/src/pytorch_lightning/plugins/precision/precision_plugin.py index c8b20218f9abd..285a0f31e3955 100644 --- a/src/pytorch_lightning/plugins/precision/precision_plugin.py +++ b/src/pytorch_lightning/plugins/precision/precision_plugin.py @@ -185,8 +185,8 @@ def _clip_gradients( model.trainer._call_lightning_module_hook( "configure_gradient_clipping", - optimizer=optimizer, - optimizer_idx=optimizer_idx, + optimizer, + optimizer_idx, gradient_clip_val=clip_val, gradient_clip_algorithm=gradient_clip_algorithm, ) diff --git a/tests/tests_pytorch/models/test_hooks.py b/tests/tests_pytorch/models/test_hooks.py index 7ae3d4b978c34..5f57ae147a05d 100644 --- a/tests/tests_pytorch/models/test_hooks.py +++ b/tests/tests_pytorch/models/test_hooks.py @@ -296,9 +296,9 @@ def 
_auto_train_batch( for i in range(current_batch, batches): out.extend( [ - dict(name="on_before_batch_transfer", kwargs=dict(batch=ANY, dataloader_idx=0)), - dict(name="transfer_batch_to_device", kwargs=dict(batch=ANY, device=device, dataloader_idx=0)), - dict(name="on_after_batch_transfer", kwargs=dict(batch=ANY, dataloader_idx=0)), + dict(name="on_before_batch_transfer", args=(ANY, 0)), + dict(name="transfer_batch_to_device", args=(ANY, device, 0)), + dict(name="on_after_batch_transfer", args=(ANY, 0)), dict(name="Callback.on_batch_start", args=(trainer, model)), dict(name="Callback.on_train_batch_start", args=(trainer, model, ANY, i)), dict(name="on_train_batch_start", args=(ANY, i)), @@ -325,9 +325,8 @@ def _auto_train_batch( ), dict( name="configure_gradient_clipping", - kwargs=dict( - optimizer=ANY, optimizer_idx=0, gradient_clip_val=None, gradient_clip_algorithm=None - ), + args=(ANY, 0), + kwargs=dict(gradient_clip_val=None, gradient_clip_algorithm=None), ), # this is after because it refers to the `LightningModule.optimizer_step` hook which encapsulates # the actual call to `PrecisionPlugin.optimizer_step` @@ -355,9 +354,9 @@ def _manual_train_batch(trainer, model, batches, device=torch.device("cpu"), **k for i in range(batches): out.extend( [ - dict(name="on_before_batch_transfer", kwargs=dict(batch=ANY, dataloader_idx=0)), - dict(name="transfer_batch_to_device", kwargs=dict(batch=ANY, device=device, dataloader_idx=0)), - dict(name="on_after_batch_transfer", kwargs=dict(batch=ANY, dataloader_idx=0)), + dict(name="on_before_batch_transfer", args=(ANY, 0)), + dict(name="transfer_batch_to_device", args=(ANY, device, 0)), + dict(name="on_after_batch_transfer", args=(ANY, 0)), dict(name="Callback.on_batch_start", args=(trainer, model)), dict(name="Callback.on_train_batch_start", args=(trainer, model, ANY, i)), dict(name="on_train_batch_start", args=(ANY, i)), @@ -406,9 +405,9 @@ def _eval_batch(fn, trainer, model, batches, key, device=torch.device("cpu")): for i in range(batches): out.extend( [ - dict(name="on_before_batch_transfer", kwargs=dict(batch=ANY, dataloader_idx=0)), - dict(name="transfer_batch_to_device", kwargs=dict(batch=ANY, device=device, dataloader_idx=0)), - dict(name="on_after_batch_transfer", kwargs=dict(batch=ANY, dataloader_idx=0)), + dict(name="on_before_batch_transfer", args=(ANY, 0)), + dict(name="transfer_batch_to_device", args=(ANY, device, 0)), + dict(name="on_after_batch_transfer", args=(ANY, 0)), dict(name=f"Callback.on_{fn}_batch_start", args=(trainer, model, ANY, i, 0)), dict(name=f"on_{fn}_batch_start", args=(ANY, i, 0)), dict(name="forward", args=(ANY,)), @@ -426,12 +425,12 @@ def _predict_batch(trainer, model, batches): for i in range(batches): out.extend( [ - dict(name="on_before_batch_transfer", kwargs=dict(batch=ANY, dataloader_idx=0)), + dict(name="on_before_batch_transfer", args=(ANY, 0)), dict( name="transfer_batch_to_device", - kwargs=dict(batch=ANY, device=torch.device("cpu"), dataloader_idx=0), + args=(ANY, torch.device("cpu"), 0), ), - dict(name="on_after_batch_transfer", kwargs=dict(batch=ANY, dataloader_idx=0)), + dict(name="on_after_batch_transfer", args=(ANY, 0)), dict(name="Callback.on_predict_batch_start", args=(trainer, model, ANY, i, 0)), dict(name="on_predict_batch_start", args=(ANY, i, 0)), dict(name="forward", args=(ANY,)), From 2f110723c6192a9c64b4cdb1352ba2fe26d874d3 Mon Sep 17 00:00:00 2001 From: rohitgr7 Date: Wed, 10 Aug 2022 14:16:43 +0530 Subject: [PATCH 4/5] update --- .../trainer/connectors/data_connector.py | 10 
+++++----- tests/tests_pytorch/models/test_hooks.py | 5 +---- 2 files changed, 6 insertions(+), 9 deletions(-) diff --git a/src/pytorch_lightning/trainer/connectors/data_connector.py b/src/pytorch_lightning/trainer/connectors/data_connector.py index 9cea805c1da12..1de8bee90d18f 100644 --- a/src/pytorch_lightning/trainer/connectors/data_connector.py +++ b/src/pytorch_lightning/trainer/connectors/data_connector.py @@ -14,7 +14,7 @@ import multiprocessing import os from dataclasses import dataclass, field -from typing import Any, Callable, Collection, List, Optional, Tuple, Union +from typing import Any, Collection, List, Optional, Tuple, Union from weakref import proxy from torch.utils.data import BatchSampler, DataLoader, Sampler, SequentialSampler @@ -529,10 +529,10 @@ def is_module(self) -> bool: class _DataHookSelector: """Stores the info about the shared DataHooks within ``LightningModule`` and ``LightningDataModule``. - The hook source can be + The hook source can be: - 1. a method from the :class:`~pytorch_lightning.core.module.LightningModule`, - 2. a method from the :class:`~pytorch_lightning.core.datamodule.LightningDataModule`, + 1. the :class:`~pytorch_lightning.core.module.LightningModule`, + 2. the :class:`~pytorch_lightning.core.datamodule.LightningDataModule`, Arguments: model: A ``LightningModule`` @@ -545,7 +545,7 @@ class _DataHookSelector: default=("on_before_batch_transfer", "transfer_batch_to_device", "on_after_batch_transfer") ) - def get_instance(self, hook_name: str) -> Callable: + def get_instance(self, hook_name: str) -> Union["pl.LightningModule", "pl.LightningDataModule"]: if hook_name not in self._valid_hooks: raise ValueError( f"`{hook_name}` is not a shared hook within `LightningModule` and `LightningDataModule`." diff --git a/tests/tests_pytorch/models/test_hooks.py b/tests/tests_pytorch/models/test_hooks.py index 5f57ae147a05d..a2235c592d5fb 100644 --- a/tests/tests_pytorch/models/test_hooks.py +++ b/tests/tests_pytorch/models/test_hooks.py @@ -426,10 +426,7 @@ def _predict_batch(trainer, model, batches): out.extend( [ dict(name="on_before_batch_transfer", args=(ANY, 0)), - dict( - name="transfer_batch_to_device", - args=(ANY, torch.device("cpu"), 0), - ), + dict(name="transfer_batch_to_device", args=(ANY, torch.device("cpu"), 0)), dict(name="on_after_batch_transfer", args=(ANY, 0)), dict(name="Callback.on_predict_batch_start", args=(trainer, model, ANY, i, 0)), dict(name="on_predict_batch_start", args=(ANY, i, 0)), From abd46f5cd26e97dd2d6fd6e34bfbdfe00f174169 Mon Sep 17 00:00:00 2001 From: awaelchli Date: Fri, 12 Aug 2022 00:30:43 +0200 Subject: [PATCH 5/5] add changelog --- src/pytorch_lightning/CHANGELOG.md | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/pytorch_lightning/CHANGELOG.md b/src/pytorch_lightning/CHANGELOG.md index 5d77a3ad293b9..a3755d7733dba 100644 --- a/src/pytorch_lightning/CHANGELOG.md +++ b/src/pytorch_lightning/CHANGELOG.md @@ -8,7 +8,7 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/). ### Added -- +- Added profiling to these hooks: `on_before_batch_transfer`, `transfer_batch_to_device`, `on_after_batch_transfer`, `configure_gradient_clipping`, `clip_gradients` ([#14069](https://github.com/Lightning-AI/lightning/pull/14069)) -
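
Note for readers of this series: the core mechanism is that `_DataHookSelector.get_instance` now returns the owning object (the `LightningModule` or the `LightningDataModule`) rather than a bound hook, and callers dispatch through `trainer._call_lightning_module_hook` / `trainer._call_lightning_datamodule_hook`; routing every call through those trainer helpers is what makes the batch-transfer and gradient-clipping hooks appear in the profiler. The sketch below illustrates that selection-and-dispatch pattern in isolation. It is not the real implementation: `ToyModule`, `ToyDataModule`, `ToySelector`, and `ToyTrainer` are simplified stand-ins for the Lightning classes, and the `__dict__` membership check stands in for Lightning's `is_overridden` helper.

# Illustrative sketch only -- simplified stand-ins, not pytorch_lightning internals.

class ToyModule:
    def on_before_batch_transfer(self, batch, dataloader_idx):
        return batch  # default: pass the batch through unchanged


class ToyDataModule:
    def on_before_batch_transfer(self, batch, dataloader_idx):
        return [x * 2 for x in batch]  # pretend the datamodule overrides the hook


class ToySelector:
    """Return the *instance* that owns a shared data hook, not a bound method."""

    def __init__(self, model, datamodule=None):
        self.model, self.datamodule = model, datamodule

    def get_instance(self, hook_name):
        # Prefer the datamodule only when it actually defines the hook
        # (a stand-in for Lightning's `is_overridden` check).
        if self.datamodule is not None and hook_name in type(self.datamodule).__dict__:
            return self.datamodule
        return self.model


class ToyTrainer:
    """Funnel every hook call through one method so it can be timed/profiled."""

    def __init__(self, model, datamodule=None):
        self.model = model
        self.selector = ToySelector(model, datamodule)

    def call_hook(self, hook_name, *args):
        owner = self.selector.get_instance(hook_name)
        print(f"[profiler] {type(owner).__name__}.{hook_name}")  # where profiling would wrap the call
        return getattr(owner, hook_name)(*args)


trainer = ToyTrainer(ToyModule(), ToyDataModule())
print(trainer.call_hook("on_before_batch_transfer", [1, 2, 3], 0))  # -> [2, 4, 6]

Because the selector hands back the owner instead of a bound method, the single dispatch point above is the only place that needs profiling instrumentation, which is the design choice this series makes for the batch-transfer and gradient-clipping hooks.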