diff --git a/src/transformers/models/bert/modeling_flax_bert.py b/src/transformers/models/bert/modeling_flax_bert.py
index 173f158a56ecde..b0c5ccdac6e99c 100644
--- a/src/transformers/models/bert/modeling_flax_bert.py
+++ b/src/transformers/models/bert/modeling_flax_bert.py
@@ -552,12 +552,12 @@ class FlaxBertLayerCollection(nn.Module):
 
     def setup(self):
         if self.gradient_checkpointing:
-            FlaxBertBlockLayer = remat(FlaxBertLayer, static_argnums=(5, 6, 7), policy=self.remat_policy)
+            FlaxBertCheckpointLayer = remat(FlaxBertLayer, static_argnums=(5, 6, 7), policy=self.remat_policy)
         else:
-            FlaxBertBlockLayer = FlaxBertLayer
+            FlaxBertCheckpointLayer = FlaxBertLayer
 
         self.layers = [
-            FlaxBertBlockLayer(self.config, name=str(i), dtype=self.dtype)
+            FlaxBertCheckpointLayer(self.config, name=str(i), dtype=self.dtype)
             for i in range(self.config.num_hidden_layers)
         ]
 