Skip to content

Commit

Permalink
Improve Java API compatibility.
Browse files Browse the repository at this point in the history
  • Loading branch information
Meihua Wu committed Nov 9, 2015
1 parent 3b42f96 commit 0072400
Show file tree
Hide file tree
Showing 5 changed files with 45 additions and 10 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -19,7 +19,6 @@ package org.apache.spark.mllib.classification

import org.apache.spark.annotation.Since
import org.apache.spark.mllib.linalg.Vector
import org.apache.spark.mllib.regression.StreamingDecay.TimeUnit
import org.apache.spark.mllib.regression.StreamingLinearAlgorithm

/**
Expand Down Expand Up @@ -105,4 +104,19 @@ class StreamingLogisticRegressionWithSGD private[mllib] (
this.model = Some(algorithm.createModel(initialWeights, 0.0))
this
}

override def setDecayFactor(decayFactor: Double): this.type = {
super.setDecayFactor(decayFactor)
this
}

override def setHalfLife(halfLife: Double): this.type = {
super.setHalfLife(halfLife)
this
}

override def setTimeUnit(timeUnit: String): this.type = {
super.setTimeUnit(timeUnit)
this
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -19,7 +19,7 @@ package org.apache.spark.mllib.regression

import org.apache.spark.Logging
import org.apache.spark.annotation.{Since, Experimental}
import org.apache.spark.mllib.regression.StreamingDecay.{TimeUnit, BATCHES, POINTS}
import org.apache.spark.mllib.regression.StreamingDecay.{BATCHES, POINTS}

/**
* :: Experimental ::
Expand All @@ -31,7 +31,7 @@ import org.apache.spark.mllib.regression.StreamingDecay.{TimeUnit, BATCHES, POIN
@Experimental
private[mllib] trait StreamingDecay extends Logging{
private var decayFactor: Double = 0
private var timeUnit: TimeUnit = BATCHES
private var timeUnit: String = StreamingDecay.BATCHES

/**
* Set the decay factor for the forgetful algorithms.
Expand Down Expand Up @@ -73,7 +73,10 @@ private[mllib] trait StreamingDecay extends Logging{
* @param timeUnit the time unit
*/
@Since("1.6.0")
def setTimeUnit(timeUnit: TimeUnit): this.type = {
def setTimeUnit(timeUnit: String): this.type = {
if (timeUnit != StreamingDecay.BATCHES && timeUnit != StreamingDecay.POINTS) {
throw new IllegalArgumentException("Invalid time unit for decay: " + timeUnit)
}
this.timeUnit = timeUnit
this
}
Expand All @@ -96,15 +99,14 @@ private[mllib] trait StreamingDecay extends Logging{
@Experimental
@Since("1.6.0")
object StreamingDecay {
private[mllib] sealed trait TimeUnit
/**
* Each RDD in the DStream will be treated as 1 time unit.
*/
@Since("1.6.0")
case object BATCHES extends TimeUnit
final val BATCHES = "BATCHES"
/**
* Each data point will be treated as 1 time unit.
*/
@Since("1.6.0")
case object POINTS extends TimeUnit
final val POINTS = "POINTS"
}
Original file line number Diff line number Diff line change
Expand Up @@ -19,7 +19,6 @@ package org.apache.spark.mllib.regression

import org.apache.spark.annotation.Since
import org.apache.spark.mllib.linalg.Vector
import org.apache.spark.mllib.regression.StreamingDecay.TimeUnit

/**
* Train or predict a linear regression model on streaming data. Training uses
Expand Down Expand Up @@ -111,4 +110,19 @@ class StreamingLinearRegressionWithSGD private[mllib] (
this.algorithm.optimizer.setConvergenceTol(tolerance)
this
}

override def setDecayFactor(decayFactor: Double): this.type = {
super.setDecayFactor(decayFactor)
this
}

override def setHalfLife(halfLife: Double): this.type = {
super.setHalfLife(halfLife)
this
}

override def setTimeUnit(timeUnit: String): this.type = {
super.setTimeUnit(timeUnit)
this
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,7 @@
import java.util.Arrays;
import java.util.List;

import org.apache.spark.mllib.regression.StreamingDecay;
import scala.Tuple2;

import org.junit.After;
Expand Down Expand Up @@ -72,7 +73,9 @@ public void javaAPI() {
attachTestInputStream(ssc, Arrays.asList(testBatch, testBatch), 2));
StreamingLogisticRegressionWithSGD slr = new StreamingLogisticRegressionWithSGD()
.setNumIterations(2)
.setInitialWeights(Vectors.dense(0.0));
.setInitialWeights(Vectors.dense(0.0))
.setDecayFactor(0.5)
.setTimeUnit("POINTS");
slr.trainOn(training);
JavaPairDStream<Integer, Double> prediction = slr.predictOnValues(test);
attachTestOutputStream(prediction.count());
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -71,7 +71,9 @@ public void javaAPI() {
attachTestInputStream(ssc, Arrays.asList(testBatch, testBatch), 2));
StreamingLinearRegressionWithSGD slr = new StreamingLinearRegressionWithSGD()
.setNumIterations(2)
.setInitialWeights(Vectors.dense(0.0));
.setInitialWeights(Vectors.dense(0.0))
.setDecayFactor(0.5)
.setTimeUnit("POINTS");
slr.trainOn(training);
JavaPairDStream<Integer, Double> prediction = slr.predictOnValues(test);
attachTestOutputStream(prediction.count());
Expand Down

0 comments on commit 0072400

Please sign in to comment.