Skip to content

Commit

Permalink
Wrap Sequential and Model (#11)
Browse files Browse the repository at this point in the history
* extend container

* input

* python wrapper

* fix topology

* update

* fix

* trial

* fix java version

* remove

* update

* update pom
  • Loading branch information
hkvision authored Apr 4, 2018
1 parent c4eafd9 commit 18471fc
Showing 1 changed file with 30 additions and 5 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -16,14 +16,18 @@

package com.intel.analytics.zoo.pipeline.api.keras.python

import java.util.{List => JList}
import scala.collection.JavaConverters._

import com.intel.analytics.bigdl.optim.Regularizer
import com.intel.analytics.bigdl.python.api.PythonBigDLKeras
import com.intel.analytics.bigdl.tensor.TensorNumericMath.TensorNumeric
import com.intel.analytics.zoo.pipeline.api.keras.layers.{Dense => ZDense}
import com.intel.analytics.bigdl.nn.Graph.ModuleNode
import com.intel.analytics.bigdl.nn.abstractnn.Activity
import com.intel.analytics.bigdl.nn.keras.KerasLayer
import com.intel.analytics.zoo.pipeline.api.keras.layers._

import scala.reflect.ClassTag
import java.util.{List => JList}


object PythonZooKeras {

Expand All @@ -34,15 +38,36 @@ object PythonZooKeras {

class PythonZooKeras[T: ClassTag](implicit ev: TensorNumeric[T]) extends PythonBigDLKeras[T] {

def createZooKerasModel(
input: JList[ModuleNode[T]],
output: JList[ModuleNode[T]]): Model[T] = {
Model[T](input.asScala.toArray, output.asScala.toArray)
}

def createZooKerasSequential(): Sequential[T] = {
Sequential[T]()
}

def createZooKerasInput(
name : String = null,
inputShape: JList[Int] = null): ModuleNode[T] = {
Input(name = name, inputShape = toScalaShape(inputShape))
}

def createZooKerasInputLayer(
inputShape: JList[Int] = null): KerasLayer[Activity, Activity, T] = {
InputLayer(inputShape = toScalaShape(inputShape))
}

def createZooKerasDense(
outputDim: Int,
init: String = "glorot_uniform",
activation: String = null,
wRegularizer: Regularizer[T] = null,
bRegularizer: Regularizer[T] = null,
bias: Boolean = true,
inputShape: JList[Int] = null): ZDense[T] = {
ZDense(outputDim, init, activation, wRegularizer,
inputShape: JList[Int] = null): Dense[T] = {
Dense(outputDim, init, activation, wRegularizer,
bRegularizer, bias, toScalaShape(inputShape))
}

Expand Down

0 comments on commit 18471fc

Please sign in to comment.