From 1f86e795b87ba93640062f29e87a032924d94b2a Mon Sep 17 00:00:00 2001 From: "wm624@hotmail.com" Date: Wed, 22 Feb 2017 11:50:24 -0800 Subject: [PATCH] [SPARK-19616][SPARKR] weightCol and aggregationDepth should be improved for some SparkR APIs ## What changes were proposed in this pull request? This is a follow-up PR of #16800 When doing SPARK-19456, we found that "" should be consider a NULL column name and should not be set. aggregationDepth should be exposed as an expert parameter. ## How was this patch tested? Existing tests. Author: wm624@hotmail.com Closes #16945 from wangmiao1981/svc. --- R/pkg/R/generics.R | 2 +- R/pkg/R/mllib_classification.R | 13 ++++++---- R/pkg/R/mllib_regression.R | 24 ++++++++++++------- .../testthat/test_mllib_classification.R | 10 +++++++- .../ml/r/AFTSurvivalRegressionWrapper.scala | 6 ++++- .../GeneralizedLinearRegressionWrapper.scala | 4 +++- .../ml/r/IsotonicRegressionWrapper.scala | 3 ++- .../ml/r/LogisticRegressionWrapper.scala | 7 ++++-- 8 files changed, 50 insertions(+), 19 deletions(-) diff --git a/R/pkg/R/generics.R b/R/pkg/R/generics.R index 11940d356039e..647cbbdd825e3 100644 --- a/R/pkg/R/generics.R +++ b/R/pkg/R/generics.R @@ -1406,7 +1406,7 @@ setGeneric("spark.randomForest", #' @rdname spark.survreg #' @export -setGeneric("spark.survreg", function(data, formula) { standardGeneric("spark.survreg") }) +setGeneric("spark.survreg", function(data, formula, ...) { standardGeneric("spark.survreg") }) #' @rdname spark.svmLinear #' @export diff --git a/R/pkg/R/mllib_classification.R b/R/pkg/R/mllib_classification.R index fa0d795faa10f..05bb95266173a 100644 --- a/R/pkg/R/mllib_classification.R +++ b/R/pkg/R/mllib_classification.R @@ -207,6 +207,9 @@ function(object, path, overwrite = FALSE) { #' excepting that at most one value may be 0. The class with largest value p/t is predicted, where p #' is the original probability of that class and t is the class's threshold. #' @param weightCol The weight column name. +#' @param aggregationDepth The depth for treeAggregate (greater than or equal to 2). If the dimensions of features +#' or the number of partitions are large, this param could be adjusted to a larger size. +#' This is an expert parameter. Default value should be good for most cases. #' @param ... additional arguments passed to the method. #' @return \code{spark.logit} returns a fitted logistic regression model. #' @rdname spark.logit @@ -245,11 +248,13 @@ function(object, path, overwrite = FALSE) { setMethod("spark.logit", signature(data = "SparkDataFrame", formula = "formula"), function(data, formula, regParam = 0.0, elasticNetParam = 0.0, maxIter = 100, tol = 1E-6, family = "auto", standardization = TRUE, - thresholds = 0.5, weightCol = NULL) { + thresholds = 0.5, weightCol = NULL, aggregationDepth = 2) { formula <- paste(deparse(formula), collapse = "") - if (is.null(weightCol)) { - weightCol <- "" + if (!is.null(weightCol) && weightCol == "") { + weightCol <- NULL + } else if (!is.null(weightCol)) { + weightCol <- as.character(weightCol) } jobj <- callJStatic("org.apache.spark.ml.r.LogisticRegressionWrapper", "fit", @@ -257,7 +262,7 @@ setMethod("spark.logit", signature(data = "SparkDataFrame", formula = "formula") as.numeric(elasticNetParam), as.integer(maxIter), as.numeric(tol), as.character(family), as.logical(standardization), as.array(thresholds), - as.character(weightCol)) + weightCol, as.integer(aggregationDepth)) new("LogisticRegressionModel", jobj = jobj) }) diff --git a/R/pkg/R/mllib_regression.R b/R/pkg/R/mllib_regression.R index 96ee220bc4113..ac0578c4ab259 100644 --- a/R/pkg/R/mllib_regression.R +++ b/R/pkg/R/mllib_regression.R @@ -102,14 +102,16 @@ setMethod("spark.glm", signature(data = "SparkDataFrame", formula = "formula"), } formula <- paste(deparse(formula), collapse = "") - if (is.null(weightCol)) { - weightCol <- "" + if (!is.null(weightCol) && weightCol == "") { + weightCol <- NULL + } else if (!is.null(weightCol)) { + weightCol <- as.character(weightCol) } # For known families, Gamma is upper-cased jobj <- callJStatic("org.apache.spark.ml.r.GeneralizedLinearRegressionWrapper", "fit", formula, data@sdf, tolower(family$family), family$link, - tol, as.integer(maxIter), as.character(weightCol), regParam) + tol, as.integer(maxIter), weightCol, regParam) new("GeneralizedLinearRegressionModel", jobj = jobj) }) @@ -305,13 +307,15 @@ setMethod("spark.isoreg", signature(data = "SparkDataFrame", formula = "formula" function(data, formula, isotonic = TRUE, featureIndex = 0, weightCol = NULL) { formula <- paste(deparse(formula), collapse = "") - if (is.null(weightCol)) { - weightCol <- "" + if (!is.null(weightCol) && weightCol == "") { + weightCol <- NULL + } else if (!is.null(weightCol)) { + weightCol <- as.character(weightCol) } jobj <- callJStatic("org.apache.spark.ml.r.IsotonicRegressionWrapper", "fit", data@sdf, formula, as.logical(isotonic), as.integer(featureIndex), - as.character(weightCol)) + weightCol) new("IsotonicRegressionModel", jobj = jobj) }) @@ -372,6 +376,10 @@ setMethod("write.ml", signature(object = "IsotonicRegressionModel", path = "char #' @param formula a symbolic description of the model to be fitted. Currently only a few formula #' operators are supported, including '~', ':', '+', and '-'. #' Note that operator '.' is not supported currently. +#' @param aggregationDepth The depth for treeAggregate (greater than or equal to 2). If the dimensions of features +#' or the number of partitions are large, this param could be adjusted to a larger size. +#' This is an expert parameter. Default value should be good for most cases. +#' @param ... additional arguments passed to the method. #' @return \code{spark.survreg} returns a fitted AFT survival regression model. #' @rdname spark.survreg #' @seealso survival: \url{https://cran.r-project.org/package=survival} @@ -396,10 +404,10 @@ setMethod("write.ml", signature(object = "IsotonicRegressionModel", path = "char #' } #' @note spark.survreg since 2.0.0 setMethod("spark.survreg", signature(data = "SparkDataFrame", formula = "formula"), - function(data, formula) { + function(data, formula, aggregationDepth = 2) { formula <- paste(deparse(formula), collapse = "") jobj <- callJStatic("org.apache.spark.ml.r.AFTSurvivalRegressionWrapper", - "fit", formula, data@sdf) + "fit", formula, data@sdf, as.integer(aggregationDepth)) new("AFTSurvivalRegressionModel", jobj = jobj) }) diff --git a/R/pkg/inst/tests/testthat/test_mllib_classification.R b/R/pkg/inst/tests/testthat/test_mllib_classification.R index 620f528f2e6c8..459254d271a58 100644 --- a/R/pkg/inst/tests/testthat/test_mllib_classification.R +++ b/R/pkg/inst/tests/testthat/test_mllib_classification.R @@ -211,7 +211,15 @@ test_that("spark.logit", { df <- createDataFrame(data) model <- spark.logit(df, label ~ feature) prediction <- collect(select(predict(model, df), "prediction")) - expect_equal(prediction$prediction, c("0.0", "0.0", "1.0", "1.0", "0.0")) + expect_equal(sort(prediction$prediction), c("0.0", "0.0", "0.0", "1.0", "1.0")) + + # Test prediction with weightCol + weight <- c(2.0, 2.0, 2.0, 1.0, 1.0) + data2 <- as.data.frame(cbind(label, feature, weight)) + df2 <- createDataFrame(data2) + model2 <- spark.logit(df2, label ~ feature, weightCol = "weight") + prediction2 <- collect(select(predict(model2, df2), "prediction")) + expect_equal(sort(prediction2$prediction), c("0.0", "0.0", "0.0", "0.0", "0.0")) }) test_that("spark.mlp", { diff --git a/mllib/src/main/scala/org/apache/spark/ml/r/AFTSurvivalRegressionWrapper.scala b/mllib/src/main/scala/org/apache/spark/ml/r/AFTSurvivalRegressionWrapper.scala index bd965acf56944..0bf543d88894e 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/r/AFTSurvivalRegressionWrapper.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/r/AFTSurvivalRegressionWrapper.scala @@ -82,7 +82,10 @@ private[r] object AFTSurvivalRegressionWrapper extends MLReadable[AFTSurvivalReg } - def fit(formula: String, data: DataFrame): AFTSurvivalRegressionWrapper = { + def fit( + formula: String, + data: DataFrame, + aggregationDepth: Int): AFTSurvivalRegressionWrapper = { val (rewritedFormula, censorCol) = formulaRewrite(formula) @@ -100,6 +103,7 @@ private[r] object AFTSurvivalRegressionWrapper extends MLReadable[AFTSurvivalReg .setCensorCol(censorCol) .setFitIntercept(rFormula.hasIntercept) .setFeaturesCol(rFormula.getFeaturesCol) + .setAggregationDepth(aggregationDepth) val pipeline = new Pipeline() .setStages(Array(rFormulaModel, aft)) diff --git a/mllib/src/main/scala/org/apache/spark/ml/r/GeneralizedLinearRegressionWrapper.scala b/mllib/src/main/scala/org/apache/spark/ml/r/GeneralizedLinearRegressionWrapper.scala index 78f401f29b004..cbd6cd1c7933c 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/r/GeneralizedLinearRegressionWrapper.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/r/GeneralizedLinearRegressionWrapper.scala @@ -87,9 +87,11 @@ private[r] object GeneralizedLinearRegressionWrapper .setFitIntercept(rFormula.hasIntercept) .setTol(tol) .setMaxIter(maxIter) - .setWeightCol(weightCol) .setRegParam(regParam) .setFeaturesCol(rFormula.getFeaturesCol) + + if (weightCol != null) glr.setWeightCol(weightCol) + val pipeline = new Pipeline() .setStages(Array(rFormulaModel, glr)) .fit(data) diff --git a/mllib/src/main/scala/org/apache/spark/ml/r/IsotonicRegressionWrapper.scala b/mllib/src/main/scala/org/apache/spark/ml/r/IsotonicRegressionWrapper.scala index 48632316f3950..d31ebb46afb97 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/r/IsotonicRegressionWrapper.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/r/IsotonicRegressionWrapper.scala @@ -74,9 +74,10 @@ private[r] object IsotonicRegressionWrapper val isotonicRegression = new IsotonicRegression() .setIsotonic(isotonic) .setFeatureIndex(featureIndex) - .setWeightCol(weightCol) .setFeaturesCol(rFormula.getFeaturesCol) + if (weightCol != null) isotonicRegression.setWeightCol(weightCol) + val pipeline = new Pipeline() .setStages(Array(rFormulaModel, isotonicRegression)) .fit(data) diff --git a/mllib/src/main/scala/org/apache/spark/ml/r/LogisticRegressionWrapper.scala b/mllib/src/main/scala/org/apache/spark/ml/r/LogisticRegressionWrapper.scala index 645bc7247f30f..c96f99cb83434 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/r/LogisticRegressionWrapper.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/r/LogisticRegressionWrapper.scala @@ -96,7 +96,8 @@ private[r] object LogisticRegressionWrapper family: String, standardization: Boolean, thresholds: Array[Double], - weightCol: String + weightCol: String, + aggregationDepth: Int ): LogisticRegressionWrapper = { val rFormula = new RFormula() @@ -119,10 +120,10 @@ private[r] object LogisticRegressionWrapper .setFitIntercept(fitIntercept) .setFamily(family) .setStandardization(standardization) - .setWeightCol(weightCol) .setFeaturesCol(rFormula.getFeaturesCol) .setLabelCol(rFormula.getLabelCol) .setPredictionCol(PREDICTED_LABEL_INDEX_COL) + .setAggregationDepth(aggregationDepth) if (thresholds.length > 1) { lr.setThresholds(thresholds) @@ -130,6 +131,8 @@ private[r] object LogisticRegressionWrapper lr.setThreshold(thresholds(0)) } + if (weightCol != null) lr.setWeightCol(weightCol) + val idxToStr = new IndexToString() .setInputCol(PREDICTED_LABEL_INDEX_COL) .setOutputCol(PREDICTED_LABEL_COL)