Skip to content
This repository has been archived by the owner on Oct 9, 2023. It is now read-only.

Commit

Permalink
Fix / simplify default_uncollate (#1077)
Browse files Browse the repository at this point in the history
  • Loading branch information
ethanwharris authored Dec 16, 2021
1 parent 7bcf4f1 commit f37e50d
Show file tree
Hide file tree
Showing 7 changed files with 93 additions and 91 deletions.
2 changes: 2 additions & 0 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,8 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).

- Fixed a bug where under some circumstances transforms would not get called ([#1072](https://github.com/PyTorchLightning/lightning-flash/pull/1072))

- Fixed a bug where prediction would sometimes give the wrong number of outputs ([#1077](https://github.com/PyTorchLightning/lightning-flash/pull/1077))

### Removed

## [0.6.0] - 2021-12-13
Expand Down
53 changes: 33 additions & 20 deletions flash/core/data/batch.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,10 +11,11 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from typing import Any, Sequence, TYPE_CHECKING
from typing import Any, List, TYPE_CHECKING

import torch
from torch import Tensor

from flash.core.data.utilities.classification import _is_list_like

if TYPE_CHECKING:
from flash.core.data.io.input import ServeInput
Expand All @@ -35,27 +36,39 @@ def forward(self, sample: str):
return sample


def default_uncollate(batch: Any):
"""
This function is used to uncollate a batch into samples.
Examples:
>>> a, b = default_uncollate(torch.rand((2,1)))
"""
def _is_list_like_excluding_str(x):
return _is_list_like(x) and str(x) != x

batch_type = type(batch)

if isinstance(batch, Tensor):
if len(batch.shape) == 0: # 0 shape tensors
return batch
return list(torch.unbind(batch, 0))
def default_uncollate(batch: Any) -> List[Any]:
"""This function is used to uncollate a batch into samples. The following conditions are used:
if isinstance(batch, dict):
return [batch_type(dict(zip(batch, default_uncollate(t)))) for t in zip(*batch.values())]
- if the ``batch`` is a ``dict``, the result will be a list of dicts
- if the ``batch`` is list-like, the result is guaranteed to be a list
if isinstance(batch, tuple) and hasattr(batch, "_fields"): # namedtuple
return [batch_type(*sample) for sample in zip(*batch)]
Args:
batch: The batch of outputs to be uncollated.
if isinstance(batch, Sequence) and not isinstance(batch, str):
return [sample for sample in batch]
Returns:
The uncollated list of predictions.
Raises:
ValueError: If the input is a ``dict`` whose values are not all list-like.
ValueError: If the input is a ``dict`` whose values are not all the same length.
ValueError: If the input is not a ``dict`` or list-like.
"""

if isinstance(batch, dict):
if any(not _is_list_like_excluding_str(sub_batch) for sub_batch in batch.values()):
raise ValueError("When uncollating a dict, all sub-batches (values) are expected to be list-like.")
if len({len(sub_batch) for sub_batch in batch.values()}) > 1:
raise ValueError("When uncollating a dict, all sub-batches (values) are expected to have the same length.")
elements = list(default_uncollate(element) for element in zip(*batch.values()))
return [dict(zip(batch.keys(), element)) for element in elements]

return batch
if _is_list_like_excluding_str(batch):
return list(batch)
raise ValueError(
"The batch of outputs to be uncollated is expected to be a `dict` or list-like "
"(e.g. `torch.Tensor`, `list`, `tuple`, etc.)."
)
14 changes: 1 addition & 13 deletions flash/core/data/io/output_transform.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,13 +11,12 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from typing import Any, Callable, Mapping, Optional, Sequence, Tuple
from typing import Any, Callable, Optional, Sequence

import torch
from torch import Tensor

from flash.core.data.batch import default_uncollate
from flash.core.data.io.input import DataKeys
from flash.core.data.properties import Properties
from flash.core.data.utils import convert_to_modules

Expand Down Expand Up @@ -77,22 +76,11 @@ def __init__(
self.output = convert_to_modules(output)
self.is_serving = is_serving

@staticmethod
def _extract_metadata(batch: Any) -> Tuple[Any, Optional[Any]]:
metadata = None
if isinstance(batch, Mapping) and DataKeys.METADATA in batch:
metadata = batch.pop(DataKeys.METADATA, None)
return batch, metadata

def forward(self, batch: Sequence[Any]):
if batch is None:
return batch

batch, metadata = self._extract_metadata(batch)
uncollated = self.uncollate_fn(self.per_batch_transform(batch))
if metadata:
for sample, sample_metadata in zip(uncollated, metadata):
sample[DataKeys.METADATA] = sample_metadata

final_preds = [self.per_sample_transform(sample) for sample in uncollated]

Expand Down
14 changes: 11 additions & 3 deletions flash/core/integrations/icevision/adapter.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,7 +19,11 @@
import flash
from flash.core.adapter import Adapter
from flash.core.data.io.input import DataKeys, InputBase
from flash.core.integrations.icevision.transforms import from_icevision_predictions, to_icevision_record
from flash.core.integrations.icevision.transforms import (
from_icevision_predictions,
from_icevision_record,
to_icevision_record,
)
from flash.core.model import Task
from flash.core.utilities.imports import _ICEVISION_AVAILABLE
from flash.core.utilities.url_error import catch_url_error
Expand Down Expand Up @@ -195,8 +199,12 @@ def test_step(self, batch, batch_idx):
return self.icevision_adapter.validation_step(batch[DataKeys.INPUT], batch_idx)

def predict_step(self, batch: Any, batch_idx: int, dataloader_idx: int = 0) -> Any:
batch[DataKeys.PREDS] = self(batch[DataKeys.INPUT])
return batch
records = batch[DataKeys.INPUT][1]
return {
DataKeys.INPUT: [from_icevision_record(record) for record in records],
DataKeys.PREDS: self(batch[DataKeys.INPUT]),
DataKeys.METADATA: batch[DataKeys.METADATA],
}

def forward(self, batch: Any) -> Any:
return from_icevision_predictions(
Expand Down
2 changes: 1 addition & 1 deletion flash/core/integrations/icevision/transforms.py
Original file line number Diff line number Diff line change
Expand Up @@ -211,7 +211,7 @@ def from_icevision_record(record: "BaseRecord"):
filepath = getattr(record, "filepath", None)
if filepath is not None:
sample[DataKeys.METADATA]["filepath"] = filepath
elif record.filepath is not None:
elif getattr(record, "filepath", None) is not None:
sample[DataKeys.INPUT] = record.filepath

sample[DataKeys.TARGET] = from_icevision_detection(record)
Expand Down
89 changes: 39 additions & 50 deletions tests/core/data/test_batch.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,65 +13,54 @@
# limitations under the License.
from collections import namedtuple

import pytest
import torch
from torch.testing import assert_allclose

from flash.core.data.batch import default_uncollate

Case = namedtuple("Case", ["collated_batch", "uncollated_batch"])

class TestDefaultUncollate:
cases = [
# Primitives
Case({"preds": [1, 2, 3]}, [{"preds": 1}, {"preds": 2}, {"preds": 3}]),
Case(
{"preds": [1, 2, 3], "metadata": [4, 5, 6]},
[{"preds": 1, "metadata": 4}, {"preds": 2, "metadata": 5}, {"preds": 3, "metadata": 6}],
),
Case(([1, 2, 3], [4, 5, 6]), [[1, 2, 3], [4, 5, 6]]),
Case([[1, 2, 3], [4, 5, 6]], [[1, 2, 3], [4, 5, 6]]),
Case([[1, 2], [4, 5, 6]], [[1, 2], [4, 5, 6]]),
Case([["a", "b"], ["a", "c", "d"]], [["a", "b"], ["a", "c", "d"]]),
# Tensors
Case({"preds": torch.tensor([1, 2, 3])}, [{"preds": 1}, {"preds": 2}, {"preds": 3}]),
Case(
{"preds": torch.tensor([1, 2, 3]), "metadata": torch.tensor([4, 5, 6])},
[{"preds": 1, "metadata": 4}, {"preds": 2, "metadata": 5}, {"preds": 3, "metadata": 6}],
),
Case(torch.tensor([1, 2, 3]), [torch.tensor(1), torch.tensor(2), torch.tensor(3)]),
# Mixed
Case(
{"preds": torch.tensor([1, 2, 3]), "metadata": [4, 5, 6]},
[{"preds": 1, "metadata": 4}, {"preds": 2, "metadata": 5}, {"preds": 3, "metadata": 6}],
),
]

BATCH_SIZE = 3

@staticmethod
def test_smoke():
batch = torch.rand(2, 1)
assert default_uncollate(batch) is not None
@pytest.mark.parametrize("case", cases)
def test_default_uncollate(case):
assert default_uncollate(case.collated_batch) == case.uncollated_batch

@staticmethod
def test_tensor_zero():
batch = torch.tensor(1)
output = default_uncollate(batch)
assert_allclose(batch, output)

@staticmethod
def test_tensor_batch():
batch = torch.rand(2, 1)
output = default_uncollate(batch)
assert isinstance(output, list)
assert all(isinstance(x, torch.Tensor) for x in output)
ErrorCase = namedtuple("ErrorCase", ["collated_batch", "match"])

def test_sequence(self):
batch = {
"a": torch.rand(self.BATCH_SIZE, 4),
"b": torch.rand(self.BATCH_SIZE, 2),
"c": torch.rand(self.BATCH_SIZE),
}
error_cases = [
ErrorCase({"preds": [1, 2, 3], "metadata": [4, 5, 6, 7]}, "expected to have the same length."),
ErrorCase({"preds": [1, 2, 3], "metadata": "test"}, "expected to be list-like."),
ErrorCase("test", "expected to be a `dict` or list-like"),
]

output = default_uncollate(batch)
assert isinstance(output, list)
assert len(batch) == self.BATCH_SIZE

for sample in output:
assert list(sample.keys()) == ["a", "b", "c"]
assert isinstance(sample["a"], torch.Tensor)
assert len(sample["a"]) == 4
assert isinstance(sample["b"], torch.Tensor)
assert len(sample["b"]) == 2
assert isinstance(sample["c"], torch.Tensor)
assert len(sample["c"].shape) == 0

def test_named_tuple(self):
Batch = namedtuple("Batch", ["x", "y"])
batch = Batch(x=torch.rand(self.BATCH_SIZE, 4), y=torch.rand(self.BATCH_SIZE))

output = default_uncollate(batch)
assert isinstance(output, list)
assert len(output) == self.BATCH_SIZE

for sample in output:
assert isinstance(sample, Batch)
assert isinstance(sample.x, torch.Tensor)
assert len(sample.x) == 4
assert isinstance(sample.y, torch.Tensor)
assert len(sample.y.shape) == 0
@pytest.mark.parametrize("error_case", error_cases)
def test_default_uncollate_raises(error_case):
with pytest.raises(ValueError, match=error_case.match):
default_uncollate(error_case.collated_batch)
10 changes: 6 additions & 4 deletions tests/image/instance_segmentation/test_model.py
Original file line number Diff line number Diff line change
Expand Up @@ -76,10 +76,11 @@ def test_instance_segmentation_inference(tmpdir):
str(data_dir / "images/yorkshire_terrier_12.jpg"),
str(data_dir / "images/yorkshire_terrier_13.jpg"),
],
batch_size=2,
batch_size=4,
)
predictions = trainer.predict(model, datamodule=datamodule)
assert len(predictions[0][0]) == 6
assert len(predictions[0]) == 3
assert len(list(predictions[0][0].keys())) == 5

model_path = os.path.join(tmpdir, "model.pt")
trainer.save_checkpoint(model_path)
Expand All @@ -91,7 +92,8 @@ def test_instance_segmentation_inference(tmpdir):
str(data_dir / "images/yorkshire_terrier_12.jpg"),
str(data_dir / "images/yorkshire_terrier_13.jpg"),
],
batch_size=2,
batch_size=4,
)
predictions = trainer.predict(model, datamodule=datamodule)
assert len(predictions[0][0]) == 6
assert len(predictions[0]) == 3
assert len(list(predictions[0][0].keys())) == 5

0 comments on commit f37e50d

Please sign in to comment.