From 7da5a2d598098dca2fee47eff0a47edc2de81419 Mon Sep 17 00:00:00 2001 From: minmingzhu Date: Thu, 24 Mar 2022 22:24:55 +0800 Subject: [PATCH 1/2] update Signed-off-by: minmingzhu --- .../ml/classification/spark320/NaiveBayes.scala | 14 +++++++++----- 1 file changed, 9 insertions(+), 5 deletions(-) diff --git a/mllib-dal/src/main/scala/org/apache/spark/ml/classification/spark320/NaiveBayes.scala b/mllib-dal/src/main/scala/org/apache/spark/ml/classification/spark320/NaiveBayes.scala index 819ef1905..2cf94ec64 100644 --- a/mllib-dal/src/main/scala/org/apache/spark/ml/classification/spark320/NaiveBayes.scala +++ b/mllib-dal/src/main/scala/org/apache/spark/ml/classification/spark320/NaiveBayes.scala @@ -133,8 +133,12 @@ class NaiveBayes @Since("1.5.0") ( val sc = spark.sparkContext - val executor_num = Utils.sparkExecutorNum(sc) - val executor_cores = Utils.sparkExecutorCores() + // select label and features columns and cache data. + val naiveBayesData = dataset.select($(labelCol), $(featuresCol)).cache() + naiveBayesData.count() + + val executorNum = Utils.sparkExecutorNum(sc) + val executorCores = Utils.sparkExecutorCores() logInfo(s"NaiveBayesDAL fit using $executor_num Executors") @@ -146,17 +150,17 @@ class NaiveBayes @Since("1.5.0") ( // numClasses should be explicitly included in the parquet metadata // This can be done by applying StringIndexer to the label column val numClasses = confClasses match { - case -1 => getNumClasses(dataset) + case -1 => getNumClasses(naiveBayesData) case _ => confClasses } instr.logNumClasses(numClasses) - val labeledPointsDS = dataset + val labeledPointsDS = naiveBayesData .select(col(getLabelCol), DatasetUtils.columnToVector(dataset, getFeaturesCol)) val dalModel = new NaiveBayesDALImpl(uid, numClasses, - executor_num, executor_cores).train(labeledPointsDS, ${labelCol}, ${featuresCol}) + executorNum, executorCores).train(labeledPointsDS, ${labelCol}, ${featuresCol}) val model = copyValues(new NaiveBayesModel( dalModel.uid, dalModel.pi, dalModel.theta, dalModel.sigma)) From e891d7d9a7ec21dddb88c1ca73025aeec8f7ae64 Mon Sep 17 00:00:00 2001 From: minmingzhu <45281494+minmingzhu@users.noreply.github.com> Date: Tue, 29 Mar 2022 10:00:03 +0800 Subject: [PATCH 2/2] Update NaiveBayes.scala --- .../apache/spark/ml/classification/spark320/NaiveBayes.scala | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/mllib-dal/src/main/scala/org/apache/spark/ml/classification/spark320/NaiveBayes.scala b/mllib-dal/src/main/scala/org/apache/spark/ml/classification/spark320/NaiveBayes.scala index 2cf94ec64..e34a19d52 100644 --- a/mllib-dal/src/main/scala/org/apache/spark/ml/classification/spark320/NaiveBayes.scala +++ b/mllib-dal/src/main/scala/org/apache/spark/ml/classification/spark320/NaiveBayes.scala @@ -140,7 +140,7 @@ class NaiveBayes @Since("1.5.0") ( val executorNum = Utils.sparkExecutorNum(sc) val executorCores = Utils.sparkExecutorCores() - logInfo(s"NaiveBayesDAL fit using $executor_num Executors") + logInfo(s"NaiveBayesDAL fit using $executorNum Executors") // DAL only support [0..numClasses) as labels, should map original labels using StringIndexer // Todo: optimize getting num of classes @@ -336,4 +336,4 @@ class NaiveBayes @Since("1.5.0") ( val sigma = new DenseMatrix(numLabels, numFeatures, sigmaArray, true) new NaiveBayesModel(uid, pi.compressed, theta.compressed, sigma.compressed) } -} \ No newline at end of file +}