From e6ae1f2719bbf0cc4c70acc1a9ac6fcd79bd4f03 Mon Sep 17 00:00:00 2001 From: Sanchit Gandhi <93869735+sanchit-gandhi@users.noreply.github.com> Date: Fri, 1 Jul 2022 18:33:54 +0100 Subject: [PATCH] [Flax] Add remat (gradient checkpointing) (#17843) * [Flax] Add remat (gradient checkpointing) * fix variable naming in test * flip: checkpoint using a method * fix naming * fix class naming * apply PVP's suggestions from code review * make fix-copies * fix big-bird, electra, roberta * cookie-cutter * fix flax big-bird * move test to common --- src/transformers/modeling_flax_utils.py | 3 + .../models/bert/modeling_flax_bert.py | 118 +++++++++++++++--- .../models/big_bird/modeling_flax_big_bird.py | 105 +++++++++++++--- .../models/electra/modeling_flax_electra.py | 88 ++++++++++--- .../models/roberta/modeling_flax_roberta.py | 102 ++++++++++++--- ...ax_{{cookiecutter.lowercase_modelname}}.py | 67 +++++++--- tests/test_modeling_flax_common.py | 27 ++++ 7 files changed, 414 insertions(+), 96 deletions(-) diff --git a/src/transformers/modeling_flax_utils.py b/src/transformers/modeling_flax_utils.py index 74124209cb6838..77eaa900de622f 100644 --- a/src/transformers/modeling_flax_utils.py +++ b/src/transformers/modeling_flax_utils.py @@ -235,6 +235,9 @@ def __init__( def init_weights(self, rng: jax.random.PRNGKey, input_shape: Tuple, params: FrozenDict = None) -> Dict: raise NotImplementedError(f"init method has to be implemented for {self}") + def enable_gradient_checkpointing(self): + raise NotImplementedError(f"gradient checkpointing method has to be implemented for {self}") + @classmethod def _from_config(cls, config, **kwargs): """ diff --git a/src/transformers/models/bert/modeling_flax_bert.py b/src/transformers/models/bert/modeling_flax_bert.py index 902d6cca3d13f4..8daa866be10561 100644 --- a/src/transformers/models/bert/modeling_flax_bert.py +++ b/src/transformers/models/bert/modeling_flax_bert.py @@ -23,6 +23,7 @@ import jax.numpy as jnp from flax.core.frozen_dict import FrozenDict, freeze, unfreeze from flax.linen import combine_masks, make_causal_mask +from flax.linen import partitioning as nn_partitioning from flax.linen.attention import dot_product_attention_weights from flax.traverse_util import flatten_dict, unflatten_dict from jax import lax @@ -56,6 +57,8 @@ _CONFIG_FOR_DOC = "BertConfig" _TOKENIZER_FOR_DOC = "BertTokenizer" +remat = nn_partitioning.remat + @flax.struct.dataclass class FlaxBertForPreTrainingOutput(ModelOutput): @@ -544,11 +547,19 @@ def __call__( class FlaxBertLayerCollection(nn.Module): config: BertConfig dtype: jnp.dtype = jnp.float32 # the dtype of the computation + gradient_checkpointing: bool = False def setup(self): - self.layers = [ - FlaxBertLayer(self.config, name=str(i), dtype=self.dtype) for i in range(self.config.num_hidden_layers) - ] + if self.gradient_checkpointing: + FlaxBertCheckpointLayer = remat(FlaxBertLayer, static_argnums=(5, 6, 7)) + self.layers = [ + FlaxBertCheckpointLayer(self.config, name=str(i), dtype=self.dtype) + for i in range(self.config.num_hidden_layers) + ] + else: + self.layers = [ + FlaxBertLayer(self.config, name=str(i), dtype=self.dtype) for i in range(self.config.num_hidden_layers) + ] def __call__( self, @@ -582,12 +593,12 @@ def __call__( layer_outputs = layer( hidden_states, attention_mask, - layer_head_mask=head_mask[i] if head_mask is not None else None, - encoder_hidden_states=encoder_hidden_states, - encoder_attention_mask=encoder_attention_mask, - init_cache=init_cache, - deterministic=deterministic, - 
output_attentions=output_attentions, + head_mask[i] if head_mask is not None else None, + encoder_hidden_states, + encoder_attention_mask, + init_cache, + deterministic, + output_attentions, ) hidden_states = layer_outputs[0] @@ -617,9 +628,14 @@ def __call__( class FlaxBertEncoder(nn.Module): config: BertConfig dtype: jnp.dtype = jnp.float32 # the dtype of the computation + gradient_checkpointing: bool = False def setup(self): - self.layer = FlaxBertLayerCollection(self.config, dtype=self.dtype) + self.layer = FlaxBertLayerCollection( + self.config, + dtype=self.dtype, + gradient_checkpointing=self.gradient_checkpointing, + ) def __call__( self, @@ -756,11 +772,24 @@ def __init__( seed: int = 0, dtype: jnp.dtype = jnp.float32, _do_init: bool = True, + gradient_checkpointing: bool = False, **kwargs ): - module = self.module_class(config=config, dtype=dtype, **kwargs) + module = self.module_class( + config=config, + dtype=dtype, + gradient_checkpointing=gradient_checkpointing, + **kwargs, + ) super().__init__(config, module, input_shape=input_shape, seed=seed, dtype=dtype, _do_init=_do_init) + def enable_gradient_checkpointing(self): + self._module = self.module_class( + config=self.config, + dtype=self.dtype, + gradient_checkpointing=True, + ) + def init_weights(self, rng: jax.random.PRNGKey, input_shape: Tuple, params: FrozenDict = None) -> FrozenDict: # init input tensors input_ids = jnp.zeros(input_shape, dtype="i4") @@ -925,10 +954,15 @@ class FlaxBertModule(nn.Module): config: BertConfig dtype: jnp.dtype = jnp.float32 # the dtype of the computation add_pooling_layer: bool = True + gradient_checkpointing: bool = False def setup(self): self.embeddings = FlaxBertEmbeddings(self.config, dtype=self.dtype) - self.encoder = FlaxBertEncoder(self.config, dtype=self.dtype) + self.encoder = FlaxBertEncoder( + self.config, + dtype=self.dtype, + gradient_checkpointing=self.gradient_checkpointing, + ) self.pooler = FlaxBertPooler(self.config, dtype=self.dtype) def __call__( @@ -1003,9 +1037,14 @@ class FlaxBertModel(FlaxBertPreTrainedModel): class FlaxBertForPreTrainingModule(nn.Module): config: BertConfig dtype: jnp.dtype = jnp.float32 + gradient_checkpointing: bool = False def setup(self): - self.bert = FlaxBertModule(config=self.config, dtype=self.dtype) + self.bert = FlaxBertModule( + config=self.config, + dtype=self.dtype, + gradient_checkpointing=self.gradient_checkpointing, + ) self.cls = FlaxBertPreTrainingHeads(config=self.config, dtype=self.dtype) def __call__( @@ -1099,9 +1138,15 @@ class FlaxBertForPreTraining(FlaxBertPreTrainedModel): class FlaxBertForMaskedLMModule(nn.Module): config: BertConfig dtype: jnp.dtype = jnp.float32 + gradient_checkpointing: bool = False def setup(self): - self.bert = FlaxBertModule(config=self.config, add_pooling_layer=False, dtype=self.dtype) + self.bert = FlaxBertModule( + config=self.config, + add_pooling_layer=False, + dtype=self.dtype, + gradient_checkpointing=self.gradient_checkpointing, + ) self.cls = FlaxBertOnlyMLMHead(config=self.config, dtype=self.dtype) def __call__( @@ -1161,9 +1206,14 @@ class FlaxBertForMaskedLM(FlaxBertPreTrainedModel): class FlaxBertForNextSentencePredictionModule(nn.Module): config: BertConfig dtype: jnp.dtype = jnp.float32 + gradient_checkpointing: bool = False def setup(self): - self.bert = FlaxBertModule(config=self.config, dtype=self.dtype) + self.bert = FlaxBertModule( + config=self.config, + dtype=self.dtype, + gradient_checkpointing=self.gradient_checkpointing, + ) self.cls = FlaxBertOnlyNSPHead(dtype=self.dtype) 
def __call__( @@ -1248,9 +1298,14 @@ class FlaxBertForNextSentencePrediction(FlaxBertPreTrainedModel): class FlaxBertForSequenceClassificationModule(nn.Module): config: BertConfig dtype: jnp.dtype = jnp.float32 + gradient_checkpointing: bool = False def setup(self): - self.bert = FlaxBertModule(config=self.config, dtype=self.dtype) + self.bert = FlaxBertModule( + config=self.config, + dtype=self.dtype, + gradient_checkpointing=self.gradient_checkpointing, + ) classifier_dropout = ( self.config.classifier_dropout if self.config.classifier_dropout is not None @@ -1324,9 +1379,14 @@ class FlaxBertForSequenceClassification(FlaxBertPreTrainedModel): class FlaxBertForMultipleChoiceModule(nn.Module): config: BertConfig dtype: jnp.dtype = jnp.float32 + gradient_checkpointing: bool = False def setup(self): - self.bert = FlaxBertModule(config=self.config, dtype=self.dtype) + self.bert = FlaxBertModule( + config=self.config, + dtype=self.dtype, + gradient_checkpointing=self.gradient_checkpointing, + ) self.dropout = nn.Dropout(rate=self.config.hidden_dropout_prob) self.classifier = nn.Dense(1, dtype=self.dtype) @@ -1399,9 +1459,15 @@ class FlaxBertForMultipleChoice(FlaxBertPreTrainedModel): class FlaxBertForTokenClassificationModule(nn.Module): config: BertConfig dtype: jnp.dtype = jnp.float32 + gradient_checkpointing: bool = False def setup(self): - self.bert = FlaxBertModule(config=self.config, dtype=self.dtype, add_pooling_layer=False) + self.bert = FlaxBertModule( + config=self.config, + dtype=self.dtype, + add_pooling_layer=False, + gradient_checkpointing=self.gradient_checkpointing, + ) classifier_dropout = ( self.config.classifier_dropout if self.config.classifier_dropout is not None @@ -1468,9 +1534,15 @@ class FlaxBertForTokenClassification(FlaxBertPreTrainedModel): class FlaxBertForQuestionAnsweringModule(nn.Module): config: BertConfig dtype: jnp.dtype = jnp.float32 + gradient_checkpointing: bool = False def setup(self): - self.bert = FlaxBertModule(config=self.config, dtype=self.dtype, add_pooling_layer=False) + self.bert = FlaxBertModule( + config=self.config, + dtype=self.dtype, + add_pooling_layer=False, + gradient_checkpointing=self.gradient_checkpointing, + ) self.qa_outputs = nn.Dense(self.config.num_labels, dtype=self.dtype) def __call__( @@ -1539,9 +1611,15 @@ class FlaxBertForQuestionAnswering(FlaxBertPreTrainedModel): class FlaxBertForCausalLMModule(nn.Module): config: BertConfig dtype: jnp.dtype = jnp.float32 + gradient_checkpointing: bool = False def setup(self): - self.bert = FlaxBertModule(config=self.config, add_pooling_layer=False, dtype=self.dtype) + self.bert = FlaxBertModule( + config=self.config, + add_pooling_layer=False, + dtype=self.dtype, + gradient_checkpointing=self.gradient_checkpointing, + ) self.cls = FlaxBertOnlyMLMHead(config=self.config, dtype=self.dtype) def __call__( diff --git a/src/transformers/models/big_bird/modeling_flax_big_bird.py b/src/transformers/models/big_bird/modeling_flax_big_bird.py index 18a0a98df8fe43..2e3192ff0eeb02 100644 --- a/src/transformers/models/big_bird/modeling_flax_big_bird.py +++ b/src/transformers/models/big_bird/modeling_flax_big_bird.py @@ -23,6 +23,7 @@ import jax.numpy as jnp from flax.core.frozen_dict import FrozenDict, freeze, unfreeze from flax.linen import combine_masks, make_causal_mask +from flax.linen import partitioning as nn_partitioning from flax.linen.attention import dot_product_attention_weights from flax.traverse_util import flatten_dict, unflatten_dict from jax import lax @@ -54,6 +55,8 @@ _CONFIG_FOR_DOC = 
"BigBirdConfig" _TOKENIZER_FOR_DOC = "BigBirdTokenizer" +remat = nn_partitioning.remat + @flax.struct.dataclass class FlaxBigBirdForPreTrainingOutput(ModelOutput): @@ -1368,12 +1371,20 @@ def __call__( class FlaxBigBirdLayerCollection(nn.Module): config: BigBirdConfig dtype: jnp.dtype = jnp.float32 # the dtype of the computation + gradient_checkpointing: bool = False def setup(self): - self.layers = [ - FlaxBigBirdLayer(self.config, layer_id=i, name=str(i), dtype=self.dtype) - for i in range(self.config.num_hidden_layers) - ] + if self.gradient_checkpointing: + FlaxBigBirdCheckpointLayer = remat(FlaxBigBirdLayer, static_argnums=(5, 6, 7)) + self.layers = [ + FlaxBigBirdCheckpointLayer(self.config, layer_id=i, name=str(i), dtype=self.dtype) + for i in range(self.config.num_hidden_layers) + ] + else: + self.layers = [ + FlaxBigBirdLayer(self.config, layer_id=i, name=str(i), dtype=self.dtype) + for i in range(self.config.num_hidden_layers) + ] # Copied from transformers.models.bert.modeling_flax_bert.FlaxBertLayerCollection.__call__ with Bert->BigBird def __call__( @@ -1408,12 +1419,12 @@ def __call__( layer_outputs = layer( hidden_states, attention_mask, - layer_head_mask=head_mask[i] if head_mask is not None else None, - encoder_hidden_states=encoder_hidden_states, - encoder_attention_mask=encoder_attention_mask, - init_cache=init_cache, - deterministic=deterministic, - output_attentions=output_attentions, + head_mask[i] if head_mask is not None else None, + encoder_hidden_states, + encoder_attention_mask, + init_cache, + deterministic, + output_attentions, ) hidden_states = layer_outputs[0] @@ -1444,9 +1455,14 @@ def __call__( class FlaxBigBirdEncoder(nn.Module): config: BigBirdConfig dtype: jnp.dtype = jnp.float32 # the dtype of the computation + gradient_checkpointing: bool = False def setup(self): - self.layer = FlaxBigBirdLayerCollection(self.config, dtype=self.dtype) + self.layer = FlaxBigBirdLayerCollection( + self.config, + dtype=self.dtype, + gradient_checkpointing=self.gradient_checkpointing, + ) def __call__( self, @@ -1559,9 +1575,10 @@ def __init__( seed: int = 0, dtype: jnp.dtype = jnp.float32, _do_init: bool = True, + gradient_checkpointing: bool = False, **kwargs ): - module = self.module_class(config=config, dtype=dtype, **kwargs) + module = self.module_class(config=config, dtype=dtype, gradient_checkpointing=gradient_checkpointing, **kwargs) if config.attention_type == "block_sparse" and input_shape is None: input_shape = (1, 12 * config.block_size) elif input_shape is None: @@ -1569,6 +1586,14 @@ def __init__( super().__init__(config, module, input_shape=input_shape, seed=seed, dtype=dtype, _do_init=_do_init) + # Copied from transformers.models.bert.modeling_flax_bert.FlaxBertPreTrainedModel.enable_gradient_checkpointing + def enable_gradient_checkpointing(self): + self._module = self.module_class( + config=self.config, + dtype=self.dtype, + gradient_checkpointing=True, + ) + # Copied from transformers.models.bert.modeling_flax_bert.FlaxBertPreTrainedModel.init_weights def init_weights(self, rng: jax.random.PRNGKey, input_shape: Tuple, params: FrozenDict = None) -> FrozenDict: # init input tensors @@ -1735,10 +1760,13 @@ class FlaxBigBirdModule(nn.Module): config: BigBirdConfig dtype: jnp.dtype = jnp.float32 # the dtype of the computation add_pooling_layer: bool = True + gradient_checkpointing: bool = False def setup(self): self.embeddings = FlaxBigBirdEmbeddings(self.config, dtype=self.dtype) - self.encoder = FlaxBigBirdEncoder(self.config, dtype=self.dtype) + 
self.encoder = FlaxBigBirdEncoder( + self.config, dtype=self.dtype, gradient_checkpointing=self.gradient_checkpointing + ) self.pooler = nn.Dense( self.config.hidden_size, kernel_init=jax.nn.initializers.normal(self.config.initializer_range), @@ -1812,9 +1840,14 @@ class FlaxBigBirdModel(FlaxBigBirdPreTrainedModel): class FlaxBigBirdForPreTrainingModule(nn.Module): config: BigBirdConfig dtype: jnp.dtype = jnp.float32 + gradient_checkpointing: bool = False def setup(self): - self.bert = FlaxBigBirdModule(config=self.config, dtype=self.dtype) + self.bert = FlaxBigBirdModule( + config=self.config, + dtype=self.dtype, + gradient_checkpointing=self.gradient_checkpointing, + ) self.cls = FlaxBigBirdPreTrainingHeads(config=self.config, dtype=self.dtype) def __call__( @@ -1910,9 +1943,15 @@ class FlaxBigBirdForPreTraining(FlaxBigBirdPreTrainedModel): class FlaxBigBirdForMaskedLMModule(nn.Module): config: BigBirdConfig dtype: jnp.dtype = jnp.float32 + gradient_checkpointing: bool = False def setup(self): - self.bert = FlaxBigBirdModule(config=self.config, add_pooling_layer=False, dtype=self.dtype) + self.bert = FlaxBigBirdModule( + config=self.config, + add_pooling_layer=False, + dtype=self.dtype, + gradient_checkpointing=self.gradient_checkpointing, + ) self.cls = FlaxBigBirdOnlyMLMHead(config=self.config, dtype=self.dtype) def __call__( @@ -1999,9 +2038,12 @@ def __call__(self, features, deterministic=True): class FlaxBigBirdForSequenceClassificationModule(nn.Module): config: BigBirdConfig dtype: jnp.dtype = jnp.float32 + gradient_checkpointing: bool = False def setup(self): - self.bert = FlaxBigBirdModule(config=self.config, dtype=self.dtype) + self.bert = FlaxBigBirdModule( + config=self.config, dtype=self.dtype, gradient_checkpointing=self.gradient_checkpointing + ) self.classifier = FlaxBigBirdClassificationHead(self.config, dtype=self.dtype) def __call__( @@ -2067,9 +2109,14 @@ class FlaxBigBirdForSequenceClassification(FlaxBigBirdPreTrainedModel): class FlaxBigBirdForMultipleChoiceModule(nn.Module): config: BigBirdConfig dtype: jnp.dtype = jnp.float32 + gradient_checkpointing: bool = False def setup(self): - self.bert = FlaxBigBirdModule(config=self.config, dtype=self.dtype) + self.bert = FlaxBigBirdModule( + config=self.config, + dtype=self.dtype, + gradient_checkpointing=self.gradient_checkpointing, + ) self.dropout = nn.Dropout(rate=self.config.hidden_dropout_prob) self.classifier = nn.Dense(1, dtype=self.dtype) @@ -2162,9 +2209,15 @@ def __init__( class FlaxBigBirdForTokenClassificationModule(nn.Module): config: BigBirdConfig dtype: jnp.dtype = jnp.float32 + gradient_checkpointing: bool = False def setup(self): - self.bert = FlaxBigBirdModule(config=self.config, dtype=self.dtype, add_pooling_layer=False) + self.bert = FlaxBigBirdModule( + config=self.config, + dtype=self.dtype, + add_pooling_layer=False, + gradient_checkpointing=self.gradient_checkpointing, + ) classifier_dropout = ( self.config.classifier_dropout if self.config.classifier_dropout is not None @@ -2255,10 +2308,16 @@ class FlaxBigBirdForQuestionAnsweringModule(nn.Module): config: BigBirdConfig dtype: jnp.dtype = jnp.float32 add_pooling_layer: bool = False + gradient_checkpointing: bool = False def setup(self): self.config.num_labels = 2 - self.bert = FlaxBigBirdModule(self.config, dtype=self.dtype, add_pooling_layer=self.add_pooling_layer) + self.bert = FlaxBigBirdModule( + self.config, + dtype=self.dtype, + add_pooling_layer=self.add_pooling_layer, + gradient_checkpointing=self.gradient_checkpointing, + ) 
self.qa_classifier = FlaxBigBirdForQuestionAnsweringHead(self.config, dtype=self.dtype) def __call__( @@ -2414,9 +2473,15 @@ def prepare_question_mask(q_lengths, maxlen: int): class FlaxBigBirdForCausalLMModule(nn.Module): config: BigBirdConfig dtype: jnp.dtype = jnp.float32 + gradient_checkpointing: bool = False def setup(self): - self.bert = FlaxBigBirdModule(config=self.config, add_pooling_layer=False, dtype=self.dtype) + self.bert = FlaxBigBirdModule( + config=self.config, + add_pooling_layer=False, + dtype=self.dtype, + gradient_checkpointing=self.gradient_checkpointing, + ) self.cls = FlaxBigBirdOnlyMLMHead(config=self.config, dtype=self.dtype) def __call__( diff --git a/src/transformers/models/electra/modeling_flax_electra.py b/src/transformers/models/electra/modeling_flax_electra.py index 3e3a7103f07e30..5f02c01a650e12 100644 --- a/src/transformers/models/electra/modeling_flax_electra.py +++ b/src/transformers/models/electra/modeling_flax_electra.py @@ -23,6 +23,7 @@ import jax.numpy as jnp from flax.core.frozen_dict import FrozenDict, freeze, unfreeze from flax.linen import combine_masks, make_causal_mask +from flax.linen import partitioning as nn_partitioning from flax.linen.attention import dot_product_attention_weights from flax.traverse_util import flatten_dict, unflatten_dict from jax import lax @@ -54,6 +55,8 @@ _CONFIG_FOR_DOC = "ElectraConfig" _TOKENIZER_FOR_DOC = "ElectraTokenizer" +remat = nn_partitioning.remat + @flax.struct.dataclass class FlaxElectraForPreTrainingOutput(ModelOutput): @@ -521,11 +524,20 @@ def __call__( class FlaxElectraLayerCollection(nn.Module): config: ElectraConfig dtype: jnp.dtype = jnp.float32 # the dtype of the computation + gradient_checkpointing: bool = False def setup(self): - self.layers = [ - FlaxElectraLayer(self.config, name=str(i), dtype=self.dtype) for i in range(self.config.num_hidden_layers) - ] + if self.gradient_checkpointing: + FlaxElectraCheckpointLayer = remat(FlaxElectraLayer, static_argnums=(5, 6, 7)) + self.layers = [ + FlaxElectraCheckpointLayer(self.config, name=str(i), dtype=self.dtype) + for i in range(self.config.num_hidden_layers) + ] + else: + self.layers = [ + FlaxElectraLayer(self.config, name=str(i), dtype=self.dtype) + for i in range(self.config.num_hidden_layers) + ] def __call__( self, @@ -559,12 +571,12 @@ def __call__( layer_outputs = layer( hidden_states, attention_mask, - layer_head_mask=head_mask[i] if head_mask is not None else None, - encoder_hidden_states=encoder_hidden_states, - encoder_attention_mask=encoder_attention_mask, - init_cache=init_cache, - deterministic=deterministic, - output_attentions=output_attentions, + head_mask[i] if head_mask is not None else None, + encoder_hidden_states, + encoder_attention_mask, + init_cache, + deterministic, + output_attentions, ) hidden_states = layer_outputs[0] @@ -595,9 +607,14 @@ def __call__( class FlaxElectraEncoder(nn.Module): config: ElectraConfig dtype: jnp.dtype = jnp.float32 # the dtype of the computation + gradient_checkpointing: bool = False def setup(self): - self.layer = FlaxElectraLayerCollection(self.config, dtype=self.dtype) + self.layer = FlaxElectraLayerCollection( + self.config, + dtype=self.dtype, + gradient_checkpointing=self.gradient_checkpointing, + ) def __call__( self, @@ -675,11 +692,20 @@ def __init__( seed: int = 0, dtype: jnp.dtype = jnp.float32, _do_init: bool = True, + gradient_checkpointing: bool = False, **kwargs ): - module = self.module_class(config=config, dtype=dtype, **kwargs) + module = self.module_class(config=config, 
dtype=dtype, gradient_checkpointing=gradient_checkpointing, **kwargs) super().__init__(config, module, input_shape=input_shape, seed=seed, dtype=dtype, _do_init=_do_init) + # Copied from transformers.models.bert.modeling_flax_bert.FlaxBertPreTrainedModel.enable_gradient_checkpointing + def enable_gradient_checkpointing(self): + self._module = self.module_class( + config=self.config, + dtype=self.dtype, + gradient_checkpointing=True, + ) + # Copied from transformers.models.bert.modeling_flax_bert.FlaxBertPreTrainedModel.init_weights def init_weights(self, rng: jax.random.PRNGKey, input_shape: Tuple, params: FrozenDict = None) -> FrozenDict: # init input tensors @@ -845,12 +871,15 @@ def __call__( class FlaxElectraModule(nn.Module): config: ElectraConfig dtype: jnp.dtype = jnp.float32 # the dtype of the computation + gradient_checkpointing: bool = False def setup(self): self.embeddings = FlaxElectraEmbeddings(self.config, dtype=self.dtype) if self.config.embedding_size != self.config.hidden_size: self.embeddings_project = nn.Dense(self.config.hidden_size, dtype=self.dtype) - self.encoder = FlaxElectraEncoder(self.config, dtype=self.dtype) + self.encoder = FlaxElectraEncoder( + self.config, dtype=self.dtype, gradient_checkpointing=self.gradient_checkpointing + ) def __call__( self, @@ -925,9 +954,12 @@ def __call__(self, x, kernel): class FlaxElectraForMaskedLMModule(nn.Module): config: ElectraConfig dtype: jnp.dtype = jnp.float32 + gradient_checkpointing: bool = False def setup(self): - self.electra = FlaxElectraModule(config=self.config, dtype=self.dtype) + self.electra = FlaxElectraModule( + config=self.config, dtype=self.dtype, gradient_checkpointing=self.gradient_checkpointing + ) self.generator_predictions = FlaxElectraGeneratorPredictions(config=self.config, dtype=self.dtype) if self.config.tie_word_embeddings: self.generator_lm_head = FlaxElectraTiedDense(self.config.vocab_size, dtype=self.dtype) @@ -989,9 +1021,12 @@ class FlaxElectraForMaskedLM(FlaxElectraPreTrainedModel): class FlaxElectraForPreTrainingModule(nn.Module): config: ElectraConfig dtype: jnp.dtype = jnp.float32 + gradient_checkpointing: bool = False def setup(self): - self.electra = FlaxElectraModule(config=self.config, dtype=self.dtype) + self.electra = FlaxElectraModule( + config=self.config, dtype=self.dtype, gradient_checkpointing=self.gradient_checkpointing + ) self.discriminator_predictions = FlaxElectraDiscriminatorPredictions(config=self.config, dtype=self.dtype) def __call__( @@ -1074,9 +1109,12 @@ class FlaxElectraForPreTraining(FlaxElectraPreTrainedModel): class FlaxElectraForTokenClassificationModule(nn.Module): config: ElectraConfig dtype: jnp.dtype = jnp.float32 + gradient_checkpointing: bool = False def setup(self): - self.electra = FlaxElectraModule(config=self.config, dtype=self.dtype) + self.electra = FlaxElectraModule( + config=self.config, dtype=self.dtype, gradient_checkpointing=self.gradient_checkpointing + ) classifier_dropout = ( self.config.classifier_dropout if self.config.classifier_dropout is not None @@ -1218,9 +1256,12 @@ def __call__(self, hidden_states, cls_index=None, deterministic: bool = True): class FlaxElectraForMultipleChoiceModule(nn.Module): config: ElectraConfig dtype: jnp.dtype = jnp.float32 + gradient_checkpointing: bool = False def setup(self): - self.electra = FlaxElectraModule(config=self.config, dtype=self.dtype) + self.electra = FlaxElectraModule( + config=self.config, dtype=self.dtype, gradient_checkpointing=self.gradient_checkpointing + ) self.sequence_summary = 
FlaxElectraSequenceSummary(config=self.config, dtype=self.dtype) self.classifier = nn.Dense(1, dtype=self.dtype) @@ -1297,9 +1338,12 @@ class FlaxElectraForMultipleChoice(FlaxElectraPreTrainedModel): class FlaxElectraForQuestionAnsweringModule(nn.Module): config: ElectraConfig dtype: jnp.dtype = jnp.float32 + gradient_checkpointing: bool = False def setup(self): - self.electra = FlaxElectraModule(config=self.config, dtype=self.dtype) + self.electra = FlaxElectraModule( + config=self.config, dtype=self.dtype, gradient_checkpointing=self.gradient_checkpointing + ) self.qa_outputs = nn.Dense(self.config.num_labels, dtype=self.dtype) def __call__( @@ -1392,9 +1436,12 @@ def __call__(self, hidden_states, deterministic: bool = True): class FlaxElectraForSequenceClassificationModule(nn.Module): config: ElectraConfig dtype: jnp.dtype = jnp.float32 + gradient_checkpointing: bool = False def setup(self): - self.electra = FlaxElectraModule(config=self.config, dtype=self.dtype) + self.electra = FlaxElectraModule( + config=self.config, dtype=self.dtype, gradient_checkpointing=self.gradient_checkpointing + ) self.classifier = FlaxElectraClassificationHead(config=self.config, dtype=self.dtype) def __call__( @@ -1457,9 +1504,12 @@ class FlaxElectraForSequenceClassification(FlaxElectraPreTrainedModel): class FlaxElectraForCausalLMModule(nn.Module): config: ElectraConfig dtype: jnp.dtype = jnp.float32 + gradient_checkpointing: bool = False def setup(self): - self.electra = FlaxElectraModule(config=self.config, dtype=self.dtype) + self.electra = FlaxElectraModule( + config=self.config, dtype=self.dtype, gradient_checkpointing=self.gradient_checkpointing + ) self.generator_predictions = FlaxElectraGeneratorPredictions(config=self.config, dtype=self.dtype) if self.config.tie_word_embeddings: self.generator_lm_head = FlaxElectraTiedDense(self.config.vocab_size, dtype=self.dtype) diff --git a/src/transformers/models/roberta/modeling_flax_roberta.py b/src/transformers/models/roberta/modeling_flax_roberta.py index 84bf15da6d8614..ddd6359b36be83 100644 --- a/src/transformers/models/roberta/modeling_flax_roberta.py +++ b/src/transformers/models/roberta/modeling_flax_roberta.py @@ -21,6 +21,7 @@ import jax.numpy as jnp from flax.core.frozen_dict import FrozenDict, freeze, unfreeze from flax.linen import combine_masks, make_causal_mask +from flax.linen import partitioning as nn_partitioning from flax.linen.attention import dot_product_attention_weights from flax.traverse_util import flatten_dict, unflatten_dict from jax import lax @@ -47,6 +48,8 @@ _CONFIG_FOR_DOC = "RobertaConfig" _TOKENIZER_FOR_DOC = "RobertaTokenizer" +remat = nn_partitioning.remat + def create_position_ids_from_input_ids(input_ids, padding_idx): """ @@ -511,11 +514,20 @@ def __call__( class FlaxRobertaLayerCollection(nn.Module): config: RobertaConfig dtype: jnp.dtype = jnp.float32 # the dtype of the computation + gradient_checkpointing: bool = False def setup(self): - self.layers = [ - FlaxRobertaLayer(self.config, name=str(i), dtype=self.dtype) for i in range(self.config.num_hidden_layers) - ] + if self.gradient_checkpointing: + FlaxRobertaCheckpointLayer = remat(FlaxRobertaLayer, static_argnums=(5, 6, 7)) + self.layers = [ + FlaxRobertaCheckpointLayer(self.config, name=str(i), dtype=self.dtype) + for i in range(self.config.num_hidden_layers) + ] + else: + self.layers = [ + FlaxRobertaLayer(self.config, name=str(i), dtype=self.dtype) + for i in range(self.config.num_hidden_layers) + ] def __call__( self, @@ -549,12 +561,12 @@ def __call__( 
layer_outputs = layer( hidden_states, attention_mask, - layer_head_mask=head_mask[i] if head_mask is not None else None, - encoder_hidden_states=encoder_hidden_states, - encoder_attention_mask=encoder_attention_mask, - init_cache=init_cache, - deterministic=deterministic, - output_attentions=output_attentions, + head_mask[i] if head_mask is not None else None, + encoder_hidden_states, + encoder_attention_mask, + init_cache, + deterministic, + output_attentions, ) hidden_states = layer_outputs[0] @@ -585,9 +597,14 @@ def __call__( class FlaxRobertaEncoder(nn.Module): config: RobertaConfig dtype: jnp.dtype = jnp.float32 # the dtype of the computation + gradient_checkpointing: bool = False def setup(self): - self.layer = FlaxRobertaLayerCollection(self.config, dtype=self.dtype) + self.layer = FlaxRobertaLayerCollection( + self.config, + dtype=self.dtype, + gradient_checkpointing=self.gradient_checkpointing, + ) def __call__( self, @@ -719,11 +736,20 @@ def __init__( seed: int = 0, dtype: jnp.dtype = jnp.float32, _do_init: bool = True, + gradient_checkpointing: bool = False, **kwargs ): - module = self.module_class(config=config, dtype=dtype, **kwargs) + module = self.module_class(config=config, dtype=dtype, gradient_checkpointing=gradient_checkpointing, **kwargs) super().__init__(config, module, input_shape=input_shape, seed=seed, dtype=dtype, _do_init=_do_init) + # Copied from transformers.models.bert.modeling_flax_bert.FlaxBertPreTrainedModel.enable_gradient_checkpointing + def enable_gradient_checkpointing(self): + self._module = self.module_class( + config=self.config, + dtype=self.dtype, + gradient_checkpointing=True, + ) + def init_weights(self, rng: jax.random.PRNGKey, input_shape: Tuple, params: FrozenDict = None) -> FrozenDict: # init input tensors input_ids = jnp.zeros(input_shape, dtype="i4") @@ -889,10 +915,15 @@ class FlaxRobertaModule(nn.Module): config: RobertaConfig dtype: jnp.dtype = jnp.float32 # the dtype of the computation add_pooling_layer: bool = True + gradient_checkpointing: bool = False def setup(self): self.embeddings = FlaxRobertaEmbeddings(self.config, dtype=self.dtype) - self.encoder = FlaxRobertaEncoder(self.config, dtype=self.dtype) + self.encoder = FlaxRobertaEncoder( + self.config, + dtype=self.dtype, + gradient_checkpointing=self.gradient_checkpointing, + ) self.pooler = FlaxRobertaPooler(self.config, dtype=self.dtype) def __call__( @@ -967,9 +998,15 @@ class FlaxRobertaModel(FlaxRobertaPreTrainedModel): class FlaxRobertaForMaskedLMModule(nn.Module): config: RobertaConfig dtype: jnp.dtype = jnp.float32 + gradient_checkpointing: bool = False def setup(self): - self.roberta = FlaxRobertaModule(config=self.config, add_pooling_layer=False, dtype=self.dtype) + self.roberta = FlaxRobertaModule( + config=self.config, + add_pooling_layer=False, + dtype=self.dtype, + gradient_checkpointing=self.gradient_checkpointing, + ) self.lm_head = FlaxRobertaLMHead(config=self.config, dtype=self.dtype) def __call__( @@ -1034,9 +1071,15 @@ class FlaxRobertaForMaskedLM(FlaxRobertaPreTrainedModel): class FlaxRobertaForSequenceClassificationModule(nn.Module): config: RobertaConfig dtype: jnp.dtype = jnp.float32 + gradient_checkpointing: bool = False def setup(self): - self.roberta = FlaxRobertaModule(config=self.config, dtype=self.dtype, add_pooling_layer=False) + self.roberta = FlaxRobertaModule( + config=self.config, + dtype=self.dtype, + add_pooling_layer=False, + gradient_checkpointing=self.gradient_checkpointing, + ) self.classifier = 
FlaxRobertaClassificationHead(config=self.config, dtype=self.dtype) def __call__( @@ -1101,9 +1144,14 @@ class FlaxRobertaForSequenceClassification(FlaxRobertaPreTrainedModel): class FlaxRobertaForMultipleChoiceModule(nn.Module): config: RobertaConfig dtype: jnp.dtype = jnp.float32 + gradient_checkpointing: bool = False def setup(self): - self.roberta = FlaxRobertaModule(config=self.config, dtype=self.dtype) + self.roberta = FlaxRobertaModule( + config=self.config, + dtype=self.dtype, + gradient_checkpointing=self.gradient_checkpointing, + ) self.dropout = nn.Dropout(rate=self.config.hidden_dropout_prob) self.classifier = nn.Dense(1, dtype=self.dtype) @@ -1181,9 +1229,15 @@ class FlaxRobertaForMultipleChoice(FlaxRobertaPreTrainedModel): class FlaxRobertaForTokenClassificationModule(nn.Module): config: RobertaConfig dtype: jnp.dtype = jnp.float32 + gradient_checkpointing: bool = False def setup(self): - self.roberta = FlaxRobertaModule(config=self.config, dtype=self.dtype, add_pooling_layer=False) + self.roberta = FlaxRobertaModule( + config=self.config, + dtype=self.dtype, + add_pooling_layer=False, + gradient_checkpointing=self.gradient_checkpointing, + ) classifier_dropout = ( self.config.classifier_dropout if self.config.classifier_dropout is not None @@ -1255,9 +1309,15 @@ class FlaxRobertaForTokenClassification(FlaxRobertaPreTrainedModel): class FlaxRobertaForQuestionAnsweringModule(nn.Module): config: RobertaConfig dtype: jnp.dtype = jnp.float32 + gradient_checkpointing: bool = False def setup(self): - self.roberta = FlaxRobertaModule(config=self.config, dtype=self.dtype, add_pooling_layer=False) + self.roberta = FlaxRobertaModule( + config=self.config, + dtype=self.dtype, + add_pooling_layer=False, + gradient_checkpointing=self.gradient_checkpointing, + ) self.qa_outputs = nn.Dense(self.config.num_labels, dtype=self.dtype) def __call__( @@ -1326,9 +1386,15 @@ class FlaxRobertaForQuestionAnswering(FlaxRobertaPreTrainedModel): class FlaxRobertaForCausalLMModule(nn.Module): config: RobertaConfig dtype: jnp.dtype = jnp.float32 + gradient_checkpointing: bool = False def setup(self): - self.roberta = FlaxRobertaModule(config=self.config, add_pooling_layer=False, dtype=self.dtype) + self.roberta = FlaxRobertaModule( + config=self.config, + add_pooling_layer=False, + dtype=self.dtype, + gradient_checkpointing=self.gradient_checkpointing, + ) self.lm_head = FlaxRobertaLMHead(config=self.config, dtype=self.dtype) def __call__( diff --git a/templates/adding_a_new_model/cookiecutter-template-{{cookiecutter.modelname}}/modeling_flax_{{cookiecutter.lowercase_modelname}}.py b/templates/adding_a_new_model/cookiecutter-template-{{cookiecutter.modelname}}/modeling_flax_{{cookiecutter.lowercase_modelname}}.py index 451dc03f62ed13..676270c131fbca 100644 --- a/templates/adding_a_new_model/cookiecutter-template-{{cookiecutter.modelname}}/modeling_flax_{{cookiecutter.lowercase_modelname}}.py +++ b/templates/adding_a_new_model/cookiecutter-template-{{cookiecutter.modelname}}/modeling_flax_{{cookiecutter.lowercase_modelname}}.py @@ -25,6 +25,7 @@ import jax.numpy as jnp from flax.core.frozen_dict import FrozenDict, unfreeze, freeze from flax.linen import combine_masks, make_causal_mask +from flax.linen import partitioning as nn_partitioning from flax.traverse_util import flatten_dict, unflatten_dict from flax.linen.attention import dot_product_attention_weights from jax import lax @@ -126,6 +127,8 @@ """ +remat = nn_partitioning.remat + # Copied from 
transformers.models.bert.modeling_flax_bert.FlaxBertEmbeddings with Bert->{{cookiecutter.camelcase_modelname}} @@ -507,11 +510,19 @@ def __call__( class Flax{{cookiecutter.camelcase_modelname}}LayerCollection(nn.Module): config: {{cookiecutter.camelcase_modelname}}Config dtype: jnp.dtype = jnp.float32 # the dtype of the computation + gradient_checkpointing: bool = False def setup(self): - self.layers = [ - Flax{{cookiecutter.camelcase_modelname}}Layer(self.config, name=str(i), dtype=self.dtype) for i in range(self.config.num_hidden_layers) - ] + if self.gradient_checkpointing: + Flax{{cookiecutter.camelcase_modelname}}CheckpointLayer = remat(Flax{{cookiecutter.camelcase_modelname}}Layer, static_argnums=(5, 6, 7)) + self.layers = [ + Flax{{cookiecutter.camelcase_modelname}}CheckpointLayer(self.config, name=str(i), dtype=self.dtype) + for i in range(self.config.num_hidden_layers) + ] + else: + self.layers = [ + Flax{{cookiecutter.camelcase_modelname}}Layer(self.config, name=str(i), dtype=self.dtype) for i in range(self.config.num_hidden_layers) + ] def __call__( self, @@ -545,12 +556,12 @@ def __call__( layer_outputs = layer( hidden_states, attention_mask, - layer_head_mask=head_mask[i] if head_mask is not None else None, - encoder_hidden_states=encoder_hidden_states, - encoder_attention_mask=encoder_attention_mask, - init_cache=init_cache, - deterministic=deterministic, - output_attentions=output_attentions, + head_mask[i] if head_mask is not None else None, + encoder_hidden_states, + encoder_attention_mask, + init_cache, + deterministic, + output_attentions, ) hidden_states = layer_outputs[0] @@ -581,9 +592,10 @@ def __call__( class Flax{{cookiecutter.camelcase_modelname}}Encoder(nn.Module): config: {{cookiecutter.camelcase_modelname}}Config dtype: jnp.dtype = jnp.float32 # the dtype of the computation + gradient_checkpointing: bool = False def setup(self): - self.layer = Flax{{cookiecutter.camelcase_modelname}}LayerCollection(self.config, dtype=self.dtype) + self.layer = Flax{{cookiecutter.camelcase_modelname}}LayerCollection(self.config, dtype=self.dtype, gradient_checkpointing=self.gradient_checkpointing) def __call__( self, @@ -725,11 +737,20 @@ def __init__( seed: int = 0, dtype: jnp.dtype = jnp.float32, _do_init: bool = True, + gradient_checkpointing: bool = False, **kwargs ): - module = self.module_class(config=config, dtype=dtype, **kwargs) + module = self.module_class(config=config, dtype=dtype, gradient_checkpointing=gradient_checkpointing, **kwargs) super().__init__(config, module, input_shape=input_shape, seed=seed, dtype=dtype, _do_init=_do_init) + # Copied from transformers.models.bert.modeling_flax_bert.FlaxBertPreTrainedModel.enable_gradient_checkpointing + def enable_gradient_checkpointing(self): + self._module = self.module_class( + config=self.config, + dtype=self.dtype, + gradient_checkpointing=True, + ) + # Copied from transformers.models.bert.modeling_flax_bert.FlaxBertPreTrainedModel.init_weights with Bert->{{cookiecutter.camelcase_modelname}} def init_weights(self, rng: jax.random.PRNGKey, input_shape: Tuple, params: FrozenDict = None) -> FrozenDict: # init input tensors @@ -897,10 +918,11 @@ class Flax{{cookiecutter.camelcase_modelname}}Module(nn.Module): config: {{cookiecutter.camelcase_modelname}}Config dtype: jnp.dtype = jnp.float32 # the dtype of the computation add_pooling_layer: bool = True + gradient_checkpointing: bool = False def setup(self): self.embeddings = Flax{{cookiecutter.camelcase_modelname}}Embeddings(self.config, dtype=self.dtype) - self.encoder 
= Flax{{cookiecutter.camelcase_modelname}}Encoder(self.config, dtype=self.dtype) + self.encoder = Flax{{cookiecutter.camelcase_modelname}}Encoder(self.config, dtype=self.dtype, gradient_checkpointing=self.gradient_checkpointing) self.pooler = Flax{{cookiecutter.camelcase_modelname}}Pooler(self.config, dtype=self.dtype) def __call__( @@ -969,9 +991,10 @@ class Flax{{cookiecutter.camelcase_modelname}}Model(Flax{{cookiecutter.camelcase class Flax{{cookiecutter.camelcase_modelname}}ForMaskedLMModule(nn.Module): config: {{cookiecutter.camelcase_modelname}}Config dtype: jnp.dtype = jnp.float32 + gradient_checkpointing: bool = False def setup(self): - self.{{cookiecutter.lowercase_modelname}} = Flax{{cookiecutter.camelcase_modelname}}Module(config=self.config, add_pooling_layer=False, dtype=self.dtype) + self.{{cookiecutter.lowercase_modelname}} = Flax{{cookiecutter.camelcase_modelname}}Module(config=self.config, add_pooling_layer=False, dtype=self.dtype, gradient_checkpointing=self.gradient_checkpointing) self.cls = Flax{{cookiecutter.camelcase_modelname}}OnlyMLMHead(config=self.config, dtype=self.dtype) def __call__( @@ -1030,9 +1053,10 @@ class Flax{{cookiecutter.camelcase_modelname}}ForMaskedLM(Flax{{cookiecutter.cam class Flax{{cookiecutter.camelcase_modelname}}ForCausalLMModule(nn.Module): config: {{cookiecutter.camelcase_modelname}}Config dtype: jnp.dtype = jnp.float32 + gradient_checkpointing: bool = False def setup(self): - self.{{cookiecutter.lowercase_modelname}} = Flax{{cookiecutter.camelcase_modelname}}Module(config=self.config, add_pooling_layer=False, dtype=self.dtype) + self.{{cookiecutter.lowercase_modelname}} = Flax{{cookiecutter.camelcase_modelname}}Module(config=self.config, add_pooling_layer=False, dtype=self.dtype, gradient_checkpointing=self.gradient_checkpointing) self.cls = Flax{{cookiecutter.camelcase_modelname}}OnlyMLMHead(config=self.config, dtype=self.dtype) def __call__( @@ -1092,9 +1116,10 @@ class Flax{{cookiecutter.camelcase_modelname}}ForCausalLM(Flax{{cookiecutter.cam class Flax{{cookiecutter.camelcase_modelname}}ForSequenceClassificationModule(nn.Module): config: {{cookiecutter.camelcase_modelname}}Config dtype: jnp.dtype = jnp.float32 + gradient_checkpointing: bool = False def setup(self): - self.{{cookiecutter.lowercase_modelname}} = Flax{{cookiecutter.camelcase_modelname}}Module(config=self.config, dtype=self.dtype) + self.{{cookiecutter.lowercase_modelname}} = Flax{{cookiecutter.camelcase_modelname}}Module(config=self.config, dtype=self.dtype, gradient_checkpointing=self.gradient_checkpointing) self.dropout = nn.Dropout(rate=self.config.hidden_dropout_prob) self.classifier = nn.Dense( self.config.num_labels, @@ -1163,9 +1188,10 @@ class Flax{{cookiecutter.camelcase_modelname}}ForSequenceClassification(Flax{{co class Flax{{cookiecutter.camelcase_modelname}}ForMultipleChoiceModule(nn.Module): config: {{cookiecutter.camelcase_modelname}}Config dtype: jnp.dtype = jnp.float32 + gradient_checkpointing: bool = False def setup(self): - self.{{cookiecutter.lowercase_modelname}} = Flax{{cookiecutter.camelcase_modelname}}Module(config=self.config, dtype=self.dtype) + self.{{cookiecutter.lowercase_modelname}} = Flax{{cookiecutter.camelcase_modelname}}Module(config=self.config, dtype=self.dtype, gradient_checkpointing=self.gradient_checkpointing) self.dropout = nn.Dropout(rate=self.config.hidden_dropout_prob) self.classifier = nn.Dense(1, dtype=self.dtype) @@ -1238,9 +1264,10 @@ class Flax{{cookiecutter.camelcase_modelname}}ForMultipleChoice(Flax{{cookiecutt class 
Flax{{cookiecutter.camelcase_modelname}}ForTokenClassificationModule(nn.Module): config: {{cookiecutter.camelcase_modelname}}Config dtype: jnp.dtype = jnp.float32 + gradient_checkpointing: bool = False def setup(self): - self.{{cookiecutter.lowercase_modelname}} = Flax{{cookiecutter.camelcase_modelname}}Module(config=self.config, dtype=self.dtype, add_pooling_layer=False) + self.{{cookiecutter.lowercase_modelname}} = Flax{{cookiecutter.camelcase_modelname}}Module(config=self.config, dtype=self.dtype, add_pooling_layer=False, gradient_checkpointing=self.gradient_checkpointing) self.dropout = nn.Dropout(rate=self.config.hidden_dropout_prob) self.classifier = nn.Dense(self.config.num_labels, dtype=self.dtype) @@ -1302,9 +1329,10 @@ class Flax{{cookiecutter.camelcase_modelname}}ForTokenClassification(Flax{{cooki class Flax{{cookiecutter.camelcase_modelname}}ForQuestionAnsweringModule(nn.Module): config: {{cookiecutter.camelcase_modelname}}Config dtype: jnp.dtype = jnp.float32 + gradient_checkpointing: bool = False def setup(self): - self.{{cookiecutter.lowercase_modelname}} = Flax{{cookiecutter.camelcase_modelname}}Module(config=self.config, dtype=self.dtype, add_pooling_layer=False) + self.{{cookiecutter.lowercase_modelname}} = Flax{{cookiecutter.camelcase_modelname}}Module(config=self.config, dtype=self.dtype, add_pooling_layer=False, gradient_checkpointing=self.gradient_checkpointing) self.qa_outputs = nn.Dense(self.config.num_labels, dtype=self.dtype) def __call__( @@ -1373,9 +1401,10 @@ class Flax{{cookiecutter.camelcase_modelname}}ForQuestionAnswering(Flax{{cookiec class Flax{{cookiecutter.camelcase_modelname}}ForCausalLMModule(nn.Module): config: {{cookiecutter.camelcase_modelname}}Config dtype: jnp.dtype = jnp.float32 + gradient_checkpointing: bool = False def setup(self): - self.{{cookiecutter.lowercase_modelname}} = Flax{{cookiecutter.camelcase_modelname}}Module(config=self.config, add_pooling_layer=False, dtype=self.dtype) + self.{{cookiecutter.lowercase_modelname}} = Flax{{cookiecutter.camelcase_modelname}}Module(config=self.config, add_pooling_layer=False, dtype=self.dtype, gradient_checkpointing=self.gradient_checkpointing) self.cls = Flax{{cookiecutter.camelcase_modelname}}OnlyMLMHead(config=self.config, dtype=self.dtype) def __call__( diff --git a/tests/test_modeling_flax_common.py b/tests/test_modeling_flax_common.py index ec3c1fcd0bc377..f90615efea3604 100644 --- a/tests/test_modeling_flax_common.py +++ b/tests/test_modeling_flax_common.py @@ -1099,6 +1099,33 @@ def test_checkpoint_sharding_local(self): for p1, p2 in zip(flatten_dict(model.params).values(), flatten_dict(new_model.params).values()): self.assertTrue(np.allclose(np.array(p1), np.array(p2))) + def test_gradient_checkpointing(self): + config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common() + + for model_class in self.all_model_classes: + # prepare inputs + prepared_inputs_dict = self._prepare_for_class(inputs_dict, model_class) + model = model_class(config) + remat_model = model_class(config) + + try: + remat_model.enable_gradient_checkpointing() + except NotImplementedError: + continue + + outputs = model(**prepared_inputs_dict) + remat_outputs = remat_model(**prepared_inputs_dict) + + # ensure that the dicts of outputs contain the same keys + self.assertEqual(outputs.keys(), remat_outputs.keys()) + + outputs = outputs.to_tuple() + remat_outputs = remat_outputs.to_tuple() + + # ensure that the outputs remain precisely equal + for output, remat_output in zip(outputs, remat_outputs): + 
self.assertTrue((output == remat_output).all())
+

 @require_flax
 @is_staging_test
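For reference, a minimal usage sketch of the API added by this patch, written against the names introduced in the diff (FlaxBertModel stands in for any of the updated Flax models; the config sizes are arbitrary illustrative values, and the snippet itself is not part of the patch). Gradient checkpointing can be requested either through the new gradient_checkpointing constructor flag or by calling enable_gradient_checkpointing() on an existing model; both paths set the module's gradient_checkpointing field, which wraps each transformer layer in nn_partitioning.remat with the boolean call arguments (init_cache, deterministic, output_attentions) marked static.

# Minimal sketch of the gradient checkpointing API added in this patch.
# Assumes a working Flax/JAX install and a transformers build that includes this change;
# the tiny BertConfig below is only meant to keep the example fast.
import jax.numpy as jnp

from transformers import BertConfig, FlaxBertModel

config = BertConfig(num_hidden_layers=2, hidden_size=64, num_attention_heads=2, intermediate_size=128)

# Option 1: request checkpointing at construction time via the new kwarg.
checkpointed_model = FlaxBertModel(config, gradient_checkpointing=True)

# Option 2: enable it on an already constructed model, as the new common test does.
model = FlaxBertModel(config)
model.enable_gradient_checkpointing()

# Checkpointing trades backward-pass memory for recomputation; forward outputs are
# expected to be identical to a non-checkpointed model built with the same seed.
input_ids = jnp.ones((1, 8), dtype="i4")
outputs = model(input_ids)
print(outputs.last_hidden_state.shape)  # (1, 8, 64)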