diff --git a/.gitignore b/.gitignore new file mode 100644 index 0000000..6ea018a --- /dev/null +++ b/.gitignore @@ -0,0 +1,160 @@ +# Byte-compiled / optimized / DLL files +__pycache__/ +*.py[cod] +*.class + +# C extensions +*.so + +# Distribution / packaging +.Python +build/ +develop-eggs/ +dist/ +downloads/ +eggs/ +.eggs/ +lib/ +lib64/ +parts/ +sdist/ +var/ +wheels/ +share/python-wheels/ +*.egg-info/ +.installed.cfg +*.egg +MANIFEST + +# PyInstaller +# Usually these files are written by a python script from a template +# before PyInstaller builds the exe, so as to inject date/other infos into it. +*.manifest +*.spec + +# Installer logs +pip-log.txt +pip-delete-this-directory.txt + +# Unit test / coverage reports +htmlcov/ +.tox/ +.nox/ +.coverage +.coverage.* +.cache +nosetests.xml +coverage.xml +*.cover +*.py,cover +.hypothesis/ +.pytest_cache/ +cover/ + +# Translations +*.mo +*.pot + +# Django stuff: +*.log +local_settings.py +db.sqlite3 +db.sqlite3-journal + +# Flask stuff: +instance/ +.webassets-cache + +# Scrapy stuff: +.scrapy + +# Sphinx documentation +docs/_build/ + +# PyBuilder +.pybuilder/ +target/ + +# Jupyter Notebook +.ipynb_checkpoints + +# IPython +profile_default/ +ipython_config.py + +# pyenv +# For a library or package, you might want to ignore these files since the code is +# intended to run in multiple environments; otherwise, check them in: +# .python-version + +# pipenv +# According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control. +# However, in case of collaboration, if having platform-specific dependencies or dependencies +# having no cross-platform support, pipenv may install dependencies that don't work, or not +# install all needed dependencies. +#Pipfile.lock + +# poetry +# Similar to Pipfile.lock, it is generally recommended to include poetry.lock in version control. +# This is especially recommended for binary packages to ensure reproducibility, and is more +# commonly ignored for libraries. +# https://python-poetry.org/docs/basic-usage/#commit-your-poetrylock-file-to-version-control +#poetry.lock + +# pdm +# Similar to Pipfile.lock, it is generally recommended to include pdm.lock in version control. +#pdm.lock +# pdm stores project-wide configurations in .pdm.toml, but it is recommended to not include it +# in version control. +# https://pdm.fming.dev/#use-with-ide +.pdm.toml + +# PEP 582; used by e.g. github.com/David-OConnor/pyflow and github.com/pdm-project/pdm +__pypackages__/ + +# Celery stuff +celerybeat-schedule +celerybeat.pid + +# SageMath parsed files +*.sage.py + +# Environments +.env +.venv +env/ +venv/ +ENV/ +env.bak/ +venv.bak/ + +# Spyder project settings +.spyderproject +.spyproject + +# Rope project settings +.ropeproject + +# mkdocs documentation +/site + +# mypy +.mypy_cache/ +.dmypy.json +dmypy.json + +# Pyre type checker +.pyre/ + +# pytype static type analyzer +.pytype/ + +# Cython debug symbols +cython_debug/ + +# PyCharm +# JetBrains specific template is maintained in a separate JetBrains.gitignore that can +# be found at https://github.com/github/gitignore/blob/main/Global/JetBrains.gitignore +# and can be added to the global gitignore or merged into this file. For a more nuclear +# option (not recommended) you can uncomment the following to ignore the entire idea folder. 
+#.idea/
diff --git a/galore_torch/fused/NOTE.md b/galore_torch/fused/NOTE.md
new file mode 100644
index 0000000..c996b7b
--- /dev/null
+++ b/galore_torch/fused/NOTE.md
@@ -0,0 +1,211 @@
+## Fused GaLore Adam (WIP)
+
+### Various fused implementations of the `Adam` update step per [Gradient Low-Rank Projection](https://arxiv.org/abs/2403.03507)
+
+This is an initial attempt at optimizing the update step of the `GaLore Adam` optimizer.
+
+#### Overview
+
+The `GaLore` `Adam` optimizer introduces additional ops to the traditional `adam` update step.
+
+Specifically:
+
+1. `grad` is projected to low rank --> additional matmul
+2. `adam` states are updated with `grad` elementwise (same as `Adam`, except in low rank)
+3. normalized `grad` is projected to full rank --> additional matmul
+4. `params` are updated with the normalized full-rank grad
+
+#### Implementation
+
+Various fusions were attempted across 2 kernel implementations:
+
+- `Fused`
+  - Steps 1 & 2 are fused: the `adam` states are loaded and updated (in place) during the first `matmul`
+  - Steps 3 & 4 are fused: the param update is folded as an epilogue into the second `matmul`
+- `Hybrid`
+  - Step 1 is performed using standard `torch matmul` (i.e., `cuBLAS`)
+  - Step 2 is fused as an elementwise kernel
+  - Steps 3 & 4 per `Fused`
+
+#### Performance
+
+Below are benchmarks for various kernels:
+
+- `torch` - reference `torch` implementation where each of the steps is implemented verbatim per above
+- `hybrid` - see above
+- `fused` - see above
+- `compiled` - `torch` reference implementation compiled using `torch.compile` with `fullgraph=True` and `mode="max-autotune"`.
+
+Configs for each benchmark are the `grad (param)` shape, the `dtype` of `grad` and the `adam` states, and `allow_tf32`, i.e., whether `torch` and `triton` matmuls are allowed to use `TF32` tensor cores (see `Discussion`).
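+
+For reference, the `torch` baseline implements the four steps in the `Overview` roughly as follows -- a condensed sketch of `_ref_op` in `tests/test_utils.py` (the function name here is illustrative):
+
+```python
+import torch
+
+
+def ref_galore_adam_step(
+    grad, proj_matrix, exp_avg, exp_avg2, params,
+    beta1=0.9, beta2=0.999, eps=1e-8, step_size=1e-4,
+):
+    M, N = grad.shape
+    # Step 1: down-project grad, keeping the larger dimension (per GaLore)
+    a, b = (grad, proj_matrix.t()) if M >= N else (proj_matrix.t(), grad)
+    low_rank_grad = a @ b
+    # Step 2: elementwise Adam state update, in low rank
+    exp_avg.mul_(beta1).add_(low_rank_grad, alpha=1.0 - beta1)
+    exp_avg2.mul_(beta2).addcmul_(low_rank_grad, low_rank_grad, value=1.0 - beta2)
+    norm_grad = exp_avg / (exp_avg2.sqrt() + eps)
+    # Step 3: project the normalized grad back to full rank
+    a, b = (norm_grad, proj_matrix) if M >= N else (proj_matrix, norm_grad)
+    # Step 4: parameter update with the full-rank normalized grad
+    params.add_(a @ b, alpha=-step_size)
+    return exp_avg, exp_avg2, params
+```
+
+For the `torch` matmuls, `allow_tf32` is toggled globally via `torch.backends.cuda.matmul.allow_tf32`; the `triton` kernels take `allow_tf32` as a kernel argument.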
+ +`Grad shape`: `4096x4096`, `dtype`: `torch.float32`, `allow_tf32`: `False` + +``` +Median times (ms): + rank torch hybrid fused compiled +0 32.0 0.560128 0.347136 0.505856 0.534528 +1 64.0 0.627712 0.404480 0.600960 0.615424 +2 128.0 0.825232 0.583168 0.985072 0.833536 +3 256.0 1.378304 1.126400 1.489920 1.375232 +4 512.0 2.286080 2.101760 2.969600 2.302976 +``` + +`Grad shape`: `4096x4096`, `dtype`: `torch.float32`, `allow_tf32`: `True` + +``` +Median times (ms): + rank torch hybrid fused compiled +0 32.0 0.540672 0.321536 0.316416 0.508928 +1 64.0 0.612240 0.337728 0.345024 0.538624 +2 128.0 0.640000 0.395264 0.393216 0.693248 +3 256.0 0.777216 0.489472 0.548784 1.102848 +4 512.0 1.216512 0.864256 0.960512 1.968128 +``` + +`Grad shape`: `4096x11008`, `dtype`: `torch.float32`, `allow_tf32`: `False` + +``` +Median times (ms): + rank torch hybrid fused compiled +0 32.0 1.538672 0.915456 0.835584 1.364032 +1 64.0 1.546240 0.940032 1.022976 1.486848 +2 128.0 2.116608 1.498112 1.613312 2.098176 +3 256.0 3.423744 2.719744 2.881536 3.227136 +4 512.0 5.499904 5.036544 5.450752 5.508096 +``` + +`Grad shape`: `4096x11008`, `dtype`: `torch.float32`, `allow_tf32`: `True` + +``` +Median times (ms): + rank torch hybrid fused compiled +0 32.0 1.413120 0.871424 0.817152 1.353184 +1 64.0 1.489920 0.916480 0.854016 1.389568 +2 128.0 1.679360 0.996352 1.005568 1.563648 +3 256.0 2.152448 1.415168 1.470464 2.185216 +4 512.0 3.210240 2.460672 2.580480 3.477504 +``` + +##### Accuracy + +Comparison to reference `torch` implementation: + +``` +Running with 4096 x 4096 grad (param) shape, GaLore orthogonal matrix [128, 4096], dtype torch.float32, and allow_tf32 True +Kernel: hybrid +Accuracy: +-> adam state - running grad mean: + Max err: 0.000000 Relative err: 0.000001 +-> adam state - running grad var: + Max err: 0.000002 Relative err: 0.000002 +-> params (after update): + Max err: 0.000000 Relative err: 0.000001 +``` + +``` +Running with 4096 x 4096 grad (param) shape, GaLore orthogonal matrix [128, 4096], dtype torch.float32 and allow_tf32 False +Kernel: hybrid +Accuracy: +-> adam state - running grad mean: + Max err: 0.000000 Relative err: 0.000000 +-> adam state - running grad var: + Max err: 0.000002 Relative err: 0.000002 +-> params (after update): + Max err: 0.000000 Relative err: 0.000000 +``` + +``` +Running with 4096 x 4096 grad (param) shape, GaLore orthogonal matrix [128, 4096], dtype torch.float32 and allow_tf32 True +Kernel: fused +Accuracy: +-> adam state - running grad mean: + Max err: 0.000845 Relative err: 0.001152 +-> adam state - running grad var: + Max err: 0.000162 Relative err: 0.000161 +-> params (after update): + Max err: 0.000000 Relative err: 0.000001 +``` + +``` +Running with 4096 x 4096 grad (param) shape, GaLore orthogonal matrix [128, 4096], dtype torch.float32 and allow_tf32 False +Kernel: fused +Accuracy: +-> adam state - running grad mean: +Max err: 0.000003 Relative err: 0.000004 +-> adam state - running grad var: +Max err: 0.000002 Relative err: 0.000002 +-> params (after update): +Max err: 0.000000 Relative err: 0.000000 +``` + +#### Discussion + +##### Down Projection GEMM Shape + +The motivation for the `hybrid` approach is the unconventional matrix shapes of the down projection (Step 1): + +- The projection is always done such that the larger dimension of the `grad` matrix is maintained while other is projected to low rank per the `GaLore` algorithm + - E.g., if `M >= N`, the GEMM is of shape (`M x N`) x (`N x rank`) = (`M x rank`), (`rank x M`) x (`M x N`) = (`rank x 
N`) otherwise
+- Since `{M, N} >> rank` by definition, this results in a large reduction dimension relative to one of the output dimensions (the output matrix is either fat or skinny).
+- This does not fit cleanly into the `split-k / parallel reduction` `GEMM` paradigm, which is more tailored for shapes where both output dims are smaller than the reduction dimension.
+- Consequently, I had trouble finding an optimal kernel config using the `triton` `autotuner` for the down projection step, despite tuning across many compute- and io-bound configs (see `galore_torch/fused/triton_utils/kernels/matmul.py`).
+- Benchmarking `triton`-tuned `matmul` against default `torch.matmul` for these shapes showed worse performance for `torch.float32`.
+
+##### Effect of `TF32` tensor cores
+
+`allow_tf32` has a significant impact on the relative performance of `triton` vs `torch` matmuls:
+
+- Quick benchmarks of the down-projection `matmul` show that:
+  - with `allow_tf32=True` for both, `triton` exhibits a `~1.30x` performance improvement over `torch`.
+  - with `allow_tf32=False`, performance of `triton` degrades significantly to `~.50x` of `torch`.
+
+See this [`torch note`](https://pytorch.org/docs/stable/notes/cuda.html#tf32-on-ampere) for more details on this feature.
+
+**Note**: This might be less of a concern given this upcoming `triton` [PR](https://github.com/openai/triton/pull/3234), which implements a fast `TF32` trick that improves both performance and accuracy.
+
+#### Repro
+
+`tests/test_fused_kernels.py` is a CLI with 2 modes: one for testing kernel accuracy, the other for benchmarking across a number of configs.
+
+**Examples**
+
+_Accuracy_
+
+- Test accuracy of `torch` vs `hybrid` for `M=4096`, `N=4096`, `rank=128`, and `tf32` switched on:
+
+  ```sh
+  python tests/test_fused_kernels.py --mode=test --kernel=hybrid --M=4096 --N=4096 --rank=128 --allow_tf32
+  ```
+
+_Benchmark_
+
+- Benchmark across all kernels without `tf32`:
+
+  ```sh
+  python tests/test_fused_kernels.py --mode=benchmark
+  ```
+
+_Additional options_
+
+```sh
+python tests/test_fused_kernels.py --help
+```
+
+_Note:_ Passing the additional flag `--verbose` will show `triton` autotuning logs -- I customized the `triton` autotuner to print out the evaluated configs and other details.
+
+#### Test Env
+
+- GPU Device Props:
+  - Name: `NVIDIA RTX A6000`
+  - CC: `86`
+  - Total_memory: `48676MB`
+  - SM count: `84`
+- Torch: `2.3.0.dev20240310+cu118`
+- Triton: `3.0.0`
+
+#### Next Steps
+
+- [ ] Implement `FusedGaLoreOptimizer`
+- [ ] `Cutlass` - given the fixed GEMM shape, experiment with `Cutlass` GEMMs (`split-k`, `stream-k`, fast `tensorops`). Interestingly, profiling `torch.matmul` for the down projection shows that `cuBlas` dispatches to a `Cutlass` kernel of shape `128x128x16`.
+- [ ] Repeat with `AdamW8bit` +- [ ] More detailed analysis of `torch.compile` performance diff --git a/galore_torch/fused/__init__.py b/galore_torch/fused/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/galore_torch/fused/triton_utils/__init__.py b/galore_torch/fused/triton_utils/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/galore_torch/fused/triton_utils/custom_autotune.py b/galore_torch/fused/triton_utils/custom_autotune.py new file mode 100644 index 0000000..3b76605 --- /dev/null +++ b/galore_torch/fused/triton_utils/custom_autotune.py @@ -0,0 +1,392 @@ +from __future__ import annotations + +import builtins +import logging +import os +import time +from typing import Dict + +import numpy as np +from triton.runtime.cache import default_cache_dir +from triton.runtime.errors import OutOfResources +from triton.runtime.jit import KernelInterface +from triton.testing import do_bench + +logger = logging.getLogger(__file__) + + +class Autotuner(KernelInterface): + + def __init__( + self, + fn, + arg_names, + configs, + key, + reset_to_zero, + restore_value, + prune_configs_by: Dict = None, + warmup=25, + rep=100, + ): + """ + :param prune_configs_by: a dict of functions that are used to prune configs, fields: + 'perf_model': performance model used to predicate running time with different configs, returns running time + 'top_k': number of configs to bench + 'prune_num_stages_by'(optional): a function used to prune num_stages. It takes configs:List[Config] as its input, and returns pruned configs. + """ + if not configs: + self.configs = [Config({}, num_warps=4, num_stages=2, num_ctas=1)] + else: + self.configs = configs + self.key_idx = [arg_names.index(k) for k in key] + self.cache = {} + self.arg_names = arg_names + + # Reset to zero or restore values + self.reset_idx = [] + if reset_to_zero is not None: + self.reset_idx = [arg_names.index(k) for k in reset_to_zero] + self.restore_idx = [] + if restore_value is not None: + self.restore_idx = [arg_names.index(k) for k in restore_value] + + # Hook to reset or restore for required tensors + self.pre_hook = lambda args, reset_only=False: 0 + self.post_hook = lambda args: 0 + if len(self.reset_idx) > 0 or len(self.restore_idx) > 0: + + def _pre_hook(args, reset_only=False): + for i in self.reset_idx: + args[i].zero_() + if not reset_only: + self.restore_copies = [args[i].clone() for i in self.restore_idx] + + self.pre_hook = _pre_hook + if len(self.restore_idx) > 0: + + def _post_hook(args): + for i, j in enumerate(self.restore_idx): + args[j].copy_(self.restore_copies[i]) + self.restore_copies = [] + + self.post_hook = _post_hook + + self.perf_model = None + self.configs_top_k = 1.0 + self.early_config_prune = None + if prune_configs_by: + self.perf_model = prune_configs_by.get("perf_model", self.perf_model) + self.configs_top_k = prune_configs_by.get("top_k", self.configs_top_k) + self.early_config_prune = prune_configs_by.get( + "early_config_prune", self.early_config_prune + ) + + self.fn = fn + self.num_warmups = warmup + self.num_reps = rep + # self.autotune_log_path = os.path.join(default_cache_dir(), autotune_log_file) + self.kernel_name = self._find_kernel_name() + + def _find_kernel_name(self): + try: + kernel_name = self.fn.__name__ + except AttributeError: + try: # in case JITfn is wrapped in both autotune and heuristic + kernel_name = self.fn.fn.__name__ + except: # noqa + kernel_name = self.fn.__name__ + return kernel_name + + def _get_key_combination(self, args, as_str=True, sep=" "): + 
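+        # Render the autotune key arguments as "name=value" strings; used for log messages and autotune log filenames.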
key_vals = [f"{self.arg_names[i]}={args[i]}" for i in self.key_idx] + return f"{sep}".join(key_vals) if as_str else key_vals + + def _bench(self, *args, config, **meta): + # check for conflicts, i.e. meta-parameters both provided + # as kwargs and by the autotuner + conflicts = meta.keys() & config.kwargs.keys() + if conflicts: + raise ValueError( + f"Conflicting meta-parameters: {', '.join(conflicts)}." + " Make sure that you don't re-define auto-tuned symbols." + ) + # augment meta-parameters with tunable ones + current = dict(meta, **config.kwargs) + full_nargs = {**self.nargs, **current} + + def kernel_call(): + if config.pre_hook: + config.pre_hook(full_nargs) + self.pre_hook(args) + self.fn.run( + *args, + num_warps=config.num_warps, + num_stages=config.num_stages, + num_ctas=config.num_ctas, + **current, + ) + self.post_hook(args) + + try: + return do_bench( + kernel_call, + warmup=self.num_warmups, + rep=self.num_reps, + quantiles=(0.5, 0.2, 0.8), + ) + except OutOfResources: + return [float("inf"), float("inf"), float("inf")] + + def run(self, *args, **kwargs): + self.nargs = dict(zip(self.arg_names, args)) + if len(self.configs) > 1: + all_args = {**self.nargs, **kwargs} + _args = [] + for name in self.arg_names: + if name in all_args: + _args.append(all_args[name]) + key = [_args[i] for i in self.key_idx] + for arg in _args: + if hasattr(arg, "dtype"): + key.append(str(arg.dtype)) + key = tuple(key) + if key not in self.cache: + logger.debug("Cache miss!\n") + logger.info( + f"\n==== Autotune ====\nRunning autotune for {self.kernel_name} for {len(self.configs)} total configs" + f" for key combination {self._get_key_combination(args)}..." + ) + # prune configs + pruned_configs = self.prune_configs(kwargs) + logger.info(f"\nNum configs after pruning {len(pruned_configs)}") + bench_start = time.time() + timings = {} + for config in pruned_configs: + timings[config] = self._bench(*args, config=config, **kwargs) + # timings = { + # config: self._bench(*args, config=config, **kwargs) + # for config in pruned_configs + # } + bench_end = time.time() + self.bench_time = bench_end - bench_start + self.cache[key] = builtins.min(timings, key=timings.get) + self.pre_hook(args, reset_only=True) + self.configs_timings = timings + + sorted_timings = dict( + sorted(timings.items(), key=lambda x: np.mean(x[1])) + ) + _key_suffix = self._get_key_combination(args, sep="-") + autotune_file = f"autotune_{self.kernel_name}_{_key_suffix}.log" + autotune_log_path = os.path.join(default_cache_dir(), autotune_file) + + logger.info(f"\nFinished autotune, writing log to {autotune_log_path}") + + with open(f"{autotune_log_path}", "w") as f: + f.write( + f" ==== Autotune Results ====\nKernel name: {self.kernel_name}\nArgs: {self.arg_names}\nKeys: {self._get_key_combination(args)}\n" + ) + f.write(f"\nPruned configs:\n") + for cfg in pruned_configs: + f.write(f"{cfg}\n") + f.write(f"Timings:\n") + for cfg, timing in sorted_timings.items(): + f.write(f"{cfg} {timing} \n") + f.write(f"Best config: {self.cache[key]}\n") + config = self.cache[key] + logger.debug(f"\nAutotune: Cache hit! 
Running best config...") + else: + config = self.configs[0] + self.best_config = config + logger.info(f"\nAutotune Best Config: {config}\n") + + full_nargs = {**self.nargs, **kwargs, **self.best_config.kwargs} + if config.pre_hook is not None: + config.pre_hook(full_nargs) + ret = self.fn.run( + *args, + num_warps=config.num_warps, + num_stages=config.num_stages, + num_ctas=config.num_ctas, + **kwargs, + **config.kwargs, + ) + self.nargs = None + return ret + + def prune_configs(self, kwargs): + pruned_configs = self.configs + if self.early_config_prune: + pruned_configs = self.early_config_prune(self.configs, self.nargs) + if self.perf_model: + top_k = self.configs_top_k + if isinstance(top_k, float) and top_k <= 1.0: + top_k = int(len(self.configs) * top_k) + if len(pruned_configs) > top_k: + est_timing = { + config: self.perf_model( + **self.nargs, + **kwargs, + **config.kwargs, + num_stages=config.num_stages, + num_warps=config.num_warps, + num_ctas=config.num_ctas, + ) + for config in pruned_configs + } + pruned_configs = sorted(est_timing.keys(), key=lambda x: est_timing[x])[ + :top_k + ] + return pruned_configs + + def warmup(self, *args, **kwargs): + self.nargs = dict(zip(self.arg_names, args)) + ret = [] + for config in self.prune_configs(kwargs): + ret.append( + self.fn.warmup( + *args, + num_warps=config.num_warps, + num_ctas=config.num_ctas, + num_stages=config.num_stages, + **kwargs, + **config.kwargs, + ) + ) + self.nargs = None + return ret + + +class Config: + """ + An object that represents a possible kernel configuration for the auto-tuner to try. + + :ivar meta: a dictionary of meta-parameters to pass to the kernel as keyword arguments. + :type meta: dict[Str, Any] + :ivar num_warps: the number of warps to use for the kernel when compiled for GPUs. For example, if + `num_warps=8`, then each kernel instance will be automatically parallelized to + cooperatively execute using `8 * 32 = 256` threads. + :type num_warps: int + :ivar num_stages: the number of stages that the compiler should use when software-pipelining loops. + Mostly useful for matrix multiplication workloads on SM80+ GPUs. + :type num_ctas: int + :ivar num_ctas: number of blocks in a block cluster. SM90+ only. + :ivar pre_hook: a function that will be called before the kernel is called. Parameters of this + function are args. + """ + + def __init__(self, kwargs, num_warps=4, num_stages=2, num_ctas=1, pre_hook=None): + self.kwargs = kwargs + self.num_warps = num_warps + self.num_ctas = num_ctas + self.num_stages = num_stages + self.pre_hook = pre_hook + + def __str__(self): + res = [] + for k, v in self.kwargs.items(): + res.append(f"{k}: {v}") + res.append(f"num_warps: {self.num_warps}") + res.append(f"num_ctas: {self.num_ctas}") + res.append(f"num_stages: {self.num_stages}") + return ", ".join(res) + + +def autotune( + configs, + key, + prune_configs_by=None, + reset_to_zero=None, + restore_value=None, + warmup=25, + rep=100, +): + """ + Decorator for auto-tuning a :code:`triton.jit`'d function. + + .. highlight:: python + .. code-block:: python + + @triton.autotune(configs=[ + triton.Config(meta={'BLOCK_SIZE': 128}, num_warps=4), + triton.Config(meta={'BLOCK_SIZE': 1024}, num_warps=8), + ], + key=['x_size'] # the two above configs will be evaluated anytime + # the value of x_size changes + ) + @triton.jit + def kernel(x_ptr, x_size, **META): + BLOCK_SIZE = META['BLOCK_SIZE'] + :note: When all the configurations are evaluated, the kernel will run multiple times. 
+ This means that whatever value the kernel updates will be updated multiple times. + To avoid this undesired behavior, you can use the `reset_to_zero` argument, which + resets the value of the provided tensor to `zero` before running any configuration. + :param configs: a list of :code:`triton.Config` objects + :type configs: list[triton.Config] + :param key: a list of argument names whose change in value will trigger the evaluation of all provided configs. + :type key: list[str] + :param prune_configs_by: a dict of functions that are used to prune configs, fields: + 'perf_model': performance model used to predicate running time with different configs, returns running time + 'top_k': number of configs to bench + 'early_config_prune'(optional): a function used to do early prune (eg, num_stages). It takes configs:List[Config] as its input, and returns pruned configs. + :param reset_to_zero: a list of argument names whose value will be reset to zero before evaluating any configs. + :type reset_to_zero: list[str] + :param restore_value: a list of argument names whose value will be restored after evaluating any configs. + :type restore_value: list[str] + :param warmup: Warmup time (in ms) to pass to benchmarking, defaults to 25. + :type warmup: int + :param rep: Repetition time (in ms) to pass to benchmarking, defaults to 100. + :type rep: int + """ + + def decorator(fn): + return Autotuner( + fn, + fn.arg_names, + configs, + key, + reset_to_zero, + restore_value, + prune_configs_by, + warmup, + rep, + ) + + return decorator + + +class Heuristics(KernelInterface): + + def __init__(self, fn, arg_names, values) -> None: + self.fn = fn + self.values = values + self.arg_names = arg_names + + def run(self, *args, **kwargs): + for v, heur in self.values.items(): + kwargs[v] = heur({**dict(zip(self.arg_names, args)), **kwargs}) + return self.fn.run(*args, **kwargs) + + +def heuristics(values): + """ + Decorator for specifying how the values of certain meta-parameters may be computed. + This is useful for cases where auto-tuning is prohibitevely expensive, or just not applicable. + + .. highlight:: python + .. code-block:: python + + @triton.heuristics(values={'BLOCK_SIZE': lambda args: 2 ** int(math.ceil(math.log2(args[1])))}) + @triton.jit + def kernel(x_ptr, x_size, **META): + BLOCK_SIZE = META['BLOCK_SIZE'] # smallest power-of-two >= x_size + :param values: a dictionary of meta-parameter names and functions that compute the value of the meta-parameter. + each such function takes a list of positional arguments as input. 
+ :type values: dict[str, Callable[[list[Any]], Any]] + """ + + def decorator(fn): + return Heuristics(fn, fn.arg_names, values) + + return decorator diff --git a/galore_torch/fused/triton_utils/kernels/__init__.py b/galore_torch/fused/triton_utils/kernels/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/galore_torch/fused/triton_utils/kernels/adam_downproj_fused.py b/galore_torch/fused/triton_utils/kernels/adam_downproj_fused.py new file mode 100644 index 0000000..e633c6c --- /dev/null +++ b/galore_torch/fused/triton_utils/kernels/adam_downproj_fused.py @@ -0,0 +1,355 @@ +import logging + +import torch +import triton +import triton.language as tl +from triton.ops.matmul import get_higher_dtype, init_to_zero +from triton.ops.matmul_perf_model import early_config_prune, estimate_matmul_time + +from galore_torch.fused.triton_utils.custom_autotune import Config, autotune +from galore_torch.fused.triton_utils.kernels.adam_step import BETA1, BETA2, EPS +from galore_torch.fused.triton_utils.kernels.matmul import TRITON_ACC_TYPES +from galore_torch.fused.triton_utils.kernels.matmul import ( + get_autotuner as default_mm_autotuner, +) +from galore_torch.fused.triton_utils.kernels.matmul import get_mm_heuristics, to_tl_type + +logger = logging.getLogger(__name__) + +AUTOTUNER_TOP_K = 50 + + +def set_tuner_top_k(k): + global AUTOTUNER_TOP_K + AUTOTUNER_TOP_K = k + + +@triton.jit +def _fused_adam_mm_kernel( + # matmul args + A, + B, + C, + M, + N, + K, # + stride_am, + stride_ak, # + stride_bk, + stride_bn, # + stride_cm, + stride_cn, # + # adam epilogue, + exp_avg_ptr, # these will be updated inplace + exp_avg2_ptr, + store, + # grad_ptr, # low rank grad output -- not needed, C is the output + # meta params + BLOCK_M: tl.constexpr, + BLOCK_N: tl.constexpr, + BLOCK_K: tl.constexpr, # + SPLIT_K: tl.constexpr, + EVEN_K: tl.constexpr, + GROUP_M: tl.constexpr, + # Adam-specific params + BETA1: tl.constexpr = BETA1, + BETA2: tl.constexpr = BETA2, + EPS: tl.constexpr = EPS, + # matmul kernel settings + acc_dtype: tl.constexpr = tl.float32, # + allow_tf32: tl.constexpr = False, # higher precision for this phase + fp8_fast_accum: tl.constexpr = False, # + AB_DTYPE: tl.constexpr = None, # +): + # matrix multiplication + pid = tl.program_id(0) + pid_z = tl.program_id(1) + grid_m = tl.cdiv(M, BLOCK_M) + grid_n = tl.cdiv(N, BLOCK_N) + # re-order program ID for better L2 performance + width = GROUP_M * grid_n + group_id = pid // width + group_size = min(grid_m - group_id * GROUP_M, GROUP_M) + pid_m = group_id * GROUP_M + (pid % group_size) + pid_n = (pid % width) // (group_size) + # do matrix multiplication + rm = pid_m * BLOCK_M + tl.arange(0, BLOCK_M) + rn = pid_n * BLOCK_N + tl.arange(0, BLOCK_N) + ram = tl.max_contiguous(tl.multiple_of(rm % M, BLOCK_M), BLOCK_M) + rbn = tl.max_contiguous(tl.multiple_of(rn % N, BLOCK_N), BLOCK_N) + rk = pid_z * BLOCK_K + tl.arange(0, BLOCK_K) + # pointers + A = A + (ram[:, None] * stride_am + rk[None, :] * stride_ak) + B = B + (rk[:, None] * stride_bk + rbn[None, :] * stride_bn) + acc = tl.zeros((BLOCK_M, BLOCK_N), dtype=acc_dtype) + for k in range(0, tl.cdiv(K, BLOCK_K * SPLIT_K)): + if EVEN_K: + a = tl.load(A) + b = tl.load(B) + else: + k_remaining = K - k * (BLOCK_K * SPLIT_K) + _0 = tl.zeros((1, 1), dtype=C.dtype.element_ty) + a = tl.load(A, mask=rk[None, :] < k_remaining, other=_0) + b = tl.load(B, mask=rk[:, None] < k_remaining, other=_0) + if AB_DTYPE is not None: + a = a.to(AB_DTYPE) + b = b.to(AB_DTYPE) + if fp8_fast_accum: + acc = tl.dot(a, b, 
acc, out_dtype=acc_dtype, allow_tf32=allow_tf32) + else: + acc += tl.dot(a, b, out_dtype=acc_dtype, allow_tf32=allow_tf32) + A += BLOCK_K * SPLIT_K * stride_ak + B += BLOCK_K * SPLIT_K * stride_bk + # acc = acc.to(C.dtype.element_ty) + + # rematerialize rm and rn to save registers + rm = pid_m * BLOCK_M + tl.arange(0, BLOCK_M) + rn = pid_n * BLOCK_N + tl.arange(0, BLOCK_N) + + epilogue_offsets = rm[:, None] * stride_cm + rn[None, :] * stride_cn + mask = (rm < M)[:, None] & (rn < N)[None, :] + + # Load adam state + exp_avg = tl.load(exp_avg_ptr + epilogue_offsets, mask=mask) + exp_avg2 = tl.load(exp_avg2_ptr + epilogue_offsets, mask=mask) + + # Perform update + exp_avg = BETA1 * exp_avg.to(acc.dtype) + (1.0 - BETA1) * acc + exp_avg2 = BETA2 * exp_avg2.to(acc.dtype) + (1.0 - BETA2) * (acc * acc) + denom = tl.sqrt(exp_avg2) + EPS + norm_grad = exp_avg / denom + # Convert to output type + norm_grad = norm_grad.to(C.dtype.element_ty) + + # acc = acc.to(C.dtype.element_ty) + C = C + epilogue_offsets + + # handles write-back with reduction-splitting + if SPLIT_K == 1: + tl.store(C, norm_grad, mask=mask) + else: + tl.atomic_add(C, norm_grad, mask=mask) + + if store: + tl.store( + exp_avg_ptr + epilogue_offsets, + exp_avg, + mask=mask, + ) + tl.store( + exp_avg2_ptr + epilogue_offsets, + exp_avg2, + mask=mask, + ) + + +def _get_configs_splitk_all(): + """ + Configs specific to split-k matmuls + Not used currently + """ + configs = [] + for num_stages in [2, 3, 4, 5]: + for block_m in [16, 32, 64, 128]: + for block_k in [16, 32, 64, 128, 256]: + for block_n in [16, 32, 64, 128]: + num_warps = 2 if block_n <= 64 else 4 + configs.append( + Config( + { + "BLOCK_M": block_m, + "BLOCK_N": block_n, + "BLOCK_K": block_k, + "SPLIT_K": 1, + }, + num_stages=num_stages, + num_warps=num_warps, + ) + ) + # split_k + for split_k in [2, 4, 8]: + configs.append( + Config( + { + "BLOCK_M": block_m, + "BLOCK_N": block_n, + "BLOCK_K": block_k, + "SPLIT_K": split_k, + }, + num_stages=num_stages, + num_warps=num_warps, + pre_hook=init_to_zero("C"), + ) + ) + return configs + + +def _get_configs_splitk_small(): + """Configs for split-k, smaller version than above + Not used currently + """ + configs = [] + for num_stages in [2, 3, 4]: + for block_m in [64, 128]: + for block_k in [16, 32, 64]: + for block_n in [64, 128]: + num_warps = 2 if block_n <= 64 else 4 + configs.append( + Config( + { + "BLOCK_M": block_m, + "BLOCK_N": block_n, + "BLOCK_K": block_k, + "SPLIT_K": 1, + }, + num_stages=num_stages, + num_warps=num_warps, + ) + ) + # split_k + for split_k in [2, 4, 8]: + configs.append( + Config( + { + "BLOCK_M": block_m, + "BLOCK_N": block_n, + "BLOCK_K": block_k, + "SPLIT_K": split_k, + }, + num_stages=num_stages, + num_warps=num_warps, + pre_hook=init_to_zero("C"), + ) + ) + return configs + + +def _splitk_autotuner( + configs=_get_configs_splitk_small(), + key=["M", "N", "K"], + early_config_prune=early_config_prune, + perf_model=estimate_matmul_time, + top_k=AUTOTUNER_TOP_K, +): + """Autotuner for splitk matmuls + Not used currently + """ + autotuner = autotune( + configs=configs, + key=key, + prune_configs_by={ + "early_config_prune": early_config_prune, + "perf_model": perf_model, + "top_k": top_k, + }, + ) + + return autotuner + + +def _get_kernel( + tuner_fn=default_mm_autotuner, heuristics_fn=get_mm_heuristics, topk=AUTOTUNER_TOP_K +): + tuner = tuner_fn() + tuner.topk = topk + heuristics = heuristics_fn() + return tuner(heuristics(_fused_adam_mm_kernel)) + + +DEFAULT_KERNEL = _get_kernel() + + +def 
fused_adam_mm_launcher( + a, + b, + *, + exp_avg, + exp_avg2, + store=True, + BETA1=BETA1, + BETA2=BETA2, + EPS=EPS, + allow_tf32=False, + fp8_fast_accum=False, + acc_dtype=None, + output_dtype=None, + kernel=None, +): + + device = a.device + # handle non-contiguous inputs if necessary + # a = grad + # b = galore_proj.ortho_matrix.t() + if a.stride(0) > 1 and a.stride(1) > 1: + a = a.contiguous() + if b.stride(0) > 1 and b.stride(1) > 1: + b = b.contiguous() + # checks constraints + assert a.shape[1] == b.shape[0], "incompatible dimensions" + M, K = a.shape + _, N = b.shape + + # common type between a and b + ab_dtype = get_higher_dtype(a.dtype, b.dtype) + + # allocates output + if output_dtype is None: + output_dtype = ab_dtype + + c = torch.empty((M, N), device=device, dtype=output_dtype) + + if acc_dtype is None: + acc_dtype = [ab_dtype][0] + else: + assert isinstance(acc_dtype, torch.dtype), "acc_dtype must be a torch.dtype" + assert ( + acc_dtype in TRITON_ACC_TYPES[a.dtype] + ), "acc_dtype not compatible with the type of a" + assert ( + acc_dtype in TRITON_ACC_TYPES[b.dtype] + ), "acc_dtype not compatible with the type of b" + + acc_dtype = to_tl_type(acc_dtype) + ab_dtype = to_tl_type(ab_dtype) + output_dtype = to_tl_type(output_dtype) + + # Tensor cores support input with mixed float8 types. + if a.dtype in [tl.float8e4nv, tl.float8e5] and b.dtype in [ + tl.float8e4nv, + tl.float8e5, + ]: + ab_dtype = None + + grid = lambda META: ( + triton.cdiv(M, META["BLOCK_M"]) * triton.cdiv(N, META["BLOCK_N"]), + META["SPLIT_K"], + ) + + if kernel is None: + kernel = DEFAULT_KERNEL + kernel[grid]( + a, + b, + c, + M, + N, + K, # + a.stride(0), + a.stride(1), # + b.stride(0), + b.stride(1), # + c.stride(0), + c.stride(1), # + exp_avg, + exp_avg2, + store=store, + BETA1=BETA1, # , # + BETA2=BETA2, # , # + EPS=EPS, # + acc_dtype=acc_dtype, # + allow_tf32=allow_tf32, # + fp8_fast_accum=fp8_fast_accum, # + GROUP_M=8, + AB_DTYPE=ab_dtype, + ) + return exp_avg, exp_avg2, c # c -> normalized low rank grad diff --git a/galore_torch/fused/triton_utils/kernels/adam_step.py b/galore_torch/fused/triton_utils/kernels/adam_step.py new file mode 100644 index 0000000..7b9a07a --- /dev/null +++ b/galore_torch/fused/triton_utils/kernels/adam_step.py @@ -0,0 +1,178 @@ +import torch +import triton +import triton.language as tl +from triton.language.extra.cuda import libdevice +from triton.runtime.autotuner import heuristics + +from galore_torch.fused.triton_utils.custom_autotune import Config, autotune + +BETA1, BETA2 = 0.9, 0.999 +EPS = 1e-8 + +AUTOTUNER_TOP_K = 10 + + +def get_configs_for_adam(num_warps=[2, 4, 8], block_sizes=[512, 1024, 2048]): + configs = [] + for w in num_warps: + for bs in block_sizes: + configs.append(Config({"BLOCK_SIZE": bs}, num_warps=w)) + return configs + + +def early_adam_prune(configs, named_args): + numels = named_args["numels"] + pruned_configs = [cfg for cfg in configs if numels % cfg.kwargs["BLOCK_SIZE"] == 0] + # print("Pruned configs:\n") + for cfg in pruned_configs: + print(f"{cfg}\n") + return pruned_configs + + +def get_adam_tuner( + configs=get_configs_for_adam(), + early_config_prune=None, # early_adam_prune, + top_k=AUTOTUNER_TOP_K, +): + return autotune( + configs=configs, + prune_configs_by={ + "early_config_prune": early_config_prune, + "top_k": top_k, + }, + key=["numels"], + ) + + +def get_adam_heuristics(): + return { + "USE_MASK": lambda args: args["numels"] % args["BLOCK_SIZE"] != 0, + } + + +@autotune(configs=get_configs_for_adam(), key=["numels"]) 
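+# The USE_MASK heuristic (from get_adam_heuristics) disables bounds masking when numels is an exact multiple of BLOCK_SIZE.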
+@heuristics(get_adam_heuristics()) +@triton.jit +def _adam_update( + avg_ptr, + avg2_ptr, + grad_ptr, + # avg_out_ptr, + # avg2_out_ptr, + # grad_out_ptr, + numels, + store, + BLOCK_SIZE: tl.constexpr, + USE_MASK: tl.constexpr, + BETA1: tl.constexpr = BETA1, + BETA2: tl.constexpr = BETA2, + EPS: tl.constexpr = EPS, +): + pid_m = tl.program_id(0) + offset = pid_m * BLOCK_SIZE + offset = offset + tl.arange(0, BLOCK_SIZE) + # load_idx = offset + tl.arange(0, BLOCK_SIZE) + load_idx = tl.max_contiguous(tl.multiple_of(offset, BLOCK_SIZE), BLOCK_SIZE) + + mask = None + if USE_MASK: + mask = load_idx < numels + avg = tl.load(avg_ptr + load_idx, mask=mask) + avg2 = tl.load(avg2_ptr + load_idx, mask=mask) + grad = tl.load(grad_ptr + load_idx, mask=mask) + + avg = BETA1 * avg + (1.0 - BETA1) * grad + avg2 = BETA2 * avg2 + (1.0 - BETA2) * (grad * grad) + + denom = libdevice.sqrt(avg2) + EPS + # denom = tl.sqrt(avg2) + EPS + + norm_grad = avg / denom + + if store: + tl.store(avg_ptr + load_idx, avg, mask=mask) + tl.store(avg2_ptr + load_idx, avg2, mask=mask) + tl.store(grad_ptr + load_idx, norm_grad, mask=mask) + # tl.store(avg_out_ptr + load_idx, avg, mask=mask) + # tl.store(avg2_out_ptr + load_idx, avg2, mask=mask) + # tl.store(grad_out_ptr + load_idx, norm_grad, mask=mask) + + +adam_update = _adam_update + + +def triton_adam_launcher( + avg, + avg2, + grad, + store=True, + beta1=BETA1, + beta2=BETA2, + eps=EPS, +): + M, N = avg.shape + # avg_out = torch.empty_like(avg) + # avg2_out = torch.empty_like(avg2) + # grad_out = torch.empty_like(grad) + + grid = lambda META: (triton.cdiv(M * N, META["BLOCK_SIZE"]),) + adam_update[grid]( + avg, + avg2, + grad, + # avg_out, + # avg2_out, + # grad_out, + avg.numel(), + store=store, + BETA1=beta1, + BETA2=beta2, + EPS=eps, + # BLOCK_SIZE=1024, + # USE_MASK=USE_MASK, + ) + return avg, avg2, grad + + +def ref_adam_step(exp_avg, exp_avg2, grad, beta1=BETA1, beta2=BETA2, eps=EPS): + exp_avg = beta1 * exp_avg + (1 - beta1) * grad + exp_avg2 = beta2 * exp_avg2 + (1 - beta2) * torch.square(grad) + denom = exp_avg2.sqrt() + eps + norm_grad = exp_avg / denom + return exp_avg, exp_avg2, norm_grad + + +def make_data(M, N, rank, dtype): + # full_grad = torch.randn(M, N, device="cuda", dtype=dtype) + params = torch.randn(M, N, device="cuda", dtype=dtype) + + if M >= N: + exp_avg = torch.randn(M, rank, device="cuda", dtype=dtype) + else: + exp_avg = torch.randn(rank, N, device="cuda", dtype=dtype) + exp_avg2 = exp_avg**2 + down_grad = torch.randn_like(exp_avg) + + return exp_avg, exp_avg2, down_grad, params + + +if __name__ == "__main__": + from triton.testing import do_bench + + M = N = 4096 + rank = 128 + dtype = torch.float32 + exp_avg, exp_avg2, grad, params = make_data(M, N, rank, dtype=dtype) + exp_avg_copy, exp_avg2_copy, grad_copy = ( + exp_avg.clone(), + exp_avg2.clone(), + grad.clone(), + ) + ref_out = ref_adam_step(exp_avg, exp_avg2, grad) + + # Autotune run -- changes exp_avg, exp_avg2, grad in-place + _ = triton_adam_launcher(exp_avg, exp_avg2, grad) + triton_out = triton_adam_launcher(exp_avg_copy, exp_avg2_copy, grad_copy) + + for ref, tt in zip(ref_out, triton_out): + print(torch.max(torch.abs(ref - tt))) diff --git a/galore_torch/fused/triton_utils/kernels/matmul.py b/galore_torch/fused/triton_utils/kernels/matmul.py new file mode 100644 index 0000000..249598b --- /dev/null +++ b/galore_torch/fused/triton_utils/kernels/matmul.py @@ -0,0 +1,383 @@ +import torch +import triton +import triton.language as tl +from triton.ops.matmul import get_higher_dtype, 
init_to_zero +from triton.ops.matmul_perf_model import early_config_prune, estimate_matmul_time + +from ..custom_autotune import Config, autotune, heuristics + +# Allowed types for acc_type given the types of a and b. +TRITON_ACC_TYPES = { + torch.float16: (torch.float32, torch.float16), + torch.bfloat16: (torch.float32, torch.bfloat16), + torch.float32: (torch.float32,), + torch.int8: (torch.int32,), +} + +AUTOTUNER_TOP_K = 50 + + +def set_tuner_top_k(k): + global AUTOTUNER_TOP_K + AUTOTUNER_TOP_K = k + + +def to_tl_type(ty): + return getattr(tl, str(ty).split(".")[-1]) + + +def get_configs_io_bound(): + configs = [] + for num_stages in [2, 3, 4, 5, 6]: + for block_m in [16, 32]: + for block_k in [32, 64]: + for block_n in [32, 64, 128, 256]: + num_warps = 2 if block_n <= 64 else 4 + configs.append( + Config( + { + "BLOCK_M": block_m, + "BLOCK_N": block_n, + "BLOCK_K": block_k, + "SPLIT_K": 1, + }, + num_stages=num_stages, + num_warps=num_warps, + ) + ) + # split_k + for split_k in [2, 4, 8, 16]: + configs.append( + Config( + { + "BLOCK_M": block_m, + "BLOCK_N": block_n, + "BLOCK_K": block_k, + "SPLIT_K": split_k, + }, + num_stages=num_stages, + num_warps=num_warps, + pre_hook=init_to_zero("C"), + ) + ) + return configs + + +def get_configs_compute_bound(): + configs = [ + # basic configs for compute-bound matmuls + Config( + {"BLOCK_M": 128, "BLOCK_N": 256, "BLOCK_K": 32, "SPLIT_K": 1}, + num_stages=3, + num_warps=8, + ), + Config( + {"BLOCK_M": 256, "BLOCK_N": 128, "BLOCK_K": 32, "SPLIT_K": 1}, + num_stages=3, + num_warps=8, + ), + Config( + {"BLOCK_M": 256, "BLOCK_N": 64, "BLOCK_K": 32, "SPLIT_K": 1}, + num_stages=4, + num_warps=4, + ), + Config( + {"BLOCK_M": 64, "BLOCK_N": 256, "BLOCK_K": 32, "SPLIT_K": 1}, + num_stages=4, + num_warps=4, + ), + Config( + {"BLOCK_M": 128, "BLOCK_N": 128, "BLOCK_K": 32, "SPLIT_K": 1}, + num_stages=4, + num_warps=4, + ), + Config( + {"BLOCK_M": 128, "BLOCK_N": 64, "BLOCK_K": 32, "SPLIT_K": 1}, + num_stages=4, + num_warps=4, + ), + Config( + {"BLOCK_M": 64, "BLOCK_N": 128, "BLOCK_K": 32, "SPLIT_K": 1}, + num_stages=4, + num_warps=4, + ), + Config( + {"BLOCK_M": 128, "BLOCK_N": 32, "BLOCK_K": 32, "SPLIT_K": 1}, + num_stages=4, + num_warps=4, + ), + Config( + {"BLOCK_M": 64, "BLOCK_N": 32, "BLOCK_K": 32, "SPLIT_K": 1}, + num_stages=5, + num_warps=2, + ), + # good for int8 + Config( + {"BLOCK_M": 128, "BLOCK_N": 256, "BLOCK_K": 128, "SPLIT_K": 1}, + num_stages=3, + num_warps=8, + ), + Config( + {"BLOCK_M": 256, "BLOCK_N": 128, "BLOCK_K": 128, "SPLIT_K": 1}, + num_stages=3, + num_warps=8, + ), + Config( + {"BLOCK_M": 256, "BLOCK_N": 64, "BLOCK_K": 128, "SPLIT_K": 1}, + num_stages=4, + num_warps=4, + ), + Config( + {"BLOCK_M": 64, "BLOCK_N": 256, "BLOCK_K": 128, "SPLIT_K": 1}, + num_stages=4, + num_warps=4, + ), + Config( + {"BLOCK_M": 128, "BLOCK_N": 128, "BLOCK_K": 128, "SPLIT_K": 1}, + num_stages=4, + num_warps=4, + ), + Config( + {"BLOCK_M": 128, "BLOCK_N": 64, "BLOCK_K": 64, "SPLIT_K": 1}, + num_stages=4, + num_warps=4, + ), + Config( + {"BLOCK_M": 64, "BLOCK_N": 128, "BLOCK_K": 64, "SPLIT_K": 1}, + num_stages=4, + num_warps=4, + ), + Config( + {"BLOCK_M": 128, "BLOCK_N": 32, "BLOCK_K": 64, "SPLIT_K": 1}, + num_stages=4, + num_warps=4, + ), + Config( + {"BLOCK_M": 64, "BLOCK_N": 32, "BLOCK_K": 64, "SPLIT_K": 1}, + num_stages=5, + num_warps=2, + ), + ] + return configs + + +def get_autotuner( + configs=get_configs_compute_bound() + get_configs_io_bound(), + key=["M", "N", "K"], + early_config_prune=early_config_prune, + perf_model=estimate_matmul_time, + 
top_k=AUTOTUNER_TOP_K, +): + autotuner = autotune( + configs=configs, + key=key, + prune_configs_by={ + "early_config_prune": early_config_prune, + "perf_model": perf_model, + "top_k": top_k, + }, + ) + + return autotuner + + +def get_mm_heuristics(): + return heuristics( + { + "EVEN_K": lambda args: args["K"] % (args["BLOCK_K"] * args["SPLIT_K"]) == 0, + } + ) + + +@triton.jit +def _matmul_kernel( + A, + B, + C, + M, + N, + K, # + stride_am, + stride_ak, # + stride_bk, + stride_bn, # + stride_cm, + stride_cn, # + # meta params + BLOCK_M: tl.constexpr, + BLOCK_N: tl.constexpr, + BLOCK_K: tl.constexpr, # + SPLIT_K: tl.constexpr, + EVEN_K: tl.constexpr, + GROUP_M: tl.constexpr, + # epilogue + epilogue_alpha=None, + epilogue_beta=None, + epilogue_source=None, # Corresponds to C in GEMM convention of D = AB + C + # matmul kernel settings + acc_dtype: tl.constexpr = tl.float32, # + allow_tf32: tl.constexpr = True, # + fp8_fast_accum: tl.constexpr = True, # + AB_DTYPE: tl.constexpr = None, # + EPILOGUE: tl.constexpr = False, +): + # matrix multiplication + pid = tl.program_id(0) + pid_z = tl.program_id(1) + grid_m = tl.cdiv(M, BLOCK_M) + grid_n = tl.cdiv(N, BLOCK_N) + # re-order program ID for better L2 performance + width = GROUP_M * grid_n + group_id = pid // width + group_size = min(grid_m - group_id * GROUP_M, GROUP_M) + pid_m = group_id * GROUP_M + (pid % group_size) + pid_n = (pid % width) // (group_size) + # do matrix multiplication + rm = pid_m * BLOCK_M + tl.arange(0, BLOCK_M) + rn = pid_n * BLOCK_N + tl.arange(0, BLOCK_N) + ram = tl.max_contiguous(tl.multiple_of(rm % M, BLOCK_M), BLOCK_M) + rbn = tl.max_contiguous(tl.multiple_of(rn % N, BLOCK_N), BLOCK_N) + rk = pid_z * BLOCK_K + tl.arange(0, BLOCK_K) + # pointers + A = A + (ram[:, None] * stride_am + rk[None, :] * stride_ak) + B = B + (rk[:, None] * stride_bk + rbn[None, :] * stride_bn) + acc = tl.zeros((BLOCK_M, BLOCK_N), dtype=acc_dtype) + for k in range(0, tl.cdiv(K, BLOCK_K * SPLIT_K)): + if EVEN_K: + a = tl.load(A) + b = tl.load(B) + else: + k_remaining = K - k * (BLOCK_K * SPLIT_K) + _0 = tl.zeros((1, 1), dtype=C.dtype.element_ty) + a = tl.load(A, mask=rk[None, :] < k_remaining, other=_0) + b = tl.load(B, mask=rk[:, None] < k_remaining, other=_0) + if AB_DTYPE is not None: + a = a.to(AB_DTYPE) + b = b.to(AB_DTYPE) + if fp8_fast_accum: + acc = tl.dot(a, b, acc, out_dtype=acc_dtype, allow_tf32=allow_tf32) + else: + acc += tl.dot(a, b, out_dtype=acc_dtype, allow_tf32=allow_tf32) + A += BLOCK_K * SPLIT_K * stride_ak + B += BLOCK_K * SPLIT_K * stride_bk + # acc = acc.to(C.dtype.element_ty) + + # rematerialize rm and rn to save registers + rm = pid_m * BLOCK_M + tl.arange(0, BLOCK_M) + rn = pid_n * BLOCK_N + tl.arange(0, BLOCK_N) + + if EPILOGUE: + if epilogue_alpha is not None: + acc = epilogue_alpha.to(acc_dtype) * acc + if epilogue_source is not None: + epilogue_src = tl.load( + epilogue_source + rm[:, None] * stride_cm + rn[None, :] * stride_cn + ) + if epilogue_beta is not None: + epilogue_src = epilogue_src.to(acc_dtype) * epilogue_beta.to(acc_dtype) + acc = acc + epilogue_src + + acc = acc.to(C.dtype.element_ty) + C = C + (rm[:, None] * stride_cm + rn[None, :] * stride_cn) + mask = (rm < M)[:, None] & (rn < N)[None, :] + # handles write-back with reduction-splitting + if SPLIT_K == 1: + tl.store(C, acc, mask=mask) + else: + tl.atomic_add(C, acc, mask=mask) + + +_autotuner = get_autotuner() +_heuristics = get_mm_heuristics() +matmul = _autotuner(_heuristics(_matmul_kernel)) + + +def triton_mm_launcher( + a, + b, + 
epilogue_alpha=None, + epilogue_beta=None, + epilogue_source=None, + allow_tf32=True, + fp8_fast_accum=True, + acc_dtype=None, + output_dtype=None, + kernel=matmul, +): + + device = a.device + # handle non-contiguous inputs if necessary + # a = grad + # b = galore_proj.ortho_matrix.t() + if a.stride(0) > 1 and a.stride(1) > 1: + a = a.contiguous() + if b.stride(0) > 1 and b.stride(1) > 1: + b = b.contiguous() + # checks constraints + assert a.shape[1] == b.shape[0], "incompatible dimensions" + M, K = a.shape + _, N = b.shape + + # common type between a and b + ab_dtype = get_higher_dtype(a.dtype, b.dtype) + + # allocates output + if output_dtype is None: + output_dtype = ab_dtype + + c = torch.empty((M, N), device=device, dtype=output_dtype) + + if acc_dtype is None: + acc_dtype = [ab_dtype][0] + else: + assert isinstance(acc_dtype, torch.dtype), "acc_dtype must be a torch.dtype" + assert ( + acc_dtype in TRITON_ACC_TYPES[a.dtype] + ), "acc_dtype not compatible with the type of a" + assert ( + acc_dtype in TRITON_ACC_TYPES[b.dtype] + ), "acc_dtype not compatible with the type of b" + + acc_dtype = to_tl_type(acc_dtype) + ab_dtype = to_tl_type(ab_dtype) + output_dtype = to_tl_type(output_dtype) + + # Tensor cores support input with mixed float8 types. + if a.dtype in [tl.float8e4nv, tl.float8e5] and b.dtype in [ + tl.float8e4nv, + tl.float8e5, + ]: + ab_dtype = None + # launch kernel + # print( + # f"{__file__} triton matmul args: (AB dtype {ab_dtype}) (C dtype {c.dtype}) (allow_tf32 {allow_tf32}) (fp8_fast_accum {fp8_fast_accum})" + # ) + grid = lambda META: ( + triton.cdiv(M, META["BLOCK_M"]) * triton.cdiv(N, META["BLOCK_N"]), + META["SPLIT_K"], + ) + + matmul[grid]( + a, + b, + c, + M, + N, + K, # + a.stride(0), + a.stride(1), # + b.stride(0), + b.stride(1), # + c.stride(0), + c.stride(1), # + epilogue_alpha=epilogue_alpha, # + epilogue_beta=epilogue_beta, # + epilogue_source=epilogue_source, # + acc_dtype=acc_dtype, # + allow_tf32=allow_tf32, # + fp8_fast_accum=fp8_fast_accum, # + GROUP_M=8, + AB_DTYPE=ab_dtype, + EPILOGUE=any([epilogue_alpha, epilogue_beta, epilogue_source]), + ) + return c diff --git a/galore_torch/fused/utils.py b/galore_torch/fused/utils.py new file mode 100644 index 0000000..58857ce --- /dev/null +++ b/galore_torch/fused/utils.py @@ -0,0 +1,111 @@ +import torch + + +def get_orthogonal_matrix(weights, rank, type): + module_params = weights + + if module_params.data.dtype != torch.float: + float_data = False + original_type = module_params.data.dtype + original_device = module_params.data.device + matrix = module_params.data.float() + else: + float_data = True + matrix = module_params.data + + U, s, Vh = torch.linalg.svd(matrix, full_matrices=False) + + # make the smaller matrix always to be orthogonal matrix + if type == "right": + A = U[:, :rank] @ torch.diag(s[:rank]) + B = Vh[:rank, :] + + if not float_data: + B = B.to(original_device).type(original_type) + return B + elif type == "left": + A = U[:, :rank] + B = torch.diag(s[:rank]) @ Vh[:rank, :] + if not float_data: + A = A.to(original_device).type(original_type) + return A + elif type == "full": + A = U[:, :rank] + B = Vh[:rank, :] + if not float_data: + A = A.to(original_device).type(original_type) + B = B.to(original_device).type(original_type) + return [A, B] + else: + raise ValueError("type should be left, right or full") + + +class TestGaLoreProjector: + def __init__( + self, + rank=128, + scale=1.0, + proj_type="std", + ): + self.rank = rank + self.scale = scale + + if proj_type != "std": + raise ("Only 
std projection is supported") + + self.proj_type = proj_type + + self.ortho_matrix = None + + def update_orthogonal_matrix(self, full_rank_grad): + + if full_rank_grad.shape[0] >= full_rank_grad.shape[1]: + self.ortho_matrix = get_orthogonal_matrix( + full_rank_grad, self.rank, type="right" + ) + else: + self.ortho_matrix = get_orthogonal_matrix( + full_rank_grad, self.rank, type="left" + ) + + def project(self, full_rank_grad): + if full_rank_grad.shape[0] >= full_rank_grad.shape[1]: + low_rank_grad = torch.matmul(full_rank_grad, self.ortho_matrix.t()) + else: + low_rank_grad = torch.matmul(self.ortho_matrix.t(), full_rank_grad) + + return low_rank_grad + + def project_back(self, low_rank_grad): + + if low_rank_grad.shape[0] >= low_rank_grad.shape[1]: + full_rank_grad = torch.matmul(low_rank_grad, self.ortho_matrix) + else: + full_rank_grad = torch.matmul(self.ortho_matrix, low_rank_grad) + + return full_rank_grad * self.scale + + +def make_copy(*args): + return [t.detach().clone() for t in args] + + +# def adam_step( +# exp_avg, +# exp_avg2, +# grad, +# galore_proj, +# params, +# step_size=1e-4, +# beta1=BETA1, +# beta2=BETA2, +# eps=EPS, +# ): +# grad = galore_proj.project(grad) +# exp_avg = beta1 * exp_avg + (1 - beta1) * grad +# exp_avg2 = beta2 * exp_avg2 + (1 - beta2) * torch.square(grad) +# denom = exp_avg2.sqrt() + eps +# norm_grad = exp_avg / denom +# norm_grad = galore_proj.project_back(norm_grad) +# # params = params - step_size * norm_grad +# return exp_avg, exp_avg2, denom, norm_grad diff --git a/requirements.txt b/requirements.txt index c60a4fc..7d9fcda 100755 --- a/requirements.txt +++ b/requirements.txt @@ -1,3 +1,3 @@ -torch -transformers -bitsandbytes \ No newline at end of file +# torch +# transformers +# bitsandbytes \ No newline at end of file diff --git a/tests/test_fused_kernels.py b/tests/test_fused_kernels.py new file mode 100644 index 0000000..67e28ad --- /dev/null +++ b/tests/test_fused_kernels.py @@ -0,0 +1,168 @@ +import argparse +import logging + +import torch +from test_utils import get_kernel, make_copy, make_data +from triton.testing import do_bench + + +def run_test(kernel, exp_avg, exp_avg2, grad, proj_matrix, params, allow_tf32): + # Copy to use for first run -- needed because of autotuning and inplace ops + ( + exp_avg_autotune_copy, + exp_avg2_autotune_copy, + grad_autotune_copy, + proj_matrix_autotune_copy, + params_autotune_copy, + ) = make_copy(exp_avg, exp_avg2, grad, proj_matrix, params) + + # Copy to use for second run to check accuracy + ( + exp_avg_test_copy, + exp_avg2_test_copy, + grad_test_copy, + proj_matrix_test_copy, + params_test_copy, + ) = make_copy(exp_avg, exp_avg2, grad, proj_matrix, params) + + print( + f"Running with {grad.shape[0]} x {grad.shape[1]} grad (param) shape, GaLore orthogonal matrix {list(proj_matrix.shape)}, dtype {grad.dtype} and allow_tf32 {allow_tf32}\n" + f"Kernel: {kernel}", + flush=True, + ) + + ref_op = get_kernel("ref") + test_op = get_kernel(kernel) + + # Reference run + ref_out = ref_op( + grad, + proj_matrix, + exp_avg, + exp_avg2, + params, + ) + + # Autotune + _ = test_op( + grad_autotune_copy, + proj_matrix_autotune_copy, + exp_avg_autotune_copy, + exp_avg2_autotune_copy, + params_autotune_copy, + store=False, + allow_tf32=allow_tf32, + ) + + # Accuracy run + test_out = test_op( + grad_test_copy, + proj_matrix_test_copy, + exp_avg_test_copy, + exp_avg2_test_copy, + params_test_copy, + store=True, + allow_tf32=allow_tf32, + ) + print("Accuracy:") + + output_names = [ + "adam state - running grad mean", + 
"adam state - running grad var", + "params (after update)", + ] + for name, ref, tt in zip(output_names, ref_out, test_out): + print( + f"-> {name}:\n Max err: {(ref- tt).abs().max():.6f} Relative err: {(ref- tt).abs().max() / ref.abs().mean():.6f}" + ) + + # Turn off autotune logging during benchmarking + from galore_torch.fused.triton_utils.custom_autotune import logger + + logger.setLevel(logging.WARNING) + + ref_perf = do_bench(lambda: ref_op(grad, proj_matrix, exp_avg, exp_avg2, params)) + test_perf = do_bench( + lambda: test_op( + grad_test_copy, + proj_matrix_test_copy, + exp_avg_test_copy, + exp_avg2_test_copy, + params_test_copy, + store=True, + allow_tf32=allow_tf32, + ) + ) + print( + f"Performance, torch vs test: {ref_perf:.4f}ms vs {test_perf:.4f}ms, {ref_perf / test_perf:1.2f}x" + ) + + +def run(args): + dtype = getattr(torch, args.dtype) + allow_tf32 = args.allow_tf32 + fp8_fast_accum = False + torch.backends.cuda.matmul.allow_tf32 = allow_tf32 + kernel = args.kernel + M, N = args.M, args.N + rank = args.rank + + exp_avg, exp_avg2, grad, proj_matrix, params = make_data(M, N, rank, dtype) + if args.mode.lower() == "test": + run_test(kernel, exp_avg, exp_avg2, grad, proj_matrix, params, allow_tf32) + elif args.mode.lower() == "benchmark": + from test_utils import get_benchmark + + benchmark = get_benchmark(M, N, dtype, allow_tf32=allow_tf32) + save_path = f'benchmark_{M}x{N}_{rank}_{args.dtype}_{"tf32" if allow_tf32 else "no-tf32"}' + print( + f"Running benchmark for {M}x{N}, dtype {args.dtype}, allow_tf32 {allow_tf32}", + flush=True, + ) + benchmark.run(show_plots=False, print_data=True, save_path=save_path) + print(f"Finished benchmark, results saved to {save_path}") + + +if __name__ == "__main__": + parser = argparse.ArgumentParser( + formatter_class=argparse.ArgumentDefaultsHelpFormatter + ) + parser.add_argument( + "--kernel", + choices=["hybrid", "fused", "compiled"], + default="hybrid", + type=str, + help="Kernel to test", + ) + parser.add_argument( + "--mode", + choices=["test", "benchmark"], + default="test", + type=str, + help="If test, runs kernel vs torch, comparing accuracy first then performance using triton `do_bench` for given config {M, N, rank, dtype, allow_tf32}." 
+ "If benchmark, runs all kernels across range of ranks for given config {M, N, dtype, allow_tf32}", + ) + parser.add_argument( + "--allow_tf32", action="store_true", help="Allow tf32 for matmuls" + ) + parser.add_argument("--M", type=int, default=4096, help="Grad (param) shape M") + parser.add_argument("--N", type=int, default=4096, help="Grad (param) shape N") + parser.add_argument( + "--rank", type=int, default=128, help="Rank of GaLore projection" + ) + parser.add_argument( + "--dtype", + type=str, + choices=["float32", "float16", "bfloat16"], + default="float32", + help="Data type of grad (param) tensors", + ) + parser.add_argument( + "--verbose", + action="store_true", + help="If true, prints autotuning output", + ) + args = parser.parse_args() + if args.verbose: + logging.basicConfig(level=logging.INFO) + run(args) diff --git a/tests/test_utils.py b/tests/test_utils.py new file mode 100644 index 0000000..cabfaac --- /dev/null +++ b/tests/test_utils.py @@ -0,0 +1,259 @@ +import torch +import triton +from triton.testing import do_bench + +from galore_torch.fused.triton_utils.kernels.adam_downproj_fused import ( + fused_adam_mm_launcher, +) +from galore_torch.fused.triton_utils.kernels.adam_step import triton_adam_launcher +from galore_torch.fused.triton_utils.kernels.matmul import triton_mm_launcher +from galore_torch.fused.utils import TestGaLoreProjector as GaLoreProjector + +torch.manual_seed(0) + +BETA1 = 0.9 +BETA2 = 0.999 +EPS = 1e-8 +STEP_SIZE = 1e-4 + + +def make_data(M, N, rank, dtype): + grad = torch.randn(M, N, device="cuda", dtype=dtype) + params = torch.randn(M, N, device="cuda", dtype=dtype) + + galore_proj = GaLoreProjector(rank=rank) + galore_proj.update_orthogonal_matrix(grad) + + if M >= N: + exp_avg = torch.randn(M, rank, device="cuda", dtype=dtype) + else: + exp_avg = torch.randn(rank, N, device="cuda", dtype=dtype) + exp_avg2 = exp_avg**2 + + return exp_avg, exp_avg2, grad, galore_proj.ortho_matrix, params + + +def make_copy(*args): + return [t.detach().clone() for t in args] + + +def _ref_op( + grad, + proj_matrix, + exp_avg, + exp_avg2, + params, + beta1=BETA1, + beta2=BETA2, + eps=EPS, + step_size=STEP_SIZE, + **kwargs, +): + + # Step 1: Down proj grad + M, N = grad.shape + if M >= N: + a, b = grad, proj_matrix.t() + else: + a, b = proj_matrix.t(), grad + low_rank_grad = a @ b + + # Step 2: update adam state + exp_avg.mul_(beta1).add_(low_rank_grad, alpha=(1.0 - beta1)) + exp_avg2.mul_(beta2).addcmul_(low_rank_grad, low_rank_grad, value=1.0 - beta2) + denom = exp_avg2.sqrt().add_(eps) + low_rank_norm_grad = exp_avg / denom + + # Step 3: project normalized low rank grad to full rank + if M >= N: + a, b = low_rank_norm_grad, proj_matrix + else: + a, b = proj_matrix, low_rank_norm_grad + full_grad_norm = a @ b + + # Finally, update params with updated grad + params.add_(full_grad_norm, alpha=-step_size) + + return exp_avg, exp_avg2, params + + +def _tt_hybrid( + grad, + proj_matrix, + exp_avg, + exp_avg2, + params, + store=True, + step_size=STEP_SIZE, + fp8_fast_accum=False, + allow_tf32=False, +): + M, N = grad.shape + if M >= N: + a, b = grad, proj_matrix.t() + else: + a, b = proj_matrix.t(), grad + low_rank_grad = a @ b + + exp_avg, exp_avg2, norm_grad = triton_adam_launcher( + exp_avg, exp_avg2, low_rank_grad, store=store + ) + + if M >= N: + a, b = low_rank_grad, proj_matrix + else: + a, b = proj_matrix, low_rank_grad + params = triton_mm_launcher( + a, + b, + epilogue_alpha=-step_size, + epilogue_source=params, + allow_tf32=allow_tf32, + 
fp8_fast_accum=fp8_fast_accum, + ) + return exp_avg, exp_avg2, params + + +def _tt_fused( + grad, + proj_matrix, + exp_avg, + exp_avg2, + params, + store=True, + step_size=STEP_SIZE, + fp8_fast_accum=False, + allow_tf32=False, +): + M, N = grad.shape + + if M >= N: + a, b = grad, proj_matrix.t() + else: + a, b = proj_matrix.t(), grad + exp_avg, exp_avg2, low_rank_grad = fused_adam_mm_launcher( + a, + b, + exp_avg=exp_avg, + exp_avg2=exp_avg2, + store=store, + fp8_fast_accum=fp8_fast_accum, + allow_tf32=allow_tf32, + ) + + if M >= N: + a, b = low_rank_grad, proj_matrix + else: + a, b = proj_matrix, low_rank_grad + params = triton_mm_launcher( + a, + b, + epilogue_alpha=-step_size, + epilogue_source=params, + allow_tf32=allow_tf32, + fp8_fast_accum=fp8_fast_accum, + ) + return exp_avg, exp_avg2, params + + # logging.basicConfig(level=logging.INFO) + + +def get_kernel(kernel): + if kernel == "ref": + op = _ref_op + elif kernel == "ref": + op = torch.compile(_ref_op, fullgraph=True, mode="max-autotune") + elif kernel == "hybrid": + op = _tt_hybrid + elif kernel == "fused": + op = _tt_fused + else: + raise ValueError(f"Unknown kernel {kernel}") + + return lambda *args, **kwargs: op(*args, **kwargs) + + +def get_benchmark( + M, N, dtype, allow_tf32, fp8_fast_accum=False, quantiles=[0.5, 0.2, 0.8] +): + config = triton.testing.Benchmark( + x_names=["rank"], # Argument names to use as an x-axis for the plot + x_vals=[ + 32, + 64, + 128, + 256, + 512, + ], # Different possible values for `x_name` + line_arg="kernel", # Argument name whose value corresponds to a different line in the plot + # Possible values for `line_arg` + line_vals=["torch", "hybrid", "fused", "compiled"], + # Label name for the lines + line_names=["torch", "hybrid", "fused", "compiled"], + # Line styles + styles=[("black", "-"), ("blue", "-"), ("red", "-"), ("green", "-")], + ylabel="ms", # Label name for the y-axis + plot_name=f"Adam Kernel Comparison Grad shape: {M}x{N}, dtype: {dtype}, allow_tf32: {allow_tf32}\nMedian times (ms)", # Name for the plot, used also as a file name for saving the plot. + args={}, + ) + + def benchmark(rank, kernel): + torch.backends.cuda.matmul.allow_tf32 = allow_tf32 + + exp_avg, exp_avg2, grad, proj_matrix, params = make_data(M, N, rank, dtype) + + if kernel == "torch": + ms, min_ms, max_ms = do_bench( + lambda: _ref_op( + grad, + proj_matrix, + exp_avg, + exp_avg2, + params, + ), + quantiles=quantiles, + ) + if kernel == "hybrid": + ms, min_ms, max_ms = triton.testing.do_bench( + lambda: _tt_hybrid( + grad, + proj_matrix, + exp_avg, + exp_avg2, + params, + store=True, + allow_tf32=allow_tf32, + fp8_fast_accum=fp8_fast_accum, + ), + quantiles=quantiles, + ) + if kernel == "fused": + ms, min_ms, max_ms = triton.testing.do_bench( + lambda: _tt_fused( + grad, + proj_matrix, + exp_avg, + exp_avg2, + params, + store=True, + allow_tf32=allow_tf32, + fp8_fast_accum=fp8_fast_accum, + ), + quantiles=quantiles, + ) + if kernel == "compiled": + compiled_op = torch.compile(_ref_op, fullgraph=True, mode="max-autotune") + ms, min_ms, max_ms = triton.testing.do_bench( + lambda: compiled_op( + grad, + proj_matrix, + exp_avg, + exp_avg2, + params, + ), + quantiles=quantiles, + ) + + return ms, max_ms, min_ms + + return triton.testing.perf_report(config)(benchmark)
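+
+
+# Example usage of get_benchmark (mirrors the benchmark mode of tests/test_fused_kernels.py;
+# the shape/dtype below are illustrative):
+#
+#   benchmark = get_benchmark(4096, 4096, torch.float32, allow_tf32=False)
+#   benchmark.run(show_plots=False, print_data=True, save_path="benchmark_4096x4096")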