diff --git a/mllib/src/main/scala/org/apache/spark/ml/feature/OneHotEncoder.scala b/mllib/src/main/scala/org/apache/spark/ml/feature/OneHotEncoder.scala index 15b78b221394d..515b6706686f9 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/feature/OneHotEncoder.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/feature/OneHotEncoder.scala @@ -32,8 +32,8 @@ import org.apache.spark.sql.types.{DataType, DoubleType, StructType} * with 5 categories, an input value of 2.0 would map to an output vector of * (0.0, 0.0, 1.0, 0.0, 0.0). If includeFirst is set to false, the first category is omitted, so the * output vector for the previous example would be (0.0, 1.0, 0.0, 0.0) and an input value - * of 0.0 would map to a vector of all zeros. Omitting the first category enables the vector - * columns to be independent. + * of 0.0 would map to a vector of all zeros. Including the first category makes the vector columns + * linearly dependent because they sum up to one. */ @AlphaComponent class OneHotEncoder extends UnaryTransformer[Double, Vector, OneHotEncoder] @@ -43,8 +43,8 @@ class OneHotEncoder extends UnaryTransformer[Double, Vector, OneHotEncoder] * Whether to include a component in the encoded vectors for the first category, defaults to true. * @group param */ - final val includeFirst: Param[Boolean] = - new Param[Boolean](this, "includeFirst", "include first category") + final val includeFirst: BooleanParam = + new BooleanParam(this, "includeFirst", "include first category") setDefault(includeFirst -> true) /** @@ -59,7 +59,7 @@ class OneHotEncoder extends UnaryTransformer[Double, Vector, OneHotEncoder] def setIncludeFirst(value: Boolean): this.type = set(includeFirst, value) /** @group setParam */ - def setLabelNames(value: Array[String]): this.type = set(labelNames, value) + def setLabelNames(attr: NominalAttribute): this.type = set(labelNames, attr.values.get) /** @group setParam */ override def setInputCol(value: String): this.type = set(inputCol, value) diff --git a/mllib/src/test/scala/org/apache/spark/ml/feature/OneHotEncoderSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/feature/OneHotEncoderSuite.scala index 6b76843a44612..038914c59fdca 100644 --- a/mllib/src/test/scala/org/apache/spark/ml/feature/OneHotEncoderSuite.scala +++ b/mllib/src/test/scala/org/apache/spark/ml/feature/OneHotEncoderSuite.scala @@ -17,7 +17,7 @@ package org.apache.spark.ml.feature -import org.apache.spark.ml.attribute.{NominalAttribute, Attribute} +import org.apache.spark.ml.attribute.{Attribute, NominalAttribute} import org.apache.spark.mllib.linalg.Vector import org.apache.spark.mllib.util.MLlibTestSparkContext @@ -49,7 +49,7 @@ class OneHotEncoderSuite extends FunSuite with MLlibTestSparkContext { test("OneHotEncoder includeFirst = true") { val (transformed, attr) = stringIndexed() val encoder = new OneHotEncoder() - .setLabelNames(attr.values.get) + .setLabelNames(attr) .setInputCol("labelIndex") .setOutputCol("labelVec") val encoded = encoder.transform(transformed) @@ -68,7 +68,7 @@ class OneHotEncoderSuite extends FunSuite with MLlibTestSparkContext { val (transformed, attr) = stringIndexed() val encoder = new OneHotEncoder() .setIncludeFirst(false) - .setLabelNames(attr.values.get) + .setLabelNames(attr) .setInputCol("labelIndex") .setOutputCol("labelVec") val encoded = encoder.transform(transformed)