About the use of rotary position coding. #8

Open
tianyabanbu opened this issue Jul 13, 2023 · 2 comments

Comments

@tianyabanbu

I have a question about the rotary positional encoding part of the code.

Your code:


def rotate_as_if_first(x, rotary_emb):
    # x: [bs, num_attention_heads, seq_len, head_size]
    # apply rotary as if all elements were first in the sequence
    cos, sin = rotary_emb(x, x.shape[-2])
    return rotate_one(x, cos, sin, torch.zeros(x.shape[0], x.shape[-2], dtype=torch.long, device=cos.device))

Should it be like this:


def rotate_as_if_first(x, rotary_emb, position_ids):
    # x: [bs, num_attention_heads, seq_len, head_size]
    # apply rotary as if all elements were first in the sequence
    cos, sin = rotary_emb(x, x.shape[-2])
    return rotate_one(x, cos, sin, position_ids)

When rotate_as_if_first calls rotate_one, shouldn't the position_ids parameter be passed in, instead of the all-zero positions generated by torch.zeros(x.shape[0], x.shape[-2], dtype=torch.long, device=cos.device)?

@CStanKonrad
Owner

Hi, thanks for the question! We always treat the memory keys as if they were at position 0. Position ids inside the local context are converted to lie in the range [0, 2047] here:

rel_pos_ids = position_ids - torch.min(position_ids, dim=-2, keepdim=True)[0]
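
For instance (a minimal illustration, not code from the repository; the [bs, seq_len, 1] shape of position_ids is an assumption made here so that dim=-2 indexes the sequence dimension), the absolute positions of the second context window are shifted back so that the window starts at 0:

import torch

# Toy position ids for the second context window t_2048 ... t_4095.
# The [bs, seq_len, 1] shape is assumed so that dim=-2 reduces over
# the sequence dimension, as in the conversion above.
position_ids = torch.arange(2048, 4096, dtype=torch.long).view(1, -1, 1)

# Shift each window so that its first position becomes 0.
rel_pos_ids = position_ids - torch.min(position_ids, dim=-2, keepdim=True)[0]

print(rel_pos_ids.min().item(), rel_pos_ids.max().item())  # 0 2047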

More context:
Memory layers apply positional encodings to the local context in the standard way, whereas the memory keys are encoded as if they were at the beginning of the local context.

In other words, let
$$t_0, t_1, t_2, t_3, \ldots t_{2047}, t_{2048}, \ldots, t_{4095}, \ldots$$
be some input.
LongLLaMA will process it in context windows. First, it will process
$$t_0, t_1, t_2, t_3, \ldots t_{2047}$$
and move the (key, value) pairs from memory layers to the memory cache. The local context part ($t_0, \ldots, t_{2047}$) uses $2048$ rotary positional encodings.
Then LongLLaMA will process
$$t_{2048}, \ldots, t_{4095}$$
Here again, the local context part ($t_{2048}, \ldots, t_{4095}$)
uses the same $2048$ rotary positional encodings as the previous local context ($t_0, \ldots, t_{2047}$).
Memory layers see the previous embeddings (keys and values corresponding to $t_0, \ldots, t_{2047}$), but as if they were located at the same position as $t_{2048}$ (which is position 0 after the conversion).
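
To make the "as if first" part concrete, here is a minimal self-contained sketch using the standard rotate-half RoPE formulation (the names apply_rope and rotate_half and the toy shapes are illustrative assumptions, not the repository's rotate_one / rotary_emb): giving the memory keys position id 0 corresponds to the identity rotation (cos 0 = 1, sin 0 = 0), while the local context uses the relative positions 0 to 2047.

import torch

def rotate_half(x):
    # Standard RoPE helper: split the head dim in half and map (x1, x2) -> (-x2, x1).
    x1, x2 = x.chunk(2, dim=-1)
    return torch.cat((-x2, x1), dim=-1)

def apply_rope(x, cos, sin, position_ids):
    # x: [bs, num_heads, seq_len, head_size]; position_ids: [bs, seq_len]
    cos = cos[position_ids].unsqueeze(1)  # [bs, 1, seq_len, head_size]
    sin = sin[position_ids].unsqueeze(1)
    return x * cos + rotate_half(x) * sin

# Toy sizes and a standard RoPE cos/sin table.
bs, num_heads, seq_len, head_size, max_pos = 1, 2, 4, 8, 16
inv_freq = 1.0 / (10000 ** (torch.arange(0, head_size, 2).float() / head_size))
freqs = torch.einsum("i,j->ij", torch.arange(max_pos).float(), inv_freq)
emb = torch.cat((freqs, freqs), dim=-1)
cos, sin = emb.cos(), emb.sin()  # [max_pos, head_size]

memory_keys = torch.randn(bs, num_heads, seq_len, head_size)

# "As if first": every memory key is rotated with position id 0 ...
zero_ids = torch.zeros(bs, seq_len, dtype=torch.long)
rotated = apply_rope(memory_keys, cos, sin, zero_ids)

# ... which is the identity rotation, so the keys come out unchanged.
assert torch.allclose(rotated, memory_keys)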

@tianyabanbu
Author

I see, thank you very much for your answer.
