From 8329f43c3ecc20ec67dbba7111137557a29436ef Mon Sep 17 00:00:00 2001
From: Le-Zheng <30695225+Le-Zheng@users.noreply.github.com>
Date: Wed, 27 May 2020 02:11:56 +0100
Subject: [PATCH] DistriOptimizerV2 argument (#3003)

* call DistriOptimizerV2
---
 .../analytics/bigdl/optim/Optimizer.scala     | 70 ++++++++++++++-----
 .../bigdl/python/api/PythonBigDL.scala        | 28 +++++---
 .../intel/analytics/bigdl/utils/Engine.scala  | 33 +++++++++
 3 files changed, 102 insertions(+), 29 deletions(-)

diff --git a/spark/dl/src/main/scala/com/intel/analytics/bigdl/optim/Optimizer.scala b/spark/dl/src/main/scala/com/intel/analytics/bigdl/optim/Optimizer.scala
index fa9e7d91132..10701ec8718 100644
--- a/spark/dl/src/main/scala/com/intel/analytics/bigdl/optim/Optimizer.scala
+++ b/spark/dl/src/main/scala/com/intel/analytics/bigdl/optim/Optimizer.scala
@@ -27,6 +27,7 @@ import com.intel.analytics.bigdl.parameters.{ConstantClippingProcessor,
 import com.intel.analytics.bigdl.tensor.TensorNumericMath.TensorNumeric
 import com.intel.analytics.bigdl.utils._
 import com.intel.analytics.bigdl.visualization.{TrainSummary, ValidationSummary}
+import com.intel.analytics.bigdl.utils.Engine
 import org.apache.log4j.Logger
 import org.apache.spark.rdd.RDD
 
@@ -611,13 +612,24 @@ object Optimizer {
     val _featurePaddingParam = if (featurePaddingParam != null) Some(featurePaddingParam) else None
     val _labelPaddingParam = if (labelPaddingParam != null) Some(labelPaddingParam) else None
 
-    new DistriOptimizer[T](
-       _model = model,
-       _dataset = (DataSet.rdd(sampleRDD) ->
-         SampleToMiniBatch(batchSize, _featurePaddingParam, _labelPaddingParam))
-         .toDistributed(),
-       _criterion = criterion
-     ).asInstanceOf[Optimizer[T, MiniBatch[T]]]
+    Engine.getOptimizerVersion() match {
+      case OptimizerV1 =>
+        new DistriOptimizer[T](
+          _model = model,
+          _dataset = (DataSet.rdd(sampleRDD) ->
+            SampleToMiniBatch(batchSize, _featurePaddingParam, _labelPaddingParam))
+            .toDistributed(),
+          _criterion = criterion
+        ).asInstanceOf[Optimizer[T, MiniBatch[T]]]
+      case OptimizerV2 =>
+        new DistriOptimizerV2[T](
+          _model = model,
+          _dataset = (DataSet.rdd(sampleRDD) ->
+            SampleToMiniBatch(batchSize, _featurePaddingParam, _labelPaddingParam))
+            .toDistributed(),
+          _criterion = criterion
+        ).asInstanceOf[Optimizer[T, MiniBatch[T]]]
+    }
   }
 
 
@@ -640,13 +652,24 @@ object Optimizer {
           batchSize: Int,
           miniBatchImpl: MiniBatch[T]
         )(implicit ev: TensorNumeric[T]): Optimizer[T, MiniBatch[T]] = {
-    new DistriOptimizer[T](
-      _model = model,
-      _dataset = (DataSet.rdd(sampleRDD) ->
-        SampleToMiniBatch(miniBatchImpl, batchSize, None))
-        .toDistributed(),
-      _criterion = criterion
-    ).asInstanceOf[Optimizer[T, MiniBatch[T]]]
+    Engine.getOptimizerVersion() match {
+      case OptimizerV1 =>
+        new DistriOptimizer[T](
+          _model = model,
+          _dataset = (DataSet.rdd(sampleRDD) ->
+            SampleToMiniBatch(miniBatchImpl, batchSize, None))
+            .toDistributed(),
+          _criterion = criterion
+        ).asInstanceOf[Optimizer[T, MiniBatch[T]]]
+      case OptimizerV2 =>
+        new DistriOptimizerV2[T](
+          _model = model,
+          _dataset = (DataSet.rdd(sampleRDD) ->
+            SampleToMiniBatch(miniBatchImpl, batchSize, None))
+            .toDistributed(),
+          _criterion = criterion
+        ).asInstanceOf[Optimizer[T, MiniBatch[T]]]
+    }
   }
 
   /**
@@ -664,11 +687,20 @@ object Optimizer {
   )(implicit ev: TensorNumeric[T]): Optimizer[T, D] = {
     dataset match {
       case d: DistributedDataSet[_] =>
-        new DistriOptimizer[T](
-          _model = model,
-          _dataset = d.toDistributed().asInstanceOf[DistributedDataSet[MiniBatch[T]]],
-          _criterion = criterion
-        ).asInstanceOf[Optimizer[T, D]]
+        Engine.getOptimizerVersion() match {
+          case OptimizerV1 =>
+            new DistriOptimizer[T](
+              _model = model,
+              _dataset = d.toDistributed().asInstanceOf[DistributedDataSet[MiniBatch[T]]],
+              _criterion = criterion
+            ).asInstanceOf[Optimizer[T, D]]
+          case OptimizerV2 =>
+            new DistriOptimizerV2[T](
+              _model = model,
+              _dataset = d.toDistributed().asInstanceOf[DistributedDataSet[MiniBatch[T]]],
+              _criterion = criterion
+            ).asInstanceOf[Optimizer[T, D]]
+        }
       case d: LocalDataSet[_] =>
         new LocalOptimizer[T](
           model = model,
diff --git a/spark/dl/src/main/scala/com/intel/analytics/bigdl/python/api/PythonBigDL.scala b/spark/dl/src/main/scala/com/intel/analytics/bigdl/python/api/PythonBigDL.scala
index 36ac83aac7d..42fb5dd913b 100644
--- a/spark/dl/src/main/scala/com/intel/analytics/bigdl/python/api/PythonBigDL.scala
+++ b/spark/dl/src/main/scala/com/intel/analytics/bigdl/python/api/PythonBigDL.scala
@@ -2351,12 +2351,11 @@ class PythonBigDL[T: ClassTag](implicit ev: TensorNumeric[T]) extends Serializab
                             endTrigger: Trigger,
                             batchSize: Int): Optimizer[T, MiniBatch[T]] = {
     val sampleRDD = toJSample(trainingRdd)
-
-    val optimizer = new DistriOptimizer(
-      _model = model,
-      _dataset = batching(DataSet.rdd(sampleRDD), batchSize)
+    val optimizer = Optimizer(
+      model = model,
+      dataset = batching(DataSet.rdd(sampleRDD), batchSize)
         .asInstanceOf[DistributedDataSet[MiniBatch[T]]],
-      _criterion = criterion
+      criterion = criterion
     ).asInstanceOf[Optimizer[T, MiniBatch[T]]]
     enrichOptimizer(optimizer, endTrigger, optimMethod.asScala.toMap)
   }
@@ -2368,11 +2367,10 @@ class PythonBigDL[T: ClassTag](implicit ev: TensorNumeric[T]) extends Serializab
     endTrigger: Trigger,
     batchSize: Int): Optimizer[T, MiniBatch[T]] = {
     val dataSet = trainDataSet -> ImageFeatureToMiniBatch[T](batchSize)
-
-    val optimizer = new DistriOptimizer(
-      _model = model,
-      _dataset = dataSet.asInstanceOf[DistributedDataSet[MiniBatch[T]]],
-      _criterion = criterion
+    val optimizer = Optimizer(
+      model = model,
+      dataset = dataSet.asInstanceOf[DistributedDataSet[MiniBatch[T]]],
+      criterion = criterion
     ).asInstanceOf[Optimizer[T, MiniBatch[T]]]
     enrichOptimizer(optimizer, endTrigger, optimMethod.asScala.toMap)
   }
@@ -2516,6 +2514,17 @@ class PythonBigDL[T: ClassTag](implicit ev: TensorNumeric[T]) extends Serializab
     Array(Engine.nodeNumber(), Engine.coreNumber())
   }
 
+  def setOptimizerVersion(version: String): Unit = {
+    version.toLowerCase() match {
+      case "optimizerv1" => Engine.setOptimizerVersion(OptimizerV1)
+      case "optimizerv2" => Engine.setOptimizerVersion(OptimizerV2)
+      case _ => throw new IllegalArgumentException(s"Unknown optimizer version $version")
+    }
+  }
+
+  def getOptimizerVersion(): String = {
+    Engine.getOptimizerVersion().toString
+  }
 
   def setWeights(model: AbstractModule[Activity, Activity, T], weights: JList[JTensor]): Unit = {
     val weightTensor = weights.asScala.toArray.map(toTensor(_))
diff --git a/spark/dl/src/main/scala/com/intel/analytics/bigdl/utils/Engine.scala b/spark/dl/src/main/scala/com/intel/analytics/bigdl/utils/Engine.scala
index e0c880573b0..2b063160010 100644
--- a/spark/dl/src/main/scala/com/intel/analytics/bigdl/utils/Engine.scala
+++ b/spark/dl/src/main/scala/com/intel/analytics/bigdl/utils/Engine.scala
@@ -37,6 +37,14 @@ sealed trait EngineType
 case object MklBlas extends EngineType
 case object MklDnn extends EngineType
 
+/**
+ * define optimizer version trait
+ */
+sealed trait OptimizerVersion
+
+case object OptimizerV1 extends OptimizerVersion
+case object OptimizerV2 extends OptimizerVersion
+
 
 object Engine {
 
@@ -215,6 +223,18 @@ object Engine {
     }
   }
 
+  /**
+   * Notice: Please use property bigdl.optimizerVersion to set optimizerVersion.
+   * Default version is OptimizerV1
+   */
+  private var optimizerVersion: OptimizerVersion = {
+    System.getProperty("bigdl.optimizerVersion", "optimizerv1").toLowerCase(Locale.ROOT) match {
+      case "optimizerv1" => OptimizerV1
+      case "optimizerv2" => OptimizerV2
+      case optimizerVersion => throw new IllegalArgumentException(s"Unknown type $optimizerVersion")
+    }
+  }
+
   // Thread pool for default use
   @volatile private var _default: ThreadPool = null
 
@@ -314,6 +334,19 @@ object Engine {
     nodeNum = n
   }
 
+  /**
+   * This method should only be used for test purpose.
+   *
+   * @param optimizerVersion
+   */
+  private[bigdl] def setOptimizerVersion(optimizerVersion : OptimizerVersion): Unit = {
+    this.optimizerVersion = optimizerVersion
+  }
+
+  private[bigdl] def getOptimizerVersion(): OptimizerVersion = {
+    this.optimizerVersion
+  }
+
   /**
    * This method should only be used for test purpose.
    *