Feature/auto batch size find #426
Conversation
@@ -93,6 +100,10 @@ def train(cfg: DictConfig, data=None) -> Tuple[dict, dict]:
        log.info("Logging hyperparameters!")
        utils.log_hyperparameters(object_dict)

+       if use_batch_tuner:
This seems related to the new block starting on line 62. Could it move up there for readability? Could the two also be a single condition? They seem logically bound.
Good point. I have it split up because the data arguments need to be changed if the batch size tuner is used, but then the actual batch tuning requires an initialized datamodule. I can rearrange things to make them a little closer, but that might be purely aesthetic as they can't be fully grouped.
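For readers without the full diff, here is a rough sketch of the two-block arrangement being discussed, assuming a Hydra/Lightning-style `train()` like the one in the hunk above. The config path `cfg.data.batch_size`, the instantiation calls, and the use of Lightning's `Tuner.scale_batch_size` (Lightning >= 2.0 API) are assumptions for illustration, not the PR's exact code.

```python
import hydra
from lightning.pytorch.tuner import Tuner


def train(cfg, data=None):
    use_batch_tuner = cfg.get("use_batch_tuner", False)

    # Block 1 (early): the tuner should start from batch_size=1, so the data
    # arguments are patched *before* the datamodule is instantiated.
    if use_batch_tuner:
        cfg.data.batch_size = 1  # illustrative config path

    datamodule = hydra.utils.instantiate(cfg.data)
    model = hydra.utils.instantiate(cfg.model)
    trainer = hydra.utils.instantiate(cfg.trainer)

    ...

    # Block 2 (later): the actual tuning needs the initialized datamodule and
    # trainer, which is why the two blocks cannot be fully merged.
    if use_batch_tuner:
        Tuner(trainer).scale_batch_size(
            model, datamodule=datamodule, mode="power", init_val=1
        )
```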
I would weight my review low, not having much context.
Is a new kwarg added by this change? It would be good to see some unit test coverage for that, but only if a test suite exists already.
No new kwarg is added - just changing the default value of an existing kwarg.
Super useful! It's definitely annoying to try to find the largest batch size that fits in memory.
What does this PR do?
The auto batch size finder increases the batch size by powers of 2 (starting with batch_size=1) by setting the datamodule's `batch_size` attribute. This PR allows just the dataframe datamodule (I think the only one used by the plugin) to respond to its `batch_size` attribute being changed. In order to do this, we store the initial batch size (set to 1) and check for changes to the batch size to update the train dataloaders.
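A minimal sketch of the mechanism described above, assuming a LightningDataModule along the lines of the plugin's dataframe datamodule. Only `batch_size` is taken from the PR description; the class name, dataset argument, and helper attributes are illustrative.

```python
import lightning.pytorch as pl
from torch.utils.data import DataLoader


class DataframeDataModule(pl.LightningDataModule):
    """Illustrative datamodule that rebuilds its train loader when `batch_size` changes."""

    def __init__(self, train_dataset, batch_size: int = 1):
        super().__init__()
        self.train_dataset = train_dataset
        self.batch_size = batch_size      # the batch size finder mutates this attribute
        self._loader_batch_size = None    # batch size the cached loader was built with
        self._train_loader = None

    def train_dataloader(self) -> DataLoader:
        # Rebuild the loader whenever the tuner (or anything else) changed batch_size.
        if self._train_loader is None or self._loader_batch_size != self.batch_size:
            self._train_loader = DataLoader(
                self.train_dataset, batch_size=self.batch_size, shuffle=True
            )
            self._loader_batch_size = self.batch_size
        return self._train_loader
```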
Before submitting
- Did you test your PR locally with the `pytest` command?
- Did you run pre-commit hooks with the `pre-commit run -a` command?

Did you have fun?
Make sure you had fun coding 🙃