
anndata DataLoader for pyTorch without DistributedSampler #757

Open
mbuttner opened this issue Apr 13, 2022 · 5 comments

@mbuttner

mbuttner commented Apr 13, 2022

Hi there,
I have been trying to implement an MLP that predicts cell type labels using PyTorch Lightning and AnnLoader.
For the implementation, I followed the AnnLoader tutorial on interfacing with PyTorch models, as well as the PyTorch Lightning tutorial.
I aim to implement training, test and prediction methods and run them on a GPU. I tested my code on a Google Colab instance; the error message is the same for the GPU and CPU runtimes.
When I try to predict cell type labels with trainer.predict, PyTorch Lightning tries to inject a DistributedSampler into the loader, which AnnLoader does not support, and I could not figure out how to disable this sampler replacement.

Here's my code:

import gdown
import pytorch_lightning as pl
import torch
import torch.nn as nn

import numpy as np
import scanpy as sc
from sklearn.preprocessing import OneHotEncoder, LabelEncoder
from torchmetrics.functional import accuracy
from anndata.experimental.pytorch import AnnLoader

#define model class 
class MLP(pl.LightningModule):
  
  def __init__(self, input_dim, hidden_dims, out_dim):
    super().__init__()
    modules = []
    for in_size, out_size in zip([input_dim]+hidden_dims, hidden_dims):
        modules.append(nn.Linear(in_size, out_size))
        modules.append(nn.LayerNorm(out_size))
        modules.append(nn.ReLU())
        modules.append(nn.Dropout(p=0.05))
    modules.append(nn.Linear(hidden_dims[-1], out_dim))
    self.layers = nn.Sequential(*modules)
    
    self.ce = nn.CrossEntropyLoss()
    
  def forward(self, x):
    return self.layers(x)
  
  def training_step(self, batch, batch_idx):
    # here, a batch has data (x) and labels (y). What is returned by
    # batch depends on the __getitem__() implementation in your Dataset
    x = batch.X
    y = batch.obs['cell_type'] #hard coded, please adapt
    x = x.view(x.size(0), -1)
    y_hat = self.layers(x)
    loss = self.ce(y_hat, y)
    self.log('train_loss', loss)
    return loss
  
  def test_step(self, batch, batch_idx):
    # here, a batch has data (x) and labels (y). What is returned by
    # batch depends on the __getitem__() implementation in your Dataset
    x = batch.X
    y = batch.obs['cell_type'] #hard coded, please adapt
    x = x.view(x.size(0), -1)
    y_hat = self.layers(x)
    loss = self.ce(y_hat, y)
    y_hat = torch.argmax(y_hat, dim=1)
    acc = accuracy(y_hat, y)
    metrics = dict({
        'test_loss': loss.clone().detach(),
        'test_acc': acc.clone().detach(),
    })
    self.log_dict(metrics, batch_size=len(y))
    return metrics

  def predict_step(self, batch, batch_idx, dataloader_idx=0):
    # here, only the data (x) is needed; labels are not used for prediction
    x = batch.X
    x = x.view(x.size(0), -1)
    y_hat = self(x)  # MLP has no `self.model` attribute; call forward() instead
    return y_hat

  def configure_optimizers(self):
    optimizer = torch.optim.Adam(self.parameters(), lr=1e-4)
    return optimizer

#load normalized pancreas data from AnnLoader tutorial 
from google.colab import drive
drive.mount('/content/drive')
file_path = '/content/drive/My Drive/pancreas_normalized.h5ad'

adata = sc.read(file_path)
adata.X = adata.raw.X # put raw counts to .X

#prepare AnnLoader 
encoder_study = OneHotEncoder(sparse=False, dtype=np.float32)
encoder_study.fit(adata.obs['study'].to_numpy()[:, None])

encoder_celltype = LabelEncoder()
encoder_celltype.fit(adata.obs['cell_type'])

use_cuda = torch.cuda.is_available()

encoders = {
    'obs': {
        'study': lambda s: encoder_study.transform(s.to_numpy()[:, None]),
        'cell_type': encoder_celltype.transform
    }
}

# Load data as AnnLoaders, split into train and test data
dataloader = AnnLoader(adata[adata.obs['study']!='Pancreas Fluidigm C1'], batch_size=128, shuffle=True, convert=encoders, use_cuda=use_cuda)
dataloader_test = AnnLoader(adata[adata.obs['study']=='Pancreas Fluidigm C1'], batch_size=128, #sampler = sampler,  
                            shuffle=False, convert=encoders, use_cuda=use_cuda)

#create MLP model, configure pytorch lightning trainer 
mlp = MLP(input_dim = adata.n_vars, hidden_dims = [128,128], out_dim=8)
trainer = pl.Trainer(auto_scale_batch_size='power', gpus=1, deterministic=True, 
                     max_epochs=5, replace_sampler_ddp=False) 
# Train the model
trainer.fit(mlp, dataloader)

# Perform evaluation
trainer.test(mlp, dataloader_test)
## output [{'test_acc': 0.9620253443717957, 'test_loss': 0.12309074401855469}]

# Return predictions
trainer.predict(mlp, dataloader_test)
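The `zip([input_dim]+hidden_dims, hidden_dims)` idiom in `MLP.__init__` pairs each hidden layer's input size with its output size, and the final `nn.Linear` maps the last hidden size to `out_dim`. The stdlib-only sketch below reproduces just that size bookkeeping so the resulting layer shapes can be checked without building the model; `layer_shapes` is a hypothetical helper, and `input_dim=1000` is an arbitrary placeholder, not the real `adata.n_vars` of the pancreas data.

```python
def layer_shapes(input_dim, hidden_dims, out_dim):
    """Return the (in, out) sizes of the nn.Linear layers built in MLP.__init__."""
    # Hidden blocks: each Linear consumes the previous block's output size.
    shapes = list(zip([input_dim] + hidden_dims, hidden_dims))
    # Final classifier layer: last hidden size -> number of cell types.
    shapes.append((hidden_dims[-1], out_dim))
    return shapes


print(layer_shapes(1000, [128, 128], 8))
# [(1000, 128), (128, 128), (128, 8)]
```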

Here is the error message from the prediction step:

LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0]
/usr/local/lib/python3.7/dist-packages/pytorch_lightning/trainer/connectors/data_connector.py:489: PossibleUserWarning: Your `predict_dataloader`'s sampler has shuffling enabled, it is strongly recommended that you turn shuffling off for val/test/predict dataloaders.
  category=PossibleUserWarning,
---------------------------------------------------------------------------
MisconfigurationException                 Traceback (most recent call last)
<ipython-input-18-0ce04d8b9a10> in <module>()
      1 # Return predictions
----> 2 trainer.predict(mlp, dataloader_test)

11 frames
/usr/local/lib/python3.7/dist-packages/pytorch_lightning/utilities/data.py in _get_dataloader_init_kwargs(dataloader, sampler, mode)
    240         dataloader_cls_name = dataloader.__class__.__name__
    241         raise MisconfigurationException(
--> 242             f"Trying to inject `DistributedSampler` into the `{dataloader_cls_name}` instance. "
    243             "This would fail as some of the `__init__` arguments are not available as instance attributes. "
    244             f"The missing attributes are {required_args}. "

MisconfigurationException: Trying to inject `DistributedSampler` into the `AnnLoader` instance. This would fail as some of the `__init__` arguments are not available as instance attributes. The missing attributes are ['adatas']. HINT: If you wrote the `AnnLoader` class, define `self.missing_arg_name` or manually add the `DistributedSampler` as: `AnnLoader(dataset, sampler=DistributedSampler(dataset))`.
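The message describes the failure mode: to swap in a new sampler, Lightning tries to build a fresh instance of the loader's class, reading the other `__init__` arguments back from attributes of the existing instance, and AnnLoader does not keep `adatas` as an attribute. The following is a simplified, hypothetical illustration of that mechanism in plain Python; `ToyLoader`, `missing_init_attributes` and `rebuild_with_sampler` are made-up names, and this is not Lightning's actual implementation.

```python
import inspect


class ToyLoader:
    """Hypothetical stand-in for a loader class such as AnnLoader.

    It accepts ``adatas`` in __init__ but does not keep it as an attribute.
    """

    def __init__(self, adatas, batch_size=1, sampler=None):
        self.dataset = list(adatas)  # the data is kept, but under another name
        self.batch_size = batch_size
        self.sampler = sampler


def missing_init_attributes(loader):
    """Return required __init__ parameters that are not instance attributes."""
    params = inspect.signature(type(loader).__init__).parameters
    return [
        name
        for name, p in params.items()
        if name != "self"
        and p.default is inspect.Parameter.empty
        and p.kind in (inspect.Parameter.POSITIONAL_OR_KEYWORD,
                       inspect.Parameter.KEYWORD_ONLY)
        and not hasattr(loader, name)
    ]


def rebuild_with_sampler(loader, sampler):
    """Simplified re-instantiation: read init args back, swap in a new sampler."""
    missing = missing_init_attributes(loader)
    if missing:
        raise RuntimeError(f"The missing attributes are {missing}.")
    params = inspect.signature(type(loader).__init__).parameters
    kwargs = {n: getattr(loader, n) for n in params if n != "self" and hasattr(loader, n)}
    kwargs["sampler"] = sampler
    return type(loader)(**kwargs)


loader = ToyLoader(adatas=[[0.1, 0.2], [0.3, 0.4]], batch_size=128)
print(missing_init_attributes(loader))  # ['adatas']
try:
    rebuild_with_sampler(loader, sampler="replacement-sampler")
except RuntimeError as err:
    print(err)  # The missing attributes are ['adatas'].
```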

Versions:

pyasn1                                      0.4.8
pyasn1_modules                              0.2.8
pydev_ipython                               NA
pydevconsole                                NA
pydevd                                      2.0.0
pydevd_concurrency_analyser                 NA
pydevd_file_utils                           NA
pydevd_plugins                              NA
pydevd_tracing                              NA
pydot_ng                                    2.0.0
pygments                                    2.6.1
pyparsing                                   3.0.7
pytorch_lightning                           1.6.0
pytz                                        2018.9
regex                                       2.5.72
requests                                    2.23.0
rsa                                         4.8
scipy                                       1.4.1
session_info                                1.0.0
setuptools                                  57.4.0
simplegeneric                               NA
sitecustomize                               NA
six                                         1.15.0
sklearn                                     1.0.2
socks                                       1.7.1
sphinxcontrib                               NA
storemagic                                  NA
tblib                                       1.7.0
tensorboard                                 2.8.0
tensorflow                                  2.8.0
termcolor                                   1.1.0
threadpoolctl                               3.1.0
toolz                                       0.11.2
torch                                       1.10.0+cu111
torchmetrics                                0.7.3
torchtext                                   0.11.0
torchvision                                 0.11.1+cu111
tornado                                     5.1.1
tqdm                                        4.63.0
traitlets                                   5.1.1
typing_extensions                           NA
uritemplate                                 3.0.1
urllib3                                     1.24.3
wcwidth                                     0.2.5
webencodings                                0.5.1
wrapt                                       1.14.0
yaml                                        6.0
zipp                                        NA
zmq                                         22.3.0
-----
IPython             5.5.0
jupyter_client      5.3.5
jupyter_core        4.9.2
notebook            5.3.1
-----
Python 3.7.13 (default, Mar 16 2022, 17:37:17) [GCC 7.5.0]
Linux-5.4.144+-x86_64-with-Ubuntu-18.04-bionic
-----
Session information updated at 2022-04-13 15:04
@Koncopd
Member

Koncopd commented Apr 13, 2022

Yes, the way PyTorch Lightning tries to inject the sampler is not supported; you probably need to disable this somehow. But I will check what can be done to fix this.

@adamgayoso
Member

It honestly sounds like a bug in PyTorch Lightning, as the code works for fit. It should remember that you didn't want to replace samplers. I would open an issue there.

@mbuttner
Author

mbuttner commented Apr 15, 2022 via email

@mbuttner
Author

mbuttner commented May 3, 2022

Hi there,
I got a reply from the PyTorch Lightning developers, who suggested some changes to the DataLoader (see the issue posted there). Could this be implemented without much hassle? Thank you!
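The HINT in the MisconfigurationException message above points to one way a loader class can cooperate with this kind of re-instantiation: keep each `__init__` argument as an attribute under the parameter's own name (for AnnLoader, `self.adatas`). Below is a minimal plain-Python sketch of that pattern, assuming a hypothetical `PatchedToyLoader` and a simplified `rebuild_with_sampler` helper. It does not reproduce the changes the Lightning developers actually proposed, and whether this alone is enough for the real AnnLoader is untested here.

```python
import inspect


class PatchedToyLoader:
    """Hypothetical loader that keeps every __init__ argument as an attribute,
    following the HINT (define ``self.<missing_arg_name>``)."""

    def __init__(self, adatas, batch_size=1, sampler=None):
        self.adatas = adatas  # stored under the parameter's own name
        self.dataset = list(adatas)
        self.batch_size = batch_size
        self.sampler = sampler


def rebuild_with_sampler(loader, sampler):
    """Simplified stand-in for a framework re-creating a loader with a new sampler."""
    params = inspect.signature(type(loader).__init__).parameters
    kwargs = {name: getattr(loader, name) for name in params if name != "self"}
    kwargs["sampler"] = sampler
    return type(loader)(**kwargs)


original = PatchedToyLoader(adatas=[[0.1, 0.2], [0.3, 0.4]], batch_size=128)
rebuilt = rebuild_with_sampler(original, sampler="replacement-sampler")
print(rebuilt.batch_size, rebuilt.sampler)  # 128 replacement-sampler
```

The design choice here is simply that every constructor argument remains recoverable from the instance, so a framework can rebuild the object without knowing anything about the class beyond its `__init__` signature.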

@Koncopd
Member

Koncopd commented May 3, 2022

@mbuttner thank you, I will check their proposed changes.

@Koncopd Koncopd self-assigned this May 4, 2022
@github-actions github-actions bot added stale and removed stale labels Jun 17, 2023
@github-actions github-actions bot added the stale label Aug 21, 2023
@scverse scverse deleted a comment from github-actions bot Aug 21, 2023
@scverse scverse deleted a comment from github-actions bot Aug 21, 2023
4 participants