
Commit

Fix .to() method for all attention biases (fairinternal/xformers#1278)
__original_commit__ = fairinternal/xformers@0f7311b
danthe3rd authored and xFormers Bot committed Dec 27, 2024
1 parent a0987e8 commit a8746f3
Showing 2 changed files with 117 additions and 0 deletions.
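In short: every AttentionBias subclass now overrides .to() so that moving a bias to another device returns the same concrete type, instead of falling back to the base-class constructor and silently losing the subclass (causal, local-attention, paged, ...). A minimal sketch of the guaranteed behavior (assuming xformers at this commit and an available CUDA device; the sequence lengths below are illustrative, not from the diff):

import torch
from xformers.ops.fmha.attn_bias import BlockDiagonalCausalMask, BlockDiagonalMask

# Build a causal block-diagonal bias (illustrative sequence lengths).
bias = BlockDiagonalMask.from_seqlens([2, 3, 4]).make_causal()
assert type(bias) is BlockDiagonalCausalMask

# With this fix, .to() rebuilds the same concrete subclass; previously the
# inherited BlockDiagonalMask.to() returned a plain BlockDiagonalMask.
moved = bias.to(torch.device("cuda"))
assert type(moved) is type(bias)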
2 changes: 2 additions & 0 deletions tests/test_mem_eff_attention.py
@@ -462,6 +462,8 @@ def test_forward(opFW_device_dtype_biasT_B_Mq_Mkv_H_K_Kv, packed, fmt, **kwargs)
fmt="BMHK" if packed else fmt,
**kwargs,
)
if attn_bias is not None:
assert type(attn_bias.to(query.device)) is type(attn_bias)

if packed:
c = torch.stack([query, key, value], 2)
115 changes: 115 additions & 0 deletions xformers/ops/fmha/attn_bias.py
@@ -169,6 +169,9 @@ class LocalAttentionFromBottomRightMask(AttentionBias):
window_left: int
window_right: int

def to(self, device) -> "LocalAttentionFromBottomRightMask":
return self

def __post_init__(self) -> None:
if self.window_left < 0:
raise ValueError(
@@ -227,6 +230,9 @@ class LowerTriangularFromBottomRightMask(AttentionBias):
"""

def to(self, device: torch.device) -> "LowerTriangularFromBottomRightMask":
assert (
type(self) is LowerTriangularFromBottomRightMask
), "Please implement in subclass"
return self

def materialize(
@@ -273,6 +279,14 @@ class LowerTriangularFromBottomRightLocalAttentionMask(

_window_size: int

def to(
self, device: torch.device
) -> "LowerTriangularFromBottomRightLocalAttentionMask":
assert (
type(self) is LowerTriangularFromBottomRightLocalAttentionMask
), "Please implement in subclass"
return self

def __post_init__(self) -> None:
if self._window_size <= 0:
raise ValueError(
@@ -314,6 +328,7 @@ class _SeqLenInfo:
seqstart_py: List[int]

def to(self, device: torch.device) -> "_SeqLenInfo":
assert type(self) is _SeqLenInfo, "Please implement in subclass"
if self.seqstart.device == device:
return self
return _SeqLenInfo(
@@ -437,6 +452,7 @@ def __post_init__(self) -> None:
assert len(self.seqstart_py) == len(self.seqlen_py) + 1

def to(self, device: torch.device) -> "_PaddedSeqLenInfo":
assert type(self) is _PaddedSeqLenInfo, "Please implement in subclass"
if self.seqlen.device == device:
return self
return _PaddedSeqLenInfo(
@@ -552,6 +568,7 @@ class _GappySeqInfo(_SeqLenInfo):
# seqstart: torch.Tensor

def to(self, device: torch.device) -> "_GappySeqInfo":
assert type(self) is _GappySeqInfo, "Please implement in subclass"
if self.seqlen.device == device:
return self
return _GappySeqInfo(
@@ -654,6 +671,7 @@ class BlockDiagonalMask(AttentionBias):
_batch_sizes: Optional[Sequence[int]] = None

def to(self, device) -> "BlockDiagonalMask":
assert type(self) is BlockDiagonalMask, "Please implement in subclass"
return BlockDiagonalMask(
q_seqinfo=self.q_seqinfo.to(device),
k_seqinfo=self.k_seqinfo.to(device),
@@ -858,6 +876,14 @@ class BlockDiagonalCausalMask(BlockDiagonalMask):
is from the initial query in block i.
"""

def to(self, device) -> "BlockDiagonalCausalMask":
assert type(self) is BlockDiagonalCausalMask, "Please implement in subclass"
return BlockDiagonalCausalMask(
q_seqinfo=self.q_seqinfo.to(device),
k_seqinfo=self.k_seqinfo.to(device),
_batch_sizes=self._batch_sizes,
)

def _create_block_mask(
self,
shape: Tuple[int, ...],
@@ -885,6 +911,16 @@ class BlockDiagonalCausalFromBottomRightMask(BlockDiagonalMask):
final query in block i.
"""

def to(self, device) -> "BlockDiagonalCausalFromBottomRightMask":
assert (
type(self) is BlockDiagonalCausalFromBottomRightMask
), "Please implement in subclass"
return BlockDiagonalCausalFromBottomRightMask(
q_seqinfo=self.q_seqinfo.to(device),
k_seqinfo=self.k_seqinfo.to(device),
_batch_sizes=self._batch_sizes,
)

def __post_init__(self) -> None:
for i, ((q_start, q_end), (k_start, k_end)) in enumerate(
zip(
@@ -933,6 +969,7 @@ class BlockDiagonalPaddedKeysMask(AttentionBias):
k_seqinfo: _PaddedSeqLenInfo

def to(self, device) -> "BlockDiagonalPaddedKeysMask":
assert type(self) is BlockDiagonalPaddedKeysMask, "Please implement in subclass"
return BlockDiagonalPaddedKeysMask(
q_seqinfo=self.q_seqinfo.to(device),
k_seqinfo=self.k_seqinfo.to(device),
@@ -1044,6 +1081,15 @@ class BlockDiagonalCausalWithOffsetPaddedKeysMask(BlockDiagonalPaddedKeysMask):

causal_diagonal: Any = None # unused. Exists for BC only.

def to(self, device) -> "BlockDiagonalCausalWithOffsetPaddedKeysMask":
assert (
type(self) is BlockDiagonalCausalWithOffsetPaddedKeysMask
), "Please implement in subclass"
return BlockDiagonalCausalWithOffsetPaddedKeysMask(
q_seqinfo=self.q_seqinfo.to(device),
k_seqinfo=self.k_seqinfo.to(device),
)

def _create_block_mask(
self,
shape: Tuple[int, ...],
@@ -1103,6 +1149,16 @@ class BlockDiagonalCausalLocalAttentionPaddedKeysMask(BlockDiagonalPaddedKeysMas

_window_size: int

def to(self, device) -> "BlockDiagonalCausalLocalAttentionPaddedKeysMask":
assert (
type(self) is BlockDiagonalCausalLocalAttentionPaddedKeysMask
), "Please implement in subclass"
return BlockDiagonalCausalLocalAttentionPaddedKeysMask(
q_seqinfo=self.q_seqinfo.to(device),
k_seqinfo=self.k_seqinfo.to(device),
_window_size=self._window_size,
)

def _create_block_mask(
self,
shape: Tuple[int, ...],
@@ -1153,6 +1209,9 @@ class PagedBlockDiagonalPaddedKeysMask(AttentionBias):
] = BlockDiagonalPaddedKeysMask

def to(self, device: torch.device) -> "PagedBlockDiagonalPaddedKeysMask":
assert (
type(self) is PagedBlockDiagonalPaddedKeysMask
), "Please implement in subclass"
return PagedBlockDiagonalPaddedKeysMask(
q_seqinfo=self.q_seqinfo.to(device),
k_seqinfo=self.k_seqinfo.to(device),
@@ -1250,6 +1309,19 @@ class PagedBlockDiagonalCausalWithOffsetPaddedKeysMask(

_UNPAGED_TYPE = BlockDiagonalCausalWithOffsetPaddedKeysMask

def to(
self, device: torch.device
) -> "PagedBlockDiagonalCausalWithOffsetPaddedKeysMask":
assert (
type(self) is PagedBlockDiagonalCausalWithOffsetPaddedKeysMask
), "Please implement in subclass"
return PagedBlockDiagonalCausalWithOffsetPaddedKeysMask(
q_seqinfo=self.q_seqinfo.to(device),
k_seqinfo=self.k_seqinfo.to(device),
block_tables=self.block_tables.to(device),
page_size=self.page_size,
)


@dataclass
class BlockDiagonalGappyKeysMask(AttentionBias):
@@ -1264,6 +1336,7 @@ class BlockDiagonalGappyKeysMask(AttentionBias):
k_seqinfo: _GappySeqInfo

def to(self, device: torch.device) -> "BlockDiagonalGappyKeysMask":
assert type(self) is BlockDiagonalGappyKeysMask, "Please implement in subclass"
return BlockDiagonalGappyKeysMask(
q_seqinfo=self.q_seqinfo.to(device),
k_seqinfo=self.k_seqinfo.to(device),
@@ -1359,6 +1432,15 @@ class BlockDiagonalCausalWithOffsetGappyKeysMask(BlockDiagonalGappyKeysMask):
than Q is to the final query in block i.
"""

def to(self, device: torch.device) -> "BlockDiagonalCausalWithOffsetGappyKeysMask":
assert (
type(self) is BlockDiagonalCausalWithOffsetGappyKeysMask
), "Please implement in subclass"
return BlockDiagonalCausalWithOffsetGappyKeysMask(
q_seqinfo=self.q_seqinfo.to(device),
k_seqinfo=self.k_seqinfo.to(device),
)

def materialize(
self,
shape: Tuple[int, ...],
@@ -1407,6 +1489,17 @@ class PagedBlockDiagonalGappyKeysMask(AttentionBias):
Type[BlockDiagonalGappyKeysMask]
] = BlockDiagonalGappyKeysMask

def to(self, device: torch.device) -> "PagedBlockDiagonalGappyKeysMask":
assert (
type(self) is PagedBlockDiagonalGappyKeysMask
), "Please implement in subclass"
return PagedBlockDiagonalGappyKeysMask(
q_seqinfo=self.q_seqinfo.to(device),
k_seqinfo=self.k_seqinfo.to(device),
block_tables=self.block_tables.to(device),
page_size=self.page_size,
)

def materialize(
self,
shape: Tuple[int, ...],
@@ -1507,6 +1600,17 @@ class BlockDiagonalCausalLocalAttentionMask(BlockDiagonalCausalMask):

_window_size: int = 0 # forced due to inheritance and default arguments

def to(self, device) -> "BlockDiagonalCausalLocalAttentionMask":
assert (
type(self) is BlockDiagonalCausalLocalAttentionMask
), "Please implement in subclass"
return BlockDiagonalCausalLocalAttentionMask(
q_seqinfo=self.q_seqinfo.to(device),
k_seqinfo=self.k_seqinfo.to(device),
_batch_sizes=self._batch_sizes,
_window_size=self._window_size,
)

def __post_init__(self):
if self._window_size <= 0:
raise ValueError(
@@ -1561,6 +1665,17 @@ class BlockDiagonalCausalLocalAttentionFromBottomRightMask(

_window_size: int = 0 # forced due to inheritance and default arguments

def to(self, device) -> "BlockDiagonalCausalLocalAttentionFromBottomRightMask":
assert (
type(self) is BlockDiagonalCausalLocalAttentionFromBottomRightMask
), "Please implement in subclass"
return BlockDiagonalCausalLocalAttentionFromBottomRightMask(
q_seqinfo=self.q_seqinfo.to(device),
k_seqinfo=self.k_seqinfo.to(device),
_batch_sizes=self._batch_sizes,
_window_size=self._window_size,
)

def __post_init__(self):
super().__post_init__()
if self._window_size <= 0:
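Note on the pattern (not part of the diff): each base-class to() now asserts it is called on its exact type ("Please implement in subclass"), so any downstream subclass has to add its own override that rebuilds the same concrete type with its tensors moved. A hypothetical subclass (the name below is illustrative, not from this commit) would follow the same shape:

from dataclasses import dataclass
import torch
from xformers.ops.fmha.attn_bias import BlockDiagonalMask

@dataclass
class MyBlockDiagonalMask(BlockDiagonalMask):  # hypothetical subclass
    def to(self, device: torch.device) -> "MyBlockDiagonalMask":
        # Rebuild the same concrete type so the base-class assert is never hit.
        return MyBlockDiagonalMask(
            q_seqinfo=self.q_seqinfo.to(device),
            k_seqinfo=self.k_seqinfo.to(device),
            _batch_sizes=self._batch_sizes,
        )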
