[DLlib] GBT CriteoClickLogsDataset example #5723
@@ -0,0 +1,167 @@
package com.intel.analytics.bigdl.dllib.example.nnframes.gbt

import com.intel.analytics.bigdl.dllib.NNContext
import com.intel.analytics.bigdl.dllib.example.nnframes.xgboost.xgbClassifierTrainingExampleOnCriteoClickLogsDataset.feature_nums
import ml.dmlc.xgboost4j.scala.spark.TrackerConf
import org.apache.spark.ml.Pipeline
import org.apache.spark.ml.classification.GBTClassifier
import org.apache.spark.ml.feature.{StringIndexer, VectorAssembler, VectorIndexer}
import org.apache.spark.sql.types.{LongType, StructField, StructType}
import org.apache.spark.sql.{Row, SQLContext}
import scopt.OptionParser
import org.slf4j.{Logger, LoggerFactory}

class Task extends Serializable {

  val default_missing_value = "-999"

  def rowToLibsvm(row: Row): String = {
    0 until row.length flatMap {
      case 0 => Some(row(0).toString)
      case i if row(i) == null => Some(default_missing_value)
      case i => Some((if (i < 14) row(i)
        else java.lang.Long.parseLong(row(i).toString, 16)).toString)
    } mkString " "
  }
}
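// Illustration (a sketch, not part of the diff): rowToLibsvm on a hypothetical
// 16-column row -- label, 13 integer features, then two hex categorical fields:
//   val task = new Task()
//   val row = Row.fromSeq(Seq(1) ++ Seq.fill(13)(2) ++ Seq("ff", null))
//   task.rowToLibsvm(row)  // "1 2 2 2 2 2 2 2 2 2 2 2 2 2 255 -999"
// Columns with index < 14 pass through unchanged, hex strings are parsed
// base-16, and nulls become default_missing_value.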

case class Params(
  trainingDataPath: String = "/host/data",
  modelSavePath: String = "/host/data/model",
  numThreads: Int = 2, // currently unused by this example
  maxIter: Int = 100,
  maxDepth: Int = 2,
  numWorkers: Int = 1 // currently unused by this example
)

object gbtClassifierTrainingExampleOnCriteoClickLogsDataset {

  // Criteo click logs: 13 integer features + 26 categorical features = 39
  val feature_nums = 39

  def main(args: Array[String]): Unit = {

    val log: Logger = LoggerFactory.getLogger(this.getClass)

    // parse params and set values
    val params = parser.parse(args, new Params).get
    val trainingDataPath = params.trainingDataPath // path to data
    val modelSavePath = params.modelSavePath // save model to this path
    val maxIter = params.maxIter // max training iterations
    val maxDepth = params.maxDepth // max tree depth

    val sc = NNContext.initNNContext()
    val spark = SQLContext.getOrCreate(sc)

    val task = new Task()

    val tStart = System.nanoTime()
    // read tab-separated files into a DataFrame
    var df = spark.read.option("header", "false").
      option("inferSchema", "true").option("delimiter", "\t").csv(trainingDataPath)

    val tBeforePreprocess = System.nanoTime()
    var elapsed = (tBeforePreprocess - tStart) / 1000000000.0f // seconds
    log.info("--reading data time is " + elapsed + " s")
    // preprocess data: fill missing values and decode hex categorical features
    val processedRdd = df.rdd.map(task.rowToLibsvm)

    // declare schema: label column plus feature_nums feature columns, all Long
    var structFieldArray = new Array[StructField](feature_nums + 1)
    for (i <- 0 to feature_nums) {
      structFieldArray(i) = StructField("_c" + i.toString, LongType, true)
    }
    var schema = new StructType(structFieldArray)

    // convert RDD[String] to RDD[Row]
    val rowRDD = processedRdd.map(_.split(" ")).map(row => Row.fromSeq(
      for {
        i <- 0 to feature_nums
      } yield {
        row(i).toLong
      }
    ))
    // RDD[Row] to DataFrame
    df = spark.createDataFrame(rowRDD, schema)

    val stringIndexer = new StringIndexer()
      .setInputCol("_c0")
      .setOutputCol("classIndex")
      .fit(df)
    val labelTransformed = stringIndexer.transform(df).drop("_c0")

    // assemble the 39 feature columns (_c1 .. _c39) into a single vector column
    var inputCols = new Array[String](feature_nums)
    for (i <- 0 to feature_nums - 1) {
      inputCols(i) = "_c" + (i + 1).toString
    }

    val vectorAssembler = new VectorAssembler().
      setInputCols(inputCols).
      setOutputCol("features")

    val xgbInput = vectorAssembler.transform(labelTransformed).select("features", "classIndex")

Review comment: do not use xgb
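A possible rename in the spirit of the comment (a sketch; this example trains Spark ML's GBTClassifier, not XGBoost, and `xgbInput` is only referenced in the randomSplit call below):

    val gbtInput = vectorAssembler.transform(labelTransformed).select("features", "classIndex")
    val Array(train, eval1, eval2, test) = gbtInput.randomSplit(Array(0.6, 0.2, 0.1, 0.1))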

    // randomly split dataset into (train, eval1, eval2, test) in proportion 6:2:1:1
    val Array(train, eval1, eval2, test) = xgbInput.randomSplit(Array(0.6, 0.2, 0.1, 0.1))

Review comment: xgb

    train.cache().count()
    eval1.cache().count()
    eval2.cache().count()

    val tBeforeTraining = System.nanoTime()
    elapsed = (tBeforeTraining - tBeforePreprocess) / 1000000000.0f // seconds
    log.info("--preprocess time is " + elapsed + " s")
    // leftover from the XGBoost version: scala tracker configuration
    // val gbtParam = Map("tracker_conf" -> TrackerConf(0L, "scala"),
    //   "eval_sets" -> Map("eval1" -> eval1, "eval2" -> eval2)
    // )

    // Train a GBT model.
    val gbtClassifier = new GBTClassifier()
    gbtClassifier.setFeaturesCol("features")
    gbtClassifier.setLabelCol("classIndex")
    gbtClassifier.setMaxDepth(maxDepth)
    gbtClassifier.setMaxIter(maxIter)
    gbtClassifier.setFeatureSubsetStrategy("auto")

    // Train the model. This also runs the indexer.
    val gbtClassificationModel = gbtClassifier.fit(train)
    val tAfterTraining = System.nanoTime()
    elapsed = (tAfterTraining - tBeforeTraining) / 1000000000.0f // seconds
    log.info("--training time is " + elapsed + " s")

    gbtClassificationModel.save(modelSavePath)

    val tAfterSave = System.nanoTime()
    elapsed = (tAfterSave - tAfterTraining) / 1000000000.0f // seconds
    log.info("--model save time is " + elapsed + " s")
    elapsed = (tAfterSave - tStart) / 1000000000.0f // seconds
    log.info("--end-to-end time is " + elapsed + " s")
    sc.stop()
  }

  val parser: OptionParser[Params] = new OptionParser[Params]("input xgboost config") {

Review comment: xgboost
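Suggested change (a sketch): the parser description still says xgboost; a name matching the example, such as the hypothetical string below, would address the comment:

    val parser: OptionParser[Params] = new OptionParser[Params]("GBT example on Criteo click logs") {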
    opt[String]('i', "trainingDataPath")
      .text("path to training data")
      .action((v, p) => p.copy(trainingDataPath = v))
      .required()

    opt[String]('s', "modelSavePath")
      .text("path to save the model")
      .action((v, p) => p.copy(modelSavePath = v))
      .required()

    opt[Int]('I', "maxIter")
      .text("max training iterations")
      .action((v, p) => p.copy(maxIter = v))

    opt[Int]('d', "maxDepth")
      .text("max tree depth")
      .action((v, p) => p.copy(maxDepth = v))
  }
}

Review comment: Remove xgboost dep
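In the spirit of this comment, a sketch of the removal: `feature_nums` is already defined on the object itself, and `TrackerConf` only appears in the commented-out tracker configuration, so both XGBoost imports can simply be dropped:

    // delete these two imports; nothing compiled in this file still needs xgboost4j
    // import com.intel.analytics.bigdl.dllib.example.nnframes.xgboost.xgbClassifierTrainingExampleOnCriteoClickLogsDataset.feature_nums
    // import ml.dmlc.xgboost4j.scala.spark.TrackerConf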