From 63d1f329e367e448f9a307a9b5e6fa7adf7ac987 Mon Sep 17 00:00:00 2001 From: tchaton Date: Wed, 12 May 2021 12:46:51 +0100 Subject: [PATCH 01/13] cleanup segmentation --- flash/__init__.py | 1 + flash/data/data_pipeline.py | 3 + flash/data/data_source.py | 7 ++ flash/data/splits.py | 5 ++ flash/vision/segmentation/data.py | 87 ++++++++++++++++--- flash/vision/segmentation/serialization.py | 20 ++--- .../finetuning/semantic_segmentation.py | 9 +- .../predict/semantic_segmentation.py | 4 +- tests/data/test_split_dataset.py | 15 ++++ 9 files changed, 114 insertions(+), 37 deletions(-) diff --git a/flash/__init__.py b/flash/__init__.py index 3eef508374..1caafa591f 100644 --- a/flash/__init__.py +++ b/flash/__init__.py @@ -18,6 +18,7 @@ _PACKAGE_ROOT = os.path.dirname(__file__) PROJECT_ROOT = os.path.dirname(_PACKAGE_ROOT) +_IS_TESTING = os.getenv("FLASH_TESTING", "0") == "1" from flash.core.model import Task # noqa: E402 from flash.core.trainer import Trainer # noqa: E402 diff --git a/flash/data/data_pipeline.py b/flash/data/data_pipeline.py index 8a0b739ace..80eb0ecbad 100644 --- a/flash/data/data_pipeline.py +++ b/flash/data/data_pipeline.py @@ -61,6 +61,9 @@ def get_state(self, state_type: Type[ProcessState]) -> Optional[ProcessState]: else: return None + def __repr__(self) -> str: + return f"{self.__class__.__name__}(initialized={self._initialized}, state={self._state})" + class DataPipeline: """ diff --git a/flash/data/data_source.py b/flash/data/data_source.py index 4238dbb514..f7003c45de 100644 --- a/flash/data/data_source.py +++ b/flash/data/data_source.py @@ -44,6 +44,12 @@ class LabelsState(ProcessState): labels: Optional[Sequence[str]] +@dataclass(unsafe_hash=True, frozen=True) +class ImageLabelsMap(ProcessState): + + labels_map: Optional[Dict[int, Tuple[int, int, int]]] + + class DefaultDataSources(LightningEnum): """The ``DefaultDataSources`` enum contains the data source names used by all of the default ``from_*`` methods in :class:`~flash.data.data_module.DataModule`.""" @@ -65,6 +71,7 @@ class DefaultDataKeys(LightningEnum): INPUT = "input" TARGET = "target" + METADATA = "metadata" # TODO: Create a FlashEnum class??? def __hash__(self) -> int: diff --git a/flash/data/splits.py b/flash/data/splits.py index d8f4e2aa7e..054ab116a1 100644 --- a/flash/data/splits.py +++ b/flash/data/splits.py @@ -38,6 +38,11 @@ def __init__(self, dataset: Any, indices: List[int] = [], use_duplicated_indices self.dataset = dataset self.indices = indices + def __getattr__(self, key: str): + if key in ("dataset", "indices", "data"): + return getattr(self, key) + return getattr(self.dataset, key) + def __getitem__(self, index: int) -> Any: return self.dataset[self.indices[index]] diff --git a/flash/vision/segmentation/data.py b/flash/vision/segmentation/data.py index 882f31ef3a..989da6196c 100644 --- a/flash/vision/segmentation/data.py +++ b/flash/vision/segmentation/data.py @@ -12,6 +12,7 @@ # See the License for the specific language governing permissions and # limitations under the License. import os +from dataclasses import dataclass from typing import Any, Callable, Dict, List, Mapping, Optional, Sequence, Tuple, Union import numpy as np @@ -23,14 +24,18 @@ from pytorch_lightning.utilities.exceptions import MisconfigurationException from torchvision.datasets.folder import has_file_allowed_extension, IMG_EXTENSIONS +import flash +from flash.data.auto_dataset import BaseAutoDataset from flash.data.base_viz import BaseVisualization # for viz from flash.data.callback import BaseDataFetcher from flash.data.data_module import DataModule from flash.data.data_source import ( DefaultDataKeys, DefaultDataSources, + ImageLabelsMap, NumpyDataSource, PathsDataSource, + SEQUENCE_DATA_TYPE, TensorDataSource, ) from flash.data.process import Preprocess @@ -44,8 +49,25 @@ plt = None +def create_random_labels_map(num_classes: int) -> Dict[int, Tuple[int, int, int]]: + labels_map: Dict[int, Tuple[int, int, int]] = {} + for i in range(num_classes): + labels_map[i] = torch.randint(0, 255, (3, )) + return labels_map + + class SemanticSegmentationNumpyDataSource(NumpyDataSource): + def __init__(self, num_classes: int): + self.num_classes = num_classes + + def load_data(self, data: Tuple[Sequence[SEQUENCE_DATA_TYPE], Optional[Sequence]], + dataset: Optional[Any]) -> Sequence[Mapping[str, Any]]: + data = super().load_data(data, dataset=dataset) + if self.training: + dataset.num_classes = self.num_classes + return data + def load_sample(self, sample: Dict[str, Any], dataset: Optional[Any] = None) -> Dict[str, Any]: sample[DefaultDataKeys.INPUT] = torch.from_numpy(sample[DefaultDataKeys.INPUT]).float() return sample @@ -53,10 +75,12 @@ def load_sample(self, sample: Dict[str, Any], dataset: Optional[Any] = None) -> class SemanticSegmentationPathsDataSource(PathsDataSource): - def __init__(self): + def __init__(self, num_classes: int): super().__init__(IMG_EXTENSIONS) + self.num_classes = num_classes - def load_data(self, data: Union[Tuple[str, str], Tuple[List[str], List[str]]]) -> Sequence[Mapping[str, Any]]: + def load_data(self, data: Union[Tuple[str, str], Tuple[List[str], List[str]]], + dataset: BaseAutoDataset) -> Sequence[Mapping[str, Any]]: input_data, target_data = data if self.isdir(input_data) and self.isdir(target_data): @@ -93,6 +117,9 @@ def load_data(self, data: Union[Tuple[str, str], Tuple[List[str], List[str]]]) - zip(input_data, target_data), ) + if self.training: + dataset.num_classes = self.num_classes + return [{DefaultDataKeys.INPUT: input, DefaultDataKeys.TARGET: target} for input, target in data] def predict_load_data(self, data: Union[str, List[str]]): @@ -108,7 +135,11 @@ def load_sample(self, sample: Mapping[str, Any]) -> Mapping[str, torch.Tensor]: img_labels: torch.Tensor = torchvision.io.read_image(img_labels_path) # CxHxW img_labels = img_labels[0] # HxW - return {DefaultDataKeys.INPUT: img.float(), DefaultDataKeys.TARGET: img_labels.float()} + return { + DefaultDataKeys.INPUT: img.float(), + DefaultDataKeys.TARGET: img_labels.float(), + DefaultDataKeys.TARGET: img_labels.float() + } def predict_load_sample(self, sample: Mapping[str, Any]) -> Mapping[str, Any]: return {DefaultDataKeys.INPUT: torchvision.io.read_image(sample[DefaultDataKeys.INPUT]).float()} @@ -123,6 +154,8 @@ def __init__( test_transform: Optional[Dict[str, Callable]] = None, predict_transform: Optional[Dict[str, Callable]] = None, image_size: Tuple[int, int] = (196, 196), + num_classes: int = None, + labels_map: Dict[int, Tuple[int, int, int]] = None, ) -> None: """Preprocess pipeline for semantic segmentation tasks. @@ -133,7 +166,15 @@ def __init__( predict_transform: Dictionary with the set of transforms to apply during prediction. image_size: A tuple with the expected output image size. """ + if not num_classes or not isinstance(num_classes, int): + raise MisconfigurationException("`num_classes` should be provided for instantiation.") + + if not labels_map: + raise MisconfigurationException("`labels_map` should be provided for instantiation.") + self.image_size = image_size + self.num_classes = num_classes + self.labels_map = labels_map super().__init__( train_transform=train_transform, @@ -141,17 +182,20 @@ def __init__( test_transform=test_transform, predict_transform=predict_transform, data_sources={ - DefaultDataSources.PATHS: SemanticSegmentationPathsDataSource(), + DefaultDataSources.PATHS: SemanticSegmentationPathsDataSource(num_classes), DefaultDataSources.TENSOR: TensorDataSource(), - DefaultDataSources.NUMPY: SemanticSegmentationNumpyDataSource(), + DefaultDataSources.NUMPY: SemanticSegmentationNumpyDataSource(num_classes), }, default_data_source=DefaultDataSources.PATHS, ) + self.set_state(ImageLabelsMap(labels_map)) + def get_state_dict(self) -> Dict[str, Any]: return { - **self.transforms, - "image_size": self.image_size, + **self.transforms, "image_size": self.image_size, + "num_classes": self.num_classes, + "labels_map": self.labels_map } @classmethod @@ -184,9 +228,6 @@ class SemanticSegmentationData(DataModule): def configure_data_fetcher(*args, **kwargs) -> BaseDataFetcher: return SegmentationMatplotlibVisualization(*args, **kwargs) - def set_labels_map(self, labels_map: Dict[int, Tuple[int, int, int]]): - self.data_fetcher.labels_map = labels_map - def set_block_viz_window(self, value: bool) -> None: """Setter method to switch on/off matplotlib to pop up windows.""" self.data_fetcher.block_viz_window = value @@ -210,7 +251,9 @@ def from_folders( val_split: Optional[float] = None, batch_size: int = 4, num_workers: Optional[int] = None, - **preprocess_kwargs: Any, + num_classes: Optional[int] = None, + labels_map: Dict[int, Tuple[int, int, int]] = None, + **preprocess_kwargs, ) -> 'DataModule': """Creates a :class:`~flash.vision.segmentation.data.SemanticSegmentationData` object from the given data folders and corresponding target folders. @@ -242,6 +285,8 @@ def from_folders( val_split: The ``val_split`` argument to pass to the :class:`~flash.data.data_module.DataModule`. batch_size: The ``batch_size`` argument to pass to the :class:`~flash.data.data_module.DataModule`. num_workers: The ``num_workers`` argument to pass to the :class:`~flash.data.data_module.DataModule`. + num_classes: Number of classes within the segmentation mask. + labels_map: Mapping between a class_id and its corresponding color. preprocess_kwargs: Additional keyword arguments to use when constructing the preprocess. Will only be used if ``preprocess = None``. @@ -255,6 +300,17 @@ def from_folders( train_target_folder="train_masks", ) """ + + if not num_classes or not isinstance(num_classes, int): + raise MisconfigurationException("`num_classes` should be provided during instantiation.") + + labels_map = labels_map or create_random_labels_map(num_classes) + + data_fetcher = data_fetcher or cls.configure_data_fetcher(labels_map) + + if flash._IS_TESTING: + data_fetcher.block_viz_window = True + return cls.from_data_source( DefaultDataSources.PATHS, (train_folder, train_target_folder), @@ -270,6 +326,8 @@ def from_folders( val_split=val_split, batch_size=batch_size, num_workers=num_workers, + num_classes=num_classes, + labels_map=labels_map, **preprocess_kwargs, ) @@ -278,11 +336,12 @@ class SegmentationMatplotlibVisualization(BaseVisualization): """Process and show the image batch and its associated label using matplotlib. """ - def __init__(self): - super().__init__(self) + def __init__(self, labels_map: Dict[int, Tuple[int, int, int]]): + super().__init__() + self.max_cols: int = 4 # maximum number of columns we accept self.block_viz_window: bool = True # parameter to allow user to block visualisation windows - self.labels_map: Dict[int, Tuple[int, int, int]] = {} + self.labels_map: Dict[int, Tuple[int, int, int]] = labels_map @staticmethod def _to_numpy(img: Union[torch.Tensor, Image.Image]) -> np.ndarray: diff --git a/flash/vision/segmentation/serialization.py b/flash/vision/segmentation/serialization.py index 50ba5be9a9..47e237b6a2 100644 --- a/flash/vision/segmentation/serialization.py +++ b/flash/vision/segmentation/serialization.py @@ -16,6 +16,8 @@ import torch +import flash +from flash.data.data_source import ImageLabelsMap from flash.data.process import Serializer from flash.utils.imports import _KORNIA_AVAILABLE, _MATPLOTLIB_AVAILABLE @@ -58,24 +60,14 @@ def labels_to_image(img_labels: torch.Tensor, labels_map: Dict[int, Tuple[int, i out[i].masked_fill_(mask, label_val[i]) return out - @staticmethod - def create_random_labels_map(num_classes: int) -> Dict[int, Tuple[int, int, int]]: - labels_map: Dict[int, Tuple[int, int, int]] = {} - for i in range(num_classes): - labels_map[i] = torch.randint(0, 255, (3, )) - return labels_map - def serialize(self, sample: torch.Tensor) -> torch.Tensor: assert len(sample.shape) == 3, sample.shape labels = torch.argmax(sample, dim=-3) # HxW - if self.visualize and os.getenv("FLASH_TESTING", "0") == "0": + + if self.visualize and not flash._IS_TESTING: if self.labels_map is None: - # create random colors map - num_classes = sample.shape[-3] - labels_map = self.create_random_labels_map(num_classes) - else: - labels_map = self.labels_map - labels_vis = self.labels_to_image(labels, labels_map) + self.labels_map = self.get_state(ImageLabelsMap).labels_map + labels_vis = self.labels_to_image(labels, self.labels_map) labels_vis = K.utils.tensor_to_image(labels_vis) plt.imshow(labels_vis) plt.show() diff --git a/flash_examples/finetuning/semantic_segmentation.py b/flash_examples/finetuning/semantic_segmentation.py index 3676353ec8..2d5bfaaee2 100644 --- a/flash_examples/finetuning/semantic_segmentation.py +++ b/flash_examples/finetuning/semantic_segmentation.py @@ -31,17 +31,17 @@ batch_size=4, val_split=0.3, image_size=(200, 200), # (600, 800) + num_classes=21, ) # 2.2 Visualise the samples -labels_map = SegmentationLabels.create_random_labels_map(num_classes=21) -datamodule.set_labels_map(labels_map) datamodule.show_train_batch(["load_sample", "post_tensor_transform"]) # 3. Build the model model = SemanticSegmentation( backbone="torchvision/fcn_resnet50", - num_classes=21, + num_classes=datamodule.num_classes, + serializer=SegmentationLabels(visualize=True) ) # 4. Create the trainer. @@ -53,9 +53,6 @@ # 5. Train the model trainer.finetune(model, datamodule=datamodule, strategy="freeze") -# 6. Predict what's on a few images! -model.serializer = SegmentationLabels(labels_map, visualize=True) - predictions = model.predict([ "data/CameraRGB/F61-1.png", "data/CameraRGB/F62-1.png", diff --git a/flash_examples/predict/semantic_segmentation.py b/flash_examples/predict/semantic_segmentation.py index f507f2a6a6..9209923be7 100644 --- a/flash_examples/predict/semantic_segmentation.py +++ b/flash_examples/predict/semantic_segmentation.py @@ -24,9 +24,7 @@ ) # 2. Load the model from a checkpoint -model = SemanticSegmentation.load_from_checkpoint( - "https://flash-weights.s3.amazonaws.com/semantic_segmentation_model.pt" -) +model = SemanticSegmentation.load_from_checkpoint("semantic_segmentation_model.pt") model.serializer = SegmentationLabels(visualize=True) # 3. Predict what's on a few images and visualize! diff --git a/tests/data/test_split_dataset.py b/tests/data/test_split_dataset.py index e92a44e1b6..382d4eb05e 100644 --- a/tests/data/test_split_dataset.py +++ b/tests/data/test_split_dataset.py @@ -37,3 +37,18 @@ def test_split_dataset(tmpdir): with pytest.raises(MisconfigurationException, match="[0, 99]"): SplitDataset(list(range(50)) + list(range(50)), indices=[-1], use_duplicated_indices=True) + + class Dataset: + + def __init__(self): + self.data = [0, 1, 2] + self.name = "something" + + def __getitem__(self, index): + return self.data[index] + + def __len__(self): + return len(self.data) + + split_dataset = SplitDataset(Dataset(), indices=[0]) + assert split_dataset.name == "something" From 9ef1d2f791712c6481ed4f6e9e2c7cc9606d587e Mon Sep 17 00:00:00 2001 From: tchaton Date: Wed, 12 May 2021 12:57:00 +0100 Subject: [PATCH 02/13] update --- flash/vision/segmentation/data.py | 25 +++++++------------------ 1 file changed, 7 insertions(+), 18 deletions(-) diff --git a/flash/vision/segmentation/data.py b/flash/vision/segmentation/data.py index 989da6196c..dc94d92d57 100644 --- a/flash/vision/segmentation/data.py +++ b/flash/vision/segmentation/data.py @@ -58,16 +58,6 @@ def create_random_labels_map(num_classes: int) -> Dict[int, Tuple[int, int, int] class SemanticSegmentationNumpyDataSource(NumpyDataSource): - def __init__(self, num_classes: int): - self.num_classes = num_classes - - def load_data(self, data: Tuple[Sequence[SEQUENCE_DATA_TYPE], Optional[Sequence]], - dataset: Optional[Any]) -> Sequence[Mapping[str, Any]]: - data = super().load_data(data, dataset=dataset) - if self.training: - dataset.num_classes = self.num_classes - return data - def load_sample(self, sample: Dict[str, Any], dataset: Optional[Any] = None) -> Dict[str, Any]: sample[DefaultDataKeys.INPUT] = torch.from_numpy(sample[DefaultDataKeys.INPUT]).float() return sample @@ -75,9 +65,8 @@ def load_sample(self, sample: Dict[str, Any], dataset: Optional[Any] = None) -> class SemanticSegmentationPathsDataSource(PathsDataSource): - def __init__(self, num_classes: int): + def __init__(self): super().__init__(IMG_EXTENSIONS) - self.num_classes = num_classes def load_data(self, data: Union[Tuple[str, str], Tuple[List[str], List[str]]], dataset: BaseAutoDataset) -> Sequence[Mapping[str, Any]]: @@ -117,9 +106,6 @@ def load_data(self, data: Union[Tuple[str, str], Tuple[List[str], List[str]]], zip(input_data, target_data), ) - if self.training: - dataset.num_classes = self.num_classes - return [{DefaultDataKeys.INPUT: input, DefaultDataKeys.TARGET: target} for input, target in data] def predict_load_data(self, data: Union[str, List[str]]): @@ -182,9 +168,9 @@ def __init__( test_transform=test_transform, predict_transform=predict_transform, data_sources={ - DefaultDataSources.PATHS: SemanticSegmentationPathsDataSource(num_classes), + DefaultDataSources.PATHS: SemanticSegmentationPathsDataSource(), DefaultDataSources.TENSOR: TensorDataSource(), - DefaultDataSources.NUMPY: SemanticSegmentationNumpyDataSource(num_classes), + DefaultDataSources.NUMPY: SemanticSegmentationNumpyDataSource(), }, default_data_source=DefaultDataSources.PATHS, ) @@ -311,7 +297,7 @@ def from_folders( if flash._IS_TESTING: data_fetcher.block_viz_window = True - return cls.from_data_source( + dm = cls.from_data_source( DefaultDataSources.PATHS, (train_folder, train_target_folder), (val_folder, val_target_folder), @@ -331,6 +317,9 @@ def from_folders( **preprocess_kwargs, ) + dm.train_dataset.num_classes = num_classes + return dm + class SegmentationMatplotlibVisualization(BaseVisualization): """Process and show the image batch and its associated label using matplotlib. From b3870f3c1bf6d91967f07c1451222956fd98ba56 Mon Sep 17 00:00:00 2001 From: tchaton Date: Wed, 12 May 2021 13:05:03 +0100 Subject: [PATCH 03/13] update --- flash/data/splits.py | 10 +++++++++- tests/data/test_split_dataset.py | 5 +++++ 2 files changed, 14 insertions(+), 1 deletion(-) diff --git a/flash/data/splits.py b/flash/data/splits.py index 054ab116a1..8c09ad2290 100644 --- a/flash/data/splits.py +++ b/flash/data/splits.py @@ -23,6 +23,8 @@ class SplitDataset(Dataset): """ + _INTERNAL_KEYS = ("dataset", "indices", "data") + def __init__(self, dataset: Any, indices: List[int] = [], use_duplicated_indices: bool = False) -> None: if not isinstance(indices, list): raise MisconfigurationException("indices should be a list") @@ -39,10 +41,16 @@ def __init__(self, dataset: Any, indices: List[int] = [], use_duplicated_indices self.indices = indices def __getattr__(self, key: str): - if key in ("dataset", "indices", "data"): + if key in self._INTERNAL_KEYS: return getattr(self, key) return getattr(self.dataset, key) + def __setattr__(self, name: str, value: Any) -> None: + if name in self._INTERNAL_KEYS: + self.__dict__[name] = value + else: + setattr(self.dataset, name, value) + def __getitem__(self, index: int) -> Any: return self.dataset[self.indices[index]] diff --git a/tests/data/test_split_dataset.py b/tests/data/test_split_dataset.py index 382d4eb05e..cc087cd167 100644 --- a/tests/data/test_split_dataset.py +++ b/tests/data/test_split_dataset.py @@ -52,3 +52,8 @@ def __len__(self): split_dataset = SplitDataset(Dataset(), indices=[0]) assert split_dataset.name == "something" + + assert split_dataset._INTERNAL_KEYS == ("dataset", "indices", "data") + + split_dataset.is_passed_down = True + assert split_dataset.dataset.is_passed_down From 3c58b0f057fb6960e1b0afa6b2e38d2cad146bfc Mon Sep 17 00:00:00 2001 From: tchaton Date: Wed, 12 May 2021 13:39:53 +0100 Subject: [PATCH 04/13] update --- flash/vision/segmentation/data.py | 6 ++---- tests/vision/segmentation/test_data.py | 2 +- tests/vision/segmentation/test_model.py | 4 ++-- 3 files changed, 5 insertions(+), 7 deletions(-) diff --git a/flash/vision/segmentation/data.py b/flash/vision/segmentation/data.py index dc94d92d57..1f3689a3a2 100644 --- a/flash/vision/segmentation/data.py +++ b/flash/vision/segmentation/data.py @@ -155,9 +155,6 @@ def __init__( if not num_classes or not isinstance(num_classes, int): raise MisconfigurationException("`num_classes` should be provided for instantiation.") - if not labels_map: - raise MisconfigurationException("`labels_map` should be provided for instantiation.") - self.image_size = image_size self.num_classes = num_classes self.labels_map = labels_map @@ -175,7 +172,8 @@ def __init__( default_data_source=DefaultDataSources.PATHS, ) - self.set_state(ImageLabelsMap(labels_map)) + if labels_map: + self.set_state(ImageLabelsMap(labels_map)) def get_state_dict(self) -> Dict[str, Any]: return { diff --git a/tests/vision/segmentation/test_data.py b/tests/vision/segmentation/test_data.py index bd51f09d21..a828bc13a6 100644 --- a/tests/vision/segmentation/test_data.py +++ b/tests/vision/segmentation/test_data.py @@ -44,7 +44,7 @@ class TestSemanticSegmentationPreprocess: @pytest.mark.xfail(reaspn="parameters are marked as optional but it returns Misconficg error.") def test_smoke(self): - prep = SemanticSegmentationPreprocess() + prep = SemanticSegmentationPreprocess(num_classes=1) assert prep is not None diff --git a/tests/vision/segmentation/test_model.py b/tests/vision/segmentation/test_model.py index d3f30129ff..dc927ee257 100644 --- a/tests/vision/segmentation/test_model.py +++ b/tests/vision/segmentation/test_model.py @@ -86,7 +86,7 @@ def test_unfreeze(): def test_predict_tensor(): img = torch.rand(1, 3, 10, 20) model = SemanticSegmentation(2) - data_pipe = DataPipeline(preprocess=SemanticSegmentationPreprocess()) + data_pipe = DataPipeline(preprocess=SemanticSegmentationPreprocess(num_classes=1)) out = model.predict(img, data_source="tensor", data_pipeline=data_pipe) assert isinstance(out[0], torch.Tensor) assert out[0].shape == (196, 196) @@ -95,7 +95,7 @@ def test_predict_tensor(): def test_predict_numpy(): img = np.ones((1, 3, 10, 20)) model = SemanticSegmentation(2) - data_pipe = DataPipeline(preprocess=SemanticSegmentationPreprocess()) + data_pipe = DataPipeline(preprocess=SemanticSegmentationPreprocess(num_classes=1)) out = model.predict(img, data_source="numpy", data_pipeline=data_pipe) assert isinstance(out[0], torch.Tensor) assert out[0].shape == (196, 196) From 7f933037c6a8a664e6406dd78d37db2c83379a4b Mon Sep 17 00:00:00 2001 From: tchaton Date: Wed, 12 May 2021 16:29:49 +0100 Subject: [PATCH 05/13] update --- flash/vision/segmentation/data.py | 95 +++++++++++++++------- flash/vision/segmentation/serialization.py | 7 ++ tests/vision/segmentation/test_data.py | 6 +- 3 files changed, 77 insertions(+), 31 deletions(-) diff --git a/flash/vision/segmentation/data.py b/flash/vision/segmentation/data.py index 1f3689a3a2..dd770c65ec 100644 --- a/flash/vision/segmentation/data.py +++ b/flash/vision/segmentation/data.py @@ -49,13 +49,6 @@ plt = None -def create_random_labels_map(num_classes: int) -> Dict[int, Tuple[int, int, int]]: - labels_map: Dict[int, Tuple[int, int, int]] = {} - for i in range(num_classes): - labels_map[i] = torch.randint(0, 255, (3, )) - return labels_map - - class SemanticSegmentationNumpyDataSource(NumpyDataSource): def load_sample(self, sample: Dict[str, Any], dataset: Optional[Any] = None) -> Dict[str, Any]: @@ -111,7 +104,7 @@ def load_data(self, data: Union[Tuple[str, str], Tuple[List[str], List[str]]], def predict_load_data(self, data: Union[str, List[str]]): return super().predict_load_data(data) - def load_sample(self, sample: Mapping[str, Any]) -> Mapping[str, torch.Tensor]: + def load_sample(self, sample: Mapping[str, Any]) -> Mapping[str, Union[torch.Tensor, torch.Size]]: # unpack data paths img_path = sample[DefaultDataKeys.INPUT] img_labels_path = sample[DefaultDataKeys.TARGET] @@ -124,7 +117,7 @@ def load_sample(self, sample: Mapping[str, Any]) -> Mapping[str, torch.Tensor]: return { DefaultDataKeys.INPUT: img.float(), DefaultDataKeys.TARGET: img_labels.float(), - DefaultDataKeys.TARGET: img_labels.float() + DefaultDataKeys.METADATA: img.shape, } def predict_load_sample(self, sample: Mapping[str, Any]) -> Mapping[str, Any]: @@ -152,12 +145,10 @@ def __init__( predict_transform: Dictionary with the set of transforms to apply during prediction. image_size: A tuple with the expected output image size. """ - if not num_classes or not isinstance(num_classes, int): - raise MisconfigurationException("`num_classes` should be provided for instantiation.") - self.image_size = image_size self.num_classes = num_classes - self.labels_map = labels_map + if num_classes: + labels_map = labels_map or SegmentationLabels.create_random_labels_map(num_classes) super().__init__( train_transform=train_transform, @@ -175,6 +166,8 @@ def __init__( if labels_map: self.set_state(ImageLabelsMap(labels_map)) + self.labels_map = labels_map + def get_state_dict(self) -> Dict[str, Any]: return { **self.transforms, "image_size": self.image_size, @@ -209,13 +202,69 @@ class SemanticSegmentationData(DataModule): preprocess_cls = SemanticSegmentationPreprocess @staticmethod - def configure_data_fetcher(*args, **kwargs) -> BaseDataFetcher: - return SegmentationMatplotlibVisualization(*args, **kwargs) + def configure_data_fetcher( + labels_map: Optional[Dict[int, Tuple[int, int, int]]] = None + ) -> 'SegmentationMatplotlibVisualization': + return SegmentationMatplotlibVisualization(labels_map=labels_map) def set_block_viz_window(self, value: bool) -> None: """Setter method to switch on/off matplotlib to pop up windows.""" self.data_fetcher.block_viz_window = value + @classmethod + def from_data_source( + cls, + data_source: str, + train_data: Any = None, + val_data: Any = None, + test_data: Any = None, + predict_data: Any = None, + train_transform: Optional[Dict[str, Callable]] = None, + val_transform: Optional[Dict[str, Callable]] = None, + test_transform: Optional[Dict[str, Callable]] = None, + predict_transform: Optional[Dict[str, Callable]] = None, + data_fetcher: Optional[BaseDataFetcher] = None, + preprocess: Optional[Preprocess] = None, + val_split: Optional[float] = None, + batch_size: int = 4, + num_workers: Optional[int] = None, + **preprocess_kwargs: Any, + ) -> 'DataModule': + + if 'num_classes' not in preprocess_kwargs: + raise MisconfigurationException("`num_classes` should be provided during instantiation.") + + num_classes = preprocess_kwargs["num_classes"] + + labels_map = getattr(preprocess_kwargs, "labels_map", + None) or SegmentationLabels.create_random_labels_map(num_classes) + + data_fetcher = data_fetcher or cls.configure_data_fetcher(labels_map) + + if flash._IS_TESTING: + data_fetcher.block_viz_window = True + + dm = super(SemanticSegmentationData, cls).from_data_source( + data_source=data_source, + train_data=train_data, + val_data=val_data, + test_data=test_data, + predict_data=predict_data, + train_transform=train_transform, + val_transform=val_transform, + test_transform=test_transform, + predict_transform=predict_transform, + data_fetcher=data_fetcher, + preprocess=preprocess, + val_split=val_split, + batch_size=batch_size, + num_workers=num_workers, + **preprocess_kwargs + ) + + dm.train_dataset.num_classes = num_classes + return dm + @classmethod def from_folders( cls, @@ -284,18 +333,7 @@ def from_folders( train_target_folder="train_masks", ) """ - - if not num_classes or not isinstance(num_classes, int): - raise MisconfigurationException("`num_classes` should be provided during instantiation.") - - labels_map = labels_map or create_random_labels_map(num_classes) - - data_fetcher = data_fetcher or cls.configure_data_fetcher(labels_map) - - if flash._IS_TESTING: - data_fetcher.block_viz_window = True - - dm = cls.from_data_source( + return cls.from_data_source( DefaultDataSources.PATHS, (train_folder, train_target_folder), (val_folder, val_target_folder), @@ -315,9 +353,6 @@ def from_folders( **preprocess_kwargs, ) - dm.train_dataset.num_classes = num_classes - return dm - class SegmentationMatplotlibVisualization(BaseVisualization): """Process and show the image batch and its associated label using matplotlib. diff --git a/flash/vision/segmentation/serialization.py b/flash/vision/segmentation/serialization.py index 47e237b6a2..5a8cb40f69 100644 --- a/flash/vision/segmentation/serialization.py +++ b/flash/vision/segmentation/serialization.py @@ -60,6 +60,13 @@ def labels_to_image(img_labels: torch.Tensor, labels_map: Dict[int, Tuple[int, i out[i].masked_fill_(mask, label_val[i]) return out + @staticmethod + def create_random_labels_map(num_classes: int) -> Dict[int, Tuple[int, int, int]]: + labels_map: Dict[int, Tuple[int, int, int]] = {} + for i in range(num_classes): + labels_map[i] = torch.randint(0, 255, (3, )) + return labels_map + def serialize(self, sample: torch.Tensor) -> torch.Tensor: assert len(sample.shape) == 3, sample.shape labels = torch.argmax(sample, dim=-3) # HxW diff --git a/tests/vision/segmentation/test_data.py b/tests/vision/segmentation/test_data.py index a828bc13a6..4d68bbc1d1 100644 --- a/tests/vision/segmentation/test_data.py +++ b/tests/vision/segmentation/test_data.py @@ -89,6 +89,7 @@ def test_from_folders(self, tmpdir): test_target_folder=str(tmp_dir / "targets"), batch_size=2, num_workers=0, + num_classes=num_classes, ) assert dm is not None assert dm.train_dataloader() is not None @@ -143,6 +144,7 @@ def test_from_folders_warning(self, tmpdir): train_target_folder=str(tmp_dir / "targets"), batch_size=1, num_workers=0, + num_classes=num_classes, ) assert dm is not None assert dm.train_dataloader() is not None @@ -185,6 +187,7 @@ def test_from_files(self, tmpdir): test_targets=targets, batch_size=2, num_workers=0, + num_classes=num_classes ) assert dm is not None assert dm.train_dataloader() is not None @@ -238,6 +241,7 @@ def test_from_files_warning(self, tmpdir): train_targets=targets + [str(tmp_dir / "labels_img4.png")], batch_size=2, num_workers=0, + num_classes=num_classes ) def test_map_labels(self, tmpdir): @@ -275,6 +279,7 @@ def test_map_labels(self, tmpdir): val_targets=targets, batch_size=2, num_workers=0, + num_classes=num_classes ) assert dm is not None assert dm.train_dataloader() is not None @@ -284,7 +289,6 @@ def test_map_labels(self, tmpdir): dm.set_block_viz_window(False) assert dm.data_fetcher.block_viz_window is False - dm.set_labels_map(labels_map) dm.show_train_batch("load_sample") dm.show_train_batch("to_tensor_transform") From a810ed8ece8622728b3ab3522cca3dca3cb2eb3e Mon Sep 17 00:00:00 2001 From: tchaton Date: Wed, 12 May 2021 17:03:35 +0100 Subject: [PATCH 06/13] update --- flash/core/model.py | 3 ++- flash/data/batch.py | 11 ++++++++++- flash/data/data_source.py | 1 + flash/vision/segmentation/data.py | 6 +++++- flash/vision/segmentation/model.py | 20 +++++++++++++++++--- 5 files changed, 35 insertions(+), 6 deletions(-) diff --git a/flash/core/model.py b/flash/core/model.py index b4d4a8b709..667ceb231f 100644 --- a/flash/core/model.py +++ b/flash/core/model.py @@ -30,7 +30,7 @@ from flash.core.schedulers import _SCHEDULERS_REGISTRY from flash.core.utils import get_callable_dict from flash.data.data_pipeline import DataPipeline, DataPipelineState -from flash.data.data_source import DataSource, DefaultDataSources +from flash.data.data_source import DataSource, DefaultDataKeys, DefaultDataSources from flash.data.process import Postprocess, Preprocess, Serializer, SerializerMapping @@ -191,6 +191,7 @@ def predict_step(self, batch: Any, batch_idx: int, dataloader_idx: int = 0) -> A elif isinstance(batch, list): # Todo: Understand why stack is needed batch = torch.stack(batch) + self(batch) return self(batch) def configure_optimizers(self) -> Union[Optimizer, Tuple[List[Optimizer], List[_LRScheduler]]]: diff --git a/flash/data/batch.py b/flash/data/batch.py index f08be37d02..3b0d9653ff 100644 --- a/flash/data/batch.py +++ b/flash/data/batch.py @@ -11,7 +11,7 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. -from typing import Any, Callable, Mapping, Optional, Sequence, TYPE_CHECKING, Union +from typing import Any, Callable, Dict, List, Mapping, Optional, Sequence, Tuple, TYPE_CHECKING, Union import torch from pytorch_lightning.trainer.states import RunningStage @@ -19,6 +19,7 @@ from torch import Tensor from flash.data.callback import ControlFlow +from flash.data.data_source import DefaultDataKeys from flash.data.utils import _contains_any_tensor, convert_to_modules, CurrentFuncContext, CurrentRunningStageContext if TYPE_CHECKING: @@ -137,6 +138,11 @@ def __init__( self._collate_context = CurrentFuncContext("collate", preprocess) self._per_batch_transform_context = CurrentFuncContext(f"per_batch_transform{extension}", preprocess) + def _extract_metadata(self, + samples: List[Dict[str, Any]]) -> Tuple[List[Dict[str, Any]], Optional[List[Dict[str, Any]]]]: + metadata = [s.pop(DefaultDataKeys.METADATA, None) for s in samples] + return samples, metadata if any(m is not None for m in metadata) else None + def forward(self, samples: Sequence[Any]) -> Any: # we create a new dict to prevent from potential memory leaks # assuming that the dictionary samples are stored in between and @@ -158,7 +164,10 @@ def forward(self, samples: Sequence[Any]) -> Any: samples = type(_samples)(_samples) with self._collate_context: + samples, metada = self._extract_metadata(samples) samples = self.collate_fn(samples) + if metada: + samples[DefaultDataKeys.METADATA] = metada self.callback.on_collate(samples, self.stage) with self._per_batch_transform_context: diff --git a/flash/data/data_source.py b/flash/data/data_source.py index 4ec980a9f9..6d5dfebfb1 100644 --- a/flash/data/data_source.py +++ b/flash/data/data_source.py @@ -70,6 +70,7 @@ class DefaultDataKeys(LightningEnum): targets.""" INPUT = "input" + PREDS = "preds" TARGET = "target" METADATA = "metadata" diff --git a/flash/vision/segmentation/data.py b/flash/vision/segmentation/data.py index dd770c65ec..83a5147207 100644 --- a/flash/vision/segmentation/data.py +++ b/flash/vision/segmentation/data.py @@ -121,7 +121,11 @@ def load_sample(self, sample: Mapping[str, Any]) -> Mapping[str, Union[torch.Ten } def predict_load_sample(self, sample: Mapping[str, Any]) -> Mapping[str, Any]: - return {DefaultDataKeys.INPUT: torchvision.io.read_image(sample[DefaultDataKeys.INPUT]).float()} + img = torchvision.io.read_image(sample[DefaultDataKeys.INPUT]).float() + return { + DefaultDataKeys.INPUT: img, + DefaultDataKeys.METADATA: img.shape, + } class SemanticSegmentationPreprocess(Preprocess): diff --git a/flash/vision/segmentation/model.py b/flash/vision/segmentation/model.py index e543b341ed..b62c71f248 100644 --- a/flash/vision/segmentation/model.py +++ b/flash/vision/segmentation/model.py @@ -21,11 +21,19 @@ from flash.core.classification import ClassificationTask from flash.core.registry import FlashRegistry from flash.data.data_source import DefaultDataKeys -from flash.data.process import Serializer +from flash.data.process import Postprocess, Serializer from flash.vision.segmentation.backbones import SEMANTIC_SEGMENTATION_BACKBONES from flash.vision.segmentation.serialization import SegmentationLabels +class SemanticSegmentationPostprocess(Postprocess): + + def per_batch_transform(self, batch: Any) -> Any: + import pdb + pdb.set_trace() + return super().per_batch_transform(batch) + + class SemanticSegmentation(ClassificationTask): """Task that performs semantic segmentation on images. @@ -53,6 +61,8 @@ class SemanticSegmentation(ClassificationTask): serializer: The :class:`~flash.data.process.Serializer` to use when serializing prediction outputs. """ + postprocess_cls = SemanticSegmentationPostprocess + backbones: FlashRegistry = SEMANTIC_SEGMENTATION_BACKBONES def __init__( @@ -67,6 +77,7 @@ def __init__( learning_rate: float = 1e-3, multi_label: bool = False, serializer: Optional[Union[Serializer, Mapping[str, Serializer]]] = None, + postprocess: Optional[Postprocess] = None, ) -> None: if metrics is None: @@ -86,6 +97,7 @@ def __init__( metrics=metrics, learning_rate=learning_rate, serializer=serializer or SegmentationLabels(), + postprocess=postprocess or self.postprocess_cls() ) self.save_hyperparameters() @@ -109,8 +121,10 @@ def test_step(self, batch: Any, batch_idx: int) -> Any: return super().test_step(batch, batch_idx) def predict_step(self, batch: Any, batch_idx: int, dataloader_idx: int = 0) -> Any: - batch = (batch[DefaultDataKeys.INPUT]) - return super().predict_step(batch, batch_idx, dataloader_idx=dataloader_idx) + batch_input = (batch[DefaultDataKeys.INPUT]) + preds = super().predict_step(batch_input, batch_idx, dataloader_idx=dataloader_idx) + batch[DefaultDataKeys.PREDS] = preds + return batch def forward(self, x) -> torch.Tensor: # infer the image to the model From 808b3c3bd6eb74429d42c57c03df1e227a6938db Mon Sep 17 00:00:00 2001 From: tchaton Date: Wed, 12 May 2021 17:26:36 +0100 Subject: [PATCH 07/13] update --- flash_examples/predict/semantic_segmentation.py | 4 +++- 1 file changed, 3 insertions(+), 1 deletion(-) diff --git a/flash_examples/predict/semantic_segmentation.py b/flash_examples/predict/semantic_segmentation.py index 9209923be7..f507f2a6a6 100644 --- a/flash_examples/predict/semantic_segmentation.py +++ b/flash_examples/predict/semantic_segmentation.py @@ -24,7 +24,9 @@ ) # 2. Load the model from a checkpoint -model = SemanticSegmentation.load_from_checkpoint("semantic_segmentation_model.pt") +model = SemanticSegmentation.load_from_checkpoint( + "https://flash-weights.s3.amazonaws.com/semantic_segmentation_model.pt" +) model.serializer = SegmentationLabels(visualize=True) # 3. Predict what's on a few images and visualize! From a669bac2e4e1ad69fc14bbfc10bc69c823df0930 Mon Sep 17 00:00:00 2001 From: tchaton Date: Wed, 12 May 2021 18:57:57 +0100 Subject: [PATCH 08/13] update --- flash/vision/segmentation/model.py | 13 +++++++++---- flash/vision/segmentation/serialization.py | 9 +++++---- 2 files changed, 14 insertions(+), 8 deletions(-) diff --git a/flash/vision/segmentation/model.py b/flash/vision/segmentation/model.py index b62c71f248..7d99949e49 100644 --- a/flash/vision/segmentation/model.py +++ b/flash/vision/segmentation/model.py @@ -22,16 +22,21 @@ from flash.core.registry import FlashRegistry from flash.data.data_source import DefaultDataKeys from flash.data.process import Postprocess, Serializer +from flash.utils.imports import _KORNIA_AVAILABLE from flash.vision.segmentation.backbones import SEMANTIC_SEGMENTATION_BACKBONES from flash.vision.segmentation.serialization import SegmentationLabels +if _KORNIA_AVAILABLE: + import kornia as K + class SemanticSegmentationPostprocess(Postprocess): - def per_batch_transform(self, batch: Any) -> Any: - import pdb - pdb.set_trace() - return super().per_batch_transform(batch) + def per_sample_transform(self, sample: Any) -> Any: + resize = K.geometry.Resize(sample[DefaultDataKeys.METADATA][-2:], interpolation='bilinear') + sample[DefaultDataKeys.PREDS] = resize(torch.stack(sample[DefaultDataKeys.PREDS])) + sample[DefaultDataKeys.INPUT] = resize(torch.stack(sample[DefaultDataKeys.INPUT])) + return super().per_sample_transform(sample) class SemanticSegmentation(ClassificationTask): diff --git a/flash/vision/segmentation/serialization.py b/flash/vision/segmentation/serialization.py index 5a8cb40f69..6a63a0bc7f 100644 --- a/flash/vision/segmentation/serialization.py +++ b/flash/vision/segmentation/serialization.py @@ -17,7 +17,7 @@ import torch import flash -from flash.data.data_source import ImageLabelsMap +from flash.data.data_source import DefaultDataKeys, ImageLabelsMap from flash.data.process import Serializer from flash.utils.imports import _KORNIA_AVAILABLE, _MATPLOTLIB_AVAILABLE @@ -67,9 +67,10 @@ def create_random_labels_map(num_classes: int) -> Dict[int, Tuple[int, int, int] labels_map[i] = torch.randint(0, 255, (3, )) return labels_map - def serialize(self, sample: torch.Tensor) -> torch.Tensor: - assert len(sample.shape) == 3, sample.shape - labels = torch.argmax(sample, dim=-3) # HxW + def serialize(self, sample: Dict[str, torch.Tensor]) -> torch.Tensor: + preds = sample[DefaultDataKeys.PREDS] + assert len(preds.shape) == 3, preds.shape + labels = torch.argmax(preds, dim=-3) # HxW if self.visualize and not flash._IS_TESTING: if self.labels_map is None: From 302014bbe8b4cba1a038632d865d408188fae754 Mon Sep 17 00:00:00 2001 From: tchaton Date: Wed, 12 May 2021 19:04:48 +0100 Subject: [PATCH 09/13] update --- flash/core/model.py | 1 - flash/data/batch.py | 6 ++++-- 2 files changed, 4 insertions(+), 3 deletions(-) diff --git a/flash/core/model.py b/flash/core/model.py index 667ceb231f..e01170fbab 100644 --- a/flash/core/model.py +++ b/flash/core/model.py @@ -191,7 +191,6 @@ def predict_step(self, batch: Any, batch_idx: int, dataloader_idx: int = 0) -> A elif isinstance(batch, list): # Todo: Understand why stack is needed batch = torch.stack(batch) - self(batch) return self(batch) def configure_optimizers(self) -> Union[Optimizer, Tuple[List[Optimizer], List[_LRScheduler]]]: diff --git a/flash/data/batch.py b/flash/data/batch.py index 580adeaaa1..c0df4deb64 100644 --- a/flash/data/batch.py +++ b/flash/data/batch.py @@ -138,8 +138,10 @@ def __init__( self._collate_context = CurrentFuncContext("collate", preprocess) self._per_batch_transform_context = CurrentFuncContext(f"per_batch_transform{extension}", preprocess) - def _extract_metadata(self, - samples: List[Dict[str, Any]]) -> Tuple[List[Dict[str, Any]], Optional[List[Dict[str, Any]]]]: + def _extract_metadata( + self, + samples: List[Dict[str, Any]], + ) -> Tuple[List[Dict[str, Any]], Optional[List[Dict[str, Any]]]]: metadata = [s.pop(DefaultDataKeys.METADATA, None) for s in samples] return samples, metadata if any(m is not None for m in metadata) else None From 4e5d044c7809b62340ad7321088384fc2957c41c Mon Sep 17 00:00:00 2001 From: tchaton Date: Wed, 12 May 2021 19:12:43 +0100 Subject: [PATCH 10/13] update --- flash/data/batch.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/flash/data/batch.py b/flash/data/batch.py index c0df4deb64..9199e6c45f 100644 --- a/flash/data/batch.py +++ b/flash/data/batch.py @@ -142,7 +142,7 @@ def _extract_metadata( self, samples: List[Dict[str, Any]], ) -> Tuple[List[Dict[str, Any]], Optional[List[Dict[str, Any]]]]: - metadata = [s.pop(DefaultDataKeys.METADATA, None) for s in samples] + metadata = [s.pop(DefaultDataKeys.METADATA, None) if isinstance(s, Mapping) else None for s in samples] return samples, metadata if any(m is not None for m in metadata) else None def forward(self, samples: Sequence[Any]) -> Any: From c512f75f103601cac9c5858f9f8446c852f693a7 Mon Sep 17 00:00:00 2001 From: tchaton Date: Wed, 12 May 2021 19:32:43 +0100 Subject: [PATCH 11/13] update --- flash/vision/segmentation/data.py | 15 +++++++++++++-- tests/vision/segmentation/test_model.py | 4 ++-- 2 files changed, 15 insertions(+), 4 deletions(-) diff --git a/flash/vision/segmentation/data.py b/flash/vision/segmentation/data.py index 67729713bc..df48cfcc80 100644 --- a/flash/vision/segmentation/data.py +++ b/flash/vision/segmentation/data.py @@ -52,7 +52,18 @@ class SemanticSegmentationNumpyDataSource(NumpyDataSource): def load_sample(self, sample: Dict[str, Any], dataset: Optional[Any] = None) -> Dict[str, Any]: - sample[DefaultDataKeys.INPUT] = torch.from_numpy(sample[DefaultDataKeys.INPUT]).float() + img = torch.from_numpy(sample[DefaultDataKeys.INPUT]).float() + sample[DefaultDataKeys.INPUT] = img + sample[DefaultDataKeys.METADATA] = img.shape + return sample + + +class SemanticSegmentationTensorDataSource(TensorDataSource): + + def load_sample(self, sample: Dict[str, Any], dataset: Optional[Any] = None) -> Dict[str, Any]: + img = sample[DefaultDataKeys.INPUT].float() + sample[DefaultDataKeys.INPUT] = img + sample[DefaultDataKeys.METADATA] = img.shape return sample @@ -162,7 +173,7 @@ def __init__( data_sources={ DefaultDataSources.FILES: SemanticSegmentationPathsDataSource(), DefaultDataSources.FOLDERS: SemanticSegmentationPathsDataSource(), - DefaultDataSources.TENSORS: TensorDataSource(), + DefaultDataSources.TENSORS: SemanticSegmentationTensorDataSource(), DefaultDataSources.NUMPY: SemanticSegmentationNumpyDataSource(), }, default_data_source=DefaultDataSources.FILES, diff --git a/tests/vision/segmentation/test_model.py b/tests/vision/segmentation/test_model.py index 5ccc86d68f..d436ffa982 100644 --- a/tests/vision/segmentation/test_model.py +++ b/tests/vision/segmentation/test_model.py @@ -89,7 +89,7 @@ def test_predict_tensor(): data_pipe = DataPipeline(preprocess=SemanticSegmentationPreprocess(num_classes=1)) out = model.predict(img, data_source="tensors", data_pipeline=data_pipe) assert isinstance(out[0], torch.Tensor) - assert out[0].shape == (196, 196) + assert out[0].shape == (10, 20) def test_predict_numpy(): @@ -98,4 +98,4 @@ def test_predict_numpy(): data_pipe = DataPipeline(preprocess=SemanticSegmentationPreprocess(num_classes=1)) out = model.predict(img, data_source="numpy", data_pipeline=data_pipe) assert isinstance(out[0], torch.Tensor) - assert out[0].shape == (196, 196) + assert out[0].shape == (10, 20) From 1da7d89b18833f03d800c909f5fb5983cd54c54f Mon Sep 17 00:00:00 2001 From: tchaton Date: Wed, 12 May 2021 19:33:26 +0100 Subject: [PATCH 12/13] update --- flash/data/batch.py | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/flash/data/batch.py b/flash/data/batch.py index 9199e6c45f..61aa6c0e26 100644 --- a/flash/data/batch.py +++ b/flash/data/batch.py @@ -166,10 +166,10 @@ def forward(self, samples: Sequence[Any]) -> Any: samples = type(_samples)(_samples) with self._collate_context: - samples, metada = self._extract_metadata(samples) + samples, metadata = self._extract_metadata(samples) samples = self.collate_fn(samples) - if metada: - samples[DefaultDataKeys.METADATA] = metada + if metadata: + samples[DefaultDataKeys.METADATA] = metadata self.callback.on_collate(samples, self.stage) with self._per_batch_transform_context: From 8349d40441146311b170b47248e596810824e85e Mon Sep 17 00:00:00 2001 From: tchaton Date: Wed, 12 May 2021 19:46:32 +0100 Subject: [PATCH 13/13] update --- tests/vision/segmentation/test_serialization.py | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/tests/vision/segmentation/test_serialization.py b/tests/vision/segmentation/test_serialization.py index a971c91fbf..872fcc2420 100644 --- a/tests/vision/segmentation/test_serialization.py +++ b/tests/vision/segmentation/test_serialization.py @@ -1,6 +1,7 @@ import pytest import torch +from flash.data.data_source import DefaultDataKeys from flash.vision.segmentation.serialization import SegmentationLabels @@ -30,7 +31,7 @@ def test_serialize(self): sample[1, 1, 2] = 1 # add peak in class 2 sample[3, 0, 1] = 1 # add peak in class 4 - classes = serial.serialize(sample) + classes = serial.serialize({DefaultDataKeys.PREDS: sample}) assert classes[1, 2] == 1 assert classes[0, 1] == 3