-
Notifications
You must be signed in to change notification settings - Fork 28.4k
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
[SPARK-5888] [MLLIB] Add OneHotEncoder as a Transformer
This patch adds a one-hot encoder for categorical features. Planning to add documentation and another test after getting feedback on the approach.

A couple of choices made here:
* There's an `includeFirst` option which, if false, creates numCategories - 1 columns and, if true, creates numCategories columns. The default is true, which is the behavior in scikit-learn.
* The user is expected to pass a `Seq` of category names when instantiating a `OneHotEncoder`. These can easily be obtained from a `StringIndexer`. The names are used for the output column names, which take the form colName_categoryName.

Author: Sandy Ryza <[email protected]>

Closes #5500 from sryza/sandy-spark-5888 and squashes the following commits:

f383250 [Sandy Ryza] Infer label names automatically
6e257b9 [Sandy Ryza] Review comments
7c539cf [Sandy Ryza] Vector transformers
1c182dd [Sandy Ryza] SPARK-5888. [MLLIB]. Add OneHotEncoder as a Transformer

(cherry picked from commit 47728db)
Signed-off-by: Xiangrui Meng <[email protected]>
- Loading branch information
Showing
2 changed files
with
187 additions
and
0 deletions.
There are no files selected for viewing
107 changes: 107 additions & 0 deletions
107
mllib/src/main/scala/org/apache/spark/ml/feature/OneHotEncoder.scala
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,107 @@ | ||
/* | ||
* Licensed to the Apache Software Foundation (ASF) under one or more | ||
* contributor license agreements. See the NOTICE file distributed with | ||
* this work for additional information regarding copyright ownership. | ||
* The ASF licenses this file to You under the Apache License, Version 2.0 | ||
* (the "License"); you may not use this file except in compliance with | ||
* the License. You may obtain a copy of the License at | ||
* | ||
* http://www.apache.org/licenses/LICENSE-2.0 | ||
* | ||
* Unless required by applicable law or agreed to in writing, software | ||
* distributed under the License is distributed on an "AS IS" BASIS, | ||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | ||
* See the License for the specific language governing permissions and | ||
* limitations under the License. | ||
*/ | ||
|
||
package org.apache.spark.ml.feature | ||
|
||
import org.apache.spark.SparkException | ||
import org.apache.spark.annotation.AlphaComponent | ||
import org.apache.spark.ml.UnaryTransformer | ||
import org.apache.spark.ml.attribute.{Attribute, BinaryAttribute, NominalAttribute} | ||
import org.apache.spark.mllib.linalg.{Vector, Vectors, VectorUDT} | ||
import org.apache.spark.ml.param._ | ||
import org.apache.spark.ml.param.shared.{HasInputCol, HasOutputCol} | ||
import org.apache.spark.ml.util.SchemaUtils | ||
import org.apache.spark.sql.types.{DataType, DoubleType, StructType} | ||
|
||
/**
 * A one-hot encoder that maps a column of label indices to a column of binary vectors, with
 * at most a single one-value. By default, the binary vector has an element for each category, so
 * with 5 categories, an input value of 2.0 would map to an output vector of
 * (0.0, 0.0, 1.0, 0.0, 0.0). If includeFirst is set to false, the first category is omitted, so the
 * output vector for the previous example would be (0.0, 1.0, 0.0, 0.0) and an input value
 * of 0.0 would map to a vector of all zeros. Including the first category makes the vector columns
 * linearly dependent because they sum up to one.
 *
 * The category names are read from the ML attribute metadata of the input column, so the input
 * column must carry a nominal or binary attribute.
 */
