Skip to content

Commit

Permalink
revert back api (intel-analytics#2943)
Browse files: browse the repository at this point in the history
  • Loading branch information
wzhongyuan authored Oct 24, 2019
1 parent dff06ce commit 1c7f742
Show file tree
Hide file tree
Showing 3 changed files with 15 additions and 34 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -647,7 +647,7 @@ abstract class AbstractModule[A <: Activity: ClassTag, B <: Activity: ClassTag,
* if -1, the default is 4 * the number of partitions of the dataset
*/

final def predictClass(dataset: RDD[Sample[T]], batchSize: Int = -1): RDD[Sample[T]] = {
final def predictClass(dataset: RDD[Sample[T]], batchSize: Int = -1): RDD[Int] = {
Predictor(this).predictClass(dataset, batchSize)
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -180,34 +180,15 @@ object Predictor {

def predictClass[T: ClassTag](dataSet: RDD[Sample[T]], batchSize: Int = -1, model: Module[T],
batchPerPartition: Int, featurePaddingParam: Option[PaddingParam[T]])(
implicit ev: TensorNumeric[T]): RDD[Sample[T]] = {
val shareBuffer = false
val modelBroad = ModelBroadcast[T]().broadcast(dataSet.sparkContext,
ConversionUtils.convert(model.evaluate()))
val partitionNum = dataSet.partitions.length
val totalBatch = if (batchSize > 0) {
require(batchSize % partitionNum == 0, s"Predictor.predict: total batch size $batchSize " +
s"should be divided by partitionNum ${partitionNum}")
batchSize
} else {
batchPerPartition * partitionNum
}
val rdd = ConversionUtils.coalesce(dataSet)
val realPartitionLength = rdd.partitions.length
val otherBroad = rdd.sparkContext.broadcast(SampleToMiniBatch(
batchSize = totalBatch,
partitionNum = Some(realPartitionLength),
featurePaddingParam = featurePaddingParam))
val localBatchPerPartition = totalBatch / realPartitionLength
rdd.mapPartitions { partition =>
val localModel = modelBroad.value()
val localTransformer = otherBroad.value.cloneTransformer()
partition.grouped(localBatchPerPartition).flatMap(samples => {
val batchOut = predictSamples(localModel, samples, localTransformer, shareBuffer)
samples.toIterator.zip(batchOut).foreach(tuple => {
Sample(tuple._1.feature(), tuple._2.toTensor)
})
samples
implicit ev: TensorNumeric[T]): RDD[Int] = {
val result = Predictor.predict(dataSet, batchSize, true, model,
batchPerPartition, featurePaddingParam)
result.mapPartitions { partition =>
partition.map(output => {
val _output = output.toTensor[T]
require(_output.dim() == 1, s"Predictor.predictClass:" +
s"Only support one sample has one label, but got ${_output.dim()} label")
ev.toType[Int](_output.max(1)._2.valueAt(1))
})
}
}
Expand All @@ -233,7 +214,7 @@ class Predictor[T: ClassTag] private[optim](
batchPerPartition: Int = 4)
(implicit ev: TensorNumeric[T]) extends Serializable {

def predictClass(dataSet: RDD[Sample[T]], batchSize: Int = -1): RDD[Sample[T]] = {
def predictClass(dataSet: RDD[Sample[T]], batchSize: Int = -1): RDD[Int] = {
Predictor.predictClass(dataSet, batchSize, model, batchPerPartition, featurePaddingParam)
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -2050,10 +2050,10 @@ class PythonBigDL[T: ClassTag](implicit ev: TensorNumeric[T]) extends Serializab
}

def modelPredictClass(model: AbstractModule[Activity, Activity, T],
dataRdd: JavaRDD[Sample]): JavaRDD[Sample] = {
val sampleRDD = toJSample(dataRdd)
val pySampleRDD = model.predictClass(sampleRDD).map(toPySample(_))
new JavaRDD[Sample](pySampleRDD)
dataRdd: JavaRDD[Sample]): JavaRDD[Int] = {
val sampleRdd = toJSample(dataRdd)
val tensorRDD = model.predictClass(sampleRdd)
new JavaRDD[Int](tensorRDD)
}

def modelForward(model: AbstractModule[Activity, Activity, T],
Expand Down

0 comments on commit 1c7f742

Please sign in to comment.