-
Notifications
You must be signed in to change notification settings - Fork 28.5k
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
[SPARK-19391][SparkR][ML] Tweedie GLM API for SparkR #16729
Changes from all commits
67364ab
654551b
852dd6e
5aa4ae7
3682692
3555afb
56f6da0
083849c
fb66ce0
0d722fd
d11fc4b
4c24158
295711d
c315fb1
9be9c51
201939b
6737122
0b5ed43
b10777e
7d5bd60
a9ac439
f540922
6cbc62f
ef65adc
c11e57c
5ce4c84
aeeb3f7
4cffc40
0b496a6
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -53,12 +53,23 @@ setClass("IsotonicRegressionModel", representation(jobj = "jobj")) | |
#' the result of a call to a family function. Refer R family at | ||
#' \url{https://stat.ethz.ch/R-manual/R-devel/library/stats/html/family.html}. | ||
#' Currently these families are supported: \code{binomial}, \code{gaussian}, | ||
#' \code{Gamma}, and \code{poisson}. | ||
#' \code{Gamma}, \code{poisson} and \code{tweedie}. | ||
#' | ||
#' Note that there are two ways to specify the tweedie family. | ||
#' \itemize{ | ||
#' \item Set \code{family = "tweedie"} and specify the var.power and link.power; | ||
#' \item When package \code{statmod} is loaded, the tweedie family is specified using the | ||
#' family definition therein, i.e., \code{tweedie(var.power, link.power)}. | ||
#' } | ||
#' @param tol positive convergence tolerance of iterations. | ||
#' @param maxIter integer giving the maximal number of IRLS iterations. | ||
#' @param weightCol the weight column name. If this is not set or \code{NULL}, we treat all instance | ||
#' weights as 1.0. | ||
#' @param regParam regularization parameter for L2 regularization. | ||
#' @param var.power the power in the variance function of the Tweedie distribution which provides | ||
#' the relationship between the variance and mean of the distribution. Only | ||
#' applicable to the Tweedie family. | ||
#' @param link.power the index in the power link function. Only applicable to the Tweedie family. | ||
#' @param ... additional arguments passed to the method. | ||
#' @aliases spark.glm,SparkDataFrame,formula-method | ||
#' @return \code{spark.glm} returns a fitted generalized linear model. | ||
|
@@ -84,14 +95,30 @@ setClass("IsotonicRegressionModel", representation(jobj = "jobj")) | |
#' # can also read back the saved model and print | ||
#' savedModel <- read.ml(path) | ||
#' summary(savedModel) | ||
#' | ||
#' # fit tweedie model | ||
#' model <- spark.glm(df, Freq ~ Sex + Age, family = "tweedie", | ||
#' var.power = 1.2, link.power = 0) | ||
#' summary(model) | ||
#' | ||
#' # use the tweedie family from statmod | ||
#' library(statmod) | ||
#' model <- spark.glm(df, Freq ~ Sex + Age, family = tweedie(1.2, 0)) | ||
#' summary(model) | ||
#' } | ||
#' @note spark.glm since 2.0.0 | ||
#' @seealso \link{glm}, \link{read.ml} | ||
setMethod("spark.glm", signature(data = "SparkDataFrame", formula = "formula"), | ||
function(data, formula, family = gaussian, tol = 1e-6, maxIter = 25, weightCol = NULL, | ||
regParam = 0.0) { | ||
regParam = 0.0, var.power = 0.0, link.power = 1.0 - var.power) { | ||
|
||
if (is.character(family)) { | ||
family <- get(family, mode = "function", envir = parent.frame()) | ||
# Handle when family = "tweedie" | ||
if (tolower(family) == "tweedie") { | ||
family <- list(family = "tweedie", link = NULL) | ||
} else { | ||
family <- get(family, mode = "function", envir = parent.frame()) | ||
} | ||
} | ||
if (is.function(family)) { | ||
family <- family() | ||
|
@@ -100,6 +127,12 @@ setMethod("spark.glm", signature(data = "SparkDataFrame", formula = "formula"), | |
print(family) | ||
stop("'family' not recognized") | ||
} | ||
# Handle when family = statmod::tweedie() | ||
if (tolower(family$family) == "tweedie" && !is.null(family$variance)) { | ||
There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more. I assume it handles the "fake" family created on L111 correctly? It doesn't have […] There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more. This part only handles the case when […] |
||
var.power <- log(family$variance(exp(1))) | ||
link.power <- log(family$linkfun(exp(1))) | ||
family <- list(family = "tweedie", link = NULL) | ||
} | ||
|
||
formula <- paste(deparse(formula), collapse = "") | ||
if (!is.null(weightCol) && weightCol == "") { | ||
|
@@ -111,7 +144,8 @@ setMethod("spark.glm", signature(data = "SparkDataFrame", formula = "formula"), | |
# For known families, Gamma is upper-cased | ||
jobj <- callJStatic("org.apache.spark.ml.r.GeneralizedLinearRegressionWrapper", | ||
"fit", formula, data@sdf, tolower(family$family), family$link, | ||
tol, as.integer(maxIter), weightCol, regParam) | ||
tol, as.integer(maxIter), weightCol, regParam, | ||
as.double(var.power), as.double(link.power)) | ||
new("GeneralizedLinearRegressionModel", jobj = jobj) | ||
}) | ||
|
||
|
@@ -126,11 +160,13 @@ setMethod("spark.glm", signature(data = "SparkDataFrame", formula = "formula"), | |
#' the result of a call to a family function. Refer R family at | ||
#' \url{https://stat.ethz.ch/R-manual/R-devel/library/stats/html/family.html}. | ||
#' Currently these families are supported: \code{binomial}, \code{gaussian}, | ||
#' \code{Gamma}, and \code{poisson}. | ||
#' \code{poisson}, \code{Gamma}, and \code{tweedie}. | ||
#' @param weightCol the weight column name. If this is not set or \code{NULL}, we treat all instance | ||
#' weights as 1.0. | ||
#' @param epsilon positive convergence tolerance of iterations. | ||
#' @param maxit integer giving the maximal number of IRLS iterations. | ||
#' @param var.power the index of the power variance function in the Tweedie family. | ||
#' @param link.power the index of the power link function in the Tweedie family. | ||
#' @return \code{glm} returns a fitted generalized linear model. | ||
#' @rdname glm | ||
#' @export | ||
|
@@ -145,8 +181,10 @@ setMethod("spark.glm", signature(data = "SparkDataFrame", formula = "formula"), | |
#' @note glm since 1.5.0 | ||
#' @seealso \link{spark.glm} | ||
setMethod("glm", signature(formula = "formula", family = "ANY", data = "SparkDataFrame"), | ||
function(formula, family = gaussian, data, epsilon = 1e-6, maxit = 25, weightCol = NULL) { | ||
spark.glm(data, formula, family, tol = epsilon, maxIter = maxit, weightCol = weightCol) | ||
function(formula, family = gaussian, data, epsilon = 1e-6, maxit = 25, weightCol = NULL, | ||
var.power = 0.0, link.power = 1.0 - var.power) { | ||
spark.glm(data, formula, family, tol = epsilon, maxIter = maxit, weightCol = weightCol, | ||
var.power = var.power, link.power = link.power) | ||
}) | ||
|
||
# Returns the summary of a model produced by glm() or spark.glm(), similarly to R's summary(). | ||
|
@@ -172,9 +210,10 @@ setMethod("summary", signature(object = "GeneralizedLinearRegressionModel"), | |
deviance <- callJMethod(jobj, "rDeviance") | ||
df.null <- callJMethod(jobj, "rResidualDegreeOfFreedomNull") | ||
df.residual <- callJMethod(jobj, "rResidualDegreeOfFreedom") | ||
aic <- callJMethod(jobj, "rAic") | ||
iter <- callJMethod(jobj, "rNumIterations") | ||
family <- callJMethod(jobj, "rFamily") | ||
aic <- callJMethod(jobj, "rAic") | ||
if (family == "tweedie" && aic == 0) aic <- NA | ||
deviance.resid <- if (is.loaded) { | ||
NULL | ||
} else { | ||
|
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Please update L56 for documentation. We should also update the programming guide and the vignettes.