Inheritance of Data Loaders #12564

Closed
rusty1s opened this issue Apr 1, 2022 · 0 comments · Fixed by #12716
Labels: bug (Something isn't working), data handling (Generic data-related topic)
rusty1s (Contributor) commented Apr 1, 2022

🐛 Bug

This is the formal bug report for my earlier comment (see #10680 (comment)).

```python
import pytorch_lightning as pl
import torch


class Model(pl.LightningModule):
    def __init__(self, in_channels):
        super().__init__()
        self.lin = torch.nn.Linear(in_channels, 1)

    def forward(self, x):
        return self.lin(x)

    def training_step(self, x, batch_idx):
        loss = self(x).mean()
        return loss

    def configure_optimizers(self):
        return torch.optim.Adam(self.parameters(), lr=0.01)


class MyBaseDataLoader(torch.utils.data.DataLoader):
    pass


class MyDataLoader(MyBaseDataLoader):
    def __init__(self, data: torch.Tensor, *args, **kwargs):
        self.data = data
        # The dataset is just the row indices; sample_fn gathers the rows.
        super().__init__(range(data.size(0)), *args, **kwargs,
                         collate_fn=self.sample_fn)

    def sample_fn(self, indices):
        # This is the assertion that fails under pytorch-lightning 1.6.0.
        assert isinstance(self.data, torch.Tensor)
        index = torch.tensor(indices)
        return self.data[index]


class DataModule(pl.LightningDataModule):
    def __init__(self, data: torch.Tensor):
        super().__init__()
        self.data = data

    def train_dataloader(self):
        return MyDataLoader(self.data, batch_size=32)


model = Model(in_channels=32)
datamodule = DataModule(torch.randn(100, 32))

trainer = pl.Trainer(accelerator='gpu', devices=1, max_epochs=10)
trainer.fit(model, datamodule)
```

which crashes with

```
  File "/home/matthias/miniconda3/lib/python3.9/site-packages/torch/utils/data/dataloader.py", line 530, in __next__
    data = self._next_data()
  File "/home/matthias/miniconda3/lib/python3.9/site-packages/torch/utils/data/dataloader.py", line 570, in _next_data
    data = self._dataset_fetcher.fetch(index)  # may raise StopIteration
  File "/home/matthias/miniconda3/lib/python3.9/site-packages/torch/utils/data/_utils/fetch.py", line 52, in fetch
    return self.collate_fn(data)
  File "/home/matthias/github/pytorch_geometric/alpha/PL.py", line 32, in sample_fn
    assert isinstance(self.data, torch.Tensor)
AssertionError
```

Notably, the code runs without errors with pytorch-lightning==1.5.10, and it also runs when MyDataLoader inherits directly from torch.utils.data.DataLoader.

I suspect this issue is related to Lightning overriding the __init__ methods of both MyBaseDataLoader and MyDataLoader here.
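To make the suspected mechanism concrete, here is a minimal, framework-free sketch. It assumes (this is not Lightning's actual implementation) that a re-instantiation mechanism wraps `__init__` on every `DataLoader` subclass, but not on `DataLoader` itself, to record the constructor arguments. `Loader`, `BaseLoader`, `NestedLoader`, `DirectLoader`, `record_init_args`, and `recreate` are hypothetical names that stand in for the classes in the reproduction above.

```python
import functools


def record_init_args(init):
    """Hypothetical wrapper: store the constructor arguments on the instance
    so that the object could later be re-created from them."""

    @functools.wraps(init)
    def wrapper(self, *args, **kwargs):
        # Every wrapped __init__ in the call chain overwrites these attributes,
        # so the innermost wrapped call (the last one to run) wins.
        self.recorded_args = args
        self.recorded_kwargs = kwargs
        init(self, *args, **kwargs)

    return wrapper


class Loader:
    """Stands in for torch.utils.data.DataLoader (never patched here)."""

    def __init__(self, dataset, batch_size=1):
        self.dataset = dataset
        self.batch_size = batch_size


class BaseLoader(Loader):
    """Like MyBaseDataLoader: adds nothing, inherits Loader.__init__."""


class NestedLoader(BaseLoader):
    """Like MyDataLoader: keeps `data` and passes an index range as dataset."""

    def __init__(self, data, **kwargs):
        self.data = data
        super().__init__(range(len(data)), **kwargs)


class DirectLoader(Loader):
    """Same as NestedLoader, but inherits directly from Loader."""

    def __init__(self, data, **kwargs):
        self.data = data
        super().__init__(range(len(data)), **kwargs)


# Patch __init__ on every subclass of Loader, but not on Loader itself.
for cls in (BaseLoader, NestedLoader, DirectLoader):
    cls.__init__ = record_init_args(cls.__init__)


def recreate(loader):
    """Re-instantiate a loader from its recorded constructor arguments."""
    return type(loader)(*loader.recorded_args, **loader.recorded_kwargs)


nested = recreate(NestedLoader([10, 20, 30], batch_size=2))
direct = recreate(DirectLoader([10, 20, 30], batch_size=2))
print(nested.data)  # range(0, 3): the original data was lost
print(direct.data)  # [10, 20, 30]
```

In the nested hierarchy, the wrapped `BaseLoader.__init__` runs last and overwrites the recorded arguments with those of the inner `super().__init__` call, so re-creation receives the index range instead of the original data. That would be one way for `self.data` to end up as something other than the original tensor. With direct inheritance only one wrapper runs, which would match the observation that inheriting directly from `torch.utils.data.DataLoader` works.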

  • PyTorch Lightning Version (e.g., 1.5.0): 1.6.0
  • PyTorch Version (e.g., 1.10): 1.10
  • Python version (e.g., 3.9): 3.9
  • OS (e.g., Linux): Linux
  • CUDA/cuDNN version: 11.3

cc @justusschock @awaelchli @ninginthecloud

rusty1s added the "needs triage" label Apr 1, 2022
rusty1s changed the title "Nested Inheritence of Data Loaders" to "Nested Inheritance of Data Loaders" Apr 2, 2022
rusty1s changed the title "Nested Inheritance of Data Loaders" to "Inheritance of Data Loaders" Apr 2, 2022
akihironitta added the "data handling" label and removed the "needs triage" label Apr 4, 2022
carmocca added the "bug" label Apr 6, 2022
carmocca self-assigned this Apr 6, 2022
carmocca added this to the 1.6.x milestone Apr 6, 2022
carmocca assigned otaj and unassigned carmocca Apr 11, 2022
carmocca moved this to In Review in Frameworks Planning Apr 12, 2022
Repository owner moved this from In Review to Done in Frameworks Planning Apr 15, 2022