-
Notifications
You must be signed in to change notification settings - Fork 0
/
LigthGbmUsage.scala
140 lines (114 loc) · 5.35 KB
/
LigthGbmUsage.scala
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
package spark.playground
import com.microsoft.ml.spark.core.metrics.MetricConstants
import com.microsoft.ml.spark.lightgbm.{LightGBMClassificationModel, LightGBMClassifier}
import com.microsoft.ml.spark.train.ComputeModelStatistics
import org.apache.spark.ml.Pipeline
import org.apache.spark.ml.evaluation.BinaryClassificationEvaluator
import org.apache.spark.mllib.evaluation.{BinaryClassificationMetrics, MulticlassMetrics}
import org.apache.spark.rdd.RDD
object LigthGbmUsage extends LocalSparkContext {
  def main(args: Array[String]): Unit = {
    // Load the pre-featurized dataset; assumes it carries a "label" column and a
    // "scaledFeatures" vector column — TODO confirm against the upstream job.
    val dataset = spark.sqlContext.read.parquet("C:/Users/User/Downloads/part-00000-711560fe-8fdd-4777-a379-b52996fd212d-c000.gz.parquet")
    dataset.show()

    // No feature names are supplied, so the importance printout below shows blank names.
    val featureColumns = Array[String]()
    val seed: Long = 41L
    val labelColumn = "label"
    val predictedColumn = "Predicted" + labelColumn
    val featureColumn = "scaledFeatures"

    // LightGBM binary classifier reading the scaled feature vector.
    val classifier = new LightGBMClassifier()
      .setBaggingSeed(seed.toInt)
      .setLabelCol(labelColumn)
      .setFeaturesCol(featureColumn)
      .setPredictionCol(predictedColumn)

    // 80/20 split; cache both sides and print the class balance of each.
    val Array(trainSet, testSet) = dataset.randomSplit(Array(0.8, 0.2), seed = seed)
    trainSet.cache()
    trainSet.groupBy("label").count.show
    testSet.cache()
    testSet.groupBy("label").count.show

    // Fit the single-stage pipeline on the training split, then score the held-out split.
    val pipeline = new Pipeline().setStages(Array(classifier))
    val model = pipeline.fit(trainSet)
    val predictions = model.transform(testSet)

    // Area under ROC on the test predictions (uses the evaluator's default raw-prediction column).
    val evaluator = new BinaryClassificationEvaluator()
      .setLabelCol(labelColumn)
      .setMetricName("areaUnderROC")
    println(evaluator.evaluate(predictions))

    // Full classification-metric table via MMLSpark's ComputeModelStatistics.
    val evaluatedData = new ComputeModelStatistics()
      .setLabelCol("label")
      .setScoredLabelsCol(predictedColumn)
      .setEvaluationMetric(MetricConstants.ClassificationMetricsName)
      .transform(predictions)
    evaluatedData.show(false)

    // Pull the fitted LightGBM stage out of the pipeline and print per-feature importances.
    model.stages.collect {
      case fitted: LightGBMClassificationModel =>
        println("gain")
        MetricHelper.getFeatureImportance(fitted, featureColumns, "gain").foreach {
          case (importance, name) => println(s"$name: $importance")
        }
        println("split")
        MetricHelper.getFeatureImportance(fitted, featureColumns, "split").foreach {
          case (importance, name) => println(s"$name: $importance")
        }
      case other =>
        println(s"Don't know $other of type ${other.getClass}")
    }
  }
}
object MetricHelper {

