# Copyright (c) Facebook, Inc. and its affiliates. All rights reserved.
#
# This source code is licensed under the BSD license found in the
# LICENSE file in the root directory of this source tree.

import logging
import math
from dataclasses import dataclass
from typing import Optional, Tuple

import torch
import torch.nn as nn
from torch.cuda.amp import autocast

from xformers.components.attention import Attention, AttentionConfig, register_attention
from xformers.components.attention.feature_maps import (
    FeatureMap,
    FeatureMapType,
    SMHyperbolic,
    SMOrf,
    SMReg,
)

logger = logging.getLogger("xformers")

@dataclass
class FavorAttentionConfig(AttentionConfig):
    causal: Optional[bool]
    dim_features: Optional[int] = None  # The dimension of the random features space
    dim_head: Optional[
        int
    ] = None  # The embedding dimension of the inputs. Only used to estimate dim_features
    iter_before_redraw: Optional[
        int
    ] = None  # The number of iterations before the random features are re-drawn from scratch
    feature_map: Optional[FeatureMapType] = None
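
# Example (hypothetical values): the config fields above mirror the constructor
# arguments below, so a registry-style dict config could look like:
#   {"name": "favor", "dropout": 0.1, "causal": True, "dim_head": 64}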

@register_attention("favor", FavorAttentionConfig)
class FavorAttention(Attention):
    def __init__(
        self,
        causal: bool = False,
        dropout: float = 0.0,
        dim_features: Optional[int] = None,
        dim_head: Optional[int] = None,
        iter_before_redraw: Optional[int] = None,
        feature_map_type: FeatureMapType = FeatureMapType.SMReg,
        normalize_inputs: bool = False,
        *_,
        **__,
    ):
        r"""
        Kernelized attention, as proposed in Performers_
        ("Rethinking Attention with Performers", K. Choromanski et al., 2020).

        FAVOR stands for "Fast Attention Via positive Orthogonal Random features".

        Args:
            dropout (float): the probability of an output to be randomly dropped at training time
            dim_features (int): the dimension of the random features space
            iter_before_redraw (int): the number of steps (forward calls) before a redraw of the features
            feature_map_type (FeatureMapType): the type of feature map being used,
                for instance orthogonal random features.

        .. _Performers: https://arxiv.org/pdf/2009.14794v1.pdf
        """
        super().__init__()

        self.causal = causal
        self.iter_before_redraw = (
            (2 * iter_before_redraw)
            if iter_before_redraw is not None
            else iter_before_redraw
        )  # This will be used for both key and query
        self.normalize_inputs = normalize_inputs
        self.feature_map_type = feature_map_type
        self.attn_drop = nn.Dropout(dropout, inplace=True)

        # Setup dimension-dependent variables
        # Reasonable dimension default
        if dim_features is None:
            assert dim_head is not None, "dim_features or dim_head needs to be passed"
            self.dim_features = math.ceil(dim_head * (1 + math.log2(dim_head)))
            self.dim_features = 2 * (
                self.dim_features // 2
            )  # needs to be even for some variants
            logger.info(
                f"FAVOR: Automatically setting the random mapping dimension to {self.dim_features} from {dim_head}"
            )
        else:
            self.dim_features = dim_features

        feature_map_constructor = {
            FeatureMapType.SMHyp: SMHyperbolic,
            FeatureMapType.SMReg: SMReg,
            FeatureMapType.SMOrf: SMOrf,
        }[self.feature_map_type]

        feature_settings = {
            "dim_features": self.dim_features,
            "iter_before_redraw": self.iter_before_redraw,
            "normalize_inputs": self.normalize_inputs,
        }

        self.feature_map: FeatureMap = feature_map_constructor(**feature_settings)  # type: ignore

        # Properties specific to this attention mechanism
        self.supports_attention_mask = False
        self.supports_key_padding_mask = False
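
    # A worked instance of the dim_features default above, following the
    # paper's m = O(d log d) guidance: for dim_head = 64,
    # ceil(64 * (1 + log2(64))) = 448, which is already even, so
    # dim_features defaults to 448 random features.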

    @staticmethod
    def _maybe_promote(x: torch.Tensor) -> torch.Tensor:
        # Only promote fp16 buffers; bfloat16, for instance, has enough
        # dynamic range and can stay as-is
        return x.float() if x.dtype == torch.float16 else x

    @staticmethod
    def _causal_attention(
        k_prime: torch.Tensor, q_prime: torch.Tensor, v: torch.Tensor
    ) -> Tuple[torch.Tensor, torch.Tensor]:
        # Algorithm 1 in the paper
        ref_v = torch.ones_like(v.unsqueeze(2))  # BATCH x SEQ x 1 x EMB
        Gps = k_prime.unsqueeze(3) * v.unsqueeze(2)  # BATCH x SEQ x FEAT x EMB
        Grenorm = k_prime.unsqueeze(3) * ref_v

        # Cumulative sum over the sequence dimension, taken before contracting
        # with the query features so that position c only aggregates keys and
        # values at positions j <= c
        Gps = Gps.cumsum(1)
        Grenorm = Grenorm.cumsum(1)

        # Consolidate against the feature dimension
        att_raw = torch.einsum("bcfe,bcf->bce", Gps, q_prime)
        att_norm = torch.einsum("bcfe,bcf->bce", Grenorm, q_prime)

        return att_raw, att_norm
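
    # In closed form, the causal branch computes, per position c (a sketch,
    # with q'_c = feature_map(q_c) and k'_j = feature_map(k_j)):
    #   att_raw_c  = q'_c . sum_{j<=c} (k'_j v_j^T)
    #   att_norm_c = q'_c . sum_{j<=c} k'_j
    # so the masked softmax is replaced by prefix sums and the cost is
    # linear in the sequence length instead of quadratic.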

    def forward(
        self,
        q: torch.Tensor,
        k: torch.Tensor,
        v: torch.Tensor,
        *_,
        **__,
    ):
        # Project keys and queries onto the feature map space
        k_prime = self.feature_map(k)
        q_prime = self.feature_map(q)

        with autocast(enabled=False):
            # The softmax kernel approximation for FAVOR will easily overflow.
            # Force the computations here to stay in fp32 for numerical stability.
            # Note that the dimensions are vastly reduced when compared to scaled_dot_product
            k_prime = self._maybe_promote(k_prime)
            q_prime = self._maybe_promote(q_prime)
            v = self._maybe_promote(v)

            if not self.causal:
                att_normalization = q_prime @ (
                    k_prime.transpose(-2, -1) @ torch.ones_like(v)
                )
                att_raw = q_prime @ (k_prime.transpose(-2, -1) @ v)
            else:
                # Actually compute attention
                att_raw, att_normalization = self._causal_attention(
                    k_prime, q_prime, v
                )

            # Normalize
            att = att_raw / att_normalization

        if self.attn_drop is not None:
            att = self.attn_drop(att)

        return att
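

# ---------------------------------------------------------------------------
# Usage sketch (an illustration, not part of the upstream module), assuming
# xformers and its feature maps are importable. Inputs follow the
# (batch, seq, dim) convention used by the feature maps above.
if __name__ == "__main__":
    B, S, D = 2, 128, 64

    attention = FavorAttention(causal=True, dim_head=D)
    q, k, v = (torch.randn(B, S, D) for _ in range(3))

    out = attention(q, k, v)
    print(out.shape)  # torch.Size([2, 128, 64])

    # The same module can also be built from a dict config through the
    # registry set up by @register_attention, e.g. (hypothetical values):
    # from xformers.components.attention import build_attention
    # attn = build_attention({"name": "favor", "dropout": 0.0, "causal": False, "dim_head": D})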