Skip to content

Commit

Permalink
[SPARK-8151] [MLLIB] pipeline components should correctly implement copy
Browse files Browse the repository at this point in the history
Otherwise, extra params get ignored in `PipelineModel.transform`. jkbradley

Author: Xiangrui Meng <[email protected]>

Closes #6622 from mengxr/SPARK-8087 and squashes the following commits:

0e4c8c4 [Xiangrui Meng] fix merge issues
26fc1f0 [Xiangrui Meng] address comments
e607a04 [Xiangrui Meng] merge master
b85b57e [Xiangrui Meng] fix examples/compile
d6f7891 [Xiangrui Meng] rename defaultCopyWithParams to defaultCopy
84ec278 [Xiangrui Meng] remove setter checks due to generics
2cf2ed0 [Xiangrui Meng] snapshot
291814f [Xiangrui Meng] OneVsRest.copy
1dfe3bd [Xiangrui Meng] PipelineModel.copy should copy stages

(cherry picked from commit 43c7ec6)
Signed-off-by: Xiangrui Meng <[email protected]>
  • Loading branch information
mengxr committed Jun 19, 2015
1 parent 164b9d3 commit 1f2dafb
Show file tree
Hide file tree
Showing 62 changed files with 350 additions and 54 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -156,6 +156,11 @@ public MyJavaLogisticRegressionModel train(DataFrame dataset) {
// Create a model, and return it.
return new MyJavaLogisticRegressionModel(uid(), weights).setParent(this);
}

@Override
public MyJavaLogisticRegression copy(ParamMap extra) {
return defaultCopy(extra);
}
}

/**
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -130,6 +130,8 @@ private class MyLogisticRegression(override val uid: String)
// Create a model, and return it.
new MyLogisticRegressionModel(uid, weights).setParent(this)
}

override def copy(extra: ParamMap): MyLogisticRegression = defaultCopy(extra)
}

/**
Expand Down
4 changes: 1 addition & 3 deletions mllib/src/main/scala/org/apache/spark/ml/Estimator.scala
Original file line number Diff line number Diff line change
Expand Up @@ -78,7 +78,5 @@ abstract class Estimator[M <: Model[M]] extends PipelineStage {
paramMaps.map(fit(dataset, _))
}

override def copy(extra: ParamMap): Estimator[M] = {
super.copy(extra).asInstanceOf[Estimator[M]]
}
override def copy(extra: ParamMap): Estimator[M]
}
5 changes: 1 addition & 4 deletions mllib/src/main/scala/org/apache/spark/ml/Model.scala
Original file line number Diff line number Diff line change
Expand Up @@ -45,8 +45,5 @@ abstract class Model[M <: Model[M]] extends Transformer {
/** Indicates whether this [[Model]] has a corresponding parent. */
def hasParent: Boolean = parent != null

override def copy(extra: ParamMap): M = {
// The default implementation of Params.copy doesn't work for models.
throw new NotImplementedError(s"${this.getClass} doesn't implement copy(extra: ParamMap)")
}
override def copy(extra: ParamMap): M
}
6 changes: 2 additions & 4 deletions mllib/src/main/scala/org/apache/spark/ml/Pipeline.scala
Original file line number Diff line number Diff line change
Expand Up @@ -63,9 +63,7 @@ abstract class PipelineStage extends Params with Logging {
outputSchema
}

override def copy(extra: ParamMap): PipelineStage = {
super.copy(extra).asInstanceOf[PipelineStage]
}
override def copy(extra: ParamMap): PipelineStage
}

/**
Expand Down Expand Up @@ -190,6 +188,6 @@ class PipelineModel private[ml] (
}

override def copy(extra: ParamMap): PipelineModel = {
new PipelineModel(uid, stages)
new PipelineModel(uid, stages.map(_.copy(extra)))
}
}
4 changes: 1 addition & 3 deletions mllib/src/main/scala/org/apache/spark/ml/Predictor.scala
Original file line number Diff line number Diff line change
Expand Up @@ -90,9 +90,7 @@ abstract class Predictor[
copyValues(train(dataset).setParent(this))
}

override def copy(extra: ParamMap): Learner = {
super.copy(extra).asInstanceOf[Learner]
}
override def copy(extra: ParamMap): Learner

/**
* Train a model using the given dataset and parameters.
Expand Down
6 changes: 3 additions & 3 deletions mllib/src/main/scala/org/apache/spark/ml/Transformer.scala
Original file line number Diff line number Diff line change
Expand Up @@ -67,9 +67,7 @@ abstract class Transformer extends PipelineStage {
*/
def transform(dataset: DataFrame): DataFrame

