Avoid false-positive warnings about method calls on the Fabric-wrapped module (#18819)
awaelchli authored Oct 23, 2023
1 parent e7afe04 commit 97303b0
Showing 4 changed files with 62 additions and 36 deletions.
5 changes: 3 additions & 2 deletions src/lightning/fabric/CHANGELOG.md
@@ -14,7 +14,8 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).

### Changed

-
- Calling a method other than `forward` that invokes submodules is now an error when the model is wrapped (e.g., with DDP) ([#18819](https://github.com/Lightning-AI/lightning/pull/18819))



### Deprecated
@@ -29,7 +30,7 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).

### Fixed

-
- Fixed false-positive warnings about method calls on the Fabric-wrapped module ([#18819](https://github.com/Lightning-AI/lightning/pull/18819))


## [2.1.0] - 2023-10-11
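In practice, the changelog entries above mean that a method which runs submodules outside of `forward()` now fails loudly instead of (sometimes wrongly) warning. A minimal sketch of the new behavior, assuming the script runs under a multi-process strategy such as DDP; `DemoModel` and `generate` are hypothetical names, not part of the commit:

```python
import torch
from torch import nn

from lightning.fabric import Fabric


class DemoModel(nn.Module):  # hypothetical example model
    def __init__(self):
        super().__init__()
        self.layer = nn.Linear(4, 4)

    def forward(self, x):
        return self.layer(x)

    def generate(self, x):
        return self.layer(x)  # invokes a submodule outside of `forward`


fabric = Fabric(accelerator="cpu", devices=2, strategy="ddp")
fabric.launch()
model = fabric.setup(DemoModel())

model(torch.rand(2, 4))  # fine: routed through the strategy wrapper
model.generate(torch.rand(2, 4))  # now raises RuntimeError (previously a warning)
```

Methods that never trigger a submodule's `forward` (plain getters, bookkeeping helpers) are exactly the false positives this commit eliminates: they now pass through silently instead of warning.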
51 changes: 33 additions & 18 deletions src/lightning/fabric/wrappers.py
@@ -12,10 +12,10 @@
# See the License for the specific language governing permissions and
# limitations under the License.
import inspect
from functools import wraps
from typing import Any, Callable, Dict, Generator, Iterator, List, Mapping, Optional, TypeVar, Union, overload

import torch
from lightning_utilities import WarningCache
from lightning_utilities.core.apply_func import apply_to_collection
from torch import Tensor
from torch import nn as nn
@@ -30,9 +30,7 @@
from lightning.fabric.utilities.device_dtype_mixin import _DeviceDtypeModuleMixin
from lightning.fabric.utilities.imports import _TORCH_GREATER_EQUAL_2_0
from lightning.fabric.utilities.types import Optimizable
from lightning.fabric.utilities.warnings import PossibleUserWarning

warning_cache = WarningCache()
T_destination = TypeVar("T_destination", bound=Dict[str, Any])
_LIGHTNING_MODULE_STEP_METHODS = ("training_step", "validation_step", "test_step", "predict_step")

@@ -161,25 +159,40 @@ def wrapped_forward(*args: Any, **kwargs: Any) -> Any:
# We expect that the `forward_module` will eventually call `original_module.forward`, which we
# have patched to redirect back to `original_module.method_name()`.
def call_forward_module(*args: Any, **kwargs: Any) -> Any:
# Patch the original_module's forward so we can redirect the arguments back to the real method
# Patch the original_module's forward, so we can redirect the arguments back to the real method
self._original_module.forward = wrapped_forward
return self.forward(*args, **kwargs)

return call_forward_module

def _validate_method_access(self, name: str, attribute: Any) -> None:
if (
inspect.ismethod(attribute)
and inspect.signature(attribute).parameters
and self._forward_module != self._original_module
):
warning_cache.warn(
f"You are calling the method `{type(self._original_module).__name__}.{name}()` from outside the"
" model. This will bypass the wrapper from the strategy and result in incorrect behavior in"
" `.backward()`. You should pass your inputs through"
f" `{type(self._original_module).__name__}.forward()`.",
category=PossibleUserWarning,
)
def _wrap_method_with_module_call_tracker(self, method: Callable, name: str) -> Callable:
"""Tracks whether any submodule in ``self._original_module`` was called during the execution of ``method`` by
registering forward hooks on all submodules."""
module_called = False

def hook(*_: Any, **__: Any) -> None:
nonlocal module_called
module_called = True

@wraps(method)
def _wrapped_method(*args: Any, **kwargs: Any) -> Any:
handles = []
for module in self._original_module.modules():
handles.append(module.register_forward_hook(hook))

output = method(*args, **kwargs)

if module_called:
raise RuntimeError(
f"You are calling the method `{type(self._original_module).__name__}.{name}()` from outside the"
" model. This will bypass the wrapper from the strategy and result in incorrect behavior in"
" `.backward()`. You should pass your inputs through `forward()`.",
)
for handle in handles:
handle.remove()
return output

return _wrapped_method

def __getattr__(self, item: Any) -> Any:
if item in _LIGHTNING_MODULE_STEP_METHODS and self._forward_module != self._original_module:
@@ -194,7 +207,9 @@ def __getattr__(self, item: Any) -> Any:
# If the attribute is not available on the _FabricModule wrapper, redirect to the wrapped nn.Module
original_module = super().__getattr__("_original_module")
attr = getattr(original_module, item)
self._validate_method_access(item, attr)

if inspect.ismethod(attr) and self._forward_module != self._original_module:
attr = self._wrap_method_with_module_call_tracker(attr, item)
return attr

def __setattr__(self, name: str, value: Any) -> None:
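The tracking mechanism above is plain `torch.nn` machinery: temporarily register a forward hook on every submodule, run the method, and raise if any hook fired. A standalone sketch of the same technique, independent of Fabric (the names here are illustrative; the try/finally cleanup is a small deviation from the committed code, which removes the handles only on the success path):

```python
from functools import wraps

import torch
from torch import nn


def track_module_calls(model: nn.Module, method, name: str):
    """Return `method` wrapped to raise if it triggers any submodule's forward."""
    module_called = False

    def hook(*_, **__):  # fires on any forward pass of any submodule
        nonlocal module_called
        module_called = True

    @wraps(method)  # preserves __name__ etc. and records the original on __wrapped__
    def wrapped(*args, **kwargs):
        # Temporarily register the hook on every submodule (and the model itself)
        handles = [m.register_forward_hook(hook) for m in model.modules()]
        try:
            output = method(*args, **kwargs)
            if module_called:
                raise RuntimeError(f"`{name}()` ran a submodule outside the strategy wrapper")
            return output
        finally:
            for handle in handles:
                handle.remove()

    return wrapped


model = nn.Sequential(nn.Linear(2, 2))
tracked = track_module_calls(model, lambda: model[0](torch.rand(1, 2)), "demo")
# tracked()  # raises RuntimeError: the lambda invokes a submodule's forward
```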
40 changes: 25 additions & 15 deletions tests/tests_fabric/test_wrappers.py
@@ -26,9 +26,7 @@
_FabricOptimizer,
_unwrap_objects,
is_wrapped,
warning_cache,
)
from lightning_utilities.test.warning import no_warning_call
from torch.utils.data import BatchSampler, DistributedSampler
from torch.utils.data.dataloader import DataLoader

@@ -79,12 +77,24 @@ def test_fabric_module_method_lookup():
"""Test that access to methods warns about improper use when a wrapper from a strategy is involved."""

class OriginalModule(torch.nn.Module):
def method_no_args(self):
def __init__(self):
super().__init__()
self.submodule = torch.nn.Linear(2, 3)

def forward(self, x):
return x

def method_without_module_invocation(self):
return 100

def method_with_args(self, arg, kwarg=1):
def method_with_submodule_invocation(self):
self.submodule(torch.rand(2, 2))
return 101

def method_with_self_invocation(self):
self(None)
return 102

class ModuleWrapper(torch.nn.Module):
def __init__(self, module):
super().__init__()
@@ -93,21 +103,21 @@ def __init__(self, module):
# Regular case: forward_module == original_module -> no warnings
original_module = OriginalModule()
fabric_module = _FabricModule(forward_module=original_module, precision=Mock(), original_module=original_module)
warning_cache.clear()
with no_warning_call(UserWarning):
assert fabric_module.method_with_args(0) == 101
assert not warning_cache
assert fabric_module.method_without_module_invocation() == 100

# Special case: original module wrapped by forward module: -> warn if method accepts args
original_module = OriginalModule()
wrapped_module = ModuleWrapper(original_module)
fabric_module = _FabricModule(forward_module=wrapped_module, precision=Mock(), original_module=original_module)
warning_cache.clear()
with no_warning_call(UserWarning):
assert fabric_module.method_no_args() == 100
with pytest.warns(UserWarning, match=r"You are calling the method `OriginalModule.method_with_args\(\)` from"):
assert fabric_module.method_with_args(0) == 101
warning_cache.clear()
assert fabric_module.method_without_module_invocation() == 100
with pytest.raises(
RuntimeError, match=r"You are calling the method `OriginalModule.method_with_submodule_invocation\(\)` from"
):
assert fabric_module.method_with_submodule_invocation() == 101
with pytest.raises(
RuntimeError, match=r"You are calling the method `OriginalModule.method_with_self_invocation\(\)` from"
):
assert fabric_module.method_with_self_invocation() == 102


def test_fabric_module_setattr():
@@ -555,7 +565,7 @@ def normal_method(self):
fabric_module = _FabricModule(forward_module=forward_module, precision=precision, original_module=original_module)

# Regular methods on the original_module are visible and identical on the fabric_module ...
assert fabric_module.normal_method == original_module.normal_method
assert fabric_module.normal_method.__wrapped__ == original_module.normal_method

# ... but special methods like training_step get redirected to the forward_module
assert fabric_module.training_step.__name__ == "call_forward_module"
Expand Down
2 changes: 1 addition & 1 deletion tests/tests_pytorch/checkpointing/test_model_checkpoint.py
@@ -1488,7 +1488,7 @@ def test_resume_and_old_checkpoint_files_remain(same_resume_folder, tmp_path):
callback = ModelCheckpoint(dirpath=first, monitor="step", mode="max", save_top_k=2, every_n_train_steps=2)
trainer = Trainer(callbacks=callback, max_steps=5, **trainer_kwargs)
trainer.fit(model)
assert os.listdir(first) == ["epoch=0-step=2.ckpt", "epoch=0-step=4.ckpt"]
assert set(os.listdir(first)) == {"epoch=0-step=2.ckpt", "epoch=0-step=4.ckpt"}

# Continue training from checkpoint
callback = ModelCheckpoint(dirpath=new_dirpath, monitor="step", mode="max", save_top_k=2, every_n_train_steps=2)
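The checkpoint test switches from a list to a set comparison because `os.listdir` makes no ordering guarantee; entries come back in arbitrary, filesystem-dependent order, so an order-sensitive assertion can fail spuriously. A tiny sketch of the robust form (the directory name is illustrative):

```python
import os

files = os.listdir("checkpoints")  # order is arbitrary, e.g. step=4 before step=2
assert set(files) == {"epoch=0-step=2.ckpt", "epoch=0-step=4.ckpt"}  # order-insensitive
```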
