diff --git a/mllib/src/main/scala/org/apache/spark/ml/regression/AFTSurvivalRegression.scala b/mllib/src/main/scala/org/apache/spark/ml/regression/AFTSurvivalRegression.scala index aedfb48058dc5..cc1d19e4a81ff 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/regression/AFTSurvivalRegression.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/regression/AFTSurvivalRegression.scala @@ -496,7 +496,7 @@ private class AFTAggregator(parameters: BDV[Double], fitIntercept: Boolean) * @return This AFTAggregator object. */ def merge(other: AFTAggregator): this.type = { - if (totalCnt != 0) { + if (other.count != 0) { totalCnt += other.totalCnt lossSum += other.lossSum diff --git a/mllib/src/test/scala/org/apache/spark/ml/regression/AFTSurvivalRegressionSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/regression/AFTSurvivalRegressionSuite.scala index d718ef63b531a..70f9693b4e96b 100644 --- a/mllib/src/test/scala/org/apache/spark/ml/regression/AFTSurvivalRegressionSuite.scala +++ b/mllib/src/test/scala/org/apache/spark/ml/regression/AFTSurvivalRegressionSuite.scala @@ -346,6 +346,23 @@ class AFTSurvivalRegressionSuite testEstimatorAndModelReadWrite(aft, datasetMultivariate, AFTSurvivalRegressionSuite.allParamSettings, checkModelData) } + + test("SPARK-15892: Incorrectly merged AFTAggregator with zero total count") { + // This `dataset` will contain an empty partition because it has five rows but + // the parallelism is bigger than that. Because the issue was about `AFTAggregator`s + // being merged incorrectly when it has an empty partition, the trained model has + // 1.0 scale from Euler's number for 0. + val points = sc.parallelize(Seq( + AFTPoint(Vectors.dense(1.560, -0.605), 1.218, 1.0), + AFTPoint(Vectors.dense(0.346, 2.158), 2.949, 0.0), + AFTPoint(Vectors.dense(1.380, 0.231), 3.627, 0.0), + AFTPoint(Vectors.dense(0.520, 1.151), 0.273, 1.0), + AFTPoint(Vectors.dense(0.795, -0.226), 4.199, 0.0)), numSlices = 6) + val dataset = sqlContext.createDataFrame(points) + val trainer = new AFTSurvivalRegression() + val model = trainer.fit(dataset) + assert(model.scale != 1) + } } object AFTSurvivalRegressionSuite {