override def copy(extra: ParamMap): Transformer = {
super.copy(extra).asInstanceOf[Transformer]
}
override def copy(extra: ParamMap): Transformer
}

/**
Expand Down Expand Up @@ -120,4 +118,6 @@ abstract class UnaryTransformer[IN, OUT, T <: UnaryTransformer[IN, OUT, T]]
dataset.withColumn($(outputCol),
callUDF(this.createTransformFunc, outputDataType, dataset($(inputCol))))
}

override def copy(extra: ParamMap): T = defaultCopy(extra)
}
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,7 @@
package org.apache.spark.ml.classification

import org.apache.spark.annotation.DeveloperApi
import org.apache.spark.ml.param.ParamMap
import org.apache.spark.ml.{PredictionModel, PredictorParams, Predictor}
import org.apache.spark.ml.param.shared.HasRawPredictionCol
import org.apache.spark.ml.util.SchemaUtils
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -86,6 +86,8 @@ final class DecisionTreeClassifier(override val uid: String)
super.getOldStrategy(categoricalFeatures, numClasses, OldAlgo.Classification, getOldImpurity,
subsamplingRate = 1.0)
}

override def copy(extra: ParamMap): DecisionTreeClassifier = defaultCopy(extra)
}

@Experimental
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -141,6 +141,8 @@ final class GBTClassifier(override val uid: String)
val oldModel = oldGBT.run(oldDataset)
GBTClassificationModel.fromOld(oldModel, this, categoricalFeatures)
}

override def copy(extra: ParamMap): GBTClassifier = defaultCopy(extra)
}

@Experimental
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -220,6 +220,8 @@ class LogisticRegression(override val uid: String)

new LogisticRegressionModel(uid, weights.compressed, intercept)
}

override def copy(extra: ParamMap): LogisticRegression = defaultCopy(extra)
}

/**
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -24,7 +24,7 @@ import scala.language.existentials
import org.apache.spark.annotation.Experimental
import org.apache.spark.ml._
import org.apache.spark.ml.attribute._
import org.apache.spark.ml.param.Param
import org.apache.spark.ml.param.{Param, ParamMap}
import org.apache.spark.ml.util.{Identifiable, MetadataUtils}
import org.apache.spark.mllib.linalg.Vector
import org.apache.spark.sql.{DataFrame, Row}
Expand Down Expand Up @@ -133,6 +133,12 @@ final class OneVsRestModel private[ml] (
aggregatedDataset.withColumn($(predictionCol), labelUdf.as($(predictionCol), labelMetadata))
.drop(accColName)
}

override def copy(extra: ParamMap): OneVsRestModel = {
val copied = new OneVsRestModel(
uid, labelMetadata, models.map(_.copy(extra).asInstanceOf[ClassificationModel[_, _]]))
copyValues(copied, extra)
}
}

/**
Expand Down Expand Up @@ -209,4 +215,12 @@ final class OneVsRest(override val uid: String)
val model = new OneVsRestModel(uid, labelAttribute.toMetadata(), models).setParent(this)
copyValues(model)
}

override def copy(extra: ParamMap): OneVsRest = {
val copied = defaultCopy(extra).asInstanceOf[OneVsRest]
if (isDefined(classifier)) {
copied.setClassifier($(classifier).copy(extra))
}
copied
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -97,6 +97,8 @@ final class RandomForestClassifier(override val uid: String)
oldDataset, strategy, getNumTrees, getFeatureSubsetStrategy, getSeed.toInt)
RandomForestClassificationModel.fromOld(oldModel, this, categoricalFeatures)
}

override def copy(extra: ParamMap): RandomForestClassifier = defaultCopy(extra)
}

@Experimental
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -79,4 +79,6 @@ class BinaryClassificationEvaluator(override val uid: String)
metrics.unpersist()
metric
}

override def copy(extra: ParamMap): BinaryClassificationEvaluator = defaultCopy(extra)
}
Original file line number Diff line number Diff line change
Expand Up @@ -46,7 +46,5 @@ abstract class Evaluator extends Params {
*/
def evaluate(dataset: DataFrame): Double

