From ae6c961fd83a0998c49050009a8d26cbe7cfda09 Mon Sep 17 00:00:00 2001 From: Karthik Rangasai Date: Wed, 14 Jul 2021 14:45:13 +0530 Subject: [PATCH 1/5] Added field parameter to the from_json method with other required changes. --- flash/core/data/data_module.py | 9 +++++---- flash/text/classification/data.py | 27 +++++++++++++++++++++------ flash/text/seq2seq/core/data.py | 27 +++++++++++++++++++++------ 3 files changed, 47 insertions(+), 16 deletions(-) diff --git a/flash/core/data/data_module.py b/flash/core/data/data_module.py index 97e8e7a49c..2b34879f29 100644 --- a/flash/core/data/data_module.py +++ b/flash/core/data/data_module.py @@ -839,6 +839,7 @@ def from_json( batch_size: int = 4, num_workers: Optional[int] = None, sampler: Optional[Sampler] = None, + field: str = None, **preprocess_kwargs: Any, ) -> 'DataModule': """Creates a :class:`~flash.core.data.data_module.DataModule` object from the given JSON files using the @@ -889,10 +890,10 @@ def from_json( """ return cls.from_data_source( DefaultDataSources.JSON, - (train_file, input_fields, target_fields), - (val_file, input_fields, target_fields), - (test_file, input_fields, target_fields), - (predict_file, input_fields, target_fields), + (train_file, input_fields, target_fields, field), + (val_file, input_fields, target_fields, field), + (test_file, input_fields, target_fields, field), + (predict_file, input_fields, target_fields, field), train_transform=train_transform, val_transform=val_transform, test_transform=test_transform, diff --git a/flash/text/classification/data.py b/flash/text/classification/data.py index 5049c0e975..a0296cb1b3 100644 --- a/flash/text/classification/data.py +++ b/flash/text/classification/data.py @@ -110,7 +110,10 @@ def load_data( dataset: Optional[Any] = None, columns: Union[List[str], Tuple[str]] = ("input_ids", "attention_mask", "labels"), ) -> Union[Sequence[Mapping[str, Any]]]: - file, input, target = data + if self.filetype == 'json': + file, input, target, field = 
data + else: + file, input, target = data data_files = {} @@ -120,13 +123,25 @@ def load_data( # FLASH_TESTING is set in the CI to run faster. if flash._IS_TESTING and not torch.cuda.is_available(): try: - dataset_dict = DatasetDict({ - stage: load_dataset(self.filetype, data_files=data_files, split=[f'{stage}[:20]'])[0] - }) + if self.filetype == 'json' and field is not None: + dataset_dict = DatasetDict({ + stage: load_dataset(self.filetype, data_files=data_files, split=[f'{stage}[:20]'], + field=field)[0] + }) + else: + dataset_dict = DatasetDict({ + stage: load_dataset(self.filetype, data_files=data_files, split=[f'{stage}[:20]'])[0] + }) except Exception: - dataset_dict = load_dataset(self.filetype, data_files=data_files) + if self.filetype == 'json' and field is not None: + dataset_dict = load_dataset(self.filetype, data_files=data_files, field=field) + else: + dataset_dict = load_dataset(self.filetype, data_files=data_files) else: - dataset_dict = load_dataset(self.filetype, data_files=data_files) + if self.filetype == 'json' and field is not None: + dataset_dict = load_dataset(self.filetype, data_files=data_files, field=field) + else: + dataset_dict = load_dataset(self.filetype, data_files=data_files) if not self.predicting: if isinstance(target, List): diff --git a/flash/text/seq2seq/core/data.py b/flash/text/seq2seq/core/data.py index 1b29d7e2c2..3dd0ef1cea 100644 --- a/flash/text/seq2seq/core/data.py +++ b/flash/text/seq2seq/core/data.py @@ -98,7 +98,10 @@ def __init__( def load_data(self, data: Any, columns: List[str] = None) -> 'datasets.Dataset': if columns is None: columns = ["input_ids", "attention_mask", "labels"] - file, input, target = data + if self.filetype == 'json': + file, input, target, field = data + else: + file, input, target = data data_files = {} stage = self._running_stage.value data_files[stage] = str(file) @@ -106,13 +109,25 @@ def load_data(self, data: Any, columns: List[str] = None) -> 'datasets.Dataset': # FLASH_TESTING is set in 
the CI to run faster. if flash._IS_TESTING: try: - dataset_dict = DatasetDict({ - stage: load_dataset(self.filetype, data_files=data_files, split=[f'{stage}[:20]'])[0] - }) + if self.filetype == 'json' and field is not None: + dataset_dict = DatasetDict({ + stage: load_dataset(self.filetype, data_files=data_files, split=[f'{stage}[:20]'], + field=field)[0] + }) + else: + dataset_dict = DatasetDict({ + stage: load_dataset(self.filetype, data_files=data_files, split=[f'{stage}[:20]'])[0] + }) except Exception: - dataset_dict = load_dataset(self.filetype, data_files=data_files) + if self.filetype == 'json' and field is not None: + dataset_dict = load_dataset(self.filetype, data_files=data_files, field=field) + else: + dataset_dict = load_dataset(self.filetype, data_files=data_files) else: - dataset_dict = load_dataset(self.filetype, data_files=data_files) + if self.filetype == 'json' and field is not None: + dataset_dict = load_dataset(self.filetype, data_files=data_files, field=field) + else: + dataset_dict = load_dataset(self.filetype, data_files=data_files) dataset_dict = dataset_dict.map(partial(self._tokenize_fn, input=input, target=target), batched=True) dataset_dict.set_format(columns=columns) From 65725e28e70f96d3955129e4737943656d82cf01 Mon Sep 17 00:00:00 2001 From: Karthik Rangasai Date: Wed, 14 Jul 2021 15:16:05 +0530 Subject: [PATCH 2/5] Updating field parameter type and CHANGELOG --- CHANGELOG.md | 1 + flash/core/data/data_module.py | 2 +- 2 files changed, 2 insertions(+), 1 deletion(-) diff --git a/CHANGELOG.md b/CHANGELOG.md index aded4ca732..8d3fc024d3 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -18,6 +18,7 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/). - Added support for nesting of `Task` objects ([#575](https://github.com/PyTorchLightning/lightning-flash/pull/575)) +- Added support for `field` parameter for loading JSON based datasets in text tasks. 
([#585](https://github.com/PyTorchLightning/lightning-flash/pull/585)) ### Changed - Changed how pretrained flag works for loading weights for ImageClassifier task ([#560](https://github.com/PyTorchLightning/lightning-flash/pull/560)) diff --git a/flash/core/data/data_module.py b/flash/core/data/data_module.py index 2b34879f29..bbca47fcea 100644 --- a/flash/core/data/data_module.py +++ b/flash/core/data/data_module.py @@ -839,7 +839,7 @@ def from_json( batch_size: int = 4, num_workers: Optional[int] = None, sampler: Optional[Sampler] = None, - field: str = None, + field: Optional[str] = None, **preprocess_kwargs: Any, ) -> 'DataModule': """Creates a :class:`~flash.core.data.data_module.DataModule` object from the given JSON files using the From b8a29dfdd407641c0db297b28fb678da9e1c4bfe Mon Sep 17 00:00:00 2001 From: Karthik Rangasai Date: Wed, 14 Jul 2021 16:38:29 +0530 Subject: [PATCH 3/5] Added docs for the new parameter --- flash/core/data/data_module.py | 23 +++++++++++++++++++++++ 1 file changed, 23 insertions(+) diff --git a/flash/core/data/data_module.py b/flash/core/data/data_module.py index bbca47fcea..6e6abca2e2 100644 --- a/flash/core/data/data_module.py +++ b/flash/core/data/data_module.py @@ -871,6 +871,7 @@ def from_json( batch_size: The ``batch_size`` argument to pass to the :class:`~flash.core.data.data_module.DataModule`. num_workers: The ``num_workers`` argument to pass to the :class:`~flash.core.data.data_module.DataModule`. sampler: The ``sampler`` argument to pass to the :class:`~flash.core.data.data_module.DataModule`. + field: To specify the field that holds the data in the JSON file. preprocess_kwargs: Additional keyword arguments to use when constructing the preprocess. Will only be used if ``preprocess = None``. 
@@ -887,6 +888,28 @@ def from_json( "to_tensor_transform": torch.as_tensor, }, ) + + # In the case where the data is of the form: + # { + # "version": 0.0.x, + # "data": { + # { + # "input_field" : "input_data", + # "target_field" : "target_output" + # }, + # ... + # } + # } + + data_module = DataModule.from_json( + "input", + "target", + train_file="train_data.json", + train_transform={ + "to_tensor_transform": torch.as_tensor, + }, + field="data" + ) """ return cls.from_data_source( DefaultDataSources.JSON, From d548fb3e426d5bb7ab4c06e555cde52a4a3cf0a7 Mon Sep 17 00:00:00 2001 From: Ethan Harris Date: Wed, 14 Jul 2021 20:11:44 +0100 Subject: [PATCH 4/5] Add some tests --- tests/text/classification/test_data.py | 24 +++++++++++++ .../seq2seq/question_answering/test_data.py | 24 +++++++++++++ tests/text/seq2seq/summarization/test_data.py | 36 +++++++++++++++---- tests/text/seq2seq/translation/test_data.py | 24 +++++++++++++ 4 files changed, 102 insertions(+), 6 deletions(-) diff --git a/tests/text/classification/test_data.py b/tests/text/classification/test_data.py index d5a3b680f9..b92c3757cc 100644 --- a/tests/text/classification/test_data.py +++ b/tests/text/classification/test_data.py @@ -44,6 +44,12 @@ {"sentence": "this is a sentence three","lab":0} """ +TEST_JSON_DATA_FIELD = """{"data": [ +{"sentence": "this is a sentence one","lab":0}, +{"sentence": "this is a sentence two","lab":1}, +{"sentence": "this is a sentence three","lab":0}]} +""" + def csv_data(tmpdir): path = Path(tmpdir) / "data.csv" @@ -57,6 +63,12 @@ def json_data(tmpdir): return path +def json_data_with_field(tmpdir): + path = Path(tmpdir) / "data.json" + path.write_text(TEST_JSON_DATA_FIELD) + return path + + @pytest.mark.skipif(os.name == "nt", reason="Huggingface timing out on Windows") @pytest.mark.skipif(not _TEXT_TESTING, reason="text libraries aren't installed.") def test_from_csv(tmpdir): @@ -99,6 +111,18 @@ def test_from_json(tmpdir): assert "input_ids" in batch 
+@pytest.mark.skipif(os.name == "nt", reason="Huggingface timing out on Windows") +@pytest.mark.skipif(not _TEXT_TESTING, reason="text libraries aren't installed.") +def test_from_json_with_field(tmpdir): + json_path = json_data_with_field(tmpdir) + dm = TextClassificationData.from_json( + "sentence", "lab", backbone=TEST_BACKBONE, train_file=json_path, batch_size=1, field="data" + ) + batch = next(iter(dm.train_dataloader())) + assert batch["labels"].item() in [0, 1] + assert "input_ids" in batch + + @pytest.mark.skipif(_TEXT_AVAILABLE, reason="text libraries are installed.") def test_text_module_not_found_error(): with pytest.raises(ModuleNotFoundError, match="[text]"): diff --git a/tests/text/seq2seq/question_answering/test_data.py b/tests/text/seq2seq/question_answering/test_data.py index 2db170464e..83f7824e57 100644 --- a/tests/text/seq2seq/question_answering/test_data.py +++ b/tests/text/seq2seq/question_answering/test_data.py @@ -33,6 +33,12 @@ {"input": "this is a question three","target":"this is an answer three"} """ +TEST_JSON_DATA_FIELD = """{"data": [ +{"input": "this is a question one","target":"this is an answer one"}, +{"input": "this is a question two","target":"this is an answer two"}, +{"input": "this is a question three","target":"this is an answer three"}]} +""" + def csv_data(tmpdir): path = Path(tmpdir) / "data.csv" @@ -46,6 +52,12 @@ def json_data(tmpdir): return path +def json_data_with_field(tmpdir): + path = Path(tmpdir) / "data.json" + path.write_text(TEST_JSON_DATA_FIELD) + return path + + @pytest.mark.skipif(os.name == "nt", reason="Huggingface timing out on Windows") @pytest.mark.skipif(not _TEXT_TESTING, reason="text libraries aren't installed.") def test_from_csv(tmpdir): @@ -106,3 +118,15 @@ def test_from_json(tmpdir): batch = next(iter(dm.train_dataloader())) assert "labels" in batch assert "input_ids" in batch + + +@pytest.mark.skipif(os.name == "nt", reason="Huggingface timing out on Windows") +@pytest.mark.skipif(not 
_TEXT_TESTING, reason="text libraries aren't installed.") +def test_from_json_with_field(tmpdir): + json_path = json_data_with_field(tmpdir) + dm = QuestionAnsweringData.from_json( + "input", "target", backbone=TEST_BACKBONE, train_file=json_path, batch_size=1, field="data" + ) + batch = next(iter(dm.train_dataloader())) + assert "labels" in batch + assert "input_ids" in batch diff --git a/tests/text/seq2seq/summarization/test_data.py b/tests/text/seq2seq/summarization/test_data.py index 2ab09f3636..a1120854ea 100644 --- a/tests/text/seq2seq/summarization/test_data.py +++ b/tests/text/seq2seq/summarization/test_data.py @@ -22,15 +22,21 @@ TEST_BACKBONE = "sshleifer/tiny-mbart" # super small model for testing TEST_CSV_DATA = """input,target -this is a sentence one,this is a translated sentence one -this is a sentence two,this is a translated sentence two -this is a sentence three,this is a translated sentence three +this is a sentence one,this is a summarized sentence one +this is a sentence two,this is a summarized sentence two +this is a sentence three,this is a summarized sentence three """ TEST_JSON_DATA = """ -{"input": "this is a sentence one","target":"this is a translated sentence one"} -{"input": "this is a sentence two","target":"this is a translated sentence two"} -{"input": "this is a sentence three","target":"this is a translated sentence three"} +{"input": "this is a sentence one","target":"this is a summarized sentence one"} +{"input": "this is a sentence two","target":"this is a summarized sentence two"} +{"input": "this is a sentence three","target":"this is a summarized sentence three"} +""" + +TEST_JSON_DATA_FIELD = """{"data": [ +{"input": "this is a sentence one","target":"this is a summarized sentence one"}, +{"input": "this is a sentence two","target":"this is a summarized sentence two"}, +{"input": "this is a sentence three","target":"this is a summarized sentence three"}]} """ @@ -46,6 +52,12 @@ def json_data(tmpdir): return path +def 
json_data_with_field(tmpdir): + path = Path(tmpdir) / "data.json" + path.write_text(TEST_JSON_DATA_FIELD) + return path + + @pytest.mark.skipif(os.name == "nt", reason="Huggingface timing out on Windows") @pytest.mark.skipif(not _TEXT_TESTING, reason="text libraries aren't installed.") def test_from_csv(tmpdir): @@ -106,3 +118,15 @@ def test_from_json(tmpdir): batch = next(iter(dm.train_dataloader())) assert "labels" in batch assert "input_ids" in batch + + +@pytest.mark.skipif(os.name == "nt", reason="Huggingface timing out on Windows") +@pytest.mark.skipif(not _TEXT_TESTING, reason="text libraries aren't installed.") +def test_from_json_with_field(tmpdir): + json_path = json_data_with_field(tmpdir) + dm = SummarizationData.from_json( + "input", "target", backbone=TEST_BACKBONE, train_file=json_path, batch_size=1, field="data" + ) + batch = next(iter(dm.train_dataloader())) + assert "labels" in batch + assert "input_ids" in batch diff --git a/tests/text/seq2seq/translation/test_data.py b/tests/text/seq2seq/translation/test_data.py index 244cb27d4a..27162491a0 100644 --- a/tests/text/seq2seq/translation/test_data.py +++ b/tests/text/seq2seq/translation/test_data.py @@ -33,6 +33,12 @@ {"input": "this is a sentence three","target":"this is a translated sentence three"} """ +TEST_JSON_DATA_FIELD = """{"data": [ +{"input": "this is a sentence one","target":"this is a translated sentence one"}, +{"input": "this is a sentence two","target":"this is a translated sentence two"}, +{"input": "this is a sentence three","target":"this is a translated sentence three"}]} +""" + def csv_data(tmpdir): path = Path(tmpdir) / "data.csv" @@ -46,6 +52,12 @@ def json_data(tmpdir): return path +def json_data_with_field(tmpdir): + path = Path(tmpdir) / "data.json" + path.write_text(TEST_JSON_DATA_FIELD) + return path + + @pytest.mark.skipif(os.name == "nt", reason="Huggingface timing out on Windows") @pytest.mark.skipif(not _TEXT_TESTING, reason="text libraries aren't installed.") def 
test_from_csv(tmpdir): @@ -86,3 +98,15 @@ def test_from_json(tmpdir): batch = next(iter(dm.train_dataloader())) assert "labels" in batch assert "input_ids" in batch + + +@pytest.mark.skipif(os.name == "nt", reason="Huggingface timing out on Windows") +@pytest.mark.skipif(not _TEXT_TESTING, reason="text libraries aren't installed.") +def test_from_json_with_field(tmpdir): + json_path = json_data_with_field(tmpdir) + dm = TranslationData.from_json( + "input", "target", backbone=TEST_BACKBONE, train_file=json_path, batch_size=1, field="data" + ) + batch = next(iter(dm.train_dataloader())) + assert "labels" in batch + assert "input_ids" in batch From dc5779218f1929275bff06092d8da95538396b87 Mon Sep 17 00:00:00 2001 From: Ethan Harris Date: Wed, 14 Jul 2021 20:32:44 +0100 Subject: [PATCH 5/5] Update flash/core/data/data_module.py --- flash/core/data/data_module.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/flash/core/data/data_module.py b/flash/core/data/data_module.py index 87eef202b1..5831c84a68 100644 --- a/flash/core/data/data_module.py +++ b/flash/core/data/data_module.py @@ -942,13 +942,13 @@ def from_json( # In the case where the data is of the form: # { # "version": 0.0.x, - # "data": { + # "data": [ # { # "input_field" : "input_data", # "target_field" : "target_output" # }, # ... - # } + # ] # } data_module = DataModule.from_json(