diff --git a/mllib/src/main/scala/org/apache/spark/mllib/tree/DecisionTree.scala b/mllib/src/main/scala/org/apache/spark/mllib/tree/DecisionTree.scala index 1184e985a1faf..5f04b369a6a60 100644 --- a/mllib/src/main/scala/org/apache/spark/mllib/tree/DecisionTree.scala +++ b/mllib/src/main/scala/org/apache/spark/mllib/tree/DecisionTree.scala @@ -816,7 +816,15 @@ object DecisionTree extends Serializable with Logging { val maxBins = strategy.maxBins val numBins = if (maxBins <= count) maxBins else count.toInt - logDebug("maxBins = " + numBins) + logDebug("numBins = " + numBins) + + // I will also add a require statement ensuring #bins is always greater than the categories + // It's a limitation of the current implementation but a reasonable tradeoff since features + // with large number of categories get favored over continuous features. + if (strategy.categoricalFeaturesInfo.size > 0){ + val maxCategoriesForFeatures = strategy.categoricalFeaturesInfo.maxBy(_._2)._2 + require(numBins >= maxCategoriesForFeatures) + } // Calculate the number of sample for approximate quantile calculation val requiredSamples = numBins*numBins