Skip to content
This repository has been archived by the owner on Oct 9, 2023. It is now read-only.

Commit

Permalink
Fix object detection / keypoint detection / instance segmentation (#1072
Browse files Browse the repository at this point in the history
)
  • Loading branch information
ethanwharris authored Dec 14, 2021
1 parent 924aee8 commit ba72af6
Show file tree
Hide file tree
Showing 13 changed files with 193 additions and 82 deletions.
2 changes: 2 additions & 0 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,8 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).

- Fixed a bug when not explicitly passing `embedding_sizes` to the `TabularClassifier` and `TabularRegressor` tasks ([#1067](https://github.com/PyTorchLightning/lightning-flash/pull/1067))

- Fixed a bug where under some circumstances transforms would not get called ([#1072](https://github.com/PyTorchLightning/lightning-flash/pull/1072))

### Removed

## [0.6.0] - 2021-13-12
Expand Down
149 changes: 101 additions & 48 deletions flash/core/data/data_module.py
Original file line number Diff line number Diff line change
Expand Up @@ -28,7 +28,7 @@
from flash.core.data.callback import BaseDataFetcher
from flash.core.data.data_pipeline import DataPipeline, DataPipelineState
from flash.core.data.io.input import DataKeys, Input, InputBase, IterableInput
from flash.core.data.io.input_transform import InputTransform
from flash.core.data.io.input_transform import _InputTransformProcessorV2, InputTransform
from flash.core.data.io.output_transform import OutputTransform
from flash.core.data.splits import SplitDataset
from flash.core.data.utils import _STAGES_PREFIX
Expand Down Expand Up @@ -193,8 +193,18 @@ def _resolve_on_after_batch_transfer_fn(self, ds: Optional[Input]) -> Optional[C
return ds._create_on_after_batch_transfer_fn([self.data_fetcher])

def _train_dataloader(self) -> DataLoader:
if isinstance(getattr(self, "trainer", None), pl.Trainer):
if isinstance(self.trainer.lightning_module, flash.Task):
self.connect(self.trainer.lightning_module)

train_ds: Input = self._train_input
collate_fn = self._train_dataloader_collate_fn

transform_processor = None
if isinstance(collate_fn, _InputTransformProcessorV2):
transform_processor = collate_fn
collate_fn = transform_processor.collate_fn

shuffle: bool = False
if isinstance(train_ds, IterableDataset):
drop_last = False
Expand All @@ -208,9 +218,7 @@ def _train_dataloader(self) -> DataLoader:
sampler = self.sampler(train_ds)

if isinstance(getattr(self, "trainer", None), pl.Trainer):
if isinstance(self.trainer.lightning_module, flash.Task):
self.connect(self.trainer.lightning_module)
return self.trainer.lightning_module.process_train_dataset(
dataloader = self.trainer.lightning_module.process_train_dataset(
train_ds,
trainer=self.trainer,
batch_size=self.batch_size,
Expand All @@ -221,97 +229,142 @@ def _train_dataloader(self) -> DataLoader:
collate_fn=collate_fn,
sampler=sampler,
)
else:
dataloader = DataLoader(
train_ds,
batch_size=self.batch_size,
shuffle=shuffle,
sampler=sampler,
num_workers=self.num_workers,
pin_memory=self.pin_memory,
drop_last=drop_last,
collate_fn=collate_fn,
persistent_workers=self.persistent_workers,
)

return DataLoader(
train_ds,
batch_size=self.batch_size,
shuffle=shuffle,
sampler=sampler,
num_workers=self.num_workers,
pin_memory=self.pin_memory,
drop_last=drop_last,
collate_fn=collate_fn,
persistent_workers=self.persistent_workers,
)
if transform_processor is not None:
transform_processor.collate_fn = dataloader.collate_fn
dataloader.collate_fn = transform_processor

return dataloader

def _val_dataloader(self) -> DataLoader:
if isinstance(getattr(self, "trainer", None), pl.Trainer):
if isinstance(self.trainer.lightning_module, flash.Task):
self.connect(self.trainer.lightning_module)

val_ds: Input = self._val_input
collate_fn = self._val_dataloader_collate_fn

transform_processor = None
if isinstance(collate_fn, _InputTransformProcessorV2):
transform_processor = collate_fn
collate_fn = transform_processor.collate_fn

if isinstance(getattr(self, "trainer", None), pl.Trainer):
if isinstance(self.trainer.lightning_module, flash.Task):
self.connect(self.trainer.lightning_module)
return self.trainer.lightning_module.process_val_dataset(
dataloader = self.trainer.lightning_module.process_val_dataset(
val_ds,
trainer=self.trainer,
batch_size=self.batch_size,
num_workers=self.num_workers,
pin_memory=self.pin_memory,
collate_fn=collate_fn,
)
else:
dataloader = DataLoader(
val_ds,
batch_size=self.batch_size,
num_workers=self.num_workers,
pin_memory=self.pin_memory,
collate_fn=collate_fn,
persistent_workers=self.persistent_workers,
)

return DataLoader(
val_ds,
batch_size=self.batch_size,
num_workers=self.num_workers,
pin_memory=self.pin_memory,
collate_fn=collate_fn,
persistent_workers=self.persistent_workers,
)
if transform_processor is not None:
transform_processor.collate_fn = dataloader.collate_fn
dataloader.collate_fn = transform_processor

return dataloader

def _test_dataloader(self) -> DataLoader:
if isinstance(getattr(self, "trainer", None), pl.Trainer):
if isinstance(self.trainer.lightning_module, flash.Task):
self.connect(self.trainer.lightning_module)

test_ds: Input = self._test_input
collate_fn = self._test_dataloader_collate_fn

transform_processor = None
if isinstance(collate_fn, _InputTransformProcessorV2):
transform_processor = collate_fn
collate_fn = transform_processor.collate_fn

if isinstance(getattr(self, "trainer", None), pl.Trainer):
if isinstance(self.trainer.lightning_module, flash.Task):
self.connect(self.trainer.lightning_module)
return self.trainer.lightning_module.process_test_dataset(
dataloader = self.trainer.lightning_module.process_test_dataset(
test_ds,
trainer=self.trainer,
batch_size=self.batch_size,
num_workers=self.num_workers,
pin_memory=self.pin_memory,
collate_fn=collate_fn,
)
else:
dataloader = DataLoader(
test_ds,
batch_size=self.batch_size,
num_workers=self.num_workers,
pin_memory=self.pin_memory,
collate_fn=collate_fn,
persistent_workers=self.persistent_workers,
)

return DataLoader(
test_ds,
batch_size=self.batch_size,
num_workers=self.num_workers,
pin_memory=self.pin_memory,
collate_fn=collate_fn,
persistent_workers=self.persistent_workers,
)
if transform_processor is not None:
transform_processor.collate_fn = dataloader.collate_fn
dataloader.collate_fn = transform_processor

return dataloader

def _predict_dataloader(self) -> DataLoader:
if isinstance(getattr(self, "trainer", None), pl.Trainer):
if isinstance(self.trainer.lightning_module, flash.Task):
self.connect(self.trainer.lightning_module)

predict_ds: Input = self._predict_input
collate_fn = self._predict_dataloader_collate_fn

transform_processor = None
if isinstance(collate_fn, _InputTransformProcessorV2):
transform_processor = collate_fn
collate_fn = transform_processor.collate_fn

if isinstance(predict_ds, IterableDataset):
batch_size = self.batch_size
else:
batch_size = min(self.batch_size, len(predict_ds) if len(predict_ds) > 0 else 1)

if isinstance(getattr(self, "trainer", None), pl.Trainer):
if isinstance(self.trainer.lightning_module, flash.Task):
self.connect(self.trainer.lightning_module)
return self.trainer.lightning_module.process_predict_dataset(
dataloader = self.trainer.lightning_module.process_predict_dataset(
predict_ds,
batch_size=batch_size,
num_workers=self.num_workers,
pin_memory=self.pin_memory,
collate_fn=collate_fn,
)
else:
dataloader = DataLoader(
predict_ds,
batch_size=batch_size,
num_workers=self.num_workers,
pin_memory=self.pin_memory,
collate_fn=collate_fn,
persistent_workers=self.persistent_workers,
)

return DataLoader(
predict_ds,
batch_size=batch_size,
num_workers=self.num_workers,
pin_memory=self.pin_memory,
collate_fn=collate_fn,
persistent_workers=self.persistent_workers,
)
if transform_processor is not None:
transform_processor.collate_fn = dataloader.collate_fn
dataloader.collate_fn = transform_processor

return dataloader

def connect(self, task: "flash.Task"):
data_pipeline_state = DataPipelineState()
Expand Down
13 changes: 6 additions & 7 deletions flash/core/data/io/input_transform.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,6 @@
from functools import partial, wraps
from typing import Any, Callable, Dict, List, Mapping, Optional, Sequence, Tuple, Union

import torch
from pytorch_lightning.utilities.enums import LightningEnum
from pytorch_lightning.utilities.exceptions import MisconfigurationException
from torch.utils.data._utils.collate import default_collate
Expand All @@ -33,7 +32,7 @@
PerSampleTransformOnDevice,
)
from flash.core.data.transforms import ApplyToKeys
from flash.core.data.utils import _INPUT_TRANSFORM_FUNCS, _STAGES_PREFIX, convert_to_modules
from flash.core.data.utils import _INPUT_TRANSFORM_FUNCS, _STAGES_PREFIX
from flash.core.registry import FlashRegistry
from flash.core.utilities.stages import RunningStage
from flash.core.utilities.types import INPUT_TRANSFORM_TYPE
Expand Down Expand Up @@ -1118,7 +1117,7 @@ def _make_collates(input_transform: "InputTransform", on_device: bool, collate:
return collate, input_transform._identity


class _InputTransformProcessorV2(torch.nn.Module):
class _InputTransformProcessorV2:
"""
This class is used to encapsulate the following functions of a InputTransformInputTransform Object:
Inside a worker:
Expand Down Expand Up @@ -1146,9 +1145,9 @@ def __init__(
super().__init__()
self.input_transform = input_transform
self.callback = ControlFlow(callbacks or [])
self.collate_fn = convert_to_modules(collate_fn)
self.per_sample_transform = convert_to_modules(per_sample_transform)
self.per_batch_transform = convert_to_modules(per_batch_transform)
self.collate_fn = collate_fn
self.per_sample_transform = per_sample_transform
self.per_batch_transform = per_batch_transform
self.apply_per_sample_transform = apply_per_sample_transform
self.stage = stage
self.on_device = on_device
Expand All @@ -1160,7 +1159,7 @@ def _extract_metadata(
metadata = [s.pop(DataKeys.METADATA, None) if isinstance(s, Mapping) else None for s in samples]
return samples, metadata if any(m is not None for m in metadata) else None

def forward(self, samples: Sequence[Any]) -> Any:
def __call__(self, samples: Sequence[Any]) -> Any:
if not self.on_device:
for sample in samples:
self.callback.on_load_sample(sample, self.stage)
Expand Down
4 changes: 2 additions & 2 deletions flash/image/detection/data.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,7 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from typing import Any, Dict, List, Optional, Type, TYPE_CHECKING
from typing import Any, Callable, Dict, List, Optional, Type, TYPE_CHECKING, Union

from flash.core.data.data_module import DataModule
from flash.core.data.data_pipeline import DataPipelineState
Expand Down Expand Up @@ -58,7 +58,7 @@ def from_icedata(
val_transform: INPUT_TRANSFORM_TYPE = IceVisionInputTransform,
test_transform: INPUT_TRANSFORM_TYPE = IceVisionInputTransform,
predict_transform: INPUT_TRANSFORM_TYPE = IceVisionInputTransform,
parser: Type[Parser] = Parser,
parser: Optional[Union[Callable, Type[Parser]]] = None,
input_cls: Type[Input] = IceVisionInput,
transform_kwargs: Optional[Dict] = None,
**data_module_kwargs,
Expand Down
4 changes: 2 additions & 2 deletions flash/image/instance_segmentation/data.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,7 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from typing import Any, Dict, List, Optional, Type
from typing import Any, Callable, Dict, List, Optional, Type, Union

from flash.core.data.data_module import DataModule
from flash.core.data.data_pipeline import DataPipelineState
Expand Down Expand Up @@ -56,7 +56,7 @@ def from_icedata(
val_transform: INPUT_TRANSFORM_TYPE = InstanceSegmentationInputTransform,
test_transform: INPUT_TRANSFORM_TYPE = InstanceSegmentationInputTransform,
predict_transform: INPUT_TRANSFORM_TYPE = InstanceSegmentationInputTransform,
parser: Optional[Type[Parser]] = Parser,
parser: Optional[Union[Callable, Type[Parser]]] = None,
input_cls: Type[Input] = IceVisionInput,
transform_kwargs: Optional[Dict] = None,
**data_module_kwargs,
Expand Down
28 changes: 14 additions & 14 deletions flash/image/keypoint_detection/data.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,16 +11,16 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from typing import Any, Dict, List, Optional, Type
from typing import Any, Callable, Dict, List, Optional, Type, Union

from flash.core.data.data_module import DataModule
from flash.core.data.data_pipeline import DataPipelineState
from flash.core.data.io.input import Input
from flash.core.integrations.icevision.data import IceVisionInput
from flash.core.integrations.icevision.transforms import IceVisionInputTransform
from flash.core.utilities.imports import _ICEVISION_AVAILABLE
from flash.core.utilities.stages import RunningStage
from flash.core.utilities.types import INPUT_TRANSFORM_TYPE
from flash.image.keypoint_detection.input_transform import KeypointDetectionInputTransform

if _ICEVISION_AVAILABLE:
from icevision.parsers import COCOKeyPointsParser, Parser
Expand All @@ -31,7 +31,7 @@

class KeypointDetectionData(DataModule):

input_transform_cls = IceVisionInputTransform
input_transform_cls = KeypointDetectionInputTransform

@classmethod
def from_icedata(
Expand All @@ -43,11 +43,11 @@ def from_icedata(
test_folder: Optional[str] = None,
test_ann_file: Optional[str] = None,
predict_folder: Optional[str] = None,
train_transform: INPUT_TRANSFORM_TYPE = IceVisionInputTransform,
val_transform: INPUT_TRANSFORM_TYPE = IceVisionInputTransform,
test_transform: INPUT_TRANSFORM_TYPE = IceVisionInputTransform,
predict_transform: INPUT_TRANSFORM_TYPE = IceVisionInputTransform,
parser: Optional[Type[Parser]] = Parser,
train_transform: INPUT_TRANSFORM_TYPE = KeypointDetectionInputTransform,
val_transform: INPUT_TRANSFORM_TYPE = KeypointDetectionInputTransform,
test_transform: INPUT_TRANSFORM_TYPE = KeypointDetectionInputTransform,
predict_transform: INPUT_TRANSFORM_TYPE = KeypointDetectionInputTransform,
parser: Optional[Union[Callable, Type[Parser]]] = None,
input_cls: Type[Input] = IceVisionInput,
transform_kwargs: Optional[Dict] = None,
**data_module_kwargs: Any,
Expand All @@ -73,10 +73,10 @@ def from_coco(
test_folder: Optional[str] = None,
test_ann_file: Optional[str] = None,
predict_folder: Optional[str] = None,
train_transform: INPUT_TRANSFORM_TYPE = IceVisionInputTransform,
val_transform: INPUT_TRANSFORM_TYPE = IceVisionInputTransform,
test_transform: INPUT_TRANSFORM_TYPE = IceVisionInputTransform,
predict_transform: INPUT_TRANSFORM_TYPE = IceVisionInputTransform,
train_transform: INPUT_TRANSFORM_TYPE = KeypointDetectionInputTransform,
val_transform: INPUT_TRANSFORM_TYPE = KeypointDetectionInputTransform,
test_transform: INPUT_TRANSFORM_TYPE = KeypointDetectionInputTransform,
predict_transform: INPUT_TRANSFORM_TYPE = KeypointDetectionInputTransform,
parser: Optional[Type[Parser]] = COCOKeyPointsParser,
input_cls: Type[Input] = IceVisionInput,
transform_kwargs: Optional[Dict] = None,
Expand Down Expand Up @@ -119,7 +119,7 @@ def from_coco(
def from_folders(
cls,
predict_folder: Optional[str] = None,
predict_transform: INPUT_TRANSFORM_TYPE = IceVisionInputTransform,
predict_transform: INPUT_TRANSFORM_TYPE = KeypointDetectionInputTransform,
input_cls: Type[Input] = IceVisionInput,
transform_kwargs: Optional[Dict] = None,
**data_module_kwargs: Any,
Expand Down Expand Up @@ -149,7 +149,7 @@ def from_folders(
def from_files(
cls,
predict_files: Optional[List[str]] = None,
predict_transform: INPUT_TRANSFORM_TYPE = IceVisionInputTransform,
predict_transform: INPUT_TRANSFORM_TYPE = KeypointDetectionInputTransform,
input_cls: Type[Input] = IceVisionInput,
transform_kwargs: Optional[Dict] = None,
**data_module_kwargs: Any,
Expand Down
Loading

0 comments on commit ba72af6

Please sign in to comment.