Skip to content

Commit

Permalink
Add weight decay to Adam (intel-analytics#2415)
Browse files (browse the repository at this point in the history)
* add weightdecay

* revert

* some change

* some update

* some update
  • Loading branch information
qiuxin2012 authored Jul 14, 2020
1 parent 683f1c8 commit 0899688
Show file tree
Hide file tree
Showing 2 changed files with 9 additions and 2 deletions.
3 changes: 2 additions & 1 deletion python/PythonZooKeras.scala
Original file line number Diff line number Diff line change
Expand Up @@ -1032,9 +1032,10 @@ class PythonZooKeras[T: ClassTag](implicit ev: TensorNumeric[T]) extends PythonZ
beta_2: Double = 0.999,
epsilon: Double = 1e-8,
decay: Double = 0.0,
weightDecay: Double = 0.0,
schedule: SGD.LearningRateSchedule = SGD.Default()
): Adam[T] = {
new Adam[T](lr, beta_1, beta_2, epsilon, decay, schedule)
new Adam[T](lr, beta_1, beta_2, epsilon, decay, weightDecay, schedule)
}

def createZooKerasHardShrink(
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -41,9 +41,10 @@ class Adam[@specialized(Float, Double) T: ClassTag](
var beta_2: Double = 0.999,
var epsilon: Double = 1e-8,
var decay: Double = 0.0,
var wDecay: Double = 0.0,
val schedule: LearningRateSchedule = Default()
)(implicit ev: TensorNumeric[T]) extends SGD[T](learningRate = lr,
learningRateDecay = decay, learningRateSchedule = schedule) {
learningRateDecay = decay, weightDecay = wDecay, learningRateSchedule = schedule) {

@transient
private var buffer: Tensor[T] = null
Expand All @@ -65,6 +66,7 @@ class Adam[@specialized(Float, Double) T: ClassTag](
val beta1 = this.beta_1
val beta2 = this.beta_2
val eps = this.epsilon
val wd = this.wDecay

val (fx, dfdx) = feval(parameter)
val state = SGDRef.getstate(this)
Expand All @@ -80,6 +82,10 @@ class Adam[@specialized(Float, Double) T: ClassTag](

val clr = - this.schedule.currentRate

if(wd > 0) {
dfdx.add(parameter * (ev.fromType(wd)))
}

/**
* m_t = beta_1 * m_t-1 + (1 - beta_1) * g_t
* v_t = beta_2 * v_t-1 + (1 - beta_2) * g_t * g_t
Expand Down

0 comments on commit 0899688

Please sign in to comment.