Skip to content

Commit

Permalink
[SPARK-15892][ML] Backport correctly merging AFTAggregators to branch…
Browse files Browse the repository at this point in the history
… 1.6

## What changes were proposed in this pull request?

This PR backports apache#13619.

The original test added in branch-2.0 was failed in branch-1.6.

This seems because the behaviour was changed in apache@101663f. This was failure while calculating Euler's number which ends up with a infinity regardless of this path.

So, I brought the dataset from `AFTSurvivalRegressionExample` to make sure this is working and then wrote the test.

I ran the test before/after creating empty partitions. `model.scale` becomes `1.0` with empty partitions and becames `1.547` without them.

After this patch, this becomes always `1.547`.

## How was this patch tested?

Unit test in `AFTSurvivalRegressionSuite`.

Author: hyukjinkwon <[email protected]>

Closes apache#13725 from HyukjinKwon/SPARK-15892-1-6.
  • Loading branch information
HyukjinKwon authored and mengxr committed Jun 18, 2016
1 parent e530823 commit fd05389
Show file tree
Hide file tree
Showing 2 changed files with 18 additions and 1 deletion.
Original file line number Diff line number Diff line change
Expand Up @@ -496,7 +496,7 @@ private class AFTAggregator(parameters: BDV[Double], fitIntercept: Boolean)
* @return This AFTAggregator object.
*/
def merge(other: AFTAggregator): this.type = {
if (totalCnt != 0) {
if (other.count != 0) {
totalCnt += other.totalCnt
lossSum += other.lossSum

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -346,6 +346,23 @@ class AFTSurvivalRegressionSuite
testEstimatorAndModelReadWrite(aft, datasetMultivariate,
AFTSurvivalRegressionSuite.allParamSettings, checkModelData)
}

test("SPARK-15892: Incorrectly merged AFTAggregator with zero total count") {
// This `dataset` will contain an empty partition because it has five rows but
// the parallelism is bigger than that. Because the issue was about `AFTAggregator`s
// being merged incorrectly when it has an empty partition, the trained model has
// 1.0 scale from Euler's number for 0.
val points = sc.parallelize(Seq(
AFTPoint(Vectors.dense(1.560, -0.605), 1.218, 1.0),
AFTPoint(Vectors.dense(0.346, 2.158), 2.949, 0.0),
AFTPoint(Vectors.dense(1.380, 0.231), 3.627, 0.0),
AFTPoint(Vectors.dense(0.520, 1.151), 0.273, 1.0),
AFTPoint(Vectors.dense(0.795, -0.226), 4.199, 0.0)), numSlices = 6)
val dataset = sqlContext.createDataFrame(points)
val trainer = new AFTSurvivalRegression()
val model = trainer.fit(dataset)
assert(model.scale != 1)
}
}

object AFTSurvivalRegressionSuite {
Expand Down

0 comments on commit fd05389

Please sign in to comment.