Skip to content

Commit

Permalink
Add AdEMAMix optimizer (#33682)
Browse files Browse the repository at this point in the history
* Add AdEMAMix optimizer

* Fix test

* Update tests/trainer/test_trainer.py

Co-authored-by: Marc Sun <[email protected]>

---------

Co-authored-by: Marc Sun <[email protected]>
  • Loading branch information
matthewdouglas and SunMarc authored Sep 25, 2024
1 parent 61e98cb commit 196d35c
Show file tree
Hide file tree
Showing 3 changed files with 201 additions and 1 deletion.
31 changes: 31 additions & 0 deletions src/transformers/trainer.py
Original file line number Diff line number Diff line change
Expand Up @@ -1237,6 +1237,10 @@ def get_optimizer_cls_and_kwargs(
OptimizerNames.ADAMW_8BIT,
OptimizerNames.PAGED_ADAMW,
OptimizerNames.PAGED_ADAMW_8BIT,
OptimizerNames.ADEMAMIX,
OptimizerNames.ADEMAMIX_8BIT,
OptimizerNames.PAGED_ADEMAMIX,
OptimizerNames.PAGED_ADEMAMIX_8BIT,
OptimizerNames.LION,
OptimizerNames.LION_8BIT,
OptimizerNames.PAGED_LION,
Expand Down Expand Up @@ -1266,6 +1270,33 @@ def get_optimizer_cls_and_kwargs(
# Above we pass all `adam_kwargs` to the optimizer, here
# we only pass `optim_args` which can be passed by the user.
additional_optim_kwargs = optim_args
elif "ademamix" in args.optim:
if is_bitsandbytes_available() and version.parse(
importlib.metadata.version("bitsandbytes")
) < version.parse("0.44.0"):
raise ValueError(
"The AdEMAMix optimizer is not supported by your current version of `bitsandbytes`. "
"Please install `bitsandbytes` >= 0.44.0."
)

from bitsandbytes.optim import AdEMAMix

optimizer_cls = AdEMAMix
additional_optim_kwargs = {
"betas": (
float(optim_args.get("beta1", args.adam_beta1)),
float(optim_args.get("beta2", args.adam_beta2)),
float(optim_args.get("beta3", 0.9999)),
),
"alpha": float(optim_args.get("alpha", 5.0)),
"eps": float(optim_args.get("eps", args.adam_epsilon)),
}

if "t_alpha" in optim_args:
additional_optim_kwargs["t_alpha"] = int(optim_args["t_alpha"])

if "t_beta3" in optim_args:
additional_optim_kwargs["t_beta3"] = int(optim_args["t_beta3"])

bnb_kwargs = {"optim_bits": optim_bits}
if "rmsprop" not in args.optim:
Expand Down
6 changes: 5 additions & 1 deletion src/transformers/training_args.py
Original file line number Diff line number Diff line change
Expand Up @@ -155,14 +155,18 @@ class OptimizerNames(ExplicitEnum):
ADAFACTOR = "adafactor"
ADAMW_ANYPRECISION = "adamw_anyprecision"
ADAMW_TORCH_4BIT = "adamw_torch_4bit"
ADEMAMIX = "ademamix"
SGD = "sgd"
ADAGRAD = "adagrad"
ADAMW_BNB = "adamw_bnb_8bit"
ADAMW_8BIT = "adamw_8bit" # just an alias for adamw_bnb_8bit
ADEMAMIX_8BIT = "ademamix_8bit"
LION_8BIT = "lion_8bit"
LION = "lion_32bit"
PAGED_ADAMW = "paged_adamw_32bit"
PAGED_ADAMW_8BIT = "paged_adamw_8bit"
PAGED_ADEMAMIX = "paged_ademamix_32bit"
PAGED_ADEMAMIX_8BIT = "paged_ademamix_8bit"
PAGED_LION = "paged_lion_32bit"
PAGED_LION_8BIT = "paged_lion_8bit"
RMSPROP = "rmsprop"
Expand Down Expand Up @@ -618,7 +622,7 @@ class TrainingArguments:
"adafactor". See `OptimizerNames` in [training_args.py](https://github.com/huggingface/transformers/blob/main/src/transformers/training_args.py)
for a full list of optimizers.
optim_args (`str`, *optional*):
Optional arguments that are supplied to AnyPrecisionAdamW.
Optional arguments that are supplied to optimizers such as AnyPrecisionAdamW, AdEMAMix, and GaLore.
group_by_length (`bool`, *optional*, defaults to `False`):
Whether or not to group together samples of roughly the same length in the training dataset (to minimize
padding applied and be more efficient). Only useful if applying dynamic padding.
Expand Down
165 changes: 165 additions & 0 deletions tests/trainer/test_trainer.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,7 @@

import dataclasses
import gc
import importlib
import json
import math
import os
Expand All @@ -32,6 +33,7 @@

import numpy as np
from huggingface_hub import HfFolder, ModelCard, create_branch, delete_repo, list_repo_commits, list_repo_files
from packaging import version
from parameterized import parameterized
from requests.exceptions import HTTPError

Expand Down Expand Up @@ -1091,6 +1093,40 @@ def test_rmsprop_bnb(self):
# Check that it trains without errors
trainer.train()

@require_bitsandbytes
def test_ademamix_bnb(self):
config = GPT2Config(vocab_size=100, n_positions=128, n_embd=32, n_layer=3, n_head=4)
tiny_gpt2 = GPT2LMHeadModel(config)
x = torch.randint(0, 100, (128,))
train_dataset = RepeatDataset(x)

with tempfile.TemporaryDirectory() as tmpdir:
# Trainer without inf/nan filter
args = TrainingArguments(
tmpdir, learning_rate=1e-9, logging_steps=5, logging_nan_inf_filter=False, optim="ademamix"
)
trainer = Trainer(tiny_gpt2, args, train_dataset=train_dataset)

# Check that it trains without errors
trainer.train()

@require_bitsandbytes
def test_ademamix_bnb_8bit(self):
config = GPT2Config(vocab_size=100, n_positions=128, n_embd=32, n_layer=3, n_head=4)
tiny_gpt2 = GPT2LMHeadModel(config)
x = torch.randint(0, 100, (128,))
train_dataset = RepeatDataset(x)

with tempfile.TemporaryDirectory() as tmpdir:
# Trainer without inf/nan filter
args = TrainingArguments(
tmpdir, learning_rate=1e-9, logging_steps=5, logging_nan_inf_filter=False, optim="ademamix_8bit"
)
trainer = Trainer(tiny_gpt2, args, train_dataset=train_dataset)

# Check that it trains without errors
trainer.train()

@require_bitsandbytes
def test_rmsprop_bnb_8bit(self):
config = GPT2Config(vocab_size=100, n_positions=128, n_embd=32, n_layer=3, n_head=4)
Expand Down Expand Up @@ -4187,6 +4223,13 @@ def hp_name(trial):
"lr": TrainingArguments.learning_rate,
}

default_ademamix_kwargs = {
"betas": (TrainingArguments.adam_beta1, TrainingArguments.adam_beta2, 0.9999),
"alpha": 5.0,
"eps": TrainingArguments.adam_epsilon,
"lr": TrainingArguments.learning_rate,
}

default_anyprecision_kwargs = {
"use_kahan_summation": False,
"momentum_dtype": torch.float32,
Expand Down Expand Up @@ -4291,6 +4334,36 @@ def hp_name(trial):
)
)

if version.parse(importlib.metadata.version("bitsandbytes")) >= version.parse("0.44.0"):
optim_test_params.append(
(
TrainingArguments(optim=OptimizerNames.ADEMAMIX, output_dir="None"),
bnb.optim.AdEMAMix,
default_ademamix_kwargs,
)
)
optim_test_params.append(
(
TrainingArguments(optim=OptimizerNames.ADEMAMIX_8BIT, output_dir="None"),
bnb.optim.AdEMAMix,
default_ademamix_kwargs,
)
)
optim_test_params.append(
(
TrainingArguments(optim=OptimizerNames.PAGED_ADEMAMIX_8BIT, output_dir="None"),
bnb.optim.AdEMAMix,
default_ademamix_kwargs,
)
)
optim_test_params.append(
(
TrainingArguments(optim=OptimizerNames.PAGED_ADEMAMIX, output_dir="None"),
bnb.optim.AdEMAMix,
default_ademamix_kwargs,
)
)

if is_torchdistx_available():
import torchdistx

Expand Down Expand Up @@ -4420,6 +4493,62 @@ def test_bnb_paged_adam8bit(self):
default_adam_kwargs,
)

def test_bnb_ademamix(self):
mock = Mock()
modules = {
"bitsandbytes": mock,
"bitsandbytes.optim": mock.optim,
"bitsandbytes.optim.AdEMAMix": mock.optim.AdEMAMix,
}
with patch.dict("sys.modules", modules):
self.check_optim_and_kwargs(
TrainingArguments(optim=OptimizerNames.ADEMAMIX, output_dir="None"),
mock.optim.AdEMAMix,
default_ademamix_kwargs,
)

def test_bnb_ademamix8bit(self):
mock = Mock()
modules = {
"bitsandbytes": mock,
"bitsandbytes.optim": mock.optim,
"bitsandbytes.optim.AdEMAMix": mock.optim.AdEMAMix,
}
with patch.dict("sys.modules", modules):
self.check_optim_and_kwargs(
TrainingArguments(optim=OptimizerNames.ADEMAMIX_8BIT, output_dir="None"),
mock.optim.AdEMAMix,
default_ademamix_kwargs,
)

def test_bnb_paged_ademamix(self):
mock = Mock()
modules = {
"bitsandbytes": mock,
"bitsandbytes.optim": mock.optim,
"bitsandbytes.optim.AdEMAMix": mock.optim.AdEMAMix,
}
with patch.dict("sys.modules", modules):
self.check_optim_and_kwargs(
TrainingArguments(optim=OptimizerNames.PAGED_ADEMAMIX, output_dir="None"),
mock.optim.AdEMAMix,
default_ademamix_kwargs,
)

def test_bnb_paged_ademamix8bit(self):
mock = Mock()
modules = {
"bitsandbytes": mock,
"bitsandbytes.optim": mock.optim,
"bitsandbytes.optim.AdEMAMix": mock.optim.AdEMAMix,
}
with patch.dict("sys.modules", modules):
self.check_optim_and_kwargs(
TrainingArguments(optim=OptimizerNames.PAGED_ADEMAMIX_8BIT, output_dir="None"),
mock.optim.AdEMAMix,
default_ademamix_kwargs,
)

def test_bnb_lion(self):
mock = Mock()
modules = {
Expand Down Expand Up @@ -4503,6 +4632,42 @@ def test_bnb_paged_adam8bit_no_bnb(self):
with self.assertRaises(ValueError):
Trainer.get_optimizer_cls_and_kwargs(args)

def test_bnb_ademamix_no_bnb(self):
args = TrainingArguments(optim=OptimizerNames.ADEMAMIX, output_dir="None")

# Pretend that bnb does not exist, even if installed. By setting bnb to None, importing
# bnb will fail even if `bitsandbytes` is installed.
with patch.dict("sys.modules", {"bitsandbytes.optim": None}):
with self.assertRaises(ValueError):
Trainer.get_optimizer_cls_and_kwargs(args)

def test_bnb_ademamix8bit_no_bnb(self):
args = TrainingArguments(optim=OptimizerNames.ADEMAMIX_8BIT, output_dir="None")

# Pretend that bnb does not exist, even if installed. By setting bnb to None, importing
# bnb will fail even if `bitsandbytes` is installed.
with patch.dict("sys.modules", {"bitsandbytes.optim": None}):
with self.assertRaises(ValueError):
Trainer.get_optimizer_cls_and_kwargs(args)

def test_bnb_paged_ademamix_no_bnb(self):
args = TrainingArguments(optim=OptimizerNames.PAGED_ADEMAMIX, output_dir="None")

# Pretend that bnb does not exist, even if installed. By setting bnb to None, importing
# bnb will fail even if `bitsandbytes` is installed.
with patch.dict("sys.modules", {"bitsandbytes.optim": None}):
with self.assertRaises(ValueError):
Trainer.get_optimizer_cls_and_kwargs(args)

def test_bnb_paged_ademamix8bit_no_bnb(self):
args = TrainingArguments(optim=OptimizerNames.PAGED_ADEMAMIX_8BIT, output_dir="None")

# Pretend that bnb does not exist, even if installed. By setting bnb to None, importing
# bnb will fail even if `bitsandbytes` is installed.
with patch.dict("sys.modules", {"bitsandbytes.optim": None}):
with self.assertRaises(ValueError):
Trainer.get_optimizer_cls_and_kwargs(args)

def test_bnb_paged_lion_no_bnb(self):
args = TrainingArguments(optim=OptimizerNames.PAGED_LION, output_dir="None")

Expand Down

0 comments on commit 196d35c

Please sign in to comment.