From b2d7f5cf184c674e64f91b60437ab604ea479234 Mon Sep 17 00:00:00 2001 From: Andres Algaba Date: Sat, 19 Feb 2022 14:03:04 +0100 Subject: [PATCH] Expose normalization parameters (#1178) Co-authored-by: Ethan Harris --- CHANGELOG.md | 16 +++++++++++++++- docs/source/reference/image_classification.rst | 10 +++++----- flash/image/classification/input_transform.py | 15 +++++---------- flash_examples/image_classification.py | 2 +- 4 files changed, 26 insertions(+), 17 deletions(-) diff --git a/CHANGELOG.md b/CHANGELOG.md index 65a1e6c3b5..3c529d56a3 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -4,7 +4,21 @@ All notable changes to this project will be documented in this file. The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/). -## [0.7.0] - 2022-15-02 +## [0.7.1] - 2022-03-01 + +### Added + +- Added the normalization parameters of ```torchvision.transforms.Normalize``` as ```transform_kwargs``` in the ```ImageClassificationInputTransform``` ([#1178](https://github.com/PyTorchLightning/lightning-flash/pull/1178)) + +### Changed + +### Deprecated + +### Removed + +### Fixed + +## [0.7.0] - 2022-02-15 ### Added diff --git a/docs/source/reference/image_classification.rst b/docs/source/reference/image_classification.rst index fa129662cf..6906cf2b8f 100644 --- a/docs/source/reference/image_classification.rst +++ b/docs/source/reference/image_classification.rst @@ -101,7 +101,7 @@ Here's an example: from torchvision import transforms as T - from typing import Tuple, Callable + from typing import Callable, Tuple, Union import flash from flash.image import ImageClassificationData, ImageClassifier from flash.core.data.io.input_transform import InputTransform @@ -112,18 +112,18 @@ Here's an example: class ImageClassificationInputTransform(InputTransform): image_size: Tuple[int, int] = (196, 196) + mean: Union[float, Tuple[float, float, float]] = (0.485, 0.456, 0.406) + std: Union[float, Tuple[float, float, float]] = (0.229, 0.224, 0.225) def input_per_sample_transform(self): - return T.Compose( - [T.ToTensor(), T.Resize(self.image_size), T.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])] - ) + return T.Compose([T.ToTensor(), T.Resize(self.image_size), T.Normalize(self.mean, self.std)]) def train_input_per_sample_transform(self): return T.Compose( [ T.ToTensor(), T.Resize(self.image_size), - T.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225]), + T.Normalize(self.mean, self.std), T.RandomHorizontalFlip(), T.ColorJitter(), T.RandomAutocontrast(), diff --git a/flash/image/classification/input_transform.py b/flash/image/classification/input_transform.py index 60f3f05901..598db6c46e 100644 --- a/flash/image/classification/input_transform.py +++ b/flash/image/classification/input_transform.py @@ -12,7 +12,7 @@ # See the License for the specific language governing permissions and # limitations under the License. from dataclasses import dataclass -from typing import Callable, Tuple +from typing import Callable, Tuple, Union import torch @@ -43,20 +43,15 @@ def forward(self, x): class ImageClassificationInputTransform(InputTransform): image_size: Tuple[int, int] = (196, 196) + mean: Union[float, Tuple[float, float, float]] = (0.485, 0.456, 0.406) + std: Union[float, Tuple[float, float, float]] = (0.229, 0.224, 0.225) def input_per_sample_transform(self): - return T.Compose( - [T.ToTensor(), T.Resize(self.image_size), T.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])] - ) + return T.Compose([T.ToTensor(), T.Resize(self.image_size), T.Normalize(self.mean, self.std)]) def train_input_per_sample_transform(self): return T.Compose( - [ - T.ToTensor(), - T.Resize(self.image_size), - T.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225]), - T.RandomHorizontalFlip(), - ] + [T.ToTensor(), T.Resize(self.image_size), T.Normalize(self.mean, self.std), T.RandomHorizontalFlip()] ) def target_per_sample_transform(self) -> Callable: diff --git a/flash_examples/image_classification.py b/flash_examples/image_classification.py index aa096b80e4..9c4dfb37d0 100644 --- a/flash_examples/image_classification.py +++ b/flash_examples/image_classification.py @@ -24,7 +24,7 @@ train_folder="data/hymenoptera_data/train/", val_folder="data/hymenoptera_data/val/", batch_size=4, - transform_kwargs={"image_size": (196, 196)}, + transform_kwargs={"image_size": (196, 196), "mean": (0.485, 0.456, 0.406), "std": (0.229, 0.224, 0.225)}, ) # 2. Build the task