diff --git a/CHANGELOG.md b/CHANGELOG.md index c2feab80eb..41b3c3c94d 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -8,6 +8,8 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/). ### Added +- Added the normalization parameters of ```torchvision.transforms.Normalize``` as ```transform_kwargs``` in the ```ImageClassificationInputTransform``` ([#1178](https://github.com/PyTorchLightning/lightning-flash/pull/1178)) + ### Changed ### Deprecated diff --git a/docs/source/reference/image_classification.rst b/docs/source/reference/image_classification.rst index fa129662cf..6906cf2b8f 100644 --- a/docs/source/reference/image_classification.rst +++ b/docs/source/reference/image_classification.rst @@ -101,7 +101,7 @@ Here's an example: from torchvision import transforms as T - from typing import Tuple, Callable + from typing import Callable, Tuple, Union import flash from flash.image import ImageClassificationData, ImageClassifier from flash.core.data.io.input_transform import InputTransform @@ -112,18 +112,18 @@ Here's an example: class ImageClassificationInputTransform(InputTransform): image_size: Tuple[int, int] = (196, 196) + mean: Union[float, Tuple[float, float, float]] = (0.485, 0.456, 0.406) + std: Union[float, Tuple[float, float, float]] = (0.229, 0.224, 0.225) def input_per_sample_transform(self): - return T.Compose( - [T.ToTensor(), T.Resize(self.image_size), T.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])] - ) + return T.Compose([T.ToTensor(), T.Resize(self.image_size), T.Normalize(self.mean, self.std)]) def train_input_per_sample_transform(self): return T.Compose( [ T.ToTensor(), T.Resize(self.image_size), - T.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225]), + T.Normalize(self.mean, self.std), T.RandomHorizontalFlip(), T.ColorJitter(), T.RandomAutocontrast(), diff --git a/flash/image/classification/input_transform.py b/flash/image/classification/input_transform.py index 60f3f05901..598db6c46e 100644 --- a/flash/image/classification/input_transform.py +++ b/flash/image/classification/input_transform.py @@ -12,7 +12,7 @@ # See the License for the specific language governing permissions and # limitations under the License. from dataclasses import dataclass -from typing import Callable, Tuple +from typing import Callable, Tuple, Union import torch @@ -43,20 +43,15 @@ def forward(self, x): class ImageClassificationInputTransform(InputTransform): image_size: Tuple[int, int] = (196, 196) + mean: Union[float, Tuple[float, float, float]] = (0.485, 0.456, 0.406) + std: Union[float, Tuple[float, float, float]] = (0.229, 0.224, 0.225) def input_per_sample_transform(self): - return T.Compose( - [T.ToTensor(), T.Resize(self.image_size), T.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])] - ) + return T.Compose([T.ToTensor(), T.Resize(self.image_size), T.Normalize(self.mean, self.std)]) def train_input_per_sample_transform(self): return T.Compose( - [ - T.ToTensor(), - T.Resize(self.image_size), - T.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225]), - T.RandomHorizontalFlip(), - ] + [T.ToTensor(), T.Resize(self.image_size), T.Normalize(self.mean, self.std), T.RandomHorizontalFlip()] ) def target_per_sample_transform(self) -> Callable: diff --git a/flash_examples/image_classification.py b/flash_examples/image_classification.py index aa096b80e4..9c4dfb37d0 100644 --- a/flash_examples/image_classification.py +++ b/flash_examples/image_classification.py @@ -24,7 +24,7 @@ train_folder="data/hymenoptera_data/train/", val_folder="data/hymenoptera_data/val/", batch_size=4, - transform_kwargs={"image_size": (196, 196)}, + transform_kwargs={"image_size": (196, 196), "mean": (0.485, 0.456, 0.406), "std": (0.229, 0.224, 0.225)}, ) # 2. Build the task