diff --git a/mllib-dal/src/main/scala/org/apache/spark/ml/classification/spark320/NaiveBayes.scala b/mllib-dal/src/main/scala/org/apache/spark/ml/classification/spark320/NaiveBayes.scala index 819ef1905..e34a19d52 100644 --- a/mllib-dal/src/main/scala/org/apache/spark/ml/classification/spark320/NaiveBayes.scala +++ b/mllib-dal/src/main/scala/org/apache/spark/ml/classification/spark320/NaiveBayes.scala @@ -133,10 +133,14 @@ class NaiveBayes @Since("1.5.0") ( val sc = spark.sparkContext - val executor_num = Utils.sparkExecutorNum(sc) - val executor_cores = Utils.sparkExecutorCores() + // select label and features columns and cache data. + val naiveBayesData = dataset.select($(labelCol), $(featuresCol)).cache() + naiveBayesData.count() - logInfo(s"NaiveBayesDAL fit using $executor_num Executors") + val executorNum = Utils.sparkExecutorNum(sc) + val executorCores = Utils.sparkExecutorCores() + + logInfo(s"NaiveBayesDAL fit using $executorNum Executors") // DAL only support [0..numClasses) as labels, should map original labels using StringIndexer // Todo: optimize getting num of classes @@ -146,17 +150,17 @@ class NaiveBayes @Since("1.5.0") ( // numClasses should be explicitly included in the parquet metadata // This can be done by applying StringIndexer to the label column val numClasses = confClasses match { - case -1 => getNumClasses(dataset) + case -1 => getNumClasses(naiveBayesData) case _ => confClasses } instr.logNumClasses(numClasses) - val labeledPointsDS = dataset + val labeledPointsDS = naiveBayesData .select(col(getLabelCol), DatasetUtils.columnToVector(dataset, getFeaturesCol)) val dalModel = new NaiveBayesDALImpl(uid, numClasses, - executor_num, executor_cores).train(labeledPointsDS, ${labelCol}, ${featuresCol}) + executorNum, executorCores).train(labeledPointsDS, ${labelCol}, ${featuresCol}) val model = copyValues(new NaiveBayesModel( dalModel.uid, dalModel.pi, dalModel.theta, dalModel.sigma)) @@ -332,4 +336,4 @@ class NaiveBayes @Since("1.5.0") ( val sigma = new DenseMatrix(numLabels, numFeatures, sigmaArray, true) new NaiveBayesModel(uid, pi.compressed, theta.compressed, sigma.compressed) } -} \ No newline at end of file +}