
enforce random sampling at first call in kornia parallel transforms #351

Merged · 7 commits · Jun 1, 2021
6 changes: 2 additions & 4 deletions CHANGELOG.md
@@ -13,12 +13,10 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).
 ### Changed
 
 - Changed the installation command for extra features ([#346](https://github.com/PyTorchLightning/lightning-flash/pull/346))
 - Fixed a bug where the translation task wasn't decoding tokens properly ([#332](https://github.com/PyTorchLightning/lightning-flash/pull/332))
 - Fixed a bug where huggingface tokenizers were sometimes being pickled ([#332](https://github.com/PyTorchLightning/lightning-flash/pull/332))
+- Fixed an issue with `KorniaParallelTransforms` to ensure the random state is shared between transforms ([#351](https://github.com/PyTorchLightning/lightning-flash/pull/351))
+- Changed the resize interpolation default mode to nearest ([#352](https://github.com/PyTorchLightning/lightning-flash/pull/352))
 
 ## [0.3.0] - 2021-05-20
11 changes: 8 additions & 3 deletions flash/core/data/transforms.py
@@ -78,16 +78,21 @@ def forward(self, inputs: Any):
         result = list(inputs) if isinstance(inputs, Sequence) else [inputs]
         for transform in self.children():
             inputs = result
+
+            # the first call is made without params, so the transform samples (and caches) its random params
+            result[0] = transform(inputs[0])
+
             if hasattr(transform, "_params") and bool(transform._params):
                 params = transform._params
             else:
                 params = None
 
-            for i, input in enumerate(inputs):
+            # apply the transform to the remaining inputs, reusing the sampled params
+            for i, input in enumerate(inputs[1:]):
                 if params is not None:
-                    result[i] = transform(input, params)
+                    result[i + 1] = transform(input, params)
                 else:  # case for non-random transforms
-                    result[i] = transform(input)
+                    result[i + 1] = transform(input)
             if hasattr(transform, "_params") and bool(transform._params):
                 transform._params = None
         return result
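To make the new control flow concrete, here is a minimal usage sketch (not part of the diff). It assumes, as the `forward` above does, that kornia's random augmentations cache the parameters they sample on `_params` and accept an explicit `params` argument on later calls; `RandomHorizontalFlip` is just a convenient stand-in for any random augmentation.

```python
import torch
import kornia.augmentation as K

from flash.core.data.transforms import KorniaParallelTransforms

image = torch.rand(1, 3, 32, 32)
mask = torch.randint(0, 2, (1, 1, 32, 32)).float()

# The wrapped augmentation is applied to every input: the first input triggers
# random sampling, the remaining inputs reuse the cached params, so the image
# and its mask receive the exact same flip.
transform = KorniaParallelTransforms(K.RandomHorizontalFlip(p=0.5))
image_out, mask_out = transform([image, mask])

# The same pattern written out by hand with a bare kornia augmentation:
aug = K.RandomHorizontalFlip(p=1.0)
image_flipped = aug(image)             # first call samples and caches aug._params
mask_flipped = aug(mask, aug._params)  # reuse the cached params for the mask
```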
2 changes: 1 addition & 1 deletion flash/image/segmentation/transforms.py
@@ -38,7 +38,7 @@ def default_transforms(image_size: Tuple[int, int]) -> Dict[str, Callable]:
"post_tensor_transform": nn.Sequential(
ApplyToKeys(
[DefaultDataKeys.INPUT, DefaultDataKeys.TARGET],
KorniaParallelTransforms(K.geometry.Resize(image_size, interpolation='bilinear')),
KorniaParallelTransforms(K.geometry.Resize(image_size, interpolation='nearest')),
             ),
         ),
         "collate": Compose([kornia_collate, ApplyToKeys(DefaultDataKeys.TARGET, prepare_target)]),
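As an aside (not from the PR), the switch from 'bilinear' to 'nearest' matters because the same resize is applied to both the image and the segmentation target: bilinear resizing blends neighbouring class ids into fractional values, whereas nearest-neighbour resizing keeps every pixel at an existing label. A small illustration, assuming kornia's `K.geometry.Resize` as used above:

```python
import torch
import kornia as K

# A 4x4 mask with three classes (0, 1, 2), stored as float so it can be resized.
mask = torch.tensor([
    [0., 0., 2., 2.],
    [0., 0., 2., 2.],
    [1., 1., 1., 1.],
    [1., 1., 1., 1.],
]).view(1, 1, 4, 4)

bilinear = K.geometry.Resize((3, 3), interpolation="bilinear")(mask)
nearest = K.geometry.Resize((3, 3), interpolation="nearest")(mask)

print(bilinear.unique())  # contains blended values (e.g. 0.5) that are not valid class ids
print(nearest.unique())   # only 0., 1., 2. -- still a valid mask
```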
7 changes: 4 additions & 3 deletions tests/core/data/test_transforms.py
@@ -97,17 +97,18 @@ def test_kornia_parallel_transforms(with_params):
     transform_b = Mock(spec=torch.nn.Module)
 
     if with_params:
-        transform_a._params = "test"
+        transform_a._params = "test"  # initialize params with some value
 
     parallel_transforms = KorniaParallelTransforms(transform_a, transform_b)
 
     parallel_transforms(samples)
 
     assert transform_a.call_count == 2
     assert transform_b.call_count == 2
 
     if with_params:
-        assert transform_a.call_args_list[0][0][1] == transform_a.call_args_list[1][0][1] == "test"
+        assert transform_a.call_args_list[1][0][1] == "test"
+        # check that after the forward `_params` is set to None
+        assert transform_a._params is None
 
     assert torch.allclose(transform_a.call_args_list[0][0][0], samples[0])
     assert torch.allclose(transform_a.call_args_list[1][0][0], samples[1])
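A brief aside on the `Mock` assertions (not part of the diff): `call_args_list[i][0]` holds the positional arguments of the i-th call, so `call_args_list[1][0][1]` is the second positional argument of the second call. Under the new `forward`, only the second and later calls receive the cached params, which is why the old assertion on the first call was dropped. A minimal sketch of that indexing:

```python
from unittest.mock import Mock

transform = Mock()
transform("input_0")            # first call: no params passed, the transform samples them
transform("input_1", "params")  # second call: the cached params are passed explicitly

assert transform.call_args_list[1][0][1] == "params"  # second call, positional args, index 1
```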