
Is in_dims unnecessary in the EfficientAdditiveAttnetion module? #17

plutolove233 opened this issue Aug 6, 2024 · 0 comments
Hello authors, your work is absolutely amazing! Thank you so much for making the results and the code publicly available.
However, I found a small issue with the EfficientAdditiveAttnetion module while migrating it into my own research code:

class EfficientAdditiveAttnetion(nn.Module):
    """
    Efficient Additive Attention module for SwiftFormer.
    Input: tensor in shape [B, N, D]
    Output: tensor in shape [B, N, D]
    """

    def __init__(self, in_dims=20, token_dim=768, num_heads=2):
        super().__init__()

        self.to_query = nn.Linear(in_dims, token_dim * num_heads)
        self.to_key = nn.Linear(in_dims, token_dim * num_heads)

        ...

If the input tensor has shape B×N×D, where B, N, and D correspond to the batch size, in_dims, and token_dim respectively, then shouldn't the input dimension of the linear transformation layers be token_dim rather than in_dims? Otherwise the matrix multiplication inside nn.Linear cannot be performed.

I did a little test as follows:

import torch
from torch import nn
import numpy as np
import einops
class EfficientAdditiveAttnetion(nn.Module):

    def __init__(self, in_dims=20, token_dim=768, num_heads=2):
        super().__init__()

        self.to_query = nn.Linear(in_dims, token_dim * num_heads)
        self.to_key = nn.Linear(in_dims, token_dim * num_heads)

        self.w_g = nn.Parameter(torch.randn(token_dim * num_heads, 1))
        self.scale_factor = token_dim ** -0.5
        self.Proj = nn.Linear(token_dim * num_heads, token_dim * num_heads)
        self.final = nn.Linear(token_dim * num_heads, token_dim)

    def forward(self, x):
        query = self.to_query(x)
        key = self.to_key(x)

        query = torch.nn.functional.normalize(query, dim=-1) #BxNxD
        key = torch.nn.functional.normalize(key, dim=-1) #BxNxD

        query_weight = query @ self.w_g # BxNx1 (BxNxD @ Dx1)
        A = query_weight * self.scale_factor # BxNx1

        A = torch.nn.functional.normalize(A, dim=1) # BxNx1

        G = torch.sum(A * query, dim=1) # BxD

        G = einops.repeat(
            G, "b d -> b repeat d", repeat=key.shape[1]
        ) # BxNxD

        out = self.Proj(G * key) + query #BxNxD

        out = self.final(out) # BxNxD

        return out

if __name__ == '__main__':
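    # Note: N = 2 is passed as in_dims and D = 3 as token_dim, so to_query
    # expects a last dimension of 2 while x actually has last dimension 3.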
    model = EfficientAdditiveAttnetion(in_dims=2, token_dim=3)
    x = np.random.randn(4, 2, 3) # B=4, N=2, D=3
    print(x.shape)
    x = torch.tensor(x, dtype=torch.float32)
    res = model(x)
    print(res.shape)

After executing the above code, the interpreter reports a shape-mismatch error raised by the first nn.Linear layer:

[screenshot of the error traceback]

Am I misunderstanding something? I hope you can help me out. 🙏
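
For what it's worth, here is how I would expect the layer definitions to look so that the module runs on a [B, N, D] input. This is only a minimal sketch under my assumption that the linear layers should operate on the channel dimension D = token_dim rather than on the token count; the class name EfficientAdditiveAttnetionFixed is just my own label for the modified version:

import torch
from torch import nn
import einops

class EfficientAdditiveAttnetionFixed(nn.Module):
    """Same module, but the linear layers take token_dim (the channel size D) as input."""

    def __init__(self, token_dim=768, num_heads=2):
        super().__init__()
        self.to_query = nn.Linear(token_dim, token_dim * num_heads)
        self.to_key = nn.Linear(token_dim, token_dim * num_heads)
        self.w_g = nn.Parameter(torch.randn(token_dim * num_heads, 1))
        self.scale_factor = token_dim ** -0.5
        self.Proj = nn.Linear(token_dim * num_heads, token_dim * num_heads)
        self.final = nn.Linear(token_dim * num_heads, token_dim)

    def forward(self, x):
        query = torch.nn.functional.normalize(self.to_query(x), dim=-1)  # BxNx(D*heads)
        key = torch.nn.functional.normalize(self.to_key(x), dim=-1)      # BxNx(D*heads)
        A = torch.nn.functional.normalize(query @ self.w_g * self.scale_factor, dim=1)  # BxNx1
        G = torch.sum(A * query, dim=1)                                   # Bx(D*heads)
        G = einops.repeat(G, "b d -> b n d", n=key.shape[1])              # BxNx(D*heads)
        out = self.Proj(G * key) + query                                  # BxNx(D*heads)
        return self.final(out)                                            # BxNxD

if __name__ == '__main__':
    model = EfficientAdditiveAttnetionFixed(token_dim=3)
    x = torch.randn(4, 2, 3)   # B=4, N=2, D=3
    print(model(x).shape)      # torch.Size([4, 2, 3])

With this change the forward pass runs for any number of tokens N, and the output shape [B, N, token_dim] matches the docstring. Am I reading the intended shapes correctly, or is in_dims meant to be the channel dimension of the input?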
