Profile batch transfer and gradient clipping hooks #14069

Merged 6 commits on Aug 11, 2022
Changes from 4 commits
27 changes: 17 additions & 10 deletions src/pytorch_lightning/core/module.py
@@ -37,7 +37,6 @@
from pytorch_lightning.core.optimizer import LightningOptimizer
from pytorch_lightning.core.saving import ModelIO
from pytorch_lightning.loggers import Logger, LoggerCollection
from pytorch_lightning.trainer.connectors.data_connector import _DataHookSelector
from pytorch_lightning.trainer.connectors.logger_connector.fx_validator import _FxValidator
from pytorch_lightning.utilities import _IS_WINDOWS, _TORCH_GREATER_EQUAL_1_10, GradClipAlgorithmType, warnings
from pytorch_lightning.utilities.apply_func import apply_to_collection, convert_to_tensors
@@ -291,16 +290,24 @@ def _apply_batch_transfer_handler(
self, batch: Any, device: Optional[torch.device] = None, dataloader_idx: int = 0
) -> Any:
device = device or self.device
datahook_selector = (
_DataHookSelector(self, None) if self._trainer is None else self.trainer._data_connector._datahook_selector
)

hook = datahook_selector.get_hook("on_before_batch_transfer")
batch = hook(batch, dataloader_idx)
hook = datahook_selector.get_hook("transfer_batch_to_device")
batch = hook(batch, device, dataloader_idx)
hook = datahook_selector.get_hook("on_after_batch_transfer")
batch = hook(batch, dataloader_idx)
def call_hook(hook_name, *args):
if self._trainer:
datahook_selector = self._trainer._data_connector._datahook_selector
obj = datahook_selector.get_instance(hook_name)
trainer_method = (
self._trainer._call_lightning_module_hook
if isinstance(obj, self.__class__)
else self._trainer._call_lightning_datamodule_hook
)
return trainer_method(hook_name, *args)
else:
hook = getattr(self, hook_name)
return hook(*args)

batch = call_hook("on_before_batch_transfer", batch, dataloader_idx)
batch = call_hook("transfer_batch_to_device", batch, device, dataloader_idx)
batch = call_hook("on_after_batch_transfer", batch, dataloader_idx)
return batch

def print(self, *args, **kwargs) -> None:
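Note on the change above: previously `_apply_batch_transfer_handler` pulled a bound method out of the selector and called it directly, so the profiler never saw the batch-transfer hooks. Routing the call through `trainer._call_lightning_module_hook` / `_call_lightning_datamodule_hook` sends every hook call through the trainer's profiled code path. A minimal standalone sketch of that dispatch pattern (plain Python, illustrative names only, not the real Lightning API):

```python
from contextlib import contextmanager
import time


class TinyProfiler:
    """Stand-in for Lightning's profiler: records wall time per hook name."""

    def __init__(self):
        self.records = {}

    @contextmanager
    def profile(self, name):
        start = time.perf_counter()
        try:
            yield
        finally:
            self.records[name] = time.perf_counter() - start


class TinyTrainer:
    """Stand-in for the trainer: every hook call goes through `call_hook`."""

    def __init__(self, profiler):
        self.profiler = profiler

    def call_hook(self, owner, hook_name, *args):
        hook = getattr(owner, hook_name)
        with self.profiler.profile(f"[Hook]{hook_name}"):
            return hook(*args)


class TinyModule:
    def on_before_batch_transfer(self, batch, dataloader_idx):
        return batch

    def transfer_batch_to_device(self, batch, device, dataloader_idx):
        return batch  # a real module would move tensors to `device` here

    def on_after_batch_transfer(self, batch, dataloader_idx):
        return batch


profiler = TinyProfiler()
trainer = TinyTrainer(profiler)
module = TinyModule()

batch = trainer.call_hook(module, "on_before_batch_transfer", [1, 2, 3], 0)
batch = trainer.call_hook(module, "transfer_batch_to_device", batch, "cpu", 0)
batch = trainer.call_hook(module, "on_after_batch_transfer", batch, 0)
print(profiler.records)  # each transfer hook now appears with its own timing
```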
@@ -182,7 +182,9 @@ def _clip_gradients(
if not isinstance(model, pl.LightningModule) or not model.automatic_optimization:
# the configuration validator disallows clipping on manual
return
model.configure_gradient_clipping(

model.trainer._call_lightning_module_hook(
"configure_gradient_clipping",
optimizer,
optimizer_idx,
gradient_clip_val=clip_val,
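Same idea for gradient clipping: calling `configure_gradient_clipping` through `_call_lightning_module_hook` means a user override now shows up in the profiler report. A hedged sketch of such an override, assuming the hook signature implied by the call site above:

```python
import pytorch_lightning as pl


class ClippedModel(pl.LightningModule):
    def configure_gradient_clipping(
        self, optimizer, optimizer_idx, gradient_clip_val=None, gradient_clip_algorithm=None
    ):
        # Custom logic could go here (e.g. clipping only some parameter groups);
        # this sketch simply defers to the built-in clipping helper.
        self.clip_gradients(
            optimizer,
            gradient_clip_val=gradient_clip_val,
            gradient_clip_algorithm=gradient_clip_algorithm,
        )
```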
22 changes: 11 additions & 11 deletions src/pytorch_lightning/trainer/connectors/data_connector.py
@@ -14,7 +14,7 @@
import multiprocessing
import os
from dataclasses import dataclass, field
from typing import Any, Callable, Collection, List, Optional, Tuple, Union
from typing import Any, Collection, List, Optional, Tuple, Union
from weakref import proxy

from torch.utils.data import BatchSampler, DataLoader, Sampler, SequentialSampler
@@ -527,16 +527,16 @@ def is_module(self) -> bool:

@dataclass
class _DataHookSelector:
"""Stores the info about the shared DataHooks within LightningModule and LightningDataModule.
"""Stores the info about the shared DataHooks within ``LightningModule`` and ``LightningDataModule``.

The hook source can be
The hook source can be:

1. a method from the :class:`~pytorch_lightning.core.module.LightningModule`,
2. a method from the :class:`~pytorch_lightning.core.datamodule.LightningDataModule`,
1. the :class:`~pytorch_lightning.core.module.LightningModule`,
2. the :class:`~pytorch_lightning.core.datamodule.LightningDataModule`,

Arguments:
model: A LightningModule
datamodule: A LightningDataModule
model: A ``LightningModule``
datamodule: A ``LightningDataModule``
"""

model: "pl.LightningModule"
@@ -545,27 +545,27 @@ class _DataHookSelector:
default=("on_before_batch_transfer", "transfer_batch_to_device", "on_after_batch_transfer")
)

def get_hook(self, hook_name: str) -> Callable:
def get_instance(self, hook_name: str) -> Union["pl.LightningModule", "pl.LightningDataModule"]:
if hook_name not in self._valid_hooks:
raise ValueError(
f"`{hook_name}` is not a shared hook within `LightningModule` and `LightningDataModule`."
f" Valid hooks are {self._valid_hooks}."
)

if self.datamodule is None:
return getattr(self.model, hook_name)
return self.model

if is_overridden(hook_name, self.datamodule):
if is_overridden(hook_name, self.model):
warning_cache.warn(
f"You have overridden `{hook_name}` in both `LightningModule` and `LightningDataModule`."
" It will use the implementation from `LightningDataModule` instance."
)
return getattr(self.datamodule, hook_name)
return self.datamodule

if is_overridden(hook_name, self.model):
warning_cache.warn(
f"You have overridden `{hook_name}` in `LightningModule` but have passed in a"
" `LightningDataModule`. It will use the implementation from `LightningModule` instance."
)
return getattr(self.model, hook_name)
return self.model
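The selector now answers "who owns this hook" instead of handing back a bound method, so the caller can dispatch through the trainer. A condensed standalone version of the precedence logic (illustrative; the warnings emitted by the real implementation are omitted):

```python
class BaseHooks:
    def on_after_batch_transfer(self, batch, dataloader_idx):
        return batch


class Model(BaseHooks):
    pass  # does not override the hook


class DataModule(BaseHooks):
    def on_after_batch_transfer(self, batch, dataloader_idx):
        return batch  # overridden


def is_overridden(hook_name, obj):
    # Overridden if the object's class defines its own version of the hook.
    return getattr(type(obj), hook_name) is not getattr(BaseHooks, hook_name)


def get_instance(model, datamodule, hook_name):
    if datamodule is None:
        return model
    if is_overridden(hook_name, datamodule):
        return datamodule  # a datamodule override wins over a model override
    return model


owner = get_instance(Model(), DataModule(), "on_after_batch_transfer")
print(type(owner).__name__)  # -> DataModule
```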
@@ -44,6 +44,8 @@ class _LogOptions(TypedDict):
allowed_on_step=(False, True), allowed_on_epoch=(False, True), default_on_step=True, default_on_epoch=False
),
"lr_scheduler_step": None,
"configure_gradient_clipping": None,
"clip_gradients": None,
"on_before_zero_grad": _LogOptions(
allowed_on_step=(False, True), allowed_on_epoch=(False, True), default_on_step=True, default_on_epoch=False
),
@@ -98,6 +100,9 @@ class _LogOptions(TypedDict):
"on_epoch_end": _LogOptions(
allowed_on_step=(False,), allowed_on_epoch=(True,), default_on_step=False, default_on_epoch=True
),
"on_before_batch_transfer": None,
"transfer_batch_to_device": None,
"on_after_batch_transfer": None,
"on_batch_start": _LogOptions(
allowed_on_step=(False, True), allowed_on_epoch=(False, True), default_on_step=True, default_on_epoch=False
),
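Mapping these hooks to `None` registers them with the fx validator as hooks in which `self.log` is not supported, which becomes necessary now that they are dispatched through the trainer. A hedged sketch of what that means for user code (the exact exception type may differ; the tests below only assert on the "You can't" message):

```python
import pytorch_lightning as pl


class LogsInTransferHook(pl.LightningModule):
    def on_after_batch_transfer(self, batch, dataloader_idx):
        self.log("batch_seen", 1.0)  # expected to raise: logging is not supported in this hook
        return batch
```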
26 changes: 13 additions & 13 deletions tests/tests_pytorch/trainer/connectors/test_data_connector.py
@@ -470,59 +470,59 @@ def test_no_datamodule_no_overridden(self, hook_name):
model, _, trainer = self.reset_instances()
trainer._data_connector.attach_datamodule(model, datamodule=None)
with no_warning_call(match=f"have overridden `{hook_name}` in"):
hook = trainer._data_connector._datahook_selector.get_hook(hook_name)
instance = trainer._data_connector._datahook_selector.get_instance(hook_name)

assert hook == getattr(model, hook_name)
assert instance is model

def test_with_datamodule_no_overridden(self, hook_name):
model, dm, trainer = self.reset_instances()
trainer._data_connector.attach_datamodule(model, datamodule=dm)
with no_warning_call(match=f"have overridden `{hook_name}` in"):
hook = trainer._data_connector._datahook_selector.get_hook(hook_name)
instance = trainer._data_connector._datahook_selector.get_instance(hook_name)

assert hook == getattr(model, hook_name)
assert instance is model

def test_override_model_hook(self, hook_name):
model, dm, trainer = self.reset_instances()
trainer._data_connector.attach_datamodule(model, datamodule=dm)
with no_warning_call(match=f"have overridden `{hook_name}` in"):
hook = trainer._data_connector._datahook_selector.get_hook(hook_name)
instance = trainer._data_connector._datahook_selector.get_instance(hook_name)

assert hook == getattr(model, hook_name)
assert instance is model

def test_override_datamodule_hook(self, hook_name):
model, dm, trainer = self.reset_instances()
trainer._data_connector.attach_datamodule(model, datamodule=dm)
setattr(dm, hook_name, self.overridden_func)
with no_warning_call(match=f"have overridden `{hook_name}` in"):
hook = trainer._data_connector._datahook_selector.get_hook(hook_name)
instance = trainer._data_connector._datahook_selector.get_instance(hook_name)

assert hook == getattr(dm, hook_name)
assert instance is dm

def test_override_both_model_and_datamodule(self, hook_name):
model, dm, trainer = self.reset_instances()
trainer._data_connector.attach_datamodule(model, datamodule=dm)
setattr(model, hook_name, self.overridden_func)
setattr(dm, hook_name, self.overridden_func)
with pytest.warns(UserWarning, match=f"have overridden `{hook_name}` in both"):
hook = trainer._data_connector._datahook_selector.get_hook(hook_name)
instance = trainer._data_connector._datahook_selector.get_instance(hook_name)

assert hook == getattr(dm, hook_name)
assert instance is dm

def test_with_datamodule_override_model(self, hook_name):
model, dm, trainer = self.reset_instances()
trainer._data_connector.attach_datamodule(model, datamodule=dm)
setattr(model, hook_name, self.overridden_func)
with pytest.warns(UserWarning, match=f"have overridden `{hook_name}` in `LightningModule`"):
hook = trainer._data_connector._datahook_selector.get_hook(hook_name)
instance = trainer._data_connector._datahook_selector.get_instance(hook_name)

assert hook == getattr(model, hook_name)
assert instance is model


def test_invalid_hook_passed_in_datahook_selector():
dh_selector = _DataHookSelector(BoringModel(), None)
with pytest.raises(ValueError, match="is not a shared hook"):
dh_selector.get_hook("setup")
dh_selector.get_instance("setup")


def test_eval_distributed_sampler_warning(tmpdir):
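Why the assertions switch from `hook == getattr(...)` to `instance is ...`: `get_instance` returns the owning object, which can be compared by identity, whereas bound methods are recreated on every attribute access and only compare equal by value. A quick standalone illustration:

```python
class M:
    def hook(self):
        pass


m = M()
assert m.hook == m.hook        # equal by value (same __self__ and __func__)...
assert m.hook is not m.hook    # ...but each access creates a new bound-method object
assert m is m                  # the instance itself supports a plain identity check
```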
10 changes: 5 additions & 5 deletions tests/tests_pytorch/trainer/logging_/test_logger_connector.py
@@ -187,11 +187,6 @@ def __init__(self, not_supported):
{
"log",
"log_dict",
# the following are problematic as they do have `self._current_fx_name` defined some times but
# not others depending on where they were called. So we cannot reliably `self.log` in them
"on_before_batch_transfer",
"transfer_batch_to_device",
"on_after_batch_transfer",
}
)
# remove `nn.Module` hooks
@@ -227,6 +222,9 @@ def test_fx_validator_integration(tmpdir):
"on_pretrain_routine_end": "You can't",
"train_dataloader": "You can't",
"val_dataloader": "You can't",
"on_before_batch_transfer": "You can't",
"transfer_batch_to_device": "You can't",
"on_after_batch_transfer": "You can't",
"on_validation_end": "You can't",
"on_train_end": "You can't",
"on_fit_end": "You can't",
@@ -238,6 +236,8 @@
"on_validation_model_eval": "You can't",
"on_validation_model_train": "You can't",
"lr_scheduler_step": "You can't",
"configure_gradient_clipping": "You can't",
"clip_gradients": "You can't",
"on_save_checkpoint": "You can't",
"on_load_checkpoint": "You can't",
"on_exception": "You can't",