Skip to content

Commit

Permalink
fix class naming
Browse files Browse the repository at this point in the history
  • Loading branch information
sanchit-gandhi committed Jun 27, 2022
1 parent 80188b6 commit ea8150a
Showing 1 changed file with 3 additions and 3 deletions.
6 changes: 3 additions & 3 deletions src/transformers/models/bert/modeling_flax_bert.py
Original file line number Diff line number Diff line change
Expand Up @@ -552,12 +552,12 @@ class FlaxBertLayerCollection(nn.Module):

def setup(self):
if self.gradient_checkpointing:
FlaxBertBlockLayer = remat(FlaxBertLayer, static_argnums=(5, 6, 7), policy=self.remat_policy)
FlaxBertCheckpointLayer = remat(FlaxBertLayer, static_argnums=(5, 6, 7), policy=self.remat_policy)
else:
FlaxBertBlockLayer = FlaxBertLayer
FlaxBertCheckpointLayer = FlaxBertLayer

self.layers = [
FlaxBertBlockLayer(self.config, name=str(i), dtype=self.dtype)
FlaxBertCheckpointLayer(self.config, name=str(i), dtype=self.dtype)
for i in range(self.config.num_hidden_layers)
]

Expand Down

0 comments on commit ea8150a

Please sign in to comment.