Commit

Fix tests
pglorio committed Aug 24, 2024
1 parent e51113d commit cf6ee16
Showing 1 changed file with 2 additions and 2 deletions.
src/transformers/models/zamba/modeling_zamba.py (4 changes: 2 additions & 2 deletions)
@@ -868,12 +868,12 @@ def slow_forward(self, input_states, cache_params: HybridMambaAttentionDynamicCa
                 hidden_states = self.act(hidden_states).to(dtype).unsqueeze(-1)  # (b d 1) : decoding
             else:
                 if attention_mask is not None and not torch.all(attention_mask == 1):
-                    hidden_states = hidden_states * attention_mask[:, -hidden_states.shape[-1]:].unsqueeze(1)
+                    hidden_states = hidden_states * attention_mask[:, -hidden_states.shape[-1] :].unsqueeze(1)
                 conv_state = nn.functional.pad(hidden_states, (self.conv_kernel_size - hidden_states.shape[-1], 0))
                 cache_params.conv_states[self.layer_idx] = conv_state
                 hidden_states = self.act(self.conv1d(hidden_states)[..., :seq_len])  # (b d l)
                 if attention_mask is not None and not torch.all(attention_mask == 1):
-                    hidden_states = hidden_states * attention_mask[:, -hidden_states.shape[-1]:].unsqueeze(1)
+                    hidden_states = hidden_states * attention_mask[:, -hidden_states.shape[-1] :].unsqueeze(1)
         else:
             ssm_state = torch.zeros(
                 (batch_size, self.n_mamba_heads, self.intermediate_size // self.n_mamba_heads, self.ssm_state_size),
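For context, here is a minimal, runnable sketch of the masking step this hunk touches (the diff itself only adjusts whitespace inside the slice, `[-1]:]` vs `[-1] :]`). The shapes and tensor values below are illustrative assumptions, not taken from modeling_zamba.py:

import torch

# Illustrative shapes; in the model these would correspond to (batch, intermediate_size, seq_len).
batch, dim, seq_len = 2, 4, 5
hidden_states = torch.randn(batch, dim, seq_len)  # (b d l)

# Left-padded batch: sequence 0 is full length, sequence 1 has two padding positions.
attention_mask = torch.tensor([[1, 1, 1, 1, 1],
                               [0, 0, 1, 1, 1]])

# Keep only the last seq_len mask columns (the mask can be longer than the current
# chunk when a cache holds a prefix), then broadcast over channels to zero padding.
masked = hidden_states * attention_mask[:, -hidden_states.shape[-1] :].unsqueeze(1)

assert masked[1, :, :2].abs().sum() == 0  # padded time steps are zeroed

The hunk applies this masking both before the causal conv1d (so padded values cannot leak into real positions) and after it (so padded positions stay zero going into the SSM scan).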
