From b41e33db8459e654aace82eef7ae4fe1a1a0e070 Mon Sep 17 00:00:00 2001
From: William Falcon
Date: Sun, 14 Feb 2021 07:11:51 -0500
Subject: [PATCH] added .csv image loading utils (#118)

* added .csv image loading utils

* added .csv image loading utils

* added .csv image loading utils

* added .csv image loading utils

* added .csv image loading utils
---
Review notes (ignored by `git am`, not part of the commit message):
- predict(): the folder listing is sorted and restricted to regular files,
  because os.listdir() returns entries in arbitrary order and also returns
  sub-directories, which the image loader cannot open.
- _rand_image(): randint upper bound raised to 256; the bound is exclusive,
  so 255 never produced the brightest uint8 value.
- Hunk line counts and the diffstat below are unchanged by these edits.

 flash/core/model.py                 |  9 ++++++++-
 flash/vision/classification/data.py | 12 ++++++++----
 tests/core/test_model.py            | 19 +++++++++++++++++++
 3 files changed, 35 insertions(+), 5 deletions(-)

diff --git a/flash/core/model.py b/flash/core/model.py
index 9c042cf590..8d45939abb 100644
--- a/flash/core/model.py
+++ b/flash/core/model.py
@@ -12,6 +12,7 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 import functools
+import os
 from typing import Any, Callable, Dict, Mapping, Optional, Sequence, Type, Union
 
 import pytorch_lightning as pl
@@ -126,7 +127,7 @@ def predict(
 
         Args:
 
-            x: Input to predict. Can be raw data or processed data.
+            x: Input to predict. Can be raw data or processed data. If str, assumed to be a folder of data.
 
             batch_idx: Batch index
 
@@ -142,6 +143,12 @@ def predict(
             The post-processed model predictions
 
         """
+        # a str is treated as a folder: predict on its files, sorted so the output order is reproducible
+        if isinstance(x, str):
+            files = sorted(os.listdir(x))
+            files = [os.path.join(x, y) for y in files]
+            x = [f for f in files if os.path.isfile(f)]
+
         data_pipeline = data_pipeline or self.data_pipeline
         batch = x if skip_collate_fn else data_pipeline.collate_fn(x)
         batch_x, batch_y = batch if len(batch) == 2 and isinstance(batch, (list, tuple)) else (batch, None)
diff --git a/flash/vision/classification/data.py b/flash/vision/classification/data.py
index 6f90f2571d..fcbfb5e5a1 100644
--- a/flash/vision/classification/data.py
+++ b/flash/vision/classification/data.py
@@ -17,7 +17,7 @@
 
 import pandas as pd
 import torch
-from PIL import Image
+from PIL import Image, UnidentifiedImageError
 from pytorch_lightning.utilities.exceptions import MisconfigurationException
 from torchvision import transforms as T
 from torchvision.datasets import VisionDataset
@@ -241,9 +241,13 @@ def before_collate(self, samples: Any) -> Any:
         if isinstance(samples, (list, tuple)) and all(isinstance(p, str) for p in samples):
             outputs = []
             for sample in samples:
-                output = self._loader(sample)
-                transform = self._valid_transform if self._use_valid_transform else self._train_transform
-                outputs.append(transform(output))
+                try:
+                    output = self._loader(sample)
+                    transform = self._valid_transform if self._use_valid_transform else self._train_transform
+                    outputs.append(transform(output))
+                except UnidentifiedImageError:
+                    print(f'Skipping: could not read file {sample}')
+
             return outputs
         raise MisconfigurationException("The samples should either be a tensor or a list of paths.")
 
diff --git a/tests/core/test_model.py b/tests/core/test_model.py
index 8e67ceb2e1..efd2009a67 100644
--- a/tests/core/test_model.py
+++ b/tests/core/test_model.py
@@ -11,11 +11,14 @@
 # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 # See the License for the specific language governing permissions and
 # limitations under the License.
+from pathlib import Path
 from typing import Any
 
+import numpy as np
 import pytest
 import pytorch_lightning as pl
 import torch
+from PIL import Image
 from torch import nn
 from torch.nn import functional as F
 
@@ -68,6 +71,18 @@ def test_classificationtask_task_predict():
     assert pred0[0] == pred1[0]
 
 
+def test_classification_task_predict_folder_path(tmpdir):
+    train_dir = Path(tmpdir / "train")
+    train_dir.mkdir()
+
+    _rand_image().save(train_dir / "1.png")
+    _rand_image().save(train_dir / "2.png")
+
+    task = ImageClassifier(num_classes=10)
+    predictions = task.predict(str(train_dir))
+    assert len(predictions) == 2
+
+
 def test_classificationtask_trainer_predict(tmpdir):
     model = nn.Sequential(nn.Flatten(), nn.Linear(28 * 28, 10))
     task = ClassificationTask(model)
@@ -127,3 +142,7 @@ def test_model_download(tmpdir, cls, filename):
     with tmpdir.as_cwd():
         task = cls.load_from_checkpoint(url + filename)
         assert isinstance(task, cls)
+
+
+def _rand_image():
+    return Image.fromarray(np.random.randint(0, 256, (64, 64, 3), dtype="uint8"))