Skip to content

Commit

Permalink
[SPARK-13030][ML] Follow-up cleanups for OneHotEncoderEstimator
Browse files Browse the repository at this point in the history
## What changes were proposed in this pull request?

Follow-up cleanups for the OneHotEncoderEstimator PR.  See some discussion in the original PR: #19527 or read below for what this PR includes:
* configedCategorySize: I reverted this to return an Array.  I realized the original setup (which I had recommended in the original PR) caused the whole model to be serialized in the UDF.
* encoder: I reorganized the logic to show what I meant in the comment in the previous PR.  I think it's simpler but am open to suggestions.

I also made some small style cleanups based on IntelliJ warnings.

## How was this patch tested?

Existing unit tests

Author: Joseph K. Bradley <[email protected]>

Closes #20132 from jkbradley/viirya-SPARK-13030.
  • Loading branch information
jkbradley committed Jan 5, 2018
1 parent c0b7424 commit 930b90a
Showing 1 changed file with 49 additions and 43 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -30,24 +30,27 @@ import org.apache.spark.ml.util._
import org.apache.spark.sql.{DataFrame, Dataset}
import org.apache.spark.sql.expressions.UserDefinedFunction
import org.apache.spark.sql.functions.{col, lit, udf}
import org.apache.spark.sql.types.{DoubleType, NumericType, StructField, StructType}
import org.apache.spark.sql.types.{DoubleType, StructField, StructType}

