diff --git a/mllib/src/main/scala/org/apache/spark/ml/feature/OneHotEncoder.scala b/mllib/src/main/scala/org/apache/spark/ml/feature/OneHotEncoder.scala index a669da183e2c8..5ab6c2dde667a 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/feature/OneHotEncoder.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/feature/OneHotEncoder.scala @@ -41,8 +41,12 @@ import org.apache.spark.sql.types.{DoubleType, NumericType, StructType} * The output vectors are sparse. * * @see `StringIndexer` for converting categorical values into category indices + * @deprecated `OneHotEncoderEstimator` will be renamed `OneHotEncoder` and this `OneHotEncoder` + * will be removed in 3.0.0. */ @Since("1.4.0") +@deprecated("`OneHotEncoderEstimator` will be renamed `OneHotEncoder` and this `OneHotEncoder`" + + " will be removed in 3.0.0.", "2.3.0") class OneHotEncoder @Since("1.4.0") (@Since("1.4.0") override val uid: String) extends Transformer with HasInputCol with HasOutputCol with DefaultParamsWritable { @@ -78,56 +82,16 @@ class OneHotEncoder @Since("1.4.0") (@Since("1.4.0") override val uid: String) e override def transformSchema(schema: StructType): StructType = { val inputColName = $(inputCol) val outputColName = $(outputCol) + val inputFields = schema.fields require(schema(inputColName).dataType.isInstanceOf[NumericType], s"Input column must be of type NumericType but got ${schema(inputColName).dataType}") - val inputFields = schema.fields require(!inputFields.exists(_.name == outputColName), s"Output column $outputColName already exists.") - val inputAttr = Attribute.fromStructField(schema(inputColName)) - val outputAttrNames: Option[Array[String]] = inputAttr match { - case nominal: NominalAttribute => - if (nominal.values.isDefined) { - nominal.values - } else if (nominal.numValues.isDefined) { - nominal.numValues.map(n => Array.tabulate(n)(_.toString)) - } else { - None - } - case binary: BinaryAttribute => - if (binary.values.isDefined) { - binary.values - } else { - Some(Array.tabulate(2)(_.toString)) - } - case _: NumericAttribute => - throw new RuntimeException( - s"The input column $inputColName cannot be numeric.") - case _ => - None // optimistic about unknown attributes - } - - val filteredOutputAttrNames = outputAttrNames.map { names => - if ($(dropLast)) { - require(names.length > 1, - s"The input column $inputColName should have at least two distinct values.") - names.dropRight(1) - } else { - names - } - } - - val outputAttrGroup = if (filteredOutputAttrNames.isDefined) { - val attrs: Array[Attribute] = filteredOutputAttrNames.get.map { name => - BinaryAttribute.defaultAttr.withName(name) - } - new AttributeGroup($(outputCol), attrs) - } else { - new AttributeGroup($(outputCol)) - } - - val outputFields = inputFields :+ outputAttrGroup.toStructField() + val outputField = OneHotEncoderCommon.transformOutputColumnSchema( + schema(inputColName), outputColName, $(dropLast)) + val outputFields = inputFields :+ outputField StructType(outputFields) } @@ -136,30 +100,17 @@ class OneHotEncoder @Since("1.4.0") (@Since("1.4.0") override val uid: String) e // schema transformation val inputColName = $(inputCol) val outputColName = $(outputCol) - val shouldDropLast = $(dropLast) - var outputAttrGroup = AttributeGroup.fromStructField( + + val outputAttrGroupFromSchema = AttributeGroup.fromStructField( transformSchema(dataset.schema)(outputColName)) - if (outputAttrGroup.size < 0) { - // If the number of attributes is unknown, we check the values from the input column. 
- val numAttrs = dataset.select(col(inputColName).cast(DoubleType)).rdd.map(_.getDouble(0)) - .treeAggregate(0.0)( - (m, x) => { - assert(x <= Int.MaxValue, - s"OneHotEncoder only supports up to ${Int.MaxValue} indices, but got $x") - assert(x >= 0.0 && x == x.toInt, - s"Values from column $inputColName must be indices, but got $x.") - math.max(m, x) - }, - (m0, m1) => { - math.max(m0, m1) - } - ).toInt + 1 - val outputAttrNames = Array.tabulate(numAttrs)(_.toString) - val filtered = if (shouldDropLast) outputAttrNames.dropRight(1) else outputAttrNames - val outputAttrs: Array[Attribute] = - filtered.map(name => BinaryAttribute.defaultAttr.withName(name)) - outputAttrGroup = new AttributeGroup(outputColName, outputAttrs) + + val outputAttrGroup = if (outputAttrGroupFromSchema.size < 0) { + OneHotEncoderCommon.getOutputAttrGroupFromData( + dataset, Seq(inputColName), Seq(outputColName), $(dropLast))(0) + } else { + outputAttrGroupFromSchema } + val metadata = outputAttrGroup.toMetadata() // data transformation diff --git a/mllib/src/main/scala/org/apache/spark/ml/feature/OneHotEncoderEstimator.scala b/mllib/src/main/scala/org/apache/spark/ml/feature/OneHotEncoderEstimator.scala new file mode 100644 index 0000000000000..074622d41e28d --- /dev/null +++ b/mllib/src/main/scala/org/apache/spark/ml/feature/OneHotEncoderEstimator.scala @@ -0,0 +1,522 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.ml.feature + +import org.apache.hadoop.fs.Path + +import org.apache.spark.SparkException +import org.apache.spark.annotation.Since +import org.apache.spark.ml.{Estimator, Model} +import org.apache.spark.ml.attribute._ +import org.apache.spark.ml.linalg.Vectors +import org.apache.spark.ml.param._ +import org.apache.spark.ml.param.shared.{HasHandleInvalid, HasInputCols, HasOutputCols} +import org.apache.spark.ml.util._ +import org.apache.spark.sql.{DataFrame, Dataset} +import org.apache.spark.sql.expressions.UserDefinedFunction +import org.apache.spark.sql.functions.{col, lit, udf} +import org.apache.spark.sql.types.{DoubleType, NumericType, StructField, StructType} + +/** Private trait for params and common methods for OneHotEncoderEstimator and OneHotEncoderModel */ +private[ml] trait OneHotEncoderBase extends Params with HasHandleInvalid + with HasInputCols with HasOutputCols { + + /** + * Param for how to handle invalid data. + * Options are 'keep' (invalid data presented as an extra categorical feature) or + * 'error' (throw an error). 
+   * Default: "error"
+   * @group param
+   */
+  @Since("2.3.0")
+  override val handleInvalid: Param[String] = new Param[String](this, "handleInvalid",
+    "How to handle invalid data. " +
+    "Options are 'keep' (invalid data presented as an extra categorical feature) " +
+    "or 'error' (throw an error).",
+    ParamValidators.inArray(OneHotEncoderEstimator.supportedHandleInvalids))
+
+  setDefault(handleInvalid, OneHotEncoderEstimator.ERROR_INVALID)
+
+  /**
+   * Whether to drop the last category in the encoded vector (default: true)
+   * @group param
+   */
+  @Since("2.3.0")
+  final val dropLast: BooleanParam =
+    new BooleanParam(this, "dropLast", "whether to drop the last category")
+  setDefault(dropLast -> true)
+
+  /** @group getParam */
+  @Since("2.3.0")
+  def getDropLast: Boolean = $(dropLast)
+
+  protected def validateAndTransformSchema(
+      schema: StructType, dropLast: Boolean, keepInvalid: Boolean): StructType = {
+    val inputColNames = $(inputCols)
+    val outputColNames = $(outputCols)
+
+    require(inputColNames.length == outputColNames.length,
+      s"The number of input columns ${inputColNames.length} must be the same as the number of " +
+      s"output columns ${outputColNames.length}.")
+
+    // Input columns must be NumericType.
+    inputColNames.foreach(SchemaUtils.checkNumericType(schema, _))
+
+    // Prepares output columns with proper attributes by examining input columns.
+    val inputFields = $(inputCols).map(schema(_))
+
+    val outputFields = inputFields.zip(outputColNames).map { case (inputField, outputColName) =>
+      OneHotEncoderCommon.transformOutputColumnSchema(
+        inputField, outputColName, dropLast, keepInvalid)
+    }
+    outputFields.foldLeft(schema) { case (newSchema, outputField) =>
+      SchemaUtils.appendColumn(newSchema, outputField)
+    }
+  }
+}
+
+/**
+ * A one-hot encoder that maps a column of category indices to a column of binary vectors, with
+ * at most a single one-value per row that indicates the input category index.
+ * For example with 5 categories, an input value of 2.0 would map to an output vector of
+ * `[0.0, 0.0, 1.0, 0.0]`.
+ * The last category is not included by default (configurable via `dropLast`),
+ * because it makes the vector entries sum up to one, and hence linearly dependent.
+ * So an input value of 4.0 maps to `[0.0, 0.0, 0.0, 0.0]`.
+ *
+ * @note This is different from scikit-learn's OneHotEncoder, which keeps all categories.
+ * The output vectors are sparse.
+ *
+ * When `handleInvalid` is configured to 'keep', an extra "category" indicating invalid values is
+ * added as the last category. So when `dropLast` is true, invalid values are encoded as an
+ * all-zeros vector.
+ *
+ * @note When encoding multiple columns with the `inputCols` and `outputCols` params, input/output
+ * columns come in pairs, specified by the order in the arrays, and each pair is treated
+ * independently.
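+ *
+ * A minimal usage sketch (the DataFrame `df` and the column names are illustrative, not part of
+ * the API):
+ * {{{
+ *   val encoder = new OneHotEncoderEstimator()
+ *     .setInputCols(Array("categoryIndex"))
+ *     .setOutputCols(Array("categoryVec"))
+ *   val model = encoder.fit(df)
+ *   val encoded = model.transform(df)
+ * }}}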
+ * + * @see `StringIndexer` for converting categorical values into category indices + */ +@Since("2.3.0") +class OneHotEncoderEstimator @Since("2.3.0") (@Since("2.3.0") override val uid: String) + extends Estimator[OneHotEncoderModel] with OneHotEncoderBase with DefaultParamsWritable { + + @Since("2.3.0") + def this() = this(Identifiable.randomUID("oneHotEncoder")) + + /** @group setParam */ + @Since("2.3.0") + def setInputCols(values: Array[String]): this.type = set(inputCols, values) + + /** @group setParam */ + @Since("2.3.0") + def setOutputCols(values: Array[String]): this.type = set(outputCols, values) + + /** @group setParam */ + @Since("2.3.0") + def setDropLast(value: Boolean): this.type = set(dropLast, value) + + /** @group setParam */ + @Since("2.3.0") + def setHandleInvalid(value: String): this.type = set(handleInvalid, value) + + @Since("2.3.0") + override def transformSchema(schema: StructType): StructType = { + val keepInvalid = $(handleInvalid) == OneHotEncoderEstimator.KEEP_INVALID + validateAndTransformSchema(schema, dropLast = $(dropLast), + keepInvalid = keepInvalid) + } + + @Since("2.3.0") + override def fit(dataset: Dataset[_]): OneHotEncoderModel = { + transformSchema(dataset.schema) + + // Compute the plain number of categories without `handleInvalid` and + // `dropLast` taken into account. + val transformedSchema = validateAndTransformSchema(dataset.schema, dropLast = false, + keepInvalid = false) + val categorySizes = new Array[Int]($(outputCols).length) + + val columnToScanIndices = $(outputCols).zipWithIndex.flatMap { case (outputColName, idx) => + val numOfAttrs = AttributeGroup.fromStructField( + transformedSchema(outputColName)).size + if (numOfAttrs < 0) { + Some(idx) + } else { + categorySizes(idx) = numOfAttrs + None + } + } + + // Some input columns don't have attributes or their attributes don't have necessary info. + // We need to scan the data to get the number of values for each column. + if (columnToScanIndices.length > 0) { + val inputColNames = columnToScanIndices.map($(inputCols)(_)) + val outputColNames = columnToScanIndices.map($(outputCols)(_)) + + // When fitting data, we want the plain number of categories without `handleInvalid` and + // `dropLast` taken into account. + val attrGroups = OneHotEncoderCommon.getOutputAttrGroupFromData( + dataset, inputColNames, outputColNames, dropLast = false) + attrGroups.zip(columnToScanIndices).foreach { case (attrGroup, idx) => + categorySizes(idx) = attrGroup.size + } + } + + val model = new OneHotEncoderModel(uid, categorySizes).setParent(this) + copyValues(model) + } + + @Since("2.3.0") + override def copy(extra: ParamMap): OneHotEncoderEstimator = defaultCopy(extra) +} + +@Since("2.3.0") +object OneHotEncoderEstimator extends DefaultParamsReadable[OneHotEncoderEstimator] { + + private[feature] val KEEP_INVALID: String = "keep" + private[feature] val ERROR_INVALID: String = "error" + private[feature] val supportedHandleInvalids: Array[String] = Array(KEEP_INVALID, ERROR_INVALID) + + @Since("2.3.0") + override def load(path: String): OneHotEncoderEstimator = super.load(path) +} + +@Since("2.3.0") +class OneHotEncoderModel private[ml] ( + @Since("2.3.0") override val uid: String, + @Since("2.3.0") val categorySizes: Array[Int]) + extends Model[OneHotEncoderModel] with OneHotEncoderBase with MLWritable { + + import OneHotEncoderModel._ + + // Returns the category size for a given index with `dropLast` and `handleInvalid` + // taken into account. 
+  private def configedCategorySize(orgCategorySize: Int, idx: Int): Int = {
+    val dropLast = getDropLast
+    val keepInvalid = getHandleInvalid == OneHotEncoderEstimator.KEEP_INVALID
+
+    if (!dropLast && keepInvalid) {
+      // When `handleInvalid` is "keep", an extra category is added as the last category
+      // for invalid data.
+      orgCategorySize + 1
+    } else if (dropLast && !keepInvalid) {
+      // When `dropLast` is true, the last category is removed.
+      orgCategorySize - 1
+    } else {
+      // When `dropLast` is true and `handleInvalid` is "keep", the extra category for invalid
+      // data is removed. Thus, it is the same as the plain number of categories.
+      orgCategorySize
+    }
+  }
+
+  private def encoder: UserDefinedFunction = {
+    val oneValue = Array(1.0)
+    val emptyValues = Array.empty[Double]
+    val emptyIndices = Array.empty[Int]
+    val dropLast = getDropLast
+    val handleInvalid = getHandleInvalid
+    val keepInvalid = handleInvalid == OneHotEncoderEstimator.KEEP_INVALID
+
+    // The UDF applied to the input data. The first parameter is the input value; the second
+    // is the index of the input column.
+    udf { (label: Double, idx: Int) =>
+      val plainNumCategories = categorySizes(idx)
+      val size = configedCategorySize(plainNumCategories, idx)
+
+      if (label < 0) {
+        throw new SparkException(s"Negative value: $label. Input can't be negative.")
+      } else if (label == size && dropLast && !keepInvalid) {
+        // When `dropLast` is true and `handleInvalid` is not "keep",
+        // the last category is removed.
+        Vectors.sparse(size, emptyIndices, emptyValues)
+      } else if (label >= plainNumCategories && keepInvalid) {
+        // When `handleInvalid` is "keep", invalid data is encoded as the last category
+        // (which is removed if `dropLast` is true).
+        if (dropLast) {
+          Vectors.sparse(size, emptyIndices, emptyValues)
+        } else {
+          Vectors.sparse(size, Array(size - 1), oneValue)
+        }
+      } else if (label < plainNumCategories) {
+        Vectors.sparse(size, Array(label.toInt), oneValue)
+      } else {
+        assert(handleInvalid == OneHotEncoderEstimator.ERROR_INVALID)
+        throw new SparkException(s"Unseen value: $label. To handle unseen values, " +
+          s"set Param handleInvalid to ${OneHotEncoderEstimator.KEEP_INVALID}.")
+      }
+    }
+  }
+
+  /** @group setParam */
+  @Since("2.3.0")
+  def setInputCols(values: Array[String]): this.type = set(inputCols, values)
+
+  /** @group setParam */
+  @Since("2.3.0")
+  def setOutputCols(values: Array[String]): this.type = set(outputCols, values)
+
+  /** @group setParam */
+  @Since("2.3.0")
+  def setDropLast(value: Boolean): this.type = set(dropLast, value)
+
+  /** @group setParam */
+  @Since("2.3.0")
+  def setHandleInvalid(value: String): this.type = set(handleInvalid, value)
+
+  @Since("2.3.0")
+  override def transformSchema(schema: StructType): StructType = {
+    val inputColNames = $(inputCols)
+
+    require(inputColNames.length == categorySizes.length,
+      s"The number of input columns ${inputColNames.length} must be the same as the number of " +
+      s"features ${categorySizes.length} seen during fitting.")
+
+    val keepInvalid = $(handleInvalid) == OneHotEncoderEstimator.KEEP_INVALID
+    val transformedSchema = validateAndTransformSchema(schema, dropLast = $(dropLast),
+      keepInvalid = keepInvalid)
+    verifyNumOfValues(transformedSchema)
+  }
+
+  /**
+   * If the metadata of the input columns also specifies the number of categories, we compare it
+   * against the expected number of categories, with `handleInvalid` and `dropLast` taken into
+   * account. A mismatch causes an exception.
+ */ + private def verifyNumOfValues(schema: StructType): StructType = { + $(outputCols).zipWithIndex.foreach { case (outputColName, idx) => + val inputColName = $(inputCols)(idx) + val attrGroup = AttributeGroup.fromStructField(schema(outputColName)) + + // If the input metadata specifies number of category for output column, + // comparing with expected category number with `handleInvalid` and + // `dropLast` taken into account. + if (attrGroup.attributes.nonEmpty) { + val numCategories = configedCategorySize(categorySizes(idx), idx) + require(attrGroup.size == numCategories, "OneHotEncoderModel expected " + + s"$numCategories categorical values for input column ${inputColName}, " + + s"but the input column had metadata specifying ${attrGroup.size} values.") + } + } + schema + } + + @Since("2.3.0") + override def transform(dataset: Dataset[_]): DataFrame = { + val transformedSchema = transformSchema(dataset.schema, logging = true) + val keepInvalid = $(handleInvalid) == OneHotEncoderEstimator.KEEP_INVALID + + val encodedColumns = (0 until $(inputCols).length).map { idx => + val inputColName = $(inputCols)(idx) + val outputColName = $(outputCols)(idx) + + val outputAttrGroupFromSchema = + AttributeGroup.fromStructField(transformedSchema(outputColName)) + + val metadata = if (outputAttrGroupFromSchema.size < 0) { + OneHotEncoderCommon.createAttrGroupForAttrNames(outputColName, + categorySizes(idx), $(dropLast), keepInvalid).toMetadata() + } else { + outputAttrGroupFromSchema.toMetadata() + } + + encoder(col(inputColName).cast(DoubleType), lit(idx)) + .as(outputColName, metadata) + } + dataset.withColumns($(outputCols), encodedColumns) + } + + @Since("2.3.0") + override def copy(extra: ParamMap): OneHotEncoderModel = { + val copied = new OneHotEncoderModel(uid, categorySizes) + copyValues(copied, extra).setParent(parent) + } + + @Since("2.3.0") + override def write: MLWriter = new OneHotEncoderModelWriter(this) +} + +@Since("2.3.0") +object OneHotEncoderModel extends MLReadable[OneHotEncoderModel] { + + private[OneHotEncoderModel] + class OneHotEncoderModelWriter(instance: OneHotEncoderModel) extends MLWriter { + + private case class Data(categorySizes: Array[Int]) + + override protected def saveImpl(path: String): Unit = { + DefaultParamsWriter.saveMetadata(instance, path, sc) + val data = Data(instance.categorySizes) + val dataPath = new Path(path, "data").toString + sparkSession.createDataFrame(Seq(data)).repartition(1).write.parquet(dataPath) + } + } + + private class OneHotEncoderModelReader extends MLReader[OneHotEncoderModel] { + + private val className = classOf[OneHotEncoderModel].getName + + override def load(path: String): OneHotEncoderModel = { + val metadata = DefaultParamsReader.loadMetadata(path, sc, className) + val dataPath = new Path(path, "data").toString + val data = sparkSession.read.parquet(dataPath) + .select("categorySizes") + .head() + val categorySizes = data.getAs[Seq[Int]](0).toArray + val model = new OneHotEncoderModel(metadata.uid, categorySizes) + DefaultParamsReader.getAndSetParams(model, metadata) + model + } + } + + @Since("2.3.0") + override def read: MLReader[OneHotEncoderModel] = new OneHotEncoderModelReader + + @Since("2.3.0") + override def load(path: String): OneHotEncoderModel = super.load(path) +} + +/** + * Provides some helper methods used by both `OneHotEncoder` and `OneHotEncoderEstimator`. 
+ */ +private[feature] object OneHotEncoderCommon { + + private def genOutputAttrNames(inputCol: StructField): Option[Array[String]] = { + val inputAttr = Attribute.fromStructField(inputCol) + inputAttr match { + case nominal: NominalAttribute => + if (nominal.values.isDefined) { + nominal.values + } else if (nominal.numValues.isDefined) { + nominal.numValues.map(n => Array.tabulate(n)(_.toString)) + } else { + None + } + case binary: BinaryAttribute => + if (binary.values.isDefined) { + binary.values + } else { + Some(Array.tabulate(2)(_.toString)) + } + case _: NumericAttribute => + throw new RuntimeException( + s"The input column ${inputCol.name} cannot be continuous-value.") + case _ => + None // optimistic about unknown attributes + } + } + + /** Creates an `AttributeGroup` filled by the `BinaryAttribute` named as required. */ + private def genOutputAttrGroup( + outputAttrNames: Option[Array[String]], + outputColName: String): AttributeGroup = { + outputAttrNames.map { attrNames => + val attrs: Array[Attribute] = attrNames.map { name => + BinaryAttribute.defaultAttr.withName(name) + } + new AttributeGroup(outputColName, attrs) + }.getOrElse{ + new AttributeGroup(outputColName) + } + } + + /** + * Prepares the `StructField` with proper metadata for `OneHotEncoder`'s output column. + */ + def transformOutputColumnSchema( + inputCol: StructField, + outputColName: String, + dropLast: Boolean, + keepInvalid: Boolean = false): StructField = { + val outputAttrNames = genOutputAttrNames(inputCol) + val filteredOutputAttrNames = outputAttrNames.map { names => + if (dropLast && !keepInvalid) { + require(names.length > 1, + s"The input column ${inputCol.name} should have at least two distinct values.") + names.dropRight(1) + } else if (!dropLast && keepInvalid) { + names ++ Seq("invalidValues") + } else { + names + } + } + + genOutputAttrGroup(filteredOutputAttrNames, outputColName).toStructField() + } + + /** + * This method is called when we want to generate `AttributeGroup` from actual data for + * one-hot encoder. + */ + def getOutputAttrGroupFromData( + dataset: Dataset[_], + inputColNames: Seq[String], + outputColNames: Seq[String], + dropLast: Boolean): Seq[AttributeGroup] = { + // The RDD approach has advantage of early-stop if any values are invalid. It seems that + // DataFrame ops don't have equivalent functions. + val columns = inputColNames.map { inputColName => + col(inputColName).cast(DoubleType) + } + val numOfColumns = columns.length + + val numAttrsArray = dataset.select(columns: _*).rdd.map { row => + (0 until numOfColumns).map(idx => row.getDouble(idx)).toArray + }.treeAggregate(new Array[Double](numOfColumns))( + (maxValues, curValues) => { + (0 until numOfColumns).foreach { idx => + val x = curValues(idx) + assert(x <= Int.MaxValue, + s"OneHotEncoder only supports up to ${Int.MaxValue} indices, but got $x.") + assert(x >= 0.0 && x == x.toInt, + s"Values from column ${inputColNames(idx)} must be indices, but got $x.") + maxValues(idx) = math.max(maxValues(idx), x) + } + maxValues + }, + (m0, m1) => { + (0 until numOfColumns).foreach { idx => + m0(idx) = math.max(m0(idx), m1(idx)) + } + m0 + } + ).map(_.toInt + 1) + + outputColNames.zip(numAttrsArray).map { case (outputColName, numAttrs) => + createAttrGroupForAttrNames(outputColName, numAttrs, dropLast, keepInvalid = false) + } + } + + /** Creates an `AttributeGroup` with the required number of `BinaryAttribute`. 
*/ + def createAttrGroupForAttrNames( + outputColName: String, + numAttrs: Int, + dropLast: Boolean, + keepInvalid: Boolean): AttributeGroup = { + val outputAttrNames = Array.tabulate(numAttrs)(_.toString) + val filtered = if (dropLast && !keepInvalid) { + outputAttrNames.dropRight(1) + } else if (!dropLast && keepInvalid) { + outputAttrNames ++ Seq("invalidValues") + } else { + outputAttrNames + } + genOutputAttrGroup(Some(filtered), outputColName) + } +} diff --git a/mllib/src/test/scala/org/apache/spark/ml/feature/OneHotEncoderEstimatorSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/feature/OneHotEncoderEstimatorSuite.scala new file mode 100644 index 0000000000000..1d3f845586426 --- /dev/null +++ b/mllib/src/test/scala/org/apache/spark/ml/feature/OneHotEncoderEstimatorSuite.scala @@ -0,0 +1,421 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.ml.feature + +import org.apache.spark.{SparkException, SparkFunSuite} +import org.apache.spark.ml.attribute.{AttributeGroup, BinaryAttribute, NominalAttribute} +import org.apache.spark.ml.linalg.{Vector, Vectors, VectorUDT} +import org.apache.spark.ml.param.ParamsSuite +import org.apache.spark.ml.util.DefaultReadWriteTest +import org.apache.spark.mllib.util.MLlibTestSparkContext +import org.apache.spark.sql.{DataFrame, Row} +import org.apache.spark.sql.functions.col +import org.apache.spark.sql.types._ + +class OneHotEncoderEstimatorSuite + extends SparkFunSuite with MLlibTestSparkContext with DefaultReadWriteTest { + + import testImplicits._ + + test("params") { + ParamsSuite.checkParams(new OneHotEncoderEstimator) + } + + test("OneHotEncoderEstimator dropLast = false") { + val data = Seq( + Row(0.0, Vectors.sparse(3, Seq((0, 1.0)))), + Row(1.0, Vectors.sparse(3, Seq((1, 1.0)))), + Row(2.0, Vectors.sparse(3, Seq((2, 1.0)))), + Row(0.0, Vectors.sparse(3, Seq((0, 1.0)))), + Row(0.0, Vectors.sparse(3, Seq((0, 1.0)))), + Row(2.0, Vectors.sparse(3, Seq((2, 1.0))))) + + val schema = StructType(Array( + StructField("input", DoubleType), + StructField("expected", new VectorUDT))) + + val df = spark.createDataFrame(sc.parallelize(data), schema) + + val encoder = new OneHotEncoderEstimator() + .setInputCols(Array("input")) + .setOutputCols(Array("output")) + assert(encoder.getDropLast === true) + encoder.setDropLast(false) + assert(encoder.getDropLast === false) + + val model = encoder.fit(df) + val encoded = model.transform(df) + encoded.select("output", "expected").rdd.map { r => + (r.getAs[Vector](0), r.getAs[Vector](1)) + }.collect().foreach { case (vec1, vec2) => + assert(vec1 === vec2) + } + } + + test("OneHotEncoderEstimator dropLast = true") { + val data = Seq( + Row(0.0, Vectors.sparse(2, Seq((0, 1.0)))), + Row(1.0, Vectors.sparse(2, Seq((1, 1.0)))), + Row(2.0, 
Vectors.sparse(2, Seq())), + Row(0.0, Vectors.sparse(2, Seq((0, 1.0)))), + Row(0.0, Vectors.sparse(2, Seq((0, 1.0)))), + Row(2.0, Vectors.sparse(2, Seq()))) + + val schema = StructType(Array( + StructField("input", DoubleType), + StructField("expected", new VectorUDT))) + + val df = spark.createDataFrame(sc.parallelize(data), schema) + + val encoder = new OneHotEncoderEstimator() + .setInputCols(Array("input")) + .setOutputCols(Array("output")) + + val model = encoder.fit(df) + val encoded = model.transform(df) + encoded.select("output", "expected").rdd.map { r => + (r.getAs[Vector](0), r.getAs[Vector](1)) + }.collect().foreach { case (vec1, vec2) => + assert(vec1 === vec2) + } + } + + test("input column with ML attribute") { + val attr = NominalAttribute.defaultAttr.withValues("small", "medium", "large") + val df = Seq(0.0, 1.0, 2.0, 1.0).map(Tuple1.apply).toDF("size") + .select(col("size").as("size", attr.toMetadata())) + val encoder = new OneHotEncoderEstimator() + .setInputCols(Array("size")) + .setOutputCols(Array("encoded")) + val model = encoder.fit(df) + val output = model.transform(df) + val group = AttributeGroup.fromStructField(output.schema("encoded")) + assert(group.size === 2) + assert(group.getAttr(0) === BinaryAttribute.defaultAttr.withName("small").withIndex(0)) + assert(group.getAttr(1) === BinaryAttribute.defaultAttr.withName("medium").withIndex(1)) + } + + test("input column without ML attribute") { + val df = Seq(0.0, 1.0, 2.0, 1.0).map(Tuple1.apply).toDF("index") + val encoder = new OneHotEncoderEstimator() + .setInputCols(Array("index")) + .setOutputCols(Array("encoded")) + val model = encoder.fit(df) + val output = model.transform(df) + val group = AttributeGroup.fromStructField(output.schema("encoded")) + assert(group.size === 2) + assert(group.getAttr(0) === BinaryAttribute.defaultAttr.withName("0").withIndex(0)) + assert(group.getAttr(1) === BinaryAttribute.defaultAttr.withName("1").withIndex(1)) + } + + test("read/write") { + val encoder = new OneHotEncoderEstimator() + .setInputCols(Array("index")) + .setOutputCols(Array("encoded")) + testDefaultReadWrite(encoder) + } + + test("OneHotEncoderModel read/write") { + val instance = new OneHotEncoderModel("myOneHotEncoderModel", Array(1, 2, 3)) + val newInstance = testDefaultReadWrite(instance) + assert(newInstance.categorySizes === instance.categorySizes) + } + + test("OneHotEncoderEstimator with varying types") { + val data = Seq( + Row(0.0, Vectors.sparse(3, Seq((0, 1.0)))), + Row(1.0, Vectors.sparse(3, Seq((1, 1.0)))), + Row(2.0, Vectors.sparse(3, Seq((2, 1.0)))), + Row(0.0, Vectors.sparse(3, Seq((0, 1.0)))), + Row(0.0, Vectors.sparse(3, Seq((0, 1.0)))), + Row(2.0, Vectors.sparse(3, Seq((2, 1.0))))) + + val schema = StructType(Array( + StructField("input", DoubleType), + StructField("expected", new VectorUDT))) + + val df = spark.createDataFrame(sc.parallelize(data), schema) + + val dfWithTypes = df + .withColumn("shortInput", df("input").cast(ShortType)) + .withColumn("longInput", df("input").cast(LongType)) + .withColumn("intInput", df("input").cast(IntegerType)) + .withColumn("floatInput", df("input").cast(FloatType)) + .withColumn("decimalInput", df("input").cast(DecimalType(10, 0))) + + val cols = Array("input", "shortInput", "longInput", "intInput", + "floatInput", "decimalInput") + for (col <- cols) { + val encoder = new OneHotEncoderEstimator() + .setInputCols(Array(col)) + .setOutputCols(Array("output")) + .setDropLast(false) + + val model = encoder.fit(dfWithTypes) + val encoded = 
model.transform(dfWithTypes)
+
+      encoded.select("output", "expected").rdd.map { r =>
+        (r.getAs[Vector](0), r.getAs[Vector](1))
+      }.collect().foreach { case (vec1, vec2) =>
+        assert(vec1 === vec2)
+      }
+    }
+  }
+
+  test("OneHotEncoderEstimator: encoding multiple columns and dropLast = false") {
+    val data = Seq(
+      Row(0.0, Vectors.sparse(3, Seq((0, 1.0))), 2.0, Vectors.sparse(4, Seq((2, 1.0)))),
+      Row(1.0, Vectors.sparse(3, Seq((1, 1.0))), 3.0, Vectors.sparse(4, Seq((3, 1.0)))),
+      Row(2.0, Vectors.sparse(3, Seq((2, 1.0))), 0.0, Vectors.sparse(4, Seq((0, 1.0)))),
+      Row(0.0, Vectors.sparse(3, Seq((0, 1.0))), 1.0, Vectors.sparse(4, Seq((1, 1.0)))),
+      Row(0.0, Vectors.sparse(3, Seq((0, 1.0))), 0.0, Vectors.sparse(4, Seq((0, 1.0)))),
+      Row(2.0, Vectors.sparse(3, Seq((2, 1.0))), 2.0, Vectors.sparse(4, Seq((2, 1.0)))))
+
+    val schema = StructType(Array(
+      StructField("input1", DoubleType),
+      StructField("expected1", new VectorUDT),
+      StructField("input2", DoubleType),
+      StructField("expected2", new VectorUDT)))
+
+    val df = spark.createDataFrame(sc.parallelize(data), schema)
+
+    val encoder = new OneHotEncoderEstimator()
+      .setInputCols(Array("input1", "input2"))
+      .setOutputCols(Array("output1", "output2"))
+    assert(encoder.getDropLast === true)
+    encoder.setDropLast(false)
+    assert(encoder.getDropLast === false)
+
+    val model = encoder.fit(df)
+    val encoded = model.transform(df)
+    encoded.select("output1", "expected1", "output2", "expected2").rdd.map { r =>
+      (r.getAs[Vector](0), r.getAs[Vector](1), r.getAs[Vector](2), r.getAs[Vector](3))
+    }.collect().foreach { case (vec1, vec2, vec3, vec4) =>
+      assert(vec1 === vec2)
+      assert(vec3 === vec4)
+    }
+  }
+
+  test("OneHotEncoderEstimator: encoding multiple columns and dropLast = true") {
+    val data = Seq(
+      Row(0.0, Vectors.sparse(2, Seq((0, 1.0))), 2.0, Vectors.sparse(3, Seq((2, 1.0)))),
+      Row(1.0, Vectors.sparse(2, Seq((1, 1.0))), 3.0, Vectors.sparse(3, Seq())),
+      Row(2.0, Vectors.sparse(2, Seq()), 0.0, Vectors.sparse(3, Seq((0, 1.0)))),
+      Row(0.0, Vectors.sparse(2, Seq((0, 1.0))), 1.0, Vectors.sparse(3, Seq((1, 1.0)))),
+      Row(0.0, Vectors.sparse(2, Seq((0, 1.0))), 0.0, Vectors.sparse(3, Seq((0, 1.0)))),
+      Row(2.0, Vectors.sparse(2, Seq()), 2.0, Vectors.sparse(3, Seq((2, 1.0)))))
+
+    val schema = StructType(Array(
+      StructField("input1", DoubleType),
+      StructField("expected1", new VectorUDT),
+      StructField("input2", DoubleType),
+      StructField("expected2", new VectorUDT)))
+
+    val df = spark.createDataFrame(sc.parallelize(data), schema)
+
+    val encoder = new OneHotEncoderEstimator()
+      .setInputCols(Array("input1", "input2"))
+      .setOutputCols(Array("output1", "output2"))
+
+    val model = encoder.fit(df)
+    val encoded = model.transform(df)
+    encoded.select("output1", "expected1", "output2", "expected2").rdd.map { r =>
+      (r.getAs[Vector](0), r.getAs[Vector](1), r.getAs[Vector](2), r.getAs[Vector](3))
+    }.collect().foreach { case (vec1, vec2, vec3, vec4) =>
+      assert(vec1 === vec2)
+      assert(vec3 === vec4)
+    }
+  }
+
+  test("Throw error on invalid values") {
+    val trainingData = Seq((0, 0), (1, 1), (2, 2))
+    val trainingDF = trainingData.toDF("id", "a")
+    val testData = Seq((0, 0), (1, 2), (1, 3))
+    val testDF = testData.toDF("id", "a")
+
+    val encoder = new OneHotEncoderEstimator()
+      .setInputCols(Array("a"))
+      .setOutputCols(Array("encoded"))
+
+    val model = encoder.fit(trainingDF)
+    val err = intercept[SparkException] {
+      model.transform(testDF).show()
+    }
+    assert(err.getMessage.contains("Unseen value: 3.0. To handle unseen values"))
+  }
+
+  test("Can't transform on negative input") {
+    val trainingDF = Seq((0, 0), (1, 1), (2, 2)).toDF("a", "b")
+    val testDF = Seq((0, 0), (-1, 2), (1, 3)).toDF("a", "b")
+
+    val encoder = new OneHotEncoderEstimator()
+      .setInputCols(Array("a"))
+      .setOutputCols(Array("encoded"))
+
+    val model = encoder.fit(trainingDF)
+    val err = intercept[SparkException] {
+      model.transform(testDF).collect()
+    }
+    assert(err.getMessage.contains("Negative value: -1.0. Input can't be negative"))
+  }
+
+  test("Keep on invalid values: dropLast = false") {
+    val trainingDF = Seq(Tuple1(0), Tuple1(1), Tuple1(2)).toDF("input")
+
+    val testData = Seq(
+      Row(0.0, Vectors.sparse(4, Seq((0, 1.0)))),
+      Row(1.0, Vectors.sparse(4, Seq((1, 1.0)))),
+      Row(3.0, Vectors.sparse(4, Seq((3, 1.0)))))
+
+    val schema = StructType(Array(
+      StructField("input", DoubleType),
+      StructField("expected", new VectorUDT)))
+
+    val testDF = spark.createDataFrame(sc.parallelize(testData), schema)
+
+    val encoder = new OneHotEncoderEstimator()
+      .setInputCols(Array("input"))
+      .setOutputCols(Array("output"))
+      .setHandleInvalid("keep")
+      .setDropLast(false)
+
+    val model = encoder.fit(trainingDF)
+    val encoded = model.transform(testDF)
+    encoded.select("output", "expected").rdd.map { r =>
+      (r.getAs[Vector](0), r.getAs[Vector](1))
+    }.collect().foreach { case (vec1, vec2) =>
+      assert(vec1 === vec2)
+    }
+  }
+
+  test("Keep on invalid values: dropLast = true") {
+    val trainingDF = Seq(Tuple1(0), Tuple1(1), Tuple1(2)).toDF("input")
+
+    val testData = Seq(
+      Row(0.0, Vectors.sparse(3, Seq((0, 1.0)))),
+      Row(1.0, Vectors.sparse(3, Seq((1, 1.0)))),
+      Row(3.0, Vectors.sparse(3, Seq())))
+
+    val schema = StructType(Array(
+      StructField("input", DoubleType),
+      StructField("expected", new VectorUDT)))
+
+    val testDF = spark.createDataFrame(sc.parallelize(testData), schema)
+
+    val encoder = new OneHotEncoderEstimator()
+      .setInputCols(Array("input"))
+      .setOutputCols(Array("output"))
+      .setHandleInvalid("keep")
+      .setDropLast(true)
+
+    val model = encoder.fit(trainingDF)
+    val encoded = model.transform(testDF)
+    encoded.select("output", "expected").rdd.map { r =>
+      (r.getAs[Vector](0), r.getAs[Vector](1))
+    }.collect().foreach { case (vec1, vec2) =>
+      assert(vec1 === vec2)
+    }
+  }
+
+  test("OneHotEncoderModel changes dropLast") {
+    val data = Seq(
+      Row(0.0, Vectors.sparse(3, Seq((0, 1.0))), Vectors.sparse(2, Seq((0, 1.0)))),
+      Row(1.0, Vectors.sparse(3, Seq((1, 1.0))), Vectors.sparse(2, Seq((1, 1.0)))),
+      Row(2.0, Vectors.sparse(3, Seq((2, 1.0))), Vectors.sparse(2, Seq())),
+      Row(0.0, Vectors.sparse(3, Seq((0, 1.0))), Vectors.sparse(2, Seq((0, 1.0)))),
+      Row(0.0, Vectors.sparse(3, Seq((0, 1.0))), Vectors.sparse(2, Seq((0, 1.0)))),
+      Row(2.0, Vectors.sparse(3, Seq((2, 1.0))), Vectors.sparse(2, Seq())))
+
+    val schema = StructType(Array(
+      StructField("input", DoubleType),
+      StructField("expected1", new VectorUDT),
+      StructField("expected2", new VectorUDT)))
+
+    val df = spark.createDataFrame(sc.parallelize(data), schema)
+
+    val encoder = new OneHotEncoderEstimator()
+      .setInputCols(Array("input"))
+      .setOutputCols(Array("output"))
+
+    val model = encoder.fit(df)
+
+    model.setDropLast(false)
+    val encoded1 = model.transform(df)
+    encoded1.select("output", "expected1").rdd.map { r =>
+      (r.getAs[Vector](0), r.getAs[Vector](1))
+    }.collect().foreach { case (vec1, vec2) =>
+      assert(vec1 === vec2)
+    }
+
+    model.setDropLast(true)
+    val encoded2 = model.transform(df)
+    encoded2.select("output", "expected2").rdd.map { r =>
+      (r.getAs[Vector](0), r.getAs[Vector](1))
+    }.collect().foreach { case (vec1, vec2) =>
+      assert(vec1 === vec2)
+    }
+  }
+
+  test("OneHotEncoderModel changes handleInvalid") {
+    val trainingDF = Seq(Tuple1(0), Tuple1(1), Tuple1(2)).toDF("input")
+
+    val testData = Seq(
+      Row(0.0, Vectors.sparse(4, Seq((0, 1.0)))),
+      Row(1.0, Vectors.sparse(4, Seq((1, 1.0)))),
+      Row(3.0, Vectors.sparse(4, Seq((3, 1.0)))))
+
+    val schema = StructType(Array(
+      StructField("input", DoubleType),
+      StructField("expected", new VectorUDT)))
+
+    val testDF = spark.createDataFrame(sc.parallelize(testData), schema)
+
+    val encoder = new OneHotEncoderEstimator()
+      .setInputCols(Array("input"))
+      .setOutputCols(Array("output"))
+
+    val model = encoder.fit(trainingDF)
+    model.setHandleInvalid("error")
+
+    val err = intercept[SparkException] {
+      model.transform(testDF).collect()
+    }
+    assert(err.getMessage.contains("Unseen value: 3.0. To handle unseen values"))
+
+    model.setHandleInvalid("keep")
+    model.transform(testDF).collect()
+  }
+
+  test("Transforming on mismatched attributes") {
+    val attr = NominalAttribute.defaultAttr.withValues("small", "medium", "large")
+    val df = Seq(0.0, 1.0, 2.0, 1.0).map(Tuple1.apply).toDF("size")
+      .select(col("size").as("size", attr.toMetadata()))
+    val encoder = new OneHotEncoderEstimator()
+      .setInputCols(Array("size"))
+      .setOutputCols(Array("encoded"))
+    val model = encoder.fit(df)
+
+    val testAttr = NominalAttribute.defaultAttr.withValues("tiny", "small", "medium", "large")
+    val testDF = Seq(0.0, 1.0, 2.0, 3.0).map(Tuple1.apply).toDF("size")
+      .select(col("size").as("size", testAttr.toMetadata()))
+    val err = intercept[Exception] {
+      model.transform(testDF).collect()
+    }
+    assert(err.getMessage.contains("OneHotEncoderModel expected 2 categorical values"))
+  }
+}
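A minimal end-to-end sketch of the API this patch introduces (assuming an active `SparkSession` named `spark`; the data and column names are illustrative):

    import org.apache.spark.ml.feature.OneHotEncoderEstimator

    // Hypothetical input: columns of category indices, e.g. produced by StringIndexer.
    val df = spark.createDataFrame(Seq(
      (0.0, 1.0),
      (1.0, 0.0),
      (2.0, 1.0)
    )).toDF("categoryIndex1", "categoryIndex2")

    // Input/output columns are paired by position in the arrays.
    val encoder = new OneHotEncoderEstimator()
      .setInputCols(Array("categoryIndex1", "categoryIndex2"))
      .setOutputCols(Array("categoryVec1", "categoryVec2"))
      .setHandleInvalid("keep") // encode unseen values as an extra category instead of failing

    val model = encoder.fit(df) // scans the data only when attribute metadata is missing
    model.transform(df).show()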