Tensors not on the same device when using FSDP auto-wrapping #14900

Closed · awaelchli opened this issue Sep 27, 2022 · 4 comments · Fixed by #15301
Labels: bug, strategy: fairscale fsdp (removed)

@awaelchli (Contributor) commented:

Bug

When using FSDP with its defaults and auto-wrapping, the forward pass fails with an error saying that the input tensors and the weights are not on the same device.

To Reproduce

import os

import torch
from torch.utils.data import DataLoader, Dataset

from pytorch_lightning import LightningModule, Trainer


class RandomDataset(Dataset):
    def __init__(self, size, length):
        self.len = length
        self.data = torch.randn(length, size)

    def __getitem__(self, index):
        return self.data[index]

    def __len__(self):
        return self.len


class BoringModel(LightningModule):
    def __init__(self):
        super().__init__()
        self.layer = torch.nn.Linear(32, 2)

    def forward(self, x):
        return self.layer(x)

    def training_step(self, batch, batch_idx):
        loss = self(batch).sum()
        self.log("train_loss", loss)
        return {"loss": loss}

    def validation_step(self, batch, batch_idx):
        loss = self(batch).sum()
        self.log("valid_loss", loss)

    def test_step(self, batch, batch_idx):
        loss = self(batch).sum()
        self.log("test_loss", loss)

    def configure_optimizers(self):
        return torch.optim.SGD(self.layer.parameters(), lr=0.1)


def run():
    train_data = DataLoader(RandomDataset(32, 64), batch_size=2)
    val_data = DataLoader(RandomDataset(32, 64), batch_size=2)
    test_data = DataLoader(RandomDataset(32, 64), batch_size=2)

    model = BoringModel()
    trainer = Trainer(
        default_root_dir=os.getcwd(),
        limit_train_batches=1,
        limit_val_batches=1,
        limit_test_batches=1,
        num_sanity_val_steps=0,
        max_epochs=1,
        accelerator='gpu',
        enable_model_summary=False,
        precision=16,
        strategy='fsdp_native'
    )
    trainer.fit(model, train_dataloaders=train_data, val_dataloaders=val_data)
    trainer.test(model, dataloaders=test_data)


if __name__ == "__main__":
    run()

Error:

RuntimeError: Expected all tensors to be on the same device, but found at least two devices, cpu and cuda:0! (when checking argument for argument mat1 in method wrapper_addmm)

Additional context

Reported by user kavya on Slack.
Comment:

However, the fairscale FSDP strategy seems to work fine with mixed precision. Are any additional steps necessary for fsdp_native with torch==1.12.1+cu113 and pytorch-lightning==1.7.7?


awaelchli added the bug and strategy: fairscale fsdp (removed) labels on Sep 27, 2022
@awaelchli (Contributor, Author) commented Sep 27, 2022:

Auto-wrapping support was added in #14383. I can't reproduce this on master in the same way, but I get:

  File "/home/adrian/repositories/lightning/repro.py", line 70, in <module>
    run()
  File "/home/adrian/repositories/lightning/repro.py", line 65, in run
    trainer.fit(model, train_dataloaders=train_data, val_dataloaders=val_data)
  File "/home/adrian/repositories/lightning/src/pytorch_lightning/trainer/trainer.py", line 570, in fit
    teardown.call_and_handle_interrupt(
  File "/home/adrian/repositories/lightning/src/pytorch_lightning/trainer/teardown.py", line 34, in call_and_handle_interrupt
    return trainer.strategy.launcher.launch(trainer_fn, *args, trainer=trainer, **kwargs)
  File "/home/adrian/repositories/lightning/src/pytorch_lightning/strategies/launchers/subprocess_script.py", line 91, in launch
    return function(*args, **kwargs)
  File "/home/adrian/repositories/lightning/src/pytorch_lightning/trainer/trainer.py", line 609, in _fit_impl
    results = self._run(model, ckpt_path=self.ckpt_path)
  File "/home/adrian/repositories/lightning/src/pytorch_lightning/trainer/trainer.py", line 1034, in _run
    self.strategy.setup(self)
  File "/home/adrian/repositories/lightning/src/pytorch_lightning/strategies/fully_sharded_native.py", line 249, in setup
    self.setup_optimizers(trainer)
  File "/home/adrian/repositories/lightning/src/pytorch_lightning/strategies/strategy.py", line 142, in setup_optimizers
    self.optimizers, self.lr_scheduler_configs, self.optimizer_frequencies = _init_optimizers_and_lr_schedulers(
  File "/home/adrian/repositories/lightning/src/pytorch_lightning/core/optimizer.py", line 180, in _init_optimizers_and_lr_schedulers
    optim_conf = model.trainer._call_lightning_module_hook("configure_optimizers", pl_module=model)
  File "/home/adrian/repositories/lightning/src/pytorch_lightning/trainer/trainer.py", line 1302, in _call_lightning_module_hook
    output = fn(*args, **kwargs)
  File "/home/adrian/repositories/lightning/repro.py", line 43, in configure_optimizers
    return torch.optim.SGD(self.layer.parameters(), lr=0.1)
  File "/home/adrian/anaconda3/envs/lightning/lib/python3.9/site-packages/torch/optim/sgd.py", line 109, in __init__
    super(SGD, self).__init__(params, defaults)
  File "/home/adrian/anaconda3/envs/lightning/lib/python3.9/site-packages/torch/optim/optimizer.py", line 61, in __init__
    raise ValueError("optimizer got an empty parameter list")
ValueError: optimizer got an empty parameter list

setup_optimizers() is probably getting called too early?
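
As a side note, here is a minimal standalone sketch (not from this thread; it assumes a single-GPU process group and the flattening behaviour of FSDP in torch 1.12, where wrapping removes the original nn.Parameters from the inner module) of why configure_optimizers ends up with an empty parameter list once the strategy has already wrapped the model:

# Hedged sketch: after FSDP wrapping, the original module no longer exposes
# nn.Parameters, so an optimizer built from them sees an empty list.
import os

import torch
import torch.distributed as dist
from torch.distributed.fsdp import FullyShardedDataParallel as FSDP

os.environ.setdefault("MASTER_ADDR", "127.0.0.1")
os.environ.setdefault("MASTER_PORT", "29501")
dist.init_process_group("nccl", rank=0, world_size=1)

layer = torch.nn.Linear(32, 2).cuda()
wrapped = FSDP(layer)

# The flat parameter now lives on the FSDP wrapper, not on the original layer.
print([name for name, _ in wrapped.named_parameters()])  # e.g. ['_fsdp_wrapped_module.flat_param']
print(list(layer.parameters()))  # [] on torch 1.12 -> "optimizer got an empty parameter list"

dist.destroy_process_group()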

cc @rohitgr7

@rohitgr7 (Contributor) commented:

ah! weird. there are tests that check this. Let me explore it more.
can you try with?

def configure_optimizers(self):
    return torch.optim.SGD(self.trainer.model.parameters(), lr=0.1)

rohitgr7 self-assigned this on Sep 28, 2022
@w3nhao commented Oct 11, 2022:

> ah! weird. there are tests that check this. Let me explore it more. can you try with?
>
> def configure_optimizers(self):
>     return torch.optim.SGD(self.trainer.model.parameters(), lr=0.1)

Hi @rohitgr7, I've tried your suggestion and the code works without throwing any errors. But now I want to set different param groups with different LRs, and it seems these params become one single group called _fsdp_wrapped_module.flat_param. What am I supposed to do?

@rohitgr7 (Contributor) commented:

yes @TOPFARMER that's the case actually: pytorch/pytorch#76382
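
A rough, untested sketch of one possible workaround (not from this thread): since every FSDP unit owns its own flat parameter, an auto-wrap policy that places the layers you want in a separate group into their own FSDP unit keeps their flat parameter distinguishable by name, which can then be grouped in configure_optimizers. The `head` submodule and the LR values below are hypothetical. Newer PyTorch releases also offer a use_orig_params=True option on FSDP that keeps the original parameters visible to the optimizer and avoids this limitation.

# Rough sketch, not from this thread. Assumes the auto-wrap policy puts a
# hypothetical `head` submodule into its own FSDP unit, so its flat parameter
# keeps "head" in its qualified name. Grouping is only possible per FSDP unit.
def configure_optimizers(self):
    head_params, rest_params = [], []
    for name, param in self.trainer.model.named_parameters():
        (head_params if "head" in name else rest_params).append(param)
    return torch.optim.SGD(
        [
            {"params": rest_params, "lr": 0.1},
            {"params": head_params, "lr": 0.01},  # hypothetical smaller LR for the head
        ]
    )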
