[WIP] Fused Adam Triton Kernels #29

Draft · wants to merge 5 commits into base: `master`

160 changes: 160 additions & 0 deletions .gitignore
@@ -0,0 +1,160 @@
# Byte-compiled / optimized / DLL files
__pycache__/
*.py[cod]
*.class

# C extensions
*.so

# Distribution / packaging
.Python
build/
develop-eggs/
dist/
downloads/
eggs/
.eggs/
lib/
lib64/
parts/
sdist/
var/
wheels/
share/python-wheels/
*.egg-info/
.installed.cfg
*.egg
MANIFEST

# PyInstaller
# Usually these files are written by a python script from a template
# before PyInstaller builds the exe, so as to inject date/other infos into it.
*.manifest
*.spec

# Installer logs
pip-log.txt
pip-delete-this-directory.txt

# Unit test / coverage reports
htmlcov/
.tox/
.nox/
.coverage
.coverage.*
.cache
nosetests.xml
coverage.xml
*.cover
*.py,cover
.hypothesis/
.pytest_cache/
cover/

# Translations
*.mo
*.pot

# Django stuff:
*.log
local_settings.py
db.sqlite3
db.sqlite3-journal

# Flask stuff:
instance/
.webassets-cache

# Scrapy stuff:
.scrapy

# Sphinx documentation
docs/_build/

# PyBuilder
.pybuilder/
target/

# Jupyter Notebook
.ipynb_checkpoints

# IPython
profile_default/
ipython_config.py

# pyenv
# For a library or package, you might want to ignore these files since the code is
# intended to run in multiple environments; otherwise, check them in:
# .python-version

# pipenv
# According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control.
# However, in case of collaboration, if having platform-specific dependencies or dependencies
# having no cross-platform support, pipenv may install dependencies that don't work, or not
# install all needed dependencies.
#Pipfile.lock

# poetry
# Similar to Pipfile.lock, it is generally recommended to include poetry.lock in version control.
# This is especially recommended for binary packages to ensure reproducibility, and is more
# commonly ignored for libraries.
# https://python-poetry.org/docs/basic-usage/#commit-your-poetrylock-file-to-version-control
#poetry.lock

# pdm
# Similar to Pipfile.lock, it is generally recommended to include pdm.lock in version control.
#pdm.lock
# pdm stores project-wide configurations in .pdm.toml, but it is recommended to not include it
# in version control.
# https://pdm.fming.dev/#use-with-ide
.pdm.toml

# PEP 582; used by e.g. github.com/David-OConnor/pyflow and github.com/pdm-project/pdm
__pypackages__/

# Celery stuff
celerybeat-schedule
celerybeat.pid

# SageMath parsed files
*.sage.py

# Environments
.env
.venv
env/
venv/
ENV/
env.bak/
venv.bak/

# Spyder project settings
.spyderproject
.spyproject

# Rope project settings
.ropeproject

# mkdocs documentation
/site

# mypy
.mypy_cache/
.dmypy.json
dmypy.json

# Pyre type checker
.pyre/

# pytype static type analyzer
.pytype/

# Cython debug symbols
cython_debug/

# PyCharm
# JetBrains specific template is maintained in a separate JetBrains.gitignore that can
# be found at https://github.com/github/gitignore/blob/main/Global/JetBrains.gitignore
# and can be added to the global gitignore or merged into this file. For a more nuclear
# option (not recommended) you can uncomment the following to ignore the entire idea folder.
#.idea/
211 changes: 211 additions & 0 deletions galore_torch/fused/NOTE.md
@@ -0,0 +1,211 @@
## Fused GaLore Adam (WIP)

### Various fused implementations of the `Adam` update step per [Gradient Low-Rank Projection](https://arxiv.org/abs/2403.03507)

This is an initial attempt at optimizing the update step of the `GaLore Adam` optimizer.

#### Overview

The `GaLore` `Adam` optimizer introduces additional ops to the traditional `adam` update step.

Specifically (sketched in plain `torch` after the list):

1. `grad` is projected to low rank --> additional matmul
2. `adam` states are updated with `grad` elementwise (same as `Adam` except in low-rank)
3. normalized `grad` is projected to full rank --> additional matmul
4. `params` are updated with the normalized full rank grad
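
For reference, here is a minimal plain-`torch` sketch of these four steps, assuming a `[rank, N]` projection matrix as in the accuracy logs below; Adam bias correction and the `GaLore` scale factor are omitted, and this is not the fused kernel code in this PR:

```python
import torch

def galore_adam_step_reference(param, grad, proj, exp_avg, exp_avg_sq,
                               lr=1e-3, beta1=0.9, beta2=0.999, eps=1e-8):
    # Step 1: project grad to low rank --> additional matmul
    low_rank_grad = grad @ proj.T                  # (M x N) x (N x rank) = (M x rank)
    # Step 2: elementwise adam state updates, in low-rank space
    exp_avg.mul_(beta1).add_(low_rank_grad, alpha=1 - beta1)
    exp_avg_sq.mul_(beta2).addcmul_(low_rank_grad, low_rank_grad, value=1 - beta2)
    norm_grad = exp_avg / (exp_avg_sq.sqrt() + eps)
    # Step 3: project normalized grad back to full rank --> additional matmul
    full_rank_norm_grad = norm_grad @ proj         # (M x rank) x (rank x N) = (M x N)
    # Step 4: apply the update to params
    param.add_(full_rank_norm_grad, alpha=-lr)
    return param
```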

#### Implementation

Various fusions were attempted across 2 kernel implementations:

- `Fused`
  - Steps 1 & 2 are fused: the `adam` state updates are loaded and updated (inplace) during the first `matmul`
  - Steps 3 & 4 are fused: the param update is folded as an epilogue into the second `matmul`
- `Hybrid`
  - Step 1 is performed using standard `torch matmul` (i.e., `cuBlas`)
  - Step 2 is fused as an elementwise kernel (an illustrative sketch follows this list)
  - Steps 3 & 4 per `Fused`
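
For illustration, here is a minimal standalone sketch of the kind of elementwise `adam`-state kernel the `hybrid` variant fuses for Step 2. This is not the PR's kernel; the name, signature, and launch scheme are assumptions:

```python
import triton
import triton.language as tl

@triton.jit
def adam_state_update_kernel(grad_ptr, exp_avg_ptr, exp_avg_sq_ptr, norm_grad_ptr,
                             n_elements, beta1, beta2, eps,
                             BLOCK_SIZE: tl.constexpr):
    pid = tl.program_id(axis=0)
    offsets = pid * BLOCK_SIZE + tl.arange(0, BLOCK_SIZE)
    mask = offsets < n_elements

    g = tl.load(grad_ptr + offsets, mask=mask)
    m = tl.load(exp_avg_ptr + offsets, mask=mask)
    v = tl.load(exp_avg_sq_ptr + offsets, mask=mask)

    # Step 2: elementwise adam moment updates on the low-rank grad
    m = beta1 * m + (1.0 - beta1) * g
    v = beta2 * v + (1.0 - beta2) * g * g
    norm_g = m / (tl.sqrt(v) + eps)

    # write the updated states back inplace and emit the normalized grad for Step 3
    tl.store(exp_avg_ptr + offsets, m, mask=mask)
    tl.store(exp_avg_sq_ptr + offsets, v, mask=mask)
    tl.store(norm_grad_ptr + offsets, norm_g, mask=mask)

# launched over a 1D grid, e.g. grid = (triton.cdiv(n_elements, BLOCK_SIZE),)
```

In the `fused` variant, the same state update is instead performed inside the first `matmul`'s main loop rather than as a separate kernel.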

#### Performance

Below are benchmarks for various kernels:

- `torch` - reference `torch` implementation where each of the steps is implemented verbatim per above
- `hybrid` - see above
- `fused` - see above
- `compiled` - `torch` reference implementation compiled using `torch.compile` with `fullgraph=True` and `mode="max-autotune"`.

Configs for each benchmark are the `grad (param)` shape, the `dtype` of the `grad` and `adam` states, and `allow_tf32`, which controls whether `torch` and `triton` matmuls are allowed to use `TF32` tensor cores (see `Discussion`).

`Grad shape`: `4096x4096`, `dtype`: `torch.float32`, `allow_tf32`: `False`

```
Median times (ms):
rank torch hybrid fused compiled
0 32.0 0.560128 0.347136 0.505856 0.534528
1 64.0 0.627712 0.404480 0.600960 0.615424
2 128.0 0.825232 0.583168 0.985072 0.833536
3 256.0 1.378304 1.126400 1.489920 1.375232
4 512.0 2.286080 2.101760 2.969600 2.302976
```

`Grad shape`: `4096x4096`, `dtype`: `torch.float32`, `allow_tf32`: `True`

```
Median times (ms):
rank torch hybrid fused compiled
0 32.0 0.540672 0.321536 0.316416 0.508928
1 64.0 0.612240 0.337728 0.345024 0.538624
2 128.0 0.640000 0.395264 0.393216 0.693248
3 256.0 0.777216 0.489472 0.548784 1.102848
4 512.0 1.216512 0.864256 0.960512 1.968128
```

`Grad shape`: `4096x11008`, `dtype`: `torch.float32`, `allow_tf32`: `False`

```
Median times (ms):
rank torch hybrid fused compiled
0 32.0 1.538672 0.915456 0.835584 1.364032
1 64.0 1.546240 0.940032 1.022976 1.486848
2 128.0 2.116608 1.498112 1.613312 2.098176
3 256.0 3.423744 2.719744 2.881536 3.227136
4 512.0 5.499904 5.036544 5.450752 5.508096
```

`Grad shape`: `4096x11008`, `dtype`: `torch.float32`, `allow_tf32`: `True`

```
Median times (ms):
rank torch hybrid fused compiled
0 32.0 1.413120 0.871424 0.817152 1.353184
1 64.0 1.489920 0.916480 0.854016 1.389568
2 128.0 1.679360 0.996352 1.005568 1.563648
3 256.0 2.152448 1.415168 1.470464 2.185216
4 512.0 3.210240 2.460672 2.580480 3.477504
```

##### Accuracy

Comparison to reference `torch` implementation:

```
Running with 4096 x 4096 grad (param) shape, GaLore orthogonal matrix [128, 4096], dtype torch.float32, and allow_tf32 True
Kernel: hybrid
Accuracy:
-> adam state - running grad mean:
Max err: 0.000000 Relative err: 0.000001
-> adam state - running grad var:
Max err: 0.000002 Relative err: 0.000002
-> params (after update):
Max err: 0.000000 Relative err: 0.000001
```

```
Running with 4096 x 4096 grad (param) shape, GaLore orthogonal matrix [128, 4096], dtype torch.float32 and allow_tf32 False
Kernel: hybrid
Accuracy:
-> adam state - running grad mean:
Max err: 0.000000 Relative err: 0.000000
-> adam state - running grad var:
Max err: 0.000002 Relative err: 0.000002
-> params (after update):
Max err: 0.000000 Relative err: 0.000000
```

```
Running with 4096 x 4096 grad (param) shape, GaLore orthogonal matrix [128, 4096], dtype torch.float32 and allow_tf32 True
Kernel: fused
Accuracy:
-> adam state - running grad mean:
Max err: 0.000845 Relative err: 0.001152
-> adam state - running grad var:
Max err: 0.000162 Relative err: 0.000161
-> params (after update):
Max err: 0.000000 Relative err: 0.000001
```

```
Running with 4096 x 4096 grad (param) shape, GaLore orthogonal matrix [128, 4096], dtype torch.float32 and allow_tf32 False
Kernel: fused
Accuracy:
-> adam state - running grad mean:
Max err: 0.000003 Relative err: 0.000004
-> adam state - running grad var:
Max err: 0.000002 Relative err: 0.000002
-> params (after update):
Max err: 0.000000 Relative err: 0.000000
```
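
A hypothetical sketch of how such a comparison against the `torch` reference could be computed (the PR's `tests/test_fused_kernels.py` may differ):

```python
import torch

def report_errors(ref: torch.Tensor, out: torch.Tensor, eps: float = 1e-12):
    # max absolute error and a relative error against the reference tensor
    abs_err = (ref - out).abs()
    max_err = abs_err.max().item()
    rel_err = (abs_err / (ref.abs() + eps)).max().item()
    print(f"Max err: {max_err:.6f} Relative err: {rel_err:.6f}")
```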

#### Discussion

##### Down Projection GEMM Shape

The `hybrid` approach is motivated by the unconventional matrix shapes of the down projection (Step 1):

- The projection is always done such that the larger dimension of the `grad` matrix is preserved while the other is projected to low rank, per the `GaLore` algorithm
- E.g., if `M >= N`, the GEMM is of shape (`M x N`) x (`N x rank`) = (`M x rank`); otherwise it is (`rank x M`) x (`M x N`) = (`rank x N`) (see the worked example after this list)
- Since `{M, N} >> rank` by definition, this results in a large reduction dimension relative to one of the output dimensions (the output matrix is either fat or skinny)
- This does not fit cleanly into the `split-k / parallel reduction` `GEMM` paradigm, which is better suited to shapes where both output dims are smaller than the reduction dimension
- Consequently, I had trouble finding an optimal kernel config with the `triton` autotuner for the down projection step, despite tuning across many compute- and io-bound configs (see `fused.triton_utils.kernels.matmul.py`)
- Benchmarking the `triton`-tuned `matmul` against the default `torch.matmul` for these shapes showed worse `triton` performance for `torch.float32`
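
A worked example of these shapes for the benchmarked configs (random matrices purely to show the shapes, not the actual `GaLore` projection factors):

```python
import torch

rank = 128
for M, N in [(4096, 4096), (4096, 11008)]:
    grad = torch.randn(M, N)
    if M >= N:
        proj = torch.randn(N, rank)            # keep the larger dim M
        low_rank = grad @ proj                 # (M x N) x (N x rank) = (M x rank)
    else:
        proj = torch.randn(rank, M)            # keep the larger dim N
        low_rank = proj @ grad                 # (rank x M) x (M x N) = (rank x N)
    # the reduction dim is 4096 in both cases vs. an output dim of only `rank`
    print(f"{M} x {N} -> {tuple(low_rank.shape)}")
```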

##### Effect of `TF32` tensor cores

`allow_tf32` has a significant impact on the relative performance of `triton` vs `torch` matmuls (the relevant switches are sketched below):

- Quick benchmarks of the down-projection `matmul` show that:
  - with `allow_tf32=True` for both, `triton` exhibits a `~1.30x` performance improvement over `torch`.
  - with `allow_tf32=False`, `triton` performance degrades significantly, to `~0.50x` of `torch`.

See this [`torch note`](https://pytorch.org/docs/stable/notes/cuda.html#tf32-on-ampere) for more details on this feature.
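
For reference, these are the standard switches involved (a sketch; the benchmark script's exact wiring may differ):

```python
import torch

# torch matmuls: allow (or disallow) TF32 tensor cores
torch.backends.cuda.matmul.allow_tf32 = True

# triton controls TF32 per-dot inside the kernel (older releases expose an
# `allow_tf32` flag on `tl.dot`, newer ones an `input_precision` argument),
# so the kernels need to be passed the same setting.
```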

**Note**: This might be less of a concern given this upcoming triton [PR](https://github.com/openai/triton/pull/3234), which implements a fast `TF32` trick that improves both performance and accuracy.

#### Repro

`tests/test_fused_kernels.py` is a `CLI` with two modes: one for testing kernel accuracy and the other for benchmarking across a number of configs.

**Examples**

_Accuracy_

- Test accuracy of `torch` vs `hybrid` for `M=4096`, `N=4096`, `rank=128`, with `tf32` switched on:

```sh
python tests/test_fused_kernels.py --mode=test --kernel=hybrid --M=4096 --N=4096 --rank=128 --allow_tf32
```

_Benchmark_

- Benchmark across all kernels without `tf32`:

```sh
python tests/test_fused_kernels.py --mode=benchmark
```

_Additional options_

```sh
python tests/test_fused_kernels.py --help
```

_Note:_ Passing the additional flag `--verbose` will show `triton` autotuning logs -- I customized the `triton` autotuner to print the configs it tries and other details.

#### Test Env

- GPU Device Props:
- Name: `NVIDIA RTX A6000`
- CC: `86`
- Total_memory: `48676MB`
- SM count: `84`
- Torch: `2.3.0.dev20240310+cu118`
- Triton: `3.0.0`

#### Next Steps

- [ ] Implement `FusedGaLoreOptimizer`
- [ ] `Cutlass` - given fixed GEMM shape, experiment with `Cutlass` GEMMs (`split-k`, `stream-k`, fast `tensorops`). Interestingly, profiling `torch.matmul` for down projection shows that `cuBlas` dispatches to a `Cutlass` kernel of shape `128x128x16`.
- [ ] Repeat with `AdamW8bit`
- [ ] More detailed analysis of `torch.compile` performance
Empty file added galore_torch/fused/__init__.py
Empty file.