Skip to content

Commit

Permalink
Save KerasModel to pure keras or tf protobuf (intel-analytics#1606)
Browse files Browse the repository at this point in the history
* checkpoint

* add unit test

* revert neuralCF

* revert neuralCF

* some update

* some update

* some change
  • Loading branch information
qiuxin2012 authored Sep 30, 2019
1 parent 9841390 commit c764f29
Show file tree
Hide file tree
Showing 11 changed files with 69 additions and 5 deletions.
2 changes: 1 addition & 1 deletion layers/Dense.scala
Original file line number Diff line number Diff line change
Expand Up @@ -60,7 +60,7 @@ class Dense[T: ClassTag](
extends BigDLDense[T](outputDim, init, activation, wRegularizer, bRegularizer, bias,
inputShape) with Net {

override private[zoo] def toKeras2(dir: String): String = {
override private[zoo] def toKeras2(): String = {
val params = Net.inputShapeToString(inputShape) ++
Net.activationToString(activation) ++
Net.param(getName()) ++
Expand Down
2 changes: 1 addition & 1 deletion layers/Dropout.scala
Original file line number Diff line number Diff line change
Expand Up @@ -38,7 +38,7 @@ class Dropout[T: ClassTag](
(implicit ev: TensorNumeric[T])
extends com.intel.analytics.bigdl.nn.keras.Dropout[T](p, inputShape) with Net {

override private[zoo] def toKeras2(dir: String): String = {
override private[zoo] def toKeras2(): String = {
val params = Net.inputShapeToString(inputShape) ++
Net.param(getName()) ++
Net.param(p, "rate")
Expand Down
9 changes: 9 additions & 0 deletions layers/Embedding.scala
Original file line number Diff line number Diff line change
Expand Up @@ -100,6 +100,15 @@ class Embedding[T: ClassTag](
model.add(layer)
model.asInstanceOf[AbstractModule[Tensor[T], Tensor[T], T]]
}

/**
 * Serializes this Embedding layer as a Keras 2 layer definition string.
 * Emits the input shape, layer name, and the Embedding-specific arguments
 * (input_dim, output_dim, mask_zero) via the `Net` helper.
 */
override private[zoo] def toKeras2(): String = {
  val shapeParam = Net.inputShapeToString(inputShape)
  val nameParam = Net.param(getName())
  // Embedding-specific keyword arguments, in Keras 2 naming.
  val embeddingParams =
    Net.param(inputDim, "input_dim") ++
    Net.param(outputDim, "output_dim") ++
    Net.param(maskZero, "mask_zero")
  Net.kerasDef(this, shapeParam ++ nameParam ++ embeddingParams)
}
}

object Embedding {
Expand Down
6 changes: 6 additions & 0 deletions layers/Flatten.scala
Original file line number Diff line number Diff line change
Expand Up @@ -47,6 +47,12 @@ class Flatten[T: ClassTag](
Array(input.slice(1, input.length).product), batchMode = Some(true))
layer.asInstanceOf[AbstractModule[Tensor[T], Tensor[T], T]]
}

/**
 * Serializes this Flatten layer as a Keras 2 layer definition string.
 * Flatten takes no layer-specific arguments beyond input shape and name.
 */
override private[zoo] def toKeras2(): String = {
  val shapeParam = Net.inputShapeToString(inputShape)
  val nameParam = Net.param(getName())
  Net.kerasDef(this, shapeParam ++ nameParam)
}
}

object Flatten {
Expand Down
6 changes: 6 additions & 0 deletions layers/Input.scala
Original file line number Diff line number Diff line change
Expand Up @@ -44,6 +44,12 @@ class Input[T: ClassTag](val inputShape: Shape)(implicit ev: TensorNumeric[T])
override def allowRebuilt(): Boolean = true

override def skipDuplicateCheck(): Boolean = skipDuplicate

/**
 * Serializes this Input layer as a Keras 2 layer definition string.
 * Note: Keras 2's Input uses the keyword "shape" (not "input_shape"),
 * hence the explicit key passed to `inputShapeToString`.
 */
override private[zoo] def toKeras2(): String = {
  val shapeParam = Net.inputShapeToString(inputShape, "shape")
  val nameParam = Net.param(getName())
  Net.kerasDef(this, shapeParam ++ nameParam)
}
}

/**
Expand Down
2 changes: 1 addition & 1 deletion layers/LSTM.scala
Original file line number Diff line number Diff line change
Expand Up @@ -79,7 +79,7 @@ class LSTM[T: ClassTag](
bRegularizer = bRegularizer)
}

override private[zoo] def toKeras2(dir: String): String = {
override private[zoo] def toKeras2(): String = {
val params = Net.inputShapeToString(inputShape) ++
Net.activationToString(activation) ++
Net.activationToString(innerActivation, "recurrent_activation") ++
Expand Down
20 changes: 20 additions & 0 deletions layers/Merge.scala
Original file line number Diff line number Diff line change
Expand Up @@ -53,6 +53,26 @@ class Merge[T: ClassTag](
extends KerasLayer[Tensor[T], Tensor[T], T](Merge.calcBatchInputShape(inputShape, layers))
with Net {

/**
 * Serializes this Merge layer as its Keras 2 equivalent layer definition.
 *
 * BigDL's single `Merge(mode)` layer maps onto distinct Keras 2 layers
 * (Add, Multiply, Concatenate, ...). For "concat" the concatenation axis
 * is emitted as the Keras `axis` argument.
 *
 * @throws IllegalArgumentException if `mode` has no Keras 2 counterpart.
 */
override private[zoo] def toKeras2(): String = {
  val baseParams = Net.inputShapeToString(inputShape) ++
    Net.param(getName())
  // Match on the lowercased mode for consistency with `mergeMode` below:
  // the rest of this class accepts mode case-insensitively, so "Sum"/"SUM"
  // must not be rejected here.
  val (kerasLayerName, params) = mode.toLowerCase match {
    case "sum" => ("Add", baseParams)
    case "mul" => ("Multiply", baseParams)
    case "max" => ("Maximum", baseParams)
    case "ave" => ("Average", baseParams)
    case "sub" => ("Subtract", baseParams)
    case "min" => ("Minimum", baseParams)
    case "concat" => ("Concatenate", baseParams ++ Net.param(concatAxis, "axis"))
    case "dot" => ("Dot", baseParams)
    case _ =>
      throw new IllegalArgumentException(s"Merge ${mode} is not supported in Keras2")
  }
  Net.kerasDef(kerasLayerName, params)
}

private val mergeMode = mode.toLowerCase()
private var axis = concatAxis

Expand Down
2 changes: 1 addition & 1 deletion layers/Permute.scala
Original file line number Diff line number Diff line change
Expand Up @@ -41,7 +41,7 @@ class Permute[T: ClassTag](
extends BigDLPermute[T](
dims, inputShape) with Net {

override private[zoo] def toKeras2(dir: String): String = {
override private[zoo] def toKeras2(): String = {
val params = Net.inputShapeToString(inputShape) ++
Net.param(getName()) ++
Net.arrayToString(dims, "dims")
Expand Down
2 changes: 1 addition & 1 deletion layers/Reshape.scala
Original file line number Diff line number Diff line change
Expand Up @@ -99,7 +99,7 @@ class Reshape[T: ClassTag](
layer.asInstanceOf[AbstractModule[Tensor[T], Tensor[T], T]]
}

override private[zoo] def toKeras2(dir: String): String = {
override private[zoo] def toKeras2(): String = {
val params = Net.inputShapeToString(inputShape) ++
Net.param(getName()) ++
Net.arrayToString(targetShape, "target_shape")
Expand Down
7 changes: 7 additions & 0 deletions layers/Select.scala
Original file line number Diff line number Diff line change
Expand Up @@ -78,6 +78,13 @@ class Select[T: ClassTag](
val layer = com.intel.analytics.bigdl.nn.Select(positiveDim + 1, positiveIndex + 1)
layer.asInstanceOf[AbstractModule[Tensor[T], Tensor[T], T]]
}

/**
 * Serializes this Select layer as a Keras 2 layer definition string,
 * including the selection dimension as the "dim" argument.
 */
override private[zoo] def toKeras2(): String = {
  val shapeParam = Net.inputShapeToString(inputShape)
  val nameParam = Net.param(getName())
  val dimParam = Net.param(dim, "dim")
  Net.kerasDef(this, shapeParam ++ nameParam ++ dimParam)
}
}

object Select {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -609,6 +609,10 @@ class Model[T: ClassTag] private (private val _inputs : Seq[ModuleNode[T]],

KerasLayerRef(this).setOutShape(Shape(_outputs.map{_.element.getOutputShape()}.toList))

/** Zoo-internal accessor for the graph's input nodes (`_inputs`). */
private[zoo] def getInputs(): Seq[ModuleNode[T]] = _inputs

/** Zoo-internal accessor for the graph's output nodes (`_outputs`). */
private[zoo] def getOutputs(): Seq[ModuleNode[T]] = _outputs

override def isKerasStyle(): Boolean = true

override def computeOutputShape(inputShape: Shape): Shape = {
Expand Down Expand Up @@ -664,6 +668,18 @@ class Model[T: ClassTag] private (private val _inputs : Seq[ModuleNode[T]],

override def toKeras(): Model[T] = this

/**
 * Collects the Keras-ordered weight tensors of every layer in this graph.
 *
 * Walks the modules of the underlying StaticGraph, asks each layer (as a
 * `Net`) for its Keras weights, and concatenates them in layer order.
 * Layers that report `null` weights contribute nothing.
 */
override private[zoo] def getKerasWeights(): Array[Tensor[Float]] = {
  val graphModules = modules(0).asInstanceOf[StaticGraph[T]].modules
  graphModules.flatMap { layer =>
    // Option(...) maps a null weight array to None, preserving the
    // original null-skip behavior.
    Option(layer.asInstanceOf[Net].getKerasWeights())
      .getOrElse(Array.empty[Tensor[Float]])
  }.toArray
}


override def summary(
lineLength: Int = 120,
positions: Array[Double] = Array(.33, .55, .67, 1)): Unit = {
Expand Down

0 comments on commit c764f29

Please sign in to comment.