[SPARK-23528][ML] Add numIter to ClusteringSummary #20701
cc @yanboliang @zhengruifeng since I saw you worked on this before, thanks.
Test build #87829 has finished for PR 20701 at commit
Test build #87830 has finished for PR 20701 at commit
I've done a quick pass and I'm going to see if @sethah has some comments.
@@ -46,6 +47,10 @@ class KMeansModel @Since("2.4.0") (@Since("1.0.0") val clusterCenters: Array[Vec
  private val clusterCentersWithNorm =
    if (clusterCenters == null) null else clusterCenters.map(new VectorWithNorm(_))

  @Since("2.4.0")
  def this(clusterCenters: Array[Vector], distanceMeasure: String) =
    this(clusterCenters: Array[Vector], distanceMeasure, -1)
So we're using -1 to indicate that we don't have the numIter information?
Yes, this can happen, for instance, when reloading a persisted model. Moreover, this applies only to the mllib model, which, as far as I know, users are advised to stop using in favor of the new ml API. Any concerns or suggestions about this?
Sounds reasonable. I personally don't enjoy using -1 to indicate a lack of information, but it seems to be what we have generally used in the past when carrying mllib summary info into ml, so my personal feelings aren't important :)
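To make the convention in this thread concrete, here is a minimal sketch. It is not the Spark class: SimpleModel and numIterOpt are hypothetical names, and wrapping the sentinel in an Option is only one way a caller might guard against reading -1 as a real count.

```scala
// Hypothetical, simplified stand-in for a model that may or may not know
// how many training iterations were run. Not the actual KMeansModel.
class SimpleModel(val centers: Array[Double], val numIter: Int) {
  // Secondary constructor for cases where the iteration count is unknown
  // (for example, a model reloaded from disk): -1 is the sentinel.
  def this(centers: Array[Double]) = this(centers, -1)

  // Illustrative helper (an assumption, not part of the PR):
  // expose the sentinel as an Option so callers cannot misread it.
  def numIterOpt: Option[Int] = if (numIter >= 0) Some(numIter) else None
}
```

With this shape, a reloaded model reports numIterOpt as None, while a freshly trained one carries the real count.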
@@ -97,6 +97,7 @@ class BisectingKMeansSuite
  test("fit, transform and summary") {
    val predictionColName = "bisecting_kmeans_prediction"
    val bkm = new BisectingKMeans().setK(k).setPredictionCol(predictionColName).setSeed(1)
      .setMaxIter(2)
I'd be more comfortable having this in a separate test; 2 iterations is not a lot.
@@ -127,6 +128,7 @@ class BisectingKMeansSuite
    assert(clusterSizes.length === k)
    assert(clusterSizes.sum === numRows)
    assert(clusterSizes.forall(_ >= 0))
    assert(summary.numIter == 2)
It would be nice to see a test where it's not just the maxIter value being copied over.
In KMeansSuite the value is not maxIter (it performs only 1 iteration in that case). In BisectingKMeans, numIter is always maxIter, since we always perform maxIter iterations; see
spark/mllib/src/main/scala/org/apache/spark/mllib/clustering/BisectingKMeans.scala
Line 192 in b6f837c
for (iter <- 0 until maxIterations) {
Does this answer your comment?
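The difference described above can be sketched in isolation. This is a simplified illustration, not the Spark implementation: IterationCount, untilConverged, fixed, and the step callbacks are hypothetical names, and the convergence test (a change below tol) is an assumption made for the example.

```scala
object IterationCount {
  // Runs `step` until the change it reports falls below `tol` or `maxIter`
  // is reached, and returns the number of iterations actually performed.
  def untilConverged(maxIter: Int, tol: Double)(step: Int => Double): Int = {
    var iter = 0
    var converged = false
    while (iter < maxIter && !converged) {
      converged = step(iter) < tol
      iter += 1
    }
    iter
  }

  // Runs `step` exactly `maxIter` times (assuming maxIter >= 0), like a
  // `for (iter <- 0 until maxIter)` loop, so the count always equals maxIter.
  def fixed(maxIter: Int)(step: Int => Unit): Int = {
    for (iter <- 0 until maxIter) step(iter)
    maxIter
  }
}
```

A loop of the first kind can stop after a single iteration, whereas a loop of the second kind can only ever report maxIter, which is why a test on it cannot distinguish a computed count from a copied parameter.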
Test build #88150 has finished for PR 20701 at commit
Test build #88306 has finished for PR 20701 at commit
Thanks for working on this, good progress. I have a few improvements in mind, and maybe we can get @sethah to take a look as well, but if the rest of the ML committers are busy, that's OK too.
@@ -46,6 +47,10 @@ class KMeansModel @Since("2.4.0") (@Since("1.0.0") val clusterCenters: Array[Vec
  private val clusterCentersWithNorm =
    if (clusterCenters == null) null else clusterCenters.map(new VectorWithNorm(_))

  @Since("2.4.0")
I think the correct Since annotation here would be 0.8.0, since this is just a move of the previous constructor, right?
I think this is the right one. 0.8.0 is the annotation for the KMeansModel class, while the previous main constructor was added (by me) in a previous PR for 2.4.0 in order to add the distanceMeasure variable.
Why does this constructor need to be public?
Yes, I will make it private, thanks.
@@ -312,4 +312,5 @@ class BisectingKMeansSummary private[clustering] (
     predictions: DataFrame,
     predictionCol: String,
     featuresCol: String,
-    k: Int) extends ClusteringSummary(predictions, predictionCol, featuresCol, k)
+    k: Int,
+    numIter: Int) extends ClusteringSummary(predictions, predictionCol, featuresCol, k, numIter)
Here (and in the others), we should add this as a param in the comment above, as is done for the other params.
Thanks for pointing this out, I completely missed it. I am adding them.
@@ -34,7 +34,8 @@ class ClusteringSummary private[clustering] (
     @transient val predictions: DataFrame,
     val predictionCol: String,
     val featuresCol: String,
-    val k: Int) extends Serializable {
+    val k: Int,
+    @Since("2.4.0") val numIter: Int) extends Serializable {
Please add this param in the comment above.
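As an illustration of what the two requests above amount to, here is a small sketch of a summary class with the new parameter documented alongside the existing ones, plus a subclass that forwards it. The class names are hypothetical stand-ins, the DataFrame argument is left out so the sketch runs without Spark, and the @param wording is an assumption rather than the text that was merged.

```scala
/**
 * Simplified stand-in for a clustering summary (no DataFrame involved).
 *
 * @param predictionCol name of the column holding the predicted cluster
 * @param featuresCol name of the column holding the features
 * @param k number of clusters
 * @param numIter number of iterations performed during training
 */
class SketchClusteringSummary(
    val predictionCol: String,
    val featuresCol: String,
    val k: Int,
    val numIter: Int) extends Serializable

/** Subclass that simply forwards the new parameter to its parent. */
class SketchBisectingKMeansSummary(
    predictionCol: String,
    featuresCol: String,
    k: Int,
    numIter: Int) extends SketchClusteringSummary(predictionCol, featuresCol, k, numIter)
```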
@@ -36,8 +36,9 @@ import org.apache.spark.sql.{Row, SparkSession}
  * A clustering model for K-means. Each point belongs to the cluster with the closest center.
  */
 @Since("0.8.0")
-class KMeansModel @Since("2.4.0") (@Since("1.0.0") val clusterCenters: Array[Vector],
-  @Since("2.4.0") val distanceMeasure: String)
+class KMeansModel private[spark] (@Since("1.0.0") val clusterCenters: Array[Vector],
Previously the main constructor was not private. Is there any particular reason we are making it private? If someone else is implementing something that extends the KMeans model, this might be a little frustrating.
I just didn't want the user to be able to create a KMeansModel while setting the number of iterations. I moved the other constructor, which is still available. I don't have strong reasons against making this public, so I will remove the private modifier if you think it is best to leave it public.
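The trade-off discussed here can be sketched with a scoped access modifier. This is an illustrative stand-in only: the enclosing object sketch plays the role of the spark package, and Model and train are hypothetical names.

```scala
object sketch {
  // The primary constructor is visible only inside `sketch`, so code outside
  // cannot set numIter directly (and cannot call it when extending Model).
  class Model private[sketch] (val centers: Array[Double], val numIter: Int) {
    // Public constructor: the iteration count is unknown, so -1 is used.
    def this(centers: Array[Double]) = this(centers, -1)
  }

  // Code inside the enclosing scope (for example a trainer) can still
  // record the real number of iterations.
  def train(centers: Array[Double], iterations: Int): Model =
    new Model(centers, iterations)
}
```

Outside callers are limited to the public constructor, which is the restriction the author wanted and also the source of the reviewer's concern about third-party subclasses.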
@@ -36,6 +36,11 @@ object MimaExcludes {

  // Exclude rules for 2.4.x
  lazy val v24excludes = v23excludes ++ Seq(
    // [SPARK-23528] Add numIter to ClusteringSummary
Just a note for other reviewers and myself: these are all private Spark constructors.
Test build #88355 has finished for PR 20701 at commit
retest this please
Test build #88374 has finished for PR 20701 at commit
retest this please
Test build #88381 has finished for PR 20701 at commit
any more comments @holdenk?
General comment: things that are specific to training, like numIter, have been separated into training summary classes elsewhere, e.g. LinearRegressionTrainingSummary extends LinearRegressionSummary. Is there some reason to deviate from that here? numIter doesn't make sense when evaluating on a test set, for instance.
@sethah I have not introduced training summary classes because it would have meant a considerably bigger change, since they follow a quite different approach (a trait and an Impl class for each of them), and I have not seen that pattern used consistently.
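For readers unfamiliar with the pattern being weighed here, a rough sketch of the trait-plus-Impl layout follows. All names are hypothetical and the fields are reduced to two integers; it only shows the shape of the alternative, not how Spark's regression summaries are actually written.

```scala
// General summary: meaningful for any dataset the model is evaluated on.
trait SketchSummary {
  def k: Int
}

// Training-only extension: numIter is defined only for the training run.
trait SketchTrainingSummary extends SketchSummary {
  def numIter: Int
}

// Concrete implementation backing the training summary.
class SketchTrainingSummaryImpl(val k: Int, val numIter: Int)
  extends SketchTrainingSummary
```

Under this layout a summary computed on a test set would implement only the base trait, so numIter would not be exposed where it has no meaning.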
retest this please
Test build #88848 has finished for PR 20701 at commit
ping @sethah: do you think this needs a separate training summary trait?
Test build #89722 has finished for PR 20701 at commit
kindly ping @holdenk
Test build #91256 has finished for PR 20701 at commit
LGTM pending Jenkins retest. Jenkins retest this please.
The AppVeyor build failure looks spurious and I don't know how to retrigger it.
retest this please
Thanks for your review @holdenk. I don't know how to retrigger AppVeyor either, unfortunately :(
Test build #92257 has finished for PR 20701 at commit
Test build #92924 has finished for PR 20701 at commit
Looks good to me. I manually built the docs locally to double-check that the Since annotation was inherited, because my memory was a bit fuzzy on how that was handled.
Merged to master
What changes were proposed in this pull request?
Added the number of iterations to ClusteringSummary. This is helpful information when evaluating how to adjust the parameters in order to get a better model.
How was this patch tested?
Modified existing UTs.