Skip to content

Commit

Permalink
Fixed bug in LogisticRegression (introduced in this PR). Fixed Java s…
Browse files Browse the repository at this point in the history
…uites
  • Loading branch information
jkbradley committed Feb 5, 2015
1 parent 0a16da9 commit 82f340b
Show file tree
Hide file tree
Showing 5 changed files with 49 additions and 91 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -149,7 +149,7 @@ class LogisticRegressionModel private[ml] (
if (map(probabilityCol) != "") {
if (map(rawPredictionCol) != "") {
val raw2prob: Vector => Vector = (rawPreds) => {
val prob1 = 1.0 / 1.0 + math.exp(-rawPreds(1))
val prob1 = 1.0 / (1.0 + math.exp(-rawPreds(1)))
Vectors.dense(1.0 - prob1, prob1)
}
tmpData = tmpData.select(Star(None),
Expand All @@ -171,7 +171,7 @@ class LogisticRegressionModel private[ml] (
predict.call(map(probabilityCol).attr) as map(predictionCol))
} else if (map(rawPredictionCol) != "") {
val predict: Vector => Double = (rawPreds) => {
val prob1 = 1.0 / 1.0 + math.exp(-rawPreds(1))
val prob1 = 1.0 / (1.0 + math.exp(-rawPreds(1)))
if (prob1 > t) 1.0 else 0.0
}
tmpData = tmpData.select(Star(None),
Expand Down Expand Up @@ -207,7 +207,7 @@ class LogisticRegressionModel private[ml] (

override protected def predictRaw(features: Vector): Vector = {
val m = margin(features)
Vectors.dense(-m, m)
Vectors.dense(0.0, m)
}

override protected def copy(): LogisticRegressionModel = {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -65,7 +65,7 @@ public void pipeline() {
.setStages(new PipelineStage[] {scaler, lr});
PipelineModel model = pipeline.fit(dataset);
model.transform(dataset).registerTempTable("prediction");
DataFrame predictions = jsql.sql("SELECT label, score, prediction FROM prediction");
DataFrame predictions = jsql.sql("SELECT label, probability, prediction FROM prediction");
predictions.collectAsList();
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -30,12 +30,12 @@
import org.apache.spark.api.java.JavaRDD;
import org.apache.spark.api.java.JavaSparkContext;
import org.apache.spark.api.java.function.Function;
import org.apache.spark.ml.LabeledPoint;
import org.apache.spark.ml.regression.LinearRegression;
import org.apache.spark.ml.regression.LinearRegressionModel;
import static org.apache.spark.mllib.classification.LogisticRegressionSuite
.generateLogisticInputAsList;
import org.apache.spark.mllib.linalg.Vector;
import org.apache.spark.mllib.regression.LabeledPoint;
import org.apache.spark.sql.api.java.JavaSQLContext;
import org.apache.spark.sql.api.java.JavaSchemaRDD;
import org.apache.spark.sql.api.java.Row;
Expand Down Expand Up @@ -93,35 +93,14 @@ public void linearRegressionWithSetters() {
.setMaxIter(10)
.setRegParam(1.0);
LinearRegressionModel model = lr.fit(dataset);
assert(model.fittingParamMap().get(lr.maxIter()).get() == 10);
assert(model.fittingParamMap().get(lr.regParam()).get() == 1.0);
assert(model.fittingParamMap().apply(lr.maxIter()) == 10);
assert(model.fittingParamMap().apply(lr.regParam()).equals(1.0));

// Call fit() with new params, and check as many params as we can.
LinearRegressionModel model2 =
lr.fit(dataset, lr.maxIter().w(5), lr.regParam().w(0.1), lr.predictionCol().w("thePred"));
assert(model2.fittingParamMap().get(lr.maxIter()).get() == 5);
assert(model2.fittingParamMap().get(lr.regParam()).get() == 0.1);
assert(model2.fittingParamMap().apply(lr.maxIter()) == 5);
assert(model2.fittingParamMap().apply(lr.regParam()).equals(0.1));
assert(model2.getPredictionCol().equals("thePred"));
}

@Test
public void linearRegressionPredictorClassifierMethods() {
LinearRegression lr = new LinearRegression();

// fit() vs. train()
LinearRegressionModel model1 = lr.fit(dataset);
LinearRegressionModel model2 = lr.train(datasetRDD);
assert(model1.intercept() == model2.intercept());
assert(model1.weights().equals(model2.weights()));

// transform() vs. predict()
model1.transform(dataset).registerTempTable("transformed");
JavaSchemaRDD trans = jsql.sql("SELECT prediction FROM transformed");
JavaRDD<Double> preds = model1.predict(featuresRDD);
for (Tuple2<Row, Double> trans_pred: trans.zip(preds).collect()) {
double t = trans_pred._1().getDouble(0);
double p = trans_pred._2();
assert(t == p);
}
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -17,8 +17,6 @@

package org.apache.spark.ml.classification;

import scala.Tuple2;

import java.io.Serializable;
import java.lang.Math;
import java.util.ArrayList;
Expand All @@ -34,9 +32,8 @@
import org.apache.spark.sql.DataFrame;
import org.apache.spark.sql.SQLContext;
import static org.apache.spark.mllib.classification.LogisticRegressionSuite.generateLogisticInputAsList;
import org.apache.spark.api.java.function.Function;
import org.apache.spark.mllib.linalg.Vector;
import org.apache.spark.ml.LabeledPoint;
import org.apache.spark.mllib.regression.LabeledPoint;
import org.apache.spark.sql.Row;


Expand All @@ -47,7 +44,6 @@ public class JavaLogisticRegressionSuite implements Serializable {
private transient DataFrame dataset;

private transient JavaRDD<LabeledPoint> datasetRDD;
private transient JavaRDD<Vector> featuresRDD;
private double eps = 1e-5;

@Before
Expand All @@ -60,9 +56,6 @@ public void setUp() {
points.add(new LabeledPoint(lp.label(), lp.features()));
}
datasetRDD = jsc.parallelize(points, 2);
featuresRDD = datasetRDD.map(new Function<LabeledPoint, Vector>() {
@Override public Vector call(LabeledPoint lp) { return lp.features(); }
});
dataset = jsql.applySchema(datasetRDD, LabeledPoint.class);
dataset.registerTempTable("dataset");
}
Expand All @@ -79,13 +72,13 @@ public void logisticRegressionDefaultParams() {
assert(lr.getLabelCol().equals("label"));
LogisticRegressionModel model = lr.fit(dataset);
model.transform(dataset).registerTempTable("prediction");
DataFrame predictions = jsql.sql("SELECT label, score, prediction FROM prediction");
DataFrame predictions = jsql.sql("SELECT label, probability, prediction FROM prediction");
predictions.collectAsList();
// Check defaults
assert(model.getThreshold() == 0.5);
assert(model.getFeaturesCol().equals("features"));
assert(model.getPredictionCol().equals("prediction"));
assert(model.getScoreCol().equals("score"));
assert(model.getProbabilityCol().equals("probability"));
}

@Test
Expand All @@ -95,17 +88,17 @@ public void logisticRegressionWithSetters() {
.setMaxIter(10)
.setRegParam(1.0)
.setThreshold(0.6)
.setScoreCol("probability");
.setProbabilityCol("myProbability");
LogisticRegressionModel model = lr.fit(dataset);
assert(model.fittingParamMap().get(lr.maxIter()).get() == 10);
assert(model.fittingParamMap().get(lr.regParam()).get() == 1.0);
assert(model.fittingParamMap().get(lr.threshold()).get() == 0.6);
assert(model.fittingParamMap().apply(lr.maxIter()) == 10);
assert(model.fittingParamMap().apply(lr.regParam()).equals(1.0));
assert(model.fittingParamMap().apply(lr.threshold()).equals(0.6));
assert(model.getThreshold() == 0.6);

// Modify model params, and check that the params worked.
model.setThreshold(1.0);
model.transform(dataset).registerTempTable("predAllZero");
SchemaRDD predAllZero = jsql.sql("SELECT prediction, probability FROM predAllZero");
SchemaRDD predAllZero = jsql.sql("SELECT prediction, myProbability FROM predAllZero");
for (Row r: predAllZero.collectAsList()) {
assert(r.getDouble(0) == 0.0);
}
Expand All @@ -117,7 +110,7 @@ public void logisticRegressionWithSetters() {
predictions.collectAsList();
*/

model.transform(dataset, model.threshold().w(0.0), model.scoreCol().w("myProb"))
model.transform(dataset, model.threshold().w(0.0), model.probabilityCol().w("myProb"))
.registerTempTable("predNotAllZero");
SchemaRDD predNotAllZero = jsql.sql("SELECT prediction, myProb FROM predNotAllZero");
boolean foundNonZero = false;
Expand All @@ -128,54 +121,37 @@ public void logisticRegressionWithSetters() {

// Call fit() with new params, and check as many params as we can.
LogisticRegressionModel model2 = lr.fit(dataset, lr.maxIter().w(5), lr.regParam().w(0.1),
lr.threshold().w(0.4), lr.scoreCol().w("theProb"));
assert(model2.fittingParamMap().get(lr.maxIter()).get() == 5);
assert(model2.fittingParamMap().get(lr.regParam()).get() == 0.1);
assert(model2.fittingParamMap().get(lr.threshold()).get() == 0.4);
lr.threshold().w(0.4), lr.probabilityCol().w("theProb"));
assert(model2.fittingParamMap().apply(lr.maxIter()) == 5);
assert(model2.fittingParamMap().apply(lr.regParam()).equals(0.1));
assert(model2.fittingParamMap().apply(lr.threshold()).equals(0.4));
assert(model2.getThreshold() == 0.4);
assert(model2.getScoreCol().equals("theProb"));
assert(model2.getProbabilityCol().equals("theProb"));
}

@SuppressWarnings("unchecked")
@Test
public void logisticRegressionPredictorClassifierMethods() {
LogisticRegression lr = new LogisticRegression();

// fit() vs. train()
LogisticRegressionModel model1 = lr.fit(dataset);
LogisticRegressionModel model2 = lr.train(datasetRDD);
assert(model1.intercept() == model2.intercept());
assert(model1.weights().equals(model2.weights()));
assert(model1.numClasses() == model2.numClasses());
assert(model1.numClasses() == 2);

// transform() vs. predict()
model1.transform(dataset).registerTempTable("transformed");
SchemaRDD trans = jsql.sql("SELECT prediction FROM transformed");
JavaRDD<Double> preds = model1.predict(featuresRDD);
for (scala.Tuple2<Row, Double> trans_pred: trans.toJavaRDD().zip(preds).collect()) {
double t = trans_pred._1().getDouble(0);
double p = trans_pred._2();
assert(t == p);
LogisticRegressionModel model = lr.fit(dataset);
assert(model.numClasses() == 2);

model.transform(dataset).registerTempTable("transformed");
SchemaRDD trans1 = jsql.sql("SELECT rawPrediction, probability FROM transformed");
for (Row row: trans1.collect()) {
Vector raw = (Vector)row.get(0);
Vector prob = (Vector)row.get(1);
assert(raw.size() == 2);
assert(prob.size() == 2);
double probFromRaw1 = 1.0 / (1.0 + Math.exp(-raw.apply(1)));
assert(Math.abs(prob.apply(1) - probFromRaw1) < eps);
assert(Math.abs(prob.apply(0) - (1.0 - probFromRaw1)) < eps);
}

// Check various types of predictions.
JavaRDD<Vector> rawPredictions = model1.predictRaw(featuresRDD);
JavaRDD<Vector> probabilities = model1.predictProbabilities(featuresRDD);
JavaRDD<Double> predictions = model1.predict(featuresRDD);
double threshold = model1.getThreshold();
for (Tuple2<Vector, Vector> raw_prob: rawPredictions.zip(probabilities).collect()) {
Vector raw = raw_prob._1();
Vector prob = raw_prob._2();
for (int i = 0; i < raw.size(); ++i) {
double r = raw.apply(i);
double p = prob.apply(i);
double pFromR = 1.0 / (1.0 + Math.exp(-r));
assert(Math.abs(r - pFromR) < eps);
}
}
for (Tuple2<Vector, Double> prob_pred: probabilities.zip(predictions).collect()) {
Vector prob = prob_pred._1();
double pred = prob_pred._2();
SchemaRDD trans2 = jsql.sql("SELECT prediction, probability FROM transformed");
for (Row row: trans2.collect()) {
double pred = row.getDouble(0);
Vector prob = (Vector)row.get(1);
double probOfPred = prob.apply((int)pred);
for (int i = 0; i < prob.size(); ++i) {
assert(probOfPred >= prob.apply(i));
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -82,7 +82,9 @@ class LogisticRegressionSuite extends FunSuite with MLlibTestSparkContext {
.select('prediction, 'myProbability)
.collect()
.map { case Row(pred: Double, prob: Vector) => pred }
assert(predAllZero.forall(_ === 0.0))
assert(predAllZero.forall(_ === 0),
s"With threshold=1.0, expected predictions to be all 0, but only" +
s" ${predAllZero.count(_ === 0)} of ${dataset.count()} were 0.")
// Call transform with params, and check that the params worked.
val predNotAllZero =
model.transform(dataset, model.threshold -> 0.0, model.probabilityCol -> "myProb")
Expand Down Expand Up @@ -115,10 +117,11 @@ class LogisticRegressionSuite extends FunSuite with MLlibTestSparkContext {
// Compare rawPrediction with probability
results.select('rawPrediction, 'probability).collect().map {
case Row(raw: Vector, prob: Vector) =>
val raw2prob: (Double => Double) = (m) => 1.0 / (1.0 + math.exp(-m))
raw.toArray.map(raw2prob).zip(prob.toArray).foreach { case (r, p) =>
assert(r ~== p relTol eps)
}
assert(raw.size === 2)
assert(prob.size === 2)
val probFromRaw1 = 1.0 / (1.0 + math.exp(-raw(1)))
assert(prob(1) ~== probFromRaw1 relTol eps)
assert(prob(0) ~== 1.0 - probFromRaw1 relTol eps)
}

// Compare prediction with probability
Expand Down

0 comments on commit 82f340b

Please sign in to comment.