This repository has been archived by the owner on Oct 9, 2023. It is now read-only.

Commit 796c9c8

Add output argument to Trainer.predict and remove `DataPipelineState` (#1157)
ethanwharris authored Feb 14, 2022
1 parent defbace commit 796c9c8
Showing 150 changed files with 1,729 additions and 2,779 deletions.
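In short: an `Output` (or a string such as `"labels"`) is now passed to `Trainer.predict` per call, instead of being attached to the model beforehand. A minimal before/after sketch of the API change (file paths are hypothetical; class and argument names are taken from the diffs below):

```python
import flash
from flash.image import ImageClassificationData, ImageClassifier

model = ImageClassifier(backbone="resnet18", labels=["ants", "bees"])
datamodule = ImageClassificationData.from_files(
    predict_files=["path/to/some_image.jpg"],  # hypothetical file
    batch_size=1,
)
trainer = flash.Trainer()

# Before this commit, the Output was attached to the model:
#   model.output = LabelsOutput()
#   predictions = trainer.predict(model, datamodule=datamodule)

# After this commit, it is selected at call time:
predictions = trainer.predict(model, datamodule=datamodule, output="labels")
```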
6 changes: 5 additions & 1 deletion CHANGELOG.md
@@ -22,7 +22,9 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).

- Added `FlashRegistry` of Available Heads for `flash.image.ImageClassifier` ([#1152](https://github.com/PyTorchLightning/lightning-flash/pull/1152))

-- Added support for `objectDetectionData.from_files` ([#1154](https://github.com/PyTorchLightning/lightning-flash/pull/1154))
+- Added support for `ObjectDetectionData.from_files` ([#1154](https://github.com/PyTorchLightning/lightning-flash/pull/1154))

+- Added support for passing the `Output` object (or a string e.g. `"labels"`) to the `flash.Trainer.predict` method ([#1157](https://github.com/PyTorchLightning/lightning-flash/pull/1157))

### Changed

@@ -86,6 +88,8 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).

- Removed the `Seq2SeqData` base class (use `TranslationData` or `SummarizationData` directly) ([#1128](https://github.com/PyTorchLightning/lightning-flash/pull/1128))

+- Removed the ability to attach the `Output` object directly to the model ([#1157](https://github.com/PyTorchLightning/lightning-flash/pull/1157))

## [0.6.0] - 2021-12-13

### Added
1 change: 0 additions & 1 deletion docs/source/api/audio.rst
@@ -39,4 +39,3 @@ __________________
speech_recognition.input.SpeechRecognitionDatasetInput
speech_recognition.input.SpeechRecognitionDeserializer
speech_recognition.output_transform.SpeechRecognitionOutputTransform
-speech_recognition.output_transform.SpeechRecognitionBackboneState
4 changes: 0 additions & 4 deletions docs/source/api/data.rst
@@ -67,7 +67,6 @@ _____________________________
:template: classtemplate.rst

~flash.core.data.data_pipeline.DataPipeline
-~flash.core.data.data_pipeline.DataPipelineState

flash.core.data.io.input
___________________________
@@ -80,7 +79,6 @@ ___________________________
~flash.core.data.io.input.Input
~flash.core.data.io.input.DataKeys
~flash.core.data.io.input.InputFormat
-~flash.core.data.io.input.ImageLabelsMap

flash.core.data.io.classification_input
_______________________________________
@@ -90,7 +88,6 @@ _______________________________________
:nosignatures:
:template: classtemplate.rst

-~flash.core.data.io.classification_input.ClassificationState
~flash.core.data.io.classification_input.ClassificationInputMixin

flash.core.data.utilities.classification
@@ -133,7 +130,6 @@ __________________________
:nosignatures:
:template: classtemplate.rst

-~flash.core.data.properties.ProcessState
~flash.core.data.properties.Properties

flash.core.data.splits
1 change: 0 additions & 1 deletion docs/source/api/tabular.rst
@@ -49,7 +49,6 @@ ___________
~forecasting.data.TabularForecastingData

forecasting.input.TabularForecastingDataFrameInput
-forecasting.input.TimeSeriesDataSetParametersState

flash.tabular.data
__________________
2 changes: 0 additions & 2 deletions docs/source/api/text.rst
@@ -53,7 +53,6 @@ __________________
question_answering.input.QuestionAnsweringJSONInput
question_answering.input.QuestionAnsweringSQuADInput
question_answering.input.QuestionAnsweringDictionaryInput
-question_answering.input_transform.QuestionAnsweringInputTransform
question_answering.output_transform.QuestionAnsweringOutputTransform

Summarization
@@ -92,7 +91,6 @@ _______________
seq2seq.core.input.Seq2SeqCSVInput
seq2seq.core.input.Seq2SeqJSONInput
seq2seq.core.input.Seq2SeqListInput
-seq2seq.core.output_transform.Seq2SeqOutputTransform

flash.text.input
________________
7 changes: 2 additions & 5 deletions docs/source/common/finetuning_example.rst
@@ -33,7 +33,7 @@ Here's an example of finetuning.
)

# 2. Build the model using desired Task
model = ImageClassifier(backbone="resnet18", num_classes=datamodule.num_classes)
model = ImageClassifier(backbone="resnet18", labels=datamodule.labels)

# 3. Create the trainer (run one epoch for demo)
trainer = flash.Trainer(max_epochs=1, gpus=torch.cuda.device_count())
@@ -56,17 +56,14 @@ Once you've finetuned, use the model to predict:

.. testcode:: finetune

-# Output predictions as labels, automatically inferred from the training data in part 2.
-model.output = LabelsOutput()

predict_datamodule = ImageClassificationData.from_files(
predict_files=[
"data/hymenoptera_data/val/bees/65038344_52a45d090d.jpg",
"data/hymenoptera_data/val/ants/2255445811_dabcdf7258.jpg",
],
batch_size=1,
)
-predictions = trainer.predict(model, datamodule=predict_datamodule)
+predictions = trainer.predict(model, datamodule=predict_datamodule, output="labels")
print(predictions)

We get the following output:
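The `output` argument also accepts an `Output` instance rather than a string. A sketch, under the assumption that `LabelsOutput` (shown in the removed lines above) lives in `flash.core.classification`:

```python
from flash.core.classification import LabelsOutput  # assumed import path

# Equivalent to output="labels", but the Output can be configured explicitly.
predictions = trainer.predict(model, datamodule=predict_datamodule, output=LabelsOutput())
```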
2 changes: 1 addition & 1 deletion docs/source/quickstart.rst
@@ -104,7 +104,7 @@ Here's an example of inference:
],
batch_size=4,
)
-predictions = trainer.predict(model, datamodule=datamodule)
+predictions = trainer.predict(model, datamodule=datamodule, output="labels")
print(predictions)

We get the following output:
6 changes: 3 additions & 3 deletions docs/source/template/data.rst
@@ -172,11 +172,11 @@ OutputTransform

:class:`~flash.core.data.io.output_transform.OutputTransform` contains any transforms that need to be applied *after* the model.
You may want to use it for: converting tokens back into text, applying an inverse normalization to an output image, resizing a generated image back to the size of the input, etc.
-As an example, here's the :class:`~text.seq2seq.core.output_transform.Seq2SeqOutputTransform` which decodes tokenized model outputs:
+As an example, here's the :class:`~image.segmentation.model.SemanticSegmentationOutputTransform` which resizes the predictions back to the size of the input image:

-.. literalinclude:: ../../../flash/text/seq2seq/core/output_transform.py
+.. literalinclude:: ../../../flash/image/segmentation/model.py
:language: python
-:pyobject: Seq2SeqOutputTransform
+:pyobject: SemanticSegmentationOutputTransform

In your :class:`~flash.core.data.io.input.Input` or :class:`~flash.core.data.io.input_transform.InputTransform`, you can add metadata to the batch using the :attr:`~flash.core.data.io.input.DataKeys.METADATA` key.
Your :class:`~flash.core.data.io.output_transform.OutputTransform` can then use this metadata in its transforms.
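To make the metadata mechanism concrete, here is a minimal hypothetical `OutputTransform` in the spirit of the class referenced above; the `"size"` metadata entry and the resize logic are assumptions, not code from this commit:

```python
import torch.nn.functional as F

from flash.core.data.io.input import DataKeys
from flash.core.data.io.output_transform import OutputTransform


class ResizePredsOutputTransform(OutputTransform):
    """Hypothetical: resize each prediction back to the input size that the
    Input stored under the METADATA key."""

    def per_sample_transform(self, sample):
        height, width = sample[DataKeys.METADATA]["size"]  # assumed metadata entry
        preds = sample[DataKeys.PREDS]
        # Add and remove a batch dimension around the (C, H, W) prediction.
        sample[DataKeys.PREDS] = F.interpolate(preds[None], size=(height, width))[0]
        return sample
```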
1 change: 1 addition & 0 deletions flash/audio/classification/cli.py
@@ -43,6 +43,7 @@ def audio_classification():
default_arguments={
"trainer.max_epochs": 3,
},
+datamodule_attributes={"num_classes", "labels", "multi_label"},
)

cli.trainer.save_checkpoint("audio_classification_model.pt")
37 changes: 24 additions & 13 deletions flash/audio/classification/data.py
@@ -28,7 +28,6 @@
from flash.audio.classification.input_transform import AudioClassificationInputTransform
from flash.core.data.callback import BaseDataFetcher
from flash.core.data.data_module import DataModule
-from flash.core.data.data_pipeline import DataPipelineState
from flash.core.data.io.input import Input
from flash.core.data.io.input_transform import INPUT_TRANSFORM_TYPE
from flash.core.data.utilities.paths import PATH_TYPE
@@ -143,13 +142,15 @@ def from_files(
"""

ds_kw = dict(
-data_pipeline_state=DataPipelineState(),
transform_kwargs=transform_kwargs,
input_transforms_registry=cls.input_transforms_registry,
)

+train_input = input_cls(RunningStage.TRAINING, train_files, train_targets, transform=train_transform, **ds_kw)
+ds_kw["target_formatter"] = getattr(train_input, "target_formatter", None)

return cls(
-input_cls(RunningStage.TRAINING, train_files, train_targets, transform=train_transform, **ds_kw),
+train_input,
input_cls(RunningStage.VALIDATING, val_files, val_targets, transform=val_transform, **ds_kw),
input_cls(RunningStage.TESTING, test_files, test_targets, transform=test_transform, **ds_kw),
input_cls(RunningStage.PREDICTING, predict_files, transform=predict_transform, **ds_kw),
@@ -269,13 +270,15 @@ def from_folders(
"""

ds_kw = dict(
-data_pipeline_state=DataPipelineState(),
transform_kwargs=transform_kwargs,
input_transforms_registry=cls.input_transforms_registry,
)

+train_input = input_cls(RunningStage.TRAINING, train_folder, transform=train_transform, **ds_kw)
+ds_kw["target_formatter"] = getattr(train_input, "target_formatter", None)

return cls(
-input_cls(RunningStage.TRAINING, train_folder, transform=train_transform, **ds_kw),
+train_input,
input_cls(RunningStage.VALIDATING, val_folder, transform=val_transform, **ds_kw),
input_cls(RunningStage.TESTING, test_folder, transform=test_transform, **ds_kw),
input_cls(RunningStage.PREDICTING, predict_folder, transform=predict_transform, **ds_kw),
@@ -358,13 +361,15 @@ def from_numpy(
"""

ds_kw = dict(
-data_pipeline_state=DataPipelineState(),
transform_kwargs=transform_kwargs,
input_transforms_registry=cls.input_transforms_registry,
)

+train_input = input_cls(RunningStage.TRAINING, train_data, train_targets, transform=train_transform, **ds_kw)
+ds_kw["target_formatter"] = getattr(train_input, "target_formatter", None)

return cls(
-input_cls(RunningStage.TRAINING, train_data, train_targets, transform=train_transform, **ds_kw),
+train_input,
input_cls(RunningStage.VALIDATING, val_data, val_targets, transform=val_transform, **ds_kw),
input_cls(RunningStage.TESTING, test_data, test_targets, transform=test_transform, **ds_kw),
input_cls(RunningStage.PREDICTING, predict_data, transform=predict_transform, **ds_kw),
@@ -447,13 +452,15 @@ def from_tensors(
"""

ds_kw = dict(
-data_pipeline_state=DataPipelineState(),
transform_kwargs=transform_kwargs,
input_transforms_registry=cls.input_transforms_registry,
)

+train_input = input_cls(RunningStage.TRAINING, train_data, train_targets, transform=train_transform, **ds_kw)
+ds_kw["target_formatter"] = getattr(train_input, "target_formatter", None)

return cls(
-input_cls(RunningStage.TRAINING, train_data, train_targets, transform=train_transform, **ds_kw),
+train_input,
input_cls(RunningStage.VALIDATING, val_data, val_targets, transform=val_transform, **ds_kw),
input_cls(RunningStage.TESTING, test_data, test_targets, transform=test_transform, **ds_kw),
input_cls(RunningStage.PREDICTING, predict_data, transform=predict_transform, **ds_kw),
@@ -595,7 +602,6 @@ def from_data_frame(
"""

ds_kw = dict(
-data_pipeline_state=DataPipelineState(),
transform_kwargs=transform_kwargs,
input_transforms_registry=cls.input_transforms_registry,
)
@@ -605,8 +611,11 @@
test_data = (test_data_frame, input_field, target_fields, test_images_root, test_resolver)
predict_data = (predict_data_frame, input_field, None, predict_images_root, predict_resolver)

+train_input = input_cls(RunningStage.TRAINING, *train_data, transform=train_transform, **ds_kw)
+ds_kw["target_formatter"] = getattr(train_input, "target_formatter", None)

return cls(
-input_cls(RunningStage.TRAINING, *train_data, transform=train_transform, **ds_kw),
+train_input,
input_cls(RunningStage.VALIDATING, *val_data, transform=val_transform, **ds_kw),
input_cls(RunningStage.TESTING, *test_data, transform=test_transform, **ds_kw),
input_cls(RunningStage.PREDICTING, *predict_data, transform=predict_transform, **ds_kw),
@@ -758,7 +767,6 @@ def from_csv(
"""

ds_kw = dict(
-data_pipeline_state=DataPipelineState(),
transform_kwargs=transform_kwargs,
input_transforms_registry=cls.input_transforms_registry,
)
@@ -768,8 +776,11 @@
test_data = (test_file, input_field, target_fields, test_images_root, test_resolver)
predict_data = (predict_file, input_field, None, predict_images_root, predict_resolver)

+train_input = input_cls(RunningStage.TRAINING, *train_data, transform=train_transform, **ds_kw)
+ds_kw["target_formatter"] = getattr(train_input, "target_formatter", None)

return cls(
-input_cls(RunningStage.TRAINING, *train_data, transform=train_transform, **ds_kw),
+train_input,
input_cls(RunningStage.VALIDATING, *val_data, transform=val_transform, **ds_kw),
input_cls(RunningStage.TESTING, *test_data, transform=test_transform, **ds_kw),
input_cls(RunningStage.PREDICTING, *predict_data, transform=predict_transform, **ds_kw),
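The recurring edit in this file replaces the shared `DataPipelineState` with an explicit hand-off: the training `Input` resolves a target formatter from the training targets, and the same formatter is passed to the other stages so targets are encoded consistently. A schematic illustration of why that matters (plain Python, not the actual Flash classes):

```python
# Schematic only: the training stage fixes the class-to-index mapping once...
train_targets = ["cat", "dog", "cat"]
val_targets = ["dog", "cat"]

labels = sorted(set(train_targets))                 # ["cat", "dog"]
formatter = {label: i for i, label in enumerate(labels)}

# ...and validation reuses it, so "dog" means the same index in both stages.
train_encoded = [formatter[t] for t in train_targets]  # [0, 1, 0]
val_encoded = [formatter[t] for t in val_targets]      # [1, 0]
```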
32 changes: 19 additions & 13 deletions flash/audio/classification/input.py
@@ -17,9 +17,9 @@
import numpy as np
import pandas as pd

-from flash.core.data.io.classification_input import ClassificationInputMixin, ClassificationState
+from flash.core.data.io.classification_input import ClassificationInputMixin
from flash.core.data.io.input import DataKeys, Input
-from flash.core.data.utilities.classification import MultiBinaryTargetFormatter
+from flash.core.data.utilities.classification import MultiBinaryTargetFormatter, TargetFormatter
from flash.core.data.utilities.data_frame import read_csv, resolve_files, resolve_targets
from flash.core.data.utilities.paths import filter_valid_files, has_file_allowed_extension, make_dataset, PATH_TYPE
from flash.core.data.utilities.samples import to_samples
@@ -54,12 +54,13 @@ def load_data(
self,
files: List[PATH_TYPE],
targets: Optional[List[Any]] = None,
+target_formatter: Optional[TargetFormatter] = None,
) -> List[Dict[str, Any]]:
if targets is None:
files = filter_valid_files(files, valid_extensions=IMG_EXTENSIONS + NP_EXTENSIONS)
return to_samples(files)
files, targets = filter_valid_files(files, targets, valid_extensions=IMG_EXTENSIONS + NP_EXTENSIONS)
-self.load_target_metadata(targets)
+self.load_target_metadata(targets, target_formatter=target_formatter)
return to_samples(files, targets)

def load_sample(self, sample: Dict[str, Any]) -> Dict[str, Any]:
@@ -71,15 +72,17 @@ def load_sample(self, sample: Dict[str, Any]) -> Dict[str, Any]:


class AudioClassificationFolderInput(AudioClassificationFilesInput):
-def load_data(self, folder: PATH_TYPE) -> List[Dict[str, Any]]:
+def load_data(self, folder: PATH_TYPE, target_formatter: Optional[TargetFormatter] = None) -> List[Dict[str, Any]]:
files, targets = make_dataset(folder, extensions=IMG_EXTENSIONS + NP_EXTENSIONS)
-return super().load_data(files, targets)
+return super().load_data(files, targets, target_formatter=target_formatter)


class AudioClassificationNumpyInput(AudioClassificationInput):
-def load_data(self, array: Any, targets: Optional[List[Any]] = None) -> List[Dict[str, Any]]:
+def load_data(
+    self, array: Any, targets: Optional[List[Any]] = None, target_formatter: Optional[TargetFormatter] = None
+) -> List[Dict[str, Any]]:
if targets is not None:
-self.load_target_metadata(targets)
+self.load_target_metadata(targets, target_formatter=target_formatter)
return to_samples(array, targets)

def load_sample(self, sample: Dict[str, Any]) -> Dict[str, Any]:
@@ -88,9 +91,11 @@ def load_sample(self, sample: Dict[str, Any]) -> Dict[str, Any]:


class AudioClassificationTensorInput(AudioClassificationNumpyInput):
-def load_data(self, tensor: Any, targets: Optional[List[Any]] = None) -> List[Dict[str, Any]]:
+def load_data(
+    self, tensor: Any, targets: Optional[List[Any]] = None, target_formatter: Optional[TargetFormatter] = None
+) -> List[Dict[str, Any]]:
if targets is not None:
-self.load_target_metadata(targets)
+self.load_target_metadata(targets, target_formatter=target_formatter)
return to_samples(tensor, targets)

def load_sample(self, sample: Dict[str, Any]) -> Dict[str, Any]:
@@ -106,22 +111,22 @@ def load_data(
target_keys: Optional[Union[str, List[str]]] = None,
root: Optional[PATH_TYPE] = None,
resolver: Optional[Callable[[Optional[PATH_TYPE], Any], PATH_TYPE]] = None,
+target_formatter: Optional[TargetFormatter] = None,
) -> List[Dict[str, Any]]:
files = resolve_files(data_frame, input_key, root, resolver)
if target_keys is not None:
targets = resolve_targets(data_frame, target_keys)
else:
targets = None
-result = super().load_data(files, targets)
+result = super().load_data(files, targets, target_formatter=target_formatter)

# If we had binary multi-class targets then we also know the labels (column names)
if (
self.training
and isinstance(self.target_formatter, MultiBinaryTargetFormatter)
and isinstance(target_keys, List)
):
-classification_state = self.get_state(ClassificationState)
-self.set_state(ClassificationState(target_keys, classification_state.num_classes))
+self.labels = target_keys

return result

@@ -134,8 +139,9 @@ def load_data(
target_keys: Optional[Union[str, List[str]]] = None,
root: Optional[PATH_TYPE] = None,
resolver: Optional[Callable[[Optional[PATH_TYPE], Any], PATH_TYPE]] = None,
+target_formatter: Optional[TargetFormatter] = None,
) -> List[Dict[str, Any]]:
data_frame = read_csv(csv_file)
if root is None:
root = os.path.dirname(csv_file)
-return super().load_data(data_frame, input_key, target_keys, root, resolver)
+return super().load_data(data_frame, input_key, target_keys, root, resolver, target_formatter=target_formatter)
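With these signatures, the formatter resolved by a training input can be forwarded when constructing the other stages, mirroring what the `from_*` methods in `data.py` above now do. A sketch with hypothetical file lists (the import path for `RunningStage` is assumed):

```python
from flash.audio.classification.input import AudioClassificationFilesInput
from flash.core.utilities.stages import RunningStage  # assumed import path

# Hypothetical spectrogram image files and their targets.
train_files, train_targets = ["a.png", "b.png"], ["cat", "dog"]
val_files, val_targets = ["c.png"], ["dog"]

train_input = AudioClassificationFilesInput(RunningStage.TRAINING, train_files, train_targets)

# Reuse the formatter resolved during training so validation targets are
# encoded with the same class-to-index mapping.
val_input = AudioClassificationFilesInput(
    RunningStage.VALIDATING,
    val_files,
    val_targets,
    target_formatter=getattr(train_input, "target_formatter", None),
)
```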
(Remaining changed files not shown.)
