Skip to content

Commit

Permalink
Save KerasModel to pure keras or tf protobuf (intel-analytics#1606)
Browse files Browse the repository at this point in the history
* checkpoint

* add unit test

* revert neuralCF

* revert neuralCF

* some update

* some update

* some change
  • Loading branch information
qiuxin2012 authored Sep 30, 2019
1 parent 18a4241 commit 6235781
Show file tree
Hide file tree
Showing 12 changed files with 130 additions and 25 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -24,8 +24,7 @@ import com.intel.analytics.bigdl.Module
import com.intel.analytics.bigdl.nn.Graph._
import com.intel.analytics.bigdl.nn.abstractnn.{AbstractModule, Activity, Initializable}
import com.intel.analytics.bigdl.nn.keras.{KerasIdentityWrapper, KerasLayer}
import com.intel.analytics.bigdl.nn.{Container, Graph, InitializationMethod}
import com.intel.analytics.bigdl.nn.{Sequential => TSequential}
import com.intel.analytics.bigdl.nn.{Container, Graph, InitializationMethod, StaticGraph, Identity => BIdentity, Sequential => TSequential}
import com.intel.analytics.bigdl.python.api.PythonBigDL
import com.intel.analytics.bigdl.tensor.Tensor
import com.intel.analytics.bigdl.tensor.TensorNumericMath.TensorNumeric
Expand All @@ -35,7 +34,7 @@ import com.intel.analytics.bigdl.utils.serializer.ModuleLoader
import com.intel.analytics.bigdl.utils.tf.{Session, TensorflowLoader}
import com.intel.analytics.zoo.common.Utils
import com.intel.analytics.zoo.pipeline.api.autograd.Variable
import com.intel.analytics.zoo.pipeline.api.keras.layers.{KerasLayerWrapper, WordEmbedding}
import com.intel.analytics.zoo.pipeline.api.keras.layers.{KerasLayerWrapper, Merge, WordEmbedding}
import com.intel.analytics.zoo.pipeline.api.keras.layers.utils.KerasUtils
import com.intel.analytics.zoo.pipeline.api.keras.models.{KerasNet, Model, Sequential}
import com.intel.analytics.zoo.pipeline.api.net.{GraphNet, NetUtils}
Expand Down Expand Up @@ -66,18 +65,22 @@ trait Net {
this.asInstanceOf[AbstractModule[Activity, Activity, T]].inputs(vars.map(_.node): _*))
}