@AlphaComponent
class OneHotEncoder extends UnaryTransformer[Double, Vector, OneHotEncoder]
  with HasInputCol with HasOutputCol {

  /**
   * Whether to include a component in the encoded vectors for the first category, defaults to true.
   * @group param
   */
  final val includeFirst: BooleanParam =
    new BooleanParam(this, "includeFirst", "include first category")
  setDefault(includeFirst -> true)

  // Category names resolved by the most recent transformSchema call and read back by
  // createTransformFunc; stays null until transformSchema has run at least once.
  private var categories: Array[String] = _

  /** @group setParam */
  def setIncludeFirst(value: Boolean): this.type = set(includeFirst, value)

  /** @group setParam */
  override def setInputCol(value: String): this.type = set(inputCol, value)

  /** @group setParam */
  override def setOutputCol(value: String): this.type = set(outputCol, value)

  /**
   * Validates the input schema, records the categories of the input column, and appends the
   * output column to the schema.
   *
   * @param schema schema of the dataset to be transformed
   * @return the input schema with the output column appended
   * @throws IllegalArgumentException if the output column already exists
   * @throws SparkException if the input column is neither nominal nor binary, or if it is nominal
   *                        but declares neither its values nor the number of values
   */
  override def transformSchema(schema: StructType): StructType = {
    val inputColName = $(inputCol)
    val outputColName = $(outputCol)
    SchemaUtils.checkColumnType(schema, inputColName, DoubleType)
    val inputFields = schema.fields
    require(inputFields.forall(_.name != outputColName),
      s"Output column $outputColName already exists.")

    categories = Attribute.fromStructField(schema(inputColName)) match {
      case nominal: NominalAttribute =>
        // Prefer explicit names, then synthesize "0", "1", ... from the declared count. Fail with
        // a descriptive error instead of a bare None.get when neither is available.
        nominal.values
          .orElse(nominal.numValues.map(n => (0 until n).map(_.toString).toArray))
          .getOrElse(throw new SparkException(
            s"OneHotEncoder input column $inputColName does not specify its categories"))
      case binary: BinaryAttribute => binary.values.getOrElse(Array("0", "1"))
      case _ =>
        throw new SparkException(s"OneHotEncoder input column $inputColName is not nominal")
    }

    val attrValues = if ($(includeFirst)) categories else categories.drop(1)
    val attr = NominalAttribute.defaultAttr.withName(outputColName).withValues(attrValues)
    val outputFields = inputFields :+ attr.toStructField()
    StructType(outputFields)
  }

  /**
   * Builds the function that encodes a single label index as a sparse vector.
   *
   * The label is truncated to an integer index. An index outside [0, number of categories) is
   * rejected with a SparkException rather than producing a vector with an invalid index. When
   * includeFirst is false, index 0 maps to the all-zeros vector and every other index is shifted
   * down by one.
   *
   * @throws IllegalStateException if transformSchema has not been called yet, since that is where
   *                               the categories are resolved
   */
  protected override def createTransformFunc(): (Double) => Vector = {
    val resolved = Option(categories).getOrElse(
      throw new IllegalStateException(
        "OneHotEncoder categories are not resolved; transformSchema must be called first"))
    val first = $(includeFirst)
    val numCategories = resolved.length
    val vecLen = if (first) numCategories else math.max(numCategories - 1, 0)
    val oneValue = Array(1.0)
    val emptyValues = Array[Double]()
    val emptyIndices = Array[Int]()
    (label: Double) => {
      // Decide on the truncated index (not the raw double) so that both includeFirst modes agree
      // on which category a label belongs to and index - 1 can never go negative.
      val index = label.toInt
      if (index < 0 || index >= numCategories) {
        throw new SparkException(
          s"OneHotEncoder encountered label $label outside the range [0, $numCategories)")
      }
      if (first) {
        Vectors.sparse(vecLen, Array(index), oneValue)
      } else if (index != 0) {
        Vectors.sparse(vecLen, Array(index - 1), oneValue)
      } else {
        Vectors.sparse(vecLen, emptyIndices, emptyValues)
      }
    }
  }

  /**
   * Returns the data type of the output column.
   */
  protected def outputDataType: DataType = new VectorUDT
}
80 changes: 80 additions & 0 deletions
80
mllib/src/test/scala/org/apache/spark/ml/feature/OneHotEncoderSuite.scala
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,80 @@ | ||
/* | ||
* Licensed to the Apache Software Foundation (ASF) under one or more | ||
* contributor license agreements. See the NOTICE file distributed with | ||
* this work for additional information regarding copyright ownership. | ||
* The ASF licenses this file to You under the Apache License, Version 2.0 | ||
* (the "License"); you may not use this file except in compliance with | ||
* the License. You may obtain a copy of the License at | ||
* | ||
* http://www.apache.org/licenses/LICENSE-2.0 | ||
* | ||
* Unless required by applicable law or agreed to in writing, software | ||
* distributed under the License is distributed on an "AS IS" BASIS, | ||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | ||
* See the License for the specific language governing permissions and | ||
* limitations under the License. | ||
*/ | ||
|
||
package org.apache.spark.ml.feature | ||
|
||
import org.scalatest.FunSuite | ||
|
||
import org.apache.spark.mllib.linalg.Vector | ||
import org.apache.spark.mllib.util.MLlibTestSparkContext | ||
import org.apache.spark.sql.{DataFrame, SQLContext} | ||
|
||
|
||
class OneHotEncoderSuite extends FunSuite with MLlibTestSparkContext {
  private var sqlContext: SQLContext = _

