From 089968825ea42084ce3f63b27cfe14ff5a2304ee Mon Sep 17 00:00:00 2001 From: Xin Qiu Date: Tue, 14 Jul 2020 14:27:10 +0800 Subject: [PATCH] add weightdecay to Adam (#2415) * add weightdecay * revert * some change * some update * some update --- python/PythonZooKeras.scala | 3 ++- .../analytics/bigdl/dllib/keras/optimizers/Adam.scala | 8 +++++++- 2 files changed, 9 insertions(+), 2 deletions(-) diff --git a/python/PythonZooKeras.scala b/python/PythonZooKeras.scala index 985589dfbe2..1fd6e54ee2f 100644 --- a/python/PythonZooKeras.scala +++ b/python/PythonZooKeras.scala @@ -1032,9 +1032,10 @@ class PythonZooKeras[T: ClassTag](implicit ev: TensorNumeric[T]) extends PythonZ beta_2: Double = 0.999, epsilon: Double = 1e-8, decay: Double = 0.0, + weightDecay: Double = 0.0, schedule: SGD.LearningRateSchedule = SGD.Default() ): Adam[T] = { - new Adam[T](lr, beta_1, beta_2, epsilon, decay, schedule) + new Adam[T](lr, beta_1, beta_2, epsilon, decay, weightDecay, schedule) } def createZooKerasHardShrink( diff --git a/scala/dllib/src/main/scala/com/intel/analytics/bigdl/dllib/keras/optimizers/Adam.scala b/scala/dllib/src/main/scala/com/intel/analytics/bigdl/dllib/keras/optimizers/Adam.scala index 069b52cb6ea..469fba9ac05 100644 --- a/scala/dllib/src/main/scala/com/intel/analytics/bigdl/dllib/keras/optimizers/Adam.scala +++ b/scala/dllib/src/main/scala/com/intel/analytics/bigdl/dllib/keras/optimizers/Adam.scala @@ -41,9 +41,10 @@ class Adam[@specialized(Float, Double) T: ClassTag]( var beta_2: Double = 0.999, var epsilon: Double = 1e-8, var decay: Double = 0.0, + var wDecay: Double = 0.0, val schedule: LearningRateSchedule = Default() )(implicit ev: TensorNumeric[T]) extends SGD[T](learningRate = lr, - learningRateDecay = decay, learningRateSchedule = schedule) { + learningRateDecay = decay, weightDecay = wDecay, learningRateSchedule = schedule) { @transient private var buffer: Tensor[T] = null @@ -65,6 +66,7 @@ class Adam[@specialized(Float, Double) T: ClassTag]( val beta1 = this.beta_1 val beta2 = this.beta_2 val eps = this.epsilon + val wd = this.wDecay val (fx, dfdx) = feval(parameter) val state = SGDRef.getstate(this) @@ -80,6 +82,10 @@ class Adam[@specialized(Float, Double) T: ClassTag]( val clr = - this.schedule.currentRate + if(wd > 0) { + dfdx.add(parameter * (ev.fromType(wd))) + } + /** * m_t = beta_1 * m_t-1 + (1 - beta_1) * g_t * v_t = beta_2 * v_t-1 + (1 - beta_2) * g_t * g_t