override def copy(extra: ParamMap): Evaluator = {
super.copy(extra).asInstanceOf[Evaluator]
}
override def copy(extra: ParamMap): Evaluator
}
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,7 @@
package org.apache.spark.ml.evaluation

import org.apache.spark.annotation.Experimental
import org.apache.spark.ml.param.{Param, ParamValidators}
import org.apache.spark.ml.param.{Param, ParamMap, ParamValidators}
import org.apache.spark.ml.param.shared.{HasLabelCol, HasPredictionCol}
import org.apache.spark.ml.util.{Identifiable, SchemaUtils}
import org.apache.spark.mllib.evaluation.RegressionMetrics
Expand Down Expand Up @@ -80,4 +80,6 @@ final class RegressionEvaluator(override val uid: String)
}
metric
}

override def copy(extra: ParamMap): RegressionEvaluator = defaultCopy(extra)
}
Original file line number Diff line number Diff line change
Expand Up @@ -83,4 +83,6 @@ final class Binarizer(override val uid: String)
val outputFields = inputFields :+ attr.toStructField()
StructType(outputFields)
}

override def copy(extra: ParamMap): Binarizer = defaultCopy(extra)
}
Original file line number Diff line number Diff line change
Expand Up @@ -89,6 +89,8 @@ final class Bucketizer(override val uid: String)
SchemaUtils.checkColumnType(schema, $(inputCol), DoubleType)
SchemaUtils.appendColumn(schema, prepOutputField(schema))
}

override def copy(extra: ParamMap): Bucketizer = defaultCopy(extra)
}

private[feature] object Bucketizer {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -19,7 +19,7 @@ package org.apache.spark.ml.feature

import org.apache.spark.annotation.Experimental
import org.apache.spark.ml.UnaryTransformer
import org.apache.spark.ml.param.Param
import org.apache.spark.ml.param.{ParamMap, Param}
import org.apache.spark.ml.util.Identifiable
import org.apache.spark.mllib.feature
import org.apache.spark.mllib.linalg.{Vector, VectorUDT}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -20,7 +20,7 @@ package org.apache.spark.ml.feature
import org.apache.spark.annotation.Experimental
import org.apache.spark.ml.Transformer
import org.apache.spark.ml.attribute.AttributeGroup
import org.apache.spark.ml.param.{IntParam, ParamValidators}
import org.apache.spark.ml.param.{IntParam, ParamMap, ParamValidators}
import org.apache.spark.ml.param.shared.{HasInputCol, HasOutputCol}
import org.apache.spark.ml.util.{Identifiable, SchemaUtils}
import org.apache.spark.mllib.feature
Expand Down Expand Up @@ -74,4 +74,6 @@ class HashingTF(override val uid: String) extends Transformer with HasInputCol w
val attrGroup = new AttributeGroup($(outputCol), $(numFeatures))
SchemaUtils.appendColumn(schema, attrGroup.toStructField())
}

override def copy(extra: ParamMap): HashingTF = defaultCopy(extra)
}
13 changes: 10 additions & 3 deletions mllib/src/main/scala/org/apache/spark/ml/feature/IDF.scala
Original file line number Diff line number Diff line change
Expand Up @@ -45,9 +45,6 @@ private[feature] trait IDFBase extends Params with HasInputCol with HasOutputCol
/** @group getParam */
def getMinDocFreq: Int = $(minDocFreq)

/** @group setParam */
def setMinDocFreq(value: Int): this.type = set(minDocFreq, value)

/**
* Validate and transform the input schema.
*/
Expand All @@ -72,6 +69,9 @@ final class IDF(override val uid: String) extends Estimator[IDFModel] with IDFBa
/** @group setParam */
def setOutputCol(value: String): this.type = set(outputCol, value)

/** @group setParam */
def setMinDocFreq(value: Int): this.type = set(minDocFreq, value)

override def fit(dataset: DataFrame): IDFModel = {
transformSchema(dataset.schema, logging = true)
val input = dataset.select($(inputCol)).map { case Row(v: Vector) => v }
Expand All @@ -82,6 +82,8 @@ final class IDF(override val uid: String) extends Estimator[IDFModel] with IDFBa
override def transformSchema(schema: StructType): StructType = {
validateAndTransformSchema(schema)
}

override def copy(extra: ParamMap): IDF = defaultCopy(extra)
}

