Resolve instantiation problem with init_meta_context #10493

Merged
merged 29 commits on Nov 15, 2021
Changes from 28 commits
Commits
29 commits
a776406
update
tchaton Nov 11, 2021
299f049
update
tchaton Nov 11, 2021
2c08d3a
update
tchaton Nov 12, 2021
8077db8
update changelog
tchaton Nov 12, 2021
1244717
Merge branch 'meta_improvement' of https://github.com/PyTorchLightnin…
tchaton Nov 12, 2021
9f749a8
update on comments
tchaton Nov 12, 2021
0d2c68b
udpate
tchaton Nov 12, 2021
4e9d805
update
tchaton Nov 12, 2021
35ffbf8
update
tchaton Nov 12, 2021
cbf424a
update
tchaton Nov 12, 2021
84ab16f
update
tchaton Nov 12, 2021
3deda2f
add comment
tchaton Nov 12, 2021
46ffd62
update
tchaton Nov 12, 2021
6b03c34
Merge branch 'meta_improvement' of https://github.com/PyTorchLightnin…
tchaton Nov 12, 2021
e277a4d
update
tchaton Nov 12, 2021
e724483
Minor fixes
carmocca Nov 12, 2021
a2a1084
__instancecheck__ magic
carmocca Nov 12, 2021
ae3bec8
Merge branch 'master' into meta_improvement
tchaton Nov 12, 2021
dfb42e4
update
tchaton Nov 12, 2021
31babe7
update
tchaton Nov 12, 2021
e08ece3
Update CHANGELOG.md
rohitgr7 Nov 15, 2021
f4fefc0
Merge branch 'master' into meta_improvement
tchaton Nov 15, 2021
49a9eaf
Merge branch 'meta_improvement' of https://github.com/PyTorchLightnin…
tchaton Nov 15, 2021
c12cd17
update
tchaton Nov 15, 2021
7324f8b
Merge branch 'master' into meta_improvement
Nov 15, 2021
2ce60ca
update
tchaton Nov 15, 2021
139b14b
update
tchaton Nov 15, 2021
90ddb71
Merge branch 'meta_improvement' of https://github.com/PyTorchLightnin…
tchaton Nov 15, 2021
e5df66e
update on comments
tchaton Nov 15, 2021
CHANGELOG.md (3 additions, 0 deletions)
@@ -131,6 +131,9 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).
- Fixed `CombinedLoader` and `max_size_cycle` didn't receive a `DistributedSampler` ([#10374](https://github.com/PyTorchLightning/pytorch-lightning/issues/10374))


- Fixed `isinstance` not working with `init_meta_context`, materialize model not being moved to the device ([#10493](https://github.com/PyTorchLightning/metrics/pull/10493))


- Fixed an issue that prevented the Trainer to shutdown workers when execution is interrupted due to failure([#10463](https://github.com/PyTorchLightning/pytorch-lightning/issues/10463))


pytorch_lightning/core/mixins/device_dtype_mixin.py (5 additions, 1 deletion)
@@ -17,6 +17,8 @@
import torch
from torch.nn import Module

import pytorch_lightning as pl


class DeviceDtypeModuleMixin(Module):
__jit_unused_properties__ = ["device", "dtype"]
@@ -177,7 +179,9 @@ def __update_properties(
self, device: Optional[torch.device] = None, dtype: Optional[Union[str, torch.dtype]] = None
) -> None:
def apply_fn(module: Union["DeviceDtypeModuleMixin", Module]) -> None:
if not isinstance(module, DeviceDtypeModuleMixin):
# TODO: Find why `isinstance(module, DeviceDtypeModuleMixin)` doesn't
# work when using `init_meta_context`.
if not isinstance(module, (DeviceDtypeModuleMixin, pl.LightningModule)):
return
if device is not None:
module._device = device
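The widened `isinstance` check above is a workaround; the companion change in `pytorch_lightning/utilities/meta.py` further down addresses the same symptom with an `__instancecheck__` override. A minimal, standalone sketch of that trick (simplified, not the library code; `Base` and `Wrapper` are illustrative names):

```python
from typing import Any


class Base:
    pass


class _IsinstanceMetaclass(type):
    def __instancecheck__(self, instance: Any) -> bool:
        # Defer the isinstance check to the wrapped class (the first base).
        return isinstance(instance, self.__bases__[0])


class Wrapper(Base, metaclass=_IsinstanceMetaclass):
    """Stands in for the generated `_MaterializerModule` classes."""


obj = Base()
# `obj` was never constructed from `Wrapper`, yet the overridden
# `__instancecheck__` makes the check succeed:
assert isinstance(obj, Wrapper)
```

In the actual change the metaclass derives from `type(subclass)`, so it stays compatible with whatever metaclass the wrapped module class already uses.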
pytorch_lightning/trainer/trainer.py (13 additions, 2 deletions)
@@ -84,7 +84,7 @@
from pytorch_lightning.utilities.distributed import distributed_available
from pytorch_lightning.utilities.exceptions import ExitGracefullyException, MisconfigurationException
from pytorch_lightning.utilities.imports import _fault_tolerant_training
from pytorch_lightning.utilities.meta import materialize_module
from pytorch_lightning.utilities.meta import is_meta_device, materialize_module
from pytorch_lightning.utilities.model_helpers import is_overridden
from pytorch_lightning.utilities.seed import reset_seed
from pytorch_lightning.utilities.types import (
@@ -1406,10 +1406,21 @@ def _call_setup_hook(self) -> None:

def _call_configure_sharded_model(self) -> None:
with self.accelerator.model_sharded_context():
materialize_module(self.lightning_module)
self._handle_meta_model()
self.call_hook("configure_sharded_model")
self.call_hook("on_configure_sharded_model")

def _handle_meta_model(self) -> None:
if not is_meta_device(self.lightning_module):
return

if isinstance(self.training_type_plugin, DDPSpawnPlugin):
raise MisconfigurationException("LightningModule on meta device isn't supported with spawn.")

materialize_module(self.lightning_module)
# the trainer reference is lost during materialization
self.lightning_module.trainer = proxy(self)

def _call_teardown_hook(self) -> None:
fn = self.state.fn._setup_fn

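For orientation, here is a hedged sketch of the user-facing flow `_handle_meta_model` supports, assuming torch >= 1.10 and a pytorch-lightning build containing this change; `BoringNet` and the random dataset are placeholders rather than library code, and the exact Trainer behaviour may differ:

```python
import torch
from torch import nn
from torch.utils.data import DataLoader, TensorDataset

from pytorch_lightning import LightningModule, Trainer
from pytorch_lightning.utilities.meta import init_meta_context


class BoringNet(LightningModule):
    def __init__(self):
        super().__init__()
        self.layer = nn.Linear(32, 2)

    def training_step(self, batch, batch_idx):
        x = batch[0]
        return self.layer(x).sum()

    def configure_optimizers(self):
        return torch.optim.SGD(self.parameters(), lr=0.1)


with init_meta_context():
    # Parameters are created on the "meta" device: shapes only, no memory allocated.
    model = BoringNet()
assert model.layer.weight.device.type == "meta"

train_data = DataLoader(TensorDataset(torch.randn(64, 32)), batch_size=8)
trainer = Trainer(fast_dev_run=True)
# During `_call_configure_sharded_model` the Trainer calls `_handle_meta_model`, which
# materializes the module and re-attaches the trainer proxy lost during materialization.
# Spawn-based plugins raise a MisconfigurationException for models on the meta device.
trainer.fit(model, train_data)
```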
pytorch_lightning/utilities/meta.py (30 additions, 16 deletions)
@@ -18,13 +18,14 @@
from functools import partial
from itertools import chain
from types import ModuleType
from typing import Callable, Dict, Generator, Iterator, List, Optional, Set, Type
from typing import Any, Callable, Dict, Generator, Iterator, List, Optional, Set, Type

import torch
from torch import nn, Tensor
from torch.nn import Module
from torch.nn.modules.container import ModuleDict, ModuleList, Sequential

import pytorch_lightning as pl
from pytorch_lightning.utilities import rank_zero_warn
from pytorch_lightning.utilities.exceptions import MisconfigurationException
from pytorch_lightning.utilities.imports import _TORCH_GREATER_EQUAL_1_10
@@ -191,7 +192,6 @@ def materialize_module(root_module: nn.Module) -> nn.Module:

# cache subclasses to optimize the search when resetting the meta device later on.
__STORAGE_META__ = {}

__CREATED_MODULES__ = set()


@@ -237,45 +237,52 @@ def _set_meta_device() -> None:

for subclass in get_all_subclasses(torch.nn.modules.module.Module):

if isinstance(subclass, (Sequential, ModuleList, ModuleDict)):
if subclass in (Sequential, ModuleList, ModuleDict, pl.LightningModule):
continue

# if a subclass has already been stored, we should use the cache
if str(subclass) in __STORAGE_META__:
# reset the class import package to its rightfull state.
# reset the class import package to its rightful state.
mods, subclass, meta_class = __STORAGE_META__[subclass]
for mod in mods:
setattr(mod, subclass.__name__, meta_class)
continue

class _IsinstanceMetaclass(type(subclass)):
def __instancecheck__(self, instance: Any) -> bool:
"""Overrides the ``isinstance`` check on ``_MaterializerModule`` objects."""
return isinstance(instance, self.__bases__[0])

# Create a class subclassing current `subclass` overriding its new method.
# this will enable use to use `torch.distributed.nn.utils.init_meta` to create a `meta`
# version of the current subclass module
class _MetaClass(subclass):
class _MaterializerModule(subclass, metaclass=_IsinstanceMetaclass):
@classmethod
@contextmanager
def instantiation_context(cls, materialize: bool):
def instantiation_context(cls):
_unset_meta_device(from_created=True)
yield
_set_meta_device_populated(from_created=True)

@classmethod
def materialize(cls, materialize_fn: Callable):
with cls.instantiation_context(materialize=True):
with cls.instantiation_context():
obj = materialize_fn()
return obj

@staticmethod
def add_subclasses(subclass):
"""This is used to unrol the instantion tree while creating the modules."""
__CREATED_MODULES__.add(subclass)
"""This is used to unroll the instantiation tree while creating the modules."""
# Don't store the LightningModule as skipped from the Meta process.
if subclass != pl.LightningModule:
__CREATED_MODULES__.add(subclass)
if subclass.__bases__[0] != torch.nn.modules.module.Module:
_MetaClass.add_subclasses(subclass.__bases__[0])
_MaterializerModule.add_subclasses(subclass.__bases__[0])

def __new__(cls, *args, **kwargs):
subclass = cls.__bases__[0]
cls.add_subclasses(subclass)
with cls.instantiation_context(materialize=False):
with cls.instantiation_context():
obj = init_meta(subclass, *args, **kwargs)

obj.materialize = partial(cls.materialize, materialize_fn=obj.materialize)
@@ -294,9 +301,8 @@ def search(mod: ModuleType) -> List[ModuleType]:
# nn.Module class can be imported at different level and they all need to be mocked.
# Example: torch.nn.Linear is actually torch.nn.modules.linear.Linear
# Therefore, torch.nn.Linear, torch.nn.modules.Linear, torch.nn.modules.linear.Linear
# needs to be replaced by the torch.nn.linear.modules.Linear _MetaClass
out = []
out.append(search(mod))
# needs to be replaced by the torch.nn.linear.modules.Linear _MaterializerModule
out = [search(mod)]
for name in submodules[1:]:
mod = getattr(mod, name)
out.append(search(mod))
@@ -305,11 +311,11 @@ def search(mod: ModuleType) -> List[ModuleType]:
mods = [mod for mod in chain(*out) if mod]

# store the modules search so it doesn't have to be performed again for this class
__STORAGE_META__[subclass] = (mods, subclass, _MetaClass)
__STORAGE_META__[subclass] = (mods, subclass, _MaterializerModule)

# replace all subclass by its meta form
for mod in mods:
setattr(mod, subclass.__name__, _MetaClass)
setattr(mod, subclass.__name__, _MaterializerModule)


@contextmanager
@@ -321,3 +327,11 @@ def init_meta_context() -> Generator:
_set_meta_device()
yield
_unset_meta_device()


def is_meta_device(module: nn.Module) -> bool:
try:
param = next(module.parameters())
return param.device.type == "meta"
except StopIteration:
return False
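A side note on the comment in `search` above about `torch.nn.Linear`: the same class object is re-exported under several import paths, so every module that holds a reference has to be patched. A quick sanity check against plain torch illustrates why:

```python
import torch
from torch import nn

# The same class object is reachable under several import paths.
assert nn.Linear is torch.nn.modules.Linear
assert nn.Linear is torch.nn.modules.linear.Linear

# Patching only one of these attributes would leave the others pointing at the
# original class, so call sites using a different path would bypass the
# meta-device interception entirely.
```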
tests/utilities/test_meta.py (7 additions, 2 deletions)
@@ -14,7 +14,7 @@
from torch import nn

from pytorch_lightning.core.lightning import LightningModule
from pytorch_lightning.utilities.meta import init_meta_context, materialize_module
from pytorch_lightning.utilities.meta import init_meta_context, is_meta_device, materialize_module
from tests.helpers.runif import RunIf


@@ -31,18 +31,23 @@ def __init__(self, num_layers: int):
self.layer = nn.Sequential(*[nn.Linear(1, 1) for _ in range(self.hparams.num_layers)])


@RunIf(min_torch="1.10.0")
@RunIf(special=True, min_torch="1.10.0")
def test_init_meta_context():

with init_meta_context():
m = nn.Linear(in_features=1, out_features=1)
assert isinstance(m, nn.Linear)
assert m.weight.device.type == "meta"
assert is_meta_device(m)
mlp = MLP(4)
assert mlp.layer[0].weight.device.type == "meta"

mlp = materialize_module(mlp)
assert mlp.layer[0].weight.device.type == "cpu"

assert not is_meta_device(mlp)
assert not is_meta_device(nn.Module())

model = BoringModel(4)
assert model.layer[0].weight.device.type == "meta"
materialize_module(model)