From 954d65e12ee089457b54ed728d62616b86514464 Mon Sep 17 00:00:00 2001 From: Xu Xiao Date: Wed, 7 Feb 2018 16:32:10 +0800 Subject: [PATCH] bug fix: DLModel prediction (#2194) * bug fix: DLModel prediction (#4) Make sure DLModel.train=False when predicting in pipeline API * 1. broadcast transformer in DLModel.transform ; 2. remove useless ut --- .../dl/src/main/scala/org/apache/spark/ml/DLEstimator.scala | 6 ++++-- 1 file changed, 4 insertions(+), 2 deletions(-) diff --git a/spark/dl/src/main/scala/org/apache/spark/ml/DLEstimator.scala b/spark/dl/src/main/scala/org/apache/spark/ml/DLEstimator.scala index a04334987512..6408d37fe8fc 100644 --- a/spark/dl/src/main/scala/org/apache/spark/ml/DLEstimator.scala +++ b/spark/dl/src/main/scala/org/apache/spark/ml/DLEstimator.scala @@ -387,11 +387,13 @@ class DLModel[@specialized(Float, Double) T: ClassTag]( val featureColIndex = dataFrame.schema.fieldIndex($(featuresCol)) val featureFunc = getConvertFunc(featureType) val sc = dataFrame.sqlContext.sparkContext - val modelBroadCast = ModelBroadcast[T]().broadcast(sc, model) + val modelBroadCast = ModelBroadcast[T]().broadcast(sc, model.evaluate()) val localBatchSize = $(batchSize) + val transformerBC = sc.broadcast(SampleToMiniBatch[T](localBatchSize)) val resultRDD = dataFrame.rdd.mapPartitions { rowIter => val localModel = modelBroadCast.value() + val transformer = transformerBC.value.cloneTransformer() rowIter.grouped(localBatchSize).flatMap { rowBatch => val samples = rowBatch.map { row => val features = featureFunc(row, featureColIndex) @@ -401,7 +403,7 @@ class DLModel[@specialized(Float, Double) T: ClassTag]( } Sample(Tensor(featureBuffer.toArray, featureSize)) }.toIterator - val predictions = SampleToMiniBatch(localBatchSize).apply(samples).flatMap { batch => + val predictions = transformer(samples).flatMap { batch => val batchResult = localModel.forward(batch.getInput()) batchResult.toTensor.split(1).map(outputToPrediction) }