
Neighborhood Attention Extension

Bringing attention to a neighborhood near you!

NATTEN is an extension to PyTorch that provides the first fast implementation of sliding window attention, backed by efficient CUDA kernels. It offers Neighborhood Attention (local attention) and Dilated Neighborhood Attention (sparse global attention, a.k.a. dilated local attention) as PyTorch modules for 1D, 2D, and 3D data.

About NATTEN

Sliding window self-attention mechanisms have been relatively overlooked, in part due to implementation difficulties. For example, the paper proposing one of the earliest such methods, SASA, noted that although these methods are theoretically efficient, they are relatively slow in practice compared to convolutions, which are implemented in virtually every well-known deep learning library.

That is why we started developing NATTEN: an extension to existing libraries with efficient implementations of sliding window attention mechanisms, which will enable research in this direction, including the development of powerful hierarchical vision transformers.

For more information, we highly recommend reading our preprints NAT and DiNAT, and checking out their repository.

How fast is NATTEN?

The latest version of NATTEN runs fast on Ampere GPUs with recent torch and CUDA versions; see the time and memory comparisons below.

[Time plot] [Memory plot]

Requirements

NATTEN supports PyTorch version 1.8 and later, and Python versions 3.7, 3.8, 3.9, 3.10 (only with torch >= 1.11), and 3.11 (only with torch >= 1.13).

NOTE: The current version of NATTEN comes with Linux-only wheels, and supports Pascal and above (SM >= 60, i.e. Tesla P100). Make sure your GPU is supported by referring to this webpage. Future versions will extend support to older GPUs.
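If you are unsure of your GPU's compute capability, a quick check through PyTorch (not part of NATTEN's API; just a sketch using standard torch calls) is:

import torch

if torch.cuda.is_available():
    # Compute capability as (major, minor); SM >= 60 (Pascal or newer) is required.
    major, minor = torch.cuda.get_device_capability()
    print(f"SM{major}{minor}:", "supported" if major >= 6 else "not supported by current wheels")
else:
    print("No CUDA device detected.")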

You may try to build from source on Windows, but do so at your own risk. We also welcome contributions in all forms.

Getting started

Linux

Just refer to our website, shi-labs.com/natten, select your PyTorch version and the CUDA version it was compiled with, copy-paste the command and install in seconds!

For example, if you're on torch==2.0.0+cu118, you should install NATTEN using the following wheel:

pip3 install natten -f https://shi-labs.com/natten/wheels/cu118/torch2.0.0/index.html

More generally:

pip3 install natten -f https://shi-labs.com/natten/wheels/{cu_version}/torch{torch_version}/index.html

NOTE: If you do not specify a wheel URL, pip will collect NATTEN and try to compile it locally, which, depending on your system, might take up to 30 minutes. We strongly recommend using our website if you're a Linux user.
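Either way, a quick post-install check is to import the package (a minimal sketch; the __version__ attribute is assumed to be present, as with most packages, and a successful import alone already confirms the install):

import natten
print(natten.__version__)  # assumed attribute; the import succeeding is the real check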

Mac

Unfortunately we are not yet able to build Mac wheels, but NATTEN will compile upon install, so just run:

pip3 install natten

Windows

NATTEN should support Windows devices with CUDA, but does not yet have Windows wheels. You can try to build NATTEN from source (see below).

Build from source

Once you've set up your Python environment and installed PyTorch with CUDA, simply clone and build:

pip install ninja # Recommended, not required
git clone https://github.com/SHI-Labs/NATTEN
cd NATTEN
make

Optional: run unit tests

You can optionally run unit tests to verify that the build from source completed successfully:

make test

Catalog

  • Neighborhood Attention 1D (CUDA)
  • Neighborhood Attention 2D (CUDA)
  • Neighborhood Attention 3D (CUDA)
  • Neighborhood Attention 1D (CPU)
  • Neighborhood Attention 2D (CPU)
  • Neighborhood Attention 3D (CPU)
  • Dilation support
  • Float16 support and utilization
  • BFloat16 support
  • Kepler and Maxwell (30<=SM<60) support
  • Windows builds

Usage

Simply import NeighborhoodAttention1D, NeighborhoodAttention2D, or NeighborhoodAttention3D from natten:

from natten import NeighborhoodAttention1D
from natten import NeighborhoodAttention2D
from natten import NeighborhoodAttention3D

na1d = NeighborhoodAttention1D(dim=128, kernel_size=7, dilation=2, num_heads=4)
na2d = NeighborhoodAttention2D(dim=128, kernel_size=7, dilation=2, num_heads=4)
na3d = NeighborhoodAttention3D(dim=128, kernel_size=7, dilation=2, num_heads=4)

NA3D also supports different kernel size and dilation values for depth:

na3d = NeighborhoodAttention3D(
	dim=128,
	kernel_size=7,
	kernel_size_d=5,
	dilation=2,
	dilation_d=3,
	num_heads=4)

Modules expect inputs of shape [batch_size, *, dim], as shown in the example after this list:

  • NA1D: [batch_size, sequence_length, dim]
  • NA2D: [batch_size, height, width, dim]
  • NA3D: [batch_size, depth, height, width, dim]
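As a minimal sketch of a forward pass (the spatial sizes below are arbitrary, chosen only to be larger than the dilated window), the modules preserve the shape of their inputs:

import torch
from natten import NeighborhoodAttention1D, NeighborhoodAttention2D

na1d = NeighborhoodAttention1D(dim=128, kernel_size=7, dilation=2, num_heads=4)
na2d = NeighborhoodAttention2D(dim=128, kernel_size=7, dilation=2, num_heads=4)

x1 = torch.randn(2, 64, 128)      # [batch_size, sequence_length, dim]
x2 = torch.randn(2, 32, 32, 128)  # [batch_size, height, width, dim]

print(na1d(x1).shape)  # torch.Size([2, 64, 128])
print(na2d(x2).shape)  # torch.Size([2, 32, 32, 128])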

FLOPs

We recommend counting FLOPs with fvcore.

pip install fvcore

Once you have fvcore installed, you can directly use our dedicated FLOP counter:

from natten.flops import get_flops

flops = get_flops(model, input)
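For instance, with the na2d module from the usage section above (a sketch; the input shape is arbitrary):

import torch
from natten import NeighborhoodAttention2D
from natten.flops import get_flops

model = NeighborhoodAttention2D(dim=128, kernel_size=7, dilation=2, num_heads=4)
x = torch.randn(1, 32, 32, 128)  # [batch_size, height, width, dim]

print(get_flops(model, x))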

Alternatively, if you are using fvcore's FlopCountAnalysis directly, be sure to add our op handles:

from fvcore.nn import FlopCountAnalysis
from natten.flops import add_natten_handle

# ...

flop_ctr = FlopCountAnalysis(model, input)
flop_ctr = add_natten_handle(flop_ctr)

# ...
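With the handle added, the counter behaves like any other fvcore FlopCountAnalysis object; for example:

total_flops = flop_ctr.total()          # total FLOPs, now including NATTEN ops
flops_by_module = flop_ctr.by_module()  # per-module breakdown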

License

NATTEN is released under the MIT License.

Citation

@inproceedings{hassani2023neighborhood,
	title        = {Neighborhood Attention Transformer},
	author       = {Ali Hassani and Steven Walton and Jiachen Li and Shen Li and Humphrey Shi},
	year         = 2023,
	booktitle    = {IEEE/CVF Conference on Computer Vision and Pattern Recognition (CVPR)}
}
@article{hassani2022dilated,
	title        = {Dilated Neighborhood Attention Transformer},
	author       = {Ali Hassani and Humphrey Shi},
	year         = 2022,
	url          = {https://arxiv.org/abs/2209.15001},
	eprint       = {2209.15001},
	archiveprefix = {arXiv},
	primaryclass = {cs.CV}
}
