pytorch Lrscheduler support (intel-analytics#2886)

* add lrscheduler

* add ut

* update TorchOptim

* hyper parameter

* some fix
qiuxin2012 authored Sep 22, 2020
1 parent 609b360 commit 830c24d
Showing 1 changed file with 86 additions and 33 deletions.
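
For context: the Array[Byte] handed to TorchOptim is a torch.save dump of either a torch.optim.Optimizer or a torch.optim.lr_scheduler._LRScheduler, which the code below reloads inside the partition-local Python interpreter. A minimal Python sketch of producing such a payload, assuming plain pickling (the real pipeline serializes with Analytics Zoo's zoo_pickle_module, and the model and optimizer here are only illustrative):

import io
import torch
import torch.nn as nn

model = nn.Linear(4, 2)
optimizer = torch.optim.SGD(model.parameters(), lr=0.1)
# Wrapping the optimizer in an LR scheduler is the case this commit adds support for.
scheduler = torch.optim.lr_scheduler.StepLR(optimizer, step_size=2, gamma=0.5)

buffer = io.BytesIO()
torch.save(scheduler, buffer)    # zoo_pickle_module would be passed as pickle_module in the real pipeline
optim_bytes = buffer.getvalue()  # the bytes that reach TorchOptim on the JVM side
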
@@ -25,17 +25,20 @@ import jep.NDArray
import org.apache.spark.TaskContext

import scala.reflect.ClassTag
import com.intel.analytics.zoo.pipeline.api.keras.models.InternalOptimizerUtil

class TorchOptim[@specialized(Float, Double) T: ClassTag](
torchOptim: Array[Byte])(implicit ev: TensorNumeric[T]) extends OptimMethod[T] {
import TorchOptim._

@transient
protected val postfix = Integer.toHexString(java.util.UUID.randomUUID().hashCode())
@transient
protected lazy val optimType: OptimType = {
val partId = TaskContext.getPartitionId()
name = s"optim_${postfix}_${partId}"
PythonInterpreter.set("optim_bytes", torchOptim)
val currentEpoch = getEpoch(this)
val loadModelCode =
s"""
|import torch
@@ -48,52 +51,79 @@ class TorchOptim[@specialized(Float, Double) T: ClassTag](
|$name = torch.load(io.BytesIO(optim_by), pickle_module=zoo_pickle_module)
|""".stripMargin
PythonInterpreter.exec(loadModelCode)
weightName = name + "_weight"
gradientName = name + "gradient"
lrStepCode = s"""
|${name}.step()
|""".stripMargin
if (PythonInterpreter.getValue[Boolean](s"isinstance($name, Optimizer)")) {
initCode = s"""
|$weightName = torch.tensor($weightName, requires_grad=True)
|$weightName = torch.autograd.Variable($weightName)
|${name}.__init__([${weightName}], **${name}.defaults)
|""".stripMargin
stepCode = s"""
|${weightName}.grad = torch.tensor(${gradientName})
|${name}.step()
|""".stripMargin
Optim
} else if (PythonInterpreter.getValue[Boolean](s"isinstance($name, _LRScheduler)")) {
initCode = s"""
|$weightName = torch.tensor($weightName, requires_grad=True)
|$weightName = torch.autograd.Variable($weightName)
|${name}.optimizer.__init__([${weightName}], **${name}.optimizer.defaults)
|""".stripMargin
stepCode = s"""
|${weightName}.grad = torch.tensor(${gradientName})
|${name}.optimizer.step()
|""".stripMargin
LrSchedule
} else {
val unknownType = PythonInterpreter.getValue[String](s"str(type($name))")
throw new IllegalArgumentException(s"Unknown optimizer type: " + unknownType)
}
}
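
Taken together, when the deserialized object is an _LRScheduler the generated fragments above (initCode, lrStepCode, stepCode) amount to roughly the Python below. This is a sketch only: the real code uses generated per-partition names and loads with pickle_module=zoo_pickle_module, while the payload and tensors here are stand-ins.

import io
import torch
import torch.nn as nn
from torch.optim.lr_scheduler import _LRScheduler

# Stand-in payload; in the real pipeline the bytes arrive from the JVM as "optim_bytes".
params = [nn.Parameter(torch.zeros(3))]
buffer = io.BytesIO()
torch.save(torch.optim.lr_scheduler.StepLR(torch.optim.SGD(params, lr=0.1), step_size=1), buffer)
optim_bytes = buffer.getvalue()

optim = torch.load(io.BytesIO(optim_bytes))
assert isinstance(optim, _LRScheduler)  # this is what selects the LrSchedule branch
# initCode: rebind the wrapped optimizer to the flattened weight tensor
weight = torch.tensor([0.0, 0.0, 0.0], requires_grad=True)
optim.optimizer.__init__([weight], **optim.optimizer.defaults)
# lrStepCode: advance the schedule once per epoch
optim.step()
# stepCode: one gradient step per iteration
weight.grad = torch.tensor([0.1, 0.1, 0.1])
optim.optimizer.step()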

@transient
protected var name = ""
@transient
protected var weightName = ""
@transient
protected var gradientName = ""
@transient
protected var initCode = ""
@transient
protected var lrStepCode = ""
@transient
protected var stepCode = ""
@transient
protected var init = false
@transient
protected var lastEpoch = -1

override def optimize(
feval: Tensor[T] => (T, Tensor[T]),
parameter: Tensor[T]): (Tensor[T], Array[T]) = {
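// referencing the lazy optimType forces the Python-side setup above (load the optimizer, build init/step code) before it is used below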
optimType
val epoch = getEpoch(this)
val (fx, dfdx) = feval(parameter)
if (!init) {
lastEpoch = epoch
PythonInterpreter.set(weightName, new NDArray[Array[Float]](
parameter.toTensor[Float].storage().array()))
PythonInterpreter.exec(initCode)
init = true
}
if (optimType == LrSchedule && lastEpoch < epoch) {
PythonInterpreter.exec(lrStepCode)
}
PythonInterpreter.set(gradientName, new NDArray[Array[Float]](
dfdx.toTensor[Float].storage().array()))
PythonInterpreter.exec(stepCode)
val updatedParameter = PythonFeatureSet.ndArrayToTensor(
PythonInterpreter.getValue(s"${weightName}.data.numpy()").asInstanceOf[NDArray[_]])
parameter.copy(updatedParameter.toTensor[T])
(parameter, Array(fx))

}

@@ -105,6 +135,9 @@ class TorchOptim[@specialized(Float, Double) T: ClassTag](
optimType match {
case Optim =>
PythonInterpreter.getValue[Double](s"${name}.defaults['lr']")
case LrSchedule =>
// TODO: multi LR support.
PythonInterpreter.getValue[Double](s"${name}.get_last_lr()[0]")
case _ =>
throw new IllegalArgumentException()
}
@@ -114,6 +147,21 @@ class TorchOptim[@specialized(Float, Double) T: ClassTag](
this
}

override def updateHyperParameter(): Unit = {
if (optimType == LrSchedule) {
val epoch = getEpoch(this)
PythonInterpreter.exec(s"${name}.step(${epoch})")
}
}

override def getHyperParameter(): String = {
if (optimType == LrSchedule) {
s"Current learning rate is ${getLearningRate()}. "
} else {
""
}
}

}

object TorchOptim{
@@ -125,4 +173,9 @@ object TorchOptim{
def apply[T: ClassTag](optimBytes: Array[Byte])(implicit ev: TensorNumeric[T]): TorchOptim[T] = {
new TorchOptim[T](optimBytes)
}

protected[net] def getEpoch[T: ClassTag](optim: TorchOptim[T]): Int = {
// BigDL's epoch starts from 1, while torch starts from 0.
InternalOptimizerUtil.getStateFromOptiMethod(optim)[Int]("epoch") - 1
}
}
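
The per-epoch updateHyperParameter above delegates the learning-rate schedule entirely to the scheduler via step(epoch), and getLearningRate reads the result back through get_last_lr(). A standalone sketch of that contract, assuming a StepLR schedule (names are illustrative; step(epoch) is the legacy PyTorch 1.x calling form used here):

import torch
import torch.nn as nn

params = [nn.Parameter(torch.zeros(2))]
optimizer = torch.optim.SGD(params, lr=0.1)
scheduler = torch.optim.lr_scheduler.StepLR(optimizer, step_size=2, gamma=0.5)

for epoch in range(5):
    # TorchOptim issues this once per epoch, with epoch = BigDL epoch - 1 (see getEpoch above)
    scheduler.step(epoch)
    # the value getLearningRate() would report back to the training loop
    print(epoch, scheduler.get_last_lr()[0])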
