Skip to content

Commit

Permalink
[SPARK-1327] GLM needs to check addIntercept for intercept and weights
Browse files Browse the repository at this point in the history
GLM needs to check addIntercept for intercept and weights. The current implementation always uses the first weight as intercept. Added a test for training without adding intercept.

JIRA: https://spark-project.atlassian.net/browse/SPARK-1327

Author: Xiangrui Meng <[email protected]>

Closes #236 from mengxr/glm and squashes the following commits:

bcac1ac [Xiangrui Meng] add two tests to ensure {Lasso, Ridge}.setIntercept will throw an exception
a104072 [Xiangrui Meng] remove protected to be compatible with 0.9
0e57aa4 [Xiangrui Meng] update Lasso and RidgeRegression to parse the weights correctly from GLM; mark createModel protected; mark predictPoint protected
d7f629f [Xiangrui Meng] fix a bug in GLM when intercept is not used
  • Loading branch information
mengxr authored and tdas committed Mar 27, 2014
1 parent 1fa48d9 commit d679843
Show file tree
Hide file tree
Showing 7 changed files with 86 additions and 37 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -136,25 +136,28 @@ abstract class GeneralizedLinearAlgorithm[M <: GeneralizedLinearModel]

// Prepend an extra variable consisting of all 1.0's for the intercept.
val data = if (addIntercept) {
input.map(labeledPoint => (labeledPoint.label, labeledPoint.features.+:(1.0)))
input.map(labeledPoint => (labeledPoint.label, 1.0 +: labeledPoint.features))
} else {
input.map(labeledPoint => (labeledPoint.label, labeledPoint.features))
}

val initialWeightsWithIntercept = if (addIntercept) {
initialWeights.+:(1.0)
0.0 +: initialWeights
} else {
initialWeights
}

val weights = optimizer.optimize(data, initialWeightsWithIntercept)
val intercept = weights(0)
val weightsScaled = weights.tail
val weightsWithIntercept = optimizer.optimize(data, initialWeightsWithIntercept)

val model = createModel(weightsScaled, intercept)
val (intercept, weights) = if (addIntercept) {
(weightsWithIntercept(0), weightsWithIntercept.tail)
} else {
(0.0, weightsWithIntercept)
}

logInfo("Final weights " + weights.mkString(","))
logInfo("Final intercept " + intercept)

