diff --git a/CHANGELOG.md b/CHANGELOG.md
index f213a32b41..7c573622c2 100644
--- a/CHANGELOG.md
+++ b/CHANGELOG.md
@@ -22,7 +22,9 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).
 
 - Added `FlashRegistry` of Available Heads for `flash.image.ImageClassifier` ([#1152](https://github.com/PyTorchLightning/lightning-flash/pull/1152))
 
-- Added support for `objectDetectionData.from_files` ([#1154](https://github.com/PyTorchLightning/lightning-flash/pull/1154))
+- Added support for `ObjectDetectionData.from_files` ([#1154](https://github.com/PyTorchLightning/lightning-flash/pull/1154))
+
+- Added support for passing the `Output` object (or a string, e.g. `"labels"`) to the `flash.Trainer.predict` method ([#1157](https://github.com/PyTorchLightning/lightning-flash/pull/1157))
 
 ### Changed
 
@@ -86,6 +88,8 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).
 
 - Removed the `Seq2SeqData` base class (use `TranslationData` or `SummarizationData` directly) ([#1128](https://github.com/PyTorchLightning/lightning-flash/pull/1128))
 
+- Removed the ability to attach the `Output` object directly to the model ([#1157](https://github.com/PyTorchLightning/lightning-flash/pull/1157))
+
 ## [0.6.0] - 2021-12-13
 
 ### Added
diff --git a/docs/source/api/audio.rst b/docs/source/api/audio.rst
index 83351ca604..9b4cffe20b 100644
--- a/docs/source/api/audio.rst
+++ b/docs/source/api/audio.rst
@@ -39,4 +39,3 @@ __________________
     speech_recognition.input.SpeechRecognitionDatasetInput
     speech_recognition.input.SpeechRecognitionDeserializer
     speech_recognition.output_transform.SpeechRecognitionOutputTransform
-    speech_recognition.output_transform.SpeechRecognitionBackboneState
diff --git a/docs/source/api/data.rst b/docs/source/api/data.rst
index 8948ed8346..c6f9c9976c 100644
--- a/docs/source/api/data.rst
+++ b/docs/source/api/data.rst
@@ -67,7 +67,6 @@ _____________________________
     :template: classtemplate.rst
 
     ~flash.core.data.data_pipeline.DataPipeline
-    ~flash.core.data.data_pipeline.DataPipelineState
 
 flash.core.data.io.input
 ___________________________
@@ -80,7 +79,6 @@ ___________________________
     ~flash.core.data.io.input.Input
     ~flash.core.data.io.input.DataKeys
     ~flash.core.data.io.input.InputFormat
-    ~flash.core.data.io.input.ImageLabelsMap
 
 flash.core.data.io.classification_input
 _______________________________________
@@ -90,7 +88,6 @@ _______________________________________
     :nosignatures:
     :template: classtemplate.rst
 
-    ~flash.core.data.io.classification_input.ClassificationState
     ~flash.core.data.io.classification_input.ClassificationInputMixin
 
 flash.core.data.utilities.classification
@@ -133,7 +130,6 @@ __________________________
     :nosignatures:
     :template: classtemplate.rst
 
-    ~flash.core.data.properties.ProcessState
     ~flash.core.data.properties.Properties
 
 flash.core.data.splits
diff --git a/docs/source/api/tabular.rst b/docs/source/api/tabular.rst
index c838c7d9c3..adade46846 100644
--- a/docs/source/api/tabular.rst
+++ b/docs/source/api/tabular.rst
@@ -49,7 +49,6 @@ ___________
     ~forecasting.data.TabularForecastingData
 
     forecasting.input.TabularForecastingDataFrameInput
-    forecasting.input.TimeSeriesDataSetParametersState
 
 flash.tabular.data
 __________________
diff --git a/docs/source/api/text.rst b/docs/source/api/text.rst
index 3144697036..312c743c63 100644
--- a/docs/source/api/text.rst
+++ b/docs/source/api/text.rst
@@ -53,7 +53,6 @@ __________________
     question_answering.input.QuestionAnsweringJSONInput
     question_answering.input.QuestionAnsweringSQuADInput
     question_answering.input.QuestionAnsweringDictionaryInput
-    question_answering.input_transform.QuestionAnsweringInputTransform
     question_answering.output_transform.QuestionAnsweringOutputTransform
 
 Summarization
@@ -92,7 +91,6 @@ _______________
     seq2seq.core.input.Seq2SeqCSVInput
     seq2seq.core.input.Seq2SeqJSONInput
     seq2seq.core.input.Seq2SeqListInput
-    seq2seq.core.output_transform.Seq2SeqOutputTransform
 
 flash.text.input
 ________________
diff --git a/docs/source/common/finetuning_example.rst b/docs/source/common/finetuning_example.rst
index 118f4a7749..cdb475a2b1 100644
--- a/docs/source/common/finetuning_example.rst
+++ b/docs/source/common/finetuning_example.rst
@@ -33,7 +33,7 @@ Here's an example of finetuning.
     )
 
     # 2. Build the model using desired Task
-    model = ImageClassifier(backbone="resnet18", num_classes=datamodule.num_classes)
+    model = ImageClassifier(backbone="resnet18", labels=datamodule.labels)
 
     # 3. Create the trainer (run one epoch for demo)
     trainer = flash.Trainer(max_epochs=1, gpus=torch.cuda.device_count())
@@ -56,9 +56,6 @@ Once you've finetuned, use the model to predict:
 
 .. testcode:: finetune
 
-    # Output predictions as labels, automatically inferred from the training data in part 2.
-    model.output = LabelsOutput()
-
     predict_datamodule = ImageClassificationData.from_files(
         predict_files=[
             "data/hymenoptera_data/val/bees/65038344_52a45d090d.jpg",
@@ -66,7 +63,7 @@ Once you've finetuned, use the model to predict:
         ],
         batch_size=1,
     )
-    predictions = trainer.predict(model, datamodule=predict_datamodule)
+    predictions = trainer.predict(model, datamodule=predict_datamodule, output="labels")
     print(predictions)
 
 We get the following output:
diff --git a/docs/source/quickstart.rst b/docs/source/quickstart.rst
index 1cf81e52da..596f6cf147 100644
--- a/docs/source/quickstart.rst
+++ b/docs/source/quickstart.rst
@@ -104,7 +104,7 @@ Here's an example of inference:
         ],
         batch_size=4,
     )
-    predictions = trainer.predict(model, datamodule=datamodule)
+    predictions = trainer.predict(model, datamodule=datamodule, output="labels")
     print(predictions)
 
 We get the following output:
diff --git a/docs/source/template/data.rst b/docs/source/template/data.rst
index b39bc5786c..c73165a419 100644
--- a/docs/source/template/data.rst
+++ b/docs/source/template/data.rst
@@ -172,11 +172,11 @@ OutputTransform
 
 :class:`~flash.core.data.io.output_transform.OutputTransform` contains any transforms that need to be applied *after* the model. You may want to use it for: converting tokens back into text, applying an inverse normalization to an output image, resizing a generated image back to the size of the input, etc.
 
-As an example, here's the :class:`~text.seq2seq.core.output_transform.Seq2SeqOutputTransform` which decodes tokenized model outputs:
+As an example, here's the :class:`~image.segmentation.model.SemanticSegmentationOutputTransform` which uses the recorded metadata to resize the predictions back to the size of the original input:
 
-.. literalinclude:: ../../../flash/text/seq2seq/core/output_transform.py
+.. literalinclude:: ../../../flash/image/segmentation/model.py
     :language: python
-    :pyobject: Seq2SeqOutputTransform
+    :pyobject: SemanticSegmentationOutputTransform
 
 In your :class:`~flash.core.data.io.input.Input` or :class:`~flash.core.data.io.input_transform.InputTransform`, you can add metadata to the batch using the :attr:`~flash.core.data.io.input.DataKeys.METADATA` key. Your :class:`~flash.core.data.io.output_transform.OutputTransform` can then use this metadata in its transforms.
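As the docs changes above show, the old pattern of attaching an `Output` to the model (`model.output = LabelsOutput()`) is replaced by passing `output` directly to `Trainer.predict`. Here is a minimal sketch of the new flow; the checkpoint path and image file are placeholders:

.. code-block:: python

    from flash import Trainer
    from flash.image import ImageClassificationData, ImageClassifier

    # Placeholder checkpoint from an earlier finetuning run.
    model = ImageClassifier.load_from_checkpoint("image_classification_model.pt")

    datamodule = ImageClassificationData.from_files(
        predict_files=["path/to/image.png"],  # placeholder file
        batch_size=1,
    )

    trainer = Trainer()
    # ``output`` accepts an Output instance or a registered name such as
    # "labels", "probabilities", "logits", or "classes".
    predictions = trainer.predict(model, datamodule=datamodule, output="labels")

The string form is resolved against the task's `outputs` registry (see the `CLASSIFICATION_OUTPUTS` registrations in `flash/core/classification.py` below), and the new `Output.from_task` hook lets an output such as `LabelsOutput` pick up `labels` and `multi_label` from the model instead of a shared state object.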
diff --git a/flash/audio/classification/cli.py b/flash/audio/classification/cli.py index b1250b826c..1181187d30 100644 --- a/flash/audio/classification/cli.py +++ b/flash/audio/classification/cli.py @@ -43,6 +43,7 @@ def audio_classification(): default_arguments={ "trainer.max_epochs": 3, }, + datamodule_attributes={"num_classes", "labels", "multi_label"}, ) cli.trainer.save_checkpoint("audio_classification_model.pt") diff --git a/flash/audio/classification/data.py b/flash/audio/classification/data.py index a29372d96f..c9920b4f2a 100644 --- a/flash/audio/classification/data.py +++ b/flash/audio/classification/data.py @@ -28,7 +28,6 @@ from flash.audio.classification.input_transform import AudioClassificationInputTransform from flash.core.data.callback import BaseDataFetcher from flash.core.data.data_module import DataModule -from flash.core.data.data_pipeline import DataPipelineState from flash.core.data.io.input import Input from flash.core.data.io.input_transform import INPUT_TRANSFORM_TYPE from flash.core.data.utilities.paths import PATH_TYPE @@ -143,13 +142,15 @@ def from_files( """ ds_kw = dict( - data_pipeline_state=DataPipelineState(), transform_kwargs=transform_kwargs, input_transforms_registry=cls.input_transforms_registry, ) + train_input = input_cls(RunningStage.TRAINING, train_files, train_targets, transform=train_transform, **ds_kw) + ds_kw["target_formatter"] = getattr(train_input, "target_formatter", None) + return cls( - input_cls(RunningStage.TRAINING, train_files, train_targets, transform=train_transform, **ds_kw), + train_input, input_cls(RunningStage.VALIDATING, val_files, val_targets, transform=val_transform, **ds_kw), input_cls(RunningStage.TESTING, test_files, test_targets, transform=test_transform, **ds_kw), input_cls(RunningStage.PREDICTING, predict_files, transform=predict_transform, **ds_kw), @@ -269,13 +270,15 @@ def from_folders( """ ds_kw = dict( - data_pipeline_state=DataPipelineState(), transform_kwargs=transform_kwargs, input_transforms_registry=cls.input_transforms_registry, ) + train_input = input_cls(RunningStage.TRAINING, train_folder, transform=train_transform, **ds_kw) + ds_kw["target_formatter"] = getattr(train_input, "target_formatter", None) + return cls( - input_cls(RunningStage.TRAINING, train_folder, transform=train_transform, **ds_kw), + train_input, input_cls(RunningStage.VALIDATING, val_folder, transform=val_transform, **ds_kw), input_cls(RunningStage.TESTING, test_folder, transform=test_transform, **ds_kw), input_cls(RunningStage.PREDICTING, predict_folder, transform=predict_transform, **ds_kw), @@ -358,13 +361,15 @@ def from_numpy( """ ds_kw = dict( - data_pipeline_state=DataPipelineState(), transform_kwargs=transform_kwargs, input_transforms_registry=cls.input_transforms_registry, ) + train_input = input_cls(RunningStage.TRAINING, train_data, train_targets, transform=train_transform, **ds_kw) + ds_kw["target_formatter"] = getattr(train_input, "target_formatter", None) + return cls( - input_cls(RunningStage.TRAINING, train_data, train_targets, transform=train_transform, **ds_kw), + train_input, input_cls(RunningStage.VALIDATING, val_data, val_targets, transform=val_transform, **ds_kw), input_cls(RunningStage.TESTING, test_data, test_targets, transform=test_transform, **ds_kw), input_cls(RunningStage.PREDICTING, predict_data, transform=predict_transform, **ds_kw), @@ -447,13 +452,15 @@ def from_tensors( """ ds_kw = dict( - data_pipeline_state=DataPipelineState(), transform_kwargs=transform_kwargs, 
input_transforms_registry=cls.input_transforms_registry, ) + train_input = input_cls(RunningStage.TRAINING, train_data, train_targets, transform=train_transform, **ds_kw) + ds_kw["target_formatter"] = getattr(train_input, "target_formatter", None) + return cls( - input_cls(RunningStage.TRAINING, train_data, train_targets, transform=train_transform, **ds_kw), + train_input, input_cls(RunningStage.VALIDATING, val_data, val_targets, transform=val_transform, **ds_kw), input_cls(RunningStage.TESTING, test_data, test_targets, transform=test_transform, **ds_kw), input_cls(RunningStage.PREDICTING, predict_data, transform=predict_transform, **ds_kw), @@ -595,7 +602,6 @@ def from_data_frame( """ ds_kw = dict( - data_pipeline_state=DataPipelineState(), transform_kwargs=transform_kwargs, input_transforms_registry=cls.input_transforms_registry, ) @@ -605,8 +611,11 @@ def from_data_frame( test_data = (test_data_frame, input_field, target_fields, test_images_root, test_resolver) predict_data = (predict_data_frame, input_field, None, predict_images_root, predict_resolver) + train_input = input_cls(RunningStage.TRAINING, *train_data, transform=train_transform, **ds_kw) + ds_kw["target_formatter"] = getattr(train_input, "target_formatter", None) + return cls( - input_cls(RunningStage.TRAINING, *train_data, transform=train_transform, **ds_kw), + train_input, input_cls(RunningStage.VALIDATING, *val_data, transform=val_transform, **ds_kw), input_cls(RunningStage.TESTING, *test_data, transform=test_transform, **ds_kw), input_cls(RunningStage.PREDICTING, *predict_data, transform=predict_transform, **ds_kw), @@ -758,7 +767,6 @@ def from_csv( """ ds_kw = dict( - data_pipeline_state=DataPipelineState(), transform_kwargs=transform_kwargs, input_transforms_registry=cls.input_transforms_registry, ) @@ -768,8 +776,11 @@ def from_csv( test_data = (test_file, input_field, target_fields, test_images_root, test_resolver) predict_data = (predict_file, input_field, None, predict_images_root, predict_resolver) + train_input = input_cls(RunningStage.TRAINING, *train_data, transform=train_transform, **ds_kw) + ds_kw["target_formatter"] = getattr(train_input, "target_formatter", None) + return cls( - input_cls(RunningStage.TRAINING, *train_data, transform=train_transform, **ds_kw), + train_input, input_cls(RunningStage.VALIDATING, *val_data, transform=val_transform, **ds_kw), input_cls(RunningStage.TESTING, *test_data, transform=test_transform, **ds_kw), input_cls(RunningStage.PREDICTING, *predict_data, transform=predict_transform, **ds_kw), diff --git a/flash/audio/classification/input.py b/flash/audio/classification/input.py index a7cf78dcba..fcb5c6f756 100644 --- a/flash/audio/classification/input.py +++ b/flash/audio/classification/input.py @@ -17,9 +17,9 @@ import numpy as np import pandas as pd -from flash.core.data.io.classification_input import ClassificationInputMixin, ClassificationState +from flash.core.data.io.classification_input import ClassificationInputMixin from flash.core.data.io.input import DataKeys, Input -from flash.core.data.utilities.classification import MultiBinaryTargetFormatter +from flash.core.data.utilities.classification import MultiBinaryTargetFormatter, TargetFormatter from flash.core.data.utilities.data_frame import read_csv, resolve_files, resolve_targets from flash.core.data.utilities.paths import filter_valid_files, has_file_allowed_extension, make_dataset, PATH_TYPE from flash.core.data.utilities.samples import to_samples @@ -54,12 +54,13 @@ def load_data( self, files: List[PATH_TYPE], 
targets: Optional[List[Any]] = None, + target_formatter: Optional[TargetFormatter] = None, ) -> List[Dict[str, Any]]: if targets is None: files = filter_valid_files(files, valid_extensions=IMG_EXTENSIONS + NP_EXTENSIONS) return to_samples(files) files, targets = filter_valid_files(files, targets, valid_extensions=IMG_EXTENSIONS + NP_EXTENSIONS) - self.load_target_metadata(targets) + self.load_target_metadata(targets, target_formatter=target_formatter) return to_samples(files, targets) def load_sample(self, sample: Dict[str, Any]) -> Dict[str, Any]: @@ -71,15 +72,17 @@ def load_sample(self, sample: Dict[str, Any]) -> Dict[str, Any]: class AudioClassificationFolderInput(AudioClassificationFilesInput): - def load_data(self, folder: PATH_TYPE) -> List[Dict[str, Any]]: + def load_data(self, folder: PATH_TYPE, target_formatter: Optional[TargetFormatter] = None) -> List[Dict[str, Any]]: files, targets = make_dataset(folder, extensions=IMG_EXTENSIONS + NP_EXTENSIONS) - return super().load_data(files, targets) + return super().load_data(files, targets, target_formatter=target_formatter) class AudioClassificationNumpyInput(AudioClassificationInput): - def load_data(self, array: Any, targets: Optional[List[Any]] = None) -> List[Dict[str, Any]]: + def load_data( + self, array: Any, targets: Optional[List[Any]] = None, target_formatter: Optional[TargetFormatter] = None + ) -> List[Dict[str, Any]]: if targets is not None: - self.load_target_metadata(targets) + self.load_target_metadata(targets, target_formatter=target_formatter) return to_samples(array, targets) def load_sample(self, sample: Dict[str, Any]) -> Dict[str, Any]: @@ -88,9 +91,11 @@ def load_sample(self, sample: Dict[str, Any]) -> Dict[str, Any]: class AudioClassificationTensorInput(AudioClassificationNumpyInput): - def load_data(self, tensor: Any, targets: Optional[List[Any]] = None) -> List[Dict[str, Any]]: + def load_data( + self, tensor: Any, targets: Optional[List[Any]] = None, target_formatter: Optional[TargetFormatter] = None + ) -> List[Dict[str, Any]]: if targets is not None: - self.load_target_metadata(targets) + self.load_target_metadata(targets, target_formatter=target_formatter) return to_samples(tensor, targets) def load_sample(self, sample: Dict[str, Any]) -> Dict[str, Any]: @@ -106,13 +111,14 @@ def load_data( target_keys: Optional[Union[str, List[str]]] = None, root: Optional[PATH_TYPE] = None, resolver: Optional[Callable[[Optional[PATH_TYPE], Any], PATH_TYPE]] = None, + target_formatter: Optional[TargetFormatter] = None, ) -> List[Dict[str, Any]]: files = resolve_files(data_frame, input_key, root, resolver) if target_keys is not None: targets = resolve_targets(data_frame, target_keys) else: targets = None - result = super().load_data(files, targets) + result = super().load_data(files, targets, target_formatter=target_formatter) # If we had binary multi-class targets then we also know the labels (column names) if ( @@ -120,8 +126,7 @@ def load_data( and isinstance(self.target_formatter, MultiBinaryTargetFormatter) and isinstance(target_keys, List) ): - classification_state = self.get_state(ClassificationState) - self.set_state(ClassificationState(target_keys, classification_state.num_classes)) + self.labels = target_keys return result @@ -134,8 +139,9 @@ def load_data( target_keys: Optional[Union[str, List[str]]] = None, root: Optional[PATH_TYPE] = None, resolver: Optional[Callable[[Optional[PATH_TYPE], Any], PATH_TYPE]] = None, + target_formatter: Optional[TargetFormatter] = None, ) -> List[Dict[str, Any]]: data_frame = 
read_csv(csv_file) if root is None: root = os.path.dirname(csv_file) - return super().load_data(data_frame, input_key, target_keys, root, resolver) + return super().load_data(data_frame, input_key, target_keys, root, resolver, target_formatter=target_formatter) diff --git a/flash/audio/speech_recognition/data.py b/flash/audio/speech_recognition/data.py index f99e21d808..0b44c4ac26 100644 --- a/flash/audio/speech_recognition/data.py +++ b/flash/audio/speech_recognition/data.py @@ -23,7 +23,6 @@ ) from flash.audio.speech_recognition.output_transform import SpeechRecognitionOutputTransform from flash.core.data.data_module import DataModule -from flash.core.data.data_pipeline import DataPipelineState from flash.core.data.io.input import Input from flash.core.data.io.input_transform import INPUT_TRANSFORM_TYPE, InputTransform from flash.core.registry import FlashRegistry @@ -128,7 +127,6 @@ def from_files( """ ds_kw = dict( - data_pipeline_state=DataPipelineState(), transform_kwargs=transform_kwargs, input_transforms_registry=cls.input_transforms_registry, sampling_rate=sampling_rate, @@ -257,7 +255,6 @@ def from_csv( """ ds_kw = dict( - data_pipeline_state=DataPipelineState(), transform_kwargs=transform_kwargs, input_transforms_registry=cls.input_transforms_registry, input_key=input_field, @@ -387,7 +384,6 @@ def from_json( """ ds_kw = dict( - data_pipeline_state=DataPipelineState(), transform_kwargs=transform_kwargs, input_transforms_registry=cls.input_transforms_registry, input_key=input_field, @@ -544,7 +540,6 @@ def from_datasets( """ ds_kw = dict( - data_pipeline_state=DataPipelineState(), transform_kwargs=transform_kwargs, input_transforms_registry=cls.input_transforms_registry, sampling_rate=sampling_rate, diff --git a/flash/audio/speech_recognition/model.py b/flash/audio/speech_recognition/model.py index bcff080c2c..49cfec4527 100644 --- a/flash/audio/speech_recognition/model.py +++ b/flash/audio/speech_recognition/model.py @@ -13,7 +13,7 @@ # limitations under the License. import os import warnings -from typing import Any, Dict, Optional, Type +from typing import Any, Dict, Optional, Type, Union import torch import torch.nn as nn @@ -21,24 +21,15 @@ from flash.audio.speech_recognition.backbone import SPEECH_RECOGNITION_BACKBONES from flash.audio.speech_recognition.collate import DataCollatorCTCWithPadding from flash.audio.speech_recognition.input import SpeechRecognitionDeserializer -from flash.audio.speech_recognition.output_transform import ( - SpeechRecognitionBackboneState, - SpeechRecognitionOutputTransform, -) +from flash.audio.speech_recognition.output_transform import SpeechRecognitionOutputTransform from flash.core.data.io.input import ServeInput from flash.core.data.io.input_transform import InputTransform -from flash.core.data.states import CollateFn +from flash.core.data.io.output import Output from flash.core.model import Task from flash.core.registry import FlashRegistry from flash.core.serve import Composition from flash.core.utilities.imports import _AUDIO_AVAILABLE, requires -from flash.core.utilities.types import ( - INPUT_TRANSFORM_TYPE, - LR_SCHEDULER_TYPE, - OPTIMIZER_TYPE, - OUTPUT_TRANSFORM_TYPE, - OUTPUT_TYPE, -) +from flash.core.utilities.types import INPUT_TRANSFORM_TYPE, LR_SCHEDULER_TYPE, OPTIMIZER_TYPE if _AUDIO_AVAILABLE: from transformers import AutoProcessor @@ -54,7 +45,6 @@ class SpeechRecognition(Task): learning_rate: Learning rate to use for training, defaults to ``1e-5``. optimizer: Optimizer to use for training. 
lr_scheduler: The LR scheduler to use during training. - output: The :class:`~flash.core.data.io.output.Output` to use when formatting prediction outputs. """ backbones: FlashRegistry = SPEECH_RECOGNITION_BACKBONES @@ -68,8 +58,6 @@ def __init__( optimizer: OPTIMIZER_TYPE = "Adam", lr_scheduler: LR_SCHEDULER_TYPE = None, learning_rate: float = 1e-5, - output: OUTPUT_TYPE = None, - output_transform: OUTPUT_TRANSFORM_TYPE = SpeechRecognitionOutputTransform(), ): os.environ["TOKENIZERS_PARALLELISM"] = "TRUE" # disable HF thousand warnings @@ -83,21 +71,15 @@ def __init__( optimizer=optimizer, lr_scheduler=lr_scheduler, learning_rate=learning_rate, - output=output, - output_transform=output_transform, + output_transform=SpeechRecognitionOutputTransform(backbone), ) self.save_hyperparameters() - self.set_state(SpeechRecognitionBackboneState(backbone)) - self.set_state( - CollateFn( - DataCollatorCTCWithPadding( - AutoProcessor.from_pretrained(backbone) - if processor_backbone is None - else AutoProcessor.from_pretrained(processor_backbone) - ) - ) + self.collate_fn = DataCollatorCTCWithPadding( + AutoProcessor.from_pretrained(backbone) + if processor_backbone is None + else AutoProcessor.from_pretrained(processor_backbone) ) def forward(self, batch: Dict[str, torch.Tensor]): @@ -120,5 +102,6 @@ def serve( input_cls: Optional[Type[ServeInput]] = SpeechRecognitionDeserializer, transform: INPUT_TRANSFORM_TYPE = InputTransform, transform_kwargs: Optional[Dict] = None, + output: Optional[Union[str, Output]] = None, ) -> Composition: - return super().serve(host, port, sanity_check, input_cls, transform, transform_kwargs) + return super().serve(host, port, sanity_check, input_cls, transform, transform_kwargs, output) diff --git a/flash/audio/speech_recognition/output_transform.py b/flash/audio/speech_recognition/output_transform.py index 56f282bed5..49f6a9defd 100644 --- a/flash/audio/speech_recognition/output_transform.py +++ b/flash/audio/speech_recognition/output_transform.py @@ -11,53 +11,29 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. -from dataclasses import dataclass from typing import Any import torch from flash.core.data.io.output_transform import OutputTransform -from flash.core.data.properties import ProcessState from flash.core.utilities.imports import _AUDIO_AVAILABLE, requires if _AUDIO_AVAILABLE: from transformers import Wav2Vec2CTCTokenizer -@dataclass(unsafe_hash=True, frozen=True) -class SpeechRecognitionBackboneState(ProcessState): - """The ``SpeechRecognitionBackboneState`` stores the backbone in use by the - :class:`~flash.audio.speech_recognition.data.SpeechRecognitionOutputTransform`. 
- """ - - backbone: str - - class SpeechRecognitionOutputTransform(OutputTransform): - def __init__(self): + def __init__(self, backbone: str): super().__init__() - self._backbone = None - self._tokenizer = None - - @property - def backbone(self): - backbone_state = self.get_state(SpeechRecognitionBackboneState) - if backbone_state is not None: - return backbone_state.backbone - - @property - def tokenizer(self): - if self.backbone is not None and self.backbone != self._backbone: - self._tokenizer = Wav2Vec2CTCTokenizer.from_pretrained(self.backbone) - self._backbone = self.backbone - return self._tokenizer + self.backbone = backbone + self._tokenizer = Wav2Vec2CTCTokenizer.from_pretrained(self.backbone) @requires("audio") def per_batch_transform(self, batch: Any) -> Any: # converts logits into greedy transcription pred_ids = torch.argmax(batch.logits, dim=-1) - transcriptions = self.tokenizer.batch_decode(pred_ids) + transcriptions = self._tokenizer.batch_decode(pred_ids) return transcriptions def __getstate__(self): # TODO: Find out why this is being pickled diff --git a/flash/core/adapter.py b/flash/core/adapter.py index 1167df2f90..1e3ca81e92 100644 --- a/flash/core/adapter.py +++ b/flash/core/adapter.py @@ -21,6 +21,7 @@ import flash from flash.core.data.io.input import InputBase from flash.core.model import DatasetProcessor, ModuleWrapperBase, Task +from flash.core.utilities.types import INPUT_TRANSFORM_TYPE class Adapter(DatasetProcessor, ModuleWrapperBase, nn.Module): @@ -78,6 +79,15 @@ def __init__(self, adapter: Adapter, **kwargs): self.adapter = adapter + @torch.jit.unused + @property + def input_transform(self) -> Optional[INPUT_TRANSFORM_TYPE]: + return self.adapter.input_transform + + @input_transform.setter + def input_transform(self, input_transform: INPUT_TRANSFORM_TYPE) -> None: + self.adapter.input_transform = input_transform + @torch.jit.unused @property def backbone(self) -> nn.Module: @@ -118,17 +128,19 @@ def process_train_dataset( shuffle: bool = False, drop_last: bool = True, sampler: Optional[Sampler] = None, + persistent_workers: bool = False, ) -> DataLoader: return self.adapter.process_train_dataset( dataset, trainer, - batch_size, - num_workers, - pin_memory, - collate_fn, - shuffle, - drop_last, - sampler, + batch_size=batch_size, + num_workers=num_workers, + pin_memory=pin_memory, + collate_fn=collate_fn, + shuffle=shuffle, + drop_last=drop_last, + sampler=sampler, + persistent_workers=persistent_workers, ) def process_val_dataset( @@ -142,17 +154,19 @@ def process_val_dataset( shuffle: bool = False, drop_last: bool = False, sampler: Optional[Sampler] = None, + persistent_workers: bool = False, ) -> DataLoader: return self.adapter.process_val_dataset( dataset, trainer, - batch_size, - num_workers, - pin_memory, - collate_fn, - shuffle, - drop_last, - sampler, + batch_size=batch_size, + num_workers=num_workers, + pin_memory=pin_memory, + collate_fn=collate_fn, + shuffle=shuffle, + drop_last=drop_last, + sampler=sampler, + persistent_workers=persistent_workers, ) def process_test_dataset( @@ -166,17 +180,19 @@ def process_test_dataset( shuffle: bool = False, drop_last: bool = False, sampler: Optional[Sampler] = None, + persistent_workers: bool = False, ) -> DataLoader: return self.adapter.process_test_dataset( dataset, trainer, - batch_size, - num_workers, - pin_memory, - collate_fn, - shuffle, - drop_last, - sampler, + batch_size=batch_size, + num_workers=num_workers, + pin_memory=pin_memory, + collate_fn=collate_fn, + shuffle=shuffle, + drop_last=drop_last, 
+ sampler=sampler, + persistent_workers=persistent_workers, ) def process_predict_dataset( @@ -189,6 +205,7 @@ def process_predict_dataset( shuffle: bool = False, drop_last: bool = True, sampler: Optional[Sampler] = None, + persistent_workers: bool = False, ) -> DataLoader: return self.adapter.process_predict_dataset( dataset, @@ -199,4 +216,5 @@ def process_predict_dataset( shuffle=shuffle, drop_last=drop_last, sampler=sampler, + persistent_workers=persistent_workers, ) diff --git a/flash/core/classification.py b/flash/core/classification.py index bc9074b0d3..9ed743e50a 100644 --- a/flash/core/classification.py +++ b/flash/core/classification.py @@ -19,11 +19,12 @@ from torchmetrics import Accuracy, Metric from flash.core.adapter import AdapterTask -from flash.core.data.io.classification_input import ClassificationState from flash.core.data.io.input import DataKeys from flash.core.data.io.output import Output from flash.core.model import Task +from flash.core.registry import FlashRegistry from flash.core.utilities.imports import _FIFTYONE_AVAILABLE, _TM_GREATER_EQUAL_0_7_0, lazy_import, requires +from flash.core.utilities.providers import _FIFTYONE if _FIFTYONE_AVAILABLE: fol = lazy_import("fiftyone.core.labels") @@ -40,19 +41,27 @@ from torchmetrics import F1 as F1Score +CLASSIFICATION_OUTPUTS = FlashRegistry("outputs") + + def binary_cross_entropy_with_logits(x: torch.Tensor, y: torch.Tensor) -> torch.Tensor: """Calls BCE with logits and cast the target one_hot (y) encoding to floating point precision.""" return F.binary_cross_entropy_with_logits(x, y.float()) class ClassificationMixin: - @staticmethod def _build( + self, num_classes: Optional[int] = None, + labels: Optional[List[str]] = None, loss_fn: Optional[Callable] = None, metrics: Union[Metric, Mapping, Sequence, None] = None, multi_label: bool = False, ): + self.num_classes = num_classes + self.multi_label = multi_label + self.labels = labels + if metrics is None: metrics = F1Score(num_classes) if (multi_label and num_classes) else Accuracy() @@ -68,6 +77,9 @@ def to_metrics_format(self, x: torch.Tensor) -> torch.Tensor: class ClassificationTask(Task, ClassificationMixin): + + outputs: FlashRegistry = Task.outputs + CLASSIFICATION_OUTPUTS + def __init__( self, *args, @@ -75,22 +87,24 @@ def __init__( loss_fn: Optional[Callable] = None, metrics: Union[Metric, Mapping, Sequence, None] = None, multi_label: bool = False, - output: Optional[Union[Output, Mapping[str, Output]]] = None, + labels: Optional[List[str]] = None, **kwargs, ) -> None: - metrics, loss_fn = ClassificationMixin._build(num_classes, loss_fn, metrics, multi_label) + metrics, loss_fn = self._build(num_classes, labels, loss_fn, metrics, multi_label) super().__init__( *args, loss_fn=loss_fn, metrics=metrics, - output=output or ClassesOutput(multi_label=multi_label), **kwargs, ) class ClassificationAdapterTask(AdapterTask, ClassificationMixin): + + outputs: FlashRegistry = Task.outputs + CLASSIFICATION_OUTPUTS + def __init__( self, *args, @@ -98,17 +112,16 @@ def __init__( loss_fn: Optional[Callable] = None, metrics: Union[Metric, Mapping, Sequence, None] = None, multi_label: bool = False, - output: Optional[Union[Output, Mapping[str, Output]]] = None, + labels: Optional[List[str]] = None, **kwargs, ) -> None: - metrics, loss_fn = ClassificationMixin._build(num_classes, loss_fn, metrics, multi_label) + metrics, loss_fn = self._build(num_classes, labels, loss_fn, metrics, multi_label) super().__init__( *args, loss_fn=loss_fn, metrics=metrics, - output=output or 
ClassesOutput(multi_label=multi_label), **kwargs, ) @@ -125,11 +138,16 @@ def __init__(self, multi_label: bool = False): self._mutli_label = multi_label + @classmethod + def from_task(cls, task: Task, **kwargs) -> Output: + return cls(multi_label=getattr(task, "multi_label", False)) + @property def multi_label(self) -> bool: return self._mutli_label +@CLASSIFICATION_OUTPUTS(name="preds") class PredsClassificationOutput(ClassificationOutput): """A :class:`~flash.core.classification.ClassificationOutput` which gets the :attr:`~flash.core.data.io.input.InputFormat.PREDS` from the sample. @@ -143,6 +161,7 @@ def transform(self, sample: Any) -> Any: return sample +@CLASSIFICATION_OUTPUTS(name="logits") class LogitsOutput(PredsClassificationOutput): """A :class:`.Output` which simply converts the model outputs (assumed to be logits) to a list.""" @@ -150,6 +169,7 @@ def transform(self, sample: Any) -> Any: return super().transform(sample).tolist() +@CLASSIFICATION_OUTPUTS(name="probabilities") class ProbabilitiesOutput(PredsClassificationOutput): """A :class:`.Output` which applies a softmax to the model outputs (assumed to be logits) and converts to a list.""" @@ -161,6 +181,7 @@ def transform(self, sample: Any) -> Any: return torch.softmax(sample, -1).tolist() +@CLASSIFICATION_OUTPUTS(name="classes") class ClassesOutput(PredsClassificationOutput): """A :class:`.Output` which applies an argmax to the model outputs (either logits or probabilities) and converts to a list. @@ -187,13 +208,13 @@ def transform(self, sample: Any) -> Union[int, List[int]]: return torch.argmax(sample, -1).tolist() +@CLASSIFICATION_OUTPUTS(name="labels") class LabelsOutput(ClassesOutput): """A :class:`.Output` which converts the model outputs (either logits or probabilities) to the label of the argmax classification. Args: - labels: A list of labels, assumed to map the class index to the label for that class. If ``labels`` is not - provided, will attempt to get them from the :class:`.ClassificationState`. + labels: A list of labels, assumed to map the class index to the label for that class. multi_label: If true, treats outputs as multi label logits. threshold: The threshold to use for multi_label classification. 
""" @@ -202,35 +223,27 @@ def __init__(self, labels: Optional[List[str]] = None, multi_label: bool = False super().__init__(multi_label=multi_label, threshold=threshold) self._labels = labels - if labels is not None: - self.set_state(ClassificationState(labels)) + @classmethod + def from_task(cls, task: Task, **kwargs) -> Output: + return cls(labels=getattr(task, "labels", None), multi_label=getattr(task, "multi_label", False)) def transform(self, sample: Any) -> Union[int, List[int], str, List[str]]: - labels = None - - if self._labels is not None: - labels = self._labels - else: - state = self.get_state(ClassificationState) - if state is not None: - labels = state.labels - classes = super().transform(sample) - if labels is not None: + if self._labels is not None: if self.multi_label: - return [labels[cls] for cls in classes] - return labels[classes] - rank_zero_warn("No ClassificationState was found, this output will act as a Classes output.", UserWarning) + return [self._labels[cls] for cls in classes] + return self._labels[classes] + rank_zero_warn("No labels were provided, this output will act as a Classes output.", UserWarning) return classes +@CLASSIFICATION_OUTPUTS(name="fiftyone", providers=_FIFTYONE) class FiftyOneLabelsOutput(ClassificationOutput): """A :class:`.Output` which converts the model outputs to FiftyOne classification format. Args: - labels: A list of labels, assumed to map the class index to the label for that class. If ``labels`` is not - provided, will attempt to get them from the :class:`.ClassificationState`. + labels: A list of labels, assumed to map the class index to the label for that class. multi_label: If true, treats outputs as multi label logits. threshold: A threshold to use to filter candidate labels. In the single label case, predictions below this threshold will be replaced with None @@ -247,7 +260,7 @@ def __init__( multi_label: bool = False, threshold: Optional[float] = None, store_logits: bool = False, - return_filepath: bool = False, + return_filepath: bool = True, ): if multi_label and threshold is None: threshold = 0.5 @@ -258,8 +271,9 @@ def __init__( self.store_logits = store_logits self.return_filepath = return_filepath - if labels is not None: - self.set_state(ClassificationState(labels)) + @classmethod + def from_task(cls, task: Task, **kwargs) -> Output: + return cls(labels=getattr(task, "labels", None), multi_label=getattr(task, "multi_label", False)) def transform( self, @@ -268,15 +282,6 @@ def transform( pred = sample[DataKeys.PREDS] if isinstance(sample, Dict) else sample pred = torch.tensor(pred) - labels = None - - if self._labels is not None: - labels = self._labels - else: - state = self.get_state(ClassificationState) - if state is not None: - labels = state.labels - logits = None if self.store_logits: logits = pred.tolist() @@ -292,12 +297,12 @@ def transform( classes = torch.argmax(pred, -1).tolist() probabilities = torch.softmax(pred, -1).tolist() - if labels is not None: + if self._labels is not None: if self.multi_label: classifications = [] for idx in classes: fo_cls = fol.Classification( - label=labels[idx], + label=self._labels[idx], confidence=probabilities[idx], ) classifications.append(fo_cls) @@ -311,12 +316,12 @@ def transform( fo_predictions = None else: fo_predictions = fol.Classification( - label=labels[classes], + label=self._labels[classes], confidence=confidence, logits=logits, ) else: - rank_zero_warn("No ClassificationState was found, int targets will be used as label strings", UserWarning) + 
rank_zero_warn("No labels were provided, int targets will be used as label strings", UserWarning) if self.multi_label: classifications = [] diff --git a/flash/core/data/batch.py b/flash/core/data/batch.py index d9b6849eca..d1ab5837d5 100644 --- a/flash/core/data/batch.py +++ b/flash/core/data/batch.py @@ -11,7 +11,7 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. -from typing import Any, List, TYPE_CHECKING +from typing import Any, Callable, List, TYPE_CHECKING import torch @@ -25,14 +25,17 @@ class _ServeInputProcessor(torch.nn.Module): def __init__( self, serve_input: "ServeInput", + collate_fn: Callable, ): super().__init__() self.serve_input = serve_input - self.dataloader_collate_fn = self.serve_input._create_dataloader_collate_fn([]) + self.collate_fn = collate_fn def forward(self, sample: str): sample = self.serve_input._call_load_sample(sample) - sample = self.dataloader_collate_fn(sample) + if not isinstance(sample, list): + sample = [sample] + sample = self.collate_fn(sample) return sample diff --git a/flash/core/data/data_module.py b/flash/core/data/data_module.py index 9056805542..2c0dbeb400 100644 --- a/flash/core/data/data_module.py +++ b/flash/core/data/data_module.py @@ -26,9 +26,13 @@ import flash from flash.core.data.base_viz import BaseVisualization from flash.core.data.callback import BaseDataFetcher -from flash.core.data.data_pipeline import DataPipeline, DataPipelineState from flash.core.data.io.input import DataKeys, Input, InputBase, IterableInput -from flash.core.data.io.input_transform import _InputTransformProcessorV2, InputTransform +from flash.core.data.io.input_transform import ( + _create_collate_input_transform_processors, + _InputTransformProcessorV2, + create_transform, + InputTransform, +) from flash.core.data.io.output_transform import OutputTransform from flash.core.data.splits import SplitDataset from flash.core.data.utils import _STAGES_PREFIX @@ -70,7 +74,6 @@ class DataModule(pl.LightningDataModule): """ input_transform_cls = InputTransform - output_transform_cls = OutputTransform input_transforms_registry: Optional[FlashRegistry] = None def __init__( @@ -85,8 +88,7 @@ def __init__( num_workers: int = 0, sampler: Optional[Type[Sampler]] = None, pin_memory: bool = True, - persistent_workers: bool = True, - output_transform: Optional[OutputTransform] = None, + persistent_workers: bool = False, ) -> None: if not batch_size: @@ -96,7 +98,6 @@ def __init__( batch_size = 16 self._input_transform: Optional[OutputTransform] = None - self._output_transform: Optional[OutputTransform] = output_transform self._viz: Optional[BaseVisualization] = None self._train_input = train_input @@ -119,10 +120,14 @@ def __init__( self._test_dataloader_collate_fn = self._resolve_dataloader_collate_fn(self._test_input) self._predict_dataloader_collate_fn = self._resolve_dataloader_collate_fn(self._predict_input) - self._train_on_after_batch_transfer_fn = self._resolve_on_after_batch_transfer_fn(self._train_input) - self._val_on_after_batch_transfer_fn = self._resolve_on_after_batch_transfer_fn(self._val_input) - self._test_on_after_batch_transfer_fn = self._resolve_on_after_batch_transfer_fn(self._test_input) - self._predict_on_after_batch_transfer_fn = self._resolve_on_after_batch_transfer_fn(self._predict_input) + self._on_after_batch_transfer_fns = { + RunningStage.TRAINING: self._resolve_on_after_batch_transfer_fn(self._train_input), + 
RunningStage.VALIDATING: self._resolve_on_after_batch_transfer_fn(self._val_input), + RunningStage.SANITY_CHECKING: self._resolve_on_after_batch_transfer_fn(self._val_input), + RunningStage.TESTING: self._resolve_on_after_batch_transfer_fn(self._test_input), + RunningStage.PREDICTING: self._resolve_on_after_batch_transfer_fn(self._predict_input), + } + self._model_on_after_batch_transfer_fns = None if self._train_input: self.train_dataloader = self._train_dataloader @@ -138,8 +143,6 @@ def __init__( self.batch_size = batch_size - if num_workers is None: - num_workers = 0 self.num_workers = num_workers self.persistent_workers = persistent_workers and num_workers > 0 self.pin_memory = pin_memory @@ -168,11 +171,6 @@ def predict_dataset(self) -> Optional[Input]: """This property returns the predict dataset.""" return self._predict_input - def _resolve_transform(self, ds: Optional[Input]) -> Optional[InputTransform]: - if not isinstance(ds, Input): - return None - return ds.transform - def _resolve_dataloader_collate_fn(self, ds: Optional[Input]) -> Optional[Callable]: if not ds: return None @@ -187,12 +185,14 @@ def _resolve_on_after_batch_transfer_fn(self, ds: Optional[Input]) -> Optional[C return ds._create_on_after_batch_transfer_fn([self.data_fetcher]) def _train_dataloader(self) -> DataLoader: - if isinstance(getattr(self, "trainer", None), pl.Trainer): - if isinstance(self.trainer.lightning_module, flash.Task): - self.connect(self.trainer.lightning_module) - train_ds: Input = self._train_input + collate_fn = self._train_dataloader_collate_fn + if isinstance(getattr(self, "trainer", None), pl.Trainer): + input_transform = getattr(self.trainer.lightning_module, "input_transform", None) + if input_transform is not None: + input_transform = create_transform(input_transform, RunningStage.TRAINING) + collate_fn = _create_collate_input_transform_processors(input_transform, [self.data_fetcher])[0] transform_processor = None if isinstance(collate_fn, _InputTransformProcessorV2): @@ -222,6 +222,7 @@ def _train_dataloader(self) -> DataLoader: drop_last=drop_last, collate_fn=collate_fn, sampler=sampler, + persistent_workers=self.persistent_workers, ) else: dataloader = DataLoader( @@ -240,15 +241,18 @@ def _train_dataloader(self) -> DataLoader: transform_processor.collate_fn = dataloader.collate_fn dataloader.collate_fn = transform_processor + self._model_on_after_batch_transfer_fns = None return dataloader def _val_dataloader(self) -> DataLoader: - if isinstance(getattr(self, "trainer", None), pl.Trainer): - if isinstance(self.trainer.lightning_module, flash.Task): - self.connect(self.trainer.lightning_module) - val_ds: Input = self._val_input + collate_fn = self._val_dataloader_collate_fn + if isinstance(getattr(self, "trainer", None), pl.Trainer): + input_transform = getattr(self.trainer.lightning_module, "input_transform", None) + if input_transform is not None: + input_transform = create_transform(input_transform, RunningStage.VALIDATING) + collate_fn = _create_collate_input_transform_processors(input_transform, [self.data_fetcher])[0] transform_processor = None if isinstance(collate_fn, _InputTransformProcessorV2): @@ -263,6 +267,7 @@ def _val_dataloader(self) -> DataLoader: num_workers=self.num_workers, pin_memory=self.pin_memory, collate_fn=collate_fn, + persistent_workers=self.persistent_workers, ) else: dataloader = DataLoader( @@ -278,15 +283,18 @@ def _val_dataloader(self) -> DataLoader: transform_processor.collate_fn = dataloader.collate_fn dataloader.collate_fn = transform_processor + 
self._model_on_after_batch_transfer_fns = None return dataloader def _test_dataloader(self) -> DataLoader: - if isinstance(getattr(self, "trainer", None), pl.Trainer): - if isinstance(self.trainer.lightning_module, flash.Task): - self.connect(self.trainer.lightning_module) - test_ds: Input = self._test_input + collate_fn = self._test_dataloader_collate_fn + if isinstance(getattr(self, "trainer", None), pl.Trainer): + input_transform = getattr(self.trainer.lightning_module, "input_transform", None) + if input_transform is not None: + input_transform = create_transform(input_transform, RunningStage.TESTING) + collate_fn = _create_collate_input_transform_processors(input_transform, [self.data_fetcher])[0] transform_processor = None if isinstance(collate_fn, _InputTransformProcessorV2): @@ -301,6 +309,7 @@ def _test_dataloader(self) -> DataLoader: num_workers=self.num_workers, pin_memory=self.pin_memory, collate_fn=collate_fn, + persistent_workers=self.persistent_workers, ) else: dataloader = DataLoader( @@ -316,15 +325,18 @@ def _test_dataloader(self) -> DataLoader: transform_processor.collate_fn = dataloader.collate_fn dataloader.collate_fn = transform_processor + self._model_on_after_batch_transfer_fns = None return dataloader def _predict_dataloader(self) -> DataLoader: - if isinstance(getattr(self, "trainer", None), pl.Trainer): - if isinstance(self.trainer.lightning_module, flash.Task): - self.connect(self.trainer.lightning_module) - predict_ds: Input = self._predict_input + collate_fn = self._predict_dataloader_collate_fn + if isinstance(getattr(self, "trainer", None), pl.Trainer): + input_transform = getattr(self.trainer.lightning_module, "input_transform", None) + if input_transform is not None: + input_transform = create_transform(input_transform, RunningStage.PREDICTING) + collate_fn = _create_collate_input_transform_processors(input_transform, [self.data_fetcher])[0] transform_processor = None if isinstance(collate_fn, _InputTransformProcessorV2): @@ -343,6 +355,7 @@ def _predict_dataloader(self) -> DataLoader: num_workers=self.num_workers, pin_memory=self.pin_memory, collate_fn=collate_fn, + persistent_workers=self.persistent_workers, ) else: dataloader = DataLoader( @@ -358,43 +371,44 @@ def _predict_dataloader(self) -> DataLoader: transform_processor.collate_fn = dataloader.collate_fn dataloader.collate_fn = transform_processor + self._model_on_after_batch_transfer_fns = None return dataloader - def connect(self, task: "flash.Task"): - data_pipeline_state = DataPipelineState() - for properties in [ - self._train_input, - self._val_input, - self._test_input, - self._predict_input, - getattr(self._train_input, "transform", None), - getattr(self._val_input, "transform", None), - getattr(self._test_input, "transform", None), - getattr(self._predict_input, "transform", None), - task._deserializer, - task._output_transform, - task._output, - task, + def _load_model_on_after_batch_transfer_fns(self) -> None: + self._model_on_after_batch_transfer_fns = {} + + for stage in [ + RunningStage.TRAINING, + RunningStage.VALIDATING, + RunningStage.SANITY_CHECKING, + RunningStage.TESTING, + RunningStage.PREDICTING, ]: - if properties is not None and hasattr(properties, "attach_data_pipeline_state"): - properties.attach_data_pipeline_state(data_pipeline_state) + transform = None + if isinstance(getattr(self, "trainer", None), pl.Trainer): + input_transform = getattr(self.trainer.lightning_module, "input_transform", None) + if input_transform is not None: + input_transform = create_transform( + 
input_transform, stage if stage != RunningStage.SANITY_CHECKING else RunningStage.VALIDATING + ) + transform = _create_collate_input_transform_processors(input_transform, [self.data_fetcher])[1] + self._model_on_after_batch_transfer_fns[stage] = transform def on_after_batch_transfer(self, batch: Any, dataloader_idx: int) -> Any: if getattr(self, "trainer", None) is None: return batch - transform = None - if self.trainer.training: - transform = self._train_on_after_batch_transfer_fn - elif self.trainer.validating or self.trainer.sanity_checking: - transform = self._val_on_after_batch_transfer_fn - elif self.trainer.testing: - transform = self._test_on_after_batch_transfer_fn - elif self.trainer.predicting: - transform = self._predict_on_after_batch_transfer_fn + + if self._model_on_after_batch_transfer_fns is None: + self._load_model_on_after_batch_transfer_fns() + + stage = self.trainer.state.stage + + transform = self._model_on_after_batch_transfer_fns[stage] + if transform is None: + transform = self._on_after_batch_transfer_fns[stage] if transform: batch = transform(batch) - return batch @property @@ -509,24 +523,6 @@ def inputs(self) -> Optional[Union[Input, List[InputBase]]]: inputs = [self.train_dataset, self.val_dataset, self.test_dataset, self.predict_dataset] return [input for input in inputs if input] - @property - def input_transform(self) -> InputTransform: - """Property that returns the input transform class used on input data.""" - # Find a better way to resolve this. - return getattr(self.train_dataset, "transform", None) or self.input_transform_cls(RunningStage.TRAINING) - - @property - def output_transform(self) -> OutputTransform: - """Property that returns the :class:`~flash.core.data.io.output_transform.OutputTransform` used to - output_transform the model outputs.""" - return self._output_transform or self.output_transform_cls() - - @property - def data_pipeline(self) -> DataPipeline: - """Property that returns the full data pipeline including the data source, input transform and - postprocessing.""" - return DataPipeline(self.inputs, self.input_transform, self.output_transform) - @staticmethod def _split_train_val( train_dataset: Dataset, diff --git a/flash/core/data/data_pipeline.py b/flash/core/data/data_pipeline.py index b310ef8882..9f653cb606 100644 --- a/flash/core/data/data_pipeline.py +++ b/flash/core/data/data_pipeline.py @@ -12,40 +12,14 @@ # See the License for the specific language governing permissions and # limitations under the License. 
import inspect -from typing import Any, Dict, List, Optional, Set, Type, Union +from typing import Any, Optional, Type from pytorch_lightning.utilities.exceptions import MisconfigurationException -from flash.core.data.io.input import Input, InputBase from flash.core.data.io.input_transform import InputTransform -from flash.core.data.io.output import _OutputProcessor, Output -from flash.core.data.io.output_transform import _OutputTransformProcessor, OutputTransform -from flash.core.data.process import Deserializer -from flash.core.data.properties import ProcessState -from flash.core.data.utils import _INPUT_TRANSFORM_FUNCS, _OUTPUT_TRANSFORM_FUNCS from flash.core.utilities.stages import RunningStage -class DataPipelineState: - """A class to store and share all process states once a :class:`.DataPipeline` has been initialized.""" - - def __init__(self): - self._state: Dict[Type[ProcessState], ProcessState] = {} - - def set_state(self, state: ProcessState): - """Add the given :class:`.ProcessState` to the :class:`.DataPipelineState`.""" - - self._state[type(state)] = state - - def get_state(self, state_type: Type[ProcessState]) -> Optional[ProcessState]: - """Get the :class:`.ProcessState` of the given type from the :class:`.DataPipelineState`.""" - - return self._state.get(state_type, None) - - def __str__(self) -> str: - return f"{self.__class__.__name__}(state={self._state})" - - class DataPipeline: """ DataPipeline holds the engineering logic to connect @@ -54,43 +28,6 @@ class DataPipeline: objects to the ``DataModule``, Flash ``Task`` and ``Trainer``. """ - INPUT_TRANSFORM_FUNCS: Set[str] = _INPUT_TRANSFORM_FUNCS - OUTPUT_TRANSFORM_FUNCS: Set[str] = _OUTPUT_TRANSFORM_FUNCS - - def __init__( - self, - input: Optional[Union[Input, List[InputBase]]] = None, - input_transform: Optional[InputTransform] = None, - output_transform: Optional[OutputTransform] = None, - deserializer: Optional[Deserializer] = None, - output: Optional[Output] = None, - ) -> None: - self.input = input - - self._input_transform_pipeline = input_transform or InputTransform(RunningStage.TRAINING) - self._output_transform = output_transform or OutputTransform() - self._output = output or Output() - self._deserializer = deserializer or Deserializer() - self._running_stage = None - - def initialize(self, data_pipeline_state: Optional[DataPipelineState] = None) -> DataPipelineState: - """Creates the :class:`.DataPipelineState` and gives the reference to the: :class:`.InputTransform`, - :class:`.OutputTransform`, and :class:`.Output`. 
Once this has been called, any attempt to add new state will - give a warning.""" - data_pipeline_state = data_pipeline_state or DataPipelineState() - if self.input is not None: - if isinstance(self.input, list): - for input in self.input: - if hasattr(input, "attach_data_pipeline_state"): - input.attach_data_pipeline_state(data_pipeline_state) - else: - self.input.attach_data_pipeline_state(data_pipeline_state) - self._deserializer.attach_data_pipeline_state(data_pipeline_state) - self._input_transform_pipeline.attach_data_pipeline_state(data_pipeline_state) - self._output_transform.attach_data_pipeline_state(data_pipeline_state) - self._output.attach_data_pipeline_state(data_pipeline_state) - return data_pipeline_state - @staticmethod def _is_overridden(method_name: str, process_obj, super_obj: Any, prefix: Optional[str] = None) -> bool: """Cropped Version of https://github.com/PyTorchLightning/pytorch- @@ -129,12 +66,6 @@ def _is_overridden_recursive( return has_different_code return has_different_code or cls._is_overridden_recursive(method_name, process_obj, super_obj) - def output_transform_processor(self, running_stage: RunningStage, is_serving=False) -> _OutputTransformProcessor: - return self._create_output_transform_processor(running_stage, is_serving=is_serving) - - def output_processor(self) -> _OutputProcessor: - return _OutputProcessor(self._output) - @classmethod def _resolve_function_hierarchy( cls, function_name, process_obj, stage: RunningStage, object_type: Optional[Type] = None @@ -162,38 +93,3 @@ def _resolve_function_hierarchy( return function_name if prefix is None else f"{prefix}_{function_name}" return function_name - - def _create_output_transform_processor( - self, - stage: RunningStage, - is_serving: bool = False, - ) -> _OutputTransformProcessor: - output_transform: OutputTransform = self._output_transform - - func_names: Dict[str, str] = { - k: self._resolve_function_hierarchy(k, output_transform, stage, object_type=OutputTransform) - for k in self.OUTPUT_TRANSFORM_FUNCS - } - - return _OutputTransformProcessor( - getattr(output_transform, func_names["uncollate"]), - getattr(output_transform, func_names["per_batch_transform"]), - getattr(output_transform, func_names["per_sample_transform"]), - output=None if is_serving else self._output, - is_serving=is_serving, - ) - - def __str__(self) -> str: - input: Input = self.input - input_transform: InputTransform = self._input_transform_pipeline - output_transform: OutputTransform = self._output_transform - output: Output = self._output - deserializer: Deserializer = self._deserializer - return ( - f"{self.__class__.__name__}(" - f"input={str(input)}, " - f"deserializer={deserializer}, " - f"input_transform={input_transform}, " - f"output_transform={output_transform}, " - f"output={output})" - ) diff --git a/flash/core/data/io/classification_input.py b/flash/core/data/io/classification_input.py index 3cea62c05f..bac6925cdf 100644 --- a/flash/core/data/io/classification_input.py +++ b/flash/core/data/io/classification_input.py @@ -11,22 +11,12 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. 
-from dataclasses import dataclass -from typing import Any, List, Optional, Sequence +from typing import Any, List, Optional -from flash.core.data.properties import ProcessState, Properties +from flash.core.data.properties import Properties from flash.core.data.utilities.classification import get_target_formatter, TargetFormatter -@dataclass(unsafe_hash=True, frozen=True) -class ClassificationState(ProcessState): - """A :class:`~flash.core.data.properties.ProcessState` containing ``labels`` (a mapping from class index to - label) and ``num_classes``.""" - - labels: Optional[Sequence[str]] - num_classes: Optional[int] = None - - class ClassificationInputMixin(Properties): """The ``ClassificationInputMixin`` class provides utility methods for handling classification targets. :class:`~flash.core.data.io.input.Input` objects that extend ``ClassificationInputMixin`` should do the following: @@ -52,22 +42,14 @@ def load_target_metadata( add_background: If ``True``, a background class will be inserted as class zero if ``labels`` and ``num_classes`` are being inferred. """ + self.target_formatter = target_formatter if target_formatter is None and targets is not None: - classification_state = self.get_state(ClassificationState) - if classification_state is not None: - labels, num_classes = classification_state.labels, classification_state.num_classes - else: - labels, num_classes = None, None + self.target_formatter = get_target_formatter(targets, add_background=add_background) - self.target_formatter = get_target_formatter(targets, labels, num_classes, add_background=add_background) - else: - self.target_formatter = target_formatter - - if getattr(self, "target_formatter", None) is not None: + if self.target_formatter is not None: self.multi_label = self.target_formatter.multi_label self.labels = self.target_formatter.labels self.num_classes = self.target_formatter.num_classes - self.set_state(ClassificationState(self.labels, self.num_classes)) def format_target(self, target: Any) -> Any: """Format a single target according to the previously computed target format and metadata. 
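With `ClassificationState` removed, the `TargetFormatter` inferred from the training targets is handed explicitly to the other splits, mirroring the `ds_kw["target_formatter"]` pattern in the `from_*` methods above. A minimal sketch of a custom input using the mixin this way follows; the class, file names, and targets are hypothetical:

.. code-block:: python

    from typing import Any, Dict, List, Optional

    from flash.core.data.io.classification_input import ClassificationInputMixin
    from flash.core.data.io.input import DataKeys, Input
    from flash.core.data.utilities.classification import TargetFormatter
    from flash.core.utilities.stages import RunningStage


    class MyClassificationInput(Input, ClassificationInputMixin):
        def load_data(
            self,
            inputs: List[Any],
            targets: Optional[List[Any]] = None,
            target_formatter: Optional[TargetFormatter] = None,
        ) -> List[Dict[str, Any]]:
            if targets is None:
                return [{DataKeys.INPUT: x} for x in inputs]
            # Infers a formatter from the targets (or reuses the one passed in)
            # and sets self.labels, self.num_classes, and self.multi_label.
            self.load_target_metadata(targets, target_formatter=target_formatter)
            return [{DataKeys.INPUT: x, DataKeys.TARGET: y} for x, y in zip(inputs, targets)]

        def load_sample(self, sample: Dict[str, Any]) -> Dict[str, Any]:
            if DataKeys.TARGET in sample:
                sample[DataKeys.TARGET] = self.format_target(sample[DataKeys.TARGET])
            return sample


    train_input = MyClassificationInput(RunningStage.TRAINING, ["a.png", "b.png"], ["cat", "dog"])
    # Reuse the formatter (and so the label-to-index mapping) for validation data.
    val_input = MyClassificationInput(
        RunningStage.VALIDATING,
        ["c.png"],
        ["cat"],
        target_formatter=getattr(train_input, "target_formatter", None),
    )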
diff --git a/flash/core/data/io/input.py b/flash/core/data/io/input.py index 22443b853f..04dba6e7f8 100644 --- a/flash/core/data/io/input.py +++ b/flash/core/data/io/input.py @@ -15,9 +15,8 @@ import os import sys from copy import deepcopy -from dataclasses import dataclass from functools import partial -from typing import Any, Callable, cast, Dict, Iterable, List, MutableMapping, Optional, Sequence, Tuple, Type, Union +from typing import Any, Callable, cast, Dict, Iterable, List, Optional, Sequence, Tuple, Type, Union from pytorch_lightning.utilities.enums import LightningEnum from pytorch_lightning.utilities.exceptions import MisconfigurationException @@ -25,7 +24,8 @@ import flash from flash.core.data.callback import FlashCallback -from flash.core.data.properties import ProcessState, Properties +from flash.core.data.properties import Properties +from flash.core.data.utils import _STAGES_PREFIX from flash.core.registry import FlashRegistry from flash.core.utilities.stages import RunningStage from flash.core.utilities.types import INPUT_TRANSFORM_TYPE @@ -44,12 +44,6 @@ IterableDataset = object -@dataclass(unsafe_hash=True, frozen=True) -class ImageLabelsMap(ProcessState): - - labels_map: Optional[Dict[int, Tuple[int, int, int]]] - - class InputFormat(LightningEnum): """The ``InputFormat`` enum contains the data source names used by all of the default ``from_*`` methods in :class:`~flash.core.data.data_module.DataModule`.""" @@ -181,7 +175,6 @@ def __init__( transform: INPUT_TRANSFORM_TYPE = None, transform_kwargs: Optional[Dict] = None, input_transforms_registry: Optional[FlashRegistry] = None, - data_pipeline_state: Optional["flash.core.data.data_pipeline.DataPipelineState"] = None, **kwargs: Any, ) -> None: from flash.core.data.io.input_transform import create_transform @@ -189,15 +182,14 @@ def __init__( self.transform = create_transform( transform, running_stage, - data_pipeline_state, input_transforms_registry or self.input_transforms_registry, transform_kwargs, ) - super().__init__(running_stage=running_stage, data_pipeline_state=data_pipeline_state) + super().__init__(running_stage=running_stage) self.data = None if len(args) >= 1 and args[0] is not None: - self.data = self._call_load_data(*args, **kwargs) + self.data = getattr(self, f"{_STAGES_PREFIX[running_stage]}_load_data")(*args, **kwargs) def _create_dataloader_collate_fn(self, callbacks: List[FlashCallback]) -> Optional[Callable]: from flash.core.data.io.input_transform import _create_collate_input_transform_processors @@ -213,29 +205,9 @@ def _create_on_after_batch_transfer_fn(self, callbacks: List[FlashCallback]) -> return return _create_collate_input_transform_processors(self.transform, callbacks)[1] - def _call_load_data(self, *args: Any, **kwargs: Any) -> Union[Sequence, Iterable]: - from flash.core.data.data_pipeline import DataPipeline - - load_data = getattr( - self, DataPipeline._resolve_function_hierarchy("load_data", self, self.running_stage, InputBase) - ) - return load_data(*args, **kwargs) - def _call_load_sample(self, sample: Any) -> Any: - from flash.core.data.data_pipeline import DataPipeline - - load_sample = getattr( - self, - DataPipeline._resolve_function_hierarchy( - "load_sample", - self, - self.running_stage, - InputBase, - ), - ) - # Deepcopy the sample to avoid leaks with complex data structures - return load_sample(deepcopy(sample)) + return getattr(self, f"{_STAGES_PREFIX[self.running_stage]}_load_sample")(deepcopy(sample)) @staticmethod def load_data(*args: Any, **kwargs: Any) -> 
Union[Sequence, Iterable]:
@@ -249,8 +221,44 @@ def load_data(*args: Any, **kwargs: Any) -> Union[Sequence, Iterable]:
"""
return args[0]
+ def train_load_data(self, *args: Any, **kwargs: Any) -> Union[Sequence, Iterable]:
+ """Override the ``train_load_data`` hook with data loading logic that is only required during training.
+
+ Args:
+ *args: Any arguments that the input requires.
+ **kwargs: Any additional keyword arguments that the input requires.
+ """
+ return self.load_data(*args, **kwargs)
+
+ def val_load_data(self, *args: Any, **kwargs: Any) -> Union[Sequence, Iterable]:
+ """Override the ``val_load_data`` hook with data loading logic that is only required during validation.
+
+ Args:
+ *args: Any arguments that the input requires.
+ **kwargs: Any additional keyword arguments that the input requires.
+ """
+ return self.load_data(*args, **kwargs)
+
+ def test_load_data(self, *args: Any, **kwargs: Any) -> Union[Sequence, Iterable]:
+ """Override the ``test_load_data`` hook with data loading logic that is only required during testing.
+
+ Args:
+ *args: Any arguments that the input requires.
+ **kwargs: Any additional keyword arguments that the input requires.
+ """
+ return self.load_data(*args, **kwargs)
+
+ def predict_load_data(self, *args: Any, **kwargs: Any) -> Union[Sequence, Iterable]:
+ """Override the ``predict_load_data`` hook with data loading logic that is only required during prediction.
+
+ Args:
+ *args: Any arguments that the input requires.
+ **kwargs: Any additional keyword arguments that the input requires.
+ """
+ return self.load_data(*args, **kwargs)
+
@staticmethod
- def load_sample(sample: MutableMapping[str, Any]) -> Any:
+ def load_sample(sample: Dict[str, Any]) -> Any:
"""The ``load_sample`` hook is called for each ``__getitem__`` or ``__next__`` call to the dataset with a
single sample from the output of the ``load_data`` hook as input.
@@ -259,6 +267,39 @@ def load_sample(sample: MutableMapping[str, Any]) -> Any:
"""
return sample
+ def train_load_sample(self, sample: Dict[str, Any]) -> Any:
+ """Override the ``train_load_sample`` hook with data loading logic that is only required during training.
+
+ Args:
+ sample: A single sample from the output of the ``load_data`` hook.
+ """
+ return self.load_sample(sample)
+
+ def val_load_sample(self, sample: Dict[str, Any]) -> Any:
+ """Override the ``val_load_sample`` hook with data loading logic that is only required during validation.
+
+ Args:
+ sample: A single sample from the output of the ``load_data`` hook.
+ """
+ return self.load_sample(sample)
+
+ def test_load_sample(self, sample: Dict[str, Any]) -> Any:
+ """Override the ``test_load_sample`` hook with data loading logic that is only required during testing.
+
+ Args:
+ sample: A single sample from the output of the ``load_data`` hook.
+ """
+ return self.load_sample(sample)
+
+ def predict_load_sample(self, sample: Dict[str, Any]) -> Any:
+ """Override the ``predict_load_sample`` hook with data loading logic that is only required during
+ prediction.
+
+ Args:
+ sample: A single sample from the output of the ``load_data`` hook.
+ """
+ return self.load_sample(sample)
+
def __getstate__(self):
"""Temporarily override pickle behaviour.
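As a quick illustration of the new stage-specific hooks (``MyInput`` below is a hypothetical example, not part of this diff), an input can now override loading logic for a single stage, with every other stage falling back to the plain ``load_data``/``load_sample``:

    from flash.core.data.io.input import Input

    class MyInput(Input):
        def train_load_data(self, samples):
            # Training-only logic, e.g. naive oversampling; the val, test,
            # and predict stages fall back to the default load_data.
            return samples * 2

        def predict_load_sample(self, sample):
            # Prediction-only per-sample logic; other stages use load_sample.
            return sample

Because ``Input.__init__`` now dispatches through ``_STAGES_PREFIX``, constructing ``MyInput(RunningStage.TRAINING, data)`` calls ``train_load_data`` directly, with no ``_resolve_function_hierarchy`` lookup.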
@@ -336,7 +377,6 @@ def __init__( self, transform: INPUT_TRANSFORM_TYPE = None, transform_kwargs: Optional[Dict] = None, - data_pipeline_state: Optional["flash.core.data.data_pipeline.DataPipelineState"] = None, ) -> None: if hasattr(self, "serve_load_data"): raise MisconfigurationException("`serve_load_data` shouldn't be implemented.") @@ -345,7 +385,6 @@ def __init__( RunningStage.SERVING, transform=transform, transform_kwargs=transform_kwargs, - data_pipeline_state=data_pipeline_state, ) def serve_load_sample(self, sample: Any) -> List[Any]: diff --git a/flash/core/data/io/input_transform.py b/flash/core/data/io/input_transform.py index 46be001a1a..c4aa3e7588 100644 --- a/flash/core/data/io/input_transform.py +++ b/flash/core/data/io/input_transform.py @@ -20,17 +20,9 @@ from pytorch_lightning.utilities.exceptions import MisconfigurationException from torch.utils.data._utils.collate import default_collate -import flash from flash.core.data.callback import ControlFlow, FlashCallback from flash.core.data.io.input import DataKeys -from flash.core.data.properties import ProcessState, Properties -from flash.core.data.states import ( - CollateFn, - PerBatchTransform, - PerBatchTransformOnDevice, - PerSampleTransform, - PerSampleTransformOnDevice, -) +from flash.core.data.properties import Properties from flash.core.data.transforms import ApplyToKeys from flash.core.data.utils import _INPUT_TRANSFORM_FUNCS, _STAGES_PREFIX from flash.core.registry import FlashRegistry @@ -89,28 +81,19 @@ def __repr__(self): return format_string -class InputTransformState(dict): - pass - - @dataclass class InputTransform(Properties): running_stage: RunningStage - data_pipeline_state: Optional["flash.core.data.data_pipeline.DataPipelineState"] = None def __post_init__(self): - transform_kwargs = { - k: v for k, v in self.__dict__.items() if k not in ("_running_stage", "data_pipeline_state") - } # used to keep track of provided transforms self._collate_in_worker_from_transform: Optional[bool] = None self._transform = None self._transform = self._check_transforms(self._resolve_transforms(self.running_stage), self.running_stage) # Hack - Properties.__init__(self, data_pipeline_state=self.data_pipeline_state, running_stage=self.running_stage) - self.set_state(InputTransformState(**transform_kwargs)) + Properties.__init__(self, running_stage=self.running_stage) @property def current_transform(self) -> Callable: @@ -850,7 +833,7 @@ def collate(self) -> Callable: @partial(transform_context, current_fn="per_sample_transform") def _per_sample_transform(self, sample: Any) -> Any: - fn = self._get_current_transform(PerSampleTransform) + fn = self.current_transform if isinstance(sample, list): return [fn(s) for s in sample] return fn(sample) @@ -859,37 +842,20 @@ def _per_sample_transform(self, sample: Any) -> Any: def _per_batch_transform(self, batch: Any) -> Any: """Transforms to apply to a whole batch (if possible use this for efficiency). - .. note:: This option is mutually exclusive with :meth:`per_sample_transform_on_device`, since if both - are specified, uncollation has to be applied. + .. note:: This option is mutually exclusive with :meth:`per_sample_transform_on_device`, since if both are + specified, uncollation has to be applied. 
""" - return self._get_current_transform(PerBatchTransform)(batch) + return self.current_transform(batch) @partial(transform_context, current_fn="collate") def _collate(self, samples: Sequence, metadata=None) -> Any: """Transform to convert a sequence of samples to a collated batch.""" - current_transform = self.current_transform - - # the model can provide a custom ``collate_fn``. - collate_fn = self.get_state(CollateFn) - if collate_fn is not None: - collate_fn = collate_fn.collate_fn - else: - collate_fn = current_transform - # return collate_fn.collate_fn(samples) - + collate_fn = self.current_transform parameters = inspect.signature(collate_fn).parameters if len(parameters) > 1 and DataKeys.METADATA in parameters: return collate_fn(samples, metadata) return collate_fn(samples) - def _get_current_transform(self, process_state: ProcessState): - fn = self.get_state(process_state) - if fn is not None: - if fn.transform is not None: - return fn.transform - return self._identity - return self.current_transform - @partial(transform_context, current_fn="per_sample_transform_on_device") def _per_sample_transform_on_device(self, sample: Any) -> Any: """Transforms to apply to the data before the collation (per-sample basis). @@ -899,7 +865,7 @@ def _per_sample_transform_on_device(self, sample: Any) -> Any: workers, since to make that happen each of the workers would have to create it's own CUDA-context which would pollute GPU memory (if on GPU). """ - fn = self._get_current_transform(PerSampleTransformOnDevice) + fn = self.current_transform if isinstance(sample, list): return [fn(s) for s in sample] return fn(sample) @@ -911,7 +877,7 @@ def _per_batch_transform_on_device(self, batch: Any) -> Any: .. note:: This function won't be called within the dataloader workers, since to make that happen each of the workers would have to create it's own CUDA-context which would pollute GPU memory (if on GPU). 
""" - return self._get_current_transform(PerBatchTransformOnDevice)(batch) + return self.current_transform(batch) ############# # UTILITIES # @@ -1033,11 +999,7 @@ def _get_transform(self, transform: Dict[str, Callable]) -> Callable: return self._identity def __str__(self) -> str: - state = self.get_state(InputTransformState) - return ( - f"{self.__class__.__name__}(" - + f"running_stage={self.running_stage}, state: {state}, transform={self._transform})" - ) + return f"{self.__class__.__name__}(" + f"running_stage={self.running_stage}, transform={self._transform})" def __getitem__(self, placement: InputTransformPlacement) -> Callable: return self._transform[placement] @@ -1047,7 +1009,6 @@ def __getitem__(self, placement: InputTransformPlacement) -> Callable: class LambdaInputTransform(InputTransform): transform: Callable = InputTransform._identity - data_pipeline_state: Optional["flash.core.data.data_pipeline.DataPipelineState"] = None def per_sample_transform(self) -> Callable: return self.transform @@ -1077,7 +1038,6 @@ def _sanitize_registry_transform( def create_transform( transform: INPUT_TRANSFORM_TYPE, running_stage: RunningStage, - data_pipeline_state: Optional["flash.core.data.data_pipeline.DataPipelineState"] = None, input_transforms_registry: Optional[FlashRegistry] = None, transform_kwargs: Optional[Dict] = None, ) -> Optional["InputTransform"]: @@ -1086,24 +1046,22 @@ def create_transform( transform_kwargs = {} if isinstance(transform, InputTransform): - transform._data_pipeline_state = data_pipeline_state return transform if inspect.isclass(transform) and issubclass(transform, InputTransform): - return transform(running_stage=running_stage, data_pipeline_state=data_pipeline_state, **transform_kwargs) + return transform(running_stage=running_stage, **transform_kwargs) if isinstance(transform, Callable): return LambdaInputTransform( running_stage=running_stage, transform=transform, - data_pipeline_state=data_pipeline_state, **transform_kwargs, ) if isinstance(transform, tuple) or isinstance(transform, (LightningEnum, str)): enum, transform_kwargs = _sanitize_registry_transform(transform, input_transforms_registry) transform_cls = input_transforms_registry.get(enum) - return transform_cls(running_stage, data_pipeline_state=data_pipeline_state, **transform_kwargs) + return transform_cls(running_stage, **transform_kwargs) if not transform: return None @@ -1164,43 +1122,37 @@ def __call__(self, samples: Sequence[Any]) -> Any: for sample in samples: self.callback.on_load_sample(sample, self.stage) - # we create a new dict to prevent from potential memory leaks - # assuming that the dictionary samples are stored in between and - # potentially modified before the transforms are applied. 
- if isinstance(samples, dict): - samples = dict(samples.items()) - if self.apply_per_sample_transform: - _samples = [] + if not isinstance(samples, list): + list_samples = [samples] + else: + list_samples = samples - if isinstance(samples, Mapping): - samples = [samples] + transformed_samples = [self.per_sample_transform(sample) for sample in list_samples] - for sample in samples: - sample = self.per_sample_transform(sample) + for sample in transformed_samples: if self.on_device: self.callback.on_per_sample_transform_on_device(sample, self.stage) else: self.callback.on_per_sample_transform(sample, self.stage) - _samples.append(sample) - - samples = type(_samples)(_samples) - samples, metadata = self._extract_metadata(samples) + extracted_samples, metadata = self._extract_metadata(transformed_samples) try: - samples = self.collate_fn(samples, metadata) + collated_samples = self.collate_fn(extracted_samples, metadata) except TypeError: - samples = self.collate_fn(samples) - if metadata and isinstance(samples, dict): - samples[DataKeys.METADATA] = metadata - self.callback.on_collate(samples, self.stage) + collated_samples = self.collate_fn(extracted_samples) + if metadata and isinstance(collated_samples, dict): + collated_samples[DataKeys.METADATA] = metadata + self.callback.on_collate(collated_samples, self.stage) + else: + collated_samples = samples - samples = self.per_batch_transform(samples) + transformed_collated_samples = self.per_batch_transform(collated_samples) if self.on_device: - self.callback.on_per_batch_transform_on_device(samples, self.stage) + self.callback.on_per_batch_transform_on_device(transformed_collated_samples, self.stage) else: - self.callback.on_per_batch_transform(samples, self.stage) - return samples + self.callback.on_per_batch_transform(transformed_collated_samples, self.stage) + return transformed_collated_samples def __str__(self) -> str: # todo: define repr function which would take object and string attributes to be shown diff --git a/flash/core/data/io/output.py b/flash/core/data/io/output.py index 816cb01213..8802d125f8 100644 --- a/flash/core/data/io/output.py +++ b/flash/core/data/io/output.py @@ -11,18 +11,23 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. +from abc import abstractmethod from typing import Any -import torch - +import flash from flash.core.data.properties import Properties -from flash.core.data.utils import convert_to_modules class Output(Properties): """An :class:`.Output` encapsulates a single :meth:`~flash.core.data.io.output.Output.transform` method which is used to convert the model output into the desired output format when predicting.""" + @classmethod + @abstractmethod + def from_task(cls, task: "flash.Task", **kwargs) -> "Output": + """Instantiate the output from the given :class:`~flash.core.model.Task`.""" + return cls() + @staticmethod def transform(sample: Any) -> Any: """Convert the given sample into the desired output format. 
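For context, a hypothetical ``Output`` built on the new ``from_task`` hook might look like the sketch below (``MyLabelsOutput`` is illustrative only, and the ``labels`` attribute on the task is an assumption): any metadata the output needs is pulled from the task at construction time rather than read from shared state:

    from typing import Any

    import flash
    from flash.core.data.io.output import Output

    class MyLabelsOutput(Output):
        def __init__(self, labels):
            super().__init__()
            self._labels = labels

        @classmethod
        def from_task(cls, task: "flash.Task", **kwargs) -> Output:
            # Assumes the task exposes the labels it was trained with.
            return cls(labels=getattr(task, "labels", None))

        def transform(self, sample: Any) -> Any:
            # Maps a predicted class index back to its label, if known.
            return self._labels[sample] if self._labels is not None else sample

Outputs registered on the ``Task.outputs`` registry (see the ``BASE_OUTPUTS`` changes below) can then be requested by name at prediction time.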
@@ -37,15 +42,3 @@ def transform(sample: Any) -> Any: def __call__(self, sample: Any) -> Any: return self.transform(sample) - - -class _OutputProcessor(torch.nn.Module): - def __init__( - self, - output: "Output", - ): - super().__init__() - self.output = convert_to_modules(output) - - def forward(self, sample): - return self.output(sample) diff --git a/flash/core/data/io/output_transform.py b/flash/core/data/io/output_transform.py index 44239fc24e..76bc5edf2f 100644 --- a/flash/core/data/io/output_transform.py +++ b/flash/core/data/io/output_transform.py @@ -11,14 +11,10 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. -from typing import Any, Callable, Optional, Sequence - -import torch -from torch import Tensor +from typing import Any, Sequence from flash.core.data.batch import default_uncollate from flash.core.data.properties import Properties -from flash.core.data.utils import convert_to_modules class OutputTransform(Properties): @@ -45,57 +41,14 @@ def per_sample_transform(sample: Any) -> Any: def uncollate(batch: Any) -> Any: """Uncollates a batch into single samples. - Tries to preserve the type whereever possible. + Tries to preserve the type wherever possible. """ return default_uncollate(batch) - -class _OutputTransformProcessor(torch.nn.Module): - """This class is used to encapsultate the following functions of a OutputTransform Object: - - Inside main process: - per_batch_transform: Function to transform a batch - per_sample_transform: Function to transform an individual sample - uncollate_fn: Function to split a batch into samples - per_sample_transform: Function to transform an individual sample - is_serving: Whether the Postprocessor is used in serving mode. 
- """ - - def __init__( - self, - uncollate_fn: Callable, - per_batch_transform: Callable, - per_sample_transform: Callable, - output: Optional[Callable], - is_serving: bool = False, - ): - super().__init__() - self.uncollate_fn = convert_to_modules(uncollate_fn) - self.per_batch_transform = convert_to_modules(per_batch_transform) - self.per_sample_transform = convert_to_modules(per_sample_transform) - self.output = convert_to_modules(output) - self.is_serving = is_serving - - def forward(self, batch: Sequence[Any]): + def __call__(self, batch: Sequence[Any]): if batch is None: return batch - uncollated = self.uncollate_fn(self.per_batch_transform(batch)) - - final_preds = [self.per_sample_transform(sample) for sample in uncollated] - - if self.output is not None: - final_preds = [self.output(sample) for sample in final_preds] - - if isinstance(uncollated, Tensor) and isinstance(final_preds[0], Tensor): - return torch.stack(final_preds) - return type(final_preds)(final_preds) + uncollated = self.uncollate(self.per_batch_transform(batch)) - def __str__(self) -> str: - return ( - "_OutputTransformProcessor:\n" - f"\t(per_batch_transform): {str(self.per_batch_transform)}\n" - f"\t(uncollate_fn): {str(self.uncollate_fn)}\n" - f"\t(per_sample_transform): {str(self.per_sample_transform)}\n" - f"\t(output): {str(self.output)}" - ) + return [self.per_sample_transform(sample) for sample in uncollated] diff --git a/flash/core/data/output.py b/flash/core/data/output.py index 337386b14f..61ba300ca5 100644 --- a/flash/core/data/output.py +++ b/flash/core/data/output.py @@ -15,8 +15,13 @@ from flash.core.data.io.input import DataKeys from flash.core.data.io.output import Output +from flash.core.registry import FlashRegistry +BASE_OUTPUTS = FlashRegistry("outputs") +BASE_OUTPUTS(name="raw")(Output) + +@BASE_OUTPUTS(name="preds") class PredsOutput(Output): """A :class:`~flash.core.data.io.output.Output` which returns the "preds" from the model outputs.""" diff --git a/flash/core/data/process.py b/flash/core/data/process.py index 83b6199d76..c83f4a7277 100644 --- a/flash/core/data/process.py +++ b/flash/core/data/process.py @@ -17,7 +17,6 @@ from deprecate import deprecated -import flash from flash.core.data.io.input import ServeInput as Deserializer from flash.core.data.io.output import Output @@ -36,10 +35,6 @@ def deserialize(self, sample: Any) -> Any: return {key: deserializer.deserialize(sample[key]) for key, deserializer in self._deserializers.items()} raise ValueError("The model output must be a mapping when using a DeserializerMapping.") - def attach_data_pipeline_state(self, data_pipeline_state: "flash.core.data.data_pipeline.DataPipelineState"): - for deserializer in self._deserializers.values(): - deserializer.attach_data_pipeline_state(data_pipeline_state) - class Serializer(Output): """Deprecated. diff --git a/flash/core/data/properties.py b/flash/core/data/properties.py index b30c7ba02c..a31de446e4 100644 --- a/flash/core/data/properties.py +++ b/flash/core/data/properties.py @@ -11,54 +11,20 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. 
-from dataclasses import dataclass -from typing import Dict, Optional, Type, TypeVar +from typing import Optional -import flash from flash.core.utilities.stages import RunningStage -@dataclass(unsafe_hash=True, frozen=True) -class ProcessState: - """Base class for all process states.""" - - -STATE_TYPE = TypeVar("STATE_TYPE", bound=ProcessState) - - class Properties: def __init__( self, running_stage: Optional[RunningStage] = None, - data_pipeline_state: Optional["flash.core.data.data_pipeline.DataPipelineState"] = None, - state: Dict[Type[ProcessState], ProcessState] = None, ): super().__init__() self._running_stage = running_stage self._current_fn: Optional[str] = None - self._data_pipeline_state = data_pipeline_state - self._state: Dict[Type[ProcessState], ProcessState] = {} if state is None else state - - def get_state(self, state_type: Type[STATE_TYPE]) -> Optional[STATE_TYPE]: - if state_type in self._state: - return self._state[state_type] - if self._data_pipeline_state is not None: - return self._data_pipeline_state.get_state(state_type) - return None - - def set_state(self, state: ProcessState): - self._state[type(state)] = state - if self._data_pipeline_state is not None: - self._data_pipeline_state.set_state(state) - - def attach_data_pipeline_state(self, data_pipeline_state: "flash.core.data.data_pipeline.DataPipelineState"): - for state in self._state.values(): - data_pipeline_state.set_state(state) - if self._data_pipeline_state: - for state in self._data_pipeline_state._state.values(): - data_pipeline_state.set_state(state) - self._data_pipeline_state = data_pipeline_state @property def current_fn(self) -> Optional[str]: diff --git a/flash/core/data/splits.py b/flash/core/data/splits.py index 60bdafb16f..e26e7c4679 100644 --- a/flash/core/data/splits.py +++ b/flash/core/data/splits.py @@ -4,7 +4,6 @@ from pytorch_lightning.utilities.exceptions import MisconfigurationException from torch.utils.data import Dataset -import flash from flash.core.data.properties import Properties @@ -29,8 +28,6 @@ def __init__(self, dataset: Any, indices: List[int] = None, use_duplicated_indic if isinstance(dataset, Properties): kwargs = dict( running_stage=dataset._running_stage, - data_pipeline_state=dataset._data_pipeline_state, - state=dataset._state, ) super().__init__(**kwargs) @@ -50,11 +47,6 @@ def __init__(self, dataset: Any, indices: List[int] = None, use_duplicated_indic self.dataset = dataset self.indices = indices - def attach_data_pipeline_state(self, data_pipeline_state: "flash.core.data.data_pipeline.DataPipelineState"): - super().attach_data_pipeline_state(data_pipeline_state) - if isinstance(self.dataset, Properties): - self.dataset.attach_data_pipeline_state(data_pipeline_state) - def __getattr__(self, key: str): if key != "dataset": return getattr(self.dataset, key) diff --git a/flash/core/data/states.py b/flash/core/data/states.py deleted file mode 100644 index 8dbb738b4a..0000000000 --- a/flash/core/data/states.py +++ /dev/null @@ -1,34 +0,0 @@ -from dataclasses import dataclass -from typing import Callable, Optional - -from flash.core.data.properties import ProcessState - - -@dataclass(unsafe_hash=True, frozen=True) -class PerSampleTransform(ProcessState): - - transform: Optional[Callable] = None - - -@dataclass(unsafe_hash=True, frozen=True) -class PerSampleTransformOnDevice(ProcessState): - - transform: Optional[Callable] = None - - -@dataclass(unsafe_hash=True, frozen=True) -class PerBatchTransform(ProcessState): - - transform: Optional[Callable] = None - - 
-@dataclass(unsafe_hash=True, frozen=True) -class PerBatchTransformOnDevice(ProcessState): - - transform: Optional[Callable] = None - - -@dataclass(unsafe_hash=True, frozen=True) -class CollateFn(ProcessState): - - collate_fn: Optional[Callable] = None diff --git a/flash/core/integrations/icevision/adapter.py b/flash/core/integrations/icevision/adapter.py index 9f4ac416e3..925e79a150 100644 --- a/flash/core/integrations/icevision/adapter.py +++ b/flash/core/integrations/icevision/adapter.py @@ -85,7 +85,7 @@ def from_task( return cls(model_type, model, icevision_adapter, backbone, predict_kwargs) @staticmethod - def _collate_fn(collate_fn, samples, metadata: Optional[List[Dict[str, Any]]] = None): + def _wrap_collate_fn(collate_fn, samples, metadata: Optional[List[Dict[str, Any]]] = None): metadata = metadata or [None] * len(samples) return { DataKeys.INPUT: collate_fn( @@ -105,6 +105,7 @@ def process_train_dataset( shuffle: bool = False, drop_last: bool = False, sampler: Optional[Sampler] = None, + persistent_workers: bool = False, ) -> DataLoader: data_loader = self.model_type.train_dl( dataset, @@ -114,8 +115,9 @@ def process_train_dataset( shuffle=shuffle, drop_last=drop_last, sampler=sampler, + persistent_workers=persistent_workers, ) - data_loader.collate_fn = functools.partial(self._collate_fn, data_loader.collate_fn) + data_loader.collate_fn = functools.partial(self._wrap_collate_fn, data_loader.collate_fn) return data_loader def process_val_dataset( @@ -129,6 +131,7 @@ def process_val_dataset( shuffle: bool = False, drop_last: bool = False, sampler: Optional[Sampler] = None, + persistent_workers: bool = False, ) -> DataLoader: data_loader = self.model_type.valid_dl( dataset, @@ -138,8 +141,9 @@ def process_val_dataset( shuffle=shuffle, drop_last=drop_last, sampler=sampler, + persistent_workers=persistent_workers, ) - data_loader.collate_fn = functools.partial(self._collate_fn, data_loader.collate_fn) + data_loader.collate_fn = functools.partial(self._wrap_collate_fn, data_loader.collate_fn) return data_loader def process_test_dataset( @@ -153,6 +157,7 @@ def process_test_dataset( shuffle: bool = False, drop_last: bool = False, sampler: Optional[Sampler] = None, + persistent_workers: bool = False, ) -> DataLoader: data_loader = self.model_type.valid_dl( dataset, @@ -162,8 +167,9 @@ def process_test_dataset( shuffle=shuffle, drop_last=drop_last, sampler=sampler, + persistent_workers=persistent_workers, ) - data_loader.collate_fn = functools.partial(self._collate_fn, data_loader.collate_fn) + data_loader.collate_fn = functools.partial(self._wrap_collate_fn, data_loader.collate_fn) return data_loader def process_predict_dataset( @@ -176,6 +182,7 @@ def process_predict_dataset( shuffle: bool = False, drop_last: bool = True, sampler: Optional[Sampler] = None, + persistent_workers: bool = False, ) -> DataLoader: data_loader = self.model_type.infer_dl( dataset, @@ -185,8 +192,9 @@ def process_predict_dataset( shuffle=shuffle, drop_last=drop_last, sampler=sampler, + persistent_workers=persistent_workers, ) - data_loader.collate_fn = functools.partial(self._collate_fn, data_loader.collate_fn) + data_loader.collate_fn = functools.partial(self._wrap_collate_fn, data_loader.collate_fn) return data_loader def training_step(self, batch, batch_idx) -> Any: diff --git a/flash/core/integrations/icevision/data.py b/flash/core/integrations/icevision/data.py index 376b747a6e..af52394c3b 100644 --- a/flash/core/integrations/icevision/data.py +++ b/flash/core/integrations/icevision/data.py @@ -16,7 
+16,6 @@ import numpy as np -from flash.core.data.io.classification_input import ClassificationState from flash.core.data.io.input import DataKeys, Input from flash.core.data.utilities.paths import list_valid_files from flash.core.integrations.icevision.transforms import from_icevision_record @@ -49,7 +48,6 @@ def load_data( if class_map is not None: self.num_classes = class_map.num_classes self.labels = [class_map.get_by_id(i) for i in range(self.num_classes)] - self.set_state(ClassificationState(self.labels)) records = parser.parse(data_splitter=SingleSplitSplitter()) return [{DataKeys.INPUT: record} for record in records[0]] diff --git a/flash/core/integrations/labelstudio/input.py b/flash/core/integrations/labelstudio/input.py index e2164444c3..a252e11fcc 100644 --- a/flash/core/integrations/labelstudio/input.py +++ b/flash/core/integrations/labelstudio/input.py @@ -11,11 +11,9 @@ from pytorch_lightning.utilities.exceptions import MisconfigurationException from torch.utils.data import Sampler -import flash from flash.core.data.io.input import DataKeys, Input, IterableInput -from flash.core.data.properties import ProcessState, Properties +from flash.core.data.properties import Properties from flash.core.data.utils import image_default_loader -from flash.core.integrations.transformers.states import TransformersBackboneState from flash.core.utilities.imports import _PYTORCHVIDEO_AVAILABLE from flash.core.utilities.stages import RunningStage @@ -24,8 +22,8 @@ @dataclass(unsafe_hash=True, frozen=True) -class LabelStudioState(ProcessState): - """The ``LabelStudioState`` stores the metadata loaded from the data.""" +class LabelStudioParameters: + """The ``LabelStudioParameters`` stores the metadata loaded from the data.""" multi_label: bool num_classes: Optional[int] @@ -106,7 +104,9 @@ def _load_json_data(data, data_folder, multi_label=False): class BaseLabelStudioInput(Properties): - def load_data(self, data: Optional[Any]) -> Sequence[Mapping[str, Any]]: + def load_data( + self, data: Optional[Any], parameters: Optional[LabelStudioParameters] = None + ) -> Sequence[Mapping[str, Any]]: """Iterate through all tasks in exported data and construct train\test\val results.""" if data and isinstance(data, dict): data_folder = data.get("data_folder") @@ -119,26 +119,23 @@ def load_data(self, data: Optional[Any]) -> Sequence[Mapping[str, Any]]: _raw_data, data_folder=data_folder, multi_label=multi_label ) if self.training: - self.set_state( - LabelStudioState( - classes=classes, - data_types=data_types, - tag_types=tag_types, - multi_label=multi_label, - num_classes=len(classes), - ) + self.parameters = LabelStudioParameters( + classes=classes, + data_types=data_types, + tag_types=tag_types, + multi_label=multi_label, + num_classes=len(classes), ) + else: + self.parameters = parameters return test_results if self.testing else results return [] def load_sample(self, sample: Mapping[str, Any] = None) -> Any: """Load 1 sample from dataset.""" - if not self.state: - self.state = self.get_state(LabelStudioState) - assert self.state # all other data types # separate label from data - label = _get_labels_from_sample(sample["label"], self.state.classes) + label = _get_labels_from_sample(sample["label"], self.parameters.classes) # delete label from input data del sample["label"] result = { @@ -225,31 +222,11 @@ class LabelStudioInput(BaseLabelStudioInput, Input): """The ``LabelStudioInput`` expects the input to :meth:`~flash.core.data.io.input.Input.load_data` to be a json export from label studio.""" - def 
__init__( - self, - running_stage: RunningStage, - *args: Any, - data_pipeline_state: Optional["flash.core.data.data_pipeline.DataPipelineState"] = None, - **kwargs: Any, - ): - self.state = None - super().__init__(running_stage, *args, data_pipeline_state=data_pipeline_state, **kwargs) - class LabelStudioIterableInput(BaseLabelStudioInput, IterableInput): """The ``LabelStudioInput`` expects the input to :meth:`~flash.core.data.io.input.Input.load_data` to be a json export from label studio.""" - def __init__( - self, - running_stage: RunningStage, - *args: Any, - data_pipeline_state: Optional["flash.core.data.data_pipeline.DataPipelineState"] = None, - **kwargs: Any, - ): - self.state = None - super().__init__(running_stage, *args, data_pipeline_state=data_pipeline_state, **kwargs) - class LabelStudioImageClassificationInput(LabelStudioInput): """The ``LabelStudioImageInput`` expects the input to @@ -258,13 +235,13 @@ class LabelStudioImageClassificationInput(LabelStudioInput): def load_sample(self, sample: Mapping[str, Any] = None) -> Any: """Load 1 sample from dataset.""" - if not self.state: - self.state = self.get_state(LabelStudioState) - assert self.state p = sample["file_upload"] # loading image image = image_default_loader(p) - result = {DataKeys.INPUT: image, DataKeys.TARGET: _get_labels_from_sample(sample["label"], self.state.classes)} + result = { + DataKeys.INPUT: image, + DataKeys.TARGET: _get_labels_from_sample(sample["label"], self.parameters.classes), + } return result @@ -274,28 +251,15 @@ class LabelStudioTextClassificationInput(LabelStudioInput): Export data should point to text data """ - def __init__(self, *args, max_length=128, **kwargs): - self.max_length = max_length - super().__init__(*args, **kwargs) - def load_sample(self, sample: Mapping[str, Any] = None) -> Any: """Load 1 sample from dataset.""" - if not self.state: - self.state = self.get_state(LabelStudioState) - - assert self.state - data = "" for key in sample.get("data"): data += sample.get("data").get(key) - tokenized_data = self.get_state(TransformersBackboneState).tokenizer( - data, max_length=self.max_length, truncation=True, padding="max_length" - ) - for key in tokenized_data: - tokenized_data[key] = torch.tensor(tokenized_data[key]) - tokenized_data["labels"] = _get_labels_from_sample(sample["label"], self.state.classes) - # separate text data type block - return tokenized_data + return { + DataKeys.INPUT: data, + DataKeys.TARGET: _get_labels_from_sample(sample["label"], self.parameters.classes), + } class LabelStudioVideoClassificationInput(LabelStudioIterableInput): @@ -332,27 +296,24 @@ def load_sample(self, sample: Mapping[str, Any] = None) -> Any: """Load 1 sample from dataset.""" return sample - def load_data(self, data: Optional[Any] = None) -> Sequence[Mapping[str, Any]]: + def load_data( + self, data: Optional[Any] = None, parameters: Optional[LabelStudioParameters] = None + ) -> Sequence[Mapping[str, Any]]: """load_data produces a sequence or iterable of samples.""" - res = super().load_data(data) + res = super().load_data(data, parameters=parameters) return self.convert_to_encodedvideo(res) def convert_to_encodedvideo(self, dataset): """Converting dataset to EncodedVideoDataset.""" if len(dataset) > 0: - if not self.state: - self.state = self.get_state(LabelStudioState) - - assert self.state - from pytorchvideo.data import LabeledVideoDataset dataset = LabeledVideoDataset( [ ( os.path.join(self._data_folder, sample["file_upload"]), - {"label": _get_labels_from_sample(sample["label"], 
self.state.classes)}, + {"label": _get_labels_from_sample(sample["label"], self.parameters.classes)}, ) for sample in dataset ], diff --git a/flash/core/integrations/labelstudio/visualizer.py b/flash/core/integrations/labelstudio/visualizer.py index d902d43145..180a6d5d12 100644 --- a/flash/core/integrations/labelstudio/visualizer.py +++ b/flash/core/integrations/labelstudio/visualizer.py @@ -5,19 +5,14 @@ from pytorch_lightning.utilities.cloud_io import get_filesystem from flash.core.data.data_module import DataModule -from flash.core.data.data_pipeline import DataPipelineState -from flash.core.integrations.labelstudio.input import LabelStudioState +from flash.core.integrations.labelstudio.input import LabelStudioParameters class App: """App for visualizing predictions in Label Studio results format.""" def __init__(self, datamodule: DataModule): - ds = datamodule.inputs - data_pipeline_state: DataPipelineState = ( - ds[0]._data_pipeline_state if isinstance(ds, list) else ds._data_pipeline_state - ) - self.state: LabelStudioState = data_pipeline_state.get_state(LabelStudioState) + self.parameters: LabelStudioParameters = datamodule.train_dataset.parameters def show_predictions(self, predictions): """Converts predictions to Label Studio results.""" @@ -30,7 +25,7 @@ def show_predictions(self, predictions): def show_tasks(self, predictions, export_json=None): """Converts predictions to tasks format.""" results = self.show_predictions(predictions) - data_type = list(self.state.data_types)[0] + data_type = list(self.parameters.data_types)[0] meta = {"ids": [], "data": [], "meta": [], "max_predictions_id": 0, "project": None} if export_json: fs = get_filesystem(export_json) @@ -77,13 +72,13 @@ def _construct_result(self, pred): """Construction Label Studio result from data source and prediction values.""" # get label if isinstance(pred, list): - label = [list(self.state.classes)[p] for p in pred] + label = [list(self.parameters.classes)[p] for p in pred] else: - label = list(self.state.classes)[pred] + label = list(self.parameters.classes)[pred] # get data type, if len(data_types) > 1 take first data type - data_type = list(self.state.data_types)[0] + data_type = list(self.parameters.data_types)[0] # get tag type, if len(tag_types) > 1 take first tag - tag_type = list(self.state.tag_types)[0] + tag_type = list(self.parameters.tag_types)[0] js = { "result": [ { diff --git a/flash/core/integrations/pytorch_forecasting/adapter.py b/flash/core/integrations/pytorch_forecasting/adapter.py index 57068b3716..46da2d0bdf 100644 --- a/flash/core/integrations/pytorch_forecasting/adapter.py +++ b/flash/core/integrations/pytorch_forecasting/adapter.py @@ -21,7 +21,6 @@ from flash.core.adapter import Adapter from flash.core.data.batch import default_uncollate from flash.core.data.io.input import DataKeys -from flash.core.data.states import CollateFn from flash.core.model import Task from flash.core.utilities.imports import _FORECASTING_AVAILABLE, _PANDAS_AVAILABLE @@ -90,7 +89,7 @@ def from_task( adapter = cls(task.backbones.get(backbone)(time_series_dataset=time_series_dataset, **backbone_kwargs)) # Attach the required collate function - adapter.set_state(CollateFn(partial(PyTorchForecastingAdapter._collate_fn, time_series_dataset._collate_fn))) + adapter.collate_fn = partial(PyTorchForecastingAdapter._collate_fn, time_series_dataset._collate_fn) return adapter diff --git a/flash/core/integrations/transformers/states.py b/flash/core/integrations/transformers/collate.py similarity index 51% rename from 
flash/core/integrations/transformers/states.py rename to flash/core/integrations/transformers/collate.py index 0669daed3f..fc7b7a6682 100644 --- a/flash/core/integrations/transformers/states.py +++ b/flash/core/integrations/transformers/collate.py @@ -12,28 +12,39 @@ # See the License for the specific language governing permissions and # limitations under the License. from dataclasses import dataclass, field -from functools import lru_cache from typing import Any, Dict, Optional -from flash.core.data.properties import ProcessState +import torch + +from flash.core.data.io.input import DataKeys from flash.core.utilities.imports import _TRANSFORMERS_AVAILABLE if _TRANSFORMERS_AVAILABLE: from transformers import AutoTokenizer -@dataclass(unsafe_hash=True, frozen=True) -class TransformersBackboneState(ProcessState): - """The ``TransformersBackboneState`` records the ``backbone`` in use by tasks which rely on Hugging Face - transformers.""" +@dataclass(unsafe_hash=True) +class TransformersCollate: backbone: str tokenizer_kwargs: Optional[Dict[str, Any]] = field(default_factory=dict, hash=False) - @property - @lru_cache(maxsize=None) - def tokenizer(self): - tokenizer_kwargs = {} - if self.tokenizer_kwargs is not None: - tokenizer_kwargs = self.tokenizer_kwargs - return AutoTokenizer.from_pretrained(self.backbone, use_fast=True, **tokenizer_kwargs) + def __post_init__(self): + tokenizer_kwargs = self.tokenizer_kwargs or {} + self.tokenizer = AutoTokenizer.from_pretrained(self.backbone, use_fast=True, **tokenizer_kwargs) + + @staticmethod + def to_tensor(sample: Dict[str, Any]) -> Dict[str, Any]: + tensor_sample = {} + for key in sample: + if key is DataKeys.METADATA: + tensor_sample[key] = sample[key] + else: + tensor_sample[key] = torch.tensor(sample[key]) + return tensor_sample + + def tokenize(self, sample): + raise NotImplementedError + + def __call__(self, samples): + return self.to_tensor(self.tokenize({key: [sample[key] for sample in samples] for key in samples[0].keys()})) diff --git a/flash/core/model.py b/flash/core/model.py index 405512bf49..5d783b88ee 100644 --- a/flash/core/model.py +++ b/flash/core/model.py @@ -11,24 +11,19 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. 
-import functools import inspect -import pickle import re from abc import ABCMeta from copy import deepcopy from importlib import import_module -from typing import Any, Callable, Dict, List, Mapping, Optional, Tuple, Type, Union -from warnings import warn +from typing import Any, Callable, Dict, List, Optional, Tuple, Type, Union import pytorch_lightning as pl import torch import torchmetrics -from deprecate import deprecated from pytorch_lightning import LightningModule, Trainer from pytorch_lightning.callbacks import Callback from pytorch_lightning.callbacks.finetuning import BaseFinetuning -from pytorch_lightning.utilities import rank_zero_warn from pytorch_lightning.utilities.enums import LightningEnum from pytorch_lightning.utilities.exceptions import MisconfigurationException from torch import nn @@ -37,13 +32,11 @@ from torch.utils.data import DataLoader, Sampler import flash -from flash.core.data.data_pipeline import DataPipeline, DataPipelineState from flash.core.data.io.input import InputBase, ServeInput from flash.core.data.io.input_transform import InputTransform from flash.core.data.io.output import Output from flash.core.data.io.output_transform import OutputTransform -from flash.core.data.process import Deserializer, DeserializerMapping -from flash.core.data.properties import ProcessState +from flash.core.data.output import BASE_OUTPUTS from flash.core.finetuning import _DEFAULTS_FINETUNE_STRATEGIES, _FINETUNING_STRATEGIES_REGISTRY from flash.core.hooks import FineTuningHooks from flash.core.optimizers.optimizers import _OPTIMIZERS_REGISTRY @@ -53,9 +46,7 @@ from flash.core.utilities.apply_func import get_callable_dict from flash.core.utilities.imports import _PL_GREATER_EQUAL_1_5_0, requires from flash.core.utilities.providers import _HUGGINGFACE -from flash.core.utilities.stages import RunningStage from flash.core.utilities.types import ( - DESERIALIZER_TYPE, INPUT_TRANSFORM_TYPE, LOSS_FN_TYPE, LR_SCHEDULER_TYPE, @@ -63,7 +54,6 @@ MODEL_TYPE, OPTIMIZER_TYPE, OUTPUT_TRANSFORM_TYPE, - OUTPUT_TYPE, ) @@ -82,12 +72,6 @@ def __init__(self): self._children = [] - # TODO: create enum values to define what are the exact states - self._data_pipeline_state: DataPipelineState = DataPipelineState() - - # model own internal state shared with the data pipeline. 
- self._state: Dict[Type[ProcessState], ProcessState] = {} - def __setattr__(self, key, value): if isinstance(value, (LightningModule, ModuleWrapperBase)): self._children.append(key) @@ -98,35 +82,35 @@ def __setattr__(self, key, value): setattr(getattr(self, child), key, value) super().__setattr__(key, value) - def get_state(self, state_type): - if state_type in self._state: - return self._state[state_type] - if self._data_pipeline_state is not None: - return self._data_pipeline_state.get_state(state_type) - return None - - def set_state(self, state: ProcessState): - self._state[type(state)] = state - if self._data_pipeline_state is not None: - self._data_pipeline_state.set_state(state) - - def attach_data_pipeline_state(self, data_pipeline_state: "DataPipelineState"): - for state in self._state.values(): - data_pipeline_state.set_state(state) - if self._data_pipeline_state: - for state in self._data_pipeline_state._state.values(): - data_pipeline_state.set_state(state) - self._data_pipeline_state = data_pipeline_state - for child in self._children: - child = getattr(self, child) - if hasattr(child, "attach_data_pipeline_state"): - child.attach_data_pipeline_state(data_pipeline_state) - class DatasetProcessor: """The ``DatasetProcessor`` mixin provides hooks for classes which need custom logic for producing the data loaders for each running stage given the corresponding dataset.""" + def __init__(self): + super().__init__() + + self._collate_fn = None + self._input_transform = None + + @torch.jit.unused + @property + def collate_fn(self) -> Optional[Callable]: + return self._collate_fn + + @collate_fn.setter + def collate_fn(self, collate_fn: Callable) -> None: + self._collate_fn = collate_fn + + @torch.jit.unused + @property + def input_transform(self) -> Optional[INPUT_TRANSFORM_TYPE]: + return self._input_transform + + @input_transform.setter + def input_transform(self, input_transform: INPUT_TRANSFORM_TYPE) -> None: + self._input_transform = input_transform + def _process_dataset( self, dataset: InputBase, @@ -137,7 +121,7 @@ def _process_dataset( shuffle: bool = False, drop_last: bool = True, sampler: Optional[Sampler] = None, - persistent_workers: bool = True, + persistent_workers: bool = False, ) -> DataLoader: return DataLoader( dataset, @@ -147,7 +131,7 @@ def _process_dataset( shuffle=shuffle, drop_last=drop_last, sampler=sampler, - collate_fn=collate_fn, + collate_fn=self.collate_fn if self.collate_fn is not None else collate_fn, persistent_workers=persistent_workers, ) @@ -162,17 +146,18 @@ def process_train_dataset( shuffle: bool = False, drop_last: bool = True, sampler: Optional[Sampler] = None, + persistent_workers: bool = False, ) -> DataLoader: return self._process_dataset( dataset, batch_size=batch_size, num_workers=num_workers, pin_memory=pin_memory, - collate_fn=collate_fn, + collate_fn=self.collate_fn if self.collate_fn is not None else collate_fn, shuffle=shuffle, drop_last=drop_last, sampler=sampler, - persistent_workers=num_workers > 0, + persistent_workers=persistent_workers, ) def process_val_dataset( @@ -186,17 +171,18 @@ def process_val_dataset( shuffle: bool = False, drop_last: bool = False, sampler: Optional[Sampler] = None, + persistent_workers: bool = False, ) -> DataLoader: return self._process_dataset( dataset, batch_size=batch_size, num_workers=num_workers, pin_memory=pin_memory, - collate_fn=collate_fn, + collate_fn=self.collate_fn if self.collate_fn is not None else collate_fn, shuffle=shuffle, drop_last=drop_last, sampler=sampler, - 
persistent_workers=num_workers > 0, + persistent_workers=persistent_workers, ) def process_test_dataset( @@ -210,17 +196,18 @@ def process_test_dataset( shuffle: bool = False, drop_last: bool = False, sampler: Optional[Sampler] = None, + persistent_workers: bool = False, ) -> DataLoader: return self._process_dataset( dataset, batch_size=batch_size, num_workers=num_workers, pin_memory=pin_memory, - collate_fn=collate_fn, + collate_fn=self.collate_fn if self.collate_fn is not None else collate_fn, shuffle=shuffle, drop_last=drop_last, sampler=sampler, - persistent_workers=num_workers > 0, + persistent_workers=persistent_workers, ) def process_predict_dataset( @@ -233,17 +220,18 @@ def process_predict_dataset( shuffle: bool = False, drop_last: bool = False, sampler: Optional[Sampler] = None, + persistent_workers: bool = False, ) -> DataLoader: return self._process_dataset( dataset, batch_size=batch_size, num_workers=num_workers, pin_memory=pin_memory, - collate_fn=collate_fn, + collate_fn=self.collate_fn if self.collate_fn is not None else collate_fn, shuffle=shuffle, drop_last=drop_last, sampler=sampler, - persistent_workers=False, + persistent_workers=persistent_workers, ) @@ -302,18 +290,14 @@ class Task(DatasetProcessor, ModuleWrapperBase, LightningModule, FineTuningHooks package, a custom metric inheriting from `torchmetrics.Metric`, a callable function or a list/dict containing a combination of the aforementioned. In all cases, each metric needs to have the signature `metric(preds,target)` and return a single scalar tensor. - deserializer: Either a single :class:`~flash.core.data.process.Deserializer` or a mapping of these to - deserialize the input - input_transform: :class:`~flash.core.data.io.input_transform.InputTransform` to use as the default - for this task. output_transform: :class:`~flash.core.data.io.output_transform.OutputTransform` to use as the default for this task. - output: The :class:`~flash.core.data.io.output.Output` to use when formatting prediction outputs. """ optimizers: FlashRegistry = _OPTIMIZERS_REGISTRY lr_schedulers: FlashRegistry = _SCHEDULERS_REGISTRY finetuning_strategies: FlashRegistry = _FINETUNING_STRATEGIES_REGISTRY + outputs: FlashRegistry = BASE_OUTPUTS required_extras: Optional[Union[str, List[str]]] = None @@ -325,10 +309,7 @@ def __init__( optimizer: OPTIMIZER_TYPE = "Adam", lr_scheduler: LR_SCHEDULER_TYPE = None, metrics: METRICS_TYPE = None, - deserializer: DESERIALIZER_TYPE = None, - input_transform: INPUT_TRANSFORM_TYPE = None, output_transform: OUTPUT_TRANSFORM_TYPE = None, - output: OUTPUT_TYPE = None, ): super().__init__() if model is not None: @@ -344,37 +325,7 @@ def __init__( # TODO: should we save more? 
Bug on some regarding yaml if we save metrics self.save_hyperparameters("learning_rate", "optimizer") - self._deserializer: Optional[Deserializer] = None - self._input_transform: Optional[InputTransform] = input_transform self._output_transform: Optional[OutputTransform] = output_transform - self._output: Optional[Output] = None - - # Explicitly set the output to call the setter - self.deserializer = deserializer - self.output = output - self._wrapped_predict_step = False - - def _wrap_predict_step(self) -> None: - if not self._wrapped_predict_step: - process_fn = self.build_data_pipeline().output_transform_processor(RunningStage.PREDICTING) - - predict_step = self.predict_step - - @functools.wraps(predict_step) - def wrapper(*args, **kwargs): - predictions = predict_step(*args, **kwargs) - return process_fn(predictions) - - self._original_predict_step = self.predict_step - self.predict_step = wrapper - - self._wrapped_predict_step = True - - def _unwrap_predict_step(self) -> None: - if self._wrapped_predict_step: - self.predict_step = self._original_predict_step - del self._original_predict_step - self._wrapped_predict_step = False def step(self, batch: Any, batch_idx: int, metrics: nn.ModuleDict) -> Any: """Implement the core logic for the training/validation/test step. By default this includes: @@ -580,258 +531,6 @@ def configure_finetune_callback( return [finetuning_strategy_fn(**finetuning_strategy_metadata)] - @staticmethod - def _resolve( - old_deserializer: Optional[Deserializer], - old_input_transform: Optional[InputTransform], - old_output_transform: Optional[OutputTransform], - old_output: Optional[Output], - new_deserializer: Optional[Deserializer], - new_input_transform: Optional[InputTransform], - new_output_transform: Optional[OutputTransform], - new_output: Optional[Output], - ) -> Tuple[Optional[Deserializer], Optional[InputTransform], Optional[OutputTransform], Optional[Output]]: - """Resolves the correct :class:`~flash.core.data.io.input_transform.InputTransform`, - :class:`~flash.core.data.io.output_transform.OutputTransform`, and :class:`~flash.core.data.io.output.Output` - to use, choosing ``new_*`` if it is not None or a base class - (:class:`~flash.core.data.io.input_transform.InputTransform`, - :class:`~flash.core.data.io.output_transform.OutputTransform`, or :class:`~flash.core.data.io.output.Output`) - and ``old_*`` otherwise. - - Args: - old_input_transform: :class:`~flash.core.data.io.input_transform.InputTransform` to be overridden. - old_output_transform: :class:`~flash.core.data.io.output_transform.OutputTransform` to be overridden. - old_output: :class:`~flash.core.data.io.output.Output` to be overridden. - new_input_transform: :class:`~flash.core.data.io.input_transform.InputTransform` to override with. - new_output_transform: :class:`~flash.core.data.io.output_transform.OutputTransform` to override with. - new_output: :class:`~flash.core.data.io.output.Output` to override with. - - Returns: - The resolved :class:`~flash.core.data.io.input_transform.InputTransform`, - :class:`~flash.core.data.io.output_transform.OutputTransform`, and - :class:`~flash.core.data.io.output.Output`. 
- """ - deserializer = old_deserializer - if new_deserializer is not None and type(new_deserializer) != Deserializer: - deserializer = new_deserializer - - input_transform = old_input_transform - if new_input_transform is not None and type(new_input_transform) != InputTransform: - input_transform = new_input_transform - - output_transform = old_output_transform - if new_output_transform is not None and type(new_output_transform) != OutputTransform: - output_transform = new_output_transform - - output = old_output - if new_output is not None and type(new_output) != Output: - output = new_output - - return deserializer, input_transform, output_transform, output - - @torch.jit.unused - @property - def deserializer(self) -> Optional[Deserializer]: - return self._deserializer - - @deserializer.setter - def deserializer(self, deserializer: Union[Deserializer, Mapping[str, Deserializer]]): - if isinstance(deserializer, Mapping): - deserializer = DeserializerMapping(deserializer) - self._deserializer = deserializer - - @torch.jit.unused - @property - def output(self) -> Optional[Output]: - """The current :class:`.Output` associated with this model.""" - return self._output - - @torch.jit.unused - @output.setter - def output(self, output: Output): - self._output = output - - @torch.jit.unused - @property - @deprecated( - None, - "0.6.0", - "0.7.0", - template_mgs="`Task.serializer` was deprecated in v%(deprecated_in)s in favor of `Task.output`. " - "It will be removed in v%(remove_in)s.", - stream=functools.partial(warn, category=FutureWarning), - ) - def serializer(self) -> Optional[Output]: - """Deprecated. - - Use ``Task.output`` instead. - """ - return self.output - - @torch.jit.unused - @serializer.setter - @deprecated( - None, - "0.6.0", - "0.7.0", - template_mgs="`Task.serializer` was deprecated in v%(deprecated_in)s in favor of `Task.output`. " - "It will be removed in v%(remove_in)s.", - stream=functools.partial(warn, category=FutureWarning), - ) - def serializer(self, serializer: Output): - self.output = serializer - - def build_data_pipeline( - self, - input: Optional[str] = None, - deserializer: Optional[Deserializer] = None, - data_pipeline: Optional[DataPipeline] = None, - ) -> Optional[DataPipeline]: - """Build a :class:`.DataPipeline` incorporating available - :class:`~flash.core.data.io.input_transform.InputTransform` and - :class:`~flash.core.data.io.output_transform.OutputTransform` - objects. These will be overridden in the following resolution order (lowest priority first): - - - Lightning ``Datamodule``, either attached to the :class:`.Trainer` or to the :class:`.Task`. - - :class:`.Task` defaults given to :meth:`.Task.__init__`. - - :class:`.Task` manual overrides by setting :py:attr:`~data_pipeline`. - - :class:`.DataPipeline` passed to this method. - - Args: - input: A string that indicates the format of the data source to use which will override - the current data source format used. - deserializer: deserializer to use - data_pipeline: Optional highest priority source of - :class:`~flash.core.data.io.input_transform.InputTransform` and - :class:`~flash.core.data.io.output_transform.OutputTransform`. - - Returns: - The fully resolved :class:`.DataPipeline`. 
- """ - deserializer, old_input, input_transform, output_transform, output = None, None, None, None, None - - # Datamodule - datamodule = None - if self.trainer is not None and hasattr(self.trainer, "datamodule"): - datamodule = self.trainer.datamodule - elif getattr(self, "datamodule", None) is not None: - datamodule = self.datamodule - - if getattr(datamodule, "data_pipeline", None) is not None: - old_input = getattr(datamodule.data_pipeline, "input", None) - input_transform = getattr(datamodule.data_pipeline, "_input_transform_pipeline", None) - output_transform = getattr(datamodule.data_pipeline, "_output_transform", None) - output = getattr(datamodule.data_pipeline, "_output", None) - deserializer = getattr(datamodule.data_pipeline, "_deserializer", None) - - # Defaults / task attributes - deserializer, input_transform, output_transform, output = Task._resolve( - deserializer, - input_transform, - output_transform, - output, - self._deserializer, - self._input_transform, - self._output_transform, - self._output, - ) - - # Datapipeline - if data_pipeline is not None: - deserializer, input_transform, output_transform, output = Task._resolve( - deserializer, - input_transform, - output_transform, - output, - getattr(data_pipeline, "_deserializer", None), - getattr(data_pipeline, "_input_transform_pipeline", None), - getattr(data_pipeline, "_output_transform", None), - getattr(data_pipeline, "_output", None), - ) - - input = input or old_input - - if deserializer is None or type(deserializer) is Deserializer: - deserializer = getattr(input_transform, "deserializer", deserializer) - - data_pipeline = DataPipeline( - input=input, - input_transform=input_transform, - output_transform=output_transform, - deserializer=deserializer, - output=output, - ) - - self._data_pipeline_state = self._data_pipeline_state or DataPipelineState() - - self.attach_data_pipeline_state(self._data_pipeline_state) - self._data_pipeline_state = data_pipeline.initialize(self._data_pipeline_state) - return data_pipeline - - @torch.jit.unused - @property - def data_pipeline(self) -> DataPipeline: - """The current :class:`.DataPipeline`. - - If set, the new value will override the :class:`.Task` defaults. See - :py:meth:`~build_data_pipeline` for more details on the resolution order. 
- """ - return self.build_data_pipeline() - - @torch.jit.unused - @data_pipeline.setter - def data_pipeline(self, data_pipeline: Optional[DataPipeline]) -> None: - self._deserializer, self._input_transform, self._output_transform, self.output = Task._resolve( - self._deserializer, - self._input_transform, - self._output_transform, - self._output, - getattr(data_pipeline, "_deserializer", None), - getattr(data_pipeline, "_input_transform_pipeline", None), - getattr(data_pipeline, "_output_transform", None), - getattr(data_pipeline, "_output", None), - ) - - # self._input_transform.state_dict() - if getattr(self._input_transform, "_ddp_params_and_buffers_to_ignore", None): - self._ddp_params_and_buffers_to_ignore = self._input_transform._ddp_params_and_buffers_to_ignore - - @torch.jit.unused - @property - def input_transform(self) -> InputTransform: - return getattr(self.data_pipeline, "_input_transform_pipeline", None) - - @torch.jit.unused - @property - def output_transform(self) -> OutputTransform: - return getattr(self.data_pipeline, "_output_transform", None) - - def on_predict_start(self) -> None: - self._wrap_predict_step() - - def on_predict_end(self) -> None: - self._unwrap_predict_step() - - def on_save_checkpoint(self, checkpoint: Dict[str, Any]) -> None: - # This may be an issue since here we create the same problems with pickle as in - # https://pytorch.org/docs/stable/notes/serialization.html - if self.data_pipeline is not None and "data_pipeline" not in checkpoint: - try: - pickle.dumps(self.data_pipeline) # TODO: DataPipeline not always pickleable - checkpoint["data_pipeline"] = self.data_pipeline - except AttributeError: - rank_zero_warn("DataPipeline couldn't be added to the checkpoint.") - if self._data_pipeline_state is not None and "_data_pipeline_state" not in checkpoint: - checkpoint["_data_pipeline_state"] = self._data_pipeline_state - super().on_save_checkpoint(checkpoint) - - def on_load_checkpoint(self, checkpoint: Dict[str, Any]) -> None: - super().on_load_checkpoint(checkpoint) - if "data_pipeline" in checkpoint: - self.data_pipeline = checkpoint["data_pipeline"] - if "_data_pipeline_state" in checkpoint: - self._data_pipeline_state = checkpoint["_data_pipeline_state"] - @classmethod def available_backbones( cls, head: Optional[str] = None @@ -1086,13 +785,13 @@ def configure_callbacks(self): return [BenchmarkConvergenceCI()] @requires("serve") - def run_serve_sanity_check(self, serve_input: ServeInput): + def run_serve_sanity_check(self, serve_input: ServeInput, output: Output): from fastapi.testclient import TestClient from flash.core.serve.flash_components import build_flash_serve_model_component print("Running serve sanity check") - comp = build_flash_serve_model_component(self, serve_input) + comp = build_flash_serve_model_component(self, serve_input, output) composition = Composition(predict=comp, TESTING=True, DEBUG=True) app = composition.serve(host="0.0.0.0", port=8000) @@ -1111,6 +810,7 @@ def serve( input_cls: Optional[Type[ServeInput]] = None, transform: INPUT_TRANSFORM_TYPE = InputTransform, transform_kwargs: Optional[Dict] = None, + output: Optional[Union[str, Output]] = None, ) -> "Composition": """Serve the ``Task``. Override this method to provide a default ``input_cls``, ``transform``, and ``transform_kwargs``. 
@@ -1130,10 +830,14 @@ def serve( serve_input = input_cls(transform=transform, transform_kwargs=transform_kwargs) + output = output or Output() + if isinstance(output, str): + output = self.outputs.get(output).from_task(self) + if sanity_check: - self.run_serve_sanity_check(serve_input) + self.run_serve_sanity_check(serve_input, output) - comp = build_flash_serve_model_component(self, serve_input) + comp = build_flash_serve_model_component(self, serve_input, output) composition = Composition(predict=comp, TESTING=flash._IS_TESTING) composition.serve(host=host, port=port) return composition diff --git a/flash/core/registry.py b/flash/core/registry.py index 640deed427..58098fc740 100644 --- a/flash/core/registry.py +++ b/flash/core/registry.py @@ -12,6 +12,7 @@ # See the License for the specific language governing permissions and # limitations under the License. import functools +import inspect import itertools from typing import Any, Callable, Dict, List, Optional, Union @@ -32,10 +33,20 @@ def print_provider_info(name, providers, func): providers = providers[:-1] message = f"Using '{name}' provided by {', '.join(str(provider) for provider in providers)}." - @functools.wraps(func) - def wrapper(*args, **kwargs): - rank_zero_info(message) - return func(*args, **kwargs) + def build_wrapper(func): + @functools.wraps(func) + def wrapper(*args, **kwargs): + rank_zero_info(message) + return func(*args, **kwargs) + + return wrapper + + wrapper = build_wrapper(func) + + if inspect.isclass(func): + callables = [f for f in dir(func) if callable(getattr(func, f)) and not f.startswith("_")] + for c in callables: + setattr(wrapper, c, build_wrapper(getattr(func, c))) return wrapper @@ -56,9 +67,9 @@ def __add__(self, other): registries += [self] if isinstance(other, ConcatRegistry): - registries += other.registries + registries = other.registries + tuple(registries) else: - registries += [other] + registries = [other] + registries return ConcatRegistry(*registries) diff --git a/flash/core/regression.py b/flash/core/regression.py index a9a516be0a..846e782d80 100644 --- a/flash/core/regression.py +++ b/flash/core/regression.py @@ -19,7 +19,6 @@ from flash.core.adapter import AdapterTask from flash.core.model import Task -from flash.core.utilities.types import OUTPUT_TYPE class RegressionMixin: @@ -43,7 +42,6 @@ def __init__( *args, loss_fn: Optional[Callable] = None, metrics: Union[torchmetrics.Metric, Mapping, Sequence, None] = None, - output: OUTPUT_TYPE = None, **kwargs, ) -> None: @@ -53,7 +51,6 @@ def __init__( *args, loss_fn=loss_fn, metrics=metrics, - output=output, **kwargs, ) @@ -64,7 +61,6 @@ def __init__( *args, loss_fn: Optional[Callable] = None, metrics: Union[torchmetrics.Metric, Mapping, Sequence, None] = None, - output: OUTPUT_TYPE = None, **kwargs, ) -> None: metrics, loss_fn = RegressionMixin._build(loss_fn, metrics) @@ -73,6 +69,5 @@ def __init__( *args, loss_fn=loss_fn, metrics=metrics, - output=output, **kwargs, ) diff --git a/flash/core/serve/flash_components.py b/flash/core/serve/flash_components.py index 6a10919e5c..c02b182039 100644 --- a/flash/core/serve/flash_components.py +++ b/flash/core/serve/flash_components.py @@ -4,10 +4,12 @@ import torch from flash.core.data.batch import _ServeInputProcessor -from flash.core.data.data_pipeline import DataPipelineState +from flash.core.data.data_module import DataModule from flash.core.data.io.input import DataKeys +from flash.core.data.io.output_transform import OutputTransform from flash.core.serve import expose, ModelComponent from 
flash.core.serve.types.base import BaseType +from flash.core.trainer import Trainer from flash.core.utilities.stages import RunningStage @@ -48,39 +50,38 @@ def deserialize(self, data: str) -> Any: # pragma: no cover return None -def build_flash_serve_model_component(model, serve_input): +def build_flash_serve_model_component(model, serve_input, output): + # TODO: Resolve this hack + data_module = DataModule(predict_input=serve_input, batch_size=1) - data_pipeline_state = DataPipelineState() - for properties in [ - serve_input, - getattr(serve_input, "transform", None), - model._output_transform, - model._output, - model, - ]: - if properties is not None and hasattr(properties, "attach_data_pipeline_state"): - properties.attach_data_pipeline_state(data_pipeline_state) + class MockTrainer(Trainer): + def __init__(self): + super().__init__() + self.state.stage = RunningStage.PREDICTING - data_pipeline = model.build_data_pipeline() + @property + def lightning_module(self): + return model + + data_module.trainer = MockTrainer() + dataloader = data_module.predict_dataloader() + + collate_fn = dataloader.collate_fn class FlashServeModelComponent(ModelComponent): def __init__(self, model): self.model = model self.model.eval() - self.data_pipeline = model.build_data_pipeline() self.serve_input = serve_input - self.dataloader_collate_fn = self.serve_input._create_dataloader_collate_fn([]) - self.on_after_batch_transfer_fn = self.serve_input._create_on_after_batch_transfer_fn([]) - self.output_transform_processor = self.data_pipeline.output_transform_processor( - RunningStage.SERVING, is_serving=True - ) + self.on_after_batch_transfer = data_module.on_after_batch_transfer + self.output_transform = getattr(model, "_output_transform", None) or OutputTransform() # todo (tchaton) Remove this hack self.extra_arguments = len(inspect.signature(self.model.transfer_batch_to_device).parameters) == 3 self.device = self.model.device @expose( - inputs={"inputs": FlashInputs(_ServeInputProcessor(serve_input))}, - outputs={"outputs": FlashOutputs(data_pipeline.output_processor())}, + inputs={"inputs": FlashInputs(_ServeInputProcessor(serve_input, collate_fn))}, + outputs={"outputs": FlashOutputs(output)}, ) def predict(self, inputs): with torch.no_grad(): @@ -88,9 +89,9 @@ def predict(self, inputs): inputs = self.model.transfer_batch_to_device(inputs, self.device, 0) else: inputs = self.model.transfer_batch_to_device(inputs, self.device) - inputs = self.on_after_batch_transfer_fn(inputs) + inputs = self.on_after_batch_transfer(inputs, 0) preds = self.model.predict_step(inputs, 0) - preds = self.output_transform_processor(preds) + preds = self.output_transform(preds) return preds return FlashServeModelComponent(model) diff --git a/flash/core/trainer.py b/flash/core/trainer.py index 77e78f30ca..3245856f4a 100644 --- a/flash/core/trainer.py +++ b/flash/core/trainer.py @@ -11,6 +11,8 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. 
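The trainer.py additions that follow replace the removed model-level wrapping (the old ``Task._wrap_predict_step`` above) with a context manager that swaps ``predict_step`` only for the duration of a ``predict`` call. A self-contained sketch of the swap pattern, with simplified names and no Flash dependencies:

    import contextlib
    import functools


    @contextlib.contextmanager
    def swapped_predict_step(model, per_prediction_transform):
        """Temporarily route model.predict_step through a transform."""
        original = model.predict_step

        @functools.wraps(original)
        def wrapped(*args, **kwargs):
            predictions = original(*args, **kwargs)
            if predictions is not None:
                predictions = [per_prediction_transform(p) for p in predictions]
            return predictions

        model.predict_step = wrapped
        try:
            yield
        finally:
            # Restore the unwrapped step even if prediction raises
            model.predict_step = original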
+import contextlib +import functools import inspect import warnings from argparse import ArgumentParser, Namespace @@ -26,7 +28,10 @@ from torch.utils.data import DataLoader import flash +from flash.core.data.io.output import Output +from flash.core.data.io.output_transform import OutputTransform from flash.core.model import Task +from flash.core.registry import FlashRegistry def from_argparse_args(cls, args: Union[Namespace, ArgumentParser], **kwargs): @@ -157,6 +162,57 @@ def finetune( self._resolve_callbacks(model, strategy, train_bn=train_bn) return super().fit(model, train_dataloader, val_dataloaders, datamodule) + @contextlib.contextmanager + def _wrap_predict_step(self, model, output_transform, output): + predict_step = model.predict_step + + @functools.wraps(predict_step) + def wrapper(*args, **kwargs): + predictions = predict_step(*args, **kwargs) + if predictions is not None: + predictions = output_transform(predictions) + predictions = [output(prediction) for prediction in predictions] + return predictions + + model.predict_step = wrapper + try: + yield + finally: + model.predict_step = predict_step + + def predict( + self, + model: Optional[LightningModule] = None, + dataloaders: Optional[Union[DataLoader, LightningDataModule]] = None, + output: Optional[Union[Output, str]] = None, + **kwargs, + ): + r""" + Run inference on your data. + This will call the model forward function to compute predictions. Useful for performing distributed + and batched predictions. Logging is disabled in the predict hooks. + + Args: + model: The model to predict with. + dataloaders: A :class:`torch.utils.data.DataLoader` or a sequence of them, + or a :class:`~pytorch_lightning.core.datamodule.LightningDataModule` specifying prediction samples. + output: The :class:`~flash.core.data.io.output.Output` to use to transform predict outputs. + kwargs: Additional keyword arguments to pass to ``pytorch_lightning.Trainer.predict``. + + Returns: + A list of dictionaries, one for each provided dataloader, containing their respective predictions.
+ """ + model = model or self.lightning_module + output_transform = getattr(model, "_output_transform", None) or OutputTransform() + if output is None: + output = Output() + if isinstance(output, str) and isinstance(model, Task): + output = getattr(model, "outputs", FlashRegistry("outputs")).get(output).from_task(model) + + with self._wrap_predict_step(model, output_transform, output): + return super().predict(model, dataloaders, **kwargs) + def _resolve_callbacks( self, model: Task, diff --git a/flash/core/utilities/providers.py b/flash/core/utilities/providers.py index 4573faceff..942e86e429 100644 --- a/flash/core/utilities/providers.py +++ b/flash/core/utilities/providers.py @@ -50,3 +50,4 @@ def __str__(self): _PYTORCH_FORECASTING = Provider("jdb78/PyTorch-Forecasting", "https://github.com/jdb78/pytorch-forecasting") _PYTORCH_GEOMETRIC = Provider("PyG/PyTorch Geometric", "https://github.com/pyg-team/pytorch_geometric") _PYTORCH_TABULAR = Provider("manujosephv/PyTorch Tabular", "https://github.com/manujosephv/pytorch_tabular") +_FIFTYONE = Provider("voxel51/fiftyone", "https://github.com/voxel51/fiftyone") diff --git a/flash/core/utilities/stages.py b/flash/core/utilities/stages.py index edfbbe0aaa..15728064bf 100644 --- a/flash/core/utilities/stages.py +++ b/flash/core/utilities/stages.py @@ -49,24 +49,3 @@ def dataloader_prefix(self) -> Optional[str]: if self == self.VALIDATING: return "val" return self.value - - -_STAGES_PREFIX = { - RunningStage.TRAINING: "train", - RunningStage.TESTING: "test", - RunningStage.VALIDATING: "val", - RunningStage.PREDICTING: "predict", - RunningStage.SERVING: "serve", -} - -_STAGES_PREFIX_VALUES = {"train", "test", "val", "predict", "serve"} - -_RUNNING_STAGE_MAPPING = { - RunningStage.TRAINING: RunningStage.TRAINING, - RunningStage.SANITY_CHECKING: RunningStage.VALIDATING, - RunningStage.VALIDATING: RunningStage.VALIDATING, - RunningStage.TESTING: RunningStage.TESTING, - RunningStage.PREDICTING: RunningStage.PREDICTING, - RunningStage.SERVING: RunningStage.SERVING, - RunningStage.TUNING: RunningStage.TUNING, -} diff --git a/flash/graph/classification/cli.py b/flash/graph/classification/cli.py index ca2f5368c1..acad250e3f 100644 --- a/flash/graph/classification/cli.py +++ b/flash/graph/classification/cli.py @@ -51,7 +51,7 @@ def graph_classification(): "trainer.max_epochs": 3, }, finetune=False, - datamodule_attributes={"num_classes", "num_features"}, + datamodule_attributes={"num_classes", "labels", "num_features"}, ) cli.trainer.save_checkpoint("graph_classification.pt") diff --git a/flash/graph/classification/data.py b/flash/graph/classification/data.py index 9a877ad378..e340535961 100644 --- a/flash/graph/classification/data.py +++ b/flash/graph/classification/data.py @@ -16,7 +16,6 @@ from torch.utils.data import Dataset from flash.core.data.data_module import DataModule -from flash.core.data.data_pipeline import DataPipelineState from flash.core.data.io.input import Input from flash.core.data.utilities.classification import TargetFormatter from flash.core.utilities.imports import _GRAPH_TESTING @@ -176,13 +175,15 @@ def from_datasets( ds_kw = dict( target_formatter=target_formatter, - data_pipeline_state=DataPipelineState(), transform_kwargs=transform_kwargs, input_transforms_registry=cls.input_transforms_registry, ) + train_input = input_cls(RunningStage.TRAINING, train_dataset, transform=train_transform, **ds_kw) + ds_kw["target_formatter"] = getattr(train_input, "target_formatter", None) + return cls( - input_cls(RunningStage.TRAINING, 
train_dataset, transform=train_transform, **ds_kw), + train_input, input_cls(RunningStage.VALIDATING, val_dataset, transform=val_transform, **ds_kw), input_cls(RunningStage.TESTING, test_dataset, transform=test_transform, **ds_kw), input_cls(RunningStage.PREDICTING, predict_dataset, transform=predict_transform, **ds_kw), diff --git a/flash/graph/classification/model.py b/flash/graph/classification/model.py index d82d55ab9d..92a5e25708 100644 --- a/flash/graph/classification/model.py +++ b/flash/graph/classification/model.py @@ -11,7 +11,7 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. -from typing import Any, Callable, Dict, Optional, Tuple, Union +from typing import Any, Callable, Dict, List, Optional, Tuple, Union import torch from torch import nn @@ -59,7 +59,8 @@ class GraphClassifier(ClassificationTask): def __init__( self, num_features: int, - num_classes: int, + num_classes: Optional[int] = None, + labels: Optional[List[str]] = None, backbone: Union[str, Tuple[nn.Module, int]] = "GCN", backbone_kwargs: Optional[Dict] = {}, pooling_fn: Optional[Union[str, Callable]] = "mean", @@ -70,15 +71,19 @@ def __init__( lr_scheduler: LR_SCHEDULER_TYPE = None, metrics: METRICS_TYPE = None, ): - self.save_hyperparameters() + if labels is not None and num_classes is None: + num_classes = len(labels) + super().__init__( loss_fn=loss_fn, optimizer=optimizer, lr_scheduler=lr_scheduler, metrics=metrics, learning_rate=learning_rate, + num_classes=num_classes, + labels=labels, ) self.save_hyperparameters() diff --git a/flash/image/classification/adapters.py b/flash/image/classification/adapters.py index 423bd267e8..89376753ae 100644 --- a/flash/image/classification/adapters.py +++ b/flash/image/classification/adapters.py @@ -332,6 +332,7 @@ def process_train_dataset( shuffle: bool = False, drop_last: bool = False, sampler: Optional[Sampler] = None, + persistent_workers: bool = False, ) -> DataLoader: dataset = self._convert_dataset( trainer=trainer, @@ -356,6 +357,7 @@ def process_train_dataset( shuffle=shuffle, drop_last=drop_last, sampler=sampler, + persistent_workers=persistent_workers, ) def process_val_dataset( @@ -369,6 +371,7 @@ def process_val_dataset( shuffle: bool = False, drop_last: bool = False, sampler: Optional[Sampler] = None, + persistent_workers: bool = False, ) -> DataLoader: dataset = self._convert_dataset( trainer=trainer, @@ -393,6 +396,7 @@ def process_val_dataset( shuffle=shuffle, drop_last=drop_last, sampler=sampler, + persistent_workers=persistent_workers, ) def process_test_dataset( @@ -406,6 +410,7 @@ def process_test_dataset( shuffle: bool = False, drop_last: bool = False, sampler: Optional[Sampler] = None, + persistent_workers: bool = False, ) -> DataLoader: dataset = self._convert_dataset( trainer=trainer, @@ -430,6 +435,7 @@ def process_test_dataset( shuffle=shuffle, drop_last=drop_last, sampler=sampler, + persistent_workers=persistent_workers, ) def process_predict_dataset( @@ -442,6 +448,7 @@ def process_predict_dataset( shuffle: bool = False, drop_last: bool = True, sampler: Optional[Sampler] = None, + persistent_workers: bool = False, ) -> DataLoader: if not self._algorithm_has_validated: @@ -458,6 +465,7 @@ def process_predict_dataset( shuffle=shuffle, drop_last=drop_last, sampler=sampler, + persistent_workers=persistent_workers, ) diff --git a/flash/image/classification/cli.py b/flash/image/classification/cli.py index 
1ffc4052d6..38779e1109 100644 --- a/flash/image/classification/cli.py +++ b/flash/image/classification/cli.py @@ -63,7 +63,7 @@ def image_classification(): default_arguments={ "trainer.max_epochs": 3, }, - datamodule_attributes={"num_classes", "multi_label"}, + datamodule_attributes={"num_classes", "labels", "multi_label"}, ) cli.trainer.save_checkpoint("image_classification_model.pt") diff --git a/flash/image/classification/data.py b/flash/image/classification/data.py index a7f311698b..59d33bcfa4 100644 --- a/flash/image/classification/data.py +++ b/flash/image/classification/data.py @@ -21,7 +21,6 @@ from flash.core.data.base_viz import BaseVisualization from flash.core.data.callback import BaseDataFetcher from flash.core.data.data_module import DataModule, DatasetInput -from flash.core.data.data_pipeline import DataPipelineState from flash.core.data.io.input import DataKeys, Input from flash.core.data.io.input_transform import INPUT_TRANSFORM_TYPE from flash.core.data.utilities.paths import PATH_TYPE @@ -169,13 +168,15 @@ def from_files( """ ds_kw = dict( - data_pipeline_state=DataPipelineState(), transform_kwargs=transform_kwargs, input_transforms_registry=cls.input_transforms_registry, ) + train_input = input_cls(RunningStage.TRAINING, train_files, train_targets, transform=train_transform, **ds_kw) + ds_kw["target_formatter"] = getattr(train_input, "target_formatter", None) + return cls( - input_cls(RunningStage.TRAINING, train_files, train_targets, transform=train_transform, **ds_kw), + train_input, input_cls(RunningStage.VALIDATING, val_files, val_targets, transform=val_transform, **ds_kw), input_cls(RunningStage.TESTING, test_files, test_targets, transform=test_transform, **ds_kw), input_cls(RunningStage.PREDICTING, predict_files, transform=predict_transform, **ds_kw), @@ -292,13 +293,15 @@ def from_folders( """ ds_kw = dict( - data_pipeline_state=DataPipelineState(), transform_kwargs=transform_kwargs, input_transforms_registry=cls.input_transforms_registry, ) + train_input = input_cls(RunningStage.TRAINING, train_folder, transform=train_transform, **ds_kw) + ds_kw["target_formatter"] = getattr(train_input, "target_formatter", None) + return cls( - input_cls(RunningStage.TRAINING, train_folder, transform=train_transform, **ds_kw), + train_input, input_cls(RunningStage.VALIDATING, val_folder, transform=val_transform, **ds_kw), input_cls(RunningStage.TESTING, test_folder, transform=test_transform, **ds_kw), input_cls(RunningStage.PREDICTING, predict_folder, transform=predict_transform, **ds_kw), @@ -380,13 +383,15 @@ def from_numpy( """ ds_kw = dict( - data_pipeline_state=DataPipelineState(), transform_kwargs=transform_kwargs, input_transforms_registry=cls.input_transforms_registry, ) + train_input = input_cls(RunningStage.TRAINING, train_data, train_targets, transform=train_transform, **ds_kw) + ds_kw["target_formatter"] = getattr(train_input, "target_formatter", None) + return cls( - input_cls(RunningStage.TRAINING, train_data, train_targets, transform=train_transform, **ds_kw), + train_input, input_cls(RunningStage.VALIDATING, val_data, val_targets, transform=val_transform, **ds_kw), input_cls(RunningStage.TESTING, test_data, test_targets, transform=test_transform, **ds_kw), input_cls(RunningStage.PREDICTING, predict_data, transform=predict_transform, **ds_kw), @@ -468,13 +473,15 @@ def from_tensors( """ ds_kw = dict( - data_pipeline_state=DataPipelineState(), transform_kwargs=transform_kwargs, input_transforms_registry=cls.input_transforms_registry, ) + train_input = 
input_cls(RunningStage.TRAINING, train_data, train_targets, transform=train_transform, **ds_kw) + ds_kw["target_formatter"] = getattr(train_input, "target_formatter", None) + return cls( - input_cls(RunningStage.TRAINING, train_data, train_targets, transform=train_transform, **ds_kw), + train_input, input_cls(RunningStage.VALIDATING, val_data, val_targets, transform=val_transform, **ds_kw), input_cls(RunningStage.TESTING, test_data, test_targets, transform=test_transform, **ds_kw), input_cls(RunningStage.PREDICTING, predict_data, transform=predict_transform, **ds_kw), @@ -609,7 +616,6 @@ def from_data_frame( """ ds_kw = dict( - data_pipeline_state=DataPipelineState(), transform_kwargs=transform_kwargs, input_transforms_registry=cls.input_transforms_registry, ) @@ -619,8 +625,11 @@ def from_data_frame( test_data = (test_data_frame, input_field, target_fields, test_images_root, test_resolver) predict_data = (predict_data_frame, input_field, None, predict_images_root, predict_resolver) + train_input = input_cls(RunningStage.TRAINING, *train_data, transform=train_transform, **ds_kw) + ds_kw["target_formatter"] = getattr(train_input, "target_formatter", None) + return cls( - input_cls(RunningStage.TRAINING, *train_data, transform=train_transform, **ds_kw), + train_input, input_cls(RunningStage.VALIDATING, *val_data, transform=val_transform, **ds_kw), input_cls(RunningStage.TESTING, *test_data, transform=test_transform, **ds_kw), input_cls(RunningStage.PREDICTING, *predict_data, transform=predict_transform, **ds_kw), @@ -769,7 +778,6 @@ def from_csv( """ ds_kw = dict( - data_pipeline_state=DataPipelineState(), transform_kwargs=transform_kwargs, input_transforms_registry=cls.input_transforms_registry, ) @@ -779,8 +787,11 @@ def from_csv( test_data = (test_file, input_field, target_fields, test_images_root, test_resolver) predict_data = (predict_file, input_field, None, predict_images_root, predict_resolver) + train_input = input_cls(RunningStage.TRAINING, *train_data, transform=train_transform, **ds_kw) + ds_kw["target_formatter"] = getattr(train_input, "target_formatter", None) + return cls( - input_cls(RunningStage.TRAINING, *train_data, transform=train_transform, **ds_kw), + train_input, input_cls(RunningStage.VALIDATING, *val_data, transform=val_transform, **ds_kw), input_cls(RunningStage.TESTING, *test_data, transform=test_transform, **ds_kw), input_cls(RunningStage.PREDICTING, *predict_data, transform=predict_transform, **ds_kw), @@ -890,15 +901,17 @@ def from_fiftyone( """ ds_kw = dict( - data_pipeline_state=DataPipelineState(), transform_kwargs=transform_kwargs, input_transforms_registry=cls.input_transforms_registry, ) + train_input = input_cls( + RunningStage.TRAINING, train_dataset, transform=train_transform, label_field=label_field, **ds_kw + ) + ds_kw["target_formatter"] = getattr(train_input, "target_formatter", None) + return cls( - input_cls( - RunningStage.TRAINING, train_dataset, transform=train_transform, label_field=label_field, **ds_kw - ), + train_input, input_cls(RunningStage.VALIDATING, val_dataset, transform=val_transform, label_field=label_field, **ds_kw), input_cls(RunningStage.TESTING, test_dataset, transform=test_transform, label_field=label_field, **ds_kw), input_cls(RunningStage.PREDICTING, predict_dataset, transform=predict_transform, **ds_kw), @@ -993,13 +1006,15 @@ def from_labelstudio( ) ds_kw = dict( - data_pipeline_state=DataPipelineState(), transform_kwargs=transform_kwargs, input_transforms_registry=cls.input_transforms_registry, ) + train_input = 
input_cls(RunningStage.TRAINING, train_data, transform=train_transform, **ds_kw) + ds_kw["parameters"] = getattr(train_input, "parameters", None) + return cls( - input_cls(RunningStage.TRAINING, train_data, transform=train_transform, **ds_kw), + train_input, input_cls(RunningStage.VALIDATING, val_data, transform=val_transform, **ds_kw), input_cls(RunningStage.TESTING, val_data, transform=test_transform, **ds_kw), input_cls(RunningStage.PREDICTING, predict_data, transform=predict_transform, **ds_kw), @@ -1053,7 +1068,6 @@ def from_datasets( ) """ ds_kw = dict( - data_pipeline_state=DataPipelineState(), transform_kwargs=transform_kwargs, input_transforms_registry=cls.input_transforms_registry, ) diff --git a/flash/image/classification/input.py b/flash/image/classification/input.py index 35ef329fe6..8e2ca2f920 100644 --- a/flash/image/classification/input.py +++ b/flash/image/classification/input.py @@ -16,9 +16,9 @@ import pandas as pd -from flash.core.data.io.classification_input import ClassificationInputMixin, ClassificationState +from flash.core.data.io.classification_input import ClassificationInputMixin from flash.core.data.io.input import DataKeys -from flash.core.data.utilities.classification import MultiBinaryTargetFormatter +from flash.core.data.utilities.classification import MultiBinaryTargetFormatter, TargetFormatter from flash.core.data.utilities.data_frame import read_csv, resolve_files, resolve_targets from flash.core.data.utilities.paths import filter_valid_files, make_dataset, PATH_TYPE from flash.core.data.utilities.samples import to_samples @@ -35,11 +35,16 @@ class ImageClassificationFilesInput(ClassificationInputMixin, ImageFilesInput): - def load_data(self, files: List[PATH_TYPE], targets: Optional[List[Any]] = None) -> List[Dict[str, Any]]: + def load_data( + self, + files: List[PATH_TYPE], + targets: Optional[List[Any]] = None, + target_formatter: Optional[TargetFormatter] = None, + ) -> List[Dict[str, Any]]: if targets is None: return super().load_data(files) files, targets = filter_valid_files(files, targets, valid_extensions=IMG_EXTENSIONS + NP_EXTENSIONS) - self.load_target_metadata(targets) + self.load_target_metadata(targets, target_formatter=target_formatter) return to_samples(files, targets) def load_sample(self, sample: Dict[str, Any]) -> Dict[str, Any]: @@ -50,14 +55,19 @@ def load_sample(self, sample: Dict[str, Any]) -> Dict[str, Any]: class ImageClassificationFolderInput(ImageClassificationFilesInput): - def load_data(self, folder: PATH_TYPE) -> List[Dict[str, Any]]: + def load_data(self, folder: PATH_TYPE, target_formatter: Optional[TargetFormatter] = None) -> List[Dict[str, Any]]: files, targets = make_dataset(folder, extensions=IMG_EXTENSIONS + NP_EXTENSIONS) - return super().load_data(files, targets) + return super().load_data(files, targets, target_formatter=target_formatter) class ImageClassificationFiftyOneInput(ImageClassificationFilesInput): @requires("fiftyone") - def load_data(self, sample_collection: SampleCollection, label_field: str = "ground_truth") -> List[Dict[str, Any]]: + def load_data( + self, + sample_collection: SampleCollection, + label_field: str = "ground_truth", + target_formatter: Optional[TargetFormatter] = None, + ) -> List[Dict[str, Any]]: label_utilities = FiftyOneLabelUtilities(label_field, fol.Label) label_utilities.validate(sample_collection) @@ -66,17 +76,21 @@ def load_data(self, sample_collection: SampleCollection, label_field: str = "gro filepaths = sample_collection.values("filepath") targets = 
sample_collection.values(label_path) - return super().load_data(filepaths, targets) + return super().load_data(filepaths, targets, target_formatter=target_formatter) @requires("fiftyone") - def predict_load_data(self, data: SampleCollection) -> List[Dict[str, Any]]: - return super().load_data(data.values("filepath")) + def predict_load_data( + self, data: SampleCollection, target_formatter: Optional[TargetFormatter] = None + ) -> List[Dict[str, Any]]: + return super().load_data(data.values("filepath"), target_formatter=target_formatter) class ImageClassificationTensorInput(ClassificationInputMixin, ImageTensorInput): - def load_data(self, tensor: Any, targets: Optional[List[Any]] = None) -> List[Dict[str, Any]]: + def load_data( + self, tensor: Any, targets: Optional[List[Any]] = None, target_formatter: Optional[TargetFormatter] = None + ) -> List[Dict[str, Any]]: if targets is not None: - self.load_target_metadata(targets) + self.load_target_metadata(targets, target_formatter=target_formatter) return to_samples(tensor, targets) def load_sample(self, sample: Dict[str, Any]) -> Dict[str, Any]: @@ -87,9 +101,11 @@ def load_sample(self, sample: Dict[str, Any]) -> Dict[str, Any]: class ImageClassificationNumpyInput(ClassificationInputMixin, ImageNumpyInput): - def load_data(self, array: Any, targets: Optional[List[Any]] = None) -> List[Dict[str, Any]]: + def load_data( + self, array: Any, targets: Optional[List[Any]] = None, target_formatter: Optional[TargetFormatter] = None + ) -> List[Dict[str, Any]]: if targets is not None: - self.load_target_metadata(targets) + self.load_target_metadata(targets, target_formatter=target_formatter) return to_samples(array, targets) def load_sample(self, sample: Dict[str, Any]) -> Dict[str, Any]: @@ -107,13 +123,14 @@ def load_data( target_keys: Optional[Union[str, List[str]]] = None, root: Optional[PATH_TYPE] = None, resolver: Optional[Callable[[Optional[PATH_TYPE], Any], PATH_TYPE]] = None, + target_formatter: Optional[TargetFormatter] = None, ) -> List[Dict[str, Any]]: files = resolve_files(data_frame, input_key, root, resolver) if target_keys is not None: targets = resolve_targets(data_frame, target_keys) else: targets = None - result = super().load_data(files, targets) + result = super().load_data(files, targets, target_formatter=target_formatter) # If we had binary multi-class targets then we also know the labels (column names) if ( @@ -121,8 +138,7 @@ def load_data( and isinstance(self.target_formatter, MultiBinaryTargetFormatter) and isinstance(target_keys, List) ): - classification_state = self.get_state(ClassificationState) - self.set_state(ClassificationState(target_keys, classification_state.num_classes)) + self.labels = target_keys return result @@ -135,8 +151,9 @@ def load_data( target_keys: Optional[Union[str, List[str]]] = None, root: Optional[PATH_TYPE] = None, resolver: Optional[Callable[[Optional[PATH_TYPE], Any], PATH_TYPE]] = None, + target_formatter: Optional[TargetFormatter] = None, ) -> List[Dict[str, Any]]: data_frame = read_csv(csv_file) if root is None: root = os.path.dirname(csv_file) - return super().load_data(data_frame, input_key, target_keys, root, resolver) + return super().load_data(data_frame, input_key, target_keys, root, resolver, target_formatter=target_formatter) diff --git a/flash/image/classification/integrations/baal/data.py b/flash/image/classification/integrations/baal/data.py index 2145cf62c9..12524bb170 100644 --- a/flash/image/classification/integrations/baal/data.py +++ 
b/flash/image/classification/integrations/baal/data.py @@ -20,9 +20,9 @@ from torch.utils.data import DataLoader, Dataset, random_split from flash.core.data.data_module import DataModule -from flash.core.data.data_pipeline import DataPipeline from flash.core.data.io.input import InputBase from flash.core.utilities.imports import _BAAL_AVAILABLE, requires +from flash.core.utilities.stages import RunningStage if _BAAL_AVAILABLE: from baal.active.dataset import ActiveLearningDataset @@ -132,10 +132,6 @@ def has_unlabelled_data(self) -> bool: def num_classes(self) -> Optional[int]: return getattr(self.labelled, "num_classes", None) or getattr(self.unlabelled, "num_classes", None) - @property - def data_pipeline(self) -> "DataPipeline": - return self.labelled.data_pipeline - def train_dataloader(self) -> "DataLoader": if self.val_split: self.labelled._train_input = train_val_split(self._dataset, self.val_split)[0] @@ -150,7 +146,9 @@ def train_dataloader(self) -> "DataLoader": def _val_dataloader(self) -> "DataLoader": self.labelled._val_input = train_val_split(self._dataset, self.val_split)[1] self.labelled._val_dataloader_collate_fn = self.labelled._train_dataloader_collate_fn - self.labelled._val_on_after_batch_transfer_fn = self.labelled._train_on_after_batch_transfer_fn + self.labelled._on_after_batch_transfer_fns[ + RunningStage.VALIDATING + ] = self.labelled._on_after_batch_transfer_fns[RunningStage.TRAINING] return self.labelled._val_dataloader() def _test_dataloader(self) -> "DataLoader": @@ -159,7 +157,9 @@ def _test_dataloader(self) -> "DataLoader": def predict_dataloader(self) -> "DataLoader": self.labelled._predict_input = self.filter_unlabelled_data(self._dataset.pool) self.labelled._predict_dataloader_collate_fn = self.labelled._train_dataloader_collate_fn - self.labelled._predict_on_after_batch_transfer_fn = self.labelled._train_on_after_batch_transfer_fn + self.labelled._on_after_batch_transfer_fns[ + RunningStage.PREDICTING + ] = self.labelled._on_after_batch_transfer_fns[RunningStage.TRAINING] return self.labelled._predict_dataloader() def label(self, probabilities: List[torch.Tensor] = None, indices=None): diff --git a/flash/image/classification/model.py b/flash/image/classification/model.py index a8d1c134e5..ea89d5eb9e 100644 --- a/flash/image/classification/model.py +++ b/flash/image/classification/model.py @@ -17,8 +17,9 @@ from pytorch_lightning.utilities.exceptions import MisconfigurationException from torch import nn -from flash.core.classification import ClassificationAdapterTask, LabelsOutput +from flash.core.classification import ClassificationAdapterTask from flash.core.data.io.input import ServeInput +from flash.core.data.io.output import Output from flash.core.registry import FlashRegistry from flash.core.serve import Composition from flash.core.utilities.imports import requires @@ -28,7 +29,6 @@ LR_SCHEDULER_TYPE, METRICS_TYPE, OPTIMIZER_TYPE, - OUTPUT_TYPE, ) from flash.image.classification.adapters import TRAINING_STRATEGIES from flash.image.classification.backbones import IMAGE_CLASSIFIER_BACKBONES @@ -75,7 +75,6 @@ def fn_resnet(pretrained: bool = True): `metric(preds,target)` and return a single scalar tensor. Defaults to :class:`torchmetrics.Accuracy`. learning_rate: Learning rate to use for training, defaults to ``1e-3``. multi_label: Whether the targets are multi-label or not. - output: The :class:`~flash.core.data.io.output.Output` to use when formatting prediction outputs. training_strategy: string indicating the training strategy. 
Adjust if you want to use `learn2learn` for doing meta-learning research training_strategy_kwargs: Additional kwargs for setting the training strategy @@ -89,6 +88,7 @@ def fn_resnet(pretrained: bool = True): def __init__( self, num_classes: Optional[int] = None, + labels: Optional[List[str]] = None, backbone: Union[str, Tuple[nn.Module, int]] = "resnet18", backbone_kwargs: Optional[Dict] = None, head: Union[str, FunctionType, nn.Module] = "linear", @@ -99,13 +99,14 @@ def __init__( metrics: METRICS_TYPE = None, learning_rate: float = 1e-3, multi_label: bool = False, - output: OUTPUT_TYPE = None, training_strategy: Optional[str] = "default", training_strategy_kwargs: Optional[Dict[str, Any]] = None, ): - self.save_hyperparameters() + if labels is not None and num_classes is None: + num_classes = len(labels) + if not backbone_kwargs: backbone_kwargs = {} @@ -113,8 +114,8 @@ def __init__( training_strategy_kwargs = {} if training_strategy == "default": - if not num_classes: - raise MisconfigurationException("`num_classes` should be provided.") + if num_classes is None and labels is None: + raise MisconfigurationException("`num_classes` or `labels` should be provided.") else: num_classes = training_strategy_kwargs.get("ways", None) if not num_classes: @@ -151,7 +152,7 @@ def __init__( optimizer=optimizer, lr_scheduler=lr_scheduler, multi_label=multi_label, - output=output or LabelsOutput(multi_label=multi_label), + labels=labels, ) @classmethod @@ -173,8 +174,9 @@ def serve( input_cls: Optional[Type[ServeInput]] = ImageDeserializer, transform: INPUT_TRANSFORM_TYPE = ImageClassificationInputTransform, transform_kwargs: Optional[Dict] = None, + output: Optional[Union[str, Output]] = None, ) -> Composition: - return super().serve(host, port, sanity_check, input_cls, transform, transform_kwargs) + return super().serve(host, port, sanity_check, input_cls, transform, transform_kwargs, output) def _ci_benchmark_fn(self, history: List[Dict[str, Any]]): """This function is used only for debugging usage with CI.""" diff --git a/flash/image/detection/data.py b/flash/image/detection/data.py index 968fe891d3..14472c1493 100644 --- a/flash/image/detection/data.py +++ b/flash/image/detection/data.py @@ -15,7 +15,6 @@ from typing import Any, Callable, Dict, List, Optional, Sequence, Type, Union from flash.core.data.data_module import DataModule -from flash.core.data.data_pipeline import DataPipelineState from flash.core.data.io.input import Input from flash.core.data.utilities.sort import sorted_alphanumeric from flash.core.integrations.icevision.data import IceVisionInput @@ -138,7 +137,7 @@ def from_files( 3 >>> datamodule.labels ['background', 'cat', 'dog'] - >>> model = ObjectDetector(num_classes=datamodule.num_classes) + >>> model = ObjectDetector(labels=datamodule.labels) >>> trainer = Trainer(fast_dev_run=True) >>> trainer.fit(model, datamodule=datamodule) # doctest: +ELLIPSIS +NORMALIZE_WHITESPACE Training... 
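The classification and detection tasks in this change (``GraphClassifier``, ``ImageClassifier``, ``ObjectDetector``) all accept either ``labels`` or ``num_classes`` and derive one from the other. A minimal sketch of the shared resolution rule, factored into a free function for illustration (the tasks inline it in ``__init__``):

    from typing import List, Optional


    def resolve_num_classes(num_classes: Optional[int], labels: Optional[List[str]]) -> int:
        """Derive num_classes from labels when only labels are provided."""
        if labels is not None and num_classes is None:
            num_classes = len(labels)
        if num_classes is None:
            raise ValueError("`num_classes` or `labels` should be provided.")
        return num_classes


    # e.g. the doctest above: labels ['background', 'cat', 'dog'] -> 3 classes
    assert resolve_num_classes(None, ["background", "cat", "dog"]) == 3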
@@ -152,17 +151,20 @@ def from_files( >>> _ = [os.remove(f"predict_image_{i}.png") for i in range(1, 4)] """ - ds_kw = dict(data_pipeline_state=DataPipelineState(), transform_kwargs=transform_kwargs) + ds_kw = dict(transform_kwargs=transform_kwargs) + + train_input = input_cls( + RunningStage.TRAINING, + train_files, + train_targets, + train_bboxes, + transform=train_transform, + **ds_kw, + ) + ds_kw["target_formatter"] = getattr(train_input, "target_formatter", None) return cls( - input_cls( - RunningStage.TRAINING, - train_files, - train_targets, - train_bboxes, - transform=train_transform, - **ds_kw, - ), + train_input, input_cls( RunningStage.VALIDATING, val_files, @@ -206,7 +208,7 @@ def from_icedata( **data_module_kwargs, ) -> "ObjectDetectionData": - ds_kw = dict(parser=parser, data_pipeline_state=DataPipelineState(), transform_kwargs=transform_kwargs) + ds_kw = dict(parser=parser, transform_kwargs=transform_kwargs) return cls( input_cls( @@ -832,7 +834,7 @@ def from_fiftyone( >>> _ = [os.remove(f"predict_image_{i}.png") for i in range(1, 4)] """ - ds_kw = dict(data_pipeline_state=DataPipelineState(), transform_kwargs=transform_kwargs) + ds_kw = dict(transform_kwargs=transform_kwargs) return cls( input_cls(RunningStage.TRAINING, train_dataset, label_field, iscrowd, transform=train_transform, **ds_kw), diff --git a/flash/image/detection/input.py b/flash/image/detection/input.py index 9ff1a6c527..7aabf07a4c 100644 --- a/flash/image/detection/input.py +++ b/flash/image/detection/input.py @@ -13,8 +13,9 @@ # limitations under the License. from typing import Any, Dict, Hashable, List, Optional, Sequence -from flash.core.data.io.classification_input import ClassificationInputMixin, ClassificationState +from flash.core.data.io.classification_input import ClassificationInputMixin from flash.core.data.io.input import DataKeys +from flash.core.data.utilities.classification import TargetFormatter from flash.core.data.utilities.paths import filter_valid_files, PATH_TYPE from flash.core.integrations.fiftyone.utils import FiftyOneLabelUtilities from flash.core.integrations.icevision.data import IceVisionInput @@ -43,13 +44,16 @@ def load_data( files: List[PATH_TYPE], targets: Optional[List[List[Any]]] = None, bboxes: Optional[List[List[Dict[str, int]]]] = None, + target_formatter: Optional[TargetFormatter] = None, ) -> List[Dict[str, Any]]: if targets is None: return super().load_data(files) files, targets, bboxes = filter_valid_files( files, targets, bboxes, valid_extensions=IMG_EXTENSIONS + NP_EXTENSIONS ) - self.load_target_metadata([t for target in targets for t in target], add_background=True) + self.load_target_metadata( + [t for target in targets for t in target], add_background=True, target_formatter=target_formatter + ) return [ {DataKeys.INPUT: file, DataKeys.TARGET: {"bboxes": bbox, "labels": label}} @@ -142,7 +146,6 @@ def load_data( class_map = ClassMap(classes) self.num_classes = len(class_map) self.labels = [class_map.get_by_id(i) for i in range(self.num_classes)] - self.set_state(ClassificationState(self.labels)) parser = FiftyOneParser(sample_collection, class_map, label_field, iscrowd) records = parser.parse(data_splitter=SingleSplitSplitter()) diff --git a/flash/image/detection/model.py b/flash/image/detection/model.py index f6b009891b..576ff0c3cd 100644 --- a/flash/image/detection/model.py +++ b/flash/image/detection/model.py @@ -14,10 +14,11 @@ from typing import Any, Dict, List, Optional from flash.core.adapter import AdapterTask -from flash.core.data.output import 
PredsOutput +from flash.core.model import Task from flash.core.registry import FlashRegistry -from flash.core.utilities.types import LR_SCHEDULER_TYPE, OPTIMIZER_TYPE, OUTPUT_TYPE +from flash.core.utilities.types import LR_SCHEDULER_TYPE, OPTIMIZER_TYPE from flash.image.detection.backbones import OBJECT_DETECTION_HEADS +from flash.image.detection.output import OBJECT_DETECTION_OUTPUTS class ObjectDetector(AdapterTask): @@ -38,24 +39,31 @@ class ObjectDetector(AdapterTask): """ heads: FlashRegistry = OBJECT_DETECTION_HEADS + outputs = Task.outputs + OBJECT_DETECTION_OUTPUTS required_extras: List[str] = ["image", "icevision", "effdet"] def __init__( self, - num_classes: int, + num_classes: Optional[int] = None, + labels: Optional[List[str]] = None, backbone: Optional[str] = "resnet18_fpn", head: Optional[str] = "retinanet", pretrained: bool = True, optimizer: OPTIMIZER_TYPE = "Adam", lr_scheduler: LR_SCHEDULER_TYPE = None, learning_rate: float = 1e-2, - output: OUTPUT_TYPE = None, predict_kwargs: Dict = None, **kwargs: Any, ): self.save_hyperparameters() + if labels is not None and num_classes is None: + num_classes = len(labels) + + self.labels = labels + self.num_classes = num_classes + predict_kwargs = predict_kwargs if predict_kwargs else {} metadata = self.heads.get(head, with_metadata=True) adapter = metadata["metadata"]["adapter"].from_task( @@ -73,7 +81,6 @@ def __init__( learning_rate=learning_rate, optimizer=optimizer, lr_scheduler=lr_scheduler, - output=output or PredsOutput(), ) def _ci_benchmark_fn(self, history: List[Dict[str, Any]]) -> None: diff --git a/flash/image/detection/output.py b/flash/image/detection/output.py index 240674be25..acb7e8b5b6 100644 --- a/flash/image/detection/output.py +++ b/flash/image/detection/output.py @@ -13,12 +13,12 @@ # limitations under the License. from typing import Any, Dict, List, Optional, Union -from pytorch_lightning.utilities import rank_zero_warn - -from flash.core.data.io.classification_input import ClassificationState from flash.core.data.io.input import DataKeys from flash.core.data.io.output import Output +from flash.core.model import Task +from flash.core.registry import FlashRegistry from flash.core.utilities.imports import _FIFTYONE_AVAILABLE, lazy_import, requires +from flash.core.utilities.providers import _FIFTYONE if _FIFTYONE_AVAILABLE: fo = lazy_import("fiftyone") @@ -28,12 +28,15 @@ Detections = None +OBJECT_DETECTION_OUTPUTS = FlashRegistry("outputs") + + +@OBJECT_DETECTION_OUTPUTS(name="fiftyone", providers=_FIFTYONE) class FiftyOneDetectionLabelsOutput(Output): """A :class:`.Output` which converts model outputs to FiftyOne detection format. Args: - labels: A list of labels, assumed to map the class index to the label for that class. If ``labels`` is not - provided, will attempt to get them from the :class:`.ClassificationState`. + labels: A list of labels, assumed to map the class index to the label for that class. threshold: a score threshold to apply to candidate detections. 
return_filepath: Boolean determining whether to return a dict containing filepath and FiftyOne labels (True) or only a @@ -45,32 +48,21 @@ def __init__( self, labels: Optional[List[str]] = None, threshold: Optional[float] = None, - return_filepath: bool = False, + return_filepath: bool = True, ): super().__init__() self._labels = labels self.threshold = threshold self.return_filepath = return_filepath - if labels is not None: - self.set_state(ClassificationState(labels)) + @classmethod + def from_task(cls, task: Task, **kwargs) -> Output: + return cls(labels=getattr(task, "labels", None)) def transform(self, sample: Dict[str, Any]) -> Union[Detections, Dict[str, Any]]: if DataKeys.METADATA not in sample: raise ValueError("sample requires DefaultDataKeys.METADATA to use a FiftyOneDetectionLabelsOutput output.") - labels = None - if self._labels is not None: - labels = self._labels - else: - state = self.get_state(ClassificationState) - if state is not None: - labels = state.labels - else: - rank_zero_warn( - "No ClassificationState was found, int targets will be used as label strings", UserWarning - ) - height, width = sample[DataKeys.METADATA]["size"] detections = [] @@ -92,8 +84,8 @@ def transform(self, sample: Dict[str, Any]) -> Union[Detections, Dict[str, Any]] ] label = label.item() - if labels is not None: - label = labels[label] + if self._labels is not None: + label = self._labels[label] else: label = str(int(label)) diff --git a/flash/image/embedding/model.py b/flash/image/embedding/model.py index 54c6510d9c..b0ad82b0d5 100644 --- a/flash/image/embedding/model.py +++ b/flash/image/embedding/model.py @@ -16,13 +16,6 @@ from flash.core.adapter import AdapterTask from flash.core.data.io.input import DataKeys -from flash.core.data.states import ( - CollateFn, - PerBatchTransform, - PerBatchTransformOnDevice, - PerSampleTransform, - PerSampleTransformOnDevice, -) from flash.core.data.transforms import ApplyToKeys from flash.core.registry import FlashRegistry from flash.core.utilities.imports import _VISSL_AVAILABLE, requires @@ -117,20 +110,8 @@ def __init__( learning_rate=learning_rate, ) - transform, collate_fn = self.transforms.get(pretraining_transform)(**pretraining_transform_kwargs) - - self.adapter.set_state(CollateFn(collate_fn)) - self.adapter.set_state( - PerSampleTransform( - ApplyToKeys( - DataKeys.INPUT, - transform, - ) - ) - ) - self.adapter.set_state(PerSampleTransformOnDevice(None)) - self.adapter.set_state(PerBatchTransform(None)) - self.adapter.set_state(PerBatchTransformOnDevice(None)) + input_transform, self.collate_fn = self.transforms.get(pretraining_transform)(**pretraining_transform_kwargs) + self.input_transform = ApplyToKeys(DataKeys.INPUT, input_transform) warnings.warn( "Warning: VISSL ImageEmbedder overrides any user provided transforms" diff --git a/flash/image/face_detection/data.py b/flash/image/face_detection/data.py index dcf9b4702d..422e2e3028 100644 --- a/flash/image/face_detection/data.py +++ b/flash/image/face_detection/data.py @@ -16,19 +16,16 @@ from torch.utils.data import Dataset from flash.core.data.data_module import DataModule -from flash.core.data.data_pipeline import DataPipelineState from flash.core.data.io.input import Input from flash.core.utilities.stages import RunningStage from flash.core.utilities.types import INPUT_TRANSFORM_TYPE from flash.image.classification.data import ImageClassificationFilesInput, ImageClassificationFolderInput from flash.image.face_detection.input import FaceDetectionInput from 
flash.image.face_detection.input_transform import FaceDetectionInputTransform -from flash.image.face_detection.output_transform import FaceDetectionOutputTransform class FaceDetectionData(DataModule): input_transform_cls = FaceDetectionInputTransform - output_transform_cls = FaceDetectionOutputTransform @classmethod def from_datasets( @@ -46,14 +43,13 @@ def from_datasets( **data_module_kwargs, ) -> "FaceDetectionData": - ds_kw = dict(data_pipeline_state=DataPipelineState(), transform_kwargs=transform_kwargs) + ds_kw = dict(transform_kwargs=transform_kwargs) return cls( input_cls(RunningStage.TRAINING, train_dataset, transform=train_transform, **ds_kw), input_cls(RunningStage.VALIDATING, val_dataset, transform=val_transform, **ds_kw), input_cls(RunningStage.TESTING, test_dataset, transform=test_transform, **ds_kw), input_cls(RunningStage.PREDICTING, predict_dataset, transform=predict_transform, **ds_kw), - output_transform=cls.output_transform_cls(), **data_module_kwargs, ) @@ -71,7 +67,6 @@ def from_files( return cls( predict_input=input_cls(RunningStage.PREDICTING, predict_files, **ds_kw), - output_transform=cls.output_transform_cls(), **data_module_kwargs, ) @@ -89,6 +84,5 @@ def from_folders( return cls( predict_input=input_cls(RunningStage.PREDICTING, predict_folder, **ds_kw), - output_transform=cls.output_transform_cls(), **data_module_kwargs, ) diff --git a/flash/image/face_detection/model.py b/flash/image/face_detection/model.py index 12295a2f61..9be3f51dcc 100644 --- a/flash/image/face_detection/model.py +++ b/flash/image/face_detection/model.py @@ -11,36 +11,22 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. -from typing import Any, Dict, Iterable, List, Union +from typing import Any, Iterable, List, Union import torch from torch.nn import Module from flash.core.data.io.input import DataKeys -from flash.core.data.io.output import Output from flash.core.model import Task from flash.core.utilities.imports import _FASTFACE_AVAILABLE -from flash.core.utilities.types import ( - INPUT_TRANSFORM_TYPE, - LOSS_FN_TYPE, - LR_SCHEDULER_TYPE, - METRICS_TYPE, - OPTIMIZER_TYPE, - OUTPUT_TYPE, -) +from flash.core.utilities.types import LOSS_FN_TYPE, LR_SCHEDULER_TYPE, METRICS_TYPE, OPTIMIZER_TYPE from flash.image.face_detection.backbones import FACE_DETECTION_BACKBONES +from flash.image.face_detection.output_transform import FaceDetectionOutputTransform if _FASTFACE_AVAILABLE: import fastface as ff -class DetectionLabelsOutput(Output): - """A :class:`.Output` which extracts predictions from sample dict.""" - - def transform(self, sample: Any) -> Dict[str, Any]: - return sample[DataKeys.PREDS] if isinstance(sample, Dict) else sample - - class FaceDetector(Task): """The ``FaceDetector`` is a :class:`~flash.Task` for detecting faces in images. @@ -55,7 +41,6 @@ class FaceDetector(Task): optimizer: Optimizer to use for training. lr_scheduler: The LR scheduler to use during training. learning_rate: The learning rate to use for training. - output: The :class:`~flash.core.data.io.output.Output` to use when formatting prediction outputs. 
kwargs: additional kwargs necessary for initializing face detector backbone """ @@ -70,8 +55,6 @@ def __init__( optimizer: OPTIMIZER_TYPE = "Adam", lr_scheduler: LR_SCHEDULER_TYPE = None, learning_rate: float = 1e-4, - output: OUTPUT_TYPE = DetectionLabelsOutput(), - input_transform: INPUT_TRANSFORM_TYPE = None, **kwargs: Any, ): self.save_hyperparameters() @@ -88,8 +71,7 @@ def __init__( learning_rate=learning_rate, optimizer=optimizer, lr_scheduler=lr_scheduler, - output=output, - input_transform=input_transform, + output_transform=FaceDetectionOutputTransform, ) @staticmethod diff --git a/flash/image/instance_segmentation/data.py b/flash/image/instance_segmentation/data.py index 3eb4e29d42..2d855ba107 100644 --- a/flash/image/instance_segmentation/data.py +++ b/flash/image/instance_segmentation/data.py @@ -15,7 +15,6 @@ from typing import Any, Callable, Dict, List, Optional, Type, Union from flash.core.data.data_module import DataModule -from flash.core.data.data_pipeline import DataPipelineState from flash.core.data.io.input import DataKeys, Input from flash.core.data.io.output_transform import OutputTransform from flash.core.data.utilities.sort import sorted_alphanumeric @@ -73,7 +72,7 @@ def from_icedata( **data_module_kwargs, ) -> "InstanceSegmentationData": - ds_kw = dict(parser=parser, data_pipeline_state=DataPipelineState(), transform_kwargs=transform_kwargs) + ds_kw = dict(parser=parser, transform_kwargs=transform_kwargs) return cls( input_cls( diff --git a/flash/image/instance_segmentation/model.py b/flash/image/instance_segmentation/model.py index d47119301b..082055cbe8 100644 --- a/flash/image/instance_segmentation/model.py +++ b/flash/image/instance_segmentation/model.py @@ -13,13 +13,9 @@ # limitations under the License. from typing import Any, Dict, List, Optional -from pytorch_lightning.utilities import rank_zero_info - from flash.core.adapter import AdapterTask -from flash.core.data.data_pipeline import DataPipeline -from flash.core.data.output import PredsOutput from flash.core.registry import FlashRegistry -from flash.core.utilities.types import LR_SCHEDULER_TYPE, OPTIMIZER_TYPE, OUTPUT_TRANSFORM_TYPE, OUTPUT_TYPE +from flash.core.utilities.types import LR_SCHEDULER_TYPE, OPTIMIZER_TYPE, OUTPUT_TRANSFORM_TYPE from flash.image.instance_segmentation.backbones import INSTANCE_SEGMENTATION_HEADS from flash.image.instance_segmentation.data import InstanceSegmentationOutputTransform @@ -36,7 +32,6 @@ class InstanceSegmentation(AdapterTask): optimizer: Optimizer to use for training. lr_scheduler: The LR scheduler to use during training. learning_rate: The learning rate to use for training. - output: The :class:`~flash.core.data.io.output.Output` to use when formatting prediction outputs. predict_kwargs: dictionary containing parameters that will be used during the prediction phase.
**kwargs: additional kwargs used for initializing the task """ @@ -55,7 +50,6 @@ def __init__( lr_scheduler: LR_SCHEDULER_TYPE = None, learning_rate: float = 5e-4, output_transform: OUTPUT_TRANSFORM_TYPE = InstanceSegmentationOutputTransform(), - output: OUTPUT_TYPE = PredsOutput(), predict_kwargs: Dict = None, **kwargs: Any, ): @@ -79,27 +73,12 @@ def __init__( optimizer=optimizer, lr_scheduler=lr_scheduler, output_transform=output_transform, - output=output, ) def _ci_benchmark_fn(self, history: List[Dict[str, Any]]) -> None: """This function is used only for debugging usage with CI.""" # todo - def on_load_checkpoint(self, checkpoint: Dict[str, Any]) -> None: - super().on_load_checkpoint(checkpoint) - # todo: currently the data pipeline for icevision is not serializable, so we re-create the pipeline. - if "data_pipeline" not in checkpoint: - rank_zero_info( - "Assigned Segmentation Data Pipeline for data processing. This is because a data-pipeline stored in " - "the model due to pickling issues. " - "If you'd like to change this, extend the InstanceSegmentation Task and override `on_load_checkpoint`." - ) - self.data_pipeline = DataPipeline( - input_transform=None, - output_transform=InstanceSegmentationOutputTransform(), - ) - @property def predict_kwargs(self) -> Dict[str, Any]: """The kwargs used for the prediction step.""" diff --git a/flash/image/keypoint_detection/data.py b/flash/image/keypoint_detection/data.py index ec2a79d61c..870e8e839e 100644 --- a/flash/image/keypoint_detection/data.py +++ b/flash/image/keypoint_detection/data.py @@ -15,7 +15,6 @@ from typing import Any, Callable, Dict, List, Optional, Type, Union from flash.core.data.data_module import DataModule -from flash.core.data.data_pipeline import DataPipelineState from flash.core.data.io.input import Input from flash.core.integrations.icevision.data import IceVisionInput from flash.core.utilities.imports import _ICEVISION_AVAILABLE @@ -90,7 +89,7 @@ def from_icedata( **data_module_kwargs: Any, ) -> "KeypointDetectionData": - ds_kw = dict(parser=parser, data_pipeline_state=DataPipelineState(), transform_kwargs=transform_kwargs) + ds_kw = dict(parser=parser, transform_kwargs=transform_kwargs) return cls( input_cls( diff --git a/flash/image/keypoint_detection/model.py b/flash/image/keypoint_detection/model.py index 74353a27f2..6f5a475191 100644 --- a/flash/image/keypoint_detection/model.py +++ b/flash/image/keypoint_detection/model.py @@ -14,9 +14,8 @@ from typing import Any, Dict, List, Optional from flash.core.adapter import AdapterTask -from flash.core.data.output import PredsOutput from flash.core.registry import FlashRegistry -from flash.core.utilities.types import LR_SCHEDULER_TYPE, OPTIMIZER_TYPE, OUTPUT_TYPE +from flash.core.utilities.types import LR_SCHEDULER_TYPE, OPTIMIZER_TYPE from flash.image.keypoint_detection.backbones import KEYPOINT_DETECTION_HEADS @@ -33,7 +32,6 @@ class KeypointDetector(AdapterTask): optimizer: Optimizer to use for training. lr_scheduler: The LR scheduler to use during training. learning_rate: The learning rate to use for training. - output: The :class:`~flash.core.data.io.output.Output` to use when formatting prediction outputs. predict_kwargs: dictionary containing parameters that will be used during the prediction phase. 
**kwargs: additional kwargs used for initializing the task """ @@ -52,7 +50,6 @@ def __init__( optimizer: OPTIMIZER_TYPE = "Adam", lr_scheduler: LR_SCHEDULER_TYPE = None, learning_rate: float = 5e-4, - output: OUTPUT_TYPE = None, predict_kwargs: Dict = None, **kwargs: Any, ): @@ -76,7 +73,6 @@ def __init__( learning_rate=learning_rate, optimizer=optimizer, lr_scheduler=lr_scheduler, - output=output or PredsOutput(), ) def _ci_benchmark_fn(self, history: List[Dict[str, Any]]) -> None: diff --git a/flash/image/segmentation/data.py b/flash/image/segmentation/data.py index 8ad97f1b4b..e6ae4bf9a4 100644 --- a/flash/image/segmentation/data.py +++ b/flash/image/segmentation/data.py @@ -18,7 +18,6 @@ from flash.core.data.callback import BaseDataFetcher from flash.core.data.data_module import DataModule -from flash.core.data.data_pipeline import DataPipelineState from flash.core.data.io.input import Input from flash.core.registry import FlashRegistry from flash.core.utilities.imports import _FIFTYONE_AVAILABLE, _IMAGE_EXTRAS_TESTING, _IMAGE_TESTING, lazy_import @@ -160,7 +159,6 @@ def from_files( """ ds_kw = dict( - data_pipeline_state=DataPipelineState(), transform_kwargs=transform_kwargs, input_transforms_registry=cls.input_transforms_registry, num_classes=num_classes, @@ -309,7 +307,6 @@ def from_folders( """ ds_kw = dict( - data_pipeline_state=DataPipelineState(), transform_kwargs=transform_kwargs, input_transforms_registry=cls.input_transforms_registry, num_classes=num_classes, @@ -405,7 +402,6 @@ def from_numpy( """ ds_kw = dict( - data_pipeline_state=DataPipelineState(), transform_kwargs=transform_kwargs, input_transforms_registry=cls.input_transforms_registry, num_classes=num_classes, @@ -501,7 +497,6 @@ def from_tensors( """ ds_kw = dict( - data_pipeline_state=DataPipelineState(), transform_kwargs=transform_kwargs, input_transforms_registry=cls.input_transforms_registry, num_classes=num_classes, @@ -620,7 +615,6 @@ def from_fiftyone( """ ds_kw = dict( - data_pipeline_state=DataPipelineState(), transform_kwargs=transform_kwargs, input_transforms_registry=cls.input_transforms_registry, ) diff --git a/flash/image/segmentation/input.py b/flash/image/segmentation/input.py index cc36137a0d..1a160fea2e 100644 --- a/flash/image/segmentation/input.py +++ b/flash/image/segmentation/input.py @@ -16,7 +16,7 @@ import torch -from flash.core.data.io.input import DataKeys, ImageLabelsMap, Input +from flash.core.data.io.input import DataKeys, Input from flash.core.data.utilities.paths import filter_valid_files, PATH_TYPE from flash.core.data.utilities.samples import to_samples from flash.core.integrations.fiftyone.utils import FiftyOneLabelUtilities @@ -44,7 +44,6 @@ def load_labels_map( labels_map = labels_map or SegmentationLabelsOutput.create_random_labels_map(num_classes) if labels_map is not None: - self.set_state(ImageLabelsMap(labels_map)) self.labels_map = labels_map def load_sample(self, sample: Dict[str, Any]) -> Dict[str, Any]: diff --git a/flash/image/segmentation/model.py b/flash/image/segmentation/model.py index e1f3981160..80721b30b7 100644 --- a/flash/image/segmentation/model.py +++ b/flash/image/segmentation/model.py @@ -19,7 +19,9 @@ from flash.core.classification import ClassificationTask from flash.core.data.io.input import DataKeys, ServeInput +from flash.core.data.io.output import Output from flash.core.data.io.output_transform import OutputTransform +from flash.core.model import Task from flash.core.registry import FlashRegistry from flash.core.serve import Composition from 
flash.core.utilities.imports import _KORNIA_AVAILABLE, _TM_GREATER_EQUAL_0_7_0, requires @@ -31,13 +33,12 @@ METRICS_TYPE, OPTIMIZER_TYPE, OUTPUT_TRANSFORM_TYPE, - OUTPUT_TYPE, ) from flash.image.segmentation.backbones import SEMANTIC_SEGMENTATION_BACKBONES from flash.image.segmentation.heads import SEMANTIC_SEGMENTATION_HEADS from flash.image.segmentation.input import SemanticSegmentationDeserializer from flash.image.segmentation.input_transform import SemanticSegmentationInputTransform -from flash.image.segmentation.output import SegmentationLabelsOutput +from flash.image.segmentation.output import SEMANTIC_SEGMENTATION_OUTPUTS if _KORNIA_AVAILABLE: import kornia as K @@ -83,8 +84,8 @@ class SemanticSegmentation(ClassificationTask): output_transform_cls = SemanticSegmentationOutputTransform backbones: FlashRegistry = SEMANTIC_SEGMENTATION_BACKBONES - heads: FlashRegistry = SEMANTIC_SEGMENTATION_HEADS + outputs: FlashRegistry = Task.outputs + SEMANTIC_SEGMENTATION_OUTPUTS required_extras: str = "image" @@ -102,7 +103,6 @@ def __init__( metrics: METRICS_TYPE = None, learning_rate: float = 1e-3, multi_label: bool = False, - output: OUTPUT_TYPE = None, output_transform: OUTPUT_TRANSFORM_TYPE = None, ) -> None: if metrics is None: @@ -122,7 +122,6 @@ def __init__( lr_scheduler=lr_scheduler, metrics=metrics, learning_rate=learning_rate, - output=output or SegmentationLabelsOutput(), output_transform=output_transform or self.output_transform_cls(), ) @@ -191,8 +190,9 @@ def serve( input_cls: Optional[Type[ServeInput]] = SemanticSegmentationDeserializer, transform: INPUT_TRANSFORM_TYPE = SemanticSegmentationInputTransform, transform_kwargs: Optional[Dict] = None, + output: Optional[Union[str, Output]] = None, ) -> Composition: - return super().serve(host, port, sanity_check, input_cls, transform, transform_kwargs) + return super().serve(host, port, sanity_check, input_cls, transform, transform_kwargs, output) @staticmethod def _ci_benchmark_fn(history: List[Dict[str, Any]]): diff --git a/flash/image/segmentation/output.py b/flash/image/segmentation/output.py index e69c923792..412cf943c8 100644 --- a/flash/image/segmentation/output.py +++ b/flash/image/segmentation/output.py @@ -18,8 +18,9 @@ import torch import flash -from flash.core.data.io.input import DataKeys, ImageLabelsMap +from flash.core.data.io.input import DataKeys from flash.core.data.io.output import Output +from flash.core.registry import FlashRegistry from flash.core.utilities.imports import ( _FIFTYONE_AVAILABLE, _KORNIA_AVAILABLE, @@ -27,6 +28,7 @@ lazy_import, requires, ) +from flash.core.utilities.providers import _FIFTYONE if _FIFTYONE_AVAILABLE: fol = lazy_import("fiftyone.core.labels") @@ -46,13 +48,17 @@ K = None +SEMANTIC_SEGMENTATION_OUTPUTS = FlashRegistry("outputs") + + +@SEMANTIC_SEGMENTATION_OUTPUTS(name="labels") class SegmentationLabelsOutput(Output): """A :class:`.Output` which converts the model outputs to the label of the argmax classification per pixel in the image for semantic segmentation tasks. Args: labels_map: A dictionary that map the labels ids to pixel intensities. - visualize: Wether to visualize the image labels. + visualize: Whether to visualize the image labels. 
""" @requires("image") @@ -83,8 +89,6 @@ def create_random_labels_map(num_classes: int) -> Dict[int, Tuple[int, int, int] @requires("matplotlib") def _visualize(self, labels): - if self.labels_map is None: - self.labels_map = self.get_state(ImageLabelsMap).labels_map labels_vis = self.labels_to_image(labels, self.labels_map) labels_vis = K.utils.tensor_to_image(labels_vis) plt.imshow(labels_vis) @@ -100,6 +104,7 @@ def transform(self, sample: Dict[str, torch.Tensor]) -> torch.Tensor: return labels.tolist() +@SEMANTIC_SEGMENTATION_OUTPUTS(name="fiftyone", providers=_FIFTYONE) class FiftyOneSegmentationLabelsOutput(SegmentationLabelsOutput): """A :class:`.Output` which converts the model outputs to FiftyOne segmentation format. @@ -116,7 +121,7 @@ def __init__( self, labels_map: Optional[Dict[int, Tuple[int, int, int]]] = None, visualize: bool = False, - return_filepath: bool = False, + return_filepath: bool = True, ): super().__init__(labels_map=labels_map, visualize=visualize) diff --git a/flash/image/style_transfer/data.py b/flash/image/style_transfer/data.py index 90753a27fd..9400f39e99 100644 --- a/flash/image/style_transfer/data.py +++ b/flash/image/style_transfer/data.py @@ -17,7 +17,6 @@ import torch from flash.core.data.data_module import DataModule -from flash.core.data.data_pipeline import DataPipelineState from flash.core.data.io.input import Input from flash.core.utilities.imports import _IMAGE_TESTING from flash.core.utilities.stages import RunningStage @@ -104,7 +103,6 @@ def from_files( """ ds_kw = dict( - data_pipeline_state=DataPipelineState(), transform_kwargs=transform_kwargs, input_transforms_registry=cls.input_transforms_registry, ) @@ -195,7 +193,6 @@ def from_folders( """ ds_kw = dict( - data_pipeline_state=DataPipelineState(), transform_kwargs=transform_kwargs, input_transforms_registry=cls.input_transforms_registry, ) @@ -260,7 +257,6 @@ def from_numpy( """ ds_kw = dict( - data_pipeline_state=DataPipelineState(), transform_kwargs=transform_kwargs, input_transforms_registry=cls.input_transforms_registry, ) @@ -325,7 +321,6 @@ def from_tensors( """ ds_kw = dict( - data_pipeline_state=DataPipelineState(), transform_kwargs=transform_kwargs, input_transforms_registry=cls.input_transforms_registry, ) diff --git a/flash/image/style_transfer/model.py b/flash/image/style_transfer/model.py index cebc3aec93..732d559a2d 100644 --- a/flash/image/style_transfer/model.py +++ b/flash/image/style_transfer/model.py @@ -20,7 +20,7 @@ from flash.core.model import Task from flash.core.registry import FlashRegistry from flash.core.utilities.imports import _IMAGE_AVAILABLE -from flash.core.utilities.types import LR_SCHEDULER_TYPE, OPTIMIZER_TYPE, OUTPUT_TYPE +from flash.core.utilities.types import LR_SCHEDULER_TYPE, OPTIMIZER_TYPE from flash.image.style_transfer import STYLE_TRANSFER_BACKBONES if _IMAGE_AVAILABLE: @@ -61,7 +61,6 @@ class StyleTransfer(Task): optimizer: Optimizer to use for training. lr_scheduler: The LR scheduler to use during training. learning_rate: Learning rate to use for training, defaults to ``1e-3``. - output: The :class:`~flash.core.data.io.output.Output` to use when serializing prediction outputs. 
""" backbones: FlashRegistry = STYLE_TRANSFER_BACKBONES @@ -80,7 +79,6 @@ def __init__( optimizer: OPTIMIZER_TYPE = "Adam", lr_scheduler: LR_SCHEDULER_TYPE = None, learning_rate: float = 1e-3, - output: OUTPUT_TYPE = None, ): self.save_hyperparameters(ignore="style_image") @@ -110,7 +108,6 @@ def __init__( optimizer=optimizer, lr_scheduler=lr_scheduler, learning_rate=learning_rate, - output=output, ) self.perceptual_loss = perceptual_loss diff --git a/flash/pointcloud/detection/data.py b/flash/pointcloud/detection/data.py index b91317aa61..3a87369f97 100644 --- a/flash/pointcloud/detection/data.py +++ b/flash/pointcloud/detection/data.py @@ -16,7 +16,6 @@ from torch.utils.data import Dataset from flash.core.data.data_module import DataModule -from flash.core.data.data_pipeline import DataPipelineState from flash.core.data.io.input import BaseDataFormat, Input from flash.core.data.io.input_transform import InputTransform from flash.core.utilities.stages import RunningStage @@ -57,7 +56,6 @@ def from_folders( labels_folder_name=labels_folder_name, calibrations_folder_name=calibrations_folder_name, data_format=data_format, - data_pipeline_state=DataPipelineState(), transform_kwargs=transform_kwargs, input_transforms_registry=cls.input_transforms_registry, ) @@ -89,7 +87,6 @@ def from_files( labels_folder_name=labels_folder_name, calibrations_folder_name=calibrations_folder_name, data_format=data_format, - data_pipeline_state=DataPipelineState(), transform_kwargs=transform_kwargs, input_transforms_registry=cls.input_transforms_registry, ) @@ -116,7 +113,6 @@ def from_datasets( ) -> "PointCloudObjectDetectorData": ds_kw = dict( - data_pipeline_state=DataPipelineState(), transform_kwargs=transform_kwargs, input_transforms_registry=cls.input_transforms_registry, ) diff --git a/flash/pointcloud/detection/model.py b/flash/pointcloud/detection/model.py index a1621ff1d9..c796c9b7ac 100644 --- a/flash/pointcloud/detection/model.py +++ b/flash/pointcloud/detection/model.py @@ -19,21 +19,15 @@ from torch.utils.data import DataLoader, Sampler from flash.core.data.io.input import DataKeys, Input -from flash.core.data.io.output import Output -from flash.core.data.states import CollateFn from flash.core.model import Task from flash.core.registry import FlashRegistry from flash.core.utilities.apply_func import get_callable_dict -from flash.core.utilities.types import LOSS_FN_TYPE, LR_SCHEDULER_TYPE, METRICS_TYPE, OPTIMIZER_TYPE, OUTPUT_TYPE +from flash.core.utilities.types import LOSS_FN_TYPE, LR_SCHEDULER_TYPE, METRICS_TYPE, OPTIMIZER_TYPE from flash.pointcloud.detection.backbones import POINTCLOUD_OBJECT_DETECTION_BACKBONES __FILE_EXAMPLE__ = "pointcloud_detection" -class PointCloudObjectDetectorOutput(Output): - pass - - class PointCloudObjectDetector(Task): """The ``PointCloudObjectDetector`` is a :class:`~flash.core.classification.ClassificationTask` that classifies pointcloud data. @@ -49,7 +43,6 @@ class PointCloudObjectDetector(Task): metrics: Any metrics to use with this :class:`~flash.core.model.Task`. If ``None``, a default will be selected by the :class:`~flash.core.classification.ClassificationTask` depending on the ``multi_label`` argument. learning_rate: The learning rate for the optimizer. - output: The :class:`~flash.core.data.io.output.Output` to use when formatting prediction outputs. lambda_loss_cls: The value to scale the loss classification. lambda_loss_bbox: The value to scale the bounding boxes loss. lambda_loss_dir: The value to scale the bounding boxes direction loss. 
@@ -68,7 +61,6 @@ def __init__( lr_scheduler: LR_SCHEDULER_TYPE = None, metrics: METRICS_TYPE = None, learning_rate: float = 1e-2, - output: OUTPUT_TYPE = PointCloudObjectDetectorOutput(), lambda_loss_cls: float = 1.0, lambda_loss_bbox: float = 1.0, lambda_loss_dir: float = 1.0, @@ -81,7 +73,6 @@ def __init__( lr_scheduler=lr_scheduler, metrics=metrics, learning_rate=learning_rate, - output=output, ) self.save_hyperparameters() @@ -92,12 +83,9 @@ def __init__( if isinstance(backbone, tuple): self.backbone, out_features = backbone else: - self.model, out_features, collate_fn = self.backbones.get(backbone)(**backbone_kwargs) + self.model, out_features, self.collate_fn = self.backbones.get(backbone)(**backbone_kwargs) self.backbone = self.model.backbone self.neck = self.model.neck - self.set_state(CollateFn(collate_fn)) - self.set_state(CollateFn(collate_fn)) - self.set_state(CollateFn(collate_fn)) self.loss_fn = get_callable_dict(self.model.loss) if __FILE_EXAMPLE__ not in sys.argv[0]: diff --git a/flash/pointcloud/segmentation/data.py b/flash/pointcloud/segmentation/data.py index 8b15a2e6dc..3afb8aa000 100644 --- a/flash/pointcloud/segmentation/data.py +++ b/flash/pointcloud/segmentation/data.py @@ -16,7 +16,6 @@ from torch.utils.data import Dataset from flash.core.data.data_module import DataModule -from flash.core.data.data_pipeline import DataPipelineState from flash.core.data.io.input import Input from flash.core.data.io.input_transform import InputTransform from flash.core.utilities.stages import RunningStage @@ -45,7 +44,6 @@ def from_folders( ) -> "PointCloudSegmentationData": ds_kw = dict( - data_pipeline_state=DataPipelineState(), transform_kwargs=transform_kwargs, input_transforms_registry=cls.input_transforms_registry, ) @@ -69,7 +67,6 @@ def from_files( ) -> "PointCloudSegmentationData": ds_kw = dict( - data_pipeline_state=DataPipelineState(), transform_kwargs=transform_kwargs, input_transforms_registry=cls.input_transforms_registry, ) @@ -96,7 +93,6 @@ def from_datasets( ) -> "PointCloudSegmentationData": ds_kw = dict( - data_pipeline_state=DataPipelineState(), transform_kwargs=transform_kwargs, input_transforms_registry=cls.input_transforms_registry, ) diff --git a/flash/pointcloud/segmentation/model.py b/flash/pointcloud/segmentation/model.py index eb9525c072..f551a8dbe5 100644 --- a/flash/pointcloud/segmentation/model.py +++ b/flash/pointcloud/segmentation/model.py @@ -20,11 +20,9 @@ from flash.core.classification import ClassificationTask from flash.core.data.io.input import DataKeys, Input -from flash.core.data.io.output import Output -from flash.core.data.states import CollateFn from flash.core.registry import FlashRegistry from flash.core.utilities.imports import _POINTCLOUD_AVAILABLE, _TM_GREATER_EQUAL_0_7_0 -from flash.core.utilities.types import LOSS_FN_TYPE, LR_SCHEDULER_TYPE, METRICS_TYPE, OPTIMIZER_TYPE, OUTPUT_TYPE +from flash.core.utilities.types import LOSS_FN_TYPE, LR_SCHEDULER_TYPE, METRICS_TYPE, OPTIMIZER_TYPE from flash.pointcloud.segmentation.backbones import POINTCLOUD_SEGMENTATION_BACKBONES if _POINTCLOUD_AVAILABLE: @@ -37,10 +35,6 @@ from torchmetrics import IoU as JaccardIndex -class PointCloudSegmentationOutput(Output): - pass - - class PointCloudSegmentation(ClassificationTask): """The ``PointCloudSegmentation`` is a :class:`~flash.core.classification.ClassificationTask` that segments pointcloud data.
@@ -59,7 +53,6 @@ class PointCloudSegmentation(ClassificationTask): by the :class:`~flash.core.classification.ClassificationTask` depending on the ``multi_label`` argument. learning_rate: The learning rate for the optimizer. multi_label: If ``True``, this will be treated as a multi-label classification problem. - output: The :class:`~flash.core.data.io.output.Output` to use when formatting prediction outputs. """ backbones: FlashRegistry = POINTCLOUD_SEGMENTATION_BACKBONES @@ -78,7 +71,6 @@ def __init__( metrics: METRICS_TYPE = None, learning_rate: float = 1e-2, multi_label: bool = False, - output: OUTPUT_TYPE = PointCloudSegmentationOutput(), ): import flash @@ -93,7 +85,6 @@ def __init__( metrics=metrics, learning_rate=learning_rate, multi_label=multi_label, - output=output, ) self.save_hyperparameters() @@ -104,11 +95,10 @@ def __init__( if isinstance(backbone, tuple): self.backbone, out_features = backbone else: - self.backbone, out_features, collate_fn = self.backbones.get(backbone)(**backbone_kwargs) + self.backbone, out_features, self.collate_fn = self.backbones.get(backbone)(**backbone_kwargs) # replace latest layer if not flash._IS_TESTING: self.backbone.fc = nn.Identity() - self.set_state(CollateFn(collate_fn)) self.head = nn.Identity() if flash._IS_TESTING else (head or nn.Linear(out_features, num_classes)) diff --git a/flash/tabular/classification/cli.py b/flash/tabular/classification/cli.py index dd29d299ca..186a21738f 100644 --- a/flash/tabular/classification/cli.py +++ b/flash/tabular/classification/cli.py @@ -50,11 +50,12 @@ def tabular_classification(): }, finetune=False, datamodule_attributes={ + "parameters", "embedding_sizes", - "categorical_fields", - "num_features", "cat_dims", + "num_features", "num_classes", + "labels", }, ) diff --git a/flash/tabular/classification/data.py b/flash/tabular/classification/data.py index 2980654a25..518b45f35b 100644 --- a/flash/tabular/classification/data.py +++ b/flash/tabular/classification/data.py @@ -13,7 +13,6 @@ # limitations under the License. 
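The updated ``datamodule_attributes`` above track the new ``TabularClassifier`` signature (see the ``flash/tabular/classification/model.py`` diff below): the task now receives the ``parameters`` computed from the training data instead of a bare ``categorical_fields`` list. A hedged sketch, assuming a ``TabularClassificationData`` instance named ``datamodule``:

from flash.tabular import TabularClassifier

# The `parameters` computed on the training split now travel with the model:
model = TabularClassifier(
    parameters=datamodule.parameters,
    embedding_sizes=datamodule.embedding_sizes,
    cat_dims=datamodule.cat_dims,
    num_features=datamodule.num_features,
    num_classes=datamodule.num_classes,
    labels=datamodule.labels,
)

# or let the classmethod used by the CLI wire these attributes for you:
model = TabularClassifier.from_data(datamodule)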
from typing import Any, Dict, List, Optional, Type, Union -from flash.core.data.data_pipeline import DataPipelineState from flash.core.data.io.input import Input from flash.core.data.io.input_transform import INPUT_TRANSFORM_TYPE, InputTransform from flash.core.utilities.imports import _PANDAS_AVAILABLE, _TABULAR_TESTING @@ -157,7 +156,6 @@ def from_data_frame( >>> del predict_data """ ds_kw = dict( - data_pipeline_state=DataPipelineState(), transform_kwargs=transform_kwargs, input_transforms_registry=cls.input_transforms_registry, categorical_fields=categorical_fields, @@ -167,8 +165,8 @@ def from_data_frame( ) train_input = input_cls(RunningStage.TRAINING, train_data_frame, transform=train_transform, **ds_kw) - ds_kw["parameters"] = train_input.parameters if train_input else parameters + ds_kw["target_formatter"] = getattr(train_input, "target_formatter", None) return cls( train_input, @@ -299,7 +297,6 @@ def from_csv( >>> os.remove("predict_data.csv") """ ds_kw = dict( - data_pipeline_state=DataPipelineState(), transform_kwargs=transform_kwargs, input_transforms_registry=cls.input_transforms_registry, categorical_fields=categorical_fields, @@ -309,8 +306,8 @@ def from_csv( ) train_input = input_cls(RunningStage.TRAINING, train_file, transform=train_transform, **ds_kw) - ds_kw["parameters"] = train_input.parameters if train_input else parameters + ds_kw["target_formatter"] = getattr(train_input, "target_formatter", None) return cls( train_input, diff --git a/flash/tabular/classification/input.py b/flash/tabular/classification/input.py index 3484b45920..a73490b366 100644 --- a/flash/tabular/classification/input.py +++ b/flash/tabular/classification/input.py @@ -15,6 +15,7 @@ from flash.core.data.io.classification_input import ClassificationInputMixin from flash.core.data.io.input import DataKeys +from flash.core.data.utilities.classification import TargetFormatter from flash.core.data.utilities.data_frame import read_csv, resolve_targets from flash.core.utilities.imports import _PANDAS_AVAILABLE from flash.tabular.input import TabularDataFrameInput @@ -33,12 +34,13 @@ def load_data( numerical_fields: Optional[Union[str, List[str]]] = None, target_fields: Optional[Union[str, List[str]]] = None, parameters: Dict[str, Any] = None, + target_formatter: Optional[TargetFormatter] = None, ): cat_vars, num_vars = self.preprocess(data_frame, categorical_fields, numerical_fields, parameters) if not self.predicting: targets = resolve_targets(data_frame, target_fields) - self.load_target_metadata(targets) + self.load_target_metadata(targets, target_formatter=target_formatter) return [{DataKeys.INPUT: (c, n), DataKeys.TARGET: t} for c, n, t in zip(cat_vars, num_vars, targets)] else: return [{DataKeys.INPUT: (c, n)} for c, n in zip(cat_vars, num_vars)] @@ -57,6 +59,9 @@ def load_data( numerical_fields: Optional[Union[str, List[str]]] = None, target_fields: Optional[Union[str, List[str]]] = None, parameters: Dict[str, Any] = None, + target_formatter: Optional[TargetFormatter] = None, ): if file is not None: - return super().load_data(read_csv(file), categorical_fields, numerical_fields, target_fields, parameters) + return super().load_data( + read_csv(file), categorical_fields, numerical_fields, target_fields, parameters, target_formatter + ) diff --git a/flash/tabular/classification/model.py b/flash/tabular/classification/model.py index 5a2015744c..0d1c0a2c77 100644 --- a/flash/tabular/classification/model.py +++ b/flash/tabular/classification/model.py @@ -12,13 +12,14 @@ # See the License for the 
specific language governing permissions and # limitations under the License. from functools import partial -from typing import Any, Callable, Dict, List, Optional, Type +from typing import Any, Callable, Dict, List, Optional, Type, Union from torch.nn import functional as F from flash.core.classification import ClassificationAdapterTask from flash.core.data.io.input import ServeInput from flash.core.data.io.input_transform import InputTransform +from flash.core.data.io.output import Output from flash.core.integrations.pytorch_tabular.backbones import PYTORCH_TABULAR_BACKBONES from flash.core.registry import FlashRegistry from flash.core.serve import Composition @@ -32,8 +33,8 @@ class TabularClassifier(ClassificationAdapterTask): :ref:`tabular_classification`. Args: - embedding_sizes: Number of columns in table (not including target column). - categorical_fields: Number of classes to classify. + parameters: The parameters computed from the training data (can be obtained from the ``parameters`` attribute of + the ``TabularClassificationData`` object containing your training data). embedding_sizes: List of (num_classes, emb_dim) to form categorical embeddings. cat_dims: Number of distinct values for each categorical column num_features: Number of columns in table @@ -55,11 +56,12 @@ class TabularClassifier(ClassificationAdapterTask): def __init__( self, + parameters: Dict[str, Any], embedding_sizes: list, - categorical_fields: list, cat_dims: list, num_features: int, num_classes: int, + labels: Optional[List[str]] = None, backbone: str = "tabnet", loss_fn: Callable = F.cross_entropy, optimizer: OPTIMIZER_TYPE = "Adam", @@ -69,12 +71,15 @@ def __init__( **backbone_kwargs, ): self.save_hyperparameters() + + self._parameters = parameters + metadata = self.backbones.get(backbone, with_metadata=True) adapter = metadata["metadata"]["adapter"].from_task( self, task_type="classification", embedding_sizes=embedding_sizes, - categorical_fields=categorical_fields, + categorical_fields=parameters["categorical_fields"], cat_dims=cat_dims, num_features=num_features, output_dim=num_classes, @@ -88,6 +93,7 @@ def __init__( optimizer=optimizer, lr_scheduler=lr_scheduler, learning_rate=learning_rate, + labels=labels, ) @staticmethod @@ -98,8 +104,8 @@ def _ci_benchmark_fn(history: List[Dict[str, Any]]): @classmethod def from_data(cls, datamodule, **kwargs) -> "TabularClassifier": model = cls( + parameters=datamodule.parameters, embedding_sizes=datamodule.embedding_sizes, - categorical_fields=datamodule.categorical_fields, cat_dims=datamodule.cat_dims, num_features=datamodule.num_features, num_classes=datamodule.num_classes, @@ -116,8 +122,10 @@ def serve( input_cls: Optional[Type[ServeInput]] = TabularDeserializer, transform: INPUT_TRANSFORM_TYPE = InputTransform, transform_kwargs: Optional[Dict] = None, + output: Optional[Union[str, Output]] = None, parameters: Optional[Dict[str, Any]] = None, ) -> Composition: + parameters = parameters or self._parameters return super().serve( - host, port, sanity_check, partial(input_cls, parameters=parameters), transform, transform_kwargs + host, port, sanity_check, partial(input_cls, parameters=parameters), transform, transform_kwargs, output ) diff --git a/flash/tabular/classification/utils.py b/flash/tabular/classification/utils.py index 6ca4ba281e..578b34036c 100644 --- a/flash/tabular/classification/utils.py +++ b/flash/tabular/classification/utils.py @@ -19,6 +19,7 @@ if _PANDAS_AVAILABLE: import pandas as pd + from pandas import Series from pandas.core.frame import 
DataFrame else: DataFrame = None @@ -31,10 +32,12 @@ def _impute(df: DataFrame, num_cols: List) -> DataFrame: def _compute_normalization(df: DataFrame, num_cols: List) -> Tuple: - return df[num_cols].mean(), df[num_cols].std() + return df[num_cols].mean().to_dict(), df[num_cols].std().to_dict() -def _normalize(df: DataFrame, num_cols: List, mean: DataFrame, std: DataFrame) -> DataFrame: +def _normalize(df: DataFrame, num_cols: List, mean: Dict, std: Dict) -> DataFrame: + mean = Series(mean) + std = Series(std) df[num_cols] = (df[num_cols] - mean) / std return df @@ -66,8 +69,8 @@ def _pre_transform( num_cols: List[str], cat_cols: List[str], codes: Dict, - mean: DataFrame, - std: DataFrame, + mean: Dict, + std: Dict, ) -> DataFrame: df = _impute(df, num_cols) df = _normalize(df, num_cols, mean=mean, std=std) diff --git a/flash/tabular/forecasting/data.py b/flash/tabular/forecasting/data.py index 91b065549e..f5fc8417bb 100644 --- a/flash/tabular/forecasting/data.py +++ b/flash/tabular/forecasting/data.py @@ -17,10 +17,8 @@ from flash.core.data.callback import BaseDataFetcher from flash.core.data.data_module import DataModule -from flash.core.data.data_pipeline import DataPipelineState from flash.core.data.io.input import Input from flash.core.data.io.input_transform import INPUT_TRANSFORM_TYPE, InputTransform -from flash.core.data.io.output_transform import OutputTransform from flash.core.utilities.imports import _PANDAS_AVAILABLE, _TABULAR_TESTING from flash.core.utilities.stages import RunningStage from flash.tabular.forecasting.input import TabularForecastingDataFrameInput @@ -72,7 +70,6 @@ def from_data_frame( sampler: Optional[Type[Sampler]] = None, pin_memory: bool = True, persistent_workers: bool = True, - output_transform: Optional[OutputTransform] = None, **input_kwargs: Any, ) -> "TabularForecastingData": """Creates a :class:`~flash.tabular.forecasting.data.TabularForecastingData` object from the given data @@ -169,7 +166,6 @@ def from_data_frame( """ ds_kw = dict( - data_pipeline_state=DataPipelineState(), transform_kwargs=transform_kwargs, input_transforms_registry=cls.input_transforms_registry, time_idx=time_idx, @@ -180,7 +176,6 @@ def from_data_frame( ) train_input = input_cls(RunningStage.TRAINING, train_data_frame, transform=train_transform, **ds_kw) - ds_kw["parameters"] = train_input.parameters if train_input else parameters return cls( @@ -195,5 +190,4 @@ def from_data_frame( sampler=sampler, pin_memory=pin_memory, persistent_workers=persistent_workers, - output_transform=output_transform, ) diff --git a/flash/tabular/forecasting/input.py b/flash/tabular/forecasting/input.py index 33f54a962b..7d16497e79 100644 --- a/flash/tabular/forecasting/input.py +++ b/flash/tabular/forecasting/input.py @@ -12,13 +12,11 @@ # See the License for the specific language governing permissions and # limitations under the License. 
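As the ``TabularForecastingDataFrameInput`` below enforces, evaluation and inference data can only be built from the ``parameters`` fitted on the training data, which are now passed explicitly rather than carried by shared pipeline state. A sketch with illustrative DataFrames and column names (``train_df``, ``predict_df``, and the ``time_idx``/``value``/``series`` columns are assumptions):

from flash.tabular.forecasting import TabularForecastingData

datamodule = TabularForecastingData.from_data_frame(
    time_idx="time_idx",
    target="value",
    group_ids=["series"],
    train_data_frame=train_df,
    batch_size=4,
)

# Inference-only data reuses the parameters fitted on the training data:
predict_datamodule = TabularForecastingData.from_data_frame(
    predict_data_frame=predict_df,
    parameters=datamodule.parameters,
    batch_size=4,
)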
from copy import copy -from dataclasses import dataclass from typing import Any, Dict, List, Optional, Tuple, Union from pytorch_lightning.utilities.exceptions import MisconfigurationException from flash.core.data.io.input import DataKeys, Input -from flash.core.data.properties import ProcessState from flash.core.utilities.imports import _FORECASTING_AVAILABLE, _PANDAS_AVAILABLE, requires if _PANDAS_AVAILABLE: @@ -30,14 +28,6 @@ from pytorch_forecasting import TimeSeriesDataSet -@dataclass(unsafe_hash=True, frozen=True) -class TimeSeriesDataSetParametersState(ProcessState): - """A :class:`~flash.core.data.properties.ProcessState` containing ``labels``, a mapping from class index to - label.""" - - parameters: Optional[Dict[str, Any]] - - class TabularForecastingDataFrameInput(Input): @requires("tabular") def load_data( @@ -58,18 +48,15 @@ def load_data( # Add some sample data so that we can recreate the `TimeSeriesDataSet` later on parameters["data_sample"] = data.iloc[[0]].to_dict() - self.set_state(TimeSeriesDataSetParametersState(parameters)) self.parameters = parameters else: - parameters_state = self.get_state(TimeSeriesDataSetParametersState) - parameters = parameters or (parameters_state.parameters if parameters_state is not None else None) - parameters = copy(parameters) if parameters is None: raise MisconfigurationException( "Loading data for evaluation or inference requires parameters from the train data. Either " "construct the train data at the same time as evaluation and inference or provide the train " "`datamodule.parameters` to `from_data_frame` in the `parameters` argument." ) + parameters = copy(parameters) parameters.pop("data_sample") time_series_dataset = TimeSeriesDataSet.from_parameters( parameters, diff --git a/flash/tabular/input.py b/flash/tabular/input.py index 7ca846c9f3..7415be9350 100644 --- a/flash/tabular/input.py +++ b/flash/tabular/input.py @@ -11,7 +11,6 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. -from dataclasses import dataclass from io import StringIO from typing import Any, Dict, List, Optional, Union @@ -20,7 +19,6 @@ from flash.core.data.io.input import DataKeys, Input from flash.core.data.process import Deserializer -from flash.core.data.properties import ProcessState from flash.core.data.utilities.data_frame import read_csv from flash.core.utilities.imports import _PANDAS_AVAILABLE from flash.tabular.classification.utils import ( @@ -37,13 +35,6 @@ DataFrame = object -@dataclass(unsafe_hash=True, frozen=True) -class TabularParametersState(ProcessState): - """A :class:`~flash.core.data.properties.ProcessState` containing tabular data ``parameters``.""" - - parameters: Optional[Dict[str, Any]] - - class TabularDataFrameInput(Input): @staticmethod def _sanetize_fields( @@ -92,17 +83,12 @@ def preprocess( if self.training: categorical_fields, numerical_fields = self._sanetize_fields(categorical_fields, numerical_fields) parameters = self.compute_parameters(df, numerical_fields, categorical_fields) - - self.set_state(TabularParametersState(parameters)) - else: - parameters_state = self.get_state(TabularParametersState) - parameters = parameters or (parameters_state.parameters if parameters_state is not None else None) - if parameters is None: - raise MisconfigurationException( - "Loading tabular data for evaluation or inference requires parameters from the train data. 
Either " - "construct the train data at the same time as evaluation and inference or provide the train " - "`datamodule.parameters` in the `parameters` argument." - ) + elif parameters is None: + raise MisconfigurationException( + "Loading tabular data for evaluation or inference requires parameters from the train data. Either " + "construct the train data at the same time as evaluation and inference or provide the train " + "`datamodule.parameters` in the `parameters` argument." + ) self.parameters = parameters @@ -131,20 +117,8 @@ def __init__(self, *args, parameters: Optional[Dict[str, Any]] = None, **kwargs) self._parameters = parameters super().__init__(*args, **kwargs) - @property - def parameters(self) -> Dict[str, Any]: - if self._parameters is not None: - return self._parameters - parameters_state = self.get_state(TabularParametersState) - if parameters_state is not None and parameters_state.parameters is not None: - return parameters_state.parameters - raise MisconfigurationException( - "Tabular tasks must previously have been trained in order to support serving or the `parameters` argument " - "must be provided to the `serve` method." - ) - def serve_load_sample(self, data: str) -> Any: - parameters = self.parameters + parameters = self._parameters df = read_csv(StringIO(data)) df = _pre_transform( @@ -166,7 +140,7 @@ def serve_load_sample(self, data: str) -> Any: @property def example_input(self) -> str: - parameters = self.parameters + parameters = self._parameters row = {} for cat_col in parameters["categorical_fields"]: diff --git a/flash/tabular/regression/data.py b/flash/tabular/regression/data.py index 61af99afe0..065f8fd16f 100644 --- a/flash/tabular/regression/data.py +++ b/flash/tabular/regression/data.py @@ -13,7 +13,6 @@ # limitations under the License. from typing import Any, Dict, List, Optional, Type, Union -from flash.core.data.data_pipeline import DataPipelineState from flash.core.data.io.input import Input from flash.core.data.io.input_transform import INPUT_TRANSFORM_TYPE, InputTransform from flash.core.utilities.imports import _PANDAS_AVAILABLE, _TABULAR_TESTING @@ -152,7 +151,6 @@ def from_data_frame( >>> del predict_data """ ds_kw = dict( - data_pipeline_state=DataPipelineState(), transform_kwargs=transform_kwargs, input_transforms_registry=cls.input_transforms_registry, categorical_fields=categorical_fields, @@ -162,7 +160,6 @@ def from_data_frame( ) train_input = input_cls(RunningStage.TRAINING, train_data_frame, transform=train_transform, **ds_kw) - ds_kw["parameters"] = train_input.parameters if train_input else parameters return cls( @@ -288,7 +285,6 @@ def from_csv( >>> os.remove("predict_data.csv") """ ds_kw = dict( - data_pipeline_state=DataPipelineState(), transform_kwargs=transform_kwargs, input_transforms_registry=cls.input_transforms_registry, categorical_fields=categorical_fields, @@ -298,7 +294,6 @@ def from_csv( ) train_input = input_cls(RunningStage.TRAINING, train_file, transform=train_transform, **ds_kw) - ds_kw["parameters"] = train_input.parameters if train_input else parameters return cls( diff --git a/flash/tabular/regression/model.py b/flash/tabular/regression/model.py index 07e08aa482..ae66dc30d7 100644 --- a/flash/tabular/regression/model.py +++ b/flash/tabular/regression/model.py @@ -12,12 +12,13 @@ # See the License for the specific language governing permissions and # limitations under the License. 
from functools import partial -from typing import Any, Callable, Dict, Optional, Type +from typing import Any, Callable, Dict, Optional, Type, Union from torch.nn import functional as F from flash.core.data.io.input import ServeInput from flash.core.data.io.input_transform import InputTransform +from flash.core.data.io.output import Output from flash.core.integrations.pytorch_tabular.backbones import PYTORCH_TABULAR_BACKBONES from flash.core.registry import FlashRegistry from flash.core.regression import RegressionAdapterTask @@ -32,8 +33,8 @@ class TabularRegressor(RegressionAdapterTask): :ref:`tabular_classification`. Args: - embedding_sizes: Number of columns in table (not including target column). - categorical_fields: Number of classes to classify. + parameters: The parameters computed from the training data (can be obtained from the ``parameters`` attribute of + the ``TabularRegressionData`` object containing your training data). embedding_sizes: List of (num_classes, emb_dim) to form categorical embeddings. cat_dims: Number of distinct values for each categorical column num_features: Number of columns in table @@ -54,8 +55,8 @@ class TabularRegressor(RegressionAdapterTask): def __init__( self, + parameters: Dict[str, Any], embedding_sizes: list, - categorical_fields: list, cat_dims: list, num_features: int, backbone: str = "tabnet", @@ -67,12 +68,15 @@ def __init__( **backbone_kwargs ): self.save_hyperparameters() + + self._parameters = parameters + metadata = self.backbones.get(backbone, with_metadata=True) adapter = metadata["metadata"]["adapter"].from_task( self, task_type="regression", embedding_sizes=embedding_sizes, - categorical_fields=categorical_fields, + categorical_fields=parameters["categorical_fields"], cat_dims=cat_dims, num_features=num_features, output_dim=1, @@ -91,8 +95,8 @@ def __init__( @classmethod def from_data(cls, datamodule, **kwargs) -> "TabularRegressor": model = cls( + parameters=datamodule.parameters, embedding_sizes=datamodule.embedding_sizes, - categorical_fields=datamodule.categorical_fields, cat_dims=datamodule.cat_dims, num_features=datamodule.num_features, **kwargs @@ -108,8 +112,10 @@ def serve( input_cls: Optional[Type[ServeInput]] = TabularDeserializer, transform: INPUT_TRANSFORM_TYPE = InputTransform, transform_kwargs: Optional[Dict] = None, + output: Optional[Union[str, Output]] = None, parameters: Optional[Dict[str, Any]] = None, ) -> Composition: + parameters = parameters or self._parameters return super().serve( - host, port, sanity_check, partial(input_cls, parameters=parameters), transform, transform_kwargs + host, port, sanity_check, partial(input_cls, parameters=parameters), transform, transform_kwargs, output ) diff --git a/flash/template/classification/data.py b/flash/template/classification/data.py index 809bd753df..f03f3898be 100644 --- a/flash/template/classification/data.py +++ b/flash/template/classification/data.py @@ -19,10 +19,10 @@ from flash.core.data.base_viz import BaseVisualization from flash.core.data.callback import BaseDataFetcher from flash.core.data.data_module import DataModule -from flash.core.data.data_pipeline import DataPipelineState from flash.core.data.io.classification_input import ClassificationInputMixin from flash.core.data.io.input import DataKeys, Input from flash.core.data.io.input_transform import InputTransform +from flash.core.data.utilities.classification import TargetFormatter from flash.core.data.utilities.samples import to_samples from flash.core.utilities.imports import _SKLEARN_AVAILABLE from 
flash.core.utilities.stages import RunningStage @@ -38,13 +38,17 @@ class TemplateNumpyClassificationInput(Input, ClassificationInputMixin): """An example data source that records ``num_features`` on the dataset.""" def load_data( - self, examples: Collection[np.ndarray], targets: Optional[Sequence[Any]] = None + self, + examples: Collection[np.ndarray], + targets: Optional[Sequence[Any]] = None, + target_formatter: Optional[TargetFormatter] = None, ) -> Sequence[Dict[str, Any]]: """Sets the ``num_features`` attribute and calls ``super().load_data``. Args: examples: The ``np.ndarray`` (num_examples x num_features). targets: Associated targets. + target_formatter: Optionally provide a ``TargetFormatter`` to control how targets are formatted. Returns: A sequence of samples / sample metadata. @@ -52,7 +56,7 @@ def load_data( if not self.predicting and isinstance(examples, np.ndarray): self.num_features = examples.shape[1] if targets is not None: - self.load_target_metadata(targets) + self.load_target_metadata(targets, target_formatter=target_formatter) return to_samples(examples, targets) def load_sample(self, sample: Dict[str, Any]) -> Any: @@ -64,16 +68,17 @@ def load_sample(self, sample: Dict[str, Any]) -> Any: class TemplateSKLearnClassificationInput(TemplateNumpyClassificationInput): """An example data source that loads data from an sklearn data ``Bunch``.""" - def load_data(self, data: Bunch) -> Sequence[Dict[str, Any]]: + def load_data(self, data: Bunch, target_formatter: Optional[TargetFormatter] = None) -> Sequence[Dict[str, Any]]: """Gets the ``data`` and ``target`` attributes from the ``Bunch`` and passes them to ``super().load_data``. Args: data: The scikit-learn data ``Bunch``. + target_formatter: Optionally provide a ``TargetFormatter`` to control how targets are formatted. Returns: A sequence of samples / sample metadata. """ - return super().load_data(data.data, data.target) + return super().load_data(data.data, data.target, target_formatter=target_formatter) def predict_load_data(self, data: Bunch) -> Sequence[Dict[str, Any]]: """Avoid including targets when predicting. @@ -160,15 +165,31 @@ def from_numpy( """ ds_kw = dict( - data_pipeline_state=DataPipelineState(), transform_kwargs=transform_kwargs, input_transforms_registry=cls.input_transforms_registry, ) + train_input = input_cls(RunningStage.TRAINING, train_data, train_targets, transform=train_transform, **ds_kw) + target_formatter = getattr(train_input, "target_formatter", None) + return cls( - input_cls(RunningStage.TRAINING, train_data, train_targets, transform=train_transform, **ds_kw), - input_cls(RunningStage.VALIDATING, val_data, val_targets, transform=val_transform, **ds_kw), - input_cls(RunningStage.TESTING, test_data, test_targets, transform=test_transform, **ds_kw), + train_input, + input_cls( + RunningStage.VALIDATING, + val_data, + val_targets, + transform=val_transform, + target_formatter=target_formatter, + **ds_kw, + ), + input_cls( + RunningStage.TESTING, + test_data, + test_targets, + transform=test_transform, + target_formatter=target_formatter, + **ds_kw, + ), input_cls(RunningStage.PREDICTING, predict_data, transform=predict_transform, **ds_kw), **data_module_kwargs, ) @@ -208,17 +229,30 @@ def from_sklearn( Returns: The constructed data module. 
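The ``from_numpy`` restructuring above establishes the pattern used throughout this changeset: build the training input first, then thread its fitted ``target_formatter`` into the validation and test inputs so target encoding stays consistent across splits. A usage sketch; the ``TemplateData`` name follows the template docs and the arrays and targets are illustrative:

import numpy as np

from flash.template.classification.data import TemplateData

# Targets in every split are encoded with the formatter fitted on the training targets.
datamodule = TemplateData.from_numpy(
    train_data=np.random.rand(10, 4),
    train_targets=[0, 1, 2, 1, 0, 2, 1, 0, 2, 1],
    val_data=np.random.rand(4, 4),
    val_targets=[0, 1, 2, 0],
    batch_size=2,
)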
""" - ds_kw = dict( - data_pipeline_state=DataPipelineState(), transform_kwargs=transform_kwargs, input_transforms_registry=cls.input_transforms_registry, ) + train_input = input_cls(RunningStage.TRAINING, train_bunch, transform=train_transform, **ds_kw) + target_formatter = getattr(train_input, "target_formatter", None) + return cls( - input_cls(RunningStage.TRAINING, train_bunch, transform=train_transform, **ds_kw), - input_cls(RunningStage.VALIDATING, val_bunch, transform=val_transform, **ds_kw), - input_cls(RunningStage.TESTING, test_bunch, transform=test_transform, **ds_kw), + train_input, + input_cls( + RunningStage.VALIDATING, + val_bunch, + transform=val_transform, + target_formatter=target_formatter, + **ds_kw, + ), + input_cls( + RunningStage.TESTING, + test_bunch, + transform=test_transform, + target_formatter=target_formatter, + **ds_kw, + ), input_cls(RunningStage.PREDICTING, predict_bunch, transform=predict_transform, **ds_kw), **data_module_kwargs, ) diff --git a/flash/template/classification/model.py b/flash/template/classification/model.py index cc99c840dd..74f07d97c3 100644 --- a/flash/template/classification/model.py +++ b/flash/template/classification/model.py @@ -11,15 +11,15 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. -from typing import Any, Dict, Optional, Tuple, Union +from typing import Any, Dict, List, Optional, Tuple, Union import torch from torch import nn -from flash.core.classification import ClassificationTask, LabelsOutput +from flash.core.classification import ClassificationTask from flash.core.data.io.input import DataKeys from flash.core.registry import FlashRegistry -from flash.core.utilities.types import LOSS_FN_TYPE, LR_SCHEDULER_TYPE, METRICS_TYPE, OPTIMIZER_TYPE, OUTPUT_TYPE +from flash.core.utilities.types import LOSS_FN_TYPE, LR_SCHEDULER_TYPE, METRICS_TYPE, OPTIMIZER_TYPE from flash.template.classification.backbones import TEMPLATE_BACKBONES @@ -40,7 +40,6 @@ class TemplateSKLearnClassifier(ClassificationTask): by the :class:`~flash.core.classification.ClassificationTask` depending on the ``multi_label`` argument. learning_rate: The learning rate for the optimizer. multi_label: If ``True``, this will be treated as a multi-label classification problem. - output: The :class:`~flash.core.data.io.output.Output` to use when formatting prediction outputs. 
""" backbones: FlashRegistry = TEMPLATE_BACKBONES @@ -48,7 +47,8 @@ class TemplateSKLearnClassifier(ClassificationTask): def __init__( self, num_features: int, - num_classes: int, + num_classes: Optional[int] = None, + labels: Optional[List[str]] = None, backbone: Union[str, Tuple[nn.Module, int]] = "mlp-128", backbone_kwargs: Optional[Dict] = None, loss_fn: LOSS_FN_TYPE = None, @@ -57,8 +57,12 @@ def __init__( metrics: METRICS_TYPE = None, learning_rate: float = 1e-2, multi_label: bool = False, - output: OUTPUT_TYPE = LabelsOutput(), ): + self.save_hyperparameters() + + if labels is not None and num_classes is None: + num_classes = len(labels) + super().__init__( model=None, loss_fn=loss_fn, @@ -67,11 +71,10 @@ def __init__( metrics=metrics, learning_rate=learning_rate, multi_label=multi_label, - output=output, + num_classes=num_classes, + labels=labels, ) - self.save_hyperparameters() - if not backbone_kwargs: backbone_kwargs = {} diff --git a/flash/text/classification/cli.py b/flash/text/classification/cli.py index 511f195e97..f00ed1fa9d 100644 --- a/flash/text/classification/cli.py +++ b/flash/text/classification/cli.py @@ -62,7 +62,7 @@ def text_classification(): default_arguments={ "trainer.max_epochs": 3, }, - datamodule_attributes={"num_classes", "multi_label"}, + datamodule_attributes={"num_classes", "labels", "multi_label"}, ) cli.trainer.save_checkpoint("text_classification_model.pt") diff --git a/flash/core/integrations/transformers/input_transform.py b/flash/text/classification/collate.py similarity index 54% rename from flash/core/integrations/transformers/input_transform.py rename to flash/text/classification/collate.py index e7b590e938..bb4ec6be17 100644 --- a/flash/core/integrations/transformers/input_transform.py +++ b/flash/text/classification/collate.py @@ -12,23 +12,21 @@ # See the License for the specific language governing permissions and # limitations under the License. 
from dataclasses import dataclass -from typing import Any, Callable, Dict - -import torch from flash.core.data.io.input import DataKeys -from flash.core.data.io.input_transform import InputTransform +from flash.core.integrations.transformers.collate import TransformersCollate + +@dataclass(unsafe_hash=True) +class TextClassificationCollate(TransformersCollate): -@dataclass -class TransformersInputTransform(InputTransform): - @staticmethod - def to_tensor(sample: Dict[str, Any]) -> Dict[str, Any]: - for key in sample: - if key is DataKeys.METADATA: - continue - sample[key] = torch.as_tensor(sample[key]) - return sample + max_length: int = 128 - def per_sample_transform(self) -> Callable: - return self.to_tensor + def tokenize(self, sample): + tokenized_sample = self.tokenizer( + sample[DataKeys.INPUT], max_length=self.max_length, truncation=True, padding="max_length" + ) + tokenized_sample = tokenized_sample.data + if DataKeys.TARGET in sample: + tokenized_sample[DataKeys.TARGET] = sample[DataKeys.TARGET] + return tokenized_sample diff --git a/flash/text/classification/data.py b/flash/text/classification/data.py index a917a4eaef..ec842edf83 100644 --- a/flash/text/classification/data.py +++ b/flash/text/classification/data.py @@ -16,11 +16,10 @@ from pandas.core.frame import DataFrame from flash.core.data.data_module import DataModule -from flash.core.data.data_pipeline import DataPipelineState from flash.core.data.io.input import Input +from flash.core.data.io.input_transform import InputTransform from flash.core.data.utilities.paths import PATH_TYPE from flash.core.integrations.labelstudio.input import _parse_labelstudio_arguments, LabelStudioTextClassificationInput -from flash.core.integrations.transformers.input_transform import TransformersInputTransform from flash.core.utilities.imports import _TEXT_AVAILABLE, _TEXT_TESTING from flash.core.utilities.stages import RunningStage from flash.text.classification.input import ( @@ -46,7 +45,7 @@ class TextClassificationData(DataModule): """The ``TextClassificationData`` class is a :class:`~flash.core.data.data_module.DataModule` with a set of classmethods for loading data for text classification.""" - input_transform_cls = TransformersInputTransform + input_transform_cls = InputTransform @classmethod def from_csv( @@ -57,13 +56,12 @@ def from_csv( val_file: Optional[PATH_TYPE] = None, test_file: Optional[PATH_TYPE] = None, predict_file: Optional[PATH_TYPE] = None, - train_transform: Optional[Dict[str, Callable]] = TransformersInputTransform, - val_transform: Optional[Dict[str, Callable]] = TransformersInputTransform, - test_transform: Optional[Dict[str, Callable]] = TransformersInputTransform, - predict_transform: Optional[Dict[str, Callable]] = TransformersInputTransform, + train_transform: Optional[Dict[str, Callable]] = InputTransform, + val_transform: Optional[Dict[str, Callable]] = InputTransform, + test_transform: Optional[Dict[str, Callable]] = InputTransform, + predict_transform: Optional[Dict[str, Callable]] = InputTransform, input_cls: Type[Input] = TextClassificationCSVInput, transform_kwargs: Optional[Dict] = None, - max_length: int = 128, **data_module_kwargs: Any, ) -> "TextClassificationData": """Load the :class:`~flash.text.classification.data.TextClassificationData` from CSV files containing text @@ -89,7 +87,6 @@ def from_csv( predicting. input_cls: The :class:`~flash.core.data.io.input.Input` type to use for loading the data. transform_kwargs: Dict of keyword arguments to be provided when instantiating the transforms. 
- max_length: The maximum length to pad / truncate sequences to. data_module_kwargs: Additional keyword arguments to provide to the :class:`~flash.core.data.data_module.DataModule` constructor. @@ -145,7 +142,7 @@ def from_csv( 3 >>> datamodule.labels ['negative', 'neutral', 'positive'] - >>> model = TextClassifier(num_classes=datamodule.num_classes) + >>> model = TextClassifier(num_classes=datamodule.num_classes, backbone="prajjwal1/bert-tiny") >>> trainer = Trainer(fast_dev_run=True) >>> trainer.fit(model, datamodule=datamodule) # doctest: +ELLIPSIS +NORMALIZE_WHITESPACE Training... @@ -161,14 +158,15 @@ def from_csv( ds_kw = dict( input_key=input_field, target_keys=target_fields, - max_length=max_length, - data_pipeline_state=DataPipelineState(), transform_kwargs=transform_kwargs, input_transforms_registry=cls.input_transforms_registry, ) + train_input = input_cls(RunningStage.TRAINING, train_file, transform=train_transform, **ds_kw) + ds_kw["target_formatter"] = getattr(train_input, "target_formatter", None) + return cls( - input_cls(RunningStage.TRAINING, train_file, transform=train_transform, **ds_kw), + train_input, input_cls(RunningStage.VALIDATING, val_file, transform=val_transform, **ds_kw), input_cls(RunningStage.TESTING, test_file, transform=test_transform, **ds_kw), input_cls(RunningStage.PREDICTING, predict_file, transform=predict_transform, **ds_kw), @@ -184,14 +182,13 @@ def from_json( val_file: Optional[PATH_TYPE] = None, test_file: Optional[PATH_TYPE] = None, predict_file: Optional[PATH_TYPE] = None, - train_transform: Optional[Dict[str, Callable]] = TransformersInputTransform, - val_transform: Optional[Dict[str, Callable]] = TransformersInputTransform, - test_transform: Optional[Dict[str, Callable]] = TransformersInputTransform, - predict_transform: Optional[Dict[str, Callable]] = TransformersInputTransform, + train_transform: Optional[Dict[str, Callable]] = InputTransform, + val_transform: Optional[Dict[str, Callable]] = InputTransform, + test_transform: Optional[Dict[str, Callable]] = InputTransform, + predict_transform: Optional[Dict[str, Callable]] = InputTransform, input_cls: Type[Input] = TextClassificationJSONInput, transform_kwargs: Optional[Dict] = None, field: Optional[str] = None, - max_length: int = 128, **data_module_kwargs: Any, ) -> "TextClassificationData": """Load the :class:`~flash.text.classification.data.TextClassificationData` from JSON files containing text @@ -218,7 +215,6 @@ def from_json( input_cls: The :class:`~flash.core.data.io.input.Input` type to use for loading the data. transform_kwargs: Dict of keyword arguments to be provided when instantiating the transforms. field: To specify the field that holds the data in the JSON file. - max_length: The maximum length to pad / truncate sequences to. data_module_kwargs: Additional keyword arguments to provide to the :class:`~flash.core.data.data_module.DataModule` constructor. @@ -272,7 +268,7 @@ def from_json( 3 >>> datamodule.labels ['negative', 'neutral', 'positive'] - >>> model = TextClassifier(num_classes=datamodule.num_classes) + >>> model = TextClassifier(num_classes=datamodule.num_classes, backbone="prajjwal1/bert-tiny") >>> trainer = Trainer(fast_dev_run=True) >>> trainer.fit(model, datamodule=datamodule) # doctest: +ELLIPSIS +NORMALIZE_WHITESPACE Training... 
@@ -289,14 +285,15 @@ def from_json( input_key=input_field, target_keys=target_fields, field=field, - max_length=max_length, - data_pipeline_state=DataPipelineState(), transform_kwargs=transform_kwargs, input_transforms_registry=cls.input_transforms_registry, ) + train_input = input_cls(RunningStage.TRAINING, train_file, transform=train_transform, **ds_kw) + ds_kw["target_formatter"] = getattr(train_input, "target_formatter", None) + return cls( - input_cls(RunningStage.TRAINING, train_file, transform=train_transform, **ds_kw), + train_input, input_cls(RunningStage.VALIDATING, val_file, transform=val_transform, **ds_kw), input_cls(RunningStage.TESTING, test_file, transform=test_transform, **ds_kw), input_cls(RunningStage.PREDICTING, predict_file, transform=predict_transform, **ds_kw), @@ -312,13 +309,12 @@ def from_parquet( val_file: Optional[PATH_TYPE] = None, test_file: Optional[PATH_TYPE] = None, predict_file: Optional[PATH_TYPE] = None, - train_transform: Optional[Dict[str, Callable]] = TransformersInputTransform, - val_transform: Optional[Dict[str, Callable]] = TransformersInputTransform, - test_transform: Optional[Dict[str, Callable]] = TransformersInputTransform, - predict_transform: Optional[Dict[str, Callable]] = TransformersInputTransform, + train_transform: Optional[Dict[str, Callable]] = InputTransform, + val_transform: Optional[Dict[str, Callable]] = InputTransform, + test_transform: Optional[Dict[str, Callable]] = InputTransform, + predict_transform: Optional[Dict[str, Callable]] = InputTransform, input_cls: Type[Input] = TextClassificationParquetInput, transform_kwargs: Optional[Dict] = None, - max_length: int = 128, **data_module_kwargs: Any, ) -> "TextClassificationData": """Load the :class:`~flash.text.classification.data.TextClassificationData` from PARQUET files containing @@ -344,7 +340,6 @@ def from_parquet( predicting. input_cls: The :class:`~flash.core.data.io.input.Input` type to use for loading the data. transform_kwargs: Dict of keyword arguments to be provided when instantiating the transforms. - max_length: The maximum length to pad / truncate sequences to. data_module_kwargs: Additional keyword arguments to provide to the :class:`~flash.core.data.data_module.DataModule` constructor. @@ -400,7 +395,7 @@ def from_parquet( 3 >>> datamodule.labels ['negative', 'neutral', 'positive'] - >>> model = TextClassifier(num_classes=datamodule.num_classes) + >>> model = TextClassifier(num_classes=datamodule.num_classes, backbone="prajjwal1/bert-tiny") >>> trainer = Trainer(fast_dev_run=True) >>> trainer.fit(model, datamodule=datamodule) # doctest: +ELLIPSIS +NORMALIZE_WHITESPACE Training... 
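The recurring `ds_kw`/`train_input` construction above is the replacement for the removed `DataPipelineState`: the train split is built first, and whatever `TargetFormatter` it inferred from the training targets is handed explicitly to the remaining splits. A sketch of the idea with one input class (file and field names are placeholders):

```python
from flash.core.utilities.stages import RunningStage
from flash.text.classification.input import TextClassificationCSVInput

ds_kw = dict(input_key="review", target_keys="sentiment")

# Build the train split first so it can infer labels from its own targets.
train_input = TextClassificationCSVInput(RunningStage.TRAINING, "train.csv", **ds_kw)

# Reuse the inferred TargetFormatter so every split maps "negative"/"neutral"/
# "positive" to the same class indices as the training data.
ds_kw["target_formatter"] = getattr(train_input, "target_formatter", None)
val_input = TextClassificationCSVInput(RunningStage.VALIDATING, "val.csv", **ds_kw)
```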
@@ -416,14 +411,15 @@ def from_parquet( ds_kw = dict( input_key=input_field, target_keys=target_fields, - max_length=max_length, - data_pipeline_state=DataPipelineState(), transform_kwargs=transform_kwargs, input_transforms_registry=cls.input_transforms_registry, ) + train_input = input_cls(RunningStage.TRAINING, train_file, transform=train_transform, **ds_kw) + ds_kw["target_formatter"] = getattr(train_input, "target_formatter", None) + return cls( - input_cls(RunningStage.TRAINING, train_file, transform=train_transform, **ds_kw), + train_input, input_cls(RunningStage.VALIDATING, val_file, transform=val_transform, **ds_kw), input_cls(RunningStage.TESTING, test_file, transform=test_transform, **ds_kw), input_cls(RunningStage.PREDICTING, predict_file, transform=predict_transform, **ds_kw), @@ -439,13 +435,12 @@ def from_hf_datasets( val_hf_dataset: Optional[Dataset] = None, test_hf_dataset: Optional[Dataset] = None, predict_hf_dataset: Optional[Dataset] = None, - train_transform: Optional[Dict[str, Callable]] = TransformersInputTransform, - val_transform: Optional[Dict[str, Callable]] = TransformersInputTransform, - test_transform: Optional[Dict[str, Callable]] = TransformersInputTransform, - predict_transform: Optional[Dict[str, Callable]] = TransformersInputTransform, + train_transform: Optional[Dict[str, Callable]] = InputTransform, + val_transform: Optional[Dict[str, Callable]] = InputTransform, + test_transform: Optional[Dict[str, Callable]] = InputTransform, + predict_transform: Optional[Dict[str, Callable]] = InputTransform, input_cls: Type[Input] = TextClassificationInput, transform_kwargs: Optional[Dict] = None, - max_length: int = 128, **data_module_kwargs: Any, ) -> "TextClassificationData": """Load the :class:`~flash.text.classification.data.TextClassificationData` from Hugging Face ``Dataset`` @@ -471,7 +466,6 @@ def from_hf_datasets( predicting. input_cls: The :class:`~flash.core.data.io.input.Input` type to use for loading the data. transform_kwargs: Dict of keyword arguments to be provided when instantiating the transforms. - max_length: The maximum length to pad / truncate sequences to. data_module_kwargs: Additional keyword arguments to provide to the :class:`~flash.core.data.data_module.DataModule` constructor. @@ -508,7 +502,7 @@ def from_hf_datasets( 3 >>> datamodule.labels ['negative', 'neutral', 'positive'] - >>> model = TextClassifier(num_classes=datamodule.num_classes) + >>> model = TextClassifier(num_classes=datamodule.num_classes, backbone="prajjwal1/bert-tiny") >>> trainer = Trainer(fast_dev_run=True) >>> trainer.fit(model, datamodule=datamodule) # doctest: +ELLIPSIS +NORMALIZE_WHITESPACE Training... 
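Since `load_sample` no longer tokenizes (see `input.py` further down), a datamodule now carries raw text until the model's collate runs. A hypothetical sketch with an in-memory Hugging Face dataset (field names are placeholders):

```python
from datasets import Dataset
from flash.text import TextClassificationData

hf_dataset = Dataset.from_dict(
    {"text": ["great film", "terrible film"], "sentiment": ["positive", "negative"]}
)
datamodule = TextClassificationData.from_hf_datasets(
    "text",       # input_field
    "sentiment",  # target_fields
    train_hf_dataset=hf_dataset,
    batch_size=2,
)
```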
@@ -524,14 +518,15 @@ def from_hf_datasets( ds_kw = dict( input_key=input_field, target_keys=target_fields, - max_length=max_length, - data_pipeline_state=DataPipelineState(), transform_kwargs=transform_kwargs, input_transforms_registry=cls.input_transforms_registry, ) + train_input = input_cls(RunningStage.TRAINING, train_hf_dataset, transform=train_transform, **ds_kw) + ds_kw["target_formatter"] = getattr(train_input, "target_formatter", None) + return cls( - input_cls(RunningStage.TRAINING, train_hf_dataset, transform=train_transform, **ds_kw), + train_input, input_cls(RunningStage.VALIDATING, val_hf_dataset, transform=val_transform, **ds_kw), input_cls(RunningStage.TESTING, test_hf_dataset, transform=test_transform, **ds_kw), input_cls(RunningStage.PREDICTING, predict_hf_dataset, transform=predict_transform, **ds_kw), @@ -547,13 +542,12 @@ def from_data_frame( val_data_frame: Optional[DataFrame] = None, test_data_frame: Optional[DataFrame] = None, predict_data_frame: Optional[DataFrame] = None, - train_transform: Optional[Dict[str, Callable]] = TransformersInputTransform, - val_transform: Optional[Dict[str, Callable]] = TransformersInputTransform, - test_transform: Optional[Dict[str, Callable]] = TransformersInputTransform, - predict_transform: Optional[Dict[str, Callable]] = TransformersInputTransform, + train_transform: Optional[Dict[str, Callable]] = InputTransform, + val_transform: Optional[Dict[str, Callable]] = InputTransform, + test_transform: Optional[Dict[str, Callable]] = InputTransform, + predict_transform: Optional[Dict[str, Callable]] = InputTransform, input_cls: Type[Input] = TextClassificationDataFrameInput, transform_kwargs: Optional[Dict] = None, - max_length: int = 128, **data_module_kwargs: Any, ) -> "TextClassificationData": """Load the :class:`~flash.text.classification.data.TextClassificationData` from Pandas ``DataFrame`` @@ -580,7 +574,6 @@ def from_data_frame( predicting. input_cls: The :class:`~flash.core.data.io.input.Input` type to use for loading the data. transform_kwargs: Dict of keyword arguments to be provided when instantiating the transforms. - max_length: The maximum length to pad / truncate sequences to. data_module_kwargs: Additional keyword arguments to provide to the :class:`~flash.core.data.data_module.DataModule` constructor. @@ -617,7 +610,7 @@ def from_data_frame( 3 >>> datamodule.labels ['negative', 'neutral', 'positive'] - >>> model = TextClassifier(num_classes=datamodule.num_classes) + >>> model = TextClassifier(num_classes=datamodule.num_classes, backbone="prajjwal1/bert-tiny") >>> trainer = Trainer(fast_dev_run=True) >>> trainer.fit(model, datamodule=datamodule) # doctest: +ELLIPSIS +NORMALIZE_WHITESPACE Training... 
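One subtlety handled later in `input.py`: when several binary target columns are given, the targets are treated as multi-label and the column names double as the labels. A hypothetical `from_data_frame` sketch:

```python
import pandas as pd
from flash.text import TextClassificationData

df = pd.DataFrame(
    {
        "text": ["cheap and cheerful", "slow but beautiful"],
        "cheap": [1, 0],       # one binary column per class
        "beautiful": [0, 1],
    }
)
datamodule = TextClassificationData.from_data_frame(
    "text",
    ["cheap", "beautiful"],  # multiple target keys -> multi-binary targets
    train_data_frame=df,
    batch_size=2,
)
print(datamodule.labels)  # expected: ["cheap", "beautiful"]
```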
@@ -633,14 +626,15 @@ def from_data_frame( ds_kw = dict( input_key=input_field, target_keys=target_fields, - max_length=max_length, - data_pipeline_state=DataPipelineState(), transform_kwargs=transform_kwargs, input_transforms_registry=cls.input_transforms_registry, ) + train_input = input_cls(RunningStage.TRAINING, train_data_frame, transform=train_transform, **ds_kw) + ds_kw["target_formatter"] = getattr(train_input, "target_formatter", None) + return cls( - input_cls(RunningStage.TRAINING, train_data_frame, transform=train_transform, **ds_kw), + train_input, input_cls(RunningStage.VALIDATING, val_data_frame, transform=val_transform, **ds_kw), input_cls(RunningStage.TESTING, test_data_frame, transform=test_transform, **ds_kw), input_cls(RunningStage.PREDICTING, predict_data_frame, transform=predict_transform, **ds_kw), @@ -657,13 +651,12 @@ def from_lists( test_data: Optional[List[str]] = None, test_targets: Optional[Union[List[Any], List[List[Any]]]] = None, predict_data: Optional[List[str]] = None, - train_transform: Optional[Dict[str, Callable]] = TransformersInputTransform, - val_transform: Optional[Dict[str, Callable]] = TransformersInputTransform, - test_transform: Optional[Dict[str, Callable]] = TransformersInputTransform, - predict_transform: Optional[Dict[str, Callable]] = TransformersInputTransform, + train_transform: Optional[Dict[str, Callable]] = InputTransform, + val_transform: Optional[Dict[str, Callable]] = InputTransform, + test_transform: Optional[Dict[str, Callable]] = InputTransform, + predict_transform: Optional[Dict[str, Callable]] = InputTransform, input_cls: Type[Input] = TextClassificationListInput, transform_kwargs: Optional[Dict] = None, - max_length: int = 128, **data_module_kwargs: Any, ) -> "TextClassificationData": """Load the :class:`~flash.text.classification.data.TextClassificationData` from lists of text snippets and @@ -689,7 +682,6 @@ def from_lists( predicting. input_cls: The :class:`~flash.core.data.io.input.Input` type to use for loading the data. transform_kwargs: Dict of keyword arguments to be provided when instantiating the transforms. - max_length: The maximum length to pad / truncate sequences to. data_module_kwargs: Additional keyword arguments to provide to the :class:`~flash.core.data.data_module.DataModule` constructor. @@ -713,7 +705,7 @@ def from_lists( 3 >>> datamodule.labels ['negative', 'neutral', 'positive'] - >>> model = TextClassifier(num_classes=datamodule.num_classes) + >>> model = TextClassifier(num_classes=datamodule.num_classes, backbone="prajjwal1/bert-tiny") >>> trainer = Trainer(fast_dev_run=True) >>> trainer.fit(model, datamodule=datamodule) # doctest: +ELLIPSIS +NORMALIZE_WHITESPACE Training... 
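Putting the pieces together: because the model now accepts `labels` directly and infers `num_classes` from them (see `model.py` below), an end-to-end sketch with the list constructor just shown looks like this (backbone choice is illustrative):

```python
import flash
from flash.text import TextClassificationData, TextClassifier

datamodule = TextClassificationData.from_lists(
    train_data=["great film", "awful film", "it was ok"],
    train_targets=["positive", "negative", "neutral"],
    batch_size=2,
)

# num_classes is inferred as len(labels) when only labels are given.
model = TextClassifier(labels=datamodule.labels, backbone="prajjwal1/bert-tiny")

trainer = flash.Trainer(fast_dev_run=True)
trainer.fit(model, datamodule=datamodule)
```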
@@ -722,14 +714,15 @@ def from_lists( """ ds_kw = dict( - max_length=max_length, - data_pipeline_state=DataPipelineState(), transform_kwargs=transform_kwargs, input_transforms_registry=cls.input_transforms_registry, ) + train_input = input_cls(RunningStage.TRAINING, train_data, train_targets, transform=train_transform, **ds_kw) + ds_kw["target_formatter"] = getattr(train_input, "target_formatter", None) + return cls( - input_cls(RunningStage.TRAINING, train_data, train_targets, transform=train_transform, **ds_kw), + train_input, input_cls(RunningStage.VALIDATING, val_data, val_targets, transform=val_transform, **ds_kw), input_cls(RunningStage.TESTING, test_data, test_targets, transform=test_transform, **ds_kw), input_cls(RunningStage.PREDICTING, predict_data, transform=predict_transform, **ds_kw), @@ -749,15 +742,14 @@ def from_labelstudio( val_data_folder: str = None, test_data_folder: str = None, predict_data_folder: str = None, - train_transform: Optional[Dict[str, Callable]] = TransformersInputTransform, - val_transform: Optional[Dict[str, Callable]] = TransformersInputTransform, - test_transform: Optional[Dict[str, Callable]] = TransformersInputTransform, - predict_transform: Optional[Dict[str, Callable]] = TransformersInputTransform, + train_transform: Optional[Dict[str, Callable]] = InputTransform, + val_transform: Optional[Dict[str, Callable]] = InputTransform, + test_transform: Optional[Dict[str, Callable]] = InputTransform, + predict_transform: Optional[Dict[str, Callable]] = InputTransform, input_cls: Type[Input] = LabelStudioTextClassificationInput, transform_kwargs: Optional[Dict] = None, val_split: Optional[float] = None, multi_label: Optional[bool] = False, - max_length: int = 128, **data_module_kwargs: Any, ) -> "TextClassificationData": """Creates a :class:`~flash.core.data.data_module.DataModule` object @@ -789,7 +781,6 @@ def from_labelstudio( :class:`~flash.core.data.io.input_transform.InputTransform` hook names to callable transforms. val_split: The ``val_split`` argument to pass to the :class:`~flash.core.data.data_module.DataModule`. multi_label: Whether the labels are multi encoded. - max_length: The maximum sequence length. data_module_kwargs: Additional keyword arguments to use when constructing the datamodule. 
Returns: @@ -812,14 +803,15 @@ def from_labelstudio( ) ds_kw = dict( - max_length=max_length, - data_pipeline_state=DataPipelineState(), transform_kwargs=transform_kwargs, input_transforms_registry=cls.input_transforms_registry, ) + train_input = input_cls(RunningStage.TRAINING, train_data, transform=train_transform, **ds_kw) + ds_kw["parameters"] = getattr(train_input, "parameters", None) + return cls( - input_cls(RunningStage.TRAINING, train_data, transform=train_transform, **ds_kw), + train_input, input_cls(RunningStage.VALIDATING, val_data, transform=val_transform, **ds_kw), input_cls(RunningStage.TESTING, test_data, transform=test_transform, **ds_kw), input_cls(RunningStage.PREDICTING, predict_data, transform=predict_transform, **ds_kw), diff --git a/flash/text/classification/input.py b/flash/text/classification/input.py index 43b5a57d69..c4c96c8ffa 100644 --- a/flash/text/classification/input.py +++ b/flash/text/classification/input.py @@ -16,11 +16,10 @@ import pandas as pd -from flash.core.data.io.classification_input import ClassificationInputMixin, ClassificationState +from flash.core.data.io.classification_input import ClassificationInputMixin from flash.core.data.io.input import DataKeys, Input -from flash.core.data.utilities.classification import MultiBinaryTargetFormatter +from flash.core.data.utilities.classification import MultiBinaryTargetFormatter, TargetFormatter from flash.core.data.utilities.paths import PATH_TYPE -from flash.core.integrations.transformers.states import TransformersBackboneState from flash.core.utilities.imports import _TEXT_AVAILABLE, requires if _TEXT_AVAILABLE: @@ -44,20 +43,17 @@ def load_data( hf_dataset: Dataset, input_key: str, target_keys: Optional[Union[str, List[str]]] = None, - max_length: int = 128, + target_formatter: Optional[TargetFormatter] = None, ) -> Dataset: """Loads data into HuggingFace datasets.Dataset.""" - self.max_length = max_length - if not self.predicting: hf_dataset = hf_dataset.map(partial(self._resolve_target, target_keys)) targets = hf_dataset.to_dict()[DataKeys.TARGET] - self.load_target_metadata(targets) + self.load_target_metadata(targets, target_formatter=target_formatter) # If we had binary multi-class targets then we also know the labels (column names) if isinstance(self.target_formatter, MultiBinaryTargetFormatter) and isinstance(target_keys, List): - classification_state = self.get_state(ClassificationState) - self.set_state(ClassificationState(target_keys, classification_state.num_classes)) + self.labels = target_keys # remove extra columns extra_columns = set(hf_dataset.column_names) - {input_key, DataKeys.TARGET} @@ -69,13 +65,9 @@ def load_data( return hf_dataset def load_sample(self, sample: Dict[str, Any]) -> Any: - tokenized_sample = self.get_state(TransformersBackboneState).tokenizer( - sample[DataKeys.INPUT], max_length=self.max_length, truncation=True, padding="max_length" - ) - tokenized_sample = tokenized_sample.data if DataKeys.TARGET in sample: - tokenized_sample[DataKeys.TARGET] = self.format_target(sample[DataKeys.TARGET]) - return tokenized_sample + sample[DataKeys.TARGET] = self.format_target(sample[DataKeys.TARGET]) + return sample class TextClassificationCSVInput(TextClassificationInput): @@ -85,10 +77,10 @@ def load_data( csv_file: PATH_TYPE, input_key: str, target_keys: Optional[Union[str, List[str]]] = None, - max_length: int = 128, + target_formatter: Optional[TargetFormatter] = None, ) -> Dataset: dataset_dict = load_dataset("csv", data_files={"data": str(csv_file)}) - return 
super().load_data(dataset_dict["data"], input_key, target_keys, max_length) + return super().load_data(dataset_dict["data"], input_key, target_keys, target_formatter=target_formatter) class TextClassificationJSONInput(TextClassificationInput): @@ -99,10 +91,10 @@ def load_data( field: str, input_key: str, target_keys: Optional[Union[str, List[str]]] = None, - max_length: int = 128, + target_formatter: Optional[TargetFormatter] = None, ) -> Dataset: dataset_dict = load_dataset("json", data_files={"data": str(json_file)}, field=field) - return super().load_data(dataset_dict["data"], input_key, target_keys, max_length) + return super().load_data(dataset_dict["data"], input_key, target_keys, target_formatter=target_formatter) class TextClassificationDataFrameInput(TextClassificationInput): @@ -112,9 +104,11 @@ def load_data( data_frame: pd.DataFrame, input_key: str, target_keys: Optional[Union[str, List[str]]] = None, - max_length: int = 128, + target_formatter: Optional[TargetFormatter] = None, ) -> Dataset: - return super().load_data(Dataset.from_pandas(data_frame), input_key, target_keys, max_length) + return super().load_data( + Dataset.from_pandas(data_frame), input_key, target_keys, target_formatter=target_formatter + ) class TextClassificationParquetInput(TextClassificationInput): @@ -124,9 +118,11 @@ def load_data( parquet_file: PATH_TYPE, input_key: str, target_keys: Optional[Union[str, List[str]]] = None, - max_length: int = 128, + target_formatter: Optional[TargetFormatter] = None, ) -> Dataset: - return super().load_data(Dataset.from_parquet(str(parquet_file)), input_key, target_keys, max_length) + return super().load_data( + Dataset.from_parquet(str(parquet_file)), input_key, target_keys, target_formatter=target_formatter + ) class TextClassificationListInput(TextClassificationInput): @@ -135,10 +131,10 @@ def load_data( self, inputs: List[str], targets: Optional[List[Any]] = None, - max_length: int = 128, + target_formatter: Optional[TargetFormatter] = None, ) -> Dataset: if targets is not None: hf_dataset = Dataset.from_dict({DataKeys.INPUT: inputs, DataKeys.TARGET: targets}) else: hf_dataset = Dataset.from_dict({DataKeys.INPUT: inputs}) - return super().load_data(hf_dataset, DataKeys.INPUT, DataKeys.TARGET, max_length) + return super().load_data(hf_dataset, DataKeys.INPUT, DataKeys.TARGET, target_formatter=target_formatter) diff --git a/flash/text/classification/model.py b/flash/text/classification/model.py index 7dc4d6e220..fcc02283a8 100644 --- a/flash/text/classification/model.py +++ b/flash/text/classification/model.py @@ -13,15 +13,15 @@ # limitations under the License. 
 import os
 import warnings
-from typing import Any, Dict, List, Optional, Type
+from typing import Any, Dict, List, Optional, Type, Union

 import torch
 from pytorch_lightning import Callback

-from flash.core.classification import ClassificationTask, LabelsOutput
+from flash.core.classification import ClassificationTask
 from flash.core.data.io.input import DataKeys, ServeInput
-from flash.core.integrations.transformers.input_transform import TransformersInputTransform
-from flash.core.integrations.transformers.states import TransformersBackboneState
+from flash.core.data.io.input_transform import InputTransform
+from flash.core.data.io.output import Output
 from flash.core.registry import FlashRegistry
 from flash.core.serve import Composition
 from flash.core.utilities.imports import _TRANSFORMERS_AVAILABLE, requires
@@ -31,9 +31,9 @@
     LR_SCHEDULER_TYPE,
     METRICS_TYPE,
     OPTIMIZER_TYPE,
-    OUTPUT_TYPE,
 )
 from flash.text.classification.backbones import TEXT_CLASSIFIER_BACKBONES
+from flash.text.classification.collate import TextClassificationCollate
 from flash.text.input import TextDeserializer
 from flash.text.ort_callback import ORTCallback
@@ -48,7 +48,8 @@ class TextClassifier(ClassificationTask):
     Args:
         num_classes: Number of classes to classify.
-        backbone: A model to use to compute text features can be any BERT model from HuggingFace/transformersimage .
+        backbone: A model to use to compute text features; can be any BERT model from HuggingFace/transformers.
+        max_length: The maximum length to pad / truncate sequences to.
         optimizer: Optimizer to use for training.
         lr_scheduler: The LR scheduler to use during training.
         metrics: Metrics to compute for training and evaluation. Can either be an metric from the `torchmetrics`
@@ -57,7 +58,6 @@ class TextClassifier(ClassificationTask):
             `metric(preds,target)` and return a single scalar tensor. Defaults to :class:`torchmetrics.Accuracy`.
         learning_rate: Learning rate to use for training, defaults to `1e-3`
         multi_label: Whether the targets are multi-label or not.
-        output: The :class:`~flash.core.data.io.output.Output` to use when formatting prediction outputs.
enable_ort: Enable Torch ONNX Runtime Optimization: https://onnxruntime.ai/docs/#onnx-runtime-for-training """ @@ -67,19 +67,23 @@ class TextClassifier(ClassificationTask): def __init__( self, - num_classes: int, + num_classes: Optional[int] = None, + labels: Optional[List[str]] = None, backbone: str = "prajjwal1/bert-medium", + max_length: int = 128, loss_fn: LOSS_FN_TYPE = None, optimizer: OPTIMIZER_TYPE = "Adam", lr_scheduler: LR_SCHEDULER_TYPE = None, metrics: METRICS_TYPE = None, learning_rate: float = 1e-2, multi_label: bool = False, - output: OUTPUT_TYPE = None, enable_ort: bool = False, ): self.save_hyperparameters() + if labels is not None and num_classes is None: + num_classes = len(labels) + os.environ["TOKENIZERS_PARALLELISM"] = "TRUE" # disable HF thousand warnings warnings.simplefilter("ignore") @@ -95,12 +99,12 @@ def __init__( metrics=metrics, learning_rate=learning_rate, multi_label=multi_label, - output=output or LabelsOutput(multi_label=multi_label), + labels=labels, ) self.enable_ort = enable_ort - self.set_state(TransformersBackboneState(backbone)) + self.max_length = max_length + self.collate_fn = TextClassificationCollate(backbone=backbone, max_length=max_length) self.model = self.backbones.get(backbone)(num_labels=num_classes) - self.save_hyperparameters() @property def backbone(self): @@ -140,7 +144,8 @@ def serve( port: int = 8000, sanity_check: bool = True, input_cls: Optional[Type[ServeInput]] = TextDeserializer, - transform: INPUT_TRANSFORM_TYPE = TransformersInputTransform, + transform: INPUT_TRANSFORM_TYPE = InputTransform, transform_kwargs: Optional[Dict] = None, + output: Optional[Union[str, Output]] = None, ) -> Composition: - return super().serve(host, port, sanity_check, input_cls, transform, transform_kwargs) + return super().serve(host, port, sanity_check, input_cls, transform, transform_kwargs, output) diff --git a/flash/text/embedding/model.py b/flash/text/embedding/model.py index 2fae923403..dcf5fa911f 100644 --- a/flash/text/embedding/model.py +++ b/flash/text/embedding/model.py @@ -19,11 +19,11 @@ import torch from pytorch_lightning import Callback -from flash.core.integrations.transformers.states import TransformersBackboneState from flash.core.model import Task from flash.core.registry import FlashRegistry, print_provider_info from flash.core.utilities.imports import _TEXT_AVAILABLE from flash.core.utilities.providers import _SENTENCE_TRANSFORMERS +from flash.text.classification.collate import TextClassificationCollate from flash.text.embedding.backbones import HUGGINGFACE_BACKBONES from flash.text.ort_callback import ORTCallback @@ -55,6 +55,7 @@ class TextEmbedder(Task): def __init__( self, backbone: str = "sentence-transformers/all-MiniLM-L6-v2", + max_length: int = 128, tokenizer_backbone: Optional[str] = None, tokenizer_kwargs: Optional[Dict[str, Any]] = None, enable_ort: bool = False, @@ -68,7 +69,10 @@ def __init__( if tokenizer_backbone is None: tokenizer_backbone = backbone - self.set_state(TransformersBackboneState(tokenizer_backbone, tokenizer_kwargs=tokenizer_kwargs)) + self.max_length = max_length + self.collate_fn = TextClassificationCollate( + backbone=tokenizer_backbone, max_length=max_length, tokenizer_kwargs=tokenizer_kwargs + ) self.model = self.backbones.get(backbone)() self.pooling = Pooling(self.model.config.hidden_size) self.enable_ort = enable_ort diff --git a/flash/text/input.py b/flash/text/input.py index 96a3171627..bbf62a48a1 100644 --- a/flash/text/input.py +++ b/flash/text/input.py @@ -11,23 +11,15 @@ # WITHOUT 
WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. -from torch import Tensor +from typing import Any, Dict +from flash.core.data.io.input import DataKeys from flash.core.data.process import Deserializer -from flash.core.integrations.transformers.states import TransformersBackboneState -from flash.core.utilities.imports import requires class TextDeserializer(Deserializer): - @requires("text") - def __init__(self, *args, max_length: int = 128, **kwargs): - super().__init__(*args, **kwargs) - self.max_length = max_length - - def serve_load_sample(self, text: str) -> Tensor: - return self.get_state(TransformersBackboneState).tokenizer( - text, max_length=self.max_length, truncation=True, padding="max_length" - ) + def serve_load_sample(self, text: str) -> Dict[str, Any]: + return {DataKeys.INPUT: text} @property def example_input(self) -> str: diff --git a/flash/text/question_answering/collate.py b/flash/text/question_answering/collate.py new file mode 100644 index 0000000000..9265c75a56 --- /dev/null +++ b/flash/text/question_answering/collate.py @@ -0,0 +1,173 @@ +# Copyright The PyTorch Lightning team. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +# Adapted from: +# https://github.com/huggingface/transformers/blob/master/examples/pytorch/question-answering/run_qa_no_trainer.py +# https://github.com/huggingface/transformers/blob/master/examples/pytorch/question-answering/utils_qa.py + +from dataclasses import dataclass +from typing import Any, Optional, Union + +from flash.core.data.io.input import DataKeys +from flash.core.integrations.transformers.collate import TransformersCollate +from flash.core.model import Task + + +@dataclass(unsafe_hash=True) +class TextQuestionAnsweringCollate(TransformersCollate): + max_source_length: int = 384 + max_target_length: int = 30 + padding: Union[str, bool] = "max_length" + doc_stride: int = 128 + model: Optional[Task] = None + + def _prepare_train_features(self, tokenizer, samples: Any, tokenized_samples: Any, pad_on_right: bool): + # Since one example might give us several features if it has a long context, we need a map from a feature to + # its corresponding example. This key gives us just that. + sample_mapping = tokenized_samples.pop("overflow_to_sample_mapping") + # The offset mappings will give us a map from token to character position in the original context. This will + # help us compute the start_positions and end_positions. + offset_mapping = tokenized_samples.pop("offset_mapping") + + # Let's label those examples! + tokenized_samples["start_positions"] = [] + tokenized_samples["end_positions"] = [] + + for i, offsets in enumerate(offset_mapping): + # We will label impossible answers with the index of the CLS token. + input_ids = tokenized_samples["input_ids"][i] + cls_index = input_ids.index(tokenizer.cls_token_id) + + # Grab the sequence corresponding to that example (to know what is the context and what is the question). 
+ sequence_ids = tokenized_samples.sequence_ids(i) + + # One example can give several spans, this is the index of the example containing this span of text. + sample_index = sample_mapping[i] + answers = samples["answer"][sample_index] + # If no answers are given, set the cls_index as answer. + if len(answers["answer_start"]) == 0: + tokenized_samples["start_positions"].append(cls_index) + tokenized_samples["end_positions"].append(cls_index) + else: + # Start/end character index of the answer in the text. + start_char = answers["answer_start"][0] + end_char = start_char + len(answers["text"][0]) + + # Start token index of the current span in the text. + token_start_index = 0 + while sequence_ids[token_start_index] != (1 if pad_on_right else 0): + token_start_index += 1 + + # End token index of the current span in the text. + token_end_index = len(input_ids) - 1 + while sequence_ids[token_end_index] != (1 if pad_on_right else 0): + token_end_index -= 1 + + # Detect if the answer is out of the span (in which case this feature is labeled with the CLS index). + if not (offsets[token_start_index][0] <= start_char and offsets[token_end_index][1] >= end_char): + tokenized_samples["start_positions"].append(cls_index) + tokenized_samples["end_positions"].append(cls_index) + else: + # Otherwise move the token_start_index and token_end_index to the two ends of the answer. + # Note: we could go after the last offset if the answer is the last word (edge case). + while token_start_index < len(offsets) and offsets[token_start_index][0] <= start_char: + token_start_index += 1 + tokenized_samples["start_positions"].append(token_start_index - 1) + while offsets[token_end_index][1] >= end_char: + token_end_index -= 1 + tokenized_samples["end_positions"].append(token_end_index + 1) + + return tokenized_samples, sample_mapping, offset_mapping + + def _prepare_val_features(self, samples: Any, tokenized_samples: Any, pad_on_right: bool): + # Since one example might give us several features if it has a long context, we need a map from a feature to + # its corresponding example. This key gives us just that. + sample_mapping = tokenized_samples.pop("overflow_to_sample_mapping") + + # For evaluation, we will need to convert our predictions to substrings of the context, so we keep the + # corresponding example_id and we will store the offset mappings. + tokenized_samples["example_id"] = [] + tokenized_samples["context"] = [] + tokenized_samples["answer"] = [] + + for i in range(len(tokenized_samples["input_ids"])): + # Grab the sequence corresponding to that example (to know what is the context and what is the question). + sequence_ids = tokenized_samples.sequence_ids(i) + context_index = 1 if pad_on_right else 0 + + # One example can give several spans, this is the index of the example containing this span of text. + sample_index = sample_mapping[i] + tokenized_samples["example_id"].append(samples["id"][sample_index]) + tokenized_samples["context"].append(samples["context"][sample_index]) + if "answer" in samples: + tokenized_samples["answer"].append(samples["answer"][sample_index]) + + # Set to None the offset_mapping that are not part of the context so it's easy to determine if a token + # position is part of the context or not. 
+ tokenized_samples["offset_mapping"][i] = [ + (o if sequence_ids[k] == context_index else None) + for k, o in enumerate(tokenized_samples["offset_mapping"][i]) + ] + + return tokenized_samples + + def tokenize(self, samples: Any): + pad_on_right = self.tokenizer.padding_side == "right" + + samples["question"] = [q.lstrip() for q in samples["question"]] + + tokenized_samples = self.tokenizer( + samples["question" if pad_on_right else "context"], + samples["context" if pad_on_right else "question"], + truncation="only_second" if pad_on_right else "only_first", + max_length=self.max_source_length, + stride=self.doc_stride, + return_overflowing_tokens=True, + return_offsets_mapping=True, + padding=self.padding, + ) + + if "answer" in samples: + tokenized_samples, _sample_mapping, _offset_mapping = self._prepare_train_features( + self.tokenizer, samples, tokenized_samples, pad_on_right + ) + + if not self.model.training: + if "answer" in samples: + tokenized_samples["overflow_to_sample_mapping"] = _sample_mapping + tokenized_samples["offset_mapping"] = _offset_mapping + + # InputTransform function for eval or predict + tokenized_samples = self._prepare_val_features(samples, tokenized_samples, pad_on_right) + + offset_mappings = tokenized_samples.pop("offset_mapping") + example_ids = tokenized_samples.pop("example_id") + contexts = tokenized_samples.pop("context") + + tokenized_samples[DataKeys.METADATA] = [] + for offset_mapping, example_id, context in zip(offset_mappings, example_ids, contexts): + tokenized_samples[DataKeys.METADATA].append( + {"context": context, "offset_mapping": offset_mapping, "example_id": example_id} + ) + if "answer" in tokenized_samples: + answers = tokenized_samples.pop("answer") + for index, answer in enumerate(answers): + tokenized_samples[DataKeys.METADATA][index]["answer"] = answer + + del offset_mappings + del example_ids + del contexts + del answers + + return tokenized_samples.data diff --git a/flash/text/question_answering/data.py b/flash/text/question_answering/data.py index 8918eb767f..80b717dfad 100644 --- a/flash/text/question_answering/data.py +++ b/flash/text/question_answering/data.py @@ -11,11 +11,11 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. 
-from typing import Any, Dict, Optional, Type, Union +from typing import Any, Dict, Optional, Type from flash.core.data.data_module import DataModule -from flash.core.data.data_pipeline import DataPipelineState from flash.core.data.io.input import Input +from flash.core.data.io.input_transform import InputTransform from flash.core.data.utilities.paths import PATH_TYPE from flash.core.utilities.imports import _TEXT_AVAILABLE from flash.core.utilities.stages import RunningStage @@ -26,8 +26,6 @@ QuestionAnsweringJSONInput, QuestionAnsweringSQuADInput, ) -from flash.text.question_answering.input_transform import QuestionAnsweringInputTransform -from flash.text.question_answering.output_transform import QuestionAnsweringOutputTransform # Skip doctests if requirements aren't available if not _TEXT_AVAILABLE: @@ -38,8 +36,7 @@ class QuestionAnsweringData(DataModule): """The ``QuestionAnsweringData`` class is a :class:`~flash.core.data.data_module.DataModule` with a set of classmethods for loading data for extractive question answering.""" - input_transform_cls = QuestionAnsweringInputTransform - output_transform_cls = QuestionAnsweringOutputTransform + input_transform_cls = InputTransform @classmethod def from_csv( @@ -48,19 +45,15 @@ def from_csv( val_file: Optional[PATH_TYPE] = None, test_file: Optional[PATH_TYPE] = None, predict_file: Optional[PATH_TYPE] = None, - train_transform: INPUT_TRANSFORM_TYPE = QuestionAnsweringInputTransform, - val_transform: INPUT_TRANSFORM_TYPE = QuestionAnsweringInputTransform, - test_transform: INPUT_TRANSFORM_TYPE = QuestionAnsweringInputTransform, - predict_transform: INPUT_TRANSFORM_TYPE = QuestionAnsweringInputTransform, + train_transform: INPUT_TRANSFORM_TYPE = InputTransform, + val_transform: INPUT_TRANSFORM_TYPE = InputTransform, + test_transform: INPUT_TRANSFORM_TYPE = InputTransform, + predict_transform: INPUT_TRANSFORM_TYPE = InputTransform, input_cls: Type[Input] = QuestionAnsweringCSVInput, transform_kwargs: Optional[Dict] = None, - max_source_length: int = 384, - max_target_length: int = 30, - padding: Union[str, bool] = "max_length", question_column_name: str = "question", context_column_name: str = "context", answer_column_name: str = "answer", - doc_stride: int = 128, **data_module_kwargs: Any, ) -> "QuestionAnsweringData": """Load the :class:`~flash.text.question_answering.data.QuestionAnsweringData` from CSV files containing @@ -84,13 +77,9 @@ def from_csv( predicting. input_cls: The :class:`~flash.core.data.io.input.Input` type to use for loading the data. transform_kwargs: Dict of keyword arguments to be provided when instantiating the transforms. - max_source_length: Max length of the sequence to be considered during tokenization. - max_target_length: Max length of each answer to be produced. - padding: Padding type during tokenization. question_column_name: The key in the JSON file to recognize the question field. context_column_name: The key in the JSON file to recognize the context field. answer_column_name: The key in the JSON file to recognize the answer field. - doc_stride: The stride amount to be taken when splitting up a long document into chunks. Returns: The constructed :class:`~flash.text.question_answering.data.QuestionAnsweringData`. @@ -165,7 +154,7 @@ def from_csv( ... batch_size=2, ... ) # doctest: +ELLIPSIS +NORMALIZE_WHITESPACE Downloading... 
- >>> model = QuestionAnsweringTask() + >>> model = QuestionAnsweringTask(max_source_length=32, max_target_length=32) >>> trainer = Trainer(fast_dev_run=True) >>> trainer.fit(model, datamodule=datamodule) # doctest: +ELLIPSIS +NORMALIZE_WHITESPACE Training... @@ -179,14 +168,9 @@ def from_csv( """ ds_kw = dict( - max_source_length=max_source_length, - max_target_length=max_target_length, - padding=padding, question_column_name=question_column_name, context_column_name=context_column_name, answer_column_name=answer_column_name, - doc_stride=doc_stride, - data_pipeline_state=DataPipelineState(), transform_kwargs=transform_kwargs, input_transforms_registry=cls.input_transforms_registry, ) @@ -206,20 +190,16 @@ def from_json( val_file: Optional[PATH_TYPE] = None, test_file: Optional[PATH_TYPE] = None, predict_file: Optional[PATH_TYPE] = None, - train_transform: INPUT_TRANSFORM_TYPE = QuestionAnsweringInputTransform, - val_transform: INPUT_TRANSFORM_TYPE = QuestionAnsweringInputTransform, - test_transform: INPUT_TRANSFORM_TYPE = QuestionAnsweringInputTransform, - predict_transform: INPUT_TRANSFORM_TYPE = QuestionAnsweringInputTransform, + train_transform: INPUT_TRANSFORM_TYPE = InputTransform, + val_transform: INPUT_TRANSFORM_TYPE = InputTransform, + test_transform: INPUT_TRANSFORM_TYPE = InputTransform, + predict_transform: INPUT_TRANSFORM_TYPE = InputTransform, input_cls: Type[Input] = QuestionAnsweringJSONInput, transform_kwargs: Optional[Dict] = None, field: Optional[str] = None, - max_source_length: int = 384, - max_target_length: int = 30, - padding: Union[str, bool] = "max_length", question_column_name: str = "question", context_column_name: str = "context", answer_column_name: str = "answer", - doc_stride: int = 128, **data_module_kwargs: Any, ) -> "QuestionAnsweringData": """Load the :class:`~flash.text.question_answering.data.QuestionAnsweringData` from JSON files containing @@ -244,13 +224,9 @@ def from_json( input_cls: The :class:`~flash.core.data.io.input.Input` type to use for loading the data. transform_kwargs: Dict of keyword arguments to be provided when instantiating the transforms. field: The field that holds the data in the JSON file. - max_source_length: Max length of the sequence to be considered during tokenization. - max_target_length: Max length of each answer to be produced. - padding: Padding type during tokenization. question_column_name: The key in the JSON file to recognize the question field. context_column_name: The key in the JSON file to recognize the context field. answer_column_name: The key in the JSON file to recognize the answer field. - doc_stride: The stride amount to be taken when splitting up a long document into chunks. Returns: The constructed :class:`~flash.text.question_answering.data.QuestionAnsweringData`. @@ -329,7 +305,7 @@ def from_json( ... batch_size=2, ... ) # doctest: +ELLIPSIS +NORMALIZE_WHITESPACE Downloading... - >>> model = QuestionAnsweringTask() + >>> model = QuestionAnsweringTask(max_source_length=32, max_target_length=32) >>> trainer = Trainer(fast_dev_run=True) >>> trainer.fit(model, datamodule=datamodule) # doctest: +ELLIPSIS +NORMALIZE_WHITESPACE Training... 
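As the doctest updates show, the tokenization settings now travel with the task instead of the datamodule, which keeps only the raw question/context/answer columns. A hypothetical sketch of the new construction (file name is a placeholder):

```python
from flash.text import QuestionAnsweringData, QuestionAnsweringTask

# The datamodule no longer accepts max_source_length/padding/doc_stride;
# they are passed to the task, which builds the collate (see model.py below).
datamodule = QuestionAnsweringData.from_json(
    train_file="train_data.json",
    batch_size=2,
)
model = QuestionAnsweringTask(
    backbone="distilbert-base-uncased",
    max_source_length=384,
    max_target_length=30,
    doc_stride=128,
)
```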
@@ -344,14 +320,9 @@ def from_json( ds_kw = dict( field=field, - max_source_length=max_source_length, - max_target_length=max_target_length, - padding=padding, question_column_name=question_column_name, context_column_name=context_column_name, answer_column_name=answer_column_name, - doc_stride=doc_stride, - data_pipeline_state=DataPipelineState(), transform_kwargs=transform_kwargs, input_transforms_registry=cls.input_transforms_registry, ) @@ -371,19 +342,15 @@ def from_squad_v2( val_file: Optional[str] = None, test_file: Optional[str] = None, predict_file: Optional[str] = None, - train_transform: INPUT_TRANSFORM_TYPE = QuestionAnsweringInputTransform, - val_transform: INPUT_TRANSFORM_TYPE = QuestionAnsweringInputTransform, - test_transform: INPUT_TRANSFORM_TYPE = QuestionAnsweringInputTransform, - predict_transform: INPUT_TRANSFORM_TYPE = QuestionAnsweringInputTransform, + train_transform: INPUT_TRANSFORM_TYPE = InputTransform, + val_transform: INPUT_TRANSFORM_TYPE = InputTransform, + test_transform: INPUT_TRANSFORM_TYPE = InputTransform, + predict_transform: INPUT_TRANSFORM_TYPE = InputTransform, input_cls: Type[Input] = QuestionAnsweringSQuADInput, transform_kwargs: Optional[Dict] = None, - max_source_length: int = 384, - max_target_length: int = 30, - padding: Union[str, bool] = "max_length", question_column_name: str = "question", context_column_name: str = "context", answer_column_name: str = "answer", - doc_stride: int = 128, **data_module_kwargs: Any, ) -> "QuestionAnsweringData": """Load the :class:`~flash.text.question_answering.data.QuestionAnsweringData` from JSON files containing @@ -407,13 +374,9 @@ def from_squad_v2( predicting. input_cls: The :class:`~flash.core.data.io.input.Input` type to use for loading the data. transform_kwargs: Dict of keyword arguments to be provided when instantiating the transforms. - max_source_length: Max length of the sequence to be considered during tokenization. - max_target_length: Max length of each answer to be produced. - padding: Padding type during tokenization. question_column_name: The key in the JSON file to recognize the question field. context_column_name: The key in the JSON file to recognize the context field. answer_column_name: The key in the JSON file to recognize the answer field. - doc_stride: The stride amount to be taken when splitting up a long document into chunks. Returns: The constructed data module. @@ -630,7 +593,7 @@ def from_squad_v2( ... predict_file="predict_data.json", ... batch_size=2, ... ) # doctest: +ELLIPSIS +NORMALIZE_WHITESPACE - >>> model = QuestionAnsweringTask() + >>> model = QuestionAnsweringTask(max_source_length=32, max_target_length=32) >>> trainer = Trainer(fast_dev_run=True) >>> trainer.fit(model, datamodule=datamodule) # doctest: +ELLIPSIS +NORMALIZE_WHITESPACE Training... 
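The `doc_stride` argument drives the Hugging Face overflow mechanism used by `TextQuestionAnsweringCollate.tokenize` above: a long context is split into overlapping windows, each of which becomes a separate feature. A standalone sketch of that behaviour with the `transformers` tokenizer API:

```python
from transformers import AutoTokenizer

tokenizer = AutoTokenizer.from_pretrained("distilbert-base-uncased")
encoded = tokenizer(
    "What is Lightning Flash?",                          # question
    "Lightning Flash is a task-based library. " * 200,   # deliberately long context
    truncation="only_second",
    max_length=384,
    stride=128,
    return_overflowing_tokens=True,
    return_offsets_mapping=True,
    padding="max_length",
)
# One example, several overlapping features; the mapping ties each window
# back to its source example (all zeros here, since there is one example).
print(encoded["overflow_to_sample_mapping"])
```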
@@ -645,14 +608,9 @@ def from_squad_v2( """ ds_kw = dict( - max_source_length=max_source_length, - max_target_length=max_target_length, - padding=padding, question_column_name=question_column_name, context_column_name=context_column_name, answer_column_name=answer_column_name, - doc_stride=doc_stride, - data_pipeline_state=DataPipelineState(), transform_kwargs=transform_kwargs, input_transforms_registry=cls.input_transforms_registry, ) @@ -672,19 +630,15 @@ def from_dicts( val_data: Optional[Dict[str, Any]] = None, test_data: Optional[Dict[str, Any]] = None, predict_data: Optional[Dict[str, Any]] = None, - train_transform: INPUT_TRANSFORM_TYPE = QuestionAnsweringInputTransform, - val_transform: INPUT_TRANSFORM_TYPE = QuestionAnsweringInputTransform, - test_transform: INPUT_TRANSFORM_TYPE = QuestionAnsweringInputTransform, - predict_transform: INPUT_TRANSFORM_TYPE = QuestionAnsweringInputTransform, + train_transform: INPUT_TRANSFORM_TYPE = InputTransform, + val_transform: INPUT_TRANSFORM_TYPE = InputTransform, + test_transform: INPUT_TRANSFORM_TYPE = InputTransform, + predict_transform: INPUT_TRANSFORM_TYPE = InputTransform, input_cls: Type[Input] = QuestionAnsweringDictionaryInput, transform_kwargs: Optional[Dict] = None, - max_source_length: int = 384, - max_target_length: int = 30, - padding: Union[str, bool] = "max_length", question_column_name: str = "question", context_column_name: str = "context", answer_column_name: str = "answer", - doc_stride: int = 128, **data_module_kwargs: Any, ) -> "QuestionAnsweringData": """Load the :class:`~flash.text.question_answering.data.QuestionAnsweringData` from Python dictionary @@ -708,13 +662,9 @@ def from_dicts( predicting. input_cls: The :class:`~flash.core.data.io.input.Input` type to use for loading the data. transform_kwargs: Dict of keyword arguments to be provided when instantiating the transforms. - max_source_length: Max length of the sequence to be considered during tokenization. - max_target_length: Max length of each answer to be produced. - padding: Padding type during tokenization. question_column_name: The key in the JSON file to recognize the question field. context_column_name: The key in the JSON file to recognize the context field. answer_column_name: The key in the JSON file to recognize the answer field. - doc_stride: The stride amount to be taken when splitting up a long document into chunks. Returns: The constructed :class:`~flash.text.question_answering.data.QuestionAnsweringData`. @@ -764,7 +714,7 @@ def from_dicts( ... predict_data=predict_data, ... batch_size=2, ... ) # doctest: +ELLIPSIS +NORMALIZE_WHITESPACE - >>> model = QuestionAnsweringTask() + >>> model = QuestionAnsweringTask(max_source_length=32, max_target_length=32) >>> trainer = Trainer(fast_dev_run=True) >>> trainer.fit(model, datamodule=datamodule) # doctest: +ELLIPSIS +NORMALIZE_WHITESPACE Training... 
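The expected answer format mirrors SQuAD: each answer is a dict of `text` and `answer_start` lists (columns named `answer_text`/`answer_start` are reshaped into this form by `_reshape_answer_column`, shown below). A hypothetical `from_dicts` sketch:

```python
from flash.text import QuestionAnsweringData

train_data = {
    "id": ["12345"],
    "question": ["What is the capital of France?"],
    "context": ["Paris is the capital and largest city of France."],
    # SQuAD-style answer: the answer text plus its character offset.
    "answer": [{"text": ["Paris"], "answer_start": [0]}],
}
datamodule = QuestionAnsweringData.from_dicts(train_data=train_data, batch_size=1)
```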
@@ -778,14 +728,9 @@ def from_dicts( """ ds_kw = dict( - max_source_length=max_source_length, - max_target_length=max_target_length, - padding=padding, question_column_name=question_column_name, context_column_name=context_column_name, answer_column_name=answer_column_name, - doc_stride=doc_stride, - data_pipeline_state=DataPipelineState(), transform_kwargs=transform_kwargs, input_transforms_registry=cls.input_transforms_registry, ) diff --git a/flash/text/question_answering/input.py b/flash/text/question_answering/input.py index cbc0b918e9..70c1dd7239 100644 --- a/flash/text/question_answering/input.py +++ b/flash/text/question_answering/input.py @@ -17,13 +17,11 @@ # https://github.com/huggingface/transformers/blob/master/examples/pytorch/question-answering/utils_qa.py import json from pathlib import Path -from typing import Any, Callable, Dict, Union +from typing import Any, Dict import flash -from flash.core.data.batch import default_uncollate -from flash.core.data.io.input import DataKeys, Input +from flash.core.data.io.input import Input from flash.core.data.utilities.paths import PATH_TYPE -from flash.core.integrations.transformers.states import TransformersBackboneState from flash.core.utilities.imports import _TEXT_AVAILABLE, requires if _TEXT_AVAILABLE: @@ -33,149 +31,6 @@ class QuestionAnsweringInputBase(Input): - def _tokenize_fn(self, samples: Any) -> Callable: - tokenizer = self.get_state(TransformersBackboneState).tokenizer - pad_on_right = tokenizer.padding_side == "right" - - samples[self.question_column_name] = [q.lstrip() for q in samples[self.question_column_name]] - - tokenized_samples = tokenizer( - samples[self.question_column_name if pad_on_right else self.context_column_name], - samples[self.context_column_name if pad_on_right else self.question_column_name], - truncation="only_second" if pad_on_right else "only_first", - max_length=self.max_source_length, - stride=self.doc_stride, - return_overflowing_tokens=True, - return_offsets_mapping=True, - padding=self.padding, - ) - - if self.training: - # InputTransform function for training - tokenized_samples, _, _ = self._prepare_train_features(tokenizer, samples, tokenized_samples, pad_on_right) - else: - if self.validating or self.testing: - tokenized_samples, _sample_mapping, _offset_mapping = self._prepare_train_features( - tokenizer, samples, tokenized_samples, pad_on_right - ) - - tokenized_samples["overflow_to_sample_mapping"] = _sample_mapping - tokenized_samples["offset_mapping"] = _offset_mapping - - # InputTransform function for eval or predict - tokenized_samples = self._prepare_val_features(samples, tokenized_samples, pad_on_right) - - offset_mappings = tokenized_samples.pop("offset_mapping") - example_ids = tokenized_samples.pop("example_id") - contexts = tokenized_samples.pop("context") - answers = tokenized_samples.pop("answer") - - tokenized_samples[DataKeys.METADATA] = [] - for offset_mapping, example_id, context in zip(offset_mappings, example_ids, contexts): - tokenized_samples[DataKeys.METADATA].append( - {"context": context, "offset_mapping": offset_mapping, "example_id": example_id} - ) - if self.validating or self.testing: - for index, answer in enumerate(answers): - tokenized_samples[DataKeys.METADATA][index]["answer"] = answer - - del offset_mappings - del example_ids - del contexts - del answers - - return tokenized_samples - - def _prepare_train_features(self, tokenizer, samples: Any, tokenized_samples: Any, pad_on_right: bool): - # Since one example might give us several features if it has 
a long context, we need a map from a feature to - # its corresponding example. This key gives us just that. - sample_mapping = tokenized_samples.pop("overflow_to_sample_mapping") - # The offset mappings will give us a map from token to character position in the original context. This will - # help us compute the start_positions and end_positions. - offset_mapping = tokenized_samples.pop("offset_mapping") - - # Let's label those examples! - tokenized_samples["start_positions"] = [] - tokenized_samples["end_positions"] = [] - - for i, offsets in enumerate(offset_mapping): - # We will label impossible answers with the index of the CLS token. - input_ids = tokenized_samples["input_ids"][i] - cls_index = input_ids.index(tokenizer.cls_token_id) - - # Grab the sequence corresponding to that example (to know what is the context and what is the question). - sequence_ids = tokenized_samples.sequence_ids(i) - - # One example can give several spans, this is the index of the example containing this span of text. - sample_index = sample_mapping[i] - answers = samples[self.answer_column_name][sample_index] - # If no answers are given, set the cls_index as answer. - if len(answers["answer_start"]) == 0: - tokenized_samples["start_positions"].append(cls_index) - tokenized_samples["end_positions"].append(cls_index) - else: - # Start/end character index of the answer in the text. - start_char = answers["answer_start"][0] - end_char = start_char + len(answers["text"][0]) - - # Start token index of the current span in the text. - token_start_index = 0 - while sequence_ids[token_start_index] != (1 if pad_on_right else 0): - token_start_index += 1 - - # End token index of the current span in the text. - token_end_index = len(input_ids) - 1 - while sequence_ids[token_end_index] != (1 if pad_on_right else 0): - token_end_index -= 1 - - # Detect if the answer is out of the span (in which case this feature is labeled with the CLS index). - if not (offsets[token_start_index][0] <= start_char and offsets[token_end_index][1] >= end_char): - tokenized_samples["start_positions"].append(cls_index) - tokenized_samples["end_positions"].append(cls_index) - else: - # Otherwise move the token_start_index and token_end_index to the two ends of the answer. - # Note: we could go after the last offset if the answer is the last word (edge case). - while token_start_index < len(offsets) and offsets[token_start_index][0] <= start_char: - token_start_index += 1 - tokenized_samples["start_positions"].append(token_start_index - 1) - while offsets[token_end_index][1] >= end_char: - token_end_index -= 1 - tokenized_samples["end_positions"].append(token_end_index + 1) - - return tokenized_samples, sample_mapping, offset_mapping - - def _prepare_val_features(self, samples: Any, tokenized_samples: Any, pad_on_right: bool): - # Since one example might give us several features if it has a long context, we need a map from a feature to - # its corresponding example. This key gives us just that. - sample_mapping = tokenized_samples.pop("overflow_to_sample_mapping") - - # For evaluation, we will need to convert our predictions to substrings of the context, so we keep the - # corresponding example_id and we will store the offset mappings. - tokenized_samples["example_id"] = [] - tokenized_samples["context"] = [] - tokenized_samples["answer"] = [] - - for i in range(len(tokenized_samples["input_ids"])): - # Grab the sequence corresponding to that example (to know what is the context and what is the question). 
- sequence_ids = tokenized_samples.sequence_ids(i) - context_index = 1 if pad_on_right else 0 - - # One example can give several spans, this is the index of the example containing this span of text. - sample_index = sample_mapping[i] - tokenized_samples["example_id"].append(samples["id"][sample_index]) - tokenized_samples["context"].append(samples["context"][sample_index]) - if self.validating or self.testing: - tokenized_samples["answer"].append(samples["answer"][sample_index]) - - # Set to None the offset_mapping that are not part of the context so it's easy to determine if a token - # position is part of the context or not. - tokenized_samples["offset_mapping"][i] = [ - (o if sequence_ids[k] == context_index else None) - for k, o in enumerate(tokenized_samples["offset_mapping"][i]) - ] - - return tokenized_samples - def _reshape_answer_column(self, sample: Any): text = sample.pop("answer_text") start = sample.pop("answer_start") @@ -190,25 +45,14 @@ def _reshape_answer_column(self, sample: Any): def load_data( self, hf_dataset: Dataset, - max_source_length: int = 384, - max_target_length: int = 30, - padding: Union[str, bool] = "max_length", question_column_name: str = "question", context_column_name: str = "context", answer_column_name: str = "answer", - doc_stride: int = 128, ) -> Dataset: - self.max_source_length = max_source_length - self.max_target_length = max_target_length - self.padding = padding - self.question_column_name = question_column_name - self.context_column_name = context_column_name - self.answer_column_name = answer_column_name - self.doc_stride = doc_stride + column_names = hf_dataset.column_names if self.training or self.validating or self.testing: - if self.answer_column_name == "answer": - column_names = hf_dataset.column_names + if answer_column_name == "answer": if "answer" not in column_names: if "answer_text" in column_names and "answer_start" in column_names: hf_dataset = hf_dataset.map(self._reshape_answer_column, batched=False) @@ -217,48 +61,42 @@ def load_data( """Dataset must contain either \"answer\" key as dict type or "answer_text" and "answer_start" as string and integer types.""" ) - if not isinstance(hf_dataset[self.answer_column_name][0], Dict): + if not isinstance(hf_dataset[answer_column_name][0], Dict): raise TypeError( - f'{self.answer_column_name} column should be of type dict with keys "text" and "answer_start"' + f'{answer_column_name} column should be of type dict with keys "text" and "answer_start"' ) + if answer_column_name in column_names and answer_column_name != "answer": + hf_dataset = hf_dataset.rename_column(answer_column_name, "answer") + + if question_column_name in column_names and question_column_name != "question": + hf_dataset = hf_dataset.rename_column(question_column_name, "question") + + if context_column_name in column_names and context_column_name != "context": + hf_dataset = hf_dataset.rename_column(context_column_name, "context") + if flash._IS_TESTING: # NOTE: must subset in this way to return a Dataset hf_dataset = hf_dataset.select(range(20)) return hf_dataset - def load_sample(self, sample: Dict[str, Any]) -> Any: - sample = {key: [value] for key, value in sample.items()} - tokenized_sample = self._tokenize_fn(sample).data - - # The tokenize function can return multiple outputs for each input. 
So we uncollate them here - return default_uncollate(tokenized_sample) - class QuestionAnsweringCSVInput(QuestionAnsweringInputBase): @requires("text") def load_data( self, csv_file: PATH_TYPE, - max_source_length: int = 384, - max_target_length: int = 30, - padding: Union[str, bool] = "max_length", question_column_name: str = "question", context_column_name: str = "context", answer_column_name: str = "answer", - doc_stride: int = 128, ) -> Dataset: dataset_dict = load_dataset("csv", data_files={"data": str(csv_file)}) return super().load_data( dataset_dict["data"], - max_source_length=max_source_length, - max_target_length=max_target_length, - padding=padding, question_column_name=question_column_name, context_column_name=context_column_name, answer_column_name=answer_column_name, - doc_stride=doc_stride, ) @@ -268,24 +106,16 @@ def load_data( self, json_file: PATH_TYPE, field: str, - max_source_length: int = 384, - max_target_length: int = 30, - padding: Union[str, bool] = "max_length", question_column_name: str = "question", context_column_name: str = "context", answer_column_name: str = "answer", - doc_stride: int = 128, ) -> Dataset: dataset_dict = load_dataset("json", data_files={"data": str(json_file)}, field=field) return super().load_data( dataset_dict["data"], - max_source_length=max_source_length, - max_target_length=max_target_length, - padding=padding, question_column_name=question_column_name, context_column_name=context_column_name, answer_column_name=answer_column_name, - doc_stride=doc_stride, ) @@ -293,23 +123,15 @@ class QuestionAnsweringDictionaryInput(QuestionAnsweringInputBase): def load_data( self, data: Dict[str, Any], - max_source_length: int = 384, - max_target_length: int = 30, - padding: Union[str, bool] = "max_length", question_column_name: str = "question", context_column_name: str = "context", answer_column_name: str = "answer", - doc_stride: int = 128, ) -> Dataset: return super().load_data( Dataset.from_dict(data), - max_source_length=max_source_length, - max_target_length=max_target_length, - padding=padding, question_column_name=question_column_name, context_column_name=context_column_name, answer_column_name=answer_column_name, - doc_stride=doc_stride, ) @@ -317,13 +139,9 @@ class QuestionAnsweringSQuADInput(QuestionAnsweringDictionaryInput): def load_data( self, json_file: PATH_TYPE, - max_source_length: int = 384, - max_target_length: int = 30, - padding: Union[str, bool] = "max_length", question_column_name: str = "question", context_column_name: str = "context", answer_column_name: str = "answer", - doc_stride: int = 128, ) -> Dataset: path = Path(json_file) with open(path, "rb") as f: @@ -358,11 +176,7 @@ def load_data( return super().load_data( data, - max_source_length=max_source_length, - max_target_length=max_target_length, - padding=padding, question_column_name=question_column_name, context_column_name=context_column_name, answer_column_name=answer_column_name, - doc_stride=doc_stride, ) diff --git a/flash/text/question_answering/input_transform.py b/flash/text/question_answering/input_transform.py deleted file mode 100644 index 604a0fe131..0000000000 --- a/flash/text/question_answering/input_transform.py +++ /dev/null @@ -1,41 +0,0 @@ -# Copyright The PyTorch Lightning team. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. 
-# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. -from dataclasses import dataclass -from typing import Any, Callable, Dict, List - -from torch.utils.data.dataloader import default_collate - -from flash.core.data.io.input import DataKeys -from flash.core.integrations.transformers.input_transform import TransformersInputTransform - - -@dataclass -class QuestionAnsweringInputTransform(TransformersInputTransform): - @staticmethod - def default_collate(samples: List[List[Dict[str, Any]]]) -> Dict[str, Any]: - # TODO: This should be handled by the InputTransformProcessor - chained_samples = [] - chained_metadata = [] - for s in samples: - for sample in s: - chained_metadata.append(sample.pop(DataKeys.METADATA, None)) - chained_samples.append(sample) - - result = default_collate(chained_samples) - if any(m is not None for m in chained_metadata): - result[DataKeys.METADATA] = chained_metadata - return result - - def collate(self) -> Callable: - return self.default_collate diff --git a/flash/text/question_answering/model.py b/flash/text/question_answering/model.py index d4c5d07a45..635cd1d200 100644 --- a/flash/text/question_answering/model.py +++ b/flash/text/question_answering/model.py @@ -30,14 +30,15 @@ from torchmetrics.text.rouge import ROUGEScore from flash.core.data.io.input import DataKeys -from flash.core.integrations.transformers.states import TransformersBackboneState from flash.core.model import Task from flash.core.registry import ExternalRegistry, FlashRegistry from flash.core.utilities.imports import _TEXT_AVAILABLE, _TM_GREATER_EQUAL_0_7_0 from flash.core.utilities.providers import _HUGGINGFACE from flash.core.utilities.types import LR_SCHEDULER_TYPE, METRICS_TYPE, OPTIMIZER_TYPE from flash.text.ort_callback import ORTCallback +from flash.text.question_answering.collate import TextQuestionAnsweringCollate from flash.text.question_answering.finetuning import _get_question_answering_bacbones_for_freezing +from flash.text.question_answering.output_transform import QuestionAnsweringOutputTransform if _TEXT_AVAILABLE: from transformers import AutoModelForQuestionAnswering @@ -67,6 +68,10 @@ class QuestionAnsweringTask(Task): Args: backbone: backbone model to use for the task. + max_source_length: Max length of the sequence to be considered during tokenization. + max_target_length: Max length of each answer to be produced. + padding: Padding type during tokenization. + doc_stride: The stride amount to be taken when splitting up a long document into chunks. loss_fn: Loss function for training. optimizer: Optimizer to use for training. lr_scheduler: The LR scheduler to use during training. 
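The hunk above moves the tokenization settings onto the task itself. As a usage sketch (an editorial illustration, not part of the patch; the argument values shown are the new defaults), the task can now be constructed like this:

    from flash.text import QuestionAnsweringTask

    # Tokenization is configured on the task (through its collate function)
    # rather than passed to `load_data`:
    model = QuestionAnsweringTask(
        backbone="distilbert-base-uncased",  # any Hugging Face QA backbone
        max_source_length=384,   # inputs are padded/truncated to this length
        max_target_length=30,    # upper bound on the length of a produced answer
        padding="max_length",
        doc_stride=128,          # overlap when splitting long contexts into chunks
    )
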
@@ -90,7 +95,11 @@ class QuestionAnsweringTask(Task): def __init__( self, - backbone: str = "distilbert-base-uncased", + backbone: str = "sshleifer/tiny-distilbert-base-cased-distilled-squad", + max_source_length: int = 384, + max_target_length: int = 30, + padding: Union[str, bool] = "max_length", + doc_stride: int = 128, loss_fn: Optional[Union[Callable, Mapping, Sequence]] = None, optimizer: OPTIMIZER_TYPE = "Adam", lr_scheduler: LR_SCHEDULER_TYPE = None, @@ -99,28 +108,40 @@ def __init__( enable_ort: bool = False, n_best_size: int = 20, version_2_with_negative: bool = True, - max_answer_length: int = 30, null_score_diff_threshold: float = 0.0, use_stemmer: bool = True, ): + self.save_hyperparameters() + os.environ["TOKENIZERS_PARALLELISM"] = "TRUE" # disable HF thousand warnings warnings.simplefilter("ignore") # set os environ variable for multiprocesses os.environ["PYTHONWARNINGS"] = "ignore" + super().__init__( loss_fn=loss_fn, optimizer=optimizer, lr_scheduler=lr_scheduler, metrics=metrics, learning_rate=learning_rate, + output_transform=QuestionAnsweringOutputTransform(), ) - self.set_state(TransformersBackboneState(backbone)) + + self.collate_fn = TextQuestionAnsweringCollate( + backbone=backbone, + max_source_length=max_source_length, + max_target_length=max_target_length, + padding=padding, + doc_stride=doc_stride, + model=self, + ) + self.model = self.backbones.get(backbone)() self.enable_ort = enable_ort self.n_best_size = n_best_size self.version_2_with_negative = version_2_with_negative - self.max_answer_length = max_answer_length + self.max_target_length = max_target_length self.null_score_diff_threshold = null_score_diff_threshold self._initialize_model_specific_parameters() @@ -167,7 +188,7 @@ def _generate_answers(self, pred_start_logits, pred_end_logits, examples): -1 : -self.n_best_size - 1 : -1 ].tolist() - max_answer_length: int = 30 + max_answer_length = self.max_target_length for start_index in start_indexes: for end_index in end_indexes: # Don't consider out-of-scope answers, either because the indices are out of bounds or correspond diff --git a/flash/text/seq2seq/core/collate.py b/flash/text/seq2seq/core/collate.py new file mode 100644 index 0000000000..f8532b1982 --- /dev/null +++ b/flash/text/seq2seq/core/collate.py @@ -0,0 +1,45 @@ +# Copyright The PyTorch Lightning team. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+from dataclasses import dataclass +from typing import Union + +from flash.core.data.io.input import DataKeys +from flash.core.integrations.transformers.collate import TransformersCollate + + +@dataclass(unsafe_hash=True) +class TextSeq2SeqCollate(TransformersCollate): + max_source_length: int = 128 + max_target_length: int = 128 + padding: Union[str, bool] = "max_length" + + def tokenize(self, sample): + tokenized_sample = self.tokenizer( + sample[DataKeys.INPUT], + max_length=self.max_source_length, + padding=self.padding, + add_special_tokens=True, + truncation=True, + ) + tokenized_sample = tokenized_sample.data + if DataKeys.TARGET in sample: + with self.tokenizer.as_target_tokenizer(): + tokenized_sample[DataKeys.TARGET] = self.tokenizer( + sample[DataKeys.TARGET], + max_length=self.max_target_length, + padding=self.padding, + add_special_tokens=True, + truncation=True, + )["input_ids"] + return tokenized_sample diff --git a/flash/text/seq2seq/core/input.py b/flash/text/seq2seq/core/input.py index e655a3a35b..5bdc775558 100644 --- a/flash/text/seq2seq/core/input.py +++ b/flash/text/seq2seq/core/input.py @@ -11,12 +11,11 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. -from typing import Any, Dict, List, Optional, Union +from typing import List, Optional import flash from flash.core.data.io.input import DataKeys, Input from flash.core.data.utilities.paths import PATH_TYPE -from flash.core.integrations.transformers.states import TransformersBackboneState from flash.core.utilities.imports import _TEXT_AVAILABLE, requires if _TEXT_AVAILABLE: @@ -32,14 +31,7 @@ def load_data( hf_dataset: Dataset, input_key: str, target_key: Optional[str] = None, - max_source_length: int = 128, - max_target_length: int = 128, - padding: Union[str, bool] = "max_length", ) -> Dataset: - self.max_source_length = max_source_length - self.max_target_length = max_target_length - self.padding = padding - # remove extra columns extra_columns = set(hf_dataset.column_names) - {input_key, target_key} hf_dataset = hf_dataset.remove_columns(extra_columns) @@ -56,27 +48,6 @@ def load_data( return hf_dataset - def load_sample(self, sample: Dict[str, Any]) -> Any: - tokenizer = self.get_state(TransformersBackboneState).tokenizer - tokenized_sample = tokenizer( - sample[DataKeys.INPUT], - max_length=self.max_source_length, - padding=self.padding, - add_special_tokens=True, - truncation=True, - ) - tokenized_sample = tokenized_sample.data - if DataKeys.TARGET in sample: - with tokenizer.as_target_tokenizer(): - tokenized_sample[DataKeys.TARGET] = tokenizer( - sample[DataKeys.TARGET], - max_length=self.max_target_length, - padding=self.padding, - add_special_tokens=True, - truncation=True, - )["input_ids"] - return tokenized_sample - class Seq2SeqCSVInput(Seq2SeqInputBase): @requires("text") @@ -85,18 +56,12 @@ def load_data( csv_file: PATH_TYPE, input_key: str, target_key: Optional[str] = None, - max_source_length: int = 128, - max_target_length: int = 128, - padding: Union[str, bool] = "max_length", ) -> Dataset: dataset_dict = load_dataset("csv", data_files={"data": str(csv_file)}) return super().load_data( dataset_dict["data"], input_key, target_key, - max_source_length, - max_target_length, - padding, ) @@ -108,18 +73,12 @@ def load_data( field: str, input_key: str, target_key: Optional[str] = None, - max_source_length: int = 128, - max_target_length: int = 128, - padding: Union[str, bool] 
= "max_length", ) -> Dataset: dataset_dict = load_dataset("json", data_files={"data": str(json_file)}, field=field) return super().load_data( dataset_dict["data"], input_key, target_key, - max_source_length, - max_target_length, - padding, ) @@ -129,9 +88,6 @@ def load_data( self, inputs: List[str], targets: Optional[List[str]] = None, - max_source_length: int = 128, - max_target_length: int = 128, - padding: Union[str, bool] = "max_length", ) -> Dataset: if targets is not None: hf_dataset = Dataset.from_dict({DataKeys.INPUT: inputs, DataKeys.TARGET: targets}) @@ -141,7 +97,4 @@ def load_data( hf_dataset, DataKeys.INPUT, DataKeys.TARGET, - max_source_length, - max_target_length, - padding, ) diff --git a/flash/text/seq2seq/core/model.py b/flash/text/seq2seq/core/model.py index 0a8ba55c80..66f08e7839 100644 --- a/flash/text/seq2seq/core/model.py +++ b/flash/text/seq2seq/core/model.py @@ -22,9 +22,9 @@ from torch.nn import Module from flash.core.data.io.input import DataKeys, ServeInput +from flash.core.data.io.input_transform import InputTransform +from flash.core.data.io.output import Output from flash.core.data.io.output_transform import OutputTransform -from flash.core.integrations.transformers.input_transform import TransformersInputTransform -from flash.core.integrations.transformers.states import TransformersBackboneState from flash.core.model import Task from flash.core.registry import ExternalRegistry, FlashRegistry from flash.core.serve import Composition @@ -39,7 +39,7 @@ ) from flash.text.input import TextDeserializer from flash.text.ort_callback import ORTCallback -from flash.text.seq2seq.core.output_transform import Seq2SeqOutputTransform +from flash.text.seq2seq.core.collate import TextSeq2SeqCollate if _TEXT_AVAILABLE: from transformers import AutoModelForSeq2SeqLM @@ -72,12 +72,15 @@ class Seq2SeqTask(Task): """General Task for Sequence2Sequence. Args: + max_source_length: The maximum length to pad / truncate input sequences to. + max_target_length: The maximum length to pad / truncate target sequences to. + padding: The type of padding to apply. One of: "longest" or ``True``, "max_length", "do_not_pad" or + ``False``. loss_fn: Loss function for training optimizer: Optimizer to use for training. lr_scheduler: The LR scheduler to use during training. metrics: Metrics to compute for training and evaluation. Changing this argument currently has no effect learning_rate: Learning rate to use for training, defaults to `3e-4` - val_target_max_length: Maximum length of targets in validation. Defaults to `128` num_beams: Number of beams to use in validation when generating predictions. 
Defaults to `4` enable_ort: Enable Torch ONNX Runtime Optimization: https://onnxruntime.ai/docs/#onnx-runtime-for-training """ @@ -90,15 +93,17 @@ def __init__( self, backbone: str = "t5-small", tokenizer_kwargs: Optional[Dict[str, Any]] = None, + max_source_length: int = 128, + max_target_length: int = 128, + padding: Union[str, bool] = "max_length", loss_fn: LOSS_FN_TYPE = None, optimizer: OPTIMIZER_TYPE = "Adam", lr_scheduler: LR_SCHEDULER_TYPE = None, metrics: METRICS_TYPE = None, learning_rate: float = 5e-5, - val_target_max_length: Optional[int] = None, num_beams: Optional[int] = None, enable_ort: bool = False, - output_transform: Optional[OutputTransform] = Seq2SeqOutputTransform(), + output_transform: Optional[OutputTransform] = None, ): os.environ["TOKENIZERS_PARALLELISM"] = "TRUE" # disable HF thousand warnings @@ -113,15 +118,23 @@ def __init__( learning_rate=learning_rate, output_transform=output_transform, ) - self.set_state(TransformersBackboneState(backbone, tokenizer_kwargs=tokenizer_kwargs)) + + self.collate_fn = TextSeq2SeqCollate( + backbone=backbone, + tokenizer_kwargs=tokenizer_kwargs, + max_source_length=max_source_length, + max_target_length=max_target_length, + padding=padding, + ) self.model = self.backbones.get(backbone)() self.enable_ort = enable_ort - self.val_target_max_length = val_target_max_length + self.max_source_length = max_source_length + self.max_target_length = max_target_length self.num_beams = num_beams self._initialize_model_specific_parameters() def forward(self, x: Any) -> Any: - max_length = self.val_target_max_length if self.val_target_max_length else self.model.config.max_length + max_length = self.max_target_length num_beams = self.num_beams if self.num_beams else self.model.config.num_beams generated_tokens = self.model.generate( input_ids=x["input_ids"], attention_mask=x["attention_mask"], max_length=max_length, num_beams=num_beams @@ -151,6 +164,10 @@ def validation_step(self, batch: Any, batch_idx: int, dataloader_idx: int = 0): def test_step(self, batch: Any, batch_idx: int, dataloader_idx: int = 0): self.common_step("test", batch) + def predict_step(self, batch: Any, batch_idx: int, dataloader_idx: int = 0) -> Any: + output = super().predict_step(batch, batch_idx, dataloader_idx) + return self.decode(output) + def compute_metrics(self, generated_tokens, batch, prefix): pass @@ -167,9 +184,9 @@ def _initialize_model_specific_parameters(self): rank_zero_info(f"Overriding model paramameters for {self.task} as defined within the model:\n {pars}") self.model.config.update(pars) - def tokenize_labels(self, labels: Tensor) -> List[str]: - label_str = self.get_state(TransformersBackboneState).tokenizer.batch_decode(labels, skip_special_tokens=True) - return [str.strip(s) for s in label_str] + def decode(self, tokens: Tensor) -> List[str]: + decoded_str = self.collate_fn.tokenizer.batch_decode(tokens, skip_special_tokens=True) + return [str.strip(s) for s in decoded_str] def modules_to_freeze(self) -> Union[Module, Iterable[Union[Module, Iterable]]]: """Return the module attributes of the model to be frozen.""" @@ -199,7 +216,8 @@ def serve( port: int = 8000, sanity_check: bool = True, input_cls: Optional[Type[ServeInput]] = TextDeserializer, - transform: INPUT_TRANSFORM_TYPE = TransformersInputTransform, + transform: INPUT_TRANSFORM_TYPE = InputTransform, transform_kwargs: Optional[Dict] = None, + output: Optional[Union[str, Output]] = None, ) -> Composition: - return super().serve(host, port, sanity_check, input_cls, transform, 
transform_kwargs) + return super().serve(host, port, sanity_check, input_cls, transform, transform_kwargs, output) diff --git a/flash/text/seq2seq/core/output_transform.py b/flash/text/seq2seq/core/output_transform.py deleted file mode 100644 index 3bb165b69b..0000000000 --- a/flash/text/seq2seq/core/output_transform.py +++ /dev/null @@ -1,33 +0,0 @@ -# Copyright The PyTorch Lightning team. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. -from typing import Any - -from flash.core.data.io.output_transform import OutputTransform -from flash.core.integrations.transformers.states import TransformersBackboneState -from flash.core.utilities.imports import requires - - -class Seq2SeqOutputTransform(OutputTransform): - def __init__(self): - super().__init__() - - self._backbone = None - self._tokenizer = None - - @requires("text") - def uncollate(self, generated_tokens: Any) -> Any: - tokenizer = self.get_state(TransformersBackboneState).tokenizer - pred_str = tokenizer.batch_decode(generated_tokens, skip_special_tokens=True) - pred_str = [str.strip(s) for s in pred_str] - return pred_str diff --git a/flash/text/seq2seq/summarization/data.py b/flash/text/seq2seq/summarization/data.py index 6be75b6d08..7ad9c5f8a9 100644 --- a/flash/text/seq2seq/summarization/data.py +++ b/flash/text/seq2seq/summarization/data.py @@ -11,18 +11,16 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. 
-from typing import Any, Dict, List, Optional, Type, Union +from typing import Any, Dict, List, Optional, Type from flash.core.data.data_module import DataModule -from flash.core.data.data_pipeline import DataPipelineState from flash.core.data.io.input import Input +from flash.core.data.io.input_transform import InputTransform from flash.core.data.utilities.paths import PATH_TYPE -from flash.core.integrations.transformers.input_transform import TransformersInputTransform from flash.core.utilities.imports import _TEXT_AVAILABLE, _TEXT_TESTING from flash.core.utilities.stages import RunningStage from flash.core.utilities.types import INPUT_TRANSFORM_TYPE from flash.text.seq2seq.core.input import Seq2SeqCSVInput, Seq2SeqInputBase, Seq2SeqJSONInput, Seq2SeqListInput -from flash.text.seq2seq.core.output_transform import Seq2SeqOutputTransform if _TEXT_AVAILABLE: from datasets import Dataset @@ -38,8 +36,7 @@ class SummarizationData(DataModule): """The ``SummarizationData`` class is a :class:`~flash.core.data.data_module.DataModule` with a set of classmethods for loading data for text summarization.""" - input_transform_cls = TransformersInputTransform - output_transform_cls = Seq2SeqOutputTransform + input_transform_cls = InputTransform @classmethod def from_csv( @@ -50,15 +47,12 @@ def from_csv( val_file: Optional[PATH_TYPE] = None, test_file: Optional[PATH_TYPE] = None, predict_file: Optional[PATH_TYPE] = None, - train_transform: INPUT_TRANSFORM_TYPE = TransformersInputTransform, - val_transform: INPUT_TRANSFORM_TYPE = TransformersInputTransform, - test_transform: INPUT_TRANSFORM_TYPE = TransformersInputTransform, - predict_transform: INPUT_TRANSFORM_TYPE = TransformersInputTransform, + train_transform: INPUT_TRANSFORM_TYPE = InputTransform, + val_transform: INPUT_TRANSFORM_TYPE = InputTransform, + test_transform: INPUT_TRANSFORM_TYPE = InputTransform, + predict_transform: INPUT_TRANSFORM_TYPE = InputTransform, input_cls: Type[Input] = Seq2SeqCSVInput, transform_kwargs: Optional[Dict] = None, - max_source_length: int = 128, - max_target_length: int = 128, - padding: Union[str, bool] = "max_length", **data_module_kwargs: Any, ) -> "SummarizationData": """Load the :class:`~flash.text.seq2seq.summarization.data.SummarizationData` from CSV files containing @@ -83,10 +77,6 @@ def from_csv( predicting. input_cls: The :class:`~flash.core.data.io.input.Input` type to use for loading the data. transform_kwargs: Dict of keyword arguments to be provided when instantiating the transforms. - max_source_length: The maximum length to pad / truncate input sequences to. - max_target_length: The maximum length to pad / truncate target sequences to. - padding: The type of padding to apply. One of: "longest" or ``True``, "max_length", "do_not_pad" or - ``False``. data_module_kwargs: Additional keyword arguments to provide to the :class:`~flash.core.data.data_module.DataModule` constructor. 
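A minimal sketch of the resulting split (illustrative only; the file name is hypothetical): the length and padding arguments removed from `from_csv` above are now supplied to the task instead, so the data module only loads raw text:

    from flash.text import SummarizationData, SummarizationTask

    datamodule = SummarizationData.from_csv(
        "input", "target", train_file="train.csv", batch_size=4
    )
    model = SummarizationTask(
        backbone="sshleifer/distilbart-xsum-1-1",
        max_source_length=128,
        max_target_length=128,
        padding="max_length",
    )
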
@@ -152,10 +142,6 @@ def from_csv( ds_kw = dict( input_key=input_field, target_key=target_field, - max_source_length=max_source_length, - max_target_length=max_target_length, - padding=padding, - data_pipeline_state=DataPipelineState(), transform_kwargs=transform_kwargs, input_transforms_registry=cls.input_transforms_registry, ) @@ -177,16 +163,13 @@ def from_json( val_file: Optional[PATH_TYPE] = None, test_file: Optional[PATH_TYPE] = None, predict_file: Optional[PATH_TYPE] = None, - train_transform: INPUT_TRANSFORM_TYPE = TransformersInputTransform, - val_transform: INPUT_TRANSFORM_TYPE = TransformersInputTransform, - test_transform: INPUT_TRANSFORM_TYPE = TransformersInputTransform, - predict_transform: INPUT_TRANSFORM_TYPE = TransformersInputTransform, + train_transform: INPUT_TRANSFORM_TYPE = InputTransform, + val_transform: INPUT_TRANSFORM_TYPE = InputTransform, + test_transform: INPUT_TRANSFORM_TYPE = InputTransform, + predict_transform: INPUT_TRANSFORM_TYPE = InputTransform, input_cls: Type[Input] = Seq2SeqJSONInput, transform_kwargs: Optional[Dict] = None, field: Optional[str] = None, - max_source_length: int = 128, - max_target_length: int = 128, - padding: Union[str, bool] = "max_length", **data_module_kwargs: Any, ) -> "SummarizationData": """Load the :class:`~flash.text.seq2seq.summarization.data.SummarizationData` from JSON files containing @@ -212,10 +195,6 @@ def from_json( input_cls: The :class:`~flash.core.data.io.input.Input` type to use for loading the data. transform_kwargs: Dict of keyword arguments to be provided when instantiating the transforms. field: The field that holds the data in the JSON file. - max_source_length: The maximum length to pad / truncate input sequences to. - max_target_length: The maximum length to pad / truncate target sequences to. - padding: The type of padding to apply. One of: "longest" or ``True``, "max_length", "do_not_pad" or - ``False``. data_module_kwargs: Additional keyword arguments to provide to the :class:`~flash.core.data.data_module.DataModule` constructor. @@ -280,10 +259,6 @@ def from_json( input_key=input_field, target_key=target_field, field=field, - max_source_length=max_source_length, - max_target_length=max_target_length, - padding=padding, - data_pipeline_state=DataPipelineState(), transform_kwargs=transform_kwargs, input_transforms_registry=cls.input_transforms_registry, ) @@ -305,15 +280,12 @@ def from_hf_datasets( val_hf_dataset: Optional[Dataset] = None, test_hf_dataset: Optional[Dataset] = None, predict_hf_dataset: Optional[Dataset] = None, - train_transform: INPUT_TRANSFORM_TYPE = TransformersInputTransform, - val_transform: INPUT_TRANSFORM_TYPE = TransformersInputTransform, - test_transform: INPUT_TRANSFORM_TYPE = TransformersInputTransform, - predict_transform: INPUT_TRANSFORM_TYPE = TransformersInputTransform, + train_transform: INPUT_TRANSFORM_TYPE = InputTransform, + val_transform: INPUT_TRANSFORM_TYPE = InputTransform, + test_transform: INPUT_TRANSFORM_TYPE = InputTransform, + predict_transform: INPUT_TRANSFORM_TYPE = InputTransform, input_cls: Type[Input] = Seq2SeqInputBase, transform_kwargs: Optional[Dict] = None, - max_source_length: int = 128, - max_target_length: int = 128, - padding: Union[str, bool] = "max_length", **data_module_kwargs: Any, ) -> "SummarizationData": """Load the :class:`~flash.text.seq2seq.summarization.data.SummarizationData` from Hugging Face ``Dataset`` @@ -338,10 +310,6 @@ def from_hf_datasets( predicting. 
input_cls: The :class:`~flash.core.data.io.input.Input` type to use for loading the data. transform_kwargs: Dict of keyword arguments to be provided when instantiating the transforms. - max_source_length: The maximum length to pad / truncate input sequences to. - max_target_length: The maximum length to pad / truncate target sequences to. - padding: The type of padding to apply. One of: "longest" or ``True``, "max_length", "do_not_pad" or - ``False``. data_module_kwargs: Additional keyword arguments to provide to the :class:`~flash.core.data.data_module.DataModule` constructor. @@ -390,10 +358,6 @@ def from_hf_datasets( ds_kw = dict( input_key=input_field, target_key=target_field, - max_source_length=max_source_length, - max_target_length=max_target_length, - padding=padding, - data_pipeline_state=DataPipelineState(), transform_kwargs=transform_kwargs, input_transforms_registry=cls.input_transforms_registry, ) @@ -416,15 +380,12 @@ def from_lists( test_data: Optional[List[str]] = None, test_targets: Optional[List[str]] = None, predict_data: Optional[List[str]] = None, - train_transform: INPUT_TRANSFORM_TYPE = TransformersInputTransform, - val_transform: INPUT_TRANSFORM_TYPE = TransformersInputTransform, - test_transform: INPUT_TRANSFORM_TYPE = TransformersInputTransform, - predict_transform: INPUT_TRANSFORM_TYPE = TransformersInputTransform, + train_transform: INPUT_TRANSFORM_TYPE = InputTransform, + val_transform: INPUT_TRANSFORM_TYPE = InputTransform, + test_transform: INPUT_TRANSFORM_TYPE = InputTransform, + predict_transform: INPUT_TRANSFORM_TYPE = InputTransform, input_cls: Type[Input] = Seq2SeqListInput, transform_kwargs: Optional[Dict] = None, - max_source_length: int = 128, - max_target_length: int = 128, - padding: Union[str, bool] = "max_length", **data_module_kwargs: Any, ) -> "SummarizationData": """Load the :class:`~flash.text.seq2seq.summarization.data.SummarizationData` from lists of input text @@ -448,10 +409,6 @@ def from_lists( predicting. input_cls: The :class:`~flash.core.data.io.input.Input` type to use for loading the data. transform_kwargs: Dict of keyword arguments to be provided when instantiating the transforms. - max_source_length: The maximum length to pad / truncate input sequences to. - max_target_length: The maximum length to pad / truncate target sequences to. - padding: The type of padding to apply. One of: "longest" or ``True``, "max_length", "do_not_pad" or - ``False``. data_module_kwargs: Additional keyword arguments to provide to the :class:`~flash.core.data.data_module.DataModule` constructor. @@ -480,10 +437,6 @@ def from_lists( """ ds_kw = dict( - max_source_length=max_source_length, - max_target_length=max_target_length, - padding=padding, - data_pipeline_state=DataPipelineState(), transform_kwargs=transform_kwargs, input_transforms_registry=cls.input_transforms_registry, ) diff --git a/flash/text/seq2seq/summarization/model.py b/flash/text/seq2seq/summarization/model.py index 78ee4301bc..ed4e0cf252 100644 --- a/flash/text/seq2seq/summarization/model.py +++ b/flash/text/seq2seq/summarization/model.py @@ -11,7 +11,7 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. 
-from typing import Any, Dict, List, Optional +from typing import Any, Dict, List, Optional, Union import torch from torchmetrics.text.rouge import ROUGEScore @@ -30,13 +30,16 @@ class SummarizationTask(Seq2SeqTask): Args: backbone: backbone model to use for the task. + max_source_length: The maximum length to pad / truncate input sequences to. + max_target_length: The maximum length to pad / truncate target sequences to. + padding: The type of padding to apply. One of: "longest" or ``True``, "max_length", "do_not_pad" or + ``False``. loss_fn: Loss function for training. optimizer: Optimizer to use for training. lr_scheduler: The LR scheduler to use during training. metrics: Metrics to compute for training and evaluation. Defauls to calculating the ROUGE metric. Changing this argument currently has no effect. learning_rate: Learning rate to use for training, defaults to `3e-4` - val_target_max_length: Maximum length of targets in validation. Defaults to `128` num_beams: Number of beams to use in validation when generating predictions. Defaults to `4` use_stemmer: Whether Porter stemmer should be used to strip word suffixes to improve matching. enable_ort: Enable Torch ONNX Runtime Optimization: https://onnxruntime.ai/docs/#onnx-runtime-for-training @@ -46,12 +49,14 @@ def __init__( self, backbone: str = "sshleifer/distilbart-xsum-1-1", tokenizer_kwargs: Optional[Dict[str, Any]] = None, + max_source_length: int = 128, + max_target_length: int = 128, + padding: Union[str, bool] = "max_length", loss_fn: LOSS_FN_TYPE = None, optimizer: OPTIMIZER_TYPE = "Adam", lr_scheduler: LR_SCHEDULER_TYPE = None, metrics: METRICS_TYPE = None, learning_rate: float = 1e-5, - val_target_max_length: Optional[int] = None, num_beams: Optional[int] = 4, use_stemmer: bool = True, enable_ort: bool = False, @@ -60,12 +65,14 @@ def __init__( super().__init__( backbone=backbone, tokenizer_kwargs=tokenizer_kwargs, + max_source_length=max_source_length, + max_target_length=max_target_length, + padding=padding, loss_fn=loss_fn, optimizer=optimizer, lr_scheduler=lr_scheduler, metrics=metrics, learning_rate=learning_rate, - val_target_max_length=val_target_max_length, num_beams=num_beams, enable_ort=enable_ort, ) @@ -84,8 +91,8 @@ def task(self) -> str: return "summarization" def compute_metrics(self, generated_tokens: torch.Tensor, batch: Dict, prefix: str) -> None: - tgt_lns = self.tokenize_labels(batch["labels"]) - result = self.rouge(self._output_transform.uncollate(generated_tokens), tgt_lns) + tgt_lns = self.decode(batch["labels"]) + result = self.rouge(self.decode(generated_tokens), tgt_lns) self.log_dict(result, on_step=False, on_epoch=True, prog_bar=True) @staticmethod diff --git a/flash/text/seq2seq/translation/data.py b/flash/text/seq2seq/translation/data.py index c160077c43..14b6d62b69 100644 --- a/flash/text/seq2seq/translation/data.py +++ b/flash/text/seq2seq/translation/data.py @@ -11,18 +11,16 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. 
-from typing import Any, Dict, List, Optional, Type, Union +from typing import Any, Dict, List, Optional, Type from flash.core.data.data_module import DataModule -from flash.core.data.data_pipeline import DataPipelineState from flash.core.data.io.input import Input +from flash.core.data.io.input_transform import InputTransform from flash.core.data.utilities.paths import PATH_TYPE -from flash.core.integrations.transformers.input_transform import TransformersInputTransform from flash.core.utilities.imports import _TEXT_AVAILABLE, _TEXT_TESTING from flash.core.utilities.stages import RunningStage from flash.core.utilities.types import INPUT_TRANSFORM_TYPE from flash.text.seq2seq.core.input import Seq2SeqCSVInput, Seq2SeqInputBase, Seq2SeqJSONInput, Seq2SeqListInput -from flash.text.seq2seq.core.output_transform import Seq2SeqOutputTransform if _TEXT_AVAILABLE: from datasets import Dataset @@ -38,8 +36,7 @@ class TranslationData(DataModule): """The ``TranslationData`` class is a :class:`~flash.core.data.data_module.DataModule` with a set of classmethods for loading data for text translation.""" - input_transform_cls = TransformersInputTransform - output_transform_cls = Seq2SeqOutputTransform + input_transform_cls = InputTransform @classmethod def from_csv( @@ -50,15 +47,12 @@ def from_csv( val_file: Optional[PATH_TYPE] = None, test_file: Optional[PATH_TYPE] = None, predict_file: Optional[PATH_TYPE] = None, - train_transform: INPUT_TRANSFORM_TYPE = TransformersInputTransform, - val_transform: INPUT_TRANSFORM_TYPE = TransformersInputTransform, - test_transform: INPUT_TRANSFORM_TYPE = TransformersInputTransform, - predict_transform: INPUT_TRANSFORM_TYPE = TransformersInputTransform, + train_transform: INPUT_TRANSFORM_TYPE = InputTransform, + val_transform: INPUT_TRANSFORM_TYPE = InputTransform, + test_transform: INPUT_TRANSFORM_TYPE = InputTransform, + predict_transform: INPUT_TRANSFORM_TYPE = InputTransform, input_cls: Type[Input] = Seq2SeqCSVInput, transform_kwargs: Optional[Dict] = None, - max_source_length: int = 128, - max_target_length: int = 128, - padding: Union[str, bool] = "max_length", **data_module_kwargs: Any, ) -> "TranslationData": """Load the :class:`~flash.text.seq2seq.translation.data.TranslationData` from CSV files containing input @@ -83,10 +77,6 @@ def from_csv( predicting. input_cls: The :class:`~flash.core.data.io.input.Input` type to use for loading the data. transform_kwargs: Dict of keyword arguments to be provided when instantiating the transforms. - max_source_length: The maximum length to pad / truncate input sequences to. - max_target_length: The maximum length to pad / truncate target sequences to. - padding: The type of padding to apply. One of: "longest" or ``True``, "max_length", "do_not_pad" or - ``False``. data_module_kwargs: Additional keyword arguments to provide to the :class:`~flash.core.data.data_module.DataModule` constructor. 
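The same pattern applies to translation, with one behavioural consequence worth noting: `val_target_max_length` is removed, and `max_target_length` now bounds both the tokenized targets and the generated sequence length in `Seq2SeqTask.forward`. A sketch (illustrative, not from the patch):

    from flash.text import TranslationTask

    model = TranslationTask(
        backbone="t5-small",
        max_source_length=128,
        max_target_length=128,  # also caps generation length during prediction
        padding="max_length",
    )
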
@@ -151,10 +141,6 @@ def from_csv( ds_kw = dict( input_key=input_field, target_key=target_field, - max_source_length=max_source_length, - max_target_length=max_target_length, - padding=padding, - data_pipeline_state=DataPipelineState(), transform_kwargs=transform_kwargs, input_transforms_registry=cls.input_transforms_registry, ) @@ -176,16 +162,13 @@ def from_json( val_file: Optional[PATH_TYPE] = None, test_file: Optional[PATH_TYPE] = None, predict_file: Optional[PATH_TYPE] = None, - train_transform: INPUT_TRANSFORM_TYPE = TransformersInputTransform, - val_transform: INPUT_TRANSFORM_TYPE = TransformersInputTransform, - test_transform: INPUT_TRANSFORM_TYPE = TransformersInputTransform, - predict_transform: INPUT_TRANSFORM_TYPE = TransformersInputTransform, + train_transform: INPUT_TRANSFORM_TYPE = InputTransform, + val_transform: INPUT_TRANSFORM_TYPE = InputTransform, + test_transform: INPUT_TRANSFORM_TYPE = InputTransform, + predict_transform: INPUT_TRANSFORM_TYPE = InputTransform, input_cls: Type[Input] = Seq2SeqJSONInput, transform_kwargs: Optional[Dict] = None, field: Optional[str] = None, - max_source_length: int = 128, - max_target_length: int = 128, - padding: Union[str, bool] = "max_length", **data_module_kwargs: Any, ) -> "TranslationData": """Load the :class:`~flash.text.seq2seq.translation.data.TranslationData` from JSON files containing input @@ -211,10 +194,6 @@ def from_json( input_cls: The :class:`~flash.core.data.io.input.Input` type to use for loading the data. transform_kwargs: Dict of keyword arguments to be provided when instantiating the transforms. field: The field that holds the data in the JSON file. - max_source_length: The maximum length to pad / truncate input sequences to. - max_target_length: The maximum length to pad / truncate target sequences to. - padding: The type of padding to apply. One of: "longest" or ``True``, "max_length", "do_not_pad" or - ``False``. data_module_kwargs: Additional keyword arguments to provide to the :class:`~flash.core.data.data_module.DataModule` constructor. @@ -278,10 +257,6 @@ def from_json( input_key=input_field, target_key=target_field, field=field, - max_source_length=max_source_length, - max_target_length=max_target_length, - padding=padding, - data_pipeline_state=DataPipelineState(), transform_kwargs=transform_kwargs, input_transforms_registry=cls.input_transforms_registry, ) @@ -303,15 +278,12 @@ def from_hf_datasets( val_hf_dataset: Optional[Dataset] = None, test_hf_dataset: Optional[Dataset] = None, predict_hf_dataset: Optional[Dataset] = None, - train_transform: INPUT_TRANSFORM_TYPE = TransformersInputTransform, - val_transform: INPUT_TRANSFORM_TYPE = TransformersInputTransform, - test_transform: INPUT_TRANSFORM_TYPE = TransformersInputTransform, - predict_transform: INPUT_TRANSFORM_TYPE = TransformersInputTransform, + train_transform: INPUT_TRANSFORM_TYPE = InputTransform, + val_transform: INPUT_TRANSFORM_TYPE = InputTransform, + test_transform: INPUT_TRANSFORM_TYPE = InputTransform, + predict_transform: INPUT_TRANSFORM_TYPE = InputTransform, input_cls: Type[Input] = Seq2SeqInputBase, transform_kwargs: Optional[Dict] = None, - max_source_length: int = 128, - max_target_length: int = 128, - padding: Union[str, bool] = "max_length", **data_module_kwargs: Any, ) -> "TranslationData": """Load the :class:`~flash.text.seq2seq.translation.data.TranslationData` from Hugging Face ``Dataset`` @@ -336,10 +308,6 @@ def from_hf_datasets( predicting. 
input_cls: The :class:`~flash.core.data.io.input.Input` type to use for loading the data. transform_kwargs: Dict of keyword arguments to be provided when instantiating the transforms. - max_source_length: The maximum length to pad / truncate input sequences to. - max_target_length: The maximum length to pad / truncate target sequences to. - padding: The type of padding to apply. One of: "longest" or ``True``, "max_length", "do_not_pad" or - ``False``. data_module_kwargs: Additional keyword arguments to provide to the :class:`~flash.core.data.data_module.DataModule` constructor. @@ -388,10 +356,6 @@ def from_hf_datasets( ds_kw = dict( input_key=input_field, target_key=target_field, - max_source_length=max_source_length, - max_target_length=max_target_length, - padding=padding, - data_pipeline_state=DataPipelineState(), transform_kwargs=transform_kwargs, input_transforms_registry=cls.input_transforms_registry, ) @@ -414,15 +378,12 @@ def from_lists( test_data: Optional[List[str]] = None, test_targets: Optional[List[str]] = None, predict_data: Optional[List[str]] = None, - train_transform: INPUT_TRANSFORM_TYPE = TransformersInputTransform, - val_transform: INPUT_TRANSFORM_TYPE = TransformersInputTransform, - test_transform: INPUT_TRANSFORM_TYPE = TransformersInputTransform, - predict_transform: INPUT_TRANSFORM_TYPE = TransformersInputTransform, + train_transform: INPUT_TRANSFORM_TYPE = InputTransform, + val_transform: INPUT_TRANSFORM_TYPE = InputTransform, + test_transform: INPUT_TRANSFORM_TYPE = InputTransform, + predict_transform: INPUT_TRANSFORM_TYPE = InputTransform, input_cls: Type[Input] = Seq2SeqListInput, transform_kwargs: Optional[Dict] = None, - max_source_length: int = 128, - max_target_length: int = 128, - padding: Union[str, bool] = "max_length", **data_module_kwargs: Any, ) -> "TranslationData": """Load the :class:`~flash.text.seq2seq.translation.data.TranslationData` from lists of input text snippets @@ -446,10 +407,6 @@ def from_lists( predicting. input_cls: The :class:`~flash.core.data.io.input.Input` type to use for loading the data. transform_kwargs: Dict of keyword arguments to be provided when instantiating the transforms. - max_source_length: The maximum length to pad / truncate input sequences to. - max_target_length: The maximum length to pad / truncate target sequences to. - padding: The type of padding to apply. One of: "longest" or ``True``, "max_length", "do_not_pad" or - ``False``. data_module_kwargs: Additional keyword arguments to provide to the :class:`~flash.core.data.data_module.DataModule` constructor. @@ -478,10 +435,6 @@ def from_lists( """ ds_kw = dict( - max_source_length=max_source_length, - max_target_length=max_target_length, - padding=padding, - data_pipeline_state=DataPipelineState(), transform_kwargs=transform_kwargs, input_transforms_registry=cls.input_transforms_registry, ) diff --git a/flash/text/seq2seq/translation/model.py b/flash/text/seq2seq/translation/model.py index b656748a74..5c08f8a485 100644 --- a/flash/text/seq2seq/translation/model.py +++ b/flash/text/seq2seq/translation/model.py @@ -11,7 +11,7 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. -from typing import Any, Dict, List, Optional +from typing import Any, Dict, List, Optional, Union from torchmetrics import BLEUScore @@ -29,13 +29,16 @@ class TranslationTask(Seq2SeqTask): Args: backbone: backbone model to use for the task. 
+ max_source_length: The maximum length to pad / truncate input sequences to. + max_target_length: The maximum length to pad / truncate target sequences to. + padding: The type of padding to apply. One of: "longest" or ``True``, "max_length", "do_not_pad" or + ``False``. loss_fn: Loss function for training. optimizer: Optimizer to use for training. lr_scheduler: The LR scheduler to use during training. metrics: Metrics to compute for training and evaluation. Defauls to calculating the BLEU metric. Changing this argument currently has no effect. learning_rate: Learning rate to use for training, defaults to `1e-5` - val_target_max_length: Maximum length of targets in validation. Defaults to `128` num_beams: Number of beams to use in validation when generating predictions. Defaults to `4` n_gram: Maximum n_grams to use in metric calculation. Defaults to `4` smooth: Apply smoothing in BLEU calculation. Defaults to `True` @@ -46,12 +49,14 @@ def __init__( self, backbone: str = "t5-small", tokenizer_kwargs: Optional[Dict[str, Any]] = None, + max_source_length: int = 128, + max_target_length: int = 128, + padding: Union[str, bool] = "max_length", loss_fn: LOSS_FN_TYPE = None, optimizer: OPTIMIZER_TYPE = "Adam", lr_scheduler: LR_SCHEDULER_TYPE = None, metrics: METRICS_TYPE = None, learning_rate: float = 1e-5, - val_target_max_length: Optional[int] = 128, num_beams: Optional[int] = 4, n_gram: bool = 4, smooth: bool = True, @@ -61,12 +66,14 @@ def __init__( super().__init__( backbone=backbone, tokenizer_kwargs=tokenizer_kwargs, + max_source_length=max_source_length, + max_target_length=max_target_length, + padding=padding, loss_fn=loss_fn, optimizer=optimizer, lr_scheduler=lr_scheduler, metrics=metrics, learning_rate=learning_rate, - val_target_max_length=val_target_max_length, num_beams=num_beams, enable_ort=enable_ort, ) @@ -80,11 +87,11 @@ def task(self) -> str: return "translation" def compute_metrics(self, generated_tokens, batch, prefix): - reference_corpus = self.tokenize_labels(batch["labels"]) + reference_corpus = self.decode(batch["labels"]) # wrap targets in list as score expects a list of potential references reference_corpus = [[reference] for reference in reference_corpus] - translate_corpus = self._output_transform.uncollate(generated_tokens) + translate_corpus = self.decode(generated_tokens) translate_corpus = [line for line in translate_corpus] if _TM_GREATER_EQUAL_0_7_0: diff --git a/flash/video/classification/cli.py b/flash/video/classification/cli.py index b4f0fdd709..9754e27b23 100644 --- a/flash/video/classification/cli.py +++ b/flash/video/classification/cli.py @@ -49,6 +49,7 @@ def video_classification(): default_arguments={ "trainer.max_epochs": 1, }, + datamodule_attributes={"num_classes", "labels"}, ) cli.trainer.save_checkpoint("video_classification.pt") diff --git a/flash/video/classification/data.py b/flash/video/classification/data.py index c5ce872240..d15c3be800 100644 --- a/flash/video/classification/data.py +++ b/flash/video/classification/data.py @@ -18,7 +18,6 @@ from torch.utils.data import Sampler from flash.core.data.data_module import DataModule -from flash.core.data.data_pipeline import DataPipelineState from flash.core.data.io.input import Input from flash.core.data.io.input_transform import INPUT_TRANSFORM_TYPE from flash.core.data.utilities.paths import PATH_TYPE @@ -179,7 +178,6 @@ def from_files( """ ds_kw = dict( - data_pipeline_state=DataPipelineState(), transform_kwargs=transform_kwargs, input_transforms_registry=cls.input_transforms_registry, 
clip_sampler=clip_sampler, @@ -189,21 +187,25 @@ def from_files( decoder=decoder, ) + train_input = input_cls( + RunningStage.TRAINING, + train_files, + train_targets, + transform=train_transform, + video_sampler=video_sampler, + **ds_kw, + ) + target_formatter = getattr(train_input, "target_formatter", None) + return cls( - input_cls( - RunningStage.TRAINING, - train_files, - train_targets, - transform=train_transform, - video_sampler=video_sampler, - **ds_kw, - ), + train_input, input_cls( RunningStage.VALIDATING, val_files, val_targets, transform=val_transform, video_sampler=video_sampler, + target_formatter=target_formatter, **ds_kw, ), input_cls( @@ -212,6 +214,7 @@ def from_files( test_targets, transform=test_transform, video_sampler=video_sampler, + target_formatter=target_formatter, **ds_kw, ), predict_input_cls(RunningStage.PREDICTING, predict_files, transform=predict_transform, **ds_kw), @@ -348,7 +351,6 @@ def from_folders( """ ds_kw = dict( - data_pipeline_state=DataPipelineState(), transform_kwargs=transform_kwargs, input_transforms_registry=cls.input_transforms_registry, clip_sampler=clip_sampler, @@ -358,15 +360,28 @@ def from_folders( decoder=decoder, ) + train_input = input_cls( + RunningStage.TRAINING, train_folder, transform=train_transform, video_sampler=video_sampler, **ds_kw + ) + target_formatter = getattr(train_input, "target_formatter", None) + return cls( + train_input, input_cls( - RunningStage.TRAINING, train_folder, transform=train_transform, video_sampler=video_sampler, **ds_kw - ), - input_cls( - RunningStage.VALIDATING, val_folder, transform=val_transform, video_sampler=video_sampler, **ds_kw + RunningStage.VALIDATING, + val_folder, + transform=val_transform, + video_sampler=video_sampler, + target_formatter=target_formatter, + **ds_kw, ), input_cls( - RunningStage.TESTING, test_folder, transform=test_transform, video_sampler=video_sampler, **ds_kw + RunningStage.TESTING, + test_folder, + transform=test_transform, + video_sampler=video_sampler, + target_formatter=target_formatter, + **ds_kw, ), predict_input_cls(RunningStage.PREDICTING, predict_folder, transform=predict_transform, **ds_kw), **data_module_kwargs, @@ -522,7 +537,6 @@ def from_data_frame( """ ds_kw = dict( - data_pipeline_state=DataPipelineState(), transform_kwargs=transform_kwargs, input_transforms_registry=cls.input_transforms_registry, clip_sampler=clip_sampler, @@ -537,14 +551,29 @@ def from_data_frame( test_data = (test_data_frame, input_field, target_fields, test_videos_root, test_resolver) predict_data = (predict_data_frame, input_field, predict_videos_root, predict_resolver) + train_input = input_cls( + RunningStage.TRAINING, *train_data, transform=train_transform, video_sampler=video_sampler, **ds_kw + ) + target_formatter = getattr(train_input, "target_formatter", None) + return cls( + train_input, input_cls( - RunningStage.TRAINING, *train_data, transform=train_transform, video_sampler=video_sampler, **ds_kw + RunningStage.VALIDATING, + *val_data, + transform=val_transform, + video_sampler=video_sampler, + target_formatter=target_formatter, + **ds_kw, ), input_cls( - RunningStage.VALIDATING, *val_data, transform=val_transform, video_sampler=video_sampler, **ds_kw + RunningStage.TESTING, + *test_data, + transform=test_transform, + video_sampler=video_sampler, + target_formatter=target_formatter, + **ds_kw, ), - input_cls(RunningStage.TESTING, *test_data, transform=test_transform, video_sampler=video_sampler, **ds_kw), predict_input_cls(RunningStage.PREDICTING, *predict_data, 
transform=predict_transform, **ds_kw), **data_module_kwargs, ) @@ -713,7 +742,6 @@ def from_csv( """ ds_kw = dict( - data_pipeline_state=DataPipelineState(), transform_kwargs=transform_kwargs, input_transforms_registry=cls.input_transforms_registry, clip_sampler=clip_sampler, @@ -728,14 +756,29 @@ def from_csv( test_data = (test_file, input_field, target_fields, test_videos_root, test_resolver) predict_data = (predict_file, input_field, predict_videos_root, predict_resolver) + train_input = input_cls( + RunningStage.TRAINING, *train_data, transform=train_transform, video_sampler=video_sampler, **ds_kw + ) + target_formatter = getattr(train_input, "target_formatter", None) + return cls( + train_input, input_cls( - RunningStage.TRAINING, *train_data, transform=train_transform, video_sampler=video_sampler, **ds_kw + RunningStage.VALIDATING, + *val_data, + transform=val_transform, + video_sampler=video_sampler, + target_formatter=target_formatter, + **ds_kw, ), input_cls( - RunningStage.VALIDATING, *val_data, transform=val_transform, video_sampler=video_sampler, **ds_kw + RunningStage.TESTING, + *test_data, + transform=test_transform, + video_sampler=video_sampler, + target_formatter=target_formatter, + **ds_kw, ), - input_cls(RunningStage.TESTING, *test_data, transform=test_transform, video_sampler=video_sampler, **ds_kw), predict_input_cls(RunningStage.PREDICTING, *predict_data, transform=predict_transform, **ds_kw), **data_module_kwargs, ) @@ -859,7 +902,6 @@ def from_fiftyone( """ ds_kw = dict( - data_pipeline_state=DataPipelineState(), transform_kwargs=transform_kwargs, input_transforms_registry=cls.input_transforms_registry, clip_sampler=clip_sampler, @@ -869,21 +911,25 @@ def from_fiftyone( decoder=decoder, ) + train_input = input_cls( + RunningStage.TRAINING, + train_dataset, + transform=train_transform, + video_sampler=video_sampler, + label_field=label_field, + **ds_kw, + ) + target_formatter = getattr(train_input, "target_formatter", None) + return cls( - input_cls( - RunningStage.TRAINING, - train_dataset, - transform=train_transform, - video_sampler=video_sampler, - label_field=label_field, - **ds_kw, - ), + train_input, input_cls( RunningStage.VALIDATING, val_dataset, transform=val_transform, video_sampler=video_sampler, label_field=label_field, + target_formatter=target_formatter, **ds_kw, ), input_cls( @@ -892,6 +938,7 @@ def from_fiftyone( transform=test_transform, video_sampler=video_sampler, label_field=label_field, + target_formatter=target_formatter, **ds_kw, ), predict_input_cls(RunningStage.PREDICTING, predict_dataset, transform=predict_transform, **ds_kw), @@ -998,7 +1045,6 @@ def from_labelstudio( ) ds_kw = dict( - data_pipeline_state=DataPipelineState(), transform_kwargs=transform_kwargs, input_transforms_registry=cls.input_transforms_registry, clip_sampler=clip_sampler, @@ -1009,8 +1055,11 @@ def from_labelstudio( decoder=decoder, ) + train_input = input_cls(RunningStage.TRAINING, train_data, transform=train_transform, **ds_kw) + ds_kw["parameters"] = getattr(train_input, "parameters", None) + return cls( - input_cls(RunningStage.TRAINING, train_data, transform=train_transform, **ds_kw), + train_input, input_cls(RunningStage.VALIDATING, val_data, transform=val_transform, **ds_kw), input_cls(RunningStage.TESTING, test_data, transform=test_transform, **ds_kw), input_cls(RunningStage.PREDICTING, predict_data, transform=predict_transform, **ds_kw), diff --git a/flash/video/classification/input.py b/flash/video/classification/input.py index b3320c4583..906af4ef8d 100644 
--- a/flash/video/classification/input.py
+++ b/flash/video/classification/input.py
@@ -19,9 +19,9 @@
 from pytorch_lightning.utilities.exceptions import MisconfigurationException
 from torch.utils.data import Sampler
 
-from flash.core.data.io.classification_input import ClassificationInputMixin, ClassificationState
+from flash.core.data.io.classification_input import ClassificationInputMixin
 from flash.core.data.io.input import DataKeys, Input, IterableInput
-from flash.core.data.utilities.classification import MultiBinaryTargetFormatter
+from flash.core.data.utilities.classification import MultiBinaryTargetFormatter, TargetFormatter
 from flash.core.data.utilities.data_frame import read_csv, resolve_files, resolve_targets
 from flash.core.data.utilities.paths import list_valid_files, make_dataset, PATH_TYPE
 from flash.core.integrations.fiftyone.utils import FiftyOneLabelUtilities
@@ -64,6 +64,7 @@ def load_data(
         video_sampler: Type[Sampler] = torch.utils.data.RandomSampler,
         decode_audio: bool = False,
         decoder: str = "pyav",
+        target_formatter: Optional[TargetFormatter] = None,
     ) -> "LabeledVideoDataset":
         dataset = LabeledVideoDataset(
             LabeledVideoPaths(list(zip(files, targets))),
@@ -73,7 +74,9 @@ def load_data(
             decoder=decoder,
         )
         if not self.predicting:
-            self.load_target_metadata([sample[1] for sample in dataset._labeled_videos._paths_and_labels])
+            self.load_target_metadata(
+                [sample[1] for sample in dataset._labeled_videos._paths_and_labels], target_formatter=target_formatter
+            )
         return dataset
 
     def load_sample(self, sample):
@@ -91,6 +94,7 @@ def load_data(
         video_sampler: Type[Sampler] = torch.utils.data.RandomSampler,
         decode_audio: bool = False,
         decoder: str = "pyav",
+        target_formatter: Optional[TargetFormatter] = None,
     ) -> "LabeledVideoDataset":
         return super().load_data(
             *make_dataset(path, extensions=("mp4", "avi")),
@@ -100,6 +104,7 @@ def load_data(
             video_sampler=video_sampler,
             decode_audio=decode_audio,
             decoder=decoder,
+            target_formatter=target_formatter,
         )
 
 
@@ -114,6 +119,7 @@ def load_data(
         video_sampler: Type[Sampler] = torch.utils.data.RandomSampler,
         decode_audio: bool = False,
         decoder: str = "pyav",
+        target_formatter: Optional[TargetFormatter] = None,
     ) -> "LabeledVideoDataset":
         return super().load_data(
             paths,
@@ -124,6 +130,7 @@ def load_data(
             video_sampler=video_sampler,
             decode_audio=decode_audio,
             decoder=decoder,
+            target_formatter=target_formatter,
         )
 
 
@@ -141,6 +148,7 @@ def load_data(
         video_sampler: Type[Sampler] = torch.utils.data.RandomSampler,
         decode_audio: bool = False,
         decoder: str = "pyav",
+        target_formatter: Optional[TargetFormatter] = None,
     ) -> "LabeledVideoDataset":
         result = super().load_data(
             resolve_files(data_frame, input_key, root, resolver),
@@ -151,6 +159,7 @@ def load_data(
             video_sampler=video_sampler,
             decode_audio=decode_audio,
             decoder=decoder,
+            target_formatter=target_formatter,
         )
 
@@ -159,8 +168,7 @@ def load_data(
         # If we had binary multi-class targets then we also know the labels (column names)
         if (
             self.training
             and isinstance(self.target_formatter, MultiBinaryTargetFormatter)
             and isinstance(target_keys, List)
         ):
-            classification_state = self.get_state(ClassificationState)
-            self.set_state(ClassificationState(target_keys, classification_state.num_classes))
+            self.labels = target_keys
 
         return result
 
@@ -179,6 +187,7 @@ def load_data(
         video_sampler: Type[Sampler] = torch.utils.data.RandomSampler,
         decode_audio: bool = False,
         decoder: str = "pyav",
+        target_formatter: Optional[TargetFormatter] = None,
     ) -> "LabeledVideoDataset":
         data_frame = read_csv(csv_file)
         if root is None:
@@ -195,6 +204,7 @@ def load_data(
             video_sampler=video_sampler,
             decode_audio=decode_audio,
             decoder=decoder,
+            target_formatter=target_formatter,
         )
@@ -210,6 +220,7 @@ def load_data(
         decode_audio: bool = False,
         decoder: str = "pyav",
         label_field: str = "ground_truth",
+        target_formatter: Optional[TargetFormatter] = None,
     ) -> "LabeledVideoDataset":
         label_utilities = FiftyOneLabelUtilities(label_field, fol.Classification)
         label_utilities.validate(sample_collection)
@@ -223,6 +234,7 @@ def load_data(
             video_sampler=video_sampler,
             decode_audio=decode_audio,
             decoder=decoder,
+            target_formatter=target_formatter,
         )
diff --git a/flash/video/classification/model.py b/flash/video/classification/model.py
index 3c049c7030..8228c14178 100644
--- a/flash/video/classification/model.py
+++ b/flash/video/classification/model.py
@@ -22,13 +22,13 @@ from torchmetrics import Accuracy
 
 import flash
-from flash.core.classification import ClassificationTask, LabelsOutput
+from flash.core.classification import ClassificationTask
 from flash.core.data.io.input import DataKeys
 from flash.core.registry import FlashRegistry
 from flash.core.utilities.compatibility import accelerator_connector
 from flash.core.utilities.imports import _PYTORCHVIDEO_AVAILABLE
 from flash.core.utilities.providers import _PYTORCHVIDEO
-from flash.core.utilities.types import LOSS_FN_TYPE, LR_SCHEDULER_TYPE, METRICS_TYPE, OPTIMIZER_TYPE, OUTPUT_TYPE
+from flash.core.utilities.types import LOSS_FN_TYPE, LR_SCHEDULER_TYPE, METRICS_TYPE, OPTIMIZER_TYPE
 
 _VIDEO_CLASSIFIER_BACKBONES = FlashRegistry("backbones")
@@ -63,7 +63,6 @@ class VideoClassifier(ClassificationTask):
         head: either a `nn.Module` or a callable function that converts the features extrated from the backbone
             into class log probabilities (assuming default loss function). If `None`, will default to using a single
             linear layer.
-        output: The :class:`~flash.core.data.io.output.Output` to use when formatting prediction outputs.
     """
 
     backbones: FlashRegistry = _VIDEO_CLASSIFIER_BACKBONES
@@ -72,7 +71,8 @@ class VideoClassifier(ClassificationTask):
 
     def __init__(
         self,
-        num_classes: int,
+        num_classes: Optional[int] = None,
+        labels: Optional[List[str]] = None,
         backbone: Union[str, nn.Module] = "x3d_xs",
         backbone_kwargs: Optional[Dict] = None,
         pretrained: bool = True,
@@ -82,8 +82,12 @@ def __init__(
         metrics: METRICS_TYPE = Accuracy(),
         learning_rate: float = 1e-3,
         head: Optional[Union[FunctionType, nn.Module]] = None,
-        output: OUTPUT_TYPE = None,
     ):
+        self.save_hyperparameters()
+
+        if labels is not None and num_classes is None:
+            num_classes = len(labels)
+
         super().__init__(
             model=None,
             loss_fn=loss_fn,
@@ -91,11 +95,10 @@ def __init__(
             lr_scheduler=lr_scheduler,
             metrics=metrics,
             learning_rate=learning_rate,
-            output=output or LabelsOutput(),
+            num_classes=num_classes,
+            labels=labels,
         )
 
-        self.save_hyperparameters()
-
         if not backbone_kwargs:
             backbone_kwargs = {}
diff --git a/flash_examples/audio_classification.py b/flash_examples/audio_classification.py
index 3302b1b9a2..51415b9e94 100644
--- a/flash_examples/audio_classification.py
+++ b/flash_examples/audio_classification.py
@@ -29,13 +29,13 @@
 )
 
 # 2. Build the model.
-model = ImageClassifier(backbone="resnet18", num_classes=datamodule.num_classes)
+model = ImageClassifier(backbone="resnet18", labels=datamodule.labels)
 
 # 3. Create the trainer and finetune the model
 trainer = flash.Trainer(max_epochs=3, gpus=torch.cuda.device_count())
 trainer.finetune(model, datamodule=datamodule, strategy=("freeze_unfreeze", 1))
 
-# 4. Predict what's on few images! air_conditioner, children_playing, siren e.t.c
+# 4. Predict what's on few images! air_conditioner, children_playing, siren etc.
 datamodule = AudioClassificationData.from_files(
     predict_files=[
         "data/urban8k_images/test/air_conditioner/13230-0-0-5.wav.jpg",
@@ -44,7 +44,7 @@
     ],
     batch_size=3,
 )
-predictions = trainer.predict(model, datamodule=datamodule)
+predictions = trainer.predict(model, datamodule=datamodule, output="labels")
 print(predictions)
 
 # 5. Save the model!
diff --git a/flash_examples/flash_components/custom_data_loading.py b/flash_examples/flash_components/custom_data_loading.py
index bf3f66ddef..5f299a6457 100644
--- a/flash_examples/flash_components/custom_data_loading.py
+++ b/flash_examples/flash_components/custom_data_loading.py
@@ -24,7 +24,6 @@
 from flash import _PACKAGE_ROOT, RunningStage
 from flash.core.data.data_module import DataModule
-from flash.core.data.data_pipeline import DataPipelineState
 from flash.core.data.io.input import DataKeys, Input
 from flash.core.data.io.input_transform import INPUT_TRANSFORM_TYPE, InputTransform
 from flash.core.data.utils import download_data
@@ -289,13 +288,11 @@ def from_multiple_folders(
         **data_module_kwargs: Any,
     ) -> "ImageClassificationDataModule":
 
-        kw = dict(data_pipeline_state=DataPipelineState())
-
         return cls(
-            MultipleFoldersImageInput(RunningStage.TRAINING, train_folders, transform=train_transform, **kw),
-            MultipleFoldersImageInput(RunningStage.VALIDATING, val_folders, transform=val_transform, **kw),
-            MultipleFoldersImageInput(RunningStage.VALIDATING, test_folders, transform=test_transform, **kw),
-            MultipleFoldersImageInput(RunningStage.PREDICTING, predict_folder, transform=predict_transform, **kw),
+            MultipleFoldersImageInput(RunningStage.TRAINING, train_folders, transform=train_transform),
+            MultipleFoldersImageInput(RunningStage.VALIDATING, val_folders, transform=val_transform),
+            MultipleFoldersImageInput(RunningStage.VALIDATING, test_folders, transform=test_transform),
+            MultipleFoldersImageInput(RunningStage.PREDICTING, predict_folder, transform=predict_transform),
             **data_module_kwargs,
         )
diff --git a/flash_examples/graph_classification.py b/flash_examples/graph_classification.py
index 50a0b8cfe4..18212ba546 100644
--- a/flash_examples/graph_classification.py
+++ b/flash_examples/graph_classification.py
@@ -44,7 +44,7 @@
     predict_dataset=dataset[:3],
     batch_size=4,
 )
-predictions = trainer.predict(model, datamodule=datamodule)
+predictions = trainer.predict(model, datamodule=datamodule, output="classes")
 print(predictions)
 
 # 5. Save the model!
diff --git a/flash_examples/image_classification.py b/flash_examples/image_classification.py
index 8157ca1b57..aa096b80e4 100644
--- a/flash_examples/image_classification.py
+++ b/flash_examples/image_classification.py
@@ -28,7 +28,7 @@
 )
 
 # 2. Build the task
-model = ImageClassifier(backbone="resnet18", num_classes=datamodule.num_classes)
+model = ImageClassifier(backbone="resnet18", labels=datamodule.labels)
 
 # 3. Create the trainer and finetune the model
 trainer = flash.Trainer(max_epochs=3, gpus=torch.cuda.device_count())
@@ -43,7 +43,7 @@
     ],
     batch_size=3,
 )
-predictions = trainer.predict(model, datamodule=datamodule)
+predictions = trainer.predict(model, datamodule=datamodule, output="labels")
 print(predictions)
 
 # 5. Save the model!
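The example migrations above all follow the same two-step pattern: build the task from `datamodule.labels` rather than `num_classes`, and select the output format per `predict` call rather than attaching an `Output` to the model. A minimal end-to-end sketch of the new pattern (assuming the hymenoptera data from the image classification example has already been downloaded):

```python
import torch

import flash
from flash.image import ImageClassificationData, ImageClassifier

datamodule = ImageClassificationData.from_folders(
    train_folder="data/hymenoptera_data/train/",
    val_folder="data/hymenoptera_data/val/",
    batch_size=4,
)

# The task now takes the label names directly; num_classes is derived from them.
model = ImageClassifier(backbone="resnet18", labels=datamodule.labels)

trainer = flash.Trainer(max_epochs=1, gpus=torch.cuda.device_count())
trainer.finetune(model, datamodule=datamodule, strategy="freeze")

# The Output is requested per call (here by name) instead of being attached
# to the model up front.
predict_datamodule = ImageClassificationData.from_files(
    predict_files=["data/hymenoptera_data/val/bees/65038344_52a45d090d.jpg"],
    batch_size=1,
)
predictions = trainer.predict(model, datamodule=predict_datamodule, output="labels")
print(predictions)
```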
diff --git a/flash_examples/image_classification_multi_label.py b/flash_examples/image_classification_multi_label.py
index 7744c3c3d6..d3b7371c4a 100644
--- a/flash_examples/image_classification_multi_label.py
+++ b/flash_examples/image_classification_multi_label.py
@@ -41,7 +41,7 @@ def resolver(root, file_id):
 )
 
 # 2. Build the task
-model = ImageClassifier(backbone="resnet18", num_classes=datamodule.num_classes, multi_label=datamodule.multi_label)
+model = ImageClassifier(backbone="resnet18", labels=datamodule.labels, multi_label=datamodule.multi_label)
 
 # 3. Create the trainer and finetune the model
 trainer = flash.Trainer(max_epochs=3, gpus=torch.cuda.device_count())
@@ -56,7 +56,7 @@ def resolver(root, file_id):
     ],
     batch_size=3,
 )
-predictions = trainer.predict(model, datamodule=datamodule)
+predictions = trainer.predict(model, datamodule=datamodule, output="labels")
 print(predictions)
 
 # 5. Save the model!
diff --git a/flash_examples/integrations/baal/image_classification_active_learning.py b/flash_examples/integrations/baal/image_classification_active_learning.py
index f56cd975a6..a3a97b7467 100644
--- a/flash_examples/integrations/baal/image_classification_active_learning.py
+++ b/flash_examples/integrations/baal/image_classification_active_learning.py
@@ -14,7 +14,6 @@
 import torch
 
 import flash
-from flash.core.classification import ProbabilitiesOutput
 from flash.core.data.utils import download_data
 from flash.image import ImageClassificationData, ImageClassifier
 from flash.image.classification.integrations.baal import ActiveLearningDataModule, ActiveLearningLoop
@@ -33,9 +32,7 @@
     torch.nn.Dropout(p=0.1),
     torch.nn.Linear(512, datamodule.num_classes),
 )
-model = ImageClassifier(
-    backbone="resnet18", head=head, num_classes=datamodule.num_classes, output=ProbabilitiesOutput()
-)
+model = ImageClassifier(backbone="resnet18", head=head, num_classes=datamodule.num_classes)
 
 # 3.1 Create the trainer
 trainer = flash.Trainer(max_epochs=3)
@@ -53,7 +50,7 @@
     predict_files=["data/hymenoptera_data/val/bees/65038344_52a45d090d.jpg"],
     batch_size=1,
 )
-predictions = trainer.predict(model, datamodule=datamodule)
+predictions = trainer.predict(model, datamodule=datamodule, output="probabilities")
 print(predictions)
 
 # 5. Save the model!
diff --git a/flash_examples/integrations/fiftyone/image_classification.py b/flash_examples/integrations/fiftyone/image_classification.py
index f5415565a2..458f59220d 100644
--- a/flash_examples/integrations/fiftyone/image_classification.py
+++ b/flash_examples/integrations/fiftyone/image_classification.py
@@ -11,12 +11,9 @@
 # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 # See the License for the specific language governing permissions and
 # limitations under the License.
-from itertools import chain
-
 import torch
 
 import flash
-from flash.core.classification import FiftyOneLabelsOutput, LabelsOutput
 from flash.core.data.utils import download_data
 from flash.core.integrations.fiftyone import visualize
 from flash.image import ImageClassificationData, ImageClassifier
@@ -30,20 +27,18 @@
     val_folder="data/hymenoptera_data/val/",
     test_folder="data/hymenoptera_data/test/",
     predict_folder="data/hymenoptera_data/predict/",
-    batch_size=1,
+    batch_size=16,
 )
 
 # 3 Fine tune a model
 model = ImageClassifier(
     backbone="resnet18",
-    num_classes=datamodule.num_classes,
-    output=LabelsOutput(),
+    labels=datamodule.labels,
 )
 trainer = flash.Trainer(
     max_epochs=1,
     gpus=torch.cuda.device_count(),
-    limit_train_batches=1,
-    limit_val_batches=1,
+    fast_dev_run=True,
 )
 trainer.finetune(
     model,
@@ -53,13 +48,9 @@
 trainer.save_checkpoint("image_classification_model.pt")
 
 # 4 Predict from checkpoint
-model = ImageClassifier.load_from_checkpoint(
-    "https://flash-weights.s3.amazonaws.com/0.7.0/image_classification_model.pt"
-)
-model.output = FiftyOneLabelsOutput(return_filepath=True)  # output FiftyOne format
-predictions = trainer.predict(model, datamodule=datamodule)
-predictions = list(chain.from_iterable(predictions))  # flatten batches
+model = ImageClassifier.load_from_checkpoint("image_classification_model.pt")
+predictions = trainer.predict(model, datamodule=datamodule, output="fiftyone")  # output FiftyOne format
 
 # 5 Visualize predictions in FiftyOne App
 # Optional: pass `wait=True` to block execution until App is closed
-session = visualize(predictions)
+session = visualize(predictions, wait=True)
diff --git a/flash_examples/integrations/fiftyone/image_classification_fiftyone_datasets.py b/flash_examples/integrations/fiftyone/image_classification_fiftyone_datasets.py
index 8e073fed96..b4a4f474f4 100644
--- a/flash_examples/integrations/fiftyone/image_classification_fiftyone_datasets.py
+++ b/flash_examples/integrations/fiftyone/image_classification_fiftyone_datasets.py
@@ -17,7 +17,7 @@
 import torch
 
 import flash
-from flash.core.classification import FiftyOneLabelsOutput, LabelsOutput
+from flash.core.classification import FiftyOneLabelsOutput
 from flash.core.data.utils import download_data
 from flash.image import ImageClassificationData, ImageClassifier
@@ -43,13 +43,13 @@
     train_dataset=train_dataset,
     val_dataset=val_dataset,
     test_dataset=test_dataset,
+    batch_size=4,
 )
 
 # 4 Fine tune model
 model = ImageClassifier(
     backbone="resnet18",
-    num_classes=datamodule.num_classes,
-    output=LabelsOutput(),
+    labels=datamodule.labels,
 )
 trainer = flash.Trainer(
     max_epochs=1,
@@ -65,13 +65,12 @@
 trainer.save_checkpoint("image_classification_model.pt")
 
 # 5 Predict from checkpoint on data with ground truth
-model = ImageClassifier.load_from_checkpoint(
-    "https://flash-weights.s3.amazonaws.com/0.7.0/image_classification_model.pt"
-)
-model.output = FiftyOneLabelsOutput(return_filepath=False)  # output FiftyOne format
-datamodule = ImageClassificationData.from_fiftyone(predict_dataset=test_dataset)
-predictions = trainer.predict(model, datamodule=datamodule)
-predictions = list(chain.from_iterable(predictions))  # flatten batches
+model = ImageClassifier.load_from_checkpoint("image_classification_model.pt")
+datamodule = ImageClassificationData.from_fiftyone(predict_dataset=test_dataset, batch_size=4)
+predictions = trainer.predict(
+    model, datamodule=datamodule, output=FiftyOneLabelsOutput(model.labels, return_filepath=False)
+)  # output FiftyOne format
+predictions = list(chain.from_iterable(predictions))
 
 # 6 Add predictions to dataset
 test_dataset.set_values("predictions", predictions)
diff --git a/flash_examples/integrations/fiftyone/object_detection.py b/flash_examples/integrations/fiftyone/object_detection.py
index 755cc91be6..cff203e94f 100644
--- a/flash_examples/integrations/fiftyone/object_detection.py
+++ b/flash_examples/integrations/fiftyone/object_detection.py
@@ -12,13 +12,11 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 import os
-from itertools import chain
 
 import flash
 from flash.core.integrations.fiftyone import visualize
 from flash.core.utilities.imports import example_requires
 from flash.image import ObjectDetectionData, ObjectDetector
-from flash.image.detection.output import FiftyOneDetectionLabelsOutput
 
 example_requires("image")
@@ -40,7 +38,7 @@
 model = ObjectDetector(
     head="efficientdet",
     backbone="d0",
-    num_classes=datamodule.num_classes,
+    labels=datamodule.labels,
     image_size=128,
     lr_scheduler=("multisteplr", {"milestones": [20]}),
 )
@@ -50,9 +48,7 @@
 trainer.finetune(model, datamodule=datamodule, strategy="freeze")
 
 # 4. Set the output and get some predictions
-model.output = FiftyOneDetectionLabelsOutput(return_filepath=True)  # output FiftyOne format
-predictions = trainer.predict(model, datamodule=datamodule)
-predictions = list(chain.from_iterable(predictions))  # flatten batches
+predictions = trainer.predict(model, datamodule=datamodule, output="fiftyone")  # output FiftyOne format
 
 # 5. Visualize predictions in FiftyOne app
 # Optional: pass `wait=True` to block execution until App is closed
diff --git a/flash_examples/question_answering.py b/flash_examples/question_answering.py
index dd777433b2..0306efe92f 100644
--- a/flash_examples/question_answering.py
+++ b/flash_examples/question_answering.py
@@ -25,7 +25,7 @@
 )
 
 # 2. Build the task
-model = QuestionAnsweringTask()
+model = QuestionAnsweringTask(backbone="distilbert-base-uncased")
 
 # 3. Create the trainer and finetune the model
 trainer = Trainer(max_epochs=3)
diff --git a/flash_examples/tabular_classification.py b/flash_examples/tabular_classification.py
index ae6d80e052..f587ee147e 100644
--- a/flash_examples/tabular_classification.py
+++ b/flash_examples/tabular_classification.py
@@ -42,7 +42,7 @@
     parameters=datamodule.parameters,
     batch_size=8,
 )
-predictions = trainer.predict(model, datamodule=datamodule)
+predictions = trainer.predict(model, datamodule=datamodule, output="classes")
 print(predictions)
 
 # 5. Save the model!
diff --git a/flash_examples/template.py b/flash_examples/template.py
index e9bfc8cb04..2af11ecf60 100644
--- a/flash_examples/template.py
+++ b/flash_examples/template.py
@@ -26,7 +26,7 @@
 )
 
 # 2. Build the task
-model = TemplateSKLearnClassifier(num_features=datamodule.num_features, num_classes=datamodule.num_classes, output=None)
+model = TemplateSKLearnClassifier(num_features=datamodule.num_features, num_classes=datamodule.num_classes)
 
 # 3. Create the trainer and train the model
 trainer = flash.Trainer(max_epochs=3, gpus=torch.cuda.device_count())
@@ -41,7 +41,7 @@
     ],
     batch_size=4,
 )
-predictions = trainer.predict(model, datamodule=datamodule)
+predictions = trainer.predict(model, datamodule=datamodule, output="classes")
 print(predictions)
 
 # 5. Save the model!
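As the FiftyOne examples above show, `output` accepts either a registered name or a constructed `Output` instance; the instance form is the one to reach for when the output needs constructor arguments. A short sketch of both forms, assuming a classification `model`, `trainer`, and predict `datamodule` like those built in the preceding examples:

```python
from flash.core.classification import FiftyOneLabelsOutput

# By name: the string selects one of the outputs registered for the task.
predictions = trainer.predict(model, datamodule=datamodule, output="labels")

# By instance: pass arguments explicitly, e.g. FiftyOne labels without
# attaching file paths to each prediction.
predictions = trainer.predict(
    model,
    datamodule=datamodule,
    output=FiftyOneLabelsOutput(model.labels, return_filepath=False),
)
```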
diff --git a/flash_examples/text_classification.py b/flash_examples/text_classification.py
index b870f47b4e..303be0e3b0 100644
--- a/flash_examples/text_classification.py
+++ b/flash_examples/text_classification.py
@@ -29,7 +29,7 @@
 )
 
 # 2. Build the task
-model = TextClassifier(backbone="prajjwal1/bert-medium", num_classes=datamodule.num_classes)
+model = TextClassifier(backbone="prajjwal1/bert-medium", labels=datamodule.labels)
 
 # 3. Create the trainer and finetune the model
 trainer = flash.Trainer(max_epochs=3, gpus=torch.cuda.device_count())
@@ -44,7 +44,7 @@
     ],
     batch_size=4,
 )
-predictions = trainer.predict(model, datamodule=datamodule)
+predictions = trainer.predict(model, datamodule=datamodule, output="labels")
 print(predictions)
 
 # 5. Save the model!
diff --git a/flash_examples/text_classification_multi_label.py b/flash_examples/text_classification_multi_label.py
index d9cc50eb98..d5dce1e4f9 100644
--- a/flash_examples/text_classification_multi_label.py
+++ b/flash_examples/text_classification_multi_label.py
@@ -27,12 +27,13 @@
     ["toxic", "severe_toxic", "obscene", "threat", "insult", "identity_hate"],
     train_file="data/jigsaw_toxic_comments/train.csv",
     val_split=0.1,
+    batch_size=4,
 )
 
 # 2. Build the task
 model = TextClassifier(
     backbone="unitary/toxic-bert",
-    num_classes=datamodule.num_classes,
+    labels=datamodule.labels,
     multi_label=datamodule.multi_label,
 )
@@ -46,9 +47,10 @@
         "No, he is an arrogant, self serving, immature idiot. Get it right.",
         "U SUCK HANNAH MONTANA",
         "Would you care to vote? Thx.",
-    ]
+    ],
+    batch_size=4,
 )
-predictions = trainer.predict(model, datamodule=datamodule)
+predictions = trainer.predict(model, datamodule=datamodule, output="labels")
 print(predictions)
 
 # 5. Save the model!
diff --git a/flash_examples/video_classification.py b/flash_examples/video_classification.py
index b458f23e6d..ace5a6615d 100644
--- a/flash_examples/video_classification.py
+++ b/flash_examples/video_classification.py
@@ -31,15 +31,15 @@
 )
 
 # 2. Build the task
-model = VideoClassifier(backbone="x3d_xs", num_classes=datamodule.num_classes, pretrained=False)
+model = VideoClassifier(backbone="x3d_xs", labels=datamodule.labels, pretrained=False)
 
 # 3. Create the trainer and finetune the model
-trainer = flash.Trainer(max_epochs=3, gpus=torch.cuda.device_count())
+trainer = flash.Trainer(max_epochs=3, gpus=torch.cuda.device_count(), fast_dev_run=True)
 trainer.finetune(model, datamodule=datamodule, strategy="freeze")
 
 # 4. Make a prediction
 datamodule = VideoClassificationData.from_folders(predict_folder="data/kinetics/predict", batch_size=1)
-predictions = trainer.predict(model, datamodule=datamodule)
+predictions = trainer.predict(model, datamodule=datamodule, output="labels")
 print(predictions)
 
 # 5. Save the model!
diff --git a/tests/core/data/io/test_output.py b/tests/core/data/io/test_output.py
index 391570bb3b..580f08134f 100644
--- a/tests/core/data/io/test_output.py
+++ b/tests/core/data/io/test_output.py
@@ -11,20 +11,9 @@
 # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 # See the License for the specific language governing permissions and
 # limitations under the License.
-import os
 from unittest.mock import Mock
 
-import torch
-from torch.utils.data import DataLoader
-
-from flash import RunningStage
-from flash.core.classification import LabelsOutput
-from flash.core.data.data_pipeline import DataPipeline, DataPipelineState
-from flash.core.data.io.classification_input import ClassificationState
-from flash.core.data.io.input_transform import InputTransform
 from flash.core.data.io.output import Output
-from flash.core.model import Task
-from flash.core.trainer import Trainer
 
 
 def test_output():
@@ -36,25 +25,3 @@ def test_output():
     my_output.transform = Mock()
     my_output("test")
     my_output.transform.assert_called_once()
-
-
-def test_saving_with_output(tmpdir):
-    checkpoint_file = os.path.join(tmpdir, "tmp.ckpt")
-
-    class CustomModel(Task):
-        def __init__(self):
-            super().__init__(model=torch.nn.Linear(1, 1), loss_fn=torch.nn.MSELoss())
-
-    output = LabelsOutput(["a", "b"])
-    model = CustomModel()
-    trainer = Trainer(fast_dev_run=True)
-    data_pipeline = DataPipeline(input_transform=InputTransform(RunningStage.TRAINING), output=output)
-    data_pipeline.initialize()
-    model.data_pipeline = data_pipeline
-    assert isinstance(model.input_transform, InputTransform)
-    dummy_data = DataLoader(list(zip(torch.arange(10, dtype=torch.float), torch.arange(10, dtype=torch.float))))
-    trainer.fit(model, train_dataloader=dummy_data)
-    trainer.save_checkpoint(checkpoint_file)
-    model = CustomModel.load_from_checkpoint(checkpoint_file)
-    assert isinstance(model._data_pipeline_state, DataPipelineState)
-    assert model._data_pipeline_state._state[ClassificationState] == ClassificationState(["a", "b"])
diff --git a/tests/core/data/io/test_output_transform.py b/tests/core/data/io/test_output_transform.py
index 5fb2b17d69..eab5054526 100644
--- a/tests/core/data/io/test_output_transform.py
+++ b/tests/core/data/io/test_output_transform.py
@@ -13,21 +13,19 @@
 # limitations under the License.
 import torch
 
-from flash.core.data.batch import default_uncollate
-from flash.core.data.io.output_transform import _OutputTransformProcessor
+from flash.core.data.io.output_transform import OutputTransform
 
 
-def test_output_transform_processor_str():
-    output_transform_processor = _OutputTransformProcessor(
-        default_uncollate,
-        torch.relu,
-        torch.softmax,
-        None,
-    )
-    assert str(output_transform_processor) == (
-        "_OutputTransformProcessor:\n"
-        "\t(per_batch_transform): FuncModule(relu)\n"
-        "\t(uncollate_fn): FuncModule(default_uncollate)\n"
-        "\t(per_sample_transform): FuncModule(softmax)\n"
-        "\t(output): None"
-    )
+def test_output_transform():
+    class CustomOutputTransform(OutputTransform):
+        @staticmethod
+        def per_batch_transform(batch):
+            return batch * 2
+
+        @staticmethod
+        def per_sample_transform(sample):
+            return sample + 1
+
+    output_transform = CustomOutputTransform()
+    transformed = output_transform(torch.ones(10))
+    assert all(torch.isclose(t, torch.tensor(3.0)) for t in transformed)
diff --git a/tests/core/data/test_data_module.py b/tests/core/data/test_data_module.py
index be113a6fdc..6f00fa8ab6 100644
--- a/tests/core/data/test_data_module.py
+++ b/tests/core/data/test_data_module.py
@@ -25,7 +25,6 @@
 from flash.core.data.data_module import DataModule, DatasetInput
 from flash.core.data.io.input import Input
 from flash.core.data.io.input_transform import InputTransform
-from flash.core.data.states import PerBatchTransformOnDevice, PerSampleTransform
 from flash.core.utilities.imports import _IMAGE_TESTING, _TORCHVISION_AVAILABLE
 from flash.core.utilities.stages import RunningStage
@@ -380,16 +379,17 @@ def per_sample_transform(self) -> Callable:
         def per_batch_transform_on_device(self) -> Callable:
             return T.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])
 
+    class OverrideInputTransform(InputTransform):
+        def per_sample_transform(self) -> Callable:
+            return T.Compose([T.ToTensor(), T.Resize(128)])
+
     # define task which overrides transforms using set_state
     class CustomModel(Task):
         def __init__(self):
             super().__init__(model=torch.nn.Linear(1, 1), loss_fn=torch.nn.MSELoss())
 
             # override default transform to resize images
-            self.set_state(PerSampleTransform(T.Compose([T.ToTensor(), T.Resize(128)])))
-
-            # remove normalization, => image still in [0, 1] range
-            self.set_state(PerBatchTransformOnDevice(None))
+            self.input_transform = OverrideInputTransform
 
         def training_step(self, batch, batch_idx):
             assert batch.shape == torch.Size([2, 3, 128, 128])
diff --git a/tests/core/data/test_data_pipeline.py b/tests/core/data/test_data_pipeline.py
index 09baaeea33..d1d720783e 100644
--- a/tests/core/data/test_data_pipeline.py
+++ b/tests/core/data/test_data_pipeline.py
@@ -11,61 +11,14 @@
 # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 # See the License for the specific language governing permissions and
 # limitations under the License.
-from typing import cast, Tuple
-
 import pytest
-import torch
 from pytorch_lightning.utilities.exceptions import MisconfigurationException
-from torch import Tensor
 
-from flash.core.data.data_pipeline import DataPipeline, DataPipelineState
-from flash.core.data.io.input import Input
+from flash.core.data.data_pipeline import DataPipeline
 from flash.core.data.io.input_transform import InputTransform
-from flash.core.data.io.output import Output
-from flash.core.data.io.output_transform import OutputTransform
-from flash.core.data.process import Deserializer
-from flash.core.data.properties import ProcessState
 from flash.core.utilities.stages import RunningStage
 
 
-class DummyDataset(torch.utils.data.Dataset):
-    def __getitem__(self, index: int) -> Tuple[Tensor, Tensor]:
-        return torch.rand(1), torch.rand(1)
-
-    def __len__(self) -> int:
-        return 5
-
-
-class TestDataPipelineState:
-    @staticmethod
-    def test_str():
-        state = DataPipelineState()
-        state.set_state(ProcessState())
-
-        assert str(state) == (
-            "DataPipelineState(state={: ProcessState()})"
-        )
-
-    @staticmethod
-    def test_get_state():
-        state = DataPipelineState()
-        assert state.get_state(ProcessState) is None
-
-
-def test_data_pipeline_str():
-    data_pipeline = DataPipeline(
-        input=cast(Input, "input"),
-        input_transform=cast(InputTransform, "input_transform"),
-        output_transform=cast(OutputTransform, "output_transform"),
-        output=cast(Output, "output"),
-        deserializer=cast(Deserializer, "deserializer"),
-    )
-
-    expected = "input=input, deserializer=deserializer, "
-    expected += "input_transform=input_transform, output_transform=output_transform, output=output"
-    assert str(data_pipeline) == (f"DataPipeline({expected})")
-
-
 def test_is_overridden_recursive(tmpdir):
     class TestInputTransform(InputTransform):
         @staticmethod
diff --git a/tests/core/data/test_properties.py b/tests/core/data/test_properties.py
index 06918624ba..f2adf2b08e 100644
--- a/tests/core/data/test_properties.py
+++ b/tests/core/data/test_properties.py
@@ -13,38 +13,10 @@
 # limitations under the License.
 import pytest
 
-from flash.core.data.data_pipeline import DataPipelineState
-from flash.core.data.properties import ProcessState, Properties
+from flash.core.data.properties import Properties
 from flash.core.utilities.stages import RunningStage
 
 
-def test_properties_data_pipeline_state():
-    """Tests that ``get_state`` and ``set_state`` work for properties and that ``DataPipelineState`` is attached
-    correctly."""
-
-    class MyProcessState1(ProcessState):
-        pass
-
-    class MyProcessState2(ProcessState):
-        pass
-
-    class OtherProcessState(ProcessState):
-        pass
-
-    my_properties = Properties()
-    my_properties.set_state(MyProcessState1())
-    assert my_properties._state == {MyProcessState1: MyProcessState1()}
-    assert my_properties.get_state(OtherProcessState) is None
-
-    data_pipeline_state = DataPipelineState()
-    data_pipeline_state.set_state(OtherProcessState())
-    my_properties.attach_data_pipeline_state(data_pipeline_state)
-    assert my_properties.get_state(OtherProcessState) == OtherProcessState()
-
-    my_properties.set_state(MyProcessState2())
-    assert data_pipeline_state.get_state(MyProcessState2) == MyProcessState2()
-
-
 @pytest.mark.parametrize(
     "running_stage", [RunningStage.TRAINING, RunningStage.VALIDATING, RunningStage.TESTING, RunningStage.PREDICTING]
 )
diff --git a/tests/core/data/test_serialization.py b/tests/core/data/test_serialization.py
deleted file mode 100644
index 58d7c72f3b..0000000000
--- a/tests/core/data/test_serialization.py
+++ /dev/null
@@ -1,52 +0,0 @@
-# Copyright The PyTorch Lightning team.
-#
-# Licensed under the Apache License, Version 2.0 (the "License");
-# you may not use this file except in compliance with the License.
-# You may obtain a copy of the License at
-#
-#     http://www.apache.org/licenses/LICENSE-2.0
-#
-# Unless required by applicable law or agreed to in writing, software
-# distributed under the License is distributed on an "AS IS" BASIS,
-# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-# See the License for the specific language governing permissions and
-# limitations under the License.
-import os - -import torch -from pytorch_lightning import Trainer -from pytorch_lightning.callbacks import ModelCheckpoint -from torch.utils.data.dataloader import DataLoader - -from flash.core.model import Task - - -class CustomModel(Task): - def __init__(self): - super().__init__(model=torch.nn.Linear(1, 1), loss_fn=torch.nn.MSELoss()) - - -def test_serialization_data_pipeline(tmpdir): - model = CustomModel() - - checkpoint_file = os.path.join(tmpdir, "tmp.ckpt") - checkpoint = ModelCheckpoint(tmpdir, "test.ckpt") - trainer = Trainer(callbacks=[checkpoint], max_epochs=1) - dummy_data = DataLoader(list(zip(torch.arange(10, dtype=torch.float), torch.arange(10, dtype=torch.float)))) - trainer.fit(model, dummy_data) - - assert model.data_pipeline - trainer.save_checkpoint(checkpoint_file) - - loaded_model = CustomModel.load_from_checkpoint(checkpoint_file) - assert loaded_model.data_pipeline - - trainer.fit(model, dummy_data) - assert model.data_pipeline - trainer.save_checkpoint(checkpoint_file) - - loaded_model = CustomModel.load_from_checkpoint(checkpoint_file) - assert loaded_model.data_pipeline - for file in os.listdir(tmpdir): - if file.endswith(".ckpt"): - os.remove(os.path.join(tmpdir, file)) diff --git a/tests/core/integrations/labelstudio/test_labelstudio.py b/tests/core/integrations/labelstudio/test_labelstudio.py index 0814906093..2a66674e7f 100644 --- a/tests/core/integrations/labelstudio/test_labelstudio.py +++ b/tests/core/integrations/labelstudio/test_labelstudio.py @@ -1,6 +1,5 @@ import pytest -from flash.core.data.data_pipeline import DataPipelineState from flash.core.data.utils import download_data from flash.core.integrations.labelstudio.input import ( _load_json_data, @@ -9,7 +8,6 @@ LabelStudioTextClassificationInput, ) from flash.core.integrations.labelstudio.visualizer import launch_app -from flash.core.integrations.transformers.states import TransformersBackboneState from flash.core.utilities.imports import _IMAGE_TESTING, _TEXT_TESTING, _VIDEO_TESTING from flash.core.utilities.stages import RunningStage from flash.image.classification.data import ImageClassificationData @@ -149,10 +147,9 @@ def test_input_labelstudio(): "multi_label": False, } - data_pipeline_state = DataPipelineState() train_data, val_data = LabelStudioInput._split_train_val_data(data, split=0.1) - train = LabelStudioInput(RunningStage.TRAINING, train_data, data_pipeline_state=data_pipeline_state) - val = LabelStudioInput(RunningStage.VALIDATING, val_data, data_pipeline_state=data_pipeline_state) + train = LabelStudioInput(RunningStage.TRAINING, train_data) + val = LabelStudioInput(RunningStage.VALIDATING, val_data, parameters=train.parameters) train_sample = train[0] val_sample = val[0] @@ -171,15 +168,9 @@ def test_input_labelstudio_image(): "multi_label": True, } - data_pipeline_state = DataPipelineState() train_data, val_data = LabelStudioInput._split_train_val_data(data, split=0.2) - train = LabelStudioImageClassificationInput( - RunningStage.TRAINING, train_data, data_pipeline_state=data_pipeline_state - ) - val = LabelStudioImageClassificationInput( - RunningStage.VALIDATING, val_data, data_pipeline_state=data_pipeline_state - ) - assert train._data_pipeline_state == val._data_pipeline_state == data_pipeline_state + train = LabelStudioImageClassificationInput(RunningStage.TRAINING, train_data) + val = LabelStudioImageClassificationInput(RunningStage.VALIDATING, val_data, parameters=train.parameters) train_sample = train[0] val_sample = val[0] @@ -239,20 +230,12 @@ def 
test_input_labelstudio_text(): "multi_label": False, } - data_pipeline_state = DataPipelineState() train_data, test_data = LabelStudioInput._split_train_test_data(data) train_data, val_data = LabelStudioInput._split_train_val_data(train_data, split=0.2) - train = LabelStudioTextClassificationInput( - RunningStage.TRAINING, train_data, data_pipeline_state=data_pipeline_state - ) - val = LabelStudioTextClassificationInput(RunningStage.VALIDATING, val_data, data_pipeline_state=data_pipeline_state) - test = LabelStudioTextClassificationInput(RunningStage.TESTING, test_data, data_pipeline_state=data_pipeline_state) - - backbone = "prajjwal1/bert-tiny" - train.set_state(TransformersBackboneState(backbone)) + train = LabelStudioTextClassificationInput(RunningStage.TRAINING, train_data) + val = LabelStudioTextClassificationInput(RunningStage.VALIDATING, val_data, parameters=train.parameters) + test = LabelStudioTextClassificationInput(RunningStage.TESTING, test_data, parameters=train.parameters) - assert train._data_pipeline_state == val._data_pipeline_state - assert train._data_pipeline_state == test._data_pipeline_state train_sample = train[0] val_sample = val[0] assert train_sample diff --git a/tests/core/test_classification.py b/tests/core/test_classification.py index 64214b549a..e0e8582453 100644 --- a/tests/core/test_classification.py +++ b/tests/core/test_classification.py @@ -65,15 +65,17 @@ def test_classification_outputs_fiftyone(): assert predictions["predictions"].label == "class_3" assert predictions["filepath"] == "something" - predictions = FiftyOneLabelsOutput(store_logits=True).transform(example_output) + predictions = FiftyOneLabelsOutput(store_logits=True, return_filepath=False).transform(example_output) assert torch.allclose(torch.tensor(predictions.logits), logits) assert torch.allclose(torch.tensor(predictions.confidence), torch.softmax(logits, -1)[-1]) assert predictions.label == "2" - predictions = FiftyOneLabelsOutput(labels, store_logits=True).transform(example_output) + predictions = FiftyOneLabelsOutput(labels, store_logits=True, return_filepath=False).transform(example_output) assert predictions.label == "class_3" - predictions = FiftyOneLabelsOutput(store_logits=True, multi_label=True).transform(example_output) + predictions = FiftyOneLabelsOutput(store_logits=True, multi_label=True, return_filepath=False).transform( + example_output + ) assert torch.allclose(torch.tensor(predictions.logits), logits) assert [c.label for c in predictions.classifications] == ["1", "2"] - predictions = FiftyOneLabelsOutput(labels, multi_label=True).transform(example_output) + predictions = FiftyOneLabelsOutput(labels, multi_label=True, return_filepath=False).transform(example_output) assert [c.label for c in predictions.classifications] == ["class_2", "class_3"] diff --git a/tests/core/test_data.py b/tests/core/test_data.py index 0b4b37584c..cdfc01ff33 100644 --- a/tests/core/test_data.py +++ b/tests/core/test_data.py @@ -34,10 +34,15 @@ def test_init(): train_input = DatasetInput(RunningStage.TRAINING, DummyDataset()) val_input = DatasetInput(RunningStage.VALIDATING, DummyDataset()) test_input = DatasetInput(RunningStage.TESTING, DummyDataset()) - DataModule(train_input, batch_size=1) - DataModule(train_input, val_input, batch_size=1) - DataModule(train_input, val_input, test_input, batch_size=1) - assert DataModule(batch_size=1).data_pipeline + + data_module = DataModule(train_input, batch_size=1) + assert data_module.train_dataset and not data_module.val_dataset and not 
data_module.test_dataset + + data_module = DataModule(train_input, val_input, batch_size=1) + assert data_module.train_dataset and data_module.val_dataset and not data_module.test_dataset + + data_module = DataModule(train_input, val_input, test_input, batch_size=1) + assert data_module.train_dataset and data_module.val_dataset and data_module.test_dataset def test_dataloaders(): @@ -52,9 +57,3 @@ def test_dataloaders(): ]: x = next(iter(dl))[DataKeys.INPUT] assert x.shape == (1, 1, 28, 28) - - -def test_cpu_count_none(): - train_input = DatasetInput(RunningStage.TRAINING, DummyDataset()) - dm = DataModule(train_input, num_workers=None, batch_size=1) - assert dm.num_workers == 0 diff --git a/tests/core/test_model.py b/tests/core/test_model.py index a69200313b..47ac9cfdc1 100644 --- a/tests/core/test_model.py +++ b/tests/core/test_model.py @@ -197,32 +197,6 @@ def test_classification_task_trainer_predict(tmpdir): assert len(list(chain.from_iterable(predictions))) == 10 -def test_task_datapipeline_save(tmpdir): - model = nn.Sequential(nn.Flatten(), nn.Linear(28 * 28, 10), nn.Softmax()) - train_dl = torch.utils.data.DataLoader(DummyDataset()) - task = ClassificationTask(model, loss_fn=F.nll_loss, output_transform=DummyOutputTransform()) - - # to check later - task.output_transform.test = True - - # generate a checkpoint - trainer = pl.Trainer( - default_root_dir=tmpdir, - limit_train_batches=1, - max_epochs=1, - progress_bar_refresh_rate=0, - weights_summary=None, - logger=False, - ) - trainer.fit(task, train_dl) - path = str(tmpdir / "model.ckpt") - trainer.save_checkpoint(path) - - # load from file - task = ClassificationTask.load_from_checkpoint(path, model=model) - assert task.output_transform.test - - @pytest.mark.parametrize( ["cls", "filename"], [ diff --git a/tests/graph/classification/test_model.py b/tests/graph/classification/test_model.py index c429d0d65a..6c0bde8b6d 100644 --- a/tests/graph/classification/test_model.py +++ b/tests/graph/classification/test_model.py @@ -89,7 +89,7 @@ def test_predict_dataset(tmpdir): batch_size=4, ) trainer = Trainer(default_root_dir=tmpdir, fast_dev_run=True) - out = trainer.predict(model, datamodule=datamodule) + out = trainer.predict(model, datamodule=datamodule, output="classes") assert isinstance(out[0][0], int) diff --git a/tests/image/classification/test_active_learning.py b/tests/image/classification/test_active_learning.py index effedd0470..53349af8a3 100644 --- a/tests/image/classification/test_active_learning.py +++ b/tests/image/classification/test_active_learning.py @@ -22,7 +22,6 @@ from torch.utils.data import SequentialSampler import flash -from flash.core.classification import ProbabilitiesOutput from flash.core.utilities.imports import _BAAL_AVAILABLE, _IMAGE_TESTING from flash.image import ImageClassificationData, ImageClassifier from flash.image.classification.integrations.baal import ActiveLearningDataModule, ActiveLearningLoop @@ -68,15 +67,13 @@ def test_active_learning_training(simple_datamodule, initial_num_labels, query_s seed_everything(42) if initial_num_labels == 0: - with pytest.warns(UserWarning) as record: + with pytest.warns(UserWarning, match="No labels provided for the initial step"): active_learning_dm = ActiveLearningDataModule( simple_datamodule, initial_num_labels=initial_num_labels, query_size=query_size, val_split=0.5, ) - assert len(record) == 1 - assert "No labels provided for the initial step" in record[0].message.args[0] else: active_learning_dm = ActiveLearningDataModule( simple_datamodule, @@ -90,9 
+87,7 @@ def test_active_learning_training(simple_datamodule, initial_num_labels, query_s nn.Linear(512, active_learning_dm.num_classes), ) - model = ImageClassifier( - backbone="resnet18", head=head, num_classes=active_learning_dm.num_classes, output=ProbabilitiesOutput() - ) + model = ImageClassifier(backbone="resnet18", head=head, num_classes=active_learning_dm.num_classes) trainer = flash.Trainer(max_epochs=3, num_sanity_val_steps=0) active_learning_loop = ActiveLearningLoop(label_epoch_frequency=1, inference_iteration=3) active_learning_loop.connect(trainer.fit_loop) @@ -142,9 +137,7 @@ def test_no_validation_loop(simple_datamodule): nn.Linear(512, active_learning_dm.num_classes), ) - model = ImageClassifier( - backbone="resnet18", head=head, num_classes=active_learning_dm.num_classes, output=ProbabilitiesOutput() - ) + model = ImageClassifier(backbone="resnet18", head=head, num_classes=active_learning_dm.num_classes) trainer = flash.Trainer(max_epochs=3) active_learning_loop = ActiveLearningLoop(label_epoch_frequency=1, inference_iteration=3) active_learning_loop.connect(trainer.fit_loop) diff --git a/tests/image/classification/test_model.py b/tests/image/classification/test_model.py index 0722feae35..c174d99ec2 100644 --- a/tests/image/classification/test_model.py +++ b/tests/image/classification/test_model.py @@ -20,7 +20,6 @@ from flash import Trainer from flash.__main__ import main -from flash.core.classification import ProbabilitiesOutput from flash.core.data.io.input import DataKeys from flash.core.utilities.imports import _IMAGE_AVAILABLE, _IMAGE_TESTING, _SERVE_TESTING from flash.image import ImageClassifier @@ -87,7 +86,7 @@ def test_init_train_head(tmpdir, head): @pytest.mark.skipif(not _IMAGE_TESTING, reason="image libraries aren't installed.") def test_non_existent_backbone(): with pytest.raises(KeyError): - ImageClassifier(2, "i am never going to implement this lol") + ImageClassifier(2, backbone="i am never going to implement this lol") @pytest.mark.skipif(not _IMAGE_TESTING, reason="image libraries aren't installed.") @@ -111,11 +110,11 @@ def test_multilabel(tmpdir): num_classes = 4 ds = DummyMultiLabelDataset(num_classes) - model = ImageClassifier(num_classes, multi_label=True, output=ProbabilitiesOutput(multi_label=True)) + model = ImageClassifier(num_classes, multi_label=True) train_dl = torch.utils.data.DataLoader(ds, batch_size=2) trainer = Trainer(default_root_dir=tmpdir, max_epochs=2, limit_train_batches=5) trainer.finetune(model, train_dl, strategy=("freeze_unfreeze", 1)) - predictions = trainer.predict(model, train_dl)[0] + predictions = trainer.predict(model, train_dl, output="probabilities")[0] assert (torch.tensor(predictions) > 1).sum() == 0 assert (torch.tensor(predictions) < 0).sum() == 0 assert len(predictions[0]) == num_classes diff --git a/tests/image/detection/test_model.py b/tests/image/detection/test_model.py index 903948a6de..db46114b0c 100644 --- a/tests/image/detection/test_model.py +++ b/tests/image/detection/test_model.py @@ -18,11 +18,11 @@ import numpy as np import pytest import torch -from pytorch_lightning import Trainer from torch.utils.data import Dataset from flash.__main__ import main from flash.core.data.io.input import DataKeys +from flash.core.trainer import Trainer from flash.core.utilities.imports import _ICEVISION_AVAILABLE, _IMAGE_AVAILABLE from flash.image import ObjectDetector @@ -146,8 +146,8 @@ def test_predict(tmpdir, head): dl = model.process_train_dataset(ds, trainer, 2, 0, False, None) trainer.fit(model, dl) dl = 
model.process_predict_dataset(ds, batch_size=2) - predictions = trainer.predict(model, dl) + predictions = trainer.predict(model, dl, output="preds") assert len(predictions[0][0]["bboxes"]) > 0 model.predict_kwargs = {"detection_threshold": 2} - predictions = trainer.predict(model, dl) + predictions = trainer.predict(model, dl, output="preds") assert len(predictions[0][0]["bboxes"]) == 0 diff --git a/tests/image/detection/test_output.py b/tests/image/detection/test_output.py index 4fe0395290..136568821c 100644 --- a/tests/image/detection/test_output.py +++ b/tests/image/detection/test_output.py @@ -18,10 +18,10 @@ def test_smoke(): @staticmethod def test_serialize_fiftyone(): labels = ["class_1", "class_2", "class_3"] - serial = FiftyOneDetectionLabelsOutput() + serial = FiftyOneDetectionLabelsOutput(return_filepath=False) filepath_serial = FiftyOneDetectionLabelsOutput(return_filepath=True) - threshold_serial = FiftyOneDetectionLabelsOutput(threshold=0.9) - labels_serial = FiftyOneDetectionLabelsOutput(labels=labels) + threshold_serial = FiftyOneDetectionLabelsOutput(threshold=0.9, return_filepath=False) + labels_serial = FiftyOneDetectionLabelsOutput(labels=labels, return_filepath=False) sample = { DataKeys.PREDS: { diff --git a/tests/image/instance_segmentation/test_model.py b/tests/image/instance_segmentation/test_model.py index 54f47f8e7d..ab703cc049 100644 --- a/tests/image/instance_segmentation/test_model.py +++ b/tests/image/instance_segmentation/test_model.py @@ -63,7 +63,6 @@ def test_instance_segmentation_inference(tmpdir): head="mask_rcnn", backbone="resnet18_fpn", num_classes=datamodule.num_classes, - output=None, ) # 3. Create the trainer and finetune the model diff --git a/tests/image/segmentation/test_model.py b/tests/image/segmentation/test_model.py index e4517fb6b2..85e7a3a918 100644 --- a/tests/image/segmentation/test_model.py +++ b/tests/image/segmentation/test_model.py @@ -106,7 +106,7 @@ def test_predict_tensor(): model = SemanticSegmentation(2, backbone="mobilenetv3_large_100") datamodule = SemanticSegmentationData.from_tensors(predict_data=img, batch_size=1) trainer = Trainer() - out = trainer.predict(model, datamodule=datamodule) + out = trainer.predict(model, datamodule=datamodule, output="labels") assert isinstance(out[0][0], list) assert len(out[0][0]) == 64 assert len(out[0][0][0]) == 64 @@ -118,7 +118,7 @@ def test_predict_numpy(): model = SemanticSegmentation(2, backbone="mobilenetv3_large_100") datamodule = SemanticSegmentationData.from_numpy(predict_data=img, batch_size=1) trainer = Trainer() - out = trainer.predict(model, datamodule=datamodule) + out = trainer.predict(model, datamodule=datamodule, output="labels") assert isinstance(out[0][0], list) assert len(out[0][0]) == 64 assert len(out[0][0][0]) == 64 diff --git a/tests/tabular/classification/test_model.py b/tests/tabular/classification/test_model.py index 0fed8a560e..d66654a5a2 100644 --- a/tests/tabular/classification/test_model.py +++ b/tests/tabular/classification/test_model.py @@ -54,8 +54,8 @@ def __len__(self) -> int: def test_init_train(backbone, tmpdir): train_dl = torch.utils.data.DataLoader(DummyDataset(), batch_size=16) data_properties = { + "parameters": {"categorical_fields": list(range(16))}, "embedding_sizes": [(10, 32) for _ in range(16)], - "categorical_fields": list(range(16)), "cat_dims": [10 for _ in range(16)], "num_features": 32, "num_classes": 10, @@ -74,8 +74,8 @@ def test_init_train(backbone, tmpdir): def test_init_train_no_num(backbone, tmpdir): train_dl = 
torch.utils.data.DataLoader(DummyDataset(num_num=0), batch_size=16) data_properties = { + "parameters": {"categorical_fields": list(range(16))}, "embedding_sizes": [(10, 32) for _ in range(16)], - "categorical_fields": list(range(16)), "cat_dims": [10 for _ in range(16)], "num_features": 16, "num_classes": 10, @@ -92,8 +92,8 @@ def test_init_train_no_num(backbone, tmpdir): def test_init_train_no_cat(backbone, tmpdir): train_dl = torch.utils.data.DataLoader(DummyDataset(num_cat=0), batch_size=16) data_properties = { + "parameters": {"categorical_fields": []}, "embedding_sizes": [], - "categorical_fields": [], "cat_dims": [], "num_features": 16, "num_classes": 10, @@ -108,8 +108,8 @@ def test_init_train_no_cat(backbone, tmpdir): @pytest.mark.skipif(_TABULAR_AVAILABLE, reason="tabular libraries are installed.") def test_module_import_error(tmpdir): data_properties = { + "parameters": {"categorical_fields": list(range(16))}, "embedding_sizes": [(10, 32) for _ in range(16)], - "categorical_fields": list(range(16)), "cat_dims": [10 for _ in range(16)], "num_features": 32, "num_classes": 10, @@ -125,8 +125,8 @@ def test_module_import_error(tmpdir): ) def test_jit(backbone, tmpdir): data_properties = { + "parameters": {"categorical_fields": list(range(4))}, "embedding_sizes": [(10, 32) for _ in range(4)], - "categorical_fields": list(range(4)), "cat_dims": [10 for _ in range(4)], "num_features": 8, "num_classes": 10, @@ -167,8 +167,6 @@ def test_serve(backbone): batch_size=1, ) model = TabularClassifier.from_data(datamodule=datamodule, backbone=backbone) - # TODO: Currently only servable once a input_transform has been attached - model._input_transform = datamodule.input_transform model.eval() model.serve(parameters=datamodule.parameters) diff --git a/tests/tabular/regression/test_model.py b/tests/tabular/regression/test_model.py index b9054d4799..a82f5e2bf1 100644 --- a/tests/tabular/regression/test_model.py +++ b/tests/tabular/regression/test_model.py @@ -52,9 +52,10 @@ def __len__(self) -> int: ) def test_init_train(backbone, tmpdir): train_dl = torch.utils.data.DataLoader(DummyDataset(), batch_size=16) + data_properties = { + "parameters": {"categorical_fields": list(range(16))}, "embedding_sizes": [(10, 32) for _ in range(16)], - "categorical_fields": list(range(16)), "cat_dims": [10 for _ in range(16)], "num_features": 32, "backbone": backbone, @@ -72,8 +73,8 @@ def test_init_train(backbone, tmpdir): def test_init_train_no_num(backbone, tmpdir): train_dl = torch.utils.data.DataLoader(DummyDataset(num_num=0), batch_size=16) data_properties = { + "parameters": {"categorical_fields": list(range(16))}, "embedding_sizes": [(10, 32) for _ in range(16)], - "categorical_fields": list(range(16)), "cat_dims": [10 for _ in range(16)], "num_features": 16, "backbone": backbone, @@ -89,8 +90,8 @@ def test_init_train_no_num(backbone, tmpdir): def test_init_train_no_cat(backbone, tmpdir): train_dl = torch.utils.data.DataLoader(DummyDataset(num_cat=0), batch_size=16) data_properties = { + "parameters": {"categorical_fields": []}, "embedding_sizes": [], - "categorical_fields": [], "cat_dims": [], "num_features": 16, "backbone": backbone, @@ -104,8 +105,8 @@ def test_init_train_no_cat(backbone, tmpdir): @pytest.mark.skipif(_TABULAR_AVAILABLE, reason="tabular libraries are installed.") def test_module_import_error(tmpdir): data_properties = { + "parameters": {"categorical_fields": list(range(16))}, "embedding_sizes": [(10, 32) for _ in range(16)], - "categorical_fields": list(range(16)), "cat_dims": [10 for _ 
in range(16)], "num_features": 32, "backbone": "tabnet", @@ -120,8 +121,8 @@ def test_module_import_error(tmpdir): ) def test_jit(backbone, tmpdir): data_properties = { + "parameters": {"categorical_fields": list(range(4))}, "embedding_sizes": [(10, 32) for _ in range(4)], - "categorical_fields": list(range(4)), "cat_dims": [10 for _ in range(4)], "num_features": 8, "backbone": backbone, @@ -161,8 +162,6 @@ def test_serve(backbone): batch_size=1, ) model = TabularRegressor.from_data(datamodule=datamodule, backbone=backbone) - # TODO: Currently only servable once a input_transform has been attached - model._input_transform = datamodule.input_transform model.eval() model.serve(parameters=datamodule.parameters) diff --git a/tests/template/classification/test_model.py b/tests/template/classification/test_model.py index c1be89c969..ba76cab455 100644 --- a/tests/template/classification/test_model.py +++ b/tests/template/classification/test_model.py @@ -106,7 +106,7 @@ def test_predict_numpy(): model = TemplateSKLearnClassifier(num_features=DummyDataset.num_features, num_classes=DummyDataset.num_classes) datamodule = TemplateData.from_numpy(predict_data=row, batch_size=1) trainer = Trainer() - out = trainer.predict(model, datamodule=datamodule) + out = trainer.predict(model, datamodule=datamodule, output="classes") assert isinstance(out[0][0], int) @@ -117,7 +117,7 @@ def test_predict_sklearn(): model = TemplateSKLearnClassifier(num_features=DummyDataset.num_features, num_classes=DummyDataset.num_classes) datamodule = TemplateData.from_sklearn(predict_bunch=bunch, batch_size=1) trainer = Trainer() - out = trainer.predict(model, datamodule=datamodule) + out = trainer.predict(model, datamodule=datamodule, output="classes") assert isinstance(out[0][0], int) diff --git a/tests/text/classification/test_data.py b/tests/text/classification/test_data.py index 73f06b018e..1ba45499c9 100644 --- a/tests/text/classification/test_data.py +++ b/tests/text/classification/test_data.py @@ -18,16 +18,12 @@ import pytest from flash.core.data.io.input import DataKeys -from flash.core.integrations.transformers.states import TransformersBackboneState from flash.core.utilities.imports import _TEXT_AVAILABLE, _TEXT_TESTING from flash.text import TextClassificationData if _TEXT_AVAILABLE: from datasets import Dataset -TEST_BACKBONE = "prajjwal1/bert-tiny" # super small model for testing -TEST_BACKBONE_STATE = TransformersBackboneState(TEST_BACKBONE) - TEST_CSV_DATA = """sentence,label this is a sentence one,0 this is a sentence two,1 @@ -134,22 +130,20 @@ def test_from_csv(tmpdir): batch_size=1, ) - dm.train_dataset.set_state(TEST_BACKBONE_STATE) - batch = next(iter(dm.train_dataloader())) assert batch[DataKeys.TARGET].item() in [0, 1] - assert "input_ids" in batch + assert isinstance(batch[DataKeys.INPUT][0], str) batch = next(iter(dm.val_dataloader())) assert batch[DataKeys.TARGET].item() in [0, 1] - assert "input_ids" in batch + assert isinstance(batch[DataKeys.INPUT][0], str) batch = next(iter(dm.test_dataloader())) assert batch[DataKeys.TARGET].item() in [0, 1] - assert "input_ids" in batch + assert isinstance(batch[DataKeys.INPUT][0], str) batch = next(iter(dm.predict_dataloader())) - assert "input_ids" in batch + assert isinstance(batch[DataKeys.INPUT][0], str) @pytest.mark.skipif(os.name == "nt", reason="Huggingface timing out on Windows") @@ -166,24 +160,22 @@ def test_from_csv_multilabel(tmpdir): batch_size=1, ) - dm.train_dataset.set_state(TEST_BACKBONE_STATE) - assert dm.multi_label batch = 
next(iter(dm.train_dataloader())) assert all([label in [0, 1] for label in batch[DataKeys.TARGET][0]]) - assert "input_ids" in batch + assert isinstance(batch[DataKeys.INPUT][0], str) batch = next(iter(dm.val_dataloader())) assert all([label in [0, 1] for label in batch[DataKeys.TARGET][0]]) - assert "input_ids" in batch + assert isinstance(batch[DataKeys.INPUT][0], str) batch = next(iter(dm.test_dataloader())) assert all([label in [0, 1] for label in batch[DataKeys.TARGET][0]]) - assert "input_ids" in batch + assert isinstance(batch[DataKeys.INPUT][0], str) batch = next(iter(dm.predict_dataloader())) - assert "input_ids" in batch + assert isinstance(batch[DataKeys.INPUT][0], str) @pytest.mark.skipif(os.name == "nt", reason="Huggingface timing out on Windows") @@ -200,22 +192,20 @@ def test_from_json(tmpdir): batch_size=1, ) - dm.train_dataset.set_state(TEST_BACKBONE_STATE) - batch = next(iter(dm.train_dataloader())) assert batch[DataKeys.TARGET].item() in [0, 1] - assert "input_ids" in batch + assert isinstance(batch[DataKeys.INPUT][0], str) batch = next(iter(dm.val_dataloader())) assert batch[DataKeys.TARGET].item() in [0, 1] - assert "input_ids" in batch + assert isinstance(batch[DataKeys.INPUT][0], str) batch = next(iter(dm.test_dataloader())) assert batch[DataKeys.TARGET].item() in [0, 1] - assert "input_ids" in batch + assert isinstance(batch[DataKeys.INPUT][0], str) batch = next(iter(dm.predict_dataloader())) - assert "input_ids" in batch + assert isinstance(batch[DataKeys.INPUT][0], str) @pytest.mark.skipif(os.name == "nt", reason="Huggingface timing out on Windows") @@ -232,24 +222,22 @@ def test_from_json_multilabel(tmpdir): batch_size=1, ) - dm.train_dataset.set_state(TEST_BACKBONE_STATE) - assert dm.multi_label batch = next(iter(dm.train_dataloader())) assert all([label in [0, 1] for label in batch[DataKeys.TARGET][0]]) - assert "input_ids" in batch + assert isinstance(batch[DataKeys.INPUT][0], str) batch = next(iter(dm.val_dataloader())) assert all([label in [0, 1] for label in batch[DataKeys.TARGET][0]]) - assert "input_ids" in batch + assert isinstance(batch[DataKeys.INPUT][0], str) batch = next(iter(dm.test_dataloader())) assert all([label in [0, 1] for label in batch[DataKeys.TARGET][0]]) - assert "input_ids" in batch + assert isinstance(batch[DataKeys.INPUT][0], str) batch = next(iter(dm.predict_dataloader())) - assert "input_ids" in batch + assert isinstance(batch[DataKeys.INPUT][0], str) @pytest.mark.skipif(os.name == "nt", reason="Huggingface timing out on Windows") @@ -267,22 +255,20 @@ def test_from_json_with_field(tmpdir): field="data", ) - dm.train_dataset.set_state(TEST_BACKBONE_STATE) - batch = next(iter(dm.train_dataloader())) assert batch[DataKeys.TARGET].item() in [0, 1] - assert "input_ids" in batch + assert isinstance(batch[DataKeys.INPUT][0], str) batch = next(iter(dm.val_dataloader())) assert batch[DataKeys.TARGET].item() in [0, 1] - assert "input_ids" in batch + assert isinstance(batch[DataKeys.INPUT][0], str) batch = next(iter(dm.test_dataloader())) assert batch[DataKeys.TARGET].item() in [0, 1] - assert "input_ids" in batch + assert isinstance(batch[DataKeys.INPUT][0], str) batch = next(iter(dm.predict_dataloader())) - assert "input_ids" in batch + assert isinstance(batch[DataKeys.INPUT][0], str) @pytest.mark.skipif(os.name == "nt", reason="Huggingface timing out on Windows") @@ -300,24 +286,22 @@ def test_from_json_with_field_multilabel(tmpdir): field="data", ) - dm.train_dataset.set_state(TEST_BACKBONE_STATE) - assert dm.multi_label batch = 
next(iter(dm.train_dataloader())) assert all([label in [0, 1] for label in batch[DataKeys.TARGET][0]]) - assert "input_ids" in batch + assert isinstance(batch[DataKeys.INPUT][0], str) batch = next(iter(dm.val_dataloader())) assert all([label in [0, 1] for label in batch[DataKeys.TARGET][0]]) - assert "input_ids" in batch + assert isinstance(batch[DataKeys.INPUT][0], str) batch = next(iter(dm.test_dataloader())) assert all([label in [0, 1] for label in batch[DataKeys.TARGET][0]]) - assert "input_ids" in batch + assert isinstance(batch[DataKeys.INPUT][0], str) batch = next(iter(dm.predict_dataloader())) - assert "input_ids" in batch + assert isinstance(batch[DataKeys.INPUT][0], str) @pytest.mark.skipif(os.name == "nt", reason="Huggingface timing out on Windows") @@ -334,22 +318,20 @@ def test_from_parquet(tmpdir): batch_size=1, ) - dm.train_dataset.set_state(TEST_BACKBONE_STATE) - batch = next(iter(dm.train_dataloader())) assert batch[DataKeys.TARGET].item() in [0, 1] - assert "input_ids" in batch + assert isinstance(batch[DataKeys.INPUT][0], str) batch = next(iter(dm.val_dataloader())) assert batch[DataKeys.TARGET].item() in [0, 1] - assert "input_ids" in batch + assert isinstance(batch[DataKeys.INPUT][0], str) batch = next(iter(dm.test_dataloader())) assert batch[DataKeys.TARGET].item() in [0, 1] - assert "input_ids" in batch + assert isinstance(batch[DataKeys.INPUT][0], str) batch = next(iter(dm.predict_dataloader())) - assert "input_ids" in batch + assert isinstance(batch[DataKeys.INPUT][0], str) @pytest.mark.skipif(os.name == "nt", reason="Huggingface timing out on Windows") @@ -366,24 +348,22 @@ def test_from_parquet_multilabel(tmpdir): batch_size=1, ) - dm.train_dataset.set_state(TEST_BACKBONE_STATE) - assert dm.multi_label batch = next(iter(dm.train_dataloader())) assert all([label in [0, 1] for label in batch[DataKeys.TARGET][0]]) - assert "input_ids" in batch + assert isinstance(batch[DataKeys.INPUT][0], str) batch = next(iter(dm.val_dataloader())) assert all([label in [0, 1] for label in batch[DataKeys.TARGET][0]]) - assert "input_ids" in batch + assert isinstance(batch[DataKeys.INPUT][0], str) batch = next(iter(dm.test_dataloader())) assert all([label in [0, 1] for label in batch[DataKeys.TARGET][0]]) - assert "input_ids" in batch + assert isinstance(batch[DataKeys.INPUT][0], str) batch = next(iter(dm.predict_dataloader())) - assert "input_ids" in batch + assert isinstance(batch[DataKeys.INPUT][0], str) @pytest.mark.skipif(os.name == "nt", reason="Huggingface timing out on Windows") @@ -399,22 +379,20 @@ def test_from_data_frame(): batch_size=1, ) - dm.train_dataset.set_state(TEST_BACKBONE_STATE) - batch = next(iter(dm.train_dataloader())) assert batch[DataKeys.TARGET].item() in [0, 1] - assert "input_ids" in batch + assert isinstance(batch[DataKeys.INPUT][0], str) batch = next(iter(dm.val_dataloader())) assert batch[DataKeys.TARGET].item() in [0, 1] - assert "input_ids" in batch + assert isinstance(batch[DataKeys.INPUT][0], str) batch = next(iter(dm.test_dataloader())) assert batch[DataKeys.TARGET].item() in [0, 1] - assert "input_ids" in batch + assert isinstance(batch[DataKeys.INPUT][0], str) batch = next(iter(dm.predict_dataloader())) - assert "input_ids" in batch + assert isinstance(batch[DataKeys.INPUT][0], str) @pytest.mark.skipif(os.name == "nt", reason="Huggingface timing out on Windows") @@ -430,24 +408,22 @@ def test_from_data_frame_multilabel(): batch_size=1, ) - dm.train_dataset.set_state(TEST_BACKBONE_STATE) - assert dm.multi_label batch = 
     assert all([label in [0, 1] for label in batch[DataKeys.TARGET][0]])
-    assert "input_ids" in batch
+    assert isinstance(batch[DataKeys.INPUT][0], str)

     batch = next(iter(dm.val_dataloader()))
     assert all([label in [0, 1] for label in batch[DataKeys.TARGET][0]])
-    assert "input_ids" in batch
+    assert isinstance(batch[DataKeys.INPUT][0], str)

     batch = next(iter(dm.test_dataloader()))
     assert all([label in [0, 1] for label in batch[DataKeys.TARGET][0]])
-    assert "input_ids" in batch
+    assert isinstance(batch[DataKeys.INPUT][0], str)

     batch = next(iter(dm.predict_dataloader()))
-    assert "input_ids" in batch
+    assert isinstance(batch[DataKeys.INPUT][0], str)


 @pytest.mark.skipif(os.name == "nt", reason="Huggingface timing out on Windows")
@@ -464,22 +440,20 @@ def test_from_hf_datasets():
         batch_size=1,
     )

-    dm.train_dataset.set_state(TEST_BACKBONE_STATE)
-
     batch = next(iter(dm.train_dataloader()))
     assert batch[DataKeys.TARGET].item() in [0, 1]
-    assert "input_ids" in batch
+    assert isinstance(batch[DataKeys.INPUT][0], str)

     batch = next(iter(dm.val_dataloader()))
     assert batch[DataKeys.TARGET].item() in [0, 1]
-    assert "input_ids" in batch
+    assert isinstance(batch[DataKeys.INPUT][0], str)

     batch = next(iter(dm.test_dataloader()))
     assert batch[DataKeys.TARGET].item() in [0, 1]
-    assert "input_ids" in batch
+    assert isinstance(batch[DataKeys.INPUT][0], str)

     batch = next(iter(dm.predict_dataloader()))
-    assert "input_ids" in batch
+    assert isinstance(batch[DataKeys.INPUT][0], str)


 @pytest.mark.skipif(os.name == "nt", reason="Huggingface timing out on Windows")
@@ -496,24 +470,22 @@ def test_from_hf_datasets_multilabel():
         batch_size=1,
     )

-    dm.train_dataset.set_state(TEST_BACKBONE_STATE)
-
     assert dm.multi_label

     batch = next(iter(dm.train_dataloader()))
     assert all([label in [0, 1] for label in batch[DataKeys.TARGET][0]])
-    assert "input_ids" in batch
+    assert isinstance(batch[DataKeys.INPUT][0], str)

     batch = next(iter(dm.val_dataloader()))
     assert all([label in [0, 1] for label in batch[DataKeys.TARGET][0]])
-    assert "input_ids" in batch
+    assert isinstance(batch[DataKeys.INPUT][0], str)

     batch = next(iter(dm.test_dataloader()))
     assert all([label in [0, 1] for label in batch[DataKeys.TARGET][0]])
-    assert "input_ids" in batch
+    assert isinstance(batch[DataKeys.INPUT][0], str)

     batch = next(iter(dm.predict_dataloader()))
-    assert "input_ids" in batch
+    assert isinstance(batch[DataKeys.INPUT][0], str)


 @pytest.mark.skipif(os.name == "nt", reason="Huggingface timing out on Windows")
@@ -530,22 +502,20 @@ def test_from_lists():
         batch_size=1,
     )

-    dm.train_dataset.set_state(TEST_BACKBONE_STATE)
-
     batch = next(iter(dm.train_dataloader()))
     assert batch[DataKeys.TARGET].item() in [0, 1]
-    assert "input_ids" in batch
+    assert isinstance(batch[DataKeys.INPUT][0], str)

     batch = next(iter(dm.val_dataloader()))
     assert batch[DataKeys.TARGET].item() in [0, 1]
-    assert "input_ids" in batch
+    assert isinstance(batch[DataKeys.INPUT][0], str)

     batch = next(iter(dm.test_dataloader()))
     assert batch[DataKeys.TARGET].item() in [0, 1]
-    assert "input_ids" in batch
+    assert isinstance(batch[DataKeys.INPUT][0], str)

     batch = next(iter(dm.predict_dataloader()))
-    assert "input_ids" in batch
+    assert isinstance(batch[DataKeys.INPUT][0], str)


 @pytest.mark.skipif(os.name == "nt", reason="Huggingface timing out on Windows")
@@ -562,24 +532,22 @@ def test_from_lists_multilabel():
         batch_size=1,
     )

-    dm.train_dataset.set_state(TEST_BACKBONE_STATE)
-
     assert dm.multi_label

     batch = next(iter(dm.train_dataloader()))
     assert all([label in [0, 1] for label in batch[DataKeys.TARGET][0]])
-    assert "input_ids" in batch
+    assert isinstance(batch[DataKeys.INPUT][0], str)

     batch = next(iter(dm.val_dataloader()))
     assert all([label in [0, 1] for label in batch[DataKeys.TARGET][0]])
-    assert "input_ids" in batch
+    assert isinstance(batch[DataKeys.INPUT][0], str)

     batch = next(iter(dm.test_dataloader()))
     assert all([label in [0, 1] for label in batch[DataKeys.TARGET][0]])
-    assert "input_ids" in batch
+    assert isinstance(batch[DataKeys.INPUT][0], str)

     batch = next(iter(dm.predict_dataloader()))
-    assert "input_ids" in batch
+    assert isinstance(batch[DataKeys.INPUT][0], str)


 @pytest.mark.skipif(_TEXT_AVAILABLE, reason="text libraries are installed.")
diff --git a/tests/text/classification/test_model.py b/tests/text/classification/test_model.py
index 1f535f6102..9f745cbe0d 100644
--- a/tests/text/classification/test_model.py
+++ b/tests/text/classification/test_model.py
@@ -46,7 +46,7 @@ def __len__(self) -> int:
 @pytest.mark.skipif(os.name == "nt", reason="Huggingface timing out on Windows")
 @pytest.mark.skipif(not _TEXT_TESTING, reason="text libraries aren't installed.")
 def test_init_train(tmpdir):
-    model = TextClassifier(2, TEST_BACKBONE)
+    model = TextClassifier(2, backbone=TEST_BACKBONE)
     train_dl = torch.utils.data.DataLoader(DummyDataset())
     trainer = Trainer(default_root_dir=tmpdir, fast_dev_run=True)
     trainer.fit(model, train_dl)
@@ -57,7 +57,7 @@ def test_jit(tmpdir):
     sample_input = {"input_ids": torch.randint(1000, size=(1, 100))}
     path = os.path.join(tmpdir, "test.pt")

-    model = TextClassifier(2, TEST_BACKBONE)
+    model = TextClassifier(2, backbone=TEST_BACKBONE)
     model.eval()

     # Huggingface bert model only supports `torch.jit.trace` with `strict=False`
@@ -74,7 +74,7 @@ def test_jit(tmpdir):
 @pytest.mark.skipif(not _SERVE_TESTING, reason="serve libraries aren't installed.")
 @mock.patch("flash._IS_TESTING", True)
 def test_serve():
-    model = TextClassifier(2, TEST_BACKBONE)
+    model = TextClassifier(2, backbone=TEST_BACKBONE)
     model.eval()
     model.serve()

diff --git a/tests/text/question_answering/test_data.py b/tests/text/question_answering/test_data.py
index c28bd271bc..afe80479d4 100644
--- a/tests/text/question_answering/test_data.py
+++ b/tests/text/question_answering/test_data.py
@@ -18,14 +18,9 @@
 import pandas as pd
 import pytest

-from flash.core.data.io.input import DataKeys
-from flash.core.integrations.transformers.states import TransformersBackboneState
 from flash.core.utilities.imports import _TEXT_TESTING
 from flash.text import QuestionAnsweringData

-TEST_BACKBONE = "distilbert-base-uncased"
-TEST_BACKBONE_STATE = TransformersBackboneState(TEST_BACKBONE)
-
 TEST_CSV_DATA = {
     "id": ["12345", "12346", "12347", "12348"],
     "context": [
@@ -115,12 +110,10 @@ def test_from_csv(tmpdir):
         train_file=csv_path,
         batch_size=2,
     )
-    dm.train_dataset.set_state(TEST_BACKBONE_STATE)
     batch = next(iter(dm.train_dataloader()))
-    assert "input_ids" in batch
-    assert "attention_mask" in batch
-    assert "start_positions" in batch
-    assert "end_positions" in batch
+    assert isinstance(batch["question"][0], str)
+    assert isinstance(batch["context"][0], str)
+    assert isinstance(batch["answer"], dict)


 @pytest.mark.skipif(os.name == "nt", reason="Huggingface timing out on Windows")
@@ -136,28 +129,15 @@ def test_from_files(tmpdir):
         test_file=csv_path,
         batch_size=2,
     )
-    dm.train_dataset.set_state(TEST_BACKBONE_STATE)
     batch = next(iter(dm.val_dataloader()))
-    assert "input_ids" in batch
-    assert "attention_mask" in batch
-    assert "start_positions" in batch
"start_positions" in batch - assert "end_positions" in batch - assert DataKeys.METADATA in batch - assert "context" in batch[DataKeys.METADATA][0] - assert "answer" in batch[DataKeys.METADATA][0] - assert "example_id" in batch[DataKeys.METADATA][0] - assert "offset_mapping" in batch[DataKeys.METADATA][0] + assert isinstance(batch["question"][0], str) + assert isinstance(batch["context"][0], str) + assert isinstance(batch["answer"], dict) batch = next(iter(dm.test_dataloader())) - assert "input_ids" in batch - assert "attention_mask" in batch - assert "start_positions" in batch - assert "end_positions" in batch - assert DataKeys.METADATA in batch - assert "context" in batch[DataKeys.METADATA][0] - assert "answer" in batch[DataKeys.METADATA][0] - assert "example_id" in batch[DataKeys.METADATA][0] - assert "offset_mapping" in batch[DataKeys.METADATA][0] + assert isinstance(batch["question"][0], str) + assert isinstance(batch["context"][0], str) + assert isinstance(batch["answer"], dict) @pytest.mark.skipif(os.name == "nt", reason="Huggingface timing out on Windows") @@ -171,12 +151,10 @@ def test_from_json(tmpdir): train_file=json_path, batch_size=2, ) - dm.train_dataset.set_state(TEST_BACKBONE_STATE) batch = next(iter(dm.train_dataloader())) - assert "input_ids" in batch - assert "attention_mask" in batch - assert "start_positions" in batch - assert "end_positions" in batch + assert isinstance(batch["question"][0], str) + assert isinstance(batch["context"][0], str) + assert isinstance(batch["answer"], dict) @pytest.mark.skipif(os.name == "nt", reason="Huggingface timing out on Windows") @@ -191,12 +169,10 @@ def test_from_json_with_field(tmpdir): field="data", batch_size=2, ) - dm.train_dataset.set_state(TEST_BACKBONE_STATE) batch = next(iter(dm.train_dataloader())) - assert "input_ids" in batch - assert "attention_mask" in batch - assert "start_positions" in batch - assert "end_positions" in batch + assert isinstance(batch["question"][0], str) + assert isinstance(batch["context"][0], str) + assert isinstance(batch["answer"], dict) @pytest.mark.skipif(os.name == "nt", reason="Huggingface timing out on Windows") diff --git a/tests/text/seq2seq/summarization/test_data.py b/tests/text/seq2seq/summarization/test_data.py index 218239f0d4..4644453778 100644 --- a/tests/text/seq2seq/summarization/test_data.py +++ b/tests/text/seq2seq/summarization/test_data.py @@ -17,15 +17,9 @@ import pytest from flash import DataKeys -from flash.core.integrations.transformers.states import TransformersBackboneState from flash.core.utilities.imports import _TEXT_TESTING from flash.text import SummarizationData -TEST_BACKBONE = "sshleifer/tiny-mbart" # super small model for testing -TEST_BACKBONE_STATE = TransformersBackboneState( - TEST_BACKBONE, tokenizer_kwargs=dict(src_lang="en_XX", tgt_lang="en_XX") -) - TEST_CSV_DATA = """input,target this is a sentence one,this is a summarized sentence one this is a sentence two,this is a summarized sentence two @@ -68,10 +62,9 @@ def json_data_with_field(tmpdir): def test_from_csv(tmpdir): csv_path = csv_data(tmpdir) dm = SummarizationData.from_csv("input", "target", train_file=csv_path, batch_size=1) - dm.train_dataset.set_state(TEST_BACKBONE_STATE) batch = next(iter(dm.train_dataloader())) - assert DataKeys.TARGET in batch - assert "input_ids" in batch + assert isinstance(batch[DataKeys.INPUT][0], str) + assert isinstance(batch[DataKeys.TARGET][0], str) @pytest.mark.skipif(os.name == "nt", reason="Huggingface timing out on Windows") @@ -86,14 +79,13 @@ def 
         test_file=csv_path,
         batch_size=1,
     )
-    dm.train_dataset.set_state(TEST_BACKBONE_STATE)
     batch = next(iter(dm.val_dataloader()))
-    assert DataKeys.TARGET in batch
-    assert "input_ids" in batch
+    assert isinstance(batch[DataKeys.INPUT][0], str)
+    assert isinstance(batch[DataKeys.TARGET][0], str)

     batch = next(iter(dm.test_dataloader()))
-    assert DataKeys.TARGET in batch
-    assert "input_ids" in batch
+    assert isinstance(batch[DataKeys.INPUT][0], str)
+    assert isinstance(batch[DataKeys.TARGET][0], str)


 @pytest.mark.skipif(os.name == "nt", reason="Huggingface timing out on Windows")
@@ -106,10 +98,9 @@ def test_from_json(tmpdir):
         train_file=json_path,
         batch_size=1,
     )
-    dm.train_dataset.set_state(TEST_BACKBONE_STATE)
     batch = next(iter(dm.train_dataloader()))
-    assert DataKeys.TARGET in batch
-    assert "input_ids" in batch
+    assert isinstance(batch[DataKeys.INPUT][0], str)
+    assert isinstance(batch[DataKeys.TARGET][0], str)


 @pytest.mark.skipif(os.name == "nt", reason="Huggingface timing out on Windows")
@@ -123,7 +114,6 @@ def test_from_json_with_field(tmpdir):
         batch_size=1,
         field="data",
     )
-    dm.train_dataset.set_state(TEST_BACKBONE_STATE)
     batch = next(iter(dm.train_dataloader()))
-    assert DataKeys.TARGET in batch
-    assert "input_ids" in batch
+    assert isinstance(batch[DataKeys.INPUT][0], str)
+    assert isinstance(batch[DataKeys.TARGET][0], str)
diff --git a/tests/text/seq2seq/summarization/test_model.py b/tests/text/seq2seq/summarization/test_model.py
index 7f1dbbd584..bba5a67a0a 100644
--- a/tests/text/seq2seq/summarization/test_model.py
+++ b/tests/text/seq2seq/summarization/test_model.py
@@ -53,7 +53,7 @@ def test_init_train(tmpdir):
 @pytest.mark.skipif(not _TEXT_TESTING, reason="text libraries aren't installed.")
 def test_jit(tmpdir):
     sample_input = {
-        "input_ids": torch.randint(1000, size=(1, 32)),
+        "input_ids": torch.randint(128, size=(1, 32)),
         "attention_mask": torch.randint(1, size=(1, 32)),
     }
     path = os.path.join(tmpdir, "test.pt")
@@ -62,7 +62,7 @@ def test_jit(tmpdir):
     model.eval()

     # Huggingface only supports `torch.jit.trace`
-    model = torch.jit.trace(model, [sample_input])
+    model = torch.jit.trace(model, [sample_input], check_trace=False)

     torch.jit.save(model, path)
     model = torch.jit.load(path)
diff --git a/tests/text/seq2seq/translation/test_data.py b/tests/text/seq2seq/translation/test_data.py
index f175681dfb..1a361c509b 100644
--- a/tests/text/seq2seq/translation/test_data.py
+++ b/tests/text/seq2seq/translation/test_data.py
@@ -17,15 +17,9 @@
 import pytest

 from flash import DataKeys
-from flash.core.integrations.transformers.states import TransformersBackboneState
 from flash.core.utilities.imports import _TEXT_TESTING
 from flash.text import TranslationData

-TEST_BACKBONE = "sshleifer/tiny-mbart"  # super small model for testing
-TEST_BACKBONE_STATE = TransformersBackboneState(
-    TEST_BACKBONE, tokenizer_kwargs=dict(src_lang="en_XX", tgt_lang="ro_RO")
-)
-
 TEST_CSV_DATA = """input,target
 this is a sentence one,this is a translated sentence one
 this is a sentence two,this is a translated sentence two
@@ -73,10 +67,9 @@ def test_from_csv(tmpdir):
         train_file=csv_path,
         batch_size=1,
     )
-    dm.train_dataset.set_state(TEST_BACKBONE_STATE)
     batch = next(iter(dm.train_dataloader()))
-    assert DataKeys.TARGET in batch
-    assert "input_ids" in batch
+    assert isinstance(batch[DataKeys.INPUT][0], str)
+    assert isinstance(batch[DataKeys.TARGET][0], str)


 @pytest.mark.skipif(os.name == "nt", reason="Huggingface timing out on Windows")
@@ -91,14 +84,13 @@ def test_from_files(tmpdir):
         test_file=csv_path,
         batch_size=1,
     )
-    dm.train_dataset.set_state(TEST_BACKBONE_STATE)
     batch = next(iter(dm.val_dataloader()))
-    assert DataKeys.TARGET in batch
-    assert "input_ids" in batch
+    assert isinstance(batch[DataKeys.INPUT][0], str)
+    assert isinstance(batch[DataKeys.TARGET][0], str)

     batch = next(iter(dm.test_dataloader()))
-    assert DataKeys.TARGET in batch
-    assert "input_ids" in batch
+    assert isinstance(batch[DataKeys.INPUT][0], str)
+    assert isinstance(batch[DataKeys.TARGET][0], str)


 @pytest.mark.skipif(os.name == "nt", reason="Huggingface timing out on Windows")
@@ -111,10 +103,9 @@ def test_from_json(tmpdir):
         train_file=json_path,
         batch_size=1,
     )
-    dm.train_dataset.set_state(TEST_BACKBONE_STATE)
     batch = next(iter(dm.train_dataloader()))
-    assert DataKeys.TARGET in batch
-    assert "input_ids" in batch
+    assert isinstance(batch[DataKeys.INPUT][0], str)
+    assert isinstance(batch[DataKeys.TARGET][0], str)


 @pytest.mark.skipif(os.name == "nt", reason="Huggingface timing out on Windows")
@@ -128,7 +119,6 @@ def test_from_json_with_field(tmpdir):
         batch_size=1,
         field="data",
     )
-    dm.train_dataset.set_state(TEST_BACKBONE_STATE)
     batch = next(iter(dm.train_dataloader()))
-    assert DataKeys.TARGET in batch
-    assert "input_ids" in batch
+    assert isinstance(batch[DataKeys.INPUT][0], str)
+    assert isinstance(batch[DataKeys.TARGET][0], str)
diff --git a/tests/text/seq2seq/translation/test_model.py b/tests/text/seq2seq/translation/test_model.py
index 172d3d55cb..506e4ae1fb 100644
--- a/tests/text/seq2seq/translation/test_model.py
+++ b/tests/text/seq2seq/translation/test_model.py
@@ -58,11 +58,11 @@ def test_jit(tmpdir):
     }
     path = os.path.join(tmpdir, "test.pt")

-    model = TranslationTask(TEST_BACKBONE, val_target_max_length=None)
+    model = TranslationTask(TEST_BACKBONE)
     model.eval()

     # Huggingface only supports `torch.jit.trace`
-    model = torch.jit.trace(model, [sample_input])
+    model = torch.jit.trace(model, [sample_input], check_trace=False)

     torch.jit.save(model, path)
     model = torch.jit.load(path)