diff --git a/flash/core/model.py b/flash/core/model.py
index b4d4a8b709..e01170fbab 100644
--- a/flash/core/model.py
+++ b/flash/core/model.py
@@ -30,7 +30,7 @@
 from flash.core.schedulers import _SCHEDULERS_REGISTRY
 from flash.core.utils import get_callable_dict
 from flash.data.data_pipeline import DataPipeline, DataPipelineState
-from flash.data.data_source import DataSource, DefaultDataSources
+from flash.data.data_source import DataSource, DefaultDataKeys, DefaultDataSources
 from flash.data.process import Postprocess, Preprocess, Serializer, SerializerMapping
diff --git a/flash/data/batch.py b/flash/data/batch.py
index 38929f6079..61aa6c0e26 100644
--- a/flash/data/batch.py
+++ b/flash/data/batch.py
@@ -11,7 +11,7 @@
 # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 # See the License for the specific language governing permissions and
 # limitations under the License.
-from typing import Any, Callable, Mapping, Optional, Sequence, TYPE_CHECKING, Union
+from typing import Any, Callable, Dict, List, Mapping, Optional, Sequence, Tuple, TYPE_CHECKING, Union

 import torch
 from pytorch_lightning.trainer.states import RunningStage
@@ -19,6 +19,7 @@
 from torch import Tensor

 from flash.data.callback import ControlFlow
+from flash.data.data_source import DefaultDataKeys
 from flash.data.utils import _contains_any_tensor, convert_to_modules, CurrentFuncContext, CurrentRunningStageContext

 if TYPE_CHECKING:
@@ -137,6 +138,13 @@ def __init__(
         self._collate_context = CurrentFuncContext("collate", preprocess)
         self._per_batch_transform_context = CurrentFuncContext(f"per_batch_transform{extension}", preprocess)

+    def _extract_metadata(
+        self,
+        samples: List[Dict[str, Any]],
+    ) -> Tuple[List[Dict[str, Any]], Optional[List[Dict[str, Any]]]]:
+        metadata = [s.pop(DefaultDataKeys.METADATA, None) if isinstance(s, Mapping) else None for s in samples]
+        return samples, metadata if any(m is not None for m in metadata) else None
+
     def forward(self, samples: Sequence[Any]) -> Any:
         # we create a new dict to prevent from potential memory leaks
         # assuming that the dictionary samples are stored in between and
@@ -158,7 +166,10 @@ def forward(self, samples: Sequence[Any]) -> Any:
         samples = type(_samples)(_samples)

         with self._collate_context:
+            samples, metadata = self._extract_metadata(samples)
             samples = self.collate_fn(samples)
+            if metadata:
+                samples[DefaultDataKeys.METADATA] = metadata
             self.callback.on_collate(samples, self.stage)

         with self._per_batch_transform_context:
diff --git a/flash/data/data_source.py b/flash/data/data_source.py
index 23955413ab..1360885b90 100644
--- a/flash/data/data_source.py
+++ b/flash/data/data_source.py
@@ -71,6 +71,7 @@ class DefaultDataKeys(LightningEnum):
     targets."""

     INPUT = "input"
+    PREDS = "preds"
     TARGET = "target"
     METADATA = "metadata"
diff --git a/flash/vision/segmentation/data.py b/flash/vision/segmentation/data.py
index c4f9d0ebf9..df48cfcc80 100644
--- a/flash/vision/segmentation/data.py
+++ b/flash/vision/segmentation/data.py
@@ -35,6 +35,7 @@
     ImageLabelsMap,
     NumpyDataSource,
     PathsDataSource,
+    SEQUENCE_DATA_TYPE,
     TensorDataSource,
 )
 from flash.data.process import Preprocess
@@ -51,7 +52,18 @@ class SemanticSegmentationNumpyDataSource(NumpyDataSource):

     def load_sample(self, sample: Dict[str, Any], dataset: Optional[Any] = None) -> Dict[str, Any]:
-        sample[DefaultDataKeys.INPUT] = torch.from_numpy(sample[DefaultDataKeys.INPUT]).float()
+        img = torch.from_numpy(sample[DefaultDataKeys.INPUT]).float()
+        sample[DefaultDataKeys.INPUT] = img
+        sample[DefaultDataKeys.METADATA] = img.shape
         return sample
+
+
+class SemanticSegmentationTensorDataSource(TensorDataSource):
+
+    def load_sample(self, sample: Dict[str, Any], dataset: Optional[Any] = None) -> Dict[str, Any]:
+        img = sample[DefaultDataKeys.INPUT].float()
+        sample[DefaultDataKeys.INPUT] = img
+        sample[DefaultDataKeys.METADATA] = img.shape
+        return sample
@@ -120,7 +132,11 @@ def load_sample(self, sample: Mapping[str, Any]) -> Mapping[str, Union[torch.Ten
         }

     def predict_load_sample(self, sample: Mapping[str, Any]) -> Mapping[str, Any]:
-        return {DefaultDataKeys.INPUT: torchvision.io.read_image(sample[DefaultDataKeys.INPUT]).float()}
+        img = torchvision.io.read_image(sample[DefaultDataKeys.INPUT]).float()
+        return {
+            DefaultDataKeys.INPUT: img,
+            DefaultDataKeys.METADATA: img.shape,
+        }
@@ -157,7 +173,7 @@ def __init__(
             data_sources={
                 DefaultDataSources.FILES: SemanticSegmentationPathsDataSource(),
                 DefaultDataSources.FOLDERS: SemanticSegmentationPathsDataSource(),
-                DefaultDataSources.TENSORS: TensorDataSource(),
+                DefaultDataSources.TENSORS: SemanticSegmentationTensorDataSource(),
                 DefaultDataSources.NUMPY: SemanticSegmentationNumpyDataSource(),
             },
             default_data_source=DefaultDataSources.FILES,
diff --git a/flash/vision/segmentation/model.py b/flash/vision/segmentation/model.py
index e543b341ed..7d99949e49 100644
--- a/flash/vision/segmentation/model.py
+++ b/flash/vision/segmentation/model.py
@@ -21,10 +21,23 @@
 from flash.core.classification import ClassificationTask
 from flash.core.registry import FlashRegistry
 from flash.data.data_source import DefaultDataKeys
-from flash.data.process import Serializer
+from flash.data.process import Postprocess, Serializer
+from flash.utils.imports import _KORNIA_AVAILABLE
 from flash.vision.segmentation.backbones import SEMANTIC_SEGMENTATION_BACKBONES
 from flash.vision.segmentation.serialization import SegmentationLabels

+if _KORNIA_AVAILABLE:
+    import kornia as K
+
+
+class SemanticSegmentationPostprocess(Postprocess):
+
+    def per_sample_transform(self, sample: Any) -> Any:
+        resize = K.geometry.Resize(sample[DefaultDataKeys.METADATA][-2:], interpolation='bilinear')
+        sample[DefaultDataKeys.PREDS] = resize(torch.stack(sample[DefaultDataKeys.PREDS]))
+        sample[DefaultDataKeys.INPUT] = resize(torch.stack(sample[DefaultDataKeys.INPUT]))
+        return super().per_sample_transform(sample)
+

 class SemanticSegmentation(ClassificationTask):
     """Task that performs semantic segmentation on images.
@@ -53,6 +66,8 @@ class SemanticSegmentation(ClassificationTask):
         serializer: The :class:`~flash.data.process.Serializer` to use when serializing prediction outputs.
""" + postprocess_cls = SemanticSegmentationPostprocess + backbones: FlashRegistry = SEMANTIC_SEGMENTATION_BACKBONES def __init__( @@ -67,6 +82,7 @@ def __init__( learning_rate: float = 1e-3, multi_label: bool = False, serializer: Optional[Union[Serializer, Mapping[str, Serializer]]] = None, + postprocess: Optional[Postprocess] = None, ) -> None: if metrics is None: @@ -86,6 +102,7 @@ def __init__( metrics=metrics, learning_rate=learning_rate, serializer=serializer or SegmentationLabels(), + postprocess=postprocess or self.postprocess_cls() ) self.save_hyperparameters() @@ -109,8 +126,10 @@ def test_step(self, batch: Any, batch_idx: int) -> Any: return super().test_step(batch, batch_idx) def predict_step(self, batch: Any, batch_idx: int, dataloader_idx: int = 0) -> Any: - batch = (batch[DefaultDataKeys.INPUT]) - return super().predict_step(batch, batch_idx, dataloader_idx=dataloader_idx) + batch_input = (batch[DefaultDataKeys.INPUT]) + preds = super().predict_step(batch_input, batch_idx, dataloader_idx=dataloader_idx) + batch[DefaultDataKeys.PREDS] = preds + return batch def forward(self, x) -> torch.Tensor: # infer the image to the model diff --git a/flash/vision/segmentation/serialization.py b/flash/vision/segmentation/serialization.py index 5a8cb40f69..6a63a0bc7f 100644 --- a/flash/vision/segmentation/serialization.py +++ b/flash/vision/segmentation/serialization.py @@ -17,7 +17,7 @@ import torch import flash -from flash.data.data_source import ImageLabelsMap +from flash.data.data_source import DefaultDataKeys, ImageLabelsMap from flash.data.process import Serializer from flash.utils.imports import _KORNIA_AVAILABLE, _MATPLOTLIB_AVAILABLE @@ -67,9 +67,10 @@ def create_random_labels_map(num_classes: int) -> Dict[int, Tuple[int, int, int] labels_map[i] = torch.randint(0, 255, (3, )) return labels_map - def serialize(self, sample: torch.Tensor) -> torch.Tensor: - assert len(sample.shape) == 3, sample.shape - labels = torch.argmax(sample, dim=-3) # HxW + def serialize(self, sample: Dict[str, torch.Tensor]) -> torch.Tensor: + preds = sample[DefaultDataKeys.PREDS] + assert len(preds.shape) == 3, preds.shape + labels = torch.argmax(preds, dim=-3) # HxW if self.visualize and not flash._IS_TESTING: if self.labels_map is None: diff --git a/tests/vision/segmentation/test_model.py b/tests/vision/segmentation/test_model.py index 5ccc86d68f..d436ffa982 100644 --- a/tests/vision/segmentation/test_model.py +++ b/tests/vision/segmentation/test_model.py @@ -89,7 +89,7 @@ def test_predict_tensor(): data_pipe = DataPipeline(preprocess=SemanticSegmentationPreprocess(num_classes=1)) out = model.predict(img, data_source="tensors", data_pipeline=data_pipe) assert isinstance(out[0], torch.Tensor) - assert out[0].shape == (196, 196) + assert out[0].shape == (10, 20) def test_predict_numpy(): @@ -98,4 +98,4 @@ def test_predict_numpy(): data_pipe = DataPipeline(preprocess=SemanticSegmentationPreprocess(num_classes=1)) out = model.predict(img, data_source="numpy", data_pipeline=data_pipe) assert isinstance(out[0], torch.Tensor) - assert out[0].shape == (196, 196) + assert out[0].shape == (10, 20) diff --git a/tests/vision/segmentation/test_serialization.py b/tests/vision/segmentation/test_serialization.py index a971c91fbf..872fcc2420 100644 --- a/tests/vision/segmentation/test_serialization.py +++ b/tests/vision/segmentation/test_serialization.py @@ -1,6 +1,7 @@ import pytest import torch +from flash.data.data_source import DefaultDataKeys from flash.vision.segmentation.serialization import SegmentationLabels @@ 
         sample[1, 1, 2] = 1  # add peak in class 2
         sample[3, 0, 1] = 1  # add peak in class 4

-        classes = serial.serialize(sample)
+        classes = serial.serialize({DefaultDataKeys.PREDS: sample})

         assert classes[1, 2] == 1
         assert classes[0, 1] == 3
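
Note for reviewers: the heart of this change is a metadata round-trip. Each segmentation data source records the original image shape under DefaultDataKeys.METADATA; _Preprocessor._extract_metadata (flash/data/batch.py) pops that key from every sample just before collate_fn runs, so the collate function never sees the per-sample torch.Size objects, and re-attaches the collected list to the collated batch. SemanticSegmentationPostprocess then reads it to resize each prediction back to the source resolution, which is why the prediction tests now expect (10, 20) rather than the model-input size (196, 196). Below is a minimal, self-contained sketch of the pop/re-attach pattern that runs without Flash; the METADATA string, the sample dicts, and the extract_metadata helper are illustrative stand-ins, not the library API:

    from typing import Any, Dict, List, Mapping, Optional, Tuple

    import torch
    from torch.utils.data._utils.collate import default_collate

    METADATA = "metadata"  # stand-in for DefaultDataKeys.METADATA


    def extract_metadata(samples: List[Dict[str, Any]]) -> Tuple[List[Dict[str, Any]], Optional[List[Any]]]:
        # Pop the metadata out of each sample before collation: default_collate
        # would otherwise zip the per-sample shapes together element-wise and
        # lose the one-torch.Size-per-sample structure.
        metadata = [s.pop(METADATA, None) if isinstance(s, Mapping) else None for s in samples]
        return samples, metadata if any(m is not None for m in metadata) else None


    # Two samples already resized to the network input size, each remembering
    # its original (C, H, W) shape; the originals may differ per sample.
    samples = [
        {"input": torch.rand(3, 196, 196), METADATA: torch.Size([3, 10, 20])},
        {"input": torch.rand(3, 196, 196), METADATA: torch.Size([3, 15, 30])},
    ]
    samples, metadata = extract_metadata(samples)
    batch = default_collate(samples)  # {"input": tensor of shape (2, 3, 196, 196)}
    if metadata:
        batch[METADATA] = metadata  # re-attached as a plain list, one entry per sample

    print(batch["input"].shape)     # torch.Size([2, 3, 196, 196])
    print(batch[METADATA][0][-2:])  # torch.Size([10, 20]), the resize target for the postprocess

Keeping the metadata as a plain list on the batch is the design choice that makes this work with any collate_fn: the samples handed to collation stay tensors-only, and the un-collated per-sample transform can still look up the matching original shape.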