From aec18460d5d6904f9de31f202239500ef4e4b2db Mon Sep 17 00:00:00 2001
From: carefree0910
Date: Wed, 24 Mar 2021 10:18:29 +0800
Subject: [PATCH] =?UTF-8?q?=F0=9F=94=A5=E2=9A=A1=EF=B8=8FSimplified=20&=20?=
 =?UTF-8?q?Optimized=20`Transformer`?=
MIME-Version: 1.0
Content-Type: text/plain; charset=UTF-8
Content-Transfer-Encoding: 8bit

1) Removed the `final_attention` logic
2) Introduced `mean` reduction
3) Introduced `final_norm_type`
4) Optimized `DecayedAttention` so that earlier tokens now attend evenly to
   every later token
---
 cflearn/modules/blocks.py                     |  2 +-
 .../modules/extractors/transformer/configs.py |  4 +-
 .../modules/extractors/transformer/core.py    | 39 +++++++------------
 3 files changed, 16 insertions(+), 29 deletions(-)

diff --git a/cflearn/modules/blocks.py b/cflearn/modules/blocks.py
index 947d9175c..e39906c5f 100644
--- a/cflearn/modules/blocks.py
+++ b/cflearn/modules/blocks.py
@@ -1251,7 +1251,7 @@ def forward(
         raw_weights = torch.bmm(q, k.transpose(-2, -1))
         if mask is not None:
             raw_weights.masked_fill_(mask, float("-inf"))
-        # B * N_head, Sq, Sk -> # B * N_head, Sq, Sk
+        # B * N_head, Sq, Sk -> B * N_head, Sq, Sk
         weights = self._get_weights(raw_weights)
         if 0.0 < self.dropout < 1.0:
             weights = F.dropout(weights, self.dropout, self.training)
diff --git a/cflearn/modules/extractors/transformer/configs.py b/cflearn/modules/extractors/transformer/configs.py
index a282c0384..62d0555bc 100644
--- a/cflearn/modules/extractors/transformer/configs.py
+++ b/cflearn/modules/extractors/transformer/configs.py
@@ -12,14 +12,14 @@ def get_default(self) -> Dict[str, Any]:
             "num_layers": 3,
             "latent_dim": 32,
             "dropout": 0.1,
-            "norm_type": "batch_norm",
+            "norm_type": "layer_norm",
             "attention_type": "decayed",
             "encoder_type": "basic",
             "input_linear_config": {"bias": False},
             "layer_config": {"latent_dim": 128},
             "encoder_config": {},
             "use_head_token": True,
-            "use_final_attention": False,
+            "final_norm_type": "layer_norm",
         }
 
 
diff --git a/cflearn/modules/extractors/transformer/core.py b/cflearn/modules/extractors/transformer/core.py
index 79fd2aa97..1f894ebf6 100644
--- a/cflearn/modules/extractors/transformer/core.py
+++ b/cflearn/modules/extractors/transformer/core.py
@@ -4,7 +4,6 @@
 
 import numpy as np
 import torch.nn as nn
-import torch.nn.functional as F
 
 from typing import *
 from torch import Tensor
@@ -129,7 +128,6 @@ def __init__(
         mask = np.zeros([seq_len, seq_len], dtype=np.float32)
         for i in range(1, seq_len):
             np.fill_diagonal(mask[i:], i ** 2)
-            np.fill_diagonal(mask[..., i:], i ** 2)
         mask_ = torch.from_numpy(mask)
         decayed_mask = torch.empty(num_heads, seq_len, seq_len)
         for i in range(num_heads):
@@ -282,7 +280,7 @@ def __init__(
         layer_config: Dict[str, Any],
         encoder_config: Dict[str, Any],
         use_head_token: bool,
-        use_final_attention: bool,
+        final_norm_type: Optional[str],
     ):
         super().__init__(in_flat_dim, dimensions)
         seq_len = dimensions.num_history
@@ -306,11 +304,10 @@ def __init__(
         layer = TransformerLayer(latent_dim, num_heads, **layer_config)
         encoder_base = TransformerEncoder.get(encoder_type)
         self.encoder = encoder_base(layer, num_layers, dimensions, **encoder_config)
-        self.encoder_norm = nn.LayerNorm(latent_dim)
-        if not use_final_attention:
-            self.final_attn_linear = None
+        if final_norm_type is None:
+            self.final_norm = None
         else:
-            self.final_attn_linear = nn.Linear(latent_dim, 1)
+            self.final_norm = _get_norm(final_norm_type, latent_dim)
 
     @property
     def flatten_ts(self) -> bool:
@@ -318,24 +315,12 @@ def flatten_ts(self) -> bool:
 
     @property
     def out_dim(self) -> int:
-        if self.head_token is None:
-            return self.latent_dim
-        if self.final_attn_linear is None:
-            return self.latent_dim
-        return 2 * self.latent_dim
+        return self.latent_dim
 
     def _aggregate(self, net: Tensor) -> Tensor:
-        last_token = net[..., -1, :]
-        if self.final_attn_linear is None:
-            return last_token
-        if self.head_token is None:
-            no_head_token = net
-        else:
-            no_head_token = net[..., :-1, :]
-        a_hat = self.final_attn_linear(no_head_token)
-        a_prob = F.softmax(a_hat, dim=1)
-        a = torch.sum(a_prob * no_head_token, dim=1)
-        return torch.cat([a, last_token], 1)
+        if self.head_token is not None:
+            return net[:, -1]
+        return net.mean(1)
 
     def forward(self, net: Tensor) -> Tensor:
         # input -> latent
@@ -347,9 +332,11 @@ def forward(self, net: Tensor) -> Tensor:
         # encode latent vector with transformer
         net = self.position_encoding(net)
         net = self.encoder(net, None)
-        net = self.encoder_norm(net)
-        # aggregate
-        return self._aggregate(net)
+        # aggregate & norm
+        net = self._aggregate(net)
+        if self.final_norm is not None:
+            net = self.final_norm(net)
+        return net
 
 
 __all__ = ["Transformer"]
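
Note for reviewers: the snippet below is a minimal, self-contained sketch of the two behavioural changes, not code from the module itself. `build_decay_mask` and `aggregate` are illustrative names, and `nn.LayerNorm` merely stands in for whatever `_get_norm(final_norm_type, latent_dim)` returns. It shows (a) the one-sided decay mask that `DecayedAttention` now builds, where only attention toward earlier positions is penalized, so every token attends evenly to all later ones, and (b) the new aggregation path: the appended head token (last position) when `use_head_token` is set, otherwise a `mean` reduction over the sequence, followed by the optional final norm.

# Minimal sketch of the post-patch behaviour (illustrative only; the helper
# names below are made up and `nn.LayerNorm` stands in for `_get_norm`).
from typing import Optional

import numpy as np
import torch
import torch.nn as nn


def build_decay_mask(seq_len: int) -> np.ndarray:
    # mask[q, k] = (q - k) ** 2 for q > k, i.e. only attention toward *earlier*
    # positions is decayed; the upper triangle stays zero, so any token can
    # attend evenly to every later token.
    mask = np.zeros([seq_len, seq_len], dtype=np.float32)
    for i in range(1, seq_len):
        np.fill_diagonal(mask[i:], i ** 2)
    return mask


def aggregate(
    net: torch.Tensor,
    use_head_token: bool,
    final_norm: Optional[nn.Module] = None,
) -> torch.Tensor:
    # net: [batch, seq_len, latent_dim] encoder output
    if use_head_token:
        net = net[:, -1]   # the appended head token summarizes the sequence
    else:
        net = net.mean(1)  # the new `mean` reduction over the sequence axis
    if final_norm is not None:  # optional `final_norm_type`, e.g. layer_norm
        net = final_norm(net)
    return net


if __name__ == "__main__":
    print(build_decay_mask(4))
    # [[0. 0. 0. 0.]
    #  [1. 0. 0. 0.]
    #  [4. 1. 0. 0.]
    #  [9. 4. 1. 0.]]
    net = torch.randn(2, 4, 32)
    out = aggregate(net, use_head_token=False, final_norm=nn.LayerNorm(32))
    print(out.shape)  # torch.Size([2, 32])

Because aggregation now yields a single latent vector per sample, normalizing once after pooling (via `final_norm_type`) replaces both the per-token `encoder_norm` and the extra `final_attention` pooling, which is why `out_dim` collapses to `latent_dim`.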