Skip to content
This repository has been archived by the owner on Oct 9, 2023. It is now read-only.

Commit

Permalink
Expose normalization parameters (#1178)
Browse files Browse the repository at this point in the history
Co-authored-by: Ethan Harris <[email protected]>
  • Loading branch information
AndresAlgaba and ethanwharris committed Mar 1, 2022
1 parent cf4609e commit b2d7f5c
Show file tree
Hide file tree
Showing 4 changed files with 26 additions and 17 deletions.
16 changes: 15 additions & 1 deletion CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,21 @@ All notable changes to this project will be documented in this file.

The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).

## [0.7.0] - 2022-15-02
## [0.7.1] - 2022-03-01

### Added

- Added the normalization parameters of ```torchvision.transforms.Normalize``` as ```transform_kwargs``` in the ```ImageClassificationInputTransform``` ([#1178](https://github.com/PyTorchLightning/lightning-flash/pull/1178))

### Changed

### Deprecated

### Removed

### Fixed

## [0.7.0] - 2022-02-15

### Added

Expand Down
10 changes: 5 additions & 5 deletions docs/source/reference/image_classification.rst
Original file line number Diff line number Diff line change
Expand Up @@ -101,7 +101,7 @@ Here's an example:

from torchvision import transforms as T

from typing import Tuple, Callable
from typing import Callable, Tuple, Union
import flash
from flash.image import ImageClassificationData, ImageClassifier
from flash.core.data.io.input_transform import InputTransform
Expand All @@ -112,18 +112,18 @@ Here's an example:
class ImageClassificationInputTransform(InputTransform):

image_size: Tuple[int, int] = (196, 196)
mean: Union[float, Tuple[float, float, float]] = (0.485, 0.456, 0.406)
std: Union[float, Tuple[float, float, float]] = (0.229, 0.224, 0.225)

def input_per_sample_transform(self):
return T.Compose(
[T.ToTensor(), T.Resize(self.image_size), T.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])]
)
return T.Compose([T.ToTensor(), T.Resize(self.image_size), T.Normalize(self.mean, self.std)])

def train_input_per_sample_transform(self):
return T.Compose(
[
T.ToTensor(),
T.Resize(self.image_size),
T.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225]),
T.Normalize(self.mean, self.std),
T.RandomHorizontalFlip(),
T.ColorJitter(),
T.RandomAutocontrast(),
Expand Down
15 changes: 5 additions & 10 deletions flash/image/classification/input_transform.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,7 @@
# See the License for the specific language governing permissions and
# limitations under the License.
from dataclasses import dataclass
from typing import Callable, Tuple
from typing import Callable, Tuple, Union

import torch

Expand Down Expand Up @@ -43,20 +43,15 @@ def forward(self, x):
class ImageClassificationInputTransform(InputTransform):

image_size: Tuple[int, int] = (196, 196)
mean: Union[float, Tuple[float, float, float]] = (0.485, 0.456, 0.406)
std: Union[float, Tuple[float, float, float]] = (0.229, 0.224, 0.225)

def input_per_sample_transform(self):
return T.Compose(
[T.ToTensor(), T.Resize(self.image_size), T.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])]
)
return T.Compose([T.ToTensor(), T.Resize(self.image_size), T.Normalize(self.mean, self.std)])

def train_input_per_sample_transform(self):
return T.Compose(
[
T.ToTensor(),
T.Resize(self.image_size),
T.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225]),
T.RandomHorizontalFlip(),
]
[T.ToTensor(), T.Resize(self.image_size), T.Normalize(self.mean, self.std), T.RandomHorizontalFlip()]
)

def target_per_sample_transform(self) -> Callable:
Expand Down
2 changes: 1 addition & 1 deletion flash_examples/image_classification.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,7 +24,7 @@
train_folder="data/hymenoptera_data/train/",
val_folder="data/hymenoptera_data/val/",
batch_size=4,
transform_kwargs={"image_size": (196, 196)},
transform_kwargs={"image_size": (196, 196), "mean": (0.485, 0.456, 0.406), "std": (0.229, 0.224, 0.225)},
)

# 2. Build the task
Expand Down

0 comments on commit b2d7f5c

Please sign in to comment.