diff --git a/mllib/src/main/scala/org/apache/spark/ml/feature/QuantileDiscretizer.scala b/mllib/src/main/scala/org/apache/spark/ml/feature/QuantileDiscretizer.scala index 769f4406e2dff..cfbc902e55898 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/feature/QuantileDiscretizer.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/feature/QuantileDiscretizer.scala @@ -118,7 +118,7 @@ object QuantileDiscretizer extends DefaultParamsReadable[QuantileDiscretizer] wi require(totalSamples > 0, "QuantileDiscretizer requires non-empty input dataset but was given an empty input.") val requiredSamples = math.max(numBins * numBins, minSamplesRequired) - val fraction = math.min(requiredSamples.toDouble / dataset.count(), 1.0) + val fraction = math.min(requiredSamples.toDouble / totalSamples, 1.0) dataset.sample(withReplacement = false, fraction, new XORShiftRandom(seed).nextInt()).collect() }