[SPARK-4980] [MLlib] Add decay factors to streaming linear methods #8022
@@ -0,0 +1,112 @@
/*
 * Licensed to the Apache Software Foundation (ASF) under one or more
 * contributor license agreements. See the NOTICE file distributed with
 * this work for additional information regarding copyright ownership.
 * The ASF licenses this file to You under the Apache License, Version 2.0
 * (the "License"); you may not use this file except in compliance with
 * the License. You may obtain a copy of the License at
 *
 *    http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */

package org.apache.spark.mllib.regression

import org.apache.spark.Logging
import org.apache.spark.annotation.{Since, Experimental}
import org.apache.spark.mllib.regression.StreamingDecay.{BATCHES, POINTS}
/**
 * :: Experimental ::
 * Supplies the discount value for the forgetful update rule
 * in [[StreamingLinearAlgorithm]]. The degree of forgetfulness
 * can be specified by the decay factor or the half life.
 */
@Experimental
private[mllib] trait StreamingDecay extends Logging {
  private var decayFactor: Double = 0
  private var timeUnit: String = StreamingDecay.BATCHES

  /**
   * Set the decay factor for the forgetful algorithm.
   * The decay factor specifies the decay of
   * the contribution of data from time t-1 to time t.
   * Valid decayFactor ranges from 0 to 1, inclusive.
   * decayFactor = 0: previous data have no contribution to the model at the next time unit.
   * decayFactor = 1: previous data have equal contribution to the model as the data arriving
   * at the next time unit.
   * decayFactor defaults to 0.
   * @param decayFactor the decay factor
   */
  @Since("1.6.0")
  def setDecayFactor(decayFactor: Double): this.type = {
    this.decayFactor = decayFactor
    this
  }
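
For intuition, a stand-alone sketch (not part of the patch): with the BATCHES time unit, a batch seen k batches ago contributes with relative weight decayFactor^k, which follows from applying the discount once per elapsed time unit.

```scala
val decayFactor = 0.5
// Relative contribution of a batch that arrived k batches ago:
(0 to 3).foreach { k =>
  println(f"k=$k: weight = ${math.pow(decayFactor, k)}%.3f")
}
// k=0: weight = 1.000
// k=1: weight = 0.500
// k=2: weight = 0.250
// k=3: weight = 0.125
```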

Review comment: Extra newline.
Reply: fixed.
  /**
   * Set the half life for the forgetful algorithm.
   * The half life provides an alternative way to specify the decay factor.
   * The decay factor is calculated such that, for data acquired at time t,
   * its contribution by time t + halfLife will have dropped to half.
   * Only halfLife > 0 is valid.
   * @param halfLife the half life
   */
  @Since("1.6.0")
  def setHalfLife(halfLife: Double): this.type = {
    this.decayFactor = math.exp(math.log(0.5) / halfLife)
    logInfo("Setting decay factor to: %g".format(this.decayFactor))
    this
  }
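
A quick stand-alone check of the conversion above: since decayFactor = exp(log(0.5) / halfLife), raising it to the halfLife power yields exactly 0.5.

```scala
val halfLife = 10.0
val decayFactor = math.exp(math.log(0.5) / halfLife)
// After halfLife time units, a data point's contribution has halved:
assert(math.abs(math.pow(decayFactor, halfLife) - 0.5) < 1e-12)
```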
  /**
   * Set the time unit for the forgetful algorithm.
   * BATCHES: Each RDD in the DStream will be treated as 1 time unit.
   * POINTS: Each data point will be treated as 1 time unit.
   * timeUnit defaults to BATCHES.
   * @param timeUnit the time unit
   */
  @Since("1.6.0")
  def setTimeUnit(timeUnit: String): this.type = {
    if (timeUnit != StreamingDecay.BATCHES && timeUnit != StreamingDecay.POINTS) {
      throw new IllegalArgumentException("Invalid time unit for decay: " + timeUnit)
    }
    this.timeUnit = timeUnit
    this
  }

  /**
   * Derive the discount factor.
   * @param numNewDataPoints number of data points in the RDD arriving at time t.
   * @return the discount factor
   */
  private[mllib] def getDiscount(numNewDataPoints: Long): Double = timeUnit match {
    case BATCHES => decayFactor
    case POINTS => math.pow(decayFactor, numNewDataPoints)
  }
}
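
To make the two time-unit modes concrete, a stand-alone sketch mirroring getDiscount with illustrative numbers:

```scala
val decayFactor = 0.9
val numNewDataPoints = 100L

// BATCHES: the whole RDD counts as one time unit, so one discount step is applied.
val batchDiscount = decayFactor                              // 0.9
// POINTS: each data point counts as one time unit, so the discount compounds.
val pointDiscount = math.pow(decayFactor, numNewDataPoints)  // 0.9^100 ≈ 2.66e-5
```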

/**
 * :: Experimental ::
 * Provides the String constants for the allowed time units in the forgetful algorithm.
 */
@Experimental
@Since("1.6.0")
object StreamingDecay {
  /**
   * Each RDD in the DStream will be treated as 1 time unit.
   */
  @Since("1.6.0")
  final val BATCHES = "BATCHES"
  /**
   * Each data point will be treated as 1 time unit.
   */
  @Since("1.6.0")
  final val POINTS = "POINTS"
}

@@ -22,7 +22,7 @@ import scala.reflect.ClassTag
 import org.apache.spark.Logging
 import org.apache.spark.annotation.{DeveloperApi, Since}
 import org.apache.spark.api.java.JavaSparkContext.fakeClassTag
-import org.apache.spark.mllib.linalg.Vector
+import org.apache.spark.mllib.linalg.{BLAS, Vector}
 import org.apache.spark.streaming.api.java.{JavaDStream, JavaPairDStream}
 import org.apache.spark.streaming.dstream.DStream
@@ -39,6 +39,23 @@ import org.apache.spark.streaming.dstream.DStream
 * Only weights will be updated, not an intercept. If the model needs
 * an intercept, it should be manually appended to the input data.
 *
 * StreamingLinearAlgorithm uses the forgetful algorithm
 * to dynamically adjust for evolution of the data source. For each batch of data,
 * the model estimates are updated by:
 *
 * $$ \theta_{t+1} = \frac{\theta_t n_t \alpha + \beta_t m_t}{n_t \alpha + m_t} $$
 * $$ n_{t+1} = n_t \alpha + m_t $$
 *
 * where $\theta_t$ is the model estimate before the data arriving at time t;
 * $n_t$ is the cumulative contribution of data arriving before time t;
 * $\beta_t$ is the estimate using the data arriving at time t alone;
 * $m_t$ is the number of data points in the data arriving at time t alone;
 * $\alpha$ is the discount factor: with $\alpha=0$ only the data from the
 * most recent RDD will be used, and with $\alpha=1$ all data since the beginning
 * of the DStream will be used with equal contributions.
 *
 * This updating rule is analogous to an exponentially-weighted moving average.
 *
 * For example usage, see `StreamingLinearRegressionWithSGD`.
 *
 * NOTE: In some use cases, the order in which trainOn and predictOn
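
A worked numeric instance of the update rule above (stand-alone scalar sketch; the actual implementation operates on weight vectors):

```scala
val alpha = 0.5   // discount factor
var theta = 1.0   // model estimate theta_t
var n     = 10.0  // cumulative data weight n_t
val beta  = 3.0   // estimate from the new batch alone, beta_t
val m     = 5.0   // number of points in the new batch, m_t

// theta_{t+1} = (theta_t * n_t * alpha + beta_t * m_t) / (n_t * alpha + m_t)
theta = (theta * n * alpha + beta * m) / (n * alpha + m)  // = 2.0
n = n * alpha + m                                         // = 10.0
```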
@@ -59,11 +76,15 @@ import org.apache.spark.streaming.dstream.DStream
 @DeveloperApi
 abstract class StreamingLinearAlgorithm[
     M <: GeneralizedLinearModel,
-    A <: GeneralizedLinearAlgorithm[M]] extends Logging {
+    A <: GeneralizedLinearAlgorithm[M]]
+  extends StreamingDecay with Logging {

   /** The model to be updated and used for prediction. */
   protected var model: Option[M]

   /** The cumulative weight of the data arriving before the current time unit. */
   protected var previousDataWeight: Double = 0

   /** The algorithm to use for updating. */
   protected val algorithm: A
@@ -91,7 +112,22 @@ abstract class StreamingLinearAlgorithm[
   }
   data.foreachRDD { (rdd, time) =>
     if (!rdd.isEmpty) {
-      model = Some(algorithm.run(rdd, model.get.weights))
+      val newModel = algorithm.run(rdd, model.get.weights)
+
+      val numNewDataPoints = rdd.count()
+      val discount = getDiscount(numNewDataPoints)
+
+      val updatedDataWeight = previousDataWeight * discount + numNewDataPoints
+      // updatedDataWeight >= 1 because rdd is not empty,
+      // so there is no need to check for division by zero below.
+      val lambda = numNewDataPoints / updatedDataWeight
+
+      BLAS.scal(lambda, newModel.weights)
+      BLAS.axpy(1 - lambda, model.get.weights, newModel.weights)

Review comment: Do we have some references about this merging scheme? I assume that this works for many cases, but there is no guarantee in theory.

+
+      previousDataWeight = updatedDataWeight
+      model = Some(newModel)

       logInfo(s"Model updated at time ${time.toString}")
       val display = model.get.weights.size match {
         case x if x > 100 => model.get.weights.toArray.take(100).mkString("[", ",", "...")
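
The scal/axpy pair computes the convex combination newWeights = lambda * beta + (1 - lambda) * theta, which is algebraically the same update as the rule in the class doc. A plain-array sketch of the same merge (Spark's mllib BLAS helper is internal to Spark, so this uses ordinary arrays):

```scala
// merged = lambda * newW + (1 - lambda) * oldW, elementwise.
def merge(oldW: Array[Double], newW: Array[Double], lambda: Double): Array[Double] =
  oldW.zip(newW).map { case (o, nw) => (1 - lambda) * o + lambda * nw }

// With alpha = 0.5, n_t = 10, m_t = 5: lambda = 5 / (10 * 0.5 + 5) = 0.5,
// and merge(Array(1.0), Array(3.0), 0.5) gives Array(2.0), matching the
// worked example in the class doc.
```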

@@ -30,6 +30,11 @@ import org.apache.spark.mllib.linalg.Vector
 * of features must be constant. An initial weight
 * vector must be provided.
 *
 * This class inherits the forgetful algorithm from [[StreamingLinearAlgorithm]]
 * to handle evolution of the data source. Users can specify the degree of forgetfulness
 * by the decay factor or the half-life. Refer to [[StreamingLinearAlgorithm]] for
 * more details.
 *
 * Use a builder pattern to construct a streaming linear regression
 * analysis in an application, like:
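
A minimal sketch of such a builder-pattern setup, combining the existing SGD setters with the new decay setters (trainingData and testData are assumed to be DStreams defined elsewhere):

```scala
import org.apache.spark.mllib.linalg.Vectors
import org.apache.spark.mllib.regression.{StreamingDecay, StreamingLinearRegressionWithSGD}

val model = new StreamingLinearRegressionWithSGD()
  .setStepSize(0.5)
  .setNumIterations(10)
  .setInitialWeights(Vectors.dense(0.0, 0.0))
  .setDecayFactor(0.9)                    // forget old batches geometrically
  .setTimeUnit(StreamingDecay.BATCHES)

model.trainOn(trainingData)               // trainingData: DStream[LabeledPoint]
model.predictOn(testData).print()         // testData: DStream[Vector]
```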
@@ -105,4 +110,19 @@ class StreamingLinearRegressionWithSGD private[mllib] (
     this.algorithm.optimizer.setConvergenceTol(tolerance)
     this
   }

   override def setDecayFactor(decayFactor: Double): this.type = {

Review comment: Ditto on duplication.
Reply: fixed.
Review comment: Can remove these setters as well, since the trait will return the concrete subtype.

     super.setDecayFactor(decayFactor)
     this
   }

   override def setHalfLife(halfLife: Double): this.type = {
     super.setHalfLife(halfLife)
     this
   }

   override def setTimeUnit(timeUnit: String): this.type = {
     super.setTimeUnit(timeUnit)
     this
   }
 }

Review comment: This boilerplate is duplicated in streaming linear regression. I am guessing you do this to get the concrete subclass (correct me if I'm wrong), but you actually don't need to do this, since the `this.type` in `trait StreamingDecay` takes care of it. A simple REPL example:
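
A sketch of the kind of REPL session being referred to (illustrative, not the reviewer's original example; the object hash in the output is made up):

```scala
scala> trait Decay { def setDecayFactor(d: Double): this.type = this }
defined trait Decay

scala> class Model extends Decay
defined class Model

scala> new Model().setDecayFactor(0.5)
res0: Model = Model@6f75e721  // this.type already yields the concrete subtype Model
```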
Reply: fixed.
Review comment: Oh, I meant that you could remove these setters entirely.