From 2480dc17b064165563fe30bf5e45d2d165c722d9 Mon Sep 17 00:00:00 2001 From: Imran Younus Date: Mon, 1 Feb 2016 16:36:57 -0800 Subject: [PATCH] added test for the case when results are produced without training (when label is constant) --- .../ml/regression/LinearRegressionSuite.scala | 55 +++++++++++++++++-- 1 file changed, 50 insertions(+), 5 deletions(-) diff --git a/mllib/src/test/scala/org/apache/spark/ml/regression/LinearRegressionSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/regression/LinearRegressionSuite.scala index c55adc10c31d7..81fc6603ccfe6 100644 --- a/mllib/src/test/scala/org/apache/spark/ml/regression/LinearRegressionSuite.scala +++ b/mllib/src/test/scala/org/apache/spark/ml/regression/LinearRegressionSuite.scala @@ -38,6 +38,7 @@ class LinearRegressionSuite @transient var datasetWithSparseFeature: DataFrame = _ @transient var datasetWithWeight: DataFrame = _ @transient var datasetWithWeightConstantLabel: DataFrame = _ + @transient var datasetWithWeightZeroLabel: DataFrame = _ /* In `LinearRegressionSuite`, we will make sure that the model trained by SparkML @@ -109,6 +110,13 @@ class LinearRegressionSuite Instance(17.0, 3.0, Vectors.dense(2.0, 11.0)), Instance(17.0, 4.0, Vectors.dense(3.0, 13.0)) ), 2)) + datasetWithWeightZeroLabel = sqlContext.createDataFrame( + sc.parallelize(Seq( + Instance(0.0, 1.0, Vectors.dense(0.0, 5.0).toSparse), + Instance(0.0, 2.0, Vectors.dense(1.0, 7.0)), + Instance(0.0, 3.0, Vectors.dense(2.0, 11.0)), + Instance(0.0, 4.0, Vectors.dense(3.0, 13.0)) + ), 2)) } test("params") { @@ -592,21 +600,31 @@ class LinearRegressionSuite Seq("auto", "l-bfgs", "normal").foreach { solver => var idx = 0 for (fitIntercept <- Seq(false, true)) { - val model = new LinearRegression() + val model1 = new LinearRegression() .setFitIntercept(fitIntercept) .setWeightCol("weight") .setSolver(solver) .fit(datasetWithWeightConstantLabel) - val actual = Vectors.dense(model.intercept, model.coefficients(0), model.coefficients(1)) - assert(actual ~== expected(idx) absTol 1e-4) + val actual1 = Vectors.dense(model1.intercept, model1.coefficients(0), + model1.coefficients(1)) + assert(actual1 ~== expected(idx) absTol 1e-4) + + val model2 = new LinearRegression() + .setFitIntercept(fitIntercept) + .setWeightCol("weight") + .setSolver(solver) + .fit(datasetWithWeightZeroLabel) + val actual2 = Vectors.dense(model2.intercept, model2.coefficients(0), + model2.coefficients(1)) + assert(actual2 ~== Vectors.dense(0.0, 0.0, 0.0) absTol 1e-4) idx += 1 } } } test("regularized linear regression through origin with constant label") { - // The problem is ill-defined if fitIntercept=false, regParam is non-zero and - // standardization=true. An exception is thrown in this case. + // The problem is ill-defined if fitIntercept=false, regParam is non-zero. + // An exception is thrown in this case. Seq("auto", "l-bfgs", "normal").foreach { solver => for (standardization <- Seq(false, true)) { val model = new LinearRegression().setFitIntercept(false) @@ -618,6 +636,33 @@ class LinearRegressionSuite } } + test("linear regression with l-bfgs when training is not needed") { + // When label is constant, l-bfgs solver returns results without training. + // There are two possibilities: If the label is non-zero but constant, + // and fitIntercept is true, then the model return yMean as intercept without training. + // If label is all zeros, then all coefficients are zero regardless of fitIntercept, so + // no training is needed. + for (fitIntercept <- Seq(false, true)) { + for (standardization <- Seq(false, true)) { + val model1 = new LinearRegression() + .setFitIntercept(fitIntercept) + .setStandardization(standardization) + .setWeightCol("weight") + .setSolver("l-bfgs") + .fit(datasetWithWeightConstantLabel) + if (fitIntercept) { + assert(model1.summary.objectiveHistory(0) ~== 0.0 absTol 1e-4) + } + val model2 = new LinearRegression() + .setFitIntercept(fitIntercept) + .setWeightCol("weight") + .setSolver("l-bfgs") + .fit(datasetWithWeightZeroLabel) + assert(model2.summary.objectiveHistory(0) ~== 0.0 absTol 1e-4) + } + } + } + test("linear regression model training summary") { Seq("auto", "l-bfgs", "normal").foreach { solver => val trainer = new LinearRegression().setSolver(solver)