/** Private trait for params and common methods for OneHotEncoderEstimator and OneHotEncoderModel */
private[ml] trait OneHotEncoderBase extends Params with HasHandleInvalid
with HasInputCols with HasOutputCols {

/**
* Param for how to handle invalid data.
* Param for how to handle invalid data during transform().
* Options are 'keep' (invalid data presented as an extra categorical feature) or
* 'error' (throw an error).
* Note that this Param is only used during transform; during fitting, invalid data
* will result in an error.
* Default: "error"
* @group param
*/
@Since("2.3.0")
override val handleInvalid: Param[String] = new Param[String](this, "handleInvalid",
"How to handle invalid data " +
"How to handle invalid data during transform(). " +
"Options are 'keep' (invalid data presented as an extra categorical feature) " +
"or error (throw an error).",
"or error (throw an error). Note that this Param is only used during transform; " +
"during fitting, invalid data will result in an error.",
ParamValidators.inArray(OneHotEncoderEstimator.supportedHandleInvalids))

setDefault(handleInvalid, OneHotEncoderEstimator.ERROR_INVALID)
Expand All @@ -66,10 +69,11 @@ private[ml] trait OneHotEncoderBase extends Params with HasHandleInvalid
def getDropLast: Boolean = $(dropLast)

protected def validateAndTransformSchema(
schema: StructType, dropLast: Boolean, keepInvalid: Boolean): StructType = {
schema: StructType,
dropLast: Boolean,
keepInvalid: Boolean): StructType = {
val inputColNames = $(inputCols)
val outputColNames = $(outputCols)
val existingFields = schema.fields

require(inputColNames.length == outputColNames.length,
s"The number of input columns ${inputColNames.length} must be the same as the number of " +
Expand Down Expand Up @@ -197,6 +201,10 @@ object OneHotEncoderEstimator extends DefaultParamsReadable[OneHotEncoderEstimat
override def load(path: String): OneHotEncoderEstimator = super.load(path)
}

/**
* @param categorySizes Original number of categories for each feature being encoded.
* The array contains one value for each input column, in order.
*/
@Since("2.3.0")
class OneHotEncoderModel private[ml] (
@Since("2.3.0") override val uid: String,
Expand All @@ -205,60 +213,58 @@ class OneHotEncoderModel private[ml] (

import OneHotEncoderModel._

// Returns the category size for a given index with `dropLast` and `handleInvalid`
// Returns the category size for each index with `dropLast` and `handleInvalid`
// taken into account.
private def configedCategorySize(orgCategorySize: Int, idx: Int): Int = {
private def getConfigedCategorySizes: Array[Int] = {
val dropLast = getDropLast
val keepInvalid = getHandleInvalid == OneHotEncoderEstimator.KEEP_INVALID

if (!dropLast && keepInvalid) {
// When `handleInvalid` is "keep", an extra category is added as last category
// for invalid data.
orgCategorySize + 1
categorySizes.map(_ + 1)
} else if (dropLast && !keepInvalid) {
// When `dropLast` is true, the last category is removed.
orgCategorySize - 1
categorySizes.map(_ - 1)
} else {
// When `dropLast` is true and `handleInvalid` is "keep", the extra category for invalid
// data is removed. Thus, it is the same as the plain number of categories.
orgCategorySize
categorySizes
}
}

private def encoder: UserDefinedFunction = {
val oneValue = Array(1.0)
val emptyValues = Array.empty[Double]
val emptyIndices = Array.empty[Int]
val dropLast = getDropLast
val handleInvalid = getHandleInvalid
val keepInvalid = handleInvalid == OneHotEncoderEstimator.KEEP_INVALID
val keepInvalid = getHandleInvalid == OneHotEncoderEstimator.KEEP_INVALID
val configedSizes = getConfigedCategorySizes
val localCategorySizes = categorySizes

// The udf performed on input data. The first parameter is the input value. The second
// parameter is the index of input.
udf { (label: Double, idx: Int) =>
val plainNumCategories = categorySizes(idx)
val size = configedCategorySize(plainNumCategories, idx)

if (label < 0) {
throw new SparkException(s"Negative value: $label. Input can't be negative.")
} else if (label == size && dropLast && !keepInvalid) {
// When `dropLast` is true and `handleInvalid` is not "keep",
// the last category is removed.
Vectors.sparse(size, emptyIndices, emptyValues)
} else if (label >= plainNumCategories && keepInvalid) {
// When `handleInvalid` is "keep", encodes invalid data to last category (and removed
// if `dropLast` is true)
if (dropLast) {
Vectors.sparse(size, emptyIndices, emptyValues)
// parameter is the index in inputCols of the column being encoded.
udf { (label: Double, colIdx: Int) =>
val origCategorySize = localCategorySizes(colIdx)
// idx: index in vector of the single 1-valued element
val idx = if (label >= 0 && label < origCategorySize) {
label
} else {
if (keepInvalid) {
origCategorySize
} else {
Vectors.sparse(size, Array(size - 1), oneValue)
if (label < 0) {
throw new SparkException(s"Negative value: $label. Input can't be negative. " +
s"To handle invalid values, set Param handleInvalid to " +
s"${OneHotEncoderEstimator.KEEP_INVALID}")
} else {
throw new SparkException(s"Unseen value: $label. To handle unseen values, " +
s"set Param handleInvalid to ${OneHotEncoderEstimator.KEEP_INVALID}.")
}
}
} else if (label < plainNumCategories) {
Vectors.sparse(size, Array(label.toInt), oneValue)
}

val size = configedSizes(colIdx)
if (idx < size) {
Vectors.sparse(size, Array(idx.toInt), Array(1.0))
} else {
assert(handleInvalid == OneHotEncoderEstimator.ERROR_INVALID)
throw new SparkException(s"Unseen value: $label. To handle unseen values, " +
s"set Param handleInvalid to ${OneHotEncoderEstimator.KEEP_INVALID}.")
Vectors.sparse(size, Array.empty[Int], Array.empty[Double])
}
}
}
Expand All @@ -282,7 +288,6 @@ class OneHotEncoderModel private[ml] (
@Since("2.3.0")
override def transformSchema(schema: StructType): StructType = {
val inputColNames = $(inputCols)
val outputColNames = $(outputCols)

require(inputColNames.length == categorySizes.length,
s"The number of input columns ${inputColNames.length} must be the same as the number of " +
Expand All @@ -300,6 +305,7 @@ class OneHotEncoderModel private[ml] (
* account. Mismatched numbers will cause exception.
*/
private def verifyNumOfValues(schema: StructType): StructType = {
val configedSizes = getConfigedCategorySizes
$(outputCols).zipWithIndex.foreach { case (outputColName, idx) =>
val inputColName = $(inputCols)(idx)
val attrGroup = AttributeGroup.fromStructField(schema(outputColName))
Expand All @@ -308,9 +314,9 @@ class OneHotEncoderModel private[ml] (
// comparing with expected category number with `handleInvalid` and
// `dropLast` taken into account.
if (attrGroup.attributes.nonEmpty) {
val numCategories = configedCategorySize(categorySizes(idx), idx)
val numCategories = configedSizes(idx)
require(attrGroup.size == numCategories, "OneHotEncoderModel expected " +
s"$numCategories categorical values for input column ${inputColName}, " +
s"$numCategories categorical values for input column $inputColName, " +
s"but the input column had metadata specifying ${attrGroup.size} values.")
}
}
Expand All @@ -322,7 +328,7 @@ class OneHotEncoderModel private[ml] (
val transformedSchema = transformSchema(dataset.schema, logging = true)
val keepInvalid = $(handleInvalid) == OneHotEncoderEstimator.KEEP_INVALID

val encodedColumns = (0 until $(inputCols).length).map { idx =>
val encodedColumns = $(inputCols).indices.map { idx =>
val inputColName = $(inputCols)(idx)
val outputColName = $(outputCols)(idx)

Expand Down

0 comments on commit 930b90a

Please sign in to comment.