diff --git a/src/transformers/models/zamba/modeling_zamba.py b/src/transformers/models/zamba/modeling_zamba.py
index b6c1e6e777cf6c..324df45027dd32 100644
--- a/src/transformers/models/zamba/modeling_zamba.py
+++ b/src/transformers/models/zamba/modeling_zamba.py
@@ -880,7 +880,7 @@ def slow_forward(self, input_states, cache_params: HybridMambaAttentionDynamicCa
         hidden_states = self.act(self.conv1d(hidden_states)[..., :seq_len])     # (b d l)
         if not torch.all(attention_mask==1):
             hidden_states = hidden_states * attention_mask.unsqueeze(1)
-
+
         # 3. State Space Model sequence transformation
         # 3.a. Selection: [batch, seq_len, self.time_step_rank + self.ssm_state_size * 2]
         hidden_states = hidden_states.reshape(
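
For context, a minimal standalone sketch of the masking step this hunk sits next to, assuming `hidden_states` has shape `(batch, intermediate_size, seq_len)` after `self.conv1d` and `attention_mask` has shape `(batch, seq_len)`; the toy shapes and tensors below are illustrative, not taken from the model:

```python
import torch

# Toy shapes: (batch=2, channels=4, seq_len=5); the real slow_forward uses the
# conv1d output of shape (batch, intermediate_size, seq_len).
hidden_states = torch.randn(2, 4, 5)

# (batch, seq_len) padding mask: 1 marks real tokens, 0 marks padding.
attention_mask = torch.tensor([[1, 1, 1, 0, 0],
                               [1, 1, 1, 1, 1]])

# Same guard as the hunk: skip the multiply when no position is masked.
if not torch.all(attention_mask == 1):
    # unsqueeze(1) gives (batch, 1, seq_len), broadcasting the mask across
    # channels so conv outputs at padded positions are zeroed.
    hidden_states = hidden_states * attention_mask.unsqueeze(1)

# Padded positions of the first sequence are now exactly zero.
assert torch.all(hidden_states[0, :, 3:] == 0)
```

Zeroing padded positions here keeps the convolution's bleed-over from padding out of the subsequent SSM selection step; the diff itself only normalizes the blank line before that step.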