  override def beforeAll(): Unit = {
    super.beforeAll()
    sqlContext = new SQLContext(sc)
  }

  /**
   * Builds a six-row (id, label) DataFrame over the labels "a", "b" and "c" and returns it with
   * an extra "labelIndex" column produced by a fitted StringIndexer.
   */
  def stringIndexed(): DataFrame = {
    val rows = Seq((0, "a"), (1, "b"), (2, "c"), (3, "a"), (4, "a"), (5, "c"))
    val labeled = sqlContext.createDataFrame(sc.parallelize(rows, 2)).toDF("id", "label")
    val model = new StringIndexer()
      .setInputCol("label")
      .setOutputCol("labelIndex")
      .fit(labeled)
    model.transform(labeled)
  }

  test("OneHotEncoder includeFirst = true") {
    val indexed = stringIndexed()
    val oneHot = new OneHotEncoder()
      .setInputCol("labelIndex")
      .setOutputCol("labelVec")

    // Flatten every encoded vector into (id, component0, component1, component2).
    val actual = oneHot.transform(indexed)
      .select("id", "labelVec")
      .map { row =>
        val v = row.getAs[Vector](1)
        (row.getInt(0), v(0), v(1), v(2))
      }
      .collect()
      .toSet

    // a -> 0, b -> 2, c -> 1
    val expected = Set(
      (0, 1.0, 0.0, 0.0), (1, 0.0, 0.0, 1.0), (2, 0.0, 1.0, 0.0),
      (3, 1.0, 0.0, 0.0), (4, 1.0, 0.0, 0.0), (5, 0.0, 1.0, 0.0))
    assert(actual === expected)
  }

  test("OneHotEncoder includeFirst = false") {
    val indexed = stringIndexed()
    val oneHot = new OneHotEncoder()
      .setInputCol("labelIndex")
      .setOutputCol("labelVec")
      .setIncludeFirst(false)

    // With the first category dropped, each vector has two components and "a" is all zeros.
    val actual = oneHot.transform(indexed)
      .select("id", "labelVec")
      .map { row =>
        val v = row.getAs[Vector](1)
        (row.getInt(0), v(0), v(1))
      }
      .collect()
      .toSet

    // a -> 0, b -> 2, c -> 1
    val expected = Set(
      (0, 0.0, 0.0), (1, 0.0, 1.0), (2, 1.0, 0.0),
      (3, 0.0, 0.0), (4, 0.0, 0.0), (5, 1.0, 0.0))
    assert(actual === expected)
  }

}