
[SPARK-5972] [MLlib] Cache residuals and gradient in GBT during training and validation #5330

Closed
wants to merge 7 commits

Conversation

MechCoder
Contributor

The previous PR #4906 made it possible to extract the learning curve, i.e. the error at each iteration. This PR continues that work by refactoring some of the code and extending the same caching of per-point predictions and errors to training and validation.
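For readers following along, here is a minimal sketch of the caching idea, assuming the helper methods computeInitialPredictionAndError and updatePredictionError on the GradientBoostedTreesModel companion object (the former is quoted later in this thread; the exact signature of the latter is an assumption, not the final API):

import org.apache.spark.mllib.regression.LabeledPoint
import org.apache.spark.mllib.tree.loss.Loss
import org.apache.spark.mllib.tree.model.{DecisionTreeModel, GradientBoostedTreesModel}
import org.apache.spark.rdd.RDD

// Sketch: keep an RDD of (prediction, error) pairs and fold each new tree into it,
// instead of re-running every tree over the data at every boosting iteration.
def cachedTrainingError(
    input: RDD[LabeledPoint],
    baseLearners: Array[DecisionTreeModel],
    baseLearnerWeights: Array[Double],
    loss: Loss): Double = {
  var predError: RDD[(Double, Double)] =
    GradientBoostedTreesModel.computeInitialPredictionAndError(
      input, baseLearnerWeights(0), baseLearners(0), loss)
  var m = 1
  while (m < baseLearners.length) {
    // Update the cached (prediction, error) pairs with the m-th tree only.
    predError = GradientBoostedTreesModel.updatePredictionError(
      input, predError, baseLearnerWeights(m), baseLearners(m), loss)
    m += 1
  }
  predError.values.mean()  // mean error after the final iteration
}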

@MechCoder
Contributor Author

ping @jkbradley

@SparkQA

SparkQA commented Apr 2, 2015

Test build #29605 has finished for PR 5330 at commit 100850a.

  • This patch fails PySpark unit tests.
  • This patch merges cleanly.
  • This patch adds no public classes.
  • This patch does not change any dependencies.

@SparkQA

SparkQA commented Apr 2, 2015

Test build #29609 has finished for PR 5330 at commit 57cd906.

  • This patch fails Spark unit tests.
  • This patch merges cleanly.
  • This patch adds no public classes.
  • This patch does not change any dependencies.

@MechCoder
Contributor Author

This unrelated YARN test failure keeps recurring for me.

@SparkQA

SparkQA commented Apr 2, 2015

Test build #29612 has finished for PR 5330 at commit 923dbf6.

  • This patch passes all tests.
  • This patch merges cleanly.
  • This patch adds no public classes.
  • This patch does not change any dependencies.

@jkbradley
Member

@MechCoder Thanks! I'll make a pass through this soon; at first glance, it looks good. Have you tested this vs. the old implementation? I'm wondering how big a difference there is, and also how big the problem has to be for that difference to be evident.

@MechCoder
Contributor Author

As I mentioned before, I do not have access to a cluster, so it would be great if you had some old benchmarks. It seems it should not matter much, at least for a small number of iterations, but I suppose it is good to have anyway to avoid unnecessary recomputation (trivial or not), just as before.
Thanks :)

logDebug("error of gbt = " + loss.computeError(startingModel, input))

var predError: RDD[(Double, Double)] = GradientBoostedTreesModel.
computeInitialPredictionAndError(input, 1.0, firstTreeModel, loss)
Member

Use "baseLearnerWeights(0)" instead of "1.0"

@jkbradley
Member

@MechCoder For this, I feel like local tests might be sufficient, since they should show the speedup and since this isn't changing the communication that much. My main worry is about the RDDs ending up with long lineages; I filed a JIRA about that today, but it can be addressed later on: https://issues.apache.org/jira/browse/SPARK-6684
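(For context, a common way to keep such iteratively rebuilt lineages short is periodic checkpointing. The sketch below is illustrative only and not part of this PR; the helper name and interval are made up.)

import org.apache.spark.rdd.RDD
import org.apache.spark.storage.StorageLevel

// Illustrative helper: truncate the lineage of an iteratively rebuilt RDD
// every `checkpointInterval` iterations so the DAG does not grow without bound.
// Requires sc.setCheckpointDir(...) to have been called beforehand.
def maybeCheckpoint[T](rdd: RDD[T], iteration: Int, checkpointInterval: Int): RDD[T] = {
  rdd.persist(StorageLevel.MEMORY_AND_DISK)
  if (checkpointInterval > 0 && iteration % checkpointInterval == 0) {
    rdd.checkpoint()
  }
  rdd
}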

@MechCoder
Contributor Author

I've fixed up your comments.

The local tests seem to run in about the same time (± 1 s); this may be because numIterations and the data size are too small to benefit from the caching.

I can work on the other issue after this is merged.

@SparkQA

SparkQA commented Apr 3, 2015

Test build #29660 has finished for PR 5330 at commit c0869e7.

  • This patch fails Spark unit tests.
  • This patch merges cleanly.
  • This patch adds no public classes.
  • This patch does not change any dependencies.

@jkbradley
Member

After that one small change, I think this will be ready to merge. Thanks!

@MechCoder
Contributor Author

@jkbradley I've fixed up your comment!

Thanks for the info. Just to clarify: does the previous code work because a copy of the broadcast variable persists on the driver node even after unpersisting, so it gets re-broadcast for each action on predError? (Source: http://stackoverflow.com/a/24587558/1170730)
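(As a small illustration of the behavior in question, assuming an existing SparkContext sc and an RDD[Double] data; this is not code from this PR. unpersist() only drops executor-side copies, so later actions can re-ship the value from the driver, whereas destroy() makes the broadcast unusable.)

val weights = Array(0.1, 0.2, 0.3)
val bcWeights = sc.broadcast(weights)

data.map(x => x * bcWeights.value.sum).count()  // ships the broadcast to executors
bcWeights.unpersist()                           // removes executor copies only
data.map(x => x * bcWeights.value.sum).count()  // still works: re-broadcast from the driver copy
// bcWeights.destroy()                          // after this, reading bcWeights.value fails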

@SparkQA

SparkQA commented Apr 11, 2015

Test build #30083 has finished for PR 5330 at commit 58f4932.

  • This patch passes all tests.
  • This patch merges cleanly.
  • This patch adds no public classes.
  • This patch does not change any dependencies.

@@ -27,6 +27,7 @@ import org.json4s.jackson.JsonMethods._
import org.apache.spark.{Logging, SparkContext}
import org.apache.spark.annotation.Experimental
import org.apache.spark.api.java.JavaRDD
import org.apache.spark.broadcast.Broadcast
Member

no longer needed

@jkbradley
Member

Your explanation for why the previous code worked is correct. It's just doing extra communication.

I just noticed a few things when opening up this PR in IntelliJ (yay error highlighting). That really should be it, though.

I'm running a speed test to see if I can tell a difference between this and the previous code. I'll post again later today.

@MechCoder
Contributor Author

@jkbradley fixed! hopefully that should be it :P

@jkbradley
Member

That is it, but I ran some timing tests locally and found essentially no difference between the two implementations, as you reported. I think the issue is the overhead of broadcast variables. I tried broadcasting the full arrays for evaluateEachIteration(), rather than each element separately, and it made evaluateEachIteration() take about 2/3 of the original time. This was with depth-2 trees and 100 iterations of regression on a tiny test dataset ("abalone" from LIBSVM's copy of the UCI dataset).

Based on this bit of testing, I would guess the best solution will be to handle learning and evaluateEachIteration separately:

  • Learning: Do not broadcast trees or weights (but do use the caching you implemented here).
    • Communicating 1 tree should be a tiny cost compared to the cost of learning the tree.
  • evaluateEachIteration: Broadcast full tree array. Compute errors for all iterations in a single map() call, and aggregate arrays of errors rather than individual errors.
    • Don't broadcast the weights array since it is small.

I'm OK with merging this PR for now and making those items a future to-do. But if you'd prefer to make these updates to this PR, that works too.

Do you agree with this assessment? What path would you prefer?
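(A rough sketch of the evaluateEachIteration strategy described above: broadcast the full tree array once, compute per-iteration errors in a single pass over the data, and aggregate arrays of errors. The names and signatures here are illustrative, not the final implementation; the per-point loss.computeError(prediction, label) is used exactly as in the diff quoted earlier.)

import org.apache.spark.SparkContext
import org.apache.spark.mllib.regression.LabeledPoint
import org.apache.spark.mllib.tree.loss.Loss
import org.apache.spark.mllib.tree.model.DecisionTreeModel
import org.apache.spark.rdd.RDD

def evaluateEachIterationSketch(
    sc: SparkContext,
    data: RDD[LabeledPoint],
    trees: Array[DecisionTreeModel],
    treeWeights: Array[Double],
    loss: Loss): Array[Double] = {
  val numIterations = trees.length
  val numExamples = data.count()
  val broadcastTrees = sc.broadcast(trees)  // one broadcast of the whole tree array
  val localTreeWeights = treeWeights        // small array, so not worth broadcasting

  // One pass over the data: for each point, accumulate the prediction tree by tree
  // and record the error after each iteration, then sum these arrays element-wise.
  val summedErrors = data.map { point =>
    val localTrees = broadcastTrees.value
    var pred = 0.0
    Array.tabulate(numIterations) { i =>
      pred += localTrees(i).predict(point.features) * localTreeWeights(i)
      loss.computeError(pred, point.label)
    }
  }.reduce { (a, b) => a.zip(b).map { case (x, y) => x + y } }

  broadcastTrees.unpersist()
  summedErrors.map(_ / numExamples)  // mean error after each iteration
}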

@MechCoder
Contributor Author

Thanks for the tests. I get the gist of what you mean.

I'd be happy to have this PR merged and to work on the rest as a future JIRA. If I have any questions, I'll comment on that JIRA.

@MechCoder
Contributor Author

Never mind, I figured it out; it should not take much effort, so I'm working on an update to this PR itself.

@SparkQA

SparkQA commented Apr 13, 2015

Test build #30137 has finished for PR 5330 at commit d542bb0.

  • This patch passes all tests.
  • This patch merges cleanly.
  • This patch adds no public classes.
  • This patch does not change any dependencies.

@MechCoder
Contributor Author

@jkbradley I pushed an update to the same PR. I agree with the observation that this has a much higher impact on evaluateEachIteration, because during training, prediction (and computing the residuals) is not really the bottleneck.

@SparkQA

SparkQA commented Apr 13, 2015

Test build #30142 has finished for PR 5330 at commit 32d409d.

  • This patch passes all tests.
  • This patch merges cleanly.
  • This patch adds no public classes.
  • This patch does not change any dependencies.

val newError = loss.computeError(newPred, point.label)
(newPred, newError)
}
val currentTreeWeight = treeWeights(nTree)
Member

We should make a local (shallow) copy of treeWeights before the map, within this method:

val localTreeWeights = treeWeights

Referencing treeWeights, a member of the class, will actually make the entire class get serialized by the ClosureCleaner. Assigning it to a local val fixes that.
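(A tiny illustration of that pattern, with made-up class and method names: referencing a field inside the closure captures `this`, while copying it to a local val means the closure captures only the array.)

import org.apache.spark.rdd.RDD

class WeightedPredictor(treeWeights: Array[Double]) {
  // data.map(_ * treeWeights(0)) would capture `this`, forcing the whole
  // (possibly non-serializable) instance to be shipped with the task.
  def scaleByFirstWeight(data: RDD[Double]): RDD[Double] = {
    val localTreeWeights = treeWeights  // shallow local copy
    data.map(_ * localTreeWeights(0))   // closure now captures only the array
  }
}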

@jkbradley
Member

@MechCoder Thanks a lot for working through all of these tweaks with me! The updates look good except for those two items.

@MechCoder
Contributor Author

@jkbradley fixed!

@jkbradley
Member

LGTM once tests pass. Thanks!

@SparkQA

SparkQA commented Apr 13, 2015

Test build #30185 has finished for PR 5330 at commit 0b5d659.

  • This patch passes all tests.
  • This patch merges cleanly.
  • This patch adds no public classes.
  • This patch does not change any dependencies.

asfgit closed this in 2a55cb4 on Apr 13, 2015
@jkbradley
Member

Merged into master

MechCoder deleted the spark-5972 branch on April 14, 2015 at 02:35