Skip to content

Commit

Permalink
Set shuffle=False in case a sampler is given (#5457)
Browse files Browse the repository at this point in the history
* sampler support

* update

* update

* follow-up

* changelog
  • Loading branch information
rusty1s authored Sep 16, 2022
1 parent ec7400e commit 8ac0a48
Show file tree
Hide file tree
Showing 2 changed files with 14 additions and 4 deletions.
2 changes: 1 addition & 1 deletion CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,7 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).

## [2.2.0] - 2022-MM-DD
### Added
- Added `sampler` support to `LightningDataModule` ([#5456](https://github.com/pyg-team/pytorch_geometric/pull/5456))
- Added `sampler` support to `LightningDataModule` ([#5456](https://github.com/pyg-team/pytorch_geometric/pull/5456), [#5457](https://github.com/pyg-team/pytorch_geometric/pull/5457))
- Added official splits to `MalNetTiny` dataset ([#5078](https://github.com/pyg-team/pytorch_geometric/pull/5078))
- Added `IndexToMask` and `MaskToIndex` transforms ([#5375](https://github.com/pyg-team/pytorch_geometric/pull/5375), [#5455](https://github.com/pyg-team/pytorch_geometric/pull/5455))
- Added `FeaturePropagation` transform ([#5387](https://github.com/pyg-team/pytorch_geometric/pull/5387))
Expand Down
16 changes: 13 additions & 3 deletions torch_geometric/data/lightning_datamodule.py
Original file line number Diff line number Diff line change
Expand Up @@ -151,7 +151,9 @@ def dataloader(self, dataset: Dataset, **kwargs) -> DataLoader:
def train_dataloader(self) -> DataLoader:
""""""
from torch.utils.data import IterableDataset
shuffle = not isinstance(self.train_dataset, IterableDataset)
shuffle = (not isinstance(self.train_dataset, IterableDataset)
and 'sampler' not in self.kwargs
and 'batch_sampler' not in self.kwargs)

return self.dataloader(self.train_dataset, shuffle=shuffle,
**self.kwargs)
Expand Down Expand Up @@ -344,6 +346,7 @@ def dataloader(
warnings.filterwarnings('ignore', '.*does not have many workers.*')
warnings.filterwarnings('ignore', '.*data loading bottlenecks.*')

kwargs['shuffle'] = True
kwargs.pop('sampler', None)
kwargs.pop('batch_sampler', None)

Expand All @@ -365,7 +368,10 @@ def dataloader(

def train_dataloader(self) -> DataLoader:
""""""
return self.dataloader(self.input_train_nodes, shuffle=True,
shuffle = ('sampler' not in self.kwargs
and 'batch_sampler' not in self.kwargs)

return self.dataloader(self.input_train_nodes, shuffle=shuffle,
**self.kwargs)

def val_dataloader(self) -> DataLoader:
Expand Down Expand Up @@ -564,6 +570,7 @@ def dataloader(
warnings.filterwarnings('ignore', '.*does not have many workers.*')
warnings.filterwarnings('ignore', '.*data loading bottlenecks.*')

kwargs['shuffle'] = True
kwargs.pop('sampler', None)
kwargs.pop('batch_sampler', None)

Expand All @@ -588,8 +595,11 @@ def dataloader(

def train_dataloader(self) -> DataLoader:
""""""
shuffle = ('sampler' not in self.kwargs
and 'batch_sampler' not in self.kwargs)

return self.dataloader(self.input_train_edges, self.input_train_labels,
self.input_train_time, shuffle=True,
self.input_train_time, shuffle=shuffle,
**self.kwargs)

def val_dataloader(self) -> DataLoader:
Expand Down

0 comments on commit 8ac0a48

Please sign in to comment.