Skip to content
This repository has been archived by the owner on Oct 9, 2023. It is now read-only.

Commit

Permalink
TextClassificationData from_dataframe (#785)
Browse files Browse the repository at this point in the history
Co-authored-by: Ethan Harris <[email protected]>
  • Loading branch information
kingyiusuen and ethanwharris authored Sep 22, 2021
1 parent 4f0ad73 commit 0c8c24d
Show file tree
Hide file tree
Showing 3 changed files with 160 additions and 2 deletions.
2 changes: 2 additions & 0 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,8 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).

- Added `vissl` training_strategies for `ImageEmbedder` ([#682](https://github.com/PyTorchLightning/lightning-flash/pull/682))

- Added support for `from_data_frame` to `TextClassificationData` ([#785](https://github.com/PyTorchLightning/lightning-flash/pull/785))

### Changed

- Changed the default `num_workers` on linux to `0` (matching the default for other OS) ([#759](https://github.com/PyTorchLightning/lightning-flash/pull/759))
Expand Down
141 changes: 139 additions & 2 deletions flash/text/classification/data.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,20 +12,23 @@
# See the License for the specific language governing permissions and
# limitations under the License.
from functools import partial
from typing import Any, Callable, Dict, List, Mapping, Optional, Sequence, Tuple, Union
from typing import Any, Callable, Dict, List, Mapping, Optional, Sequence, Tuple, Type, Union

import pandas as pd
import torch
from torch import Tensor
from torch.utils.data.sampler import Sampler

import flash
from flash.core.data.auto_dataset import AutoDataset
from flash.core.data.callback import BaseDataFetcher
from flash.core.data.data_module import DataModule
from flash.core.data.data_source import DataSource, DefaultDataSources, LabelsState
from flash.core.data.process import Deserializer, Postprocess, Preprocess
from flash.core.utilities.imports import _TEXT_AVAILABLE, requires

if _TEXT_AVAILABLE:
from datasets import DatasetDict, load_dataset
from datasets import Dataset, DatasetDict, load_dataset
from transformers import AutoTokenizer, default_data_collator
from transformers.modeling_outputs import SequenceClassifierOutput

Expand Down Expand Up @@ -215,6 +218,66 @@ def __setstate__(self, state):
self.tokenizer = AutoTokenizer.from_pretrained(self.backbone, use_fast=True)


class TextDataFrameDataSource(TextDataSource):
@staticmethod
def _multilabel_target(targets, element):
targets = [element.pop(target) for target in targets]
element["labels"] = targets
return element

def load_data(
self,
data: Union[Tuple[pd.DataFrame, Union[str, List[str]], Union[str, List[str]]], Tuple[List[str], List[str]]],
dataset: Optional[Any] = None,
columns: Union[List[str], Tuple[str]] = ("input_ids", "attention_mask", "labels"),
) -> Union[Sequence[Mapping[str, Any]]]:
df, input, target = data
hf_dataset = Dataset.from_pandas(df)

if not self.predicting:
if isinstance(target, List):
# multi-target
dataset.multi_label = True
hf_dataset = hf_dataset.map(partial(self._multilabel_target, target))
dataset.num_classes = len(target)
self.set_state(LabelsState(target))
else:
dataset.multi_label = False
if self.training:
labels = list(sorted(list(set(hf_dataset[target]))))
dataset.num_classes = len(labels)
self.set_state(LabelsState(labels))

labels = self.get_state(LabelsState)

# convert labels to ids
if labels is not None:
labels = labels.labels
label_to_class_mapping = {v: k for k, v in enumerate(labels)}
hf_dataset = hf_dataset.map(partial(self._transform_label, label_to_class_mapping, target))

# Hugging Face models expect target to be named ``labels``.
if target != "labels":
hf_dataset.rename_column_(target, "labels")

hf_dataset = hf_dataset.map(partial(self._tokenize_fn, input=input), batched=True)
hf_dataset.set_format("torch", columns=columns)

return hf_dataset

def predict_load_data(self, data: Any, dataset: AutoDataset):
return self.load_data(data, dataset, columns=["input_ids", "attention_mask"])

def __getstate__(self): # TODO: Find out why this is being pickled
state = self.__dict__.copy()
state.pop("tokenizer")
return state

def __setstate__(self, state):
self.__dict__.update(state)
self.tokenizer = AutoTokenizer.from_pretrained(self.backbone, use_fast=True)


class TextSentencesDataSource(TextDataSource):
def __init__(self, backbone: str, max_length: int = 128):
super().__init__(backbone, max_length=max_length)
Expand Down Expand Up @@ -266,6 +329,7 @@ def __init__(
data_sources={
DefaultDataSources.CSV: TextCSVDataSource(self.backbone, max_length=max_length),
DefaultDataSources.JSON: TextJSONDataSource(self.backbone, max_length=max_length),
"data_frame": TextDataFrameDataSource(self.backbone, max_length=max_length),
"sentences": TextSentencesDataSource(self.backbone, max_length=max_length),
},
default_data_source="sentences",
Expand Down Expand Up @@ -313,3 +377,76 @@ class TextClassificationData(DataModule):
@property
def backbone(self) -> Optional[str]:
return getattr(self.preprocess, "backbone", None)

@classmethod
def from_data_frame(
cls,
input_field: str,
target_fields: Union[str, Sequence[str]],
train_data_frame: Optional[pd.DataFrame] = None,
val_data_frame: Optional[pd.DataFrame] = None,
test_data_frame: Optional[pd.DataFrame] = None,
predict_data_frame: Optional[pd.DataFrame] = None,
train_transform: Optional[Union[Callable, List, Dict[str, Callable]]] = None,
val_transform: Optional[Union[Callable, List, Dict[str, Callable]]] = None,
test_transform: Optional[Union[Callable, List, Dict[str, Callable]]] = None,
predict_transform: Optional[Dict[str, Callable]] = None,
data_fetcher: Optional[BaseDataFetcher] = None,
preprocess: Optional[Preprocess] = None,
val_split: Optional[float] = None,
batch_size: int = 4,
num_workers: int = 0,
sampler: Optional[Type[Sampler]] = None,
**preprocess_kwargs: Any,
) -> "DataModule":
"""Creates a :class:`~flash.text.classification.data.TextClassificationData` object from the given pandas
``DataFrame`` objects.
Args:
input_field: The field (column) in the pandas ``DataFrame`` to use for the input.
target_fields: The field or fields (columns) in the pandas ``DataFrame`` to use for the target.
train_data_frame: The pandas ``DataFrame`` containing the training data.
val_data_frame: The pandas ``DataFrame`` containing the validation data.
test_data_frame: The pandas ``DataFrame`` containing the testing data.
predict_data_frame: The pandas ``DataFrame`` containing the data to use when predicting.
train_transform: The dictionary of transforms to use during training which maps
:class:`~flash.core.data.process.Preprocess` hook names to callable transforms.
val_transform: The dictionary of transforms to use during validation which maps
:class:`~flash.core.data.process.Preprocess` hook names to callable transforms.
test_transform: The dictionary of transforms to use during testing which maps
:class:`~flash.core.data.process.Preprocess` hook names to callable transforms.
predict_transform: The dictionary of transforms to use during predicting which maps
:class:`~flash.core.data.process.Preprocess` hook names to callable transforms.
data_fetcher: The :class:`~flash.core.data.callback.BaseDataFetcher` to pass to the
:class:`~flash.core.data.data_module.DataModule`.
preprocess: The :class:`~flash.core.data.data.Preprocess` to pass to the
:class:`~flash.core.data.data_module.DataModule`. If ``None``, ``cls.preprocess_cls``
will be constructed and used.
val_split: The ``val_split`` argument to pass to the :class:`~flash.core.data.data_module.DataModule`.
batch_size: The ``batch_size`` argument to pass to the :class:`~flash.core.data.data_module.DataModule`.
num_workers: The ``num_workers`` argument to pass to the :class:`~flash.core.data.data_module.DataModule`.
sampler: The ``sampler`` to use for the ``train_dataloader``.
preprocess_kwargs: Additional keyword arguments to use when constructing the preprocess. Will only be used
if ``preprocess = None``.
Returns:
The constructed data module.
"""
return cls.from_data_source(
"data_frame",
(train_data_frame, input_field, target_fields),
(val_data_frame, input_field, target_fields),
(test_data_frame, input_field, target_fields),
(predict_data_frame, input_field, target_fields),
train_transform=train_transform,
val_transform=val_transform,
test_transform=test_transform,
predict_transform=predict_transform,
data_fetcher=data_fetcher,
preprocess=preprocess,
val_split=val_split,
batch_size=batch_size,
num_workers=num_workers,
sampler=sampler,
**preprocess_kwargs,
)
19 changes: 19 additions & 0 deletions tests/text/classification/test_data.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,12 +14,14 @@
import os
from pathlib import Path

import pandas as pd
import pytest

from flash.core.utilities.imports import _TEXT_AVAILABLE
from flash.text import TextClassificationData
from flash.text.classification.data import (
TextCSVDataSource,
TextDataFrameDataSource,
TextDataSource,
TextFileDataSource,
TextJSONDataSource,
Expand Down Expand Up @@ -51,6 +53,11 @@
"""


TEST_DATA_FRAME_DATA = pd.DataFrame(
{"sentence": ["this is a sentence one", "this is a sentence two", "this is a sentence three"], "lab": [0, 1, 0]},
)


def csv_data(tmpdir):
path = Path(tmpdir) / "data.csv"
path.write_text(TEST_CSV_DATA)
Expand Down Expand Up @@ -123,6 +130,17 @@ def test_from_json_with_field(tmpdir):
assert "input_ids" in batch


@pytest.mark.skipif(os.name == "nt", reason="Huggingface timing out on Windows")
@pytest.mark.skipif(not _TEXT_TESTING, reason="text libraries aren't installed.")
def test_from_data_frame():
dm = TextClassificationData.from_data_frame(
"sentence", "lab", backbone=TEST_BACKBONE, train_data_frame=TEST_DATA_FRAME_DATA, batch_size=1
)
batch = next(iter(dm.train_dataloader()))
assert batch["labels"].item() in [0, 1]
assert "input_ids" in batch


@pytest.mark.skipif(_TEXT_AVAILABLE, reason="text libraries are installed.")
def test_text_module_not_found_error():
with pytest.raises(ModuleNotFoundError, match="[text]"):
Expand All @@ -138,6 +156,7 @@ def test_text_module_not_found_error():
(TextFileDataSource, {"filetype": "csv"}),
(TextCSVDataSource, {}),
(TextJSONDataSource, {}),
(TextDataFrameDataSource, {}),
(TextSentencesDataSource, {}),
],
)
Expand Down

0 comments on commit 0c8c24d

Please sign in to comment.