private[zoo] def toKeras2(dir: String): String = {
/**
 * Render this layer as a Keras2 (Python) constructor-call string for model export.
 * The default implementation is unsupported; layers that can be exported
 * override this method (see the overrides in the keras.layers package).
 */
private[zoo] def toKeras2(): String = {
  throw new UnimplementedException()
}

/**
* Get keras-like weights.
* @tparam T
* @return
* Need to override this when this default weights doesn't match the weights in Keras.
* @return keras-like weights.
*/
private[zoo] def getKerasWeights(): Array[Tensor[Float]] = {
if (this.asInstanceOf[AbstractModule[_, _, _]].parameters()._1.length != 0) {
throw new UnimplementedException()
val weights = this.asInstanceOf[AbstractModule[_, _, _]].parameters()._1
val kWeights = Array.tabulate(weights.length)(_ => Tensor[Float]())
(0 until weights.length).foreach(i =>
weights(i).cast[Float](kWeights(i).resizeAs(weights(i))))
kWeights
} else {
Array()
}
Expand Down Expand Up @@ -236,9 +239,9 @@ object Net {

private[zoo] def inputShapeToString(
inputShape: Shape,
paramName: String = "inputShape"): Map[String, String] = {
paramName: String = "input_shape"): Map[String, String] = {
if (inputShape != null) {
Map("input_shape" -> s"(${inputShape.toSingle().mkString(", ")},)")
Map(paramName -> s"(${inputShape.toSingle().mkString(", ")},)")
} else {
Map()
}
Expand Down Expand Up @@ -295,12 +298,19 @@ object Net {
params.map(v => s"${v._1}=${v._2}").mkString(", ") + ")"
}

/**
 * Build a Keras2 constructor-call string, e.g. `Dense(name='d1', units=32)`.
 *
 * @param moduleType the Keras layer/class name to emit.
 * @param params already-stringified keyword arguments, emitted in map order.
 * @return the Python constructor-call expression.
 */
private[zoo] def kerasDef(
    moduleType: String,
    params: Map[String, String]): String = {
  val argList = params
    .map { case (key, value) => s"$key=$value" }
    .mkString(", ")
  s"$moduleType($argList)"
}

protected object NetSaver {
private val logger = Logger.getLogger(getClass)

protected val header =
"""
|from tensorflow.keras.models import Sequential
|from tensorflow.keras.models import Sequential, Model
|from tensorflow.keras.layers import *
|from pyspark.serializers import PickleSerializer
|
Expand All @@ -320,22 +330,23 @@ object Net {
""".stripMargin + "\n"

def save[T: ClassTag](
m: Module[T],
module: Module[T],
path: String,
python: String,
saveCommand: String)
(implicit ev: TensorNumeric[T]): Unit = {
val tmpDir = Utils.createTmpDir("ZooKeras")
logger.info(s"Write model's temp file to ${tmpDir}")
val modelFile = tmpDir.toString + s"/${m.getName()}.py"
val modelFile = tmpDir.toString + s"/${module.getName()}.py"
val bw = new BufferedWriter(new FileWriter(modelFile))
bw.write(header)
if (m.isInstanceOf[Sequential[T]]) {
export(m.asInstanceOf[Sequential[T]], tmpDir.toString, bw)
} else {
throw new IllegalArgumentException(s"${m.getClass.getName} is not supported.")
module match {
case s: Sequential[T] => export(s, bw)
case m: Model[T] => export(m, bw)
case _ =>
throw new IllegalArgumentException(s"${module.getClass.getName} is not supported.")
}
bw.write(saveWeights(m, tmpDir.toString))
bw.write(saveWeights(module, tmpDir.toString))
bw.write(saveCommand)
bw.flush()
bw.close()
Expand Down Expand Up @@ -368,22 +379,52 @@ object Net {
}
throw new RuntimeException(s"Export Keras2 model failed:\n" + errorMsg.toString())
}
}

/**
 * Write a functional-API Keras2 definition of a graph `Model` to `writer`:
 * one line per layer (in topologically sorted forward order), followed by the
 * `Model(inputs=[...], outputs=[...])` constructor wiring named tensors together.
 */
def export[T: ClassTag](
    model: Model[T],
    writer: BufferedWriter): Unit = {
  val graph = model.labor.asInstanceOf[StaticGraph[T]]
  // Emit every layer definition first so the Model(...) line can refer to them.
  graph.getSortedForwardExecutions().foreach(export(_, writer))
  val inputNames = model.getInputs().map(_.element.getName).mkString(", ")
  val outputNames = model.getOutputs().map(_.element.getName).mkString(", ")
  writer.write(s"${model.getName()} = Model(inputs=[${inputNames}]," +
    s" outputs=[${outputNames}])\n")
}

/**
 * Write one Keras2 layer definition line for a single graph node, e.g.
 * `dense_1 = Dense(...)(input_1)`. The call suffix lists the node's
 * predecessors: none for a source node, one name for a single input,
 * or a Python list for multiple inputs.
 *
 * @throws IllegalArgumentException if the node's layer does not support export.
 */
def export[T: ClassTag](
    node: ModuleNode[T],
    writer: BufferedWriter): Unit = {
  node.element match {
    case layer: Net =>
      val inputNames = node.prevNodes.map(_.element.getName)
      val callSuffix =
        if (inputNames.isEmpty) ""
        else if (inputNames.length == 1) s"(${inputNames.head})"
        else s"([${inputNames.mkString(", ")}])"
      writer.write(s"${node.element.getName()} = ${layer.toKeras2()}${callSuffix}\n")
      writer.flush()
    case other =>
      throw new IllegalArgumentException(s"Unsupported layer ${other.getName()}")
  }
}

def export[T: ClassTag](
sequential: Sequential[T],
path: String,
writer: BufferedWriter): Unit = {
writer.write(s"${sequential.getName()} = " +
s"Sequential(name='${(sequential.getName())}')\n")
val modules = sequential.modules(0).asInstanceOf[TSequential[T]].modules
modules.foreach{ module =>
if (module.isInstanceOf[Sequential[T]]) {
export(module.asInstanceOf[Sequential[T]], path, writer)
export(module.asInstanceOf[Sequential[T]], writer)
writer.write(s"${sequential.getName()}.add(${module.getName})\n")
} else if (module.isInstanceOf[Net]) {
writer.write(s"${module.getName()} = ${module.asInstanceOf[Net].toKeras2(path)}\n")
writer.write(s"${module.getName()} = ${module.asInstanceOf[Net].toKeras2()}\n")
writer.write(s"${sequential.getName()}.add(${module.getName})\n")
} else {
throw new IllegalArgumentException(s"unkown type ${this.getClass.getName}")
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -60,7 +60,7 @@ class Dense[T: ClassTag](
extends BigDLDense[T](outputDim, init, activation, wRegularizer, bRegularizer, bias,
inputShape) with Net {

override private[zoo] def toKeras2(dir: String): String = {
override private[zoo] def toKeras2(): String = {
val params = Net.inputShapeToString(inputShape) ++
Net.activationToString(activation) ++
Net.param(getName()) ++
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -38,7 +38,7 @@ class Dropout[T: ClassTag](
(implicit ev: TensorNumeric[T])
extends com.intel.analytics.bigdl.nn.keras.Dropout[T](p, inputShape) with Net {

override private[zoo] def toKeras2(dir: String): String = {
override private[zoo] def toKeras2(): String = {
val params = Net.inputShapeToString(inputShape) ++
Net.param(getName()) ++
Net.param(p, "rate")
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -100,6 +100,15 @@ class Embedding[T: ClassTag](
model.add(layer)
model.asInstanceOf[AbstractModule[Tensor[T], Tensor[T], T]]
}

/** Render this Embedding layer as a Keras2 constructor call string. */
override private[zoo] def toKeras2(): String = {
  // Accumulate keyword arguments in the order Keras expects them printed.
  val params = Seq(
    Net.inputShapeToString(inputShape),
    Net.param(getName()),
    Net.param(inputDim, "input_dim"),
    Net.param(outputDim, "output_dim"),
    Net.param(maskZero, "mask_zero")
  ).reduce(_ ++ _)
  Net.kerasDef(this, params)
}
}

object Embedding {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -47,6 +47,12 @@ class Flatten[T: ClassTag](
Array(input.slice(1, input.length).product), batchMode = Some(true))
layer.asInstanceOf[AbstractModule[Tensor[T], Tensor[T], T]]
}

/** Render this Flatten layer as a Keras2 constructor call string. */
override private[zoo] def toKeras2(): String = {
  // Flatten takes no layer-specific arguments: just input_shape (if set) and name.
  val shapeParam = Net.inputShapeToString(inputShape)
  val nameParam = Net.param(getName())
  Net.kerasDef(this, shapeParam ++ nameParam)
}
}

object Flatten {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -44,6 +44,12 @@ class Input[T: ClassTag](val inputShape: Shape)(implicit ev: TensorNumeric[T])
override def allowRebuilt(): Boolean = true

override def skipDuplicateCheck(): Boolean = skipDuplicate

/** Render this Input layer as a Keras2 constructor call string. */
override private[zoo] def toKeras2(): String = {
  // Keras2 Input uses the keyword `shape` rather than `input_shape`.
  val shapeParam = Net.inputShapeToString(inputShape, "shape")
  Net.kerasDef(this, shapeParam ++ Net.param(getName()))
}
}

/**
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -79,7 +79,7 @@ class LSTM[T: ClassTag](
bRegularizer = bRegularizer)
}

override private[zoo] def toKeras2(dir: String): String = {
override private[zoo] def toKeras2(): String = {
val params = Net.inputShapeToString(inputShape) ++
Net.activationToString(activation) ++
Net.activationToString(innerActivation, "recurrent_activation") ++
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -53,6 +53,26 @@ class Merge[T: ClassTag](
extends KerasLayer[Tensor[T], Tensor[T], T](Merge.calcBatchInputShape(inputShape, layers))
with Net {

/**
 * Render this Merge layer as the corresponding Keras2 merging layer
 * (Add/Multiply/Maximum/Average/Subtract/Minimum/Concatenate/Dot).
 * Concatenate additionally carries the `axis` argument.
 *
 * @throws IllegalArgumentException for merge modes with no Keras2 equivalent.
 */
override private[zoo] def toKeras2(): String = {
  val baseParams = Net.inputShapeToString(inputShape) ++
    Net.param(getName())
  // Resolve both the Keras2 class name and any mode-specific extra arguments.
  val (kerasLayerName, params) = mode match {
    case "sum" => ("Add", baseParams)
    case "mul" => ("Multiply", baseParams)
    case "max" => ("Maximum", baseParams)
    case "ave" => ("Average", baseParams)
    case "sub" => ("Subtract", baseParams)
    case "min" => ("Minimum", baseParams)
    case "concat" => ("Concatenate", baseParams ++ Net.param(concatAxis, "axis"))
    case "dot" => ("Dot", baseParams)
    case _ =>
      throw new IllegalArgumentException(s"Merge ${mode} is not supported in Keras2")
  }
  Net.kerasDef(kerasLayerName, params)
}

private val mergeMode = mode.toLowerCase()
private var axis = concatAxis

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -41,7 +41,7 @@ class Permute[T: ClassTag](
extends BigDLPermute[T](
dims, inputShape) with Net {

override private[zoo] def toKeras2(dir: String): String = {
override private[zoo] def toKeras2(): String = {
val params = Net.inputShapeToString(inputShape) ++
Net.param(getName()) ++
Net.arrayToString(dims, "dims")
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -99,7 +99,7 @@ class Reshape[T: ClassTag](
layer.asInstanceOf[AbstractModule[Tensor[T], Tensor[T], T]]
}

override private[zoo] def toKeras2(dir: String): String = {
override private[zoo] def toKeras2(): String = {
val params = Net.inputShapeToString(inputShape) ++
Net.param(getName()) ++
Net.arrayToString(targetShape, "target_shape")
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -78,6 +78,13 @@ class Select[T: ClassTag](
val layer = com.intel.analytics.bigdl.nn.Select(positiveDim + 1, positiveIndex + 1)
layer.asInstanceOf[AbstractModule[Tensor[T], Tensor[T], T]]
}

/** Render this Select layer as a Keras2-style constructor call string. */
override private[zoo] def toKeras2(): String = {
  val params = Seq(
    Net.inputShapeToString(inputShape),
    Net.param(getName()),
    Net.param(dim, "dim")
  ).reduce(_ ++ _)
  Net.kerasDef(this, params)
}
}

object Select {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -609,6 +609,10 @@ class Model[T: ClassTag] private (private val _inputs : Seq[ModuleNode[T]],

KerasLayerRef(this).setOutShape(Shape(_outputs.map{_.element.getOutputShape()}.toList))

/** Input nodes this graph model was constructed with, in declaration order. */
private[zoo] def getInputs(): Seq[ModuleNode[T]] = _inputs

/** Output nodes this graph model was constructed with, in declaration order. */
private[zoo] def getOutputs(): Seq[ModuleNode[T]] = _outputs

override def isKerasStyle(): Boolean = true

override def computeOutputShape(inputShape: Shape): Shape = {
Expand Down Expand Up @@ -664,6 +668,18 @@ class Model[T: ClassTag] private (private val _inputs : Seq[ModuleNode[T]],

override def toKeras(): Model[T] = this

/**
 * Collect Keras-style weight tensors from every layer of the underlying
 * static graph, concatenated in the graph's module order. A layer reporting
 * null contributes nothing.
 */
override private[zoo] def getKerasWeights(): Array[Tensor[Float]] = {
  modules(0).asInstanceOf[StaticGraph[T]].modules
    .flatMap { layer =>
      // NOTE(review): assumes every graph layer mixes in Net — TODO confirm.
      Option(layer.asInstanceOf[Net].getKerasWeights())
        .getOrElse(Array.empty[Tensor[Float]])
    }
    .toArray
}


override def summary(
lineLength: Int = 120,
positions: Array[Double] = Array(.33, .55, .67, 1)): Unit = {
Expand Down

0 comments on commit 6235781

Please sign in to comment.