diff --git a/mllib/src/main/scala/org/apache/spark/ml/regression/AFTSurvivalRegression.scala b/mllib/src/main/scala/org/apache/spark/ml/regression/AFTSurvivalRegression.scala
index 2f78dd30b3af7..9ed400bd9eafe 100644
--- a/mllib/src/main/scala/org/apache/spark/ml/regression/AFTSurvivalRegression.scala
+++ b/mllib/src/main/scala/org/apache/spark/ml/regression/AFTSurvivalRegression.scala
@@ -644,4 +644,5 @@ private class AFTCostFun(
  */
 private[regression] case class AFTPoint(features: Vector, label: Double, censor: Double) {
   require(censor == 1.0 || censor == 0.0, "censor of class AFTPoint must be 1.0 or 0.0")
+  require(label > 0.0, "label of AFTPoint must be positive")
 }
diff --git a/mllib/src/test/scala/org/apache/spark/ml/regression/AFTSurvivalRegressionSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/regression/AFTSurvivalRegressionSuite.scala
index 0fdfdf37cf38d..4eb4fc50ef3c4 100644
--- a/mllib/src/test/scala/org/apache/spark/ml/regression/AFTSurvivalRegressionSuite.scala
+++ b/mllib/src/test/scala/org/apache/spark/ml/regression/AFTSurvivalRegressionSuite.scala
@@ -19,14 +19,15 @@ package org.apache.spark.ml.regression
 
 import scala.util.Random
 
-import org.apache.spark.SparkFunSuite
+import org.apache.spark.{SparkException, SparkFunSuite}
+import org.apache.spark.ml.{Estimator, Model}
 import org.apache.spark.ml.linalg.{Vector, Vectors}
 import org.apache.spark.ml.param.ParamsSuite
 import org.apache.spark.ml.util.{DefaultReadWriteTest, MLTestingUtils}
 import org.apache.spark.ml.util.TestingUtils._
 import org.apache.spark.mllib.random.{ExponentialGenerator, WeibullGenerator}
 import org.apache.spark.mllib.util.MLlibTestSparkContext
-import org.apache.spark.sql.{DataFrame, Row}
+import org.apache.spark.sql.{DataFrame, Row, SparkSession}
 
 class AFTSurvivalRegressionSuite
   extends SparkFunSuite with MLlibTestSparkContext with DefaultReadWriteTest {
@@ -400,6 +401,19 @@ class AFTSurvivalRegressionSuite
     val trainer = new AFTSurvivalRegression()
     trainer.fit(dataset)
   }
+
+  test("SPARK-19234: Fail fast on zero-valued labels") {
+    val dataset = spark.createDataFrame(Seq(
+      (1.218, 1.0, Vectors.dense(1.560, -0.605)),
+      (0.000, 0.0, Vectors.dense(0.346, 2.158)), // zero-valued label is invalid and must trigger the failure
+      (4.199, 0.0, Vectors.dense(0.795, -0.226)))).toDF("label", "censor", "features")
+    val aft = new AFTSurvivalRegression()
+    withClue("label of AFTPoint must be positive") {
+      intercept[SparkException] {
+        aft.fit(dataset)
+      }
+    }
+  }
 }
 
 object AFTSurvivalRegressionSuite {
diff --git a/mllib/src/test/scala/org/apache/spark/ml/util/MLTestingUtils.scala b/mllib/src/test/scala/org/apache/spark/ml/util/MLTestingUtils.scala
index f1ed568d5e60a..9795550f64ce2 100644
--- a/mllib/src/test/scala/org/apache/spark/ml/util/MLTestingUtils.scala
+++ b/mllib/src/test/scala/org/apache/spark/ml/util/MLTestingUtils.scala
@@ -64,7 +64,7 @@ object MLTestingUtils extends SparkFunSuite {
     actuals.foreach(actual => check(expected, actual))
 
     val dfWithStringLabels = spark.createDataFrame(Seq(
-      ("0", 1, Vectors.dense(0, 2, 3), 0.0)
+      ("1", 1, Vectors.dense(0, 2, 3), 0.0)
    )).toDF("label", "weight", "features", "censor")
     val thrown = intercept[IllegalArgumentException] {
       estimator.fit(dfWithStringLabels)
@@ -156,7 +156,6 @@ object MLTestingUtils extends SparkFunSuite {
       featuresColName: String = "features",
       censorColName: String = "censor"): Map[NumericType, DataFrame] = {
     val df = spark.createDataFrame(Seq(
-      (0, Vectors.dense(0)),
       (1, Vectors.dense(1)),
       (2, Vectors.dense(2)),
       (3, Vectors.dense(3)),