Skip to content
This repository has been archived by the owner on Oct 9, 2023. It is now read-only.

Add output argument to Trainer.predict and remove DataPipelineState #1157

Merged
merged 42 commits into from
Feb 14, 2022
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
42 commits
Select commit Hold shift + click to select a range
f792dd5
Initial commit
ethanwharris Feb 7, 2022
ff297ab
Updates
ethanwharris Feb 8, 2022
99cc5a1
Fix video
ethanwharris Feb 8, 2022
4f701a9
Fix tabular
ethanwharris Feb 8, 2022
0b8c24a
Fix graph
ethanwharris Feb 8, 2022
e42dbbc
Fix audio classification
ethanwharris Feb 8, 2022
416bbb9
Fix object detection
ethanwharris Feb 8, 2022
301e1c5
Fix pointcloud
ethanwharris Feb 8, 2022
42cbc69
Fixes
ethanwharris Feb 8, 2022
8ffe5f6
Fixes
ethanwharris Feb 8, 2022
2aaf7c0
CLI fixes
ethanwharris Feb 8, 2022
bbf2e83
Merge branch 'master' into feature/predict_output
ethanwharris Feb 8, 2022
673e8c2
Fixes
ethanwharris Feb 8, 2022
db06f3e
Fixes
ethanwharris Feb 8, 2022
f60aca9
Remove data pipeline state
ethanwharris Feb 8, 2022
9347acb
Fixes
ethanwharris Feb 8, 2022
68dd199
Fixes
ethanwharris Feb 8, 2022
922863a
Fixes
ethanwharris Feb 8, 2022
87a94cd
Fixes
ethanwharris Feb 8, 2022
acb99fa
Fixes
ethanwharris Feb 8, 2022
aa1930a
Fixes
ethanwharris Feb 8, 2022
e46add9
Fixes
ethanwharris Feb 8, 2022
0fe3aef
Fixes
ethanwharris Feb 9, 2022
a701714
Fixes
ethanwharris Feb 9, 2022
5c3bf98
Try to reduce memory footprint
ethanwharris Feb 9, 2022
599b763
Fix tabular serving
ethanwharris Feb 9, 2022
ab3480b
Batch tokenize
ethanwharris Feb 9, 2022
c2b2500
Try to reduce memory footprint
ethanwharris Feb 9, 2022
dda397b
Try to reduce memory footprint
ethanwharris Feb 9, 2022
e11c5a4
Fixes
ethanwharris Feb 9, 2022
b82d130
Try to reduce memory footprint
ethanwharris Feb 9, 2022
1748f11
Drop broken test
ethanwharris Feb 9, 2022
8b0c11f
Drop broken test
ethanwharris Feb 9, 2022
77081f9
Fixes
ethanwharris Feb 9, 2022
550b246
Docs fixes
ethanwharris Feb 9, 2022
9f3a301
Merge branch 'master' into feature/predict_output
ethanwharris Feb 10, 2022
4767d43
Merge branch 'master' into feature/predict_output
ethanwharris Feb 14, 2022
29e358a
Fixes
ethanwharris Feb 14, 2022
aba7042
Trigger CI
ethanwharris Feb 14, 2022
81e9f6e
Trigger CI
ethanwharris Feb 14, 2022
055cf7d
Trigger CI
ethanwharris Feb 14, 2022
71773c7
Update CHANGELOG.md
ethanwharris Feb 14, 2022
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
6 changes: 5 additions & 1 deletion CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -22,7 +22,9 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).

