From c9634dee9600440e247a7dfe758e42f6d4babb30 Mon Sep 17 00:00:00 2001 From: Jerry Wu Date: Thu, 24 Oct 2019 16:22:08 +0800 Subject: [PATCH] revert back api (#2943) --- .../dllib/nn/abstractnn/AbstractModule.scala | 2 +- .../bigdl/dllib/optim/Predictor.scala | 39 +++++-------------- .../dllib/utils/python/api/PythonBigDL.scala | 8 ++-- 3 files changed, 15 insertions(+), 34 deletions(-) diff --git a/scala/dllib/src/main/scala/com/intel/analytics/bigdl/dllib/nn/abstractnn/AbstractModule.scala b/scala/dllib/src/main/scala/com/intel/analytics/bigdl/dllib/nn/abstractnn/AbstractModule.scala index cec55896e55..f218abb36cf 100644 --- a/scala/dllib/src/main/scala/com/intel/analytics/bigdl/dllib/nn/abstractnn/AbstractModule.scala +++ b/scala/dllib/src/main/scala/com/intel/analytics/bigdl/dllib/nn/abstractnn/AbstractModule.scala @@ -647,7 +647,7 @@ abstract class AbstractModule[A <: Activity: ClassTag, B <: Activity: ClassTag, * if -1, default is 4 * partitionNumber of dataset */ - final def predictClass(dataset: RDD[Sample[T]], batchSize: Int = -1): RDD[Sample[T]] = { + final def predictClass(dataset: RDD[Sample[T]], batchSize: Int = -1): RDD[Int] = { Predictor(this).predictClass(dataset, batchSize) } diff --git a/scala/dllib/src/main/scala/com/intel/analytics/bigdl/dllib/optim/Predictor.scala b/scala/dllib/src/main/scala/com/intel/analytics/bigdl/dllib/optim/Predictor.scala index 5b18d86c6d5..86886b7fc01 100644 --- a/scala/dllib/src/main/scala/com/intel/analytics/bigdl/dllib/optim/Predictor.scala +++ b/scala/dllib/src/main/scala/com/intel/analytics/bigdl/dllib/optim/Predictor.scala @@ -180,34 +180,15 @@ object Predictor { def predictClass[T: ClassTag](dataSet: RDD[Sample[T]], batchSize: Int = -1, model: Module[T], batchPerPartition: Int, featurePaddingParam: Option[PaddingParam[T]])( - implicit ev: TensorNumeric[T]): RDD[Sample[T]] = { - val shareBuffer = false - val modelBroad = ModelBroadcast[T]().broadcast(dataSet.sparkContext, - ConversionUtils.convert(model.evaluate())) - val partitionNum = dataSet.partitions.length - val totalBatch = if (batchSize > 0) { - require(batchSize % partitionNum == 0, s"Predictor.predict: total batch size $batchSize " + - s"should be divided by partitionNum ${partitionNum}") - batchSize - } else { - batchPerPartition * partitionNum - } - val rdd = ConversionUtils.coalesce(dataSet) - val realPartitionLength = rdd.partitions.length - val otherBroad = rdd.sparkContext.broadcast(SampleToMiniBatch( - batchSize = totalBatch, - partitionNum = Some(realPartitionLength), - featurePaddingParam = featurePaddingParam)) - val localBatchPerPartition = totalBatch / realPartitionLength - rdd.mapPartitions { partition => - val localModel = modelBroad.value() - val localTransformer = otherBroad.value.cloneTransformer() - partition.grouped(localBatchPerPartition).flatMap(samples => { - val batchOut = predictSamples(localModel, samples, localTransformer, shareBuffer) - samples.toIterator.zip(batchOut).foreach(tuple => { - Sample(tuple._1.feature(), tuple._2.toTensor) - }) - samples + implicit ev: TensorNumeric[T]): RDD[Int] = { + val result = Predictor.predict(dataSet, batchSize, true, model, + batchPerPartition, featurePaddingParam) + result.mapPartitions { partition => + partition.map(output => { + val _output = output.toTensor[T] + require(_output.dim() == 1, s"Predictor.predictClass:" + + s"Only support one sample has one label, but got ${_output.dim()} label") + ev.toType[Int](_output.max(1)._2.valueAt(1)) }) } } @@ -233,7 +214,7 @@ class Predictor[T: ClassTag] private[optim]( batchPerPartition: Int = 4) (implicit ev: TensorNumeric[T]) extends Serializable { - def predictClass(dataSet: RDD[Sample[T]], batchSize: Int = -1): RDD[Sample[T]] = { + def predictClass(dataSet: RDD[Sample[T]], batchSize: Int = -1): RDD[Int] = { Predictor.predictClass(dataSet, batchSize, model, batchPerPartition, featurePaddingParam) } diff --git a/scala/dllib/src/main/scala/com/intel/analytics/bigdl/dllib/utils/python/api/PythonBigDL.scala b/scala/dllib/src/main/scala/com/intel/analytics/bigdl/dllib/utils/python/api/PythonBigDL.scala index feec0418c26..bb849fa12c1 100644 --- a/scala/dllib/src/main/scala/com/intel/analytics/bigdl/dllib/utils/python/api/PythonBigDL.scala +++ b/scala/dllib/src/main/scala/com/intel/analytics/bigdl/dllib/utils/python/api/PythonBigDL.scala @@ -2050,10 +2050,10 @@ class PythonBigDL[T: ClassTag](implicit ev: TensorNumeric[T]) extends Serializab } def modelPredictClass(model: AbstractModule[Activity, Activity, T], - dataRdd: JavaRDD[Sample]): JavaRDD[Sample] = { - val sampleRDD = toJSample(dataRdd) - val pySampleRDD = model.predictClass(sampleRDD).map(toPySample(_)) - new JavaRDD[Sample](pySampleRDD) + dataRdd: JavaRDD[Sample]): JavaRDD[Int] = { + val sampleRdd = toJSample(dataRdd) + val tensorRDD = model.predictClass(sampleRdd) + new JavaRDD[Int](tensorRDD) } def modelForward(model: AbstractModule[Activity, Activity, T],