Skip to content


Address comments.
Browse files Browse the repository at this point in the history
  • Loading branch information
viirya committed Dec 26, 2017
1 parent 32318fa commit 144f07d
Show file tree
Hide file tree
Showing 3 changed files with 348 additions and 186 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -90,7 +90,7 @@ class OneHotEncoder @Since("1.4.0") (@Since("1.4.0") override val uid: String) e
s"Output column $outputColName already exists.")

val outputField = OneHotEncoderCommon.transformOutputColumnSchema(
schema(inputColName), $(dropLast), outputColName)
schema(inputColName), outputColName, $(dropLast))
val outputFields = inputFields :+ outputField
Expand All @@ -106,7 +106,7 @@ class OneHotEncoder @Since("1.4.0") (@Since("1.4.0") override val uid: String) e

val outputAttrGroup = if (outputAttrGroupFromSchema.size < 0) {
dataset, $(dropLast), Seq(inputColName), Seq(outputColName))(0)
dataset, Seq(inputColName), Seq(outputColName), $(dropLast))(0)
} else {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -65,7 +65,8 @@ private[ml] trait OneHotEncoderBase extends Params with HasHandleInvalid
def getDropLast: Boolean = $(dropLast)

protected def validateAndTransformSchema(schema: StructType): StructType = {
protected def validateAndTransformSchema(
schema: StructType, dropLast: Boolean, keepInvalid: Boolean): StructType = {
val inputColNames = $(inputCols)
val outputColNames = $(outputCols)
val existingFields = schema.fields
Expand All @@ -74,22 +75,19 @@ private[ml] trait OneHotEncoderBase extends Params with HasHandleInvalid
s"The number of input columns ${inputColNames.length} must be the same as the number of " +
s"output columns ${outputColNames.length}.") { case (inputColName, outputColName) =>
s"Input column must be of type NumericType but got ${schema(inputColName).dataType}")
require(!existingFields.exists( == outputColName),
s"Output column $outputColName already exists.")
// Input columns must be NumericType.
inputColNames.foreach(SchemaUtils.checkNumericType(schema, _))

// Prepares output columns with proper attributes by examining input columns.
val inputFields = $(inputCols).map(schema(_))
val keepInvalid = $(handleInvalid) == OneHotEncoderEstimator.KEEP_INVALID

val outputFields = { case (inputField, outputColName) =>
inputField, $(dropLast), outputColName, keepInvalid)
inputField, outputColName, dropLast, keepInvalid)
outputFields.foldLeft(schema) { case (newSchema, outputField) =>
SchemaUtils.appendColumn(newSchema, outputField)
StructType(schema.fields ++ outputFields)

Expand All @@ -109,6 +107,9 @@ private[ml] trait OneHotEncoderBase extends Params with HasHandleInvalid
* added as last category. So when `dropLast` is true, invalid values are encoded as all-zeros
* vector.
* @note When encoding multi-column by using `inputCols` and `outputCols` params, input/output cols
* come in pairs, specified by the order in the arrays, and each pair is treated independently.
* @see `StringIndexer` for converting categorical values into category indices
Expand Down Expand Up @@ -136,7 +137,9 @@ class OneHotEncoderEstimator @Since("2.3.0") (@Since("2.3.0") override val uid:

override def transformSchema(schema: StructType): StructType = {
// When fitting data, we want the plain number of categories without `handleInvalid` and
// `dropLast` taken into account.
validateAndTransformSchema(schema, dropLast = false, keepInvalid = false)

Expand All @@ -160,9 +163,11 @@ class OneHotEncoderEstimator @Since("2.3.0") (@Since("2.3.0") override val uid:
if (columnToScanIndices.length > 0) {
val inputColNames =$(inputCols)(_))
val outputColNames =$(outputCols)(_))
val keepInvalid = $(handleInvalid) == OneHotEncoderEstimator.KEEP_INVALID

// When fitting data, we want the plain number of categories without `handleInvalid` and
// `dropLast` taken into account.
val attrGroups = OneHotEncoderCommon.getOutputAttrGroupFromData(
dataset, $(dropLast), inputColNames, outputColNames, keepInvalid)
dataset, inputColNames, outputColNames, dropLast = false) { case (attrGroup, idx) =>
categorySizes(idx) = attrGroup.size
Expand Down Expand Up @@ -195,6 +200,26 @@ class OneHotEncoderModel private[ml] (

import OneHotEncoderModel._

// The actual number of categories varies with the settings of `dropLast` and
// `handleInvalid`.
private def configedCategorySizes: Array[Int] = {
val dropLast = getDropLast
val keepInvalid = getHandleInvalid == OneHotEncoderEstimator.KEEP_INVALID

if (!dropLast && keepInvalid) {
// When `handleInvalid` is "keep", an extra category is added as last category
// for invalid data. + 1)
} else if (dropLast && !keepInvalid) {
// When `dropLast` is true, the last category is removed. - 1)
} else {
// When `dropLast` is true and `handleInvalid` is "keep", the extra category for invalid
// data is removed. Thus, it is the same as the plain number of categories.

private def encoder: UserDefinedFunction = {
val oneValue = Array(1.0)
val emptyValues = Array.empty[Double]
Expand All @@ -205,21 +230,29 @@ class OneHotEncoderModel private[ml] (

udf { (label: Double, size: Int) =>
val numCategory = if (!dropLast && keepInvalid) {
// When `handleInvalid` is 'keep' and `dropLast` is false, the last category is
// When `dropLast` is false and `handleInvalid` is "keep", the last category is
// for invalid data.
size - 1
} else {

if (label < numCategory) {
if (label < 0) {
throw new SparkException(s"Negative value: $label. Input can't be negative.")
} else if (label < numCategory) {
Vectors.sparse(size, Array(label.toInt), oneValue)
} else if (label == numCategory && dropLast && !keepInvalid) {
// When `dropLast` is true and `handleInvalid` is not "keep",
// the last category is removed.
Vectors.sparse(size, emptyIndices, emptyValues)
} else if (dropLast && keepInvalid) {
// When `dropLast` is true and `handleInvalid` is "keep",
// invalid data is encoded to the removed last category.
Vectors.sparse(size, emptyIndices, emptyValues)
} else if (keepInvalid) {
Vectors.sparse(size, Array(size - 1), oneValue)
// When `dropLast` is false and `handleInvalid` is "keep",
// invalid data is encoded to the last category.
Vectors.sparse(size, Array(numCategory), oneValue)
} else {
assert(handleInvalid == OneHotEncoderEstimator.ERROR_INVALID)
throw new SparkException(s"Unseen value: $label. To handle unseen values, " +
Expand Down Expand Up @@ -253,26 +286,29 @@ class OneHotEncoderModel private[ml] (
s"The number of input columns ${inputColNames.length} must be the same as the number of " +
s"features ${categorySizes.length} during fitting.")

val transformedSchema = validateAndTransformSchema(schema)
val keepInvalid = $(handleInvalid) == OneHotEncoderEstimator.KEEP_INVALID
val transformedSchema = validateAndTransformSchema(schema, dropLast = $(dropLast),
keepInvalid = keepInvalid)

* If the metadata of input columns also specifies the number of categories, we need to
* compare with expected category number obtained during fitting. Mismatched numbers will
* cause exception.
* compare with expected category number with `handleInvalid` and `dropLast` taken into
* account. Mismatched numbers will cause exception.
private def verifyNumOfValues(schema: StructType): StructType = {
$(outputCols).zipWithIndex.foreach { case (outputColName, idx) =>
val inputColName = $(inputCols)(idx)
val attrGroup = AttributeGroup.fromStructField(schema(outputColName))

// If the input metadata specifies number of category,
// compare with expected category number.
// If the input metadata specifies the number of categories for the output column,
// compare it with the expected number of categories, with `handleInvalid` and
// `dropLast` taken into account.
if (attrGroup.attributes.nonEmpty) {
require(attrGroup.size == categorySizes(idx), "OneHotEncoderModel expected " +
s"${categorySizes(idx)} categorical values for input column ${inputColName}, but " +
s"the input column had metadata specifying ${attrGroup.size} values.")
require(attrGroup.size == configedCategorySizes(idx), "OneHotEncoderModel expected " +
s"${configedCategorySizes(idx)} categorical values for input column ${inputColName}, " +
s"but the input column had metadata specifying ${attrGroup.size} values.")
Expand All @@ -281,6 +317,7 @@ class OneHotEncoderModel private[ml] (
override def transform(dataset: Dataset[_]): DataFrame = {
val transformedSchema = transformSchema(dataset.schema, logging = true)
val keepInvalid = $(handleInvalid) == OneHotEncoderEstimator.KEEP_INVALID

val encodedColumns = (0 until $(inputCols).length).map { idx =>
val inputColName = $(inputCols)(idx)
Expand All @@ -290,13 +327,13 @@ class OneHotEncoderModel private[ml] (

val metadata = if (outputAttrGroupFromSchema.size < 0) {
OneHotEncoderCommon.createAttrGroupForAttrNames(outputColName, false,
categorySizes(idx), $(dropLast), keepInvalid).toMetadata()
} else {

encoder(col(inputColName).cast(DoubleType), lit(categorySizes(idx)))
encoder(col(inputColName).cast(DoubleType), lit(configedCategorySizes(idx)))
.as(outputColName, metadata)
dataset.withColumns($(outputCols), encodedColumns)
Expand Down Expand Up @@ -376,7 +413,7 @@ private[feature] object OneHotEncoderCommon {
case _: NumericAttribute =>
throw new RuntimeException(
s"The input column ${} cannot be numeric.")
s"The input column ${} cannot be continuous-value.")
case _ =>
None // optimistic about unknown attributes
Expand All @@ -401,8 +438,8 @@ private[feature] object OneHotEncoderCommon {
def transformOutputColumnSchema(
inputCol: StructField,
dropLast: Boolean,
outputColName: String,
dropLast: Boolean,
keepInvalid: Boolean = false): StructField = {
val outputAttrNames = genOutputAttrNames(inputCol)
val filteredOutputAttrNames = { names =>
Expand All @@ -426,10 +463,9 @@ private[feature] object OneHotEncoderCommon {
def getOutputAttrGroupFromData(
dataset: Dataset[_],
dropLast: Boolean,
inputColNames: Seq[String],
outputColNames: Seq[String],
handleInvalid: Boolean = false): Seq[AttributeGroup] = {
dropLast: Boolean): Seq[AttributeGroup] = {
// The RDD approach has advantage of early-stop if any values are invalid. It seems that
// DataFrame ops don't have equivalent functions.
val columns = { inputColName =>
Expand All @@ -441,31 +477,35 @@ private[feature] object OneHotEncoderCommon {
(0 until numOfColumns).map(idx => row.getDouble(idx)).toArray
}.treeAggregate(new Array[Double](numOfColumns))(
(maxValues, curValues) => {
(0 until numOfColumns).map { idx =>
(0 until numOfColumns).foreach { idx =>
val x = curValues(idx)
assert(x <= Int.MaxValue,
s"OneHotEncoder only supports up to ${Int.MaxValue} indices, but got $x.")
assert(x >= 0.0 && x == x.toInt,
s"Values from column ${inputColNames(idx)} must be indices, but got $x.")
math.max(maxValues(idx), x)
maxValues(idx) = math.max(maxValues(idx), x)
(m0, m1) => {
(0 until numOfColumns).map(idx => math.max(m0(idx), m1(idx))).toArray
(0 until numOfColumns).foreach { idx =>
m0(idx) = math.max(m0(idx), m1(idx))
).map(_.toInt + 1) { case (outputColName, numAttrs) =>
createAttrGroupForAttrNames(outputColName, dropLast, numAttrs, handleInvalid)
createAttrGroupForAttrNames(outputColName, numAttrs, dropLast, keepInvalid = false)

/** Creates an `AttributeGroup` with the required number of `BinaryAttribute`. */
def createAttrGroupForAttrNames(
outputColName: String,
dropLast: Boolean,
numAttrs: Int,
keepInvalid: Boolean = false): AttributeGroup = {
dropLast: Boolean,
keepInvalid: Boolean): AttributeGroup = {
val outputAttrNames = Array.tabulate(numAttrs)(_.toString)
val filtered = if (dropLast && !keepInvalid) {
Expand Down

0 comments on commit 144f07d

Please sign in to comment.