diff --git a/mllib/src/main/scala/org/apache/spark/mllib/tree/DecisionTree.scala b/mllib/src/main/scala/org/apache/spark/mllib/tree/DecisionTree.scala index a2a3dba213e7f..7749bdd687d1f 100644 --- a/mllib/src/main/scala/org/apache/spark/mllib/tree/DecisionTree.scala +++ b/mllib/src/main/scala/org/apache/spark/mllib/tree/DecisionTree.scala @@ -135,6 +135,11 @@ object DecisionTree extends Serializable { */ def isSampleValid(parentFilters: List[Filter], labeledPoint: LabeledPoint): Boolean = { + //Leaf + if (parentFilters.length == 0 ){ + return false + } + for (filter <- parentFilters) { val features = labeledPoint.features val featureIndex = filter.split.feature diff --git a/mllib/src/main/scala/org/apache/spark/mllib/tree/model/InformationGainStats.scala b/mllib/src/main/scala/org/apache/spark/mllib/tree/model/InformationGainStats.scala new file mode 100644 index 0000000000000..4ca02beec03c0 --- /dev/null +++ b/mllib/src/main/scala/org/apache/spark/mllib/tree/model/InformationGainStats.scala @@ -0,0 +1,32 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.apache.spark.mllib.tree.model + +class InformationGainStats(val gain : Double, + val impurity: Double, + val leftImpurity : Double, + val leftSamples : Long, + val rightImpurity : Double, + val rightSamples : Long) { + + override def toString = + "gain = " + gain + ", impurity = " + impurity + ", left impurity = " + + leftImpurity + ", leftSamples = " + leftSamples + ", right impurity = " + + rightImpurity + ", rightSamples = " + rightSamples + + +}