This repository has been archived by the owner on Nov 21, 2022. It is now read-only.

Support for pytorch_lightning.Trainer.predict #260

Closed
RR-28023 opened this issue Jun 16, 2022 · 1 comment · Fixed by #261
Labels: enhancement (New feature or request), help wanted (Extra attention is needed)

Comments

@RR-28023
Contributor

RR-28023 commented Jun 16, 2022

🚀 Feature

Support for using PyTorch Lightning's pytorch_lightning.Trainer.predict method with lightning-transformers models and datamodules.

Motivation

Trainer.predict is a very convenient way to run inference on a large dataset, since it leverages all of Trainer's device-management and parallelization functionality. However, I recently trained a lightning-transformers TextClassificationTransformer and, although I can train it with Trainer.fit, I can't run inference with Trainer.predict.

Pitch

Only a couple of small changes are needed: mainly, defining predict_dataloader in TransformerDataModule and then defining predict_step in each child class of TaskTransformer. I can submit a PR showing what would need to change (I already made these changes for my own project).

Alternatives

Currently, inference can be run using a transformers.pipeline, as shown in predict.py. However, that approach does not leverage the Trainer functionalities, and it does not fit the main purpose of this project, which is to enable using Trainer in combination with transformers.

Note that using Trainer.test would not work, because it returns only the metrics logged during testing, not the predictions themselves.

Thanks!

RR-28023 added the enhancement (New feature or request) and help wanted (Extra attention is needed) labels on Jun 16, 2022
RR-28023 changed the title from Support for pytorch_lightning.Trainer.inference to Support for pytorch_lightning.Trainer.predict on Jun 16, 2022
@SeanNaren
Contributor

Thanks @RR-28023, we can definitely get this in! We're gradually working towards removing the rest of the Hydra code, after which the project will rely on pure Lightning (we'll provide examples showing how to do this).

We'll be making a release before that change, though, so we'll make sure this is in before then.
