Skip to content
This repository has been archived by the owner on Oct 9, 2023. It is now read-only.

Commit

Permalink
Onboard text classification inputs to new object (#1022)
Browse files Browse the repository at this point in the history
  • Loading branch information
ethanwharris authored Dec 6, 2021
1 parent 5dd695f commit 159cd98
Show file tree
Hide file tree
Showing 19 changed files with 567 additions and 596 deletions.
1 change: 1 addition & 0 deletions .gitignore
Original file line number Diff line number Diff line change
Expand Up @@ -173,4 +173,5 @@ mini-imagenet*

docs/source/_static/images/course_UvA-DL
docs/source/_static/images/lightning_examples
docs/source/_static/images/flash_tutorials
docs/source/notebooks
2 changes: 2 additions & 0 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -78,6 +78,8 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).

- Removed `Task.predict`, use `Trainer.predict` instead ([#1030](https://github.com/PyTorchLightning/lightning-flash/pull/1030))

- Removed the `backbone` argument from `TextClassificationData`, it is now sufficient to only provide a `backbone` argument to the `TextClassifier` ([#1022](https://github.com/PyTorchLightning/lightning-flash/pull/1022))

## [0.5.2] - 2021-11-05

### Added
Expand Down
14 changes: 6 additions & 8 deletions docs/source/api/text.rst
Original file line number Diff line number Diff line change
Expand Up @@ -20,16 +20,14 @@ ______________
~classification.model.TextClassifier
~classification.data.TextClassificationData

classification.data.TextClassificationOutputTransform
classification.data.TextClassificationInputTransform
classification.data.TextDeserializer
classification.data.TextInput
classification.data.TextCSVInput
classification.data.TextJSONInput
classification.data.TextDataFrameInput
classification.data.TextParquetInput
classification.data.TextHuggingFaceDatasetInput
classification.data.TextListInput
classification.data.TextClassificationInput
classification.data.TextClassificationCSVInput
classification.data.TextClassificationJSONInput
classification.data.TextClassificationDataFrameInput
classification.data.TextClassificationParquetInput
classification.data.TextClassificationListInput

Question Answering
__________________
Expand Down
6 changes: 3 additions & 3 deletions docs/source/template/data.rst
Original file line number Diff line number Diff line change
Expand Up @@ -201,11 +201,11 @@ OutputTransform

:class:`~flash.core.data.io.output_transform.OutputTransform` contains any transforms that need to be applied *after* the model.
You may want to use it for: converting tokens back into text, applying an inverse normalization to an output image, resizing a generated image back to the size of the input, etc.
As an example, here's the :class:`~text.classification.data.TextClassificationOutputTransform` which gets the logits from a ``SequenceClassifierOutput``:
As an example, here's the :class:`~text.seq2seq.core.data.Seq2SeqOutputTransform` which decodes tokenized model outputs:

.. literalinclude:: ../../../flash/text/classification/data.py
.. literalinclude:: ../../../flash/text/seq2seq/core/data.py
:language: python
:pyobject: TextClassificationOutputTransform
:pyobject: Seq2SeqOutputTransform

In your :class:`~flash.core.data.io.input.Input` or :class:`~flash.core.data.io.input_transform.InputTransform`, you can add metadata to the batch using the :attr:`~flash.core.data.io.input.DataKeys.METADATA` key.
Your :class:`~flash.core.data.io.output_transform.OutputTransform` can then use this metadata in its transforms.
Expand Down
2 changes: 1 addition & 1 deletion flash/core/data/io/input.py
Original file line number Diff line number Diff line change
Expand Up @@ -154,7 +154,7 @@ class InputFormat(LightningEnum):
JSON = "json"
PARQUET = "parquet"
DATASETS = "datasets"
HUGGINGFACE_DATASET = "hf_dataset"
HUGGINGFACE_DATASET = "hf_datasets"
FIFTYONE = "fiftyone"
DATAFRAME = "data_frame"
LISTS = "lists"
Expand Down
12 changes: 11 additions & 1 deletion flash/core/data/splits.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@
from pytorch_lightning.utilities.exceptions import MisconfigurationException
from torch.utils.data import Dataset

import flash
from flash.core.data.properties import Properties


Expand All @@ -26,7 +27,11 @@ class SplitDataset(Properties, Dataset):
def __init__(self, dataset: Any, indices: List[int] = None, use_duplicated_indices: bool = False) -> None:
kwargs = {}
if isinstance(dataset, Properties):
kwargs = {"running_stage": dataset._running_stage, "state": dataset._state}
kwargs = dict(
running_stage=dataset._running_stage,
data_pipeline_state=dataset._data_pipeline_state,
state=dataset._state,
)
super().__init__(**kwargs)

if indices is None:
Expand All @@ -45,6 +50,11 @@ def __init__(self, dataset: Any, indices: List[int] = None, use_duplicated_indic
self.dataset = dataset
self.indices = indices

def attach_data_pipeline_state(self, data_pipeline_state: "flash.core.data.data_pipeline.DataPipelineState"):
super().attach_data_pipeline_state(data_pipeline_state)
if isinstance(self.dataset, Properties):
self.dataset.attach_data_pipeline_state(data_pipeline_state)

def __getattr__(self, key: str):
if key != "dataset":
return getattr(self.dataset, key)
Expand Down
34 changes: 14 additions & 20 deletions flash/core/integrations/labelstudio/input.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,15 +16,12 @@
from flash.core.data.io.input_base import Input, IterableInput
from flash.core.data.properties import ProcessState, Properties
from flash.core.data.utils import image_default_loader
from flash.core.utilities.imports import _PYTORCHVIDEO_AVAILABLE, _TEXT_AVAILABLE
from flash.core.utilities.imports import _PYTORCHVIDEO_AVAILABLE
from flash.core.utilities.stages import RunningStage

if _PYTORCHVIDEO_AVAILABLE:
from pytorchvideo.data.clip_sampling import make_clip_sampler

if _TEXT_AVAILABLE:
from transformers import AutoTokenizer


@dataclass(unsafe_hash=True, frozen=True)
class LabelStudioState(ProcessState):
Expand Down Expand Up @@ -277,11 +274,8 @@ class LabelStudioTextClassificationInput(LabelStudioInput):
Export data should point to text data
"""

def __init__(self, *args, backbone=None, max_length=128, **kwargs):
if backbone:
self.backbone = backbone
self.tokenizer = AutoTokenizer.from_pretrained(backbone, use_fast=True)
self.max_length = max_length
def __init__(self, *args, max_length=128, **kwargs):
self.max_length = max_length
super().__init__(*args, **kwargs)

def load_sample(self, sample: Mapping[str, Any] = None) -> Any:
Expand All @@ -291,17 +285,17 @@ def load_sample(self, sample: Mapping[str, Any] = None) -> Any:

assert self.state

if self.backbone:
data = ""
for key in sample.get("data"):
data += sample.get("data").get(key)
tokenized_data = self.tokenizer(data, max_length=self.max_length, truncation=True, padding="max_length")
for key in tokenized_data:
tokenized_data[key] = torch.tensor(tokenized_data[key])
tokenized_data["labels"] = _get_labels_from_sample(sample["label"], self.state.classes)
# separate text data type block
result = tokenized_data
return result
data = ""
for key in sample.get("data"):
data += sample.get("data").get(key)
tokenized_data = self.get_state(flash.text.classification.model.TextClassificationBackboneState).tokenizer(
data, max_length=self.max_length, truncation=True, padding="max_length"
)
for key in tokenized_data:
tokenized_data[key] = torch.tensor(tokenized_data[key])
tokenized_data["labels"] = _get_labels_from_sample(sample["label"], self.state.classes)
# separate text data type block
return tokenized_data


class LabelStudioVideoClassificationInput(LabelStudioIterableInput):
Expand Down
8 changes: 4 additions & 4 deletions flash/image/classification/data.py
Original file line number Diff line number Diff line change
Expand Up @@ -383,10 +383,10 @@ def from_data_frame(
predict_data = (predict_data_frame, input_field, predict_images_root, predict_resolver)

return cls(
ImageClassificationCSVInput(RunningStage.TRAINING, *train_data, **dataset_kwargs),
ImageClassificationCSVInput(RunningStage.VALIDATING, *val_data, **dataset_kwargs),
ImageClassificationCSVInput(RunningStage.TESTING, *test_data, **dataset_kwargs),
ImageClassificationCSVInput(RunningStage.PREDICTING, *predict_data, **dataset_kwargs),
ImageClassificationDataFrameInput(RunningStage.TRAINING, *train_data, **dataset_kwargs),
ImageClassificationDataFrameInput(RunningStage.VALIDATING, *val_data, **dataset_kwargs),
ImageClassificationDataFrameInput(RunningStage.TESTING, *test_data, **dataset_kwargs),
ImageClassificationDataFrameInput(RunningStage.PREDICTING, *predict_data, **dataset_kwargs),
input_transform=cls.input_transform_cls(
train_transform,
val_transform,
Expand Down
23 changes: 5 additions & 18 deletions flash/text/classification/cli.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,10 +20,7 @@


def from_imdb(
backbone: str = "prajjwal1/bert-medium",
batch_size: int = 4,
num_workers: int = 0,
**input_transform_kwargs,
**data_module_kwargs,
) -> TextClassificationData:
"""Downloads and loads the IMDB sentiment classification data set."""
download_data("https://pl-flash-data.s3.amazonaws.com/imdb.zip", "./data/")
Expand All @@ -32,31 +29,22 @@ def from_imdb(
"sentiment",
train_file="data/imdb/train.csv",
val_file="data/imdb/valid.csv",
backbone=backbone,
batch_size=batch_size,
num_workers=num_workers,
**input_transform_kwargs,
**data_module_kwargs,
)


def from_toxic(
backbone: str = "unitary/toxic-bert",
val_split: float = 0.1,
batch_size: int = 4,
num_workers: int = 0,
**input_transform_kwargs,
**data_module_kwargs,
) -> TextClassificationData:
"""Downloads and loads the Jigsaw toxic comments data set."""
download_data("https://pl-flash-data.s3.amazonaws.com/jigsaw_toxic_comments.zip", "./data")
return TextClassificationData.from_csv(
"comment_text",
["toxic", "severe_toxic", "obscene", "threat", "insult", "identity_hate"],
train_file="data/jigsaw_toxic_comments/train.csv",
backbone=backbone,
val_split=val_split,
batch_size=batch_size,
num_workers=num_workers,
**input_transform_kwargs,
**data_module_kwargs,
)


Expand All @@ -70,8 +58,7 @@ def text_classification():
default_arguments={
"trainer.max_epochs": 3,
},
datamodule_attributes={"num_classes", "multi_label", "backbone"},
legacy=True,
datamodule_attributes={"num_classes", "multi_label"},
)

cli.trainer.save_checkpoint("text_classification_model.pt")
Expand Down
Loading

0 comments on commit 159cd98

Please sign in to comment.