Skip to content

Commit

Permalink
[SPARK-19616][SPARKR] weightCol and aggregationDepth should be improv…
Browse files Browse the repository at this point in the history
…ed for some SparkR APIs

## What changes were proposed in this pull request?

This is a follow-up PR of apache#16800

When doing SPARK-19456, we found that "" should be consider a NULL column name and should not be set. aggregationDepth should be exposed as an expert parameter.

## How was this patch tested?
Existing tests.

Author: [email protected] <[email protected]>

Closes apache#16945 from wangmiao1981/svc.
  • Loading branch information
wangmiao1981 authored and Yun Ni committed Feb 27, 2017
1 parent 93c6477 commit 7c32b69
Show file tree
Hide file tree
Showing 8 changed files with 50 additions and 19 deletions.
2 changes: 1 addition & 1 deletion R/pkg/R/generics.R
Original file line number Diff line number Diff line change
Expand Up @@ -1406,7 +1406,7 @@ setGeneric("spark.randomForest",

#' @rdname spark.survreg
#' @export
setGeneric("spark.survreg", function(data, formula) { standardGeneric("spark.survreg") })
setGeneric("spark.survreg", function(data, formula, ...) { standardGeneric("spark.survreg") })

#' @rdname spark.svmLinear
#' @export
Expand Down
13 changes: 9 additions & 4 deletions R/pkg/R/mllib_classification.R
Original file line number Diff line number Diff line change
Expand Up @@ -207,6 +207,9 @@ function(object, path, overwrite = FALSE) {
#' excepting that at most one value may be 0. The class with largest value p/t is predicted, where p
#' is the original probability of that class and t is the class's threshold.
#' @param weightCol The weight column name.
#' @param aggregationDepth The depth for treeAggregate (greater than or equal to 2). If the dimensions of features
#' or the number of partitions are large, this param could be adjusted to a larger size.
#' This is an expert parameter. Default value should be good for most cases.
#' @param ... additional arguments passed to the method.
#' @return \code{spark.logit} returns a fitted logistic regression model.
#' @rdname spark.logit
Expand Down Expand Up @@ -245,19 +248,21 @@ function(object, path, overwrite = FALSE) {
setMethod("spark.logit", signature(data = "SparkDataFrame", formula = "formula"),
function(data, formula, regParam = 0.0, elasticNetParam = 0.0, maxIter = 100,
tol = 1E-6, family = "auto", standardization = TRUE,
thresholds = 0.5, weightCol = NULL) {
thresholds = 0.5, weightCol = NULL, aggregationDepth = 2) {
formula <- paste(deparse(formula), collapse = "")

if (is.null(weightCol)) {
weightCol <- ""
if (!is.null(weightCol) && weightCol == "") {
weightCol <- NULL
} else if (!is.null(weightCol)) {
weightCol <- as.character(weightCol)
}

jobj <- callJStatic("org.apache.spark.ml.r.LogisticRegressionWrapper", "fit",
data@sdf, formula, as.numeric(regParam),
as.numeric(elasticNetParam), as.integer(maxIter),
as.numeric(tol), as.character(family),
as.logical(standardization), as.array(thresholds),
as.character(weightCol))
weightCol, as.integer(aggregationDepth))
new("LogisticRegressionModel", jobj = jobj)
})

Expand Down
24 changes: 16 additions & 8 deletions R/pkg/R/mllib_regression.R
Original file line number Diff line number Diff line change
Expand Up @@ -102,14 +102,16 @@ setMethod("spark.glm", signature(data = "SparkDataFrame", formula = "formula"),
}

formula <- paste(deparse(formula), collapse = "")
if (is.null(weightCol)) {
weightCol <- ""
if (!is.null(weightCol) && weightCol == "") {
weightCol <- NULL
} else if (!is.null(weightCol)) {
weightCol <- as.character(weightCol)
}

# For known families, Gamma is upper-cased
jobj <- callJStatic("org.apache.spark.ml.r.GeneralizedLinearRegressionWrapper",
"fit", formula, data@sdf, tolower(family$family), family$link,
tol, as.integer(maxIter), as.character(weightCol), regParam)
tol, as.integer(maxIter), weightCol, regParam)
new("GeneralizedLinearRegressionModel", jobj = jobj)
})

Expand Down Expand Up @@ -305,13 +307,15 @@ setMethod("spark.isoreg", signature(data = "SparkDataFrame", formula = "formula"
function(data, formula, isotonic = TRUE, featureIndex = 0, weightCol = NULL) {
formula <- paste(deparse(formula), collapse = "")

if (is.null(weightCol)) {
weightCol <- ""
if (!is.null(weightCol) && weightCol == "") {
weightCol <- NULL
} else if (!is.null(weightCol)) {
weightCol <- as.character(weightCol)
}

jobj <- callJStatic("org.apache.spark.ml.r.IsotonicRegressionWrapper", "fit",
data@sdf, formula, as.logical(isotonic), as.integer(featureIndex),
as.character(weightCol))
weightCol)
new("IsotonicRegressionModel", jobj = jobj)
})

Expand Down Expand Up @@ -372,6 +376,10 @@ setMethod("write.ml", signature(object = "IsotonicRegressionModel", path = "char
#' @param formula a symbolic description of the model to be fitted. Currently only a few formula
#' operators are supported, including '~', ':', '+', and '-'.
#' Note that operator '.' is not supported currently.
#' @param aggregationDepth The depth for treeAggregate (greater than or equal to 2). If the dimensions of features
#' or the number of partitions are large, this param could be adjusted to a larger size.
#' This is an expert parameter. Default value should be good for most cases.
#' @param ... additional arguments passed to the method.
#' @return \code{spark.survreg} returns a fitted AFT survival regression model.
#' @rdname spark.survreg
#' @seealso survival: \url{https://cran.r-project.org/package=survival}
Expand All @@ -396,10 +404,10 @@ setMethod("write.ml", signature(object = "IsotonicRegressionModel", path = "char
#' }
#' @note spark.survreg since 2.0.0
setMethod("spark.survreg", signature(data = "SparkDataFrame", formula = "formula"),
function(data, formula) {
function(data, formula, aggregationDepth = 2) {
formula <- paste(deparse(formula), collapse = "")
jobj <- callJStatic("org.apache.spark.ml.r.AFTSurvivalRegressionWrapper",
"fit", formula, data@sdf)
"fit", formula, data@sdf, as.integer(aggregationDepth))
new("AFTSurvivalRegressionModel", jobj = jobj)
})

Expand Down
10 changes: 9 additions & 1 deletion R/pkg/inst/tests/testthat/test_mllib_classification.R
Original file line number Diff line number Diff line change
Expand Up @@ -211,7 +211,15 @@ test_that("spark.logit", {
df <- createDataFrame(data)
model <- spark.logit(df, label ~ feature)
prediction <- collect(select(predict(model, df), "prediction"))
expect_equal(prediction$prediction, c("0.0", "0.0", "1.0", "1.0", "0.0"))
expect_equal(sort(prediction$prediction), c("0.0", "0.0", "0.0", "1.0", "1.0"))

# Test prediction with weightCol
weight <- c(2.0, 2.0, 2.0, 1.0, 1.0)
data2 <- as.data.frame(cbind(label, feature, weight))
df2 <- createDataFrame(data2)
model2 <- spark.logit(df2, label ~ feature, weightCol = "weight")
prediction2 <- collect(select(predict(model2, df2), "prediction"))
expect_equal(sort(prediction2$prediction), c("0.0", "0.0", "0.0", "0.0", "0.0"))
})

test_that("spark.mlp", {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -82,7 +82,10 @@ private[r] object AFTSurvivalRegressionWrapper extends MLReadable[AFTSurvivalReg
}


def fit(formula: String, data: DataFrame): AFTSurvivalRegressionWrapper = {
def fit(
formula: String,
data: DataFrame,
aggregationDepth: Int): AFTSurvivalRegressionWrapper = {

val (rewritedFormula, censorCol) = formulaRewrite(formula)

Expand All @@ -100,6 +103,7 @@ private[r] object AFTSurvivalRegressionWrapper extends MLReadable[AFTSurvivalReg
.setCensorCol(censorCol)
.setFitIntercept(rFormula.hasIntercept)
.setFeaturesCol(rFormula.getFeaturesCol)
.setAggregationDepth(aggregationDepth)

val pipeline = new Pipeline()
.setStages(Array(rFormulaModel, aft))
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -87,9 +87,11 @@ private[r] object GeneralizedLinearRegressionWrapper
.setFitIntercept(rFormula.hasIntercept)
.setTol(tol)
.setMaxIter(maxIter)
.setWeightCol(weightCol)
.setRegParam(regParam)
.setFeaturesCol(rFormula.getFeaturesCol)

if (weightCol != null) glr.setWeightCol(weightCol)

val pipeline = new Pipeline()
.setStages(Array(rFormulaModel, glr))
.fit(data)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -74,9 +74,10 @@ private[r] object IsotonicRegressionWrapper
val isotonicRegression = new IsotonicRegression()
.setIsotonic(isotonic)
.setFeatureIndex(featureIndex)
.setWeightCol(weightCol)
.setFeaturesCol(rFormula.getFeaturesCol)

if (weightCol != null) isotonicRegression.setWeightCol(weightCol)

val pipeline = new Pipeline()
.setStages(Array(rFormulaModel, isotonicRegression))
.fit(data)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -96,7 +96,8 @@ private[r] object LogisticRegressionWrapper
family: String,
standardization: Boolean,
thresholds: Array[Double],
weightCol: String
weightCol: String,
aggregationDepth: Int
): LogisticRegressionWrapper = {

val rFormula = new RFormula()
Expand All @@ -119,17 +120,19 @@ private[r] object LogisticRegressionWrapper
.setFitIntercept(fitIntercept)
.setFamily(family)
.setStandardization(standardization)
.setWeightCol(weightCol)
.setFeaturesCol(rFormula.getFeaturesCol)
.setLabelCol(rFormula.getLabelCol)
.setPredictionCol(PREDICTED_LABEL_INDEX_COL)
.setAggregationDepth(aggregationDepth)

if (thresholds.length > 1) {
lr.setThresholds(thresholds)
} else {
lr.setThreshold(thresholds(0))
}

if (weightCol != null) lr.setWeightCol(weightCol)

val idxToStr = new IndexToString()
.setInputCol(PREDICTED_LABEL_INDEX_COL)
.setOutputCol(PREDICTED_LABEL_COL)
Expand Down

0 comments on commit 7c32b69

Please sign in to comment.