diff --git a/lightning_transformers/core/config.py b/lightning_transformers/core/config.py
index b47a5573..3438d6be 100644
--- a/lightning_transformers/core/config.py
+++ b/lightning_transformers/core/config.py
@@ -29,6 +29,7 @@ class TransformerDataConfig:
     train_val_split: Optional[int] = None
     train_file: Optional[str] = None
     test_file: Optional[str] = None
+    predict_file: Optional[str] = None
     validation_file: Optional[str] = None
     padding: Union[str, bool] = "max_length"
     truncation: str = "only_first"
@@ -42,6 +43,7 @@ class TransformerDataConfig:
     train_subset_name: Optional[str] = None
     validation_subset_name: Optional[str] = None
     test_subset_name: Optional[str] = None
+    predict_subset_name: Optional[str] = None
     streaming: bool = False
 
 
diff --git a/lightning_transformers/core/data.py b/lightning_transformers/core/data.py
index 1c489d39..0766538f 100644
--- a/lightning_transformers/core/data.py
+++ b/lightning_transformers/core/data.py
@@ -75,7 +75,7 @@ def load_dataset(self) -> Dataset:
         )
 
         # Use special subset names if provided, and rename them back to standard ones
-        for subset in ("train", "validation", "test"):
+        for subset in ("train", "validation", "test", "predict"):
            config_attr = f"{subset}_subset_name"
            if getattr(self.cfg, config_attr) is not None:
                special_subset_name = getattr(self.cfg, config_attr)
@@ -141,6 +141,16 @@ def test_dataloader(self) -> Optional[DataLoader]:
                 collate_fn=self.collate_fn,
             )
 
+    def predict_dataloader(self) -> Optional[DataLoader]:
+        if "predict" in self.ds:
+            cls = DataLoader if not self.cfg.streaming else IterableDataLoader
+            return cls(
+                self.ds["predict"],
+                batch_size=self.batch_size,
+                num_workers=self.cfg.num_workers,
+                collate_fn=self.collate_fn,
+            )
+
     @property
     def batch_size(self) -> int:
         return self.cfg.batch_size
diff --git a/lightning_transformers/task/nlp/text_classification/model.py b/lightning_transformers/task/nlp/text_classification/model.py
index 0cca066d..e563b01c 100644
--- a/lightning_transformers/task/nlp/text_classification/model.py
+++ b/lightning_transformers/task/nlp/text_classification/model.py
@@ -60,6 +60,13 @@ def test_step(self, batch: Any, batch_idx: int, dataloader_idx: int = 0) -> torc
         batch["labels"] = None
         return self.common_step("test", batch)
 
+    def predict_step(self, batch: Any, batch_idx: int, dataloader_idx: int = 0) -> torch.Tensor:
+        batch["labels"] = None
+        outputs = self.model(**batch)
+        logits = outputs.logits
+        preds = torch.argmax(logits, dim=1)
+        return preds
+
     def configure_metrics(self, _) -> None:
         self.prec = Precision(num_classes=self.num_classes, average="macro")
         self.recall = Recall(num_classes=self.num_classes, average="macro")
diff --git a/tests/task/nlp/test_text_classification.py b/tests/task/nlp/test_text_classification.py
index a0e64d0e..c5690b2a 100644
--- a/tests/task/nlp/test_text_classification.py
+++ b/tests/task/nlp/test_text_classification.py
@@ -33,6 +33,29 @@ def test_smoke_train(hf_cache_path):
     trainer.fit(model, dm)
 
 
+def test_smoke_predict_with_trainer(hf_cache_path):
+    tokenizer = AutoTokenizer.from_pretrained(pretrained_model_name_or_path="prajjwal1/bert-tiny")
+    dm = TextClassificationDataModule(
+        cfg=TextClassificationDataConfig(
+            batch_size=1,
+            dataset_name="glue",
+            dataset_config_name="sst2",
+            max_length=512,
+            limit_test_samples=64,
+            limit_val_samples=64,
+            limit_train_samples=64,
+            cache_dir=hf_cache_path,
+            predict_subset_name="test",  # Use the "test" split of the dataset as our prediction subset
+        ),
+        tokenizer=tokenizer,
+    )
+    model = TextClassificationTransformer(pretrained_model_name_or_path="prajjwal1/bert-tiny")
+    trainer = pl.Trainer(fast_dev_run=True)
+    y = trainer.predict(model, dm)
+    assert len(y) == 1
+    assert int(y[0]) in (0, 1)
+
+
 @pytest.mark.skipif(sys.platform == "win32", reason="Currently Windows is not supported")
 def test_smoke_predict():
     model = TextClassificationTransformer(