[SPARK-13030][ML] Create OneHotEncoderEstimator for OneHotEncoder as Estimator #19527
Conversation
Test build #82879 has finished for PR 19527 at commit
cc @MLnick @WeichenXu123 @jkbradley This adds a new class `OneHotEncoderEstimator`.
encoder(col(inputColName).cast(DoubleType)).as(outputColName, metadata)
}
val allCols = Seq(col("*")) ++ encodedColumns
Just looking at this in the context of the previous multi-column work (Bucketizer and Imputer): is there any real performance difference between this and using dataset.withColumns(...)?
Before running a real performance test, I'd guess the difference is not significant. Besides performance, multi-column support on OneHotEncoder also makes the API more convenient to use; scikit-learn's OneHotEncoder supports multiple columns as well.
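As a quick illustration of the convenience, a minimal sketch of the multi-column API from the user's side (the DataFrame df and its column names here are hypothetical):

import org.apache.spark.ml.feature.OneHotEncoderEstimator

// One estimator covers all columns in a single fit/transform,
// instead of chaining one OneHotEncoder stage per column.
val encoder = new OneHotEncoderEstimator()
  .setInputCols(Array("city_idx", "device_idx"))
  .setOutputCols(Array("city_vec", "device_vec"))
val model = encoder.fit(df)
val encoded = model.transform(df)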
If we can get the number of values for all columns at once in https://github.com/apache/spark/pull/19527/files#r145457081, the fitting performance can be improved.
Performance test results are attached now. The existing OneHotEncoder only performs transform, so when we compare the combined fitting and transforming time of this estimator/model against the existing one-hot encoder, the multi-column approach is faster.
* This method is called when we want to generate `AttributeGroup` from actual data for
* one-hot encoder.
*/
def getOutputAttrGroupFromData(
Calling this multiple times for multiple columns seems inefficient. It should be possible to use DataFrame ops to compute this efficiently?
This can be improved under multi-column usage. Let me think about it. Thanks.
The RDD approach has the advantage of early-stopping if any values are invalid. DataFrame ops don't seem to have equivalent functions: we can only check whether the max value exceeds Int.MaxValue after aggregation, and we would also need a min aggregation per column to check whether any values are less than zero.
So for now I plan to modify this to a multi-column version but still use the RDD approach. Sounds good to you?
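To illustrate the early-stop point, a minimal sketch of my own (not the PR's code, and assuming the first column holds doubles): an exception thrown inside an RDD task fails the job at the first invalid record, whereas a DataFrame min/max aggregation only reveals the problem after scanning everything.

import org.apache.spark.SparkException

// Hypothetical validation: category indices must be non-negative.
// The throw aborts the job as soon as one bad value is seen.
val maxValue = df.rdd.map { row =>
  val v = row.getDouble(0)
  if (v < 0.0) throw new SparkException(s"Negative value not allowed: $v")
  v
}.max()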
throw new SparkException(s"Unseen value: $label. To handle unseen values, " +
  s"set Param handleInvalid to ${OneHotEncoderEstimator.SKIP_INVALID}.")
} else {
  Vectors.sparse(size, emptyIndices, emptyValues)
This behaviour seems more like "keep" in StringIndexer than "skip". Skip filters out rows with invalid values and returns a dataframe with number of rows <= the number of rows in the input. "keep" maps all invalid values to a single, known value. Should we rename this to KEEP_INVALID?
https://spark.apache.org/docs/2.2.0/ml-features.html#stringindexer
When I checked scikit-learn's OneHotEncoder (http://scikit-learn.org/stable/modules/generated/sklearn.preprocessing.OneHotEncoder.html), its handle_unknown parameter has error and ignore choices. Sounds like ignore is closer to skip in semantics?
I'm willing to rename it to keep if others also think that is better.
Although keep does seem a little strange semantically, it's probably best to match the param option names from StringIndexer, since StringIndexer -> OneHotEncoder will be the typical pipeline.
And skip definitely would seem to imply that rows are filtered out or dropped, which is not the case here.
Ok. Then I will rename it to keep. Thanks.
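To make the renamed semantics concrete, a small hand-worked example (the category counts and labels here are hypothetical):

import org.apache.spark.ml.linalg.Vectors

// 3 seen categories (indices 0, 1, 2), dropLast = false.
// "keep" appends one extra category for invalid values, so vectors get
// size 4 and every row is preserved:
Vectors.sparse(4, Array(2), Array(1.0)) // label 2.0 (valid)   -> [0, 0, 1, 0]
Vectors.sparse(4, Array(3), Array(1.0)) // label 7.0 (invalid) -> [0, 0, 0, 1]
// "error" would instead throw on label 7.0. No rows are filtered out in
// either mode, which is why "skip" would have been misleading.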
Benchmark against multi-column one-hot encoder (Multi-Col, Multiple Run: the first commit).

Fitting: (benchmark result table not preserved)

Transforming: (benchmark result table not preserved)
Benchmark code:
import org.apache.spark.ml.feature._
import org.apache.spark.sql.Row
import org.apache.spark.sql.types._
import spark.implicits._
import scala.util.Random
val seed = 123L
val random = new Random(seed)
val n = 10000
val m = 1000
val rows = sc.parallelize(1 to n).map(i => Row(Array.fill(m)(random.nextInt(1000)): _*))
val struct = new StructType(Array.range(0,m,1).map(i => StructField(s"c$i",IntegerType,true)))
val df = spark.createDataFrame(rows, struct)
df.persist()
df.count()
val inputCols = Array.range(0,m,1).map(i => s"c$i")
val outputCols = Array.range(0,m,1).map(i => s"c${i}_encoded")
val encoder = new OneHotEncoderEstimator().setInputCols(inputCols).setOutputCols(outputCols)
var durationFitting = 0.0
var durationTransforming = 0.0
for (i <- 0 until 10) {
val startFitting = System.nanoTime()
val model = encoder.fit(df)
val endFitting = System.nanoTime()
durationFitting += (endFitting - startFitting) / 1e9
val startTransforming = System.nanoTime()
model.transform(df).count
val endTransforming = System.nanoTime()
durationTransforming += (endTransforming - startTransforming) / 1e9
}
println(s"fitting: ${durationFitting / 10}")
println(s"transforming: ${durationTransforming / 10}")
import org.apache.spark.ml.attribute._
import org.apache.spark.ml.linalg.Vectors
import org.apache.spark.ml.param._
import org.apache.spark.ml.param.shared.{HasHandleInvalid, HasInputCol, HasInputCols, HasOutputCol, HasOutputCols}
HasInputCol and HasOutputCol are not needed; I think the Transformer import above isn't either.
Yes, thanks.
import org.apache.spark.sql.types.{DoubleType, NumericType, StructField, StructType}

/** Private trait for params for OneHotEncoderEstimator and OneHotEncoderModel */
private[ml] trait OneHotEncoderParams extends Params with HasHandleInvalid
Can the old OneHotEncoder also inherit this trait?
Hmm, should we? If we don't plan to add HasHandleInvalid, HasInputCols and HasOutputCols to the existing OneHotEncoder, I think we can keep it as it is now.
With this one-hot encoder estimator added, we may want to deprecate the existing OneHotEncoder.
I think the best approach is to leave the existing OneHotEncoder as it is and deprecate it for 3.0.0. Then in 3.0.0 we can remove the old one and rename OneHotEncoderEstimator -> OneHotEncoder to fit with the convention of other Estimator/Model pairs.
Benchmark against existing one-hot encoder. Because the existing encoder only needs to run transform, only transforming time is compared.

Transforming: (benchmark result table not preserved)
Benchmark code:
import org.apache.spark.ml.feature._
import org.apache.spark.sql.Row
import org.apache.spark.sql.types._
import spark.implicits._
import scala.util.Random
val seed = 123L
val random = new Random(seed)
val n = 10000
val m = 1000
val rows = sc.parallelize(1 to n).map(i => Row(Array.fill(m)(random.nextInt(1000)): _*))
val struct = new StructType(Array.range(0,m,1).map(i => StructField(s"c$i",IntegerType,true)))
val df = spark.createDataFrame(rows, struct)
df.persist()
df.count()
val inputCols = Array.range(0,m,1).map(i => s"c$i")
val outputCols = Array.range(0,m,1).map(i => s"c${i}_encoded")
val encoders = Array.range(0,m,1).map(i => new OneHotEncoder().setInputCol(s"c$i").setOutputCol(s"c${i}_encoded"))
var duration = 0.0
for (i <- 0 until 10) {
var encoded = df
val start = System.nanoTime()
encoders.foreach { encoder =>
encoded = encoder.transform(encoded)
}
encoded.count
val end = System.nanoTime()
duration += (end - start) / 1e9
}
println(s"duration: ${duration / 10}") |
Test build #82917 has finished for PR 19527 at commit
nice work, thanks!
encoder(col(inputColName).cast(DoubleType)).as(outputColName, metadata)
}
val allCols = Seq(col("*")) ++ encodedColumns
dataset.select(allCols: _*)
Do we need to handle the case where outputColNames conflict with existing column names?
Yeah, my other comment (#19527 (comment)) was that we created withColumns in Dataset to handle this:
private[spark] def withColumns(colNames: Seq[String], cols: Seq[Column]): DataFrame = {
@WeichenXu123 Yes, transformSchema has handled the case where existing columns conflict with any of outputColNames.
@MLnick Sorry, I misunderstood your comment #19527 (comment). The withColumns approach calls select, so there is no performance difference. withColumns does one more thing: it detects duplication in the column names. We should check for duplication too.
@MLnick Using withColumns now.
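For context, with the private[spark] withColumns shown above, the manual select can be replaced by roughly the following sketch (variable names follow the snippets in this thread):

// withColumns is private[spark], so it is callable from org.apache.spark.ml
// code; it also rejects duplicate names among the new columns.
// Instead of: dataset.select(Seq(col("*")) ++ encodedColumns: _*)
dataset.toDF().withColumns(outputColNames.toSeq, encodedColumns)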
import OneHotEncoderModel._

private def encoders: Array[UserDefinedFunction] = {
What about defining one UDF (label, size) => Vector instead of defining an array of UDFs? Or do you see some apparent perf difference between them?
If not, I prefer defining only one UDF, which makes the code clearer.
Sounds good. Using one UDF now.
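A sketch of the single shared UDF shape agreed on here (simplified and hypothetical; the real encoder also honors handleInvalid and dropLast):

import org.apache.spark.ml.linalg.{Vector, Vectors}
import org.apache.spark.sql.functions.{col, lit, udf}

// One (label, size) => Vector function reused for every column; the
// per-column category count is passed in as a literal argument.
val encode = udf { (label: Double, size: Int) =>
  if (label >= 0 && label < size) {
    Vectors.sparse(size, Array(label.toInt), Array(1.0))
  } else {
    Vectors.sparse(size, Array.empty[Int], Array.empty[Double]) // all-zeros
  }
}
val out = df.withColumn("c0_encoded", encode(col("c0").cast("double"), lit(4)))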
Test build #82930 has finished for PR 19527 at commit
Test build #82932 has finished for PR 19527 at commit
Test build #82933 has finished for PR 19527 at commit
Test build #82934 has finished for PR 19527 at commit
@BryanCutler @MLnick @WeichenXu123 Thanks for reviewing. Your comments should all be addressed now. Please take a look again when you have more time.
Just one minor comment on grammar, LGTM!
@Since("2.3.0") | ||
override val handleInvalid: Param[String] = new Param[String](this, "handleInvalid", | ||
"How to handle invalid data " + | ||
"Options are 'keep' (invalid data are ignored) or error (throw an error).", |
I think it should read "invalid data is ignored", same above too
It would be nice to be more precise than "is ignored". "ignoring data" sounds a lot like "skipping data", which is not what "keep" does :).
How about something like 'keep' (invalid data produces a vector of zeros) or ...
I will change the wording. Thanks.
Test build #83000 has finished for PR 19527 at commit
It seems that the input parameter outputColName is not used in method OneHotEncoderCommon.genOutputAttrNames. @viirya
@huaxingao Good catch! Thanks.
Test build #83070 has finished for PR 19527 at commit
@@ -41,8 +41,12 @@ import org.apache.spark.sql.types.{DoubleType, NumericType, StructType}
* The output vectors are sparse.
*
* @see `StringIndexer` for converting categorical values into category indices
* @deprecated `OneHotEncoderEstimator` will be renamed `OneHotEncoder` and this `OneHotEncoder`
Note for the future: For 3.0, it'd be nice to do what you're describing here but also leave OneHotEncoderEstimator as a deprecated alias. That way, user code won't break but will have deprecation warnings when upgrading to 3.0.
Ok. Sounds good.
Question about this PR description comment: Why is this necessary? With … (quoted text not preserved)

The behavior I thought is: … For the cases of … If we make it behave as: … For example with 5 categories, we don't know …

For the semantics I described ("OPTION 1"), I think it is clear: … For OPTION 1, I figured: …

I realize now that it's the ordering of operations which is unclear here. If we handled "dropLast" before "keep" then we would have OPTION 2: … OPTION 1 is more flexible than disallowing … What do you think?

Ok, I understood. In other words, the extra category is added as the last category and …
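Restating the agreed ordering (OPTION 1) with a concrete, hypothetical example, based on the behavior documented later in this thread:

// 3 seen categories; handleInvalid = "keep" appends the invalid-value
// category as the *last* one (4 slots), and only then does dropLast apply.
//
//   dropLast = false: label 1.0 (valid)   -> [0, 1, 0, 0]
//                     label 9.0 (invalid) -> [0, 0, 0, 1]
//   dropLast = true:  label 1.0 (valid)   -> [0, 1, 0]
//                     label 9.0 (invalid) -> [0, 0, 0]  // its slot was dropped
//
// So "keep" + dropLast encodes invalid values as the all-zeros vector.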
Test build #84690 has finished for PR 19527 at commit
ping @jkbradley Can you review this again? Thanks.
Sorry for the delay, and thanks for the PR! I finally made a detailed pass. There are a couple of subtle issues I commented on (and verified locally).
One general comment about unit tests: When you're writing a test which takes a toy dataset, transforms it, and checks against expected output, I recommend writing the input and expected output in the original DataFrame in a nice, human-readable layout. It makes the tests easier to read and maintain. I'm OK if you don't want to reformat this PR though.
size
}

if (label < numCategory) {
Just noticed: we should check label >= 0 too
It'd be nice to test negative values in unit tests too.
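A sketch of the guard with the negative check added (a minimal illustration; the keep/error branching is collapsed into a throw here):

import org.apache.spark.SparkException
import org.apache.spark.ml.linalg.Vectors

// A valid label is an integer index in [0, numCategory).
def encodeLabel(label: Double, numCategory: Int, size: Int) =
  if (label >= 0 && label < numCategory) {
    Vectors.sparse(size, Array(label.toInt), Array(1.0))
  } else {
    // negative or unseen label -> handleInvalid path ("keep" or "error")
    throw new SparkException(s"Invalid value: $label")
  }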
private def verifyNumOfValues(schema: StructType): StructType = {
  $(outputCols).zipWithIndex.foreach { case (outputColName, idx) =>
    val inputColName = $(inputCols)(idx)
    val attrGroup = AttributeGroup.fromStructField(schema(outputColName))
Should use inputColName, not outputColName, here.
Here we use outputColName to get the size of the attribute group of the output column. For example, if the first input column specifies 5 categorical values, the size of the attribute group is 5.
This number should be consistent with the corresponding number in trained model.
Oh, you're right; my bad!
val numAttrsArray = dataset.select(columns: _*).rdd.map { row =>
  (0 until numOfColumns).map(idx => row.getDouble(idx)).toArray
}.treeAggregate(new Array[Double](numOfColumns))(
The use of treeAggregate could be more efficient: You're allowed to modify the first arg of combOp and seqOp, which would decrease the number of object allocations a lot.
Fixed.
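For reference, a sketch of the allocation-friendly pattern described above, assuming rows: RDD[Array[Double]] holding the selected columns: both operators mutate and return their first argument, so per-row allocations are avoided.

val numAttrsArray = rows.treeAggregate(new Array[Double](numOfColumns))(
  (maxes, values) => { // seqOp: fold one row into the reused buffer
    var i = 0
    while (i < numOfColumns) {
      maxes(i) = math.max(maxes(i), values(i))
      i += 1
    }
    maxes
  },
  (m1, m2) => { // combOp: merge partition results into m1 in place
    var i = 0
    while (i < numOfColumns) {
      m1(i) = math.max(m1(i), m2(i))
      i += 1
    }
    m1
  })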
dropLast: Boolean,
inputColNames: Seq[String],
outputColNames: Seq[String],
handleInvalid: Boolean = false): Seq[AttributeGroup] = {
Let's call this keepInvalid since that's what it seems to mean.
s"output columns ${outputColNames.length}.") | ||
|
||
inputColNames.zip(outputColNames).map { case (inputColName, outputColName) => | ||
require(schema(inputColName).dataType.isInstanceOf[NumericType], |
It's slightly more concise to use ml.util.SchemaUtils.checkNumericType
inputColNames.zip(outputColNames).map { case (inputColName, outputColName) =>
  require(schema(inputColName).dataType.isInstanceOf[NumericType],
    s"Input column must be of type NumericType but got ${schema(inputColName).dataType}")
  require(!existingFields.exists(_.name == outputColName),
Likewise, this could be checked by using ml.util.SchemaUtils.appendColumn at the end of this method.
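For reference, the two suggestions combined would look roughly like this sketch (SchemaUtils is a private[spark] helper in org.apache.spark.ml.util, so it is usable here):

import org.apache.spark.ml.linalg.SQLDataTypes.VectorType
import org.apache.spark.ml.util.SchemaUtils

inputColNames.foreach(SchemaUtils.checkNumericType(schema, _))
// appendColumn also fails if an output name already exists in the schema.
val transformed = outputColNames.foldLeft(schema) { (s, name) =>
  SchemaUtils.appendColumn(s, name, VectorType)
}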
* When `handleInvalid` is configured to 'keep', an extra "category" indicating invalid values is
* added as last category. So when `dropLast` is true, invalid values are encoded as all-zeros
* vector.
*
It'd be nice to add a note explaining the multi-column API: input/output cols come in pairs, specified by the order in the arrays, and each pair is treated independently.
Added.
private[feature] val KEEP_INVALID: String = "keep"
private[feature] val ERROR_INVALID: String = "error"
private[feature] val supportedHandleInvalids: Array[String] = Array(KEEP_INVALID, ERROR_INVALID)
I'm OK not supporting skip for now, but I don't see a reason not to have it in the future.
val attrGroups = OneHotEncoderCommon.getOutputAttrGroupFromData(
  dataset, $(dropLast), inputColNames, outputColNames, keepInvalid)
attrGroups.zip(columnToScanIndices).foreach { case (attrGroup, idx) =>
  categorySizes(idx) = attrGroup.size
This category size is being specified with keepInvalid and dropLast taken into account, but it shouldn't be. It should be the plain number of categories since OneHotEncoderModel can have its keepInvalid and dropLast settings changed. Unit testing this somewhat complex behavior would be nice (and verifyNumOfValues should fail in some of these cases).
Note: I tried this by fitting with dropLast = false and then transforming with dropLast = true, and it caused a failure.
So we don't need to consider keepInvalid and dropLast when fitting?
Yea, I think it's correct. We only need to consider dropLast and keepInvalid when transforming.
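A sketch of the resulting transform-time derivation (my own formulation of what was agreed): the model keeps the raw category count from fitting, and the effective vector size comes from the current Params.

// rawSize: plain category count recorded at fit time.
def configedSize(rawSize: Int, keepInvalid: Boolean, dropLast: Boolean): Int = {
  val withInvalid = if (keepInvalid) rawSize + 1 else rawSize // "keep" adds a slot
  if (dropLast) withInvalid - 1 else withInvalid              // dropLast removes one
}
// e.g. rawSize = 5: keep + dropLast -> 5, error + dropLast -> 4, keep alone -> 6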
}
case _: NumericAttribute =>
  throw new RuntimeException(
    s"The input column ${inputCol.name} cannot be numeric.")
This could be confusing since elsewhere we do say it can be Numeric (since "numeric" is overloaded). How about "cannot be continuous-valued"?
Agree on keeping the new OneHotEncoderEstimator as an alias for 3.0
Test build #85389 has finished for PR 19527 at commit
Unit tests are reformatted too.
Test build #85391 has finished for PR 19527 at commit
Thanks for the updates! The tests are a lot easier to read through now. My new comments are just 1 potential bug and 2 suggestions.
@Since("2.3.0")
override def transformSchema(schema: StructType): StructType = {
  // When fitting data, we want the plain number of categories without `handleInvalid` and
This isn't correct. If handleInvalid or dropLast are set, then transformSchema should take them into account. (transformSchema should return the correct schema for the given Param values.)
This is for fitting to get the plain number of categories based on the input schema. Here we just get the plain numbers from the input schema and record them into the trained model.
In the model's transform, transformSchema takes them into account.
Estimator.transformSchema should return an accurate schema for whatever Estimator.fit().transform() would produce. That's what the API requires. The schema this method is returning right now is not accurate since it does not take the Params handleInvalid and dropLast into account.
I agree the Model should store the plain number of categories, but that's a separate issue from what transformSchema is supposed to do.
During fitting, we obtain the plain number of categories from the metadata in the schema first, and then from the input data if no metadata is available.
If transformSchema is required to return an accurate output schema even for an Estimator (I had thought not, since an estimator outputs a model rather than data), then I need to convert the number of categories, computed with handleInvalid and dropLast taken into account, back to the plain number before storing it into the model.
Why do you need to change the numCategories which the model stores?
Model.transformSchema already has logic for taking the raw numCategories and computing a schema which takes the Params into account. Can that not be reused for Estimator.transformSchema?
I just think it's a lot simpler to store the raw numCategories in the Model. (It would be really confusing for those values to be based off of 1 setting of Params, and then for the user to change the Params, and for us to therefore need to store both original and new values for Params.)
There is a misunderstanding. The model stores the raw numCategories now; I don't want to change that.
For now, the transformSchema logic shared between Model and Estimator is the same: it uses the metadata from the input schema to compute the output schema, with or without the Params taken into account.
The numCategories is then determined from this schema. If Estimator.transformSchema must return a schema with the Params taken into account, the numCategories derived from that schema is no longer raw. But we want to record the raw numCategories in the model, so I need to compute the raw numCategories back from the numCategories derived from the schema.
Another approach: I let Estimator.transformSchema return a schema with the Params taken into account, and when I need to record the raw numCategories into the model, I call validateAndTransformSchema again to get the raw numCategories without the Params taken into account.
Oh, sorry, I did misunderstand. I like your last suggestion:
Another approach: I let Estimator.transformSchema return a schema with the Params taken into account, and when I need to record the raw numCategories into the model, I call validateAndTransformSchema again to get the raw numCategories without the Params taken into account.
This will be nice for reusing code + give the desired output for transformSchema.
Thank you!
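Schematically, the settled approach looks like this sketch (the validateAndTransformSchema signature is assumed from the discussion, and the fragment is not compilable on its own):

// Estimator: report the schema that fit().transform() would actually produce.
override def transformSchema(schema: StructType): StructType = {
  val keepInvalid = getHandleInvalid == OneHotEncoderEstimator.KEEP_INVALID
  validateAndTransformSchema(schema, dropLast = $(dropLast), keepInvalid = keepInvalid)
}

override def fit(dataset: Dataset[_]): OneHotEncoderModel = {
  transformSchema(dataset.schema)
  // Call again with the Params disabled to recover the *raw* category
  // sizes that the model should store.
  val rawSchema = validateAndTransformSchema(dataset.schema,
    dropLast = false, keepInvalid = false)
  // derive categorySizes from rawSchema's attribute groups, then:
  ??? // construct and return the OneHotEncoderModel
}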
// The actual number of categories varies due to different setting of `dropLast` and
// `handleInvalid`.
private def configedCategorySizes: Array[Int] = {
This is always called with a single index, but it will construct the whole array each time. How about making it a method which takes an index and returns a single size Int?
val handleInvalid = getHandleInvalid
val keepInvalid = handleInvalid == OneHotEncoderEstimator.KEEP_INVALID

udf { (label: Double, size: Int) =>
This logic would be simpler if this were written to take the original numCategories value. The way it is now, it takes the "configed" numCategories and then reverses that logic a bit when looking at the label. It's confusing to me, at least.
@viirya Will you have time to update this soon? I'd like to get it in 2.3, which will mean merging it by Jan. 1. If you're busy, I can merge it as is and do a follow-up myself too. Thanks!
@jkbradley Ok. I will update this today.
Test build #85551 has finished for PR 19527 at commit
Thanks for the updates! I still think there's some confusion, but since I think this code is correct & it doesn't affect APIs, I'll go ahead and merge this. I'll ping you on a follow-up PR to show what I had in mind. LGTM
@jkbradley Thanks for reviewing and merging this. Thanks to everyone else who helped with this too.
Thank you for all the work in this PR! Here's the follow-up: #20132
## What changes were proposed in this pull request?

Follow-up cleanups for the OneHotEncoderEstimator PR. See some discussion in the original PR: #19527, or read below for what this PR includes:

* configedCategorySize: I reverted this to return an Array. I realized the original setup (which I had recommended in the original PR) caused the whole model to be serialized in the UDF.
* encoder: I reorganized the logic to show what I meant in the comment in the previous PR. I think it's simpler but am open to suggestions.

I also made some small style cleanups based on IntelliJ warnings.

## How was this patch tested?

Existing unit tests

Author: Joseph K. Bradley <[email protected]>

Closes #20132 from jkbradley/viirya-SPARK-13030.

(cherry picked from commit 930b90a)
Signed-off-by: Joseph K. Bradley <[email protected]>
What changes were proposed in this pull request?

This patch adds a new class `OneHotEncoderEstimator` which extends `Estimator`. The `fit` method returns `OneHotEncoderModel`.

Common methods between the existing `OneHotEncoder` and the new `OneHotEncoderEstimator`, such as transforming schema, are extracted and put into `OneHotEncoderCommon` to reduce code duplication.

Multi-column support

`OneHotEncoderEstimator` adds simpler multi-column support because it is a new API and can be free from backward compatibility.

handleInvalid Param support

`OneHotEncoderEstimator` supports the `handleInvalid` Param. It supports `error` and `keep`.

How was this patch tested?

Added new test suite `OneHotEncoderEstimatorSuite`.