From ddd85f0f7a7c58de62fcaa9ae4e2ea85ad002840 Mon Sep 17 00:00:00 2001
From: Ethan Harris
Date: Sat, 10 Jul 2021 12:31:20 +0100
Subject: [PATCH] Image classification csv data source (#556)

* Initial commit
* Add support for from_csv and from_data_frame to ImageClassificationData
* Update CHANGELOG.md
* Fixes
* Clean
---
 CHANGELOG.md | 1 +
 flash/image/classification/data.py | 296 +++++++++++++++++-
 flash/image/classification/model.py | 2 +-
 .../image_classification_multi_label.py | 26 +-
 flash_examples/object_detection.py | 4 +-
 tests/image/classification/test_data.py | 111 +++++++
 6 files changed, 415 insertions(+), 25 deletions(-)

diff --git a/CHANGELOG.md b/CHANGELOG.md
index 14cd73c12b..117c68ebb0 100644
--- a/CHANGELOG.md
+++ b/CHANGELOG.md
@@ -9,6 +9,7 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).
 
 ### Added
 
 - Added support for (input, target) style datasets (e.g. torchvision) to the from_datasets method ([#552](https://github.com/PyTorchLightning/lightning-flash/pull/552))
+- Added support for `from_csv` and `from_data_frame` to `ImageClassificationData` ([#556](https://github.com/PyTorchLightning/lightning-flash/pull/556))
 
 ### Changed

diff --git a/flash/image/classification/data.py b/flash/image/classification/data.py
index deb84f82a4..2da17645ae 100644
--- a/flash/image/classification/data.py
+++ b/flash/image/classification/data.py
@@ -11,18 +11,23 @@
 # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 # See the License for the specific language governing permissions and
 # limitations under the License.
-from typing import Any, Callable, Dict, List, Optional, Tuple, Union
+import glob
+import os
+from functools import partial
+from typing import Any, Callable, Dict, List, Mapping, Optional, Sequence, Tuple, Union
 
 import numpy as np
+import pandas as pd
 import torch
 from pytorch_lightning.trainer.states import RunningStage
+from torch.utils.data.sampler import Sampler
 
 from flash.core.data.base_viz import BaseVisualization  # for viz
 from flash.core.data.callback import BaseDataFetcher
 from flash.core.data.data_module import DataModule
-from flash.core.data.data_source import DefaultDataKeys, DefaultDataSources
+from flash.core.data.data_source import DataSource, DefaultDataKeys, DefaultDataSources, LabelsState
 from flash.core.data.process import Deserializer, Preprocess
-from flash.core.utilities.imports import _MATPLOTLIB_AVAILABLE, _PIL_AVAILABLE, _requires_extras
+from flash.core.utilities.imports import _MATPLOTLIB_AVAILABLE, _PIL_AVAILABLE, _requires_extras, _TORCHVISION_AVAILABLE
 from flash.image.classification.transforms import default_transforms, train_default_transforms
 from flash.image.data import (
     ImageDeserializer,
@@ -37,6 +42,9 @@
 else:
     plt = None
 
+if _TORCHVISION_AVAILABLE:
+    from torchvision.datasets.folder import default_loader
+
 if _PIL_AVAILABLE:
     from PIL import Image
 else:
@@ -45,6 +53,96 @@ class Image:
         Image = None
 
 
+class ImageClassificationDataFrameDataSource(
+    DataSource[Tuple[pd.DataFrame, str, Union[str, List[str]], Optional[str]]]
+):
+
+    @staticmethod
+    def _resolve_file(root: str, file_id: str) -> str:
+        if os.path.isabs(file_id):
+            pattern = f"{file_id}*"
+        else:
+            pattern = os.path.join(root, f"*{file_id}*")
+        files = glob.glob(pattern)
+        if len(files) > 1:
+            raise ValueError(
+                f"Found multiple matches for pattern: {pattern}. File IDs should uniquely identify the file to load."
+            )
+        elif len(files) == 0:
+            raise ValueError(
+                f"Found no matches for pattern: {pattern}. File IDs should uniquely identify the file to load."
+            )
+        return files[0]
+
+    @staticmethod
+    def _resolve_target(label_to_class: Dict[str, int], target_key: str, row: pd.Series) -> pd.Series:
+        row[target_key] = label_to_class[row[target_key]]
+        return row
+
+    @staticmethod
+    def _resolve_multi_target(target_keys: List[str], row: pd.Series) -> pd.Series:
+        row[target_keys[0]] = [row[target_key] for target_key in target_keys]
+        return row
+
+    def load_data(
+        self,
+        data: Tuple[pd.DataFrame, str, Union[str, List[str]], Optional[str]],
+        dataset: Optional[Any] = None,
+    ) -> Sequence[Mapping[str, Any]]:
+        data_frame, input_key, target_keys, root = data
+        if root is None:
+            root = ""
+
+        if not self.predicting:
+            if isinstance(target_keys, List):
+                dataset.num_classes = len(target_keys)
+                self.set_state(LabelsState(target_keys))
+                data_frame = data_frame.apply(partial(self._resolve_multi_target, target_keys), axis=1)
+                target_keys = target_keys[0]
+            else:
+                if self.training:
+                    labels = list(sorted(data_frame[target_keys].unique()))
+                    dataset.num_classes = len(labels)
+                    self.set_state(LabelsState(labels))
+
+                labels = self.get_state(LabelsState)
+
+                if labels is not None:
+                    labels = labels.labels
+                    label_to_class = {v: k for k, v in enumerate(labels)}
+                    data_frame = data_frame.apply(partial(self._resolve_target, label_to_class, target_keys), axis=1)
+
+            return [{
+                DefaultDataKeys.INPUT: row[input_key],
+                DefaultDataKeys.TARGET: row[target_keys],
+                DefaultDataKeys.METADATA: dict(root=root),
+            } for _, row in data_frame.iterrows()]
+        else:
+            return [{
+                DefaultDataKeys.INPUT: row[input_key],
+                DefaultDataKeys.METADATA: dict(root=root),
+            } for _, row in data_frame.iterrows()]
+
+    def load_sample(self, sample: Dict[str, Any], dataset: Optional[Any] = None) -> Dict[str, Any]:
+        file = self._resolve_file(sample[DefaultDataKeys.METADATA]['root'], sample[DefaultDataKeys.INPUT])
+        sample[DefaultDataKeys.INPUT] = default_loader(file)
+        return sample
+
+
+class ImageClassificationCSVDataSource(ImageClassificationDataFrameDataSource):
+
+    def load_data(
+        self,
+        data: Tuple[str, str, Union[str, List[str]], Optional[str]],
+        dataset: Optional[Any] = None,
+    ) -> Sequence[Mapping[str, Any]]:
+        csv_file, input_key, target_keys, root = data
+        data_frame = pd.read_csv(csv_file)
+        if root is None:
+            root = os.path.dirname(csv_file)
+        return super().load_data((data_frame, input_key, target_keys, root), dataset)
+
+
 class ImageClassificationPreprocess(Preprocess):
 
     def __init__(
@@ -70,6 +168,8 @@ def __init__(
             DefaultDataSources.FOLDERS: ImagePathsDataSource(),
             DefaultDataSources.NUMPY: ImageNumpyDataSource(),
             DefaultDataSources.TENSORS: ImageTensorDataSource(),
+            "data_frame": ImageClassificationDataFrameDataSource(),
+            DefaultDataSources.CSV: ImageClassificationCSVDataSource(),
         },
         deserializer=deserializer or ImageDeserializer(),
         default_data_source=DefaultDataSources.FILES,
@@ -94,6 +194,196 @@ class ImageClassificationData(DataModule):
 
     preprocess_cls = ImageClassificationPreprocess
 
+    @classmethod
+    def from_data_frame(
+        cls,
+        input_field: str,
+        target_fields: Optional[Union[str, Sequence[str]]] = None,
+        train_data_frame: Optional[pd.DataFrame] = None,
+        train_images_root: Optional[str] = None,
+        val_data_frame: Optional[pd.DataFrame] = None,
+        val_images_root: Optional[str] = None,
+        test_data_frame: Optional[pd.DataFrame] = None,
+        test_images_root: Optional[str] = None,
+        predict_data_frame: Optional[pd.DataFrame] = None,
+        predict_images_root: Optional[str] = None,
+        train_transform: Optional[Dict[str, Callable]] = None,
+        val_transform: Optional[Dict[str, Callable]] = None,
+        test_transform: Optional[Dict[str, Callable]] = None,
+        predict_transform: Optional[Dict[str, Callable]] = None,
+        data_fetcher: Optional[BaseDataFetcher] = None,
+        preprocess: Optional[Preprocess] = None,
+        val_split: Optional[float] = None,
+        batch_size: int = 4,
+        num_workers: Optional[int] = None,
+        sampler: Optional[Sampler] = None,
+        **preprocess_kwargs: Any,
+    ) -> 'DataModule':
+        """Creates a :class:`~flash.image.classification.data.ImageClassificationData` object from the given pandas
+        ``DataFrame`` objects.
+
+        Args:
+            input_field: The field (column) in the pandas ``DataFrame`` to use for the input.
+            target_fields: The field or fields (columns) in the pandas ``DataFrame`` to use for the target.
+            train_data_frame: The pandas ``DataFrame`` containing the training data.
+            train_images_root: The directory containing the train images. If ``None``, values in the ``input_field``
+                will be assumed to be the full file paths.
+            val_data_frame: The pandas ``DataFrame`` containing the validation data.
+            val_images_root: The directory containing the validation images. If ``None``, values in the
+                ``input_field`` will be assumed to be the full file paths.
+            test_data_frame: The pandas ``DataFrame`` containing the testing data.
+            test_images_root: The directory containing the test images. If ``None``, values in the ``input_field``
+                will be assumed to be the full file paths.
+            predict_data_frame: The pandas ``DataFrame`` containing the data to use when predicting.
+            predict_images_root: The directory containing the predict images. If ``None``, values in the
+                ``input_field`` will be assumed to be the full file paths.
+            train_transform: The dictionary of transforms to use during training which maps
+                :class:`~flash.core.data.process.Preprocess` hook names to callable transforms.
+            val_transform: The dictionary of transforms to use during validation which maps
+                :class:`~flash.core.data.process.Preprocess` hook names to callable transforms.
+            test_transform: The dictionary of transforms to use during testing which maps
+                :class:`~flash.core.data.process.Preprocess` hook names to callable transforms.
+            predict_transform: The dictionary of transforms to use during predicting which maps
+                :class:`~flash.core.data.process.Preprocess` hook names to callable transforms.
+            data_fetcher: The :class:`~flash.core.data.callback.BaseDataFetcher` to pass to the
+                :class:`~flash.core.data.data_module.DataModule`.
+            preprocess: The :class:`~flash.core.data.process.Preprocess` to pass to the
+                :class:`~flash.core.data.data_module.DataModule`. If ``None``, ``cls.preprocess_cls``
+                will be constructed and used.
+            val_split: The ``val_split`` argument to pass to the :class:`~flash.core.data.data_module.DataModule`.
+            batch_size: The ``batch_size`` argument to pass to the :class:`~flash.core.data.data_module.DataModule`.
+            num_workers: The ``num_workers`` argument to pass to the :class:`~flash.core.data.data_module.DataModule`.
+            sampler: The ``sampler`` argument to pass to the :class:`~flash.core.data.data_module.DataModule`.
+            preprocess_kwargs: Additional keyword arguments to use when constructing the preprocess. Will only be used
+                if ``preprocess = None``.
+
+        Returns:
+            The constructed data module.
+
+        Examples::
+
+            data_module = ImageClassificationData.from_data_frame(
+                "image_id",
+                "target",
+                train_data_frame=train_data,
+                train_images_root="data/train_images",
+            )
+        """
+        return cls.from_data_source(
+            "data_frame",
+            (train_data_frame, input_field, target_fields, train_images_root),
+            (val_data_frame, input_field, target_fields, val_images_root),
+            (test_data_frame, input_field, target_fields, test_images_root),
+            (predict_data_frame, input_field, target_fields, predict_images_root),
+            train_transform=train_transform,
+            val_transform=val_transform,
+            test_transform=test_transform,
+            predict_transform=predict_transform,
+            data_fetcher=data_fetcher,
+            preprocess=preprocess,
+            val_split=val_split,
+            batch_size=batch_size,
+            num_workers=num_workers,
+            sampler=sampler,
+            **preprocess_kwargs,
+        )
+
+    @classmethod
+    def from_csv(
+        cls,
+        input_field: str,
+        target_fields: Optional[Union[str, Sequence[str]]] = None,
+        train_file: Optional[str] = None,
+        train_images_root: Optional[str] = None,
+        val_file: Optional[str] = None,
+        val_images_root: Optional[str] = None,
+        test_file: Optional[str] = None,
+        test_images_root: Optional[str] = None,
+        predict_file: Optional[str] = None,
+        predict_images_root: Optional[str] = None,
+        train_transform: Optional[Dict[str, Callable]] = None,
+        val_transform: Optional[Dict[str, Callable]] = None,
+        test_transform: Optional[Dict[str, Callable]] = None,
+        predict_transform: Optional[Dict[str, Callable]] = None,
+        data_fetcher: Optional[BaseDataFetcher] = None,
+        preprocess: Optional[Preprocess] = None,
+        val_split: Optional[float] = None,
+        batch_size: int = 4,
+        num_workers: Optional[int] = None,
+        sampler: Optional[Sampler] = None,
+        **preprocess_kwargs: Any,
+    ) -> 'DataModule':
+        """Creates a :class:`~flash.image.classification.data.ImageClassificationData` object from the given CSV files
+        using the :class:`~flash.core.data.data_source.DataSource`
+        of name :attr:`~flash.core.data.data_source.DefaultDataSources.CSV`
+        from the passed or constructed :class:`~flash.core.data.process.Preprocess`.
+
+        Args:
+            input_field: The field (column) in the CSV file to use for the input.
+            target_fields: The field or fields (columns) in the CSV file to use for the target.
+            train_file: The CSV file containing the training data.
+            train_images_root: The directory containing the train images. If ``None``, the directory containing the
+                ``train_file`` will be used.
+            val_file: The CSV file containing the validation data.
+            val_images_root: The directory containing the validation images. If ``None``, the directory containing the
+                ``val_file`` will be used.
+            test_file: The CSV file containing the testing data.
+            test_images_root: The directory containing the test images. If ``None``, the directory containing the
+                ``test_file`` will be used.
+            predict_file: The CSV file containing the data to use when predicting.
+            predict_images_root: The directory containing the predict images. If ``None``, the directory containing the
+                ``predict_file`` will be used.
+            train_transform: The dictionary of transforms to use during training which maps
+                :class:`~flash.core.data.process.Preprocess` hook names to callable transforms.
+            val_transform: The dictionary of transforms to use during validation which maps
+                :class:`~flash.core.data.process.Preprocess` hook names to callable transforms.
+            test_transform: The dictionary of transforms to use during testing which maps
+                :class:`~flash.core.data.process.Preprocess` hook names to callable transforms.
+            predict_transform: The dictionary of transforms to use during predicting which maps
+                :class:`~flash.core.data.process.Preprocess` hook names to callable transforms.
+            data_fetcher: The :class:`~flash.core.data.callback.BaseDataFetcher` to pass to the
+                :class:`~flash.core.data.data_module.DataModule`.
+            preprocess: The :class:`~flash.core.data.process.Preprocess` to pass to the
+                :class:`~flash.core.data.data_module.DataModule`. If ``None``, ``cls.preprocess_cls``
+                will be constructed and used.
+            val_split: The ``val_split`` argument to pass to the :class:`~flash.core.data.data_module.DataModule`.
+            batch_size: The ``batch_size`` argument to pass to the :class:`~flash.core.data.data_module.DataModule`.
+            num_workers: The ``num_workers`` argument to pass to the :class:`~flash.core.data.data_module.DataModule`.
+            sampler: The ``sampler`` argument to pass to the :class:`~flash.core.data.data_module.DataModule`.
+            preprocess_kwargs: Additional keyword arguments to use when constructing the preprocess. Will only be used
+                if ``preprocess = None``.
+
+        Returns:
+            The constructed data module.
+
+        Examples::
+
+            data_module = ImageClassificationData.from_csv(
+                "image_id",
+                "target",
+                train_file="train_data.csv",
+                train_images_root="data/train_images",
+            )
+        """
+        return cls.from_data_source(
+            DefaultDataSources.CSV,
+            (train_file, input_field, target_fields, train_images_root),
+            (val_file, input_field, target_fields, val_images_root),
+            (test_file, input_field, target_fields, test_images_root),
+            (predict_file, input_field, target_fields, predict_images_root),
+            train_transform=train_transform,
+            val_transform=val_transform,
+            test_transform=test_transform,
+            predict_transform=predict_transform,
+            data_fetcher=data_fetcher,
+            preprocess=preprocess,
+            val_split=val_split,
+            batch_size=batch_size,
+            num_workers=num_workers,
+            sampler=sampler,
+            **preprocess_kwargs,
+        )
+
     def set_block_viz_window(self, value: bool) -> None:
         """Setter method to switch on/off matplotlib to pop up windows."""
         self.data_fetcher.block_viz_window = value
diff --git a/flash/image/classification/model.py b/flash/image/classification/model.py
index 46c1f6cbd2..abd366c2a8 100644
--- a/flash/image/classification/model.py
+++ b/flash/image/classification/model.py
@@ -94,7 +94,7 @@ def __init__(
             metrics=metrics or F1(num_classes) if multi_label else Accuracy(),
             learning_rate=learning_rate,
             multi_label=multi_label,
-            serializer=serializer or Labels(),
+            serializer=serializer or Labels(multi_label=multi_label),
         )
 
         self.save_hyperparameters()
diff --git a/flash_examples/image_classification_multi_label.py b/flash_examples/image_classification_multi_label.py
index 00e86d7f0b..9f2ef46457 100644
--- a/flash_examples/image_classification_multi_label.py
+++ b/flash_examples/image_classification_multi_label.py
@@ -11,11 +11,6 @@
 # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 # See the License for the specific language governing permissions and
 # limitations under the License.
-import os.path as osp
-from typing import List, Tuple
-
-import pandas as pd
-
 import flash
 from flash.core.data.utils import download_data
 from flash.image import ImageClassificationData, ImageClassifier
@@ -24,25 +19,18 @@
 # Data set from the paper “Movie Genre Classification based on Poster Images with Deep Neural Networks”.
 # More info here: https://www.cs.ccu.edu.tw/~wtchu/projects/MoviePoster/
 download_data("https://pl-flash-data.s3.amazonaws.com/movie_posters.zip")
-genres = ["Action", "Romance", "Crime", "Thriller", "Adventure"]
-
-
-def load_data(data: str, root: str = 'data/movie_posters') -> Tuple[List[str], List[List[int]]]:
-    metadata = pd.read_csv(osp.join(root, data, "metadata.csv"))
-    return ([osp.join(root, data, row['Id'] + ".jpg") for _, row in metadata.iterrows()],
-            [[int(row[genre]) for genre in genres] for _, row in metadata.iterrows()])
-
-
-train_files, train_targets = load_data('train')
-datamodule = ImageClassificationData.from_files(
-    train_files=train_files,
-    train_targets=train_targets,
+datamodule = ImageClassificationData.from_csv(
+    'Id',
+    ["Action", "Romance", "Crime", "Thriller", "Adventure"],
+    train_file="data/movie_posters/train/metadata.csv",
+    val_file="data/movie_posters/val/metadata.csv",
     val_split=0.1,
     image_size=(128, 128),
 )
 
 # 2. Build the task
-model = ImageClassifier(backbone="resnet18", num_classes=len(genres), multi_label=True)
+model = ImageClassifier(backbone="resnet18", num_classes=datamodule.num_classes, multi_label=True)
 
 # 3. Create the trainer and finetune the model
 trainer = flash.Trainer(max_epochs=3)
@@ -56,5 +44,5 @@ def load_data(data: str, root: str = 'data/movie_posters') -> Tuple[List[str], L
 ])
 print(predictions)
 
-# 7. Save the model!
+# 5. Save the model!
 trainer.save_checkpoint("image_classification_multi_label_model.pt")
diff --git a/flash_examples/object_detection.py b/flash_examples/object_detection.py
index 4f488e1e11..118bdc5c67 100644
--- a/flash_examples/object_detection.py
+++ b/flash_examples/object_detection.py
@@ -17,11 +17,11 @@
 
 # 1. Create the DataModule
 # Dataset Credit: https://www.kaggle.com/ultralytics/coco128
-download_data("https://github.com/zhiqwang/yolov5-rt-stack/releases/download/v0.3.0/coco128.zip", "finetuning/data/")
+download_data("https://github.com/zhiqwang/yolov5-rt-stack/releases/download/v0.3.0/coco128.zip", "data/")
 
 datamodule = ObjectDetectionData.from_coco(
     train_folder="data/coco128/images/train2017/",
-    train_ann_file="finetuning/data/coco128/annotations/instances_train2017.json",
+    train_ann_file="data/coco128/annotations/instances_train2017.json",
     val_split=0.1,
 )
 
diff --git a/tests/image/classification/test_data.py b/tests/image/classification/test_data.py
index 183f3427a4..232998522e 100644
--- a/tests/image/classification/test_data.py
+++ b/tests/image/classification/test_data.py
@@ -11,6 +11,7 @@
 # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 # See the License for the specific language governing permissions and
 # limitations under the License.
+import csv
 from pathlib import Path
 from typing import Any, List, Tuple
 
@@ -473,3 +474,113 @@ def test_from_datasets():
     imgs, labels = data[DefaultDataKeys.INPUT], data[DefaultDataKeys.TARGET]
     assert imgs.shape == (2, 3, 196, 196)
     assert labels.shape == (2, )
+
+
+@pytest.fixture
+def image_tmpdir(tmpdir):
+    (tmpdir / "train").mkdir()
+    Image.new("RGB", (128, 128)).save(str(tmpdir / "train" / "image_1.png"))
+    Image.new("RGB", (128, 128)).save(str(tmpdir / "train" / "image_2.png"))
+    return tmpdir / "train"
+
+
+@pytest.fixture
+def single_target_csv(image_tmpdir):
+    with open(image_tmpdir / "metadata.csv", "w") as csvfile:
+        fieldnames = ["image", "target"]
+        writer = csv.DictWriter(csvfile, fieldnames)
+        writer.writeheader()
+        writer.writerow({"image": "image_1", "target": "Ants"})
+        writer.writerow({"image": "image_2", "target": "Bees"})
+    return str(image_tmpdir / "metadata.csv")
+
+
+@pytest.mark.skipif(not _IMAGE_TESTING, reason="image libraries aren't installed.")
+def test_from_csv_single_target(single_target_csv):
+    img_data = ImageClassificationData.from_csv(
+        "image",
+        "target",
+        train_file=single_target_csv,
+        batch_size=2,
+        num_workers=0,
+    )
+
+    # check training data
+    data = next(iter(img_data.train_dataloader()))
+    imgs, labels = data[DefaultDataKeys.INPUT], data[DefaultDataKeys.TARGET]
+    assert imgs.shape == (2, 3, 196, 196)
+    assert labels.shape == (2, )
+
+
+@pytest.fixture
+def multi_target_csv(image_tmpdir):
+    with open(image_tmpdir / "metadata.csv", "w") as csvfile:
+        fieldnames = ["image", "target_1", "target_2"]
+        writer = csv.DictWriter(csvfile, fieldnames)
+        writer.writeheader()
+        writer.writerow({"image": "image_1", "target_1": 1, "target_2": 0})
+        writer.writerow({"image": "image_2", "target_1": 1, "target_2": 1})
+    return str(image_tmpdir / "metadata.csv")
+
+
+@pytest.mark.skipif(not _IMAGE_TESTING, reason="image libraries aren't installed.")
+def test_from_csv_multi_target(multi_target_csv):
+    img_data = ImageClassificationData.from_csv(
+        "image",
+        ["target_1", "target_2"],
+        train_file=multi_target_csv,
+        batch_size=2,
+        num_workers=0,
+    )
+
+    # check training data
+    data = next(iter(img_data.train_dataloader()))
+    imgs, labels = data[DefaultDataKeys.INPUT], data[DefaultDataKeys.TARGET]
+    assert imgs.shape == (2, 3, 196, 196)
+    assert labels.shape == (2, 2)
+
+
+@pytest.fixture
+def bad_csv_multi_image(image_tmpdir):
+    with open(image_tmpdir / "metadata.csv", "w") as csvfile:
+        fieldnames = ["image", "target"]
+        writer = csv.DictWriter(csvfile, fieldnames)
+        writer.writeheader()
+        writer.writerow({"image": "image", "target": "Ants"})
+    return str(image_tmpdir / "metadata.csv")
+
+
+@pytest.mark.skipif(not _IMAGE_TESTING, reason="image libraries aren't installed.")
+def test_from_bad_csv_multi_image(bad_csv_multi_image):
+    with pytest.raises(ValueError, match="Found multiple matches"):
+        img_data = ImageClassificationData.from_csv(
+            "image",
+            ["target"],
+            train_file=bad_csv_multi_image,
+            batch_size=1,
+            num_workers=0,
+        )
+        _ = next(iter(img_data.train_dataloader()))
+
+
+@pytest.fixture
+def bad_csv_no_image(image_tmpdir):
+    with open(image_tmpdir / "metadata.csv", "w") as csvfile:
+        fieldnames = ["image", "target"]
+        writer = csv.DictWriter(csvfile, fieldnames)
+        writer.writeheader()
+        writer.writerow({"image": "image_3", "target": "Ants"})
+    return str(image_tmpdir / "metadata.csv")
+
+
+@pytest.mark.skipif(not _IMAGE_TESTING, reason="image libraries aren't installed.")
+def test_from_bad_csv_no_image(bad_csv_no_image):
+    with pytest.raises(ValueError, match="Found no matches"):
+        img_data = ImageClassificationData.from_csv(
+            "image",
+            ["target"],
+            train_file=bad_csv_no_image,
+            batch_size=1,
+            num_workers=0,
+        )
+        _ = next(iter(img_data.train_dataloader()))
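
Reviewer note: the sketch below shows how the `from_csv` API added by this patch would typically be wired into a training script, mirroring the flash_examples scripts touched above. The CSV path, image directory, and the "image"/"label" column names are hypothetical placeholders for your own data, not files shipped with the repository.

Example usage (minimal sketch, hypothetical data)::

    import flash
    from flash.image import ImageClassificationData, ImageClassifier

    # Hypothetical CSV with an "image" column (file IDs resolved under
    # data/train) and a single "label" target column.
    datamodule = ImageClassificationData.from_csv(
        "image",
        "label",
        train_file="data/train/metadata.csv",
        train_images_root="data/train",
        val_split=0.1,
        batch_size=4,
    )

    # num_classes is inferred from the target column by the new data source.
    model = ImageClassifier(backbone="resnet18", num_classes=datamodule.num_classes)

    trainer = flash.Trainer(max_epochs=1)
    trainer.finetune(model, datamodule=datamodule, strategy="freeze")
    trainer.save_checkpoint("image_classification_model.pt")

Passing a list of columns (e.g. ``["target_1", "target_2"]``) instead of a single target column switches the data source into the multi-label path exercised by the new tests above.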