From ccf25c37efdcb43287038e3d68bc2b8a69cb2bf2 Mon Sep 17 00:00:00 2001 From: Michael Boesl Date: Sun, 26 Dec 2021 23:18:40 +0100 Subject: [PATCH 1/3] fix from_data_frame factory method with prediction df --- flash/image/classification/data.py | 2 +- tests/image/classification/test_data.py | 50 +++++++++++++++++++++++++ 2 files changed, 51 insertions(+), 1 deletion(-) diff --git a/flash/image/classification/data.py b/flash/image/classification/data.py index 3bcf3611e1..0ab5395f71 100644 --- a/flash/image/classification/data.py +++ b/flash/image/classification/data.py @@ -217,7 +217,7 @@ def from_data_frame( train_data = (train_data_frame, input_field, target_fields, train_images_root, train_resolver) val_data = (val_data_frame, input_field, target_fields, val_images_root, val_resolver) test_data = (test_data_frame, input_field, target_fields, test_images_root, test_resolver) - predict_data = (predict_data_frame, input_field, predict_images_root, predict_resolver) + predict_data = (predict_data_frame, input_field, None, predict_images_root, predict_resolver) return cls( input_cls(RunningStage.TRAINING, *train_data, transform=train_transform, **ds_kw), diff --git a/tests/image/classification/test_data.py b/tests/image/classification/test_data.py index 59bf13ccb5..54d2fb4f33 100644 --- a/tests/image/classification/test_data.py +++ b/tests/image/classification/test_data.py @@ -17,6 +17,7 @@ from typing import Any, List, Tuple import numpy as np +import pandas as pd import pytest import torch import torch.nn as nn @@ -84,6 +85,55 @@ def test_from_filepaths_smoke(tmpdir): assert sorted(list(labels.numpy())) == [1, 2] +@pytest.mark.skipif(not _IMAGE_TESTING, reason="image libraries aren't installed.") +def test_from_data_frame_smoke(tmpdir): + tmpdir = Path(tmpdir) + + df = pd.DataFrame({"file": ["train.png", "valid.png", "test.png"], "split": ["train", "valid", "test"], + "target": [0, 1, 1]}) + + [_rand_image().save(tmpdir / row.file) for i, row in df.iterrows()] + + img_data = ImageClassificationData.from_data_frame( + "file", "target", + train_images_root=str(tmpdir), + val_images_root=str(tmpdir), + test_images_root=str(tmpdir), + train_data_frame=df[df.split == "train"], + val_data_frame=df[df.split == "valid"], + test_data_frame=df[df.split == "test"], + predict_images_root=str(tmpdir), + batch_size=1, + predict_data_frame=df) + + assert img_data.train_dataloader() is not None + assert img_data.val_dataloader() is not None + assert img_data.test_dataloader() is not None + assert img_data.predict_dataloader() is not None + + data = next(iter(img_data.train_dataloader())) + imgs, labels = data["input"], data["target"] + assert imgs.shape == (1, 3, 196, 196) + assert labels.shape == (1,) + assert sorted(list(labels.numpy())) == [0] + + data = next(iter(img_data.val_dataloader())) + imgs, labels = data["input"], data["target"] + assert imgs.shape == (1, 3, 196, 196) + assert labels.shape == (1,) + assert sorted(list(labels.numpy())) == [1] + + data = next(iter(img_data.test_dataloader())) + imgs, labels = data["input"], data["target"] + assert imgs.shape == (1, 3, 196, 196) + assert labels.shape == (1,) + assert sorted(list(labels.numpy())) == [1] + + data = next(iter(img_data.predict_dataloader())) + imgs = data["input"] + assert imgs.shape == (1, 3, 196, 196) + + @pytest.mark.skipif(not _IMAGE_TESTING, reason="image libraries aren't installed.") def test_from_filepaths_list_image_paths(tmpdir): tmpdir = Path(tmpdir) From 0f7df42933225a51ca22785342fe561ce876aa14 Mon Sep 17 00:00:00 2001 From: "pre-commit-ci[bot]" <66853113+pre-commit-ci[bot]@users.noreply.github.com> Date: Wed, 29 Dec 2021 18:35:21 +0000 Subject: [PATCH 2/3] [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci --- tests/image/classification/test_data.py | 11 +++++++---- 1 file changed, 7 insertions(+), 4 deletions(-) diff --git a/tests/image/classification/test_data.py b/tests/image/classification/test_data.py index 54d2fb4f33..6213f26708 100644 --- a/tests/image/classification/test_data.py +++ b/tests/image/classification/test_data.py @@ -89,13 +89,15 @@ def test_from_filepaths_smoke(tmpdir): def test_from_data_frame_smoke(tmpdir): tmpdir = Path(tmpdir) - df = pd.DataFrame({"file": ["train.png", "valid.png", "test.png"], "split": ["train", "valid", "test"], - "target": [0, 1, 1]}) + df = pd.DataFrame( + {"file": ["train.png", "valid.png", "test.png"], "split": ["train", "valid", "test"], "target": [0, 1, 1]} + ) [_rand_image().save(tmpdir / row.file) for i, row in df.iterrows()] img_data = ImageClassificationData.from_data_frame( - "file", "target", + "file", + "target", train_images_root=str(tmpdir), val_images_root=str(tmpdir), test_images_root=str(tmpdir), @@ -104,7 +106,8 @@ def test_from_data_frame_smoke(tmpdir): test_data_frame=df[df.split == "test"], predict_images_root=str(tmpdir), batch_size=1, - predict_data_frame=df) + predict_data_frame=df, + ) assert img_data.train_dataloader() is not None assert img_data.val_dataloader() is not None From b64b62119b49f7cfa45069522467551cfe0f49f9 Mon Sep 17 00:00:00 2001 From: Ethan Harris Date: Thu, 30 Dec 2021 10:35:34 +0000 Subject: [PATCH 3/3] Update CHANGELOG.md --- CHANGELOG.md | 2 ++ 1 file changed, 2 insertions(+) diff --git a/CHANGELOG.md b/CHANGELOG.md index 5cb78ee80b..48639ade80 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -26,6 +26,8 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/). - Fixed a bug where passing the `val_split` to the `DataModule` would not have the desired effect ([#1079](https://github.com/PyTorchLightning/lightning-flash/pull/1079)) +- Fixed a bug where passing `predict_data_frame` to `ImageClassificationData.from_data_frame` raised an error ([#1088](https://github.com/PyTorchLightning/lightning-flash/pull/1088)) + ### Removed ## [0.6.0] - 2021-13-12