MisconfigurationException: Trying to inject DistributedSampler into the AnnLoader instance
#12917
Comments
@mbuttner Thanks for reporting this! Here's a smaller reproduction that copies how the AnnLoader is defined:

    import os
    from typing import Any

    import torch
    from torch.utils.data import DataLoader
    from pytorch_lightning import LightningModule, Trainer


    class BoringModel(LightningModule):
        def __init__(self):
            super().__init__()
            self.layer = torch.nn.Linear(32, 2)

        def forward(self, x):
            return self.layer(x)


    class AnnLoader(DataLoader):
        # batch_size and shuffle are named parameters so they exist before being forwarded
        def __init__(self, adatas: Any, batch_size: int = 1, shuffle: bool = False, **kwargs):
            super().__init__(adatas, batch_size=batch_size, shuffle=shuffle, **kwargs)


    def run():
        data = AnnLoader([1, 2, 3], batch_size=2)
        model = BoringModel()
        trainer = Trainer(
            default_root_dir=os.getcwd(),
            limit_predict_batches=1,
            replace_sampler_ddp=False,
        )
        trainer.predict(model, data)


    if __name__ == "__main__":
        run()

The error you are experiencing can be solved in two ways:

      def __init__(self, adatas: Any, batch_size: int = 1, shuffle: bool = False, **kwargs):
    +     self.adatas = adatas
          super().__init__(adatas, batch_size=batch_size, shuffle=shuffle, **kwargs)

    + def predict_dataloader(self):
    +     return AnnLoader([1, 2, 3], batch_size=2)

Still, a different error will now appear:
which would require the following change in the implementation:

      def __init__(self, adatas: Any, **kwargs):
    +     kwargs.pop("dataset", None)  # `dataset` is `adatas` already
          super().__init__(adatas, **kwargs)

However, I can see that you are setting […] But the changes described above should unblock you for the time being.
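
To illustrate why the `kwargs.pop("dataset", None)` line matters, here is a minimal sketch that needs neither torch nor Lightning. `BaseLoader`, `NaiveLoader`, `TolerantLoader`, and `rebuild` are hypothetical stand-ins invented for this example. `rebuild` only imitates the general idea of re-creating a loader from its attributes with a new sampler, and is not Lightning's actual code.

```python
class BaseLoader:
    """Hypothetical stand-in for torch.utils.data.DataLoader."""

    def __init__(self, dataset, batch_size=1, shuffle=False, sampler=None):
        self.dataset = dataset
        self.batch_size = batch_size
        self.shuffle = shuffle
        self.sampler = sampler


class NaiveLoader(BaseLoader):
    def __init__(self, adatas, batch_size=1, shuffle=False, **kwargs):
        self.adatas = adatas
        # `adatas` is passed positionally as the dataset, so a `dataset=...`
        # entry in kwargs collides with it.
        super().__init__(adatas, batch_size=batch_size, shuffle=shuffle, **kwargs)


class TolerantLoader(BaseLoader):
    def __init__(self, adatas, batch_size=1, shuffle=False, **kwargs):
        self.adatas = adatas
        kwargs.pop("dataset", None)  # `dataset` is `adatas` already
        super().__init__(adatas, batch_size=batch_size, shuffle=shuffle, **kwargs)


def rebuild(loader, sampler):
    """Re-create a loader from its public attributes, swapping in a new sampler."""
    kwargs = {k: v for k, v in vars(loader).items() if not k.startswith("_")}
    kwargs["sampler"] = sampler
    return type(loader)(**kwargs)


if __name__ == "__main__":
    try:
        rebuild(NaiveLoader([1, 2, 3], batch_size=2), sampler="new-sampler")
    except TypeError as err:
        print("NaiveLoader failed:", err)

    rebuilt = rebuild(TolerantLoader([1, 2, 3], batch_size=2), sampler="new-sampler")
    print("TolerantLoader rebuilt with sampler:", rebuilt.sampler)
```

Because the base class stores the data under `dataset` while the subclass also stores it under `adatas`, rebuilding from attributes hands the constructor both names. Dropping the redundant `dataset` keyword is what lets the second class survive the round trip.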
Hi again!
After more investigation, it turns out we need to inject more components other than the […] I suggest you propose the above changes to the DataLoader implementation to its authors so that it can work as expected. Thank you for your report!
Hi, […]
Sorry for the back and forth! We just thought of an improvement that removes this requirement as long as the […]
Hi, I'm seeing a similar issue even when I define predict_dataloader:

    class MyLightningModule(LightningModule):
        ...

        def predict_dataloader(self, samples, **kwargs):
            """Generate the dataloader for inference."""
            return self.inference_dataloader_cls(
                samples,
                self.preprocessor,
                collate_fn=identity,
                shuffle=False,
                pin_memory=False if self.device == torch.device("cpu") else True,
                **kwargs,
            )

where […]
I get quite an informative error message:
But adding the predict_dataloader method to MyLightningModule doesn't resolve the error, so I'm confused. Thanks for any help.
The issue seems to be resolved by adding the so-called "missing arguments" as attributes when instantiating the custom dataloader class, i.e. in the […]
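
As a rough illustration of the "missing arguments" idea, the sketch below checks which required `__init__` parameters of an object are not available as instance attributes. It uses only the standard library. `missing_init_attributes`, `OpaqueLoader`, and `TransparentLoader` are hypothetical names made up for this example, and the check is a simplified approximation, not Lightning's implementation.

```python
import inspect


def missing_init_attributes(obj):
    """List required __init__ parameters that are not stored as attributes on obj."""
    params = inspect.signature(type(obj).__init__).parameters
    missing = []
    for name, param in params.items():
        if name == "self":
            continue
        # *args / **kwargs cannot be recovered by name, so skip them.
        if param.kind in (inspect.Parameter.VAR_POSITIONAL, inspect.Parameter.VAR_KEYWORD):
            continue
        # Only parameters without a default are strictly required for a rebuild.
        if param.default is inspect.Parameter.empty and not hasattr(obj, name):
            missing.append(name)
    return missing


class OpaqueLoader:
    """Consumes its arguments without keeping them under the same names."""

    def __init__(self, samples, preprocessor, **kwargs):
        self._items = [preprocessor(s) for s in samples]


class TransparentLoader:
    """Keeps every required argument as an attribute of the same name."""

    def __init__(self, samples, preprocessor, **kwargs):
        self.samples = samples
        self.preprocessor = preprocessor
        self._items = [preprocessor(s) for s in samples]


if __name__ == "__main__":
    print(missing_init_attributes(OpaqueLoader([1, 2], str)))
    print(missing_init_attributes(TransparentLoader([1, 2], str)))
```

The first loader cannot be rebuilt from the outside because nothing records what `samples` and `preprocessor` were. The second can, which matches the fix described above of storing the constructor arguments as attributes.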
🐛 Bug
This is a copy of issue 757 posted at the anndata GitHub repository.
I have been trying to implement an MLP to predict cell type labels using PyTorch Lightning and the AnnLoader class from the anndata Python package. For the implementation, I followed the AnnLoader tutorial on interfacing with PyTorch models and the PyTorch Lightning tutorial.
I aim to implement the training, test, and prediction methods and run them on a GPU. I tested my code on a Google Colab instance. The error message is the same for the GPU and CPU runtimes.
When I try to predict a cell type label using the predict function, PyTorch Lightning wants to use DistributedSampler as the sampler, which is not implemented in AnnLoader, and I could not figure out how to disable the sampler.
My intuition is that this is an issue with the setup of the prediction step as the test step works without error.
To Reproduce
Here is the error message from the prediction step:
Expected behavior
I expected a return of the predicted label per cell (i.e. per input).
Environment
Installation method (conda, pip, source): pip
Output of torch.__config__.show():
Output of all versions:
Additional context
cc @justusschock @awaelchli @ninginthecloud @rohitgr7