
Commit

Merge branch 'master' into feat/auto_wrap_pt_fsdp
justusschock authored Aug 26, 2022
2 parents 5726103 + e67842d commit 38f9bc8
Showing 14 changed files with 246 additions and 114 deletions.
1 change: 0 additions & 1 deletion pyproject.toml
@@ -51,7 +51,6 @@ warn_no_return = "False"
module = [
"pytorch_lightning.callbacks.progress.rich_progress",
"pytorch_lightning.core.datamodule",
"pytorch_lightning.demos.boring_classes",
"pytorch_lightning.demos.mnist_datamodule",
"pytorch_lightning.profilers.base",
"pytorch_lightning.profilers.pytorch",
10 changes: 9 additions & 1 deletion src/pytorch_lightning/CHANGELOG.md
@@ -18,6 +18,9 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).
- Added support for passing extra init-parameters to the `LightningDataModule.from_datasets` ([#14185](https://github.com/Lightning-AI/lightning/issues/14185))


- Added support for saving sharded optimizer state dict outside of `DDPShardedStrategy` ([#14208](https://github.com/PyTorchLightning/pytorch-lightning/pull/14208))



### Changed

@@ -30,7 +33,9 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).
- Replaced the unwrapping logic in strategies with direct access to unwrapped `LightningModule` ([#13738](https://github.com/Lightning-AI/lightning/pull/13738))


- Enabled `on_before_batch_transfer` for `DPStrategy` and `IPUAccelerator` ([14023](https://github.com/Lightning-AI/lightning/pull/14023))
- Enabled `on_before_batch_transfer` for `DPStrategy` and `IPUAccelerator` ([#14023](https://github.com/Lightning-AI/lightning/pull/14023))

- Included `torch.cuda` rng state to the aggregate `_collect_rng_states()` and `_set_rng_states()` ([#14384](https://github.com/Lightning-AI/lightning/pull/14384))



@@ -85,6 +90,9 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).
- Fixed wrong num padding for `RichProgressBar` ([#14296](https://github.com/Lightning-AI/lightning/pull/14296))


- Fixed `LightningDataModule` hparams parsing ([#12806](https://github.com/PyTorchLightning/pytorch-lightning/pull/12806))


## [1.7.2] - 2022-08-17

### Added
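The #14384 entry above adds the `torch.cuda` generator state to the aggregate RNG-state helpers. The sketch below shows that pattern in isolation; the function names are illustrative and this is not the library's exact implementation.

```python
# Sketch: collect and restore global RNG states, including per-device CUDA generators.
import random
from typing import Any, Dict

import numpy as np
import torch


def collect_rng_states() -> Dict[str, Any]:
    """Gather the global RNG states of Python, NumPy, and PyTorch (CPU and CUDA)."""
    states: Dict[str, Any] = {
        "python": random.getstate(),
        "numpy": np.random.get_state(),
        "torch": torch.get_rng_state(),
    }
    if torch.cuda.is_available():
        # one generator state per visible CUDA device
        states["torch.cuda"] = torch.cuda.get_rng_state_all()
    return states


def set_rng_states(states: Dict[str, Any]) -> None:
    """Restore RNG states previously captured by `collect_rng_states`."""
    random.setstate(states["python"])
    np.random.set_state(states["numpy"])
    torch.set_rng_state(states["torch"])
    if "torch.cuda" in states and torch.cuda.is_available():
        torch.cuda.set_rng_state_all(states["torch.cuda"])
```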
77 changes: 43 additions & 34 deletions src/pytorch_lightning/demos/boring_classes.py
@@ -11,22 +11,27 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from typing import Optional
from typing import cast, Dict, Iterator, List, Optional, Tuple, Union

import torch
import torch.nn as nn
import torch.nn.functional as F
from torch import Tensor
from torch.optim import Optimizer
from torch.optim.lr_scheduler import _LRScheduler
from torch.utils.data import DataLoader, Dataset, IterableDataset, Subset

from pytorch_lightning import LightningDataModule, LightningModule
from pytorch_lightning.core.optimizer import LightningOptimizer
from pytorch_lightning.utilities.types import EPOCH_OUTPUT, STEP_OUTPUT


class RandomDictDataset(Dataset):
def __init__(self, size: int, length: int):
self.len = length
self.data = torch.randn(length, size)

def __getitem__(self, index):
def __getitem__(self, index: int) -> Dict[str, Tensor]:
a = self.data[index]
b = a + 2
return {"a": a, "b": b}
@@ -40,7 +45,7 @@ def __init__(self, size: int, length: int):
self.len = length
self.data = torch.randn(length, size)

def __getitem__(self, index):
def __getitem__(self, index: int) -> Tensor:
return self.data[index]

def __len__(self) -> int:
@@ -52,7 +57,7 @@ def __init__(self, size: int, count: int):
self.count = count
self.size = size

def __iter__(self):
def __iter__(self) -> Iterator[Tensor]:
for _ in range(self.count):
yield torch.randn(self.size)

@@ -62,16 +67,16 @@ def __init__(self, size: int, count: int):
self.count = count
self.size = size

def __iter__(self):
def __iter__(self) -> Iterator[Tensor]:
for _ in range(len(self)):
yield torch.randn(self.size)

def __len__(self):
def __len__(self) -> int:
return self.count


class BoringModel(LightningModule):
def __init__(self):
def __init__(self) -> None:
"""Testing PL Module.
Use as follows:
@@ -90,60 +95,63 @@ def training_step(...):
super().__init__()
self.layer = torch.nn.Linear(32, 2)

def forward(self, x):
def forward(self, x: Tensor) -> Tensor: # type: ignore[override]
return self.layer(x)

def loss(self, batch, preds):
def loss(self, batch: Tensor, preds: Tensor) -> Tensor:
# An arbitrary loss to have a loss that updates the model weights during `Trainer.fit` calls
return torch.nn.functional.mse_loss(preds, torch.ones_like(preds))

def step(self, x):
def step(self, x: Tensor) -> Tensor:
x = self(x)
out = torch.nn.functional.mse_loss(x, torch.ones_like(x))
return out

def training_step(self, batch, batch_idx):
def training_step(self, batch: Tensor, batch_idx: int) -> STEP_OUTPUT: # type: ignore[override]
output = self(batch)
loss = self.loss(batch, output)
return {"loss": loss}

def training_step_end(self, training_step_outputs):
def training_step_end(self, training_step_outputs: STEP_OUTPUT) -> STEP_OUTPUT:
return training_step_outputs

def training_epoch_end(self, outputs) -> None:
def training_epoch_end(self, outputs: EPOCH_OUTPUT) -> None:
outputs = cast(List[Dict[str, Tensor]], outputs)
torch.stack([x["loss"] for x in outputs]).mean()

def validation_step(self, batch, batch_idx):
def validation_step(self, batch: Tensor, batch_idx: int) -> Optional[STEP_OUTPUT]: # type: ignore[override]
output = self(batch)
loss = self.loss(batch, output)
return {"x": loss}

def validation_epoch_end(self, outputs) -> None:
def validation_epoch_end(self, outputs: Union[EPOCH_OUTPUT, List[EPOCH_OUTPUT]]) -> None:
outputs = cast(List[Dict[str, Tensor]], outputs)
torch.stack([x["x"] for x in outputs]).mean()

def test_step(self, batch, batch_idx):
def test_step(self, batch: Tensor, batch_idx: int) -> Optional[STEP_OUTPUT]: # type: ignore[override]
output = self(batch)
loss = self.loss(batch, output)
return {"y": loss}

def test_epoch_end(self, outputs) -> None:
def test_epoch_end(self, outputs: Union[EPOCH_OUTPUT, List[EPOCH_OUTPUT]]) -> None:
outputs = cast(List[Dict[str, Tensor]], outputs)
torch.stack([x["y"] for x in outputs]).mean()

def configure_optimizers(self):
def configure_optimizers(self) -> Tuple[List[torch.optim.Optimizer], List[_LRScheduler]]:
optimizer = torch.optim.SGD(self.layer.parameters(), lr=0.1)
lr_scheduler = torch.optim.lr_scheduler.StepLR(optimizer, step_size=1)
return [optimizer], [lr_scheduler]

def train_dataloader(self):
def train_dataloader(self) -> DataLoader:
return DataLoader(RandomDataset(32, 64))

def val_dataloader(self):
def val_dataloader(self) -> DataLoader:
return DataLoader(RandomDataset(32, 64))

def test_dataloader(self):
def test_dataloader(self) -> DataLoader:
return DataLoader(RandomDataset(32, 64))

def predict_dataloader(self):
def predict_dataloader(self) -> DataLoader:
return DataLoader(RandomDataset(32, 64))


@@ -155,7 +163,7 @@ def __init__(self, data_dir: str = "./"):
self.checkpoint_state: Optional[str] = None
self.random_full = RandomDataset(32, 64 * 4)

def setup(self, stage: Optional[str] = None):
def setup(self, stage: Optional[str] = None) -> None:
if stage == "fit" or stage is None:
self.random_train = Subset(self.random_full, indices=range(64))

@@ -168,26 +176,27 @@ def setup(self, stage: Optional[str] = None):
if stage == "predict" or stage is None:
self.random_predict = Subset(self.random_full, indices=range(64 * 3, 64 * 4))

def train_dataloader(self):
def train_dataloader(self) -> DataLoader:
return DataLoader(self.random_train)

def val_dataloader(self):
def val_dataloader(self) -> DataLoader:
return DataLoader(self.random_val)

def test_dataloader(self):
def test_dataloader(self) -> DataLoader:
return DataLoader(self.random_test)

def predict_dataloader(self):
def predict_dataloader(self) -> DataLoader:
return DataLoader(self.random_predict)


class ManualOptimBoringModel(BoringModel):
def __init__(self):
def __init__(self) -> None:
super().__init__()
self.automatic_optimization = False

def training_step(self, batch, batch_idx):
def training_step(self, batch: Tensor, batch_idx: int) -> STEP_OUTPUT: # type: ignore[override]
opt = self.optimizers()
assert isinstance(opt, (Optimizer, LightningOptimizer))
output = self(batch)
loss = self.loss(batch, output)
opt.zero_grad()
@@ -202,21 +211,21 @@ def __init__(self, out_dim: int = 10, learning_rate: float = 0.02):
self.l1 = torch.nn.Linear(32, out_dim)
self.learning_rate = learning_rate

def forward(self, x):
def forward(self, x: Tensor) -> Tensor: # type: ignore[override]
return torch.relu(self.l1(x.view(x.size(0), -1)))

def training_step(self, batch, batch_nb):
def training_step(self, batch: Tensor, batch_nb: int) -> STEP_OUTPUT: # type: ignore[override]
x = batch
x = self(x)
loss = x.sum()
return loss

def configure_optimizers(self):
def configure_optimizers(self) -> torch.optim.Optimizer:
return torch.optim.Adam(self.parameters(), lr=self.learning_rate)


class Net(nn.Module):
def __init__(self):
def __init__(self) -> None:
super().__init__()
self.conv1 = nn.Conv2d(1, 32, 3, 1)
self.conv2 = nn.Conv2d(32, 64, 3, 1)
@@ -225,7 +234,7 @@ def __init__(self):
self.fc1 = nn.Linear(9216, 128)
self.fc2 = nn.Linear(128, 10)

def forward(self, x):
def forward(self, x: Tensor) -> Tensor:
x = self.conv1(x)
x = F.relu(x)
x = self.conv2(x)
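The annotations added in this file do not change behaviour; the classes remain lightweight stand-ins for tests and bug reports. A typical usage sketch, assuming a standard `pytorch_lightning` install:

```python
# Smoke test with the demo classes touched by this diff.
# `fast_dev_run=True` runs a single batch through fit/validate to verify the wiring.
from pytorch_lightning import Trainer
from pytorch_lightning.demos.boring_classes import BoringDataModule, BoringModel

model = BoringModel()
datamodule = BoringDataModule()
trainer = Trainer(fast_dev_run=True, accelerator="cpu", devices=1)
trainer.fit(model, datamodule=datamodule)
trainer.test(model, datamodule=datamodule)
```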
17 changes: 1 addition & 16 deletions src/pytorch_lightning/strategies/sharded.py
@@ -12,7 +12,7 @@
# See the License for the specific language governing permissions and
# limitations under the License.
from contextlib import contextmanager
from typing import Dict, Generator, List, Optional, Tuple, Union
from typing import Dict, Generator, List, Tuple, Union

from torch import Tensor
from torch.nn import Module
@@ -27,7 +27,6 @@
from pytorch_lightning.utilities.exceptions import MisconfigurationException
from pytorch_lightning.utilities.imports import _FAIRSCALE_AVAILABLE, _FAIRSCALE_OSS_FP16_BROADCAST_AVAILABLE
from pytorch_lightning.utilities.optimizer import optimizers_to_device
from pytorch_lightning.utilities.rank_zero import rank_zero_only

if _FAIRSCALE_AVAILABLE:
from fairscale.nn.data_parallel.sharded_ddp import ShardedDataParallel
@@ -120,20 +119,6 @@ def _reinit_optimizers_with_oss(self, optimizers: List[Union[Optimizer, Lightnin
del optimizer
return optimizers

def optimizer_state(self, optimizer: "OSS") -> Optional[dict]:
if isinstance(optimizer, LightningOptimizer):
optimizer = optimizer._optimizer
optimizer.consolidate_state_dict()
return self._optim_state_dict(optimizer)

@rank_zero_only
def _optim_state_dict(self, optimizer):
"""
Retrieves state dict only on rank 0, which contains the entire optimizer state after calling
:meth:`consolidate_state_dict`.
"""
return optimizer.state_dict()

def pre_backward(self, closure_loss: Tensor) -> None:
pass

16 changes: 1 addition & 15 deletions src/pytorch_lightning/strategies/sharded_spawn.py
@@ -12,7 +12,7 @@
# See the License for the specific language governing permissions and
# limitations under the License.
from contextlib import contextmanager
from typing import Any, Dict, Generator, List, Tuple
from typing import Dict, Generator, List, Tuple

from torch import Tensor
from torch.nn import Module
@@ -25,7 +25,6 @@
from pytorch_lightning.utilities.exceptions import MisconfigurationException
from pytorch_lightning.utilities.imports import _FAIRSCALE_AVAILABLE
from pytorch_lightning.utilities.optimizer import optimizers_to_device
from pytorch_lightning.utilities.rank_zero import rank_zero_only

if _FAIRSCALE_AVAILABLE:
from fairscale.nn.data_parallel.sharded_ddp import ShardedDataParallel
@@ -85,11 +84,6 @@ def _wrap_optimizers(self, optimizers: List[Optimizer]) -> List["OSS"]:

return self._reinit_optimizers_with_oss(optimizers)

def optimizer_state(self, optimizer: "OSS") -> Dict[str, Any]:
if isinstance(optimizer, OSS):
optimizer.consolidate_state_dict()
return self._optim_state_dict(optimizer)

@contextmanager
def block_backward_sync(self) -> Generator:
"""Blocks syncing gradients behaviour on backwards pass.
@@ -103,14 +97,6 @@ def block_backward_sync(self) -> Generator:
else:
yield None

@rank_zero_only
def _optim_state_dict(self, optimizer: Optimizer) -> Dict[str, Any]:
"""
Retrieves state dict only on rank 0, which contains the entire optimizer state after calling
:meth:`consolidate_state_dict`.
"""
return optimizer.state_dict()

def pre_backward(self, closure_loss: Tensor) -> None:
pass

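Both sharded strategies keep `block_backward_sync`, which wraps fairscale's `ShardedDataParallel.no_sync()`. A standalone sketch of the pattern it implements, skipping the gradient all-reduce on intermediate accumulation steps, assuming `ddp_model` is a `DistributedDataParallel` or `ShardedDataParallel` instance inside an initialized process group:

```python
# Sketch: only synchronize gradients on the step that will call optimizer.step().
import contextlib

from torch import Tensor


def accumulation_backward(ddp_model, loss: Tensor, is_last_accumulation_step: bool) -> None:
    # `no_sync()` suppresses the gradient all-reduce for this backward pass
    ctx = contextlib.nullcontext() if is_last_accumulation_step else ddp_model.no_sync()
    with ctx:
        loss.backward()
```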
10 changes: 10 additions & 0 deletions src/pytorch_lightning/strategies/strategy.py
@@ -170,6 +170,16 @@ def optimizer_state(self, optimizer: Optimizer) -> Dict[str, Tensor]:
Allows for syncing/collating optimizer state from processes in custom plugins.
"""
if isinstance(optimizer, LightningOptimizer):
optimizer = optimizer._optimizer

if hasattr(optimizer, "consolidate_state_dict"):
# there are optimizers like Fairscale's OSS or PyTorch's ZeroRedundancyOptimizer that shard their
# states, and to avoid OOM we consolidate the full state on rank 0 only
optimizer.consolidate_state_dict()
return optimizer.state_dict() if self.is_global_zero else {}

# for optimizers that are not sharded, we return the state dict on all ranks
return optimizer.state_dict()

def backward(
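The duck-typed `consolidate_state_dict` check added above is what lets the base `Strategy` handle any sharded optimizer without strategy-specific overrides (hence the deletions in `sharded.py` and `sharded_spawn.py`). A standalone sketch of the same consolidate-on-rank-zero pattern, shown here with PyTorch's `ZeroRedundancyOptimizer` and assuming an already-initialized process group:

```python
# Sketch: gather the full optimizer state on rank 0 only, to avoid OOM on large models.
from typing import Any, Dict

import torch
import torch.distributed as dist
from torch.distributed.optim import ZeroRedundancyOptimizer


def optimizer_state(optimizer: torch.optim.Optimizer, is_global_zero: bool) -> Dict[str, Any]:
    if hasattr(optimizer, "consolidate_state_dict"):
        # sharded optimizers (fairscale OSS, ZeroRedundancyOptimizer) hold only a
        # shard of the state per rank; consolidate the full state onto rank 0
        optimizer.consolidate_state_dict()
        return optimizer.state_dict() if is_global_zero else {}
    # non-sharded optimizers already hold the full state on every rank
    return optimizer.state_dict()


# usage sketch, given a `model` and an initialized process group:
# opt = ZeroRedundancyOptimizer(model.parameters(), optimizer_class=torch.optim.Adam, lr=1e-3)
# state = optimizer_state(opt, is_global_zero=dist.get_rank() == 0)
```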
13 changes: 8 additions & 5 deletions src/pytorch_lightning/utilities/parsing.py
@@ -321,14 +321,17 @@ def _lightning_get_all_attr_holders(model: "pl.LightningModule", attribute: str)
holders.append(model)

# Check if attribute in model.hparams, either namespace or dict
if hasattr(model, "hparams"):
if attribute in model.hparams:
holders.append(model.hparams)
if hasattr(model, "hparams") and attribute in model.hparams:
holders.append(model.hparams)

trainer = model._trainer
# Check if the attribute in datamodule (datamodule gets registered in Trainer)
if trainer is not None and trainer.datamodule is not None and hasattr(trainer.datamodule, attribute):
holders.append(trainer.datamodule)
if trainer is not None and trainer.datamodule is not None:
if hasattr(trainer.datamodule, attribute):
holders.append(trainer.datamodule)

if hasattr(trainer.datamodule, "hparams") and attribute in trainer.datamodule.hparams:
holders.append(trainer.datamodule.hparams)

return holders

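With this fix, attribute lookups such as `lightning_getattr` can also resolve through `trainer.datamodule.hparams`, not just the datamodule's direct attributes. A sketch of the behaviour, using a hypothetical datamodule subclass that stores `batch_size` only in its hparams:

```python
# Sketch: `batch_size` lives only in the datamodule's hparams and is still found.
from pytorch_lightning import Trainer
from pytorch_lightning.demos.boring_classes import BoringDataModule, BoringModel
from pytorch_lightning.utilities.parsing import lightning_getattr


class HParamsBoringDataModule(BoringDataModule):
    def __init__(self, batch_size: int = 16) -> None:
        super().__init__()
        self.save_hyperparameters()  # stores `batch_size` in self.hparams


model = BoringModel()
datamodule = HParamsBoringDataModule(batch_size=16)
trainer = Trainer(fast_dev_run=True, accelerator="cpu", devices=1)
trainer.fit(model, datamodule=datamodule)

# resolved via trainer.datamodule.hparams thanks to the change above
assert lightning_getattr(model, "batch_size") == 16
```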