Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

SPARK-5888. [MLLIB]. Add OneHotEncoder as a Transformer #5500

Closed
wants to merge 4 commits into from
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
107 changes: 107 additions & 0 deletions mllib/src/main/scala/org/apache/spark/ml/feature/OneHotEncoder.scala
Original file line number Diff line number Diff line change
@@ -0,0 +1,107 @@
/*
* Licensed to the Apache Software Foundation (ASF) under one or more
* contributor license agreements. See the NOTICE file distributed with
* this work for additional information regarding copyright ownership.
* The ASF licenses this file to You under the Apache License, Version 2.0
* (the "License"); you may not use this file except in compliance with
* the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/

package org.apache.spark.ml.feature

import org.apache.spark.SparkException
import org.apache.spark.annotation.AlphaComponent
import org.apache.spark.ml.UnaryTransformer
import org.apache.spark.ml.attribute.{Attribute, BinaryAttribute, NominalAttribute}
import org.apache.spark.mllib.linalg.{Vector, Vectors, VectorUDT}
import org.apache.spark.ml.param._
import org.apache.spark.ml.param.shared.{HasInputCol, HasOutputCol}
import org.apache.spark.ml.util.SchemaUtils
import org.apache.spark.sql.types.{DataType, DoubleType, StructType}

/**
* A one-hot encoder that maps a column of label indices to a column of binary vectors, with
* at most a single one-value. By default, the binary vector has an element for each category, so
* with 5 categories, an input value of 2.0 would map to an output vector of
* (0.0, 0.0, 1.0, 0.0, 0.0). If includeFirst is set to false, the first category is omitted, so the
* output vector for the previous example would be (0.0, 1.0, 0.0, 0.0) and an input value
* of 0.0 would map to a vector of all zeros. Including the first category makes the vector columns
* linearly dependent because they sum up to one.
*/
@AlphaComponent
class OneHotEncoder extends UnaryTransformer[Double, Vector, OneHotEncoder]
with HasInputCol with HasOutputCol {

/**
* Whether to include a component in the encoded vectors for the first category, defaults to true.
* @group param
*/
final val includeFirst: BooleanParam =
new BooleanParam(this, "includeFirst", "include first category")
setDefault(includeFirst -> true)

private var categories: Array[String] = _

/** @group setParam */
def setIncludeFirst(value: Boolean): this.type = set(includeFirst, value)

/** @group setParam */
override def setInputCol(value: String): this.type = set(inputCol, value)

/** @group setParam */
override def setOutputCol(value: String): this.type = set(outputCol, value)

override def transformSchema(schema: StructType): StructType = {
SchemaUtils.checkColumnType(schema, $(inputCol), DoubleType)
val inputFields = schema.fields
val outputColName = $(outputCol)
require(inputFields.forall(_.name != $(outputCol)),
s"Output column ${$(outputCol)} already exists.")

val inputColAttr = Attribute.fromStructField(schema($(inputCol)))
categories = inputColAttr match {
case nominal: NominalAttribute =>
nominal.values.getOrElse((0 until nominal.numValues.get).map(_.toString).toArray)
case binary: BinaryAttribute => binary.values.getOrElse(Array("0", "1"))
case _ =>
throw new SparkException(s"OneHotEncoder input column ${$(inputCol)} is not nominal")
}

val attrValues = (if ($(includeFirst)) categories else categories.drop(1)).toArray
val attr = NominalAttribute.defaultAttr.withName(outputColName).withValues(attrValues)
val outputFields = inputFields :+ attr.toStructField()
StructType(outputFields)
}

protected override def createTransformFunc(): (Double) => Vector = {
val first = $(includeFirst)
val vecLen = if (first) categories.length else categories.length - 1
val oneValue = Array(1.0)
val emptyValues = Array[Double]()
val emptyIndices = Array[Int]()
label: Double => {
val values = if (first || label != 0.0) oneValue else emptyValues
val indices = if (first) {
Array(label.toInt)
} else if (label != 0.0) {
Array(label.toInt - 1)
} else {
emptyIndices
}
Vectors.sparse(vecLen, indices, values)
}
}

/**
* Returns the data type of the output column.
*/
protected def outputDataType: DataType = new VectorUDT
}
Original file line number Diff line number Diff line change
@@ -0,0 +1,80 @@
/*
* Licensed to the Apache Software Foundation (ASF) under one or more
* contributor license agreements. See the NOTICE file distributed with
* this work for additional information regarding copyright ownership.
* The ASF licenses this file to You under the Apache License, Version 2.0
* (the "License"); you may not use this file except in compliance with
* the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/

package org.apache.spark.ml.feature

import org.scalatest.FunSuite

import org.apache.spark.mllib.linalg.Vector
import org.apache.spark.mllib.util.MLlibTestSparkContext
import org.apache.spark.sql.{DataFrame, SQLContext}


class OneHotEncoderSuite extends FunSuite with MLlibTestSparkContext {
private var sqlContext: SQLContext = _

override def beforeAll(): Unit = {
super.beforeAll()
sqlContext = new SQLContext(sc)
}

def stringIndexed(): DataFrame = {
val data = sc.parallelize(Seq((0, "a"), (1, "b"), (2, "c"), (3, "a"), (4, "a"), (5, "c")), 2)
val df = sqlContext.createDataFrame(data).toDF("id", "label")
val indexer = new StringIndexer()
.setInputCol("label")
.setOutputCol("labelIndex")
.fit(df)
indexer.transform(df)
}

test("OneHotEncoder includeFirst = true") {
val transformed = stringIndexed()
val encoder = new OneHotEncoder()
.setInputCol("labelIndex")
.setOutputCol("labelVec")
val encoded = encoder.transform(transformed)

val output = encoded.select("id", "labelVec").map { r =>
val vec = r.get(1).asInstanceOf[Vector]
(r.getInt(0), vec(0), vec(1), vec(2))
}.collect().toSet
// a -> 0, b -> 2, c -> 1
val expected = Set((0, 1.0, 0.0, 0.0), (1, 0.0, 0.0, 1.0), (2, 0.0, 1.0, 0.0),
(3, 1.0, 0.0, 0.0), (4, 1.0, 0.0, 0.0), (5, 0.0, 1.0, 0.0))
assert(output === expected)
}

test("OneHotEncoder includeFirst = false") {
val transformed = stringIndexed()
val encoder = new OneHotEncoder()
.setIncludeFirst(false)
.setInputCol("labelIndex")
.setOutputCol("labelVec")
val encoded = encoder.transform(transformed)

val output = encoded.select("id", "labelVec").map { r =>
val vec = r.get(1).asInstanceOf[Vector]
(r.getInt(0), vec(0), vec(1))
}.collect().toSet
// a -> 0, b -> 2, c -> 1
val expected = Set((0, 0.0, 0.0), (1, 0.0, 1.0), (2, 1.0, 0.0),
(3, 0.0, 0.0), (4, 0.0, 0.0), (5, 1.0, 0.0))
assert(output === expected)
}

}