Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

padding attention mask on torch side #12577

Merged
merged 1 commit into from
Dec 19, 2024
Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
77 changes: 77 additions & 0 deletions python/llm/src/ipex_llm/transformers/models/common.py
Original file line number Diff line number Diff line change
Expand Up @@ -184,3 +184,80 @@ def layer_norm_forward(self, hidden_states: torch.Tensor):
hidden_states, self.normalized_shape,
self.weight, self.bias, self.eps
)


def prepare_mask(mask, bsz, n_heads, seq_length, kv_length, is_causal, dtype, device):
max_kvs = 128
padding_kv_length = (kv_length + max_kvs - 1) // max_kvs * max_kvs
if mask is None:
if is_causal:
mask = torch.full([1, 1, seq_length, padding_kv_length], torch.finfo(dtype).min,
dtype=dtype, device=device)
mask.triu_(1)
mask = mask.expand([bsz, n_heads, seq_length, padding_kv_length])
elif seq_length != kv_length and seq_length <= 32:
mask = None
else:
mask = torch.zeros([1, 1, 1, padding_kv_length], torch.finfo(dtype).min,
dtype=dtype, device=device)
mask = mask.expand([bsz, n_heads, seq_length, padding_kv_length])
else:
if seq_length != kv_length and seq_length <= 32:
mask = mask[..., :seq_length, :kv_length]
mask = mask.expand([bsz, n_heads, seq_length, kv_length])
elif mask.size(3) != padding_kv_length:
new_mask = torch.empty([bsz, 1, seq_length, padding_kv_length],
dtype=dtype, device=device)
new_mask[:, :, :, :kv_length] = mask[:, 0:1, :seq_length, :kv_length]
new_mask[:, :, :, kv_length:] = torch.finfo(dtype).min
new_mask = new_mask.expand([bsz, n_heads, seq_length, padding_kv_length])
mask.set_(new_mask) # modify `mask` inplaced
else:
mask = mask.expand([bsz, n_heads, seq_length, padding_kv_length])
return mask


def scaled_dot_product_attention(query: torch.Tensor, key: torch.Tensor, value: torch.Tensor,
mask: torch.Tensor = None,
is_causal: bool = False, scale: float = None) -> torch.Tensor:
bsz, n_heads, seq_length, head_dim = query.shape
_, n_kv_heads, kv_length, _ = key.shape

dtype, device = query.dtype, query.device

if (
device.type == "xpu"
and dtype in [torch.float, torch.half]
and head_dim in [64, 80, 96, 128]
):
# prepare scale
scale = 1 / math.sqrt(head_dim) if scale is None else scale

# prepare mask
mask = prepare_mask(mask, bsz, n_heads, seq_length, kv_length, is_causal, dtype, device)

# compute
import xe_addons
if is_causal:
if key.dtype == torch.uint8:
attn_output = xe_addons.sdp_fp8_causal(query, key, value, mask, scale)
else:
attn_output = xe_addons.sdp_causal(query, key, value, mask, scale)
elif seq_length != kv_length and seq_length <= 32:
# todo: add scale support
if key.dtype == torch.uint8:
attn_output = xe_addons.sdp_fp8(query, key, value, mask)
else:
attn_output = xe_addons.sdp(query, key, value, mask)
else:
if key.dtype == torch.uint8:
attn_output = xe_addons.sdp_fp8(query, key, value, mask, scale)
else:
attn_output = xe_addons.sdp_non_causal(query, key, value, mask, scale)

return attn_output
else:
mask = mask[..., :seq_length, :kv_length] if mask is not None else None
return torch.nn.functional.scaled_dot_product_attention(
query, key, value, mask, is_causal=is_causal, scale=scale
)
Loading