Updates based on code review. Major ones are:
* Created a weakly typed Predictor.train() method, which fit() calls, so that developers do not have to handle schema validation or copy parameters themselves.
* Made Predictor.featuresDataType have a default value of VectorUDT.
  * NOTE: This could be dangerous since the FeaturesType type parameter cannot have a default value.
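A minimal sketch of the resulting contract (simplified from the Predictor.scala diff below; SchemaRDD, ParamMap, transformSchema, and Params.inheritValues are the existing helpers it builds on):

    // fit() owns the boilerplate; developers override only train().
    def fit(dataset: SchemaRDD, paramMap: ParamMap): M = {
      transformSchema(dataset.schema, paramMap, logging = true) // fail early on bad schemas
      val map = this.paramMap ++ paramMap    // given params override embedded ones
      val model = train(dataset, map)        // developer-implemented hook
      Params.inheritValues(map, this, model) // copy params to the model
      model
    }
    protected def train(dataset: SchemaRDD, paramMap: ParamMap): M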
jkbradley committed Feb 5, 2015
1 parent 343e7bd commit f549e34
Showing 8 changed files with 65 additions and 72 deletions.
examples/src/main/scala/org/apache/spark/examples/ml/CrossValidatorExample.scala
@@ -104,7 +104,7 @@ object CrossValidatorExample {
.select('id, 'text, 'probability, 'prediction)
.collect()
.foreach { case Row(id: Long, text: String, prob: Vector, prediction: Double) =>
println("(" + id + ", " + text + ") --> prob=" + prob + ", prediction=" + prediction)
println(s"($id, $text) --> prob=$prob, prediction=$prediction")
}

sc.stop()
examples/src/main/scala/org/apache/spark/examples/ml/DeveloperApiExample.scala
@@ -21,9 +21,9 @@ import org.apache.spark.{SparkConf, SparkContext}
import org.apache.spark.SparkContext._
import org.apache.spark.ml.classification.{Classifier, ClassifierParams, ClassificationModel}
import org.apache.spark.ml.param.{Params, IntParam, ParamMap}
import org.apache.spark.mllib.linalg.{BLAS, Vector, Vectors, VectorUDT}
import org.apache.spark.mllib.linalg.{BLAS, Vector, Vectors}
import org.apache.spark.mllib.regression.LabeledPoint
import org.apache.spark.sql.{DataType, SchemaRDD, Row, SQLContext}
import org.apache.spark.sql.{SchemaRDD, Row, SQLContext}

/**
* A simple example demonstrating how to write your own learning algorithm using Estimator,
@@ -85,7 +85,14 @@ object DeveloperApiExample {
*/
private trait MyLogisticRegressionParams extends ClassifierParams {

/** param for max number of iterations */
/**
* Param for max number of iterations
*
* NOTE: The usual way to add a parameter to a model or algorithm is to include:
* - val myParamName: ParamType
* - def getMyParamName
* - def setMyParamName
*/
val maxIter: IntParam = new IntParam(this, "maxIter", "max number of iterations")
def getMaxIter: Int = get(maxIter)
}
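For illustration, the full convention with a hypothetical parameter (myThreshold, MyParams, and MyLearner are made up; in this file, maxIter plays that role, with the setter defined in the concrete class so that it can return MyLogisticRegression):

    private trait MyParams extends Params {
      // val: the Param object itself
      val myThreshold: DoubleParam = new DoubleParam(this, "myThreshold", "decision threshold")
      // getter: reads the current value
      def getMyThreshold: Double = get(myThreshold)
    }

    private class MyLearner extends MyParams {
      // setter: returns this.type so that calls can be chained
      def setMyThreshold(value: Double): this.type = set(myThreshold, value)
    }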
@@ -101,40 +108,23 @@ private class MyLogisticRegression

setMaxIter(100) // Initialize

// The parameter setter is in this class since it should return type MyLogisticRegression.
def setMaxIter(value: Int): this.type = set(maxIter, value)

override def fit(dataset: SchemaRDD, paramMap: ParamMap): MyLogisticRegressionModel = {
// Check schema (types). This allows early failure before running the algorithm.
transformSchema(dataset.schema, paramMap, logging = true)

// This method is used by fit()
override protected def train(
dataset: SchemaRDD,
paramMap: ParamMap): MyLogisticRegressionModel = {
// Extract columns from data using helper method.
val oldDataset = extractLabeledPoints(dataset, paramMap)

// Combine given parameters with the embedded parameters, where the given paramMap overrides
// any embedded settings.
val map = this.paramMap ++ paramMap

// Do learning to estimate the weight vector.
val numFeatures = oldDataset.take(1)(0).features.size
val weights = Vectors.zeros(numFeatures) // Learning would happen here.

// Create a model to return.
val lrm = new MyLogisticRegressionModel(this, map, weights)

// Copy model params.
// An Estimator stores the parameters for the Model it produces, and this copies any relevant
// parameters to the model.
Params.inheritValues(map, this, lrm)

// Return the learned model.
lrm
// Create a model, and return it.
new MyLogisticRegressionModel(this, paramMap, weights)
}

/**
* Returns the SQL DataType corresponding to the FeaturesType type parameter.
* This is used by [[ClassifierParams.validateAndTransformSchema()]] to check the input data.
*/
override protected def featuresDataType: DataType = new VectorUDT
}

/**
@@ -186,10 +176,4 @@ private class MyLogisticRegressionModel(
Params.inheritValues(this.paramMap, this, m)
m
}

/**
* Returns the SQL DataType corresponding to the FeaturesType type parameter.
* This is used by [[ClassifierParams.validateAndTransformSchema()]] to check the input data.
*/
override protected def featuresDataType: DataType = new VectorUDT
}
examples/src/main/scala/org/apache/spark/examples/ml/SimpleParamsExample.scala
@@ -94,7 +94,7 @@ object SimpleParamsExample {
.select('features, 'label, 'myProbability, 'prediction)
.collect()
.foreach { case Row(features: Vector, label: Double, prob: Vector, prediction: Double) =>
println("(" + features + ", " + label + ") -> prob=" + prob + ", prediction=" + prediction)
println("($features, $label) -> prob=$prob, prediction=$prediction")
}

sc.stop()
examples/src/main/scala/org/apache/spark/examples/ml/SimpleTextClassificationPipeline.scala
@@ -83,7 +83,7 @@ object SimpleTextClassificationPipeline {
.select('id, 'text, 'probability, 'prediction)
.collect()
.foreach { case Row(id: Long, text: String, prob: Vector, prediction: Double) =>
println("(" + id + ", " + text + ") --> prob=" + prob + ", prediction=" + prediction)
println("($id, $text) --> prob=$prob, prediction=$prediction")
}

sc.stop()
mllib/src/main/scala/org/apache/spark/ml/classification/LogisticRegression.scala
@@ -20,7 +20,7 @@ package org.apache.spark.ml.classification
import org.apache.spark.annotation.AlphaComponent
import org.apache.spark.ml.param._
import org.apache.spark.mllib.classification.LogisticRegressionWithLBFGS
import org.apache.spark.mllib.linalg.{BLAS, Vector, VectorUDT, Vectors}
import org.apache.spark.mllib.linalg.{BLAS, Vector, Vectors}
import org.apache.spark.sql._
import org.apache.spark.sql.Dsl._
import org.apache.spark.sql.types.{DoubleType, StructField, StructType}
@@ -52,13 +52,9 @@ class LogisticRegression
def setMaxIter(value: Int): this.type = set(maxIter, value)
def setThreshold(value: Double): this.type = set(threshold, value)

override def fit(dataset: SchemaRDD, paramMap: ParamMap): LogisticRegressionModel = {
// Check schema
transformSchema(dataset.schema, paramMap, logging = true)

override protected def train(dataset: SchemaRDD, paramMap: ParamMap): LogisticRegressionModel = {
// Extract columns from data. If dataset is persisted, do not persist oldDataset.
val oldDataset = extractLabeledPoints(dataset, paramMap)
val map = this.paramMap ++ paramMap
val handlePersistence = dataset.getStorageLevel == StorageLevel.NONE
if (handlePersistence) {
oldDataset.persist(StorageLevel.MEMORY_AND_DISK)
@@ -67,21 +63,16 @@ class LogisticRegression
// Train model
val lr = new LogisticRegressionWithLBFGS
lr.optimizer
.setRegParam(map(regParam))
.setNumIterations(map(maxIter))
.setRegParam(paramMap(regParam))
.setNumIterations(paramMap(maxIter))
val oldModel = lr.run(oldDataset)
val lrm = new LogisticRegressionModel(this, map, oldModel.weights, oldModel.intercept)
val lrm = new LogisticRegressionModel(this, paramMap, oldModel.weights, oldModel.intercept)

if (handlePersistence) {
oldDataset.unpersist()
}

// copy model params
Params.inheritValues(map, this, lrm)
lrm
}

override protected def featuresDataType: DataType = new VectorUDT
}


@@ -215,6 +206,4 @@ class LogisticRegressionModel private[ml] (
Params.inheritValues(this.paramMap, this, m)
m
}

override protected def featuresDataType: DataType = new VectorUDT
}
mllib/src/main/scala/org/apache/spark/ml/impl/estimator/Predictor.scala
@@ -20,7 +20,7 @@ package org.apache.spark.ml.impl.estimator
import org.apache.spark.annotation.{AlphaComponent, DeveloperApi}
import org.apache.spark.ml.{Estimator, Model}
import org.apache.spark.ml.param._
import org.apache.spark.mllib.linalg.Vector
import org.apache.spark.mllib.linalg.{VectorUDT, Vector}
import org.apache.spark.mllib.regression.LabeledPoint
import org.apache.spark.rdd.RDD
import org.apache.spark.sql._
@@ -84,16 +84,43 @@ abstract class Predictor[
def setFeaturesCol(value: String): Learner = set(featuresCol, value).asInstanceOf[Learner]
def setPredictionCol(value: String): Learner = set(predictionCol, value).asInstanceOf[Learner]

override def fit(dataset: SchemaRDD, paramMap: ParamMap): M = {
// This handles a few items such as schema validation.
// Developers only need to implement train().
transformSchema(dataset.schema, paramMap, logging = true)
val map = this.paramMap ++ paramMap
val model = train(dataset, map)
Params.inheritValues(map, this, model) // copy params to model
model
}

/**
* :: DeveloperApi ::
*
* Train a model using the given dataset and parameters.
* Developers can implement this instead of [[fit()]] to avoid dealing with schema validation
* and copying parameters into the model.
*
* @param dataset Training dataset
* @param paramMap Parameter map. Unlike [[fit()]]'s paramMap, this paramMap has already
* been combined with the embedded ParamMap.
* @return Fitted model
*/
@DeveloperApi
protected def train(dataset: SchemaRDD, paramMap: ParamMap): M

/**
* :: DeveloperApi ::
*
* Returns the SQL DataType corresponding to the FeaturesType type parameter.
*
* This is used by [[validateAndTransformSchema()]].
* This workaround is needed since SQL has different APIs for Scala and Java.
*
* The default value is VectorUDT, but it may be overridden if FeaturesType is not Vector.
*/
@DeveloperApi
protected def featuresDataType: DataType
protected def featuresDataType: DataType = new VectorUDT

private[ml] override def transformSchema(schema: StructType, paramMap: ParamMap): StructType = {
validateAndTransformSchema(schema, paramMap, fitting = true, featuresDataType)
@@ -138,9 +165,11 @@ abstract class PredictionModel[FeaturesType, M <: PredictionModel[FeaturesType,
*
* This is used by [[validateAndTransformSchema()]].
* This workaround is needed since SQL has different APIs for Scala and Java.
*
* The default value is VectorUDT, but it may be overridden if FeaturesType is not Vector.
*/
@DeveloperApi
protected def featuresDataType: DataType
protected def featuresDataType: DataType = new VectorUDT

private[ml] override def transformSchema(schema: StructType, paramMap: ParamMap): StructType = {
validateAndTransformSchema(schema, paramMap, fitting = false, featuresDataType)
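The featuresDataType default above covers the common FeaturesType = Vector case. A hypothetical learner whose features column holds Array[Double] instead (made-up example; ArrayType and DoubleType come from org.apache.spark.sql.types) would override the default along these lines:

    private class MyArrayClassifier
      extends Classifier[Array[Double], MyArrayClassifier, MyArrayModel] {
      // FeaturesType is Array[Double], so schema validation must expect an array column.
      override protected def featuresDataType: DataType =
        ArrayType(DoubleType, containsNull = false)
      // ... train(), etc.
    }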
mllib/src/main/scala/org/apache/spark/ml/param/sharedParams.scala
@@ -18,7 +18,9 @@
package org.apache.spark.ml.param

/* NOTE TO DEVELOPERS:
* If you add these parameter traits into your algorithm, you need to add a setter method as well.
* If you mix these parameter traits into your algorithm, please add a setter method as well
* so that users may use a builder pattern:
* val myLearner = new MyLearner().setParam1(x).setParam2(y)...
*/

private[ml] trait HasRegParam extends Params {
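For example, a hypothetical learner mixing in these traits would expose chained setters (MyLearner is made up; compare the setRegParam/setMaxIter pair on LinearRegression below):

    private class MyLearner extends HasRegParam with HasMaxIter {
      // setters return this.type so configuration can be chained
      def setRegParam(value: Double): this.type = set(regParam, value)
      def setMaxIter(value: Int): this.type = set(maxIter, value)
    }

    // Builder-style usage:
    val learner = new MyLearner().setRegParam(0.1).setMaxIter(50)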
mllib/src/main/scala/org/apache/spark/ml/regression/LinearRegression.scala
@@ -19,7 +19,7 @@ package org.apache.spark.ml.regression

import org.apache.spark.annotation.AlphaComponent
import org.apache.spark.ml.param.{Params, ParamMap, HasMaxIter, HasRegParam}
import org.apache.spark.mllib.linalg.{VectorUDT, BLAS, Vector}
import org.apache.spark.mllib.linalg.{BLAS, Vector}
import org.apache.spark.mllib.regression.LinearRegressionWithSGD
import org.apache.spark.sql._
import org.apache.spark.storage.StorageLevel
@@ -45,13 +45,9 @@ class LinearRegression extends Regressor[Vector, LinearRegression, LinearRegress
def setRegParam(value: Double): this.type = set(regParam, value)
def setMaxIter(value: Int): this.type = set(maxIter, value)

override def fit(dataset: SchemaRDD, paramMap: ParamMap): LinearRegressionModel = {
// Check schema
transformSchema(dataset.schema, paramMap, logging = true)

override protected def train(dataset: SchemaRDD, paramMap: ParamMap): LinearRegressionModel = {
// Extract columns from data. If dataset is persisted, do not persist oldDataset.
val oldDataset = extractLabeledPoints(dataset, paramMap)
val map = this.paramMap ++ paramMap
val handlePersistence = dataset.getStorageLevel == StorageLevel.NONE
if (handlePersistence) {
oldDataset.persist(StorageLevel.MEMORY_AND_DISK)
@@ -60,21 +56,16 @@ class LinearRegression extends Regressor[Vector, LinearRegression, LinearRegress
// Train model
val lr = new LinearRegressionWithSGD()
lr.optimizer
.setRegParam(map(regParam))
.setNumIterations(map(maxIter))
.setRegParam(paramMap(regParam))
.setNumIterations(paramMap(maxIter))
val model = lr.run(oldDataset)
val lrm = new LinearRegressionModel(this, map, model.weights, model.intercept)
val lrm = new LinearRegressionModel(this, paramMap, model.weights, model.intercept)

if (handlePersistence) {
oldDataset.unpersist()
}

// copy model params
Params.inheritValues(map, this, lrm)
lrm
}

override protected def featuresDataType: DataType = new VectorUDT
}

/**
@@ -100,6 +91,4 @@ class LinearRegressionModel private[ml] (
Params.inheritValues(this.paramMap, this, m)
m
}

override protected def featuresDataType: DataType = new VectorUDT
}
