diff --git a/scala/dllib/src/main/scala/com/intel/analytics/bigdl/dllib/keras/python/PythonZooKeras.scala b/scala/dllib/src/main/scala/com/intel/analytics/bigdl/dllib/keras/python/PythonZooKeras.scala index 45dbc694291..39a7932c377 100644 --- a/scala/dllib/src/main/scala/com/intel/analytics/bigdl/dllib/keras/python/PythonZooKeras.scala +++ b/scala/dllib/src/main/scala/com/intel/analytics/bigdl/dllib/keras/python/PythonZooKeras.scala @@ -26,6 +26,8 @@ import com.intel.analytics.bigdl.nn.Graph.ModuleNode import com.intel.analytics.bigdl.nn.abstractnn.Activity import com.intel.analytics.bigdl.nn.keras.KerasLayer import com.intel.analytics.zoo.pipeline.api.keras.layers._ +import com.intel.analytics.zoo.pipeline.api.keras.layers.extra._ +import com.intel.analytics.zoo.pipeline.api.keras.layers.utils.KerasUtils import scala.reflect.ClassTag @@ -71,4 +73,71 @@ class PythonZooKeras[T: ClassTag](implicit ev: TensorNumeric[T]) extends PythonB bRegularizer, bias, toScalaShape(inputShape)) } + + // ================================= Torch layers in Keras Style ================================= + + def createZooKerasSelect( + dim: Int, + index: Int, + inputShape: JList[Int] = null): Select[T] = { + Select(dim, index, toScalaShape(inputShape)) + } + + def createZooKerasNarrow( + dim: Int, + offset: Int, + length: Int = 1, + inputShape: JList[Int] = null): Narrow[T] = { + Narrow(dim, offset, length, toScalaShape(inputShape)) + } + + def createZooKerasSqueeze( + dims: JList[Int], + inputShape: JList[Int] = null): Squeeze[T] = { + Squeeze(toScalaArray(dims), toScalaShape(inputShape)) + } + + def createZooKerasAddConstant( + constant: Double, + inputShape: JList[Int] = null): AddConstant[T] = { + AddConstant(constant, toScalaShape(inputShape)) + } + + def createZooKerasMulConstant( + constant: Double, + inputShape: JList[Int] = null): MulConstant[T] = { + MulConstant(constant, toScalaShape(inputShape)) + } + + def createZooKerasLRN2D( + alpha: Double = 1e-4, + k: Double = 1.0, + beta: Double = 0.75, + n: Int = 5, + dimOrdering: String = "th", + inputShape: JList[Int] = null): LRN2D[T] = { + LRN2D(alpha, k, beta, n, dimOrdering, toScalaShape(inputShape)) + } + + def createZooKerasShareConvolution2D( + nbFilter: Int, + nbRow: Int, + nbCol: Int, + init: String = "glorot_uniform", + activation: String = null, + subsample: JList[Int], + padH: Int = 0, + padW: Int = 0, + propagateBack: Boolean = true, + dimOrdering: String = "th", + wRegularizer: Regularizer[T] = null, + bRegularizer: Regularizer[T] = null, + bias: Boolean = true, + inputShape: JList[Int] = null): ShareConvolution2D[T] = { + new ShareConvolution2D(nbFilter, nbRow, nbCol, KerasUtils.getInitMethod(init), + KerasUtils.getKerasActivation(activation), toScalaArray(subsample), + padH, padW, propagateBack, KerasUtils.toBigDLFormat(dimOrdering), + wRegularizer, bRegularizer, bias, toScalaShape(inputShape)) + } + }