[Flax] Add remat (gradient checkpointing) #17843
Conversation
```diff
-                layer_head_mask=head_mask[i] if head_mask is not None else None,
-                encoder_hidden_states=encoder_hidden_states,
-                encoder_attention_mask=encoder_attention_mask,
-                init_cache=init_cache,
-                deterministic=deterministic,
-                output_attentions=output_attentions,
+                head_mask[i] if head_mask is not None else None,
+                encoder_hidden_states,
+                encoder_attention_mask,
+                init_cache,
+                deterministic,
+                output_attentions,
```
Note: remat does not support kwargs, hence the need to change to args.
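For illustration, here is a minimal standalone sketch (toy module, not the transformers code) of why the call sites switch to positional arguments: the lifted remat transform identifies static arguments by position via static_argnums, so the wrapped layer has to be called positionally.

```python
import jax
import jax.numpy as jnp
import flax.linen as nn

# Toy module: `deterministic` feeds Python control flow, so it must be a
# static (trace-time) argument once the module is remat'd.
class Block(nn.Module):
    @nn.compact
    def __call__(self, hidden_states, deterministic: bool = True):
        hidden_states = nn.Dense(16)(hidden_states)
        if not deterministic:
            hidden_states = hidden_states * 0.9
        return hidden_states


# static_argnums indexes positional arguments (self excluded): index 1 == `deterministic`.
CheckpointBlock = nn.remat(Block, static_argnums=(1,))

block = CheckpointBlock()
x = jnp.ones((2, 8))
params = block.init(jax.random.PRNGKey(0), x, False)
# Static arguments have to be passed positionally; the lifted remat transform
# does not accept keyword arguments, which is why the diff above drops them.
out = block.apply(params, x, False)
```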
ok!
The documentation is not available anymore as the PR was closed or merged.
Is there a downside to adding it to all layers? In my case I used it only on the transformer blocks (attention + feed-forward).
We wrap FlaxBertLayer with remat, and then use this remat'd layer to construct the Transformer block (layers collection): transformers/src/transformers/models/bert/modeling_flax_bert.py, lines 559 to 562 in ea8150a.
This means that each component of the Bert layer is checkpointed, and that all Bert layers in the Transformer block (layers collection) are checkpointed. Would you like to see ...
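A rough, self-contained sketch of the pattern described above (hypothetical toy names; the referenced lines in modeling_flax_bert.py are the authoritative version):

```python
import jax
import jax.numpy as jnp
import flax.linen as nn

# Stand-in for FlaxBertLayer (the real layer takes more arguments; argnums 5-7
# are the static ones there, hence static_argnums=(5, 6, 7) in the actual code).
class ToyLayer(nn.Module):
    @nn.compact
    def __call__(self, hidden_states, deterministic: bool = True):
        return hidden_states + nn.Dense(hidden_states.shape[-1])(hidden_states)


class ToyLayerCollection(nn.Module):
    num_layers: int = 4
    gradient_checkpointing: bool = False

    def setup(self):
        # Wrap the layer class once, then build every layer of the stack from the
        # (possibly remat'd) class, so checkpointing covers all layers.
        layer_cls = nn.remat(ToyLayer, static_argnums=(1,)) if self.gradient_checkpointing else ToyLayer
        self.layers = [layer_cls(name=str(i)) for i in range(self.num_layers)]

    def __call__(self, hidden_states, deterministic: bool = True):
        for layer in self.layers:
            hidden_states = layer(hidden_states, deterministic)
        return hidden_states


model = ToyLayerCollection(gradient_checkpointing=True)
params = model.init(jax.random.PRNGKey(0), jnp.ones((2, 8, 16)), True)
```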
Very nice! Also cc @younesbelkada.
We could also look into implementing this for OPT and BLOOM in Flax :-) Great job @sanchit-gandhi
The only feedback from my side would be to remove the option to override the policy (also since we don't test it).
```python
    def setup(self):
        if self.gradient_checkpointing:
            FlaxBertCheckpointLayer = remat(FlaxBertLayer, static_argnums=(5, 6, 7), policy=self.remat_policy)
```
Suggested change:
```diff
-            FlaxBertCheckpointLayer = remat(FlaxBertLayer, static_argnums=(5, 6, 7), policy=self.remat_policy)
+            FlaxBertLayer = remat(FlaxBertLayer, static_argnums=(5, 6, 7), policy=self.remat_policy)
```
(nit) I'd just keep the FlaxBertLayer name. IMO it's easier to read the code and compare to PyTorch this way, but I'm also happy to leave it as is.
remat prevented re-use of the class name FlaxBertLayer: google/flax#2251
Can rename in a follow-up PR if we find a workaround!
```diff
@@ -617,9 +628,16 @@ def __call__(
 class FlaxBertEncoder(nn.Module):
     config: BertConfig
     dtype: jnp.dtype = jnp.float32  # the dtype of the computation
+    gradient_checkpointing: bool = False
+    remat_policy: Callable[..., bool] = (None,)  # the gradient checkpointing policy
```
Are there multiple policies? Would one ever use a different one than the default? I'm wondering if allowing this parameter to be customizable might be a bit scary for the user and make the whole functionality less understandable. I think I'd prefer to just use the default here and not allow the user to configure it.
The full list of remat policies can be found here. They dictate whether output value(s) are saved as residuals or whether they must be recomputed in the (co)tangent computation.

The advice for selecting an appropriate remat policy is empirically driven: try them all and see what works best! On paper, dot_with_no_batch_dims should work best for Transformer architectures, and indeed was the preference for T5x. However, for the Seq2Seq project, I found the default policy to be optimal!

I'm in agreement that including remat_policy as an attribute is probably too heavy and clutters the code. It's straightforward to add one's own policy choice by overriding the policy arg to the remat method, and users who wish to do so can easily access this.
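For reference, a sketch of what overriding the policy looks like at the remat call site (toy module; the built-in jax.checkpoint_policies entry used below is assumed to be the "dot with no batch dims" policy mentioned above):

```python
import jax
import jax.numpy as jnp
import flax.linen as nn


class ToyLayer(nn.Module):
    @nn.compact
    def __call__(self, hidden_states):
        return nn.Dense(hidden_states.shape[-1])(hidden_states)


# Default policy (policy=None): save nothing, recompute everything on the backward pass.
DefaultCheckpointLayer = nn.remat(ToyLayer)

# Custom policy: keep the outputs of non-batched matmuls as residuals instead
# of recomputing them (one of the built-in jax.checkpoint_policies).
CustomCheckpointLayer = nn.remat(
    ToyLayer,
    policy=jax.checkpoint_policies.dots_with_no_batch_dims_saveable,
)

layer = CustomCheckpointLayer()
x = jnp.ones((2, 8, 16))
params = layer.init(jax.random.PRNGKey(0), x)
out = layer.apply(params, x)
```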
No, actually I thought it was applied to all layers, but the way you did it is great!
Cool! Once the tests are green, happy to merge it here :-)
Force-pushed from c791574 to 9972f38.
* [Flax] Add remat (gradient checkpointing)
* fix variable naming in test
* flip: checkpoint using a method
* fix naming
* fix class naming
* apply PVP's suggestions from code review
* make fix-copies
* fix big-bird, electra, roberta
* cookie-cutter
* fix flax big-bird
* move test to common
What does this PR do?
Adds gradient checkpointing in Flax (cf. #17399). The API currently takes the form of a method:
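A sketch of the missing snippet; the method name is taken from the option listed in the TODO below and is an assumption, since the final name was one of the open questions of this PR:

```python
from transformers import FlaxBertModel

model = FlaxBertModel.from_pretrained("bert-base-uncased")

# Turn on remat for the underlying Flax module. The method name here follows
# the PyTorch-style name referenced in the TODO list below (assumption).
model.gradient_checkpointing_enable()
```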
Note: checkpointing has currently only been implemented for FlaxBert. Implementing for all Flax models is a TODO.
TODO:
- init
- from_pretrained
- test_modeling_flax_bert
- API: a flag (gradient_checkpointing=True) or a method (model.gradient_checkpointing_enable())?
- test_modeling_flax_common
Before submitting
- Did you read the contributor guideline, Pull Request section?
- Was this discussed/approved via a GitHub issue or the forum? Please add a link to it if that's the case.
- Did you make sure to update the documentation with your changes? Here are the documentation guidelines, and here are tips on formatting docstrings.
Who can review?
Anyone in the community is free to review the PR once the tests have passed. Feel free to tag members/contributors who may be interested in your PR.
cc @borisdayma