From 1022049c78e55914c54dff6d5206ad56dba7eef4 Mon Sep 17 00:00:00 2001 From: Felix Cheung Date: Tue, 10 Jan 2017 21:22:16 -0800 Subject: [PATCH] [SPARK-19133][SPARKR][ML][BACKPORT-2.1] fix glm for Gamma, clarify glm family supported ## What changes were proposed in this pull request? backporting to 2.1, 2.0 and 1.6 ## How was this patch tested? unit tests Author: Felix Cheung Closes #16532 from felixcheung/rgammabackport. --- R/pkg/R/mllib.R | 7 ++++++- R/pkg/inst/tests/testthat/test_mllib.R | 8 ++++++++ 2 files changed, 14 insertions(+), 1 deletion(-) diff --git a/R/pkg/R/mllib.R b/R/pkg/R/mllib.R index d736bbb5e9113..1a254ad49b086 100644 --- a/R/pkg/R/mllib.R +++ b/R/pkg/R/mllib.R @@ -184,6 +184,8 @@ predict_internal <- function(object, newData) { #' This can be a character string naming a family function, a family function or #' the result of a call to a family function. Refer R family at #' \url{https://stat.ethz.ch/R-manual/R-devel/library/stats/html/family.html}. +#' Currently these families are supported: \code{binomial}, \code{gaussian}, +#' \code{Gamma}, and \code{poisson}. #' @param tol positive convergence tolerance of iterations. #' @param maxIter integer giving the maximal number of IRLS iterations. #' @param weightCol the weight column name. If this is not set or \code{NULL}, we treat all instance @@ -236,8 +238,9 @@ setMethod("spark.glm", signature(data = "SparkDataFrame", formula = "formula"), weightCol <- "" } + # For known families, Gamma is upper-cased jobj <- callJStatic("org.apache.spark.ml.r.GeneralizedLinearRegressionWrapper", - "fit", formula, data@sdf, family$family, family$link, + "fit", formula, data@sdf, tolower(family$family), family$link, tol, as.integer(maxIter), as.character(weightCol), regParam) new("GeneralizedLinearRegressionModel", jobj = jobj) }) @@ -252,6 +255,8 @@ setMethod("spark.glm", signature(data = "SparkDataFrame", formula = "formula"), #' This can be a character string naming a family function, a family function or #' the result of a call to a family function. Refer R family at #' \url{https://stat.ethz.ch/R-manual/R-devel/library/stats/html/family.html}. +#' Currently these families are supported: \code{binomial}, \code{gaussian}, +#' \code{Gamma}, and \code{poisson}. #' @param weightCol the weight column name. If this is not set or \code{NULL}, we treat all instance #' weights as 1.0. #' @param epsilon positive convergence tolerance of iterations. diff --git a/R/pkg/inst/tests/testthat/test_mllib.R b/R/pkg/inst/tests/testthat/test_mllib.R index 40c0446740277..1f2fae9c813f3 100644 --- a/R/pkg/inst/tests/testthat/test_mllib.R +++ b/R/pkg/inst/tests/testthat/test_mllib.R @@ -74,6 +74,14 @@ test_that("spark.glm and predict", { data = iris, family = poisson(link = identity)), iris)) expect_true(all(abs(rVals - vals) < 1e-6), rVals - vals) + # Gamma family + x <- runif(100, -1, 1) + y <- rgamma(100, rate = 10 / exp(0.5 + 1.2 * x), shape = 10) + df <- as.DataFrame(as.data.frame(list(x = x, y = y))) + model <- glm(y ~ x, family = Gamma, df) + out <- capture.output(print(summary(model))) + expect_true(any(grepl("Dispersion parameter for gamma family", out))) + # Test stats::predict is working x <- rnorm(15) y <- x + rnorm(15)