From 43b062ff99ba2c976a781ebb41f188d3a1da3a78 Mon Sep 17 00:00:00 2001 From: William Falcon Date: Sat, 13 Feb 2021 23:53:01 -0500 Subject: [PATCH 1/5] added .csv image loading utils --- flash/core/model.py | 7 +++++++ tests/core/test_model.py | 19 +++++++++++++++++++ 2 files changed, 26 insertions(+) diff --git a/flash/core/model.py b/flash/core/model.py index 9c042cf590..c1710a67ea 100644 --- a/flash/core/model.py +++ b/flash/core/model.py @@ -12,6 +12,7 @@ # See the License for the specific language governing permissions and # limitations under the License. import functools +import os from typing import Any, Callable, Dict, Mapping, Optional, Sequence, Type, Union import pytorch_lightning as pl @@ -142,6 +143,12 @@ def predict( The post-processed model predictions """ + # enable x to be a path to a folder + if isinstance(x, str): + files = os.listdir(x) + files = [os.path.join(x, y) for y in files] + x = files + data_pipeline = data_pipeline or self.data_pipeline batch = x if skip_collate_fn else data_pipeline.collate_fn(x) batch_x, batch_y = batch if len(batch) == 2 and isinstance(batch, (list, tuple)) else (batch, None) diff --git a/tests/core/test_model.py b/tests/core/test_model.py index 8e67ceb2e1..a507266bc5 100644 --- a/tests/core/test_model.py +++ b/tests/core/test_model.py @@ -11,8 +11,11 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. +from pathlib import Path from typing import Any +import numpy as np +from PIL import Image import pytest import pytorch_lightning as pl import torch @@ -68,6 +71,18 @@ def test_classificationtask_task_predict(): assert pred0[0] == pred1[0] +def test_classification_task_predict_folder_path(tmpdir): + train_dir = Path(tmpdir / "train") + train_dir.mkdir() + + _rand_image().save(train_dir / "1.png") + _rand_image().save(train_dir / "2.png") + + task = ImageClassifier(num_classes=10) + predictions = task.predict(str(train_dir)) + assert len(predictions) == 2 + + def test_classificationtask_trainer_predict(tmpdir): model = nn.Sequential(nn.Flatten(), nn.Linear(28 * 28, 10)) task = ClassificationTask(model) @@ -127,3 +142,7 @@ def test_model_download(tmpdir, cls, filename): with tmpdir.as_cwd(): task = cls.load_from_checkpoint(url + filename) assert isinstance(task, cls) + + +def _rand_image(): + return Image.fromarray(np.random.randint(0, 255, (64, 64, 3), dtype="uint8")) From 2c15fe7c29fd80605a854cb32e3ddb4c1afdc52d Mon Sep 17 00:00:00 2001 From: William Falcon Date: Sat, 13 Feb 2021 23:54:11 -0500 Subject: [PATCH 2/5] added .csv image loading utils --- flash/core/model.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/flash/core/model.py b/flash/core/model.py index c1710a67ea..8d45939abb 100644 --- a/flash/core/model.py +++ b/flash/core/model.py @@ -127,7 +127,7 @@ def predict( Args: - x: Input to predict. Can be raw data or processed data. + x: Input to predict. Can be raw data or processed data. If str, assumed to be a folder of data. batch_idx: Batch index From f1549438ca5f5126ace1941847934d7b48181258 Mon Sep 17 00:00:00 2001 From: William Falcon Date: Sat, 13 Feb 2021 23:57:37 -0500 Subject: [PATCH 3/5] added .csv image loading utils --- tests/core/test_model.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/tests/core/test_model.py b/tests/core/test_model.py index a507266bc5..efd2009a67 100644 --- a/tests/core/test_model.py +++ b/tests/core/test_model.py @@ -15,10 +15,10 @@ from typing import Any import numpy as np -from PIL import Image import pytest import pytorch_lightning as pl import torch +from PIL import Image from torch import nn from torch.nn import functional as F From 2b213fbf66a96033dcb95dfb422a8268625d2337 Mon Sep 17 00:00:00 2001 From: William Falcon Date: Sun, 14 Feb 2021 00:05:45 -0500 Subject: [PATCH 4/5] added .csv image loading utils --- flash/vision/classification/data.py | 12 ++++++++---- 1 file changed, 8 insertions(+), 4 deletions(-) diff --git a/flash/vision/classification/data.py b/flash/vision/classification/data.py index 6f90f2571d..766da2460b 100644 --- a/flash/vision/classification/data.py +++ b/flash/vision/classification/data.py @@ -17,7 +17,7 @@ import pandas as pd import torch -from PIL import Image +from PIL import Image, UnidentifiedImageError from pytorch_lightning.utilities.exceptions import MisconfigurationException from torchvision import transforms as T from torchvision.datasets import VisionDataset @@ -241,9 +241,13 @@ def before_collate(self, samples: Any) -> Any: if isinstance(samples, (list, tuple)) and all(isinstance(p, str) for p in samples): outputs = [] for sample in samples: - output = self._loader(sample) - transform = self._valid_transform if self._use_valid_transform else self._train_transform - outputs.append(transform(output)) + try: + output = self._loader(sample) + transform = self._valid_transform if self._use_valid_transform else self._train_transform + outputs.append(transform(output)) + except UnidentifiedImageError as e: + print(f'Skipping: could not read file {sample}') + return outputs raise MisconfigurationException("The samples should either be a tensor or a list of paths.") From dda28349310862b24ab3fe6aa3b651a4145beb6f Mon Sep 17 00:00:00 2001 From: William Falcon Date: Sun, 14 Feb 2021 00:10:43 -0500 Subject: [PATCH 5/5] added .csv image loading utils --- flash/vision/classification/data.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/flash/vision/classification/data.py b/flash/vision/classification/data.py index 766da2460b..fcbfb5e5a1 100644 --- a/flash/vision/classification/data.py +++ b/flash/vision/classification/data.py @@ -245,7 +245,7 @@ def before_collate(self, samples: Any) -> Any: output = self._loader(sample) transform = self._valid_transform if self._use_valid_transform else self._train_transform outputs.append(transform(output)) - except UnidentifiedImageError as e: + except UnidentifiedImageError: print(f'Skipping: could not read file {sample}') return outputs