diff --git a/CHANGELOG.md b/CHANGELOG.md
index 91f7061a4b..3845a894fc 100644
--- a/CHANGELOG.md
+++ b/CHANGELOG.md
@@ -16,6 +16,8 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).
 
 - Added support for `from_data_frame` to `TextClassificationData` ([#785](https://github.com/PyTorchLightning/lightning-flash/pull/785))
 
+- Added support for `from_lists` to `TextClassificationData` ([#805](https://github.com/PyTorchLightning/lightning-flash/pull/805))
+
 ### Changed
 
 - Changed the default `num_workers` on linux to `0` (matching the default for other OS) ([#759](https://github.com/PyTorchLightning/lightning-flash/pull/759))
diff --git a/flash/core/data/data_source.py b/flash/core/data/data_source.py
index fb4260ed89..de20be9791 100644
--- a/flash/core/data/data_source.py
+++ b/flash/core/data/data_source.py
@@ -160,6 +160,9 @@ class DefaultDataSources(LightningEnum):
     JSON = "json"
     DATASETS = "datasets"
     FIFTYONE = "fiftyone"
+    DATAFRAME = "data_frame"
+    LISTS = "lists"
+    SENTENCES = "sentences"
     LABELSTUDIO = "labelstudio"
 
 # TODO: Create a FlashEnum class???
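`DefaultDataSources` is a `LightningEnum`, which is string-backed, so the three new members interoperate with the plain string keys they replace later in this diff. A minimal sketch of that equivalence, assuming only the enum values defined in the hunk above (the import path is the one shown in the diff header):

```python
from flash.core.data.data_source import DefaultDataSources

# LightningEnum is a str-based Enum, so each member compares equal to its
# string value; swapping the literal keys "data_frame"/"sentences" for enum
# members further down in this diff is therefore backwards compatible.
assert DefaultDataSources.DATAFRAME == "data_frame"
assert DefaultDataSources.LISTS == "lists"
assert DefaultDataSources.SENTENCES == "sentences"
```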
diff --git a/flash/text/classification/data.py b/flash/text/classification/data.py
index 085b30988c..ab02e746a7 100644
--- a/flash/text/classification/data.py
+++ b/flash/text/classification/data.py
@@ -228,7 +228,7 @@ def _multilabel_target(targets, element):
 
     def load_data(
         self,
-        data: Union[Tuple[pd.DataFrame, Union[str, List[str]], Union[str, List[str]]], Tuple[List[str], List[str]]],
+        data: Tuple[pd.DataFrame, Union[str, List[str]], Union[str, List[str]]],
         dataset: Optional[Any] = None,
         columns: Union[List[str], Tuple[str]] = ("input_ids", "attention_mask", "labels"),
     ) -> Union[Sequence[Mapping[str, Any]]]:
@@ -279,6 +279,55 @@ def __setstate__(self, state):
         self.tokenizer = AutoTokenizer.from_pretrained(self.backbone, use_fast=True)
 
 
+class TextListDataSource(TextDataSource):
+    def load_data(
+        self,
+        data: Tuple[List[str], Union[List[Any], List[List[Any]]]],
+        dataset: Optional[Any] = None,
+        columns: Union[List[str], Tuple[str]] = ("input_ids", "attention_mask", "labels"),
+    ) -> Union[Sequence[Mapping[str, Any]]]:
+        input, target = data
+        hf_dataset = Dataset.from_dict({"input": input, "labels": target})
+
+        if not self.predicting:
+            if isinstance(target[0], List):
+                # multi-target
+                dataset.multi_label = True
+                dataset.num_classes = len(target[0])
+                self.set_state(LabelsState(target))
+            else:
+                dataset.multi_label = False
+                if self.training:
+                    labels = list(sorted(list(set(hf_dataset["labels"]))))
+                    dataset.num_classes = len(labels)
+                    self.set_state(LabelsState(labels))
+
+                labels = self.get_state(LabelsState)
+
+                # convert labels to ids
+                if labels is not None:
+                    labels = labels.labels
+                    label_to_class_mapping = {v: k for k, v in enumerate(labels)}
+                    hf_dataset = hf_dataset.map(partial(self._transform_label, label_to_class_mapping, "labels"))
+
+        hf_dataset = hf_dataset.map(partial(self._tokenize_fn, input="input"), batched=True)
+        hf_dataset.set_format("torch", columns=columns)
+
+        return hf_dataset
+
+    def predict_load_data(self, data: Any, dataset: AutoDataset):
+        return self.load_data(data, dataset, columns=["input_ids", "attention_mask"])
+
+    def __getstate__(self):  # TODO: Find out why this is being pickled
+        state = self.__dict__.copy()
+        state.pop("tokenizer")
+        return state
+
+    def __setstate__(self, state):
+        self.__dict__.update(state)
+        self.tokenizer = AutoTokenizer.from_pretrained(self.backbone, use_fast=True)
+
+
 class TextSentencesDataSource(TextDataSource):
     def __init__(self, backbone: str, max_length: int = 128):
         super().__init__(backbone, max_length=max_length)
@@ -330,13 +379,14 @@ def __init__(
             data_sources={
                 DefaultDataSources.CSV: TextCSVDataSource(self.backbone, max_length=max_length),
                 DefaultDataSources.JSON: TextJSONDataSource(self.backbone, max_length=max_length),
-                "data_frame": TextDataFrameDataSource(self.backbone, max_length=max_length),
-                "sentences": TextSentencesDataSource(self.backbone, max_length=max_length),
+                DefaultDataSources.DATAFRAME: TextDataFrameDataSource(self.backbone, max_length=max_length),
+                DefaultDataSources.LISTS: TextListDataSource(self.backbone, max_length=max_length),
+                DefaultDataSources.SENTENCES: TextSentencesDataSource(self.backbone, max_length=max_length),
                 DefaultDataSources.LABELSTUDIO: LabelStudioTextClassificationDataSource(
                     backbone=self.backbone, max_length=max_length
                 ),
             },
-            default_data_source="sentences",
+            default_data_source=DefaultDataSources.SENTENCES,
             deserializer=TextDeserializer(backbone, max_length),
         )
 
@@ -437,7 +487,7 @@ def from_data_frame(
             The constructed data module.
         """
         return cls.from_data_source(
-            "data_frame",
+            DefaultDataSources.DATAFRAME,
             (train_data_frame, input_field, target_fields),
             (val_data_frame, input_field, target_fields),
             (test_data_frame, input_field, target_fields),
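The single-label branch of `TextListDataSource.load_data` above derives class ids from the sorted set of training labels and rewrites each sample with `datasets.Dataset.map`. A standalone sketch of just that step, using illustrative string labels and a local stand-in for the `_transform_label` helper:

```python
from functools import partial

from datasets import Dataset

texts = ["this is a sentence one", "this is a sentence two", "this is a sentence three"]
targets = ["positive", "negative", "positive"]  # illustrative string labels

hf_dataset = Dataset.from_dict({"input": texts, "labels": targets})

# Sorted unique labels define the class ids, exactly as in load_data above.
labels = sorted(set(hf_dataset["labels"]))
label_to_class_mapping = {v: k for k, v in enumerate(labels)}  # {"negative": 0, "positive": 1}

def transform_label(mapping, key, example):
    # Stand-in for TextDataSource._transform_label: replace a label with its id.
    example[key] = mapping[example[key]]
    return example

hf_dataset = hf_dataset.map(partial(transform_label, label_to_class_mapping, "labels"))
print(hf_dataset["labels"])  # [1, 0, 1]
```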
@@ -454,3 +504,81 @@ def from_data_frame(
             sampler=sampler,
             **preprocess_kwargs,
         )
+
+    @classmethod
+    def from_lists(
+        cls,
+        train_data: Optional[List[str]] = None,
+        train_targets: Optional[Union[List[Any], List[List[Any]]]] = None,
+        val_data: Optional[List[str]] = None,
+        val_targets: Optional[Union[List[Any], List[List[Any]]]] = None,
+        test_data: Optional[List[str]] = None,
+        test_targets: Optional[Union[List[Any], List[List[Any]]]] = None,
+        predict_data: Optional[List[str]] = None,
+        train_transform: Optional[Union[Callable, List, Dict[str, Callable]]] = None,
+        val_transform: Optional[Union[Callable, List, Dict[str, Callable]]] = None,
+        test_transform: Optional[Union[Callable, List, Dict[str, Callable]]] = None,
+        predict_transform: Optional[Dict[str, Callable]] = None,
+        data_fetcher: Optional[BaseDataFetcher] = None,
+        preprocess: Optional[Preprocess] = None,
+        val_split: Optional[float] = None,
+        batch_size: int = 4,
+        num_workers: int = 0,
+        sampler: Optional[Type[Sampler]] = None,
+        **preprocess_kwargs: Any,
+    ) -> "DataModule":
+        """Creates a :class:`~flash.text.classification.data.TextClassificationData` object from the given Python
+        lists.
+
+        Args:
+            train_data: A list of sentences to use as the train inputs.
+            train_targets: A list of targets to use as the train targets. For multi-label classification, the
+                targets should be provided as a list of lists, where each inner list contains the targets for a
+                sample.
+            val_data: A list of sentences to use as the validation inputs.
+            val_targets: A list of targets to use as the validation targets. For multi-label classification, the
+                targets should be provided as a list of lists, where each inner list contains the targets for a
+                sample.
+            test_data: A list of sentences to use as the test inputs.
+            test_targets: A list of targets to use as the test targets. For multi-label classification, the
+                targets should be provided as a list of lists, where each inner list contains the targets for a
+                sample.
+            predict_data: A list of sentences to use when predicting.
+            train_transform: The dictionary of transforms to use during training which maps
+                :class:`~flash.core.data.process.Preprocess` hook names to callable transforms.
+            val_transform: The dictionary of transforms to use during validation which maps
+                :class:`~flash.core.data.process.Preprocess` hook names to callable transforms.
+            test_transform: The dictionary of transforms to use during testing which maps
+                :class:`~flash.core.data.process.Preprocess` hook names to callable transforms.
+            predict_transform: The dictionary of transforms to use during predicting which maps
+                :class:`~flash.core.data.process.Preprocess` hook names to callable transforms.
+            data_fetcher: The :class:`~flash.core.data.callback.BaseDataFetcher` to pass to the
+                :class:`~flash.core.data.data_module.DataModule`.
+            preprocess: The :class:`~flash.core.data.process.Preprocess` to pass to the
+                :class:`~flash.core.data.data_module.DataModule`. If ``None``, ``cls.preprocess_cls``
+                will be constructed and used.
+            val_split: The ``val_split`` argument to pass to the :class:`~flash.core.data.data_module.DataModule`.
+            batch_size: The ``batch_size`` argument to pass to the :class:`~flash.core.data.data_module.DataModule`.
+            num_workers: The ``num_workers`` argument to pass to the :class:`~flash.core.data.data_module.DataModule`.
+            sampler: The ``sampler`` to use for the ``train_dataloader``.
+            preprocess_kwargs: Additional keyword arguments to use when constructing the preprocess. Will only be
+                used if ``preprocess = None``.
+
+        Returns:
+            The constructed data module.
+        """
+        return cls.from_data_source(
+            DefaultDataSources.LISTS,
+            (train_data, train_targets),
+            (val_data, val_targets),
+            (test_data, test_targets),
+            predict_data,
+            train_transform=train_transform,
+            val_transform=val_transform,
+            test_transform=test_transform,
+            predict_transform=predict_transform,
+            data_fetcher=data_fetcher,
+            preprocess=preprocess,
+            val_split=val_split,
+            batch_size=batch_size,
+            num_workers=num_workers,
+            sampler=sampler,
+            **preprocess_kwargs,
+        )
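End to end, the new classmethod is driven entirely from in-memory lists, mirroring `test_from_lists` in the test diff below. A minimal sketch; the backbone name is only an illustrative small Hugging Face model, forwarded to the preprocess via `**preprocess_kwargs`:

```python
from flash.text import TextClassificationData

datamodule = TextClassificationData.from_lists(
    train_data=["this is a sentence one", "this is a sentence two", "this is a sentence three"],
    train_targets=[0, 1, 0],
    backbone="prajjwal1/bert-tiny",  # illustrative backbone; any HF model name works
    batch_size=1,
)

# Same shape checks as the tests: tokenized inputs plus integer class ids.
batch = next(iter(datamodule.train_dataloader()))
assert batch["labels"].item() in [0, 1]
assert "input_ids" in batch
```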
diff --git a/tests/text/classification/test_data.py b/tests/text/classification/test_data.py
index 238f419522..a2f57be5f8 100644
--- a/tests/text/classification/test_data.py
+++ b/tests/text/classification/test_data.py
@@ -25,6 +25,7 @@
     TextDataSource,
     TextFileDataSource,
     TextJSONDataSource,
+    TextListDataSource,
     TextSentencesDataSource,
 )
 from tests.helpers.utils import _TEXT_TESTING
@@ -54,10 +55,19 @@
 
 TEST_DATA_FRAME_DATA = pd.DataFrame(
-    {"sentence": ["this is a sentence one", "this is a sentence two", "this is a sentence three"], "lab": [0, 1, 0]},
+    {
+        "sentence": ["this is a sentence one", "this is a sentence two", "this is a sentence three"],
+        "lab1": [0, 1, 0],
+        "lab2": [1, 0, 1],
+    },
 )
 
+TEST_LIST_DATA = ["this is a sentence one", "this is a sentence two", "this is a sentence three"]
+TEST_LIST_TARGETS = [0, 1, 0]
+TEST_LIST_TARGETS_MULTILABEL = [[0, 1], [1, 0], [0, 1]]
+
+
 def csv_data(tmpdir):
     path = Path(tmpdir) / "data.csv"
     path.write_text(TEST_CSV_DATA)
@@ -134,13 +144,46 @@ def test_from_json_with_field(tmpdir):
 @pytest.mark.skipif(not _TEXT_TESTING, reason="text libraries aren't installed.")
 def test_from_data_frame():
     dm = TextClassificationData.from_data_frame(
-        "sentence", "lab", backbone=TEST_BACKBONE, train_data_frame=TEST_DATA_FRAME_DATA, batch_size=1
+        "sentence", "lab1", backbone=TEST_BACKBONE, train_data_frame=TEST_DATA_FRAME_DATA, batch_size=1
     )
     batch = next(iter(dm.train_dataloader()))
     assert batch["labels"].item() in [0, 1]
     assert "input_ids" in batch
 
 
+@pytest.mark.skipif(os.name == "nt", reason="Huggingface timing out on Windows")
+@pytest.mark.skipif(not _TEXT_TESTING, reason="text libraries aren't installed.")
+def test_from_data_frame_multilabel():
+    dm = TextClassificationData.from_data_frame(
+        "sentence", ["lab1", "lab2"], backbone=TEST_BACKBONE, train_data_frame=TEST_DATA_FRAME_DATA, batch_size=1
+    )
+    batch = next(iter(dm.train_dataloader()))
+    assert all([label in [0, 1] for label in batch["labels"][0]])
+    assert "input_ids" in batch
+
+
+@pytest.mark.skipif(os.name == "nt", reason="Huggingface timing out on Windows")
+@pytest.mark.skipif(not _TEXT_TESTING, reason="text libraries aren't installed.")
+def test_from_lists():
+    dm = TextClassificationData.from_lists(
+        backbone=TEST_BACKBONE, train_data=TEST_LIST_DATA, train_targets=TEST_LIST_TARGETS, batch_size=1
+    )
+    batch = next(iter(dm.train_dataloader()))
+    assert batch["labels"].item() in [0, 1]
+    assert "input_ids" in batch
+
+
+@pytest.mark.skipif(os.name == "nt", reason="Huggingface timing out on Windows")
+@pytest.mark.skipif(not _TEXT_TESTING, reason="text libraries aren't installed.")
+def test_from_lists_multilabel():
+    dm = TextClassificationData.from_lists(
+        backbone=TEST_BACKBONE, train_data=TEST_LIST_DATA, train_targets=TEST_LIST_TARGETS_MULTILABEL, batch_size=1
+    )
+    batch = next(iter(dm.train_dataloader()))
+    assert all([label in [0, 1] for label in batch["labels"][0]])
+    assert "input_ids" in batch
+
+
 @pytest.mark.skipif(_TEXT_AVAILABLE, reason="text libraries are installed.")
 def test_text_module_not_found_error():
     with pytest.raises(ModuleNotFoundError, match="[text]"):
@@ -157,6 +200,7 @@ def test_text_module_not_found_error():
         (TextCSVDataSource, {}),
         (TextJSONDataSource, {}),
         (TextDataFrameDataSource, {}),
+        (TextListDataSource, {}),
         (TextSentencesDataSource, {}),
     ],
 )
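For multi-label training, `from_lists` takes one list of binary targets per sample (as in `TEST_LIST_TARGETS_MULTILABEL`), and `TextListDataSource.load_data` then sets `multi_label = True` and `num_classes = len(target[0])` without remapping labels. A sketch under the same illustrative-backbone assumption as above:

```python
from flash.text import TextClassificationData

datamodule = TextClassificationData.from_lists(
    train_data=["this is a sentence one", "this is a sentence two", "this is a sentence three"],
    train_targets=[[0, 1], [1, 0], [0, 1]],  # one list of binary targets per sample
    backbone="prajjwal1/bert-tiny",  # illustrative backbone
    batch_size=1,
)

batch = next(iter(datamodule.train_dataloader()))
assert all(label in [0, 1] for label in batch["labels"][0])
```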