Skip to content

Commit

Permalink
[SPARK-1212, Part II] Support sparse data in MLlib
Browse files Browse the repository at this point in the history
In PR apache/spark#117, we added dense/sparse vector data model and updated KMeans to support sparse input. This PR is to replace all other `Array[Double]` usage by `Vector` in generalized linear models (GLMs) and Naive Bayes. Major changes:

1. `LabeledPoint` becomes `LabeledPoint(Double, Vector)`.
2. Methods that accept `RDD[Array[Double]]` now accept `RDD[Vector]`. We cannot support both in an elegant way because of type erasure.
3. Mark 'createModel' and 'predictPoint' protected because they are not for end users.
4. Add libSVMFile to MLContext.
5. NaiveBayes can accept arbitrary labels (introducing a breaking change to Python's `NaiveBayesModel`).
6. Gradient computation no longer creates temp vectors.
7. Column normalization and centering are removed from Lasso and Ridge because the operation will densify the data. Simple feature transformation can be done before training.

TODO:
1. ~~Use axpy when possible.~~
2. ~~Optimize Naive Bayes.~~

Author: Xiangrui Meng <[email protected]>

Closes #245 from mengxr/vector and squashes the following commits:

eb6e793 [Xiangrui Meng] move libSVMFile to MLUtils and rename to loadLibSVMData
c26c4fc [Xiangrui Meng] update DecisionTree to use RDD[Vector]
11999c7 [Xiangrui Meng] Merge branch 'master' into vector
f7da54b [Xiangrui Meng] add minSplits to libSVMFile
da25e24 [Xiangrui Meng] revert the change to default addIntercept because it might change the behavior of existing code without warning
493f26f [Xiangrui Meng] Merge branch 'master' into vector
7c1bc01 [Xiangrui Meng] add a TODO to NB
b9b7ef7 [Xiangrui Meng] change default value of addIntercept to false
b01df54 [Xiangrui Meng] allow to change or clear threshold in LR and SVM
4addc50 [Xiangrui Meng] merge master
4ca5b1b [Xiangrui Meng] remove normalization from Lasso and update tests
f04fe8a [Xiangrui Meng] remove normalization from RidgeRegression and update tests
d088552 [Xiangrui Meng] use static constructor for MLContext
6f59eed [Xiangrui Meng] update libSVMFile to determine number of features automatically
3432e84 [Xiangrui Meng] update NaiveBayes to support sparse data
0f8759b [Xiangrui Meng] minor updates to NB
b11659c [Xiangrui Meng] style update
78c4671 [Xiangrui Meng] add libSVMFile to MLContext
f0fe616 [Xiangrui Meng] add a test for sparse linear regression
44733e1 [Xiangrui Meng] use in-place gradient computation
e981396 [Xiangrui Meng] use axpy in Updater
db808a1 [Xiangrui Meng] update JavaLR example
befa592 [Xiangrui Meng] passed scala/java tests
75c83a4 [Xiangrui Meng] passed test compile
1859701 [Xiangrui Meng] passed compile
834ada2 [Xiangrui Meng] optimized MLUtils.computeStats update some ml algorithms to use Vector (cont.)
135ab72 [Xiangrui Meng] merge glm
0e57aa4 [Xiangrui Meng] update Lasso and RidgeRegression to parse the weights correctly from GLM mark createModel protected mark predictPoint protected
d7f629f [Xiangrui Meng] fix a bug in GLM when intercept is not used
3f346ba [Xiangrui Meng] update some ml algorithms to use Vector
  • Loading branch information
mengxr authored and mateiz committed Apr 2, 2014
1 parent 9182c23 commit da0f38b
Showing 1 changed file with 4 additions and 10 deletions.
14 changes: 4 additions & 10 deletions src/main/java/org/apache/spark/mllib/examples/JavaLR.java
Original file line number Diff line number Diff line change
Expand Up @@ -17,18 +17,17 @@

package org.apache.spark.mllib.examples;

import java.util.regex.Pattern;

import org.apache.spark.api.java.JavaRDD;
import org.apache.spark.api.java.JavaSparkContext;
import org.apache.spark.api.java.function.Function;

import org.apache.spark.mllib.classification.LogisticRegressionWithSGD;
import org.apache.spark.mllib.classification.LogisticRegressionModel;
import org.apache.spark.mllib.linalg.Vectors;
import org.apache.spark.mllib.regression.LabeledPoint;

import java.util.Arrays;
import java.util.regex.Pattern;

/**
* Logistic regression based classification using ML Lib.
*/
Expand All @@ -47,14 +46,10 @@ public LabeledPoint call(String line) {
for (int i = 0; i < tok.length; ++i) {
x[i] = Double.parseDouble(tok[i]);
}
return new LabeledPoint(y, x);
return new LabeledPoint(y, Vectors.dense(x));
}
}

public static void printWeights(double[] a) {
System.out.println(Arrays.toString(a));
}

public static void main(String[] args) {
if (args.length != 4) {
System.err.println("Usage: JavaLR <master> <input_dir> <step_size> <niters>");
Expand All @@ -80,8 +75,7 @@ public static void main(String[] args) {
LogisticRegressionModel model = LogisticRegressionWithSGD.train(points.rdd(),
iterations, stepSize);

System.out.print("Final w: ");
printWeights(model.weights());
System.out.print("Final w: " + model.weights());

System.exit(0);
}
Expand Down

0 comments on commit da0f38b

Please sign in to comment.