forked from keras-team/keras-cv
-
Notifications
You must be signed in to change notification settings - Fork 0
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
Adds all remaining Keras optimizers (Adamax, Adafactor, Nadam, and Ft…
…rl) (keras-team#80) * Add golden correctness tests for Adam and SGD * Fix dtype issues * Sync with main (keras-team#56) * Minor touch ups * Fix a pretty major bug * Format code * Big rethink of Variable API * Make build-by-run the default build(), leveraging new zero_history KerasTensor mode * Minor fixes * Format code * Switch back to build-by-eager-run for simplicity * Add raise upon build failure * Work around JAX bug. * Add a few more tests. * Add saving tests * Adds test suite for SGD and golden correctness tests for all optimizers (keras-team#40) * Add golden correctness tests for Adam and SGD * Fix dtype issues * Add binary accuracy (keras-team#41) * chore: adding binary accuracy * chore: fix docstring * Add tests for add_loss and activity regularization. * Reformat code * Add ActivityRegularization layer * Fix JAX CI. * Add Lambda Callback (keras-team#42) * Add LambdaCallback * Add Lambda Callback * Add Lambda Callback * Rename lambda_callback_test.py * Add einsum (keras-team#43) * Add einsum * address comments * Fix format line length (keras-team#45) * Add Embedding layer * Shorten lines * Add .vscode to .gitignore (keras-team#46) * rm vscode settings * add .vscode to gitignore * Set demo program backend (keras-team#48) * Add tests for training arg resolution in Layer. * Implement mixed precision. * Replace backend.execute with backend.numpy.XXX (keras-team#50) * Add cosine similarity loss and update l2_normalize from regularizers (keras-team#34) * Begin cosine loss * Add testing for cosine similarity * Fix formatting * Docstring standardization * Formatting * Create numerical_utils * Fix issue with call context lingering. * Add the EarlyStopping callback (keras-team#44) * add earlystopping callback * addressing comments * address comments * addressing comments * remove unused imports * re-enable imports checks (keras-team#51) * Add nn.one_hot (keras-team#52) * Add GaussianDropout layer. * Add GaussianNoise layer * Add Categorical Accuracy Metric (keras-team#47) * chore: adding categorical accuracy metric * chore: reformat docstrings * chore: reformat * chore: ndims with len * refactor the docstring * Fix typos * Implement masking. --------- Co-authored-by: Francois Chollet <[email protected]> Co-authored-by: Aritra Roy Gosthipaty <[email protected]> Co-authored-by: Ramesh Sampath <[email protected]> Co-authored-by: Chen Qian <[email protected]> Co-authored-by: Haifeng Jin <[email protected]> Co-authored-by: Gabriel Rasskin <[email protected]> * Adds rmsprop optimizer and tests * Add AdamW optimizer and tests, minor formatting changes * Implemented formatting fixes * Adds clip norm and clip value tests to Adam * Adds Adagrad and Adadelta optimizers * Applies fixes to formatting and deletes unnecessary kwargs * Adds Adamax and Adafactor and associated tests * Adds Nadam and Ftrl optimizers and associated tests --------- Co-authored-by: Francois Chollet <[email protected]> Co-authored-by: Aritra Roy Gosthipaty <[email protected]> Co-authored-by: Ramesh Sampath <[email protected]> Co-authored-by: Chen Qian <[email protected]> Co-authored-by: Haifeng Jin <[email protected]> Co-authored-by: Gabriel Rasskin <[email protected]>
- Loading branch information
1 parent
cd8e5fb
commit 3cf0a83
Showing
8 changed files
with
1,055 additions
and
0 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,190 @@ | ||
from keras_core import backend | ||
from keras_core import operations as ops | ||
from keras_core.api_export import keras_core_export | ||
from keras_core.optimizers import optimizer | ||
|
||
|
||
@keras_core_export(["keras_core.optimizers.Adafactor"]) | ||
class Adafactor(optimizer.Optimizer): | ||
"""Optimizer that implements the Adafactor algorithm. | ||
Adafactor is commonly used in NLP tasks, and has the advantage | ||
of taking less memory because it only saves partial information of previous | ||
gradients. | ||
The default argument setup is based on the original paper (see reference). | ||
When gradients are of dimension > 2, Adafactor optimizer will delete the | ||
last 2 dimensions separately in its accumulator variables. | ||
Args: | ||
learning_rate: Initial value for the learning rate: | ||
a floating point value, Defaults to 0.001. | ||
beta_2_decay: float, defaults to -0.8. The decay rate of `beta_2`. | ||
epsilon_1: float, defaults to 1e-30. A small offset to keep demoninator | ||
away from 0. | ||
epsilon_2: float, defaults to 1e-3. A small offset to avoid learning | ||
rate becoming too small by time. | ||
clip_threshold: float, defaults to 1.0. Clipping threshold. This is a | ||
part of Adafactor algorithm, independent from `clipnorm`, | ||
`clipvalue`, and `global_clipnorm`. | ||
relative_step: bool, defaults to True. If `learning_rate` is a | ||
constant and `relative_step=True`, learning rate will be adjusted | ||
based on current iterations. This is a default learning rate decay | ||
in Adafactor. | ||
{{base_optimizer_keyword_args}} | ||
Reference: | ||
- [Shazeer, Noam et al., 2018](https://arxiv.org/abs/1804.04235). | ||
""" | ||
|
||
def __init__( | ||
self, | ||
learning_rate=0.001, | ||
beta_2_decay=-0.8, | ||
epsilon_1=1e-30, | ||
epsilon_2=1e-3, | ||
clip_threshold=1.0, | ||
relative_step=True, | ||
weight_decay=None, | ||
clipnorm=None, | ||
clipvalue=None, | ||
global_clipnorm=None, | ||
use_ema=False, | ||
ema_momentum=0.99, | ||
ema_overwrite_frequency=None, | ||
name="adafactor", | ||
): | ||
super().__init__( | ||
learning_rate=learning_rate, | ||
name=name, | ||
weight_decay=weight_decay, | ||
clipnorm=clipnorm, | ||
clipvalue=clipvalue, | ||
global_clipnorm=global_clipnorm, | ||
use_ema=use_ema, | ||
ema_momentum=ema_momentum, | ||
ema_overwrite_frequency=ema_overwrite_frequency, | ||
) | ||
self.beta_2_decay = beta_2_decay | ||
self.epsilon_1 = epsilon_1 | ||
self.epsilon_2 = epsilon_2 | ||
self.clip_threshold = clip_threshold | ||
self.relative_step = relative_step | ||
|
||
def build(self, var_list): | ||
"""Initialize optimizer variables. | ||
Adam optimizer has 3 types of variables: momentums, velocities and | ||
velocity_hat (only set when amsgrad is applied), | ||
Args: | ||
var_list: list of model variables to build Adam variables on. | ||
""" | ||
if self.built: | ||
return | ||
super().build(var_list) | ||
self._r = [] | ||
self._c = [] | ||
self._v = [] | ||
for var in var_list: | ||
if len(var.shape) < 2: | ||
# Don't factor if variable is of dimension < 2, but we still | ||
# need to create dummy variables as placeholder. | ||
self._r.append(backend.Variable(0, name=var.name)) | ||
self._c.append(backend.Variable(0, name=var.name)) | ||
else: | ||
# Always factor the last 2 dimenstions. | ||
r_shape = var.shape[:-1] | ||
c_shape = var.shape[:-2] + var.shape[-1] | ||
self._r.append( | ||
self.add_variable( | ||
shape=r_shape, | ||
dtype=var.dtype, | ||
name=var.name, | ||
) | ||
) | ||
self._c.append( | ||
self.add_variable( | ||
shape=c_shape, | ||
dtype=var.dtype, | ||
name=var.name, | ||
) | ||
) | ||
self._v.append( | ||
self.add_variable_from_reference( | ||
reference_variable=var, name="v" | ||
) | ||
) | ||
|
||
def _rms(self, x): | ||
return ops.sqrt(ops.mean(ops.square(x))) | ||
|
||
def update_step(self, gradient, variable, learning_rate): | ||
"""Update step given gradient and the associated model variable.""" | ||
|
||
lr = ops.cast(learning_rate, variable.dtype) | ||
gradient = ops.cast(gradient, variable.dtype) | ||
epsilon_2 = ops.cast(self.epsilon_2, variable.dtype) | ||
one = ops.cast(1.0, variable.dtype) | ||
local_step = ops.cast(self.iterations + 1, variable.dtype) | ||
if self.relative_step: # TODO: add learning_rate_schedule logic | ||
# If `relative_step=True` and learning rate is a constant, we | ||
# apply the relative step algorithm. | ||
lr = ops.minimum(lr, 1 / ops.sqrt(local_step)) | ||
|
||
r = self._r[self._get_variable_index(variable)] | ||
c = self._c[self._get_variable_index(variable)] | ||
v = self._v[self._get_variable_index(variable)] | ||
|
||
rho_t = ops.minimum(lr, 1 / ops.sqrt(local_step)) | ||
alpha_t = ops.maximum(epsilon_2, self._rms(variable)) * rho_t | ||
regulated_grad_square = ops.square(gradient) + self.epsilon_1 | ||
beta_2_t = 1 - ops.power(local_step, self.beta_2_decay) | ||
|
||
if len(variable.shape) >= 2: | ||
# `r` deletes the last dimension of gradient, so it is of shape | ||
# `gradient.shape[:-1]`. | ||
r.assign( | ||
beta_2_t * r | ||
+ (1 - beta_2_t) * ops.mean(regulated_grad_square, axis=-1) | ||
) | ||
# `c` deletes the second last dimension of gradient, so it is of | ||
# shape `gradient.shape[:-2] + gradient.shape[-1]`. | ||
c.assign( | ||
beta_2_t * c | ||
+ (1 - beta_2_t) * ops.mean(regulated_grad_square, axis=-2) | ||
) | ||
v.assign( | ||
ops.expand_dims( | ||
r / ops.mean(r, axis=-1, keepdims=True), axis=-1 | ||
) | ||
* ops.expand_dims(c, -2) | ||
) | ||
else: | ||
v.assign(beta_2_t * v + (1 - beta_2_t) * regulated_grad_square) | ||
|
||
# `convert_to_tensor` unifies the handling of sparse and dense grads. | ||
u_t = gradient / ops.sqrt(v) | ||
u_t_hat = u_t / ops.maximum(one, (self._rms(u_t) / self.clip_threshold)) | ||
variable.assign(variable - alpha_t * u_t_hat) | ||
|
||
def get_config(self): | ||
config = super().get_config() | ||
|
||
config.update( | ||
{ | ||
"beta_2_decay": self.beta_2_decay, | ||
"epsilon_1": self.epsilon_1, | ||
"epsilon_2": self.epsilon_2, | ||
"clip_threshold": self.clip_threshold, | ||
"relative_step": self.relative_step, | ||
} | ||
) | ||
return config | ||
|
||
|
||
Adafactor.__doc__ = Adafactor.__doc__.replace( | ||
"{{base_optimizer_keyword_args}}", optimizer.base_optimizer_keyword_args | ||
) |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,93 @@ | ||
# flake8: noqa | ||
|
||
|
||
import numpy as np | ||
|
||
from keras_core import backend | ||
from keras_core import testing | ||
from keras_core.optimizers.adafactor import Adafactor | ||
|
||
|
||
class AdafactorTest(testing.TestCase): | ||
def test_config(self): | ||
optimizer = Adafactor( | ||
learning_rate=0.5, | ||
beta_2_decay=-0.65, | ||
epsilon_1=1e-15, | ||
epsilon_2=1e-4, | ||
clip_threshold=0.9, | ||
relative_step=False, | ||
) | ||
self.run_class_serialization_test(optimizer) | ||
|
||
def test_single_step(self): | ||
optimizer = Adafactor(learning_rate=0.5) | ||
grads = np.array([1.0, 6.0, 7.0, 2.0]) | ||
vars = backend.Variable([1.0, 2.0, 3.0, 4.0]) | ||
optimizer.apply_gradients(zip([grads], [vars])) | ||
self.assertAllClose( | ||
vars, [-0.3693, 0.6307, 1.6307, 2.6307], rtol=1e-4, atol=1e-4 | ||
) | ||
|
||
def test_weight_decay(self): | ||
grads, var1, var2, var3 = ( | ||
np.zeros(()), | ||
backend.Variable(2.0), | ||
backend.Variable(2.0, name="exclude"), | ||
backend.Variable(2.0), | ||
) | ||
optimizer_1 = Adafactor(learning_rate=1.0, weight_decay=0.004) | ||
optimizer_1.apply_gradients(zip([grads], [var1])) | ||
|
||
optimizer_2 = Adafactor(learning_rate=1.0, weight_decay=0.004) | ||
optimizer_2.exclude_from_weight_decay(var_names=["exclude"]) | ||
optimizer_2.apply_gradients(zip([grads, grads], [var1, var2])) | ||
|
||
optimizer_3 = Adafactor(learning_rate=1.0, weight_decay=0.004) | ||
optimizer_3.exclude_from_weight_decay(var_list=[var3]) | ||
optimizer_3.apply_gradients(zip([grads, grads], [var1, var3])) | ||
|
||
self.assertAlmostEqual(var1.numpy(), 1.9760959, decimal=6) | ||
self.assertAlmostEqual(var2.numpy(), 2.0, decimal=6) | ||
self.assertAlmostEqual(var3.numpy(), 2.0, decimal=6) | ||
|
||
def test_correctness_with_golden(self): | ||
optimizer = Adafactor( | ||
learning_rate=0.5, | ||
beta_2_decay=-0.65, | ||
epsilon_1=1e-15, | ||
epsilon_2=1e-4, | ||
clip_threshold=0.9, | ||
relative_step=False, | ||
) | ||
|
||
x = backend.Variable(np.ones([10])) | ||
grads = np.arange(0.1, 1.1, 0.1) | ||
first_grads = np.full((10,), 0.01) | ||
|
||
# fmt: off | ||
golden = np.array( | ||
[[0.55, 0.55, 0.55, 0.55, 0.55, 0.55, 0.55, 0.55, 0.55, 0.55], | ||
[0.3031, 0.3026, 0.3025, 0.3024, 0.3024, 0.3024, 0.3024, 0.3024, 0.3024, 0.3024], | ||
[0.1671, 0.1665, 0.1663, 0.1663, 0.1663, 0.1663, 0.1663, 0.1663, 0.1663, 0.1663], | ||
[0.0923, 0.0916, 0.0915, 0.0914, 0.0914, 0.0914, 0.0914, 0.0914, 0.0914, 0.0914], | ||
[0.0554, 0.0548, 0.0546, 0.0546, 0.0546, 0.0546, 0.0546, 0.0545, 0.0545, 0.0545]] | ||
) | ||
# fmt: on | ||
|
||
optimizer.apply_gradients(zip([first_grads], [x])) | ||
for i in range(5): | ||
self.assertAllClose(x, golden[i], rtol=5e-4, atol=5e-4) | ||
optimizer.apply_gradients(zip([grads], [x])) | ||
|
||
def test_clip_norm(self): | ||
optimizer = Adafactor(clipnorm=1) | ||
grad = [np.array([100.0, 100.0])] | ||
clipped_grad = optimizer._clip_gradients(grad) | ||
self.assertAllClose(clipped_grad[0], [2**0.5 / 2, 2**0.5 / 2]) | ||
|
||
def test_clip_value(self): | ||
optimizer = Adafactor(clipvalue=1) | ||
grad = [np.array([100.0, 100.0])] | ||
clipped_grad = optimizer._clip_gradients(grad) | ||
self.assertAllClose(clipped_grad[0], [1.0, 1.0]) |
Oops, something went wrong.