How best to implement a differential transformer? #567

Open
Wilsontomass opened this issue Oct 22, 2024 · 2 comments
Comments

Wilsontomass commented Oct 22, 2024

I'm not sure Issues is the best place to post this, but I wanted to see if anyone else has been trying this idea:

There was a paper that came out recently proposing a new attention head architecture (the Differential Transformer), and I wanted to see if I could replicate the results, which according to the paper are very promising. It didn't seem too hard given what I knew from messing around with this repo. The authors provide three versions of the code here, and to keep things simple I tried to use this implementation here. I added rotary positional encoding separately first and tested that; it worked well. Then I added the differential mechanism.
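For context, the core operation from the paper, as I understand it, is roughly

$$\operatorname{DiffAttn}(X) = \Big(\operatorname{softmax}\big(\tfrac{Q_1 K_1^\top}{\sqrt{d}}\big) - \lambda\,\operatorname{softmax}\big(\tfrac{Q_2 K_2^\top}{\sqrt{d}}\big)\Big) V, \qquad \lambda = e^{\lambda_{q_1} \cdot \lambda_{k_1}} - e^{\lambda_{q_2} \cdot \lambda_{k_2}} + \lambda_{\text{init}}$$

i.e. two softmax attention maps sharing the same values, with the second map subtracted off using a learned weight λ. My code looks like this: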

class CausalSelfAttention(nn.Module):
    def __init__(self, config, depth):
        super().__init__()
        assert config.n_embd % config.n_head == 0
        self.q_proj = nn.Linear(config.n_embd, config.n_embd, bias=False)
        self.k_proj = nn.Linear(config.n_embd, config.n_embd, bias=False)
        self.v_proj = nn.Linear(config.n_embd, config.n_embd, bias=False)
        # output projection
        self.c_proj = nn.Linear(config.n_embd, config.n_embd, bias=config.bias)
        # regularization
        self.attn_dropout = nn.Dropout(config.dropout)
        self.resid_dropout = nn.Dropout(config.dropout)
        self.n_head = config.n_head // 2  # div by 2 because each head is larger, so we only have half as many
        self.n_embd = config.n_embd
        self.dropout = config.dropout
        
        self.head_dim = self.n_embd // self.n_head // 2 # div by 2 because double key and query
        self.rotary_emb = RotaryEmbedding(dim=self.head_dim, max_position_embeddings=config.block_size)  # Added line
        
        self.lambda_init = lambda_init_fn(depth)
        self.lambda_q1 = nn.Parameter(torch.zeros(self.head_dim, dtype=torch.float32).normal_(mean=0,std=0.1))
        self.lambda_k1 = nn.Parameter(torch.zeros(self.head_dim, dtype=torch.float32).normal_(mean=0,std=0.1))
        self.lambda_q2 = nn.Parameter(torch.zeros(self.head_dim, dtype=torch.float32).normal_(mean=0,std=0.1))
        self.lambda_k2 = nn.Parameter(torch.zeros(self.head_dim, dtype=torch.float32).normal_(mean=0,std=0.1))
        self.subln = RMSNorm(2 * self.head_dim, eps=1e-5, elementwise_affine=False)
        # flash attention make GPU go brrrrr but support is only in PyTorch >= 2.0
        self.flash = hasattr(torch.nn.functional, 'scaled_dot_product_attention')
        if not self.flash:
            print("WARNING: using slow attention. Flash Attention requires PyTorch >= 2.0")
            # causal mask to ensure that attention is only applied to the left in the input sequence
            self.register_buffer("bias", torch.tril(torch.ones(config.block_size, config.block_size))
                                        .view(1, 1, config.block_size, config.block_size))
    def forward(self, x):
        B, T, C = x.size() # batch size, sequence length, embedding dimensionality (n_embd)
        # calculate query, key, values for all heads in batch and move head forward to be the batch dim
        q = self.q_proj(x)
        k = self.k_proj(x)
        v = self.v_proj(x)
        k = k.view(B, T, self.n_head*2, self.head_dim)  # (B, T, nh, hs)
        q = q.view(B, T, self.n_head*2, self.head_dim)  # (B, T, nh, hs)
        v = v.view(B, T, self.n_head, 2, self.head_dim)  # (B, T, nh, hs)
        # Apply rotary embeddings to q and k
        cos, sin = self.rotary_emb(q, seq_len=T)
        q = apply_rotary_pos_emb(q, cos, sin)
        k = apply_rotary_pos_emb(k, cos, sin)
        q = q.reshape(B, T, self.n_head, 2, self.head_dim)
        k = k.reshape(B, T, self.n_head, 2, self.head_dim)
        q1, q2 = q[:, :, :, 0], q[:, :, :, 1]
        k1, k2 = k[:, :, :, 0], k[:, :, :, 1]
        v1, v2 = v[:, :, :, 0], v[:, :, :, 1]
        attn11 = F.scaled_dot_product_attention(q1, k1, v1, attn_mask=None, is_causal=True)
        attn12 = F.scaled_dot_product_attention(q1, k1, v2, attn_mask=None, is_causal=True)
        attn1 = torch.cat([attn11, attn12], dim=-1)
        attn21 = F.scaled_dot_product_attention(q2, k2, v1, attn_mask=None, is_causal=True)
        attn22 = F.scaled_dot_product_attention(q2, k2, v2, attn_mask=None, is_causal=True)
        attn2 = torch.cat([attn21, attn22], dim=-1)
        
        lambda_1 = torch.exp(torch.sum(self.lambda_q1 * self.lambda_k1, dim=-1).float()).type_as(q)
        lambda_2 = torch.exp(torch.sum(self.lambda_q2 * self.lambda_k2, dim=-1).float()).type_as(q)
        lambda_full = lambda_1 - lambda_2 + self.lambda_init
        attn = attn1 - lambda_full * attn2
        attn = self.subln(attn)
        attn = attn * (1 - self.lambda_init)
        attn = attn.reshape(B, T, C)
        # output projection
        y = self.resid_dropout(self.c_proj(attn))
        return y
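
(lambda_init_fn and RMSNorm aren't shown above; I took them from the authors' code. Roughly they look like the sketch below, though the authors' versions may differ in the details.)

import math
import torch
import torch.nn as nn

def lambda_init_fn(depth):
    # per-layer initialization of lambda from the paper:
    # lambda_init = 0.8 - 0.6 * exp(-0.3 * depth), with depth the 0-indexed layer
    return 0.8 - 0.6 * math.exp(-0.3 * depth)

class RMSNorm(nn.Module):
    # minimal RMSNorm; used above without a learnable scale (elementwise_affine=False)
    def __init__(self, dim, eps=1e-5, elementwise_affine=False):
        super().__init__()
        self.eps = eps
        self.weight = nn.Parameter(torch.ones(dim)) if elementwise_affine else None

    def forward(self, x):
        out = x * torch.rsqrt(x.pow(2).mean(dim=-1, keepdim=True) + self.eps)
        return out * self.weight if self.weight is not None else out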

When I try to train this model it understandably runs at a lower iterations/sec, but if we look at the loss per iteration it seems to be getting stuck. (In each iteration I have kept the total batch size the same as in the gpt2-124M-RoPE run.)

(screenshot of the loss-per-iteration curves)

Any ideas on what I've gotten wrong? I'm no ML expert.

@karpathy on the off chance that you see this: have you read the Diff Transformer paper, and if so, what do you think about it?

notlober commented Nov 6, 2024

I've done a similar experiment too. First of all, I recommend looking at the non-flash implementation from diff attn, "multihead_diffattn.py", and not "multihead_flashdiff_1.py".

Secondly, you're dividing both n_head and head_dim by 2, which is an issue; that does not appear in the original code (see the quick shape check below).
First here: self.n_head = config.n_head // 2
and here: self.head_dim = self.n_embd // self.n_head // 2
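
Just to make that concrete, here is what your constructor ends up with, assuming gpt2-124M sizes (n_embd=768, n_head=12) purely for illustration:

n_embd, n_head = 768, 12                 # gpt2-124M sizes, just for illustration
diff_n_head = n_head // 2                # 6   <- first division by 2
head_dim = n_embd // diff_n_head // 2    # 64  <- second division by 2
# q and k are then viewed as (B, T, diff_n_head * 2, head_dim) = (B, T, 12, 64)
# v is viewed as             (B, T, diff_n_head, 2, head_dim)  = (B, T, 6, 2, 64)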

Lastly, even though standard gpt2 with RoPE went well, I recommend starting with the non-RoPE version, since it's easier to begin with.

And F.scaled_dot_product_attention is probably a bit different than flash attention internally.

notlober commented Nov 6, 2024

My implementation:

class CausalSelfAttention(nn.Module):

    def __init__(self, config):
        super().__init__()
        assert config.n_embd % config.n_head == 0
        # key, query, value projections for all heads, but in a batch
        self.c_attn = nn.Linear(config.n_embd, 3 * config.n_embd, bias=config.bias)
        # output projection
        self.c_proj = nn.Linear(config.n_embd, config.n_embd, bias=config.bias)
        # regularization
        self.attn_dropout = nn.Dropout(config.dropout)
        self.resid_dropout = nn.Dropout(config.dropout)
        self.n_head = config.n_head
        self.n_embd = config.n_embd
        self.dropout = config.dropout
        self.head_dim = self.n_embd // self.n_head // 2
        # causal mask to ensure that attention is only applied to the left in the input sequence
        self.register_buffer("bias", torch.tril(torch.ones(config.block_size, config.block_size))
                                    .view(1, 1, config.block_size, config.block_size))
        self.lambda_init = lambda_init_fn(config.n_layer)
        self.lambda_q1 = nn.Parameter(torch.zeros(self.head_dim, dtype=torch.float32).normal_(mean=0,std=0.1))
        self.lambda_k1 = nn.Parameter(torch.zeros(self.head_dim, dtype=torch.float32).normal_(mean=0,std=0.1))
        self.lambda_q2 = nn.Parameter(torch.zeros(self.head_dim, dtype=torch.float32).normal_(mean=0,std=0.1))
        self.lambda_k2 = nn.Parameter(torch.zeros(self.head_dim, dtype=torch.float32).normal_(mean=0,std=0.1))

        self.subln = nn.LayerNorm(2 * self.head_dim, eps=1e-5, elementwise_affine=True)
    
    def forward(self, x):
        B, T, C = x.size() # batch size, sequence length, embedding dimensionality (n_embd)

        # calculate query, key, values for all heads in batch and move head forward to be the batch dim
        q, k, v  = self.c_attn(x).split(self.n_embd, dim=2)
        k = k.view(B, T, 2*self.n_head, self.head_dim).transpose(1, 2) # (B, nh, T, hs)
        q = q.view(B, T, 2*self.n_head, self.head_dim).transpose(1, 2) # (B, nh, T, hs)
        v = v.view(B, T, self.n_head, 2*self.head_dim).transpose(1, 2) # (B, nh, T, hs)

        # manual implementation of attention
        att = (q @ k.transpose(-2, -1)) * (1.0 / math.sqrt(k.size(-1)))
        att = att.masked_fill(self.bias[:,:,:T,:T] == 0, float('-inf'))
        att = F.softmax(att, dim=-1)

        lambda_1 = torch.exp(torch.sum(self.lambda_q1 * self.lambda_k1, dim=-1))
        lambda_2 = torch.exp(torch.sum(self.lambda_q2 * self.lambda_k2, dim=-1))
        lambda_full = lambda_1 - lambda_2 + self.lambda_init
        lambda_full = lambda_full.view(1, self.n_head, 1, 1) 

        att = att.view(B, self.n_head, 2, T, T)
        attn1 = att[:, :, 0, :, :]  # (B, nh, T, T)
        attn2 = att[:, :, 1, :, :]  # (B, nh, T, T)

        attn_weights = attn1 - lambda_full * attn2  # (B, nh, T, T)
        
        att = self.attn_dropout(attn_weights)
        att = att @ v # (B, nh, T, T) x (B, nh, T, hs) -> (B, nh, T, hs)
        att = self.subln(att)
        att = att * (1 - self.lambda_init)
        y = att.permute(0, 2, 1, 3).contiguous().view(B, T, C)  # (B, T, n_embd)
        # output projection
        y = self.resid_dropout(self.c_proj(y))
        return y
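
A quick way to sanity check the shapes (this assumes torch, torch.nn.functional as F, math and lambda_init_fn are already in scope, and uses a hypothetical minimal Config purely for illustration):

from dataclasses import dataclass

@dataclass
class Config:
    n_embd: int = 768
    n_head: int = 12
    n_layer: int = 12
    block_size: int = 1024
    dropout: float = 0.0
    bias: bool = False

attn = CausalSelfAttention(Config())
x = torch.randn(2, 16, 768)   # (B, T, C)
y = attn(x)
print(y.shape)                # expect torch.Size([2, 16, 768])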
