
[cleanup] vestiges of causal mask (#29806)
nit
ArthurZucker authored and amyeroberts committed Mar 22, 2024
1 parent dc8b789 commit e49ebae
Showing 2 changed files with 0 additions and 10 deletions.
6 changes: 0 additions & 6 deletions src/transformers/models/cohere/modeling_cohere.py
@@ -825,12 +825,6 @@ def __init__(self, config: CohereConfig):
         self.norm = CohereLayerNorm(config.hidden_size, eps=config.layer_norm_eps)
         self.gradient_checkpointing = False
 
-        # Register a causal mask to separate causal and padding mask creation. Merging happens in the attention class.
-        # NOTE: This is not friendly with TorchScript, ONNX, ExportedProgram serialization for very large `max_position_embeddings`.
-        causal_mask = torch.full(
-            (config.max_position_embeddings, config.max_position_embeddings), fill_value=True, dtype=torch.bool
-        )
-        self.register_buffer("causal_mask", torch.triu(causal_mask, diagonal=1), persistent=False)
         # Initialize weights and apply final processing
         self.post_init()

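Context for the deletion above (an illustration, not part of the commit): the removed buffer eagerly materializes a `max_position_embeddings` × `max_position_embeddings` boolean tensor at construction time, which is what the NOTE about TorchScript/ONNX/ExportedProgram refers to. A minimal sketch of the cost, assuming a hypothetical context length of 8192:

```python
import torch

# Hypothetical config value, for illustration only.
max_position_embeddings = 8192

# What the deleted code built eagerly in __init__: a full-size upper-triangular
# boolean mask covering every position the model could ever attend over.
causal_mask = torch.triu(
    torch.full((max_position_embeddings, max_position_embeddings), fill_value=True, dtype=torch.bool),
    diagonal=1,
)

# 8192 * 8192 bool elements = 64 MiB held for the model's lifetime, and carried
# along by serialization backends that capture module buffers.
print(causal_mask.numel() * causal_mask.element_size())  # 67108864 bytes
```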
4 changes: 0 additions & 4 deletions src/transformers/models/gemma/modeling_gemma.py
@@ -719,10 +719,6 @@ def _setup_cache(self, cache_cls, max_batch_size, max_cache_len: Optional[int] =
                 "make sure to use `sdpa` in the mean time, and open an issue at https://github.com/huggingface/transformers"
             )
 
-        if max_cache_len > self.model.causal_mask.shape[-1] or self.device != self.model.causal_mask.device:
-            causal_mask = torch.full((max_cache_len, max_cache_len), fill_value=1, device=self.device)
-            self.register_buffer("causal_mask", torch.triu(causal_mask, diagonal=1), persistent=False)
-
         for layer in self.model.layers:
             weights = layer.self_attn.o_proj.weight
             layer.self_attn.past_key_value = cache_cls(
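For contrast (a sketch of the direction of this cleanup, not code from the commit): instead of re-registering a persistent, device-pinned buffer whenever the cache outgrows it, the mask can be built on demand at exactly the size needed. A minimal sketch with a hypothetical `make_causal_mask` helper:

```python
import torch

def make_causal_mask(seq_len: int, device: torch.device) -> torch.Tensor:
    # Hypothetical helper: construct the causal mask per call rather than
    # storing a persistent max-size buffer on the model.
    mask = torch.full((seq_len, seq_len), fill_value=True, dtype=torch.bool, device=device)
    return torch.triu(mask, diagonal=1)  # True above the diagonal = masked (future) positions

# Usage: the mask is only ever as large as the current cache/sequence needs.
mask = make_causal_mask(16, torch.device("cpu"))
assert mask[0, 1].item() and not mask[1, 0].item()  # future masked, past visible
```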
