[SPARK-13030][ML] Create OneHotEncoderEstimator for OneHotEncoder as Estimator #19527

Closed
wants to merge 14 commits

Conversation

viirya
Member

@viirya viirya commented Oct 18, 2017

What changes were proposed in this pull request?

This patch adds a new class OneHotEncoderEstimator which extends Estimator. The fit method returns OneHotEncoderModel.

Common methods between existing OneHotEncoder and new OneHotEncoderEstimator, such as transforming schema, are extracted and put into OneHotEncoderCommon to reduce code duplication.

Multi-column support

OneHotEncoderEstimator adds simpler multi-column support because it is a new API and is not constrained by backward compatibility.

handleInvalid Param support

OneHotEncoderEstimator supports the handleInvalid Param with the options error and keep.
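
For illustration, a minimal usage sketch of the proposed API (the column names and df here are made up):

import org.apache.spark.ml.feature.OneHotEncoderEstimator

val encoder = new OneHotEncoderEstimator()
  .setInputCols(Array("categoryIndex1", "categoryIndex2"))
  .setOutputCols(Array("categoryVec1", "categoryVec2"))
  .setHandleInvalid("error")  // or "keep"

val model = encoder.fit(df)        // fit returns an OneHotEncoderModel
val encoded = model.transform(df)  // one output vector column per input column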

How was this patch tested?

Added new test suite OneHotEncoderEstimatorSuite.

@SparkQA

SparkQA commented Oct 18, 2017

Test build #82879 has finished for PR 19527 at commit 8fd4677.

  • This patch passes all tests.
  • This patch merges cleanly.
  • This patch adds the following public classes (experimental):
  • class OneHotEncoderEstimator @Since("2.3.0") (@Since("2.3.0") override val uid: String)
  • class OneHotEncoderModelWriter(instance: OneHotEncoderModel) extends MLWriter

@viirya
Member Author

viirya commented Oct 18, 2017

cc @MLnick @WeichenXu123 @jkbradley This adds a new class OneHotEncoderEstimator which extends Estimator. Please review this when you can. Thanks.


encoder(col(inputColName).cast(DoubleType)).as(outputColName, metadata)
}
val allCols = Seq(col("*")) ++ encodedColumns
Contributor

Just looking at this in the context of the previous multi-col work done (Bucketizer and Imputer): is there any real performance difference between this and the use of dataset.withColumns(...)?

Member Author

Before running a real performance test, my guess is the difference might not be significant.

Besides performance, multi-column support in OneHotEncoder also makes the API more convenient to use. For example, scikit-learn's OneHotEncoder supports multiple columns as well.

Member Author

If we can get the number of values for multi-column in https://github.com/apache/spark/pull/19527/files#r145457081, the performance of fitting can be improved.

Member Author

Performance test results are attached now. The existing OneHotEncoder only performs transform, so even when the combined time of fitting and transforming in this estimator/model is compared against the existing one-hot encoder, this multi-column approach is faster.

* This method is called when we want to generate `AttributeGroup` from actual data for
* one-hot encoder.
*/
def getOutputAttrGroupFromData(
Contributor

Calling this multiple times for multiple columns seems inefficient. It should be possible to use DataFrame ops to compute this more efficiently?

Member Author

This can be improved under multi-column usage. Let me think about it. Thanks.

Member Author

The RDD approach has the advantage of stopping early if any values are invalid. DataFrame ops don't seem to have an equivalent.

With DataFrame aggregation we can only check whether the max value exceeds Int.MaxValue after the aggregation finishes. We would also need a min aggregation per column so we can check whether any values are less than zero.

So currently I think I will modify this to a multi-column version but still use the RDD approach. Sounds good to you?
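
For reference, the DataFrame-based check discussed above might look roughly like the sketch below (illustrative only; inputCols is assumed to be the sequence of input column names). It can only detect invalid values after the full aggregation finishes, with no early stop:

import org.apache.spark.sql.functions.{col, max, min}

// One min and one max aggregate per input column.
val aggExprs = inputCols.flatMap(c => Seq(min(col(c).cast("double")), max(col(c).cast("double"))))
val row = df.agg(aggExprs.head, aggExprs.tail: _*).head()
inputCols.zipWithIndex.foreach { case (c, i) =>
  val (lo, hi) = (row.getDouble(2 * i), row.getDouble(2 * i + 1))
  require(lo >= 0 && hi < Int.MaxValue, s"Column $c contains values out of range for one-hot encoding.")
}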

throw new SparkException(s"Unseen value: $label. To handle unseen values, " +
s"set Param handleInvalid to ${OneHotEncoderEstimator.SKIP_INVALID}.")
} else {
Vectors.sparse(size, emptyIndices, emptyValues)
Contributor

This behaviour seems more like "keep" in StringIndexer than "skip". Skip filters out rows with invalid values and returns a dataframe with number of rows <= the number of rows in the input. "keep" maps all invalid values to a single, known value. Should we rename this to KEEP_INVALID?

https://spark.apache.org/docs/2.2.0/ml-features.html#stringindexer

Member Author
@viirya viirya Oct 18, 2017

When I checked scikit-learn's OneHotEncoder (http://scikit-learn.org/stable/modules/generated/sklearn.preprocessing.OneHotEncoder.html), its handle_unknown has error or ignore choices. It sounds like ignore is closer to skip in semantics?

I'm willing to rename it to keep if others also think it is better.

Contributor
@MLnick MLnick Oct 20, 2017

Although keep does seem a little strange semantically, it's probably best to match the param option names from StringIndexer since that will be the typical pipeline StringIndexer -> OneHotEncoder.

And skip definitely would seem to imply that rows are filtered out or dropped, which is not the case here.

Member Author

Ok. Then I will rename it to keep. Thanks.

@viirya
Member Author

viirya commented Oct 19, 2017

Benchmark of the multi-column one-hot encoder.

Multi-Col, Multiple Run: the first commit; runs a separate treeAggregate per column.
Multi-Col, Single Run: runs a single treeAggregate over all columns, per the suggestion at #19527 (comment).

Definitely, running one treeAggregate on all columns at once is faster.

Fitting:

numColumns | Multi-Col, Multiple Run (s) | Multi-Col, Single Run (s)
--- | --- | ---
1 | 0.11003638430000003 | 0.12968824099999998
100 | 3.6879334635000007 | 0.36438897839999995
1000 | 90.3695017947 | 2.4687475008

Transforming:

numColumns | Multi-Col, Multiple Run (s) | Multi-Col, Single Run (s)
--- | --- | ---
1 | 0.14080461019999999 | 0.1434849307
100 | 0.3636357813 | 0.41459606969999996
1000 | 3.1933874685 | 2.8026313985

Benchmark code:

import org.apache.spark.ml.feature._
import org.apache.spark.sql.Row
import org.apache.spark.sql.types._
import spark.implicits._
import scala.util.Random

val seed = 123L
val random = new Random(seed)
val n = 10000  // number of rows
val m = 1000   // number of input columns
// m integer columns of random category indices in [0, 1000).
val rows = sc.parallelize(1 to n).map(i => Row(Array.fill(m)(random.nextInt(1000)): _*))
val struct = new StructType(Array.range(0, m, 1).map(i => StructField(s"c$i", IntegerType, true)))
val df = spark.createDataFrame(rows, struct)
df.persist()
df.count()

val inputCols = Array.range(0,m,1).map(i => s"c$i")
val outputCols = Array.range(0,m,1).map(i => s"c${i}_encoded")

val encoder = new OneHotEncoderEstimator().setInputCols(inputCols).setOutputCols(outputCols)
var durationFitting = 0.0
var durationTransforming = 0.0
for (i <- 0 until 10) {
  val startFitting = System.nanoTime()
  val model = encoder.fit(df)
  val endFitting = System.nanoTime()
  durationFitting += (endFitting - startFitting) / 1e9

  val startTransforming = System.nanoTime()
  model.transform(df).count
  val endTransforming = System.nanoTime()
  durationTransforming += (endTransforming - startTransforming) / 1e9
}
println(s"fitting: ${durationFitting / 10}")
println(s"transforming: ${durationTransforming / 10}")

import org.apache.spark.ml.attribute._
import org.apache.spark.ml.linalg.Vectors
import org.apache.spark.ml.param._
import org.apache.spark.ml.param.shared.{HasHandleInvalid, HasInputCol, HasInputCols, HasOutputCol, HasOutputCols}
Member

HasInputCol and HasOutputCol are not needed; I think the same goes for the Transformer import above.

Member Author

Yes, thanks.

import org.apache.spark.sql.types.{DoubleType, NumericType, StructField, StructType}

/** Private trait for params for OneHotEncoderEstimator and OneHotEncoderModel */
private[ml] trait OneHotEncoderParams extends Params with HasHandleInvalid
Member

Can the old OneHotEncoder also inherit this trait?

Member Author
@viirya viirya Oct 19, 2017

Hmm, should we? If we don't plan to add HasHandleInvalid, HasInputCols and HasOutputCols to the existing OneHotEncoder, I think we can keep it as it is now.

With this one-hot encoder estimator added, we may want to deprecate the existing OneHotEncoder.

Contributor
@MLnick MLnick Oct 20, 2017

I think the best approach is to leave the existing OneHotEncoder as it is and deprecate it for 3.0.0. Then in 3.0.0 we can remove the old one and rename OneHotEncoderEstimator -> OneHotEncoder to fit with the convention of other Estimator/Model pairs.

@viirya
Member Author

viirya commented Oct 19, 2017

Benchmark against the existing one-hot encoder.

Because the existing encoder only runs transform, there is no fitting time.

Transforming:

numColumns | Existing one-hot encoder (s)
--- | ---
1 | 0.2516055188
100 | 20.291758921100005
1000 | 26242.039411932*

* Because ten iterations take too long to finish, I just ran one iteration for 1000 columns. But it shows the scale already.

Benchmark code:

import org.apache.spark.ml.feature._
import org.apache.spark.sql.Row
import org.apache.spark.sql.types._
import spark.implicits._
import scala.util.Random

// Same random data setup as in the multi-column benchmark above.
val seed = 123L
val random = new Random(seed)
val n = 10000
val m = 1000
val rows = sc.parallelize(1 to n).map(i => Row(Array.fill(m)(random.nextInt(1000)): _*))
val struct = new StructType(Array.range(0, m, 1).map(i => StructField(s"c$i", IntegerType, true)))
val df = spark.createDataFrame(rows, struct)
df.persist()
df.count()

val inputCols = Array.range(0,m,1).map(i => s"c$i")
val outputCols = Array.range(0,m,1).map(i => s"c${i}_encoded")

val encoders = Array.range(0,m,1).map(i => new OneHotEncoder().setInputCol(s"c$i").setOutputCol(s"c${i}_encoded"))
var duration = 0.0
for (i <- 0 until 10) {
  var encoded = df
  val start = System.nanoTime()
  encoders.foreach { encoder =>
    encoded = encoder.transform(encoded)
  }
  encoded.count
  val end = System.nanoTime()
  duration += (end - start) / 1e9
}
println(s"duration: ${duration / 10}")

@SparkQA

SparkQA commented Oct 20, 2017

Test build #82917 has finished for PR 19527 at commit b42d175.

  • This patch passes all tests.
  • This patch merges cleanly.
  • This patch adds no public classes.

Contributor
@WeichenXu123 WeichenXu123 left a comment

nice work, thanks!

encoder(col(inputColName).cast(DoubleType)).as(outputColName, metadata)
}
val allCols = Seq(col("*")) ++ encodedColumns
dataset.select(allCols: _*)
Contributor

Do we need to handle the case where outputColNames conflict with existing column names?

Contributor

Yeah, my other comment #19527 (comment) was that we created withColumns in Dataset, which handles this?

private[spark] def withColumns(colNames: Seq[String], cols: Seq[Column]): DataFrame = {

Member Author

@WeichenXu123 Yes. transformSchema already handles the case where existing columns conflict with any of outputColNames.

@MLnick Sorry, I misunderstood your comment #19527 (comment). The withColumns approach calls select, so there is no performance difference. withColumns does one extra thing: it detects duplicates in the column names. We should check for duplicates too.

Member Author

@MLnick Use withColumns now.
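
Roughly, with withColumns the transform becomes the sketch below (illustrative only; encoder stands in for the encoding UDF, and metadata handling is omitted):

// Build one encoded Column per input column and attach them all in one call;
// withColumns also rejects output names that duplicate existing columns.
val encodedCols = inputColNames.map { inputColName =>
  encoder(col(inputColName).cast(DoubleType))
}
dataset.withColumns(outputColNames, encodedCols)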


import OneHotEncoderModel._

private def encoders: Array[UserDefinedFunction] = {
Contributor

What about defining one UDF (label, size) => Vector instead of defining an array of UDFs?
Or did you find some apparent perf difference between them?
If not, I prefer defining only one UDF, which makes the code clearer.

Member Author

Sounds good. Use one UDF now.
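
A sketch of what the single shared UDF might look like (illustrative; the handleInvalid and dropLast handling in the actual PR is more involved):

import org.apache.spark.ml.linalg.Vectors
import org.apache.spark.sql.functions.udf

// One UDF shared by all columns: (category index, output vector size) => one-hot vector.
val encoder = udf { (label: Double, size: Int) =>
  if (label >= 0 && label < size) {
    Vectors.sparse(size, Array(label.toInt), Array(1.0))
  } else {
    // e.g. "keep": map invalid values to an all-zeros vector.
    Vectors.sparse(size, Array.empty[Int], Array.empty[Double])
  }
}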

@SparkQA

SparkQA commented Oct 20, 2017

Test build #82930 has finished for PR 19527 at commit 66d46ac.

  • This patch passes all tests.
  • This patch merges cleanly.
  • This patch adds no public classes.

@SparkQA

SparkQA commented Oct 20, 2017

Test build #82932 has finished for PR 19527 at commit a9e9262.

  • This patch passes all tests.
  • This patch merges cleanly.
  • This patch adds no public classes.

@SparkQA

SparkQA commented Oct 20, 2017

Test build #82933 has finished for PR 19527 at commit fe80e98.

  • This patch passes all tests.
  • This patch merges cleanly.
  • This patch adds no public classes.

@SparkQA

SparkQA commented Oct 20, 2017

Test build #82934 has finished for PR 19527 at commit e024120.

  • This patch passes all tests.
  • This patch merges cleanly.
  • This patch adds no public classes.

@viirya
Member Author

viirya commented Oct 23, 2017

@BryanCutler @MLnick @WeichenXu123 Thanks for reviewing. Your comments should be all addressed now. Please take a look again when you have more time.

Member
@BryanCutler BryanCutler left a comment

Just one minor comment on grammar, LGTM!

@Since("2.3.0")
override val handleInvalid: Param[String] = new Param[String](this, "handleInvalid",
"How to handle invalid data " +
"Options are 'keep' (invalid data are ignored) or error (throw an error).",
Member

I think it should read "invalid data is ignored", same above too

Contributor
@MrBago MrBago Oct 23, 2017

It would be nice to be more precise than "is ignored". "ignoring data" sounds a lot like "skipping data", which is not what "keep" does :).

How about something like 'keep' (invalid data produces a vector of zeros) or ...

Member Author

I will change the wording. Thanks.

@SparkQA

SparkQA commented Oct 24, 2017

Test build #83000 has finished for PR 19527 at commit adc4107.

  • This patch passes all tests.
  • This patch merges cleanly.
  • This patch adds no public classes.

Contributor
@huaxingao huaxingao left a comment

It seems that the input parameter outputColName is not used in method OneHotEncoderCommon.genOutputAttrNames.
@viirya

@viirya
Member Author

viirya commented Oct 26, 2017

@huaxingao Good catch! Thanks.

@SparkQA

SparkQA commented Oct 26, 2017

Test build #83070 has finished for PR 19527 at commit ae2ac82.

  • This patch passes all tests.
  • This patch merges cleanly.
  • This patch adds no public classes.

@@ -41,8 +41,12 @@ import org.apache.spark.sql.types.{DoubleType, NumericType, StructType}
* The output vectors are sparse.
*
* @see `StringIndexer` for converting categorical values into category indices
* @deprecated `OneHotEncoderEstimator` will be renamed `OneHotEncoder` and this `OneHotEncoder`
Member

Note for the future: For 3.0, it'd be nice to do what you're describing here but also leave OneHotEncoderEstimator as a deprecated alias. That way, user code won't break but will have deprecation warnings when upgrading to 3.0.

Member Author

Ok. Sounds good.

@jkbradley
Member

jkbradley commented Dec 1, 2017

Question about this PR description comment:

Note that keep can't be used at the same time with dropLast as true. Because they will conflict in encoded vector by producing a vector of zeros.

Why is this necessary? With n categories found in fitting, shouldn't the behavior be the following?

  • keep=true, dropLast=true ==> vector size n (all-0 only if there was an invalid value)
  • keep=true, dropLast=false ==> vector size n+1 (never all-0)
  • keep=false, dropLast=true ==> vector size n-1 (all-0 only for the last category)
  • keep=false, dropLast=false ==> vector size n (never all-0)

@viirya
Member Author

viirya commented Dec 2, 2017

The behavior I thought is:

  • keep=true, dropLast=true ==> error option
  • keep=true, dropLast=false ==> vector size n (all-0 only for invalid value)

For the dropLast = false case, it behaves similarly to sklearn.preprocessing.OneHotEncoder.

If we make it behave as:

  • keep=true, dropLast=true ==> vector size n (all-0 only if there was an invalid value)

For example, with 5 categories we can't tell whether [0.0, 0.0, 0.0, 0.0, 0.0] means the last category or an invalid value.

@jkbradley
Member

For example, with 5 categories we can't tell whether [0.0, 0.0, 0.0, 0.0, 0.0] means the last category or an invalid value.

For the semantics I described ("OPTION 1"), I think it is clear:

  • Last category would lead to [0.0, 0.0, 0.0, 0.0, 1.0]
  • Invalid value would lead to [0.0, 0.0, 0.0, 0.0, 0.0]

For OPTION 1, I figured:

  • keep=true adds 1 extra "category" indicating an invalid value
  • dropLast=true removes the last category
  • Invalid values ("keep") are handled before removing the last category ("dropLast").

I realize now that it's the ordering of operations which is unclear here. If we handled "dropLast" before "keep" then we would have OPTION 2:

  • keep=true, dropLast=true ==> vector size n (all-0 for the last category; has a 1 at the end for invalid values)
    • These semantics seem weird to me, so I'd prefer we handle "keep" before "dropLast".
  • keep=true, dropLast=false ==> (same regardless of the order of operations)
  • keep=false, dropLast=true ==> (same regardless of the order of operations)
  • keep=false, dropLast=false ==> (same regardless of the order of operations)

OPTION 1 is more flexible than disallowing keep=true, dropLast=true. With OPTION 1, we can match sklearn's behavior with keep=true, dropLast=true.

What do you think?

@viirya
Member Author

viirya commented Dec 9, 2017

OK, I understand now. In other words, the extra category is added as the last category and the dropLast option works as before. That makes sense to me.
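
For reference, a small worked example of these semantics (illustrative, not from the PR), assuming 3 categories 0/1/2 were seen during fitting:

  • handleInvalid = keep, dropLast = false: vector size 4; category 2 => [0, 0, 1, 0]; invalid => [0, 0, 0, 1]
  • handleInvalid = keep, dropLast = true: vector size 3; category 2 => [0, 0, 1]; invalid => [0, 0, 0]
  • handleInvalid = error, dropLast = true: vector size 2; category 2 => [0, 0]; invalid => exception
  • handleInvalid = error, dropLast = false: vector size 3; category 2 => [0, 0, 1]; invalid => exception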

@SparkQA

SparkQA commented Dec 10, 2017

Test build #84690 has finished for PR 19527 at commit 32318fa.

  • This patch passes all tests.
  • This patch merges cleanly.
  • This patch adds no public classes.

@viirya
Member Author

viirya commented Dec 13, 2017

ping @jkbradley Can you review this again? Thanks.

Member
@jkbradley jkbradley left a comment

Sorry for the delay, and thanks for the PR! I finally made a detailed pass. There are a couple of subtle issues I commented on (and verified locally).

One general comment about unit tests: When you're writing a test which takes a toy dataset, transforms it, and checks against expected output, I recommend writing the input and expected output in the original DataFrame in a nice, human-readable layout. It makes the tests easier to read and maintain. I'm OK if you don't want to reformat this PR though.

size
}

if (label < numCategory) {
Member

Just noticed: we should check label >= 0 too

Member

It'd be nice to test negative values in unit tests too.

private def verifyNumOfValues(schema: StructType): StructType = {
$(outputCols).zipWithIndex.foreach { case (outputColName, idx) =>
val inputColName = $(inputCols)(idx)
val attrGroup = AttributeGroup.fromStructField(schema(outputColName))
Member

Should use inputColName, not outputColName, here.

Member Author

Here we use outputColName to get the size of the attribute group of the output column. For example, if the first input column specifies 5 categorical values, the size of the attribute group is 5.

Member Author

This number should be consistent with the corresponding number in trained model.

Member

Oh, you're right; my bad!


val numAttrsArray = dataset.select(columns: _*).rdd.map { row =>
(0 until numOfColumns).map(idx => row.getDouble(idx)).toArray
}.treeAggregate(new Array[Double](numOfColumns))(
Member

The use of treeAggregate could be more efficient: You're allowed to modify the first arg of combOp and seqOp, which would decrease the number of object allocations a lot.

Member Author

Fixed.
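
For reference, the pattern being suggested looks roughly like this (a sketch; columns and numOfColumns come from the surrounding code, and the accumulator arrays are mutated in place to avoid per-row allocations):

val maxValues = dataset.select(columns: _*).rdd.treeAggregate(new Array[Double](numOfColumns))(
  seqOp = (maxes, row) => {
    var i = 0
    while (i < numOfColumns) {
      maxes(i) = math.max(maxes(i), row.getDouble(i))
      i += 1
    }
    maxes  // reuse the accumulator instead of allocating a new array
  },
  combOp = (m1, m2) => {
    var i = 0
    while (i < numOfColumns) {
      m1(i) = math.max(m1(i), m2(i))
      i += 1
    }
    m1
  }
)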

dropLast: Boolean,
inputColNames: Seq[String],
outputColNames: Seq[String],
handleInvalid: Boolean = false): Seq[AttributeGroup] = {
Member

Let's call this keepInvalid since that's what it seems to mean.

s"output columns ${outputColNames.length}.")

inputColNames.zip(outputColNames).map { case (inputColName, outputColName) =>
require(schema(inputColName).dataType.isInstanceOf[NumericType],
Member

It's slightly more concise to use ml.util.SchemaUtils.checkNumericType

inputColNames.zip(outputColNames).map { case (inputColName, outputColName) =>
require(schema(inputColName).dataType.isInstanceOf[NumericType],
s"Input column must be of type NumericType but got ${schema(inputColName).dataType}")
require(!existingFields.exists(_.name == outputColName),
Member

Likewise, this could be checked by using ml.util.SchemaUtils.appendColumn at the end of this method.
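
For reference, with those helpers the checks might look roughly like this (a sketch; VectorType stands in for whatever output column type the PR actually uses):

import org.apache.spark.ml.linalg.SQLDataTypes.VectorType
import org.apache.spark.ml.util.SchemaUtils

val outputSchema = inputColNames.zip(outputColNames).foldLeft(schema) {
  case (newSchema, (inputColName, outputColName)) =>
    // Fails if the input column is not of a numeric type.
    SchemaUtils.checkNumericType(schema, inputColName)
    // appendColumn fails if outputColName already exists in the schema.
    SchemaUtils.appendColumn(newSchema, outputColName, VectorType)
}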

* When `handleInvalid` is configured to 'keep', an extra "category" indicating invalid values is
* added as last category. So when `dropLast` is true, invalid values are encoded as all-zeros
* vector.
*
Member

It'd be nice to add a note explaining the multi-column API: input/output cols come in pairs, specified by the order in the arrays, and each pair is treated independently.

Member Author

Added.


private[feature] val KEEP_INVALID: String = "keep"
private[feature] val ERROR_INVALID: String = "error"
private[feature] val supportedHandleInvalids: Array[String] = Array(KEEP_INVALID, ERROR_INVALID)
Member

I'm OK not supporting skip for now, but I don't see a reason not to have it in the future.

val attrGroups = OneHotEncoderCommon.getOutputAttrGroupFromData(
dataset, $(dropLast), inputColNames, outputColNames, keepInvalid)
attrGroups.zip(columnToScanIndices).foreach { case (attrGroup, idx) =>
categorySizes(idx) = attrGroup.size
Member

This category size is being specified with keepInvalid and dropLast taken into account, but it shouldn't be. It should be the plain number of categories since OneHotEncoderModel can have its keepInvalid and dropLast settings changed. Unit testing this somewhat complex behavior would be nice (and verifyNumOfValues should fail in some of these cases).

Member

Note: I tried this by fitting with dropLast = false and then transforming with dropLast = true, and it caused a failure.

Member Author

So we don't need to consider keepInvalid and dropLast when fitting?

Member Author

Yeah, I think that's correct. We only need to consider dropLast and keepInvalid when transforming.

}
case _: NumericAttribute =>
throw new RuntimeException(
s"The input column ${inputCol.name} cannot be numeric.")
Member

This could be confusing since elsewhere we do say it can be Numeric (since "numeric" is overloaded). How about "cannot be continuous-valued"?

@MLnick
Contributor

MLnick commented Dec 25, 2017 via email

@SparkQA

SparkQA commented Dec 26, 2017

Test build #85389 has finished for PR 19527 at commit 144f07d.

  • This patch passes all tests.
  • This patch merges cleanly.
  • This patch adds no public classes.

@viirya
Member Author

viirya commented Dec 26, 2017

Unit tests are reformatted too.

@SparkQA

SparkQA commented Dec 26, 2017

Test build #85391 has finished for PR 19527 at commit 587ad42.

  • This patch passes all tests.
  • This patch merges cleanly.
  • This patch adds no public classes.

Member
@jkbradley jkbradley left a comment

Thanks for the updates! The tests are a lot easier to read through now. My new comments are just 1 potential bug and 2 suggestions.


@Since("2.3.0")
override def transformSchema(schema: StructType): StructType = {
// When fitting data, we want the the plain number of categories without `handleInvalid` and
Member
@jkbradley jkbradley Dec 27, 2017

This isn't correct. If handleInvalid or dropLast are set, then transformSchema should take them into account. (transformSchema should return the correct schema for the given Param values.)

Member Author

This is for fitting, to get the plain number of categories based on the input schema. Here we just read the plain numbers from the input schema and record them in the trained model.

In the model's transform, transformSchema does take them into account.

Member

Estimator.transformSchema should return an accurate schema for whatever Estimator.fit().transform() would produce. That's what the API requires. The schema this method is returning right now is not accurate since it does not take the Params handleInvalid and dropLast into account.

I agree the Model should store the plain number of categories, but that's a separate issue from what transformSchema is supposed to do.

Member Author

During fitting, we obtain the plain number of categories from the metadata in the schema first, and from the input data if no metadata is available.

If transformSchema is required to return the accurate output schema even for an Estimator (I had assumed not, since an estimator outputs a model rather than data), then I will need to convert the number of categories, computed with handleInvalid and dropLast taken into account, back to the plain number before storing it in the model.

Member

Why do you need to change the numCategories which the model stores?

Model.transformSchema already has logic for taking the raw numCategories and computing a schema which takes the Params into account. Can that not be reused for Estimator.transformSchema?

I just think it's a lot simpler to store the raw numCategories in the Model. (It would be really confusing for those values to be based off of 1 setting of Params, and then for the user to change the Params, and for us to therefore need to store both original and new values for Params.)

Member Author
@viirya viirya Dec 29, 2017

There is a misunderstanding: the model already stores the raw numCategories, and I don't want to change that.

For now, the transformSchema logic shared between Model and Estimator is the same: it uses the metadata from the input schema to compute the output schema, with or without the Params taken into account.

The numCategories is then determined from that schema. If Estimator.transformSchema must return a schema with the Params taken into account, the numCategories derived from the schema is no longer raw. But we want to record the raw numCategories in the model, so I would need to compute the raw numCategories back from the values derived from the schema.

Another approach: let Estimator.transformSchema return a schema with the Params taken into account, and when I need to record the raw numCategories in the model, call validateAndTransformSchema again to get the raw numCategories without the Params taken into account.

Member

Oh, sorry, I did misunderstand. I like your last suggestion:

Another approach: let Estimator.transformSchema return a schema with the Params taken into account, and when I need to record the raw numCategories in the model, call validateAndTransformSchema again to get the raw numCategories without the Params taken into account.

This will be nice for reusing code + give the desired output for transformSchema.

Thank you!


// The actual number of categories varies due to different setting of `dropLast` and
// `handleInvalid`.
private def configedCategorySizes: Array[Int] = {
Member

This is always called with a single index, but it will construct the whole array each time. How about making it a method which takes an index and returns a single size Int?
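
Something along these lines is what's being suggested (a sketch; categorySizes, dropLast, and keepInvalid are assumed to hold the model's raw category counts and configured Param values):

// Output vector size for column `idx`, derived from the raw category count
// and the configured handleInvalid / dropLast settings.
private def configedCategorySize(idx: Int): Int = {
  val rawSize = categorySizes(idx)
  // "keep" appends one extra category for invalid values; dropLast removes the last category.
  val withInvalid = if (keepInvalid) rawSize + 1 else rawSize
  if (dropLast) withInvalid - 1 else withInvalid
}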

val handleInvalid = getHandleInvalid
val keepInvalid = handleInvalid == OneHotEncoderEstimator.KEEP_INVALID

udf { (label: Double, size: Int) =>
Member

This logic would be simpler if this were written to take the original numCategories value. The way it is now, it takes the "configed" numCategories and then reverses that logic a bit when looking at the label. It's confusing to me, at least.

@jkbradley
Member

@viirya Will you have time to update this soon? I'd like to get it in 2.3, which will mean merging it by Jan. 1. If you're busy, I can merge it as is and do a follow-up myself too. Thanks!

@viirya
Member Author

viirya commented Dec 31, 2017

@jkbradley Ok. I will update this today.

@SparkQA

SparkQA commented Dec 31, 2017

Test build #85551 has finished for PR 19527 at commit e94496a.

  • This patch passes all tests.
  • This patch merges cleanly.
  • This patch adds no public classes.

@jkbradley
Member

Thanks for the updates! I still think there's some confusion, but since I think this code is correct & it doesn't affect APIs, I'll go ahead and merge this. I'll ping you on a follow-up PR to show what I had in mind.

LGTM
Merging with master

@asfgit asfgit closed this in 994065d Dec 31, 2017
@viirya
Member Author

viirya commented Dec 31, 2017

@jkbradley Thanks for reviewing and merging this, and thanks to everyone else who helped as well.

@jkbradley
Member

Thank you for all the work in this PR! Here's the follow-up: #20132

asfgit pushed a commit that referenced this pull request Jan 5, 2018
## What changes were proposed in this pull request?

Follow-up cleanups for the OneHotEncoderEstimator PR.  See some discussion in the original PR: #19527 or read below for what this PR includes:
* configedCategorySize: I reverted this to return an Array.  I realized the original setup (which I had recommended in the original PR) caused the whole model to be serialized in the UDF.
* encoder: I reorganized the logic to show what I meant in the comment in the previous PR.  I think it's simpler but am open to suggestions.

I also made some small style cleanups based on IntelliJ warnings.

## How was this patch tested?

Existing unit tests

Author: Joseph K. Bradley <[email protected]>

Closes #20132 from jkbradley/viirya-SPARK-13030.

(cherry picked from commit 930b90a)
Signed-off-by: Joseph K. Bradley <[email protected]>
@viirya viirya deleted the SPARK-13030 branch December 27, 2023 18:35