diff --git a/examples/src/main/java/org/apache/spark/examples/ml/JavaCrossValidatorExample.java b/examples/src/main/java/org/apache/spark/examples/ml/JavaCrossValidatorExample.java
index 0fbee6e433608..5041e0b6d34b0 100644
--- a/examples/src/main/java/org/apache/spark/examples/ml/JavaCrossValidatorExample.java
+++ b/examples/src/main/java/org/apache/spark/examples/ml/JavaCrossValidatorExample.java
@@ -116,10 +116,12 @@ public static void main(String[] args) {
// Make predictions on test documents. cvModel uses the best model found (lrModel).
cvModel.transform(test).registerTempTable("prediction");
- DataFrame predictions = jsql.sql("SELECT id, text, score, prediction FROM prediction");
+ DataFrame predictions = jsql.sql("SELECT id, text, probability, prediction FROM prediction");
for (Row r: predictions.collect()) {
- System.out.println("(" + r.get(0) + ", " + r.get(1) + ") --> score=" + r.get(2)
+ System.out.println("(" + r.get(0) + ", " + r.get(1) + ") --> prob=" + r.get(2)
+ ", prediction=" + r.get(3));
}
+
+ jsc.stop();
}
}
diff --git a/examples/src/main/java/org/apache/spark/examples/ml/JavaDeveloperApiExample.java b/examples/src/main/java/org/apache/spark/examples/ml/JavaDeveloperApiExample.java
new file mode 100644
index 0000000000000..42d4d7d0bef26
--- /dev/null
+++ b/examples/src/main/java/org/apache/spark/examples/ml/JavaDeveloperApiExample.java
@@ -0,0 +1,217 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.examples.ml;
+
+import java.util.List;
+
+import com.google.common.collect.Lists;
+
+import org.apache.spark.SparkConf;
+import org.apache.spark.api.java.JavaRDD;
+import org.apache.spark.api.java.JavaSparkContext;
+import org.apache.spark.ml.classification.Classifier;
+import org.apache.spark.ml.classification.ClassificationModel;
+import org.apache.spark.ml.param.IntParam;
+import org.apache.spark.ml.param.ParamMap;
+import org.apache.spark.ml.param.Params;
+import org.apache.spark.ml.param.Params$;
+import org.apache.spark.mllib.linalg.BLAS;
+import org.apache.spark.mllib.linalg.Vector;
+import org.apache.spark.mllib.linalg.Vectors;
+import org.apache.spark.mllib.regression.LabeledPoint;
+import org.apache.spark.sql.DataFrame;
+import org.apache.spark.sql.Row;
+import org.apache.spark.sql.SQLContext;
+
+
+/**
+ * A simple example demonstrating how to write your own learning algorithm using Estimator,
+ * Transformer, and other abstractions.
+ * This mimics {@link org.apache.spark.ml.classification.LogisticRegression}.
+ *
+ * Run with
+ *
+ * bin/run-example ml.JavaDeveloperApiExample
+ *
+ */
+public class JavaDeveloperApiExample {
+
+ public static void main(String[] args) throws Exception {
+ SparkConf conf = new SparkConf().setAppName("JavaDeveloperApiExample");
+ JavaSparkContext jsc = new JavaSparkContext(conf);
+ SQLContext jsql = new SQLContext(jsc);
+
+ // Prepare training data.
+ List localTraining = Lists.newArrayList(
+ new LabeledPoint(1.0, Vectors.dense(0.0, 1.1, 0.1)),
+ new LabeledPoint(0.0, Vectors.dense(2.0, 1.0, -1.0)),
+ new LabeledPoint(0.0, Vectors.dense(2.0, 1.3, 1.0)),
+ new LabeledPoint(1.0, Vectors.dense(0.0, 1.2, -0.5)));
+ DataFrame training = jsql.applySchema(jsc.parallelize(localTraining), LabeledPoint.class);
+
+ // Create a LogisticRegression instance. This instance is an Estimator.
+ MyJavaLogisticRegression lr = new MyJavaLogisticRegression();
+ // Print out the parameters, documentation, and any default values.
+ System.out.println("MyJavaLogisticRegression parameters:\n" + lr.explainParams() + "\n");
+
+ // We may set parameters using setter methods.
+ lr.setMaxIter(10);
+
+ // Learn a LogisticRegression model. This uses the parameters stored in lr.
+ MyJavaLogisticRegressionModel model = lr.fit(training);
+
+ // Prepare test data.
+ List localTest = Lists.newArrayList(
+ new LabeledPoint(1.0, Vectors.dense(-1.0, 1.5, 1.3)),
+ new LabeledPoint(0.0, Vectors.dense(3.0, 2.0, -0.1)),
+ new LabeledPoint(1.0, Vectors.dense(0.0, 2.2, -1.5)));
+ DataFrame test = jsql.applySchema(jsc.parallelize(localTest), LabeledPoint.class);
+
+ // Make predictions on test documents. cvModel uses the best model found (lrModel).
+ DataFrame results = model.transform(test);
+ double sumPredictions = 0;
+ for (Row r : results.select("features", "label", "prediction").collect()) {
+ sumPredictions += r.getDouble(2);
+ }
+ if (sumPredictions != 0.0) {
+ throw new Exception("MyJavaLogisticRegression predicted something other than 0," +
+ " even though all weights are 0!");
+ }
+
+ jsc.stop();
+ }
+}
+
+/**
+ * Example of defining a type of {@link Classifier}.
+ *
+ * NOTE: This is private since it is an example. In practice, you may not want it to be private.
+ */
+class MyJavaLogisticRegression
+ extends Classifier
+ implements Params {
+
+ /**
+ * Param for max number of iterations
+ *
+ * NOTE: The usual way to add a parameter to a model or algorithm is to include:
+ * - val myParamName: ParamType
+ * - def getMyParamName
+ * - def setMyParamName
+ */
+ IntParam maxIter = new IntParam(this, "maxIter", "max number of iterations");
+
+ int getMaxIter() { return (int)get(maxIter); }
+
+ public MyJavaLogisticRegression() {
+ setMaxIter(100);
+ }
+
+ // The parameter setter is in this class since it should return type MyJavaLogisticRegression.
+ MyJavaLogisticRegression setMaxIter(int value) {
+ return (MyJavaLogisticRegression)set(maxIter, value);
+ }
+
+ // This method is used by fit().
+ // In Java, we have to make it public since Java does not understand Scala's protected modifier.
+ public MyJavaLogisticRegressionModel train(DataFrame dataset, ParamMap paramMap) {
+ // Extract columns from data using helper method.
+ JavaRDD oldDataset = extractLabeledPoints(dataset, paramMap).toJavaRDD();
+
+ // Do learning to estimate the weight vector.
+ int numFeatures = oldDataset.take(1).get(0).features().size();
+ Vector weights = Vectors.zeros(numFeatures); // Learning would happen here.
+
+ // Create a model, and return it.
+ return new MyJavaLogisticRegressionModel(this, paramMap, weights);
+ }
+}
+
+/**
+ * Example of defining a type of {@link ClassificationModel}.
+ *
+ * NOTE: This is private since it is an example. In practice, you may not want it to be private.
+ */
+class MyJavaLogisticRegressionModel
+ extends ClassificationModel implements Params {
+
+ private MyJavaLogisticRegression parent_;
+ public MyJavaLogisticRegression parent() { return parent_; }
+
+ private ParamMap fittingParamMap_;
+ public ParamMap fittingParamMap() { return fittingParamMap_; }
+
+ private Vector weights_;
+ public Vector weights() { return weights_; }
+
+ public MyJavaLogisticRegressionModel(
+ MyJavaLogisticRegression parent_,
+ ParamMap fittingParamMap_,
+ Vector weights_) {
+ this.parent_ = parent_;
+ this.fittingParamMap_ = fittingParamMap_;
+ this.weights_ = weights_;
+ }
+
+ // This uses the default implementation of transform(), which reads column "features" and outputs
+ // columns "prediction" and "rawPrediction."
+
+ // This uses the default implementation of predict(), which chooses the label corresponding to
+ // the maximum value returned by [[predictRaw()]].
+
+ /**
+ * Raw prediction for each possible label.
+ * The meaning of a "raw" prediction may vary between algorithms, but it intuitively gives
+ * a measure of confidence in each possible label (where larger = more confident).
+ * This internal method is used to implement [[transform()]] and output [[rawPredictionCol]].
+ *
+ * @return vector where element i is the raw prediction for label i.
+ * This raw prediction may be any real number, where a larger value indicates greater
+ * confidence for that label.
+ *
+ * In Java, we have to make this method public since Java does not understand Scala's protected
+ * modifier.
+ */
+ public Vector predictRaw(Vector features) {
+ double margin = BLAS.dot(features, weights_);
+ // There are 2 classes (binary classification), so we return a length-2 vector,
+ // where index i corresponds to class i (i = 0, 1).
+ return Vectors.dense(-margin, margin);
+ }
+
+ /**
+ * Number of classes the label can take. 2 indicates binary classification.
+ */
+ public int numClasses() { return 2; }
+
+ /**
+ * Create a copy of the model.
+ * The copy is shallow, except for the embedded paramMap, which gets a deep copy.
+ *
+ * This is used for the defaul implementation of [[transform()]].
+ *
+ * In Java, we have to make this method public since Java does not understand Scala's protected
+ * modifier.
+ */
+ public MyJavaLogisticRegressionModel copy() {
+ MyJavaLogisticRegressionModel m =
+ new MyJavaLogisticRegressionModel(parent_, fittingParamMap_, weights_);
+ Params$.MODULE$.inheritValues(this.paramMap(), this, m);
+ return m;
+ }
+}
diff --git a/examples/src/main/java/org/apache/spark/examples/ml/JavaSimpleParamsExample.java b/examples/src/main/java/org/apache/spark/examples/ml/JavaSimpleParamsExample.java
index eaaa344be49c8..cc69e6315fdda 100644
--- a/examples/src/main/java/org/apache/spark/examples/ml/JavaSimpleParamsExample.java
+++ b/examples/src/main/java/org/apache/spark/examples/ml/JavaSimpleParamsExample.java
@@ -81,7 +81,7 @@ public static void main(String[] args) {
// One can also combine ParamMaps.
ParamMap paramMap2 = new ParamMap();
- paramMap2.put(lr.scoreCol().w("probability")); // Change output column name
+ paramMap2.put(lr.probabilityCol().w("myProbability")); // Change output column name
ParamMap paramMapCombined = paramMap.$plus$plus(paramMap2);
// Now learn a new model using the paramMapCombined parameters.
@@ -98,14 +98,16 @@ public static void main(String[] args) {
// Make predictions on test documents using the Transformer.transform() method.
// LogisticRegression.transform will only use the 'features' column.
- // Note that model2.transform() outputs a 'probability' column instead of the usual 'score'
- // column since we renamed the lr.scoreCol parameter previously.
+ // Note that model2.transform() outputs a 'myProbability' column instead of the usual
+ // 'probability' column since we renamed the lr.probabilityCol parameter previously.
model2.transform(test).registerTempTable("results");
DataFrame results =
- jsql.sql("SELECT features, label, probability, prediction FROM results");
+ jsql.sql("SELECT features, label, myProbability, prediction FROM results");
for (Row r: results.collect()) {
System.out.println("(" + r.get(0) + ", " + r.get(1) + ") -> prob=" + r.get(2)
+ ", prediction=" + r.get(3));
}
+
+ jsc.stop();
}
}
diff --git a/examples/src/main/java/org/apache/spark/examples/ml/JavaSimpleTextClassificationPipeline.java b/examples/src/main/java/org/apache/spark/examples/ml/JavaSimpleTextClassificationPipeline.java
index 82d665a3e1386..d929f1ad2014a 100644
--- a/examples/src/main/java/org/apache/spark/examples/ml/JavaSimpleTextClassificationPipeline.java
+++ b/examples/src/main/java/org/apache/spark/examples/ml/JavaSimpleTextClassificationPipeline.java
@@ -85,8 +85,10 @@ public static void main(String[] args) {
model.transform(test).registerTempTable("prediction");
DataFrame predictions = jsql.sql("SELECT id, text, score, prediction FROM prediction");
for (Row r: predictions.collect()) {
- System.out.println("(" + r.get(0) + ", " + r.get(1) + ") --> score=" + r.get(2)
+ System.out.println("(" + r.get(0) + ", " + r.get(1) + ") --> prob=" + r.get(2)
+ ", prediction=" + r.get(3));
}
+
+ jsc.stop();
}
}
diff --git a/examples/src/main/scala/org/apache/spark/examples/ml/CrossValidatorExample.scala b/examples/src/main/scala/org/apache/spark/examples/ml/CrossValidatorExample.scala
index b6c30a007d88f..a2893f78e0fec 100644
--- a/examples/src/main/scala/org/apache/spark/examples/ml/CrossValidatorExample.scala
+++ b/examples/src/main/scala/org/apache/spark/examples/ml/CrossValidatorExample.scala
@@ -23,6 +23,7 @@ import org.apache.spark.ml.classification.LogisticRegression
import org.apache.spark.ml.evaluation.BinaryClassificationEvaluator
import org.apache.spark.ml.feature.{HashingTF, Tokenizer}
import org.apache.spark.ml.tuning.{ParamGridBuilder, CrossValidator}
+import org.apache.spark.mllib.linalg.Vector
import org.apache.spark.sql.{Row, SQLContext}
/**
@@ -100,10 +101,10 @@ object CrossValidatorExample {
// Make predictions on test documents. cvModel uses the best model found (lrModel).
cvModel.transform(test)
- .select("id", "text", "score", "prediction")
+ .select("id", "text", "probability", "prediction")
.collect()
- .foreach { case Row(id: Long, text: String, score: Double, prediction: Double) =>
- println("(" + id + ", " + text + ") --> score=" + score + ", prediction=" + prediction)
+ .foreach { case Row(id: Long, text: String, prob: Vector, prediction: Double) =>
+ println(s"($id, $text) --> prob=$prob, prediction=$prediction")
}
sc.stop()
diff --git a/examples/src/main/scala/org/apache/spark/examples/ml/DeveloperApiExample.scala b/examples/src/main/scala/org/apache/spark/examples/ml/DeveloperApiExample.scala
new file mode 100644
index 0000000000000..aed44238939c7
--- /dev/null
+++ b/examples/src/main/scala/org/apache/spark/examples/ml/DeveloperApiExample.scala
@@ -0,0 +1,184 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.examples.ml
+
+import org.apache.spark.{SparkConf, SparkContext}
+import org.apache.spark.ml.classification.{Classifier, ClassifierParams, ClassificationModel}
+import org.apache.spark.ml.param.{Params, IntParam, ParamMap}
+import org.apache.spark.mllib.linalg.{BLAS, Vector, Vectors}
+import org.apache.spark.mllib.regression.LabeledPoint
+import org.apache.spark.sql.{DataFrame, Row, SQLContext}
+
+
+/**
+ * A simple example demonstrating how to write your own learning algorithm using Estimator,
+ * Transformer, and other abstractions.
+ * This mimics [[org.apache.spark.ml.classification.LogisticRegression]].
+ * Run with
+ * {{{
+ * bin/run-example ml.DeveloperApiExample
+ * }}}
+ */
+object DeveloperApiExample {
+
+ def main(args: Array[String]) {
+ val conf = new SparkConf().setAppName("DeveloperApiExample")
+ val sc = new SparkContext(conf)
+ val sqlContext = new SQLContext(sc)
+ import sqlContext.implicits._
+
+ // Prepare training data.
+ val training = sc.parallelize(Seq(
+ LabeledPoint(1.0, Vectors.dense(0.0, 1.1, 0.1)),
+ LabeledPoint(0.0, Vectors.dense(2.0, 1.0, -1.0)),
+ LabeledPoint(0.0, Vectors.dense(2.0, 1.3, 1.0)),
+ LabeledPoint(1.0, Vectors.dense(0.0, 1.2, -0.5))))
+
+ // Create a LogisticRegression instance. This instance is an Estimator.
+ val lr = new MyLogisticRegression()
+ // Print out the parameters, documentation, and any default values.
+ println("MyLogisticRegression parameters:\n" + lr.explainParams() + "\n")
+
+ // We may set parameters using setter methods.
+ lr.setMaxIter(10)
+
+ // Learn a LogisticRegression model. This uses the parameters stored in lr.
+ val model = lr.fit(training)
+
+ // Prepare test data.
+ val test = sc.parallelize(Seq(
+ LabeledPoint(1.0, Vectors.dense(-1.0, 1.5, 1.3)),
+ LabeledPoint(0.0, Vectors.dense(3.0, 2.0, -0.1)),
+ LabeledPoint(1.0, Vectors.dense(0.0, 2.2, -1.5))))
+
+ // Make predictions on test data.
+ val sumPredictions: Double = model.transform(test)
+ .select("features", "label", "prediction")
+ .collect()
+ .map { case Row(features: Vector, label: Double, prediction: Double) =>
+ prediction
+ }.sum
+ assert(sumPredictions == 0.0,
+ "MyLogisticRegression predicted something other than 0, even though all weights are 0!")
+
+ sc.stop()
+ }
+}
+
+/**
+ * Example of defining a parameter trait for a user-defined type of [[Classifier]].
+ *
+ * NOTE: This is private since it is an example. In practice, you may not want it to be private.
+ */
+private trait MyLogisticRegressionParams extends ClassifierParams {
+
+ /**
+ * Param for max number of iterations
+ *
+ * NOTE: The usual way to add a parameter to a model or algorithm is to include:
+ * - val myParamName: ParamType
+ * - def getMyParamName
+ * - def setMyParamName
+ * Here, we have a trait to be mixed in with the Estimator and Model (MyLogisticRegression
+ * and MyLogisticRegressionModel). We place the setter (setMaxIter) method in the Estimator
+ * class since the maxIter parameter is only used during training (not in the Model).
+ */
+ val maxIter: IntParam = new IntParam(this, "maxIter", "max number of iterations")
+ def getMaxIter: Int = get(maxIter)
+}
+
+/**
+ * Example of defining a type of [[Classifier]].
+ *
+ * NOTE: This is private since it is an example. In practice, you may not want it to be private.
+ */
+private class MyLogisticRegression
+ extends Classifier[Vector, MyLogisticRegression, MyLogisticRegressionModel]
+ with MyLogisticRegressionParams {
+
+ setMaxIter(100) // Initialize
+
+ // The parameter setter is in this class since it should return type MyLogisticRegression.
+ def setMaxIter(value: Int): this.type = set(maxIter, value)
+
+ // This method is used by fit()
+ override protected def train(
+ dataset: DataFrame,
+ paramMap: ParamMap): MyLogisticRegressionModel = {
+ // Extract columns from data using helper method.
+ val oldDataset = extractLabeledPoints(dataset, paramMap)
+
+ // Do learning to estimate the weight vector.
+ val numFeatures = oldDataset.take(1)(0).features.size
+ val weights = Vectors.zeros(numFeatures) // Learning would happen here.
+
+ // Create a model, and return it.
+ new MyLogisticRegressionModel(this, paramMap, weights)
+ }
+}
+
+/**
+ * Example of defining a type of [[ClassificationModel]].
+ *
+ * NOTE: This is private since it is an example. In practice, you may not want it to be private.
+ */
+private class MyLogisticRegressionModel(
+ override val parent: MyLogisticRegression,
+ override val fittingParamMap: ParamMap,
+ val weights: Vector)
+ extends ClassificationModel[Vector, MyLogisticRegressionModel]
+ with MyLogisticRegressionParams {
+
+ // This uses the default implementation of transform(), which reads column "features" and outputs
+ // columns "prediction" and "rawPrediction."
+
+ // This uses the default implementation of predict(), which chooses the label corresponding to
+ // the maximum value returned by [[predictRaw()]].
+
+ /**
+ * Raw prediction for each possible label.
+ * The meaning of a "raw" prediction may vary between algorithms, but it intuitively gives
+ * a measure of confidence in each possible label (where larger = more confident).
+ * This internal method is used to implement [[transform()]] and output [[rawPredictionCol]].
+ *
+ * @return vector where element i is the raw prediction for label i.
+ * This raw prediction may be any real number, where a larger value indicates greater
+ * confidence for that label.
+ */
+ override protected def predictRaw(features: Vector): Vector = {
+ val margin = BLAS.dot(features, weights)
+ // There are 2 classes (binary classification), so we return a length-2 vector,
+ // where index i corresponds to class i (i = 0, 1).
+ Vectors.dense(-margin, margin)
+ }
+
+ /** Number of classes the label can take. 2 indicates binary classification. */
+ override val numClasses: Int = 2
+
+ /**
+ * Create a copy of the model.
+ * The copy is shallow, except for the embedded paramMap, which gets a deep copy.
+ *
+ * This is used for the defaul implementation of [[transform()]].
+ */
+ override protected def copy(): MyLogisticRegressionModel = {
+ val m = new MyLogisticRegressionModel(parent, fittingParamMap, weights)
+ Params.inheritValues(this.paramMap, this, m)
+ m
+ }
+}
diff --git a/examples/src/main/scala/org/apache/spark/examples/ml/SimpleParamsExample.scala b/examples/src/main/scala/org/apache/spark/examples/ml/SimpleParamsExample.scala
index 4d1530cd1349f..80c9f5ff5781e 100644
--- a/examples/src/main/scala/org/apache/spark/examples/ml/SimpleParamsExample.scala
+++ b/examples/src/main/scala/org/apache/spark/examples/ml/SimpleParamsExample.scala
@@ -72,7 +72,7 @@ object SimpleParamsExample {
paramMap.put(lr.regParam -> 0.1, lr.threshold -> 0.55) // Specify multiple Params.
// One can also combine ParamMaps.
- val paramMap2 = ParamMap(lr.scoreCol -> "probability") // Change output column name
+ val paramMap2 = ParamMap(lr.probabilityCol -> "myProbability") // Change output column name
val paramMapCombined = paramMap ++ paramMap2
// Now learn a new model using the paramMapCombined parameters.
@@ -80,21 +80,21 @@ object SimpleParamsExample {
val model2 = lr.fit(training, paramMapCombined)
println("Model 2 was fit using parameters: " + model2.fittingParamMap)
- // Prepare test documents.
+ // Prepare test data.
val test = sc.parallelize(Seq(
LabeledPoint(1.0, Vectors.dense(-1.0, 1.5, 1.3)),
LabeledPoint(0.0, Vectors.dense(3.0, 2.0, -0.1)),
LabeledPoint(1.0, Vectors.dense(0.0, 2.2, -1.5))))
- // Make predictions on test documents using the Transformer.transform() method.
+ // Make predictions on test data using the Transformer.transform() method.
// LogisticRegression.transform will only use the 'features' column.
- // Note that model2.transform() outputs a 'probability' column instead of the usual 'score'
- // column since we renamed the lr.scoreCol parameter previously.
+ // Note that model2.transform() outputs a 'myProbability' column instead of the usual
+ // 'probability' column since we renamed the lr.probabilityCol parameter previously.
model2.transform(test)
- .select("features", "label", "probability", "prediction")
+ .select("features", "label", "myProbability", "prediction")
.collect()
- .foreach { case Row(features: Vector, label: Double, prob: Double, prediction: Double) =>
- println("(" + features + ", " + label + ") -> prob=" + prob + ", prediction=" + prediction)
+ .foreach { case Row(features: Vector, label: Double, prob: Vector, prediction: Double) =>
+ println("($features, $label) -> prob=$prob, prediction=$prediction")
}
sc.stop()
diff --git a/examples/src/main/scala/org/apache/spark/examples/ml/SimpleTextClassificationPipeline.scala b/examples/src/main/scala/org/apache/spark/examples/ml/SimpleTextClassificationPipeline.scala
index dbbe01dd5ce8e..968cb292120d8 100644
--- a/examples/src/main/scala/org/apache/spark/examples/ml/SimpleTextClassificationPipeline.scala
+++ b/examples/src/main/scala/org/apache/spark/examples/ml/SimpleTextClassificationPipeline.scala
@@ -23,6 +23,7 @@ import org.apache.spark.{SparkConf, SparkContext}
import org.apache.spark.ml.Pipeline
import org.apache.spark.ml.classification.LogisticRegression
import org.apache.spark.ml.feature.{HashingTF, Tokenizer}
+import org.apache.spark.mllib.linalg.Vector
import org.apache.spark.sql.{Row, SQLContext}
@BeanInfo
@@ -79,10 +80,10 @@ object SimpleTextClassificationPipeline {
// Make predictions on test documents.
model.transform(test)
- .select("id", "text", "score", "prediction")
+ .select("id", "text", "probability", "prediction")
.collect()
- .foreach { case Row(id: Long, text: String, score: Double, prediction: Double) =>
- println("(" + id + ", " + text + ") --> score=" + score + ", prediction=" + prediction)
+ .foreach { case Row(id: Long, text: String, prob: Vector, prediction: Double) =>
+ println("($id, $text) --> prob=$prob, prediction=$prediction")
}
sc.stop()
diff --git a/mllib/src/main/scala/org/apache/spark/ml/Estimator.scala b/mllib/src/main/scala/org/apache/spark/ml/Estimator.scala
index bc3defe968afd..eff7ef925dfbd 100644
--- a/mllib/src/main/scala/org/apache/spark/ml/Estimator.scala
+++ b/mllib/src/main/scala/org/apache/spark/ml/Estimator.scala
@@ -34,7 +34,8 @@ abstract class Estimator[M <: Model[M]] extends PipelineStage with Params {
* Fits a single model to the input data with optional parameters.
*
* @param dataset input dataset
- * @param paramPairs optional list of param pairs (overwrite embedded params)
+ * @param paramPairs Optional list of param pairs.
+ * These values override any specified in this Estimator's embedded ParamMap.
* @return fitted model
*/
@varargs
@@ -47,7 +48,8 @@ abstract class Estimator[M <: Model[M]] extends PipelineStage with Params {
* Fits a single model to the input data with provided parameter map.
*
* @param dataset input dataset
- * @param paramMap parameter map
+ * @param paramMap Parameter map.
+ * These values override any specified in this Estimator's embedded ParamMap.
* @return fitted model
*/
def fit(dataset: DataFrame, paramMap: ParamMap): M
@@ -58,7 +60,8 @@ abstract class Estimator[M <: Model[M]] extends PipelineStage with Params {
* Subclasses could overwrite this to optimize multi-model training.
*
* @param dataset input dataset
- * @param paramMaps an array of parameter maps
+ * @param paramMaps An array of parameter maps.
+ * These values override any specified in this Estimator's embedded ParamMap.
* @return fitted models, matching the input parameter maps
*/
def fit(dataset: DataFrame, paramMaps: Array[ParamMap]): Seq[M] = {
diff --git a/mllib/src/main/scala/org/apache/spark/ml/classification/Classifier.scala b/mllib/src/main/scala/org/apache/spark/ml/classification/Classifier.scala
new file mode 100644
index 0000000000000..1bf8eb4640d11
--- /dev/null
+++ b/mllib/src/main/scala/org/apache/spark/ml/classification/Classifier.scala
@@ -0,0 +1,206 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.ml.classification
+
+import org.apache.spark.annotation.{DeveloperApi, AlphaComponent}
+import org.apache.spark.ml.impl.estimator.{PredictionModel, Predictor, PredictorParams}
+import org.apache.spark.ml.param.{Params, ParamMap, HasRawPredictionCol}
+import org.apache.spark.mllib.linalg.{Vector, VectorUDT}
+import org.apache.spark.sql.Dsl._
+import org.apache.spark.sql.DataFrame
+import org.apache.spark.sql.types.{DataType, DoubleType, StructType}
+
+
+/**
+ * :: DeveloperApi ::
+ * Params for classification.
+ *
+ * NOTE: This is currently private[spark] but will be made public later once it is stabilized.
+ */
+@DeveloperApi
+private[spark] trait ClassifierParams extends PredictorParams
+ with HasRawPredictionCol {
+
+ override protected def validateAndTransformSchema(
+ schema: StructType,
+ paramMap: ParamMap,
+ fitting: Boolean,
+ featuresDataType: DataType): StructType = {
+ val parentSchema = super.validateAndTransformSchema(schema, paramMap, fitting, featuresDataType)
+ val map = this.paramMap ++ paramMap
+ addOutputColumn(parentSchema, map(rawPredictionCol), new VectorUDT)
+ }
+}
+
+/**
+ * :: AlphaComponent ::
+ * Single-label binary or multiclass classification.
+ * Classes are indexed {0, 1, ..., numClasses - 1}.
+ *
+ * @tparam FeaturesType Type of input features. E.g., [[Vector]]
+ * @tparam E Concrete Estimator type
+ * @tparam M Concrete Model type
+ *
+ * NOTE: This is currently private[spark] but will be made public later once it is stabilized.
+ */
+@AlphaComponent
+private[spark] abstract class Classifier[
+ FeaturesType,
+ E <: Classifier[FeaturesType, E, M],
+ M <: ClassificationModel[FeaturesType, M]]
+ extends Predictor[FeaturesType, E, M]
+ with ClassifierParams {
+
+ def setRawPredictionCol(value: String): E =
+ set(rawPredictionCol, value).asInstanceOf[E]
+
+ // TODO: defaultEvaluator (follow-up PR)
+}
+
+/**
+ * :: AlphaComponent ::
+ * Model produced by a [[Classifier]].
+ * Classes are indexed {0, 1, ..., numClasses - 1}.
+ *
+ * @tparam FeaturesType Type of input features. E.g., [[Vector]]
+ * @tparam M Concrete Model type
+ *
+ * NOTE: This is currently private[spark] but will be made public later once it is stabilized.
+ */
+@AlphaComponent
+private[spark]
+abstract class ClassificationModel[FeaturesType, M <: ClassificationModel[FeaturesType, M]]
+ extends PredictionModel[FeaturesType, M] with ClassifierParams {
+
+ def setRawPredictionCol(value: String): M = set(rawPredictionCol, value).asInstanceOf[M]
+
+ /** Number of classes (values which the label can take). */
+ def numClasses: Int
+
+ /**
+ * Transforms dataset by reading from [[featuresCol]], and appending new columns as specified by
+ * parameters:
+ * - predicted labels as [[predictionCol]] of type [[Double]]
+ * - raw predictions (confidences) as [[rawPredictionCol]] of type [[Vector]].
+ *
+ * @param dataset input dataset
+ * @param paramMap additional parameters, overwrite embedded params
+ * @return transformed dataset
+ */
+ override def transform(dataset: DataFrame, paramMap: ParamMap): DataFrame = {
+ // This default implementation should be overridden as needed.
+
+ // Check schema
+ transformSchema(dataset.schema, paramMap, logging = true)
+ val map = this.paramMap ++ paramMap
+
+ // Prepare model
+ val tmpModel = if (paramMap.size != 0) {
+ val tmpModel = this.copy()
+ Params.inheritValues(paramMap, parent, tmpModel)
+ tmpModel
+ } else {
+ this
+ }
+
+ val (numColsOutput, outputData) =
+ ClassificationModel.transformColumnsImpl[FeaturesType](dataset, tmpModel, map)
+ if (numColsOutput == 0) {
+ logWarning(s"$uid: ClassificationModel.transform() was called as NOOP" +
+ " since no output columns were set.")
+ }
+ outputData
+ }
+
+ /**
+ * :: DeveloperApi ::
+ *
+ * Predict label for the given features.
+ * This internal method is used to implement [[transform()]] and output [[predictionCol]].
+ *
+ * This default implementation for classification predicts the index of the maximum value
+ * from [[predictRaw()]].
+ */
+ @DeveloperApi
+ override protected def predict(features: FeaturesType): Double = {
+ predictRaw(features).toArray.zipWithIndex.maxBy(_._1)._2
+ }
+
+ /**
+ * :: DeveloperApi ::
+ *
+ * Raw prediction for each possible label.
+ * The meaning of a "raw" prediction may vary between algorithms, but it intuitively gives
+ * a measure of confidence in each possible label (where larger = more confident).
+ * This internal method is used to implement [[transform()]] and output [[rawPredictionCol]].
+ *
+ * @return vector where element i is the raw prediction for label i.
+ * This raw prediction may be any real number, where a larger value indicates greater
+ * confidence for that label.
+ */
+ @DeveloperApi
+ protected def predictRaw(features: FeaturesType): Vector
+
+}
+
+private[ml] object ClassificationModel {
+
+ /**
+ * Added prediction column(s). This is separated from [[ClassificationModel.transform()]]
+ * since it is used by [[org.apache.spark.ml.classification.ProbabilisticClassificationModel]].
+ * @param dataset Input dataset
+ * @param map Parameter map. This will NOT be merged with the embedded paramMap; the merge
+ * should already be done.
+ * @return (number of columns added, transformed dataset)
+ */
+ def transformColumnsImpl[FeaturesType](
+ dataset: DataFrame,
+ model: ClassificationModel[FeaturesType, _],
+ map: ParamMap): (Int, DataFrame) = {
+
+ // Output selected columns only.
+ // This is a bit complicated since it tries to avoid repeated computation.
+ var tmpData = dataset
+ var numColsOutput = 0
+ if (map(model.rawPredictionCol) != "") {
+ // output raw prediction
+ val features2raw: FeaturesType => Vector = model.predictRaw
+ tmpData = tmpData.select($"*",
+ callUDF(features2raw, new VectorUDT,
+ col(map(model.featuresCol))).as(map(model.rawPredictionCol)))
+ numColsOutput += 1
+ if (map(model.predictionCol) != "") {
+ val raw2pred: Vector => Double = (rawPred) => {
+ rawPred.toArray.zipWithIndex.maxBy(_._1)._2
+ }
+ tmpData = tmpData.select($"*", callUDF(raw2pred, DoubleType,
+ col(map(model.rawPredictionCol))).as(map(model.predictionCol)))
+ numColsOutput += 1
+ }
+ } else if (map(model.predictionCol) != "") {
+ // output prediction
+ val features2pred: FeaturesType => Double = model.predict
+ tmpData = tmpData.select($"*",
+ callUDF(features2pred, DoubleType,
+ col(map(model.featuresCol))).as(map(model.predictionCol)))
+ numColsOutput += 1
+ }
+ (numColsOutput, tmpData)
+ }
+
+}
diff --git a/mllib/src/main/scala/org/apache/spark/ml/classification/LogisticRegression.scala b/mllib/src/main/scala/org/apache/spark/ml/classification/LogisticRegression.scala
index b46a5cd8bdf29..c146fe244c66e 100644
--- a/mllib/src/main/scala/org/apache/spark/ml/classification/LogisticRegression.scala
+++ b/mllib/src/main/scala/org/apache/spark/ml/classification/LogisticRegression.scala
@@ -18,61 +18,32 @@
package org.apache.spark.ml.classification
import org.apache.spark.annotation.AlphaComponent
-import org.apache.spark.ml._
import org.apache.spark.ml.param._
import org.apache.spark.mllib.classification.LogisticRegressionWithLBFGS
-import org.apache.spark.mllib.linalg.{BLAS, Vector, VectorUDT}
-import org.apache.spark.mllib.regression.LabeledPoint
-import org.apache.spark.sql._
+import org.apache.spark.mllib.linalg.{VectorUDT, BLAS, Vector, Vectors}
+import org.apache.spark.sql.DataFrame
import org.apache.spark.sql.Dsl._
-import org.apache.spark.sql.types.{DoubleType, StructField, StructType}
+import org.apache.spark.sql.types.DoubleType
import org.apache.spark.storage.StorageLevel
+
/**
- * :: AlphaComponent ::
* Params for logistic regression.
*/
-@AlphaComponent
-private[classification] trait LogisticRegressionParams extends Params
- with HasRegParam with HasMaxIter with HasLabelCol with HasThreshold with HasFeaturesCol
- with HasScoreCol with HasPredictionCol {
+private[classification] trait LogisticRegressionParams extends ProbabilisticClassifierParams
+ with HasRegParam with HasMaxIter with HasThreshold
- /**
- * Validates and transforms the input schema with the provided param map.
- * @param schema input schema
- * @param paramMap additional parameters
- * @param fitting whether this is in fitting
- * @return output schema
- */
- protected def validateAndTransformSchema(
- schema: StructType,
- paramMap: ParamMap,
- fitting: Boolean): StructType = {
- val map = this.paramMap ++ paramMap
- val featuresType = schema(map(featuresCol)).dataType
- // TODO: Support casting Array[Double] and Array[Float] to Vector.
- require(featuresType.isInstanceOf[VectorUDT],
- s"Features column ${map(featuresCol)} must be a vector column but got $featuresType.")
- if (fitting) {
- val labelType = schema(map(labelCol)).dataType
- require(labelType == DoubleType,
- s"Cannot convert label column ${map(labelCol)} of type $labelType to a double column.")
- }
- val fieldNames = schema.fieldNames
- require(!fieldNames.contains(map(scoreCol)), s"Score column ${map(scoreCol)} already exists.")
- require(!fieldNames.contains(map(predictionCol)),
- s"Prediction column ${map(predictionCol)} already exists.")
- val outputFields = schema.fields ++ Seq(
- StructField(map(scoreCol), DoubleType, false),
- StructField(map(predictionCol), DoubleType, false))
- StructType(outputFields)
- }
-}
/**
+ * :: AlphaComponent ::
+ *
* Logistic regression.
+ * Currently, this class only supports binary classification.
*/
-class LogisticRegression extends Estimator[LogisticRegressionModel] with LogisticRegressionParams {
+@AlphaComponent
+class LogisticRegression
+ extends ProbabilisticClassifier[Vector, LogisticRegression, LogisticRegressionModel]
+ with LogisticRegressionParams {
setRegParam(0.1)
setMaxIter(100)
@@ -80,68 +51,151 @@ class LogisticRegression extends Estimator[LogisticRegressionModel] with Logisti
def setRegParam(value: Double): this.type = set(regParam, value)
def setMaxIter(value: Int): this.type = set(maxIter, value)
- def setLabelCol(value: String): this.type = set(labelCol, value)
def setThreshold(value: Double): this.type = set(threshold, value)
- def setFeaturesCol(value: String): this.type = set(featuresCol, value)
- def setScoreCol(value: String): this.type = set(scoreCol, value)
- def setPredictionCol(value: String): this.type = set(predictionCol, value)
- override def fit(dataset: DataFrame, paramMap: ParamMap): LogisticRegressionModel = {
- transformSchema(dataset.schema, paramMap, logging = true)
- val map = this.paramMap ++ paramMap
- val instances = dataset.select(map(labelCol), map(featuresCol))
- .map { case Row(label: Double, features: Vector) =>
- LabeledPoint(label, features)
- }.persist(StorageLevel.MEMORY_AND_DISK)
+ override protected def train(dataset: DataFrame, paramMap: ParamMap): LogisticRegressionModel = {
+ // Extract columns from data. If dataset is persisted, do not persist oldDataset.
+ val oldDataset = extractLabeledPoints(dataset, paramMap)
+ val handlePersistence = dataset.rdd.getStorageLevel == StorageLevel.NONE
+ if (handlePersistence) {
+ oldDataset.persist(StorageLevel.MEMORY_AND_DISK)
+ }
+
+ // Train model
val lr = new LogisticRegressionWithLBFGS
lr.optimizer
- .setRegParam(map(regParam))
- .setNumIterations(map(maxIter))
- val lrm = new LogisticRegressionModel(this, map, lr.run(instances).weights)
- instances.unpersist()
- // copy model params
- Params.inheritValues(map, this, lrm)
- lrm
- }
+ .setRegParam(paramMap(regParam))
+ .setNumIterations(paramMap(maxIter))
+ val oldModel = lr.run(oldDataset)
+ val lrm = new LogisticRegressionModel(this, paramMap, oldModel.weights, oldModel.intercept)
- private[ml] override def transformSchema(schema: StructType, paramMap: ParamMap): StructType = {
- validateAndTransformSchema(schema, paramMap, fitting = true)
+ if (handlePersistence) {
+ oldDataset.unpersist()
+ }
+ lrm
}
}
+
/**
* :: AlphaComponent ::
+ *
* Model produced by [[LogisticRegression]].
*/
@AlphaComponent
class LogisticRegressionModel private[ml] (
override val parent: LogisticRegression,
override val fittingParamMap: ParamMap,
- weights: Vector)
- extends Model[LogisticRegressionModel] with LogisticRegressionParams {
+ val weights: Vector,
+ val intercept: Double)
+ extends ProbabilisticClassificationModel[Vector, LogisticRegressionModel]
+ with LogisticRegressionParams {
+
+ setThreshold(0.5)
def setThreshold(value: Double): this.type = set(threshold, value)
- def setFeaturesCol(value: String): this.type = set(featuresCol, value)
- def setScoreCol(value: String): this.type = set(scoreCol, value)
- def setPredictionCol(value: String): this.type = set(predictionCol, value)
- private[ml] override def transformSchema(schema: StructType, paramMap: ParamMap): StructType = {
- validateAndTransformSchema(schema, paramMap, fitting = false)
+ private val margin: Vector => Double = (features) => {
+ BLAS.dot(features, weights) + intercept
+ }
+
+ private val score: Vector => Double = (features) => {
+ val m = margin(features)
+ 1.0 / (1.0 + math.exp(-m))
}
override def transform(dataset: DataFrame, paramMap: ParamMap): DataFrame = {
+ // This is overridden (a) to be more efficient (avoiding re-computing values when creating
+ // multiple output columns) and (b) to handle threshold, which the abstractions do not use.
+ // TODO: We should abstract away the steps defined by UDFs below so that the abstractions
+ // can call whichever UDFs are needed to create the output columns.
+
+ // Check schema
transformSchema(dataset.schema, paramMap, logging = true)
+
val map = this.paramMap ++ paramMap
- val scoreFunction = udf { v: Vector =>
- val margin = BLAS.dot(v, weights)
- 1.0 / (1.0 + math.exp(-margin))
+
+ // Output selected columns only.
+ // This is a bit complicated since it tries to avoid repeated computation.
+ // rawPrediction (-margin, margin)
+ // probability (1.0-score, score)
+ // prediction (max margin)
+ var tmpData = dataset
+ var numColsOutput = 0
+ if (map(rawPredictionCol) != "") {
+ val features2raw: Vector => Vector = (features) => predictRaw(features)
+ tmpData = tmpData.select($"*",
+ callUDF(features2raw, new VectorUDT, col(map(featuresCol))).as(map(rawPredictionCol)))
+ numColsOutput += 1
+ }
+ if (map(probabilityCol) != "") {
+ if (map(rawPredictionCol) != "") {
+ val raw2prob: Vector => Vector = { (rawPreds: Vector) =>
+ val prob1 = 1.0 / (1.0 + math.exp(-rawPreds(1)))
+ Vectors.dense(1.0 - prob1, prob1)
+ }
+ tmpData = tmpData.select($"*",
+ callUDF(raw2prob, new VectorUDT, col(map(rawPredictionCol))).as(map(probabilityCol)))
+ } else {
+ val features2prob: Vector => Vector = (features: Vector) => predictProbabilities(features)
+ tmpData = tmpData.select($"*",
+ callUDF(features2prob, new VectorUDT, col(map(featuresCol))).as(map(probabilityCol)))
+ }
+ numColsOutput += 1
}
- val t = map(threshold)
- val predictFunction = udf { score: Double =>
- if (score > t) 1.0 else 0.0
+ if (map(predictionCol) != "") {
+ val t = map(threshold)
+ if (map(probabilityCol) != "") {
+ val predict: Vector => Double = { probs: Vector =>
+ if (probs(1) > t) 1.0 else 0.0
+ }
+ tmpData = tmpData.select($"*",
+ callUDF(predict, DoubleType, col(map(probabilityCol))).as(map(predictionCol)))
+ } else if (map(rawPredictionCol) != "") {
+ val predict: Vector => Double = { rawPreds: Vector =>
+ val prob1 = 1.0 / (1.0 + math.exp(-rawPreds(1)))
+ if (prob1 > t) 1.0 else 0.0
+ }
+ tmpData = tmpData.select($"*",
+ callUDF(predict, DoubleType, col(map(rawPredictionCol))).as(map(predictionCol)))
+ } else {
+ val predict: Vector => Double = (features: Vector) => this.predict(features)
+ tmpData = tmpData.select($"*",
+ callUDF(predict, DoubleType, col(map(featuresCol))).as(map(predictionCol)))
+ }
+ numColsOutput += 1
}
- dataset
- .select($"*", scoreFunction(col(map(featuresCol))).as(map(scoreCol)))
- .select($"*", predictFunction(col(map(scoreCol))).as(map(predictionCol)))
+ if (numColsOutput == 0) {
+ this.logWarning(s"$uid: LogisticRegressionModel.transform() was called as NOOP" +
+ " since no output columns were set.")
+ }
+ tmpData
+ }
+
+ override val numClasses: Int = 2
+
+ /**
+ * Predict label for the given feature vector.
+ * The behavior of this can be adjusted using [[threshold]].
+ */
+ override protected def predict(features: Vector): Double = {
+ println(s"LR.predict with threshold: ${paramMap(threshold)}")
+ if (score(features) > paramMap(threshold)) 1 else 0
+ }
+
+ override protected def predictProbabilities(features: Vector): Vector = {
+ val s = score(features)
+ Vectors.dense(1.0 - s, s)
+ }
+
+ override protected def predictRaw(features: Vector): Vector = {
+ val m = margin(features)
+ Vectors.dense(0.0, m)
+ }
+
+ override protected def copy(): LogisticRegressionModel = {
+ val m = new LogisticRegressionModel(parent, fittingParamMap, weights, intercept)
+ Params.inheritValues(this.paramMap, this, m)
+ m
}
}
diff --git a/mllib/src/main/scala/org/apache/spark/ml/classification/ProbabilisticClassifier.scala b/mllib/src/main/scala/org/apache/spark/ml/classification/ProbabilisticClassifier.scala
new file mode 100644
index 0000000000000..1202528ca654e
--- /dev/null
+++ b/mllib/src/main/scala/org/apache/spark/ml/classification/ProbabilisticClassifier.scala
@@ -0,0 +1,147 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.ml.classification
+
+import org.apache.spark.annotation.{AlphaComponent, DeveloperApi}
+import org.apache.spark.ml.param.{HasProbabilityCol, ParamMap, Params}
+import org.apache.spark.mllib.linalg.{Vector, VectorUDT}
+import org.apache.spark.sql.DataFrame
+import org.apache.spark.sql.Dsl._
+import org.apache.spark.sql.types.{DataType, StructType}
+
+
+/**
+ * Params for probabilistic classification.
+ */
+private[classification] trait ProbabilisticClassifierParams
+ extends ClassifierParams with HasProbabilityCol {
+
+ override protected def validateAndTransformSchema(
+ schema: StructType,
+ paramMap: ParamMap,
+ fitting: Boolean,
+ featuresDataType: DataType): StructType = {
+ val parentSchema = super.validateAndTransformSchema(schema, paramMap, fitting, featuresDataType)
+ val map = this.paramMap ++ paramMap
+ addOutputColumn(parentSchema, map(probabilityCol), new VectorUDT)
+ }
+}
+
+
+/**
+ * :: AlphaComponent ::
+ *
+ * Single-label binary or multiclass classifier which can output class conditional probabilities.
+ *
+ * @tparam FeaturesType Type of input features. E.g., [[Vector]]
+ * @tparam E Concrete Estimator type
+ * @tparam M Concrete Model type
+ *
+ * NOTE: This is currently private[spark] but will be made public later once it is stabilized.
+ */
+@AlphaComponent
+private[spark] abstract class ProbabilisticClassifier[
+ FeaturesType,
+ E <: ProbabilisticClassifier[FeaturesType, E, M],
+ M <: ProbabilisticClassificationModel[FeaturesType, M]]
+ extends Classifier[FeaturesType, E, M] with ProbabilisticClassifierParams {
+
+ def setProbabilityCol(value: String): E = set(probabilityCol, value).asInstanceOf[E]
+}
+
+
+/**
+ * :: AlphaComponent ::
+ *
+ * Model produced by a [[ProbabilisticClassifier]].
+ * Classes are indexed {0, 1, ..., numClasses - 1}.
+ *
+ * @tparam FeaturesType Type of input features. E.g., [[Vector]]
+ * @tparam M Concrete Model type
+ *
+ * NOTE: This is currently private[spark] but will be made public later once it is stabilized.
+ */
+@AlphaComponent
+private[spark] abstract class ProbabilisticClassificationModel[
+ FeaturesType,
+ M <: ProbabilisticClassificationModel[FeaturesType, M]]
+ extends ClassificationModel[FeaturesType, M] with ProbabilisticClassifierParams {
+
+ def setProbabilityCol(value: String): M = set(probabilityCol, value).asInstanceOf[M]
+
+ /**
+ * Transforms dataset by reading from [[featuresCol]], and appending new columns as specified by
+ * parameters:
+ * - predicted labels as [[predictionCol]] of type [[Double]]
+ * - raw predictions (confidences) as [[rawPredictionCol]] of type [[Vector]]
+ * - probability of each class as [[probabilityCol]] of type [[Vector]].
+ *
+ * @param dataset input dataset
+ * @param paramMap additional parameters, overwrite embedded params
+ * @return transformed dataset
+ */
+ override def transform(dataset: DataFrame, paramMap: ParamMap): DataFrame = {
+ // This default implementation should be overridden as needed.
+
+ // Check schema
+ transformSchema(dataset.schema, paramMap, logging = true)
+ val map = this.paramMap ++ paramMap
+
+ // Prepare model
+ val tmpModel = if (paramMap.size != 0) {
+ val tmpModel = this.copy()
+ Params.inheritValues(paramMap, parent, tmpModel)
+ tmpModel
+ } else {
+ this
+ }
+
+ val (numColsOutput, outputData) =
+ ClassificationModel.transformColumnsImpl[FeaturesType](dataset, tmpModel, map)
+
+ // Output selected columns only.
+ if (map(probabilityCol) != "") {
+ // output probabilities
+ val features2probs: FeaturesType => Vector = (features) => {
+ tmpModel.predictProbabilities(features)
+ }
+ outputData.select($"*",
+ callUDF(features2probs, new VectorUDT, col(map(featuresCol))).as(map(probabilityCol)))
+ } else {
+ if (numColsOutput == 0) {
+ this.logWarning(s"$uid: ProbabilisticClassificationModel.transform() was called as NOOP" +
+ " since no output columns were set.")
+ }
+ outputData
+ }
+ }
+
+ /**
+ * :: DeveloperApi ::
+ *
+ * Predict the probability of each class given the features.
+ * These predictions are also called class conditional probabilities.
+ *
+ * WARNING: Not all models output well-calibrated probability estimates! These probabilities
+ * should be treated as confidences, not precise probabilities.
+ *
+ * This internal method is used to implement [[transform()]] and output [[probabilityCol]].
+ */
+ @DeveloperApi
+ protected def predictProbabilities(features: FeaturesType): Vector
+}
diff --git a/mllib/src/main/scala/org/apache/spark/ml/evaluation/BinaryClassificationEvaluator.scala b/mllib/src/main/scala/org/apache/spark/ml/evaluation/BinaryClassificationEvaluator.scala
index 1979ab9eb6516..f21a30627e540 100644
--- a/mllib/src/main/scala/org/apache/spark/ml/evaluation/BinaryClassificationEvaluator.scala
+++ b/mllib/src/main/scala/org/apache/spark/ml/evaluation/BinaryClassificationEvaluator.scala
@@ -18,19 +18,22 @@
package org.apache.spark.ml.evaluation
import org.apache.spark.annotation.AlphaComponent
-import org.apache.spark.ml._
+import org.apache.spark.ml.Evaluator
import org.apache.spark.ml.param._
import org.apache.spark.mllib.evaluation.BinaryClassificationMetrics
+import org.apache.spark.mllib.linalg.{Vector, VectorUDT}
import org.apache.spark.sql.{DataFrame, Row}
import org.apache.spark.sql.types.DoubleType
+
/**
* :: AlphaComponent ::
+ *
* Evaluator for binary classification, which expects two input columns: score and label.
*/
@AlphaComponent
class BinaryClassificationEvaluator extends Evaluator with Params
- with HasScoreCol with HasLabelCol {
+ with HasRawPredictionCol with HasLabelCol {
/** param for metric name in evaluation */
val metricName: Param[String] = new Param(this, "metricName",
@@ -38,23 +41,20 @@ class BinaryClassificationEvaluator extends Evaluator with Params
def getMetricName: String = get(metricName)
def setMetricName(value: String): this.type = set(metricName, value)
- def setScoreCol(value: String): this.type = set(scoreCol, value)
+ def setScoreCol(value: String): this.type = set(rawPredictionCol, value)
def setLabelCol(value: String): this.type = set(labelCol, value)
override def evaluate(dataset: DataFrame, paramMap: ParamMap): Double = {
val map = this.paramMap ++ paramMap
val schema = dataset.schema
- val scoreType = schema(map(scoreCol)).dataType
- require(scoreType == DoubleType,
- s"Score column ${map(scoreCol)} must be double type but found $scoreType")
- val labelType = schema(map(labelCol)).dataType
- require(labelType == DoubleType,
- s"Label column ${map(labelCol)} must be double type but found $labelType")
+ checkInputColumn(schema, map(rawPredictionCol), new VectorUDT)
+ checkInputColumn(schema, map(labelCol), DoubleType)
- val scoreAndLabels = dataset.select(map(scoreCol), map(labelCol))
- .map { case Row(score: Double, label: Double) =>
- (score, label)
+ // TODO: When dataset metadata has been implemented, check rawPredictionCol vector length = 2.
+ val scoreAndLabels = dataset.select(map(rawPredictionCol), map(labelCol))
+ .map { case Row(rawPrediction: Vector, label: Double) =>
+ (rawPrediction(1), label)
}
val metrics = new BinaryClassificationMetrics(scoreAndLabels)
val metric = map(metricName) match {
diff --git a/mllib/src/main/scala/org/apache/spark/ml/feature/Tokenizer.scala b/mllib/src/main/scala/org/apache/spark/ml/feature/Tokenizer.scala
index e622a5cf9e6f3..0b1f90daa7d8e 100644
--- a/mllib/src/main/scala/org/apache/spark/ml/feature/Tokenizer.scala
+++ b/mllib/src/main/scala/org/apache/spark/ml/feature/Tokenizer.scala
@@ -29,11 +29,11 @@ import org.apache.spark.sql.types.{DataType, StringType, ArrayType}
@AlphaComponent
class Tokenizer extends UnaryTransformer[String, Seq[String], Tokenizer] {
- protected override def createTransformFunc(paramMap: ParamMap): String => Seq[String] = {
+ override protected def createTransformFunc(paramMap: ParamMap): String => Seq[String] = {
_.toLowerCase.split("\\s")
}
- protected override def validateInputType(inputType: DataType): Unit = {
+ override protected def validateInputType(inputType: DataType): Unit = {
require(inputType == StringType, s"Input type must be string type but got $inputType.")
}
diff --git a/mllib/src/main/scala/org/apache/spark/ml/impl/estimator/Predictor.scala b/mllib/src/main/scala/org/apache/spark/ml/impl/estimator/Predictor.scala
new file mode 100644
index 0000000000000..89b53f3890ea3
--- /dev/null
+++ b/mllib/src/main/scala/org/apache/spark/ml/impl/estimator/Predictor.scala
@@ -0,0 +1,234 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.ml.impl.estimator
+
+import org.apache.spark.annotation.{AlphaComponent, DeveloperApi}
+import org.apache.spark.ml.{Estimator, Model}
+import org.apache.spark.ml.param._
+import org.apache.spark.mllib.linalg.{VectorUDT, Vector}
+import org.apache.spark.mllib.regression.LabeledPoint
+import org.apache.spark.rdd.RDD
+import org.apache.spark.sql.{DataFrame, Row}
+import org.apache.spark.sql.Dsl._
+import org.apache.spark.sql.types.{DataType, DoubleType, StructType}
+
+
+/**
+ * :: DeveloperApi ::
+ *
+ * Trait for parameters for prediction (regression and classification).
+ *
+ * NOTE: This is currently private[spark] but will be made public later once it is stabilized.
+ */
+@DeveloperApi
+private[spark] trait PredictorParams extends Params
+ with HasLabelCol with HasFeaturesCol with HasPredictionCol {
+
+ /**
+ * Validates and transforms the input schema with the provided param map.
+ * @param schema input schema
+ * @param paramMap additional parameters
+ * @param fitting whether this is in fitting
+ * @param featuresDataType SQL DataType for FeaturesType.
+ * E.g., [[org.apache.spark.mllib.linalg.VectorUDT]] for vector features.
+ * @return output schema
+ */
+ protected def validateAndTransformSchema(
+ schema: StructType,
+ paramMap: ParamMap,
+ fitting: Boolean,
+ featuresDataType: DataType): StructType = {
+ val map = this.paramMap ++ paramMap
+ // TODO: Support casting Array[Double] and Array[Float] to Vector when FeaturesType = Vector
+ checkInputColumn(schema, map(featuresCol), featuresDataType)
+ if (fitting) {
+ // TODO: Allow other numeric types
+ checkInputColumn(schema, map(labelCol), DoubleType)
+ }
+ addOutputColumn(schema, map(predictionCol), DoubleType)
+ }
+}
+
+/**
+ * :: AlphaComponent ::
+ *
+ * Abstraction for prediction problems (regression and classification).
+ *
+ * @tparam FeaturesType Type of features.
+ * E.g., [[org.apache.spark.mllib.linalg.VectorUDT]] for vector features.
+ * @tparam Learner Specialization of this class. If you subclass this type, use this type
+ * parameter to specify the concrete type.
+ * @tparam M Specialization of [[PredictionModel]]. If you subclass this type, use this type
+ * parameter to specify the concrete type for the corresponding model.
+ *
+ * NOTE: This is currently private[spark] but will be made public later once it is stabilized.
+ */
+@AlphaComponent
+private[spark] abstract class Predictor[
+ FeaturesType,
+ Learner <: Predictor[FeaturesType, Learner, M],
+ M <: PredictionModel[FeaturesType, M]]
+ extends Estimator[M] with PredictorParams {
+
+ def setLabelCol(value: String): Learner = set(labelCol, value).asInstanceOf[Learner]
+ def setFeaturesCol(value: String): Learner = set(featuresCol, value).asInstanceOf[Learner]
+ def setPredictionCol(value: String): Learner = set(predictionCol, value).asInstanceOf[Learner]
+
+ override def fit(dataset: DataFrame, paramMap: ParamMap): M = {
+ // This handles a few items such as schema validation.
+ // Developers only need to implement train().
+ transformSchema(dataset.schema, paramMap, logging = true)
+ val map = this.paramMap ++ paramMap
+ val model = train(dataset, map)
+ Params.inheritValues(map, this, model) // copy params to model
+ model
+ }
+
+ /**
+ * :: DeveloperApi ::
+ *
+ * Train a model using the given dataset and parameters.
+ * Developers can implement this instead of [[fit()]] to avoid dealing with schema validation
+ * and copying parameters into the model.
+ *
+ * @param dataset Training dataset
+ * @param paramMap Parameter map. Unlike [[fit()]]'s paramMap, this paramMap has already
+ * been combined with the embedded ParamMap.
+ * @return Fitted model
+ */
+ @DeveloperApi
+ protected def train(dataset: DataFrame, paramMap: ParamMap): M
+
+ /**
+ * :: DeveloperApi ::
+ *
+ * Returns the SQL DataType corresponding to the FeaturesType type parameter.
+ *
+ * This is used by [[validateAndTransformSchema()]].
+ * This workaround is needed since SQL has different APIs for Scala and Java.
+ *
+ * The default value is VectorUDT, but it may be overridden if FeaturesType is not Vector.
+ */
+ @DeveloperApi
+ protected def featuresDataType: DataType = new VectorUDT
+
+ private[ml] override def transformSchema(schema: StructType, paramMap: ParamMap): StructType = {
+ validateAndTransformSchema(schema, paramMap, fitting = true, featuresDataType)
+ }
+
+ /**
+ * Extract [[labelCol]] and [[featuresCol]] from the given dataset,
+ * and put it in an RDD with strong types.
+ */
+ protected def extractLabeledPoints(dataset: DataFrame, paramMap: ParamMap): RDD[LabeledPoint] = {
+ val map = this.paramMap ++ paramMap
+ dataset.select(map(labelCol), map(featuresCol))
+ .map { case Row(label: Double, features: Vector) =>
+ LabeledPoint(label, features)
+ }
+ }
+}
+
+/**
+ * :: AlphaComponent ::
+ *
+ * Abstraction for a model for prediction tasks (regression and classification).
+ *
+ * @tparam FeaturesType Type of features.
+ * E.g., [[org.apache.spark.mllib.linalg.VectorUDT]] for vector features.
+ * @tparam M Specialization of [[PredictionModel]]. If you subclass this type, use this type
+ * parameter to specify the concrete type for the corresponding model.
+ *
+ * NOTE: This is currently private[spark] but will be made public later once it is stabilized.
+ */
+@AlphaComponent
+private[spark] abstract class PredictionModel[FeaturesType, M <: PredictionModel[FeaturesType, M]]
+ extends Model[M] with PredictorParams {
+
+ def setFeaturesCol(value: String): M = set(featuresCol, value).asInstanceOf[M]
+
+ def setPredictionCol(value: String): M = set(predictionCol, value).asInstanceOf[M]
+
+ /**
+ * :: DeveloperApi ::
+ *
+ * Returns the SQL DataType corresponding to the FeaturesType type parameter.
+ *
+ * This is used by [[validateAndTransformSchema()]].
+ * This workaround is needed since SQL has different APIs for Scala and Java.
+ *
+ * The default value is VectorUDT, but it may be overridden if FeaturesType is not Vector.
+ */
+ @DeveloperApi
+ protected def featuresDataType: DataType = new VectorUDT
+
+ private[ml] override def transformSchema(schema: StructType, paramMap: ParamMap): StructType = {
+ validateAndTransformSchema(schema, paramMap, fitting = false, featuresDataType)
+ }
+
+ /**
+ * Transforms dataset by reading from [[featuresCol]], calling [[predict()]], and storing
+ * the predictions as a new column [[predictionCol]].
+ *
+ * @param dataset input dataset
+ * @param paramMap additional parameters, overwrite embedded params
+ * @return transformed dataset with [[predictionCol]] of type [[Double]]
+ */
+ override def transform(dataset: DataFrame, paramMap: ParamMap): DataFrame = {
+ // This default implementation should be overridden as needed.
+
+ // Check schema
+ transformSchema(dataset.schema, paramMap, logging = true)
+ val map = this.paramMap ++ paramMap
+
+ // Prepare model
+ val tmpModel = if (paramMap.size != 0) {
+ val tmpModel = this.copy()
+ Params.inheritValues(paramMap, parent, tmpModel)
+ tmpModel
+ } else {
+ this
+ }
+
+ if (map(predictionCol) != "") {
+ val pred: FeaturesType => Double = (features) => {
+ tmpModel.predict(features)
+ }
+ dataset.select($"*", callUDF(pred, DoubleType, col(map(featuresCol))).as(map(predictionCol)))
+ } else {
+ this.logWarning(s"$uid: Predictor.transform() was called as NOOP" +
+ " since no output columns were set.")
+ dataset
+ }
+ }
+
+ /**
+ * :: DeveloperApi ::
+ *
+ * Predict label for the given features.
+ * This internal method is used to implement [[transform()]] and output [[predictionCol]].
+ */
+ @DeveloperApi
+ protected def predict(features: FeaturesType): Double
+
+ /**
+ * Create a copy of the model.
+ * The copy is shallow, except for the embedded paramMap, which gets a deep copy.
+ */
+ protected def copy(): M
+}
diff --git a/mllib/src/main/scala/org/apache/spark/ml/param/params.scala b/mllib/src/main/scala/org/apache/spark/ml/param/params.scala
index 5fb4379e23c2f..17ece897a6c55 100644
--- a/mllib/src/main/scala/org/apache/spark/ml/param/params.scala
+++ b/mllib/src/main/scala/org/apache/spark/ml/param/params.scala
@@ -22,8 +22,10 @@ import scala.collection.mutable
import java.lang.reflect.Modifier
-import org.apache.spark.annotation.AlphaComponent
+import org.apache.spark.annotation.{AlphaComponent, DeveloperApi}
import org.apache.spark.ml.Identifiable
+import org.apache.spark.sql.types.{DataType, StructField, StructType}
+
/**
* :: AlphaComponent ::
@@ -65,37 +67,47 @@ class Param[T] (
// specialize primitive-typed params because Java doesn't recognize scala.Double, scala.Int, ...
/** Specialized version of [[Param[Double]]] for Java. */
-class DoubleParam(parent: Params, name: String, doc: String, defaultValue: Option[Double] = None)
+class DoubleParam(parent: Params, name: String, doc: String, defaultValue: Option[Double])
extends Param[Double](parent, name, doc, defaultValue) {
+ def this(parent: Params, name: String, doc: String) = this(parent, name, doc, None)
+
override def w(value: Double): ParamPair[Double] = super.w(value)
}
/** Specialized version of [[Param[Int]]] for Java. */
-class IntParam(parent: Params, name: String, doc: String, defaultValue: Option[Int] = None)
+class IntParam(parent: Params, name: String, doc: String, defaultValue: Option[Int])
extends Param[Int](parent, name, doc, defaultValue) {
+ def this(parent: Params, name: String, doc: String) = this(parent, name, doc, None)
+
override def w(value: Int): ParamPair[Int] = super.w(value)
}
/** Specialized version of [[Param[Float]]] for Java. */
-class FloatParam(parent: Params, name: String, doc: String, defaultValue: Option[Float] = None)
+class FloatParam(parent: Params, name: String, doc: String, defaultValue: Option[Float])
extends Param[Float](parent, name, doc, defaultValue) {
+ def this(parent: Params, name: String, doc: String) = this(parent, name, doc, None)
+
override def w(value: Float): ParamPair[Float] = super.w(value)
}
/** Specialized version of [[Param[Long]]] for Java. */
-class LongParam(parent: Params, name: String, doc: String, defaultValue: Option[Long] = None)
+class LongParam(parent: Params, name: String, doc: String, defaultValue: Option[Long])
extends Param[Long](parent, name, doc, defaultValue) {
+ def this(parent: Params, name: String, doc: String) = this(parent, name, doc, None)
+
override def w(value: Long): ParamPair[Long] = super.w(value)
}
/** Specialized version of [[Param[Boolean]]] for Java. */
-class BooleanParam(parent: Params, name: String, doc: String, defaultValue: Option[Boolean] = None)
+class BooleanParam(parent: Params, name: String, doc: String, defaultValue: Option[Boolean])
extends Param[Boolean](parent, name, doc, defaultValue) {
+ def this(parent: Params, name: String, doc: String) = this(parent, name, doc, None)
+
override def w(value: Boolean): ParamPair[Boolean] = super.w(value)
}
@@ -158,7 +170,7 @@ trait Params extends Identifiable with Serializable {
/**
* Sets a parameter in the embedded param map.
*/
- private[ml] def set[T](param: Param[T], value: T): this.type = {
+ protected def set[T](param: Param[T], value: T): this.type = {
require(param.parent.eq(this))
paramMap.put(param.asInstanceOf[Param[Any]], value)
this
@@ -174,7 +186,7 @@ trait Params extends Identifiable with Serializable {
/**
* Gets the value of a parameter in the embedded param map.
*/
- private[ml] def get[T](param: Param[T]): T = {
+ protected def get[T](param: Param[T]): T = {
require(param.parent.eq(this))
paramMap(param)
}
@@ -183,9 +195,40 @@ trait Params extends Identifiable with Serializable {
* Internal param map.
*/
protected val paramMap: ParamMap = ParamMap.empty
+
+ /**
+ * Check whether the given schema contains an input column.
+ * @param colName Parameter name for the input column.
+ * @param dataType SQL DataType of the input column.
+ */
+ protected def checkInputColumn(schema: StructType, colName: String, dataType: DataType): Unit = {
+ val actualDataType = schema(colName).dataType
+ require(actualDataType.equals(dataType),
+ s"Input column $colName must be of type $dataType" +
+ s" but was actually $actualDataType. Column param description: ${getParam(colName)}")
+ }
+
+ protected def addOutputColumn(
+ schema: StructType,
+ colName: String,
+ dataType: DataType): StructType = {
+ if (colName.length == 0) return schema
+ val fieldNames = schema.fieldNames
+ require(!fieldNames.contains(colName), s"Prediction column $colName already exists.")
+ val outputFields = schema.fields ++ Seq(StructField(colName, dataType, nullable = false))
+ StructType(outputFields)
+ }
}
-private[ml] object Params {
+/**
+ * :: DeveloperApi ::
+ *
+ * Helper functionality for developers.
+ *
+ * NOTE: This is currently private[spark] but will be made public later once it is stabilized.
+ */
+@DeveloperApi
+private[spark] object Params {
/**
* Copies parameter values from the parent estimator to the child model it produced.
@@ -279,7 +322,7 @@ class ParamMap private[ml] (private val map: mutable.Map[Param[Any], Any]) exten
def copy: ParamMap = new ParamMap(map.clone())
override def toString: String = {
- map.map { case (param, value) =>
+ map.toSeq.sortBy(_._1.name).map { case (param, value) =>
s"\t${param.parent.uid}-${param.name}: $value"
}.mkString("{\n", ",\n", "\n}")
}
@@ -310,6 +353,11 @@ class ParamMap private[ml] (private val map: mutable.Map[Param[Any], Any]) exten
ParamPair(param, value)
}
}
+
+ /**
+ * Number of param pairs in this set.
+ */
+ def size: Int = map.size
}
object ParamMap {
diff --git a/mllib/src/main/scala/org/apache/spark/ml/param/sharedParams.scala b/mllib/src/main/scala/org/apache/spark/ml/param/sharedParams.scala
index ef141d3eb2b06..32fc74462ef4a 100644
--- a/mllib/src/main/scala/org/apache/spark/ml/param/sharedParams.scala
+++ b/mllib/src/main/scala/org/apache/spark/ml/param/sharedParams.scala
@@ -17,6 +17,12 @@
package org.apache.spark.ml.param
+/* NOTE TO DEVELOPERS:
+ * If you mix these parameter traits into your algorithm, please add a setter method as well
+ * so that users may use a builder pattern:
+ * val myLearner = new MyLearner().setParam1(x).setParam2(y)...
+ */
+
private[ml] trait HasRegParam extends Params {
/** param for regularization parameter */
val regParam: DoubleParam = new DoubleParam(this, "regParam", "regularization parameter")
@@ -42,12 +48,6 @@ private[ml] trait HasLabelCol extends Params {
def getLabelCol: String = get(labelCol)
}
-private[ml] trait HasScoreCol extends Params {
- /** param for score column name */
- val scoreCol: Param[String] = new Param(this, "scoreCol", "score column name", Some("score"))
- def getScoreCol: String = get(scoreCol)
-}
-
private[ml] trait HasPredictionCol extends Params {
/** param for prediction column name */
val predictionCol: Param[String] =
@@ -55,6 +55,22 @@ private[ml] trait HasPredictionCol extends Params {
def getPredictionCol: String = get(predictionCol)
}
+private[ml] trait HasRawPredictionCol extends Params {
+ /** param for raw prediction column name */
+ val rawPredictionCol: Param[String] =
+ new Param(this, "rawPredictionCol", "raw prediction (a.k.a. confidence) column name",
+ Some("rawPrediction"))
+ def getRawPredictionCol: String = get(rawPredictionCol)
+}
+
+private[ml] trait HasProbabilityCol extends Params {
+ /** param for predicted class conditional probabilities column name */
+ val probabilityCol: Param[String] =
+ new Param(this, "probabilityCol", "column name for predicted class conditional probabilities",
+ Some("probability"))
+ def getProbabilityCol: String = get(probabilityCol)
+}
+
private[ml] trait HasThreshold extends Params {
/** param for threshold in (binary) prediction */
val threshold: DoubleParam = new DoubleParam(this, "threshold", "threshold in prediction")
diff --git a/mllib/src/main/scala/org/apache/spark/ml/regression/LinearRegression.scala b/mllib/src/main/scala/org/apache/spark/ml/regression/LinearRegression.scala
new file mode 100644
index 0000000000000..d5a7bdafcb623
--- /dev/null
+++ b/mllib/src/main/scala/org/apache/spark/ml/regression/LinearRegression.scala
@@ -0,0 +1,96 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.ml.regression
+
+import org.apache.spark.annotation.AlphaComponent
+import org.apache.spark.ml.param.{Params, ParamMap, HasMaxIter, HasRegParam}
+import org.apache.spark.mllib.linalg.{BLAS, Vector}
+import org.apache.spark.mllib.regression.LinearRegressionWithSGD
+import org.apache.spark.sql.DataFrame
+import org.apache.spark.storage.StorageLevel
+
+
+/**
+ * Params for linear regression.
+ */
+private[regression] trait LinearRegressionParams extends RegressorParams
+ with HasRegParam with HasMaxIter
+
+
+/**
+ * :: AlphaComponent ::
+ *
+ * Linear regression.
+ */
+@AlphaComponent
+class LinearRegression extends Regressor[Vector, LinearRegression, LinearRegressionModel]
+ with LinearRegressionParams {
+
+ setRegParam(0.1)
+ setMaxIter(100)
+
+ def setRegParam(value: Double): this.type = set(regParam, value)
+ def setMaxIter(value: Int): this.type = set(maxIter, value)
+
+ override protected def train(dataset: DataFrame, paramMap: ParamMap): LinearRegressionModel = {
+ // Extract columns from data. If dataset is persisted, do not persist oldDataset.
+ val oldDataset = extractLabeledPoints(dataset, paramMap)
+ val handlePersistence = dataset.rdd.getStorageLevel == StorageLevel.NONE
+ if (handlePersistence) {
+ oldDataset.persist(StorageLevel.MEMORY_AND_DISK)
+ }
+
+ // Train model
+ val lr = new LinearRegressionWithSGD()
+ lr.optimizer
+ .setRegParam(paramMap(regParam))
+ .setNumIterations(paramMap(maxIter))
+ val model = lr.run(oldDataset)
+ val lrm = new LinearRegressionModel(this, paramMap, model.weights, model.intercept)
+
+ if (handlePersistence) {
+ oldDataset.unpersist()
+ }
+ lrm
+ }
+}
+
+/**
+ * :: AlphaComponent ::
+ *
+ * Model produced by [[LinearRegression]].
+ */
+@AlphaComponent
+class LinearRegressionModel private[ml] (
+ override val parent: LinearRegression,
+ override val fittingParamMap: ParamMap,
+ val weights: Vector,
+ val intercept: Double)
+ extends RegressionModel[Vector, LinearRegressionModel]
+ with LinearRegressionParams {
+
+ override protected def predict(features: Vector): Double = {
+ BLAS.dot(features, weights) + intercept
+ }
+
+ override protected def copy(): LinearRegressionModel = {
+ val m = new LinearRegressionModel(parent, fittingParamMap, weights, intercept)
+ Params.inheritValues(this.paramMap, this, m)
+ m
+ }
+}
diff --git a/mllib/src/main/scala/org/apache/spark/ml/regression/Regressor.scala b/mllib/src/main/scala/org/apache/spark/ml/regression/Regressor.scala
new file mode 100644
index 0000000000000..d679085eeafe1
--- /dev/null
+++ b/mllib/src/main/scala/org/apache/spark/ml/regression/Regressor.scala
@@ -0,0 +1,78 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.ml.regression
+
+import org.apache.spark.annotation.{DeveloperApi, AlphaComponent}
+import org.apache.spark.ml.impl.estimator.{PredictionModel, Predictor, PredictorParams}
+
+/**
+ * :: DeveloperApi ::
+ * Params for regression.
+ * Currently empty, but may add functionality later.
+ *
+ * NOTE: This is currently private[spark] but will be made public later once it is stabilized.
+ */
+@DeveloperApi
+private[spark] trait RegressorParams extends PredictorParams
+
+/**
+ * :: AlphaComponent ::
+ *
+ * Single-label regression
+ *
+ * @tparam FeaturesType Type of input features. E.g., [[org.apache.spark.mllib.linalg.Vector]]
+ * @tparam Learner Concrete Estimator type
+ * @tparam M Concrete Model type
+ *
+ * NOTE: This is currently private[spark] but will be made public later once it is stabilized.
+ */
+@AlphaComponent
+private[spark] abstract class Regressor[
+ FeaturesType,
+ Learner <: Regressor[FeaturesType, Learner, M],
+ M <: RegressionModel[FeaturesType, M]]
+ extends Predictor[FeaturesType, Learner, M]
+ with RegressorParams {
+
+ // TODO: defaultEvaluator (follow-up PR)
+}
+
+/**
+ * :: AlphaComponent ::
+ *
+ * Model produced by a [[Regressor]].
+ *
+ * @tparam FeaturesType Type of input features. E.g., [[org.apache.spark.mllib.linalg.Vector]]
+ * @tparam M Concrete Model type.
+ *
+ * NOTE: This is currently private[spark] but will be made public later once it is stabilized.
+ */
+@AlphaComponent
+private[spark] abstract class RegressionModel[FeaturesType, M <: RegressionModel[FeaturesType, M]]
+ extends PredictionModel[FeaturesType, M] with RegressorParams {
+
+ /**
+ * :: DeveloperApi ::
+ *
+ * Predict real-valued label for the given features.
+ * This internal method is used to implement [[transform()]] and output [[predictionCol]].
+ */
+ @DeveloperApi
+ protected def predict(features: FeaturesType): Double
+
+}
diff --git a/mllib/src/main/scala/org/apache/spark/mllib/linalg/Vectors.scala b/mllib/src/main/scala/org/apache/spark/mllib/linalg/Vectors.scala
index 77785bdbd03d9..480bbfb5fe94a 100644
--- a/mllib/src/main/scala/org/apache/spark/mllib/linalg/Vectors.scala
+++ b/mllib/src/main/scala/org/apache/spark/mllib/linalg/Vectors.scala
@@ -26,6 +26,7 @@ import scala.collection.JavaConverters._
import breeze.linalg.{DenseVector => BDV, SparseVector => BSV, Vector => BV}
import org.apache.spark.SparkException
+import org.apache.spark.annotation.DeveloperApi
import org.apache.spark.mllib.util.NumericParser
import org.apache.spark.sql.Row
import org.apache.spark.sql.catalyst.expressions.GenericMutableRow
@@ -110,9 +111,14 @@ sealed trait Vector extends Serializable {
}
/**
+ * :: DeveloperApi ::
+ *
* User-defined type for [[Vector]] which allows easy interaction with SQL
* via [[org.apache.spark.sql.DataFrame]].
+ *
+ * NOTE: This is currently private[spark] but will be made public later once it is stabilized.
*/
+@DeveloperApi
private[spark] class VectorUDT extends UserDefinedType[Vector] {
override def sqlType: StructType = {
@@ -169,6 +175,13 @@ private[spark] class VectorUDT extends UserDefinedType[Vector] {
override def pyUDT: String = "pyspark.mllib.linalg.VectorUDT"
override def userClass: Class[Vector] = classOf[Vector]
+
+ override def equals(o: Any): Boolean = {
+ o match {
+ case v: VectorUDT => true
+ case _ => false
+ }
+ }
}
/**
diff --git a/mllib/src/test/java/org/apache/spark/ml/JavaPipelineSuite.java b/mllib/src/test/java/org/apache/spark/ml/JavaPipelineSuite.java
index 56a9dbdd58b64..50995ffef9ad5 100644
--- a/mllib/src/test/java/org/apache/spark/ml/JavaPipelineSuite.java
+++ b/mllib/src/test/java/org/apache/spark/ml/JavaPipelineSuite.java
@@ -65,7 +65,7 @@ public void pipeline() {
.setStages(new PipelineStage[] {scaler, lr});
PipelineModel model = pipeline.fit(dataset);
model.transform(dataset).registerTempTable("prediction");
- DataFrame predictions = jsql.sql("SELECT label, score, prediction FROM prediction");
+ DataFrame predictions = jsql.sql("SELECT label, probability, prediction FROM prediction");
predictions.collectAsList();
}
}
diff --git a/mllib/src/test/java/org/apache/spark/ml/classification/JavaLogisticRegressionSuite.java b/mllib/src/test/java/org/apache/spark/ml/classification/JavaLogisticRegressionSuite.java
index f4ba23c44563e..26284023b0f69 100644
--- a/mllib/src/test/java/org/apache/spark/ml/classification/JavaLogisticRegressionSuite.java
+++ b/mllib/src/test/java/org/apache/spark/ml/classification/JavaLogisticRegressionSuite.java
@@ -18,17 +18,22 @@
package org.apache.spark.ml.classification;
import java.io.Serializable;
+import java.lang.Math;
import java.util.List;
import org.junit.After;
import org.junit.Before;
import org.junit.Test;
+import org.apache.spark.api.java.JavaRDD;
import org.apache.spark.api.java.JavaSparkContext;
+import static org.apache.spark.mllib.classification.LogisticRegressionSuite.generateLogisticInputAsList;
+import org.apache.spark.mllib.linalg.Vector;
import org.apache.spark.mllib.regression.LabeledPoint;
import org.apache.spark.sql.DataFrame;
import org.apache.spark.sql.SQLContext;
-import static org.apache.spark.mllib.classification.LogisticRegressionSuite.generateLogisticInputAsList;
+import org.apache.spark.sql.Row;
+
public class JavaLogisticRegressionSuite implements Serializable {
@@ -36,12 +41,17 @@ public class JavaLogisticRegressionSuite implements Serializable {
private transient SQLContext jsql;
private transient DataFrame dataset;
+ private transient JavaRDD datasetRDD;
+ private double eps = 1e-5;
+
@Before
public void setUp() {
jsc = new JavaSparkContext("local", "JavaLogisticRegressionSuite");
jsql = new SQLContext(jsc);
List points = generateLogisticInputAsList(1.0, 1.0, 100, 42);
- dataset = jsql.applySchema(jsc.parallelize(points, 2), LabeledPoint.class);
+ datasetRDD = jsc.parallelize(points, 2);
+ dataset = jsql.applySchema(datasetRDD, LabeledPoint.class);
+ dataset.registerTempTable("dataset");
}
@After
@@ -51,29 +61,88 @@ public void tearDown() {
}
@Test
- public void logisticRegression() {
+ public void logisticRegressionDefaultParams() {
LogisticRegression lr = new LogisticRegression();
+ assert(lr.getLabelCol().equals("label"));
LogisticRegressionModel model = lr.fit(dataset);
model.transform(dataset).registerTempTable("prediction");
- DataFrame predictions = jsql.sql("SELECT label, score, prediction FROM prediction");
+ DataFrame predictions = jsql.sql("SELECT label, probability, prediction FROM prediction");
predictions.collectAsList();
+ // Check defaults
+ assert(model.getThreshold() == 0.5);
+ assert(model.getFeaturesCol().equals("features"));
+ assert(model.getPredictionCol().equals("prediction"));
+ assert(model.getProbabilityCol().equals("probability"));
}
@Test
public void logisticRegressionWithSetters() {
+ // Set params, train, and check as many params as we can.
LogisticRegression lr = new LogisticRegression()
.setMaxIter(10)
- .setRegParam(1.0);
+ .setRegParam(1.0)
+ .setThreshold(0.6)
+ .setProbabilityCol("myProbability");
LogisticRegressionModel model = lr.fit(dataset);
- model.transform(dataset, model.threshold().w(0.8)) // overwrite threshold
- .registerTempTable("prediction");
- DataFrame predictions = jsql.sql("SELECT label, score, prediction FROM prediction");
- predictions.collectAsList();
+ assert(model.fittingParamMap().apply(lr.maxIter()) == 10);
+ assert(model.fittingParamMap().apply(lr.regParam()).equals(1.0));
+ assert(model.fittingParamMap().apply(lr.threshold()).equals(0.6));
+ assert(model.getThreshold() == 0.6);
+
+ // Modify model params, and check that the params worked.
+ model.setThreshold(1.0);
+ model.transform(dataset).registerTempTable("predAllZero");
+ DataFrame predAllZero = jsql.sql("SELECT prediction, myProbability FROM predAllZero");
+ for (Row r: predAllZero.collectAsList()) {
+ assert(r.getDouble(0) == 0.0);
+ }
+ // Call transform with params, and check that the params worked.
+ model.transform(dataset, model.threshold().w(0.0), model.probabilityCol().w("myProb"))
+ .registerTempTable("predNotAllZero");
+ DataFrame predNotAllZero = jsql.sql("SELECT prediction, myProb FROM predNotAllZero");
+ boolean foundNonZero = false;
+ for (Row r: predNotAllZero.collectAsList()) {
+ if (r.getDouble(0) != 0.0) foundNonZero = true;
+ }
+ assert(foundNonZero);
+
+ // Call fit() with new params, and check as many params as we can.
+ LogisticRegressionModel model2 = lr.fit(dataset, lr.maxIter().w(5), lr.regParam().w(0.1),
+ lr.threshold().w(0.4), lr.probabilityCol().w("theProb"));
+ assert(model2.fittingParamMap().apply(lr.maxIter()) == 5);
+ assert(model2.fittingParamMap().apply(lr.regParam()).equals(0.1));
+ assert(model2.fittingParamMap().apply(lr.threshold()).equals(0.4));
+ assert(model2.getThreshold() == 0.4);
+ assert(model2.getProbabilityCol().equals("theProb"));
}
+ @SuppressWarnings("unchecked")
@Test
- public void logisticRegressionFitWithVarargs() {
+ public void logisticRegressionPredictorClassifierMethods() {
LogisticRegression lr = new LogisticRegression();
- lr.fit(dataset, lr.maxIter().w(10), lr.regParam().w(1.0));
+ LogisticRegressionModel model = lr.fit(dataset);
+ assert(model.numClasses() == 2);
+
+ model.transform(dataset).registerTempTable("transformed");
+ DataFrame trans1 = jsql.sql("SELECT rawPrediction, probability FROM transformed");
+ for (Row row: trans1.collect()) {
+ Vector raw = (Vector)row.get(0);
+ Vector prob = (Vector)row.get(1);
+ assert(raw.size() == 2);
+ assert(prob.size() == 2);
+ double probFromRaw1 = 1.0 / (1.0 + Math.exp(-raw.apply(1)));
+ assert(Math.abs(prob.apply(1) - probFromRaw1) < eps);
+ assert(Math.abs(prob.apply(0) - (1.0 - probFromRaw1)) < eps);
+ }
+
+ DataFrame trans2 = jsql.sql("SELECT prediction, probability FROM transformed");
+ for (Row row: trans2.collect()) {
+ double pred = row.getDouble(0);
+ Vector prob = (Vector)row.get(1);
+ double probOfPred = prob.apply((int)pred);
+ for (int i = 0; i < prob.size(); ++i) {
+ assert(probOfPred >= prob.apply(i));
+ }
+ }
}
}
diff --git a/mllib/src/test/java/org/apache/spark/ml/regression/JavaLinearRegressionSuite.java b/mllib/src/test/java/org/apache/spark/ml/regression/JavaLinearRegressionSuite.java
new file mode 100644
index 0000000000000..5bd616e74d86c
--- /dev/null
+++ b/mllib/src/test/java/org/apache/spark/ml/regression/JavaLinearRegressionSuite.java
@@ -0,0 +1,89 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.ml.regression;
+
+import java.io.Serializable;
+import java.util.List;
+
+import org.junit.After;
+import org.junit.Before;
+import org.junit.Test;
+
+import org.apache.spark.api.java.JavaRDD;
+import org.apache.spark.api.java.JavaSparkContext;
+import static org.apache.spark.mllib.classification.LogisticRegressionSuite
+ .generateLogisticInputAsList;
+import org.apache.spark.mllib.regression.LabeledPoint;
+import org.apache.spark.sql.DataFrame;
+import org.apache.spark.sql.SQLContext;
+
+
+public class JavaLinearRegressionSuite implements Serializable {
+
+ private transient JavaSparkContext jsc;
+ private transient SQLContext jsql;
+ private transient DataFrame dataset;
+ private transient JavaRDD datasetRDD;
+
+ @Before
+ public void setUp() {
+ jsc = new JavaSparkContext("local", "JavaLinearRegressionSuite");
+ jsql = new SQLContext(jsc);
+ List points = generateLogisticInputAsList(1.0, 1.0, 100, 42);
+ datasetRDD = jsc.parallelize(points, 2);
+ dataset = jsql.applySchema(datasetRDD, LabeledPoint.class);
+ dataset.registerTempTable("dataset");
+ }
+
+ @After
+ public void tearDown() {
+ jsc.stop();
+ jsc = null;
+ }
+
+ @Test
+ public void linearRegressionDefaultParams() {
+ LinearRegression lr = new LinearRegression();
+ assert(lr.getLabelCol().equals("label"));
+ LinearRegressionModel model = lr.fit(dataset);
+ model.transform(dataset).registerTempTable("prediction");
+ DataFrame predictions = jsql.sql("SELECT label, prediction FROM prediction");
+ predictions.collect();
+ // Check defaults
+ assert(model.getFeaturesCol().equals("features"));
+ assert(model.getPredictionCol().equals("prediction"));
+ }
+
+ @Test
+ public void linearRegressionWithSetters() {
+ // Set params, train, and check as many params as we can.
+ LinearRegression lr = new LinearRegression()
+ .setMaxIter(10)
+ .setRegParam(1.0);
+ LinearRegressionModel model = lr.fit(dataset);
+ assert(model.fittingParamMap().apply(lr.maxIter()) == 10);
+ assert(model.fittingParamMap().apply(lr.regParam()).equals(1.0));
+
+ // Call fit() with new params, and check as many params as we can.
+ LinearRegressionModel model2 =
+ lr.fit(dataset, lr.maxIter().w(5), lr.regParam().w(0.1), lr.predictionCol().w("thePred"));
+ assert(model2.fittingParamMap().apply(lr.maxIter()) == 5);
+ assert(model2.fittingParamMap().apply(lr.regParam()).equals(0.1));
+ assert(model2.getPredictionCol().equals("thePred"));
+ }
+}
diff --git a/mllib/src/test/scala/org/apache/spark/ml/classification/LogisticRegressionSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/classification/LogisticRegressionSuite.scala
index 33e40dc7410cc..b3d1bfcfbee0f 100644
--- a/mllib/src/test/scala/org/apache/spark/ml/classification/LogisticRegressionSuite.scala
+++ b/mllib/src/test/scala/org/apache/spark/ml/classification/LogisticRegressionSuite.scala
@@ -20,44 +20,108 @@ package org.apache.spark.ml.classification
import org.scalatest.FunSuite
import org.apache.spark.mllib.classification.LogisticRegressionSuite.generateLogisticInput
+import org.apache.spark.mllib.linalg.Vector
import org.apache.spark.mllib.util.MLlibTestSparkContext
-import org.apache.spark.sql.{SQLContext, DataFrame}
+import org.apache.spark.mllib.util.TestingUtils._
+import org.apache.spark.sql.{DataFrame, Row, SQLContext}
+
class LogisticRegressionSuite extends FunSuite with MLlibTestSparkContext {
@transient var sqlContext: SQLContext = _
@transient var dataset: DataFrame = _
+ private val eps: Double = 1e-5
override def beforeAll(): Unit = {
super.beforeAll()
sqlContext = new SQLContext(sc)
dataset = sqlContext.createDataFrame(
- sc.parallelize(generateLogisticInput(1.0, 1.0, 100, 42), 2))
+ sc.parallelize(generateLogisticInput(1.0, 1.0, nPoints = 100, seed = 42), 2))
}
- test("logistic regression") {
+ test("logistic regression: default params") {
val lr = new LogisticRegression
+ assert(lr.getLabelCol == "label")
+ assert(lr.getFeaturesCol == "features")
+ assert(lr.getPredictionCol == "prediction")
+ assert(lr.getRawPredictionCol == "rawPrediction")
+ assert(lr.getProbabilityCol == "probability")
val model = lr.fit(dataset)
model.transform(dataset)
- .select("label", "prediction")
+ .select("label", "probability", "prediction", "rawPrediction")
.collect()
+ assert(model.getThreshold === 0.5)
+ assert(model.getFeaturesCol == "features")
+ assert(model.getPredictionCol == "prediction")
+ assert(model.getRawPredictionCol == "rawPrediction")
+ assert(model.getProbabilityCol == "probability")
}
test("logistic regression with setters") {
+ // Set params, train, and check as many params as we can.
val lr = new LogisticRegression()
.setMaxIter(10)
.setRegParam(1.0)
+ .setThreshold(0.6)
+ .setProbabilityCol("myProbability")
val model = lr.fit(dataset)
- model.transform(dataset, model.threshold -> 0.8) // overwrite threshold
- .select("label", "score", "prediction")
+ assert(model.fittingParamMap.get(lr.maxIter) === Some(10))
+ assert(model.fittingParamMap.get(lr.regParam) === Some(1.0))
+ assert(model.fittingParamMap.get(lr.threshold) === Some(0.6))
+ assert(model.getThreshold === 0.6)
+
+ // Modify model params, and check that the params worked.
+ model.setThreshold(1.0)
+ val predAllZero = model.transform(dataset)
+ .select("prediction", "myProbability")
.collect()
+ .map { case Row(pred: Double, prob: Vector) => pred }
+ assert(predAllZero.forall(_ === 0),
+ s"With threshold=1.0, expected predictions to be all 0, but only" +
+ s" ${predAllZero.count(_ === 0)} of ${dataset.count()} were 0.")
+ // Call transform with params, and check that the params worked.
+ val predNotAllZero =
+ model.transform(dataset, model.threshold -> 0.0, model.probabilityCol -> "myProb")
+ .select("prediction", "myProb")
+ .collect()
+ .map { case Row(pred: Double, prob: Vector) => pred }
+ assert(predNotAllZero.exists(_ !== 0.0))
+
+ // Call fit() with new params, and check as many params as we can.
+ val model2 = lr.fit(dataset, lr.maxIter -> 5, lr.regParam -> 0.1, lr.threshold -> 0.4,
+ lr.probabilityCol -> "theProb")
+ assert(model2.fittingParamMap.get(lr.maxIter).get === 5)
+ assert(model2.fittingParamMap.get(lr.regParam).get === 0.1)
+ assert(model2.fittingParamMap.get(lr.threshold).get === 0.4)
+ assert(model2.getThreshold === 0.4)
+ assert(model2.getProbabilityCol == "theProb")
}
- test("logistic regression fit and transform with varargs") {
+ test("logistic regression: Predictor, Classifier methods") {
+ val sqlContext = this.sqlContext
val lr = new LogisticRegression
- val model = lr.fit(dataset, lr.maxIter -> 10, lr.regParam -> 1.0)
- model.transform(dataset, model.threshold -> 0.8, model.scoreCol -> "probability")
- .select("label", "probability", "prediction")
- .collect()
+
+ val model = lr.fit(dataset)
+ assert(model.numClasses === 2)
+
+ val threshold = model.getThreshold
+ val results = model.transform(dataset)
+
+ // Compare rawPrediction with probability
+ results.select("rawPrediction", "probability").collect().map {
+ case Row(raw: Vector, prob: Vector) =>
+ assert(raw.size === 2)
+ assert(prob.size === 2)
+ val probFromRaw1 = 1.0 / (1.0 + math.exp(-raw(1)))
+ assert(prob(1) ~== probFromRaw1 relTol eps)
+ assert(prob(0) ~== 1.0 - probFromRaw1 relTol eps)
+ }
+
+ // Compare prediction with probability
+ results.select("prediction", "probability").collect().map {
+ case Row(pred: Double, prob: Vector) =>
+ val predFromProb = prob.toArray.zipWithIndex.maxBy(_._1)._2
+ assert(pred == predFromProb)
+ }
}
}
diff --git a/mllib/src/test/scala/org/apache/spark/ml/regression/LinearRegressionSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/regression/LinearRegressionSuite.scala
new file mode 100644
index 0000000000000..bbb44c3e2dfc2
--- /dev/null
+++ b/mllib/src/test/scala/org/apache/spark/ml/regression/LinearRegressionSuite.scala
@@ -0,0 +1,65 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.ml.regression
+
+import org.scalatest.FunSuite
+
+import org.apache.spark.mllib.classification.LogisticRegressionSuite.generateLogisticInput
+import org.apache.spark.mllib.util.MLlibTestSparkContext
+import org.apache.spark.sql.{DataFrame, SQLContext}
+
+class LinearRegressionSuite extends FunSuite with MLlibTestSparkContext {
+
+ @transient var sqlContext: SQLContext = _
+ @transient var dataset: DataFrame = _
+
+ override def beforeAll(): Unit = {
+ super.beforeAll()
+ sqlContext = new SQLContext(sc)
+ dataset = sqlContext.createDataFrame(
+ sc.parallelize(generateLogisticInput(1.0, 1.0, nPoints = 100, seed = 42), 2))
+ }
+
+ test("linear regression: default params") {
+ val lr = new LinearRegression
+ assert(lr.getLabelCol == "label")
+ val model = lr.fit(dataset)
+ model.transform(dataset)
+ .select("label", "prediction")
+ .collect()
+ // Check defaults
+ assert(model.getFeaturesCol == "features")
+ assert(model.getPredictionCol == "prediction")
+ }
+
+ test("linear regression with setters") {
+ // Set params, train, and check as many as we can.
+ val lr = new LinearRegression()
+ .setMaxIter(10)
+ .setRegParam(1.0)
+ val model = lr.fit(dataset)
+ assert(model.fittingParamMap.get(lr.maxIter).get === 10)
+ assert(model.fittingParamMap.get(lr.regParam).get === 1.0)
+
+ // Call fit() with new params, and check as many as we can.
+ val model2 = lr.fit(dataset, lr.maxIter -> 5, lr.regParam -> 0.1, lr.predictionCol -> "thePred")
+ assert(model2.fittingParamMap.get(lr.maxIter).get === 5)
+ assert(model2.fittingParamMap.get(lr.regParam).get === 0.1)
+ assert(model2.getPredictionCol == "thePred")
+ }
+}
diff --git a/project/MimaExcludes.scala b/project/MimaExcludes.scala
index b17532c1d814c..4065a562a1a18 100644
--- a/project/MimaExcludes.scala
+++ b/project/MimaExcludes.scala
@@ -36,6 +36,7 @@ object MimaExcludes {
case v if v.startsWith("1.3") =>
Seq(
MimaBuild.excludeSparkPackage("deploy"),
+ MimaBuild.excludeSparkPackage("ml"),
// These are needed if checking against the sbt build, since they are part of
// the maven-generated artifacts in the 1.2 build.
MimaBuild.excludeSparkPackage("unused"),
@@ -142,6 +143,11 @@ object MimaExcludes {
"org.apache.spark.graphx.Graph.getCheckpointFiles"),
ProblemFilters.exclude[MissingMethodProblem](
"org.apache.spark.graphx.Graph.isCheckpointed")
+ ) ++ Seq(
+ // SPARK-4789 Standardize ML Prediction APIs
+ ProblemFilters.exclude[MissingTypesProblem]("org.apache.spark.mllib.linalg.VectorUDT"),
+ ProblemFilters.exclude[IncompatibleResultTypeProblem]("org.apache.spark.mllib.linalg.VectorUDT.serialize"),
+ ProblemFilters.exclude[IncompatibleResultTypeProblem]("org.apache.spark.mllib.linalg.VectorUDT.sqlType")
)
case v if v.startsWith("1.2") =>