Skip to content
This repository has been archived by the owner on Oct 9, 2023. It is now read-only.

Fix drop last for predicting and testing #671

Merged
merged 4 commits into from
Aug 17, 2021
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 2 additions & 0 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -70,6 +70,8 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).

- Fixed a bug where it was not possible to pass no metrics to the `ImageClassifier` or `TestClassifier` ([#660](https://github.com/PyTorchLightning/lightning-flash/pull/660))

- Fixed a bug where `drop_last` would be set to True during prediction and testing ([#671](https://github.com/PyTorchLightning/lightning-flash/pull/671))

## [0.4.0] - 2021-06-22

### Added
Expand Down
4 changes: 2 additions & 2 deletions flash/core/model.py
Original file line number Diff line number Diff line change
Expand Up @@ -182,7 +182,7 @@ def process_test_dataset(
pin_memory: bool,
collate_fn: Callable,
shuffle: bool = False,
drop_last: bool = True,
drop_last: bool = False,
sampler: Optional[Sampler] = None,
) -> DataLoader:
return self._process_dataset(
Expand All @@ -204,7 +204,7 @@ def process_predict_dataset(
pin_memory: bool = False,
collate_fn: Callable = None,
shuffle: bool = False,
drop_last: bool = True,
drop_last: bool = False,
sampler: Optional[Sampler] = None,
) -> DataLoader:
return self._process_dataset(
Expand Down
20 changes: 12 additions & 8 deletions tests/core/test_model.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,7 @@
# See the License for the specific language governing permissions and
# limitations under the License.
import math
from itertools import chain
from numbers import Number
from pathlib import Path
from typing import Any, Tuple
Expand Down Expand Up @@ -52,14 +53,20 @@ class Image:


class DummyDataset(torch.utils.data.Dataset):
def __init__(self, num_samples: int = 9):
self.num_samples = num_samples

def __getitem__(self, index: int) -> Tuple[Tensor, Number]:
return torch.rand(1, 28, 28), torch.randint(10, size=(1,)).item()

def __len__(self) -> int:
return 9
return self.num_samples


class PredictDummyDataset(DummyDataset):
def __init__(self, num_samples: int):
super().__init__(num_samples)

def __getitem__(self, index: int) -> Tensor:
return torch.rand(1, 28, 28)

Expand Down Expand Up @@ -211,15 +218,12 @@ def _rand_image():
def test_classification_task_trainer_predict(tmpdir):
model = nn.Sequential(nn.Flatten(), nn.Linear(28 * 28, 10))
task = ClassificationTask(model)
ds = PredictDummyDataset()
batch_size = 3
predict_dl = torch.utils.data.DataLoader(ds, batch_size=batch_size)
ds = PredictDummyDataset(10)
batch_size = 6
predict_dl = task.process_predict_dataset(ds, batch_size=batch_size)
trainer = pl.Trainer(default_root_dir=tmpdir)
predictions = trainer.predict(task, predict_dl)
assert len(predictions) == len(ds) // batch_size
for batch_pred in predictions:
assert len(batch_pred) == batch_size
assert all(y < 10 for y in batch_pred)
assert len(list(chain.from_iterable(predictions))) == 10


def test_task_datapipeline_save(tmpdir):
Expand Down