/**
Expand Down Expand Up @@ -109,4 +111,9 @@ class IDFModel private[ml] (
override def transformSchema(schema: StructType): StructType = {
validateAndTransformSchema(schema)
}

override def copy(extra: ParamMap): IDFModel = {
val copied = new IDFModel(uid, idfModel)
copyValues(copied, extra)
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -165,4 +165,6 @@ class OneHotEncoder(override val uid: String) extends Transformer

dataset.select(col("*"), encode(col(inputColName).cast(DoubleType)).as(outputColName, metadata))
}

override def copy(extra: ParamMap): OneHotEncoder = defaultCopy(extra)
}
Original file line number Diff line number Diff line change
Expand Up @@ -21,7 +21,7 @@ import scala.collection.mutable

import org.apache.spark.annotation.Experimental
import org.apache.spark.ml.UnaryTransformer
import org.apache.spark.ml.param.{IntParam, ParamValidators}
import org.apache.spark.ml.param.{ParamMap, IntParam, ParamValidators}
import org.apache.spark.ml.util.Identifiable
import org.apache.spark.mllib.linalg._
import org.apache.spark.sql.types.DataType
Expand Down Expand Up @@ -61,6 +61,8 @@ class PolynomialExpansion(override val uid: String)
}

override protected def outputDataType: DataType = new VectorUDT()

override def copy(extra: ParamMap): PolynomialExpansion = defaultCopy(extra)
}

/**
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -92,6 +92,8 @@ class StandardScaler(override val uid: String) extends Estimator[StandardScalerM
val outputFields = schema.fields :+ StructField($(outputCol), new VectorUDT, false)
StructType(outputFields)
}

override def copy(extra: ParamMap): StandardScaler = defaultCopy(extra)
}

/**
Expand Down Expand Up @@ -125,4 +127,9 @@ class StandardScalerModel private[ml] (
val outputFields = schema.fields :+ StructField($(outputCol), new VectorUDT, false)
StructType(outputFields)
}

override def copy(extra: ParamMap): StandardScalerModel = {
val copied = new StandardScalerModel(uid, scaler)
copyValues(copied, extra)
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -83,6 +83,8 @@ class StringIndexer(override val uid: String) extends Estimator[StringIndexerMod
override def transformSchema(schema: StructType): StructType = {
validateAndTransformSchema(schema)
}

override def copy(extra: ParamMap): StringIndexer = defaultCopy(extra)
}

/**
Expand Down Expand Up @@ -144,4 +146,9 @@ class StringIndexerModel private[ml] (
schema
}
}

override def copy(extra: ParamMap): StringIndexerModel = {
val copied = new StringIndexerModel(uid, labels)
copyValues(copied, extra)
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -43,6 +43,8 @@ class Tokenizer(override val uid: String) extends UnaryTransformer[String, Seq[S
}

override protected def outputDataType: DataType = new ArrayType(StringType, false)

override def copy(extra: ParamMap): Tokenizer = defaultCopy(extra)
}

/**
Expand Down Expand Up @@ -112,4 +114,6 @@ class RegexTokenizer(override val uid: String)
}

override protected def outputDataType: DataType = new ArrayType(StringType, false)

override def copy(extra: ParamMap): RegexTokenizer = defaultCopy(extra)
}
Original file line number Diff line number Diff line change
Expand Up @@ -23,6 +23,7 @@ import org.apache.spark.SparkException
import org.apache.spark.annotation.Experimental
import org.apache.spark.ml.Transformer
import org.apache.spark.ml.attribute.{Attribute, AttributeGroup, NumericAttribute, UnresolvedAttribute}
import org.apache.spark.ml.param.ParamMap
import org.apache.spark.ml.param.shared._
import org.apache.spark.ml.util.Identifiable
import org.apache.spark.mllib.linalg.{Vector, VectorUDT, Vectors}
Expand Down Expand Up @@ -117,6 +118,8 @@ class VectorAssembler(override val uid: String)
}
StructType(schema.fields :+ new StructField(outputColName, new VectorUDT, false))
}

override def copy(extra: ParamMap): VectorAssembler = defaultCopy(extra)
}

private object VectorAssembler {
Expand Down
Loading

0 comments on commit 1f2dafb

Please sign in to comment.