Skip to content

Commit

Permalink
[SPARK-13761][ML] Deprecate validateParams
Browse files Browse the repository at this point in the history
## What changes were proposed in this pull request?

Deprecate validateParams() method here: https://github.com/apache/spark/blob/035d3acdf3c1be5b309a861d5c5beb803b946b5e/mllib/src/main/scala/org/apache/spark/ml/param/params.scala#L553
Move all functionality in overridden methods to transformSchema().
Check docs to make sure they indicate complex Param interaction checks should be done in transformSchema.

## How was this patch tested?

unit tests

Author: Yuhao Yang <[email protected]>

Closes #11620 from hhbyyh/depreValid.
  • Loading branch information
hhbyyh authored and jkbradley committed Mar 17, 2016
1 parent d4d8493 commit 92b7057
Show file tree
Hide file tree
Showing 33 changed files with 36 additions and 89 deletions.
14 changes: 0 additions & 14 deletions mllib/src/main/scala/org/apache/spark/ml/Pipeline.scala
Original file line number Diff line number Diff line change
Expand Up @@ -110,12 +110,6 @@ class Pipeline @Since("1.4.0") (
@Since("1.2.0")
def getStages: Array[PipelineStage] = $(stages).clone()

@Since("1.4.0")
override def validateParams(): Unit = {
super.validateParams()
$(stages).foreach(_.validateParams())
}

/**
* Fits the pipeline to the input dataset with additional parameters. If a stage is an
* [[Estimator]], its [[Estimator#fit]] method will be called on the input dataset to fit a model.
Expand Down Expand Up @@ -175,7 +169,6 @@ class Pipeline @Since("1.4.0") (

@Since("1.2.0")
override def transformSchema(schema: StructType): StructType = {
validateParams()
val theStages = $(stages)
require(theStages.toSet.size == theStages.length,
"Cannot have duplicate components in a pipeline.")
Expand Down Expand Up @@ -297,12 +290,6 @@ class PipelineModel private[ml] (
this(uid, stages.asScala.toArray)
}

@Since("1.4.0")
override def validateParams(): Unit = {
super.validateParams()
stages.foreach(_.validateParams())
}

@Since("1.2.0")
override def transform(dataset: DataFrame): DataFrame = {
transformSchema(dataset.schema, logging = true)
Expand All @@ -311,7 +298,6 @@ class PipelineModel private[ml] (

@Since("1.2.0")
override def transformSchema(schema: StructType): StructType = {
validateParams()
stages.foldLeft(schema)((cur, transformer) => transformer.transformSchema(cur))
}

Expand Down
1 change: 0 additions & 1 deletion mllib/src/main/scala/org/apache/spark/ml/Predictor.scala
Original file line number Diff line number Diff line change
Expand Up @@ -46,7 +46,6 @@ private[ml] trait PredictorParams extends Params
schema: StructType,
fitting: Boolean,
featuresDataType: DataType): StructType = {
validateParams()
// TODO: Support casting Array[Double] and Array[Float] to Vector when FeaturesType = Vector
SchemaUtils.checkColumnType(schema, $(featuresCol), featuresDataType)
if (fitting) {
Expand Down
1 change: 0 additions & 1 deletion mllib/src/main/scala/org/apache/spark/ml/Transformer.scala
Original file line number Diff line number Diff line change
Expand Up @@ -103,7 +103,6 @@ abstract class UnaryTransformer[IN, OUT, T <: UnaryTransformer[IN, OUT, T]]
protected def validateInputType(inputType: DataType): Unit = {}

override def transformSchema(schema: StructType): StructType = {
validateParams()
val inputType = schema($(inputCol)).dataType
validateInputType(inputType)
if (schema.fieldNames.contains($(outputCol))) {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -81,7 +81,6 @@ private[clustering] trait KMeansParams extends Params with HasMaxIter with HasFe
* @return output schema
*/
protected def validateAndTransformSchema(schema: StructType): StructType = {
validateParams()
SchemaUtils.checkColumnType(schema, $(featuresCol), new VectorUDT)
SchemaUtils.appendColumn(schema, $(predictionCol), IntegerType)
}
Expand Down
9 changes: 2 additions & 7 deletions mllib/src/main/scala/org/apache/spark/ml/clustering/LDA.scala
Original file line number Diff line number Diff line change
Expand Up @@ -263,13 +263,6 @@ private[clustering] trait LDAParams extends Params with HasFeaturesCol with HasM
* @return output schema
*/
protected def validateAndTransformSchema(schema: StructType): StructType = {
validateParams()
SchemaUtils.checkColumnType(schema, $(featuresCol), new VectorUDT)
SchemaUtils.appendColumn(schema, $(topicDistributionCol), new VectorUDT)
}

@Since("1.6.0")
override def validateParams(): Unit = {
if (isSet(docConcentration)) {
if (getDocConcentration.length != 1) {
require(getDocConcentration.length == getK, s"LDA docConcentration was of length" +
Expand Down Expand Up @@ -297,6 +290,8 @@ private[clustering] trait LDAParams extends Params with HasFeaturesCol with HasM
s" must be >= 1. Found value: $getTopicConcentration")
}
}
SchemaUtils.checkColumnType(schema, $(featuresCol), new VectorUDT)
SchemaUtils.appendColumn(schema, $(topicDistributionCol), new VectorUDT)
}

private[clustering] def getOldOptimizer: OldLDAOptimizer = getOptimizer match {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -86,7 +86,6 @@ final class Bucketizer(override val uid: String)
}

override def transformSchema(schema: StructType): StructType = {
validateParams()
SchemaUtils.checkColumnType(schema, $(inputCol), DoubleType)
SchemaUtils.appendColumn(schema, prepOutputField(schema))
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -88,7 +88,6 @@ final class ChiSqSelector(override val uid: String)
}

override def transformSchema(schema: StructType): StructType = {
validateParams()
SchemaUtils.checkColumnType(schema, $(featuresCol), new VectorUDT)
SchemaUtils.checkColumnType(schema, $(labelCol), DoubleType)
SchemaUtils.appendColumn(schema, $(outputCol), new VectorUDT)
Expand Down Expand Up @@ -136,7 +135,6 @@ final class ChiSqSelectorModel private[ml] (
}

override def transformSchema(schema: StructType): StructType = {
validateParams()
SchemaUtils.checkColumnType(schema, $(featuresCol), new VectorUDT)
val newField = prepOutputField(schema)
val outputFields = schema.fields :+ newField
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -70,7 +70,6 @@ private[feature] trait CountVectorizerParams extends Params with HasInputCol wit

/** Validates and transforms the input schema. */
protected def validateAndTransformSchema(schema: StructType): StructType = {
validateParams()
val typeCandidates = List(new ArrayType(StringType, true), new ArrayType(StringType, false))
SchemaUtils.checkColumnTypes(schema, $(inputCol), typeCandidates)
SchemaUtils.appendColumn(schema, $(outputCol), new VectorUDT)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -69,7 +69,6 @@ class HashingTF(override val uid: String)
}

override def transformSchema(schema: StructType): StructType = {
validateParams()
val inputType = schema($(inputCol)).dataType
require(inputType.isInstanceOf[ArrayType],
s"The input column must be ArrayType, but got $inputType.")
Expand Down
1 change: 0 additions & 1 deletion mllib/src/main/scala/org/apache/spark/ml/feature/IDF.scala
Original file line number Diff line number Diff line change
Expand Up @@ -52,7 +52,6 @@ private[feature] trait IDFBase extends Params with HasInputCol with HasOutputCol
* Validate and transform the input schema.
*/
protected def validateAndTransformSchema(schema: StructType): StructType = {
validateParams()
SchemaUtils.checkColumnType(schema, $(inputCol), new VectorUDT)
SchemaUtils.appendColumn(schema, $(outputCol), new VectorUDT)
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -61,13 +61,15 @@ class Interaction @Since("1.6.0") (override val uid: String) extends Transformer
// optimistic schema; does not contain any ML attributes
@Since("1.6.0")
override def transformSchema(schema: StructType): StructType = {
validateParams()
require(get(inputCols).isDefined, "Input cols must be defined first.")
require(get(outputCol).isDefined, "Output col must be defined first.")
require($(inputCols).length > 0, "Input cols must have non-zero length.")
require($(inputCols).distinct.length == $(inputCols).length, "Input cols must be distinct.")
StructType(schema.fields :+ StructField($(outputCol), new VectorUDT, false))
}

@Since("1.6.0")
override def transform(dataset: DataFrame): DataFrame = {
validateParams()
val inputFeatures = $(inputCols).map(c => dataset.schema(c))
val featureEncoders = getFeatureEncoders(inputFeatures)
val featureAttrs = getFeatureAttrs(inputFeatures)
Expand Down Expand Up @@ -217,13 +219,6 @@ class Interaction @Since("1.6.0") (override val uid: String) extends Transformer
@Since("1.6.0")
override def copy(extra: ParamMap): Interaction = defaultCopy(extra)

@Since("1.6.0")
override def validateParams(): Unit = {
require(get(inputCols).isDefined, "Input cols must be defined first.")
require(get(outputCol).isDefined, "Output col must be defined first.")
require($(inputCols).length > 0, "Input cols must have non-zero length.")
require($(inputCols).distinct.length == $(inputCols).length, "Input cols must be distinct.")
}
}

@Since("1.6.0")
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -37,7 +37,6 @@ private[feature] trait MaxAbsScalerParams extends Params with HasInputCol with H

/** Validates and transforms the input schema. */
protected def validateAndTransformSchema(schema: StructType): StructType = {
validateParams()
val inputType = schema($(inputCol)).dataType
require(inputType.isInstanceOf[VectorUDT],
s"Input column ${$(inputCol)} must be a vector column")
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -59,7 +59,7 @@ private[feature] trait MinMaxScalerParams extends Params with HasInputCol with H

/** Validates and transforms the input schema. */
protected def validateAndTransformSchema(schema: StructType): StructType = {
validateParams()
require($(min) < $(max), s"The specified min(${$(min)}) is larger or equal to max(${$(max)})")
val inputType = schema($(inputCol)).dataType
require(inputType.isInstanceOf[VectorUDT],
s"Input column ${$(inputCol)} must be a vector column")
Expand All @@ -69,9 +69,6 @@ private[feature] trait MinMaxScalerParams extends Params with HasInputCol with H
StructType(outputFields)
}

override def validateParams(): Unit = {
require($(min) < $(max), s"The specified min(${$(min)}) is larger or equal to max(${$(max)})")
}
}

/**
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -66,7 +66,6 @@ class OneHotEncoder(override val uid: String) extends Transformer
def setOutputCol(value: String): this.type = set(outputCol, value)

override def transformSchema(schema: StructType): StructType = {
validateParams()
val inputColName = $(inputCol)
val outputColName = $(outputCol)

Expand Down
2 changes: 0 additions & 2 deletions mllib/src/main/scala/org/apache/spark/ml/feature/PCA.scala
Original file line number Diff line number Diff line change
Expand Up @@ -77,7 +77,6 @@ class PCA (override val uid: String) extends Estimator[PCAModel] with PCAParams
}

override def transformSchema(schema: StructType): StructType = {
validateParams()
val inputType = schema($(inputCol)).dataType
require(inputType.isInstanceOf[VectorUDT],
s"Input column ${$(inputCol)} must be a vector column")
Expand Down Expand Up @@ -133,7 +132,6 @@ class PCAModel private[ml] (
}

override def transformSchema(schema: StructType): StructType = {
validateParams()
val inputType = schema($(inputCol)).dataType
require(inputType.isInstanceOf[VectorUDT],
s"Input column ${$(inputCol)} must be a vector column")
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -78,7 +78,6 @@ final class QuantileDiscretizer(override val uid: String)
def setSeed(value: Long): this.type = set(seed, value)

override def transformSchema(schema: StructType): StructType = {
validateParams()
SchemaUtils.checkColumnType(schema, $(inputCol), DoubleType)
val inputFields = schema.fields
require(inputFields.forall(_.name != $(outputCol)),
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -167,7 +167,6 @@ class RFormula(override val uid: String) extends Estimator[RFormulaModel] with R

// optimistic schema; does not contain any ML attributes
override def transformSchema(schema: StructType): StructType = {
validateParams()
if (hasLabelCol(schema)) {
StructType(schema.fields :+ StructField($(featuresCol), new VectorUDT, true))
} else {
Expand Down Expand Up @@ -200,7 +199,6 @@ class RFormulaModel private[feature](
}

override def transformSchema(schema: StructType): StructType = {
validateParams()
checkCanTransform(schema)
val withFeatures = pipelineModel.transformSchema(schema)
if (hasLabelCol(withFeatures)) {
Expand Down Expand Up @@ -263,7 +261,6 @@ private class ColumnPruner(columnsToPrune: Set[String]) extends Transformer {
}

override def transformSchema(schema: StructType): StructType = {
validateParams()
StructType(schema.fields.filter(col => !columnsToPrune.contains(col.name)))
}

Expand Down Expand Up @@ -312,7 +309,6 @@ private class VectorAttributeRewriter(
}

override def transformSchema(schema: StructType): StructType = {
validateParams()
StructType(
schema.fields.filter(_.name != vectorCol) ++
schema.fields.filter(_.name == vectorCol))
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -74,7 +74,6 @@ class SQLTransformer @Since("1.6.0") (override val uid: String) extends Transfor

@Since("1.6.0")
override def transformSchema(schema: StructType): StructType = {
validateParams()
val sc = SparkContext.getOrCreate()
val sqlContext = SQLContext.getOrCreate(sc)
val dummyRDD = sc.parallelize(Seq(Row.empty))
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -94,7 +94,6 @@ class StandardScaler(override val uid: String) extends Estimator[StandardScalerM
}

override def transformSchema(schema: StructType): StructType = {
validateParams()
val inputType = schema($(inputCol)).dataType
require(inputType.isInstanceOf[VectorUDT],
s"Input column ${$(inputCol)} must be a vector column")
Expand Down Expand Up @@ -144,7 +143,6 @@ class StandardScalerModel private[ml] (
}

override def transformSchema(schema: StructType): StructType = {
validateParams()
val inputType = schema($(inputCol)).dataType
require(inputType.isInstanceOf[VectorUDT],
s"Input column ${$(inputCol)} must be a vector column")
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -145,7 +145,6 @@ class StopWordsRemover(override val uid: String)
}

override def transformSchema(schema: StructType): StructType = {
validateParams()
val inputType = schema($(inputCol)).dataType
require(inputType.sameType(ArrayType(StringType)),
s"Input type must be ArrayType(StringType) but got $inputType.")
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -39,7 +39,6 @@ private[feature] trait StringIndexerBase extends Params with HasInputCol with Ha

/** Validates and transforms the input schema. */
protected def validateAndTransformSchema(schema: StructType): StructType = {
validateParams()
val inputColName = $(inputCol)
val inputDataType = schema(inputColName).dataType
require(inputDataType == StringType || inputDataType.isInstanceOf[NumericType],
Expand Down Expand Up @@ -275,7 +274,6 @@ class IndexToString private[ml] (override val uid: String)
final def getLabels: Array[String] = $(labels)

override def transformSchema(schema: StructType): StructType = {
validateParams()
val inputColName = $(inputCol)
val inputDataType = schema(inputColName).dataType
require(inputDataType.isInstanceOf[NumericType],
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -106,7 +106,6 @@ class VectorAssembler(override val uid: String)
}

override def transformSchema(schema: StructType): StructType = {
validateParams()
val inputColNames = $(inputCols)
val outputColName = $(outputCol)
val inputDataTypes = inputColNames.map(name => schema(name).dataType)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -126,7 +126,6 @@ class VectorIndexer(override val uid: String) extends Estimator[VectorIndexerMod
}

override def transformSchema(schema: StructType): StructType = {
validateParams()
// We do not transfer feature metadata since we do not know what types of features we will
// produce in transform().
val dataType = new VectorUDT
Expand Down Expand Up @@ -355,7 +354,6 @@ class VectorIndexerModel private[ml] (
}

override def transformSchema(schema: StructType): StructType = {
validateParams()
val dataType = new VectorUDT
require(isDefined(inputCol),
s"VectorIndexerModel requires input column parameter: $inputCol")
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -89,11 +89,6 @@ final class VectorSlicer(override val uid: String)
/** @group setParam */
def setOutputCol(value: String): this.type = set(outputCol, value)

override def validateParams(): Unit = {
require($(indices).length > 0 || $(names).length > 0,
s"VectorSlicer requires that at least one feature be selected.")
}

override def transform(dataset: DataFrame): DataFrame = {
// Validity checks
transformSchema(dataset.schema)
Expand Down Expand Up @@ -139,7 +134,8 @@ final class VectorSlicer(override val uid: String)
}

override def transformSchema(schema: StructType): StructType = {
validateParams()
require($(indices).length > 0 || $(names).length > 0,
s"VectorSlicer requires that at least one feature be selected.")
SchemaUtils.checkColumnType(schema, $(inputCol), new VectorUDT)

if (schema.fieldNames.contains($(outputCol))) {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -92,7 +92,6 @@ private[feature] trait Word2VecBase extends Params
* Validate and transform the input schema.
*/
protected def validateAndTransformSchema(schema: StructType): StructType = {
validateParams()
SchemaUtils.checkColumnType(schema, $(inputCol), new ArrayType(StringType, true))
SchemaUtils.appendColumn(schema, $(outputCol), new VectorUDT)
}
Expand Down
7 changes: 4 additions & 3 deletions mllib/src/main/scala/org/apache/spark/ml/param/params.scala
Original file line number Diff line number Diff line change
Expand Up @@ -58,9 +58,8 @@ class Param[T](val parent: String, val name: String, val doc: String, val isVali
/**
* Assert that the given value is valid for this parameter.
*
* Note: Parameter checks involving interactions between multiple parameters should be
* implemented in [[Params.validateParams()]]. Checks for input/output columns should be
* implemented in [[org.apache.spark.ml.PipelineStage.transformSchema()]].
* Note: Parameter checks involving interactions between multiple parameters and input/output
* columns should be implemented in [[org.apache.spark.ml.PipelineStage.transformSchema()]].
*
* DEVELOPERS: This method is only called by [[ParamPair]], which means that all parameters
* should be specified via [[ParamPair]].
Expand Down Expand Up @@ -555,7 +554,9 @@ trait Params extends Identifiable with Serializable {
* Parameter value checks which do not depend on other parameters are handled by
* [[Param.validate()]]. This method does not handle input/output column parameters;
* those are checked during schema validation.
* @deprecated Will be removed in 2.1.0. All the checks should be merged into transformSchema
*/
@deprecated("Will be removed in 2.1.0. Checks should be merged into transformSchema.", "2.0.0")
def validateParams(): Unit = {
// Do nothing by default. Override to handle Param interactions.
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -162,7 +162,6 @@ private[recommendation] trait ALSParams extends ALSModelParams with HasMaxIter w
* @return output schema
*/
protected def validateAndTransformSchema(schema: StructType): StructType = {
validateParams()
SchemaUtils.checkColumnType(schema, $(userCol), IntegerType)
SchemaUtils.checkColumnType(schema, $(itemCol), IntegerType)
val ratingType = schema($(ratingCol)).dataType
Expand Down Expand Up @@ -220,7 +219,6 @@ class ALSModel private[ml] (

@Since("1.3.0")
override def transformSchema(schema: StructType): StructType = {
validateParams()
SchemaUtils.checkColumnType(schema, $(userCol), IntegerType)
SchemaUtils.checkColumnType(schema, $(itemCol), IntegerType)
SchemaUtils.appendColumn(schema, $(predictionCol), FloatType)
Expand Down
Loading

0 comments on commit 92b7057

Please sign in to comment.