This repository has been archived by the owner on Oct 9, 2023. It is now read-only.

Refactor InputTransform and DataModule #1233

Merged

Conversation

karthikrangasai (Contributor) commented Mar 15, 2022

What does this PR do?

Resolves discussion from #1166

At present, the InputTransform for every DataModule is passed for every stage, even though InputTransform has stage-specific methods to differentiate which transform runs at which stage. This also creates four different instances of the class, one per input (train, val, test, predict), each with different methods to run.

This PR changes that by making the DataModule the owner of the InputTransform, because all the class does is generate the collate_fn for the dataloaders and the implementation of the on_after_batch_transfer callback.

Thus a single instance of the InputTransform class, present in the DataModule, can resolve the required Callables for every stage, and the appropriate dataloader collate_fn and on_after_batch_transfer functions are created in the DataLoader's __init__ method.

This also relieves the Input class from having to take care of the InputTransform.
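The resolution logic can be sketched roughly as follows (the names `_resolve_collate`, `_collate_for`, and the stage strings are illustrative only, not the actual Flash implementation):

```python
from typing import Callable


class InputTransform:
    """Single transform instance owned by the DataModule."""

    def train_collate(self, samples):
        # Stage-specific hook: only used when building the train dataloader.
        return {"stage": "train", "batch": samples}

    def collate(self, samples):
        # Generic fallback used by every stage without a specific hook.
        return {"stage": "any", "batch": samples}

    def _resolve_collate(self, stage: str) -> Callable:
        # Prefer a stage-prefixed hook (e.g. ``train_collate``) and fall
        # back to the generic ``collate``.
        return getattr(self, f"{stage}_collate", self.collate)


class DataModule:
    def __init__(self, transform: InputTransform):
        # One instance serves train/val/test/predict instead of four copies.
        self.transform = transform

    def _collate_for(self, stage: str) -> Callable:
        return self.transform._resolve_collate(stage)
```

Here `DataModule(InputTransform())._collate_for("train")` resolves to the stage-specific `train_collate`, while `"val"` falls back to the generic `collate`.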

TL;DR

Previous API

dm = XYZTask_DataModule.from_xyz(
    train_file=train_file,
    val_file=val_file,
    test_file=test_file,
    predict_file=predict_file,
    train_transform=InputTransform,
    val_transform=InputTransform,
    test_transform=InputTransform,
    predict_transform=InputTransform,
    transform_kwargs=transform_kwargs,
)

# Implementation
class XYZTask_DataModule(DataModule):
    
    @classmethod
    def from_xyz(
        cls,
        train_file=None,
        val_file=None,
        test_file=None,
        predict_file=None,
        train_transform=InputTransform,
        val_transform=InputTransform,
        test_transform=InputTransform,
        predict_transform=InputTransform,
        transform_kwargs=None,
        input_cls=Input,
    ):
        return cls(
            input_cls(RunningStage.TRAINING, train_file, transform=train_transform, **transform_kwargs),
            input_cls(RunningStage.VALIDATING, val_file, transform=val_transform, **transform_kwargs),
            input_cls(RunningStage.TESTING, test_file, transform=test_transform, **transform_kwargs),
            input_cls(RunningStage.PREDICTING, predict_file, transform=predict_transform, **transform_kwargs),
        )

New API

dm = XYZTask_DataModule.from_xyz(
    train_file=train_file,
    val_file=val_file,
    test_file=test_file,
    predict_file=predict_file,
    transform=InputTransform,
    transform_kwargs=transform_kwargs,
)

# Implementation
class XYZTask_DataModule(DataModule):
    
    @classmethod
    def from_xyz(
        cls,
        train_file=None,
        val_file=None,
        test_file=None,
        predict_file=None,
        transform=InputTransform,
        transform_kwargs=None,
        input_cls=Input,
    ):
        return cls(
            input_cls(RunningStage.TRAINING, train_file),
            input_cls(RunningStage.VALIDATING, val_file),
            input_cls(RunningStage.TESTING, test_file),
            input_cls(RunningStage.PREDICTING, predict_file),
            transform=transform,
            **transform_kwargs
        )

Before submitting

  • Was this discussed/approved via a GitHub issue? (no need for typos and docs improvements)
  • Did you read the contributor guideline, Pull Request section?
  • Did you make sure your PR does only one thing, instead of bundling different changes together?
  • Did you make sure to update the documentation with your changes?
  • Did you write any new necessary tests? [not needed for typos/docs]
  • Did you verify new and existing tests pass locally with your changes?
  • If you made a notable change (that affects users), did you update the CHANGELOG?

PR review

  • Is this pull request ready for review? (if not, please submit in draft mode)

Anyone in the community is free to review the PR once the tests have passed.
If we didn't discuss your PR in GitHub issues, there's a high chance it will not be merged.

Did you have fun?

Make sure you had fun coding 🙃

codecov bot commented Mar 15, 2022

Codecov Report

Merging #1233 (e10cd5a) into master (9001449) will decrease coverage by 0.06%.
The diff coverage is 88.88%.

@@            Coverage Diff             @@
##           master    #1233      +/-   ##
==========================================
- Coverage   91.11%   91.05%   -0.06%     
==========================================
  Files         285      286       +1     
  Lines       12791    12764      -27     
==========================================
- Hits        11654    11622      -32     
- Misses       1137     1142       +5     
Flag Coverage Δ
unittests 91.05% <88.88%> (-0.06%) ⬇️

Flags with carried forward coverage won't be shown.

Impacted Files Coverage Δ
flash/audio/speech_recognition/data.py 100.00% <ø> (ø)
flash/image/detection/backbones.py 93.54% <ø> (+19.12%) ⬆️
flash/image/segmentation/data.py 100.00% <ø> (ø)
flash/image/style_transfer/data.py 100.00% <ø> (ø)
flash/pointcloud/detection/data.py 92.00% <0.00%> (ø)
flash/text/question_answering/data.py 100.00% <ø> (ø)
flash/pointcloud/segmentation/data.py 81.81% <33.33%> (ø)
flash/core/integrations/icevision/adapter.py 87.91% <68.18%> (-5.34%) ⬇️
flash/core/integrations/icevision/wrappers.py 75.00% <75.00%> (ø)
flash/core/integrations/icevision/backbones.py 92.30% <83.33%> (-2.14%) ⬇️
... and 33 more

Legend: Δ = absolute <relative> (impact), ø = not affected, ? = missing data

karthikrangasai (author) commented:

Hello all,

I have updated the PR for almost all tasks.

I need some help for:

  1. ServeInput and serve related parts.
  2. Pointcloud tasks retrieve their collate_fn when the model is instantiated. The updated _collate_fn method on the InputTransform class accepts an extra argument (stage: RunningStage), and I wasn't sure how to set that up cleanly.
  3. The same issue occurs with image.ObjectDetectionTask, which fails in the test_init_train test function: the extra argument causes an error. I need some help updating/changing this test.

Thanks
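For point 2 above, one possible workaround (a sketch with made-up names, not something proposed in the PR itself) is to bind the extra stage argument once, when the per-stage dataloader is built, so call sites that expect the old single-argument collate keep working:

```python
from enum import Enum
from functools import partial


class RunningStage(Enum):
    TRAINING = "train"
    VALIDATING = "validate"


class InputTransform:
    # Updated signature from the PR: collate now also receives the stage.
    def _collate_fn(self, samples, stage: RunningStage):
        return {"stage": stage.value, "batch": samples}


transform = InputTransform()
# Bind the stage up front; the result has the old one-argument signature
# that existing model/dataloader code can call unchanged.
train_collate = partial(transform._collate_fn, stage=RunningStage.TRAINING)
```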

ethanwharris (Collaborator) left a comment


Awesome work! Let's do the following:

  • revert the changes around moving collation from seq2seq, qa, forecasting, etc. (these can be discussed separately and done in individual PRs if we decide)
  • update the examples / docs to use the now recommended API (that is, not using transform_kwargs)

Regarding ServeInput, it may cause issues having the datamodule own the transforms the way it's currently laid out. But I think we just need to refactor the serving stuff to apply the transforms in the same way as the datamodule.



@dataclass
class SpeechRecognitionInputCollateTransform(InputTransform):
Collaborator:


I would prefer to keep this as it was and then consider moving it in a separate PR.

karthikrangasai (author):


I was initially planning to make this a separate PR, but the tests for these tasks were failing, so I ended up implementing them all here.

Collaborator:


I guess we just need to resolve the collate function from the models correctly (or there are issues with it). The main reason to leave this for the future is that I'm not sure this is where this functionality should end up. To me it's weird that the transform can include collation, since that's not really a transform. I also think we should avoid a situation where users have to provide the backbone in two places (since really only the model should know the backbone).

flash/core/data/data_module.py
flash/core/data/data_module.py
flash/core/data/io/input_transform.py
flash/tabular/forecasting/input_transform.py (outdated)
tests/text/question_answering/test_model.py (outdated)
pad_to_multiple_of: Optional[int] = None
pad_to_multiple_of_labels: Optional[int] = None

def __post_init__(self):
Contributor:


Would be awesome to investigate AugLy Augmentation for Speech: https://github.com/facebookresearch/AugLy/tree/main/augly/audio

flash/core/data/data_module.py (outdated)
flash/core/data/data_module.py (outdated)
flash/text/question_answering/input_transform.py (outdated)
flash/text/question_answering/model.py (outdated)
@karthikrangasai karthikrangasai requested review from ethanwharris and tchaton and removed request for ananyahjha93 March 23, 2022 09:02
krshrimali (Contributor) left a comment


Awesome work, @karthikrangasai - lots of important refactoring done in this PR. Thank you! We are getting there, just left a few comments (minor nits, and a few questions for my knowledge). Please let me know if you have any questions.

Just taking a note for the future: we should also update the examples since the API has now changed (I guess the tests are also failing for this reason). My suggestion would be to create a PR to fix the examples, and only merge this one and the examples PR once that is ready, just to make sure the examples are never out of date. But open to discussion, of course :) cc: @ethanwharris @Borda

flash/core/data/io/input_transform.py (outdated)
flash/core/data/io/input_transform.py
flash/core/data/io/input_transform.py (outdated)
Comment on lines +1187 to +1189
if on_device:
return input_transform._identity, collate
return collate, input_transform._identity
Contributor:


Can you please add a comment in this function on what it does?

Also, for the future, we should add a comment in the _InputTransformProcessorV2 class explaining what on_device means and does.

cc: @ethanwharris

flash/core/data/io/input_transform.py Outdated Show resolved Hide resolved
flash/core/integrations/icevision/adapter.py Show resolved Hide resolved
flash_examples/flash_components/custom_data_loading.py Outdated Show resolved Hide resolved
flash_examples/flash_components/custom_data_loading.py Outdated Show resolved Hide resolved
flash_examples/flash_components/custom_data_loading.py Outdated Show resolved Hide resolved
ethanwharris (Collaborator) left a comment


Awesome, LGTM 😃

@ethanwharris ethanwharris merged commit 6da53fe into Lightning-Universe:master Mar 25, 2022
tchaton (Contributor) commented Mar 26, 2022

Awesome work @karthikrangasai !
