From fd39e711d199b033f955c239d61954c0fb13616d Mon Sep 17 00:00:00 2001
From: Yuhao Yang
Date: Tue, 20 Aug 2019 12:15:59 -0700
Subject: [PATCH] support multi input models for nnframes (#1553)

* support multi input for nnframes

* update ut

* add doc and unit test

* doc update

* scala style
---
 .../dllib/nnframes/NNEstimatorSpec.scala          | 21 +++++++++++++++++++
 1 file changed, 21 insertions(+)

diff --git a/scala/dllib/src/test/scala/com/intel/analytics/bigdl/dllib/nnframes/NNEstimatorSpec.scala b/scala/dllib/src/test/scala/com/intel/analytics/bigdl/dllib/nnframes/NNEstimatorSpec.scala
index ac4621b204e..d905ddaca0b 100644
--- a/scala/dllib/src/test/scala/com/intel/analytics/bigdl/dllib/nnframes/NNEstimatorSpec.scala
+++ b/scala/dllib/src/test/scala/com/intel/analytics/bigdl/dllib/nnframes/NNEstimatorSpec.scala
@@ -31,6 +31,9 @@ import com.intel.analytics.bigdl.visualization.{TrainSummary, ValidationSummary}
 import com.intel.analytics.zoo.common.NNContext
 import com.intel.analytics.zoo.feature.common.{TensorToSample, _}
 import com.intel.analytics.zoo.feature.image._
+import com.intel.analytics.zoo.pipeline.api.keras.layers.Merge.merge
+import com.intel.analytics.zoo.pipeline.api.keras.layers.{Input, Dense}
+import com.intel.analytics.zoo.pipeline.api.keras.models.Model
 import org.apache.spark.SparkContext
 import org.apache.spark.ml.{Pipeline, PipelineModel}
 import org.apache.spark.ml.feature.MinMaxScaler
@@ -596,6 +599,24 @@ class NNEstimatorSpec extends FlatSpec with Matchers with BeforeAndAfter {
       Path(tmpFile).deleteRecursively()
     }
   }
+
+  "An NNEstimator" should "support multi-input model" in {
+    val input1 = Input(Shape(4))
+    val input2 = Input(Shape(2))
+    val latent = merge(inputs = List(input1, input2), mode = "concat")
+    val output = Dense(2, activation = "log_softmax").inputs(latent)
+    val model = Model(Array(input1, input2), output)
+
+    val criterion = ClassNLLCriterion[Float]()
+    val estimator = NNEstimator(model, criterion, Array(Array(4), Array(2)), Array(1))
+      .setBatchSize(nRecords)
+      .setMaxEpoch(5)
+
+    val data = sc.parallelize(smallData)
+    val df = sqlContext.createDataFrame(data).toDF("features", "label")
+    val nnmodel = estimator.fit(df)
+    nnmodel.transform(df).collect()
+  }
 }
 
 private case class MinibatchData[T](featureData : Array[T], labelData : Array[T])