Skip to content
This repository has been archived by the owner on Nov 21, 2022. It is now read-only.

support for Trainer.predict method #261

Merged

Conversation

RR-28023
Copy link
Contributor

This PR would resolve #260. Is just some additions to the TransformerDataConfig and the TransformerDataModule objects in order to support Trainer.predict.

The model passed to the Trainer must have a a predict_step or a forward method implemented (it defaults to forward if there is no predict_step) in order for Trainer.predictto work. I've implemented it in TextClassificationTransformer, but eventually would need to be implemented for all nlp tasks.

@RR-28023 RR-28023 requested review from SeanNaren and Borda as code owners June 16, 2022 16:48
@codecov
Copy link

codecov bot commented Jun 16, 2022

Codecov Report

Merging #261 (946e23a) into master (c5bca75) will increase coverage by 0%.
The diff coverage is 100%.

@@          Coverage Diff          @@
##           master   #261   +/-   ##
=====================================
  Coverage      75%    75%           
=====================================
  Files          74     74           
  Lines        1641   1653   +12     
=====================================
+ Hits         1228   1243   +15     
+ Misses        413    410    -3     

@SeanNaren SeanNaren enabled auto-merge (squash) June 21, 2022 11:49
@SeanNaren SeanNaren disabled auto-merge June 21, 2022 11:49
@SeanNaren SeanNaren merged commit 840a67a into Lightning-Universe:master Jun 21, 2022
@SeanNaren
Copy link
Contributor

Thanks @RR-28023!!!

Sign up for free to subscribe to this conversation on GitHub. Already have an account? Sign in.
Labels
None yet
Projects
None yet
Development

Successfully merging this pull request may close these issues.

Support for pytorch_lightning.Trainer.predict
2 participants