support for Trainer.predict method (#261)
RR-28023 authored Jun 21, 2022
Commit 840a67a (1 parent: c5bca75)
Showing 4 changed files with 43 additions and 1 deletion.
lightning_transformers/core/config.py (2 additions, 0 deletions)
@@ -29,6 +29,7 @@ class TransformerDataConfig:
     train_val_split: Optional[int] = None
     train_file: Optional[str] = None
     test_file: Optional[str] = None
+    predict_file: Optional[str] = None
     validation_file: Optional[str] = None
     padding: Union[str, bool] = "max_length"
     truncation: str = "only_first"
@@ -42,6 +43,7 @@ class TransformerDataConfig:
     train_subset_name: Optional[str] = None
     validation_subset_name: Optional[str] = None
     test_subset_name: Optional[str] = None
+    predict_subset_name: Optional[str] = None
     streaming: bool = False
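Taken together, the two new fields let a config name a dedicated prediction split: predict_file points at a local data file, while predict_subset_name (consumed by the renaming loop in data.py below) maps a named subset of a Hub dataset onto the standard "predict" slot. A minimal sketch of the first form, with illustrative paths; batch_size is assumed from its use in data.py below, and constructing the base config directly (rather than a task-specific subclass) is for illustration only:

    from lightning_transformers.core.config import TransformerDataConfig

    # Illustrative values; only predict_file and predict_subset_name are new in this commit.
    cfg = TransformerDataConfig(
        batch_size=8,
        train_file="data/train.json",
        validation_file="data/valid.json",
        predict_file="data/unlabeled.json",  # prediction split; labels not required
    )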
lightning_transformers/core/data.py (11 additions, 1 deletion)
@@ -75,7 +75,7 @@ def load_dataset(self) -> Dataset:
             )

         # Use special subset names if provided, and rename them back to standard ones
-        for subset in ("train", "validation", "test"):
+        for subset in ("train", "validation", "test", "predict"):
             config_attr = f"{subset}_subset_name"
             if getattr(self.cfg, config_attr) is not None:
                 special_subset_name = getattr(self.cfg, config_attr)
@@ -141,6 +141,16 @@ def test_dataloader(self) -> Optional[DataLoader]:
                 collate_fn=self.collate_fn,
             )

+    def predict_dataloader(self) -> Optional[DataLoader]:
+        if "predict" in self.ds:
+            cls = DataLoader if not self.cfg.streaming else IterableDataLoader
+            return cls(
+                self.ds["predict"],
+                batch_size=self.batch_size,
+                num_workers=self.cfg.num_workers,
+                collate_fn=self.collate_fn,
+            )
+
     @property
     def batch_size(self) -> int:
         return self.cfg.batch_size
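The widened loop is what makes predict_subset_name work: any subset whose <subset>_subset_name field is set is looked up under that special name and re-keyed to the standard one, so predict_dataloader() later finds it at self.ds["predict"] (wrapping it in a DataLoader, or an IterableDataLoader when cfg.streaming is set). A self-contained sketch of the re-keying, assuming the rest of the loop body, which the hunk truncates, amounts to a dict re-key:

    from dataclasses import dataclass
    from typing import Optional

    @dataclass
    class Cfg:  # stand-in for TransformerDataConfig, illustration only
        train_subset_name: Optional[str] = None
        validation_subset_name: Optional[str] = None
        test_subset_name: Optional[str] = None
        predict_subset_name: Optional[str] = "test"

    cfg = Cfg()
    ds = {"train": ["ex1"], "test": ["ex2"]}  # stand-in for a loaded DatasetDict

    # Same shape as the loop in load_dataset(): "test" is re-keyed to "predict"
    # because cfg.predict_subset_name == "test".
    for subset in ("train", "validation", "test", "predict"):
        special_subset_name = getattr(cfg, f"{subset}_subset_name")
        if special_subset_name is not None:
            ds[subset] = ds.pop(special_subset_name)

    assert "predict" in ds and "test" not in ds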
lightning_transformers/task/nlp/text_classification/model.py (7 additions, 0 deletions)
@@ -60,6 +60,13 @@ def test_step(self, batch: Any, batch_idx: int, dataloader_idx: int = 0) -> torch.Tensor:
         batch["labels"] = None
         return self.common_step("test", batch)

+    def predict_step(self, batch: Any, batch_idx: int, dataloader_idx: int = 0) -> torch.Tensor:
+        batch["labels"] = None
+        outputs = self.model(**batch)
+        logits = outputs.logits
+        preds = torch.argmax(logits, dim=1)
+        return preds
+
     def configure_metrics(self, _) -> None:
         self.prec = Precision(num_classes=self.num_classes, average="macro")
         self.recall = Recall(num_classes=self.num_classes, average="macro")
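Where test_step routes through common_step (and therefore the configured metrics), predict_step clears the labels, since a prediction split may be unlabeled, runs a plain forward pass, and reduces the (batch_size, num_classes) logits to one class id per example. The argmax reduction in isolation, with made-up logits:

    import torch

    logits = torch.tensor([[0.2, 1.3],
                           [2.0, -1.0]])  # shape (batch=2, num_classes=2)
    preds = torch.argmax(logits, dim=1)   # highest-scoring class per row
    print(preds)  # tensor([1, 0])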
tests/task/nlp/test_text_classification.py (23 additions, 0 deletions)
@@ -33,6 +33,29 @@ def test_smoke_train(hf_cache_path):
     trainer.fit(model, dm)


+def test_smoke_predict_with_trainer(hf_cache_path):
+    tokenizer = AutoTokenizer.from_pretrained(pretrained_model_name_or_path="prajjwal1/bert-tiny")
+    dm = TextClassificationDataModule(
+        cfg=TextClassificationDataConfig(
+            batch_size=1,
+            dataset_name="glue",
+            dataset_config_name="sst2",
+            max_length=512,
+            limit_test_samples=64,
+            limit_val_samples=64,
+            limit_train_samples=64,
+            cache_dir=hf_cache_path,
+            predict_subset_name="test",  # Use the "test" split of the dataset as our prediction subset
+        ),
+        tokenizer=tokenizer,
+    )
+    model = TextClassificationTransformer(pretrained_model_name_or_path="prajjwal1/bert-tiny")
+    trainer = pl.Trainer(fast_dev_run=True)
+    y = trainer.predict(model, dm)
+    assert len(y) == 1
+    assert int(y[0]) in (0, 1)
+
+
 @pytest.mark.skipif(sys.platform == "win32", reason="Currently Windows is not supported")
 def test_smoke_predict():
     model = TextClassificationTransformer(
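The assertions lean on Lightning's return convention for predict: trainer.predict returns a list with one entry per prediction batch, each entry being whatever predict_step returned. With fast_dev_run=True only a single batch runs, and with batch_size=1 that entry is a one-element tensor holding an SST-2 class id, hence len(y) == 1 and int(y[0]) in (0, 1).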
