Skip to content

Commit

Permalink
[SPARK-19133][SPARKR][ML][BACKPORT-2.1] fix glm for Gamma, clarify gl…
Browse files Browse the repository at this point in the history
…m family supported

## What changes were proposed in this pull request?

backporting to 2.1, 2.0 and 1.6

## How was this patch tested?

unit tests

Author: Felix Cheung <[email protected]>

Closes #16532 from felixcheung/rgammabackport.
  • Loading branch information
felixcheung authored and Felix Cheung committed Jan 11, 2017
1 parent 230607d commit 1022049
Show file tree
Hide file tree
Showing 2 changed files with 14 additions and 1 deletion.
7 changes: 6 additions & 1 deletion R/pkg/R/mllib.R
Original file line number Diff line number Diff line change
Expand Up @@ -184,6 +184,8 @@ predict_internal <- function(object, newData) {
#' This can be a character string naming a family function, a family function or
#' the result of a call to a family function. Refer R family at
#' \url{https://stat.ethz.ch/R-manual/R-devel/library/stats/html/family.html}.
#' Currently these families are supported: \code{binomial}, \code{gaussian},
#' \code{Gamma}, and \code{poisson}.
#' @param tol positive convergence tolerance of iterations.
#' @param maxIter integer giving the maximal number of IRLS iterations.
#' @param weightCol the weight column name. If this is not set or \code{NULL}, we treat all instance
Expand Down Expand Up @@ -236,8 +238,9 @@ setMethod("spark.glm", signature(data = "SparkDataFrame", formula = "formula"),
weightCol <- ""
}

# For known families, Gamma is upper-cased
jobj <- callJStatic("org.apache.spark.ml.r.GeneralizedLinearRegressionWrapper",
"fit", formula, data@sdf, family$family, family$link,
"fit", formula, data@sdf, tolower(family$family), family$link,
tol, as.integer(maxIter), as.character(weightCol), regParam)
new("GeneralizedLinearRegressionModel", jobj = jobj)
})
Expand All @@ -252,6 +255,8 @@ setMethod("spark.glm", signature(data = "SparkDataFrame", formula = "formula"),
#' This can be a character string naming a family function, a family function or
#' the result of a call to a family function. Refer R family at
#' \url{https://stat.ethz.ch/R-manual/R-devel/library/stats/html/family.html}.
#' Currently these families are supported: \code{binomial}, \code{gaussian},
#' \code{Gamma}, and \code{poisson}.
#' @param weightCol the weight column name. If this is not set or \code{NULL}, we treat all instance
#' weights as 1.0.
#' @param epsilon positive convergence tolerance of iterations.
Expand Down
8 changes: 8 additions & 0 deletions R/pkg/inst/tests/testthat/test_mllib.R
Original file line number Diff line number Diff line change
Expand Up @@ -74,6 +74,14 @@ test_that("spark.glm and predict", {
data = iris, family = poisson(link = identity)), iris))
expect_true(all(abs(rVals - vals) < 1e-6), rVals - vals)

# Gamma family
x <- runif(100, -1, 1)
y <- rgamma(100, rate = 10 / exp(0.5 + 1.2 * x), shape = 10)
df <- as.DataFrame(as.data.frame(list(x = x, y = y)))
model <- glm(y ~ x, family = Gamma, df)
out <- capture.output(print(summary(model)))
expect_true(any(grepl("Dispersion parameter for gamma family", out)))

# Test stats::predict is working
x <- rnorm(15)
y <- x + rnorm(15)
Expand Down

0 comments on commit 1022049

Please sign in to comment.