
Commit fa76daf

seeing really big improvements with per token learned value residual mixing values

lucidrains committed Nov 22, 2024
1 parent 9a5f78d commit fa76daf
Showing 3 changed files with 30 additions and 7 deletions.
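
For context: value residual learning (Zhou et al., https://arxiv.org/abs/2410.17897v1) feeds the values from the first attention layer back into every later attention layer. Before this commit the two value streams were averaged with a fixed 0.5 weight; this commit lets each token predict its own mixing weight with a small sigmoid gate. Below is a minimal sketch of the idea, separate from the library code (the module and tensor names are illustrative only):

import torch
from torch import nn

class PerTokenValueResidualMix(nn.Module):
    # predicts one mixing weight in (0, 1) per token, broadcast over heads
    def __init__(self, dim):
        super().__init__()
        self.to_mix = nn.Linear(dim, 1)

    def forward(self, x, values, first_layer_values):
        # x:                  (batch, seq, dim)             - input to the attention block
        # values:             (batch, heads, seq, dim_head) - this layer's values
        # first_layer_values: (batch, heads, seq, dim_head) - value residual from the first layer
        mix = self.to_mix(x).sigmoid()   # (batch, seq, 1)
        mix = mix.unsqueeze(1)           # (batch, 1, seq, 1), broadcasts over heads
        return values * mix + first_layer_values * (1. - mix)

# toy usage
mixer = PerTokenValueResidualMix(dim = 128)
x = torch.randn(2, 16, 128)
v = torch.randn(2, 8, 16, 64)
v_first = torch.randn(2, 8, 16, 64)
mixed = mixer(x, v, v_first)  # (2, 8, 16, 64)
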
2 changes: 1 addition & 1 deletion setup.py
@@ -3,7 +3,7 @@
 setup(
   name = 'x-transformers',
   packages = find_packages(exclude=['examples']),
-  version = '1.42.12',
+  version = '1.42.14',
   license='MIT',
   description = 'X-Transformers - Pytorch',
   author = 'Phil Wang',
6 changes: 5 additions & 1 deletion tests/test_x_transformers.py
@@ -331,7 +331,10 @@ def test_reinject_input():

     model(x) # (1, 1024, 20000)

-def test_value_residual():
+@pytest.mark.parametrize('learned_value_residual_mix', (False, True))
+def test_value_residual(
+    learned_value_residual_mix: bool
+):

     model = TransformerWrapper(
         num_tokens = 20000,
@@ -341,6 +344,7 @@ def test_value_residual():
             depth = 6,
             heads = 8,
             add_value_residual = True,
+            learned_value_residual_mix = learned_value_residual_mix
         )
     )

29 changes: 24 additions & 5 deletions x_transformers/x_transformers.py
@@ -1072,6 +1072,7 @@ def __init__(
         logit_softclamp_value = 50.,
         neutreno_value_residual = False, # Nguyen et al. https://arxiv.org/abs/2312.00751
         neutreno_alpha = 0.4,
+        learned_value_residual_mix = False,
         onnxable = False,
         attend_sdp_kwargs: dict = dict(
             enable_flash = True,
@@ -1231,6 +1232,14 @@ def __init__(
             self.mem_k = nn.Parameter(torch.randn(kv_heads, num_mem_kv, dim_head))
             self.mem_v = nn.Parameter(torch.randn(kv_heads, num_mem_kv, dim_head))

+        # maybe learned value residual mixer per token
+
+        self.to_value_residual_mix = nn.Sequential(
+            nn.Linear(dim, 1),
+            nn.Sigmoid(),
+            Rearrange('b n 1 -> b 1 n 1')
+        ) if learned_value_residual_mix else always(0.5)
+
         # attention on attention

         self.attn_on_attn = on_attn
@@ -1303,7 +1312,8 @@ def forward(
                 diff_values = repeat(diff_values, 'b h n d -> b (r h) n d', r = h // kv_h)
             else:
                 # https://arxiv.org/abs/2410.17897v1
-                v = 0.5 * (v + value_residual)
+                value_residual_mix = self.to_value_residual_mix(q_input)
+                v = v * value_residual_mix + value_residual * (1. - value_residual_mix)

         # take care of caching

@@ -1541,8 +1551,9 @@ def __init__(
         use_layerscale = False,
         layerscale_init_value = 0.,
         unet_skips = False,
-        reinject_input = False, # seen first in DEQ paper https://arxiv.org/abs/1909.01377, but later used in a number of papers trying to achieve depthwise generalization https://arxiv.org/abs/2410.03020v1
-        add_value_residual = False, # resformer from Zhou et al - https://arxiv.org/abs/2410.17897v1
+        reinject_input = False, # seen first in DEQ paper https://arxiv.org/abs/1909.01377, but later used in a number of papers trying to achieve depthwise generalization https://arxiv.org/abs/2410.03020v1
+        add_value_residual = False, # resformer from Zhou et al - https://arxiv.org/abs/2410.17897v1
+        learned_value_residual_mix = False, # seeing big improvements when the value residual mix value is learned per token - credit goes to @faresobeid for taking the first step with learned scalar mix, then @Blinkdl for taking it a step further with data dependent. here we will use per token learned
         rel_pos_kwargs: dict = dict(),
         **kwargs
     ):
@@ -1786,6 +1797,10 @@ def __init__(

         self.add_value_residual = add_value_residual

+        is_first_self_attn = True
+        is_first_cross_attn = True
+        learned_value_residual_mix &= add_value_residual
+
         # iterate and construct layers

         for ind, (layer_type, layer_shift_tokens) in enumerate(zip(self.layer_types, shift_tokens)):
@@ -1801,9 +1816,13 @@
             # attention, cross attention, feedforward

             if layer_type == 'a':
-                layer = Attention(dim, heads = heads, causal = causal, **attn_kwargs)
+                self_attn_learned_value_residual = learned_value_residual_mix and not is_first_self_attn
+                layer = Attention(dim, heads = heads, causal = causal, learned_value_residual_mix = self_attn_learned_value_residual, **attn_kwargs)
+                is_first_self_attn = False
             elif layer_type == 'c':
-                layer = Attention(dim, heads = heads, **{**attn_kwargs, **cross_attn_kwargs})
+                cross_attn_learned_value_residual = learned_value_residual_mix and not is_first_cross_attn
+                layer = Attention(dim, heads = heads, learned_value_residual_mix = cross_attn_learned_value_residual, **{**attn_kwargs, **cross_attn_kwargs})
+                is_first_cross_attn = False
             elif layer_type == 'f':
                 layer = FeedForward(dim, **ff_kwargs)
                 layer = layer if not macaron else Scale(0.5, layer)
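
As exercised by the updated test, the new behavior is opt-in from the public API. Note that learned_value_residual_mix only takes effect together with add_value_residual, and the first attention layer never receives a learned mix since it is the layer producing the value residual. A usage sketch along the lines of the test above:

import torch
from x_transformers import TransformerWrapper, Decoder

model = TransformerWrapper(
    num_tokens = 20000,
    max_seq_len = 1024,
    attn_layers = Decoder(
        dim = 128,
        depth = 6,
        heads = 8,
        add_value_residual = True,           # resformer value residual
        learned_value_residual_mix = True    # per token learned mix from this commit
    )
)

x = torch.randint(0, 20000, (1, 1024))
logits = model(x)  # (1, 1024, 20000)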
