This repository has been archived by the owner on Nov 21, 2022. It is now read-only.
Support for pytorch_lightning.Trainer.predict
#260
🚀 Feature
Support for using PyTorch Lightning's `pytorch_lightning.Trainer.predict` method with lightning-transformers models and datamodules.
Motivation
`Trainer.predict` is a very convenient way to run inference on a large dataset, since it leverages all of the device-management and parallelization functionality of `Trainer`. However, I recently trained a lightning-transformers `TextClassificationTransformer` and, although I was able to train it using `Trainer.fit`, I can't run inference using `Trainer.predict`.
Pitch
Only a couple of small changes are needed: mainly, defining `predict_dataloader` in `TransformerDataModule` and then defining `predict_step` on each child class of `TaskTransformer`. I can submit a PR showing what would need to change, since I already did it for my own project.
Alternatives
Currently, inference can be run using a `transformers.pipeline`, as shown in `predict.py`. However, that does not leverage the `Trainer` functionality, and it does not fit the main purpose of this project, which is to enable the use of `Trainer` in combination with transformers.
Note that using `Trainer.test` would not work either, because it does not return predictions.
Thanks!