diff --git a/src/transformers/models/zamba/configuration_zamba.py b/src/transformers/models/zamba/configuration_zamba.py
index dff48b51b48802..23bad4323f990b 100644
--- a/src/transformers/models/zamba/configuration_zamba.py
+++ b/src/transformers/models/zamba/configuration_zamba.py
@@ -50,6 +50,8 @@ class ZambaConfig(PretrainedConfig):
             Number of hidden layers in the model.
         num_attention_heads (`int`, *optional*, defaults to 16):
             Number of attention heads for each attention layer in the Transformer decoder.
+        attention_head_dim (`int`, *optional*):
+            Dimension of the attention head in the Transformer decoder.
         num_key_value_heads (`int`, *optional*):
             This is the number of key_value heads that should be used to implement Grouped Query Attention. If
             `num_key_value_heads=None`, the model will use Multi Head Attention (MHA), if
@@ -127,7 +129,7 @@ def __init__(
         intermediate_size=14848,
         num_hidden_layers=76,
         num_attention_heads=16,
-        attention_head_dim=None,
+        attention_head_dim=None,
         num_key_value_heads=None,
         n_mamba_heads=2,
         hidden_act="gelu",
@@ -193,8 +195,10 @@ def __init__(
         self.mamba_proj_bias = mamba_proj_bias

         self.layers_block_type = self._layers_block_type(num_hidden_layers, attn_layer_period, attn_layer_offset)
-
-        assert (self.mamba_expand * self.hidden_size) % self.n_mamba_heads == 0, '`intermediate_size` should be divisible by `n_mamba_heads`.'
+
+        assert (
+            self.mamba_expand * self.hidden_size
+        ) % self.n_mamba_heads == 0, "`intermediate_size` should be divisible by `n_mamba_heads`."

         super().__init__(
             pad_token_id=pad_token_id,
diff --git a/src/transformers/models/zamba/modeling_zamba.py b/src/transformers/models/zamba/modeling_zamba.py
index 44febe4f1cfa2e..7fb3682400e38d 100644
--- a/src/transformers/models/zamba/modeling_zamba.py
+++ b/src/transformers/models/zamba/modeling_zamba.py
@@ -610,17 +610,13 @@ def __init__(self, config: ZambaConfig, layer_idx):
             * 2
             / self.time_step_rank**0.5
         )
-        self.dt_proj_bias = nn.Parameter(
-            torch.zeros(self.n_mamba_heads, self.mamba_head_dim)
-        )
+        self.dt_proj_bias = nn.Parameter(torch.zeros(self.n_mamba_heads, self.mamba_head_dim))

         # S4D real initialization. These are not discretized!
         # The core is to load them, compute the discrete states, then write the updated state. Keeps the memory bounded
         A = torch.arange(1, self.ssm_state_size + 1, dtype=torch.float32)[None, :]
         A = A.expand(self.intermediate_size, -1).contiguous()
-        self.A_log = nn.Parameter(
-            torch.log(A).reshape(self.n_mamba_heads, self.mamba_head_dim, -1)
-        )
+        self.A_log = nn.Parameter(torch.log(A).reshape(self.n_mamba_heads, self.mamba_head_dim, -1))
         self.D = nn.Parameter(torch.ones(self.n_mamba_heads, self.mamba_head_dim))

         self.out_proj = nn.Linear(self.intermediate_size, self.hidden_size, bias=self.use_bias)
@@ -669,9 +665,7 @@ def cuda_kernels_forward(

             # 3. SSM sequence transformation
             # 3.a. input varying initialization of time_step, B and C
-            hidden_states = hidden_states.reshape(
-                -1, self.n_mamba_heads, self.mamba_head_dim, seq_len
-            ).transpose(0, 1)
+            hidden_states = hidden_states.reshape(-1, self.n_mamba_heads, self.mamba_head_dim, seq_len).transpose(0, 1)
             ssm_parameters = (self.x_proj_weight[:, None, :, :] @ hidden_states).transpose(-1, -2)

             time_step, B, C = torch.split(
@@ -787,9 +781,7 @@ def slow_forward(self, input_states, cache_params: HybridMambaAttentionDynamicCa

         # 3. State Space Model sequence transformation
         # 3.a. Selection: [batch, seq_len, self.time_step_rank + self.ssm_state_size * 2]
-        hidden_states = hidden_states.reshape(
-            -1, self.n_mamba_heads, self.mamba_head_dim, seq_len
-        ).transpose(0, 1)
+        hidden_states = hidden_states.reshape(-1, self.n_mamba_heads, self.mamba_head_dim, seq_len).transpose(0, 1)
         ssm_parameters = (self.x_proj_weight[:, None, :, :] @ hidden_states).transpose(-1, -2)

         time_step, B, C = torch.split(
@@ -809,9 +801,7 @@ def slow_forward(self, input_states, cache_params: HybridMambaAttentionDynamicCa
         # 3.c perform the recurrence y ← SSM(A, B, C)(x)
         scan_outputs = []
         for i in range(seq_len):
-            ssm_state = discrete_A[:, :, :, i, :].transpose(0, 1) * ssm_state + deltaB_u[:, :, :, i, :].transpose(
-                0, 1
-            )
+            ssm_state = discrete_A[:, :, :, i, :].transpose(0, 1) * ssm_state + deltaB_u[:, :, :, i, :].transpose(0, 1)
             scan_output = torch.matmul(ssm_state.transpose(0, 1).to(dtype), C[:, :, i, :].unsqueeze(-1))
             scan_outputs.append(scan_output[:, :, :, 0])
         scan_output = torch.stack(scan_outputs, dim=-1)
@@ -997,7 +987,9 @@ def forward(

 class HybridLayer(nn.Module):
-    def __init__(self, shared_transformer: ZambaAttentionDecoderLayer, linear: nn.Linear, mamba: ZambaMambaDecoderLayer):
+    def __init__(
+        self, shared_transformer: ZambaAttentionDecoderLayer, linear: nn.Linear, mamba: ZambaMambaDecoderLayer
+    ):
         super().__init__()
         self.shared_transformer = shared_transformer
         self.linear = linear
@@ -1133,7 +1125,7 @@ def _init_weights(self, module):
         with torch.no_grad():
             module.dt_proj_bias.copy_(inv_dt)
             module.dt_proj_bias._no_reinit = True
-
+
     @classmethod
     @classmethod
     def _check_and_enable_flash_attn_2(
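For context on what the reformatted scan loop in `slow_forward` computes, here is a minimal, self-contained sketch of the selective-scan recurrence y ← SSM(A, B, C)(x), reduced to a single mamba head with illustrative shapes. All tensor names and sizes below are assumptions made for this example, not code from the PR; in the diff, the head dimension is carried by the `reshape(-1, n_mamba_heads, mamba_head_dim, seq_len).transpose(0, 1)` so that each head gets its own slice of `A_log`, `D`, and the `dt_proj` parameters.

import torch

batch, seq_len, head_dim, state_size = 2, 5, 4, 3  # illustrative sizes only

A = -torch.rand(head_dim, state_size)         # continuous-time A; negative real, as in the S4D init above
delta = torch.rand(batch, seq_len, head_dim)  # input-dependent step sizes (the role of time_step)
B = torch.rand(batch, seq_len, state_size)    # input-dependent input projection
C = torch.rand(batch, seq_len, state_size)    # input-dependent output projection
x = torch.rand(batch, seq_len, head_dim)      # the per-head hidden states being scanned

# Zero-order-hold discretization: A_bar = exp(delta * A), B_bar * x = delta * B * x.
discrete_A = torch.exp(delta.unsqueeze(-1) * A)                    # [batch, seq, head_dim, state]
deltaB_x = delta.unsqueeze(-1) * B.unsqueeze(2) * x.unsqueeze(-1)  # [batch, seq, head_dim, state]

ssm_state = torch.zeros(batch, head_dim, state_size)
outputs = []
for t in range(seq_len):
    # h_t = A_bar_t * h_{t-1} + B_bar_t * x_t  (elementwise over the hidden state)
    ssm_state = discrete_A[:, t] * ssm_state + deltaB_x[:, t]
    # y_t = C_t · h_t, contracting the state dimension
    outputs.append(torch.einsum("bds,bs->bd", ssm_state, C[:, t]))
y = torch.stack(outputs, dim=1)  # [batch, seq_len, head_dim]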