Skip to content
This repository has been archived by the owner on Oct 9, 2023. It is now read-only.

Commit

Permalink
Refactor InputTransform and DataModule (#1233)
Browse files: browse the repository at this point in the history
Co-authored-by: Kushashwa Ravi Shrimali <[email protected]>
Co-authored-by: Ethan Harris <[email protected]>
  • Loading branch information
3 people authored Mar 25, 2022
1 parent 9001449 commit 6da53fe
Show file tree
Hide file tree
Showing 57 changed files with 1,386 additions and 1,856 deletions.
2 changes: 1 addition & 1 deletion README.md
Original file line number Diff line number Diff line change
Expand Up @@ -259,7 +259,7 @@ class MixUpInputTransform(InputTransform):

datamodule = ImageClassificationData.from_folders(
train_folder="data/train",
train_transform=MixUpInputTransform,
transform=MixUpInputTransform,
batch_size=2,
)

Expand Down
4 changes: 2 additions & 2 deletions docs/source/integrations/icevision.rst
Original file line number Diff line number Diff line change
Expand Up @@ -34,11 +34,11 @@ Here's an example:
from flash.core.integrations.icevision.transforms import IceVisionTransformAdapter
from flash.image import ObjectDetectionData
train_transform = {
transform = {
"per_sample_transform": IceVisionTransformAdapter([A.HorizontalFlip(), A.Normalize()]),
}
datamodule = ObjectDetectionData.from_coco(
...,
train_transform=train_transform,
transform=transform,
)
2 changes: 1 addition & 1 deletion docs/source/reference/image_classification.rst
Original file line number Diff line number Diff line change
Expand Up @@ -142,7 +142,7 @@ Here's an example:
datamodule = ImageClassificationData.from_folders(
train_folder="data/hymenoptera_data/train/",
val_folder="data/hymenoptera_data/val/",
train_transform=ImageClassificationInputTransform,
transform=ImageClassificationInputTransform,
transform_kwargs=dict(image_size=(128, 128)),
batch_size=1,
)
Expand Down
2 changes: 1 addition & 1 deletion docs/source/reference/object_detection.rst
Original file line number Diff line number Diff line change
Expand Up @@ -112,6 +112,6 @@ creating a subclass of :class:`~flash.core.data.io.input_transform.InputTransform`
train_folder="data/coco128/images/train2017/",
train_ann_file="data/coco128/annotations/instances_train2017.json",
val_split=0.1,
train_transform=BrightnessContrastTransform,
transform=BrightnessContrastTransform,
batch_size=4,
)
152 changes: 54 additions & 98 deletions flash/audio/classification/data.py

Large diffs are not rendered by default.

94 changes: 32 additions & 62 deletions flash/audio/speech_recognition/data.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,7 +25,6 @@
from flash.core.data.data_module import DataModule
from flash.core.data.io.input import Input
from flash.core.data.io.input_transform import INPUT_TRANSFORM_TYPE, InputTransform
from flash.core.registry import FlashRegistry
from flash.core.utilities.imports import _AUDIO_TESTING
from flash.core.utilities.stages import RunningStage

Expand All @@ -40,7 +39,6 @@ class SpeechRecognitionData(DataModule):

input_transform_cls = InputTransform
output_transform_cls = SpeechRecognitionOutputTransform
input_transforms_registry = FlashRegistry("input_transforms")

@classmethod
def from_files(
Expand All @@ -53,11 +51,8 @@ def from_files(
test_targets: Optional[Sequence[str]] = None,
predict_files: Optional[Sequence[str]] = None,
sampling_rate: int = 16000,
train_transform: INPUT_TRANSFORM_TYPE = InputTransform,
val_transform: INPUT_TRANSFORM_TYPE = InputTransform,
test_transform: INPUT_TRANSFORM_TYPE = InputTransform,
predict_transform: INPUT_TRANSFORM_TYPE = InputTransform,
input_cls: Type[Input] = SpeechRecognitionPathsInput,
transform: INPUT_TRANSFORM_TYPE = InputTransform,
transform_kwargs: Optional[Dict] = None,
**data_module_kwargs: Any,
) -> "SpeechRecognitionData":
Expand All @@ -77,12 +72,8 @@ def from_files(
test_targets: The list of targets (ground truth speech transcripts) to use when testing.
predict_files: The list of audio files to use when predicting.
sampling_rate: Sampling rate to use when loading the audio files.
train_transform: The :class:`~flash.core.data.io.input_transform.InputTransform` type to use when training.
val_transform: The :class:`~flash.core.data.io.input_transform.InputTransform` type to use when validating.
test_transform: The :class:`~flash.core.data.io.input_transform.InputTransform` type to use when testing.
predict_transform: The :class:`~flash.core.data.io.input_transform.InputTransform` type to use when
predicting.
input_cls: The :class:`~flash.core.data.io.input.Input` type to use for loading the data.
transform: The :class:`~flash.core.data.io.input_transform.InputTransform` type to use.
transform_kwargs: Dict of keyword arguments to be provided when instantiating the transforms.
data_module_kwargs: Additional keyword arguments to provide to the
:class:`~flash.core.data.data_module.DataModule` constructor.
Expand Down Expand Up @@ -127,16 +118,16 @@ def from_files(
"""

ds_kw = dict(
transform_kwargs=transform_kwargs,
input_transforms_registry=cls.input_transforms_registry,
sampling_rate=sampling_rate,
)

return cls(
input_cls(RunningStage.TRAINING, train_files, train_targets, transform=train_transform, **ds_kw),
input_cls(RunningStage.VALIDATING, val_files, val_targets, transform=val_transform, **ds_kw),
input_cls(RunningStage.TESTING, test_files, test_targets, transform=test_transform, **ds_kw),
input_cls(RunningStage.PREDICTING, predict_files, transform=predict_transform, **ds_kw),
input_cls(RunningStage.TRAINING, train_files, train_targets, **ds_kw),
input_cls(RunningStage.VALIDATING, val_files, val_targets, **ds_kw),
input_cls(RunningStage.TESTING, test_files, test_targets, **ds_kw),
input_cls(RunningStage.PREDICTING, predict_files, **ds_kw),
transform=transform,
transform_kwargs=transform_kwargs,
**data_module_kwargs,
)

Expand All @@ -150,11 +141,8 @@ def from_csv(
test_file: Optional[str] = None,
predict_file: Optional[str] = None,
sampling_rate: int = 16000,
train_transform: INPUT_TRANSFORM_TYPE = InputTransform,
val_transform: INPUT_TRANSFORM_TYPE = InputTransform,
test_transform: INPUT_TRANSFORM_TYPE = InputTransform,
predict_transform: INPUT_TRANSFORM_TYPE = InputTransform,
input_cls: Type[Input] = SpeechRecognitionCSVInput,
transform: INPUT_TRANSFORM_TYPE = InputTransform,
transform_kwargs: Optional[Dict] = None,
**data_module_kwargs: Any,
) -> "SpeechRecognitionData":
Expand All @@ -175,12 +163,8 @@ def from_csv(
test_file: The CSV file to use when testing.
predict_file: The CSV file to use when predicting.
sampling_rate: Sampling rate to use when loading the audio files.
train_transform: The :class:`~flash.core.data.io.input_transform.InputTransform` type to use when training.
val_transform: The :class:`~flash.core.data.io.input_transform.InputTransform` type to use when validating.
test_transform: The :class:`~flash.core.data.io.input_transform.InputTransform` type to use when testing.
predict_transform: The :class:`~flash.core.data.io.input_transform.InputTransform` type to use when
predicting.
input_cls: The :class:`~flash.core.data.io.input.Input` type to use for loading the data.
transform: The :class:`~flash.core.data.io.input_transform.InputTransform` type to use.
transform_kwargs: Dict of keyword arguments to be provided when instantiating the transforms.
data_module_kwargs: Additional keyword arguments to provide to the
:class:`~flash.core.data.data_module.DataModule` constructor.
Expand Down Expand Up @@ -255,17 +239,17 @@ def from_csv(
"""

ds_kw = dict(
transform_kwargs=transform_kwargs,
input_transforms_registry=cls.input_transforms_registry,
input_key=input_field,
sampling_rate=sampling_rate,
)

return cls(
input_cls(RunningStage.TRAINING, train_file, transform=train_transform, target_key=target_field, **ds_kw),
input_cls(RunningStage.VALIDATING, val_file, transform=val_transform, target_key=target_field, **ds_kw),
input_cls(RunningStage.TESTING, test_file, transform=test_transform, target_key=target_field, **ds_kw),
input_cls(RunningStage.PREDICTING, predict_file, transform=predict_transform, **ds_kw),
input_cls(RunningStage.TRAINING, train_file, target_key=target_field, **ds_kw),
input_cls(RunningStage.VALIDATING, val_file, target_key=target_field, **ds_kw),
input_cls(RunningStage.TESTING, test_file, target_key=target_field, **ds_kw),
input_cls(RunningStage.PREDICTING, predict_file, **ds_kw),
transform=transform,
transform_kwargs=transform_kwargs,
**data_module_kwargs,
)

Expand All @@ -280,11 +264,8 @@ def from_json(
predict_file: Optional[str] = None,
sampling_rate: int = 16000,
field: Optional[str] = None,
train_transform: INPUT_TRANSFORM_TYPE = InputTransform,
val_transform: INPUT_TRANSFORM_TYPE = InputTransform,
test_transform: INPUT_TRANSFORM_TYPE = InputTransform,
predict_transform: INPUT_TRANSFORM_TYPE = InputTransform,
input_cls: Type[Input] = SpeechRecognitionJSONInput,
transform: INPUT_TRANSFORM_TYPE = InputTransform,
transform_kwargs: Optional[Dict] = None,
**data_module_kwargs: Any,
) -> "SpeechRecognitionData":
Expand All @@ -306,12 +287,8 @@ def from_json(
predict_file: The JSON file to use when predicting.
sampling_rate: Sampling rate to use when loading the audio files.
field: The field that holds the data in the JSON file.
train_transform: The :class:`~flash.core.data.io.input_transform.InputTransform` type to use when training.
val_transform: The :class:`~flash.core.data.io.input_transform.InputTransform` type to use when validating.
test_transform: The :class:`~flash.core.data.io.input_transform.InputTransform` type to use when testing.
predict_transform: The :class:`~flash.core.data.io.input_transform.InputTransform` type to use when
predicting.
input_cls: The :class:`~flash.core.data.io.input.Input` type to use for loading the data.
transform: The :class:`~flash.core.data.io.input_transform.InputTransform` type to use.
transform_kwargs: Dict of keyword arguments to be provided when instantiating the transforms.
data_module_kwargs: Additional keyword arguments to provide to the
:class:`~flash.core.data.data_module.DataModule` constructor.
Expand Down Expand Up @@ -384,18 +361,18 @@ def from_json(
"""

ds_kw = dict(
transform_kwargs=transform_kwargs,
input_transforms_registry=cls.input_transforms_registry,
input_key=input_field,
sampling_rate=sampling_rate,
field=field,
)

return cls(
input_cls(RunningStage.TRAINING, train_file, transform=train_transform, target_key=target_field, **ds_kw),
input_cls(RunningStage.VALIDATING, val_file, transform=val_transform, target_key=target_field, **ds_kw),
input_cls(RunningStage.TESTING, test_file, transform=test_transform, target_key=target_field, **ds_kw),
input_cls(RunningStage.PREDICTING, predict_file, transform=predict_transform, **ds_kw),
input_cls(RunningStage.TRAINING, train_file, target_key=target_field, **ds_kw),
input_cls(RunningStage.VALIDATING, val_file, target_key=target_field, **ds_kw),
input_cls(RunningStage.TESTING, test_file, target_key=target_field, **ds_kw),
input_cls(RunningStage.PREDICTING, predict_file, **ds_kw),
transform=transform,
transform_kwargs=transform_kwargs,
**data_module_kwargs,
)

Expand All @@ -406,12 +383,9 @@ def from_datasets(
val_dataset: Optional[Dataset] = None,
test_dataset: Optional[Dataset] = None,
predict_dataset: Optional[Dataset] = None,
train_transform: INPUT_TRANSFORM_TYPE = InputTransform,
val_transform: INPUT_TRANSFORM_TYPE = InputTransform,
test_transform: INPUT_TRANSFORM_TYPE = InputTransform,
predict_transform: INPUT_TRANSFORM_TYPE = InputTransform,
sampling_rate: int = 16000,
input_cls: Type[Input] = SpeechRecognitionDatasetInput,
transform: INPUT_TRANSFORM_TYPE = InputTransform,
transform_kwargs: Optional[Dict] = None,
**data_module_kwargs: Any,
) -> "SpeechRecognitionData":
Expand All @@ -433,12 +407,8 @@ def from_datasets(
test_dataset: The Dataset to use when testing.
predict_dataset: The Dataset to use when predicting.
sampling_rate: Sampling rate to use when loading the audio files.
train_transform: The :class:`~flash.core.data.io.input_transform.InputTransform` type to use when training.
val_transform: The :class:`~flash.core.data.io.input_transform.InputTransform` type to use when validating.
test_transform: The :class:`~flash.core.data.io.input_transform.InputTransform` type to use when testing.
predict_transform: The :class:`~flash.core.data.io.input_transform.InputTransform` type to use when
predicting.
input_cls: The :class:`~flash.core.data.io.input.Input` type to use for loading the data.
transform: The :class:`~flash.core.data.io.input_transform.InputTransform` type to use.
transform_kwargs: Dict of keyword arguments to be provided when instantiating the transforms.
data_module_kwargs: Additional keyword arguments to provide to the
:class:`~flash.core.data.data_module.DataModule` constructor.
Expand Down Expand Up @@ -540,15 +510,15 @@ def from_datasets(
"""

ds_kw = dict(
transform_kwargs=transform_kwargs,
input_transforms_registry=cls.input_transforms_registry,
sampling_rate=sampling_rate,
)

return cls(
input_cls(RunningStage.TRAINING, train_dataset, transform=train_transform, **ds_kw),
input_cls(RunningStage.VALIDATING, val_dataset, transform=val_transform, **ds_kw),
input_cls(RunningStage.TESTING, test_dataset, transform=test_transform, **ds_kw),
input_cls(RunningStage.PREDICTING, predict_dataset, transform=predict_transform, **ds_kw),
input_cls(RunningStage.TRAINING, train_dataset, **ds_kw),
input_cls(RunningStage.VALIDATING, val_dataset, **ds_kw),
input_cls(RunningStage.TESTING, test_dataset, **ds_kw),
input_cls(RunningStage.PREDICTING, predict_dataset, **ds_kw),
transform=transform,
transform_kwargs=transform_kwargs,
**data_module_kwargs,
)
23 changes: 16 additions & 7 deletions flash/core/adapter.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,7 @@

import flash
from flash.core.data.io.input import InputBase
from flash.core.data.io.input_transform import InputTransform
from flash.core.model import DatasetProcessor, ModuleWrapperBase, Task
from flash.core.utilities.types import INPUT_TRANSFORM_TYPE

Expand Down Expand Up @@ -121,6 +122,7 @@ def process_train_dataset(
self,
dataset: InputBase,
trainer: "flash.Trainer",
input_transform: InputTransform,
batch_size: int,
num_workers: int,
pin_memory: bool,
Expand All @@ -131,8 +133,9 @@ def process_train_dataset(
persistent_workers: bool = False,
) -> DataLoader:
return self.adapter.process_train_dataset(
dataset,
trainer,
dataset=dataset,
trainer=trainer,
input_transform=input_transform,
batch_size=batch_size,
num_workers=num_workers,
pin_memory=pin_memory,
Expand All @@ -147,6 +150,7 @@ def process_val_dataset(
self,
dataset: InputBase,
trainer: "flash.Trainer",
input_transform: InputTransform,
batch_size: int,
num_workers: int,
pin_memory: bool,
Expand All @@ -157,8 +161,9 @@ def process_val_dataset(
persistent_workers: bool = False,
) -> DataLoader:
return self.adapter.process_val_dataset(
dataset,
trainer,
dataset=dataset,
trainer=trainer,
input_transform=input_transform,
batch_size=batch_size,
num_workers=num_workers,
pin_memory=pin_memory,
Expand All @@ -173,6 +178,7 @@ def process_test_dataset(
self,
dataset: InputBase,
trainer: "flash.Trainer",
input_transform: InputTransform,
batch_size: int,
num_workers: int,
pin_memory: bool,
Expand All @@ -183,8 +189,9 @@ def process_test_dataset(
persistent_workers: bool = False,
) -> DataLoader:
return self.adapter.process_test_dataset(
dataset,
trainer,
dataset=dataset,
trainer=trainer,
input_transform=input_transform,
batch_size=batch_size,
num_workers=num_workers,
pin_memory=pin_memory,
Expand All @@ -198,6 +205,7 @@ def process_test_dataset(
def process_predict_dataset(
self,
dataset: InputBase,
input_transform: InputTransform,
batch_size: int = 1,
num_workers: int = 0,
pin_memory: bool = False,
Expand All @@ -208,7 +216,8 @@ def process_predict_dataset(
persistent_workers: bool = False,
) -> DataLoader:
return self.adapter.process_predict_dataset(
dataset,
dataset=dataset,
input_transform=input_transform,
batch_size=batch_size,
num_workers=num_workers,
pin_memory=pin_memory,
Expand Down
Loading

0 comments on commit 6da53fe

Please sign in to comment.