diff --git a/mllib/src/main/scala/org/apache/spark/mllib/classification/StreamingLogisticRegressionWithSGD.scala b/mllib/src/main/scala/org/apache/spark/mllib/classification/StreamingLogisticRegressionWithSGD.scala
index 47bff5ebdde47..a13e4143e5c4f 100644
--- a/mllib/src/main/scala/org/apache/spark/mllib/classification/StreamingLogisticRegressionWithSGD.scala
+++ b/mllib/src/main/scala/org/apache/spark/mllib/classification/StreamingLogisticRegressionWithSGD.scala
@@ -31,6 +31,11 @@ import org.apache.spark.mllib.regression.StreamingLinearAlgorithm
  * of features must be constant. An initial weight
  * vector must be provided.
  *
+ * This class inherits the forgetful algorithm from [[StreamingLinearAlgorithm]]
+ * to handle evolving data sources. Users can specify the degree of forgetfulness
+ * by the decay factor or the half-life. Refer to [[StreamingLinearAlgorithm]] for
+ * more details.
+ *
  * Use a builder pattern to construct a streaming logistic regression
  * analysis in an application, like:
  *
@@ -99,4 +104,19 @@ class StreamingLogisticRegressionWithSGD private[mllib] (
     this.model = Some(algorithm.createModel(initialWeights, 0.0))
     this
   }
+
+  override def setDecayFactor(decayFactor: Double): this.type = {
+    super.setDecayFactor(decayFactor)
+    this
+  }
+
+  override def setHalfLife(halfLife: Double): this.type = {
+    super.setHalfLife(halfLife)
+    this
+  }
+
+  override def setTimeUnit(timeUnit: String): this.type = {
+    super.setTimeUnit(timeUnit)
+    this
+  }
 }
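The new setters compose with the class's existing builder pattern. A minimal usage sketch, assuming `trainingData` and `testData` are already-constructed DStreams (those names and the parameter values are illustrative, not part of the patch):

import org.apache.spark.mllib.classification.StreamingLogisticRegressionWithSGD
import org.apache.spark.mllib.linalg.Vectors
import org.apache.spark.mllib.regression.StreamingDecay

// Forget old data with a half-life of 5 batches: each batch's contribution
// to the model is halved every 5 batches after it arrives.
val model = new StreamingLogisticRegressionWithSGD()
  .setInitialWeights(Vectors.dense(0.0))
  .setHalfLife(5)
  .setTimeUnit(StreamingDecay.BATCHES)

model.trainOn(trainingData)        // trainingData: DStream[LabeledPoint], assumed defined
model.predictOn(testData).print()  // testData: DStream[Vector], assumed defined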
diff --git a/mllib/src/main/scala/org/apache/spark/mllib/regression/StreamingDecay.scala b/mllib/src/main/scala/org/apache/spark/mllib/regression/StreamingDecay.scala
new file mode 100644
index 0000000000000..643b70c3552e2
--- /dev/null
+++ b/mllib/src/main/scala/org/apache/spark/mllib/regression/StreamingDecay.scala
@@ -0,0 +1,112 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements.  See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License.  You may obtain a copy of the License at
+ *
+ *    http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.mllib.regression
+
+import org.apache.spark.Logging
+import org.apache.spark.annotation.{Experimental, Since}
+import org.apache.spark.mllib.regression.StreamingDecay.{BATCHES, POINTS}
+
+/**
+ * :: Experimental ::
+ * Supplies the discount value for the
+ * forgetful update rule in [[StreamingLinearAlgorithm]].
+ * The degree of forgetfulness can be specified by the decay factor
+ * or the half-life.
+ */
+@Experimental
+private[mllib] trait StreamingDecay extends Logging {
+  private var decayFactor: Double = 0
+  private var timeUnit: String = StreamingDecay.BATCHES
+
+  /**
+   * Set the decay factor for the forgetful algorithm.
+   * The decay factor specifies the decay of
+   * the contribution of data from time t-1 to time t.
+   * Valid decayFactor values range from 0 to 1, inclusive.
+   * decayFactor = 0: previous data make no contribution to the model at the next time unit.
+   * decayFactor = 1: previous data make the same contribution to the model as data arriving
+   * at the next time unit.
+   * decayFactor defaults to 0.
+   * @param decayFactor the decay factor
+   */
+  @Since("1.6.0")
+  def setDecayFactor(decayFactor: Double): this.type = {
+    this.decayFactor = decayFactor
+    this
+  }
+
+  /**
+   * Set the half-life for the forgetful algorithm.
+   * The half-life provides an alternative way to specify the decay factor.
+   * The decay factor is calculated such that, for data acquired at time t,
+   * its contribution by time t + halfLife will have dropped by 0.5.
+   * Only halfLife > 0 is valid.
+   * @param halfLife the half-life
+   */
+  @Since("1.6.0")
+  def setHalfLife(halfLife: Double): this.type = {
+    this.decayFactor = math.exp(math.log(0.5) / halfLife)
+    logInfo("Setting decay factor to: %g".format(this.decayFactor))
+    this
+  }
+
+  /**
+   * Set the time unit for the forgetful algorithm.
+   * BATCHES: Each RDD in the DStream will be treated as 1 time unit.
+   * POINTS: Each data point will be treated as 1 time unit.
+   * timeUnit defaults to BATCHES.
+   * @param timeUnit the time unit
+   */
+  @Since("1.6.0")
+  def setTimeUnit(timeUnit: String): this.type = {
+    if (timeUnit != StreamingDecay.BATCHES && timeUnit != StreamingDecay.POINTS) {
+      throw new IllegalArgumentException("Invalid time unit for decay: " + timeUnit)
+    }
+    this.timeUnit = timeUnit
+    this
+  }
+
+  /**
+   * Derive the discount factor.
+   * @param numNewDataPoints number of data points in the RDD arriving at time t.
+   * @return the discount factor
+   */
+  private[mllib] def getDiscount(numNewDataPoints: Long): Double = timeUnit match {
+    case BATCHES => decayFactor
+    case POINTS => math.pow(decayFactor, numNewDataPoints)
+  }
+}
+
+/**
+ * :: Experimental ::
+ * Provides the String constants for the allowed time units in the forgetful algorithm.
+ */
+@Experimental
+@Since("1.6.0")
+object StreamingDecay {
+  /**
+   * Each RDD in the DStream will be treated as 1 time unit.
+   */
+  @Since("1.6.0")
+  final val BATCHES = "BATCHES"
+  /**
+   * Each data point will be treated as 1 time unit.
+   */
+  @Since("1.6.0")
+  final val POINTS = "POINTS"
+}
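To make the discount semantics concrete, the same arithmetic can be run standalone (plain Scala; the half-life and batch-size values below are arbitrary examples, not taken from the patch):

// Half-life to decay factor: a contribution halves every `halfLife` time units.
val halfLife = 4.0
val decayFactor = math.exp(math.log(0.5) / halfLife)  // 0.5^(1/4), about 0.8409

// With BATCHES, each RDD is one time unit, so the per-batch discount
// is the decay factor itself.
val discountBatches = decayFactor                      // about 0.8409

// With POINTS, each data point is one time unit, so a batch of n points
// discounts the history by decayFactor^n.
val numNewDataPoints = 8L
val discountPoints = math.pow(decayFactor, numNewDataPoints)  // 0.5^(8/4) = 0.25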
diff --git a/mllib/src/main/scala/org/apache/spark/mllib/regression/StreamingLinearAlgorithm.scala b/mllib/src/main/scala/org/apache/spark/mllib/regression/StreamingLinearAlgorithm.scala
index 73948b2d9851a..57a1a72b9e5a8 100644
--- a/mllib/src/main/scala/org/apache/spark/mllib/regression/StreamingLinearAlgorithm.scala
+++ b/mllib/src/main/scala/org/apache/spark/mllib/regression/StreamingLinearAlgorithm.scala
@@ -22,7 +22,7 @@ import scala.reflect.ClassTag
 import org.apache.spark.Logging
 import org.apache.spark.annotation.{DeveloperApi, Since}
 import org.apache.spark.api.java.JavaSparkContext.fakeClassTag
-import org.apache.spark.mllib.linalg.Vector
+import org.apache.spark.mllib.linalg.{BLAS, Vector}
 import org.apache.spark.streaming.api.java.{JavaDStream, JavaPairDStream}
 import org.apache.spark.streaming.dstream.DStream
@@ -39,6 +39,23 @@ import org.apache.spark.streaming.dstream.DStream
 * Only weights will be updated, not an intercept. If the model needs
 * an intercept, it should be manually appended to the input data.
 *
+ * StreamingLinearAlgorithm uses the forgetful algorithm
+ * to dynamically adjust for evolving data sources. For each batch of data,
+ * the model estimates are updated by:
+ *
+ * $$ \theta_{t+1} = \frac{\theta_t n_t \alpha + \beta_t m_t}{n_t \alpha + m_t} $$
+ * $$ n_{t+1} = n_t \alpha + m_t $$
+ *
+ * where $\theta_t$ is the model estimate before the batch arriving at time t;
+ * $n_t$ is the cumulative contribution of the data arriving before time t;
+ * $\beta_t$ is the estimate using the data arriving at time t alone;
+ * $m_t$ is the number of data points arriving at time t alone;
+ * $\alpha$ is the discount factor: with $\alpha=0$, only the data from the
+ * most recent RDD will be used; with $\alpha=1$, all data since the beginning
+ * of the DStream will be used with equal contributions.
+ *
+ * This updating rule is analogous to an exponentially-weighted moving average.
+ *
 * For example usage, see `StreamingLinearRegressionWithSGD`.
 *
 * NOTE: In some use cases, the order in which trainOn and predictOn
@@ -59,11 +76,15 @@ import org.apache.spark.streaming.dstream.DStream
 @DeveloperApi
 abstract class StreamingLinearAlgorithm[
     M <: GeneralizedLinearModel,
-    A <: GeneralizedLinearAlgorithm[M]] extends Logging {
+    A <: GeneralizedLinearAlgorithm[M]]
+  extends StreamingDecay with Logging {
 
   /** The model to be updated and used for prediction. */
   protected var model: Option[M]
 
+  /** The cumulative weight of the data arriving before the current time unit. */
+  protected var previousDataWeight: Double = 0
+
   /** The algorithm to use for updating. */
   protected val algorithm: A
 
@@ -91,7 +112,22 @@ abstract class StreamingLinearAlgorithm[
     }
     data.foreachRDD { (rdd, time) =>
       if (!rdd.isEmpty) {
-        model = Some(algorithm.run(rdd, model.get.weights))
+        val newModel = algorithm.run(rdd, model.get.weights)
+
+        val numNewDataPoints = rdd.count()
+        val discount = getDiscount(numNewDataPoints)
+
+        val updatedDataWeight = previousDataWeight * discount + numNewDataPoints
+        // updatedDataWeight >= 1 because rdd is not empty,
+        // so there is no need to check for division by zero below
+        val lambda = numNewDataPoints / updatedDataWeight
+
+        BLAS.scal(lambda, newModel.weights)
+        BLAS.axpy(1 - lambda, model.get.weights, newModel.weights)
+
+        previousDataWeight = updatedDataWeight
+        model = Some(newModel)
+
         logInfo(s"Model updated at time ${time.toString}")
         val display = model.get.weights.size match {
           case x if x > 100 => model.get.weights.toArray.take(100).mkString("[", ",", "...")
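As a sanity check on the update rule, here is the arithmetic worked for a single scalar weight (a standalone sketch with made-up values; thetaT, nT, betaT, mT, and alpha mirror the symbols in the doc comment above):

// History so far: estimate theta_t = 1.0 with cumulative weight n_t = 100.
// A new batch of m_t = 50 points estimates beta_t = 4.0; alpha = 0.5.
val thetaT = 1.0
val nT = 100.0
val betaT = 4.0
val mT = 50.0
val alpha = 0.5

val nNext = nT * alpha + mT                             // 100 * 0.5 + 50 = 100
val lambda = mT / nNext                                 // 50 / 100 = 0.5
val thetaNext = lambda * betaT + (1 - lambda) * thetaT  // 0.5 * 4 + 0.5 * 1 = 2.5
// Same value as (thetaT * nT * alpha + betaT * mT) / (nT * alpha + mT) = 250 / 100 = 2.5,
// which is exactly the scal/axpy combination applied to the weight vectors in trainOn.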
diff --git a/mllib/src/main/scala/org/apache/spark/mllib/regression/StreamingLinearRegressionWithSGD.scala b/mllib/src/main/scala/org/apache/spark/mllib/regression/StreamingLinearRegressionWithSGD.scala
index fe2a46b9eecc7..67d66ab70b90c 100644
--- a/mllib/src/main/scala/org/apache/spark/mllib/regression/StreamingLinearRegressionWithSGD.scala
+++ b/mllib/src/main/scala/org/apache/spark/mllib/regression/StreamingLinearRegressionWithSGD.scala
@@ -30,6 +30,11 @@ import org.apache.spark.mllib.linalg.Vector
 * of features must be constant. An initial weight
 * vector must be provided.
 *
+ * This class inherits the forgetful algorithm from [[StreamingLinearAlgorithm]]
+ * to handle evolving data sources. Users can specify the degree of forgetfulness
+ * by the decay factor or the half-life. Refer to [[StreamingLinearAlgorithm]] for
+ * more details.
+ *
 * Use a builder pattern to construct a streaming linear regression
 * analysis in an application, like:
 *
@@ -105,4 +110,19 @@ class StreamingLinearRegressionWithSGD private[mllib] (
     this.algorithm.optimizer.setConvergenceTol(tolerance)
     this
   }
+
+  override def setDecayFactor(decayFactor: Double): this.type = {
+    super.setDecayFactor(decayFactor)
+    this
+  }
+
+  override def setHalfLife(halfLife: Double): this.type = {
+    super.setHalfLife(halfLife)
+    this
+  }
+
+  override def setTimeUnit(timeUnit: String): this.type = {
+    super.setTimeUnit(timeUnit)
+    this
+  }
 }
diff --git a/mllib/src/test/java/org/apache/spark/mllib/classification/JavaStreamingLogisticRegressionSuite.java b/mllib/src/test/java/org/apache/spark/mllib/classification/JavaStreamingLogisticRegressionSuite.java
index c9e5ee22f3273..351aa22b78804 100644
--- a/mllib/src/test/java/org/apache/spark/mllib/classification/JavaStreamingLogisticRegressionSuite.java
+++ b/mllib/src/test/java/org/apache/spark/mllib/classification/JavaStreamingLogisticRegressionSuite.java
@@ -72,7 +72,9 @@ public void javaAPI() {
       attachTestInputStream(ssc, Arrays.asList(testBatch, testBatch), 2));
     StreamingLogisticRegressionWithSGD slr = new StreamingLogisticRegressionWithSGD()
       .setNumIterations(2)
-      .setInitialWeights(Vectors.dense(0.0));
+      .setInitialWeights(Vectors.dense(0.0))
+      .setDecayFactor(0.5)
+      .setTimeUnit("POINTS");
     slr.trainOn(training);
     JavaPairDStream<Integer, Double> prediction = slr.predictOnValues(test);
     attachTestOutputStream(prediction.count());
diff --git a/mllib/src/test/java/org/apache/spark/mllib/regression/JavaStreamingLinearRegressionSuite.java b/mllib/src/test/java/org/apache/spark/mllib/regression/JavaStreamingLinearRegressionSuite.java
index dbf6488d41085..03b24c7c2f9b3 100644
--- a/mllib/src/test/java/org/apache/spark/mllib/regression/JavaStreamingLinearRegressionSuite.java
+++ b/mllib/src/test/java/org/apache/spark/mllib/regression/JavaStreamingLinearRegressionSuite.java
@@ -71,7 +71,9 @@ public void javaAPI() {
       attachTestInputStream(ssc, Arrays.asList(testBatch, testBatch), 2));
     StreamingLinearRegressionWithSGD slr = new StreamingLinearRegressionWithSGD()
       .setNumIterations(2)
-      .setInitialWeights(Vectors.dense(0.0));
+      .setInitialWeights(Vectors.dense(0.0))
+      .setDecayFactor(0.5)
+      .setTimeUnit("POINTS");
     slr.trainOn(training);
     JavaPairDStream<Integer, Double> prediction = slr.predictOnValues(test);
     attachTestOutputStream(prediction.count());
diff --git a/mllib/src/test/scala/org/apache/spark/mllib/classification/StreamingLogisticRegressionSuite.scala b/mllib/src/test/scala/org/apache/spark/mllib/classification/StreamingLogisticRegressionSuite.scala
index d7b291d5a6330..00a24d376286c 100644
--- a/mllib/src/test/scala/org/apache/spark/mllib/classification/StreamingLogisticRegressionSuite.scala
+++ b/mllib/src/test/scala/org/apache/spark/mllib/classification/StreamingLogisticRegressionSuite.scala
@@ -184,4 +184,72 @@ class StreamingLogisticRegressionSuite extends SparkFunSuite with TestSuiteBase
     )
     val output: Seq[Seq[(Double, Double)]] = runStreams(ssc, numBatches, numBatches)
   }
+
+  test("parameter accuracy with full memory (decayFactor = 1)") {
+
+    val nPoints = 600
+
+    // create model
+    val model = new StreamingLogisticRegressionWithSGD()
+      .setDecayFactor(1)
+      .setInitialWeights(Vectors.dense(0.0))
+      .setStepSize(1)
+      .setNumIterations(100)
+
+    // generate sequence of simulated data
+    val numBatches = 20
+    // the first few RDDs are generated under model A
+    val inputA = (0 until (numBatches - 1)).map { i =>
+      LogisticRegressionSuite.generateLogisticInput(0.0, 0.1, nPoints, 33 * (i + 1))
+    }
+    // the last RDD is generated under model B
+    val inputB =
+      LogisticRegressionSuite.generateLogisticInput(0.0, 0.5, nPoints, 33 * (numBatches + 1))
+    val input = inputA :+ inputB
+
+    // apply model training to input stream
+    ssc = setupStreams(input, (inputDStream: DStream[LabeledPoint]) => {
+      model.trainOn(inputDStream)
+      inputDStream.count()
+    })
+    runStreams(ssc, numBatches, numBatches)
+
+    // with full memory, the final parameter estimates should be close to model A
+    assert(model.latestModel().weights(0) ~== 0.1 relTol 0.1)
+
+  }
+
+  test("parameter accuracy with no memory (decayFactor = 0)") {
+
+    val nPoints = 600
+
+    // create model
+    val model = new StreamingLogisticRegressionWithSGD()
+      .setDecayFactor(0)
+      .setInitialWeights(Vectors.dense(0.0))
+      .setStepSize(1)
+      .setNumIterations(100)
+
+    // generate sequence of simulated data
+    val numBatches = 20
+    // the first few RDDs are generated under model A
+    val inputA = (0 until (numBatches - 1)).map { i =>
+      LogisticRegressionSuite.generateLogisticInput(0.0, 0.1, nPoints, 33 * (i + 1))
+    }
+    // the last RDD is generated under model B
+    val inputB =
+      LogisticRegressionSuite.generateLogisticInput(0.0, 0.5, nPoints, 33 * (numBatches + 1))
+    val input = inputA :+ inputB
+
+    // apply model training to input stream
+    ssc = setupStreams(input, (inputDStream: DStream[LabeledPoint]) => {
+      model.trainOn(inputDStream)
+      inputDStream.count()
+    })
+    runStreams(ssc, numBatches, numBatches)
+
+    // with no memory, the final parameter estimates should be close to model B
+    assert(model.latestModel().weights(0) ~== 0.5 relTol 0.1)
+
+  }
 }
diff --git a/mllib/src/test/scala/org/apache/spark/mllib/regression/StreamingLinearRegressionSuite.scala b/mllib/src/test/scala/org/apache/spark/mllib/regression/StreamingLinearRegressionSuite.scala
index 34c07ed170816..0bc8a4e6f2938 100644
--- a/mllib/src/test/scala/org/apache/spark/mllib/regression/StreamingLinearRegressionSuite.scala
+++ b/mllib/src/test/scala/org/apache/spark/mllib/regression/StreamingLinearRegressionSuite.scala
@@ -21,9 +21,11 @@ import scala.collection.mutable.ArrayBuffer
 
 import org.apache.spark.SparkFunSuite
 import org.apache.spark.mllib.linalg.Vectors
+import org.apache.spark.mllib.regression.StreamingDecay.{BATCHES, POINTS}
 import org.apache.spark.mllib.util.LinearDataGenerator
-import org.apache.spark.streaming.{StreamingContext, TestSuiteBase}
+import org.apache.spark.mllib.util.TestingUtils._
 import org.apache.spark.streaming.dstream.DStream
+import org.apache.spark.streaming.{StreamingContext, TestSuiteBase}
 
 class StreamingLinearRegressionSuite extends SparkFunSuite with TestSuiteBase {
 
@@ -194,4 +196,208 @@ class StreamingLinearRegressionSuite extends SparkFunSuite with TestSuiteBase {
     )
     val output: Seq[Seq[(Double, Double)]] = runStreams(ssc, numBatches, numBatches)
   }
+
+  test("parameter accuracy with full memory (decayFactor = 1)") {
+    // create model
+    val model = new StreamingLinearRegressionWithSGD()
+      .setDecayFactor(1)
+      .setInitialWeights(Vectors.dense(0.0, 0.0))
+      .setStepSize(0.5)
+      .setNumIterations(50)
+      .setConvergenceTol(0.0001)
+
+    // generate sequence of simulated data
+    val numBatches = 10
+    // the first few RDDs are generated under model A
+    val inputA = (0 until (numBatches - 1)).map { i =>
+      LinearDataGenerator.generateLinearInput(0.0, Array(10.0, 10.0), 200, 42 * (i + 1))
+    }
+    // the last RDD is generated under model B
+    val inputB =
+      LinearDataGenerator.generateLinearInput(0.0, Array(5.0, 3.0), 200, 42 * (numBatches + 1))
+    val input = inputA :+ inputB
+
+    // apply model training to input stream
+    ssc = setupStreams(input, (inputDStream: DStream[LabeledPoint]) => {
+      model.trainOn(inputDStream)
+      inputDStream.count()
+    })
+    runStreams(ssc, numBatches, numBatches)
+
+    // with full memory, the final parameter estimates should be close to model A
+    assert(model.latestModel().intercept ~== 0.0 absTol 1.0)
+    assert(model.latestModel().weights(0) ~== 10.0 absTol 1.0)
+    assert(model.latestModel().weights(1) ~== 10.0 absTol 1.0)
+  }
+
+  test("parameter accuracy with no memory (decayFactor = 0)") {
+    // create model
+    val model = new StreamingLinearRegressionWithSGD()
+      .setDecayFactor(0)
+      .setInitialWeights(Vectors.dense(0.0, 0.0))
+      .setStepSize(0.5)
+      .setNumIterations(50)
+      .setConvergenceTol(0.0001)
+
+    // generate sequence of simulated data
+    val numBatches = 10
+    // the first few RDDs are generated under model A
+    val inputA = (0 until (numBatches - 1)).map { i =>
+      LinearDataGenerator.generateLinearInput(0.0, Array(10.0, 10.0), 200, 42 * (i + 1))
+    }
+    // the last RDD is generated under model B
+    val inputB =
+      LinearDataGenerator.generateLinearInput(0.0, Array(5.0, 3.0), 200, 42 * (numBatches + 1))
+    val input = inputA :+ inputB
+
+    // apply model training to input stream
+    ssc = setupStreams(input, (inputDStream: DStream[LabeledPoint]) => {
+      model.trainOn(inputDStream)
+      inputDStream.count()
+    })
+    runStreams(ssc, numBatches, numBatches)
+
+    // with no memory, the final parameter estimates should be close to model B
+    assert(model.latestModel().intercept ~== 0.0 absTol 1.0)
+    assert(model.latestModel().weights(0) ~== 5.0 absTol 1.0)
+    assert(model.latestModel().weights(1) ~== 3.0 absTol 1.0)
+  }
+
+  test("parameter accuracy with long half life and POINTS as TimeUnit") {
+    // create model
+    val model = new StreamingLinearRegressionWithSGD()
+      .setHalfLife(5000)
+      .setTimeUnit(POINTS)
+      .setInitialWeights(Vectors.dense(0.0, 0.0))
+      .setStepSize(0.5)
+      .setNumIterations(50)
+      .setConvergenceTol(0.0001)
+
+    // generate sequence of simulated data
+    val numBatches = 10
+    // the first few RDDs are generated under model A
+    val inputA = (0 until (numBatches - 1)).map { i =>
+      LinearDataGenerator.generateLinearInput(0.0, Array(10.0, 10.0), 200, 42 * (i + 1))
+    }
+    // the last RDD is generated under model B
+    val inputB =
+      LinearDataGenerator.generateLinearInput(0.0, Array(5.0, 3.0), 200, 42 * (numBatches + 1))
+    val input = inputA :+ inputB
+
+    // apply model training to input stream
+    ssc = setupStreams(input, (inputDStream: DStream[LabeledPoint]) => {
+      model.trainOn(inputDStream)
+      inputDStream.count()
+    })
+    runStreams(ssc, numBatches, numBatches)
+
+    // with long half life, the final parameter estimates should be close to model A
+    assert(model.latestModel().intercept ~== 0.0 absTol 1.0)
+    assert(model.latestModel().weights(0) ~== 10.0 absTol 1.0)
+    assert(model.latestModel().weights(1) ~== 10.0 absTol 1.0)
+  }
+
+  test("parameter accuracy with long half life and BATCHES as TimeUnit") {
+    // create model
+    val model = new StreamingLinearRegressionWithSGD()
+      .setHalfLife(20)
+      .setTimeUnit(BATCHES)
+      .setInitialWeights(Vectors.dense(0.0, 0.0))
+      .setStepSize(0.5)
+      .setNumIterations(50)
+      .setConvergenceTol(0.0001)
+
+    // generate sequence of simulated data
+    val numBatches = 10
+    // the first few RDDs are generated under model A
+    val inputA = (0 until (numBatches - 1)).map { i =>
+      LinearDataGenerator.generateLinearInput(0.0, Array(10.0, 10.0), 200, 42 * (i + 1))
+    }
+    // the last RDD is generated under model B
+    val inputB =
+      LinearDataGenerator.generateLinearInput(0.0, Array(5.0, 3.0), 200, 42 * (numBatches + 1))
+    val input = inputA :+ inputB
+
+    // apply model training to input stream
+    ssc = setupStreams(input, (inputDStream: DStream[LabeledPoint]) => {
+      model.trainOn(inputDStream)
+      inputDStream.count()
+    })
+    runStreams(ssc, numBatches, numBatches)
+
+    // with long half life, the final parameter estimates should be close to model A
+    assert(model.latestModel().intercept ~== 0.0 absTol 1.0)
+    assert(model.latestModel().weights(0) ~== 10.0 absTol 1.0)
+    assert(model.latestModel().weights(1) ~== 10.0 absTol 1.0)
+  }
+
+  test("parameter accuracy with short half life and POINTS as TimeUnit") {
+    // create model
+    val model = new StreamingLinearRegressionWithSGD()
+      .setHalfLife(50)
+      .setTimeUnit(POINTS)
+      .setInitialWeights(Vectors.dense(0.0, 0.0))
+      .setStepSize(0.5)
+      .setNumIterations(50)
+      .setConvergenceTol(0.0001)
+
+    // generate sequence of simulated data
+    val numBatches = 10
+    // the first half of the RDDs are generated under model A
+    val inputA = (0 until numBatches / 2).map { i =>
+      LinearDataGenerator.generateLinearInput(0.0, Array(10.0, 10.0), 200, 42 * (i + 1))
+    }
+    // the second half of the RDDs are generated under model B
+    val inputB = (0 until numBatches / 2).map { i =>
+      LinearDataGenerator.generateLinearInput(0.0, Array(5.0, 3.0), 200, 42 * (i + 1))
+    }
+    val input = inputA ++ inputB
+
+    // apply model training to input stream
+    ssc = setupStreams(input, (inputDStream: DStream[LabeledPoint]) => {
+      model.trainOn(inputDStream)
+      inputDStream.count()
+    })
+    runStreams(ssc, numBatches, numBatches)
+
+    // with short half life, the final parameter estimates should be close to model B
+    assert(model.latestModel().intercept ~== 0.0 absTol 1.0)
+    assert(model.latestModel().weights(0) ~== 5.0 absTol 1.0)
+    assert(model.latestModel().weights(1) ~== 3.0 absTol 1.0)
+  }
+
+  test("parameter accuracy with short half life and BATCHES as TimeUnit") {
+    // create model
+    val model = new StreamingLinearRegressionWithSGD()
+      .setHalfLife(1)
+      .setTimeUnit(BATCHES)
+      .setInitialWeights(Vectors.dense(0.0, 0.0))
+      .setStepSize(0.5)
+      .setNumIterations(50)
+      .setConvergenceTol(0.0001)
+
+    // generate sequence of simulated data
+    val numBatches = 10
+    // the first half of the RDDs are generated under model A
+    val inputA = (0 until numBatches / 2).map { i =>
+      LinearDataGenerator.generateLinearInput(0.0, Array(10.0, 10.0), 200, 42 * (i + 1))
+    }
+    // the second half of the RDDs are generated under model B
+    val inputB = (0 until numBatches / 2).map { i =>
+      LinearDataGenerator.generateLinearInput(0.0, Array(5.0, 3.0), 200, 42 * (i + 1))
+    }
+    val input = inputA ++ inputB
+
+    // apply model training to input stream
+    ssc = setupStreams(input, (inputDStream: DStream[LabeledPoint]) => {
+      model.trainOn(inputDStream)
+      inputDStream.count()
+    })
+    runStreams(ssc, numBatches, numBatches)
+
+    // with short half life, the final parameter estimates should be close to model B
+    assert(model.latestModel().intercept ~== 0.0 absTol 1.0)
+    assert(model.latestModel().weights(0) ~== 5.0 absTol 1.0)
+    assert(model.latestModel().weights(1) ~== 3.0 absTol 1.0)
+  }
 }
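As a cross-check on the half-life settings exercised by these tests, the per-batch discounts they imply can be computed directly (standalone sketch; 200 is the batch size used by the generated test data above):

// decayFactor = 0.5^(1/halfLife); the per-batch discount depends on the time unit.
def decay(halfLife: Double): Double = math.exp(math.log(0.5) / halfLife)

val pointsPerBatch = 200

decay(1)                                 // BATCHES, halfLife = 1:   0.5 per batch (short memory)
decay(20)                                // BATCHES, halfLife = 20:  ~0.966 per batch (long memory)
math.pow(decay(5000), pointsPerBatch)    // POINTS, halfLife = 5000: ~0.973 per batch (long memory)
math.pow(decay(50), pointsPerBatch)      // POINTS, halfLife = 50:   0.5^4 = 0.0625 per batch (short memory)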