diff --git a/src/transformers/models/cohere/modeling_cohere.py b/src/transformers/models/cohere/modeling_cohere.py
index db748013ff35e1..4d7d9c4b780750 100644
--- a/src/transformers/models/cohere/modeling_cohere.py
+++ b/src/transformers/models/cohere/modeling_cohere.py
@@ -825,12 +825,6 @@ def __init__(self, config: CohereConfig):
         self.norm = CohereLayerNorm(config.hidden_size, eps=config.layer_norm_eps)
         self.gradient_checkpointing = False
 
-        # Register a causal mask to separate causal and padding mask creation. Merging happens in the attention class.
-        # NOTE: This is not friendly with TorchScript, ONNX, ExportedProgram serialization for very large `max_position_embeddings`.
-        causal_mask = torch.full(
-            (config.max_position_embeddings, config.max_position_embeddings), fill_value=True, dtype=torch.bool
-        )
-        self.register_buffer("causal_mask", torch.triu(causal_mask, diagonal=1), persistent=False)
 
         # Initialize weights and apply final processing
         self.post_init()
diff --git a/src/transformers/models/gemma/modeling_gemma.py b/src/transformers/models/gemma/modeling_gemma.py
index 8360b4080781ff..c60c67d46e1bfb 100644
--- a/src/transformers/models/gemma/modeling_gemma.py
+++ b/src/transformers/models/gemma/modeling_gemma.py
@@ -719,10 +719,6 @@ def _setup_cache(self, cache_cls, max_batch_size, max_cache_len: Optional[int] =
                 "make sure to use `sdpa` in the mean time, and open an issue at https://github.com/huggingface/transformers"
             )
 
-        if max_cache_len > self.model.causal_mask.shape[-1] or self.device != self.model.causal_mask.device:
-            causal_mask = torch.full((max_cache_len, max_cache_len), fill_value=1, device=self.device)
-            self.register_buffer("causal_mask", torch.triu(causal_mask, diagonal=1), persistent=False)
-
         for layer in self.model.layers:
             weights = layer.self_attn.o_proj.weight
             layer.self_attn.past_key_value = cache_cls(
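
Both hunks drop the pre-registered `causal_mask` buffer, which was sized to `max_position_embeddings` (or `max_cache_len`) and therefore costly to serialize and awkward for TorchScript/ONNX export. The implication is that the mask is built at runtime for whatever sequence and cache length is actually needed. The snippet below is only a minimal sketch of that idea; the helper name `build_causal_mask` and its signature are assumptions for illustration, not the library's actual replacement code.

```python
import torch


def build_causal_mask(
    sequence_length: int,
    target_length: int,
    dtype: torch.dtype,
    device: torch.device,
) -> torch.Tensor:
    # Illustrative sketch (assumed helper, not transformers' API): construct the
    # causal mask on the fly instead of keeping a persistent
    # (max_position_embeddings x max_position_embeddings) boolean buffer.
    # Entries above the diagonal get the dtype's minimum value so softmax
    # effectively ignores future positions; the rest stay at 0.
    min_value = torch.finfo(dtype).min
    mask = torch.full(
        (sequence_length, target_length), fill_value=min_value, dtype=dtype, device=device
    )
    return torch.triu(mask, diagonal=1)


# Example: a 4-token query attending over a 4-slot key/value length.
mask = build_causal_mask(4, 4, torch.float32, torch.device("cpu"))
```

Building the mask per forward pass keeps its size tied to the current cache/sequence length rather than the model's maximum context, which is the motivation suggested by the removed NOTE comment in the Cohere hunk.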