Add AdEMAMix optimizer (#1360)
* Add AdEMAMix optimizer

* Add PagedAdEMAMix32bit, AdEMAMix32bit

* AdEMAMix: add support for alpha/beta3 scheduling

* Update paged AdEMAMix
matthewdouglas authored Sep 20, 2024
1 parent 8fc7892 commit d964546
Showing 12 changed files with 858 additions and 85 deletions.
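For background: AdEMAMix (Pagliardini et al., 2024, arXiv:2409.03137) extends Adam with a second, slowly decaying EMA of the gradient, governed by beta3, whose contribution is mixed into the update's numerator with weight alpha. A minimal pure-PyTorch sketch of one step, assuming AdamW-style decoupled weight decay; the names m1/m2/nu are illustrative, not the kernel's internal names:

import torch

def ademamix_step(p, grad, m1, m2, nu, *, lr, beta1, beta2, beta3,
                  alpha, eps, weight_decay, step):
    # Fast EMA (as in Adam) and slow EMA (the AdEMAMix addition).
    m1.mul_(beta1).add_(grad, alpha=1 - beta1)
    m2.mul_(beta3).add_(grad, alpha=1 - beta3)
    # Second moment, as in Adam.
    nu.mul_(beta2).addcmul_(grad, grad, value=1 - beta2)
    # Bias-correct the fast EMA and second moment; per the paper,
    # the slow EMA is left uncorrected.
    m1_hat = m1 / (1 - beta1**step)
    nu_hat = nu / (1 - beta2**step)
    update = (m1_hat + alpha * m2) / (nu_hat.sqrt() + eps)
    if weight_decay != 0:
        update = update.add(p, alpha=weight_decay)  # decoupled (AdamW-style) decay
    p.add_(update, alpha=-lr)

The diff below threads beta3 and alpha from the Python wrappers through to the CUDA kernels that implement this update.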
22 changes: 22 additions & 0 deletions bitsandbytes/functional.py
@@ -53,6 +53,11 @@ def prod(iterable):
         lib.cadam32bit_grad_fp32,
         lib.cadam32bit_grad_fp16,
     ),
+    "ademamix": (
+        lib.cademamix32bit_grad_fp32,
+        lib.cademamix32bit_grad_fp16,
+        lib.cademamix32bit_grad_bf16,
+    ),
 }
 
 str2optimizer8bit = {
@@ -105,6 +110,11 @@ def prod(iterable):
         lib.cadagrad_8bit_blockwise_grad_fp32,
         lib.cadagrad_8bit_blockwise_grad_fp16,
     ),
+    "ademamix": (
+        lib.cademamix_8bit_blockwise_grad_fp32,
+        lib.cademamix_8bit_blockwise_grad_fp16,
+        lib.cademamix_8bit_blockwise_grad_bf16,
+    ),
 }
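These per-dtype tuples feed a dtype dispatch; a hedged sketch of the lookup pattern (the real branching lives inside optimizer_update_32bit / optimizer_update_8bit_blockwise and may differ in detail):

import torch
from bitsandbytes.functional import str2optimizer32bit

def select_kernel(optimizer_name: str, g: torch.Tensor):
    # Pick the C kernel matching the gradient dtype; the bf16 entry is
    # optional, so some optimizers register only fp32/fp16 kernels.
    fns = str2optimizer32bit[optimizer_name]
    if g.dtype == torch.float32:
        return fns[0]
    elif g.dtype == torch.float16:
        return fns[1]
    elif g.dtype == torch.bfloat16 and len(fns) == 3:
        return fns[2]
    raise ValueError(f"Unsupported gradient dtype: {g.dtype}")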


@@ -1550,6 +1560,8 @@ def optimizer_update_32bit(
     lr: float,
     state2: Optional[torch.Tensor] = None,
     beta2: float = 0.0,
+    beta3: float = 0.0,
+    alpha: float = 0.0,
     weight_decay: float = 0.0,
     gnorm_scale: float = 1.0,
     unorm_vec: Optional[torch.Tensor] = None,
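A hedged sketch of a call with the new arguments, assuming the leading positional parameters keep their existing names (optimizer_name, g, p, state1, beta1, eps, step, lr) and using the paper's defaults betas=(0.9, 0.999, 0.9999), alpha=5.0:

import bitsandbytes.functional as F

# p is the parameter, p.grad its gradient; state1/state2 are the optimizer
# state tensors allocated by the optimizer class (their AdEMAMix layout is
# defined there, not here).
F.optimizer_update_32bit(
    "ademamix", p.grad, p, state1, 0.9,  # optimizer name, grad, param, state1, beta1
    1e-8, step, 1e-3,                    # eps, step, lr
    state2=state2, beta2=0.999, beta3=0.9999, alpha=5.0,
    weight_decay=1e-2,
)

Because beta3 and alpha default to 0.0 here, existing 32-bit optimizers keep working without touching their call sites.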
@@ -1585,6 +1597,10 @@ def optimizer_update_32bit(
         Optimizer state 2.
     beta2 : float
         Optimizer beta2.
+    beta3 : float
+        Optimizer beta3.
+    alpha : float
+        Optimizer alpha.
     gnorm_scale : float
         The factor to rescale the gradient to the max clip value.
     unorm_vec : torch.Tensor
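On the scheduling mentioned in the commit message: the paper warms alpha up linearly and increases beta3 from beta1 to its final value linearly in the EMA half-life (i.e., linearly in 1/ln(beta)). A sketch under those assumptions; the names t_alpha/t_beta3 and the exact clamping in bitsandbytes may differ:

import math

def alpha_scheduler(step, alpha, t_alpha):
    # Linear warmup of alpha over t_alpha steps, then constant.
    return min(step * alpha / t_alpha, alpha)

def beta3_scheduler(step, beta1, beta3, t_beta3):
    # Interpolate from beta1 to beta3 linearly in 1/ln(beta), which is
    # equivalent to interpolating the EMA half-life, then clamp at beta3.
    s = min(step / t_beta3, 1.0)
    ln_b = (math.log(beta1) * math.log(beta3)) / (
        (1 - s) * math.log(beta3) + s * math.log(beta1)
    )
    return min(math.exp(ln_b), beta3)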
@@ -1623,6 +1639,8 @@ def optimizer_update_32bit(
        ct.c_float(param_norm),
        ct.c_float(beta1),
        ct.c_float(beta2),
+        ct.c_float(beta3),
+        ct.c_float(alpha),
        ct.c_float(eps),
        ct.c_float(weight_decay),
        ct.c_int32(step),
@@ -1775,6 +1793,8 @@ def optimizer_update_8bit_blockwise(
     state2: Optional[torch.Tensor],
     beta1: float,
     beta2: float,
+    beta3: float,
+    alpha: float,
     eps: float,
     step: int,
     lr: float,
@@ -1815,6 +1835,8 @@ def optimizer_update_8bit_blockwise(
        get_ptr(state2),
        ct.c_float(beta1),
        ct.c_float(beta2),
+        ct.c_float(beta3),
+        ct.c_float(alpha),
        ct.c_float(eps),
        ct.c_int32(step),
        ct.c_float(lr),
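Because beta3 and alpha are inserted as required positional parameters between beta2 and eps (in both the signature and the ctypes call, mirroring the kernel's argument order), every 8-bit blockwise call site has to supply them; optimizers without a second gradient EMA can pass neutral values. An illustrative call, with the arguments after lr (quantization maps and absmax tensors) assumed from the library's existing 8-bit path:

import bitsandbytes.functional as F

# Neutral beta3/alpha keep the extra EMA inert for optimizers that don't
# use it (illustrative; real call sites live in bitsandbytes/optim).
F.optimizer_update_8bit_blockwise(
    "adam", g, p, state1, state2,
    0.9, 0.999, 0.0, 0.0,  # beta1, beta2, beta3=0, alpha=0
    1e-8, step, lr,
    qmap1, qmap2, absmax1, absmax2,
    weight_decay=0.0, gnorm_scale=1.0,
)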
1 change: 1 addition & 0 deletions bitsandbytes/optim/__init__.py
@@ -13,6 +13,7 @@
     PagedAdamW8bit,
     PagedAdamW32bit,
 )
+from .ademamix import AdEMAMix, AdEMAMix8bit, AdEMAMix32bit, PagedAdEMAMix, PagedAdEMAMix8bit, PagedAdEMAMix32bit
 from .lamb import LAMB, LAMB8bit, LAMB32bit
 from .lars import LARS, LARS8bit, LARS32bit, PytorchLARS
 from .lion import Lion, Lion8bit, Lion32bit, PagedLion, PagedLion8bit, PagedLion32bit
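A usage sketch for the newly exported classes, assuming they follow the constructor conventions of the other bitsandbytes optimizers, with a 3-tuple betas (beta1, beta2, beta3) plus alpha:

import bitsandbytes as bnb

# model is any torch.nn.Module; hyperparameters shown are the paper's defaults.
optimizer = bnb.optim.PagedAdEMAMix8bit(
    model.parameters(),
    lr=1e-3,
    betas=(0.9, 0.999, 0.9999),  # assumed 3-tuple: beta1, beta2, beta3
    alpha=5.0,
)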
(The remaining 10 of the 12 changed files are not shown here.)
