-
Notifications
You must be signed in to change notification settings - Fork 28.4k
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
[SPARK-12732][ML] bug fix in linear regression train #10702
Changes from 7 commits
3e23479
23ce5f3
5803bd1
e83b822
0b16353
c0744d8
2480dc1
fd7eb99
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -74,7 +74,8 @@ class LinearRegression @Since("1.3.0") (@Since("1.3.0") override val uid: String | |
/** | ||
* Set the regularization parameter. | ||
* Default is 0.0. | ||
* @group setParam | ||
* | ||
* @group setParam | ||
*/ | ||
@Since("1.3.0") | ||
def setRegParam(value: Double): this.type = set(regParam, value) | ||
|
@@ -83,7 +84,8 @@ class LinearRegression @Since("1.3.0") (@Since("1.3.0") override val uid: String | |
/** | ||
* Set if we should fit the intercept | ||
* Default is true. | ||
* @group setParam | ||
* | ||
There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more. ditto |
||
* @group setParam | ||
*/ | ||
@Since("1.5.0") | ||
def setFitIntercept(value: Boolean): this.type = set(fitIntercept, value) | ||
|
@@ -96,7 +98,8 @@ class LinearRegression @Since("1.3.0") (@Since("1.3.0") override val uid: String | |
* the models should be always converged to the same solution when no regularization | ||
* is applied. In R's GLMNET package, the default behavior is true as well. | ||
* Default is true. | ||
* @group setParam | ||
* | ||
* @group setParam | ||
*/ | ||
@Since("1.5.0") | ||
def setStandardization(value: Boolean): this.type = set(standardization, value) | ||
|
@@ -107,7 +110,8 @@ class LinearRegression @Since("1.3.0") (@Since("1.3.0") override val uid: String | |
* For alpha = 0, the penalty is an L2 penalty. For alpha = 1, it is an L1 penalty. | ||
* For 0 < alpha < 1, the penalty is a combination of L1 and L2. | ||
* Default is 0.0 which is an L2 penalty. | ||
* @group setParam | ||
* | ||
* @group setParam | ||
*/ | ||
@Since("1.4.0") | ||
def setElasticNetParam(value: Double): this.type = set(elasticNetParam, value) | ||
|
@@ -116,7 +120,8 @@ class LinearRegression @Since("1.3.0") (@Since("1.3.0") override val uid: String | |
/** | ||
* Set the maximum number of iterations. | ||
* Default is 100. | ||
* @group setParam | ||
* | ||
* @group setParam | ||
*/ | ||
@Since("1.3.0") | ||
def setMaxIter(value: Int): this.type = set(maxIter, value) | ||
|
@@ -126,7 +131,8 @@ class LinearRegression @Since("1.3.0") (@Since("1.3.0") override val uid: String | |
* Set the convergence tolerance of iterations. | ||
* Smaller value will lead to higher accuracy with the cost of more iterations. | ||
* Default is 1E-6. | ||
* @group setParam | ||
* | ||
* @group setParam | ||
*/ | ||
@Since("1.4.0") | ||
def setTol(value: Double): this.type = set(tol, value) | ||
|
@@ -136,7 +142,8 @@ class LinearRegression @Since("1.3.0") (@Since("1.3.0") override val uid: String | |
* Whether to over-/under-sample training instances according to the given weights in weightCol. | ||
* If empty, all instances are treated equally (weight 1.0). | ||
* Default is empty, so all instances have weight one. | ||
* @group setParam | ||
* | ||
* @group setParam | ||
*/ | ||
@Since("1.6.0") | ||
def setWeightCol(value: String): this.type = set(weightCol, value) | ||
|
@@ -150,7 +157,8 @@ class LinearRegression @Since("1.3.0") (@Since("1.3.0") override val uid: String | |
* solution to the linear regression problem. | ||
* The default value is "auto" which means that the solver algorithm is | ||
* selected automatically. | ||
* @group setParam | ||
* | ||
* @group setParam | ||
*/ | ||
@Since("1.6.0") | ||
def setSolver(value: String): this.type = set(solver, value) | ||
|
@@ -219,33 +227,49 @@ class LinearRegression @Since("1.3.0") (@Since("1.3.0") override val uid: String | |
} | ||
|
||
val yMean = ySummarizer.mean(0) | ||
val yStd = math.sqrt(ySummarizer.variance(0)) | ||
|
||
// If the yStd is zero, then the intercept is yMean with zero coefficient; | ||
// as a result, training is not needed. | ||
if (yStd == 0.0) { | ||
logWarning(s"The standard deviation of the label is zero, so the coefficients will be " + | ||
s"zeros and the intercept will be the mean of the label; as a result, " + | ||
s"training is not needed.") | ||
if (handlePersistence) instances.unpersist() | ||
val coefficients = Vectors.sparse(numFeatures, Seq()) | ||
val intercept = yMean | ||
|
||
val model = new LinearRegressionModel(uid, coefficients, intercept) | ||
// Handle possible missing or invalid prediction columns | ||
val (summaryModel, predictionColName) = model.findSummaryModelAndPredictionCol() | ||
|
||
val trainingSummary = new LinearRegressionTrainingSummary( | ||
summaryModel.transform(dataset), | ||
predictionColName, | ||
$(labelCol), | ||
model, | ||
Array(0D), | ||
$(featuresCol), | ||
Array(0D)) | ||
return copyValues(model.setSummary(trainingSummary)) | ||
val rawYStd = math.sqrt(ySummarizer.variance(0)) | ||
if (rawYStd == 0.0) { | ||
if ($(fitIntercept) || yMean==0.0) { | ||
// If the rawYStd is zero and fitIntercept=true, then the intercept is yMean with | ||
// zero coefficient; as a result, training is not needed. | ||
// Also, if yMean==0 and rawYStd==0, all the coefficients are zero regardless of | ||
// the fitIntercept | ||
if (yMean == 0.0) { | ||
logWarning(s"Mean and standard deviation of the label are zero, so the coefficients " + | ||
s"and the intercept will all be zero; as a result, training is not needed.") | ||
} else { | ||
logWarning(s"The standard deviation of the label is zero, so the coefficients will be " + | ||
s"zeros and the intercept will be the mean of the label; as a result, " + | ||
s"training is not needed.") | ||
} | ||
if (handlePersistence) instances.unpersist() | ||
val coefficients = Vectors.sparse(numFeatures, Seq()) | ||
val intercept = yMean | ||
|
||
val model = new LinearRegressionModel(uid, coefficients, intercept) | ||
// Handle possible missing or invalid prediction columns | ||
val (summaryModel, predictionColName) = model.findSummaryModelAndPredictionCol() | ||
|
||
val trainingSummary = new LinearRegressionTrainingSummary( | ||
summaryModel.transform(dataset), | ||
predictionColName, | ||
$(labelCol), | ||
model, | ||
Array(0D), | ||
$(featuresCol), | ||
Array(0D)) | ||
return copyValues(model.setSummary(trainingSummary)) | ||
} else { | ||
require($(regParam) == 0.0, "The standard deviation of the label is zero. " + | ||
"Model cannot be regularized.") | ||
logWarning(s"The standard deviation of the label is zero. " + | ||
"Consider setting fitIntercept=true.") | ||
} | ||
} | ||
|
||
// if y is constant (rawYStd is zero), then y cannot be scaled. In this case | ||
// setting yStd=1.0 ensures that y is not scaled anymore in l-bfgs algorithm. | ||
val yStd = if (rawYStd > 0) rawYStd else math.abs(yMean) | ||
val featuresMean = featuresSummarizer.mean.toArray | ||
val featuresStd = featuresSummarizer.variance.toArray.map(math.sqrt) | ||
|
||
|
@@ -398,7 +422,8 @@ class LinearRegressionModel private[ml] ( | |
|
||
/** | ||
* Evaluates the model on a testset. | ||
* @param dataset Test dataset to evaluate model on. | ||
* | ||
There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more. ditto There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more. I really don't know how this has happened. :) |
||
* @param dataset Test dataset to evaluate model on. | ||
*/ | ||
// TODO: decide on a good name before exposing to public API | ||
private[regression] def evaluate(dataset: DataFrame): LinearRegressionSummary = { | ||
|
@@ -496,7 +521,8 @@ object LinearRegressionModel extends MLReadable[LinearRegressionModel] { | |
* :: Experimental :: | ||
* Linear regression training results. Currently, the training summary ignores the | ||
* training coefficients except for the objective trace. | ||
* @param predictions predictions outputted by the model's `transform` method. | ||
* | ||
* @param predictions predictions outputted by the model's `transform` method. | ||
* @param objectiveHistory objective function (scaled loss + regularization) at each iteration. | ||
*/ | ||
@Since("1.5.0") | ||
|
@@ -520,7 +546,8 @@ class LinearRegressionTrainingSummary private[regression] ( | |
/** | ||
* :: Experimental :: | ||
* Linear regression results evaluated on a dataset. | ||
* @param predictions predictions outputted by the model's `transform` method. | ||
* | ||
* @param predictions predictions outputted by the model's `transform` method. | ||
*/ | ||
@Since("1.5.0") | ||
@Experimental | ||
|
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -37,6 +37,8 @@ class LinearRegressionSuite | |
@transient var datasetWithDenseFeatureWithoutIntercept: DataFrame = _ | ||
@transient var datasetWithSparseFeature: DataFrame = _ | ||
@transient var datasetWithWeight: DataFrame = _ | ||
@transient var datasetWithWeightConstantLabel: DataFrame = _ | ||
@transient var datasetWithWeightZeroLabel: DataFrame = _ | ||
|
||
/* | ||
In `LinearRegressionSuite`, we will make sure that the model trained by SparkML | ||
|
@@ -92,6 +94,29 @@ class LinearRegressionSuite | |
Instance(23.0, 3.0, Vectors.dense(2.0, 11.0)), | ||
Instance(29.0, 4.0, Vectors.dense(3.0, 13.0)) | ||
), 2)) | ||
|
||
/* | ||
R code: | ||
|
||
A <- matrix(c(0, 1, 2, 3, 5, 7, 11, 13), 4, 2) | ||
b.const <- c(17, 17, 17, 17) | ||
w <- c(1, 2, 3, 4) | ||
df.const.label <- as.data.frame(cbind(A, b.const)) | ||
*/ | ||
datasetWithWeightConstantLabel = sqlContext.createDataFrame( | ||
sc.parallelize(Seq( | ||
Instance(17.0, 1.0, Vectors.dense(0.0, 5.0).toSparse), | ||
Instance(17.0, 2.0, Vectors.dense(1.0, 7.0)), | ||
Instance(17.0, 3.0, Vectors.dense(2.0, 11.0)), | ||
Instance(17.0, 4.0, Vectors.dense(3.0, 13.0)) | ||
), 2)) | ||
datasetWithWeightZeroLabel = sqlContext.createDataFrame( | ||
sc.parallelize(Seq( | ||
Instance(0.0, 1.0, Vectors.dense(0.0, 5.0).toSparse), | ||
Instance(0.0, 2.0, Vectors.dense(1.0, 7.0)), | ||
Instance(0.0, 3.0, Vectors.dense(2.0, 11.0)), | ||
Instance(0.0, 4.0, Vectors.dense(3.0, 13.0)) | ||
), 2)) | ||
} | ||
|
||
test("params") { | ||
|
@@ -558,6 +583,86 @@ class LinearRegressionSuite | |
} | ||
} | ||
|
||
test("linear regression model with constant label") { | ||
/* | ||
R code: | ||
for (formula in c(b.const ~ . -1, b.const ~ .)) { | ||
model <- lm(formula, data=df.const.label, weights=w) | ||
print(as.vector(coef(model))) | ||
} | ||
[1] -9.221298 3.394343 | ||
[1] 17 0 0 | ||
*/ | ||
val expected = Seq( | ||
Vectors.dense(0.0, -9.221298, 3.394343), | ||
Vectors.dense(17.0, 0.0, 0.0)) | ||
|
||
Seq("auto", "l-bfgs", "normal").foreach { solver => | ||
var idx = 0 | ||
for (fitIntercept <- Seq(false, true)) { | ||
val model1 = new LinearRegression() | ||
.setFitIntercept(fitIntercept) | ||
.setWeightCol("weight") | ||
.setSolver(solver) | ||
.fit(datasetWithWeightConstantLabel) | ||
val actual1 = Vectors.dense(model1.intercept, model1.coefficients(0), | ||
model1.coefficients(1)) | ||
assert(actual1 ~== expected(idx) absTol 1e-4) | ||
|
||
val model2 = new LinearRegression() | ||
.setFitIntercept(fitIntercept) | ||
.setWeightCol("weight") | ||
.setSolver(solver) | ||
.fit(datasetWithWeightZeroLabel) | ||
val actual2 = Vectors.dense(model2.intercept, model2.coefficients(0), | ||
model2.coefficients(1)) | ||
assert(actual2 ~== Vectors.dense(0.0, 0.0, 0.0) absTol 1e-4) | ||
idx += 1 | ||
There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more. When Will be nice to add one small test that There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more. I'm not sure how to check the size of lost history. Could you please point me to some example? There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more. In There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more. I've added tests as suggested. |
||
} | ||
} | ||
} | ||
|
||
test("regularized linear regression through origin with constant label") { | ||
// The problem is ill-defined if fitIntercept=false, regParam is non-zero. | ||
// An exception is thrown in this case. | ||
Seq("auto", "l-bfgs", "normal").foreach { solver => | ||
for (standardization <- Seq(false, true)) { | ||
val model = new LinearRegression().setFitIntercept(false) | ||
.setRegParam(0.1).setStandardization(standardization).setSolver(solver) | ||
intercept[IllegalArgumentException] { | ||
model.fit(datasetWithWeightConstantLabel) | ||
} | ||
} | ||
} | ||
} | ||
|
||
test("linear regression with l-bfgs when training is not needed") { | ||
// When label is constant, l-bfgs solver returns results without training. | ||
// There are two possibilities: If the label is non-zero but constant, | ||
// and fitIntercept is true, then the model return yMean as intercept without training. | ||
// If label is all zeros, then all coefficients are zero regardless of fitIntercept, so | ||
// no training is needed. | ||
for (fitIntercept <- Seq(false, true)) { | ||
for (standardization <- Seq(false, true)) { | ||
val model1 = new LinearRegression() | ||
.setFitIntercept(fitIntercept) | ||
.setStandardization(standardization) | ||
.setWeightCol("weight") | ||
.setSolver("l-bfgs") | ||
.fit(datasetWithWeightConstantLabel) | ||
if (fitIntercept) { | ||
assert(model1.summary.objectiveHistory(0) ~== 0.0 absTol 1e-4) | ||
} | ||
val model2 = new LinearRegression() | ||
.setFitIntercept(fitIntercept) | ||
.setWeightCol("weight") | ||
.setSolver("l-bfgs") | ||
.fit(datasetWithWeightZeroLabel) | ||
assert(model2.summary.objectiveHistory(0) ~== 0.0 absTol 1e-4) | ||
} | ||
} | ||
} | ||
|
||
test("linear regression model training summary") { | ||
Seq("auto", "l-bfgs", "normal").foreach { solver => | ||
val trainer = new LinearRegression().setSolver(solver) | ||
|
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
All the indentations you just added are off.