🔥⚡️Simplified & Optimized Transformer
1) Removed the `final_attention` logic
2) Introduced `mean` reduction for aggregation (see the sketch below this list)
3) Introduced `final_norm_type`
4) Optimized `DecayedAttention` so that an earlier token can attend evenly to every later token
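For items 1) and 2): the dedicated attention head over the sequence is gone, and aggregation now reduces to either the head token or a plain mean. A minimal standalone sketch of the new aggregation behavior, restated from the `core.py` diff below (the surrounding extractor class is omitted):

import torch
from torch import Tensor


def aggregate(net: Tensor, use_head_token: bool) -> Tensor:
    # net: (batch, seq_len, latent_dim) encoder output
    if use_head_token:
        # the head token is appended at the last position of the sequence
        return net[:, -1]
    # `mean` reduction over the sequence dimension
    return net.mean(1)


# usage sketch
net = torch.randn(4, 10, 32)
print(aggregate(net, use_head_token=False).shape)  # torch.Size([4, 32])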
carefree0910 committed Mar 24, 2021
1 parent e8160ec commit aec1846
Showing 3 changed files with 16 additions and 29 deletions.
cflearn/modules/blocks.py: 2 changes (1 addition, 1 deletion)
@@ -1251,7 +1251,7 @@ def forward(
         raw_weights = torch.bmm(q, k.transpose(-2, -1))
         if mask is not None:
             raw_weights.masked_fill_(mask, float("-inf"))
-        # B * N_head, Sq, Sk -> # B * N_head, Sq, Sk
+        # B * N_head, Sq, Sk -> B * N_head, Sq, Sk
         weights = self._get_weights(raw_weights)
         if 0.0 < self.dropout < 1.0:
             weights = F.dropout(weights, self.dropout, self.training)
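The hunk above only corrects the shape comment. For context, that comment says `_get_weights` maps the raw scores to attention weights without changing their shape; a minimal sketch of what such a step typically looks like, assuming a plain softmax (the actual `_get_weights` implementation is not part of this diff):

import torch.nn.functional as F
from torch import Tensor


def get_weights(raw_weights: Tensor) -> Tensor:
    # raw_weights: (B * N_head, Sq, Sk); normalize over the key dimension,
    # leaving the shape unchanged, as the corrected comment states
    return F.softmax(raw_weights, dim=-1)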
cflearn/modules/extractors/transformer/configs.py: 4 changes (2 additions, 2 deletions)
@@ -12,14 +12,14 @@ def get_default(self) -> Dict[str, Any]:
             "num_layers": 3,
             "latent_dim": 32,
             "dropout": 0.1,
-            "norm_type": "batch_norm",
+            "norm_type": "layer_norm",
             "attention_type": "decayed",
             "encoder_type": "basic",
             "input_linear_config": {"bias": False},
             "layer_config": {"latent_dim": 128},
             "encoder_config": {},
             "use_head_token": True,
-            "use_final_attention": False,
+            "final_norm_type": "layer_norm",
         }


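`final_norm_type` is consumed in `core.py` via `_get_norm(final_norm_type, latent_dim)` (see the diff below). That helper is outside this commit, so the following is only a plausible sketch of the mapping it performs, assuming the two norm names used in these defaults:

import torch.nn as nn


def get_norm(norm_type: str, dim: int) -> nn.Module:
    # hypothetical stand-in for `_get_norm`; only the call signature appears in the diff
    if norm_type == "layer_norm":
        return nn.LayerNorm(dim)
    if norm_type == "batch_norm":
        return nn.BatchNorm1d(dim)
    raise NotImplementedError(f"norm type '{norm_type}' is not recognized")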
cflearn/modules/extractors/transformer/core.py: 39 changes (13 additions, 26 deletions)
@@ -4,7 +4,6 @@
 
 import numpy as np
 import torch.nn as nn
-import torch.nn.functional as F
 
 from typing import *
 from torch import Tensor
@@ -129,7 +128,6 @@ def __init__(
         mask = np.zeros([seq_len, seq_len], dtype=np.float32)
         for i in range(1, seq_len):
             np.fill_diagonal(mask[i:], i ** 2)
-            np.fill_diagonal(mask[..., i:], i ** 2)
         mask_ = torch.from_numpy(mask)
         decayed_mask = torch.empty(num_heads, seq_len, seq_len)
         for i in range(num_heads):
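With the second `fill_diagonal` call removed, only entries below the diagonal of the offset matrix are non-zero, so decay applies only when a later query attends to an earlier key; an earlier token attends to every later token without decay. A small check for seq_len = 4 (how these offsets are later turned into per-head attention weights is unchanged and not shown here):

import numpy as np

seq_len = 4
mask = np.zeros([seq_len, seq_len], dtype=np.float32)
for i in range(1, seq_len):
    np.fill_diagonal(mask[i:], i ** 2)
print(mask)
# [[0. 0. 0. 0.]
#  [1. 0. 0. 0.]
#  [4. 1. 0. 0.]
#  [9. 4. 1. 0.]]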
@@ -282,7 +280,7 @@ def __init__(
         layer_config: Dict[str, Any],
         encoder_config: Dict[str, Any],
         use_head_token: bool,
-        use_final_attention: bool,
+        final_norm_type: Optional[str],
     ):
         super().__init__(in_flat_dim, dimensions)
         seq_len = dimensions.num_history
@@ -306,36 +304,23 @@ def __init__(
         layer = TransformerLayer(latent_dim, num_heads, **layer_config)
         encoder_base = TransformerEncoder.get(encoder_type)
         self.encoder = encoder_base(layer, num_layers, dimensions, **encoder_config)
-        self.encoder_norm = nn.LayerNorm(latent_dim)
-        if not use_final_attention:
-            self.final_attn_linear = None
+        if final_norm_type is None:
+            self.final_norm = None
         else:
-            self.final_attn_linear = nn.Linear(latent_dim, 1)
+            self.final_norm = _get_norm(final_norm_type, latent_dim)
 
     @property
     def flatten_ts(self) -> bool:
         return False
 
     @property
     def out_dim(self) -> int:
-        if self.head_token is None:
-            return self.latent_dim
-        if self.final_attn_linear is None:
-            return self.latent_dim
-        return 2 * self.latent_dim
+        return self.latent_dim
 
     def _aggregate(self, net: Tensor) -> Tensor:
-        last_token = net[..., -1, :]
-        if self.final_attn_linear is None:
-            return last_token
-        if self.head_token is None:
-            no_head_token = net
-        else:
-            no_head_token = net[..., :-1, :]
-        a_hat = self.final_attn_linear(no_head_token)
-        a_prob = F.softmax(a_hat, dim=1)
-        a = torch.sum(a_prob * no_head_token, dim=1)
-        return torch.cat([a, last_token], 1)
+        if self.head_token is not None:
+            return net[:, -1]
+        return net.mean(1)
 
     def forward(self, net: Tensor) -> Tensor:
         # input -> latent
@@ -347,9 +332,11 @@ def forward(self, net: Tensor) -> Tensor:
         # encode latent vector with transformer
         net = self.position_encoding(net)
         net = self.encoder(net, None)
-        net = self.encoder_norm(net)
-        # aggregate
-        return self._aggregate(net)
+        # aggregate & norm
+        net = self._aggregate(net)
+        if self.final_norm is not None:
+            net = self.final_norm(net)
+        return net
 
 
 __all__ = ["Transformer"]
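Compared with the old unconditional `nn.LayerNorm` on the full (batch, seq_len, latent_dim) encoder output, normalization is now optional and applied after aggregation, i.e. on a (batch, latent_dim) tensor. A minimal shape-level sketch with illustrative dimensions:

import torch
import torch.nn as nn

batch_size, seq_len, latent_dim = 8, 16, 32
net = torch.randn(batch_size, seq_len, latent_dim)  # stand-in for the encoder output

final_norm = nn.LayerNorm(latent_dim)  # built when final_norm_type == "layer_norm"
net = net.mean(1)                      # aggregate (mean reduction, no head token here)
net = final_norm(net)                  # norm on the aggregated vector
print(net.shape)                       # torch.Size([8, 32])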
