Skip to content

Commit

Permalink
fixed JavaLinearRegressionSuite.java Java sql api
Browse files Browse the repository at this point in the history
  • Loading branch information
jkbradley committed Feb 5, 2015
1 parent f542997 commit 9872424
Showing 1 changed file with 6 additions and 15 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -17,8 +17,6 @@

package org.apache.spark.ml.classification;

import scala.Tuple2;

import java.io.Serializable;
import java.util.ArrayList;
import java.util.List;
Expand All @@ -29,40 +27,33 @@

import org.apache.spark.api.java.JavaRDD;
import org.apache.spark.api.java.JavaSparkContext;
import org.apache.spark.api.java.function.Function;
import org.apache.spark.ml.regression.LinearRegression;
import org.apache.spark.ml.regression.LinearRegressionModel;
import static org.apache.spark.mllib.classification.LogisticRegressionSuite
.generateLogisticInputAsList;
import org.apache.spark.mllib.linalg.Vector;
import org.apache.spark.mllib.regression.LabeledPoint;
import org.apache.spark.sql.api.java.JavaSQLContext;
import org.apache.spark.sql.api.java.JavaSchemaRDD;
import org.apache.spark.sql.api.java.Row;
import org.apache.spark.sql.SQLContext;
import org.apache.spark.sql.SchemaRDD;


public class JavaLinearRegressionSuite implements Serializable {

private transient JavaSparkContext jsc;
private transient JavaSQLContext jsql;
private transient JavaSchemaRDD dataset;
private transient SQLContext jsql;
private transient SchemaRDD dataset;
private transient JavaRDD<LabeledPoint> datasetRDD;
private transient JavaRDD<Vector> featuresRDD;
private double eps = 1e-5;

@Before
public void setUp() {
jsc = new JavaSparkContext("local", "JavaLinearRegressionSuite");
jsql = new JavaSQLContext(jsc);
jsql = new SQLContext(jsc);
List<LabeledPoint> points = new ArrayList<LabeledPoint>();
for (org.apache.spark.mllib.regression.LabeledPoint lp:
generateLogisticInputAsList(1.0, 1.0, 100, 42)) {
points.add(new LabeledPoint(lp.label(), lp.features()));
}
datasetRDD = jsc.parallelize(points, 2);
featuresRDD = datasetRDD.map(new Function<LabeledPoint, Vector>() {
@Override public Vector call(LabeledPoint lp) { return lp.features(); }
});
dataset = jsql.applySchema(datasetRDD, LabeledPoint.class);
dataset.registerTempTable("dataset");
}
Expand All @@ -79,7 +70,7 @@ public void linearRegressionDefaultParams() {
assert(lr.getLabelCol().equals("label"));
LinearRegressionModel model = lr.fit(dataset);
model.transform(dataset).registerTempTable("prediction");
JavaSchemaRDD predictions = jsql.sql("SELECT label, prediction FROM prediction");
SchemaRDD predictions = jsql.sql("SELECT label, prediction FROM prediction");
predictions.collect();
// Check defaults
assert(model.getFeaturesCol().equals("features"));
Expand Down

0 comments on commit 9872424

Please sign in to comment.