Skip to content

Commit

Permalink
added test for the case when results are produced without training (w…
Browse files Browse the repository at this point in the history
…hen label is constant)
  • Loading branch information
iyounus committed Feb 2, 2016
1 parent c0744d8 commit 2480dc1
Showing 1 changed file with 50 additions and 5 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -38,6 +38,7 @@ class LinearRegressionSuite
@transient var datasetWithSparseFeature: DataFrame = _
@transient var datasetWithWeight: DataFrame = _
@transient var datasetWithWeightConstantLabel: DataFrame = _
@transient var datasetWithWeightZeroLabel: DataFrame = _

/*
In `LinearRegressionSuite`, we will make sure that the model trained by SparkML
Expand Down Expand Up @@ -109,6 +110,13 @@ class LinearRegressionSuite
Instance(17.0, 3.0, Vectors.dense(2.0, 11.0)),
Instance(17.0, 4.0, Vectors.dense(3.0, 13.0))
), 2))
datasetWithWeightZeroLabel = sqlContext.createDataFrame(
sc.parallelize(Seq(
Instance(0.0, 1.0, Vectors.dense(0.0, 5.0).toSparse),
Instance(0.0, 2.0, Vectors.dense(1.0, 7.0)),
Instance(0.0, 3.0, Vectors.dense(2.0, 11.0)),
Instance(0.0, 4.0, Vectors.dense(3.0, 13.0))
), 2))
}

test("params") {
Expand Down Expand Up @@ -592,21 +600,31 @@ class LinearRegressionSuite
Seq("auto", "l-bfgs", "normal").foreach { solver =>
var idx = 0
for (fitIntercept <- Seq(false, true)) {
val model = new LinearRegression()
val model1 = new LinearRegression()
.setFitIntercept(fitIntercept)
.setWeightCol("weight")
.setSolver(solver)
.fit(datasetWithWeightConstantLabel)
val actual = Vectors.dense(model.intercept, model.coefficients(0), model.coefficients(1))
assert(actual ~== expected(idx) absTol 1e-4)
val actual1 = Vectors.dense(model1.intercept, model1.coefficients(0),
model1.coefficients(1))
assert(actual1 ~== expected(idx) absTol 1e-4)

val model2 = new LinearRegression()
.setFitIntercept(fitIntercept)
.setWeightCol("weight")
.setSolver(solver)
.fit(datasetWithWeightZeroLabel)
val actual2 = Vectors.dense(model2.intercept, model2.coefficients(0),
model2.coefficients(1))
assert(actual2 ~== Vectors.dense(0.0, 0.0, 0.0) absTol 1e-4)
idx += 1
}
}
}

test("regularized linear regression through origin with constant label") {
// The problem is ill-defined if fitIntercept=false, regParam is non-zero and
// standardization=true. An exception is thrown in this case.
// The problem is ill-defined if fitIntercept=false, regParam is non-zero.
// An exception is thrown in this case.
Seq("auto", "l-bfgs", "normal").foreach { solver =>
for (standardization <- Seq(false, true)) {
val model = new LinearRegression().setFitIntercept(false)
Expand All @@ -618,6 +636,33 @@ class LinearRegressionSuite
}
}

test("linear regression with l-bfgs when training is not needed") {
  // When the label is constant, the l-bfgs solver returns results without training.
  // There are two possibilities: if the label is non-zero but constant and
  // fitIntercept is true, then the model returns yMean as the intercept without training.
  // If the label is all zeros, then all coefficients are zero regardless of fitIntercept,
  // so no training is needed.
  //
  // In both cases the test checks that the first entry of the training summary's
  // objective history is (approximately) zero.
  for (fitIntercept <- Seq(false, true)) {
    for (standardization <- Seq(false, true)) {
      val constantLabelModel = new LinearRegression()
        .setFitIntercept(fitIntercept)
        .setStandardization(standardization)
        .setWeightCol("weight")
        .setSolver("l-bfgs")
        .fit(datasetWithWeightConstantLabel)
      // With a non-zero constant label the shortcut only applies when an intercept is fit,
      // so the objective is only checked in that configuration.
      if (fitIntercept) {
        assert(constantLabelModel.summary.objectiveHistory(0) ~== 0.0 absTol 1e-4)
      }

      // Apply the loop's standardization setting here as well; otherwise the zero-label
      // case is fitted twice with the same default configuration and the other
      // standardization value is never exercised.
      val zeroLabelModel = new LinearRegression()
        .setFitIntercept(fitIntercept)
        .setStandardization(standardization)
        .setWeightCol("weight")
        .setSolver("l-bfgs")
        .fit(datasetWithWeightZeroLabel)
      assert(zeroLabelModel.summary.objectiveHistory(0) ~== 0.0 absTol 1e-4)
    }
  }
}

test("linear regression model training summary") {
Seq("auto", "l-bfgs", "normal").foreach { solver =>
val trainer = new LinearRegression().setSolver(solver)
Expand Down

0 comments on commit 2480dc1

Please sign in to comment.