Skip to content

Commit

Permalink
Torchoptim, wrap pytorch optimizer to bigdl's optimmethod (intel-anal…
Browse files Browse the repository at this point in the history
…ytics#2869)

* TorchOptim

* add python api

* add scala test

* add python test

* add python test

* add python test

* fix on yarn

* clean up

* add test

* fix style check
  • Loading branch information
qiuxin2012 authored Sep 15, 2020
1 parent 1944f93 commit 5f3b659
Show file tree
Hide file tree
Showing 3 changed files with 133 additions and 2 deletions.
3 changes: 1 addition & 2 deletions net/TorchModel.scala
Original file line number Diff line number Diff line change
Expand Up @@ -37,6 +37,7 @@ class TorchModel private(private val modelHolder: TorchModel2Holder, init_weight
import TorchModel._

protected var loaded = false
@transient
protected lazy val load = {
PythonInterpreter.set("model_bytes", modelHolder.torchBytes)
val loadModelCode =
Expand All @@ -50,8 +51,6 @@ class TorchModel private(private val modelHolder: TorchModel2Holder, init_weight
|from zoo.pipeline.api.torch.utils import trainable_param
|from zoo.pipeline.api.torch import zoo_pickle_module
|
|import pickle
|from pyspark.serializers import CloudPickleSerializer
|by = bytes(b % 256 for b in model_bytes)
|${getName()} = torch.load(io.BytesIO(by), pickle_module=zoo_pickle_module)
|""".stripMargin
Expand Down
128 changes: 128 additions & 0 deletions net/TorchOptim.scala
Original file line number Diff line number Diff line change
@@ -0,0 +1,128 @@
/*
* Copyright 2018 Analytics Zoo Authors.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package com.intel.analytics.zoo.pipeline.api.net

import com.intel.analytics.bigdl.optim.OptimMethod
import com.intel.analytics.bigdl.tensor.Tensor
import com.intel.analytics.bigdl.tensor.TensorNumericMath.TensorNumeric
import com.intel.analytics.bigdl.utils.{EngineType, Table}
import com.intel.analytics.zoo.common.PythonInterpreter
import com.intel.analytics.zoo.feature.PythonFeatureSet
import jep.NDArray
import org.apache.spark.TaskContext

import scala.reflect.ClassTag

class TorchOptim[@specialized(Float, Double) T: ClassTag](
torchOptim: Array[Byte])(implicit ev: TensorNumeric[T]) extends OptimMethod[T] {
import TorchOptim._

protected val postfix = Integer.toHexString(java.util.UUID.randomUUID().hashCode())
@transient
protected lazy val optimType: OptimType = {
val partId = TaskContext.getPartitionId()
name = s"optim_${postfix}_${partId}"
PythonInterpreter.set("optim_bytes", torchOptim)
val loadModelCode =
s"""
|import torch
|import io
|from torch.optim.optimizer import Optimizer
|from torch.optim.lr_scheduler import _LRScheduler
|from zoo.pipeline.api.torch import zoo_pickle_module
|
|optim_by = bytes(b % 256 for b in optim_bytes)
|$name = torch.load(io.BytesIO(optim_by), pickle_module=zoo_pickle_module)
|""".stripMargin
PythonInterpreter.exec(loadModelCode)
if (PythonInterpreter.getValue[Boolean](s"isinstance($name, Optimizer)")) {
Optim
} else if (PythonInterpreter.getValue[Boolean](s"isinstance($name, _LRScheduler)")) {
LrSchedule
} else {
throw new IllegalArgumentException(s"Unknown optimizer type")
}
}

var name = ""
var init = false

override def optimize(
feval: Tensor[T] => (T, Tensor[T]),
parameter: Tensor[T]): (Tensor[T], Array[T]) = {
optimType match {
case Optim =>
val (fx, dfdx) = feval(parameter)
val weightName = "weight"
if (!init) {
PythonInterpreter.set(weightName, new NDArray[Array[Float]](
parameter.toTensor[Float].storage().array()))
val initCode =
s"""
|$weightName = torch.tensor($weightName, requires_grad=True)
|$weightName = torch.autograd.Variable($weightName)
|${name}.__init__([${weightName}], **${name}.defaults)
|""".stripMargin
PythonInterpreter.exec(initCode)
}
val gradientName = "gradient"
PythonInterpreter.set("gradient", new NDArray[Array[Float]](
dfdx.toTensor[Float].storage().array()))
val stepCode =
s"""
|${weightName}.grad = torch.tensor(${gradientName})
|${name}.step()
|""".stripMargin
PythonInterpreter.exec(stepCode)
val updatedParameter = PythonFeatureSet.ndArrayToTensor(
PythonInterpreter.getValue(s"${weightName}.data.numpy()").asInstanceOf[NDArray[_]])
parameter.copy(updatedParameter.toTensor[T])
(parameter, Array(fx))
case LrSchedule =>
throw new IllegalArgumentException()
}

}

override def clearHistory(): Unit = {

}

override def getLearningRate(): Double = {
optimType match {
case Optim =>
PythonInterpreter.getValue[Double](s"${name}.defaults['lr']")
case _ =>
throw new IllegalArgumentException()
}
}

override def loadFromTable(config: Table): TorchOptim.this.type = {
this
}

}

object TorchOptim{
sealed trait OptimType

case object LrSchedule extends OptimType
case object Optim extends OptimType

def apply[T: ClassTag](optimBytes: Array[Byte])(implicit ev: TensorNumeric[T]): TorchOptim[T] = {
new TorchOptim[T](optimBytes)
}
}
4 changes: 4 additions & 0 deletions net/python/PythonZooNet.scala
Original file line number Diff line number Diff line change
Expand Up @@ -184,6 +184,10 @@ class PythonZooNet[T: ClassTag](implicit ev: TensorNumeric[T]) extends PythonZoo
TorchLoss(criterion)
}

def createTorchOptim(optim: Array[Byte]): TorchOptim[T] = {
TorchOptim(optim)
}

def torchNetSavePytorch(torchnet: TorchNet, path: String): Unit = {
torchnet.savePytorch(path)
}
Expand Down

0 comments on commit 5f3b659

Please sign in to comment.