- Added `FlashRegistry` of Available Heads for `flash.image.ImageClassifier` ([#1152](https://github.com/PyTorchLightning/lightning-flash/pull/1152))

- Added support for `objectDetectionData.from_files` ([#1154](https://github.com/PyTorchLightning/lightning-flash/pull/1154))
- Added support for `ObjectDetectionData.from_files` ([#1154](https://github.com/PyTorchLightning/lightning-flash/pull/1154))

- Added support for passing the `Output` object (or a string e.g. `"labels"`) to the `flash.Trainer.predict` method ([#1157](https://github.com/PyTorchLightning/lightning-flash/pull/1157))

### Changed

Expand Down Expand Up @@ -86,6 +88,8 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).

- Removed the `Seq2SeqData` base class (use `TranslationData` or `SummarizationData` directly) ([#1128](https://github.com/PyTorchLightning/lightning-flash/pull/1128))

- Removed the ability to attach the `Output` object directly to the model ([#1157](https://github.com/PyTorchLightning/lightning-flash/pull/1157))

## [0.6.0] - 2021-12-13

### Added
Expand Down
1 change: 0 additions & 1 deletion docs/source/api/audio.rst
Original file line number Diff line number Diff line change
Expand Up @@ -39,4 +39,3 @@ __________________
speech_recognition.input.SpeechRecognitionDatasetInput
speech_recognition.input.SpeechRecognitionDeserializer
speech_recognition.output_transform.SpeechRecognitionOutputTransform
speech_recognition.output_transform.SpeechRecognitionBackboneState
4 changes: 0 additions & 4 deletions docs/source/api/data.rst
Original file line number Diff line number Diff line change
Expand Up @@ -67,7 +67,6 @@ _____________________________
:template: classtemplate.rst

~flash.core.data.data_pipeline.DataPipeline
~flash.core.data.data_pipeline.DataPipelineState

flash.core.data.io.input
___________________________
Expand All @@ -80,7 +79,6 @@ ___________________________
~flash.core.data.io.input.Input
~flash.core.data.io.input.DataKeys
~flash.core.data.io.input.InputFormat
~flash.core.data.io.input.ImageLabelsMap

flash.core.data.io.classification_input
_______________________________________
Expand All @@ -90,7 +88,6 @@ _______________________________________
:nosignatures:
:template: classtemplate.rst

~flash.core.data.io.classification_input.ClassificationState
~flash.core.data.io.classification_input.ClassificationInputMixin

flash.core.data.utilities.classification
Expand Down Expand Up @@ -133,7 +130,6 @@ __________________________
:nosignatures:
:template: classtemplate.rst

~flash.core.data.properties.ProcessState
~flash.core.data.properties.Properties

flash.core.data.splits
Expand Down
1 change: 0 additions & 1 deletion docs/source/api/tabular.rst
Original file line number Diff line number Diff line change
Expand Up @@ -49,7 +49,6 @@ ___________
~forecasting.data.TabularForecastingData

forecasting.input.TabularForecastingDataFrameInput
forecasting.input.TimeSeriesDataSetParametersState

flash.tabular.data
__________________
Expand Down
2 changes: 0 additions & 2 deletions docs/source/api/text.rst
Original file line number Diff line number Diff line change
Expand Up @@ -53,7 +53,6 @@ __________________
question_answering.input.QuestionAnsweringJSONInput
question_answering.input.QuestionAnsweringSQuADInput
question_answering.input.QuestionAnsweringDictionaryInput
question_answering.input_transform.QuestionAnsweringInputTransform
question_answering.output_transform.QuestionAnsweringOutputTransform

Summarization
Expand Down Expand Up @@ -92,7 +91,6 @@ _______________
seq2seq.core.input.Seq2SeqCSVInput
seq2seq.core.input.Seq2SeqJSONInput
seq2seq.core.input.Seq2SeqListInput
seq2seq.core.output_transform.Seq2SeqOutputTransform

flash.text.input
________________
Expand Down
7 changes: 2 additions & 5 deletions docs/source/common/finetuning_example.rst
Original file line number Diff line number Diff line change
Expand Up @@ -33,7 +33,7 @@ Here's an example of finetuning.
)

# 2. Build the model using desired Task
model = ImageClassifier(backbone="resnet18", num_classes=datamodule.num_classes)
model = ImageClassifier(backbone="resnet18", labels=datamodule.labels)

# 3. Create the trainer (run one epoch for demo)
trainer = flash.Trainer(max_epochs=1, gpus=torch.cuda.device_count())
Expand All @@ -56,17 +56,14 @@ Once you've finetuned, use the model to predict:

.. testcode:: finetune

# Output predictions as labels, automatically inferred from the training data in part 2.
model.output = LabelsOutput()

predict_datamodule = ImageClassificationData.from_files(
predict_files=[
"data/hymenoptera_data/val/bees/65038344_52a45d090d.jpg",
"data/hymenoptera_data/val/ants/2255445811_dabcdf7258.jpg",
],
batch_size=1,
)
predictions = trainer.predict(model, datamodule=predict_datamodule)
predictions = trainer.predict(model, datamodule=predict_datamodule, output="labels")
print(predictions)

We get the following output:
Expand Down
2 changes: 1 addition & 1 deletion docs/source/quickstart.rst
Original file line number Diff line number Diff line change
Expand Up @@ -104,7 +104,7 @@ Here's an example of inference:
],
batch_size=4,
)
predictions = trainer.predict(model, datamodule=datamodule)
predictions = trainer.predict(model, datamodule=datamodule, output="labels")
print(predictions)

We get the following output:
Expand Down
6 changes: 3 additions & 3 deletions docs/source/template/data.rst
Original file line number Diff line number Diff line change
Expand Up @@ -172,11 +172,11 @@ OutputTransform

:class:`~flash.core.data.io.output_transform.OutputTransform` contains any transforms that need to be applied *after* the model.
You may want to use it for: converting tokens back into text, applying an inverse normalization to an output image, resizing a generated image back to the size of the input, etc.
As an example, here's the :class:`~text.seq2seq.core.output_transform.Seq2SeqOutputTransform` which decodes tokenized model outputs:
As an example, here's the :class:`~image.segmentation.model.SemanticSegmentationOutputTransform` which converts the model outputs back to segmentation predictions at the size of the original input:

.. literalinclude:: ../../../flash/text/seq2seq/core/output_transform.py
.. literalinclude:: ../../../flash/image/segmentation/model.py
:language: python
:pyobject: Seq2SeqOutputTransform
:pyobject: SemanticSegmentationOutputTransform

In your :class:`~flash.core.data.io.input.Input` or :class:`~flash.core.data.io.input_transform.InputTransform`, you can add metadata to the batch using the :attr:`~flash.core.data.io.input.DataKeys.METADATA` key.
Your :class:`~flash.core.data.io.output_transform.OutputTransform` can then use this metadata in its transforms.
Expand Down
1 change: 1 addition & 0 deletions flash/audio/classification/cli.py
Original file line number Diff line number Diff line change
Expand Up @@ -43,6 +43,7 @@ def audio_classification():
default_arguments={
"trainer.max_epochs": 3,
},
datamodule_attributes={"num_classes", "labels", "multi_label"},
)

cli.trainer.save_checkpoint("audio_classification_model.pt")
Expand Down
37 changes: 24 additions & 13 deletions flash/audio/classification/data.py
Original file line number Diff line number Diff line change
Expand Up @@ -28,7 +28,6 @@
from flash.audio.classification.input_transform import AudioClassificationInputTransform
from flash.core.data.callback import BaseDataFetcher
from flash.core.data.data_module import DataModule
from flash.core.data.data_pipeline import DataPipelineState
from flash.core.data.io.input import Input
from flash.core.data.io.input_transform import INPUT_TRANSFORM_TYPE
from flash.core.data.utilities.paths import PATH_TYPE
Expand Down Expand Up @@ -143,13 +142,15 @@ def from_files(
"""

ds_kw = dict(
data_pipeline_state=DataPipelineState(),
transform_kwargs=transform_kwargs,
input_transforms_registry=cls.input_transforms_registry,
)

train_input = input_cls(RunningStage.TRAINING, train_files, train_targets, transform=train_transform, **ds_kw)
ds_kw["target_formatter"] = getattr(train_input, "target_formatter", None)

return cls(
input_cls(RunningStage.TRAINING, train_files, train_targets, transform=train_transform, **ds_kw),
train_input,
input_cls(RunningStage.VALIDATING, val_files, val_targets, transform=val_transform, **ds_kw),
input_cls(RunningStage.TESTING, test_files, test_targets, transform=test_transform, **ds_kw),
input_cls(RunningStage.PREDICTING, predict_files, transform=predict_transform, **ds_kw),
Expand Down Expand Up @@ -269,13 +270,15 @@ def from_folders(
"""

ds_kw = dict(
data_pipeline_state=DataPipelineState(),
transform_kwargs=transform_kwargs,
input_transforms_registry=cls.input_transforms_registry,
)

train_input = input_cls(RunningStage.TRAINING, train_folder, transform=train_transform, **ds_kw)
ds_kw["target_formatter"] = getattr(train_input, "target_formatter", None)

return cls(
input_cls(RunningStage.TRAINING, train_folder, transform=train_transform, **ds_kw),
train_input,
input_cls(RunningStage.VALIDATING, val_folder, transform=val_transform, **ds_kw),
input_cls(RunningStage.TESTING, test_folder, transform=test_transform, **ds_kw),
input_cls(RunningStage.PREDICTING, predict_folder, transform=predict_transform, **ds_kw),
Expand Down Expand Up @@ -358,13 +361,15 @@ def from_numpy(
"""

ds_kw = dict(
data_pipeline_state=DataPipelineState(),
transform_kwargs=transform_kwargs,
input_transforms_registry=cls.input_transforms_registry,
)

train_input = input_cls(RunningStage.TRAINING, train_data, train_targets, transform=train_transform, **ds_kw)
ds_kw["target_formatter"] = getattr(train_input, "target_formatter", None)

return cls(
input_cls(RunningStage.TRAINING, train_data, train_targets, transform=train_transform, **ds_kw),
train_input,
input_cls(RunningStage.VALIDATING, val_data, val_targets, transform=val_transform, **ds_kw),
input_cls(RunningStage.TESTING, test_data, test_targets, transform=test_transform, **ds_kw),
input_cls(RunningStage.PREDICTING, predict_data, transform=predict_transform, **ds_kw),
Expand Down Expand Up @@ -447,13 +452,15 @@ def from_tensors(
"""

ds_kw = dict(
data_pipeline_state=DataPipelineState(),
transform_kwargs=transform_kwargs,
input_transforms_registry=cls.input_transforms_registry,
)

train_input = input_cls(RunningStage.TRAINING, train_data, train_targets, transform=train_transform, **ds_kw)
ds_kw["target_formatter"] = getattr(train_input, "target_formatter", None)

return cls(
input_cls(RunningStage.TRAINING, train_data, train_targets, transform=train_transform, **ds_kw),
train_input,
input_cls(RunningStage.VALIDATING, val_data, val_targets, transform=val_transform, **ds_kw),
input_cls(RunningStage.TESTING, test_data, test_targets, transform=test_transform, **ds_kw),
input_cls(RunningStage.PREDICTING, predict_data, transform=predict_transform, **ds_kw),
Expand Down Expand Up @@ -595,7 +602,6 @@ def from_data_frame(
"""

ds_kw = dict(
data_pipeline_state=DataPipelineState(),
transform_kwargs=transform_kwargs,
input_transforms_registry=cls.input_transforms_registry,
)
Expand All @@ -605,8 +611,11 @@ def from_data_frame(
test_data = (test_data_frame, input_field, target_fields, test_images_root, test_resolver)
predict_data = (predict_data_frame, input_field, None, predict_images_root, predict_resolver)

train_input = input_cls(RunningStage.TRAINING, *train_data, transform=train_transform, **ds_kw)
ds_kw["target_formatter"] = getattr(train_input, "target_formatter", None)

return cls(
input_cls(RunningStage.TRAINING, *train_data, transform=train_transform, **ds_kw),
train_input,
input_cls(RunningStage.VALIDATING, *val_data, transform=val_transform, **ds_kw),
input_cls(RunningStage.TESTING, *test_data, transform=test_transform, **ds_kw),
input_cls(RunningStage.PREDICTING, *predict_data, transform=predict_transform, **ds_kw),
Expand Down Expand Up @@ -758,7 +767,6 @@ def from_csv(
"""

ds_kw = dict(
data_pipeline_state=DataPipelineState(),
transform_kwargs=transform_kwargs,
input_transforms_registry=cls.input_transforms_registry,
)
Expand All @@ -768,8 +776,11 @@ def from_csv(
test_data = (test_file, input_field, target_fields, test_images_root, test_resolver)
predict_data = (predict_file, input_field, None, predict_images_root, predict_resolver)

train_input = input_cls(RunningStage.TRAINING, *train_data, transform=train_transform, **ds_kw)
ds_kw["target_formatter"] = getattr(train_input, "target_formatter", None)

return cls(
input_cls(RunningStage.TRAINING, *train_data, transform=train_transform, **ds_kw),
train_input,
input_cls(RunningStage.VALIDATING, *val_data, transform=val_transform, **ds_kw),
input_cls(RunningStage.TESTING, *test_data, transform=test_transform, **ds_kw),
input_cls(RunningStage.PREDICTING, *predict_data, transform=predict_transform, **ds_kw),
Expand Down
32 changes: 19 additions & 13 deletions flash/audio/classification/input.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,9 +17,9 @@
import numpy as np
import pandas as pd

from flash.core.data.io.classification_input import ClassificationInputMixin, ClassificationState
from flash.core.data.io.classification_input import ClassificationInputMixin
from flash.core.data.io.input import DataKeys, Input
from flash.core.data.utilities.classification import MultiBinaryTargetFormatter
from flash.core.data.utilities.classification import MultiBinaryTargetFormatter, TargetFormatter
from flash.core.data.utilities.data_frame import read_csv, resolve_files, resolve_targets
from flash.core.data.utilities.paths import filter_valid_files, has_file_allowed_extension, make_dataset, PATH_TYPE
from flash.core.data.utilities.samples import to_samples
Expand Down Expand Up @@ -54,12 +54,13 @@ def load_data(
self,
files: List[PATH_TYPE],
targets: Optional[List[Any]] = None,
target_formatter: Optional[TargetFormatter] = None,
) -> List[Dict[str, Any]]:
if targets is None:
files = filter_valid_files(files, valid_extensions=IMG_EXTENSIONS + NP_EXTENSIONS)
return to_samples(files)
files, targets = filter_valid_files(files, targets, valid_extensions=IMG_EXTENSIONS + NP_EXTENSIONS)
self.load_target_metadata(targets)
self.load_target_metadata(targets, target_formatter=target_formatter)
return to_samples(files, targets)

def load_sample(self, sample: Dict[str, Any]) -> Dict[str, Any]:
Expand All @@ -71,15 +72,17 @@ def load_sample(self, sample: Dict[str, Any]) -> Dict[str, Any]:


class AudioClassificationFolderInput(AudioClassificationFilesInput):
def load_data(self, folder: PATH_TYPE) -> List[Dict[str, Any]]:
def load_data(self, folder: PATH_TYPE, target_formatter: Optional[TargetFormatter] = None) -> List[Dict[str, Any]]:
files, targets = make_dataset(folder, extensions=IMG_EXTENSIONS + NP_EXTENSIONS)
return super().load_data(files, targets)
return super().load_data(files, targets, target_formatter=target_formatter)


class AudioClassificationNumpyInput(AudioClassificationInput):
def load_data(self, array: Any, targets: Optional[List[Any]] = None) -> List[Dict[str, Any]]:
def load_data(
self, array: Any, targets: Optional[List[Any]] = None, target_formatter: Optional[TargetFormatter] = None
) -> List[Dict[str, Any]]:
if targets is not None:
self.load_target_metadata(targets)
self.load_target_metadata(targets, target_formatter=target_formatter)
return to_samples(array, targets)

def load_sample(self, sample: Dict[str, Any]) -> Dict[str, Any]:
Expand All @@ -88,9 +91,11 @@ def load_sample(self, sample: Dict[str, Any]) -> Dict[str, Any]:


class AudioClassificationTensorInput(AudioClassificationNumpyInput):
def load_data(self, tensor: Any, targets: Optional[List[Any]] = None) -> List[Dict[str, Any]]:
def load_data(
self, tensor: Any, targets: Optional[List[Any]] = None, target_formatter: Optional[TargetFormatter] = None
) -> List[Dict[str, Any]]:
if targets is not None:
self.load_target_metadata(targets)
self.load_target_metadata(targets, target_formatter=target_formatter)
return to_samples(tensor, targets)

def load_sample(self, sample: Dict[str, Any]) -> Dict[str, Any]:
Expand All @@ -106,22 +111,22 @@ def load_data(
target_keys: Optional[Union[str, List[str]]] = None,
root: Optional[PATH_TYPE] = None,
resolver: Optional[Callable[[Optional[PATH_TYPE], Any], PATH_TYPE]] = None,
target_formatter: Optional[TargetFormatter] = None,
) -> List[Dict[str, Any]]:
files = resolve_files(data_frame, input_key, root, resolver)
if target_keys is not None:
targets = resolve_targets(data_frame, target_keys)
else:
targets = None
result = super().load_data(files, targets)
result = super().load_data(files, targets, target_formatter=target_formatter)

# If we had binary multi-class targets then we also know the labels (column names)
if (
self.training
and isinstance(self.target_formatter, MultiBinaryTargetFormatter)
and isinstance(target_keys, List)
):
classification_state = self.get_state(ClassificationState)
self.set_state(ClassificationState(target_keys, classification_state.num_classes))
self.labels = target_keys

return result

Expand All @@ -134,8 +139,9 @@ def load_data(
target_keys: Optional[Union[str, List[str]]] = None,
root: Optional[PATH_TYPE] = None,
resolver: Optional[Callable[[Optional[PATH_TYPE], Any], PATH_TYPE]] = None,
target_formatter: Optional[TargetFormatter] = None,
) -> List[Dict[str, Any]]:
data_frame = read_csv(csv_file)
if root is None:
root = os.path.dirname(csv_file)
return super().load_data(data_frame, input_key, target_keys, root, resolver)
return super().load_data(data_frame, input_key, target_keys, root, resolver, target_formatter=target_formatter)
Loading