
Commit

Batched inference
pglorio committed Aug 14, 2024
1 parent 435a119 commit e6c6278
Showing 1 changed file with 21 additions and 9 deletions.
30 changes: 21 additions & 9 deletions src/transformers/models/zamba/modeling_zamba.py
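In short, this commit threads the 2D padding `attention_mask` through the Zamba Mamba mixer (`cuda_kernels_forward`, `slow_forward`, and `forward`) so that batched, padded inputs do not leak into the convolution state or the SSM scan. Below is a minimal sketch of the masking idea, assuming the mixer's channels-first layout; the helper name, the `None` guard, and the shapes are illustrative, not part of the commit:

```python
import torch

def mask_padded_positions(hidden_states: torch.Tensor, attention_mask: torch.Tensor) -> torch.Tensor:
    """Zero out padded positions of a (batch, channels, seq_len) tensor.

    attention_mask is the usual (batch, seq_len) padding mask with 1 for real tokens.
    """
    if attention_mask is not None and not torch.all(attention_mask == 1):
        # Broadcast the mask over the channel dimension; padded positions become 0.
        hidden_states = hidden_states * attention_mask.unsqueeze(1)
    return hidden_states
```

The diff applies this both before the causal convolution, so padding never enters the conv state cached in `cache_params.conv_states`, and again after it, since the conv bias and activation can otherwise reintroduce nonzero values at padded positions.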
@@ -729,7 +729,7 @@ def __init__(self, config: ZambaConfig, layer_idx):
" https://github.com/Dao-AILab/causal-conv1d. If you want to use the naive implementation, set `use_mamba_kernels=False` in the model config"
)

def cuda_kernels_forward(self, hidden_states: torch.Tensor, cache_params: HybridMambaAttentionDynamicCache = None):
def cuda_kernels_forward(self, hidden_states: torch.Tensor, cache_params: HybridMambaAttentionDynamicCache = None, attention_mask = None):
batch_size, seq_len, _ = hidden_states.shape
use_precomputed_states = cache_params is not None and cache_params.has_previous_state and seq_len == 1

@@ -753,10 +753,14 @@ def cuda_kernels_forward(self, hidden_states: torch.Tensor, cache_params: Hybrid
)
hidden_states = hidden_states.unsqueeze(-1)
else:
if not torch.all(attention_mask==1):
hidden_states = hidden_states * attention_mask.unsqueeze(1)
if cache_params is not None:
conv_states = nn.functional.pad(hidden_states, (self.conv_kernel_size - hidden_states.shape[-1], 0))
cache_params.conv_states[self.layer_idx].copy_(conv_states)
hidden_states = causal_conv1d_fn(hidden_states, conv_weights, self.conv1d.bias, activation=self.activation)
if not torch.all(attention_mask==1):
hidden_states = hidden_states * attention_mask.unsqueeze(1)

# 3. State Space Model sequence transformation
# 3.a. input varying initialization of time_step, B and C
@@ -822,7 +826,7 @@ def cuda_kernels_forward(self, hidden_states: torch.Tensor, cache_params: Hybrid
contextualized_states = self.out_proj(scan_outputs.transpose(1, 2))
return contextualized_states

def slow_forward(self, input_states, cache_params: HybridMambaAttentionDynamicCache = None):
def slow_forward(self, input_states, cache_params: HybridMambaAttentionDynamicCache = None, attention_mask = None):
batch_size, seq_len, _ = input_states.shape
dtype = input_states.dtype
# 1. Gated linear projection
@@ -858,17 +862,25 @@ def slow_forward(self, input_states, cache_params: HybridMambaAttentionDynamicCa
hidden_states += self.conv1d.bias
hidden_states = self.act(hidden_states).to(dtype).unsqueeze(-1) # (b d 1) : decoding
else:
if not torch.all(attention_mask==1):
hidden_states = hidden_states * attention_mask.unsqueeze(1)
conv_state = nn.functional.pad(hidden_states, (self.conv_kernel_size - hidden_states.shape[-1], 0))
cache_params.conv_states[self.layer_idx] = conv_state
hidden_states = self.act(self.conv1d(hidden_states)[..., :seq_len]) # (b d l)
if not torch.all(attention_mask==1):
hidden_states = hidden_states * attention_mask.unsqueeze(1)
else:
ssm_state = torch.zeros(
(batch_size, self.n_mamba_heads, self.intermediate_size // self.n_mamba_heads, self.ssm_state_size),
device=hidden_states.device,
dtype=dtype,
) # (b h d l)
if not torch.all(attention_mask==1):
hidden_states = hidden_states * attention_mask.unsqueeze(1)
hidden_states = self.act(self.conv1d(hidden_states)[..., :seq_len]) # (b d l)

if not torch.all(attention_mask==1):
hidden_states = hidden_states * attention_mask.unsqueeze(1)

# 3. State Space Model sequence transformation
# 3.a. Selection: [batch, seq_len, self.time_step_rank + self.ssm_state_size * 2]
hidden_states = hidden_states.reshape(
@@ -879,7 +891,6 @@ def slow_forward(self, input_states, cache_params: HybridMambaAttentionDynamicCa
time_step, B, C = torch.split(
ssm_parameters, [self.time_step_rank, self.ssm_state_size, self.ssm_state_size], dim=-1
) # (h b l d)

discrete_time_step = (self.dt_proj_weight[:, None] @ time_step.transpose(-1, -2)) + self.dt_proj_bias[
:, None, :, None
] # (h b d l)
@@ -912,16 +923,16 @@ def slow_forward(self, input_states, cache_params: HybridMambaAttentionDynamicCa
) # (b l d)
return contextualized_states

def forward(self, hidden_states, cache_params: HybridMambaAttentionDynamicCache = None):
def forward(self, hidden_states, cache_params: HybridMambaAttentionDynamicCache = None, attention_mask = None):
if self.use_fast_kernels:
if not is_fast_path_available or "cuda" not in self.x_proj_weight.device.type:
raise ValueError(
"Fast Mamba kernels are not available. Make sure to they are installed and that "
"the mamba module is on a CUDA device. lease run 'pip install causal-conv1d>=1.2.0' "
"and 'pip install mamba-ssm', or set use_fast_kernels=False in the model's config."
)
return self.cuda_kernels_forward(hidden_states, cache_params)
return self.slow_forward(hidden_states, cache_params)
return self.cuda_kernels_forward(hidden_states, cache_params, attention_mask=attention_mask)
return self.slow_forward(hidden_states, cache_params, attention_mask=attention_mask)


class ZambaMLP(nn.Module):
@@ -1056,6 +1067,7 @@ def forward(
hidden_states = self.mamba(
hidden_states=hidden_states,
cache_params=past_key_value,
attention_mask=attention_mask,
)

self_attn_weights = None
@@ -1342,7 +1354,7 @@ def forward(
next(mamba_layers).__call__,
hidden_states,
transformer_hidden_states,
causal_mask,
attention_mask,
position_ids,
past_key_values,
output_attentions,
@@ -1353,7 +1365,7 @@ def forward(
layer_outputs = next(mamba_layers)(
hidden_states,
transformer_hidden_states=transformer_hidden_states,
attention_mask=causal_mask,
attention_mask=attention_mask,
position_ids=position_ids,
past_key_value=past_key_values,
output_attentions=output_attentions,
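Note that the last two hunks swap the 4D `causal_mask` for the original 2D `attention_mask` when calling the mamba layers, since the mixer needs the padding mask rather than the attention bias. For context, here is a hedged sketch of the batched generation this change is meant to support: left-pad the prompts and pass the padding mask along, so the Mamba layers can zero padded positions around the convolution. The checkpoint id and generation settings are placeholders for illustration, not taken from the commit:

```python
from transformers import AutoModelForCausalLM, AutoTokenizer

model_id = "Zyphra/Zamba-7B-v1"  # placeholder checkpoint id, assumed for illustration
tokenizer = AutoTokenizer.from_pretrained(model_id, padding_side="left")
if tokenizer.pad_token is None:
    tokenizer.pad_token = tokenizer.eos_token  # define a pad token if the tokenizer lacks one
model = AutoModelForCausalLM.from_pretrained(model_id)

prompts = ["The capital of France is", "Mamba layers process sequences by"]
inputs = tokenizer(prompts, return_tensors="pt", padding=True)

# attention_mask marks the left padding; with this change it is also forwarded
# to the Mamba mixer, not just the attention layers.
outputs = model.generate(**inputs, max_new_tokens=20)
print(tokenizer.batch_decode(outputs, skip_special_tokens=True))
```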
