Skip to content

Commit

Permalink
Wrap some useful torch layers (intel#3)
Browse files Browse the repository at this point in the history
* more torch layers

* update

* python wrapper

* update alias

* style
  • Loading branch information
hkvision authored Apr 13, 2018
1 parent 18471fc commit e65a059
Showing 1 changed file with 69 additions and 0 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -26,6 +26,8 @@ import com.intel.analytics.bigdl.nn.Graph.ModuleNode
import com.intel.analytics.bigdl.nn.abstractnn.Activity
import com.intel.analytics.bigdl.nn.keras.KerasLayer
import com.intel.analytics.zoo.pipeline.api.keras.layers._
import com.intel.analytics.zoo.pipeline.api.keras.layers.extra._
import com.intel.analytics.zoo.pipeline.api.keras.layers.utils.KerasUtils

import scala.reflect.ClassTag

Expand Down Expand Up @@ -71,4 +73,71 @@ class PythonZooKeras[T: ClassTag](implicit ev: TensorNumeric[T]) extends PythonB
bRegularizer, bias, toScalaShape(inputShape))
}


// ================================= Torch layers in Keras Style =================================

def createZooKerasSelect(
dim: Int,
index: Int,
inputShape: JList[Int] = null): Select[T] = {
Select(dim, index, toScalaShape(inputShape))
}

def createZooKerasNarrow(
dim: Int,
offset: Int,
length: Int = 1,
inputShape: JList[Int] = null): Narrow[T] = {
Narrow(dim, offset, length, toScalaShape(inputShape))
}

def createZooKerasSqueeze(
dims: JList[Int],
inputShape: JList[Int] = null): Squeeze[T] = {
Squeeze(toScalaArray(dims), toScalaShape(inputShape))
}

def createZooKerasAddConstant(
constant: Double,
inputShape: JList[Int] = null): AddConstant[T] = {
AddConstant(constant, toScalaShape(inputShape))
}

def createZooKerasMulConstant(
constant: Double,
inputShape: JList[Int] = null): MulConstant[T] = {
MulConstant(constant, toScalaShape(inputShape))
}

def createZooKerasLRN2D(
alpha: Double = 1e-4,
k: Double = 1.0,
beta: Double = 0.75,
n: Int = 5,
dimOrdering: String = "th",
inputShape: JList[Int] = null): LRN2D[T] = {
LRN2D(alpha, k, beta, n, dimOrdering, toScalaShape(inputShape))
}

def createZooKerasShareConvolution2D(
nbFilter: Int,
nbRow: Int,
nbCol: Int,
init: String = "glorot_uniform",
activation: String = null,
subsample: JList[Int],
padH: Int = 0,
padW: Int = 0,
propagateBack: Boolean = true,
dimOrdering: String = "th",
wRegularizer: Regularizer[T] = null,
bRegularizer: Regularizer[T] = null,
bias: Boolean = true,
inputShape: JList[Int] = null): ShareConvolution2D[T] = {
new ShareConvolution2D(nbFilter, nbRow, nbCol, KerasUtils.getInitMethod(init),
KerasUtils.getKerasActivation(activation), toScalaArray(subsample),
padH, padW, propagateBack, KerasUtils.toBigDLFormat(dimOrdering),
wRegularizer, bRegularizer, bias, toScalaShape(inputShape))
}

}

0 comments on commit e65a059

Please sign in to comment.