brandnewchoppa/gau-tensorflow

GAU - TensorFlow

Gated Attention Unit (TensorFlow implementation) from the paper Transformer Quality in Linear Time.

The authors present a simpler and more efficient architecture than the vanilla Transformer. GAU itself still has quadratic complexity in the context length, but the authors claim it can replace multi-head attention while using only a single head.

The key idea is to formulate attention and GLU (Gated Linear Unit) as a unified layer to share their computation as much as possible. This not only results in higher param/compute efficiency, but also naturally enables a powerful attentive gating mechanism. (Section 2.)
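To make the unified layer concrete, here is a minimal NumPy sketch of a single-head GAU forward pass following the formulation in Section 2 of the paper. It is not this repository's implementation: the function and parameter names (`gau_forward`, `w_u`, `w_v`, `w_z`, `w_o`, `gamma`, `beta`) are hypothetical, and the normalisation, residual connection and relative position bias are left out for brevity.

```python
import numpy as np

def silu(x):
    # SiLU / swish activation: x * sigmoid(x)
    return x / (1.0 + np.exp(-x))

def gau_forward(x, p, causal=False):
    """Single-head GAU on one sequence x of shape (n, d). Illustrative only."""
    n = x.shape[0]
    u = silu(x @ p["w_u"])   # gate branch,          (n, e)
    v = silu(x @ p["w_v"])   # value branch,         (n, e)
    z = silu(x @ p["w_z"])   # shared low-dim base,  (n, s)
    # Queries and keys are cheap per-dimension scale/offset transforms of z.
    q = z * p["gamma"][0] + p["beta"][0]
    k = z * p["gamma"][1] + p["beta"][1]
    # Squared ReLU replaces softmax; scores are scaled by the sequence length.
    attn = np.maximum(q @ k.T / n, 0.0) ** 2
    if causal:
        attn = attn * np.tril(np.ones((n, n)))  # hide future positions
    # The attention output is gated elementwise by u, then projected back to d.
    return (u * (attn @ v)) @ p["w_o"]

rng = np.random.default_rng(0)
d, e, s, n = 8, 16, 4, 6   # model dim, expanded dim, qk dim, sequence length
params = {
    "w_u": rng.normal(scale=0.1, size=(d, e)),
    "w_v": rng.normal(scale=0.1, size=(d, e)),
    "w_z": rng.normal(scale=0.1, size=(d, s)),
    "w_o": rng.normal(scale=0.1, size=(e, d)),
    "gamma": np.ones((2, s)),
    "beta": np.zeros((2, s)),
}
x = rng.normal(size=(n, d))
out = gau_forward(x, params, causal=True)
print(out.shape)  # (6, 8)
```

The gate `u` and the value `v` come from the same input, which is how the attention and the GLU share computation in a single layer.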

Gated Attention Unit

Roadmap

  • GAU module, Transformer Model
  • AutoregressiveWrapper (top_p, top_k)
  • Rotary Embeddings
  • ScaleNorm + FixNorm experiment from the paper
  • Extend inference with a tokenizer and a call(str) method so the model can be called directly on text
  • Implement custom 'pre_train_step' and 'legacy_train_step' for compatibility with model.fit (LegacyModel)

Warning

This repository is under development, so expect changes regularly. Please feel free to explore and share any feedback or suggestions you may have. 🚧

Install

Install through pip

!pip install git+https://github.com/brandnewchoppa/gau-tensorflow.git

Clone into Colab with git

!git clone https://github.com/brandnewchoppa/gau-tensorflow.git
!mv /content/gau-tensorflow/gau_tensorflow .
!rm -rf gau-tensorflow

Usage

import tensorflow as tf
from gau_tensorflow import GAUTransformer

model = GAUTransformer(
    emb_dim = 128,            # embedding dimension
    n_tokens = 50257,         # number of tokens used in the vocabulary
    depth = 4,                # number of blocks stacked in the model
    causal = True,            # autoregressive functionality
    use_rope = False,         # rotary position embeddings
    laplace_attn_fn = False   # laplacian attention function
)

x = tf.random.uniform([1, 512], 0, 50257, tf.int64)
logits = model(x, training = False)

Gated Attention Unit

import tensorflow as tf
from gau_tensorflow import GAU

model = GAU(
    qk_dim = 64,              # query/key dimension
    expansion_factor = 2,     # feed forward multiplier
    causal = True,            # autoregressive functionality
    norm_type = 'layer_norm', # normalisation type (layer_norm, scale_norm, rms_norm)
    shift_tokens = False,     # extra for autoregressive functionality
    use_rope = True,          # rotary position embeddings
    laplace_attn_fn = True    # laplacian attention function
)

x = tf.random.normal([1, 512])
z = model(x, training = False)
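
The `laplace_attn_fn` flag switches the attention nonlinearity. The sketch below assumes it swaps the squared ReLU from the GAU paper for the Laplace function proposed in the Mega paper, using the constants reported there (μ = √(1/2), σ = √(1/(4π))). The helper names are hypothetical and this has not been checked against the repository's own code.

```python
import math

def relu_squared(x):
    # Attention function from the GAU paper: unbounded above.
    return max(x, 0.0) ** 2

def laplace(x, mu=math.sqrt(0.5), sigma=math.sqrt(0.25 / math.pi)):
    # Laplace attention function from the Mega paper: bounded in [0, 1].
    return 0.5 * (1.0 + math.erf((x - mu) / (sigma * math.sqrt(2.0))))

for score in (-1.0, 0.0, 0.5, 1.0, 2.0):
    print(f"{score:5.1f}  relu^2={relu_squared(score):.4f}  laplace={laplace(score):.4f}")
```

Squared ReLU grows without bound for large scores, while the Laplace function saturates at 1, which is the stability argument made in the Mega paper.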

Interpolate Sequence Positions

for i in range(model.depth):
    model.get_layer('blocks').get_layer(f'gau{i}').rotary_pos_embs.interpolate_factor = 2.0
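
As a rough illustration of what an interpolation factor does, the NumPy sketch below assumes that `interpolate_factor` divides the position indices before the rotary angles are computed, as in position interpolation for rotary embeddings. Whether this repository applies it exactly this way is an assumption; `rotary_angles` is a hypothetical helper, not part of the package.

```python
import numpy as np

def rotary_angles(seq_len, dim, interpolate_factor=1.0, base=10000.0):
    # One inverse frequency per pair of feature dimensions.
    inv_freq = 1.0 / (base ** (np.arange(0, dim, 2) / dim))
    # Dividing positions by the factor squeezes a longer sequence
    # into the position range seen during training.
    positions = np.arange(seq_len) / interpolate_factor
    return np.outer(positions, inv_freq)  # shape (seq_len, dim // 2)

trained = rotary_angles(512, 64)                            # positions 0 .. 511
stretched = rotary_angles(1024, 64, interpolate_factor=2.0) # positions 0 .. 511.5

# Every second interpolated position lands exactly on a trained position.
print(np.allclose(stretched[::2], trained))
```

With a factor of 2.0, a sequence of 1024 tokens is mapped onto the 0 to 512 position range, so the rotation angles stay within the range covered by a 512-token context.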

Citations

@article{Hua2022TransformerQI,
    title   = {Transformer Quality in Linear Time},
    author  = {Weizhe Hua and Zihang Dai and Hanxiao Liu and Quoc V. Le},
    journal = {ArXiv},
    year    = {2022},
    volume  = {abs/2202.10447}
}
@article{Toan2019TransformerwT,
    title   = {Transformers without Tears: Improving the Normalization of Self-Attention},
    author  = {Toan Q. Nguyen and Julian Salazar},
    journal = {ArXiv},
    year    = {2019},
    volume  = {abs/1910.05895}
}
@article{Xuezhe2023MEGA,
    title   = {Mega: Moving Average Equipped Gated Attention},
    author  = {Xuezhe Ma and Chunting Zhou and Xiang Kong and Junxian He and Liangke Gui and Graham Neubig and Jonathan May and Luke Zettlemoyer},
    journal = {ArXiv},
    year    = {2023},
    volume  = {abs/2209.10655}
}