Skip to content

Commit

Permalink
[SPARK-10573] [ML] IndexToString output schema should be StringType
Browse files Browse the repository at this point in the history
Fixes bug where IndexToString output schema was DoubleType. Correct me if I'm wrong, but it doesn't seem like the output needs to have any "ML Attribute" metadata.

Author: Nick Pritchard <[email protected]>

Closes #8751 from pnpritchard/SPARK-10573.

(cherry picked from commit 8a634e9)
Signed-off-by: Xiangrui Meng <[email protected]>
  • Loading branch information
pnpritchard authored and mengxr committed Sep 14, 2015
1 parent 5f58704 commit 5b7067c
Show file tree
Hide file tree
Showing 2 changed files with 11 additions and 4 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -27,7 +27,7 @@ import org.apache.spark.ml.Transformer
import org.apache.spark.ml.util.Identifiable
import org.apache.spark.sql.DataFrame
import org.apache.spark.sql.functions._
import org.apache.spark.sql.types.{DoubleType, NumericType, StringType, StructType}
import org.apache.spark.sql.types._
import org.apache.spark.util.collection.OpenHashMap

/**
Expand Down Expand Up @@ -220,8 +220,7 @@ class IndexToString private[ml] (
val outputColName = $(outputCol)
require(inputFields.forall(_.name != outputColName),
s"Output column $outputColName already exists.")
val attr = NominalAttribute.defaultAttr.withName($(outputCol))
val outputFields = inputFields :+ attr.toStructField()
val outputFields = inputFields :+ StructField($(outputCol), StringType)
StructType(outputFields)
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,8 @@

package org.apache.spark.ml.feature

import org.apache.spark.SparkFunSuite
import org.apache.spark.sql.types.{StringType, StructType, StructField, DoubleType}
import org.apache.spark.{SparkException, SparkFunSuite}
import org.apache.spark.ml.attribute.{Attribute, NominalAttribute}
import org.apache.spark.ml.param.ParamsSuite
import org.apache.spark.ml.util.MLTestingUtils
Expand Down Expand Up @@ -134,4 +135,11 @@ class StringIndexerSuite extends SparkFunSuite with MLlibTestSparkContext {
assert(a === b)
}
}

test("IndexToString.transformSchema (SPARK-10573)") {
val idxToStr = new IndexToString().setInputCol("input").setOutputCol("output")
val inSchema = StructType(Seq(StructField("input", DoubleType)))
val outSchema = idxToStr.transformSchema(inSchema)
assert(outSchema("output").dataType === StringType)
}
}

0 comments on commit 5b7067c

Please sign in to comment.