diff --git a/CHANGELOG.md b/CHANGELOG.md
index ad47448567..66d0ab8c1a 100644
--- a/CHANGELOG.md
+++ b/CHANGELOG.md
@@ -12,6 +12,8 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).
 
 - Added `vissl` training_strategies for `ImageEmbedder` ([#682](https://github.com/PyTorchLightning/lightning-flash/pull/682))
 
+- Added support for `from_data_frame` to `TextClassificationData` ([#785](https://github.com/PyTorchLightning/lightning-flash/pull/785))
+
 ### Changed
 
 - Changed the default `num_workers` on linux to `0` (matching the default for other OS) ([#759](https://github.com/PyTorchLightning/lightning-flash/pull/759))
diff --git a/flash/text/classification/data.py b/flash/text/classification/data.py
index 3e1c7e39cb..c71538c0b9 100644
--- a/flash/text/classification/data.py
+++ b/flash/text/classification/data.py
@@ -12,20 +12,23 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 from functools import partial
-from typing import Any, Callable, Dict, List, Mapping, Optional, Sequence, Tuple, Union
+from typing import Any, Callable, Dict, List, Mapping, Optional, Sequence, Tuple, Type, Union
 
+import pandas as pd
 import torch
 from torch import Tensor
+from torch.utils.data.sampler import Sampler
 
 import flash
 from flash.core.data.auto_dataset import AutoDataset
+from flash.core.data.callback import BaseDataFetcher
 from flash.core.data.data_module import DataModule
 from flash.core.data.data_source import DataSource, DefaultDataSources, LabelsState
 from flash.core.data.process import Deserializer, Postprocess, Preprocess
 from flash.core.utilities.imports import _TEXT_AVAILABLE, requires
 
 if _TEXT_AVAILABLE:
-    from datasets import DatasetDict, load_dataset
+    from datasets import Dataset, DatasetDict, load_dataset
     from transformers import AutoTokenizer, default_data_collator
     from transformers.modeling_outputs import SequenceClassifierOutput
 
@@ -215,6 +218,68 @@ def __setstate__(self, state):
         self.tokenizer = AutoTokenizer.from_pretrained(self.backbone, use_fast=True)
 
 
+class TextDataFrameDataSource(TextDataSource):
+    @staticmethod
+    def _multilabel_target(targets, element):
+        targets = [element.pop(target) for target in targets]
+        element["labels"] = targets
+        return element
+
+    def load_data(
+        self,
+        data: Union[Tuple[pd.DataFrame, Union[str, List[str]], Union[str, List[str]]], Tuple[List[str], List[str]]],
+        dataset: Optional[Any] = None,
+        columns: Union[List[str], Tuple[str, ...]] = ("input_ids", "attention_mask", "labels"),
+    ) -> Sequence[Mapping[str, Any]]:
+        df, input, target = data
+        hf_dataset = Dataset.from_pandas(df)
+
+        if not self.predicting:
+            if isinstance(target, list):
+                # multi-target
+                dataset.multi_label = True
+                hf_dataset = hf_dataset.map(partial(self._multilabel_target, target))
+                dataset.num_classes = len(target)
+                self.set_state(LabelsState(target))
+            else:
+                dataset.multi_label = False
+                if self.training:
+                    labels = sorted(set(hf_dataset[target]))
+                    dataset.num_classes = len(labels)
+                    self.set_state(LabelsState(labels))
+
+                labels = self.get_state(LabelsState)
+
+                # convert labels to ids
+                if labels is not None:
+                    labels = labels.labels
+                    label_to_class_mapping = {v: k for k, v in enumerate(labels)}
+                    hf_dataset = hf_dataset.map(partial(self._transform_label, label_to_class_mapping, target))
+
+                # Hugging Face models expect target to be named ``labels``.
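+                # Note: ``rename_column_`` mutates ``hf_dataset`` in place (newer
+                # ``datasets`` releases deprecate it in favor of ``rename_column``).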
+                if target != "labels":
+                    hf_dataset.rename_column_(target, "labels")
+
+        hf_dataset = hf_dataset.map(partial(self._tokenize_fn, input=input), batched=True)
+        hf_dataset.set_format("torch", columns=columns)
+
+        return hf_dataset
+
+    def predict_load_data(self, data: Any, dataset: AutoDataset):
+        return self.load_data(data, dataset, columns=["input_ids", "attention_mask"])
+
+    def __getstate__(self):  # TODO: Find out why this is being pickled
+        state = self.__dict__.copy()
+        state.pop("tokenizer")
+        return state
+
+    def __setstate__(self, state):
+        self.__dict__.update(state)
+        self.tokenizer = AutoTokenizer.from_pretrained(self.backbone, use_fast=True)
+
+
 class TextSentencesDataSource(TextDataSource):
     def __init__(self, backbone: str, max_length: int = 128):
         super().__init__(backbone, max_length=max_length)
@@ -266,6 +331,7 @@ def __init__(
             data_sources={
                 DefaultDataSources.CSV: TextCSVDataSource(self.backbone, max_length=max_length),
                 DefaultDataSources.JSON: TextJSONDataSource(self.backbone, max_length=max_length),
+                "data_frame": TextDataFrameDataSource(self.backbone, max_length=max_length),
                 "sentences": TextSentencesDataSource(self.backbone, max_length=max_length),
             },
             default_data_source="sentences",
@@ -313,3 +379,76 @@ class TextClassificationData(DataModule):
     @property
     def backbone(self) -> Optional[str]:
         return getattr(self.preprocess, "backbone", None)
+
+    @classmethod
+    def from_data_frame(
+        cls,
+        input_field: str,
+        target_fields: Union[str, Sequence[str]],
+        train_data_frame: Optional[pd.DataFrame] = None,
+        val_data_frame: Optional[pd.DataFrame] = None,
+        test_data_frame: Optional[pd.DataFrame] = None,
+        predict_data_frame: Optional[pd.DataFrame] = None,
+        train_transform: Optional[Union[Callable, List, Dict[str, Callable]]] = None,
+        val_transform: Optional[Union[Callable, List, Dict[str, Callable]]] = None,
+        test_transform: Optional[Union[Callable, List, Dict[str, Callable]]] = None,
+        predict_transform: Optional[Union[Callable, List, Dict[str, Callable]]] = None,
+        data_fetcher: Optional[BaseDataFetcher] = None,
+        preprocess: Optional[Preprocess] = None,
+        val_split: Optional[float] = None,
+        batch_size: int = 4,
+        num_workers: int = 0,
+        sampler: Optional[Type[Sampler]] = None,
+        **preprocess_kwargs: Any,
+    ) -> "DataModule":
+        """Creates a :class:`~flash.text.classification.data.TextClassificationData` object from the given pandas
+        ``DataFrame`` objects.
+
+        Args:
+            input_field: The field (column) in the pandas ``DataFrame`` to use for the input.
+            target_fields: The field or fields (columns) in the pandas ``DataFrame`` to use for the target.
+            train_data_frame: The pandas ``DataFrame`` containing the training data.
+            val_data_frame: The pandas ``DataFrame`` containing the validation data.
+            test_data_frame: The pandas ``DataFrame`` containing the testing data.
+            predict_data_frame: The pandas ``DataFrame`` containing the data to use when predicting.
+            train_transform: The dictionary of transforms to use during training which maps
+                :class:`~flash.core.data.process.Preprocess` hook names to callable transforms.
+            val_transform: The dictionary of transforms to use during validation which maps
+                :class:`~flash.core.data.process.Preprocess` hook names to callable transforms.
+            test_transform: The dictionary of transforms to use during testing which maps
+                :class:`~flash.core.data.process.Preprocess` hook names to callable transforms.
+            predict_transform: The dictionary of transforms to use during predicting which maps
+                :class:`~flash.core.data.process.Preprocess` hook names to callable transforms.
+            data_fetcher: The :class:`~flash.core.data.callback.BaseDataFetcher` to pass to the
+                :class:`~flash.core.data.data_module.DataModule`.
+            preprocess: The :class:`~flash.core.data.process.Preprocess` to pass to the
+                :class:`~flash.core.data.data_module.DataModule`. If ``None``, ``cls.preprocess_cls``
+                will be constructed and used.
+            val_split: The ``val_split`` argument to pass to the :class:`~flash.core.data.data_module.DataModule`.
+            batch_size: The ``batch_size`` argument to pass to the :class:`~flash.core.data.data_module.DataModule`.
+            num_workers: The ``num_workers`` argument to pass to the :class:`~flash.core.data.data_module.DataModule`.
+            sampler: The ``sampler`` to use for the ``train_dataloader``.
+            preprocess_kwargs: Additional keyword arguments to use when constructing the preprocess. Will only be
+                used if ``preprocess = None``.
+
+        Returns:
+            The constructed data module.
+        """
+        return cls.from_data_source(
+            "data_frame",
+            (train_data_frame, input_field, target_fields),
+            (val_data_frame, input_field, target_fields),
+            (test_data_frame, input_field, target_fields),
+            (predict_data_frame, input_field, target_fields),
+            train_transform=train_transform,
+            val_transform=val_transform,
+            test_transform=test_transform,
+            predict_transform=predict_transform,
+            data_fetcher=data_fetcher,
+            preprocess=preprocess,
+            val_split=val_split,
+            batch_size=batch_size,
+            num_workers=num_workers,
+            sampler=sampler,
+            **preprocess_kwargs,
+        )
diff --git a/tests/text/classification/test_data.py b/tests/text/classification/test_data.py
index 4c42909b35..238f419522 100644
--- a/tests/text/classification/test_data.py
+++ b/tests/text/classification/test_data.py
@@ -14,12 +14,14 @@
 import os
 from pathlib import Path
 
+import pandas as pd
 import pytest
 
 from flash.core.utilities.imports import _TEXT_AVAILABLE
 from flash.text import TextClassificationData
 from flash.text.classification.data import (
     TextCSVDataSource,
+    TextDataFrameDataSource,
     TextDataSource,
     TextFileDataSource,
     TextJSONDataSource,
@@ -51,6 +53,11 @@
 """
 
 
+TEST_DATA_FRAME_DATA = pd.DataFrame(
+    {"sentence": ["this is a sentence one", "this is a sentence two", "this is a sentence three"], "lab": [0, 1, 0]},
+)
+
+
 def csv_data(tmpdir):
     path = Path(tmpdir) / "data.csv"
     path.write_text(TEST_CSV_DATA)
@@ -123,6 +130,17 @@ def test_from_json_with_field(tmpdir):
     assert "input_ids" in batch
 
 
+@pytest.mark.skipif(os.name == "nt", reason="Huggingface timing out on Windows")
+@pytest.mark.skipif(not _TEXT_TESTING, reason="text libraries aren't installed.")
+def test_from_data_frame():
+    dm = TextClassificationData.from_data_frame(
+        "sentence", "lab", backbone=TEST_BACKBONE, train_data_frame=TEST_DATA_FRAME_DATA, batch_size=1
+    )
+    batch = next(iter(dm.train_dataloader()))
+    assert batch["labels"].item() in [0, 1]
+    assert "input_ids" in batch
+
+
 @pytest.mark.skipif(_TEXT_AVAILABLE, reason="text libraries are installed.")
 def test_text_module_not_found_error():
     with pytest.raises(ModuleNotFoundError, match="[text]"):
@@ -138,6 +156,7 @@ def test_text_module_not_found_error():
         (TextFileDataSource, {"filetype": "csv"}),
         (TextCSVDataSource, {}),
         (TextJSONDataSource, {}),
+        (TextDataFrameDataSource, {}),
         (TextSentencesDataSource, {}),
     ],
 )
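
Usage sketch (a minimal example mirroring ``test_from_data_frame`` above; the backbone name is an
assumption, and any Hugging Face text backbone should work since ``backbone`` is forwarded to the
preprocess via ``**preprocess_kwargs``):

    import pandas as pd

    from flash.text import TextClassificationData

    train_df = pd.DataFrame(
        {"sentence": ["this film is great", "this film is terrible"], "lab": [1, 0]}
    )

    datamodule = TextClassificationData.from_data_frame(
        "sentence",  # input_field: the column containing the raw text
        "lab",  # target_fields: one column, or a list of columns for multi-label
        train_data_frame=train_df,
        backbone="prajjwal1/bert-tiny",  # assumed backbone; passed through **preprocess_kwargs
        batch_size=1,
    )

    batch = next(iter(datamodule.train_dataloader()))
    # ``batch`` holds the tokenized "input_ids", "attention_mask" and "labels" tensors

Passing a list as ``target_fields`` (e.g. ``["lab1", "lab2"]``) switches the datamodule into
multi-label mode, following the ``isinstance(target, list)`` branch in
``TextDataFrameDataSource.load_data``.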