-
Notifications
You must be signed in to change notification settings - Fork 28.4k
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
[SPARK-19391][SparkR][ML] Tweedie GLM API for SparkR #16729
Conversation
Test build #72111 has finished for PR 16729 at commit
|
Test build #72112 has finished for PR 16729 at commit
|
Test build #72114 has finished for PR 16729 at commit
|
@@ -77,6 +77,18 @@ test_that("spark.glm and predict", { | |||
out <- capture.output(print(summary(model))) | |||
expect_true(any(grepl("Dispersion parameter for gamma family", out))) | |||
|
|||
# tweedie family | |||
require(statmod) |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
we can't require this as of now - we would need to update Jenkins otherwise it would fail like it is right now, because on Jenkins we don't have the statmod package
spark.glm and predict (@test_mllib_regression.R#81) - there is no package called 'statmod'
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
in fact, library
is more correct here as it will fail if the package isn't installed, instead of the warning we see and failing later.
R/pkg/R/mllib_regression.R
Outdated
@@ -84,6 +84,12 @@ setClass("IsotonicRegressionModel", representation(jobj = "jobj")) | |||
#' # can also read back the saved model and print | |||
#' savedModel <- read.ml(path) | |||
#' summary(savedModel) | |||
#' | |||
#' # fit tweedie model | |||
#' require(statmod) |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
generally people use library
instead of require
I did look into this... I think it's great if Is there a way to expose this in the API without having a hard dependency on tweedie family defined in |
@@ -84,6 +84,12 @@ setClass("IsotonicRegressionModel", representation(jobj = "jobj")) | |||
#' # can also read back the saved model and print | |||
#' savedModel <- read.ml(path) | |||
#' summary(savedModel) | |||
#' |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
please update L56 for documentation. Also we should update the programming guide and vignettes too
R/pkg/R/mllib_regression.R
Outdated
@@ -109,7 +125,8 @@ setMethod("spark.glm", signature(data = "SparkDataFrame", formula = "formula"), | |||
# For known families, Gamma is upper-cased | |||
jobj <- callJStatic("org.apache.spark.ml.r.GeneralizedLinearRegressionWrapper", | |||
"fit", formula, data@sdf, tolower(family$family), family$link, | |||
tol, as.integer(maxIter), as.character(weightCol), regParam) | |||
tol, as.integer(maxIter), as.character(weightCol), regParam, | |||
as.double(variancePower), as.double(linkPower)) |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
we probably don't need to as.double
here since it is either set to fixed values (L116) or from a calculation (L112). Instead, we should check var.power
and link.power
are within the correct range - not sure if the tweedie function does that.
prediction <- predict(model, training) | ||
expect_equal(typeof(take(select(prediction, "prediction"), 1)$prediction), "double") | ||
vals <- collect(select(prediction, "prediction")) | ||
rVals <- suppressWarnings(predict( |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
why do we need suppressWarnings
here?
.setFitIntercept(rFormula.hasIntercept) | ||
.setTol(tol) | ||
.setMaxIter(maxIter) | ||
.setWeightCol(weightCol) | ||
.setRegParam(regParam) | ||
.setFeaturesCol(rFormula.getFeaturesCol) | ||
// set variancePower and linkPower if family is tweedie; otherwise, set link function | ||
if (family.toLowerCase == "tweedie") { | ||
glr = glr.setVariancePower(variancePower).setLinkPower(linkPower) |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
do we need to assign glr =
here? generally the setter method will update the instance
model <- spark.glm(training, Sepal_Width ~ Sepal_Length + Species, | ||
family = tweedie(var.power = 1.2, link.power = 1.0)) | ||
prediction <- predict(model, training) | ||
expect_equal(typeof(take(select(prediction, "prediction"), 1)$prediction), "double") |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
you might want to use dtypes
instead?
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Would you remind me what dtypes
is and why we need to use it here? Thanks.
@@ -143,7 +150,12 @@ private[r] object GeneralizedLinearRegressionWrapper | |||
val rDeviance: Double = summary.deviance | |||
val rResidualDegreeOfFreedomNull: Long = summary.residualDegreeOfFreedomNull | |||
val rResidualDegreeOfFreedom: Long = summary.residualDegreeOfFreedom | |||
val rAic: Double = summary.aic | |||
val rAic: Double = if (family.toLowerCase == "tweedie" && | |||
!Array(0.0, 1.0, 2.0).contains(variancePower)) { |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
we are comparing double values here, do you know how reliable is this? should it have epsilon in the comparison?
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Thanks for the suggestion. Changed it to comparison instead.
Test build #72131 has finished for PR 16729 at commit
|
@felixcheung Thanks so much for your quick and detailed review. I have made a new commit that removed dependency on |
Test build #72132 has finished for PR 16729 at commit
|
Test build #72136 has finished for PR 16729 at commit
|
Test build #72137 has finished for PR 16729 at commit
|
Test build #73964 has finished for PR 16729 at commit
|
@felixcheung Sorry for taking so long for this update. I think your first suggestion makes most sense, i.e., we do not expose the internal I have made this to work. The following shows it now works both when statmod is not loaded (using Let me know if there is any other issues. Thanks.
|
Thanks for working on this - to clarify, this only works with |
@felixcheung Yes, the SparkR
|
@felixcheung Could you take a look at this new fix when you get a chance? Thanks. |
yea - I'm sorry if it was confusing - I was referring to In the past there were concerns of exposing methods privately, so I'm not sure if we want to encourage accessing the tweedie function that way? Perhaps then # 3 would be the only option (and it would be like Python) |
@felixcheung If we go with # 3, do we still want to compatibility with statmod::tweedie? It's confusing to have two different ways of specifying the same model. |
@actuaryzhang that's true, it's not ideal. This is somewhat an unusual case for R for several reasons.
But since in this method we have this odd design where we take a R function (R
In the case Also the other concern is we do have another signature |
@felixcheung OK, new implementation of # 3. Now works in two ways:
They work for both
|
Test build #74242 has finished for PR 16729 at commit
|
Test build #74243 has finished for PR 16729 at commit
|
One other change I could make is to change |
I like the example in this implementation! thanks |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
looking good, thanks for working on this!
you can follow up with programming guide change separately if you want.
R/pkg/vignettes/sparkr-vignettes.Rmd
Outdated
summary(tweedieGLM1) | ||
``` | ||
We can try other distributions in the tweedie family, for example, a compound Poisson distribution with a log link: | ||
```{r} | ||
tweedieGLM2 <- spark.glm(carsDF, mpg ~ wt + hp, family = SparkR:::tweedie(1.2, 0.0)) | ||
tweedieGLM2 <- spark.glm(carsDF, mpg ~ wt + hp, family = "tweedie", | ||
variancePower = 1.2, linkPower = 0.0) | ||
summary(tweedieGLM2) |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
let's add an example with statmod
too? either here or in roxygen2 API doc (later might be a better place?)
R/pkg/R/mllib_regression.R
Outdated
#' | ||
#' Note that there are two ways to specify the tweedie family. | ||
#' a) Set \code{family = "tweedie"} and specify the variancePower and linkPower | ||
#' b) When package \code{statmod} is loaded, the tweedie family is specified using the |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
roxygen2 will collapse these two lines - suggest separating with ;
or use \item
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
basically roxygen2 trims all the "insignificant whitespace"
@@ -100,6 +120,12 @@ setMethod("spark.glm", signature(data = "SparkDataFrame", formula = "formula"), | |||
print(family) | |||
stop("'family' not recognized") | |||
} | |||
# Handle when family = statmod::tweedie() | |||
if (tolower(family$family) == "tweedie" && !is.null(family$variance)) { |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
i assume it handles the "fake" family created on L111 correctly? it doesn't have variance
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
This part only handles the case when statmod::tweedie
is specified: it retrieves the var.power
and link.power
and construct a list with family name and link name to be used.
The check for non-null variance
is to skip handling the "fake" family. All we need when specifying family = "tweedie"
is just a list with family name and link name.
R/pkg/R/mllib_regression.R
Outdated
family <- get(family, mode = "function", envir = parent.frame()) | ||
# Handle when family = "tweedie" | ||
if (tolower(family) == "tweedie") { | ||
family <- list(family = "tweedie", link = "linkNotUsed") |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
nit: I think you can set link = NULL
R/pkg/R/mllib_regression.R
Outdated
if (tolower(family$family) == "tweedie" && !is.null(family$variance)) { | ||
variancePower <- log(family$variance(exp(1))) | ||
linkPower <- log(family$linkfun(exp(1))) | ||
family <- list(family = "tweedie", link = "linkNotUsed") |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
ditto, link = NULL
R/pkg/R/mllib_regression.R
Outdated
#' | ||
#' # fit tweedie model | ||
#' model <- spark.glm(df, Freq ~ Sex + Age, family = "tweedie", | ||
#' variancePower = 1.2, linkPower = 0) |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
could you add an example with statmod?
@felixcheung Thanks for the feedback. Made a new commit that
Let me know if there is anything else needed. |
Test build #74417 has finished for PR 16729 at commit
|
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
looking good! just this earlier comment #16729 (comment)
Sorry that I forgot to address that comment. Fixed now. |
Test build #74423 has finished for PR 16729 at commit
|
@felixcheung Could you merge this please? Thanks! |
merged to master |
What changes were proposed in this pull request?
Port Tweedie GLM #16344 to SparkR
@felixcheung @yanboliang
How was this patch tested?
new test in SparkR