Fix circle/ci tests
pglorio committed Sep 24, 2024
1 parent 1e4ffe6 commit 2c53db2
Showing 2 changed files with 16 additions and 20 deletions.
10 changes: 7 additions & 3 deletions src/transformers/models/zamba/configuration_zamba.py
@@ -50,6 +50,8 @@ class ZambaConfig(PretrainedConfig):
             Number of hidden layers in the model.
         num_attention_heads (`int`, *optional*, defaults to 16):
             Number of attention heads for each attention layer in the Transformer decoder.
+        attention_head_dim (`int`, *optional*):
+            Dimension of the attention head in the Transformer decoder.
         num_key_value_heads (`int`, *optional*):
             This is the number of key_value heads that should be used to implement Grouped Query Attention. If
             `num_key_value_heads=None`, the model will use Multi Head Attention (MHA), if
@@ -127,7 +129,7 @@ def __init__(
         intermediate_size=14848,
         num_hidden_layers=76,
         num_attention_heads=16,
-        attention_head_dim=None,
+        attention_head_dim=None,
         num_key_value_heads=None,
         n_mamba_heads=2,
         hidden_act="gelu",
@@ -193,8 +195,10 @@ def __init__(
         self.mamba_proj_bias = mamba_proj_bias

         self.layers_block_type = self._layers_block_type(num_hidden_layers, attn_layer_period, attn_layer_offset)
-
-        assert (self.mamba_expand * self.hidden_size) % self.n_mamba_heads == 0, '`intermediate_size` should be divisible by `n_mamba_heads`.'
+
+        assert (
+            self.mamba_expand * self.hidden_size
+        ) % self.n_mamba_heads == 0, "`intermediate_size` should be divisible by `n_mamba_heads`."

         super().__init__(
             pad_token_id=pad_token_id,
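The reformatted assert keeps the constraint that the expanded Mamba channel dimension splits evenly across heads. Below is a minimal standalone sketch of that check; the helper name and the example values (`hidden_size`, `mamba_expand`) are illustrative assumptions, not values taken from the config.

# Hypothetical helper illustrating the divisibility constraint enforced by the
# assert above; names and example values are assumptions for illustration only.
def check_mamba_head_split(hidden_size: int, mamba_expand: int, n_mamba_heads: int) -> int:
    intermediate_size = mamba_expand * hidden_size
    if intermediate_size % n_mamba_heads != 0:
        raise ValueError("`intermediate_size` should be divisible by `n_mamba_heads`.")
    # Each Mamba head then owns an equal slice of the expanded channel dimension.
    return intermediate_size // n_mamba_heads


# 2 heads split a 2 * 1024 = 2048-wide intermediate dimension into slices of 1024.
print(check_mamba_head_split(hidden_size=1024, mamba_expand=2, n_mamba_heads=2))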
26 changes: 9 additions & 17 deletions src/transformers/models/zamba/modeling_zamba.py
@@ -610,17 +610,13 @@ def __init__(self, config: ZambaConfig, layer_idx):
             * 2
             / self.time_step_rank**0.5
         )
-        self.dt_proj_bias = nn.Parameter(
-            torch.zeros(self.n_mamba_heads, self.mamba_head_dim)
-        )
+        self.dt_proj_bias = nn.Parameter(torch.zeros(self.n_mamba_heads, self.mamba_head_dim))

         # S4D real initialization. These are not discretized!
         # The core is to load them, compute the discrete states, then write the updated state. Keeps the memory bounded
         A = torch.arange(1, self.ssm_state_size + 1, dtype=torch.float32)[None, :]
         A = A.expand(self.intermediate_size, -1).contiguous()
-        self.A_log = nn.Parameter(
-            torch.log(A).reshape(self.n_mamba_heads, self.mamba_head_dim, -1)
-        )
+        self.A_log = nn.Parameter(torch.log(A).reshape(self.n_mamba_heads, self.mamba_head_dim, -1))
         self.D = nn.Parameter(torch.ones(self.n_mamba_heads, self.mamba_head_dim))
         self.out_proj = nn.Linear(self.intermediate_size, self.hidden_size, bias=self.use_bias)

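The head-wise reshape in the `A_log` initialization above can be checked in isolation. The sizes below are illustrative assumptions, chosen only so that `intermediate_size = n_mamba_heads * mamba_head_dim`, which is what the reshape in the diff requires.

import torch

# Illustrative sizes (assumptions, not the model's real configuration).
n_mamba_heads, mamba_head_dim, ssm_state_size = 2, 8, 16
intermediate_size = n_mamba_heads * mamba_head_dim

# Same S4D-style construction as the hunk above, outside the module.
A = torch.arange(1, ssm_state_size + 1, dtype=torch.float32)[None, :]
A = A.expand(intermediate_size, -1).contiguous()
A_log = torch.log(A).reshape(n_mamba_heads, mamba_head_dim, -1)
print(A_log.shape)  # torch.Size([2, 8, 16]): one (head_dim, state_size) block per Mamba head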
@@ -669,9 +665,7 @@ def cuda_kernels_forward(
         # 3. SSM sequence transformation
         # 3.a. input varying initialization of time_step, B and C

-        hidden_states = hidden_states.reshape(
-            -1, self.n_mamba_heads, self.mamba_head_dim, seq_len
-        ).transpose(0, 1)
+        hidden_states = hidden_states.reshape(-1, self.n_mamba_heads, self.mamba_head_dim, seq_len).transpose(0, 1)
         ssm_parameters = (self.x_proj_weight[:, None, :, :] @ hidden_states).transpose(-1, -2)

         time_step, B, C = torch.split(
@@ -787,9 +781,7 @@ def slow_forward(self, input_states, cache_params: HybridMambaAttentionDynamicCa

         # 3. State Space Model sequence transformation
         # 3.a. Selection: [batch, seq_len, self.time_step_rank + self.ssm_state_size * 2]
-        hidden_states = hidden_states.reshape(
-            -1, self.n_mamba_heads, self.mamba_head_dim, seq_len
-        ).transpose(0, 1)
+        hidden_states = hidden_states.reshape(-1, self.n_mamba_heads, self.mamba_head_dim, seq_len).transpose(0, 1)
         ssm_parameters = (self.x_proj_weight[:, None, :, :] @ hidden_states).transpose(-1, -2)

         time_step, B, C = torch.split(
@@ -809,9 +801,7 @@ def slow_forward(self, input_states, cache_params: HybridMambaAttentionDynamicCa
         # 3.c perform the recurrence y ← SSM(A, B, C)(x)
         scan_outputs = []
         for i in range(seq_len):
-            ssm_state = discrete_A[:, :, :, i, :].transpose(0, 1) * ssm_state + deltaB_u[:, :, :, i, :].transpose(
-                0, 1
-            )
+            ssm_state = discrete_A[:, :, :, i, :].transpose(0, 1) * ssm_state + deltaB_u[:, :, :, i, :].transpose(0, 1)
             scan_output = torch.matmul(ssm_state.transpose(0, 1).to(dtype), C[:, :, i, :].unsqueeze(-1))
             scan_outputs.append(scan_output[:, :, :, 0])
         scan_output = torch.stack(scan_outputs, dim=-1)
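The single-line update above is the sequential selective-scan recurrence, and it can be exercised on its own with random tensors. A shape-only sketch follows; the sizes and tensor layouts are assumptions inferred from the indexing in the diff, not the module's actual buffers.

import torch

# Illustrative sizes (assumptions); the state is carried head-first as
# (n_heads, batch, head_dim, state_size), matching the transpose(0, 1) calls above.
batch, n_heads, head_dim, seq_len, state_size = 2, 2, 8, 5, 16
dtype = torch.float32

discrete_A = torch.rand(batch, n_heads, head_dim, seq_len, state_size)
deltaB_u = torch.rand(batch, n_heads, head_dim, seq_len, state_size)
C = torch.rand(batch, n_heads, seq_len, state_size)
ssm_state = torch.zeros(n_heads, batch, head_dim, state_size)

scan_outputs = []
for i in range(seq_len):
    # Advance the running state one timestep, then read it out against C.
    ssm_state = discrete_A[:, :, :, i, :].transpose(0, 1) * ssm_state + deltaB_u[:, :, :, i, :].transpose(0, 1)
    scan_output = torch.matmul(ssm_state.transpose(0, 1).to(dtype), C[:, :, i, :].unsqueeze(-1))
    scan_outputs.append(scan_output[:, :, :, 0])

scan_output = torch.stack(scan_outputs, dim=-1)
print(scan_output.shape)  # torch.Size([2, 2, 8, 5]): (batch, n_heads, head_dim, seq_len)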
@@ -997,7 +987,9 @@ def forward(


 class HybridLayer(nn.Module):
-    def __init__(self, shared_transformer: ZambaAttentionDecoderLayer, linear: nn.Linear, mamba: ZambaMambaDecoderLayer):
+    def __init__(
+        self, shared_transformer: ZambaAttentionDecoderLayer, linear: nn.Linear, mamba: ZambaMambaDecoderLayer
+    ):
         super().__init__()
         self.shared_transformer = shared_transformer
         self.linear = linear
@@ -1133,7 +1125,7 @@ def _init_weights(self, module):
             with torch.no_grad():
                 module.dt_proj_bias.copy_(inv_dt)
             module.dt_proj_bias._no_reinit = True

-    @classmethod
+    @classmethod
     def _check_and_enable_flash_attn_2(