logInfo("Final model weights " + model.weights.mkString(","))
logInfo("Final model intercept " + model.intercept)
model
createModel(weights, intercept)
}
}
20 changes: 14 additions & 6 deletions mllib/src/main/scala/org/apache/spark/mllib/regression/Lasso.scala
Original file line number Diff line number Diff line change
Expand Up @@ -36,8 +36,10 @@ class LassoModel(
extends GeneralizedLinearModel(weights, intercept)
with RegressionModel with Serializable {

override def predictPoint(dataMatrix: DoubleMatrix, weightMatrix: DoubleMatrix,
intercept: Double) = {
override def predictPoint(
dataMatrix: DoubleMatrix,
weightMatrix: DoubleMatrix,
intercept: Double): Double = {
dataMatrix.dot(weightMatrix) + intercept
}
}
Expand Down Expand Up @@ -66,7 +68,7 @@ class LassoWithSGD private (
.setMiniBatchFraction(miniBatchFraction)

// We don't want to penalize the intercept, so set this to false.
setIntercept(false)
super.setIntercept(false)

var yMean = 0.0
var xColMean: DoubleMatrix = _
Expand All @@ -77,10 +79,16 @@ class LassoWithSGD private (
*/
def this() = this(1.0, 100, 1.0, 1.0)

def createModel(weights: Array[Double], intercept: Double) = {
val weightsMat = new DoubleMatrix(weights.length + 1, 1, (Array(intercept) ++ weights):_*)
/**
 * Guards the intercept setting for this algorithm: a request to enable the
 * intercept is rejected, and a request to leave it disabled is a no-op.
 *
 * @param addIntercept whether an intercept term is being requested
 * @return this instance, when `addIntercept` is false
 * @throws UnsupportedOperationException when `addIntercept` is true
 */
override def setIntercept(addIntercept: Boolean): this.type = {
  // TODO: Support adding intercept.
  if (!addIntercept) {
    // Nothing to change; hand back the receiver so calls can be chained.
    this
  } else {
    throw new UnsupportedOperationException("Adding intercept is not supported.")
  }
}

override def createModel(weights: Array[Double], intercept: Double) = {
val weightsMat = new DoubleMatrix(weights.length, 1, weights: _*)
val weightsScaled = weightsMat.div(xColSd)
val interceptScaled = yMean - (weightsMat.transpose().mmul(xColMean.div(xColSd)).get(0))
val interceptScaled = yMean - weightsMat.transpose().mmul(xColMean.div(xColSd)).get(0)

new LassoModel(weightsScaled.data, interceptScaled)
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -31,13 +31,14 @@ import org.jblas.DoubleMatrix
* @param intercept Intercept computed for this model.
*/
class LinearRegressionModel(
override val weights: Array[Double],
override val intercept: Double)
extends GeneralizedLinearModel(weights, intercept)
with RegressionModel with Serializable {

override def predictPoint(dataMatrix: DoubleMatrix, weightMatrix: DoubleMatrix,
intercept: Double) = {
override val weights: Array[Double],
override val intercept: Double)
extends GeneralizedLinearModel(weights, intercept) with RegressionModel with Serializable {

override def predictPoint(
dataMatrix: DoubleMatrix,
weightMatrix: DoubleMatrix,
intercept: Double): Double = {
dataMatrix.dot(weightMatrix) + intercept
}
}
Expand All @@ -55,8 +56,7 @@ class LinearRegressionWithSGD private (
var stepSize: Double,
var numIterations: Int,
var miniBatchFraction: Double)
extends GeneralizedLinearAlgorithm[LinearRegressionModel]
with Serializable {
extends GeneralizedLinearAlgorithm[LinearRegressionModel] with Serializable {

val gradient = new LeastSquaresGradient()
val updater = new SimpleUpdater()
Expand All @@ -69,7 +69,7 @@ class LinearRegressionWithSGD private (
*/
def this() = this(1.0, 100, 1.0)

def createModel(weights: Array[Double], intercept: Double) = {
override def createModel(weights: Array[Double], intercept: Double) = {
new LinearRegressionModel(weights, intercept)
}
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -36,8 +36,10 @@ class RidgeRegressionModel(
extends GeneralizedLinearModel(weights, intercept)
with RegressionModel with Serializable {

override def predictPoint(dataMatrix: DoubleMatrix, weightMatrix: DoubleMatrix,
intercept: Double) = {
override def predictPoint(
dataMatrix: DoubleMatrix,
weightMatrix: DoubleMatrix,
intercept: Double): Double = {
dataMatrix.dot(weightMatrix) + intercept
}
}
Expand Down Expand Up @@ -67,7 +69,7 @@ class RidgeRegressionWithSGD private (
.setMiniBatchFraction(miniBatchFraction)

// We don't want to penalize the intercept in RidgeRegression, so set this to false.
setIntercept(false)
super.setIntercept(false)

var yMean = 0.0
var xColMean: DoubleMatrix = _
Expand All @@ -78,8 +80,14 @@ class RidgeRegressionWithSGD private (
*/
def this() = this(1.0, 100, 1.0, 1.0)

def createModel(weights: Array[Double], intercept: Double) = {
val weightsMat = new DoubleMatrix(weights.length + 1, 1, (Array(intercept) ++ weights):_*)
/**
 * Guards the intercept setting for this algorithm: a request to enable the
 * intercept is rejected, and a request to leave it disabled is a no-op.
 *
 * @param addIntercept whether an intercept term is being requested
 * @return this instance, when `addIntercept` is false
 * @throws UnsupportedOperationException when `addIntercept` is true
 */
override def setIntercept(addIntercept: Boolean): this.type = {
  // TODO: Support adding intercept.
  if (!addIntercept) {
    // Nothing to change; hand back the receiver so calls can be chained.
    this
  } else {
    throw new UnsupportedOperationException("Adding intercept is not supported.")
  }
}

override def createModel(weights: Array[Double], intercept: Double) = {
val weightsMat = new DoubleMatrix(weights.length, 1, weights: _*)
val weightsScaled = weightsMat.div(xColSd)
val interceptScaled = yMean - weightsMat.transpose().mmul(xColMean.div(xColSd)).get(0)

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -17,11 +17,8 @@

package org.apache.spark.mllib.regression


import org.scalatest.BeforeAndAfterAll
import org.scalatest.FunSuite

import org.apache.spark.SparkContext
import org.apache.spark.mllib.util.{LinearDataGenerator, LocalSparkContext}

class LassoSuite extends FunSuite with LocalSparkContext {
Expand Down Expand Up @@ -104,4 +101,10 @@ class LassoSuite extends FunSuite with LocalSparkContext {
// Test prediction on Array.
validatePrediction(validationData.map(row => model.predict(row.features)), validationData)
}

// Requesting an intercept on LassoWithSGD must be rejected: the setIntercept(true)
// call is wrapped in an intercept block expecting UnsupportedOperationException.
test("do not support intercept") {
intercept[UnsupportedOperationException] {
new LassoWithSGD().setIntercept(true)
}
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,6 @@

package org.apache.spark.mllib.regression

import org.scalatest.BeforeAndAfterAll
import org.scalatest.FunSuite

import org.apache.spark.mllib.util.{LinearDataGenerator, LocalSparkContext}
Expand Down Expand Up @@ -57,4 +56,29 @@ class LinearRegressionSuite extends FunSuite with LocalSparkContext {
// Test prediction on Array.
validatePrediction(validationData.map(row => model.predict(row.features)), validationData)
}

// Test if we can correctly learn Y = 10*X1 + 10*X2
// Trains with the intercept disabled and checks that the model reports a zero
// intercept and keeps both feature weights.
test("linear regression without intercept") {
// Training data is generated with intercept 0.0 and weights (10.0, 10.0);
// presumably 100 points with seed 42 -- confirm against LinearDataGenerator.
val testRDD = sc.parallelize(LinearDataGenerator.generateLinearInput(
0.0, Array(10.0, 10.0), 100, 42), 2).cache()
// Disable the intercept term before training.
val linReg = new LinearRegressionWithSGD().setIntercept(false)
linReg.optimizer.setNumIterations(1000).setStepSize(1.0)

val model = linReg.run(testRDD)

// With the intercept disabled, the reported intercept must be exactly 0.0 and
// no weight may be consumed as an intercept: both feature weights remain.
assert(model.intercept === 0.0)
assert(model.weights.length === 2)
// Each learned weight must land within +/- 1.0 of the true value 10.0.
assert(model.weights(0) >= 9.0 && model.weights(0) <= 11.0)
assert(model.weights(1) >= 9.0 && model.weights(1) <= 11.0)

// Validation data uses the same intercept and weights but a different last
// argument (17 instead of 42) -- presumably a different random seed; confirm.
val validationData = LinearDataGenerator.generateLinearInput(
0.0, Array(10.0, 10.0), 100, 17)
val validationRDD = sc.parallelize(validationData, 2).cache()

// validatePrediction is defined elsewhere in this suite; it is passed the
// predictions together with the labeled validation points.
// Test prediction on RDD.
validatePrediction(model.predict(validationRDD.map(_.features)).collect(), validationData)

// Test prediction on Array.
validatePrediction(validationData.map(row => model.predict(row.features)), validationData)
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -17,14 +17,11 @@

package org.apache.spark.mllib.regression


import org.jblas.DoubleMatrix
import org.scalatest.BeforeAndAfterAll
import org.scalatest.FunSuite

import org.apache.spark.mllib.util.{LinearDataGenerator, LocalSparkContext}


class RidgeRegressionSuite extends FunSuite with LocalSparkContext {

def predictionError(predictions: Seq[Double], input: Seq[LabeledPoint]) = {
Expand Down Expand Up @@ -74,4 +71,10 @@ class RidgeRegressionSuite extends FunSuite with LocalSparkContext {
assert(ridgeErr < linearErr,
"ridgeError (" + ridgeErr + ") was not less than linearError(" + linearErr + ")")
}

// Requesting an intercept on RidgeRegressionWithSGD must be rejected: the
// setIntercept(true) call is wrapped in an intercept block expecting
// UnsupportedOperationException.
test("do not support intercept") {
intercept[UnsupportedOperationException] {
new RidgeRegressionWithSGD().setIntercept(true)
}
}
}

0 comments on commit d679843

Please sign in to comment.