
MisconfigurationException: Trying to inject DistributedSampler into the AnnLoader instance #12917

Closed
mbuttner opened this issue Apr 28, 2022 · 6 comments · Fixed by #12981
Labels
bug Something isn't working data handling Generic data-related topic trainer: predict

@mbuttner

mbuttner commented Apr 28, 2022

🐛 Bug

This is a copy of issue 757 posted in the anndata GitHub repository.

I have been trying to implement an MLP that predicts cell type labels using PyTorch Lightning and the AnnLoader class from the anndata Python package. For the implementation, I followed the AnnLoader tutorial on interfacing with PyTorch models and the PyTorch Lightning tutorial.
I aim to implement the training, test and prediction methods and run them on a GPU. I tested my code on a Google Colab instance; the error message is the same for the GPU and CPU runtimes.
When I try to predict cell type labels with trainer.predict, PyTorch Lightning tries to use a DistributedSampler as the sampler, which AnnLoader does not support, and I could not figure out how to disable the sampler.
My intuition is that this is an issue with the setup of the prediction step, as the test step works without error.

To Reproduce

import gdown
import pytorch_lightning as pl
import torch
import torch.nn as nn

import numpy as np
import scanpy as sc
from sklearn.preprocessing import OneHotEncoder, LabelEncoder
from torchmetrics.functional import accuracy
from anndata.experimental.pytorch import AnnLoader

#define model class 
class MLP(pl.LightningModule):
  
  def __init__(self, input_dim, hidden_dims, out_dim):
    super().__init__()
    modules = []
    for in_size, out_size in zip([input_dim]+hidden_dims, hidden_dims):
        modules.append(nn.Linear(in_size, out_size))
        modules.append(nn.LayerNorm(out_size))
        modules.append(nn.ReLU())
        modules.append(nn.Dropout(p=0.05))
    modules.append(nn.Linear(hidden_dims[-1], out_dim))
    self.layers = nn.Sequential(*modules)
    
    self.ce = nn.CrossEntropyLoss()
    
  def forward(self, x):
    return self.layers(x)
  
  def training_step(self, batch, batch_idx):
    # here, a batch has data (x) and labels (y). What is returned by
    # batch depends on the __getitem__() implementation in your Dataset
    x = batch.X
    y = batch.obs['cell_type'] #hard coded, please adapt
    x = x.view(x.size(0), -1)
    y_hat = self.layers(x)
    loss = self.ce(y_hat, y)
    self.log('train_loss', loss)
    return loss
  
  def test_step(self, batch, batch_idx):
    # here, a batch has data (x) and labels (y). What is returned by
    # batch depends on the __getitem__() implementation in your Dataset
    x = batch.X
    y = batch.obs['cell_type'] #hard coded, please adapt
    x = x.view(x.size(0), -1)
    y_hat = self.layers(x)
    loss = self.ce(y_hat, y)
    y_hat = torch.argmax(y_hat, dim=1)
    acc = accuracy(y_hat, y)
    metrics = dict({
        'test_loss': loss.clone().detach(),
        'test_acc': acc.clone().detach(),
    })
    self.log_dict(metrics, batch_size=len(y))
    return metrics

  def predict_step(self, batch, batch_idx, dataloader_idx=0):
    # here, a batch has data (x); labels are not needed for prediction. What is
    # returned by batch depends on the __getitem__() implementation in your Dataset
    x = batch.X
    x = x.view(x.size(0), -1)
    y_hat = self.layers(x)  # the module defines `self.layers`, not `self.model`
    return y_hat

  def configure_optimizers(self):
    optimizer = torch.optim.Adam(self.parameters(), lr=1e-4)
    return optimizer

#load normalized pancreas data from AnnLoader tutorial 
from google.colab import drive
drive.mount('/content/drive')
file_path = '/content/drive/My Drive/pancreas_normalized.h5ad'

adata = sc.read(file_path)
adata.X = adata.raw.X # put raw counts to .X

#prepare AnnLoader 
encoder_study = OneHotEncoder(sparse=False, dtype=np.float32)
encoder_study.fit(adata.obs['study'].to_numpy()[:, None])

encoder_celltype = LabelEncoder()
encoder_celltype.fit(adata.obs['cell_type'])

use_cuda = torch.cuda.is_available()

encoders = {
    'obs': {
        'study': lambda s: encoder_study.transform(s.to_numpy()[:, None]),
        'cell_type': encoder_celltype.transform
    }
}

# Load data as dataLoader, split in train and test data  
dataloader = AnnLoader(adata[adata.obs['study']!='Pancreas Fluidigm C1'], batch_size=128, shuffle=True, convert=encoders, use_cuda=use_cuda)
dataloader_test = AnnLoader(adata[adata.obs['study']=='Pancreas Fluidigm C1'], batch_size=128,
                            shuffle=False, convert=encoders, use_cuda=use_cuda)

#create MLP model, configure pytorch lightning trainer 
mlp = MLP(input_dim = adata.n_vars, hidden_dims = [128,128], out_dim=8)
trainer = pl.Trainer(auto_scale_batch_size='power', gpus=1, deterministic=True, 
                     max_epochs=5, replace_sampler_ddp=False) 
# Train the model
trainer.fit(mlp, dataloader)

# Perform evaluation
trainer.test(mlp, dataloader_test)
## output [{'test_acc': 0.9620253443717957, 'test_loss': 0.12309074401855469}]

# Return predictions
trainer.predict(mlp, dataloader_test)

Here is the error message from the prediction step:

LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0]
/usr/local/lib/python3.7/dist-packages/pytorch_lightning/trainer/connectors/data_connector.py:489: PossibleUserWarning: Your `predict_dataloader`'s sampler has shuffling enabled, it is strongly recommended that you turn shuffling off for val/test/predict dataloaders.
  category=PossibleUserWarning,
---------------------------------------------------------------------------
MisconfigurationException                 Traceback (most recent call last)
<ipython-input-18-0ce04d8b9a10> in <module>()
      1 # Return predictions
----> 2 trainer.predict(mlp, dataloader_test)

11 frames
/usr/local/lib/python3.7/dist-packages/pytorch_lightning/utilities/data.py in _get_dataloader_init_kwargs(dataloader, sampler, mode)
    240         dataloader_cls_name = dataloader.__class__.__name__
    241         raise MisconfigurationException(
--> 242             f"Trying to inject `DistributedSampler` into the `{dataloader_cls_name}` instance. "
    243             "This would fail as some of the `__init__` arguments are not available as instance attributes. "
    244             f"The missing attributes are {required_args}. "

MisconfigurationException: Trying to inject `DistributedSampler` into the `AnnLoader` instance. This would fail as some of the `__init__` arguments are not available as instance attributes. The missing attributes are ['adatas']. HINT: If you wrote the `AnnLoader` class, define `self.missing_arg_name` or manually add the `DistributedSampler` as: `AnnLoader(dataset, sampler=DistributedSampler(dataset))`.
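The check behind this message can be approximated in plain Python. The sketch below is a simplified illustration, not Lightning's or anndata's code: it assumes the check compares the required parameters of the loader class's `__init__` against the instance's attributes, and `missing_init_attributes` and `AnnLoaderLike` are hypothetical names.

```python
import inspect


def missing_init_attributes(obj):
    """Return required __init__ parameters that obj does not keep as instance attributes."""
    params = inspect.signature(type(obj).__init__).parameters.values()
    required = [
        p.name
        for p in params
        if p.name != "self"
        and p.default is inspect.Parameter.empty
        and p.kind in (inspect.Parameter.POSITIONAL_OR_KEYWORD, inspect.Parameter.KEYWORD_ONLY)
    ]
    instance_attrs = vars(obj)
    return [name for name in required if name not in instance_attrs]


class AnnLoaderLike:
    # hypothetical stand-in: the constructor argument is consumed, not stored
    def __init__(self, adatas, batch_size=1):
        self.dataset = list(adatas)
        self.batch_size = batch_size


print(missing_init_attributes(AnnLoaderLike([1, 2, 3])))  # ['adatas']
```

Because `adatas` is only used to build other state, nothing on the instance lets a framework pass it back into `__init__` when re-creating the loader.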

Expected behavior

I expected a return of the predicted label per cell (i.e. per input).

Environment

  • PyTorch Lightning Version (e.g., 1.5.0): 1.6.0
  • PyTorch Version (e.g., 1.10): 1.10.0+cu111
  • Python version (e.g., 3.9): 3.7.13
  • OS (e.g., Linux): Linux-5.4.144+-x86_64-with-Ubuntu-18.04-bionic
  • CUDA/cuDNN version: cu111 (from a Google Colab instance)
  • GPU models and configuration:
  • How you installed PyTorch (conda, pip, source): pip
  • If compiling from source, the output of torch.__config__.show():
  • Any other relevant information:

Output of all versions:

pyasn1                                      0.4.8
pyasn1_modules                              0.2.8
pydev_ipython                               NA
pydevconsole                                NA
pydevd                                      2.0.0
pydevd_concurrency_analyser                 NA
pydevd_file_utils                           NA
pydevd_plugins                              NA
pydevd_tracing                              NA
pydot_ng                                    2.0.0
pygments                                    2.6.1
pyparsing                                   3.0.7
pytorch_lightning                           1.6.0
pytz                                        2018.9
regex                                       2.5.72
requests                                    2.23.0
rsa                                         4.8
scipy                                       1.4.1
session_info                                1.0.0
setuptools                                  57.4.0
simplegeneric                               NA
sitecustomize                               NA
six                                         1.15.0
sklearn                                     1.0.2
socks                                       1.7.1
sphinxcontrib                               NA
storemagic                                  NA
tblib                                       1.7.0
tensorboard                                 2.8.0
tensorflow                                  2.8.0
termcolor                                   1.1.0
threadpoolctl                               3.1.0
toolz                                       0.11.2
torch                                       1.10.0+cu111
torchmetrics                                0.7.3
torchtext                                   0.11.0
torchvision                                 0.11.1+cu111
tornado                                     5.1.1
tqdm                                        4.63.0
traitlets                                   5.1.1
typing_extensions                           NA
uritemplate                                 3.0.1
urllib3                                     1.24.3
wcwidth                                     0.2.5
webencodings                                0.5.1
wrapt                                       1.14.0
yaml                                        6.0
zipp                                        NA
zmq                                         22.3.0
-----
IPython             5.5.0
jupyter_client      5.3.5
jupyter_core        4.9.2
notebook            5.3.1
-----
Python 3.7.13 (default, Mar 16 2022, 17:37:17) [GCC 7.5.0]
Linux-5.4.144+-x86_64-with-Ubuntu-18.04-bionic
-----
Session information updated at 2022-04-13 15:04

Additional context

cc @justusschock @awaelchli @ninginthecloud @rohitgr7

@mbuttner mbuttner added the needs triage Waiting to be triaged by maintainers label Apr 28, 2022
@carmocca
Contributor

carmocca commented Apr 29, 2022

@mbuttner Thanks for reporting this!

Here's a smaller reproduction that mimics how the AnnLoader is implemented:

import os
from typing import Any

import torch
from torch.utils.data import DataLoader

from pytorch_lightning import LightningModule, Trainer


class BoringModel(LightningModule):
    def __init__(self):
        super().__init__()
        self.layer = torch.nn.Linear(32, 2)

    def forward(self, x):
        return self.layer(x)


class AnnLoader(DataLoader):
    def __init__(self, adatas: Any, **kwargs):
        # batch_size, shuffle, etc. are forwarded to DataLoader via **kwargs
        super().__init__(adatas, **kwargs)


def run():
    data = AnnLoader([1, 2, 3], batch_size=2)
    model = BoringModel()
    trainer = Trainer(
        default_root_dir=os.getcwd(),
        limit_predict_batches=1,
        replace_sampler_ddp=False
    )
    trainer.predict(model, data)


if __name__ == "__main__":
    run()

The error you are experiencing can be solved in two ways:

  1. Before 1.6.0, modify the AnnLoader implementation:
    def __init__(self, adatas: Any, **kwargs):
+       self.adatas = adatas
        super().__init__(adatas, **kwargs)
  2. From 1.6.0 on, we do this automatically as long as you create the dataloader inside a *_dataloader hook:
+    def predict_dataloader(self):
+        return AnnLoader([1, 2, 3], batch_size=2)
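Why storing the argument helps: to inject a new sampler, the loader has to be re-created with the same constructor arguments, which are read back from instance attributes. A rough, stdlib-only sketch of that re-creation, assuming a simplified attribute lookup (`rebuild_with_sampler` and `AnnLoaderLike` are hypothetical; a real DataLoader has many more parameters):

```python
import inspect


class AnnLoaderLike:
    # hypothetical stand-in for a DataLoader subclass
    def __init__(self, adatas, batch_size=1, sampler=None):
        self.adatas = adatas  # option 1: keep the constructor argument as an attribute
        self.dataset = list(adatas)
        self.batch_size = batch_size
        self.sampler = sampler


def rebuild_with_sampler(loader, sampler):
    """Re-create loader with the same __init__ arguments but a different sampler."""
    params = inspect.signature(type(loader).__init__).parameters
    kwargs = {
        name: getattr(loader, name)
        for name in params
        if name != "self" and hasattr(loader, name)
    }
    kwargs["sampler"] = sampler  # the component being injected
    return type(loader)(**kwargs)


original = AnnLoaderLike([1, 2, 3], batch_size=2)
rebuilt = rebuild_with_sampler(original, sampler="my-sampler")
print(rebuilt.adatas, rebuilt.batch_size, rebuilt.sampler)  # [1, 2, 3] 2 my-sampler
```

Without `self.adatas = adatas`, the lookup would have nothing to pass for `adatas` and the call to `type(loader)(**kwargs)` would fail.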

Still, a different error will now appear:

The AnnLoader DataLoader implementation has an error where more than one __init__ argument can be passed to its parent's dataset=... __init__ argument. This is likely caused by allowing passing both a custom argument that will map to the dataset argument as well as **kwargs. kwargs should be filtered to make sure they don't contain the dataset key.

which would require the following change in the implementation:

    def __init__(self, adatas: Any, **kwargs):
+       kwargs.pop("dataset", None)  # `dataset` is `adatas` already 
        super().__init__(adatas, **kwargs)
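To see why the pop is needed: when the loader is re-created, its constructor may receive `dataset` as a keyword argument in addition to `adatas`, and forwarding both to the parent raises a `TypeError`. A self-contained sketch with hypothetical stand-in classes (`BaseLoader` mimics only the first few DataLoader parameters):

```python
class BaseLoader:
    # hypothetical stand-in for the start of torch.utils.data.DataLoader's signature
    def __init__(self, dataset, batch_size=1, sampler=None):
        self.dataset = dataset
        self.batch_size = batch_size
        self.sampler = sampler


class UnfilteredLoader(BaseLoader):
    def __init__(self, adatas, **kwargs):
        super().__init__(adatas, **kwargs)  # `dataset` may arrive twice


class FilteredLoader(BaseLoader):
    def __init__(self, adatas, **kwargs):
        kwargs.pop("dataset", None)  # `dataset` is `adatas` already
        super().__init__(adatas, **kwargs)


# Keyword arguments a re-creating framework might pass (assumed for illustration).
rebuild_kwargs = {"adatas": [1, 2, 3], "dataset": [1, 2, 3], "batch_size": 2}

try:
    UnfilteredLoader(**rebuild_kwargs)
    unfiltered_error = None
except TypeError as err:
    unfiltered_error = str(err)  # mentions "multiple values for argument 'dataset'"

loader = FilteredLoader(**rebuild_kwargs)
print(unfiltered_error)
print(loader.dataset, loader.batch_size)  # [1, 2, 3] 2
```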

However, I can see that you are setting Trainer(replace_sampler_ddp=False) and that the problem only occurs with trainer.predict, which means we have a bug in our data-handling logic. We should fix that.

In the meantime, the changes described above should unblock you.

@carmocca carmocca added bug Something isn't working data handling Generic data-related topic trainer: predict and removed needs triage Waiting to be triaged by maintainers labels Apr 29, 2022
@carmocca carmocca added this to the 1.6.x milestone Apr 29, 2022
@carmocca
Contributor

carmocca commented May 2, 2022

Hi again!

However, I can see that you are setting Trainer(replace_sampler_ddp=False) and this is only a problem with trainer.predict, meaning we have a bug in our data handling logic. So we should fix that

After more investigation, it turns out that in the case of predict we need to inject components other than the DistributedSampler, namely a custom batch sampler: https://github.com/PyTorchLightning/pytorch-lightning/blob/0e4c4424fd2aa9528a20db7873c9adb9b9ba7465/pytorch_lightning/utilities/data.py#L296-L297. This means the requirements on the DataLoader implementation are still necessary; the error message is just misleading in that case.
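Roughly, for predictions the batch sampler is wrapped so that the dataset indices of each batch can be reported alongside the outputs. Because the wrapped batch sampler has to be passed back into the loader's `__init__`, the loader must be re-creatable even with `replace_sampler_ddp=False`. A simplified stand-in for such a wrapper (`RecordingBatchSampler` is hypothetical, not Lightning's class):

```python
class RecordingBatchSampler:
    """Hypothetical wrapper: yields the inner sampler's batches and records their indices."""

    def __init__(self, batch_sampler):
        self.batch_sampler = batch_sampler
        self.seen_batch_indices = []

    def __iter__(self):
        self.seen_batch_indices = []  # reset at the start of every pass
        for batch in self.batch_sampler:
            self.seen_batch_indices.append(list(batch))
            yield batch

    def __len__(self):
        return len(self.batch_sampler)


inner = [[0, 1], [2, 3], [4]]  # any iterable of index lists, e.g. a BatchSampler
wrapper = RecordingBatchSampler(inner)
for batch in wrapper:
    pass
print(wrapper.seen_batch_indices)  # [[0, 1], [2, 3], [4]]
```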

I suggest you propose the above changes to the authors of the DataLoader implementation so that it works as expected.

Thank you for your report!

@carmocca carmocca closed this as completed May 2, 2022
@mbuttner
Author

mbuttner commented May 3, 2022

Hi,
Thank you for your quick response and for looking into this! I'll get in touch with the anndata developers.

@carmocca
Contributor

carmocca commented May 3, 2022

Sorry for the back and forth! We just came up with an improvement that removes this requirement as long as the DataLoader is instantiated inside a *_dataloader hook.
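The general idea behind hook-based argument capture can be sketched with the standard library: while a class's `__init__` is temporarily patched, each new instance records the arguments it was built with, so nothing has to be stored as public attributes. This is an illustration of the technique, not Lightning's code; `capture_init_args`, `AnnLoaderLike` and `_captured_init_args` are hypothetical names. It also shows that timing matters: an instance created outside the patched window records nothing.

```python
import functools
import inspect
from contextlib import contextmanager


@contextmanager
def capture_init_args(cls):
    """Temporarily patch cls.__init__ so new instances remember their constructor arguments."""
    original_init = cls.__init__
    signature = inspect.signature(original_init)

    @functools.wraps(original_init)
    def recording_init(self, *args, **kwargs):
        bound = signature.bind(self, *args, **kwargs)
        bound.apply_defaults()
        arguments = dict(bound.arguments)
        arguments.pop("self")
        self._captured_init_args = arguments  # hypothetical attribute name
        original_init(self, *args, **kwargs)

    cls.__init__ = recording_init
    try:
        yield
    finally:
        cls.__init__ = original_init  # restore the unpatched constructor


class AnnLoaderLike:
    # hypothetical stand-in: consumes `adatas` without storing it
    def __init__(self, adatas, batch_size=1):
        self.dataset = list(adatas)
        self.batch_size = batch_size


with capture_init_args(AnnLoaderLike):
    inside_hook = AnnLoaderLike([1, 2, 3], batch_size=2)  # like creating it in a hook

outside_hook = AnnLoaderLike([1, 2, 3], batch_size=2)  # like creating it before trainer.predict

print(inside_hook._captured_init_args)  # {'adatas': [1, 2, 3], 'batch_size': 2}
print(hasattr(outside_hook, "_captured_init_args"))  # False
```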

@carmocca carmocca moved this to In Progress in Frameworks Planning May 3, 2022
@carmocca carmocca moved this from In Progress to In Review in Frameworks Planning May 6, 2022
Repository owner moved this from In Review to Done in Frameworks Planning Jun 21, 2022
@sammlapp

sammlapp commented May 24, 2024

Hi, I'm seeing a similar issue even when I define MyLightningModule.predict_dataloader():

class MyLightningModule(LightningModule):
    ...

    def predict_dataloader(self, samples, **kwargs):
        """generate dataloader for inference"""

        return self.inference_dataloader_cls(
            samples,
            self.preprocessor,
            collate_fn=identity,
            shuffle=False,
            pin_memory=False if self.device == torch.device("cpu") else True,
            **kwargs,
        )

where self.inference_dataloader_cls is SafeAudioDataloader, a custom DataLoader class.

self.preprocessor is an object that defines the preprocessing operations of the Dataset initialized within SafeAudioDataloader.

I get quite an informative error message:

MisconfigurationException: Trying to inject custom Sampler into the SafeAudioDataloader instance. This would fail as some of the __init__ arguments are not available as instance attributes. The missing attributes are ['preprocessor', 'samples']. If you instantiate your SafeAudioDataloader inside a *_dataloader hook of your module, we will do this for you. Otherwise, define self.preprocessor, self.samples inside your __init__.

However, adding the predict_dataloader method to MyLightningModule doesn't resolve the error, so I'm confused.

Thanks for any help!

@sammlapp
Copy link

The issue seems to be resolved by adding the so-called "missing arguments" as attributes when instantiating the custom dataloader class, i.e.

self.samples = samples
self.preprocessor = preprocessor

in the __init__ method of my custom class SafeAudioDataloader. However, I would prefer not to set these arguments as attributes: they are only used to initialize SafeAudioDataloader.dataset, and modifying them afterwards would be nonsensical because it would not update .dataset. Is there an alternative to this workaround? @carmocca any advice?

Projects
No open projects
Status: Done
4 participants