From fb16b71a96ef55541207b77c9bb9bc49d0a85243 Mon Sep 17 00:00:00 2001 From: hyukjinkwon Date: Sun, 12 Jun 2016 01:40:04 +0900 Subject: [PATCH 1/4] Fix incorrect comparison --- .../org/apache/spark/ml/regression/AFTSurvivalRegression.scala | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/mllib/src/main/scala/org/apache/spark/ml/regression/AFTSurvivalRegression.scala b/mllib/src/main/scala/org/apache/spark/ml/regression/AFTSurvivalRegression.scala index e5f23f44bc5ee..7f57af19e9df9 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/regression/AFTSurvivalRegression.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/regression/AFTSurvivalRegression.scala @@ -538,7 +538,7 @@ private class AFTAggregator( * @return This AFTAggregator object. */ def merge(other: AFTAggregator): this.type = { - if (totalCnt != 0) { + if (other.count != 0) { totalCnt += other.totalCnt lossSum += other.lossSum From 4447d0a969229dfe5d6cf1bdfc3c0ac62c1fd53e Mon Sep 17 00:00:00 2001 From: hyukjinkwon Date: Sun, 12 Jun 2016 13:35:39 +0900 Subject: [PATCH 2/4] Add a test --- .../ml/regression/AFTSurvivalRegressionSuite.scala | 12 ++++++++++++ 1 file changed, 12 insertions(+) diff --git a/mllib/src/test/scala/org/apache/spark/ml/regression/AFTSurvivalRegressionSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/regression/AFTSurvivalRegressionSuite.scala index 05aae80c660ea..4dcc57c2ad3b9 100644 --- a/mllib/src/test/scala/org/apache/spark/ml/regression/AFTSurvivalRegressionSuite.scala +++ b/mllib/src/test/scala/org/apache/spark/ml/regression/AFTSurvivalRegressionSuite.scala @@ -390,6 +390,18 @@ class AFTSurvivalRegressionSuite testEstimatorAndModelReadWrite(aft, datasetMultivariate, AFTSurvivalRegressionSuite.allParamSettings, checkModelData) } + + test("SPARK-15892: Incorrectly merged AFTAggregator with zero total count") { + // This `dataset` will contain a empty partition because it has two rows but + // the parallelism is bigger than that. Because the issue was about `AFTAggregator`s + // being merged incorrectly when it has a empty partition, running the codes below + // should not throw a exception. + val dataset = spark.createDataFrame( + sc.parallelize(generateAFTInput( + 1, Array(5.5), Array(0.8), 1000, 42, 1.0, 2.0, 2.0).take(2), numSlices = 3)) + val trainer = new AFTSurvivalRegression() + trainer.fit(dataset) + } } object AFTSurvivalRegressionSuite { From c86ede8fc913be2713d5f6794741f9fbc51f5169 Mon Sep 17 00:00:00 2001 From: Hyukjin Kwon Date: Sun, 12 Jun 2016 15:14:56 +0900 Subject: [PATCH 3/4] Fix typos --- .../spark/ml/regression/AFTSurvivalRegressionSuite.scala | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/mllib/src/test/scala/org/apache/spark/ml/regression/AFTSurvivalRegressionSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/regression/AFTSurvivalRegressionSuite.scala index 4dcc57c2ad3b9..44e55bbd4ccf3 100644 --- a/mllib/src/test/scala/org/apache/spark/ml/regression/AFTSurvivalRegressionSuite.scala +++ b/mllib/src/test/scala/org/apache/spark/ml/regression/AFTSurvivalRegressionSuite.scala @@ -392,10 +392,10 @@ class AFTSurvivalRegressionSuite } test("SPARK-15892: Incorrectly merged AFTAggregator with zero total count") { - // This `dataset` will contain a empty partition because it has two rows but + // This `dataset` will contain an empty partition because it has two rows but // the parallelism is bigger than that. Because the issue was about `AFTAggregator`s - // being merged incorrectly when it has a empty partition, running the codes below - // should not throw a exception. + // being merged incorrectly when it has an empty partition, running the codes below + // should not throw an exception. val dataset = spark.createDataFrame( sc.parallelize(generateAFTInput( 1, Array(5.5), Array(0.8), 1000, 42, 1.0, 2.0, 2.0).take(2), numSlices = 3)) From 2c8adbfe53493b4bcef1a633ab429f6612c67fe5 Mon Sep 17 00:00:00 2001 From: hyukjinkwon Date: Sun, 12 Jun 2016 15:43:11 +0900 Subject: [PATCH 4/4] Do not make 1000 AFTInput but just 2 --- .../apache/spark/ml/regression/AFTSurvivalRegressionSuite.scala | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/mllib/src/test/scala/org/apache/spark/ml/regression/AFTSurvivalRegressionSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/regression/AFTSurvivalRegressionSuite.scala index 44e55bbd4ccf3..1c70b702de063 100644 --- a/mllib/src/test/scala/org/apache/spark/ml/regression/AFTSurvivalRegressionSuite.scala +++ b/mllib/src/test/scala/org/apache/spark/ml/regression/AFTSurvivalRegressionSuite.scala @@ -398,7 +398,7 @@ class AFTSurvivalRegressionSuite // should not throw an exception. val dataset = spark.createDataFrame( sc.parallelize(generateAFTInput( - 1, Array(5.5), Array(0.8), 1000, 42, 1.0, 2.0, 2.0).take(2), numSlices = 3)) + 1, Array(5.5), Array(0.8), 2, 42, 1.0, 2.0, 2.0), numSlices = 3)) val trainer = new AFTSurvivalRegression() trainer.fit(dataset) }