Commit

[transformer] set use_reentrant=False for gradient ckpt (#2491)
xingchensong authored Apr 18, 2024
1 parent a30dbb7 commit 69a084f
Showing 3 changed files with 5 additions and 4 deletions.
4 changes: 2 additions & 2 deletions wenet/paraformer/layers.py
@@ -291,7 +291,7 @@ def forward_layers_checkpointed(self, xs: torch.Tensor,
             xs, _, _, _ = layer(xs, chunk_masks, pos_emb, mask_pad)
         for layer in self.encoders:
             xs, _, _, _ = ckpt.checkpoint(layer.__call__, xs, chunk_masks,
-                                          pos_emb, mask_pad)
+                                          pos_emb, mask_pad, use_reentrant=False)
         return xs


@@ -481,7 +481,7 @@ def forward_layers_checkpointed(self, x: torch.Tensor,
                 x, _, _, _ = layer(x, tgt_mask, memory, memory_mask)
             else:
                 x, _, _, _ = ckpt.checkpoint(layer.__call__, x, tgt_mask,
-                                             memory, memory_mask)
+                                             memory, memory_mask, use_reentrant=False)
         for layer in self.decoders3:
             x = layer(x)
         return x
3 changes: 2 additions & 1 deletion wenet/transformer/decoder.py
@@ -214,7 +214,8 @@ def forward_layers_checkpointed(self, x: torch.Tensor,
                                     memory_mask: torch.Tensor) -> torch.Tensor:
         for layer in self.decoders:
             x, tgt_mask, memory, memory_mask = ckpt.checkpoint(
-                layer.__call__, x, tgt_mask, memory, memory_mask)
+                layer.__call__, x, tgt_mask, memory, memory_mask,
+                use_reentrant=False)
         return x

     def forward_one_step(
2 changes: 1 addition & 1 deletion wenet/transformer/encoder.py
@@ -192,7 +192,7 @@ def forward_layers_checkpointed(self, xs: torch.Tensor,
         for layer in self.encoders:
             xs, chunk_masks, _, _ = ckpt.checkpoint(layer.__call__, xs,
                                                     chunk_masks, pos_emb,
-                                                    mask_pad)
+                                                    mask_pad, use_reentrant=False)
         return xs

     def forward_chunk(
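The change is the same at every call site: the per-layer call wrapped in torch.utils.checkpoint.checkpoint now passes use_reentrant=False explicitly, selecting the non-reentrant checkpoint implementation that PyTorch recommends and avoiding the warning emitted when the argument is left at its implicit default. Below is a minimal, self-contained sketch of the same pattern; it is not WeNet code, and TinyEncoder, its dimensions, and layer types are made up for illustration.

import torch
import torch.nn as nn
from torch.utils import checkpoint as ckpt


class TinyEncoder(nn.Module):
    """Hypothetical stand-in for a stack of encoder layers."""

    def __init__(self, dim: int = 16, num_layers: int = 2):
        super().__init__()
        self.encoders = nn.ModuleList(
            [nn.Linear(dim, dim) for _ in range(num_layers)])

    def forward(self, xs: torch.Tensor) -> torch.Tensor:
        for layer in self.encoders:
            # Recompute this layer's activations during backward instead of
            # storing them; use_reentrant=False picks the non-reentrant
            # implementation, mirroring the call sites changed in this commit.
            xs = ckpt.checkpoint(layer.__call__, xs, use_reentrant=False)
        return xs


model = TinyEncoder()
x = torch.randn(4, 16, requires_grad=True)
model(x).sum().backward()  # gradients flow through the checkpointed layers

Passing the flag explicitly also matters because PyTorch has deprecated relying on the default value of use_reentrant, so spelling it out at each checkpointed call keeps the behavior stable across PyTorch versions.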