  /**
   * Prints a summary of binary- and multiclass-classification metrics for the
   * given (prediction, label) pairs and returns the overall accuracy.
   *
   * NOTE(review): `MulticlassMetrics` treats the first tuple element as a
   * discrete predicted label, so this assumes scores are already thresholded
   * to 0.0/1.0 — confirm with callers.
   *
   * @param estimatorName  human-readable name used in the printed header
   * @param hashCode       identifier printed next to the estimator name
   * @param scoreAndLabels RDD of (prediction, trueLabel) pairs
   * @return accuracy over all pairs
   */
  def showMetrics(estimatorName: String, hashCode: Int, scoreAndLabels: RDD[(Double, Double)]): Double = {
    val binaryMetrics = new BinaryClassificationMetrics(scoreAndLabels)
    val auPRC = binaryMetrics.areaUnderPR
    val auROC = binaryMetrics.areaUnderROC
    // One line per threshold: F1 (beta = 1) at that threshold.
    val f1Score = binaryMetrics.fMeasureByThreshold.map { case (t, f) =>
      s"Threshold: $t, F-score: $f, Beta = 1"
    }.collect()
    val multiClassMetrics = new MulticlassMetrics(scoreAndLabels)
    val accuracy = multiClassMetrics.accuracy
    val confusionMatrixTable: String = getConfusionMatrixTable(multiClassMetrics)
    // Row = actual class, column = predicted class: (0, 1) is the false-positive count.
    val fp = multiClassMetrics.confusionMatrix(0, 1)
    // `%%` survives the s-interpolator and is collapsed to a literal '%' by the
    // trailing .format call, which also substitutes the indented table for %s.
    val str =
      s"""Estimator: $estimatorName [$hashCode]
         |\tFalse Positive: ${fp / scoreAndLabels.count() * 100}%%
         |\tF-measure:
         |\t\t${f1Score.mkString("\n\t\t")}
         |\tArea under precision-recall curve = $auPRC
         |\tArea under ROC = $auROC
         |\tConfusion matrix:
         |%s
         |\tAccuracy = $accuracy""".stripMargin.format("\t\t" + confusionMatrixTable.replace("\n", "\n\t\t"))
    println(str) //scalastyle:ignore
    accuracy
  }

  /**
   * Renders a 2x2 confusion matrix as an aligned ASCII table.
   *
   * Fixes vs. the previous version: the first data column renders the
   * "Actual = 1" values (tp / fn), so its width is now computed from those
   * values (it was previously computed from the "Actual = 0" column), the
   * off-by-one `- 1` width adjustment that misaligned cells is removed, and
   * headers/borders are sized dynamically so the table stays aligned even
   * when a count is wider than its column title.
   *
   * @param multiClassMetrics metrics whose confusionMatrix has rows = actual,
   *                          columns = predicted, class labels 0 and 1
   */
  def getConfusionMatrixTable(multiClassMetrics: MulticlassMetrics): String = {
    val cm = multiClassMetrics.confusionMatrix
    val tn = cm(0, 0)
    val fp = cm(0, 1)
    val fn = cm(1, 0)
    val tp = cm(1, 1)
    // Width of each data column: whichever is longer, the title or a value in it.
    val fcl = Array("Actual = 1", tp.toLong.toString, fn.toLong.toString).maxBy(_.length).length
    val scl = Array("Actual = 0", fp.toLong.toString, tn.toLong.toString).maxBy(_.length).length
    val border = "+---------------+" + "-" * (fcl + 2) + "+" + "-" * (scl + 2) + "+"
    val header = s"| ############# | %-${fcl}s | %-${scl}s |".format("Actual = 1", "Actual = 0")
    val row1 = s"| Predicted = 1 | %-${fcl}d | %-${scl}d |".format(tp.toLong, fp.toLong)
    val row0 = s"| Predicted = 0 | %-${fcl}d | %-${scl}d |".format(fn.toLong, tn.toLong)
    Array(border, header, border, row1, border, row0, border).mkString("\n")
  }

  /**
   * Returns (importance, columnName) pairs for features with non-zero
   * importance, sorted by importance descending. Indices without a matching
   * entry in `allColumns` get an empty name (e.g. when the caller passes an
   * empty array).
   *
   * @param importanceType importance measure understood by LightGBM, e.g. "gain" or "split"
   */
  def getFeatureImportance(gbtModel: LightGBMClassificationModel, allColumns: Array[String], importanceType: String): Array[(Double, String)] = {
    val featureImportance = gbtModel.getFeatureImportances(importanceType).zipWithIndex.map { case (importance, idx) => (importance, allColumns.lift(idx).getOrElse("")) }
    featureImportance.filter { case (importance, _) => importance > 0 }.sortBy(x => -x._1)
  }
}