Skip to content

Commit

Permalink
update since versions in mllib.tree
Browse files Browse the repository at this point in the history
  • Loading branch information
mengxr committed Aug 26, 2015
1 parent 125205c commit a4781d5
Show file tree
Hide file tree
Showing 12 changed files with 57 additions and 44 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -46,7 +46,8 @@ import org.apache.spark.util.random.XORShiftRandom
*/
@Since("1.0.0")
@Experimental
class DecisionTree (private val strategy: Strategy) extends Serializable with Logging {
class DecisionTree @Since("1.0.0") (private val strategy: Strategy)
extends Serializable with Logging {

strategy.assertValid()

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -51,7 +51,7 @@ import org.apache.spark.storage.StorageLevel
*/
@Since("1.2.0")
@Experimental
class GradientBoostedTrees(private val boostingStrategy: BoostingStrategy)
class GradientBoostedTrees @Since("1.2.0") (private val boostingStrategy: BoostingStrategy)
extends Serializable with Logging {

/**
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -26,7 +26,9 @@ import org.apache.spark.annotation.{Experimental, Since}
@Since("1.0.0")
@Experimental
object Algo extends Enumeration {
@Since("1.0.0")
type Algo = Value
@Since("1.0.0")
val Classification, Regression = Value

private[mllib] def fromString(name: String): Algo = name match {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -41,14 +41,14 @@ import org.apache.spark.mllib.tree.loss.{LogLoss, SquaredError, Loss}
*/
@Since("1.2.0")
@Experimental
case class BoostingStrategy(
case class BoostingStrategy @Since("1.4.0") (
// Required boosting parameters
@BeanProperty var treeStrategy: Strategy,
@BeanProperty var loss: Loss,
@Since("1.2.0") @BeanProperty var treeStrategy: Strategy,
@Since("1.2.0") @BeanProperty var loss: Loss,
// Optional boosting parameters
@BeanProperty var numIterations: Int = 100,
@BeanProperty var learningRate: Double = 0.1,
@BeanProperty var validationTol: Double = 1e-5) extends Serializable {
@Since("1.2.0") @BeanProperty var numIterations: Int = 100,
@Since("1.2.0") @BeanProperty var learningRate: Double = 0.1,
@Since("1.4.0") @BeanProperty var validationTol: Double = 1e-5) extends Serializable {

/**
* Check validity of parameters.
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -26,6 +26,8 @@ import org.apache.spark.annotation.{Experimental, Since}
@Since("1.0.0")
@Experimental
object FeatureType extends Enumeration {
@Since("1.0.0")
type FeatureType = Value
@Since("1.0.0")
val Continuous, Categorical = Value
}
Original file line number Diff line number Diff line change
Expand Up @@ -26,6 +26,8 @@ import org.apache.spark.annotation.{Experimental, Since}
@Since("1.0.0")
@Experimental
object QuantileStrategy extends Enumeration {
@Since("1.0.0")
type QuantileStrategy = Value
@Since("1.0.0")
val Sort, MinMax, ApproxHist = Value
}
Original file line number Diff line number Diff line change
Expand Up @@ -69,20 +69,20 @@ import org.apache.spark.mllib.tree.configuration.QuantileStrategy._
*/
@Since("1.0.0")
@Experimental
class Strategy (
@BeanProperty var algo: Algo,
@BeanProperty var impurity: Impurity,
@BeanProperty var maxDepth: Int,
@BeanProperty var numClasses: Int = 2,
@BeanProperty var maxBins: Int = 32,
@BeanProperty var quantileCalculationStrategy: QuantileStrategy = Sort,
@BeanProperty var categoricalFeaturesInfo: Map[Int, Int] = Map[Int, Int](),
@BeanProperty var minInstancesPerNode: Int = 1,
@BeanProperty var minInfoGain: Double = 0.0,
@BeanProperty var maxMemoryInMB: Int = 256,
@BeanProperty var subsamplingRate: Double = 1,
@BeanProperty var useNodeIdCache: Boolean = false,
@BeanProperty var checkpointInterval: Int = 10) extends Serializable {
class Strategy @Since("1.3.0") (
@Since("1.0.0") @BeanProperty var algo: Algo,
@Since("1.0.0") @BeanProperty var impurity: Impurity,
@Since("1.0.0") @BeanProperty var maxDepth: Int,
@Since("1.2.0") @BeanProperty var numClasses: Int = 2,
@Since("1.0.0") @BeanProperty var maxBins: Int = 32,
@Since("1.0.0") @BeanProperty var quantileCalculationStrategy: QuantileStrategy = Sort,
@Since("1.0.0") @BeanProperty var categoricalFeaturesInfo: Map[Int, Int] = Map[Int, Int](),
@Since("1.2.0") @BeanProperty var minInstancesPerNode: Int = 1,
@Since("1.2.0") @BeanProperty var minInfoGain: Double = 0.0,
@Since("1.0.0") @BeanProperty var maxMemoryInMB: Int = 256,
@Since("1.2.0") @BeanProperty var subsamplingRate: Double = 1,
@Since("1.2.0") @BeanProperty var useNodeIdCache: Boolean = false,
@Since("1.2.0") @BeanProperty var checkpointInterval: Int = 10) extends Serializable {

/**
*/
Expand Down Expand Up @@ -206,6 +206,7 @@ object Strategy {
}

@deprecated("Use Strategy.defaultStrategy instead.", "1.5.0")
@Since("")
def defaultStategy(algo: Algo): Strategy = defaultStrategy(algo)

}
Original file line number Diff line number Diff line change
Expand Up @@ -43,7 +43,9 @@ import org.apache.spark.util.Utils
*/
@Since("1.0.0")
@Experimental
class DecisionTreeModel(val topNode: Node, val algo: Algo) extends Serializable with Saveable {
class DecisionTreeModel @Since("1.0.0") (
@Since("1.0.0") val topNode: Node,
@Since("1.0.0") val algo: Algo) extends Serializable with Saveable {

/**
* Predict values for a single data point using the model trained.
Expand Down Expand Up @@ -110,6 +112,7 @@ class DecisionTreeModel(val topNode: Node, val algo: Algo) extends Serializable
/**
* Print the full model to a string.
*/
@Since("1.2.0")
def toDebugString: String = {
val header = toString + "\n"
header + topNode.subtreeToString(2)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -41,15 +41,15 @@ import org.apache.spark.mllib.linalg.Vector
*/
@Since("1.0.0")
@DeveloperApi
class Node (
val id: Int,
var predict: Predict,
var impurity: Double,
var isLeaf: Boolean,
var split: Option[Split],
var leftNode: Option[Node],
var rightNode: Option[Node],
var stats: Option[InformationGainStats]) extends Serializable with Logging {
class Node @Since("1.0.0") (
@Since("1.0.0") val id: Int,
@Since("1.0.0") var predict: Predict,
@Since("1.0.0") var impurity: Double,
@Since("1.0.0") var isLeaf: Boolean,
@Since("1.0.0") var split: Option[Split],
@Since("1.0.0") var leftNode: Option[Node],
@Since("1.0.0") var rightNode: Option[Node],
@Since("1.0.0") var stats: Option[InformationGainStats]) extends Serializable with Logging {

override def toString: String = {
s"id = $id, isLeaf = $isLeaf, predict = $predict, impurity = $impurity, " +
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -26,9 +26,9 @@ import org.apache.spark.annotation.{DeveloperApi, Since}
*/
@Since("1.2.0")
@DeveloperApi
class Predict(
val predict: Double,
val prob: Double = 0.0) extends Serializable {
class Predict @Since("1.2.0") (
@Since("1.2.0") val predict: Double,
@Since("1.2.0") val prob: Double = 0.0) extends Serializable {

override def toString: String = s"$predict (prob = $prob)"

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -34,10 +34,10 @@ import org.apache.spark.mllib.tree.configuration.FeatureType.FeatureType
@Since("1.0.0")
@DeveloperApi
case class Split(
feature: Int,
threshold: Double,
featureType: FeatureType,
categories: List[Double]) {
@Since("1.0.0") feature: Int,
@Since("1.0.0") threshold: Double,
@Since("1.0.0") featureType: FeatureType,
@Since("1.0.0") categories: List[Double]) {

override def toString: String = {
s"Feature = $feature, threshold = $threshold, featureType = $featureType, " +
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -48,7 +48,9 @@ import org.apache.spark.util.Utils
*/
@Since("1.2.0")
@Experimental
class RandomForestModel(override val algo: Algo, override val trees: Array[DecisionTreeModel])
class RandomForestModel @Since("1.2.0") (
@Since("1.2.0") override val algo: Algo,
@Since("1.2.0") override val trees: Array[DecisionTreeModel])
extends TreeEnsembleModel(algo, trees, Array.fill(trees.length)(1.0),
combiningStrategy = if (algo == Classification) Vote else Average)
with Saveable {
Expand Down Expand Up @@ -115,10 +117,10 @@ object RandomForestModel extends Loader[RandomForestModel] {
*/
@Since("1.2.0")
@Experimental
class GradientBoostedTreesModel(
override val algo: Algo,
override val trees: Array[DecisionTreeModel],
override val treeWeights: Array[Double])
class GradientBoostedTreesModel @Since("1.2.0") (
@Since("1.2.0") override val algo: Algo,
@Since("1.2.0") override val trees: Array[DecisionTreeModel],
@Since("1.2.0") override val treeWeights: Array[Double])
extends TreeEnsembleModel(algo, trees, treeWeights, combiningStrategy = Sum)
with Saveable {

Expand Down

0 comments on commit a4781d5

Please sign in to comment.