diff --git a/scala/orca/src/main/scala/com/intel/analytics/bigdl/orca/net/TorchOptim.scala b/scala/orca/src/main/scala/com/intel/analytics/bigdl/orca/net/TorchOptim.scala
index 006aec9eebe..1d938daafc5 100644
--- a/scala/orca/src/main/scala/com/intel/analytics/bigdl/orca/net/TorchOptim.scala
+++ b/scala/orca/src/main/scala/com/intel/analytics/bigdl/orca/net/TorchOptim.scala
@@ -25,17 +25,20 @@ import jep.NDArray
 import org.apache.spark.TaskContext
 
 import scala.reflect.ClassTag
+import com.intel.analytics.zoo.pipeline.api.keras.models.InternalOptimizerUtil
 
 class TorchOptim[@specialized(Float, Double) T: ClassTag](
     torchOptim: Array[Byte])(implicit ev: TensorNumeric[T]) extends OptimMethod[T] {
   import TorchOptim._
 
+  @transient
   protected val postfix = Integer.toHexString(java.util.UUID.randomUUID().hashCode())
   @transient
   protected lazy val optimType: OptimType = {
     val partId = TaskContext.getPartitionId()
     name = s"optim_${postfix}_${partId}"
     PythonInterpreter.set("optim_bytes", torchOptim)
+    val currentEpoch = getEpoch(this)
     val loadModelCode =
       s"""
          |import torch
@@ -48,52 +51,79 @@ class TorchOptim[@specialized(Float, Double) T: ClassTag](
          |$name = torch.load(io.BytesIO(optim_by), pickle_module=zoo_pickle_module)
          |""".stripMargin
     PythonInterpreter.exec(loadModelCode)
+    weightName = name + "_weight"
+    gradientName = name + "gradient"
+    lrStepCode = s"""
+                    |${name}.step()
+                    |""".stripMargin
     if (PythonInterpreter.getValue[Boolean](s"isinstance($name, Optimizer)")) {
+      initCode = s"""
+                    |$weightName = torch.tensor($weightName, requires_grad=True)
+                    |$weightName = torch.autograd.Variable($weightName)
+                    |${name}.__init__([${weightName}], **${name}.defaults)
+                    |""".stripMargin
+      stepCode = s"""
+                    |${weightName}.grad = torch.tensor(${gradientName})
+                    |${name}.step()
+                    |""".stripMargin
       Optim
     } else if (PythonInterpreter.getValue[Boolean](s"isinstance($name, _LRScheduler)")) {
+      initCode = s"""
+                    |$weightName = torch.tensor($weightName, requires_grad=True)
+                    |$weightName = torch.autograd.Variable($weightName)
+                    |${name}.optimizer.__init__([${weightName}], **${name}.optimizer.defaults)
+                    |""".stripMargin
+      stepCode = s"""
+                    |${weightName}.grad = torch.tensor(${gradientName})
+                    |${name}.optimizer.step()
+                    |""".stripMargin
       LrSchedule
     } else {
-      throw new IllegalArgumentException(s"Unknown optimizer type")
+      val unknownType = PythonInterpreter.getValue[String](s"str(type($name))")
+      throw new IllegalArgumentException(s"Unknown optimizer type: $unknownType")
     }
   }
 
-  var name = ""
-  var init = false
+  @transient
+  protected var name = ""
+  @transient
+  protected var weightName = ""
+  @transient
+  protected var gradientName = ""
+  @transient
+  protected var initCode = ""
+  @transient
+  protected var lrStepCode = ""
+  @transient
+  protected var stepCode = ""
+  @transient
+  protected var init = false
+  @transient
+  protected var lastEpoch = -1
 
   override def optimize(
       feval: Tensor[T] => (T, Tensor[T]),
       parameter: Tensor[T]): (Tensor[T], Array[T]) = {
-    optimType match {
-      case Optim =>
-        val (fx, dfdx) = feval(parameter)
-        val weightName = "weight"
-        if (!init) {
-          PythonInterpreter.set(weightName, new NDArray[Array[Float]](
-            parameter.toTensor[Float].storage().array()))
-          val initCode =
-            s"""
-               |$weightName = torch.tensor($weightName, requires_grad=True)
-               |$weightName = torch.autograd.Variable($weightName)
-               |${name}.__init__([${weightName}], **${name}.defaults)
-               |""".stripMargin
-          PythonInterpreter.exec(initCode)
-        }
-        val gradientName = "gradient"
-        PythonInterpreter.set("gradient", new NDArray[Array[Float]](
-          dfdx.toTensor[Float].storage().array()))
-        val stepCode =
-          s"""
-             |${weightName}.grad = torch.tensor(${gradientName})
-             |${name}.step()
-             |""".stripMargin
-        PythonInterpreter.exec(stepCode)
-        val updatedParameter = PythonFeatureSet.ndArrayToTensor(
-          PythonInterpreter.getValue(s"${weightName}.data.numpy()").asInstanceOf[NDArray[_]])
-        parameter.copy(updatedParameter.toTensor[T])
-        (parameter, Array(fx))
-      case LrSchedule =>
-        throw new IllegalArgumentException()
+    optimType
+    val epoch = getEpoch(this)
+    val (fx, dfdx) = feval(parameter)
+    if (!init) {
+      lastEpoch = epoch
+      PythonInterpreter.set(weightName, new NDArray[Array[Float]](
+        parameter.toTensor[Float].storage().array()))
+      PythonInterpreter.exec(initCode)
+      init = true
+    }
+    if (optimType == LrSchedule && lastEpoch < epoch) {
+      PythonInterpreter.exec(lrStepCode)
     }
+    PythonInterpreter.set(gradientName, new NDArray[Array[Float]](
+      dfdx.toTensor[Float].storage().array()))
+    PythonInterpreter.exec(stepCode)
+    val updatedParameter = PythonFeatureSet.ndArrayToTensor(
+      PythonInterpreter.getValue(s"${weightName}.data.numpy()").asInstanceOf[NDArray[_]])
+    parameter.copy(updatedParameter.toTensor[T])
+    (parameter, Array(fx))
   }
 
@@ -105,6 +135,9 @@ class TorchOptim[@specialized(Float, Double) T: ClassTag](
     optimType match {
       case Optim =>
         PythonInterpreter.getValue[Double](s"${name}.defaults['lr']")
+      case LrSchedule =>
+        // TODO: multi LR support.
+        PythonInterpreter.getValue[Double](s"${name}.get_last_lr()[0]")
       case _ =>
         throw new IllegalArgumentException()
     }
@@ -114,6 +147,21 @@ class TorchOptim[@specialized(Float, Double) T: ClassTag](
     this
   }
 
+  override def updateHyperParameter(): Unit = {
+    if (optimType == LrSchedule) {
+      val epoch = getEpoch(this)
+      PythonInterpreter.exec(s"${name}.step(${epoch})")
+    }
+  }
+
+  override def getHyperParameter(): String = {
+    if (optimType == LrSchedule) {
+      s"Current learning rate is ${getLearningRate()}. "
+    } else {
+      ""
+    }
+  }
+
 }
 
 object TorchOptim{
@@ -125,4 +173,9 @@ object TorchOptim{
   def apply[T: ClassTag](optimBytes: Array[Byte])(implicit ev: TensorNumeric[T]): TorchOptim[T] = {
     new TorchOptim[T](optimBytes)
   }
+
+  protected[net] def getEpoch[T: ClassTag](optim: TorchOptim[T]): Int = {
+    // BigDL's epoch starts from 1, while torch starts from 0.
+    InternalOptimizerUtil.getStateFromOptiMethod(optim)[Int]("epoch") - 1
+  }
 }
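
For context (not part of the patch): TorchOptim expects the byte array to be a PyTorch optimizer or _LRScheduler pickled with zoo_pickle_module, matching the torch.load(..., pickle_module=zoo_pickle_module) call in the loader code above. Below is a minimal sketch of how such bytes might be produced on the Python side; the zoo_pickle_module import path and the placeholder model/optimizer are assumptions, not taken from this diff.

# Python-side serialization sketch that pairs with the torch.load call above.
import io
import torch
from zoo.pipeline.api.torch import zoo_pickle_module  # assumed import path

# Placeholder parameters: on the JVM side TorchOptim re-binds the optimizer to the
# flattened weight tensor via __init__([weight], **defaults), so only the
# hyper-parameters (defaults) chosen here carry over.
model = torch.nn.Linear(2, 1)
optimizer = torch.optim.SGD(model.parameters(), lr=0.01)
# Either the optimizer itself (Optim) or an _LRScheduler wrapping it (LrSchedule)
# can be serialized; TorchOptim distinguishes them with isinstance checks.
scheduler = torch.optim.lr_scheduler.StepLR(optimizer, step_size=1, gamma=0.5)

buf = io.BytesIO()
torch.save(scheduler, buf, pickle_module=zoo_pickle_module)
optim_bytes = buf.getvalue()  # bytes handed to TorchOptim.apply on the JVM side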