support epochdecay iterationdecay epochdecaybyscore in TorchOptim (intel-analytics#2951)

* support epochdecay iterationdecay epochdecaybyscore

* fix ut

* fix ut

* add ut

* fix unit test
qiuxin2012 authored Oct 14, 2020
1 parent 830c24d commit 885a4d7
Showing 2 changed files with 89 additions and 21 deletions.
@@ -26,9 +26,11 @@ import org.apache.spark.TaskContext

import scala.reflect.ClassTag
import com.intel.analytics.zoo.pipeline.api.keras.models.InternalOptimizerUtil
import com.intel.analytics.zoo.pipeline.api.net.TorchOptim.DecayType

class TorchOptim[@specialized(Float, Double) T: ClassTag](
torchOptim: Array[Byte])(implicit ev: TensorNumeric[T]) extends OptimMethod[T] {
torchOptim: Array[Byte],
decayType: DecayType)(implicit ev: TensorNumeric[T]) extends OptimMethod[T] {
import TorchOptim._

@transient
@@ -43,8 +45,9 @@ class TorchOptim[@specialized(Float, Double) T: ClassTag](
s"""
|import torch
|import io
|from torch.optim.optimizer import Optimizer
|from torch.optim.optimizer import *
|from torch.optim.lr_scheduler import _LRScheduler
|from torch.optim.lr_scheduler import *
|from zoo.pipeline.api.torch import zoo_pickle_module
|
|optim_by = bytes(b % 256 for b in optim_bytes)
@@ -53,9 +56,6 @@ class TorchOptim[@specialized(Float, Double) T: ClassTag](
PythonInterpreter.exec(loadModelCode)
weightName = name + "_weight"
gradientName = name + "gradient"
lrStepCode = s"""
|${name}.step()
|""".stripMargin
if (PythonInterpreter.getValue[Boolean](s"isinstance($name, Optimizer)")) {
initCode = s"""
|$weightName = torch.tensor($weightName, requires_grad=True)
@@ -77,7 +77,20 @@ class TorchOptim[@specialized(Float, Double) T: ClassTag](
|${weightName}.grad = torch.tensor(${gradientName})
|${name}.optimizer.step()
|""".stripMargin
LrSchedule
LrScheduler
} else if (PythonInterpreter.getValue[Boolean](s"isinstance($name, ReduceLROnPlateau)")) {
// ReduceLROnPlateau is not a subclass of _LRScheduler
require(decayType == EpochDecayByScore, "Plateau should use decayType EpochDecayByScore")
initCode = s"""
|$weightName = torch.tensor($weightName, requires_grad=True)
|$weightName = torch.autograd.Variable($weightName)
|${name}.optimizer.__init__([${weightName}], **${name}.optimizer.defaults)
|""".stripMargin
stepCode = s"""
|${weightName}.grad = torch.tensor(${gradientName})
|${name}.optimizer.step()
|""".stripMargin
Plateau
} else {
val unknownType = PythonInterpreter.getValue[String](s"str(type($name))")
throw new IllegalArgumentException(s"Unknown optimizer type: " + unknownType)
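For context on the new Plateau branch above: ReduceLROnPlateau is not an _LRScheduler subclass and its step() expects a validation metric, which is why it requires decayType EpochDecayByScore. A minimal standalone PyTorch sketch (not part of this commit) illustrating both points:

import torch
from torch.optim.lr_scheduler import _LRScheduler, ReduceLROnPlateau

model = torch.nn.Linear(4, 1)
sgd = torch.optim.SGD(model.parameters(), lr=0.1)
plateau = ReduceLROnPlateau(sgd, mode='min', factor=0.5, patience=2)

print(isinstance(plateau, _LRScheduler))  # False, so the isinstance(..., _LRScheduler) check misses it
plateau.step(0.42)                        # step() takes a score, matching EpochDecayByScore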
@@ -93,8 +106,6 @@ class TorchOptim[@specialized(Float, Double) T: ClassTag](
@transient
protected var initCode = ""
@transient
protected var lrStepCode = ""
@transient
protected var stepCode = ""
@transient
protected var init = false
@@ -113,9 +124,8 @@ class TorchOptim[@specialized(Float, Double) T: ClassTag](
parameter.toTensor[Float].storage().array()))
PythonInterpreter.exec(initCode)
init = true
}
if (optimType == LrSchedule && lastEpoch < epoch) {
PythonInterpreter.exec(lrStepCode)
} else {
updateHyperParameter()
}
PythonInterpreter.set(gradientName, new NDArray[Array[Float]](
dfdx.toTensor[Float].storage().array()))
@@ -135,9 +145,15 @@ class TorchOptim[@specialized(Float, Double) T: ClassTag](
optimType match {
case Optim =>
PythonInterpreter.getValue[Double](s"${name}.defaults['lr']")
case lrSchedule =>
case LrScheduler =>
// TODO: multi LR support.
PythonInterpreter.getValue[Double](s"${name}.get_last_lr()[0]")
case Plateau =>
if (PythonInterpreter.getValue[Boolean](s"hasattr(${name}, '_last_lr')")) {
PythonInterpreter.getValue[Double](s"${name}._last_lr[0]")
} else {
PythonInterpreter.getValue[Double](s"${name}.optimizer.defaults['lr']")
}
case _ =>
throw new IllegalArgumentException()
}
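The Plateau case above falls back to the optimizer defaults because ReduceLROnPlateau has no get_last_lr() and only sets _last_lr after its first step(). A small standalone sketch, assuming a recent PyTorch where _LRScheduler.get_last_lr() and ReduceLROnPlateau._last_lr exist:

import torch
from torch.optim.lr_scheduler import StepLR, ReduceLROnPlateau

model = torch.nn.Linear(4, 1)
sgd = torch.optim.SGD(model.parameters(), lr=0.1)

step_lr = StepLR(sgd, step_size=1, gamma=0.5)
print(step_lr.get_last_lr()[0])           # _LRScheduler subclasses expose get_last_lr()

plateau = ReduceLROnPlateau(sgd, mode='min')
print(hasattr(plateau, '_last_lr'))       # False before any step(), hence the fallback
print(plateau.optimizer.defaults['lr'])   # fallback value used by getLearningRate above
plateau.step(1.0)
print(plateau._last_lr[0])                # populated once step() has run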
@@ -148,14 +164,29 @@ class TorchOptim[@specialized(Float, Double) T: ClassTag](
}

override def updateHyperParameter(): Unit = {
if (optimType == LrSchedule) {
if (optimType == LrScheduler || optimType == Plateau) {
val epoch = getEpoch(this)
PythonInterpreter.exec(s"${name}.step(${epoch})")
decayType match {
case TorchOptim.EpochDecay =>
if (lastEpoch < epoch) {
PythonInterpreter.exec(s"${name}.step()")
lastEpoch += 1
}
case TorchOptim.IterationDecay =>
PythonInterpreter.exec(s"${name}.step()")
case TorchOptim.EpochDecayByScore =>
if (lastEpoch < epoch) {
val valScore = getScore(this)
PythonInterpreter.set("val_score", java.lang.Float.valueOf(valScore))
PythonInterpreter.exec(s"${name}.step(val_score)")
lastEpoch += 1
}
}
}
}

override def getHyperParameter(): String = {
if (optimType == LrSchedule) {
if (optimType == LrScheduler) {
s"Current learning rate is ${getLearningRate()}. "
} else {
""
@@ -166,16 +197,53 @@ class TorchOptim[@specialized(Float, Double) T: ClassTag](

object TorchOptim{
sealed trait OptimType

case object LrSchedule extends OptimType
case object LrScheduler extends OptimType
case object Optim extends OptimType
case object Plateau extends OptimType

sealed trait DecayType
case object EpochDecay extends DecayType
case object IterationDecay extends DecayType
case object EpochDecayByScore extends DecayType
// TODO: Support this later.
// case object IterationDecayByEpoch extends DecayType

def getDecayType(decayType: String): DecayType = {
decayType.toLowerCase() match {
case "epochdecay" =>
EpochDecay
case "iterationdecay" =>
IterationDecay
case "epochdecaybyscore" =>
EpochDecayByScore
// case "iterationdecaybyepoch" =>
// IterationDecayByEpoch
case _ =>
throw new IllegalArgumentException(s"unknow decay type: ${decayType}, expected:" +
s"EpochDecay, IterationDecay, EpochDecayByScore")
}

}

def apply[T: ClassTag](optimBytes: Array[Byte])(implicit ev: TensorNumeric[T]): TorchOptim[T] = {
new TorchOptim[T](optimBytes)
def apply[T: ClassTag](
optimBytes: Array[Byte],
decayType: String)(implicit ev: TensorNumeric[T]): TorchOptim[T] = {
apply[T](optimBytes, getDecayType(decayType))
}

def apply[T: ClassTag](
optimBytes: Array[Byte],
decayType: DecayType)(implicit ev: TensorNumeric[T]): TorchOptim[T] = {
new TorchOptim[T](optimBytes, decayType)
}

protected[net] def getEpoch[T: ClassTag](optim: TorchOptim[T]): Int = {
// BigDL's epoch starts from 1, while torch starts from 0.
InternalOptimizerUtil.getStateFromOptiMethod(optim)[Int]("epoch") - 1
}

protected[net] def getScore[T: ClassTag](optim: TorchOptim[T]): Float = {
// Read the latest validation score from the OptimMethod's state.
InternalOptimizerUtil.getStateFromOptiMethod(optim)[Float]("score")
}
}
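The optim_bytes consumed above are a PyTorch optimizer or scheduler serialized with zoo_pickle_module, the same module the loading code imports. A hedged sketch of how such bytes could be produced; the exact producer used elsewhere in analytics-zoo may differ:

import io
import torch
from torch.optim.lr_scheduler import ReduceLROnPlateau
from zoo.pipeline.api.torch import zoo_pickle_module

model = torch.nn.Linear(4, 1)
sgd = torch.optim.SGD(model.parameters(), lr=0.1)
plateau = ReduceLROnPlateau(sgd, mode='min')

buf = io.BytesIO()
torch.save(plateau, buf, pickle_module=zoo_pickle_module)  # mirrors torch.load(..., pickle_module=zoo_pickle_module) in loadModelCode
optim_bytes = buf.getvalue()                               # the byte payload handed to TorchOptim / createTorchOptim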
@@ -184,8 +184,8 @@ class PythonZooNet[T: ClassTag](implicit ev: TensorNumeric[T]) extends PythonZoo
TorchLoss(criterion)
}

def createTorchOptim(optim: Array[Byte]): TorchOptim[T] = {
TorchOptim(optim)
def createTorchOptim(optim: Array[Byte], decayType: String): TorchOptim[T] = {
TorchOptim(optim, decayType)
}

def torchNetSavePytorch(torchnet: TorchNet, path: String): Unit = {

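With createTorchOptim now taking a decay-type string, a Python caller is expected to pass both the pickled optimizer bytes and one of "EpochDecay", "IterationDecay" or "EpochDecayByScore". A rough usage sketch; the wrapper class and its from_pytorch constructor are assumptions for illustration, not taken from this diff:

import torch
from zoo.pipeline.api.torch import TorchOptim  # wrapper name assumed for illustration

model = torch.nn.Linear(4, 1)
sgd = torch.optim.SGD(model.parameters(), lr=0.1)
scheduler = torch.optim.lr_scheduler.StepLR(sgd, step_size=1, gamma=0.5)

# Hypothetical call: the wrapper pickles the scheduler and forwards the bytes
# plus the decay-type string to createTorchOptim on the JVM side.
zoo_optim = TorchOptim.from_pytorch(scheduler, decayType="EpochDecay")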