diff --git a/mllib/src/main/scala/org/apache/spark/mllib/regression/GeneralizedLinearAlgorithm.scala b/mllib/src/main/scala/org/apache/spark/mllib/regression/GeneralizedLinearAlgorithm.scala index e4e710a726308..643ce63f62eeb 100644 --- a/mllib/src/main/scala/org/apache/spark/mllib/regression/GeneralizedLinearAlgorithm.scala +++ b/mllib/src/main/scala/org/apache/spark/mllib/regression/GeneralizedLinearAlgorithm.scala @@ -42,7 +42,7 @@ abstract class GeneralizedLinearModel(val weights: Vector, val intercept: Double * @param weightMatrix Column vector containing the weights of the model * @param intercept Intercept of the model. */ - def predictPoint(dataMatrix: Vector, weightMatrix: Vector, intercept: Double): Double + protected def predictPoint(dataMatrix: Vector, weightMatrix: Vector, intercept: Double): Double /** * Predict values for the given data set using the model trained. @@ -116,6 +116,7 @@ abstract class GeneralizedLinearAlgorithm[M <: GeneralizedLinearModel] run(input, initialWeights) } + /** Prepends one to the input vector. */ private def prependOne(vector: Vector): Vector = { val vectorWithIntercept = vector match { case dv: BDV[Double] => BDV.vertcat(BDV.ones(1), dv) @@ -154,8 +155,6 @@ abstract class GeneralizedLinearAlgorithm[M <: GeneralizedLinearModel] val intercept = if (addIntercept) brzWeightsWithIntercept(0) else 0.0 val brzWeights = if (addIntercept) brzWeightsWithIntercept(1 to -1) else brzWeightsWithIntercept - val model = createModel(Vectors.fromBreeze(brzWeights), intercept) - - model + createModel(Vectors.fromBreeze(brzWeights), intercept) } } diff --git a/mllib/src/main/scala/org/apache/spark/mllib/regression/Lasso.scala b/mllib/src/main/scala/org/apache/spark/mllib/regression/Lasso.scala index fb2bc9b92a51c..e397a573079e5 100644 --- a/mllib/src/main/scala/org/apache/spark/mllib/regression/Lasso.scala +++ b/mllib/src/main/scala/org/apache/spark/mllib/regression/Lasso.scala @@ -36,8 +36,10 @@ class LassoModel( extends GeneralizedLinearModel(weights, intercept) with RegressionModel with Serializable { - override def predictPoint(dataMatrix: DoubleMatrix, weightMatrix: DoubleMatrix, - intercept: Double) = { + override protected def predictPoint( + dataMatrix: DoubleMatrix, + weightMatrix: DoubleMatrix, + intercept: Double): Double = { dataMatrix.dot(weightMatrix) + intercept } } @@ -66,7 +68,7 @@ class LassoWithSGD private ( .setMiniBatchFraction(miniBatchFraction) // We don't want to penalize the intercept, so set this to false. - setIntercept(false) + super.setIntercept(false) var yMean = 0.0 var xColMean: DoubleMatrix = _ @@ -77,10 +79,16 @@ class LassoWithSGD private ( */ def this() = this(1.0, 100, 1.0, 1.0) - def createModel(weights: Array[Double], intercept: Double) = { - val weightsMat = new DoubleMatrix(weights.length + 1, 1, (Array(intercept) ++ weights):_*) + override def setIntercept(addIntercept: Boolean): this.type = { + // TODO: Support adding intercept. + if (addIntercept) throw new UnsupportedOperationException("Adding intercept is not supported.") + this + } + + override protected def createModel(weights: Array[Double], intercept: Double) = { + val weightsMat = new DoubleMatrix(weights.length, 1, weights: _*) val weightsScaled = weightsMat.div(xColSd) - val interceptScaled = yMean - (weightsMat.transpose().mmul(xColMean.div(xColSd)).get(0)) + val interceptScaled = yMean - weightsMat.transpose().mmul(xColMean.div(xColSd)).get(0) new LassoModel(weightsScaled.data, interceptScaled) } diff --git a/mllib/src/main/scala/org/apache/spark/mllib/regression/LinearRegression.scala b/mllib/src/main/scala/org/apache/spark/mllib/regression/LinearRegression.scala index 8ee40addb25d9..b4aafbe8bcaff 100644 --- a/mllib/src/main/scala/org/apache/spark/mllib/regression/LinearRegression.scala +++ b/mllib/src/main/scala/org/apache/spark/mllib/regression/LinearRegression.scala @@ -31,13 +31,14 @@ import org.jblas.DoubleMatrix * @param intercept Intercept computed for this model. */ class LinearRegressionModel( - override val weights: Array[Double], - override val intercept: Double) - extends GeneralizedLinearModel(weights, intercept) - with RegressionModel with Serializable { - - override def predictPoint(dataMatrix: DoubleMatrix, weightMatrix: DoubleMatrix, - intercept: Double) = { + override val weights: Array[Double], + override val intercept: Double) + extends GeneralizedLinearModel(weights, intercept) with RegressionModel with Serializable { + + override protected def predictPoint( + dataMatrix: DoubleMatrix, + weightMatrix: DoubleMatrix, + intercept: Double): Double = { dataMatrix.dot(weightMatrix) + intercept } } @@ -55,8 +56,7 @@ class LinearRegressionWithSGD private ( var stepSize: Double, var numIterations: Int, var miniBatchFraction: Double) - extends GeneralizedLinearAlgorithm[LinearRegressionModel] - with Serializable { + extends GeneralizedLinearAlgorithm[LinearRegressionModel] with Serializable { val gradient = new LeastSquaresGradient() val updater = new SimpleUpdater() @@ -69,7 +69,7 @@ class LinearRegressionWithSGD private ( */ def this() = this(1.0, 100, 1.0) - def createModel(weights: Array[Double], intercept: Double) = { + override protected def createModel(weights: Array[Double], intercept: Double) = { new LinearRegressionModel(weights, intercept) } } diff --git a/mllib/src/main/scala/org/apache/spark/mllib/regression/RidgeRegression.scala b/mllib/src/main/scala/org/apache/spark/mllib/regression/RidgeRegression.scala index b8ce4602b53ef..d5371f5c33414 100644 --- a/mllib/src/main/scala/org/apache/spark/mllib/regression/RidgeRegression.scala +++ b/mllib/src/main/scala/org/apache/spark/mllib/regression/RidgeRegression.scala @@ -35,8 +35,10 @@ class RidgeRegressionModel( extends GeneralizedLinearModel(weights, intercept) with RegressionModel with Serializable { - override def predictPoint(dataMatrix: DoubleMatrix, weightMatrix: DoubleMatrix, - intercept: Double) = { + override protected def predictPoint( + dataMatrix: DoubleMatrix, + weightMatrix: DoubleMatrix, + intercept: Double): Double = { dataMatrix.dot(weightMatrix) + intercept } } @@ -66,7 +68,7 @@ class RidgeRegressionWithSGD private ( .setMiniBatchFraction(miniBatchFraction) // We don't want to penalize the intercept in RidgeRegression, so set this to false. - setIntercept(false) + super.setIntercept(false) var yMean = 0.0 var xColMean: DoubleMatrix = _ @@ -77,8 +79,14 @@ class RidgeRegressionWithSGD private ( */ def this() = this(1.0, 100, 1.0, 1.0) - def createModel(weights: Array[Double], intercept: Double) = { - val weightsMat = new DoubleMatrix(weights.length + 1, 1, (Array(intercept) ++ weights):_*) + override def setIntercept(addIntercept: Boolean): this.type = { + // TODO: Support adding intercept. + if (addIntercept) throw new UnsupportedOperationException("Adding intercept is not supported.") + this + } + + override protected def createModel(weights: Array[Double], intercept: Double) = { + val weightsMat = new DoubleMatrix(weights.length, 1, weights: _*) val weightsScaled = weightsMat.div(xColSd) val interceptScaled = yMean - weightsMat.transpose().mmul(xColMean.div(xColSd)).get(0) diff --git a/mllib/src/test/scala/org/apache/spark/mllib/regression/LinearRegressionSuite.scala b/mllib/src/test/scala/org/apache/spark/mllib/regression/LinearRegressionSuite.scala index 281f9df36ddb3..5d251bcbf35db 100644 --- a/mllib/src/test/scala/org/apache/spark/mllib/regression/LinearRegressionSuite.scala +++ b/mllib/src/test/scala/org/apache/spark/mllib/regression/LinearRegressionSuite.scala @@ -17,7 +17,6 @@ package org.apache.spark.mllib.regression -import org.scalatest.BeforeAndAfterAll import org.scalatest.FunSuite import org.apache.spark.mllib.util.{LinearDataGenerator, LocalSparkContext} @@ -57,4 +56,29 @@ class LinearRegressionSuite extends FunSuite with LocalSparkContext { // Test prediction on Array. validatePrediction(validationData.map(row => model.predict(row.features)), validationData) } + + // Test if we can correctly learn Y = 10*X1 + 10*X2 + test("linear regression without intercept") { + val testRDD = sc.parallelize(LinearDataGenerator.generateLinearInput( + 0.0, Array(10.0, 10.0), 100, 42), 2).cache() + val linReg = new LinearRegressionWithSGD().setIntercept(false) + linReg.optimizer.setNumIterations(1000).setStepSize(1.0) + + val model = linReg.run(testRDD) + + assert(model.intercept === 0.0) + assert(model.weights.length === 2) + assert(model.weights(0) >= 9.0 && model.weights(0) <= 11.0) + assert(model.weights(1) >= 9.0 && model.weights(1) <= 11.0) + + val validationData = LinearDataGenerator.generateLinearInput( + 0.0, Array(10.0, 10.0), 100, 17) + val validationRDD = sc.parallelize(validationData, 2).cache() + + // Test prediction on RDD. + validatePrediction(model.predict(validationRDD.map(_.features)).collect(), validationData) + + // Test prediction on Array. + validatePrediction(validationData.map(row => model.predict(row.features)), validationData) + } }