Skip to content

Commit

Permalink
fix naming
Browse files Browse the repository at this point in the history
  • Loading branch information
sanchit-gandhi committed Jun 27, 2022
1 parent 90443a8 commit 80188b6
Showing 1 changed file with 4 additions and 4 deletions.
8 changes: 4 additions & 4 deletions src/transformers/models/bert/modeling_flax_bert.py
Original file line number Diff line number Diff line change
Expand Up @@ -774,15 +774,15 @@ def __init__(
seed: int = 0,
dtype: jnp.dtype = jnp.float32,
_do_init: bool = True,
_gradient_checkpointing: bool = False,
_remat_policy: Callable[..., bool] = (None,),
gradient_checkpointing: bool = False,
remat_policy: Callable[..., bool] = (None,),
**kwargs
):
module = self.module_class(
config=config,
dtype=dtype,
gradient_checkpointing=_gradient_checkpointing,
remat_policy=_remat_policy,
gradient_checkpointing=gradient_checkpointing,
remat_policy=remat_policy,
**kwargs,
)
super().__init__(config, module, input_shape=input_shape, seed=seed, dtype=dtype, _do_init=_do_init)
Expand Down

0 comments on commit 80188b6

Please sign in to comment.