-
Notifications
You must be signed in to change notification settings - Fork 334
/
Copy path5xa-spark-1hot.txt
66 lines (46 loc) · 2.77 KB
/
5xa-spark-1hot.txt
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
# Pre-built Spark 1.4.x binaries are compiled against Scala 2.10, so load the _2.10 build of spark-csv
# (a _2.11 artifact is binary-incompatible with a Scala 2.10 runtime).
spark-1.4.1-bin-hadoop2.4/bin/spark-shell --driver-memory 120G --executor-memory 120G --packages com.databricks:spark-csv_2.10:1.2.0
// from Joseph Bradley https://gist.github.com/jkbradley/1e3cc0b3116f2f615b3f
import org.apache.spark.sql.functions.{col, lit}
import org.apache.spark.sql.types.DoubleType
import org.apache.spark.ml.feature.{OneHotEncoder, StringIndexer, VectorAssembler}
import org.apache.spark.ml.Pipeline
import org.apache.spark.mllib.linalg.Vector
// Input CSVs and the Parquet destinations for the one-hot-encoded output.
val (origTrainPath, origTestPath) = ("train-1m.csv", "test.csv")
val (newTrainPath, newTestPath) = ("spark1hot-train-1m.parquet", "spark1hot-test.parquet")
// Load a headered CSV through the spark-csv data source. No schema inference is requested,
// so columns come back as strings (hence the explicit numeric casts further down).
def loadCsv(path: String) = sqlContext.read.format("com.databricks.spark.csv").option("header", "true").load(path)
val trainDF = loadCsv(origTrainPath)
val testDF = loadCsv(origTestPath)
// Combine train and test temporarily, tagging each row with an isTrain flag so they can be split
// apart again after featurization (presumably so the fitted indexers cover categories from both sets).
// unionAll matches columns by position, not by name, so testDF's columns are first reordered to
// trainDF's order. This also fails fast with an analysis error if test.csv lacks one of train's columns.
val fullDF = trainDF.withColumn("isTrain", lit(true)).unionAll(testDF.select(trainDF.columns.map(c => col(c)): _*).withColumn("isTrain", lit(false)))
// display(fullDF)
// Feature types: the categorical columns are one-hot encoded below, the numeric columns are cast to
// Double under a "<name>_double" alias, and var_y names the label column.
val vars_categ = Array("Month", "DayofMonth", "DayOfWeek", "UniqueCarrier", "Origin", "Dest")
val vars_num = Array("DepTime", "Distance")
val vars_num_double = for (name <- vars_num) yield s"${name}_double"
val var_y = "dep_delayed_15min"
// Cast each numeric column (read as a string from CSV) to Double, writing it to the matching
// vars_num_double name. Driving the casts from vars_num/vars_num_double keeps the new columns in
// lock-step with what featureAssembler reads, so adding a numeric variable needs no edit here.
val fullDF2 = vars_num.zip(vars_num_double).foldLeft(fullDF) { case (df, (src, dst)) =>
  df.withColumn(dst, col(src).cast(DoubleType))
}
// display(fullDF2)
// Assemble the featurization Pipeline.
// OneHotEncoder does not accept String input here, so each categorical column is first mapped to
// numeric indices by a StringIndexer ("<name>_indexed") and then encoded ("<name>_ohe").
val stringIndexers = vars_categ.map { name =>
  new StringIndexer().setInputCol(name).setOutputCol(s"${name}_indexed")
}
val oneHotEncoders = vars_categ.map { name =>
  // setDropLast(false): keep a slot for every category rather than dropping the last one.
  new OneHotEncoder().setInputCol(s"${name}_indexed").setOutputCol(s"${name}_ohe").setDropLast(false)
}
// Concatenate all one-hot vectors into "catFeatures", then put the Double numeric columns in front of it to form "features".
val catAssembler = new VectorAssembler().setInputCols(vars_categ.map(name => s"${name}_ohe")).setOutputCol("catFeatures")
val featureAssembler = new VectorAssembler().setInputCols(vars_num_double ++ Array("catFeatures")).setOutputCol("features")
// Map the label column's string values to numeric indices in "label".
val labelIndexer = new StringIndexer().setInputCol(var_y).setOutputCol("label")
// Stage order: every indexer, then every encoder, then the two assemblers, then the label indexer.
val pipeline = new Pipeline().setStages(stringIndexers ++ oneHotEncoders ++ Array(catAssembler, featureAssembler, labelIndexer))
// Compute features.
// Fitting the pipeline runs a separate job per StringIndexer, and each Parquet write below re-evaluates
// transformedDF. Without caching, every one of those passes re-parses both CSVs. Cache the cast input once
// (DataFrame.cache spills to disk rather than failing when memory is short).
fullDF2.cache()
val pipelineModel = pipeline.fit(fullDF2)
val transformedDF = pipelineModel.transform(fullDF2)
// display(transformedDF)
// Split back into train, test using the isTrain flag added before the union.
val finalTrainDF = transformedDF.where(col("isTrain"))
val finalTestDF = transformedDF.where(!col("isTrain"))
// Save Spark DataFrames as Parquet, replacing any previous output at those paths.
finalTrainDF.write.mode("overwrite").parquet(newTrainPath)
finalTestDF.write.mode("overwrite").parquet(newTestPath)
// Both outputs are on disk; release the cached input.
fullDF2.unpersist()
// finalTrainDF.printSchema()
// finalTrainDF.first()