Graph Convolutions Enrich the Self-Attention in Transformers!

Jeongwhan Choi¹*, Hyowon Wi²*, Jayoung Kim², Yehjin Shin², Kookjin Lee³, Nathaniel Trask⁴, Noseong Park²
¹Yonsei University, ²KAIST, ³Arizona State University, ⁴University of Pennsylvania

📢 News!

Introduction

  • Graph Filter-based Self-Attention (GFSA) is a novel approach to enhance the self-attention mechanism in Transformers.
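  • GFSA interprets the original self-attention as a simple graph filter and redesigns it from a graph signal processing (GSP) perspective. This addresses the oversmoothing problem of deep Transformers, where representations across layers converge to indistinguishable values, and improves performance across various domains.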
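Why plain self-attention tends to oversmooth can be seen in a toy setting. A softmax attention matrix is row-stochastic (non-negative entries, each row summing to 1), so applying it replaces every token with a weighted average of all tokens, which behaves like a low-pass graph filter. The NumPy sketch below is an illustration under simplifying assumptions, not the paper's analysis: one head, random queries and keys, the same attention matrix reused at every layer, and no residual connections, MLPs, or normalization. It tracks how far apart the token features remain after each round of averaging.

```python
import numpy as np

def softmax(x, axis=-1):
    # Numerically stable softmax along the given axis
    z = x - x.max(axis=axis, keepdims=True)
    e = np.exp(z)
    return e / e.sum(axis=axis, keepdims=True)

rng = np.random.default_rng(0)
n, d = 8, 4                      # toy sizes: number of tokens, feature dimension
Q = rng.normal(size=(n, d))      # queries
Km = rng.normal(size=(n, d))     # keys (named Km to avoid clashing with the order K)
X = rng.normal(size=(n, d))      # initial token features

# Row-stochastic attention matrix: non-negative rows that each sum to 1
A = softmax(Q @ Km.T / np.sqrt(d))

def spread(H):
    # Largest per-feature range across tokens; 0 means all tokens are identical
    return float((H.max(axis=0) - H.min(axis=0)).max())

spreads = []
H = X
for layer in range(50):
    H = A @ H                    # one round of attention-weighted averaging
    spreads.append(spread(H))

print(f"after 1 layer: {spreads[0]:.4f}, after 50 layers: {spreads[-1]:.2e}")
```

Because each round takes convex combinations of the previous features, the spread can never grow, and with strictly positive attention weights it shrinks toward zero. GFSA's additional identity and high-order terms let each head express filters other than this plain averaging.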

Key Features:

  • Integrates easily with existing Transformer models
  • Adds only a small computational overhead over the original self-attention
  • Improves performance on a range of tasks across multiple domains


Tasks and Directories

Detailed instructions are in the README.md of each subdirectory:

  1. 🖼️ Image Classification 👉 ./Image

  2. 📚 Natural Language Understanding 👉 ./NLP

  3. 🧠 Causal Language Modeling 👉 ./NLP

  4. 🌐 Graph Regression 👉 ./Graph

  5. 🎙️ Speech Recognition 👉 ./Speech

  6. 💻 Code Classification 👉 ./Code


Implementation Example with Pseudocode

GFSA's core implementation is shown in the following pseudocode:

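import torch

def GFSA(att, K):
    """
    Graph Filter-based Self-Attention
    
    Args:
        att: original self-attention matrix of shape (batch, h, n, n),
             where h is the number of heads and n the number of tokens
        K: order of high-order term
        
    Notes:
        w_0, w_1 can be set in two ways:
        1) As learnable parameters
        2) Fixed as hyperparameters (w_0=0, w_1=1)
    
    Returns:
        gf_att: GFSA attention matrix of shape (batch, h, n, n)
    """
    _, h, n, _ = att.shape
    
    # Initialize per-head weights (see Notes: these may instead be learnable parameters)
    w_0 = torch.zeros(h, dtype=att.dtype, device=att.device)  # identity term weight
    w_1 = torch.ones(h, dtype=att.dtype, device=att.device)   # first-order term weight
    w_K = torch.zeros(h, dtype=att.dtype, device=att.device)  # high-order term weight
    I = torch.eye(n, dtype=att.dtype, device=att.device)[None, None, ...]
    
    # Compute high-order term using a first-order Taylor approximation of att^K
    att_K = att + (K - 1) * (torch.matmul(att, att) - att)
    
    # Combine terms with per-head weights (broadcast over batch and token dims)
    gf_att = w_0[None, :, None, None] * I + \
             w_1[None, :, None, None] * att + \
             w_K[None, :, None, None] * att_K
             
    return gf_att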

Key Implementation Features

  • Weight Initialization: w_0 and w_1 can be either learnable parameters or fixed hyperparameters
  • High-order Term: Approximates att^K with a first-order Taylor expansion, so only one extra matrix product (att · att) is needed for any K
  • Minimal Parameters: Adds at most three weights per attention head in each layer (w_0, w_1, w_K) on top of the base model
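The high-order term replaces the exact power att^K with a first-order Taylor expansion around K = 1, using att · att - att as a finite-difference stand-in for the derivative. The NumPy sketch below (single head; `taylor_power` and `random_attention` are hypothetical helper names, and the random matrix only stands in for a real softmax attention matrix) checks properties that follow directly from the formula: it is exact for K = 1 and K = 2, every row still sums to 1, and for larger K it only approximates the exact power.

```python
import numpy as np

def taylor_power(att, K):
    # First-order Taylor approximation of att^K around K = 1, as in the GFSA pseudocode:
    # att^K ≈ att + (K - 1) * (att @ att - att)
    return att + (K - 1) * (att @ att - att)

def random_attention(n, rng):
    # Random row-stochastic matrix standing in for a softmax attention matrix
    scores = rng.normal(size=(n, n))
    e = np.exp(scores - scores.max(axis=1, keepdims=True))
    return e / e.sum(axis=1, keepdims=True)

rng = np.random.default_rng(0)
att = random_attention(6, rng)

for K in (1, 2, 3, 5):
    approx = taylor_power(att, K)
    exact = np.linalg.matrix_power(att, K)   # exact K-th power for comparison
    err = np.abs(approx - exact).max()
    rows_ok = np.allclose(approx.sum(axis=1), 1.0)
    print(f"K={K}: max |approx - exact| = {err:.2e}, rows sum to 1: {rows_ok}")
```

Since the rows of att and att · att both sum to 1, the correction term has zero row sums. For K > 2, however, the approximation subtracts a multiple of att, so unlike the exact power it is not guaranteed to be entrywise non-negative.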

Integration Example

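from models.attention import GFSA

# Replace the original self-attention matrix with the GFSA matrix
gf_att = GFSA(
    att=attention_probs,  # original (post-softmax) attention matrix
    K=3                   # order of high-order term
)
attention_output = gf_att @ value_layer  # value_layer: the value tensor, applied as usual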
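GFSA changes only the attention matrix, so in a standard attention layer it sits between the softmax and the multiplication with the values. The NumPy sketch below shows a forward pass only (no training), assuming inputs already split into h heads and per-head weight vectors w_0, w_1, w_K of length h; `gfsa` and `attention_forward` are hypothetical names for illustration, not the repository's `models.attention` API. With the pseudocode's default weights (w_0 = 0, w_1 = 1, w_K = 0) the filtered matrix equals the original one, so the layer starts out computing ordinary self-attention.

```python
import numpy as np

def softmax(x, axis=-1):
    # Numerically stable softmax along the given axis
    z = x - x.max(axis=axis, keepdims=True)
    e = np.exp(z)
    return e / e.sum(axis=axis, keepdims=True)

def gfsa(att, K, w_0, w_1, w_K):
    # att: (batch, h, n, n) post-softmax attention; w_*: per-head weights of shape (h,)
    n = att.shape[-1]
    I = np.eye(n)[None, None, :, :]
    att_K = att + (K - 1) * (att @ att - att)   # Taylor approximation of att^K
    return (w_0[None, :, None, None] * I
            + w_1[None, :, None, None] * att
            + w_K[None, :, None, None] * att_K)

def attention_forward(q, k, v, K=3, weights=None):
    # q, k, v: (batch, h, n, d_head); weights: optional (w_0, w_1, w_K) tuple
    h, d_head = q.shape[1], q.shape[-1]
    att = softmax(q @ np.swapaxes(k, -1, -2) / np.sqrt(d_head))  # original self-attention
    if weights is None:
        weights = (np.zeros(h), np.ones(h), np.zeros(h))         # defaults from the pseudocode
    gf_att = gfsa(att, K, *weights)                               # filtered attention matrix
    return gf_att @ v, att

rng = np.random.default_rng(0)
b, h, n, d_head = 2, 4, 5, 8
q, k, v = (rng.normal(size=(b, h, n, d_head)) for _ in range(3))

# Default weights: GFSA reduces to standard self-attention
out_default, att = attention_forward(q, k, v)

# Hypothetical non-default weights, chosen only for illustration
w_0 = np.full(h, 0.1)
w_1 = np.ones(h)
w_K = np.full(h, -0.2)
out_gf, _ = attention_forward(q, k, v, weights=(w_0, w_1, w_K))
print(out_default.shape, out_gf.shape)
```

The non-default weights are arbitrary illustration values; learning them would require an autodiff framework such as PyTorch, as in the pseudocode above.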

Citation

If you use this code for your research, please cite our paper:

@inproceedings{choi2024gfsa,
   title={Graph Convolutions Enrich the Self-Attention in Transformers!},
   author={Jeongwhan Choi and Hyowon Wi and Jayoung Kim and Yehjin Shin and Kookjin Lee and Nathaniel Trask and Noseong Park},
   booktitle={The Thirty-eighth Annual Conference on Neural Information Processing Systems},
   year={2024},
   url={https://openreview.net/forum?id=ffNrpcBpi6}
}
