diff --git a/spark/dl/src/main/scala/org/apache/spark/ml/DLEstimator.scala b/spark/dl/src/main/scala/org/apache/spark/ml/DLEstimator.scala index a0433498751..6408d37fe8f 100644 --- a/spark/dl/src/main/scala/org/apache/spark/ml/DLEstimator.scala +++ b/spark/dl/src/main/scala/org/apache/spark/ml/DLEstimator.scala @@ -387,11 +387,13 @@ class DLModel[@specialized(Float, Double) T: ClassTag]( val featureColIndex = dataFrame.schema.fieldIndex($(featuresCol)) val featureFunc = getConvertFunc(featureType) val sc = dataFrame.sqlContext.sparkContext - val modelBroadCast = ModelBroadcast[T]().broadcast(sc, model) + val modelBroadCast = ModelBroadcast[T]().broadcast(sc, model.evaluate()) val localBatchSize = $(batchSize) + val transformerBC = sc.broadcast(SampleToMiniBatch[T](localBatchSize)) val resultRDD = dataFrame.rdd.mapPartitions { rowIter => val localModel = modelBroadCast.value() + val transformer = transformerBC.value.cloneTransformer() rowIter.grouped(localBatchSize).flatMap { rowBatch => val samples = rowBatch.map { row => val features = featureFunc(row, featureColIndex) @@ -401,7 +403,7 @@ class DLModel[@specialized(Float, Double) T: ClassTag]( } Sample(Tensor(featureBuffer.toArray, featureSize)) }.toIterator - val predictions = SampleToMiniBatch(localBatchSize).apply(samples).flatMap { batch => + val predictions = transformer(samples).flatMap { batch => val batchResult = localModel.forward(batch.getInput()) batchResult.toTensor.split(1).map(outputToPrediction) }