Skip to content
This repository has been archived by the owner on Nov 21, 2022. It is now read-only.

support for Trainer.predict method #261

Merged
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 2 additions & 0 deletions lightning_transformers/core/config.py
Original file line number Diff line number Diff line change
Expand Up @@ -29,6 +29,7 @@ class TransformerDataConfig:
train_val_split: Optional[int] = None
train_file: Optional[str] = None
test_file: Optional[str] = None
predict_file: Optional[str] = None
validation_file: Optional[str] = None
padding: Union[str, bool] = "max_length"
truncation: str = "only_first"
Expand All @@ -42,6 +43,7 @@ class TransformerDataConfig:
train_subset_name: Optional[str] = None
validation_subset_name: Optional[str] = None
test_subset_name: Optional[str] = None
predict_subset_name: Optional[str] = None
streaming: bool = False


Expand Down
12 changes: 11 additions & 1 deletion lightning_transformers/core/data.py
Original file line number Diff line number Diff line change
Expand Up @@ -75,7 +75,7 @@ def load_dataset(self) -> Dataset:
)

# Use special subset names if provided, and rename them back to standard ones
for subset in ("train", "validation", "test"):
for subset in ("train", "validation", "test", "predict"):
config_attr = f"{subset}_subset_name"
if getattr(self.cfg, config_attr) is not None:
special_subset_name = getattr(self.cfg, config_attr)
Expand Down Expand Up @@ -141,6 +141,16 @@ def test_dataloader(self) -> Optional[DataLoader]:
collate_fn=self.collate_fn,
)

def predict_dataloader(self) -> Optional[DataLoader]:
if "predict" in self.ds:
cls = DataLoader if not self.cfg.streaming else IterableDataLoader
return cls(
self.ds["predict"],
batch_size=self.batch_size,
num_workers=self.cfg.num_workers,
collate_fn=self.collate_fn,
)

@property
def batch_size(self) -> int:
return self.cfg.batch_size
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -60,6 +60,13 @@ def test_step(self, batch: Any, batch_idx: int, dataloader_idx: int = 0) -> torc
batch["labels"] = None
return self.common_step("test", batch)

def predict_step(self, batch: Any, batch_idx: int, dataloader_idx: int = 0) -> torch.Tensor:
batch["labels"] = None
outputs = self.model(**batch)
logits = outputs.logits
preds = torch.argmax(logits, dim=1)
return preds

def configure_metrics(self, _) -> None:
self.prec = Precision(num_classes=self.num_classes, average="macro")
self.recall = Recall(num_classes=self.num_classes, average="macro")
Expand Down
23 changes: 23 additions & 0 deletions tests/task/nlp/test_text_classification.py
Original file line number Diff line number Diff line change
Expand Up @@ -33,6 +33,29 @@ def test_smoke_train(hf_cache_path):
trainer.fit(model, dm)


def test_smoke_predict_with_trainer(hf_cache_path):
tokenizer = AutoTokenizer.from_pretrained(pretrained_model_name_or_path="prajjwal1/bert-tiny")
dm = TextClassificationDataModule(
cfg=TextClassificationDataConfig(
batch_size=1,
dataset_name="glue",
dataset_config_name="sst2",
max_length=512,
limit_test_samples=64,
limit_val_samples=64,
limit_train_samples=64,
cache_dir=hf_cache_path,
predict_subset_name="test", # Use the "test" split of the dataset as our prediction subset
),
tokenizer=tokenizer,
)
model = TextClassificationTransformer(pretrained_model_name_or_path="prajjwal1/bert-tiny")
trainer = pl.Trainer(fast_dev_run=True)
y = trainer.predict(model, dm)
assert len(y) == 1
assert int(y[0]) in (0, 1)


@pytest.mark.skipif(sys.platform == "win32", reason="Currently Windows is not supported")
def test_smoke_predict():
model = TextClassificationTransformer(
Expand Down