Adds all remaining Keras optimizers (Adamax, Adafactor, Nadam, and Ftrl) (keras-team#80)

* Add golden correctness tests for Adam and SGD

* Fix dtype issues

* Sync with main (keras-team#56)

* Minor touch ups

* Fix a pretty major bug

* Format code

* Big rethink of Variable API

* Make build-by-run the default build(), leveraging new zero_history KerasTensor mode

* Minor fixes

* Format code

* Switch back to build-by-eager-run for simplicity

* Add raise upon build failure

* Work around JAX bug.

* Add a few more tests.

* Add saving tests

* Adds test suite for SGD and golden correctness tests for all optimizers (keras-team#40)

* Add golden correctness tests for Adam and SGD

* Fix dtype issues

* Add binary accuracy (keras-team#41)

* chore: adding binary accuracy

* chore: fix docstring

* Add tests for add_loss and activity regularization.

* Reformat code

* Add ActivityRegularization layer

* Fix JAX CI.

* Add Lambda Callback (keras-team#42)

* Add LambdaCallback

* Add Lambda Callback

* Add Lambda Callback

* Rename lambda_callback_test.py

* Add einsum (keras-team#43)

* Add einsum

* address comments

* Fix format line length (keras-team#45)

* Add Embedding layer

* Shorten lines

* Add .vscode to .gitignore (keras-team#46)

* rm vscode settings

* add .vscode to gitignore

* Set demo program backend (keras-team#48)

* Add tests for training arg resolution in Layer.

* Implement mixed precision.

* Replace backend.execute with backend.numpy.XXX (keras-team#50)

* Add cosine similarity loss and update l2_normalize from regularizers (keras-team#34)

* Begin cosine loss

* Add testing for cosine similarity

* Fix formatting

* Docstring standardization

* Formatting

* Create numerical_utils

* Fix issue with call context lingering.

* Add the EarlyStopping callback (keras-team#44)

* add earlystopping callback

* addressing comments

* address comments

* addressing comments

* remove unused imports

* re-enable imports checks (keras-team#51)

* Add nn.one_hot (keras-team#52)

* Add GaussianDropout layer.

* Add GaussianNoise layer

* Add Categorical Accuracy Metric (keras-team#47)

* chore: adding categorical accuracy metric

* chore: reformat docstrings

* chore: reformat

* chore: ndims with len

* refactor the docstring

* Fix typos

* Implement masking.

---------

Co-authored-by: Francois Chollet <[email protected]>
Co-authored-by: Aritra Roy Gosthipaty <[email protected]>
Co-authored-by: Ramesh Sampath <[email protected]>
Co-authored-by: Chen Qian <[email protected]>
Co-authored-by: Haifeng Jin <[email protected]>
Co-authored-by: Gabriel Rasskin <[email protected]>

* Adds rmsprop optimizer and tests

* Add AdamW optimizer and tests, minor formatting changes

* Implemented formatting fixes

* Adds clip norm and clip value tests to Adam

* Adds Adagrad and Adadelta optimizers

* Applies fixes to formatting and deletes unnecessary kwargs

* Adds Adamax and Adafactor and associated tests

* Adds Nadam and Ftrl optimizers and associated tests

---------

Co-authored-by: Francois Chollet <[email protected]>
Co-authored-by: Aritra Roy Gosthipaty <[email protected]>
Co-authored-by: Ramesh Sampath <[email protected]>
Co-authored-by: Chen Qian <[email protected]>
Co-authored-by: Haifeng Jin <[email protected]>
Co-authored-by: Gabriel Rasskin <[email protected]>
7 people authored May 3, 2023
1 parent cd8e5fb commit 3cf0a83
Showing 8 changed files with 1,055 additions and 0 deletions.
190 changes: 190 additions & 0 deletions keras_core/optimizers/adafactor.py
@@ -0,0 +1,190 @@
from keras_core import backend
from keras_core import operations as ops
from keras_core.api_export import keras_core_export
from keras_core.optimizers import optimizer


@keras_core_export(["keras_core.optimizers.Adafactor"])
class Adafactor(optimizer.Optimizer):
"""Optimizer that implements the Adafactor algorithm.
Adafactor is commonly used in NLP tasks, and has the advantage
of taking less memory because it only saves partial information of previous
gradients.
The default argument setup is based on the original paper (see reference).
When gradients are of dimension > 2, Adafactor optimizer will delete the
last 2 dimensions separately in its accumulator variables.
Args:
learning_rate: Initial value for the learning rate:
a floating point value, Defaults to 0.001.
beta_2_decay: float, defaults to -0.8. The decay rate of `beta_2`.
epsilon_1: float, defaults to 1e-30. A small offset to keep demoninator
away from 0.
epsilon_2: float, defaults to 1e-3. A small offset to avoid learning
rate becoming too small by time.
clip_threshold: float, defaults to 1.0. Clipping threshold. This is a
part of Adafactor algorithm, independent from `clipnorm`,
`clipvalue`, and `global_clipnorm`.
relative_step: bool, defaults to True. If `learning_rate` is a
constant and `relative_step=True`, learning rate will be adjusted
based on current iterations. This is a default learning rate decay
in Adafactor.
{{base_optimizer_keyword_args}}
Reference:
- [Shazeer, Noam et al., 2018](https://arxiv.org/abs/1804.04235).
"""

def __init__(
self,
learning_rate=0.001,
beta_2_decay=-0.8,
epsilon_1=1e-30,
epsilon_2=1e-3,
clip_threshold=1.0,
relative_step=True,
weight_decay=None,
clipnorm=None,
clipvalue=None,
global_clipnorm=None,
use_ema=False,
ema_momentum=0.99,
ema_overwrite_frequency=None,
name="adafactor",
):
super().__init__(
learning_rate=learning_rate,
name=name,
weight_decay=weight_decay,
clipnorm=clipnorm,
clipvalue=clipvalue,
global_clipnorm=global_clipnorm,
use_ema=use_ema,
ema_momentum=ema_momentum,
ema_overwrite_frequency=ema_overwrite_frequency,
)
self.beta_2_decay = beta_2_decay
self.epsilon_1 = epsilon_1
self.epsilon_2 = epsilon_2
self.clip_threshold = clip_threshold
self.relative_step = relative_step

def build(self, var_list):
"""Initialize optimizer variables.
Adam optimizer has 3 types of variables: momentums, velocities and
velocity_hat (only set when amsgrad is applied),
Args:
var_list: list of model variables to build Adam variables on.
"""
if self.built:
return
super().build(var_list)
self._r = []
self._c = []
self._v = []
for var in var_list:
if len(var.shape) < 2:
                # Don't factor if the variable is of dimension < 2, but we
                # still need to create dummy variables as placeholders.
self._r.append(backend.Variable(0, name=var.name))
self._c.append(backend.Variable(0, name=var.name))
else:
                # Always factor the last 2 dimensions.
r_shape = var.shape[:-1]
                c_shape = var.shape[:-2] + var.shape[-1:]
self._r.append(
self.add_variable(
shape=r_shape,
dtype=var.dtype,
name=var.name,
)
)
self._c.append(
self.add_variable(
shape=c_shape,
dtype=var.dtype,
name=var.name,
)
)
self._v.append(
self.add_variable_from_reference(
reference_variable=var, name="v"
)
)

def _rms(self, x):
return ops.sqrt(ops.mean(ops.square(x)))

def update_step(self, gradient, variable, learning_rate):
"""Update step given gradient and the associated model variable."""

lr = ops.cast(learning_rate, variable.dtype)
gradient = ops.cast(gradient, variable.dtype)
epsilon_2 = ops.cast(self.epsilon_2, variable.dtype)
one = ops.cast(1.0, variable.dtype)
local_step = ops.cast(self.iterations + 1, variable.dtype)
if self.relative_step: # TODO: add learning_rate_schedule logic
# If `relative_step=True` and learning rate is a constant, we
# apply the relative step algorithm.
lr = ops.minimum(lr, 1 / ops.sqrt(local_step))

r = self._r[self._get_variable_index(variable)]
c = self._c[self._get_variable_index(variable)]
v = self._v[self._get_variable_index(variable)]

rho_t = ops.minimum(lr, 1 / ops.sqrt(local_step))
alpha_t = ops.maximum(epsilon_2, self._rms(variable)) * rho_t
regulated_grad_square = ops.square(gradient) + self.epsilon_1
beta_2_t = 1 - ops.power(local_step, self.beta_2_decay)

if len(variable.shape) >= 2:
# `r` deletes the last dimension of gradient, so it is of shape
# `gradient.shape[:-1]`.
r.assign(
beta_2_t * r
+ (1 - beta_2_t) * ops.mean(regulated_grad_square, axis=-1)
)
# `c` deletes the second last dimension of gradient, so it is of
# shape `gradient.shape[:-2] + gradient.shape[-1]`.
c.assign(
beta_2_t * c
+ (1 - beta_2_t) * ops.mean(regulated_grad_square, axis=-2)
)
v.assign(
ops.expand_dims(
r / ops.mean(r, axis=-1, keepdims=True), axis=-1
)
* ops.expand_dims(c, -2)
)
else:
v.assign(beta_2_t * v + (1 - beta_2_t) * regulated_grad_square)

        # Compute the candidate update from the factored second moment.
        u_t = gradient / ops.sqrt(v)
u_t_hat = u_t / ops.maximum(one, (self._rms(u_t) / self.clip_threshold))
variable.assign(variable - alpha_t * u_t_hat)

def get_config(self):
config = super().get_config()

config.update(
{
"beta_2_decay": self.beta_2_decay,
"epsilon_1": self.epsilon_1,
"epsilon_2": self.epsilon_2,
"clip_threshold": self.clip_threshold,
"relative_step": self.relative_step,
}
)
return config


Adafactor.__doc__ = Adafactor.__doc__.replace(
"{{base_optimizer_keyword_args}}", optimizer.base_optimizer_keyword_args
)
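
For readers who want the factored update in `update_step` spelled out on its own, here is a standalone NumPy sketch of one Adafactor step for a single 2D parameter. It mirrors the math above under the default hyperparameters; the helper names (`rms`, `adafactor_step`) are illustrative and not part of this commit.

import numpy as np


def rms(x):
    return np.sqrt(np.mean(np.square(x)))


def adafactor_step(param, grad, r, c, step, lr=0.001, beta_2_decay=-0.8,
                   epsilon_1=1e-30, epsilon_2=1e-3, clip_threshold=1.0):
    # Relative-step learning rate and per-step update scale.
    rho_t = min(lr, 1.0 / np.sqrt(step))
    alpha_t = max(epsilon_2, rms(param)) * rho_t
    regulated_grad_square = np.square(grad) + epsilon_1
    beta_2_t = 1.0 - step ** beta_2_decay
    # Row and column accumulators stand in for the full second-moment matrix.
    r = beta_2_t * r + (1.0 - beta_2_t) * regulated_grad_square.mean(axis=-1)
    c = beta_2_t * c + (1.0 - beta_2_t) * regulated_grad_square.mean(axis=-2)
    v = (r / r.mean(axis=-1, keepdims=True))[..., None] * c[..., None, :]
    # Clip the update by its RMS before applying it.
    u_t = grad / np.sqrt(v)
    u_t_hat = u_t / max(1.0, rms(u_t) / clip_threshold)
    return param - alpha_t * u_t_hat, r, c


# Example: one step on a 3x4 parameter with zero-initialized accumulators.
param = np.ones((3, 4))
grad = np.full((3, 4), 0.1)
r0, c0 = np.zeros(3), np.zeros(4)
param, r0, c0 = adafactor_step(param, grad, r0, c0, step=1)
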
93 changes: 93 additions & 0 deletions keras_core/optimizers/adafactor_test.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,93 @@
# flake8: noqa


import numpy as np

from keras_core import backend
from keras_core import testing
from keras_core.optimizers.adafactor import Adafactor


class AdafactorTest(testing.TestCase):
def test_config(self):
optimizer = Adafactor(
learning_rate=0.5,
beta_2_decay=-0.65,
epsilon_1=1e-15,
epsilon_2=1e-4,
clip_threshold=0.9,
relative_step=False,
)
self.run_class_serialization_test(optimizer)

def test_single_step(self):
optimizer = Adafactor(learning_rate=0.5)
grads = np.array([1.0, 6.0, 7.0, 2.0])
vars = backend.Variable([1.0, 2.0, 3.0, 4.0])
optimizer.apply_gradients(zip([grads], [vars]))
self.assertAllClose(
vars, [-0.3693, 0.6307, 1.6307, 2.6307], rtol=1e-4, atol=1e-4
)

def test_weight_decay(self):
grads, var1, var2, var3 = (
np.zeros(()),
backend.Variable(2.0),
backend.Variable(2.0, name="exclude"),
backend.Variable(2.0),
)
optimizer_1 = Adafactor(learning_rate=1.0, weight_decay=0.004)
optimizer_1.apply_gradients(zip([grads], [var1]))

optimizer_2 = Adafactor(learning_rate=1.0, weight_decay=0.004)
optimizer_2.exclude_from_weight_decay(var_names=["exclude"])
optimizer_2.apply_gradients(zip([grads, grads], [var1, var2]))

optimizer_3 = Adafactor(learning_rate=1.0, weight_decay=0.004)
optimizer_3.exclude_from_weight_decay(var_list=[var3])
optimizer_3.apply_gradients(zip([grads, grads], [var1, var3]))

self.assertAlmostEqual(var1.numpy(), 1.9760959, decimal=6)
self.assertAlmostEqual(var2.numpy(), 2.0, decimal=6)
self.assertAlmostEqual(var3.numpy(), 2.0, decimal=6)

def test_correctness_with_golden(self):
optimizer = Adafactor(
learning_rate=0.5,
beta_2_decay=-0.65,
epsilon_1=1e-15,
epsilon_2=1e-4,
clip_threshold=0.9,
relative_step=False,
)

x = backend.Variable(np.ones([10]))
grads = np.arange(0.1, 1.1, 0.1)
first_grads = np.full((10,), 0.01)

# fmt: off
golden = np.array(
[[0.55, 0.55, 0.55, 0.55, 0.55, 0.55, 0.55, 0.55, 0.55, 0.55],
[0.3031, 0.3026, 0.3025, 0.3024, 0.3024, 0.3024, 0.3024, 0.3024, 0.3024, 0.3024],
[0.1671, 0.1665, 0.1663, 0.1663, 0.1663, 0.1663, 0.1663, 0.1663, 0.1663, 0.1663],
[0.0923, 0.0916, 0.0915, 0.0914, 0.0914, 0.0914, 0.0914, 0.0914, 0.0914, 0.0914],
[0.0554, 0.0548, 0.0546, 0.0546, 0.0546, 0.0546, 0.0546, 0.0545, 0.0545, 0.0545]]
)
# fmt: on

optimizer.apply_gradients(zip([first_grads], [x]))
for i in range(5):
self.assertAllClose(x, golden[i], rtol=5e-4, atol=5e-4)
optimizer.apply_gradients(zip([grads], [x]))

def test_clip_norm(self):
optimizer = Adafactor(clipnorm=1)
grad = [np.array([100.0, 100.0])]
clipped_grad = optimizer._clip_gradients(grad)
self.assertAllClose(clipped_grad[0], [2**0.5 / 2, 2**0.5 / 2])

def test_clip_value(self):
optimizer = Adafactor(clipvalue=1)
grad = [np.array([100.0, 100.0])]
clipped_grad = optimizer._clip_gradients(grad)
self.assertAllClose(clipped_grad[0], [1.0, 1.0])
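
As a quick sanity check, the new optimizer can be exercised directly with the same API the tests above use (`backend.Variable` plus `apply_gradients`); the snippet below is illustrative and not part of the commit.

import numpy as np

from keras_core import backend
from keras_core.optimizers.adafactor import Adafactor

optimizer = Adafactor(learning_rate=0.5)
weights = backend.Variable([1.0, 2.0, 3.0, 4.0])
grads = np.array([1.0, 6.0, 7.0, 2.0])
# One update step; compare with the golden values in test_single_step.
optimizer.apply_gradients(zip([grads], [weights]))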