Refactor InputTransform and DataModule #1233
Conversation
Codecov Report
@@ Coverage Diff @@
## master #1233 +/- ##
==========================================
- Coverage 91.11% 91.05% -0.06%
==========================================
Files 285 286 +1
Lines 12791 12764 -27
==========================================
- Hits 11654 11622 -32
- Misses 1137 1142 +5
…his PR and not torch 11.
Hello all, I have updated the PR for almost all tasks. I need some help with:
Thanks
Awesome work! Let's do the following:
- revert the changes around moving collation from seq2seq, qa, forecasting, etc. (these can be discussed separately and done in individual PRs if we decide)
- update the examples / docs to use the now recommended API (that is, not using `transform_kwargs`)

Regarding `ServeInput`, it may cause issues having the datamodule own the transforms the way it's currently laid out. But I think we just need to refactor the serving stuff to apply the transforms in the same way as the datamodule.
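A possible shape for that serving refactor, as a minimal sketch (the class layout and hook name are assumptions, not Flash's actual API):

```python
# Hypothetical sketch, not Flash's actual API: serving reuses the
# DataModule-owned InputTransform so both paths share one code path.
from typing import Any


class MyInputTransform:
    def predict_per_sample_transform(self, sample: Any) -> Any:
        # Illustrative per-sample transform applied at predict/serve time.
        return sample


class ServeInput:
    def __init__(self, input_transform: MyInputTransform) -> None:
        # The single transform instance owned by the DataModule.
        self._input_transform = input_transform

    def serve_load_sample(self, raw: Any) -> Any:
        sample = {"input": raw}
        # Apply the same predict-stage transform the dataloaders use.
        return self._input_transform.predict_per_sample_transform(sample)
```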
@dataclass
class SpeechRecognitionInputCollateTransform(InputTransform):
I would prefer to keep this as it was and then consider moving it in a separate PR
I was initially planning to make it a separate PR, but the tests for these tasks were failing. So I ended up implementing them all here.
I guess we just need to resolve the collate function from the models correctly (or there are issues with it). The main reason to leave this for the future is that I'm not sure this is where this functionality should end up. To me it's weird that the transform can include collation since that's not really a transform. I also think we should avoid a situation where users have to provide the backbone in two places (since only the model should know the backbone really).
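To make the "only the model should know the backbone" point concrete, a minimal sketch (all names hypothetical) of resolving collation from the model instead of bundling it into the transform:

```python
# Hypothetical sketch: the model, which already knows its backbone, exposes
# the collate function; the transform stays a pure transform.
from dataclasses import dataclass
from typing import Any, Callable, List


@dataclass
class SpeechModel:
    backbone: str = "facebook/wav2vec2-base-960h"

    def build_collate_fn(self) -> Callable[[List[Any]], Any]:
        # A real implementation would construct the backbone's processor
        # here and pad/stack the samples with it.
        def collate(samples: List[Any]) -> List[Any]:
            return samples  # placeholder

        return collate


# The datamodule would then ask the model for its collate_fn, so the user
# never has to provide the backbone in two places.
```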
pad_to_multiple_of: Optional[int] = None
pad_to_multiple_of_labels: Optional[int] = None

def __post_init__(self):
Would be awesome to investigate AugLy Augmentation for Speech: https://github.com/facebookresearch/AugLy/tree/main/augly/audio
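For reference, a rough sketch of plugging an AugLy audio op into a per-sample training transform (the `pitch_shift` call follows AugLy's functional API as I understand it; the hook name and sample layout are assumptions):

```python
# Sketch only: augment the raw waveform with AugLy before collation.
import numpy as np
from augly import audio as audaugs


def train_per_sample_transform(sample: dict) -> dict:
    waveform: np.ndarray = sample["input"]
    # pitch_shift returns an (augmented_audio, sample_rate) tuple.
    augmented, _ = audaugs.pitch_shift(waveform, sample_rate=16000, n_steps=2.0)
    sample["input"] = augmented
    return sample
```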
Awesome work, @karthikrangasai - lots of important refactoring done in this PR. Thank you! We are getting there, just left a few comments (minor nits, and a few questions for my knowledge). Please let me know if you have any questions.
Just taking a note for the future: we should also update the examples, since the API has now changed (I guess the tests are failing for this reason too). My suggestion would be that we create a PR to fix the examples, and only once that is ready, merge this together with the examples PR. Just to make sure that examples are never out of date. But open to discussion, of course :) cc: @ethanwharris @Borda
if on_device:
    return input_transform._identity, collate
return collate, input_transform._identity
Can you please add a comment in this function on what it does? Also, for the future, we should add a comment on what `on_device` means and does, in the `_InputTransformProcessorV2` class.
cc: @ethanwharris
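As a starting point, the comment could read like the sketch below; the assumption that the returned pair is ordered as (worker-side collate, on-device collate) would need checking against `_InputTransformProcessorV2`, and the function name here is illustrative:

```python
def _resolve_collate_fns(input_transform, collate, on_device: bool):
    # Returns the pair of callables applied (in the DataLoader worker,
    # after batch transfer). When the transform runs on device, the worker
    # applies the identity and collation is deferred to
    # on_after_batch_transfer; otherwise the worker collates and the
    # on-device step is a no-op.
    if on_device:
        return input_transform._identity, collate
    return collate, input_transform._identity
```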
Awesome, LGTM 😃
Awesome work @karthikrangasai!
What does this PR do?
Resolves discussion from #1166
At present, the `InputTransform` for every `DataModule` is being passed for every stage, even though `InputTransform` has specific methods to differentiate the stage at which a given transform runs. This also creates four different instances of the class, one per input (train, val, test, predict), each with different methods to run.

This PR changes that by making the `DataModule` the owner of the `InputTransform`, because all the class does is generate the `collate_fn` for the dataloaders and the implementation for the `on_after_batch_transfer` callback.

Thus a single instance of the `InputTransform` class, present in the `DataModule`, can resolve the required `Callable`s for every stage, and the appropriate dataloader `collate_fn` and `on_after_batch_transfer` functions are created in the `DataLoader`'s `__init__` method.

This also relieves the `Input` class from having to take care of the `InputTransform`.

TL;DR
Previous API
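The snippet below is an illustrative sketch rather than the exact Flash signature: each stage took its own transform argument, and `transform_kwargs` were forwarded to build four separate `InputTransform` instances (the `from_files` argument names are assumptions):

```python
datamodule = MyTaskData.from_files(
    train_files=train_files,
    train_targets=train_targets,
    train_transform=MyInputTransform,
    val_transform=MyInputTransform,
    predict_transform=MyInputTransform,
    transform_kwargs={"image_size": (196, 196)},
)
```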
New API
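With this PR, a single instance is passed once and owned by the `DataModule`, which resolves the stage-specific hooks from it (again a sketch, not the verbatim signature):

```python
datamodule = MyTaskData.from_files(
    train_files=train_files,
    train_targets=train_targets,
    transform=MyInputTransform(image_size=(196, 196)),
)
```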
Before submitting
PR review
Anyone in the community is free to review the PR once the tests have passed.
If we didn't discuss your PR in GitHub issues there's a high chance it will not be merged.
Did you have fun?
Make sure you had fun coding 🙃