Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add Dilated Sliding Window mask_mod #12

Closed
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions attn_gym/masks/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,3 +2,4 @@
from attn_gym.masks.sliding_window import generate_sliding_window
from attn_gym.masks.prefix_lm import generate_prefix_lm_mask
from attn_gym.masks.document_mask import generate_doc_mask_mod
from attn_gym.masks.dilated_sliding_window import generate_dilated_sliding_window
60 changes: 60 additions & 0 deletions attn_gym/masks/dilated_sliding_window.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,60 @@
import torch
from torch.nn.attention.flex_attention import _mask_mod_signature, and_masks
from attn_gym.masks import causal_mask


def generate_dilated_sliding_window(window_size: int, dilation: int) -> _mask_mod_signature:
"""Generates a dilated sliding window attention mask.
Args:
window_size: The size of the sliding window.
dilation: The dilation factor for the sliding window.

Note:
We assume that the window size represents the lookback size and we mask out all future tokens
similar to causal masking.
"""

def dilated_sliding_window(b, h, q_idx, kv_idx):
diff = q_idx - kv_idx
in_window = (diff >= 0) & (diff < window_size * dilation)
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

hmm from the paper its not clear to me that its always causal

what about torch.abs(diff) < window_size ?

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

One other nit I think that its clearer if we keep the window_size and dilation separate

e.g. to recreate the paper (if we didnt have the and_causal mask)
we would set window_size = 8 and dilation = 2

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

what about torch.abs(diff) < window_size ?

I thought it would be good to make this implementation consistent with attn_gym/masks/sliding_window.py.

However, seems reasonable to follow the non-causal way the paper described. I will update the generate_dilated_sliding_window() function.

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

One other nit I think that its clearer if we keep the window_size and dilation separate

e.g. to recreate the paper (if we didnt have the and_causal mask)
we would set window_size = 8 and dilation = 2

Maybe, I missed something. Can you please explain what does it mean by "if we keep the window_size and dilation separate"?
Did you mean setting window_size = 8 and dilation = 2?

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Ohh I just meant that the dilation factor doesnt have any impact on the absolute size of the window.
window_size * dilation -> window_size

So the "potential" size of the window is 16 elements (8 forward, 8 backward ) but a dilation factor knocks out half and we end up up with 4 on both sides. We dont extend the window so as to capture more elements

is_dilated = (diff % dilation) == 0
return in_window & is_dilated

dilated_sliding_window_mask = and_masks(dilated_sliding_window, causal_mask)
dilated_sliding_window_mask.__name__ = (
f"dilated_sliding_window_{window_size}_dilation_{dilation}"
)
return dilated_sliding_window_mask


def main(device: str = "cpu"):
"""Visualize the attention scores of dilated sliding window mask mod.

Args:
device (str): Device to use for computation.
"""
from attn_gym import visualize_attention_scores

B, H, SEQ_LEN, HEAD_DIM = 1, 1, 24, 8

def make_tensor():
return torch.ones(B, H, SEQ_LEN, HEAD_DIM, device=device)

query, key = make_tensor(), make_tensor()

dilated_sliding_window_mask = generate_dilated_sliding_window(window_size=5, dilation=2)
visualize_attention_scores(
query,
key,
mask_mod=dilated_sliding_window_mask,
device=device,
name="dilated_sliding_window_mask",
)


if __name__ == "__main__":
try:
from jsonargparse import CLI
except ImportError:
raise ImportError("Be sure to run: pip install -e .'[viz]'")
CLI(main)
Loading