diff --git a/R/pkg/DESCRIPTION b/R/pkg/DESCRIPTION
index 052f68c6c24e2..1c1779a763c7e 100644
--- a/R/pkg/DESCRIPTION
+++ b/R/pkg/DESCRIPTION
@@ -19,7 +19,7 @@ Collate:
'jobj.R'
'RDD.R'
'pairRDD.R'
- 'SQLTypes.R'
+ 'schema.R'
'column.R'
'group.R'
'DataFrame.R'
diff --git a/R/pkg/NAMESPACE b/R/pkg/NAMESPACE
index a354cdce74afa..80283643861ac 100644
--- a/R/pkg/NAMESPACE
+++ b/R/pkg/NAMESPACE
@@ -5,6 +5,7 @@ exportMethods(
"aggregateByKey",
"aggregateRDD",
"cache",
+ "cartesian",
"checkpoint",
"coalesce",
"cogroup",
@@ -28,6 +29,7 @@ exportMethods(
"fullOuterJoin",
"glom",
"groupByKey",
+ "intersection",
"join",
"keyBy",
"keys",
@@ -52,11 +54,14 @@ exportMethods(
"reduceByKeyLocally",
"repartition",
"rightOuterJoin",
+ "sampleByKey",
"sampleRDD",
"saveAsTextFile",
"saveAsObjectFile",
"sortBy",
"sortByKey",
+ "subtract",
+ "subtractByKey",
"sumRDD",
"take",
"takeOrdered",
@@ -95,6 +100,7 @@ exportClasses("DataFrame")
exportMethods("columns",
"distinct",
"dtypes",
+ "except",
"explain",
"filter",
"groupBy",
@@ -118,7 +124,6 @@ exportMethods("columns",
"show",
"showDF",
"sortDF",
- "subtract",
"toJSON",
"toRDD",
"unionAll",
@@ -178,5 +183,14 @@ export("cacheTable",
"toDF",
"uncacheTable")
-export("print.structType",
- "print.structField")
+export("sparkRSQL.init",
+ "sparkRHive.init")
+
+export("structField",
+ "structField.jobj",
+ "structField.character",
+ "print.structField",
+ "structType",
+ "structType.jobj",
+ "structType.structField",
+ "print.structType")
diff --git a/R/pkg/R/DataFrame.R b/R/pkg/R/DataFrame.R
index 044fdb4d01223..861fe1c78b0db 100644
--- a/R/pkg/R/DataFrame.R
+++ b/R/pkg/R/DataFrame.R
@@ -17,7 +17,7 @@
# DataFrame.R - DataFrame class and methods implemented in S4 OO classes
-#' @include generics.R jobj.R SQLTypes.R RDD.R pairRDD.R column.R group.R
+#' @include generics.R jobj.R schema.R RDD.R pairRDD.R column.R group.R
NULL
setOldClass("jobj")
@@ -1141,15 +1141,15 @@ setMethod("intersect",
dataFrame(intersected)
})
-#' Subtract
+#' except
#'
#' Return a new DataFrame containing rows in this DataFrame
#' but not in another DataFrame. This is equivalent to `EXCEPT` in SQL.
#'
#' @param x A Spark DataFrame
#' @param y A Spark DataFrame
-#' @return A DataFrame containing the result of the subtract operation.
-#' @rdname subtract
+#' @return A DataFrame containing the result of the except operation.
+#' @rdname except
#' @export
#' @examples
#'\dontrun{
@@ -1157,13 +1157,15 @@ setMethod("intersect",
#' sqlCtx <- sparkRSQL.init(sc)
#' df1 <- jsonFile(sqlCtx, path)
#' df2 <- jsonFile(sqlCtx, path2)
-#' subtractDF <- subtract(df, df2)
+#' exceptDF <- except(df, df2)
#' }
-setMethod("subtract",
+#' @rdname except
+#' @export
+setMethod("except",
signature(x = "DataFrame", y = "DataFrame"),
function(x, y) {
- subtracted <- callJMethod(x@sdf, "except", y@sdf)
- dataFrame(subtracted)
+ excepted <- callJMethod(x@sdf, "except", y@sdf)
+ dataFrame(excepted)
})
#' Save the contents of the DataFrame to a data source
diff --git a/R/pkg/R/RDD.R b/R/pkg/R/RDD.R
index 820027ef67e3b..128431334ca52 100644
--- a/R/pkg/R/RDD.R
+++ b/R/pkg/R/RDD.R
@@ -730,6 +730,7 @@ setMethod("take",
index <- -1
jrdd <- getJRDD(x)
numPartitions <- numPartitions(x)
+ serializedModeRDD <- getSerializedMode(x)
# TODO(shivaram): Collect more than one partition based on size
# estimates similar to the scala version of `take`.
@@ -748,13 +749,14 @@ setMethod("take",
elems <- convertJListToRList(partition,
flatten = TRUE,
logicalUpperBound = size,
- serializedMode = getSerializedMode(x))
- # TODO: Check if this append is O(n^2)?
+ serializedMode = serializedModeRDD)
+
resList <- append(resList, elems)
}
resList
})
+
#' First
#'
#' Return the first element of an RDD
@@ -1092,21 +1094,42 @@ takeOrderedElem <- function(x, num, ascending = TRUE) {
if (num < length(part)) {
# R limitation: order works only on primitive types!
ord <- order(unlist(part, recursive = FALSE), decreasing = !ascending)
- list(part[ord[1:num]])
+ part[ord[1:num]]
} else {
- list(part)
+ part
}
}
- reduceFunc <- function(elems, part) {
- newElems <- append(elems, part)
- # R limitation: order works only on primitive types!
- ord <- order(unlist(newElems, recursive = FALSE), decreasing = !ascending)
- newElems[ord[1:num]]
- }
-
newRdd <- mapPartitions(x, partitionFunc)
- reduce(newRdd, reduceFunc)
+
+ resList <- list()
+ index <- -1
+ jrdd <- getJRDD(newRdd)
+ numPartitions <- numPartitions(newRdd)
+ serializedModeRDD <- getSerializedMode(newRdd)
+
+ while (TRUE) {
+ index <- index + 1
+
+ if (index >= numPartitions) {
+ ord <- order(unlist(resList, recursive = FALSE), decreasing = !ascending)
+ resList <- resList[ord[1:num]]
+ break
+ }
+
+ # a JList of byte arrays
+ partitionArr <- callJMethod(jrdd, "collectPartitions", as.list(as.integer(index)))
+ partition <- partitionArr[[1]]
+
+ # elems is capped to have at most `num` elements
+ elems <- convertJListToRList(partition,
+ flatten = TRUE,
+ logicalUpperBound = num,
+ serializedMode = serializedModeRDD)
+
+ resList <- append(resList, elems)
+ }
+ resList
}
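# A minimal usage sketch for the rewritten helper above, assuming a SparkR
# context `sc`: takeOrdered() and top() both build on takeOrderedElem(), which
# now collects partitions one at a time and caps each partition's contribution
# at `num` elements before the final ordering.
sc <- sparkR.init()
rdd <- parallelize(sc, list(10, 1, 2, 9, 3, 4, 5, 6, 7))
takeOrdered(rdd, 6L)  # list(1, 2, 3, 4, 5, 6)
top(rdd, 6L)          # list(10, 9, 7, 6, 5, 4)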
#' Returns the first N elements from an RDD in ascending order.
@@ -1465,67 +1488,105 @@ setMethod("zipRDD",
stop("Can only zip RDDs which have the same number of partitions.")
}
- if (getSerializedMode(x) != getSerializedMode(other) ||
- getSerializedMode(x) == "byte") {
- # Append the number of elements in each partition to that partition so that we can later
- # check if corresponding partitions of both RDDs have the same number of elements.
- #
- # Note that this appending also serves the purpose of reserialization, because even if
- # any RDD is serialized, we need to reserialize it to make sure its partitions are encoded
- # as a single byte array. For example, partitions of an RDD generated from partitionBy()
- # may be encoded as multiple byte arrays.
- appendLength <- function(part) {
- part[[length(part) + 1]] <- length(part) + 1
- part
- }
- x <- lapplyPartition(x, appendLength)
- other <- lapplyPartition(other, appendLength)
- }
+ rdds <- appendPartitionLengths(x, other)
+ jrdd <- callJMethod(getJRDD(rdds[[1]]), "zip", getJRDD(rdds[[2]]))
+ # The jrdd's elements are of scala Tuple2 type. The serialized
+ # flag here is used for the elements inside the tuples.
+ rdd <- RDD(jrdd, getSerializedMode(rdds[[1]]))
- zippedJRDD <- callJMethod(getJRDD(x), "zip", getJRDD(other))
- # The zippedRDD's elements are of scala Tuple2 type. The serialized
- # flag Here is used for the elements inside the tuples.
- serializerMode <- getSerializedMode(x)
- zippedRDD <- RDD(zippedJRDD, serializerMode)
+ mergePartitions(rdd, TRUE)
+ })
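# Usage sketch for the refactored zipRDD() above, assuming a SparkR context
# `sc`: both inputs must have the same number of partitions and the same number
# of elements per partition, which mergePartitions() verifies via the appended
# partition lengths.
sc <- sparkR.init()
rdd1 <- parallelize(sc, 0:4, 2L)
rdd2 <- parallelize(sc, 1000:1004, 2L)
collect(zipRDD(rdd1, rdd2))
# list(list(0, 1000), list(1, 1001), list(2, 1002), list(3, 1003), list(4, 1004))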
+
+#' Cartesian product of this RDD and another one.
+#'
+#' Return the Cartesian product of this RDD and another one,
+#' that is, the RDD of all pairs of elements (a, b) where a
+#' is in this and b is in other.
+#'
+#' @param x An RDD.
+#' @param other An RDD.
+#' @return A new RDD which is the Cartesian product of these two RDDs.
+#' @examples
+#'\dontrun{
+#' sc <- sparkR.init()
+#' rdd <- parallelize(sc, 1:2)
+#' sortByKey(cartesian(rdd, rdd))
+#' # list(list(1, 1), list(1, 2), list(2, 1), list(2, 2))
+#'}
+#' @rdname cartesian
+#' @aliases cartesian,RDD,RDD-method
+setMethod("cartesian",
+ signature(x = "RDD", other = "RDD"),
+ function(x, other) {
+ rdds <- appendPartitionLengths(x, other)
+ jrdd <- callJMethod(getJRDD(rdds[[1]]), "cartesian", getJRDD(rdds[[2]]))
+ # The jrdd's elements are of scala Tuple2 type. The serialized
+ # flag here is used for the elements inside the tuples.
+ rdd <- RDD(jrdd, getSerializedMode(rdds[[1]]))
- partitionFunc <- function(split, part) {
- len <- length(part)
- if (len > 0) {
- if (serializerMode == "byte") {
- lengthOfValues <- part[[len]]
- lengthOfKeys <- part[[len - lengthOfValues]]
- stopifnot(len == lengthOfKeys + lengthOfValues)
-
- # check if corresponding partitions of both RDDs have the same number of elements.
- if (lengthOfKeys != lengthOfValues) {
- stop("Can only zip RDDs with same number of elements in each pair of corresponding partitions.")
- }
-
- if (lengthOfKeys > 1) {
- keys <- part[1 : (lengthOfKeys - 1)]
- values <- part[(lengthOfKeys + 1) : (len - 1)]
- } else {
- keys <- list()
- values <- list()
- }
- } else {
- # Keys, values must have same length here, because this has
- # been validated inside the JavaRDD.zip() function.
- keys <- part[c(TRUE, FALSE)]
- values <- part[c(FALSE, TRUE)]
- }
- mapply(
- function(k, v) {
- list(k, v)
- },
- keys,
- values,
- SIMPLIFY = FALSE,
- USE.NAMES = FALSE)
- } else {
- part
- }
+ mergePartitions(rdd, FALSE)
+ })
+
+#' Subtract an RDD with another RDD.
+#'
+#' Return an RDD with the elements from this that are not in other.
+#'
+#' @param x An RDD.
+#' @param other An RDD.
+#' @param numPartitions Number of the partitions in the result RDD.
+#' @return An RDD with the elements from this that are not in other.
+#' @examples
+#'\dontrun{
+#' sc <- sparkR.init()
+#' rdd1 <- parallelize(sc, list(1, 1, 2, 2, 3, 4))
+#' rdd2 <- parallelize(sc, list(2, 4))
+#' collect(subtract(rdd1, rdd2))
+#' # list(1, 1, 3)
+#'}
+#' @rdname subtract
+#' @aliases subtract,RDD
+setMethod("subtract",
+ signature(x = "RDD", other = "RDD"),
+ function(x, other, numPartitions = SparkR::numPartitions(x)) {
+ mapFunction <- function(e) { list(e, NA) }
+ rdd1 <- map(x, mapFunction)
+ rdd2 <- map(other, mapFunction)
+ keys(subtractByKey(rdd1, rdd2, numPartitions))
+ })
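# Sketch of the key-tagging trick used by subtract() above, assuming a SparkR
# context `sc`: every value v is wrapped as list(v, NA), so plain set
# subtraction reduces to subtractByKey() followed by keys().
sc <- sparkR.init()
rdd1 <- parallelize(sc, list(1, 1, 2, 2, 3, 4))
rdd2 <- parallelize(sc, list(2, 4))
collect(subtract(rdd1, rdd2))  # list(1, 1, 3)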
+
+#' Intersection of this RDD and another one.
+#'
+#' Return the intersection of this RDD and another one.
+#' The output will not contain any duplicate elements,
+#' even if the input RDDs did. Performs a hash partition
+#' across the cluster.
+#' Note that this method performs a shuffle internally.
+#'
+#' @param x An RDD.
+#' @param other An RDD.
+#' @param numPartitions The number of partitions in the result RDD.
+#' @return An RDD which is the intersection of these two RDDs.
+#' @examples
+#'\dontrun{
+#' sc <- sparkR.init()
+#' rdd1 <- parallelize(sc, list(1, 10, 2, 3, 4, 5))
+#' rdd2 <- parallelize(sc, list(1, 6, 2, 3, 7, 8))
+#' collect(sortBy(intersection(rdd1, rdd2), function(x) { x }))
+#' # list(1, 2, 3)
+#'}
+#' @rdname intersection
+#' @aliases intersection,RDD
+setMethod("intersection",
+ signature(x = "RDD", other = "RDD"),
+ function(x, other, numPartitions = SparkR::numPartitions(x)) {
+ rdd1 <- map(x, function(v) { list(v, NA) })
+ rdd2 <- map(other, function(v) { list(v, NA) })
+
+ filterFunction <- function(elem) {
+ iters <- elem[[2]]
+ all(as.vector(
+ lapply(iters, function(iter) { length(iter) > 0 }), mode = "logical"))
}
-
- PipelinedRDD(zippedRDD, partitionFunc)
+
+ keys(filterRDD(cogroup(rdd1, rdd2, numPartitions = numPartitions), filterFunction))
})
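# Sketch of the cogroup-based filter above, assuming a SparkR context `sc`:
# after cogroup(rdd1, rdd2), each key maps to a pair of iterators, and a key is
# kept only when both iterators are non-empty, i.e. it occurred in both inputs;
# duplicates disappear because cogroup groups by key.
sc <- sparkR.init()
rdd1 <- parallelize(sc, list(1, 10, 2, 3, 4, 5))
rdd2 <- parallelize(sc, list(1, 6, 2, 3, 7, 8))
collect(sortBy(intersection(rdd1, rdd2), function(x) { x }))  # list(1, 2, 3)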
diff --git a/R/pkg/R/SQLContext.R b/R/pkg/R/SQLContext.R
index 930ada22f4c38..4f05ba524a01a 100644
--- a/R/pkg/R/SQLContext.R
+++ b/R/pkg/R/SQLContext.R
@@ -54,9 +54,9 @@ infer_type <- function(x) {
# StructType
types <- lapply(x, infer_type)
fields <- lapply(1:length(x), function(i) {
- list(name = names[[i]], type = types[[i]], nullable = TRUE)
+ structField(names[[i]], types[[i]], TRUE)
})
- list(type = "struct", fields = fields)
+ do.call(structType, fields)
}
} else if (length(x) > 1) {
list(type = "array", elementType = type, containsNull = TRUE)
@@ -65,30 +65,6 @@ infer_type <- function(x) {
}
}
-#' dump the schema into JSON string
-tojson <- function(x) {
- if (is.list(x)) {
- names <- names(x)
- if (!is.null(names)) {
- items <- lapply(names, function(n) {
- safe_n <- gsub('"', '\\"', n)
- paste(tojson(safe_n), ':', tojson(x[[n]]), sep = '')
- })
- d <- paste(items, collapse = ', ')
- paste('{', d, '}', sep = '')
- } else {
- l <- paste(lapply(x, tojson), collapse = ', ')
- paste('[', l, ']', sep = '')
- }
- } else if (is.character(x)) {
- paste('"', x, '"', sep = '')
- } else if (is.logical(x)) {
- if (x) "true" else "false"
- } else {
- stop(paste("unexpected type:", class(x)))
- }
-}
-
#' Create a DataFrame from an RDD
#'
#' Converts an RDD to a DataFrame by inferring the types.
@@ -134,7 +110,7 @@ createDataFrame <- function(sqlCtx, data, schema = NULL, samplingRatio = 1.0) {
stop(paste("unexpected type:", class(data)))
}
- if (is.null(schema) || is.null(names(schema))) {
+ if (is.null(schema) || (!inherits(schema, "structType") && is.null(names(schema)))) {
row <- first(rdd)
names <- if (is.null(schema)) {
names(row)
@@ -143,7 +119,7 @@ createDataFrame <- function(sqlCtx, data, schema = NULL, samplingRatio = 1.0) {
}
if (is.null(names)) {
names <- lapply(1:length(row), function(x) {
- paste("_", as.character(x), sep = "")
+ paste("_", as.character(x), sep = "")
})
}
@@ -159,20 +135,18 @@ createDataFrame <- function(sqlCtx, data, schema = NULL, samplingRatio = 1.0) {
types <- lapply(row, infer_type)
fields <- lapply(1:length(row), function(i) {
- list(name = names[[i]], type = types[[i]], nullable = TRUE)
+ structField(names[[i]], types[[i]], TRUE)
})
- schema <- list(type = "struct", fields = fields)
+ schema <- do.call(structType, fields)
}
- stopifnot(class(schema) == "list")
- stopifnot(schema$type == "struct")
- stopifnot(class(schema$fields) == "list")
- schemaString <- tojson(schema)
+ stopifnot(class(schema) == "structType")
jrdd <- getJRDD(lapply(rdd, function(x) x), "row")
srdd <- callJMethod(jrdd, "rdd")
sdf <- callJStatic("org.apache.spark.sql.api.r.SQLUtils", "createDF",
- srdd, schemaString, sqlCtx)
+ srdd, schema$jobj, sqlCtx)
dataFrame(sdf)
}
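# Usage sketch for the structType-based schema path above: the schema's Java
# object (schema$jobj) is now handed to SQLUtils.createDF directly, instead of
# a hand-rolled JSON schema string.
sc <- sparkR.init()
sqlCtx <- sparkRSQL.init(sc)
rdd <- lapply(parallelize(sc, 1:10), function(x) { list(x, as.character(x)) })
schema <- structType(structField("a", "integer"), structField("b", "string"))
df <- createDataFrame(sqlCtx, rdd, schema)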
diff --git a/R/pkg/R/SQLTypes.R b/R/pkg/R/SQLTypes.R
deleted file mode 100644
index 962fba5b3cf03..0000000000000
--- a/R/pkg/R/SQLTypes.R
+++ /dev/null
@@ -1,64 +0,0 @@
-#
-# Licensed to the Apache Software Foundation (ASF) under one or more
-# contributor license agreements. See the NOTICE file distributed with
-# this work for additional information regarding copyright ownership.
-# The ASF licenses this file to You under the Apache License, Version 2.0
-# (the "License"); you may not use this file except in compliance with
-# the License. You may obtain a copy of the License at
-#
-# http://www.apache.org/licenses/LICENSE-2.0
-#
-# Unless required by applicable law or agreed to in writing, software
-# distributed under the License is distributed on an "AS IS" BASIS,
-# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-# See the License for the specific language governing permissions and
-# limitations under the License.
-#
-
-# Utility functions for handling SparkSQL DataTypes.
-
-# Handler for StructType
-structType <- function(st) {
- obj <- structure(new.env(parent = emptyenv()), class = "structType")
- obj$jobj <- st
- obj$fields <- function() { lapply(callJMethod(st, "fields"), structField) }
- obj
-}
-
-#' Print a Spark StructType.
-#'
-#' This function prints the contents of a StructType returned from the
-#' SparkR JVM backend.
-#'
-#' @param x A StructType object
-#' @param ... further arguments passed to or from other methods
-print.structType <- function(x, ...) {
- fieldsList <- lapply(x$fields(), function(i) { i$print() })
- print(fieldsList)
-}
-
-# Handler for StructField
-structField <- function(sf) {
- obj <- structure(new.env(parent = emptyenv()), class = "structField")
- obj$jobj <- sf
- obj$name <- function() { callJMethod(sf, "name") }
- obj$dataType <- function() { callJMethod(sf, "dataType") }
- obj$dataType.toString <- function() { callJMethod(obj$dataType(), "toString") }
- obj$dataType.simpleString <- function() { callJMethod(obj$dataType(), "simpleString") }
- obj$nullable <- function() { callJMethod(sf, "nullable") }
- obj$print <- function() { paste("StructField(",
- paste(obj$name(), obj$dataType.toString(), obj$nullable(), sep = ", "),
- ")", sep = "") }
- obj
-}
-
-#' Print a Spark StructField.
-#'
-#' This function prints the contents of a StructField returned from the
-#' SparkR JVM backend.
-#'
-#' @param x A StructField object
-#' @param ... further arguments passed to or from other methods
-print.structField <- function(x, ...) {
- cat(x$print())
-}
diff --git a/R/pkg/R/column.R b/R/pkg/R/column.R
index b282001d8b6b5..95fb9ff0887b6 100644
--- a/R/pkg/R/column.R
+++ b/R/pkg/R/column.R
@@ -17,7 +17,7 @@
# Column Class
-#' @include generics.R jobj.R SQLTypes.R
+#' @include generics.R jobj.R schema.R
NULL
setOldClass("jobj")
diff --git a/R/pkg/R/generics.R b/R/pkg/R/generics.R
index 5fb1ccaa84ee2..6c6233390134c 100644
--- a/R/pkg/R/generics.R
+++ b/R/pkg/R/generics.R
@@ -230,6 +230,10 @@ setGeneric("zipWithUniqueId", function(x) { standardGeneric("zipWithUniqueId") }
############ Binary Functions #############
+#' @rdname cartesian
+#' @export
+setGeneric("cartesian", function(x, other) { standardGeneric("cartesian") })
+
#' @rdname countByKey
#' @export
setGeneric("countByKey", function(x) { standardGeneric("countByKey") })
@@ -238,6 +242,11 @@ setGeneric("countByKey", function(x) { standardGeneric("countByKey") })
#' @export
setGeneric("flatMapValues", function(X, FUN) { standardGeneric("flatMapValues") })
+#' @rdname intersection
+#' @export
+setGeneric("intersection", function(x, other, numPartitions = 1L) {
+ standardGeneric("intersection") })
+
#' @rdname keys
#' @export
setGeneric("keys", function(x) { standardGeneric("keys") })
@@ -250,12 +259,18 @@ setGeneric("lookup", function(x, key) { standardGeneric("lookup") })
#' @export
setGeneric("mapValues", function(X, FUN) { standardGeneric("mapValues") })
+#' @rdname sampleByKey
+#' @export
+setGeneric("sampleByKey",
+ function(x, withReplacement, fractions, seed) {
+ standardGeneric("sampleByKey")
+ })
+
#' @rdname values
#' @export
setGeneric("values", function(x) { standardGeneric("values") })
-
############ Shuffle Functions ############
#' @rdname aggregateByKey
@@ -330,9 +345,24 @@ setGeneric("rightOuterJoin", function(x, y, numPartitions) { standardGeneric("ri
#' @rdname sortByKey
#' @export
-setGeneric("sortByKey", function(x, ascending = TRUE, numPartitions = 1L) {
- standardGeneric("sortByKey")
-})
+setGeneric("sortByKey",
+ function(x, ascending = TRUE, numPartitions = 1L) {
+ standardGeneric("sortByKey")
+ })
+
+#' @rdname subtract
+#' @export
+setGeneric("subtract",
+ function(x, other, numPartitions = 1L) {
+ standardGeneric("subtract")
+ })
+
+#' @rdname subtractByKey
+#' @export
+setGeneric("subtractByKey",
+ function(x, other, numPartitions = 1L) {
+ standardGeneric("subtractByKey")
+ })
################### Broadcast Variable Methods #################
@@ -357,6 +387,10 @@ setGeneric("dtypes", function(x) { standardGeneric("dtypes") })
#' @export
setGeneric("explain", function(x, ...) { standardGeneric("explain") })
+#' @rdname except
+#' @export
+setGeneric("except", function(x, y) { standardGeneric("except") })
+
#' @rdname filter
#' @export
setGeneric("filter", function(x, condition) { standardGeneric("filter") })
@@ -434,10 +468,6 @@ setGeneric("showDF", function(x,...) { standardGeneric("showDF") })
#' @export
setGeneric("sortDF", function(x, col, ...) { standardGeneric("sortDF") })
-#' @rdname subtract
-#' @export
-setGeneric("subtract", function(x, y) { standardGeneric("subtract") })
-
#' @rdname tojson
#' @export
setGeneric("toJSON", function(x) { standardGeneric("toJSON") })
diff --git a/R/pkg/R/group.R b/R/pkg/R/group.R
index 855fbdfc7c4ca..02237b3672d6b 100644
--- a/R/pkg/R/group.R
+++ b/R/pkg/R/group.R
@@ -17,7 +17,7 @@
# group.R - GroupedData class and methods implemented in S4 OO classes
-#' @include generics.R jobj.R SQLTypes.R column.R
+#' @include generics.R jobj.R schema.R column.R
NULL
setOldClass("jobj")
diff --git a/R/pkg/R/pairRDD.R b/R/pkg/R/pairRDD.R
index 5d64822859d1f..13efebc11c46e 100644
--- a/R/pkg/R/pairRDD.R
+++ b/R/pkg/R/pairRDD.R
@@ -430,7 +430,7 @@ setMethod("combineByKey",
pred <- function(item) exists(item$hash, keys)
lapply(part,
function(item) {
- item$hash <- as.character(item[[1]])
+ item$hash <- as.character(hashCode(item[[1]]))
updateOrCreatePair(item, keys, combiners, pred, mergeValue, createCombiner)
})
convertEnvsToList(keys, combiners)
@@ -443,7 +443,7 @@ setMethod("combineByKey",
pred <- function(item) exists(item$hash, keys)
lapply(part,
function(item) {
- item$hash <- as.character(item[[1]])
+ item$hash <- as.character(hashCode(item[[1]]))
updateOrCreatePair(item, keys, combiners, pred, mergeCombiners, identity)
})
convertEnvsToList(keys, combiners)
@@ -452,19 +452,19 @@ setMethod("combineByKey",
})
#' Aggregate a pair RDD by each key.
-#'
+#'
#' Aggregate the values of each key in an RDD, using given combine functions
#' and a neutral "zero value". This function can return a different result type,
#' U, than the type of the values in this RDD, V. Thus, we need one operation
-#' for merging a V into a U and one operation for merging two U's, The former
-#' operation is used for merging values within a partition, and the latter is
-#' used for merging values between partitions. To avoid memory allocation, both
-#' of these functions are allowed to modify and return their first argument
+#' for merging a V into a U and one operation for merging two U's. The former
+#' operation is used for merging values within a partition, and the latter is
+#' used for merging values between partitions. To avoid memory allocation, both
+#' of these functions are allowed to modify and return their first argument
#' instead of creating a new U.
-#'
+#'
#' @param x An RDD.
#' @param zeroValue A neutral "zero value".
-#' @param seqOp A function to aggregate the values of each key. It may return
+#' @param seqOp A function to aggregate the values of each key. It may return
#' a different result type from the type of the values.
#' @param combOp A function to aggregate results of seqOp.
#' @return An RDD containing the aggregation result.
@@ -476,7 +476,7 @@ setMethod("combineByKey",
#' zeroValue <- list(0, 0)
#' seqOp <- function(x, y) { list(x[[1]] + y, x[[2]] + 1) }
#' combOp <- function(x, y) { list(x[[1]] + y[[1]], x[[2]] + y[[2]]) }
-#' aggregateByKey(rdd, zeroValue, seqOp, combOp, 2L)
+#' aggregateByKey(rdd, zeroValue, seqOp, combOp, 2L)
#' # list(list(1, list(3, 2)), list(2, list(7, 2)))
#'}
#' @rdname aggregateByKey
@@ -493,12 +493,12 @@ setMethod("aggregateByKey",
})
#' Fold a pair RDD by each key.
-#'
+#'
#' Aggregate the values of each key in an RDD, using an associative function "func"
-#' and a neutral "zero value" which may be added to the result an arbitrary
-#' number of times, and must not change the result (e.g., 0 for addition, or
+#' and a neutral "zero value" which may be added to the result an arbitrary
+#' number of times, and must not change the result (e.g., 0 for addition, or
#' 1 for multiplication.).
-#'
+#'
#' @param x An RDD.
#' @param zeroValue A neutral "zero value".
#' @param func An associative function for folding values of each key.
@@ -548,11 +548,11 @@ setMethod("join",
function(x, y, numPartitions) {
xTagged <- lapply(x, function(i) { list(i[[1]], list(1L, i[[2]])) })
yTagged <- lapply(y, function(i) { list(i[[1]], list(2L, i[[2]])) })
-
+
doJoin <- function(v) {
joinTaggedList(v, list(FALSE, FALSE))
}
-
+
joined <- flatMapValues(groupByKey(unionRDD(xTagged, yTagged), numToInt(numPartitions)),
doJoin)
})
@@ -568,8 +568,8 @@ setMethod("join",
#' @param y An RDD to be joined. Should be an RDD where each element is
#' list(K, V).
#' @param numPartitions Number of partitions to create.
-#' @return For each element (k, v) in x, the resulting RDD will either contain
-#' all pairs (k, (v, w)) for (k, w) in rdd2, or the pair (k, (v, NULL))
+#' @return For each element (k, v) in x, the resulting RDD will either contain
+#' all pairs (k, (v, w)) for (k, w) in rdd2, or the pair (k, (v, NULL))
#' if no elements in rdd2 have key k.
#' @examples
#'\dontrun{
@@ -586,11 +586,11 @@ setMethod("leftOuterJoin",
function(x, y, numPartitions) {
xTagged <- lapply(x, function(i) { list(i[[1]], list(1L, i[[2]])) })
yTagged <- lapply(y, function(i) { list(i[[1]], list(2L, i[[2]])) })
-
+
doJoin <- function(v) {
joinTaggedList(v, list(FALSE, TRUE))
}
-
+
joined <- flatMapValues(groupByKey(unionRDD(xTagged, yTagged), numPartitions), doJoin)
})
@@ -623,18 +623,18 @@ setMethod("rightOuterJoin",
function(x, y, numPartitions) {
xTagged <- lapply(x, function(i) { list(i[[1]], list(1L, i[[2]])) })
yTagged <- lapply(y, function(i) { list(i[[1]], list(2L, i[[2]])) })
-
+
doJoin <- function(v) {
joinTaggedList(v, list(TRUE, FALSE))
}
-
+
joined <- flatMapValues(groupByKey(unionRDD(xTagged, yTagged), numPartitions), doJoin)
})
#' Full outer join two RDDs
#'
#' @description
-#' \code{fullouterjoin} This function full-outer-joins two RDDs where every element is of the form list(K, V).
+#' \code{fullOuterJoin} This function full-outer-joins two RDDs where every element is of the form list(K, V).
#' The key types of the two RDDs should be the same.
#'
#' @param x An RDD to be joined. Should be an RDD where each element is
@@ -644,7 +644,7 @@ setMethod("rightOuterJoin",
#' @param numPartitions Number of partitions to create.
#' @return For each element (k, v) in x and (k, w) in y, the resulting RDD
#' will contain all pairs (k, (v, w)) for both (k, v) in x and
-#' (k, w) in y, or the pair (k, (NULL, w))/(k, (v, NULL)) if no elements
+#' (k, w) in y, or the pair (k, (NULL, w))/(k, (v, NULL)) if no elements
#' in x/y have key k.
#' @examples
#'\dontrun{
@@ -683,7 +683,7 @@ setMethod("fullOuterJoin",
#' sc <- sparkR.init()
#' rdd1 <- parallelize(sc, list(list(1, 1), list(2, 4)))
#' rdd2 <- parallelize(sc, list(list(1, 2), list(1, 3)))
-#' cogroup(rdd1, rdd2, numPartitions = 2L)
+#' cogroup(rdd1, rdd2, numPartitions = 2L)
#' # list(list(1, list(1, list(2, 3))), list(2, list(list(4), list()))
#'}
#' @rdname cogroup
@@ -694,7 +694,7 @@ setMethod("cogroup",
rdds <- list(...)
rddsLen <- length(rdds)
for (i in 1:rddsLen) {
- rdds[[i]] <- lapply(rdds[[i]],
+ rdds[[i]] <- lapply(rdds[[i]],
function(x) { list(x[[1]], list(i, x[[2]])) })
}
union.rdd <- Reduce(unionRDD, rdds)
@@ -719,7 +719,7 @@ setMethod("cogroup",
}
})
}
- cogroup.rdd <- mapValues(groupByKey(union.rdd, numPartitions),
+ cogroup.rdd <- mapValues(groupByKey(union.rdd, numPartitions),
group.func)
})
@@ -741,18 +741,18 @@ setMethod("sortByKey",
signature(x = "RDD"),
function(x, ascending = TRUE, numPartitions = SparkR::numPartitions(x)) {
rangeBounds <- list()
-
+
if (numPartitions > 1) {
rddSize <- count(x)
# constant from Spark's RangePartitioner
maxSampleSize <- numPartitions * 20
fraction <- min(maxSampleSize / max(rddSize, 1), 1.0)
-
+
samples <- collect(keys(sampleRDD(x, FALSE, fraction, 1L)))
-
+
# Note: the built-in R sort() function only works on atomic vectors
samples <- sort(unlist(samples, recursive = FALSE), decreasing = !ascending)
-
+
if (length(samples) > 0) {
rangeBounds <- lapply(seq_len(numPartitions - 1),
function(i) {
@@ -764,24 +764,146 @@ setMethod("sortByKey",
rangePartitionFunc <- function(key) {
partition <- 0
-
+
# TODO: Use binary search instead of linear search, similar with Spark
while (partition < length(rangeBounds) && key > rangeBounds[[partition + 1]]) {
partition <- partition + 1
}
-
+
if (ascending) {
partition
} else {
numPartitions - partition - 1
}
}
-
+
partitionFunc <- function(part) {
sortKeyValueList(part, decreasing = !ascending)
}
-
+
newRDD <- partitionBy(x, numPartitions, rangePartitionFunc)
lapplyPartition(newRDD, partitionFunc)
})
+#' Subtract a pair RDD with another pair RDD.
+#'
+#' Return an RDD with the pairs from x whose keys are not in other.
+#'
+#' @param x An RDD.
+#' @param other An RDD.
+#' @param numPartitions Number of the partitions in the result RDD.
+#' @return An RDD with the pairs from x whose keys are not in other.
+#' @examples
+#'\dontrun{
+#' sc <- sparkR.init()
+#' rdd1 <- parallelize(sc, list(list("a", 1), list("b", 4),
+#' list("b", 5), list("a", 2)))
+#' rdd2 <- parallelize(sc, list(list("a", 3), list("c", 1)))
+#' collect(subtractByKey(rdd1, rdd2))
+#' # list(list("b", 4), list("b", 5))
+#'}
+#' @rdname subtractByKey
+#' @aliases subtractByKey,RDD
+setMethod("subtractByKey",
+ signature(x = "RDD", other = "RDD"),
+ function(x, other, numPartitions = SparkR::numPartitions(x)) {
+ filterFunction <- function(elem) {
+ iters <- elem[[2]]
+ (length(iters[[1]]) > 0) && (length(iters[[2]]) == 0)
+ }
+
+ flatMapValues(filterRDD(cogroup(x,
+ other,
+ numPartitions = numPartitions),
+ filterFunction),
+ function (v) { v[[1]] })
+ })
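# Worked sketch of the filter above, assuming a SparkR context `sc`: after
# cogroup(x, other), a key survives only when it has values from x
# (iters[[1]] non-empty) and none from other (iters[[2]] empty); flatMapValues()
# then re-emits the original values from x for the surviving keys.
sc <- sparkR.init()
rdd1 <- parallelize(sc, list(list("a", 1), list("b", 4), list("b", 5), list("a", 2)))
rdd2 <- parallelize(sc, list(list("a", 3), list("c", 1)))
collect(subtractByKey(rdd1, rdd2))  # list(list("b", 4), list("b", 5))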
+
+#' Return a subset of this RDD sampled by key.
+#'
+#' @description
+#' \code{sampleByKey} Create a sample of this RDD using variable sampling rates
+#' for different keys as specified by fractions, a key to sampling rate map.
+#'
+#' @param x The RDD to sample elements by key, where each element is
+#' list(K, V) or c(K, V).
+#' @param withReplacement Sampling with replacement or not
+#' @param fractions A map from key to the (rough) target sampling fraction for that key
+#' @param seed Randomness seed value
+#' @examples
+#'\dontrun{
+#' sc <- sparkR.init()
+#' rdd <- parallelize(sc, 1:3000)
+#' pairs <- lapply(rdd, function(x) { if (x %% 3 == 0) list("a", x)
+#' else { if (x %% 3 == 1) list("b", x) else list("c", x) }})
+#' fractions <- list(a = 0.2, b = 0.1, c = 0.3)
+#' sample <- sampleByKey(pairs, FALSE, fractions, 1618L)
+#' 100 < length(lookup(sample, "a")) && 300 > length(lookup(sample, "a")) # TRUE
+#' 50 < length(lookup(sample, "b")) && 150 > length(lookup(sample, "b")) # TRUE
+#' 200 < length(lookup(sample, "c")) && 400 > length(lookup(sample, "c")) # TRUE
+#' lookup(sample, "a")[which.min(lookup(sample, "a"))] >= 0 # TRUE
+#' lookup(sample, "a")[which.max(lookup(sample, "a"))] <= 2000 # TRUE
+#' lookup(sample, "b")[which.min(lookup(sample, "b"))] >= 0 # TRUE
+#' lookup(sample, "b")[which.max(lookup(sample, "b"))] <= 2000 # TRUE
+#' lookup(sample, "c")[which.min(lookup(sample, "c"))] >= 0 # TRUE
+#' lookup(sample, "c")[which.max(lookup(sample, "c"))] <= 2000 # TRUE
+#' fractions <- list(a = 0.2, b = 0.1, c = 0.3, d = 0.4)
+#' sample <- sampleByKey(pairs, FALSE, fractions, 1618L) # Key "d" will be ignored
+#' fractions <- list(a = 0.2, b = 0.1)
+#' sample <- sampleByKey(pairs, FALSE, fractions, 1618L) # KeyError: "c"
+#'}
+#' @rdname sampleByKey
+#' @aliases sampleByKey,RDD-method
+setMethod("sampleByKey",
+ signature(x = "RDD", withReplacement = "logical",
+ fractions = "vector", seed = "integer"),
+ function(x, withReplacement, fractions, seed) {
+
+ for (elem in fractions) {
+ if (elem < 0.0) {
+ stop(paste("Negative fraction value ", fractions[which(fractions == elem)]))
+ }
+ }
+
+ # The sampler: takes a partition and returns its sampled version.
+ samplingFunc <- function(split, part) {
+ set.seed(bitwXor(seed, split))
+ res <- vector("list", length(part))
+ len <- 0
+
+ # mixing because the initial seeds are close to each other
+ runif(10)
+
+ for (elem in part) {
+ if (elem[[1]] %in% names(fractions)) {
+ frac <- as.numeric(fractions[which(elem[[1]] == names(fractions))])
+ if (withReplacement) {
+ count <- rpois(1, frac)
+ if (count > 0) {
+ res[(len + 1):(len + count)] <- rep(list(elem), count)
+ len <- len + count
+ }
+ } else {
+ if (runif(1) < frac) {
+ len <- len + 1
+ res[[len]] <- elem
+ }
+ }
+ } else {
+ stop("KeyError: \"", elem[[1]], "\"")
+ }
+ }
+
+ # TODO(zongheng): look into the performance of the current
+ # implementation. Look into some iterator package? Note that
+ # Scala avoids many calls to creating an empty list and PySpark
+ # similarly achieves this using `yield'. (duplicated from sampleRDD)
+ if (len > 0) {
+ res[1:len]
+ } else {
+ list()
+ }
+ }
+
+ lapplyPartitionsWithIndex(x, samplingFunc)
+ })
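# Sampling sketch for the method above, assuming a SparkR context `sc`: with
# replacement, each matching element is emitted rpois(1, frac) times; without
# replacement, it is kept when runif(1) < frac; a key that is missing from
# `fractions` raises a KeyError.
sc <- sparkR.init()
rdd <- parallelize(sc, 1:2000)
pairs <- lapply(rdd, function(x) { if (x %% 2 == 0) list("a", x) else list("b", x) })
sample <- sampleByKey(pairs, FALSE, list(a = 0.2, b = 0.1), 1618L)
length(lookup(sample, "a"))  # roughly 0.2 * 1000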
diff --git a/R/pkg/R/schema.R b/R/pkg/R/schema.R
new file mode 100644
index 0000000000000..e442119086b17
--- /dev/null
+++ b/R/pkg/R/schema.R
@@ -0,0 +1,162 @@
+#
+# Licensed to the Apache Software Foundation (ASF) under one or more
+# contributor license agreements. See the NOTICE file distributed with
+# this work for additional information regarding copyright ownership.
+# The ASF licenses this file to You under the Apache License, Version 2.0
+# (the "License"); you may not use this file except in compliance with
+# the License. You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+#
+
+# A set of S3 classes and methods that support the SparkSQL `StructType` and
+# `StructField` data types. These are used to create and interact with DataFrame schemas.
+
+#' structType
+#'
+#' Create a structType object that contains the metadata for a DataFrame. Intended for
+#' use with createDataFrame and toDF.
+#'
+#' @param x a structField object (created with the structField() function)
+#' @param ... additional structField objects
+#' @return a structType object
+#' @export
+#' @examples
+#'\dontrun{
+#' sc <- sparkR.init()
+#' sqlCtx <- sparkRSQL.init(sc)
+#' rdd <- lapply(parallelize(sc, 1:10), function(x) { list(x, as.character(x)) })
+#' schema <- structType(structField("a", "integer"), structField("b", "string"))
+#' df <- createDataFrame(sqlCtx, rdd, schema)
+#' }
+structType <- function(x, ...) {
+ UseMethod("structType", x)
+}
+
+structType.jobj <- function(x) {
+ obj <- structure(list(), class = "structType")
+ obj$jobj <- x
+ obj$fields <- function() { lapply(callJMethod(obj$jobj, "fields"), structField) }
+ obj
+}
+
+structType.structField <- function(x, ...) {
+ fields <- list(x, ...)
+ if (!all(sapply(fields, inherits, "structField"))) {
+ stop("All arguments must be structField objects.")
+ }
+ sfObjList <- lapply(fields, function(field) {
+ field$jobj
+ })
+ stObj <- callJStatic("org.apache.spark.sql.api.r.SQLUtils",
+ "createStructType",
+ listToSeq(sfObjList))
+ structType(stObj)
+}
+
+#' Print a Spark StructType.
+#'
+#' This function prints the contents of a StructType returned from the
+#' SparkR JVM backend.
+#'
+#' @param x A StructType object
+#' @param ... further arguments passed to or from other methods
+print.structType <- function(x, ...) {
+ cat("StructType\n",
+ sapply(x$fields(), function(field) { paste("|-", "name = \"", field$name(),
+ "\", type = \"", field$dataType.toString(),
+ "\", nullable = ", field$nullable(), "\n",
+ sep = "") })
+ , sep = "")
+}
+
+#' structField
+#'
+#' Create a structField object that contains the metadata for a single field in a schema.
+#'
+#' @param x The name of the field
+#' @param type The data type of the field
+#' @param nullable A logical value indicating whether or not the field is nullable
+#' @return a structField object
+#' @export
+#' @examples
+#'\dontrun{
+#' sc <- sparkR.init()
+#' sqlCtx <- sparkRSQL.init(sc)
+#' rdd <- lapply(parallelize(sc, 1:10), function(x) { list(x, as.character(x)) })
+#' field1 <- structField("a", "integer", TRUE)
+#' field2 <- structField("b", "string", TRUE)
+#' schema <- structType(field1, field2)
+#' df <- createDataFrame(sqlCtx, rdd, schema)
+#' }
+structField <- function(x, ...) {
+ UseMethod("structField", x)
+}
+
+structField.jobj <- function(x) {
+ obj <- structure(list(), class = "structField")
+ obj$jobj <- x
+ obj$name <- function() { callJMethod(x, "name") }
+ obj$dataType <- function() { callJMethod(x, "dataType") }
+ obj$dataType.toString <- function() { callJMethod(obj$dataType(), "toString") }
+ obj$dataType.simpleString <- function() { callJMethod(obj$dataType(), "simpleString") }
+ obj$nullable <- function() { callJMethod(x, "nullable") }
+ obj
+}
+
+structField.character <- function(x, type, nullable = TRUE) {
+ if (class(x) != "character") {
+ stop("Field name must be a string.")
+ }
+ if (class(type) != "character") {
+ stop("Field type must be a string.")
+ }
+ if (class(nullable) != "logical") {
+ stop("nullable must be either TRUE or FALSE")
+ }
+ options <- c("byte",
+ "integer",
+ "double",
+ "numeric",
+ "character",
+ "string",
+ "binary",
+ "raw",
+ "logical",
+ "boolean",
+ "timestamp",
+ "date")
+ dataType <- if (type %in% options) {
+ type
+ } else {
+ stop(paste("Unsupported type for Dataframe:", type))
+ }
+ sfObj <- callJStatic("org.apache.spark.sql.api.r.SQLUtils",
+ "createStructField",
+ x,
+ dataType,
+ nullable)
+ structField(sfObj)
+}
+
+#' Print a Spark StructField.
+#'
+#' This function prints the contents of a StructField returned from the
+#' SparkR JVM backend.
+#'
+#' @param x A StructField object
+#' @param ... further arguments passed to or from other methods
+print.structField <- function(x, ...) {
+ cat("StructField(name = \"", x$name(),
+ "\", type = \"", x$dataType.toString(),
+ "\", nullable = ", x$nullable(),
+ ")",
+ sep = "")
+}
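# Usage sketch for the S3 helpers above, assuming a running SparkR backend
# (structField() and structType() call into the JVM):
sc <- sparkR.init()
field1 <- structField("a", "integer", TRUE)
field2 <- structField("b", "string", TRUE)
schema <- structType(field1, field2)
print(schema)
# StructType
# |-name = "a", type = "IntegerType", nullable = TRUE
# |-name = "b", type = "StringType", nullable = TRUE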
diff --git a/R/pkg/R/serialize.R b/R/pkg/R/serialize.R
index 8a9c0c652ce24..c53d0a961016f 100644
--- a/R/pkg/R/serialize.R
+++ b/R/pkg/R/serialize.R
@@ -69,8 +69,9 @@ writeJobj <- function(con, value) {
}
writeString <- function(con, value) {
- writeInt(con, as.integer(nchar(value) + 1))
- writeBin(value, con, endian = "big")
+ utfVal <- enc2utf8(value)
+ writeInt(con, as.integer(nchar(utfVal, type = "bytes") + 1))
+ writeBin(utfVal, con, endian = "big")
}
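# Byte-length sketch for the UTF-8 handling above: the length prefix is the
# byte count of the encoded string plus one, accounting for the terminating NUL
# that writeBin() appends to a character value; for multi-byte text this
# differs from the character count.
utfVal <- enc2utf8("héllo")
nchar(utfVal, type = "bytes") + 1  # 7, the value written by writeInt()
nchar(utfVal)                      # 5 characters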
writeInt <- function(con, value) {
@@ -189,7 +190,3 @@ writeArgs <- function(con, args) {
}
}
}
-
-writeStrings <- function(con, stringList) {
- writeLines(unlist(stringList), con)
-}
diff --git a/R/pkg/R/utils.R b/R/pkg/R/utils.R
index c337fb0751e72..23305d3c67074 100644
--- a/R/pkg/R/utils.R
+++ b/R/pkg/R/utils.R
@@ -465,3 +465,83 @@ cleanClosure <- function(func, checkedFuncs = new.env()) {
}
func
}
+
+# Append partition lengths to each partition in two input RDDs if needed.
+# param
+# x An RDD.
+# other An RDD.
+# return value
+# A list of two result RDDs.
+appendPartitionLengths <- function(x, other) {
+ if (getSerializedMode(x) != getSerializedMode(other) ||
+ getSerializedMode(x) == "byte") {
+ # Append the number of elements in each partition to that partition so that we can later
+ # know the boundary of elements from x and other.
+ #
+ # Note that this appending also serves the purpose of reserialization, because even if
+ # any RDD is serialized, we need to reserialize it to make sure its partitions are encoded
+ # as a single byte array. For example, partitions of an RDD generated from partitionBy()
+ # may be encoded as multiple byte arrays.
+ appendLength <- function(part) {
+ len <- length(part)
+ part[[len + 1]] <- len + 1
+ part
+ }
+ x <- lapplyPartition(x, appendLength)
+ other <- lapplyPartition(other, appendLength)
+ }
+ list(x, other)
+}
+
+# Perform zip or cartesian between elements from two RDDs in each partition
+# param
+# rdd An RDD.
+# zip A boolean flag indicating this call is for zip operation or not.
+# return value
+# A result RDD.
+mergePartitions <- function(rdd, zip) {
+ serializerMode <- getSerializedMode(rdd)
+ partitionFunc <- function(split, part) {
+ len <- length(part)
+ if (len > 0) {
+ if (serializerMode == "byte") {
+ lengthOfValues <- part[[len]]
+ lengthOfKeys <- part[[len - lengthOfValues]]
+ stopifnot(len == lengthOfKeys + lengthOfValues)
+
+ # For zip operation, check if corresponding partitions of both RDDs have the same number of elements.
+ if (zip && lengthOfKeys != lengthOfValues) {
+ stop("Can only zip RDDs with same number of elements in each pair of corresponding partitions.")
+ }
+
+ if (lengthOfKeys > 1) {
+ keys <- part[1 : (lengthOfKeys - 1)]
+ } else {
+ keys <- list()
+ }
+ if (lengthOfValues > 1) {
+ values <- part[(lengthOfKeys + 1) : (len - 1)]
+ } else {
+ values <- list()
+ }
+
+ if (!zip) {
+ return(mergeCompactLists(keys, values))
+ }
+ } else {
+ keys <- part[c(TRUE, FALSE)]
+ values <- part[c(FALSE, TRUE)]
+ }
+ mapply(
+ function(k, v) { list(k, v) },
+ keys,
+ values,
+ SIMPLIFY = FALSE,
+ USE.NAMES = FALSE)
+ } else {
+ part
+ }
+ }
+
+ PipelinedRDD(rdd, partitionFunc)
+}
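# Worked sketch of the length-appending protocol shared by the two helpers
# above (plain R, no Spark objects): a byte-serialized partition arrives as
# keys, key-count, values, value-count, and mergePartitions() uses the two
# trailing counts to split the flat list back into keys and values.
part <- list("k1", "k2", 3L, "v1", "v2", 3L)
len <- length(part)
lengthOfValues <- part[[len]]                 # 3
lengthOfKeys <- part[[len - lengthOfValues]]  # 3
keys <- part[1:(lengthOfKeys - 1)]            # list("k1", "k2")
values <- part[(lengthOfKeys + 1):(len - 1)]  # list("v1", "v2")
mapply(function(k, v) { list(k, v) }, keys, values, SIMPLIFY = FALSE, USE.NAMES = FALSE)
# list(list("k1", "v1"), list("k2", "v2"))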
diff --git a/R/pkg/inst/tests/test_rdd.R b/R/pkg/inst/tests/test_rdd.R
index b76e4db03e715..3ba7d1716302a 100644
--- a/R/pkg/inst/tests/test_rdd.R
+++ b/R/pkg/inst/tests/test_rdd.R
@@ -35,7 +35,7 @@ test_that("get number of partitions in RDD", {
test_that("first on RDD", {
expect_true(first(rdd) == 1)
newrdd <- lapply(rdd, function(x) x + 1)
- expect_true(first(newrdd) == 2)
+ expect_true(first(newrdd) == 2)
})
test_that("count and length on RDD", {
@@ -48,7 +48,7 @@ test_that("count by values and keys", {
actual <- countByValue(mods)
expected <- list(list(0, 3L), list(1, 4L), list(2, 3L))
expect_equal(sortKeyValueList(actual), sortKeyValueList(expected))
-
+
actual <- countByKey(intRdd)
expected <- list(list(2L, 2L), list(1L, 2L))
expect_equal(sortKeyValueList(actual), sortKeyValueList(expected))
@@ -82,11 +82,11 @@ test_that("filterRDD on RDD", {
filtered.rdd <- filterRDD(rdd, function(x) { x %% 2 == 0 })
actual <- collect(filtered.rdd)
expect_equal(actual, list(2, 4, 6, 8, 10))
-
+
filtered.rdd <- Filter(function(x) { x[[2]] < 0 }, intRdd)
actual <- collect(filtered.rdd)
expect_equal(actual, list(list(1L, -1)))
-
+
# Filter out all elements.
filtered.rdd <- filterRDD(rdd, function(x) { x > 10 })
actual <- collect(filtered.rdd)
@@ -96,7 +96,7 @@ test_that("filterRDD on RDD", {
test_that("lookup on RDD", {
vals <- lookup(intRdd, 1L)
expect_equal(vals, list(-1, 200))
-
+
vals <- lookup(intRdd, 3L)
expect_equal(vals, list())
})
@@ -110,7 +110,7 @@ test_that("several transformations on RDD (a benchmark on PipelinedRDD)", {
})
rdd2 <- lapply(rdd2, function(x) x + x)
actual <- collect(rdd2)
- expected <- list(24, 24, 24, 24, 24,
+ expected <- list(24, 24, 24, 24, 24,
168, 170, 172, 174, 176)
expect_equal(actual, expected)
})
@@ -248,10 +248,10 @@ test_that("flatMapValues() on pairwise RDDs", {
l <- parallelize(sc, list(list(1, c(1,2)), list(2, c(3,4))))
actual <- collect(flatMapValues(l, function(x) { x }))
expect_equal(actual, list(list(1,1), list(1,2), list(2,3), list(2,4)))
-
+
# Generate x to x+1 for every value
actual <- collect(flatMapValues(intRdd, function(x) { x:(x + 1) }))
- expect_equal(actual,
+ expect_equal(actual,
list(list(1L, -1), list(1L, 0), list(2L, 100), list(2L, 101),
list(2L, 1), list(2L, 2), list(1L, 200), list(1L, 201)))
})
@@ -348,7 +348,7 @@ test_that("top() on RDDs", {
rdd <- parallelize(sc, l)
actual <- top(rdd, 6L)
expect_equal(actual, as.list(sort(unlist(l), decreasing = TRUE))[1:6])
-
+
l <- list("e", "d", "c", "d", "a")
rdd <- parallelize(sc, l)
actual <- top(rdd, 3L)
@@ -358,7 +358,7 @@ test_that("top() on RDDs", {
test_that("fold() on RDDs", {
actual <- fold(rdd, 0, "+")
expect_equal(actual, Reduce("+", nums, 0))
-
+
rdd <- parallelize(sc, list())
actual <- fold(rdd, 0, "+")
expect_equal(actual, 0)
@@ -371,7 +371,7 @@ test_that("aggregateRDD() on RDDs", {
combOp <- function(x, y) { list(x[[1]] + y[[1]], x[[2]] + y[[2]]) }
actual <- aggregateRDD(rdd, zeroValue, seqOp, combOp)
expect_equal(actual, list(10, 4))
-
+
rdd <- parallelize(sc, list())
actual <- aggregateRDD(rdd, zeroValue, seqOp, combOp)
expect_equal(actual, list(0, 0))
@@ -380,13 +380,13 @@ test_that("aggregateRDD() on RDDs", {
test_that("zipWithUniqueId() on RDDs", {
rdd <- parallelize(sc, list("a", "b", "c", "d", "e"), 3L)
actual <- collect(zipWithUniqueId(rdd))
- expected <- list(list("a", 0), list("b", 3), list("c", 1),
+ expected <- list(list("a", 0), list("b", 3), list("c", 1),
list("d", 4), list("e", 2))
expect_equal(actual, expected)
-
+
rdd <- parallelize(sc, list("a", "b", "c", "d", "e"), 1L)
actual <- collect(zipWithUniqueId(rdd))
- expected <- list(list("a", 0), list("b", 1), list("c", 2),
+ expected <- list(list("a", 0), list("b", 1), list("c", 2),
list("d", 3), list("e", 4))
expect_equal(actual, expected)
})
@@ -394,13 +394,13 @@ test_that("zipWithUniqueId() on RDDs", {
test_that("zipWithIndex() on RDDs", {
rdd <- parallelize(sc, list("a", "b", "c", "d", "e"), 3L)
actual <- collect(zipWithIndex(rdd))
- expected <- list(list("a", 0), list("b", 1), list("c", 2),
+ expected <- list(list("a", 0), list("b", 1), list("c", 2),
list("d", 3), list("e", 4))
expect_equal(actual, expected)
-
+
rdd <- parallelize(sc, list("a", "b", "c", "d", "e"), 1L)
actual <- collect(zipWithIndex(rdd))
- expected <- list(list("a", 0), list("b", 1), list("c", 2),
+ expected <- list(list("a", 0), list("b", 1), list("c", 2),
list("d", 3), list("e", 4))
expect_equal(actual, expected)
})
@@ -427,12 +427,12 @@ test_that("pipeRDD() on RDDs", {
actual <- collect(pipeRDD(rdd, "more"))
expected <- as.list(as.character(1:10))
expect_equal(actual, expected)
-
+
trailed.rdd <- parallelize(sc, c("1", "", "2\n", "3\n\r\n"))
actual <- collect(pipeRDD(trailed.rdd, "sort"))
expected <- list("", "1", "2", "3")
expect_equal(actual, expected)
-
+
rev.nums <- 9:0
rev.rdd <- parallelize(sc, rev.nums, 2L)
actual <- collect(pipeRDD(rev.rdd, "sort"))
@@ -446,11 +446,11 @@ test_that("zipRDD() on RDDs", {
actual <- collect(zipRDD(rdd1, rdd2))
expect_equal(actual,
list(list(0, 1000), list(1, 1001), list(2, 1002), list(3, 1003), list(4, 1004)))
-
+
mockFile = c("Spark is pretty.", "Spark is awesome.")
fileName <- tempfile(pattern="spark-test", fileext=".tmp")
writeLines(mockFile, fileName)
-
+
rdd <- textFile(sc, fileName, 1)
actual <- collect(zipRDD(rdd, rdd))
expected <- lapply(mockFile, function(x) { list(x ,x) })
@@ -465,10 +465,125 @@ test_that("zipRDD() on RDDs", {
actual <- collect(zipRDD(rdd, rdd1))
expected <- lapply(mockFile, function(x) { list(x, x) })
expect_equal(actual, expected)
-
+
+ unlink(fileName)
+})
+
+test_that("cartesian() on RDDs", {
+ rdd <- parallelize(sc, 1:3)
+ actual <- collect(cartesian(rdd, rdd))
+ expect_equal(sortKeyValueList(actual),
+ list(
+ list(1, 1), list(1, 2), list(1, 3),
+ list(2, 1), list(2, 2), list(2, 3),
+ list(3, 1), list(3, 2), list(3, 3)))
+
+ # test case where one RDD is empty
+ emptyRdd <- parallelize(sc, list())
+ actual <- collect(cartesian(rdd, emptyRdd))
+ expect_equal(actual, list())
+
+ mockFile = c("Spark is pretty.", "Spark is awesome.")
+ fileName <- tempfile(pattern="spark-test", fileext=".tmp")
+ writeLines(mockFile, fileName)
+
+ rdd <- textFile(sc, fileName)
+ actual <- collect(cartesian(rdd, rdd))
+ expected <- list(
+ list("Spark is awesome.", "Spark is pretty."),
+ list("Spark is awesome.", "Spark is awesome."),
+ list("Spark is pretty.", "Spark is pretty."),
+ list("Spark is pretty.", "Spark is awesome."))
+ expect_equal(sortKeyValueList(actual), expected)
+
+ rdd1 <- parallelize(sc, 0:1)
+ actual <- collect(cartesian(rdd1, rdd))
+ expect_equal(sortKeyValueList(actual),
+ list(
+ list(0, "Spark is pretty."),
+ list(0, "Spark is awesome."),
+ list(1, "Spark is pretty."),
+ list(1, "Spark is awesome.")))
+
+ rdd1 <- map(rdd, function(x) { x })
+ actual <- collect(cartesian(rdd, rdd1))
+ expect_equal(sortKeyValueList(actual), expected)
+
unlink(fileName)
})
+test_that("subtract() on RDDs", {
+ l <- list(1, 1, 2, 2, 3, 4)
+ rdd1 <- parallelize(sc, l)
+
+ # subtract by itself
+ actual <- collect(subtract(rdd1, rdd1))
+ expect_equal(actual, list())
+
+ # subtract by an empty RDD
+ rdd2 <- parallelize(sc, list())
+ actual <- collect(subtract(rdd1, rdd2))
+ expect_equal(as.list(sort(as.vector(actual, mode="integer"))),
+ l)
+
+ rdd2 <- parallelize(sc, list(2, 4))
+ actual <- collect(subtract(rdd1, rdd2))
+ expect_equal(as.list(sort(as.vector(actual, mode="integer"))),
+ list(1, 1, 3))
+
+ l <- list("a", "a", "b", "b", "c", "d")
+ rdd1 <- parallelize(sc, l)
+ rdd2 <- parallelize(sc, list("b", "d"))
+ actual <- collect(subtract(rdd1, rdd2))
+ expect_equal(as.list(sort(as.vector(actual, mode="character"))),
+ list("a", "a", "c"))
+})
+
+test_that("subtractByKey() on pairwise RDDs", {
+ l <- list(list("a", 1), list("b", 4),
+ list("b", 5), list("a", 2))
+ rdd1 <- parallelize(sc, l)
+
+ # subtractByKey by itself
+ actual <- collect(subtractByKey(rdd1, rdd1))
+ expect_equal(actual, list())
+
+ # subtractByKey by an empty RDD
+ rdd2 <- parallelize(sc, list())
+ actual <- collect(subtractByKey(rdd1, rdd2))
+ expect_equal(sortKeyValueList(actual),
+ sortKeyValueList(l))
+
+ rdd2 <- parallelize(sc, list(list("a", 3), list("c", 1)))
+ actual <- collect(subtractByKey(rdd1, rdd2))
+ expect_equal(actual,
+ list(list("b", 4), list("b", 5)))
+
+ l <- list(list(1, 1), list(2, 4),
+ list(2, 5), list(1, 2))
+ rdd1 <- parallelize(sc, l)
+ rdd2 <- parallelize(sc, list(list(1, 3), list(3, 1)))
+ actual <- collect(subtractByKey(rdd1, rdd2))
+ expect_equal(actual,
+ list(list(2, 4), list(2, 5)))
+})
+
+test_that("intersection() on RDDs", {
+ # intersection with self
+ actual <- collect(intersection(rdd, rdd))
+ expect_equal(sort(as.integer(actual)), nums)
+
+ # intersection with an empty RDD
+ emptyRdd <- parallelize(sc, list())
+ actual <- collect(intersection(rdd, emptyRdd))
+ expect_equal(actual, list())
+
+ rdd1 <- parallelize(sc, list(1, 10, 2, 3, 4, 5))
+ rdd2 <- parallelize(sc, list(1, 6, 2, 3, 7, 8))
+ actual <- collect(intersection(rdd1, rdd2))
+ expect_equal(sort(as.integer(actual)), 1:3)
+})
+
test_that("join() on pairwise RDDs", {
rdd1 <- parallelize(sc, list(list(1,1), list(2,4)))
rdd2 <- parallelize(sc, list(list(1,2), list(1,3)))
@@ -596,9 +711,9 @@ test_that("sortByKey() on pairwise RDDs", {
sortedRdd3 <- sortByKey(rdd3)
actual <- collect(sortedRdd3)
expect_equal(actual, list(list("1", 3), list("2", 5), list("a", 1), list("b", 2), list("d", 4)))
-
+
# test on the boundary cases
-
+
# boundary case 1: the RDD to be sorted has only 1 partition
rdd4 <- parallelize(sc, l, 1L)
sortedRdd4 <- sortByKey(rdd4)
@@ -623,7 +738,7 @@ test_that("sortByKey() on pairwise RDDs", {
rdd7 <- parallelize(sc, l3, 2L)
sortedRdd7 <- sortByKey(rdd7)
actual <- collect(sortedRdd7)
- expect_equal(actual, l3)
+ expect_equal(actual, l3)
})
test_that("collectAsMap() on a pairwise RDD", {
@@ -634,12 +749,36 @@ test_that("collectAsMap() on a pairwise RDD", {
rdd <- parallelize(sc, list(list("a", 1), list("b", 2)))
vals <- collectAsMap(rdd)
expect_equal(vals, list(a = 1, b = 2))
-
+
rdd <- parallelize(sc, list(list(1.1, 2.2), list(1.2, 2.4)))
vals <- collectAsMap(rdd)
expect_equal(vals, list(`1.1` = 2.2, `1.2` = 2.4))
-
+
rdd <- parallelize(sc, list(list(1, "a"), list(2, "b")))
vals <- collectAsMap(rdd)
expect_equal(vals, list(`1` = "a", `2` = "b"))
})
+
+test_that("sampleByKey() on pairwise RDDs", {
+ rdd <- parallelize(sc, 1:2000)
+ pairsRDD <- lapply(rdd, function(x) { if (x %% 2 == 0) list("a", x) else list("b", x) })
+ fractions <- list(a = 0.2, b = 0.1)
+ sample <- sampleByKey(pairsRDD, FALSE, fractions, 1618L)
+ expect_equal(100 < length(lookup(sample, "a")) && 300 > length(lookup(sample, "a")), TRUE)
+ expect_equal(50 < length(lookup(sample, "b")) && 150 > length(lookup(sample, "b")), TRUE)
+ expect_equal(lookup(sample, "a")[which.min(lookup(sample, "a"))] >= 0, TRUE)
+ expect_equal(lookup(sample, "a")[which.max(lookup(sample, "a"))] <= 2000, TRUE)
+ expect_equal(lookup(sample, "b")[which.min(lookup(sample, "b"))] >= 0, TRUE)
+ expect_equal(lookup(sample, "b")[which.max(lookup(sample, "b"))] <= 2000, TRUE)
+
+ rdd <- parallelize(sc, 1:2000)
+ pairsRDD <- lapply(rdd, function(x) { if (x %% 2 == 0) list(2, x) else list(3, x) })
+ fractions <- list(`2` = 0.2, `3` = 0.1)
+ sample <- sampleByKey(pairsRDD, TRUE, fractions, 1618L)
+ expect_equal(100 < length(lookup(sample, 2)) && 300 > length(lookup(sample, 2)), TRUE)
+ expect_equal(50 < length(lookup(sample, 3)) && 150 > length(lookup(sample, 3)), TRUE)
+ expect_equal(lookup(sample, 2)[which.min(lookup(sample, 2))] >= 0, TRUE)
+ expect_equal(lookup(sample, 2)[which.max(lookup(sample, 2))] <= 2000, TRUE)
+ expect_equal(lookup(sample, 3)[which.min(lookup(sample, 3))] >= 0, TRUE)
+ expect_equal(lookup(sample, 3)[which.max(lookup(sample, 3))] <= 2000, TRUE)
+})
diff --git a/R/pkg/inst/tests/test_shuffle.R b/R/pkg/inst/tests/test_shuffle.R
index d1da8232aea81..d7dedda553c56 100644
--- a/R/pkg/inst/tests/test_shuffle.R
+++ b/R/pkg/inst/tests/test_shuffle.R
@@ -87,6 +87,18 @@ test_that("combineByKey for doubles", {
expect_equal(sortKeyValueList(actual), sortKeyValueList(expected))
})
+test_that("combineByKey for characters", {
+ stringKeyRDD <- parallelize(sc,
+ list(list("max", 1L), list("min", 2L),
+ list("other", 3L), list("max", 4L)), 2L)
+ reduced <- combineByKey(stringKeyRDD,
+ function(x) { x }, "+", "+", 2L)
+ actual <- collect(reduced)
+
+ expected <- list(list("max", 5L), list("min", 2L), list("other", 3L))
+ expect_equal(sortKeyValueList(actual), sortKeyValueList(expected))
+})
+
test_that("aggregateByKey", {
# test aggregateByKey for int keys
rdd <- parallelize(sc, list(list(1, 1), list(1, 2), list(2, 3), list(2, 4)))
diff --git a/R/pkg/inst/tests/test_sparkSQL.R b/R/pkg/inst/tests/test_sparkSQL.R
index cf5cf6d1692af..25831ae2d9e18 100644
--- a/R/pkg/inst/tests/test_sparkSQL.R
+++ b/R/pkg/inst/tests/test_sparkSQL.R
@@ -44,9 +44,8 @@ test_that("infer types", {
expect_equal(infer_type(list(1L, 2L)),
list(type = 'array', elementType = "integer", containsNull = TRUE))
expect_equal(infer_type(list(a = 1L, b = "2")),
- list(type = "struct",
- fields = list(list(name = "a", type = "integer", nullable = TRUE),
- list(name = "b", type = "string", nullable = TRUE))))
+ structType(structField(x = "a", type = "integer", nullable = TRUE),
+ structField(x = "b", type = "string", nullable = TRUE)))
e <- new.env()
assign("a", 1L, envir = e)
expect_equal(infer_type(e),
@@ -54,6 +53,18 @@ test_that("infer types", {
valueContainsNull = TRUE))
})
+test_that("structType and structField", {
+ testField <- structField("a", "string")
+ expect_true(inherits(testField, "structField"))
+ expect_true(testField$name() == "a")
+ expect_true(testField$nullable())
+
+ testSchema <- structType(testField, structField("b", "integer"))
+ expect_true(inherits(testSchema, "structType"))
+ expect_true(inherits(testSchema$fields()[[2]], "structField"))
+ expect_true(testSchema$fields()[[1]]$dataType.toString() == "StringType")
+})
+
test_that("create DataFrame from RDD", {
rdd <- lapply(parallelize(sc, 1:10), function(x) { list(x, as.character(x)) })
df <- createDataFrame(sqlCtx, rdd, list("a", "b"))
@@ -66,9 +77,8 @@ test_that("create DataFrame from RDD", {
expect_true(inherits(df, "DataFrame"))
expect_equal(columns(df), c("_1", "_2"))
- fields <- list(list(name = "a", type = "integer", nullable = TRUE),
- list(name = "b", type = "string", nullable = TRUE))
- schema <- list(type = "struct", fields = fields)
+ schema <- structType(structField(x = "a", type = "integer", nullable = TRUE),
+ structField(x = "b", type = "string", nullable = TRUE))
df <- createDataFrame(sqlCtx, rdd, schema)
expect_true(inherits(df, "DataFrame"))
expect_equal(columns(df), c("a", "b"))
@@ -94,9 +104,8 @@ test_that("toDF", {
expect_true(inherits(df, "DataFrame"))
expect_equal(columns(df), c("_1", "_2"))
- fields <- list(list(name = "a", type = "integer", nullable = TRUE),
- list(name = "b", type = "string", nullable = TRUE))
- schema <- list(type = "struct", fields = fields)
+ schema <- structType(structField(x = "a", type = "integer", nullable = TRUE),
+ structField(x = "b", type = "string", nullable = TRUE))
df <- toDF(rdd, schema)
expect_true(inherits(df, "DataFrame"))
expect_equal(columns(df), c("a", "b"))
@@ -635,7 +644,7 @@ test_that("isLocal()", {
expect_false(isLocal(df))
})
-test_that("unionAll(), subtract(), and intersect() on a DataFrame", {
+test_that("unionAll(), except(), and intersect() on a DataFrame", {
df <- jsonFile(sqlCtx, jsonPath)
lines <- c("{\"name\":\"Bob\", \"age\":24}",
@@ -650,10 +659,10 @@ test_that("unionAll(), subtract(), and intersect() on a DataFrame", {
expect_true(count(unioned) == 6)
expect_true(first(unioned)$name == "Michael")
- subtracted <- sortDF(subtract(df, df2), desc(df$age))
+ excepted <- sortDF(except(df, df2), desc(df$age))
expect_true(inherits(unioned, "DataFrame"))
- expect_true(count(subtracted) == 2)
- expect_true(first(subtracted)$name == "Justin")
+ expect_true(count(excepted) == 2)
+ expect_true(first(excepted)$name == "Justin")
intersected <- sortDF(intersect(df, df2), df$age)
expect_true(inherits(unioned, "DataFrame"))
diff --git a/R/pkg/inst/worker/worker.R b/R/pkg/inst/worker/worker.R
index c6542928e8ddd..014bf7bd7b3fe 100644
--- a/R/pkg/inst/worker/worker.R
+++ b/R/pkg/inst/worker/worker.R
@@ -17,6 +17,23 @@
# Worker class
+# Get current system time
+currentTimeSecs <- function() {
+ as.numeric(Sys.time())
+}
+
+# Get elapsed time
+elapsedSecs <- function() {
+ proc.time()[3]
+}
+
+# Constants
+specialLengths <- list(END_OF_STREAM = 0L, TIMING_DATA = -1L)
+
+# Timing R process boot
+bootTime <- currentTimeSecs()
+bootElap <- elapsedSecs()
+
rLibDir <- Sys.getenv("SPARKR_RLIBDIR")
# Set libPaths to include SparkR package as loadNamespace needs this
# TODO: Figure out if we can avoid this by not loading any objects that require
@@ -37,7 +54,7 @@ serializer <- SparkR:::readString(inputCon)
# Include packages as required
packageNames <- unserialize(SparkR:::readRaw(inputCon))
for (pkg in packageNames) {
- suppressPackageStartupMessages(require(as.character(pkg), character.only=TRUE))
+ suppressPackageStartupMessages(library(as.character(pkg), character.only=TRUE))
}
# read function dependencies
@@ -46,6 +63,9 @@ computeFunc <- unserialize(SparkR:::readRawLen(inputCon, funcLen))
env <- environment(computeFunc)
parent.env(env) <- .GlobalEnv # Attach under global environment.
+# Timing init envs for computing
+initElap <- elapsedSecs()
+
# Read and set broadcast variables
numBroadcastVars <- SparkR:::readInt(inputCon)
if (numBroadcastVars > 0) {
@@ -56,6 +76,9 @@ if (numBroadcastVars > 0) {
}
}
+# Timing broadcast
+broadcastElap <- elapsedSecs()
+
# If -1: read as normal RDD; if >= 0, treat as pairwise RDD and treat the int
# as number of partitions to create.
numPartitions <- SparkR:::readInt(inputCon)
@@ -73,14 +96,23 @@ if (isEmpty != 0) {
} else if (deserializer == "row") {
data <- SparkR:::readDeserializeRows(inputCon)
}
+ # Timing reading input data for execution
+ inputElap <- elapsedSecs()
+
output <- computeFunc(partition, data)
+ # Timing computing
+ computeElap <- elapsedSecs()
+
if (serializer == "byte") {
SparkR:::writeRawSerialize(outputCon, output)
} else if (serializer == "row") {
SparkR:::writeRowSerialize(outputCon, output)
} else {
- SparkR:::writeStrings(outputCon, output)
+ # write lines one-by-one with flag
+ lapply(output, function(line) SparkR:::writeString(outputCon, line))
}
+ # Timing output
+ outputElap <- elapsedSecs()
} else {
if (deserializer == "byte") {
# Now read as many characters as described in funcLen
@@ -90,6 +122,8 @@ if (isEmpty != 0) {
} else if (deserializer == "row") {
data <- SparkR:::readDeserializeRows(inputCon)
}
+ # Timing reading input data for execution
+ inputElap <- elapsedSecs()
res <- new.env()
@@ -107,6 +141,8 @@ if (isEmpty != 0) {
res[[bucket]] <- acc
}
invisible(lapply(data, hashTupleToEnvir))
+ # Timing computing
+ computeElap <- elapsedSecs()
# Step 2: write out all of the environment as key-value pairs.
for (name in ls(res)) {
@@ -116,13 +152,26 @@ if (isEmpty != 0) {
length(res[[name]]$data) <- res[[name]]$counter
SparkR:::writeRawSerialize(outputCon, res[[name]]$data)
}
+ # Timing output
+ outputElap <- elapsedSecs()
}
+} else {
+ inputElap <- broadcastElap
+ computeElap <- broadcastElap
+ outputElap <- broadcastElap
}
+# Report timing
+SparkR:::writeInt(outputCon, specialLengths$TIMING_DATA)
+SparkR:::writeDouble(outputCon, bootTime)
+SparkR:::writeDouble(outputCon, initElap - bootElap) # init
+SparkR:::writeDouble(outputCon, broadcastElap - initElap) # broadcast
+SparkR:::writeDouble(outputCon, inputElap - broadcastElap) # input
+SparkR:::writeDouble(outputCon, computeElap - inputElap) # compute
+SparkR:::writeDouble(outputCon, outputElap - computeElap) # output
+
# End of output
-if (serializer %in% c("byte", "row")) {
- SparkR:::writeInt(outputCon, 0L)
-}
+SparkR:::writeInt(outputCon, specialLengths$END_OF_STREAM)
close(outputCon)
close(inputCon)
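The timing block above reports the worker's boot time as epoch seconds plus five elapsed-time deltas. A small illustrative sketch of how the driver side (see the RRDD.scala change later in this diff) turns the reported boot time into a latency; the names and values below are made up for illustration only:

```scala
// Illustrative only: both sides use epoch seconds, so the driver can subtract
// its own launch timestamp from the worker's reported bootTime.
val jvmLaunchTime = System.currentTimeMillis / 1000.0 // taken before the R worker starts
val workerBootTime = jvmLaunchTime + 0.350            // hypothetical value read back from the worker
val bootLatency = workerBootTime - jvmLaunchTime      // surfaced in the log as "boot = 0.350 s"
```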
diff --git a/core/src/main/scala/org/apache/spark/SparkConf.scala b/core/src/main/scala/org/apache/spark/SparkConf.scala
index b0186e9a007b8..e3a649d755450 100644
--- a/core/src/main/scala/org/apache/spark/SparkConf.scala
+++ b/core/src/main/scala/org/apache/spark/SparkConf.scala
@@ -403,6 +403,9 @@ private[spark] object SparkConf extends Logging {
*/
private val deprecatedConfigs: Map[String, DeprecatedConfig] = {
val configs = Seq(
+ DeprecatedConfig("spark.cache.class", "0.8",
+ "The spark.cache.class property is no longer being used! Specify storage levels using " +
+ "the RDD.persist() method instead."),
DeprecatedConfig("spark.yarn.user.classpath.first", "1.3",
"Please use spark.{driver,executor}.userClassPathFirst instead."))
Map(configs.map { cfg => (cfg.key -> cfg) }:_*)
@@ -420,7 +423,15 @@ private[spark] object SparkConf extends Logging {
"spark.history.fs.update.interval" -> Seq(
AlternateConfig("spark.history.fs.update.interval.seconds", "1.4"),
AlternateConfig("spark.history.fs.updateInterval", "1.3"),
- AlternateConfig("spark.history.updateInterval", "1.3"))
+ AlternateConfig("spark.history.updateInterval", "1.3")),
+ "spark.history.fs.cleaner.interval" -> Seq(
+ AlternateConfig("spark.history.fs.cleaner.interval.seconds", "1.4")),
+ "spark.history.fs.cleaner.maxAge" -> Seq(
+ AlternateConfig("spark.history.fs.cleaner.maxAge.seconds", "1.4")),
+ "spark.yarn.am.waitTime" -> Seq(
+ AlternateConfig("spark.yarn.applicationMaster.waitTries", "1.3",
+ // Translate old value to a duration, with 10s wait time per try.
+ translation = s => s"${s.toLong * 10}s"))
)
/**
@@ -470,7 +481,7 @@ private[spark] object SparkConf extends Logging {
configsWithAlternatives.get(key).flatMap { alts =>
alts.collectFirst { case alt if conf.contains(alt.key) =>
val value = conf.get(alt.key)
- alt.translation.map(_(value)).getOrElse(value)
+ if (alt.translation != null) alt.translation(value) else value
}
}
}
@@ -514,6 +525,6 @@ private[spark] object SparkConf extends Logging {
private case class AlternateConfig(
key: String,
version: String,
- translation: Option[String => String] = None)
+ translation: String => String = null)
}
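For reference, the alternate-key translation behaves as in the sketch below, which mirrors the new SparkConfSuite assertion further down in this diff; the values are illustrative:

```scala
import org.apache.spark.SparkConf

// Setting only the deprecated key still works: reads of the new key pick it up,
// with the translation applied ("42" tries -> "420s" -> 420 seconds).
val conf = new SparkConf().set("spark.yarn.applicationMaster.waitTries", "42")
assert(conf.getTimeAsSeconds("spark.yarn.am.waitTime") == 420)

// spark.cache.class is now only flagged as deprecated; storage levels are chosen
// per RDD instead, e.g. rdd.persist(StorageLevel.MEMORY_ONLY).
```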
diff --git a/core/src/main/scala/org/apache/spark/SparkContext.scala b/core/src/main/scala/org/apache/spark/SparkContext.scala
index e106c5c4bef60..86269eac52db0 100644
--- a/core/src/main/scala/org/apache/spark/SparkContext.scala
+++ b/core/src/main/scala/org/apache/spark/SparkContext.scala
@@ -23,7 +23,7 @@ import java.io._
import java.lang.reflect.Constructor
import java.net.URI
import java.util.{Arrays, Properties, UUID}
-import java.util.concurrent.atomic.{AtomicBoolean, AtomicInteger}
+import java.util.concurrent.atomic.{AtomicReference, AtomicBoolean, AtomicInteger}
import java.util.UUID.randomUUID
import scala.collection.{Map, Set}
@@ -1887,11 +1887,12 @@ object SparkContext extends Logging {
private val SPARK_CONTEXT_CONSTRUCTOR_LOCK = new Object()
/**
- * The active, fully-constructed SparkContext. If no SparkContext is active, then this is `None`.
+ * The active, fully-constructed SparkContext. If no SparkContext is active, then this is `null`.
*
- * Access to this field is guarded by SPARK_CONTEXT_CONSTRUCTOR_LOCK
+ * Access to this field is guarded by SPARK_CONTEXT_CONSTRUCTOR_LOCK.
*/
- private var activeContext: Option[SparkContext] = None
+ private val activeContext: AtomicReference[SparkContext] =
+ new AtomicReference[SparkContext](null)
/**
* Points to a partially-constructed SparkContext if some thread is in the SparkContext
@@ -1926,7 +1927,8 @@ object SparkContext extends Logging {
logWarning(warnMsg)
}
- activeContext.foreach { ctx =>
+ if (activeContext.get() != null) {
+ val ctx = activeContext.get()
val errMsg = "Only one SparkContext may be running in this JVM (see SPARK-2243)." +
" To ignore this error, set spark.driver.allowMultipleContexts = true. " +
s"The currently running SparkContext was created at:\n${ctx.creationSite.longForm}"
@@ -1941,6 +1943,39 @@ object SparkContext extends Logging {
}
}
+ /**
+ * This function may be used to get or instantiate a SparkContext and register it as a
+ * singleton object. Because we can only have one active SparkContext per JVM,
+ * this is useful when applications may wish to share a SparkContext.
+ *
+ * Note: This function cannot be used to create multiple SparkContext instances
+ * even if multiple contexts are allowed.
+ */
+ def getOrCreate(config: SparkConf): SparkContext = {
+ // Synchronize to ensure that multiple create requests don't trigger an exception
+ // from assertNoOtherContextIsRunning within setActiveContext
+ SPARK_CONTEXT_CONSTRUCTOR_LOCK.synchronized {
+ if (activeContext.get() == null) {
+ setActiveContext(new SparkContext(config), allowMultipleContexts = false)
+ }
+ activeContext.get()
+ }
+ }
+
+ /**
+ * This function may be used to get or instantiate a SparkContext and register it as a
+ * singleton object. Because we can only have one active SparkContext per JVM,
+ * this is useful when applications may wish to share a SparkContext.
+ *
+ * This method allows not passing a SparkConf (useful if just retrieving).
+ *
+ * Note: This function cannot be used to create multiple SparkContext instances
+ * even if multiple contexts are allowed.
+ */
+ def getOrCreate(): SparkContext = {
+ getOrCreate(new SparkConf())
+ }
+
/**
* Called at the beginning of the SparkContext constructor to ensure that no SparkContext is
* running. Throws an exception if a running context is detected and logs a warning if another
@@ -1967,7 +2002,7 @@ object SparkContext extends Logging {
SPARK_CONTEXT_CONSTRUCTOR_LOCK.synchronized {
assertNoOtherContextIsRunning(sc, allowMultipleContexts)
contextBeingConstructed = None
- activeContext = Some(sc)
+ activeContext.set(sc)
}
}
@@ -1978,7 +2013,7 @@ object SparkContext extends Logging {
*/
private[spark] def clearActiveContext(): Unit = {
SPARK_CONTEXT_CONSTRUCTOR_LOCK.synchronized {
- activeContext = None
+ activeContext.set(null)
}
}
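A minimal usage sketch for the new getOrCreate entry points (the app name and master below are placeholders):

```scala
import org.apache.spark.{SparkConf, SparkContext}

// The first caller creates the context; later callers get the same instance back.
val sc = SparkContext.getOrCreate(new SparkConf().setAppName("shared-app").setMaster("local[*]"))
val again = SparkContext.getOrCreate() // no-arg variant, useful when only retrieving
assert(sc eq again)
```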
diff --git a/core/src/main/scala/org/apache/spark/SparkEnv.scala b/core/src/main/scala/org/apache/spark/SparkEnv.scala
index 0171488e09562..959aefabd8de4 100644
--- a/core/src/main/scala/org/apache/spark/SparkEnv.scala
+++ b/core/src/main/scala/org/apache/spark/SparkEnv.scala
@@ -103,7 +103,7 @@ class SparkEnv (
// actorSystem.awaitTermination()
// Note that blockTransferService is stopped by BlockManager since it is started by it.
-
+
// If we only stop sc, but the driver process still run as a services then we need to delete
// the tmp dir, if not, it will create too many tmp dirs.
// We only need to delete the tmp dir create by driver, because sparkFilesDir is point to the
@@ -375,12 +375,6 @@ object SparkEnv extends Logging {
"."
}
- // Warn about deprecated spark.cache.class property
- if (conf.contains("spark.cache.class")) {
- logWarning("The spark.cache.class property is no longer being used! Specify storage " +
- "levels using the RDD.persist() method instead.")
- }
-
val outputCommitCoordinator = mockOutputCommitCoordinator.getOrElse {
new OutputCommitCoordinator(conf)
}
@@ -406,7 +400,7 @@ object SparkEnv extends Logging {
shuffleMemoryManager,
outputCommitCoordinator,
conf)
-
+
// Add a reference to tmp dir created by driver, we will delete this tmp dir when stop() is
// called, and we only need to do it for driver. Because driver may run as a service, and if we
// don't delete this tmp dir when sc is stopped, then will create too many tmp dirs.
diff --git a/core/src/main/scala/org/apache/spark/api/r/RRDD.scala b/core/src/main/scala/org/apache/spark/api/r/RRDD.scala
index 5fa4d483b8342..6fea5e1144f2f 100644
--- a/core/src/main/scala/org/apache/spark/api/r/RRDD.scala
+++ b/core/src/main/scala/org/apache/spark/api/r/RRDD.scala
@@ -42,10 +42,15 @@ private abstract class BaseRRDD[T: ClassTag, U: ClassTag](
rLibDir: String,
broadcastVars: Array[Broadcast[Object]])
extends RDD[U](parent) with Logging {
+ protected var dataStream: DataInputStream = _
+ private var bootTime: Double = _
override def getPartitions: Array[Partition] = parent.partitions
override def compute(partition: Partition, context: TaskContext): Iterator[U] = {
+ // Timing start
+ bootTime = System.currentTimeMillis / 1000.0
+
// The parent may be also an RRDD, so we should launch it first.
val parentIterator = firstParent[T].iterator(partition, context)
@@ -69,7 +74,7 @@ private abstract class BaseRRDD[T: ClassTag, U: ClassTag](
// the socket used to receive the output of task
val outSocket = serverSocket.accept()
val inputStream = new BufferedInputStream(outSocket.getInputStream)
- val dataStream = openDataStream(inputStream)
+ dataStream = new DataInputStream(inputStream)
serverSocket.close()
try {
@@ -155,6 +160,7 @@ private abstract class BaseRRDD[T: ClassTag, U: ClassTag](
} else if (deserializer == SerializationFormats.ROW) {
dataOut.write(elem.asInstanceOf[Array[Byte]])
} else if (deserializer == SerializationFormats.STRING) {
+ // write string (for StringRRDD)
printOut.println(elem)
}
}
@@ -180,9 +186,41 @@ private abstract class BaseRRDD[T: ClassTag, U: ClassTag](
}.start()
}
- protected def openDataStream(input: InputStream): Closeable
+ protected def readData(length: Int): U
- protected def read(): U
+ protected def read(): U = {
+ try {
+ val length = dataStream.readInt()
+
+ length match {
+ case SpecialLengths.TIMING_DATA =>
+ // Timing data from R worker
+ val boot = dataStream.readDouble - bootTime
+ val init = dataStream.readDouble
+ val broadcast = dataStream.readDouble
+ val input = dataStream.readDouble
+ val compute = dataStream.readDouble
+ val output = dataStream.readDouble
+ logInfo(
+ ("Times: boot = %.3f s, init = %.3f s, broadcast = %.3f s, " +
+ "read-input = %.3f s, compute = %.3f s, write-output = %.3f s, " +
+ "total = %.3f s").format(
+ boot,
+ init,
+ broadcast,
+ input,
+ compute,
+ output,
+ boot + init + broadcast + input + compute + output))
+ read()
+ case length if length >= 0 =>
+ readData(length)
+ }
+ } catch {
+ case eof: EOFException =>
+ throw new SparkException("R worker exited unexpectedly (crashed)", eof)
+ }
+ }
}
/**
@@ -202,31 +240,16 @@ private class PairwiseRRDD[T: ClassTag](
SerializationFormats.BYTE, packageNames, rLibDir,
broadcastVars.map(x => x.asInstanceOf[Broadcast[Object]])) {
- private var dataStream: DataInputStream = _
-
- override protected def openDataStream(input: InputStream): Closeable = {
- dataStream = new DataInputStream(input)
- dataStream
- }
-
- override protected def read(): (Int, Array[Byte]) = {
- try {
- val length = dataStream.readInt()
-
- length match {
- case length if length == 2 =>
- val hashedKey = dataStream.readInt()
- val contentPairsLength = dataStream.readInt()
- val contentPairs = new Array[Byte](contentPairsLength)
- dataStream.readFully(contentPairs)
- (hashedKey, contentPairs)
- case _ => null // End of input
- }
- } catch {
- case eof: EOFException => {
- throw new SparkException("R worker exited unexpectedly (crashed)", eof)
- }
- }
+ override protected def readData(length: Int): (Int, Array[Byte]) = {
+ length match {
+ case length if length == 2 =>
+ val hashedKey = dataStream.readInt()
+ val contentPairsLength = dataStream.readInt()
+ val contentPairs = new Array[Byte](contentPairsLength)
+ dataStream.readFully(contentPairs)
+ (hashedKey, contentPairs)
+ case _ => null
+ }
}
lazy val asJavaPairRDD : JavaPairRDD[Int, Array[Byte]] = JavaPairRDD.fromRDD(this)
@@ -247,28 +270,13 @@ private class RRDD[T: ClassTag](
parent, -1, func, deserializer, serializer, packageNames, rLibDir,
broadcastVars.map(x => x.asInstanceOf[Broadcast[Object]])) {
- private var dataStream: DataInputStream = _
-
- override protected def openDataStream(input: InputStream): Closeable = {
- dataStream = new DataInputStream(input)
- dataStream
- }
-
- override protected def read(): Array[Byte] = {
- try {
- val length = dataStream.readInt()
-
- length match {
- case length if length > 0 =>
- val obj = new Array[Byte](length)
- dataStream.readFully(obj, 0, length)
- obj
- case _ => null
- }
- } catch {
- case eof: EOFException => {
- throw new SparkException("R worker exited unexpectedly (crashed)", eof)
- }
+ override protected def readData(length: Int): Array[Byte] = {
+ length match {
+ case length if length > 0 =>
+ val obj = new Array[Byte](length)
+ dataStream.readFully(obj)
+ obj
+ case _ => null
}
}
@@ -289,26 +297,21 @@ private class StringRRDD[T: ClassTag](
parent, -1, func, deserializer, SerializationFormats.STRING, packageNames, rLibDir,
broadcastVars.map(x => x.asInstanceOf[Broadcast[Object]])) {
- private var dataStream: BufferedReader = _
-
- override protected def openDataStream(input: InputStream): Closeable = {
- dataStream = new BufferedReader(new InputStreamReader(input))
- dataStream
- }
-
- override protected def read(): String = {
- try {
- dataStream.readLine()
- } catch {
- case e: IOException => {
- throw new SparkException("R worker exited unexpectedly (crashed)", e)
- }
+ override protected def readData(length: Int): String = {
+ length match {
+ case length if length > 0 =>
+ SerDe.readStringBytes(dataStream, length)
+ case _ => null
}
}
lazy val asJavaRDD : JavaRDD[String] = JavaRDD.fromRDD(this)
}
+private object SpecialLengths {
+ val TIMING_DATA = -1
+}
+
private[r] class BufferedStreamThread(
in: InputStream,
name: String,
diff --git a/core/src/main/scala/org/apache/spark/api/r/SerDe.scala b/core/src/main/scala/org/apache/spark/api/r/SerDe.scala
index ccb2a371f4e48..371dfe454d1a2 100644
--- a/core/src/main/scala/org/apache/spark/api/r/SerDe.scala
+++ b/core/src/main/scala/org/apache/spark/api/r/SerDe.scala
@@ -85,13 +85,17 @@ private[spark] object SerDe {
in.readDouble()
}
+ def readStringBytes(in: DataInputStream, len: Int): String = {
+ val bytes = new Array[Byte](len)
+ in.readFully(bytes)
+ assert(bytes(len - 1) == 0)
+ val str = new String(bytes.dropRight(1), "UTF-8")
+ str
+ }
+
def readString(in: DataInputStream): String = {
val len = in.readInt()
- val asciiBytes = new Array[Byte](len)
- in.readFully(asciiBytes)
- assert(asciiBytes(len - 1) == 0)
- val str = new String(asciiBytes.dropRight(1).map(_.toChar))
- str
+ readStringBytes(in, len)
}
def readBoolean(in: DataInputStream): Boolean = {
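For readers of the SerDe change: readStringBytes assumes a length prefix that counts a trailing NUL byte. The sketch below is an assumed writer-side counterpart for illustration, not the actual SparkR writer:

```scala
import java.io.DataOutputStream

// Assumed counterpart to readStringBytes: UTF-8 bytes plus one NUL terminator,
// with the int length prefix including that terminator.
def writeNulTerminatedString(out: DataOutputStream, s: String): Unit = {
  val bytes = s.getBytes("UTF-8")
  out.writeInt(bytes.length + 1) // length counts the trailing 0 byte
  out.write(bytes)
  out.writeByte(0)               // matches the reader's assert(bytes(len - 1) == 0)
}
```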
diff --git a/core/src/main/scala/org/apache/spark/deploy/history/FsHistoryProvider.scala b/core/src/main/scala/org/apache/spark/deploy/history/FsHistoryProvider.scala
index 985545742df67..47bdd7749ec3d 100644
--- a/core/src/main/scala/org/apache/spark/deploy/history/FsHistoryProvider.scala
+++ b/core/src/main/scala/org/apache/spark/deploy/history/FsHistoryProvider.scala
@@ -52,8 +52,7 @@ private[history] class FsHistoryProvider(conf: SparkConf) extends ApplicationHis
private val UPDATE_INTERVAL_S = conf.getTimeAsSeconds("spark.history.fs.update.interval", "10s")
// Interval between each cleaner checks for event logs to delete
- private val CLEAN_INTERVAL_MS = conf.getLong("spark.history.fs.cleaner.interval.seconds",
- DEFAULT_SPARK_HISTORY_FS_CLEANER_INTERVAL_S) * 1000
+ private val CLEAN_INTERVAL_S = conf.getTimeAsSeconds("spark.history.fs.cleaner.interval", "1d")
private val logDir = conf.getOption("spark.history.fs.logDirectory")
.map { d => Utils.resolveURI(d).toString }
@@ -130,8 +129,7 @@ private[history] class FsHistoryProvider(conf: SparkConf) extends ApplicationHis
if (conf.getBoolean("spark.history.fs.cleaner.enabled", false)) {
// A task that periodically cleans event logs on disk.
- pool.scheduleAtFixedRate(getRunner(cleanLogs), 0, CLEAN_INTERVAL_MS,
- TimeUnit.MILLISECONDS)
+ pool.scheduleAtFixedRate(getRunner(cleanLogs), 0, CLEAN_INTERVAL_S, TimeUnit.SECONDS)
}
}
}
@@ -270,8 +268,7 @@ private[history] class FsHistoryProvider(conf: SparkConf) extends ApplicationHis
try {
val statusList = Option(fs.listStatus(new Path(logDir))).map(_.toSeq)
.getOrElse(Seq[FileStatus]())
- val maxAge = conf.getLong("spark.history.fs.cleaner.maxAge.seconds",
- DEFAULT_SPARK_HISTORY_FS_MAXAGE_S) * 1000
+ val maxAge = conf.getTimeAsSeconds("spark.history.fs.cleaner.maxAge", "7d") * 1000
val now = System.currentTimeMillis()
val appsToRetain = new mutable.LinkedHashMap[String, FsApplicationHistoryInfo]()
@@ -417,12 +414,6 @@ private[history] class FsHistoryProvider(conf: SparkConf) extends ApplicationHis
private object FsHistoryProvider {
val DEFAULT_LOG_DIR = "file:/tmp/spark-events"
-
- // One day
- val DEFAULT_SPARK_HISTORY_FS_CLEANER_INTERVAL_S = Duration(1, TimeUnit.DAYS).toSeconds
-
- // One week
- val DEFAULT_SPARK_HISTORY_FS_MAXAGE_S = Duration(7, TimeUnit.DAYS).toSeconds
}
private class FsApplicationHistoryInfo(
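With getTimeAsSeconds, the cleaner settings now accept duration strings. A hedged configuration sketch (the values are illustrative, not recommendations):

```scala
import org.apache.spark.SparkConf

// Illustrative history-server settings using the new duration-string keys.
val conf = new SparkConf()
  .set("spark.history.fs.cleaner.enabled", "true")
  .set("spark.history.fs.cleaner.interval", "12h") // replaces spark.history.fs.cleaner.interval.seconds
  .set("spark.history.fs.cleaner.maxAge", "14d")   // replaces spark.history.fs.cleaner.maxAge.seconds
```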
diff --git a/core/src/main/scala/org/apache/spark/scheduler/cluster/mesos/MesosSchedulerBackend.scala b/core/src/main/scala/org/apache/spark/scheduler/cluster/mesos/MesosSchedulerBackend.scala
index b381436839227..d9d62b0e287ed 100644
--- a/core/src/main/scala/org/apache/spark/scheduler/cluster/mesos/MesosSchedulerBackend.scala
+++ b/core/src/main/scala/org/apache/spark/scheduler/cluster/mesos/MesosSchedulerBackend.scala
@@ -67,6 +67,8 @@ private[spark] class MesosSchedulerBackend(
// The listener bus to publish executor added/removed events.
val listenerBus = sc.listenerBus
+
+ private[mesos] val mesosExecutorCores = sc.conf.getDouble("spark.mesos.mesosExecutor.cores", 1)
@volatile var appId: String = _
@@ -139,7 +141,7 @@ private[spark] class MesosSchedulerBackend(
.setName("cpus")
.setType(Value.Type.SCALAR)
.setScalar(Value.Scalar.newBuilder()
- .setValue(scheduler.CPUS_PER_TASK).build())
+ .setValue(mesosExecutorCores).build())
.build()
val memory = Resource.newBuilder()
.setName("mem")
@@ -220,10 +222,9 @@ private[spark] class MesosSchedulerBackend(
val mem = getResource(o.getResourcesList, "mem")
val cpus = getResource(o.getResourcesList, "cpus")
val slaveId = o.getSlaveId.getValue
- // TODO(pwendell): Should below be 1 + scheduler.CPUS_PER_TASK?
(mem >= MemoryUtils.calculateTotalMemory(sc) &&
// need at least 1 for executor, 1 for task
- cpus >= 2 * scheduler.CPUS_PER_TASK) ||
+ cpus >= (mesosExecutorCores + scheduler.CPUS_PER_TASK)) ||
(slaveIdsWithExecutors.contains(slaveId) &&
cpus >= scheduler.CPUS_PER_TASK)
}
@@ -232,10 +233,9 @@ private[spark] class MesosSchedulerBackend(
val cpus = if (slaveIdsWithExecutors.contains(o.getSlaveId.getValue)) {
getResource(o.getResourcesList, "cpus").toInt
} else {
- // If the executor doesn't exist yet, subtract CPU for executor
- // TODO(pwendell): Should below just subtract "1"?
- getResource(o.getResourcesList, "cpus").toInt -
- scheduler.CPUS_PER_TASK
+ // If the Mesos executor has not been started on this slave yet, set aside a few
+ // cores for the Mesos executor by offering fewer cores to the Spark executor
+ (getResource(o.getResourcesList, "cpus") - mesosExecutorCores).toInt
}
new WorkerOffer(
o.getSlaveId.getValue,
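A worked example of the new offer check, under assumed values of mesosExecutorCores = 1.0 and CPUS_PER_TASK = 1:

```scala
// Assumed values for illustration: mesosExecutorCores = 1.0, CPUS_PER_TASK = 1.
val mesosExecutorCores = 1.0
val cpusPerTask = 1
val offeredCpus = 2.0

// A slave without an executor needs cores for the Mesos executor plus one task.
val usable = offeredCpus >= (mesosExecutorCores + cpusPerTask)  // true: 2.0 >= 2.0
// The WorkerOffer advertises only the cores left after the Mesos executor's share.
val advertisedCores = (offeredCpus - mesosExecutorCores).toInt  // 1
```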
diff --git a/core/src/test/scala/org/apache/spark/SparkConfSuite.scala b/core/src/test/scala/org/apache/spark/SparkConfSuite.scala
index 7d87ba5fd2610..8e6c200c4ba00 100644
--- a/core/src/test/scala/org/apache/spark/SparkConfSuite.scala
+++ b/core/src/test/scala/org/apache/spark/SparkConfSuite.scala
@@ -217,6 +217,9 @@ class SparkConfSuite extends FunSuite with LocalSparkContext with ResetSystemPro
val count = conf.getAll.filter { case (k, v) => k.startsWith("spark.history.") }.size
assert(count === 4)
+
+ conf.set("spark.yarn.applicationMaster.waitTries", "42")
+ assert(conf.getTimeAsSeconds("spark.yarn.am.waitTime") === 420)
}
}
diff --git a/core/src/test/scala/org/apache/spark/SparkContextSuite.scala b/core/src/test/scala/org/apache/spark/SparkContextSuite.scala
index 94be1c6d6397c..728558a424780 100644
--- a/core/src/test/scala/org/apache/spark/SparkContextSuite.scala
+++ b/core/src/test/scala/org/apache/spark/SparkContextSuite.scala
@@ -67,6 +67,26 @@ class SparkContextSuite extends FunSuite with LocalSparkContext {
}
}
+ test("Test getOrCreate") {
+ var sc2: SparkContext = null
+ SparkContext.clearActiveContext()
+ val conf = new SparkConf().setAppName("test").setMaster("local")
+
+ sc = SparkContext.getOrCreate(conf)
+
+ assert(sc.getConf.get("spark.app.name").equals("test"))
+ sc2 = SparkContext.getOrCreate(new SparkConf().setAppName("test2").setMaster("local"))
+ assert(sc2.getConf.get("spark.app.name").equals("test"))
+ assert(sc === sc2)
+ assert(sc eq sc2)
+
+ // Try creating second context to confirm that it's still possible, if desired
+ sc2 = new SparkContext(new SparkConf().setAppName("test3").setMaster("local")
+ .set("spark.driver.allowMultipleContexts", "true"))
+
+ sc2.stop()
+ }
+
test("BytesWritable implicit conversion is correct") {
// Regression test for SPARK-3121
val bytesWritable = new BytesWritable()
diff --git a/core/src/test/scala/org/apache/spark/scheduler/cluster/mesos/MesosSchedulerBackendSuite.scala b/core/src/test/scala/org/apache/spark/scheduler/cluster/mesos/MesosSchedulerBackendSuite.scala
index a311512e82c5e..cdd7be0fbe5dd 100644
--- a/core/src/test/scala/org/apache/spark/scheduler/cluster/mesos/MesosSchedulerBackendSuite.scala
+++ b/core/src/test/scala/org/apache/spark/scheduler/cluster/mesos/MesosSchedulerBackendSuite.scala
@@ -118,12 +118,12 @@ class MesosSchedulerBackendSuite extends FunSuite with LocalSparkContext with Mo
expectedWorkerOffers.append(new WorkerOffer(
mesosOffers.get(0).getSlaveId.getValue,
mesosOffers.get(0).getHostname,
- 2
+ (minCpu - backend.mesosExecutorCores).toInt
))
expectedWorkerOffers.append(new WorkerOffer(
mesosOffers.get(2).getSlaveId.getValue,
mesosOffers.get(2).getHostname,
- 2
+ (minCpu - backend.mesosExecutorCores).toInt
))
val taskDesc = new TaskDescription(1L, 0, "s1", "n1", 0, ByteBuffer.wrap(new Array[Byte](0)))
when(taskScheduler.resourceOffers(expectedWorkerOffers)).thenReturn(Seq(Seq(taskDesc)))
diff --git a/docs/monitoring.md b/docs/monitoring.md
index 2a130224591ca..8a85928d6d44d 100644
--- a/docs/monitoring.md
+++ b/docs/monitoring.md
@@ -153,19 +153,18 @@ follows:
   <tr>
-    <td>spark.history.fs.cleaner.interval.seconds</td>
-    <td>86400</td>
+    <td>spark.history.fs.cleaner.interval</td>
+    <td>1d</td>
     <td>
-      How often the job history cleaner checks for files to delete, in seconds. Defaults to 86400 (one day).
-      Files are only deleted if they are older than spark.history.fs.cleaner.maxAge.seconds.
+      How often the job history cleaner checks for files to delete.
+      Files are only deleted if they are older than spark.history.fs.cleaner.maxAge.
     </td>
   </tr>
   <tr>
-    <td>spark.history.fs.cleaner.maxAge.seconds</td>
-    <td>3600 * 24 * 7</td>
+    <td>spark.history.fs.cleaner.maxAge</td>
+    <td>7d</td>
     <td>
-      Job history files older than this many seconds will be deleted when the history cleaner runs.
-      Defaults to 3600 * 24 * 7 (1 week).
+      Job history files older than this will be deleted when the history cleaner runs.
     </td>
   </tr>
diff --git a/docs/running-on-mesos.md b/docs/running-on-mesos.md
index c984639bd34cf..594bf78b67713 100644
--- a/docs/running-on-mesos.md
+++ b/docs/running-on-mesos.md
@@ -210,6 +210,16 @@ See the [configuration page](configuration.html) for information on Spark config
     Note that total amount of cores the executor will request in total will not exceed the spark.cores.max setting.
   </td>
 </tr>
+<tr>
+  <td><code>spark.mesos.mesosExecutor.cores</code></td>
+  <td>1.0</td>
+  <td>
+    (Fine-grained mode only) Number of cores to give each Mesos executor. This does not
+    include the cores used to run the Spark tasks. In other words, even if no Spark task
+    is being run, each Mesos executor will occupy the number of cores configured here.
+    The value can be a floating point number.
+  </td>
+</tr>
 <tr>
   <td><code>spark.mesos.executor.home</code></td>
   <td>driver side <code>SPARK_HOME</code></td>
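For example, an application might reserve less than a full core per Mesos executor (the value below is illustrative):

```scala
import org.apache.spark.SparkConf

// Illustrative: give each fine-grained Mesos executor half a core.
val conf = new SparkConf().set("spark.mesos.mesosExecutor.cores", "0.5")
```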
diff --git a/docs/running-on-yarn.md b/docs/running-on-yarn.md
index 853c9f26b0ec9..0968fc5ad632b 100644
--- a/docs/running-on-yarn.md
+++ b/docs/running-on-yarn.md
@@ -211,7 +211,11 @@ Most of the configs are the same for Spark on YARN as for other deployment modes
# Launching Spark on YARN
Ensure that `HADOOP_CONF_DIR` or `YARN_CONF_DIR` points to the directory which contains the (client side) configuration files for the Hadoop cluster.
-These configs are used to write to the dfs and connect to the YARN ResourceManager.
+These configs are used to write to the dfs and connect to the YARN ResourceManager. The
+configuration contained in this directory will be distributed to the YARN cluster so that all
+containers used by the application use the same configuration. If the configuration references
+Java system properties or environment variables not managed by YARN, they should also be set in the
+Spark application's configuration (driver, executors, and the AM when running in client mode).
There are two deploy modes that can be used to launch Spark applications on YARN. In yarn-cluster mode, the Spark driver runs inside an application master process which is managed by YARN on the cluster, and the client can go away after initiating the application. In yarn-client mode, the driver runs in the client process, and the application master is only used for requesting resources from YARN.
diff --git a/docs/sql-programming-guide.md b/docs/sql-programming-guide.md
index 03500867df70f..d49233714a0bb 100644
--- a/docs/sql-programming-guide.md
+++ b/docs/sql-programming-guide.md
@@ -193,8 +193,8 @@ df.groupBy("age").count().show()
{% highlight java %}
-val sc: JavaSparkContext // An existing SparkContext.
-val sqlContext = new org.apache.spark.sql.SQLContext(sc)
+JavaSparkContext sc = ...; // An existing JavaSparkContext.
+SQLContext sqlContext = new org.apache.spark.sql.SQLContext(sc);
// Create the DataFrame
DataFrame df = sqlContext.jsonFile("examples/src/main/resources/people.json");
@@ -308,8 +308,8 @@ val df = sqlContext.sql("SELECT * FROM table")
{% highlight java %}
-val sqlContext = ... // An existing SQLContext
-val df = sqlContext.sql("SELECT * FROM table")
+SQLContext sqlContext = ...; // An existing SQLContext
+DataFrame df = sqlContext.sql("SELECT * FROM table");
{% endhighlight %}
@@ -435,7 +435,7 @@ DataFrame teenagers = sqlContext.sql("SELECT name FROM people WHERE age >= 13 AN
// The results of SQL queries are DataFrames and support all the normal RDD operations.
// The columns of a row in the result can be accessed by ordinal.
-List teenagerNames = teenagers.map(new Function() {
+List teenagerNames = teenagers.javaRDD().map(new Function() {
public String call(Row row) {
return "Name: " + row.getString(0);
}
@@ -599,7 +599,7 @@ DataFrame results = sqlContext.sql("SELECT name FROM people");
// The results of SQL queries are DataFrames and support all the normal RDD operations.
// The columns of a row in the result can be accessed by ordinal.
-List names = results.map(new Function() {
+List names = results.javaRDD().map(new Function() {
public String call(Row row) {
return "Name: " + row.getString(0);
}
@@ -860,7 +860,7 @@ DataFrame parquetFile = sqlContext.parquetFile("people.parquet");
//Parquet files can also be registered as tables and then used in SQL statements.
parquetFile.registerTempTable("parquetFile");
DataFrame teenagers = sqlContext.sql("SELECT name FROM parquetFile WHERE age >= 13 AND age <= 19");
-List teenagerNames = teenagers.map(new Function() {
+List teenagerNames = teenagers.javaRDD().map(new Function() {
public String call(Row row) {
return "Name: " + row.getString(0);
}
diff --git a/examples/src/main/scala/org/apache/spark/examples/ml/DecisionTreeExample.scala b/examples/src/main/scala/org/apache/spark/examples/ml/DecisionTreeExample.scala
new file mode 100644
index 0000000000000..d4cc8dede07ef
--- /dev/null
+++ b/examples/src/main/scala/org/apache/spark/examples/ml/DecisionTreeExample.scala
@@ -0,0 +1,322 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.examples.ml
+
+import scala.collection.mutable
+import scala.language.reflectiveCalls
+
+import scopt.OptionParser
+
+import org.apache.spark.ml.tree.DecisionTreeModel
+import org.apache.spark.{SparkConf, SparkContext}
+import org.apache.spark.examples.mllib.AbstractParams
+import org.apache.spark.ml.{Pipeline, PipelineStage}
+import org.apache.spark.ml.classification.{DecisionTreeClassificationModel, DecisionTreeClassifier}
+import org.apache.spark.ml.feature.{VectorIndexer, StringIndexer}
+import org.apache.spark.ml.regression.{DecisionTreeRegressionModel, DecisionTreeRegressor}
+import org.apache.spark.ml.util.MetadataUtils
+import org.apache.spark.mllib.evaluation.{RegressionMetrics, MulticlassMetrics}
+import org.apache.spark.mllib.linalg.Vector
+import org.apache.spark.mllib.regression.LabeledPoint
+import org.apache.spark.mllib.util.MLUtils
+import org.apache.spark.rdd.RDD
+import org.apache.spark.sql.types.StringType
+import org.apache.spark.sql.{SQLContext, DataFrame}
+
+
+/**
+ * An example runner for decision trees. Run with
+ * {{{
+ * ./bin/run-example ml.DecisionTreeExample [options]
+ * }}}
+ * If you use it as a template to create your own app, please use `spark-submit` to submit your app.
+ */
+object DecisionTreeExample {
+
+ case class Params(
+ input: String = null,
+ testInput: String = "",
+ dataFormat: String = "libsvm",
+ algo: String = "Classification",
+ maxDepth: Int = 5,
+ maxBins: Int = 32,
+ minInstancesPerNode: Int = 1,
+ minInfoGain: Double = 0.0,
+ numTrees: Int = 1,
+ featureSubsetStrategy: String = "auto",
+ fracTest: Double = 0.2,
+ cacheNodeIds: Boolean = false,
+ checkpointDir: Option[String] = None,
+ checkpointInterval: Int = 10) extends AbstractParams[Params]
+
+ def main(args: Array[String]) {
+ val defaultParams = Params()
+
+ val parser = new OptionParser[Params]("DecisionTreeExample") {
+ head("DecisionTreeExample: an example decision tree app.")
+ opt[String]("algo")
+ .text(s"algorithm (Classification, Regression), default: ${defaultParams.algo}")
+ .action((x, c) => c.copy(algo = x))
+ opt[Int]("maxDepth")
+ .text(s"max depth of the tree, default: ${defaultParams.maxDepth}")
+ .action((x, c) => c.copy(maxDepth = x))
+ opt[Int]("maxBins")
+ .text(s"max number of bins, default: ${defaultParams.maxBins}")
+ .action((x, c) => c.copy(maxBins = x))
+ opt[Int]("minInstancesPerNode")
+ .text(s"min number of instances required at child nodes to create the parent split," +
+ s" default: ${defaultParams.minInstancesPerNode}")
+ .action((x, c) => c.copy(minInstancesPerNode = x))
+ opt[Double]("minInfoGain")
+ .text(s"min info gain required to create a split, default: ${defaultParams.minInfoGain}")
+ .action((x, c) => c.copy(minInfoGain = x))
+ opt[Double]("fracTest")
+ .text(s"fraction of data to hold out for testing. If given option testInput, " +
+ s"this option is ignored. default: ${defaultParams.fracTest}")
+ .action((x, c) => c.copy(fracTest = x))
+ opt[Boolean]("cacheNodeIds")
+ .text(s"whether to use node Id cache during training, " +
+ s"default: ${defaultParams.cacheNodeIds}")
+ .action((x, c) => c.copy(cacheNodeIds = x))
+ opt[String]("checkpointDir")
+ .text(s"checkpoint directory where intermediate node Id caches will be stored, " +
+ s"default: ${defaultParams.checkpointDir match {
+ case Some(strVal) => strVal
+ case None => "None"
+ }}")
+ .action((x, c) => c.copy(checkpointDir = Some(x)))
+ opt[Int]("checkpointInterval")
+ .text(s"how often to checkpoint the node Id cache, " +
+ s"default: ${defaultParams.checkpointInterval}")
+ .action((x, c) => c.copy(checkpointInterval = x))
+ opt[String]("testInput")
+ .text(s"input path to test dataset. If given, option fracTest is ignored." +
+ s" default: ${defaultParams.testInput}")
+ .action((x, c) => c.copy(testInput = x))
+ opt[String]("<dataFormat>")
+ .text("data format: libsvm (default), dense (deprecated in Spark v1.1)")
+ .action((x, c) => c.copy(dataFormat = x))
+ arg[String]("<input>")
+ .text("input path to labeled examples")
+ .required()
+ .action((x, c) => c.copy(input = x))
+ checkConfig { params =>
+ if (params.fracTest < 0 || params.fracTest > 1) {
+ failure(s"fracTest ${params.fracTest} value incorrect; should be in [0,1].")
+ } else {
+ success
+ }
+ }
+ }
+
+ parser.parse(args, defaultParams).map { params =>
+ run(params)
+ }.getOrElse {
+ sys.exit(1)
+ }
+ }
+
+ /** Load a dataset from the given path, using the given format */
+ private[ml] def loadData(
+ sc: SparkContext,
+ path: String,
+ format: String,
+ expectedNumFeatures: Option[Int] = None): RDD[LabeledPoint] = {
+ format match {
+ case "dense" => MLUtils.loadLabeledPoints(sc, path)
+ case "libsvm" => expectedNumFeatures match {
+ case Some(numFeatures) => MLUtils.loadLibSVMFile(sc, path, numFeatures)
+ case None => MLUtils.loadLibSVMFile(sc, path)
+ }
+ case _ => throw new IllegalArgumentException(s"Bad data format: $format")
+ }
+ }
+
+ /**
+ * Load training and test data from files.
+ * @param input Path to input dataset.
+ * @param dataFormat "libsvm" or "dense"
+ * @param testInput Path to test dataset.
+ * @param algo Classification or Regression
+ * @param fracTest Fraction of input data to hold out for testing. Ignored if testInput given.
+ * @return (training dataset, test dataset)
+ */
+ private[ml] def loadDatasets(
+ sc: SparkContext,
+ input: String,
+ dataFormat: String,
+ testInput: String,
+ algo: String,
+ fracTest: Double): (DataFrame, DataFrame) = {
+ val sqlContext = new SQLContext(sc)
+ import sqlContext.implicits._
+
+ // Load training data
+ val origExamples: RDD[LabeledPoint] = loadData(sc, input, dataFormat)
+
+ // Load or create test set
+ val splits: Array[RDD[LabeledPoint]] = if (testInput != "") {
+ // Load testInput.
+ val numFeatures = origExamples.take(1)(0).features.size
+ val origTestExamples: RDD[LabeledPoint] = loadData(sc, testInput, dataFormat, Some(numFeatures))
+ Array(origExamples, origTestExamples)
+ } else {
+ // Split input into training, test.
+ origExamples.randomSplit(Array(1.0 - fracTest, fracTest), seed = 12345)
+ }
+
+ // For classification, convert labels to Strings since we will index them later with
+ // StringIndexer.
+ def labelsToStrings(data: DataFrame): DataFrame = {
+ algo.toLowerCase match {
+ case "classification" =>
+ data.withColumn("labelString", data("label").cast(StringType))
+ case "regression" =>
+ data
+ case _ =>
+ throw new IllegalArgumentException(s"Algo $algo not supported.")
+ }
+ }
+ val dataframes = splits.map(_.toDF()).map(labelsToStrings).map(_.cache())
+
+ (dataframes(0), dataframes(1))
+ }
+
+ def run(params: Params) {
+ val conf = new SparkConf().setAppName(s"DecisionTreeExample with $params")
+ val sc = new SparkContext(conf)
+ params.checkpointDir.foreach(sc.setCheckpointDir)
+ val algo = params.algo.toLowerCase
+
+ println(s"DecisionTreeExample with parameters:\n$params")
+
+ // Load training and test data and cache it.
+ val (training: DataFrame, test: DataFrame) =
+ loadDatasets(sc, params.input, params.dataFormat, params.testInput, algo, params.fracTest)
+
+ val numTraining = training.count()
+ val numTest = test.count()
+ val numFeatures = training.select("features").first().getAs[Vector](0).size
+ println("Loaded data:")
+ println(s" numTraining = $numTraining, numTest = $numTest")
+ println(s" numFeatures = $numFeatures")
+
+ // Set up Pipeline
+ val stages = new mutable.ArrayBuffer[PipelineStage]()
+ // (1) For classification, re-index classes.
+ val labelColName = if (algo == "classification") "indexedLabel" else "label"
+ if (algo == "classification") {
+ val labelIndexer = new StringIndexer().setInputCol("labelString").setOutputCol(labelColName)
+ stages += labelIndexer
+ }
+ // (2) Identify categorical features using VectorIndexer.
+ // Features with more than maxCategories values will be treated as continuous.
+ val featuresIndexer = new VectorIndexer().setInputCol("features")
+ .setOutputCol("indexedFeatures").setMaxCategories(10)
+ stages += featuresIndexer
+ // (3) Learn DecisionTree
+ val dt = algo match {
+ case "classification" =>
+ new DecisionTreeClassifier().setFeaturesCol("indexedFeatures")
+ .setLabelCol(labelColName)
+ .setMaxDepth(params.maxDepth)
+ .setMaxBins(params.maxBins)
+ .setMinInstancesPerNode(params.minInstancesPerNode)
+ .setMinInfoGain(params.minInfoGain)
+ .setCacheNodeIds(params.cacheNodeIds)
+ .setCheckpointInterval(params.checkpointInterval)
+ case "regression" =>
+ new DecisionTreeRegressor().setFeaturesCol("indexedFeatures")
+ .setLabelCol(labelColName)
+ .setMaxDepth(params.maxDepth)
+ .setMaxBins(params.maxBins)
+ .setMinInstancesPerNode(params.minInstancesPerNode)
+ .setMinInfoGain(params.minInfoGain)
+ .setCacheNodeIds(params.cacheNodeIds)
+ .setCheckpointInterval(params.checkpointInterval)
+ case _ => throw new IllegalArgumentException(s"Algo ${params.algo} not supported.")
+ }
+ stages += dt
+ val pipeline = new Pipeline().setStages(stages.toArray)
+
+ // Fit the Pipeline
+ val startTime = System.nanoTime()
+ val pipelineModel = pipeline.fit(training)
+ val elapsedTime = (System.nanoTime() - startTime) / 1e9
+ println(s"Training time: $elapsedTime seconds")
+
+ // Get the trained Decision Tree from the fitted PipelineModel
+ val treeModel: DecisionTreeModel = algo match {
+ case "classification" =>
+ pipelineModel.getModel[DecisionTreeClassificationModel](
+ dt.asInstanceOf[DecisionTreeClassifier])
+ case "regression" =>
+ pipelineModel.getModel[DecisionTreeRegressionModel](dt.asInstanceOf[DecisionTreeRegressor])
+ case _ => throw new IllegalArgumentException(s"Algo ${params.algo} not supported.")
+ }
+ if (treeModel.numNodes < 20) {
+ println(treeModel.toDebugString) // Print full model.
+ } else {
+ println(treeModel) // Print model summary.
+ }
+
+ // Predict on training
+ val trainingFullPredictions = pipelineModel.transform(training).cache()
+ val trainingPredictions = trainingFullPredictions.select("prediction")
+ .map(_.getDouble(0))
+ val trainingLabels = trainingFullPredictions.select(labelColName).map(_.getDouble(0))
+ // Predict on test data
+ val testFullPredictions = pipelineModel.transform(test).cache()
+ val testPredictions = testFullPredictions.select("prediction")
+ .map(_.getDouble(0))
+ val testLabels = testFullPredictions.select(labelColName).map(_.getDouble(0))
+
+ // For classification, print number of classes for reference.
+ if (algo == "classification") {
+ val numClasses =
+ MetadataUtils.getNumClasses(trainingFullPredictions.schema(labelColName)) match {
+ case Some(n) => n
+ case None => throw new RuntimeException(
+ "DecisionTreeExample had unknown failure when indexing labels for classification.")
+ }
+ println(s"numClasses = $numClasses.")
+ }
+
+ // Evaluate model on training, test data
+ algo match {
+ case "classification" =>
+ val trainingAccuracy =
+ new MulticlassMetrics(trainingPredictions.zip(trainingLabels)).precision
+ println(s"Train accuracy = $trainingAccuracy")
+ val testAccuracy =
+ new MulticlassMetrics(testPredictions.zip(testLabels)).precision
+ println(s"Test accuracy = $testAccuracy")
+ case "regression" =>
+ val trainingRMSE =
+ new RegressionMetrics(trainingPredictions.zip(trainingLabels)).rootMeanSquaredError
+ println(s"Training root mean squared error (RMSE) = $trainingRMSE")
+ val testRMSE =
+ new RegressionMetrics(testPredictions.zip(testLabels)).rootMeanSquaredError
+ println(s"Test root mean squared error (RMSE) = $testRMSE")
+ case _ =>
+ throw new IllegalArgumentException(s"Algo ${params.algo} not supported.")
+ }
+
+ sc.stop()
+ }
+}
diff --git a/mllib/src/main/scala/org/apache/spark/ml/attribute/AttributeGroup.scala b/mllib/src/main/scala/org/apache/spark/ml/attribute/AttributeGroup.scala
index aa27a668f1695..d7dee8fed2a55 100644
--- a/mllib/src/main/scala/org/apache/spark/ml/attribute/AttributeGroup.scala
+++ b/mllib/src/main/scala/org/apache/spark/ml/attribute/AttributeGroup.scala
@@ -117,12 +117,12 @@ class AttributeGroup private (
case numeric: NumericAttribute =>
// Skip default numeric attributes.
if (numeric.withoutIndex != NumericAttribute.defaultAttr) {
- numericMetadata += numeric.toMetadata(withType = false)
+ numericMetadata += numeric.toMetadataImpl(withType = false)
}
case nominal: NominalAttribute =>
- nominalMetadata += nominal.toMetadata(withType = false)
+ nominalMetadata += nominal.toMetadataImpl(withType = false)
case binary: BinaryAttribute =>
- binaryMetadata += binary.toMetadata(withType = false)
+ binaryMetadata += binary.toMetadataImpl(withType = false)
}
val attrBldr = new MetadataBuilder
if (numericMetadata.nonEmpty) {
@@ -151,7 +151,7 @@ class AttributeGroup private (
}
/** Converts to ML metadata */
- def toMetadata: Metadata = toMetadata(Metadata.empty)
+ def toMetadata(): Metadata = toMetadata(Metadata.empty)
/** Converts to a StructField with some existing metadata. */
def toStructField(existingMetadata: Metadata): StructField = {
@@ -159,7 +159,7 @@ class AttributeGroup private (
}
/** Converts to a StructField. */
- def toStructField: StructField = toStructField(Metadata.empty)
+ def toStructField(): StructField = toStructField(Metadata.empty)
override def equals(other: Any): Boolean = {
other match {
diff --git a/mllib/src/main/scala/org/apache/spark/ml/attribute/attributes.scala b/mllib/src/main/scala/org/apache/spark/ml/attribute/attributes.scala
index 00b7566aab434..5717d6ec2eaec 100644
--- a/mllib/src/main/scala/org/apache/spark/ml/attribute/attributes.scala
+++ b/mllib/src/main/scala/org/apache/spark/ml/attribute/attributes.scala
@@ -68,21 +68,32 @@ sealed abstract class Attribute extends Serializable {
* Converts this attribute to [[Metadata]].
* @param withType whether to include the type info
*/
- private[attribute] def toMetadata(withType: Boolean): Metadata
+ private[attribute] def toMetadataImpl(withType: Boolean): Metadata
/**
* Converts this attribute to [[Metadata]]. For numeric attributes, the type info is excluded to
* save space, because numeric type is the default attribute type. For nominal and binary
* attributes, the type info is included.
*/
- private[attribute] def toMetadata(): Metadata = {
+ private[attribute] def toMetadataImpl(): Metadata = {
if (attrType == AttributeType.Numeric) {
- toMetadata(withType = false)
+ toMetadataImpl(withType = false)
} else {
- toMetadata(withType = true)
+ toMetadataImpl(withType = true)
}
}
+ /** Converts to ML metadata with some existing metadata. */
+ def toMetadata(existingMetadata: Metadata): Metadata = {
+ new MetadataBuilder()
+ .withMetadata(existingMetadata)
+ .putMetadata(AttributeKeys.ML_ATTR, toMetadataImpl())
+ .build()
+ }
+
+ /** Converts to ML metadata */
+ def toMetadata(): Metadata = toMetadata(Metadata.empty)
+
/**
* Converts to a [[StructField]] with some existing metadata.
* @param existingMetadata existing metadata to carry over
@@ -90,7 +101,7 @@ sealed abstract class Attribute extends Serializable {
def toStructField(existingMetadata: Metadata): StructField = {
val newMetadata = new MetadataBuilder()
.withMetadata(existingMetadata)
- .putMetadata(AttributeKeys.ML_ATTR, withoutName.withoutIndex.toMetadata())
+ .putMetadata(AttributeKeys.ML_ATTR, withoutName.withoutIndex.toMetadataImpl())
.build()
StructField(name.get, DoubleType, nullable = false, newMetadata)
}
@@ -98,7 +109,7 @@ sealed abstract class Attribute extends Serializable {
/** Converts to a [[StructField]]. */
def toStructField(): StructField = toStructField(Metadata.empty)
- override def toString: String = toMetadata(withType = true).toString
+ override def toString: String = toMetadataImpl(withType = true).toString
}
/** Trait for ML attribute factories. */
@@ -210,7 +221,7 @@ class NumericAttribute private[ml] (
override def isNominal: Boolean = false
/** Convert this attribute to metadata. */
- private[attribute] override def toMetadata(withType: Boolean): Metadata = {
+ override private[attribute] def toMetadataImpl(withType: Boolean): Metadata = {
import org.apache.spark.ml.attribute.AttributeKeys._
val bldr = new MetadataBuilder()
if (withType) bldr.putString(TYPE, attrType.name)
@@ -353,6 +364,20 @@ class NominalAttribute private[ml] (
/** Copy without the `numValues`. */
def withoutNumValues: NominalAttribute = copy(numValues = None)
+ /**
+ * Get the number of values, either from `numValues` or from `values`.
+ * Return None if unknown.
+ */
+ def getNumValues: Option[Int] = {
+ if (numValues.nonEmpty) {
+ numValues
+ } else if (values.nonEmpty) {
+ Some(values.get.length)
+ } else {
+ None
+ }
+ }
+
/** Creates a copy of this attribute with optional changes. */
private def copy(
name: Option[String] = name,
@@ -363,7 +388,7 @@ class NominalAttribute private[ml] (
new NominalAttribute(name, index, isOrdinal, numValues, values)
}
- private[attribute] override def toMetadata(withType: Boolean): Metadata = {
+ override private[attribute] def toMetadataImpl(withType: Boolean): Metadata = {
import org.apache.spark.ml.attribute.AttributeKeys._
val bldr = new MetadataBuilder()
if (withType) bldr.putString(TYPE, attrType.name)
@@ -465,7 +490,7 @@ class BinaryAttribute private[ml] (
new BinaryAttribute(name, index, values)
}
- private[attribute] override def toMetadata(withType: Boolean): Metadata = {
+ override private[attribute] def toMetadataImpl(withType: Boolean): Metadata = {
import org.apache.spark.ml.attribute.AttributeKeys._
val bldr = new MetadataBuilder
if (withType) bldr.putString(TYPE, attrType.name)
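With toMetadataImpl now private, toMetadata() and toMetadata(existingMetadata) are the public entry points. A short sketch mirroring the StringIndexer change later in this diff; the column name and values are illustrative:

```scala
import org.apache.spark.ml.attribute.NominalAttribute

// Attach nominal-attribute metadata to an output column (name and values illustrative).
val labels = Array("a", "b", "c")
val metadata = NominalAttribute.defaultAttr
  .withName("label")
  .withValues(labels)
  .toMetadata()

// toStructField() wraps the same information into a full column definition.
val field = NominalAttribute.defaultAttr.withName("label").withValues(labels).toStructField()
```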
diff --git a/mllib/src/main/scala/org/apache/spark/ml/classification/DecisionTreeClassifier.scala b/mllib/src/main/scala/org/apache/spark/ml/classification/DecisionTreeClassifier.scala
new file mode 100644
index 0000000000000..3855e396b5534
--- /dev/null
+++ b/mllib/src/main/scala/org/apache/spark/ml/classification/DecisionTreeClassifier.scala
@@ -0,0 +1,155 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.ml.classification
+
+import org.apache.spark.annotation.AlphaComponent
+import org.apache.spark.ml.impl.estimator.{Predictor, PredictionModel}
+import org.apache.spark.ml.impl.tree._
+import org.apache.spark.ml.param.{Params, ParamMap}
+import org.apache.spark.ml.tree.{DecisionTreeModel, Node}
+import org.apache.spark.ml.util.MetadataUtils
+import org.apache.spark.mllib.linalg.Vector
+import org.apache.spark.mllib.regression.LabeledPoint
+import org.apache.spark.mllib.tree.{DecisionTree => OldDecisionTree}
+import org.apache.spark.mllib.tree.configuration.{Algo => OldAlgo, Strategy => OldStrategy}
+import org.apache.spark.mllib.tree.model.{DecisionTreeModel => OldDecisionTreeModel}
+import org.apache.spark.rdd.RDD
+import org.apache.spark.sql.DataFrame
+
+
+/**
+ * :: AlphaComponent ::
+ *
+ * [[http://en.wikipedia.org/wiki/Decision_tree_learning Decision tree]] learning algorithm
+ * for classification.
+ * It supports both binary and multiclass labels, as well as both continuous and categorical
+ * features.
+ */
+@AlphaComponent
+final class DecisionTreeClassifier
+ extends Predictor[Vector, DecisionTreeClassifier, DecisionTreeClassificationModel]
+ with DecisionTreeParams
+ with TreeClassifierParams {
+
+ // Override parameter setters from parent trait for Java API compatibility.
+
+ override def setMaxDepth(value: Int): this.type = super.setMaxDepth(value)
+
+ override def setMaxBins(value: Int): this.type = super.setMaxBins(value)
+
+ override def setMinInstancesPerNode(value: Int): this.type =
+ super.setMinInstancesPerNode(value)
+
+ override def setMinInfoGain(value: Double): this.type = super.setMinInfoGain(value)
+
+ override def setMaxMemoryInMB(value: Int): this.type = super.setMaxMemoryInMB(value)
+
+ override def setCacheNodeIds(value: Boolean): this.type =
+ super.setCacheNodeIds(value)
+
+ override def setCheckpointInterval(value: Int): this.type =
+ super.setCheckpointInterval(value)
+
+ override def setImpurity(value: String): this.type = super.setImpurity(value)
+
+ override protected def train(
+ dataset: DataFrame,
+ paramMap: ParamMap): DecisionTreeClassificationModel = {
+ val categoricalFeatures: Map[Int, Int] =
+ MetadataUtils.getCategoricalFeatures(dataset.schema(paramMap(featuresCol)))
+ val numClasses: Int = MetadataUtils.getNumClasses(dataset.schema(paramMap(labelCol))) match {
+ case Some(n: Int) => n
+ case None => throw new IllegalArgumentException("DecisionTreeClassifier was given input" +
+ s" with invalid label column, without the number of classes specified.")
+ // TODO: Automatically index labels.
+ }
+ val oldDataset: RDD[LabeledPoint] = extractLabeledPoints(dataset, paramMap)
+ val strategy = getOldStrategy(categoricalFeatures, numClasses)
+ val oldModel = OldDecisionTree.train(oldDataset, strategy)
+ DecisionTreeClassificationModel.fromOld(oldModel, this, paramMap, categoricalFeatures)
+ }
+
+ /** (private[ml]) Create a Strategy instance to use with the old API. */
+ override private[ml] def getOldStrategy(
+ categoricalFeatures: Map[Int, Int],
+ numClasses: Int): OldStrategy = {
+ val strategy = super.getOldStrategy(categoricalFeatures, numClasses)
+ strategy.algo = OldAlgo.Classification
+ strategy.setImpurity(getOldImpurity)
+ strategy
+ }
+}
+
+object DecisionTreeClassifier {
+ /** Accessor for supported impurities */
+ final val supportedImpurities: Array[String] = TreeClassifierParams.supportedImpurities
+}
+
+/**
+ * :: AlphaComponent ::
+ *
+ * [[http://en.wikipedia.org/wiki/Decision_tree_learning Decision tree]] model for classification.
+ * It supports both binary and multiclass labels, as well as both continuous and categorical
+ * features.
+ */
+@AlphaComponent
+final class DecisionTreeClassificationModel private[ml] (
+ override val parent: DecisionTreeClassifier,
+ override val fittingParamMap: ParamMap,
+ override val rootNode: Node)
+ extends PredictionModel[Vector, DecisionTreeClassificationModel]
+ with DecisionTreeModel with Serializable {
+
+ require(rootNode != null,
+ "DecisionTreeClassificationModel given null rootNode, but it requires a non-null rootNode.")
+
+ override protected def predict(features: Vector): Double = {
+ rootNode.predict(features)
+ }
+
+ override protected def copy(): DecisionTreeClassificationModel = {
+ val m = new DecisionTreeClassificationModel(parent, fittingParamMap, rootNode)
+ Params.inheritValues(this.extractParamMap(), this, m)
+ m
+ }
+
+ override def toString: String = {
+ s"DecisionTreeClassificationModel of depth $depth with $numNodes nodes"
+ }
+
+ /** (private[ml]) Convert to a model in the old API */
+ private[ml] def toOld: OldDecisionTreeModel = {
+ new OldDecisionTreeModel(rootNode.toOld(1), OldAlgo.Classification)
+ }
+}
+
+private[ml] object DecisionTreeClassificationModel {
+
+ /** (private[ml]) Convert a model from the old API */
+ def fromOld(
+ oldModel: OldDecisionTreeModel,
+ parent: DecisionTreeClassifier,
+ fittingParamMap: ParamMap,
+ categoricalFeatures: Map[Int, Int]): DecisionTreeClassificationModel = {
+ require(oldModel.algo == OldAlgo.Classification,
+ s"Cannot convert non-classification DecisionTreeModel (old API) to" +
+ s" DecisionTreeClassificationModel (new API). Algo is: ${oldModel.algo}")
+ val rootNode = Node.fromOld(oldModel.topNode, categoricalFeatures)
+ new DecisionTreeClassificationModel(parent, fittingParamMap, rootNode)
+ }
+}
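A brief usage sketch of the new classifier, following the pattern in DecisionTreeExample earlier in this diff; the `training` DataFrame and its "labelString"/"features" columns are assumptions for illustration:

```scala
import org.apache.spark.ml.{Pipeline, PipelineModel}
import org.apache.spark.ml.classification.DecisionTreeClassifier
import org.apache.spark.ml.feature.{StringIndexer, VectorIndexer}
import org.apache.spark.sql.DataFrame

// Sketch only: `training` is assumed to have "labelString" and "features" columns.
def fitTree(training: DataFrame): PipelineModel = {
  val labelIndexer = new StringIndexer().setInputCol("labelString").setOutputCol("indexedLabel")
  val featureIndexer = new VectorIndexer()
    .setInputCol("features").setOutputCol("indexedFeatures").setMaxCategories(10)
  val dt = new DecisionTreeClassifier()
    .setLabelCol("indexedLabel")
    .setFeaturesCol("indexedFeatures")
    .setMaxDepth(5)
  new Pipeline().setStages(Array(labelIndexer, featureIndexer, dt)).fit(training)
}
```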
diff --git a/mllib/src/main/scala/org/apache/spark/ml/feature/StringIndexer.scala b/mllib/src/main/scala/org/apache/spark/ml/feature/StringIndexer.scala
index 4d960df357fe9..23956c512c8a6 100644
--- a/mllib/src/main/scala/org/apache/spark/ml/feature/StringIndexer.scala
+++ b/mllib/src/main/scala/org/apache/spark/ml/feature/StringIndexer.scala
@@ -118,7 +118,7 @@ class StringIndexerModel private[ml] (
}
val outputColName = map(outputCol)
val metadata = NominalAttribute.defaultAttr
- .withName(outputColName).withValues(labels).toStructField().metadata
+ .withName(outputColName).withValues(labels).toMetadata()
dataset.select(col("*"), indexer(dataset(map(inputCol))).as(outputColName, metadata))
}
diff --git a/mllib/src/main/scala/org/apache/spark/ml/impl/tree/treeParams.scala b/mllib/src/main/scala/org/apache/spark/ml/impl/tree/treeParams.scala
new file mode 100644
index 0000000000000..6f4509f03d033
--- /dev/null
+++ b/mllib/src/main/scala/org/apache/spark/ml/impl/tree/treeParams.scala
@@ -0,0 +1,300 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.ml.impl.tree
+
+import org.apache.spark.annotation.DeveloperApi
+import org.apache.spark.ml.impl.estimator.PredictorParams
+import org.apache.spark.ml.param._
+import org.apache.spark.mllib.tree.configuration.{Algo => OldAlgo, Strategy => OldStrategy}
+import org.apache.spark.mllib.tree.impurity.{Gini => OldGini, Entropy => OldEntropy,
+ Impurity => OldImpurity, Variance => OldVariance}
+
+
+/**
+ * :: DeveloperApi ::
+ * Parameters for Decision Tree-based algorithms.
+ *
+ * Note: Marked as private and DeveloperApi since this may be made public in the future.
+ */
+@DeveloperApi
+private[ml] trait DecisionTreeParams extends PredictorParams {
+
+ /**
+ * Maximum depth of the tree.
+ * E.g., depth 0 means 1 leaf node; depth 1 means 1 internal node + 2 leaf nodes.
+ * (default = 5)
+ * @group param
+ */
+ final val maxDepth: IntParam =
+ new IntParam(this, "maxDepth", "Maximum depth of the tree." +
+ " E.g., depth 0 means 1 leaf node; depth 1 means 1 internal node + 2 leaf nodes.")
+
+ /**
+ * Maximum number of bins used for discretizing continuous features and for choosing how to split
+ * on features at each node. More bins give higher granularity.
+ * Must be >= 2 and >= number of categories in any categorical feature.
+ * (default = 32)
+ * @group param
+ */
+ final val maxBins: IntParam = new IntParam(this, "maxBins", "Max number of bins for" +
+ " discretizing continuous features. Must be >=2 and >= number of categories for any" +
+ " categorical feature.")
+
+ /**
+ * Minimum number of instances each child must have after split.
+ * If a split causes the left or right child to have fewer than minInstancesPerNode,
+ * the split will be discarded as invalid.
+ * Should be >= 1.
+ * (default = 1)
+ * @group param
+ */
+ final val minInstancesPerNode: IntParam = new IntParam(this, "minInstancesPerNode", "Minimum" +
+ " number of instances each child must have after split. If a split causes the left or right" +
+ " child to have fewer than minInstancesPerNode, the split will be discarded as invalid." +
+ " Should be >= 1.")
+
+ /**
+ * Minimum information gain for a split to be considered at a tree node.
+ * (default = 0.0)
+ * @group param
+ */
+ final val minInfoGain: DoubleParam = new DoubleParam(this, "minInfoGain",
+ "Minimum information gain for a split to be considered at a tree node.")
+
+ /**
+ * Maximum memory in MB allocated to histogram aggregation.
+ * (default = 256 MB)
+ * @group expertParam
+ */
+ final val maxMemoryInMB: IntParam = new IntParam(this, "maxMemoryInMB",
+ "Maximum memory in MB allocated to histogram aggregation.")
+
+ /**
+ * If false, the algorithm will pass trees to executors to match instances with nodes.
+ * If true, the algorithm will cache node IDs for each instance.
+ * Caching can speed up training of deeper trees.
+ * (default = false)
+ * @group expertParam
+ */
+ final val cacheNodeIds: BooleanParam = new BooleanParam(this, "cacheNodeIds", "If false, the" +
+ " algorithm will pass trees to executors to match instances with nodes. If true, the" +
+ " algorithm will cache node IDs for each instance. Caching can speed up training of deeper" +
+ " trees.")
+
+ /**
+ * Specifies how often to checkpoint the cached node IDs.
+ * E.g. 10 means that the cache will get checkpointed every 10 iterations.
+ * This is only used if cacheNodeIds is true and if the checkpoint directory is set in
+ * [[org.apache.spark.SparkContext]].
+ * Must be >= 1.
+ * (default = 10)
+ * @group expertParam
+ */
+ final val checkpointInterval: IntParam = new IntParam(this, "checkpointInterval", "Specifies" +
+ " how often to checkpoint the cached node IDs. E.g. 10 means that the cache will get" +
+ " checkpointed every 10 iterations. This is only used if cacheNodeIds is true and if the" +
+ " checkpoint directory is set in the SparkContext. Must be >= 1.")
+
+ setDefault(maxDepth -> 5, maxBins -> 32, minInstancesPerNode -> 1, minInfoGain -> 0.0,
+ maxMemoryInMB -> 256, cacheNodeIds -> false, checkpointInterval -> 10)
+
+ /** @group setParam */
+ def setMaxDepth(value: Int): this.type = {
+ require(value >= 0, s"maxDepth parameter must be >= 0. Given bad value: $value")
+ set(maxDepth, value)
+ this
+ }
+
+ /** @group getParam */
+ def getMaxDepth: Int = getOrDefault(maxDepth)
+
+ /** @group setParam */
+ def setMaxBins(value: Int): this.type = {
+ require(value >= 2, s"maxBins parameter must be >= 2. Given bad value: $value")
+ set(maxBins, value)
+ this
+ }
+
+ /** @group getParam */
+ def getMaxBins: Int = getOrDefault(maxBins)
+
+ /** @group setParam */
+ def setMinInstancesPerNode(value: Int): this.type = {
+ require(value >= 1, s"minInstancesPerNode parameter must be >= 1. Given bad value: $value")
+ set(minInstancesPerNode, value)
+ this
+ }
+
+ /** @group getParam */
+ def getMinInstancesPerNode: Int = getOrDefault(minInstancesPerNode)
+
+ /** @group setParam */
+ def setMinInfoGain(value: Double): this.type = {
+ set(minInfoGain, value)
+ this
+ }
+
+ /** @group getParam */
+ def getMinInfoGain: Double = getOrDefault(minInfoGain)
+
+ /** @group expertSetParam */
+ def setMaxMemoryInMB(value: Int): this.type = {
+ require(value > 0, s"maxMemoryInMB parameter must be > 0. Given bad value: $value")
+ set(maxMemoryInMB, value)
+ this
+ }
+
+ /** @group expertGetParam */
+ def getMaxMemoryInMB: Int = getOrDefault(maxMemoryInMB)
+
+ /** @group expertSetParam */
+ def setCacheNodeIds(value: Boolean): this.type = {
+ set(cacheNodeIds, value)
+ this
+ }
+
+ /** @group expertGetParam */
+ def getCacheNodeIds: Boolean = getOrDefault(cacheNodeIds)
+
+ /** @group expertSetParam */
+ def setCheckpointInterval(value: Int): this.type = {
+ require(value >= 1, s"checkpointInterval parameter must be >= 1. Given bad value: $value")
+ set(checkpointInterval, value)
+ this
+ }
+
+ /** @group expertGetParam */
+ def getCheckpointInterval: Int = getOrDefault(checkpointInterval)
+
+ /**
+ * Create a Strategy instance to use with the old API.
+ * NOTE: The caller should set impurity; subsamplingRate is set to 1.0 here,
+ * which is the default for single trees.
+ */
+ private[ml] def getOldStrategy(
+ categoricalFeatures: Map[Int, Int],
+ numClasses: Int): OldStrategy = {
+ val strategy = OldStrategy.defaultStategy(OldAlgo.Classification)
+ strategy.checkpointInterval = getCheckpointInterval
+ strategy.maxBins = getMaxBins
+ strategy.maxDepth = getMaxDepth
+ strategy.maxMemoryInMB = getMaxMemoryInMB
+ strategy.minInfoGain = getMinInfoGain
+ strategy.minInstancesPerNode = getMinInstancesPerNode
+ strategy.useNodeIdCache = getCacheNodeIds
+ strategy.numClasses = numClasses
+ strategy.categoricalFeaturesInfo = categoricalFeatures
+ strategy.subsamplingRate = 1.0 // default for individual trees
+ strategy
+ }
+}
+
+/**
+ * (private trait) Parameters for Decision Tree-based classification algorithms.
+ */
+private[ml] trait TreeClassifierParams extends Params {
+
+ /**
+ * Criterion used for information gain calculation (case-insensitive).
+ * Supported: "entropy" and "gini".
+ * (default = gini)
+ * @group param
+ */
+ val impurity: Param[String] = new Param[String](this, "impurity", "Criterion used for" +
+ " information gain calculation (case-insensitive). Supported options:" +
+ s" ${TreeClassifierParams.supportedImpurities.mkString(", ")}")
+
+ setDefault(impurity -> "gini")
+
+ /** @group setParam */
+ def setImpurity(value: String): this.type = {
+ val impurityStr = value.toLowerCase
+ require(TreeClassifierParams.supportedImpurities.contains(impurityStr),
+ s"Tree-based classifier was given unrecognized impurity: $value." +
+ s" Supported options: ${TreeClassifierParams.supportedImpurities.mkString(", ")}")
+ set(impurity, impurityStr)
+ this
+ }
+
+ /** @group getParam */
+ def getImpurity: String = getOrDefault(impurity)
+
+ /** Convert new impurity to old impurity. */
+ private[ml] def getOldImpurity: OldImpurity = {
+ getImpurity match {
+ case "entropy" => OldEntropy
+ case "gini" => OldGini
+ case _ =>
+ // Should never happen because of check in setter method.
+ throw new RuntimeException(
+ s"TreeClassifierParams was given unrecognized impurity: $impurity.")
+ }
+ }
+}
+
+private[ml] object TreeClassifierParams {
+ // These options should be lowercase.
+ val supportedImpurities: Array[String] = Array("entropy", "gini").map(_.toLowerCase)
+}
+
+/**
+ * (private trait) Parameters for Decision Tree-based regression algorithms.
+ */
+private[ml] trait TreeRegressorParams extends Params {
+
+ /**
+ * Criterion used for information gain calculation (case-insensitive).
+ * Supported: "variance".
+ * (default = variance)
+ * @group param
+ */
+ val impurity: Param[String] = new Param[String](this, "impurity", "Criterion used for" +
+ " information gain calculation (case-insensitive). Supported options:" +
+ s" ${TreeRegressorParams.supportedImpurities.mkString(", ")}")
+
+ setDefault(impurity -> "variance")
+
+ /** @group setParam */
+ def setImpurity(value: String): this.type = {
+ val impurityStr = value.toLowerCase
+ require(TreeRegressorParams.supportedImpurities.contains(impurityStr),
+ s"Tree-based regressor was given unrecognized impurity: $value." +
+ s" Supported options: ${TreeRegressorParams.supportedImpurities.mkString(", ")}")
+ set(impurity, impurityStr)
+ this
+ }
+
+ /** @group getParam */
+ def getImpurity: String = getOrDefault(impurity)
+
+ /** Convert new impurity to old impurity. */
+ protected def getOldImpurity: OldImpurity = {
+ getImpurity match {
+ case "variance" => OldVariance
+ case _ =>
+ // Should never happen because of check in setter method.
+ throw new RuntimeException(
+ s"TreeRegressorParams was given unrecognized impurity: $impurity")
+ }
+ }
+}
+
+private[ml] object TreeRegressorParams {
+ // These options should be lowercase.
+ val supportedImpurities: Array[String] = Array("variance").map(_.toLowerCase)
+}
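
To make the parameter plumbing above concrete, a small sketch (illustrative only) of the fluent setters and fail-fast validation these traits provide to the public estimators:

```scala
import org.apache.spark.ml.classification.DecisionTreeClassifier

val dt = new DecisionTreeClassifier()
  .setMaxDepth(4)
  .setMaxBins(16)
  .setMinInstancesPerNode(2)
  .setMinInfoGain(0.01)
  .setImpurity("entropy")

assert(dt.getMaxDepth == 4 && dt.getImpurity == "entropy")
// Invalid values are rejected in the setter rather than at training time,
// e.g. dt.setMaxBins(1) throws IllegalArgumentException (maxBins must be >= 2).
```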
diff --git a/mllib/src/main/scala/org/apache/spark/ml/package.scala b/mllib/src/main/scala/org/apache/spark/ml/package.scala
index b45bd1499b72e..ac75e9de1a8f2 100644
--- a/mllib/src/main/scala/org/apache/spark/ml/package.scala
+++ b/mllib/src/main/scala/org/apache/spark/ml/package.scala
@@ -32,6 +32,18 @@ package org.apache.spark
* @groupname getParam Parameter getters
* @groupprio getParam 6
*
+ * @groupname expertParam (expert-only) Parameters
+ * @groupdesc expertParam A list of advanced, expert-only (hyper-)parameter keys this algorithm can
+ * take. Users can set and get the parameter values through setters and getters,
+ * respectively.
+ * @groupprio expertParam 7
+ *
+ * @groupname expertSetParam (expert-only) Parameter setters
+ * @groupprio expertSetParam 8
+ *
+ * @groupname expertGetParam (expert-only) Parameter getters
+ * @groupprio expertGetParam 9
+ *
* @groupname Ungrouped Members
* @groupprio Ungrouped 0
*/
diff --git a/mllib/src/main/scala/org/apache/spark/ml/param/params.scala b/mllib/src/main/scala/org/apache/spark/ml/param/params.scala
index 849c60433c777..ddc5907e7facd 100644
--- a/mllib/src/main/scala/org/apache/spark/ml/param/params.scala
+++ b/mllib/src/main/scala/org/apache/spark/ml/param/params.scala
@@ -296,8 +296,9 @@ private[spark] object Params {
paramMap: ParamMap,
parent: E,
child: M): Unit = {
+ val childParams = child.params.map(_.name).toSet
parent.params.foreach { param =>
- if (paramMap.contains(param)) {
+ if (paramMap.contains(param) && childParams.contains(param.name)) {
child.set(child.getParam(param.name), paramMap(param))
}
}
diff --git a/mllib/src/main/scala/org/apache/spark/ml/regression/DecisionTreeRegressor.scala b/mllib/src/main/scala/org/apache/spark/ml/regression/DecisionTreeRegressor.scala
new file mode 100644
index 0000000000000..49a8b77acf960
--- /dev/null
+++ b/mllib/src/main/scala/org/apache/spark/ml/regression/DecisionTreeRegressor.scala
@@ -0,0 +1,145 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.ml.regression
+
+import org.apache.spark.annotation.AlphaComponent
+import org.apache.spark.ml.impl.estimator.{PredictionModel, Predictor}
+import org.apache.spark.ml.impl.tree._
+import org.apache.spark.ml.param.{Params, ParamMap}
+import org.apache.spark.ml.tree.{DecisionTreeModel, Node}
+import org.apache.spark.ml.util.MetadataUtils
+import org.apache.spark.mllib.linalg.Vector
+import org.apache.spark.mllib.regression.LabeledPoint
+import org.apache.spark.mllib.tree.{DecisionTree => OldDecisionTree}
+import org.apache.spark.mllib.tree.configuration.{Algo => OldAlgo, Strategy => OldStrategy}
+import org.apache.spark.mllib.tree.model.{DecisionTreeModel => OldDecisionTreeModel}
+import org.apache.spark.rdd.RDD
+import org.apache.spark.sql.DataFrame
+
+
+/**
+ * :: AlphaComponent ::
+ *
+ * [[http://en.wikipedia.org/wiki/Decision_tree_learning Decision tree]] learning algorithm
+ * for regression.
+ * It supports both continuous and categorical features.
+ */
+@AlphaComponent
+final class DecisionTreeRegressor
+ extends Predictor[Vector, DecisionTreeRegressor, DecisionTreeRegressionModel]
+ with DecisionTreeParams
+ with TreeRegressorParams {
+
+ // Override parameter setters from parent trait for Java API compatibility.
+
+ override def setMaxDepth(value: Int): this.type = super.setMaxDepth(value)
+
+ override def setMaxBins(value: Int): this.type = super.setMaxBins(value)
+
+ override def setMinInstancesPerNode(value: Int): this.type =
+ super.setMinInstancesPerNode(value)
+
+ override def setMinInfoGain(value: Double): this.type = super.setMinInfoGain(value)
+
+ override def setMaxMemoryInMB(value: Int): this.type = super.setMaxMemoryInMB(value)
+
+ override def setCacheNodeIds(value: Boolean): this.type = super.setCacheNodeIds(value)
+
+ override def setCheckpointInterval(value: Int): this.type =
+ super.setCheckpointInterval(value)
+
+ override def setImpurity(value: String): this.type = super.setImpurity(value)
+
+ override protected def train(
+ dataset: DataFrame,
+ paramMap: ParamMap): DecisionTreeRegressionModel = {
+ val categoricalFeatures: Map[Int, Int] =
+ MetadataUtils.getCategoricalFeatures(dataset.schema(paramMap(featuresCol)))
+ val oldDataset: RDD[LabeledPoint] = extractLabeledPoints(dataset, paramMap)
+ val strategy = getOldStrategy(categoricalFeatures)
+ val oldModel = OldDecisionTree.train(oldDataset, strategy)
+ DecisionTreeRegressionModel.fromOld(oldModel, this, paramMap, categoricalFeatures)
+ }
+
+ /** (private[ml]) Create a Strategy instance to use with the old API. */
+ private[ml] def getOldStrategy(categoricalFeatures: Map[Int, Int]): OldStrategy = {
+ val strategy = super.getOldStrategy(categoricalFeatures, numClasses = 0)
+ strategy.algo = OldAlgo.Regression
+ strategy.setImpurity(getOldImpurity)
+ strategy
+ }
+}
+
+object DecisionTreeRegressor {
+ /** Accessor for supported impurities */
+ final val supportedImpurities: Array[String] = TreeRegressorParams.supportedImpurities
+}
+
+/**
+ * :: AlphaComponent ::
+ *
+ * [[http://en.wikipedia.org/wiki/Decision_tree_learning Decision tree]] model for regression.
+ * It supports both continuous and categorical features.
+ * @param rootNode Root of the decision tree
+ */
+@AlphaComponent
+final class DecisionTreeRegressionModel private[ml] (
+ override val parent: DecisionTreeRegressor,
+ override val fittingParamMap: ParamMap,
+ override val rootNode: Node)
+ extends PredictionModel[Vector, DecisionTreeRegressionModel]
+ with DecisionTreeModel with Serializable {
+
+ require(rootNode != null,
+ "DecisionTreeClassificationModel given null rootNode, but it requires a non-null rootNode.")
+
+ override protected def predict(features: Vector): Double = {
+ rootNode.predict(features)
+ }
+
+ override protected def copy(): DecisionTreeRegressionModel = {
+ val m = new DecisionTreeRegressionModel(parent, fittingParamMap, rootNode)
+ Params.inheritValues(this.extractParamMap(), this, m)
+ m
+ }
+
+ override def toString: String = {
+ s"DecisionTreeRegressionModel of depth $depth with $numNodes nodes"
+ }
+
+ /** Convert to a model in the old API */
+ private[ml] def toOld: OldDecisionTreeModel = {
+ new OldDecisionTreeModel(rootNode.toOld(1), OldAlgo.Regression)
+ }
+}
+
+private[ml] object DecisionTreeRegressionModel {
+
+ /** (private[ml]) Convert a model from the old API */
+ def fromOld(
+ oldModel: OldDecisionTreeModel,
+ parent: DecisionTreeRegressor,
+ fittingParamMap: ParamMap,
+ categoricalFeatures: Map[Int, Int]): DecisionTreeRegressionModel = {
+ require(oldModel.algo == OldAlgo.Regression,
+ s"Cannot convert non-regression DecisionTreeModel (old API) to" +
+ s" DecisionTreeRegressionModel (new API). Algo is: ${oldModel.algo}")
+ val rootNode = Node.fromOld(oldModel.topNode, categoricalFeatures)
+ new DecisionTreeRegressionModel(parent, fittingParamMap, rootNode)
+ }
+}
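
For completeness, a hedged usage sketch of the regressor above, assuming a DataFrame `training` with a continuous "label" column and a "features" Vector column:

```scala
import org.apache.spark.ml.regression.DecisionTreeRegressor

val dtr = new DecisionTreeRegressor()
  .setMaxDepth(5)
  .setImpurity("variance")   // the only impurity supported for regression

val model = dtr.fit(training)
println(s"Learned a regression tree of depth ${model.depth} with ${model.numNodes} nodes")
```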
diff --git a/mllib/src/main/scala/org/apache/spark/ml/tree/Node.scala b/mllib/src/main/scala/org/apache/spark/ml/tree/Node.scala
new file mode 100644
index 0000000000000..d6e2203d9f937
--- /dev/null
+++ b/mllib/src/main/scala/org/apache/spark/ml/tree/Node.scala
@@ -0,0 +1,205 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.ml.tree
+
+import org.apache.spark.mllib.linalg.Vector
+import org.apache.spark.mllib.tree.model.{InformationGainStats => OldInformationGainStats,
+ Node => OldNode, Predict => OldPredict}
+
+
+/**
+ * Decision tree node interface.
+ */
+sealed abstract class Node extends Serializable {
+
+ // TODO: Add aggregate stats (once available). This will happen after we move the DecisionTree
+ // code into the new API and deprecate the old API.
+
+ /** Prediction this node makes (or would make, if it is an internal node) */
+ def prediction: Double
+
+ /** Impurity measure at this node (for training data) */
+ def impurity: Double
+
+ /** Recursive prediction helper method */
+ private[ml] def predict(features: Vector): Double = prediction
+
+ /**
+ * Get the number of nodes in the tree below this node, including leaf nodes.
+ * E.g., if this is a leaf, returns 0. If both children are leaves, returns 2.
+ */
+ private[tree] def numDescendants: Int
+
+ /**
+ * Recursive print function.
+ * @param indentFactor The number of spaces to add to each level of indentation.
+ */
+ private[tree] def subtreeToString(indentFactor: Int = 0): String
+
+ /**
+ * Get depth of tree from this node.
+ * E.g.: Depth 0 means this is a leaf node. Depth 1 means 1 internal and 2 leaf nodes.
+ */
+ private[tree] def subtreeDepth: Int
+
+ /**
+ * Create a copy of this node in the old Node format, recursively creating child nodes as needed.
+ * @param id Node ID using old format IDs
+ */
+ private[ml] def toOld(id: Int): OldNode
+}
+
+private[ml] object Node {
+
+ /**
+ * Create a new Node from the old Node format, recursively creating child nodes as needed.
+ */
+ def fromOld(oldNode: OldNode, categoricalFeatures: Map[Int, Int]): Node = {
+ if (oldNode.isLeaf) {
+ // TODO: Once the implementation has been moved to this API, then include sufficient
+ // statistics here.
+ new LeafNode(prediction = oldNode.predict.predict, impurity = oldNode.impurity)
+ } else {
+ val gain = if (oldNode.stats.nonEmpty) {
+ oldNode.stats.get.gain
+ } else {
+ 0.0
+ }
+ new InternalNode(prediction = oldNode.predict.predict, impurity = oldNode.impurity,
+ gain = gain, leftChild = fromOld(oldNode.leftNode.get, categoricalFeatures),
+ rightChild = fromOld(oldNode.rightNode.get, categoricalFeatures),
+ split = Split.fromOld(oldNode.split.get, categoricalFeatures))
+ }
+ }
+}
+
+/**
+ * Decision tree leaf node.
+ * @param prediction Prediction this node makes
+ * @param impurity Impurity measure at this node (for training data)
+ */
+final class LeafNode private[ml] (
+ override val prediction: Double,
+ override val impurity: Double) extends Node {
+
+ override def toString: String = s"LeafNode(prediction = $prediction, impurity = $impurity)"
+
+ override private[ml] def predict(features: Vector): Double = prediction
+
+ override private[tree] def numDescendants: Int = 0
+
+ override private[tree] def subtreeToString(indentFactor: Int = 0): String = {
+ val prefix: String = " " * indentFactor
+ prefix + s"Predict: $prediction\n"
+ }
+
+ override private[tree] def subtreeDepth: Int = 0
+
+ override private[ml] def toOld(id: Int): OldNode = {
+ // NOTE: We do NOT store 'prob' in the new API currently.
+ new OldNode(id, new OldPredict(prediction, prob = 0.0), impurity, isLeaf = true,
+ None, None, None, None)
+ }
+}
+
+/**
+ * Internal Decision Tree node.
+ * @param prediction Prediction this node would make if it were a leaf node
+ * @param impurity Impurity measure at this node (for training data)
+ * @param gain Information gain value.
+ * Values < 0 indicate missing values; this quirk will be removed with future updates.
+ * @param leftChild Left-hand child node
+ * @param rightChild Right-hand child node
+ * @param split Information about the test used to split to the left or right child.
+ */
+final class InternalNode private[ml] (
+ override val prediction: Double,
+ override val impurity: Double,
+ val gain: Double,
+ val leftChild: Node,
+ val rightChild: Node,
+ val split: Split) extends Node {
+
+ override def toString: String = {
+ s"InternalNode(prediction = $prediction, impurity = $impurity, split = $split)"
+ }
+
+ override private[ml] def predict(features: Vector): Double = {
+ if (split.shouldGoLeft(features)) {
+ leftChild.predict(features)
+ } else {
+ rightChild.predict(features)
+ }
+ }
+
+ override private[tree] def numDescendants: Int = {
+ 2 + leftChild.numDescendants + rightChild.numDescendants
+ }
+
+ override private[tree] def subtreeToString(indentFactor: Int = 0): String = {
+ val prefix: String = " " * indentFactor
+ prefix + s"If (${InternalNode.splitToString(split, left=true)})\n" +
+ leftChild.subtreeToString(indentFactor + 1) +
+ prefix + s"Else (${InternalNode.splitToString(split, left=false)})\n" +
+ rightChild.subtreeToString(indentFactor + 1)
+ }
+
+ override private[tree] def subtreeDepth: Int = {
+ 1 + math.max(leftChild.subtreeDepth, rightChild.subtreeDepth)
+ }
+
+ override private[ml] def toOld(id: Int): OldNode = {
+ assert(id.toLong * 2 < Int.MaxValue, "Decision Tree could not be converted from new to old API"
+ + " since the old API does not support deep trees.")
+ // NOTE: We do NOT store 'prob' in the new API currently.
+ new OldNode(id, new OldPredict(prediction, prob = 0.0), impurity, isLeaf = false,
+ Some(split.toOld), Some(leftChild.toOld(OldNode.leftChildIndex(id))),
+ Some(rightChild.toOld(OldNode.rightChildIndex(id))),
+ Some(new OldInformationGainStats(gain, impurity, leftChild.impurity, rightChild.impurity,
+ new OldPredict(leftChild.prediction, prob = 0.0),
+ new OldPredict(rightChild.prediction, prob = 0.0))))
+ }
+}
+
+private object InternalNode {
+
+ /**
+ * Helper method for [[Node.subtreeToString()]].
+ * @param split Split to print
+ * @param left Whether to describe the left-hand branch of the split (true)
+ * or the right-hand branch (false).
+ */
+ private def splitToString(split: Split, left: Boolean): String = {
+ val featureStr = s"feature ${split.featureIndex}"
+ split match {
+ case contSplit: ContinuousSplit =>
+ if (left) {
+ s"$featureStr <= ${contSplit.threshold}"
+ } else {
+ s"$featureStr > ${contSplit.threshold}"
+ }
+ case catSplit: CategoricalSplit =>
+ val categoriesStr = catSplit.getLeftCategories.mkString("{", ",", "}")
+ if (left) {
+ s"$featureStr in $categoriesStr"
+ } else {
+ s"$featureStr not in $categoriesStr"
+ }
+ }
+ }
+}
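
Because Node, LeafNode, and InternalNode are public and Node is sealed, user code can walk a fitted tree directly. An illustrative helper (not part of the patch) that counts leaves by pattern matching:

```scala
import org.apache.spark.ml.tree.{InternalNode, LeafNode, Node}

def countLeaves(node: Node): Int = node match {
  case _: LeafNode     => 1
  case n: InternalNode => countLeaves(n.leftChild) + countLeaves(n.rightChild)
}

// Every internal node here has exactly two children, so for a fitted model:
// countLeaves(model.rootNode) == (model.numNodes + 1) / 2
```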
diff --git a/mllib/src/main/scala/org/apache/spark/ml/tree/Split.scala b/mllib/src/main/scala/org/apache/spark/ml/tree/Split.scala
new file mode 100644
index 0000000000000..cb940f62990ed
--- /dev/null
+++ b/mllib/src/main/scala/org/apache/spark/ml/tree/Split.scala
@@ -0,0 +1,151 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.ml.tree
+
+import org.apache.spark.mllib.linalg.Vector
+import org.apache.spark.mllib.tree.configuration.{FeatureType => OldFeatureType}
+import org.apache.spark.mllib.tree.model.{Split => OldSplit}
+
+
+/**
+ * Interface for a "Split," which specifies a test made at a decision tree node
+ * to choose the left or right path.
+ */
+sealed trait Split extends Serializable {
+
+ /** Index of feature which this split tests */
+ def featureIndex: Int
+
+ /** Return true (split to left) or false (split to right) */
+ private[ml] def shouldGoLeft(features: Vector): Boolean
+
+ /** Convert to old Split format */
+ private[tree] def toOld: OldSplit
+}
+
+private[ml] object Split {
+
+ def fromOld(oldSplit: OldSplit, categoricalFeatures: Map[Int, Int]): Split = {
+ oldSplit.featureType match {
+ case OldFeatureType.Categorical =>
+ new CategoricalSplit(featureIndex = oldSplit.feature,
+ leftCategories = oldSplit.categories.toArray, categoricalFeatures(oldSplit.feature))
+ case OldFeatureType.Continuous =>
+ new ContinuousSplit(featureIndex = oldSplit.feature, threshold = oldSplit.threshold)
+ }
+ }
+}
+
+/**
+ * Split which tests a categorical feature.
+ * @param featureIndex Index of the feature to test
+ * @param leftCategories If the feature value is in this set of categories, then the split goes
+ * left. Otherwise, it goes right.
+ * @param numCategories Number of categories for this feature.
+ */
+final class CategoricalSplit(
+ override val featureIndex: Int,
+ leftCategories: Array[Double],
+ private val numCategories: Int)
+ extends Split {
+
+ require(leftCategories.forall(cat => 0 <= cat && cat < numCategories), "Invalid leftCategories" +
+ s" (should be in range [0, $numCategories)): ${leftCategories.mkString(",")}")
+
+ /**
+ * If true, then "categories" is the set of categories for splitting to the left, and vice versa.
+ */
+ private val isLeft: Boolean = leftCategories.length <= numCategories / 2
+
+ /** Set of categories determining the splitting rule, along with [[isLeft]]. */
+ private val categories: Set[Double] = {
+ if (isLeft) {
+ leftCategories.toSet
+ } else {
+ setComplement(leftCategories.toSet)
+ }
+ }
+
+ override private[ml] def shouldGoLeft(features: Vector): Boolean = {
+ if (isLeft) {
+ categories.contains(features(featureIndex))
+ } else {
+ !categories.contains(features(featureIndex))
+ }
+ }
+
+ override def equals(o: Any): Boolean = {
+ o match {
+ case other: CategoricalSplit => featureIndex == other.featureIndex &&
+ isLeft == other.isLeft && categories == other.categories
+ case _ => false
+ }
+ }
+
+ override private[tree] def toOld: OldSplit = {
+ val oldCats = if (isLeft) {
+ categories
+ } else {
+ setComplement(categories)
+ }
+ OldSplit(featureIndex, threshold = 0.0, OldFeatureType.Categorical, oldCats.toList)
+ }
+
+ /** Get sorted categories which split to the left */
+ def getLeftCategories: Array[Double] = {
+ val cats = if (isLeft) categories else setComplement(categories)
+ cats.toArray.sorted
+ }
+
+ /** Get sorted categories which split to the right */
+ def getRightCategories: Array[Double] = {
+ val cats = if (isLeft) setComplement(categories) else categories
+ cats.toArray.sorted
+ }
+
+ /** [0, numCategories) \ cats */
+ private def setComplement(cats: Set[Double]): Set[Double] = {
+ Range(0, numCategories).map(_.toDouble).filter(cat => !cats.contains(cat)).toSet
+ }
+}
+
+/**
+ * Split which tests a continuous feature.
+ * @param featureIndex Index of the feature to test
+ * @param threshold If the feature value is <= this threshold, then the split goes left.
+ * Otherwise, it goes right.
+ */
+final class ContinuousSplit(override val featureIndex: Int, val threshold: Double) extends Split {
+
+ override private[ml] def shouldGoLeft(features: Vector): Boolean = {
+ features(featureIndex) <= threshold
+ }
+
+ override def equals(o: Any): Boolean = {
+ o match {
+ case other: ContinuousSplit =>
+ featureIndex == other.featureIndex && threshold == other.threshold
+ case _ =>
+ false
+ }
+ }
+
+ override private[tree] def toOld: OldSplit = {
+ OldSplit(featureIndex, threshold, OldFeatureType.Continuous, List.empty[Double])
+ }
+}
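
A small sketch (illustrative, not part of the patch) of the complement-set behavior described above; both split classes have public constructors, so this can be run directly:

```scala
import org.apache.spark.ml.tree.{CategoricalSplit, ContinuousSplit}

// Feature 0 has 4 categories; categories 1.0 and 3.0 send instances to the left.
val catSplit = new CategoricalSplit(0, Array(1.0, 3.0), 4)
assert(catSplit.getLeftCategories.toSeq == Seq(1.0, 3.0))
assert(catSplit.getRightCategories.toSeq == Seq(0.0, 2.0))

// Feature 1 goes left when its value is <= 0.5, right otherwise.
val contSplit = new ContinuousSplit(featureIndex = 1, threshold = 0.5)
```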
diff --git a/mllib/src/main/scala/org/apache/spark/ml/tree/treeModels.scala b/mllib/src/main/scala/org/apache/spark/ml/tree/treeModels.scala
new file mode 100644
index 0000000000000..8e3bc3849dcf0
--- /dev/null
+++ b/mllib/src/main/scala/org/apache/spark/ml/tree/treeModels.scala
@@ -0,0 +1,60 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.ml.tree
+
+import org.apache.spark.annotation.AlphaComponent
+
+
+/**
+ * :: AlphaComponent ::
+ *
+ * Abstraction for Decision Tree models.
+ *
+ * TODO: Add support for predicting probabilities and raw predictions
+ */
+@AlphaComponent
+trait DecisionTreeModel {
+
+ /** Root of the decision tree */
+ def rootNode: Node
+
+ /** Number of nodes in tree, including leaf nodes. */
+ def numNodes: Int = {
+ 1 + rootNode.numDescendants
+ }
+
+ /**
+ * Depth of the tree.
+ * E.g.: Depth 0 means 1 leaf node. Depth 1 means 1 internal node and 2 leaf nodes.
+ */
+ lazy val depth: Int = {
+ rootNode.subtreeDepth
+ }
+
+ /** Summary of the model */
+ override def toString: String = {
+ // Implementing classes should generally override this method to be more descriptive.
+ s"DecisionTreeModel of depth $depth with $numNodes nodes"
+ }
+
+ /** Full description of model */
+ def toDebugString: String = {
+ val header = toString + "\n"
+ header + rootNode.subtreeToString(2)
+ }
+}
diff --git a/mllib/src/main/scala/org/apache/spark/ml/util/MetadataUtils.scala b/mllib/src/main/scala/org/apache/spark/ml/util/MetadataUtils.scala
new file mode 100644
index 0000000000000..c84c8b4eb744f
--- /dev/null
+++ b/mllib/src/main/scala/org/apache/spark/ml/util/MetadataUtils.scala
@@ -0,0 +1,82 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.ml.util
+
+import scala.collection.immutable.HashMap
+
+import org.apache.spark.annotation.Experimental
+import org.apache.spark.ml.attribute.{Attribute, AttributeGroup, BinaryAttribute, NominalAttribute,
+ NumericAttribute}
+import org.apache.spark.sql.types.StructField
+
+
+/**
+ * :: Experimental ::
+ *
+ * Helper utilities for tree-based algorithms
+ */
+@Experimental
+object MetadataUtils {
+
+ /**
+ * Examine a schema to identify the number of classes in a label column.
+ * Returns None if the number of labels is not specified, or if the label column is continuous.
+ */
+ def getNumClasses(labelSchema: StructField): Option[Int] = {
+ Attribute.fromStructField(labelSchema) match {
+ case numAttr: NumericAttribute => None
+ case binAttr: BinaryAttribute => Some(2)
+ case nomAttr: NominalAttribute => nomAttr.getNumValues
+ }
+ }
+
+ /**
+ * Examine a schema to identify categorical (Binary and Nominal) features.
+ *
+ * @param featuresSchema Schema of the features column.
+ * If a feature does not have metadata, it is assumed to be continuous.
+ * If a feature is Nominal, then it must have the number of values
+ * specified.
+ * @return Map: feature index --> number of categories.
+ * The map's set of keys will be the set of categorical feature indices.
+ */
+ def getCategoricalFeatures(featuresSchema: StructField): Map[Int, Int] = {
+ val metadata = AttributeGroup.fromStructField(featuresSchema)
+ if (metadata.attributes.isEmpty) {
+ HashMap.empty[Int, Int]
+ } else {
+ metadata.attributes.get.zipWithIndex.flatMap { case (attr, idx) =>
+ if (attr == null) {
+ Iterator()
+ } else {
+ attr match {
+ case numAttr: NumericAttribute => Iterator()
+ case binAttr: BinaryAttribute => Iterator(idx -> 2)
+ case nomAttr: NominalAttribute =>
+ nomAttr.getNumValues match {
+ case Some(numValues: Int) => Iterator(idx -> numValues)
+ case None => throw new IllegalArgumentException(s"Feature $idx is marked as" +
+ " Nominal (categorical), but it does not have the number of values specified.")
+ }
+ }
+ }
+ }.toMap
+ }
+ }
+
+}
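
A hedged sketch of how these helpers are meant to be called, assuming a DataFrame `df` whose "features" and "label" columns already carry ML attribute metadata (for example, metadata written by StringIndexer or by the test utilities in this patch):

```scala
import org.apache.spark.ml.util.MetadataUtils

// Map: feature index -> number of categories; features absent from the map
// are treated as continuous.
val categoricalFeatures: Map[Int, Int] =
  MetadataUtils.getCategoricalFeatures(df.schema("features"))

// None if the label column is continuous or the class count is unspecified.
val numClasses: Option[Int] = MetadataUtils.getNumClasses(df.schema("label"))
```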
diff --git a/mllib/src/main/scala/org/apache/spark/mllib/tree/DecisionTree.scala b/mllib/src/main/scala/org/apache/spark/mllib/tree/DecisionTree.scala
index b9d0c56dd1ea3..dfe3a0b6913ef 100644
--- a/mllib/src/main/scala/org/apache/spark/mllib/tree/DecisionTree.scala
+++ b/mllib/src/main/scala/org/apache/spark/mllib/tree/DecisionTree.scala
@@ -1147,7 +1147,10 @@ object DecisionTree extends Serializable with Logging {
}
}
- assert(splits.length > 0)
+ // TODO: Do not fail; just ignore the useless feature.
+ assert(splits.length > 0,
+ s"DecisionTree could not handle feature $featureIndex since it had only 1 unique value." +
+ " Please remove this feature and then try again.")
// set number of splits accordingly
metadata.setNumSplits(featureIndex, splits.length)
diff --git a/mllib/src/main/scala/org/apache/spark/mllib/tree/GradientBoostedTrees.scala b/mllib/src/main/scala/org/apache/spark/mllib/tree/GradientBoostedTrees.scala
index c02c79f094b66..0e31c7ed58df8 100644
--- a/mllib/src/main/scala/org/apache/spark/mllib/tree/GradientBoostedTrees.scala
+++ b/mllib/src/main/scala/org/apache/spark/mllib/tree/GradientBoostedTrees.scala
@@ -81,11 +81,11 @@ class GradientBoostedTrees(private val boostingStrategy: BoostingStrategy)
/**
* Method to validate a gradient boosting model
* @param input Training dataset: RDD of [[org.apache.spark.mllib.regression.LabeledPoint]].
- * @param validationInput Validation dataset:
- RDD of [[org.apache.spark.mllib.regression.LabeledPoint]].
- Should be different from and follow the same distribution as input.
- e.g., these two datasets could be created from an original dataset
- by using [[org.apache.spark.rdd.RDD.randomSplit()]]
+ * @param validationInput Validation dataset.
+ * This dataset should be different from the training dataset,
+ * but it should follow the same distribution.
+ * E.g., these two datasets could be created from an original dataset
+ * by using [[org.apache.spark.rdd.RDD.randomSplit()]]
* @return a gradient boosted trees model that can be used for prediction
*/
def runWithValidation(
@@ -194,8 +194,6 @@ object GradientBoostedTrees extends Logging {
val firstTreeWeight = 1.0
baseLearners(0) = firstTreeModel
baseLearnerWeights(0) = firstTreeWeight
- val startingModel = new GradientBoostedTreesModel(
- Regression, Array(firstTreeModel), baseLearnerWeights.slice(0, 1))
var predError: RDD[(Double, Double)] = GradientBoostedTreesModel.
computeInitialPredictionAndError(input, firstTreeWeight, firstTreeModel, loss)
diff --git a/mllib/src/main/scala/org/apache/spark/mllib/tree/RandomForest.scala b/mllib/src/main/scala/org/apache/spark/mllib/tree/RandomForest.scala
index db01f2e229e5a..055e60c7d9c95 100644
--- a/mllib/src/main/scala/org/apache/spark/mllib/tree/RandomForest.scala
+++ b/mllib/src/main/scala/org/apache/spark/mllib/tree/RandomForest.scala
@@ -249,7 +249,7 @@ private class RandomForest (
nodeIdCache.get.deleteAllCheckpoints()
} catch {
case e:IOException =>
- logWarning(s"delete all chackpoints failed. Error reason: ${e.getMessage}")
+ logWarning(s"delete all checkpoints failed. Error reason: ${e.getMessage}")
}
}
diff --git a/mllib/src/main/scala/org/apache/spark/mllib/tree/configuration/BoostingStrategy.scala b/mllib/src/main/scala/org/apache/spark/mllib/tree/configuration/BoostingStrategy.scala
index 664c8df019233..2d6b01524ff3d 100644
--- a/mllib/src/main/scala/org/apache/spark/mllib/tree/configuration/BoostingStrategy.scala
+++ b/mllib/src/main/scala/org/apache/spark/mllib/tree/configuration/BoostingStrategy.scala
@@ -89,14 +89,14 @@ object BoostingStrategy {
* @return Configuration for boosting algorithm
*/
def defaultParams(algo: Algo): BoostingStrategy = {
- val treeStragtegy = Strategy.defaultStategy(algo)
- treeStragtegy.maxDepth = 3
+ val treeStrategy = Strategy.defaultStategy(algo)
+ treeStrategy.maxDepth = 3
algo match {
case Algo.Classification =>
- treeStragtegy.numClasses = 2
- new BoostingStrategy(treeStragtegy, LogLoss)
+ treeStrategy.numClasses = 2
+ new BoostingStrategy(treeStrategy, LogLoss)
case Algo.Regression =>
- new BoostingStrategy(treeStragtegy, SquaredError)
+ new BoostingStrategy(treeStrategy, SquaredError)
case _ =>
throw new IllegalArgumentException(s"$algo is not supported by boosting.")
}
diff --git a/mllib/src/main/scala/org/apache/spark/mllib/tree/loss/AbsoluteError.scala b/mllib/src/main/scala/org/apache/spark/mllib/tree/loss/AbsoluteError.scala
index 6f570b4e09c79..2bdef73c4a8f1 100644
--- a/mllib/src/main/scala/org/apache/spark/mllib/tree/loss/AbsoluteError.scala
+++ b/mllib/src/main/scala/org/apache/spark/mllib/tree/loss/AbsoluteError.scala
@@ -20,7 +20,7 @@ package org.apache.spark.mllib.tree.loss
import org.apache.spark.annotation.DeveloperApi
import org.apache.spark.mllib.regression.LabeledPoint
import org.apache.spark.mllib.tree.model.TreeEnsembleModel
-import org.apache.spark.rdd.RDD
+
/**
* :: DeveloperApi ::
@@ -45,9 +45,8 @@ object AbsoluteError extends Loss {
if (label - prediction < 0) 1.0 else -1.0
}
- override def computeError(prediction: Double, label: Double): Double = {
+ override private[mllib] def computeError(prediction: Double, label: Double): Double = {
val err = label - prediction
math.abs(err)
}
-
}
diff --git a/mllib/src/main/scala/org/apache/spark/mllib/tree/loss/LogLoss.scala b/mllib/src/main/scala/org/apache/spark/mllib/tree/loss/LogLoss.scala
index 24ee9f3d51293..778c24526de70 100644
--- a/mllib/src/main/scala/org/apache/spark/mllib/tree/loss/LogLoss.scala
+++ b/mllib/src/main/scala/org/apache/spark/mllib/tree/loss/LogLoss.scala
@@ -21,7 +21,7 @@ import org.apache.spark.annotation.DeveloperApi
import org.apache.spark.mllib.regression.LabeledPoint
import org.apache.spark.mllib.tree.model.TreeEnsembleModel
import org.apache.spark.mllib.util.MLUtils
-import org.apache.spark.rdd.RDD
+
/**
* :: DeveloperApi ::
@@ -47,10 +47,9 @@ object LogLoss extends Loss {
- 4.0 * label / (1.0 + math.exp(2.0 * label * prediction))
}
- override def computeError(prediction: Double, label: Double): Double = {
+ override private[mllib] def computeError(prediction: Double, label: Double): Double = {
val margin = 2.0 * label * prediction
// The following is equivalent to 2.0 * log(1 + exp(-margin)) but more numerically stable.
2.0 * MLUtils.log1pExp(-margin)
}
-
}
diff --git a/mllib/src/main/scala/org/apache/spark/mllib/tree/loss/Loss.scala b/mllib/src/main/scala/org/apache/spark/mllib/tree/loss/Loss.scala
index d3b82b752fa0d..64ffccbce073f 100644
--- a/mllib/src/main/scala/org/apache/spark/mllib/tree/loss/Loss.scala
+++ b/mllib/src/main/scala/org/apache/spark/mllib/tree/loss/Loss.scala
@@ -22,6 +22,7 @@ import org.apache.spark.mllib.regression.LabeledPoint
import org.apache.spark.mllib.tree.model.TreeEnsembleModel
import org.apache.spark.rdd.RDD
+
/**
* :: DeveloperApi ::
* Trait for adding "pluggable" loss functions for the gradient boosting algorithm.
@@ -57,6 +58,5 @@ trait Loss extends Serializable {
* @param label True label.
* @return Measure of model error on datapoint.
*/
- def computeError(prediction: Double, label: Double): Double
-
+ private[mllib] def computeError(prediction: Double, label: Double): Double
}
diff --git a/mllib/src/main/scala/org/apache/spark/mllib/tree/loss/SquaredError.scala b/mllib/src/main/scala/org/apache/spark/mllib/tree/loss/SquaredError.scala
index 58857ae15e93e..a5582d3ef3324 100644
--- a/mllib/src/main/scala/org/apache/spark/mllib/tree/loss/SquaredError.scala
+++ b/mllib/src/main/scala/org/apache/spark/mllib/tree/loss/SquaredError.scala
@@ -20,7 +20,7 @@ package org.apache.spark.mllib.tree.loss
import org.apache.spark.annotation.DeveloperApi
import org.apache.spark.mllib.regression.LabeledPoint
import org.apache.spark.mllib.tree.model.TreeEnsembleModel
-import org.apache.spark.rdd.RDD
+
/**
* :: DeveloperApi ::
@@ -45,9 +45,8 @@ object SquaredError extends Loss {
2.0 * (prediction - label)
}
- override def computeError(prediction: Double, label: Double): Double = {
+ override private[mllib] def computeError(prediction: Double, label: Double): Double = {
val err = prediction - label
err * err
}
-
}
diff --git a/mllib/src/main/scala/org/apache/spark/mllib/tree/model/DecisionTreeModel.scala b/mllib/src/main/scala/org/apache/spark/mllib/tree/model/DecisionTreeModel.scala
index c9bafd60fba4d..331af428533de 100644
--- a/mllib/src/main/scala/org/apache/spark/mllib/tree/model/DecisionTreeModel.scala
+++ b/mllib/src/main/scala/org/apache/spark/mllib/tree/model/DecisionTreeModel.scala
@@ -113,11 +113,13 @@ class DecisionTreeModel(val topNode: Node, val algo: Algo) extends Serializable
DecisionTreeModel.SaveLoadV1_0.save(sc, path, this)
}
- override protected def formatVersion: String = "1.0"
+ override protected def formatVersion: String = DecisionTreeModel.formatVersion
}
object DecisionTreeModel extends Loader[DecisionTreeModel] with Logging {
+ private[spark] def formatVersion: String = "1.0"
+
private[tree] object SaveLoadV1_0 {
def thisFormatVersion: String = "1.0"
diff --git a/mllib/src/main/scala/org/apache/spark/mllib/tree/model/Node.scala b/mllib/src/main/scala/org/apache/spark/mllib/tree/model/Node.scala
index 4f72bb8014cc0..708ba04b567d3 100644
--- a/mllib/src/main/scala/org/apache/spark/mllib/tree/model/Node.scala
+++ b/mllib/src/main/scala/org/apache/spark/mllib/tree/model/Node.scala
@@ -175,7 +175,7 @@ class Node (
}
}
-private[tree] object Node {
+private[spark] object Node {
/**
* Return a node with the given node id (but nothing else set).
diff --git a/mllib/src/main/scala/org/apache/spark/mllib/tree/model/treeEnsembleModels.scala b/mllib/src/main/scala/org/apache/spark/mllib/tree/model/treeEnsembleModels.scala
index fef3d2acb202a..8341219bfa71c 100644
--- a/mllib/src/main/scala/org/apache/spark/mllib/tree/model/treeEnsembleModels.scala
+++ b/mllib/src/main/scala/org/apache/spark/mllib/tree/model/treeEnsembleModels.scala
@@ -38,6 +38,7 @@ import org.apache.spark.rdd.RDD
import org.apache.spark.sql.SQLContext
import org.apache.spark.util.Utils
+
/**
* :: Experimental ::
* Represents a random forest model.
@@ -47,7 +48,7 @@ import org.apache.spark.util.Utils
*/
@Experimental
class RandomForestModel(override val algo: Algo, override val trees: Array[DecisionTreeModel])
- extends TreeEnsembleModel(algo, trees, Array.fill(trees.size)(1.0),
+ extends TreeEnsembleModel(algo, trees, Array.fill(trees.length)(1.0),
combiningStrategy = if (algo == Classification) Vote else Average)
with Saveable {
@@ -58,11 +59,13 @@ class RandomForestModel(override val algo: Algo, override val trees: Array[Decis
RandomForestModel.SaveLoadV1_0.thisClassName)
}
- override protected def formatVersion: String = TreeEnsembleModel.SaveLoadV1_0.thisFormatVersion
+ override protected def formatVersion: String = RandomForestModel.formatVersion
}
object RandomForestModel extends Loader[RandomForestModel] {
+ private[mllib] def formatVersion: String = TreeEnsembleModel.SaveLoadV1_0.thisFormatVersion
+
override def load(sc: SparkContext, path: String): RandomForestModel = {
val (loadedClassName, version, jsonMetadata) = Loader.loadMetadata(sc, path)
val classNameV1_0 = SaveLoadV1_0.thisClassName
@@ -102,15 +105,13 @@ class GradientBoostedTreesModel(
extends TreeEnsembleModel(algo, trees, treeWeights, combiningStrategy = Sum)
with Saveable {
- require(trees.size == treeWeights.size)
+ require(trees.length == treeWeights.length)
override def save(sc: SparkContext, path: String): Unit = {
TreeEnsembleModel.SaveLoadV1_0.save(sc, path, this,
GradientBoostedTreesModel.SaveLoadV1_0.thisClassName)
}
- override protected def formatVersion: String = TreeEnsembleModel.SaveLoadV1_0.thisFormatVersion
-
/**
* Method to compute error or loss for every iteration of gradient boosting.
* @param data RDD of [[org.apache.spark.mllib.regression.LabeledPoint]]
@@ -138,7 +139,7 @@ class GradientBoostedTreesModel(
evaluationArray(0) = predictionAndError.values.mean()
val broadcastTrees = sc.broadcast(trees)
- (1 until numIterations).map { nTree =>
+ (1 until numIterations).foreach { nTree =>
predictionAndError = remappedData.zip(predictionAndError).mapPartitions { iter =>
val currentTree = broadcastTrees.value(nTree)
val currentTreeWeight = localTreeWeights(nTree)
@@ -155,6 +156,7 @@ class GradientBoostedTreesModel(
evaluationArray
}
+ override protected def formatVersion: String = GradientBoostedTreesModel.formatVersion
}
object GradientBoostedTreesModel extends Loader[GradientBoostedTreesModel] {
@@ -200,17 +202,17 @@ object GradientBoostedTreesModel extends Loader[GradientBoostedTreesModel] {
loss: Loss): RDD[(Double, Double)] = {
val newPredError = data.zip(predictionAndError).mapPartitions { iter =>
- iter.map {
- case (lp, (pred, error)) => {
- val newPred = pred + tree.predict(lp.features) * treeWeight
- val newError = loss.computeError(newPred, lp.label)
- (newPred, newError)
- }
+ iter.map { case (lp, (pred, error)) =>
+ val newPred = pred + tree.predict(lp.features) * treeWeight
+ val newError = loss.computeError(newPred, lp.label)
+ (newPred, newError)
}
}
newPredError
}
+ private[mllib] def formatVersion: String = TreeEnsembleModel.SaveLoadV1_0.thisFormatVersion
+
override def load(sc: SparkContext, path: String): GradientBoostedTreesModel = {
val (loadedClassName, version, jsonMetadata) = Loader.loadMetadata(sc, path)
val classNameV1_0 = SaveLoadV1_0.thisClassName
@@ -340,12 +342,12 @@ private[tree] sealed class TreeEnsembleModel(
}
/**
- * Get number of trees in forest.
+ * Get number of trees in ensemble.
*/
- def numTrees: Int = trees.size
+ def numTrees: Int = trees.length
/**
- * Get total number of nodes, summed over all trees in the forest.
+ * Get total number of nodes, summed over all trees in the ensemble.
*/
def totalNumNodes: Int = trees.map(_.numNodes).sum
}
diff --git a/mllib/src/test/java/org/apache/spark/ml/classification/JavaDecisionTreeClassifierSuite.java b/mllib/src/test/java/org/apache/spark/ml/classification/JavaDecisionTreeClassifierSuite.java
new file mode 100644
index 0000000000000..43b8787f9dd7e
--- /dev/null
+++ b/mllib/src/test/java/org/apache/spark/ml/classification/JavaDecisionTreeClassifierSuite.java
@@ -0,0 +1,98 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.ml.classification;
+
+import java.io.File;
+import java.io.Serializable;
+import java.util.HashMap;
+import java.util.Map;
+
+import org.junit.After;
+import org.junit.Before;
+import org.junit.Test;
+
+import org.apache.spark.api.java.JavaRDD;
+import org.apache.spark.api.java.JavaSparkContext;
+import org.apache.spark.ml.impl.TreeTests;
+import org.apache.spark.mllib.classification.LogisticRegressionSuite;
+import org.apache.spark.mllib.regression.LabeledPoint;
+import org.apache.spark.sql.DataFrame;
+import org.apache.spark.util.Utils;
+
+
+public class JavaDecisionTreeClassifierSuite implements Serializable {
+
+ private transient JavaSparkContext sc;
+
+ @Before
+ public void setUp() {
+ sc = new JavaSparkContext("local", "JavaDecisionTreeClassifierSuite");
+ }
+
+ @After
+ public void tearDown() {
+ sc.stop();
+ sc = null;
+ }
+
+ @Test
+ public void runDT() {
+ int nPoints = 20;
+ double A = 2.0;
+ double B = -1.5;
+
+ JavaRDD<LabeledPoint> data = sc.parallelize(
+ LogisticRegressionSuite.generateLogisticInputAsList(A, B, nPoints, 42), 2).cache();
+ Map<Integer, Integer> categoricalFeatures = new HashMap<Integer, Integer>();
+ DataFrame dataFrame = TreeTests.setMetadata(data, categoricalFeatures, 2);
+
+ // This tests setters. Training with various options is tested in Scala.
+ DecisionTreeClassifier dt = new DecisionTreeClassifier()
+ .setMaxDepth(2)
+ .setMaxBins(10)
+ .setMinInstancesPerNode(5)
+ .setMinInfoGain(0.0)
+ .setMaxMemoryInMB(256)
+ .setCacheNodeIds(false)
+ .setCheckpointInterval(10)
+ .setMaxDepth(2); // duplicate setMaxDepth to check builder pattern
+ for (int i = 0; i < DecisionTreeClassifier.supportedImpurities().length; ++i) {
+ dt.setImpurity(DecisionTreeClassifier.supportedImpurities()[i]);
+ }
+ DecisionTreeClassificationModel model = dt.fit(dataFrame);
+
+ model.transform(dataFrame);
+ model.numNodes();
+ model.depth();
+ model.toDebugString();
+
+ /*
+ // TODO: Add test once save/load are implemented.
+ File tempDir = Utils.createTempDir(System.getProperty("java.io.tmpdir"), "spark");
+ String path = tempDir.toURI().toString();
+ try {
+ model3.save(sc.sc(), path);
+ DecisionTreeClassificationModel sameModel =
+ DecisionTreeClassificationModel.load(sc.sc(), path);
+ TreeTests.checkEqual(model3, sameModel);
+ } finally {
+ Utils.deleteRecursively(tempDir);
+ }
+ */
+ }
+}
diff --git a/mllib/src/test/java/org/apache/spark/ml/regression/JavaDecisionTreeRegressorSuite.java b/mllib/src/test/java/org/apache/spark/ml/regression/JavaDecisionTreeRegressorSuite.java
new file mode 100644
index 0000000000000..a3a339004f31c
--- /dev/null
+++ b/mllib/src/test/java/org/apache/spark/ml/regression/JavaDecisionTreeRegressorSuite.java
@@ -0,0 +1,97 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.ml.regression;
+
+import java.io.File;
+import java.io.Serializable;
+import java.util.HashMap;
+import java.util.Map;
+
+import org.junit.After;
+import org.junit.Before;
+import org.junit.Test;
+
+import org.apache.spark.api.java.JavaRDD;
+import org.apache.spark.api.java.JavaSparkContext;
+import org.apache.spark.ml.impl.TreeTests;
+import org.apache.spark.mllib.classification.LogisticRegressionSuite;
+import org.apache.spark.mllib.regression.LabeledPoint;
+import org.apache.spark.sql.DataFrame;
+import org.apache.spark.util.Utils;
+
+
+public class JavaDecisionTreeRegressorSuite implements Serializable {
+
+ private transient JavaSparkContext sc;
+
+ @Before
+ public void setUp() {
+ sc = new JavaSparkContext("local", "JavaDecisionTreeRegressorSuite");
+ }
+
+ @After
+ public void tearDown() {
+ sc.stop();
+ sc = null;
+ }
+
+ @Test
+ public void runDT() {
+ int nPoints = 20;
+ double A = 2.0;
+ double B = -1.5;
+
+ JavaRDD<LabeledPoint> data = sc.parallelize(
+ LogisticRegressionSuite.generateLogisticInputAsList(A, B, nPoints, 42), 2).cache();
+ Map<Integer, Integer> categoricalFeatures = new HashMap<Integer, Integer>();
+ DataFrame dataFrame = TreeTests.setMetadata(data, categoricalFeatures, 2);
+
+ // This tests setters. Training with various options is tested in Scala.
+ DecisionTreeRegressor dt = new DecisionTreeRegressor()
+ .setMaxDepth(2)
+ .setMaxBins(10)
+ .setMinInstancesPerNode(5)
+ .setMinInfoGain(0.0)
+ .setMaxMemoryInMB(256)
+ .setCacheNodeIds(false)
+ .setCheckpointInterval(10)
+ .setMaxDepth(2); // duplicate setMaxDepth to check builder pattern
+ for (int i = 0; i < DecisionTreeRegressor.supportedImpurities().length; ++i) {
+ dt.setImpurity(DecisionTreeRegressor.supportedImpurities()[i]);
+ }
+ DecisionTreeRegressionModel model = dt.fit(dataFrame);
+
+ model.transform(dataFrame);
+ model.numNodes();
+ model.depth();
+ model.toDebugString();
+
+ /*
+ // TODO: Add test once save/load are implemented.
+ File tempDir = Utils.createTempDir(System.getProperty("java.io.tmpdir"), "spark");
+ String path = tempDir.toURI().toString();
+ try {
+ model2.save(sc.sc(), path);
+ DecisionTreeRegressionModel sameModel = DecisionTreeRegressionModel.load(sc.sc(), path);
+ TreeTests.checkEqual(model2, sameModel);
+ } finally {
+ Utils.deleteRecursively(tempDir);
+ }
+ */
+ }
+}
diff --git a/mllib/src/test/scala/org/apache/spark/ml/attribute/AttributeGroupSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/attribute/AttributeGroupSuite.scala
index 0dcfe5a2002dc..17ddd335deb6d 100644
--- a/mllib/src/test/scala/org/apache/spark/ml/attribute/AttributeGroupSuite.scala
+++ b/mllib/src/test/scala/org/apache/spark/ml/attribute/AttributeGroupSuite.scala
@@ -44,7 +44,7 @@ class AttributeGroupSuite extends FunSuite {
group("abc")
}
assert(group === AttributeGroup.fromMetadata(group.toMetadataImpl, group.name))
- assert(group === AttributeGroup.fromStructField(group.toStructField))
+ assert(group === AttributeGroup.fromStructField(group.toStructField()))
}
test("attribute group without attributes") {
@@ -54,7 +54,7 @@ class AttributeGroupSuite extends FunSuite {
assert(group0.size === 10)
assert(group0.attributes.isEmpty)
assert(group0 === AttributeGroup.fromMetadata(group0.toMetadataImpl, group0.name))
- assert(group0 === AttributeGroup.fromStructField(group0.toStructField))
+ assert(group0 === AttributeGroup.fromStructField(group0.toStructField()))
val group1 = new AttributeGroup("item")
assert(group1.name === "item")
diff --git a/mllib/src/test/scala/org/apache/spark/ml/attribute/AttributeSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/attribute/AttributeSuite.scala
index 6ec35b03656f9..3e1a7196e37cb 100644
--- a/mllib/src/test/scala/org/apache/spark/ml/attribute/AttributeSuite.scala
+++ b/mllib/src/test/scala/org/apache/spark/ml/attribute/AttributeSuite.scala
@@ -36,9 +36,9 @@ class AttributeSuite extends FunSuite {
assert(attr.max.isEmpty)
assert(attr.std.isEmpty)
assert(attr.sparsity.isEmpty)
- assert(attr.toMetadata() === metadata)
- assert(attr.toMetadata(withType = false) === metadata)
- assert(attr.toMetadata(withType = true) === metadataWithType)
+ assert(attr.toMetadataImpl() === metadata)
+ assert(attr.toMetadataImpl(withType = false) === metadata)
+ assert(attr.toMetadataImpl(withType = true) === metadataWithType)
assert(attr === Attribute.fromMetadata(metadata))
assert(attr === Attribute.fromMetadata(metadataWithType))
intercept[NoSuchElementException] {
@@ -59,9 +59,9 @@ class AttributeSuite extends FunSuite {
assert(!attr.isNominal)
assert(attr.name === Some(name))
assert(attr.index === Some(index))
- assert(attr.toMetadata() === metadata)
- assert(attr.toMetadata(withType = false) === metadata)
- assert(attr.toMetadata(withType = true) === metadataWithType)
+ assert(attr.toMetadataImpl() === metadata)
+ assert(attr.toMetadataImpl(withType = false) === metadata)
+ assert(attr.toMetadataImpl(withType = true) === metadataWithType)
assert(attr === Attribute.fromMetadata(metadata))
assert(attr === Attribute.fromMetadata(metadataWithType))
val field = attr.toStructField()
@@ -81,7 +81,7 @@ class AttributeSuite extends FunSuite {
assert(attr2.max === Some(1.0))
assert(attr2.std === Some(0.5))
assert(attr2.sparsity === Some(0.3))
- assert(attr2 === Attribute.fromMetadata(attr2.toMetadata()))
+ assert(attr2 === Attribute.fromMetadata(attr2.toMetadataImpl()))
}
test("bad numeric attributes") {
@@ -105,9 +105,9 @@ class AttributeSuite extends FunSuite {
assert(attr.values.isEmpty)
assert(attr.numValues.isEmpty)
assert(attr.isOrdinal.isEmpty)
- assert(attr.toMetadata() === metadata)
- assert(attr.toMetadata(withType = true) === metadata)
- assert(attr.toMetadata(withType = false) === metadataWithoutType)
+ assert(attr.toMetadataImpl() === metadata)
+ assert(attr.toMetadataImpl(withType = true) === metadata)
+ assert(attr.toMetadataImpl(withType = false) === metadataWithoutType)
assert(attr === Attribute.fromMetadata(metadata))
assert(attr === NominalAttribute.fromMetadata(metadataWithoutType))
intercept[NoSuchElementException] {
@@ -135,9 +135,9 @@ class AttributeSuite extends FunSuite {
assert(attr.values === Some(values))
assert(attr.indexOf("medium") === 1)
assert(attr.getValue(1) === "medium")
- assert(attr.toMetadata() === metadata)
- assert(attr.toMetadata(withType = true) === metadata)
- assert(attr.toMetadata(withType = false) === metadataWithoutType)
+ assert(attr.toMetadataImpl() === metadata)
+ assert(attr.toMetadataImpl(withType = true) === metadata)
+ assert(attr.toMetadataImpl(withType = false) === metadataWithoutType)
assert(attr === Attribute.fromMetadata(metadata))
assert(attr === NominalAttribute.fromMetadata(metadataWithoutType))
assert(attr.withoutIndex === Attribute.fromStructField(attr.toStructField()))
@@ -147,8 +147,8 @@ class AttributeSuite extends FunSuite {
assert(attr2.index.isEmpty)
assert(attr2.values.get === Array("small", "medium", "large", "x-large"))
assert(attr2.indexOf("x-large") === 3)
- assert(attr2 === Attribute.fromMetadata(attr2.toMetadata()))
- assert(attr2 === NominalAttribute.fromMetadata(attr2.toMetadata(withType = false)))
+ assert(attr2 === Attribute.fromMetadata(attr2.toMetadataImpl()))
+ assert(attr2 === NominalAttribute.fromMetadata(attr2.toMetadataImpl(withType = false)))
}
test("bad nominal attributes") {
@@ -168,9 +168,9 @@ class AttributeSuite extends FunSuite {
assert(attr.name.isEmpty)
assert(attr.index.isEmpty)
assert(attr.values.isEmpty)
- assert(attr.toMetadata() === metadata)
- assert(attr.toMetadata(withType = true) === metadata)
- assert(attr.toMetadata(withType = false) === metadataWithoutType)
+ assert(attr.toMetadataImpl() === metadata)
+ assert(attr.toMetadataImpl(withType = true) === metadata)
+ assert(attr.toMetadataImpl(withType = false) === metadataWithoutType)
assert(attr === Attribute.fromMetadata(metadata))
assert(attr === BinaryAttribute.fromMetadata(metadataWithoutType))
intercept[NoSuchElementException] {
@@ -196,9 +196,9 @@ class AttributeSuite extends FunSuite {
assert(attr.name === Some(name))
assert(attr.index === Some(index))
assert(attr.values.get === values)
- assert(attr.toMetadata() === metadata)
- assert(attr.toMetadata(withType = true) === metadata)
- assert(attr.toMetadata(withType = false) === metadataWithoutType)
+ assert(attr.toMetadataImpl() === metadata)
+ assert(attr.toMetadataImpl(withType = true) === metadata)
+ assert(attr.toMetadataImpl(withType = false) === metadataWithoutType)
assert(attr === Attribute.fromMetadata(metadata))
assert(attr === BinaryAttribute.fromMetadata(metadataWithoutType))
assert(attr.withoutIndex === Attribute.fromStructField(attr.toStructField()))
diff --git a/mllib/src/test/scala/org/apache/spark/ml/classification/DecisionTreeClassifierSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/classification/DecisionTreeClassifierSuite.scala
new file mode 100644
index 0000000000000..af88595df5245
--- /dev/null
+++ b/mllib/src/test/scala/org/apache/spark/ml/classification/DecisionTreeClassifierSuite.scala
@@ -0,0 +1,274 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.ml.classification
+
+import org.scalatest.FunSuite
+
+import org.apache.spark.ml.impl.TreeTests
+import org.apache.spark.mllib.linalg.Vectors
+import org.apache.spark.mllib.regression.LabeledPoint
+import org.apache.spark.mllib.tree.{DecisionTree => OldDecisionTree,
+ DecisionTreeSuite => OldDecisionTreeSuite}
+import org.apache.spark.mllib.util.MLlibTestSparkContext
+import org.apache.spark.rdd.RDD
+import org.apache.spark.sql.DataFrame
+
+
+class DecisionTreeClassifierSuite extends FunSuite with MLlibTestSparkContext {
+
+ import DecisionTreeClassifierSuite.compareAPIs
+
+ private var categoricalDataPointsRDD: RDD[LabeledPoint] = _
+ private var orderedLabeledPointsWithLabel0RDD: RDD[LabeledPoint] = _
+ private var orderedLabeledPointsWithLabel1RDD: RDD[LabeledPoint] = _
+ private var categoricalDataPointsForMulticlassRDD: RDD[LabeledPoint] = _
+ private var continuousDataPointsForMulticlassRDD: RDD[LabeledPoint] = _
+ private var categoricalDataPointsForMulticlassForOrderedFeaturesRDD: RDD[LabeledPoint] = _
+
+ override def beforeAll() {
+ super.beforeAll()
+ categoricalDataPointsRDD =
+ sc.parallelize(OldDecisionTreeSuite.generateCategoricalDataPoints())
+ orderedLabeledPointsWithLabel0RDD =
+ sc.parallelize(OldDecisionTreeSuite.generateOrderedLabeledPointsWithLabel0())
+ orderedLabeledPointsWithLabel1RDD =
+ sc.parallelize(OldDecisionTreeSuite.generateOrderedLabeledPointsWithLabel1())
+ categoricalDataPointsForMulticlassRDD =
+ sc.parallelize(OldDecisionTreeSuite.generateCategoricalDataPointsForMulticlass())
+ continuousDataPointsForMulticlassRDD =
+ sc.parallelize(OldDecisionTreeSuite.generateContinuousDataPointsForMulticlass())
+ categoricalDataPointsForMulticlassForOrderedFeaturesRDD = sc.parallelize(
+ OldDecisionTreeSuite.generateCategoricalDataPointsForMulticlassForOrderedFeatures())
+ }
+
+ /////////////////////////////////////////////////////////////////////////////
+ // Tests calling train()
+ /////////////////////////////////////////////////////////////////////////////
+
+ test("Binary classification stump with ordered categorical features") {
+ val dt = new DecisionTreeClassifier()
+ .setImpurity("gini")
+ .setMaxDepth(2)
+ .setMaxBins(100)
+ val categoricalFeatures = Map(0 -> 3, 1-> 3)
+ val numClasses = 2
+ compareAPIs(categoricalDataPointsRDD, dt, categoricalFeatures, numClasses)
+ }
+
+ test("Binary classification stump with fixed labels 0,1 for Entropy,Gini") {
+ val dt = new DecisionTreeClassifier()
+ .setMaxDepth(3)
+ .setMaxBins(100)
+ val numClasses = 2
+ Array(orderedLabeledPointsWithLabel0RDD, orderedLabeledPointsWithLabel1RDD).foreach { rdd =>
+ DecisionTreeClassifier.supportedImpurities.foreach { impurity =>
+ dt.setImpurity(impurity)
+ compareAPIs(rdd, dt, categoricalFeatures = Map.empty[Int, Int], numClasses)
+ }
+ }
+ }
+
+ test("Multiclass classification stump with 3-ary (unordered) categorical features") {
+ val rdd = categoricalDataPointsForMulticlassRDD
+ val dt = new DecisionTreeClassifier()
+ .setImpurity("Gini")
+ .setMaxDepth(4)
+ val numClasses = 3
+ val categoricalFeatures = Map(0 -> 3, 1 -> 3)
+ compareAPIs(rdd, dt, categoricalFeatures, numClasses)
+ }
+
+ test("Binary classification stump with 1 continuous feature, to check off-by-1 error") {
+ val arr = Array(
+ LabeledPoint(0.0, Vectors.dense(0.0)),
+ LabeledPoint(1.0, Vectors.dense(1.0)),
+ LabeledPoint(1.0, Vectors.dense(2.0)),
+ LabeledPoint(1.0, Vectors.dense(3.0)))
+ val rdd = sc.parallelize(arr)
+ val dt = new DecisionTreeClassifier()
+ .setImpurity("Gini")
+ .setMaxDepth(4)
+ val numClasses = 2
+ compareAPIs(rdd, dt, categoricalFeatures = Map.empty[Int, Int], numClasses)
+ }
+
+ test("Binary classification stump with 2 continuous features") {
+ val arr = Array(
+ LabeledPoint(0.0, Vectors.sparse(2, Seq((0, 0.0)))),
+ LabeledPoint(1.0, Vectors.sparse(2, Seq((1, 1.0)))),
+ LabeledPoint(0.0, Vectors.sparse(2, Seq((0, 0.0)))),
+ LabeledPoint(1.0, Vectors.sparse(2, Seq((1, 2.0)))))
+ val rdd = sc.parallelize(arr)
+ val dt = new DecisionTreeClassifier()
+ .setImpurity("Gini")
+ .setMaxDepth(4)
+ val numClasses = 2
+ compareAPIs(rdd, dt, categoricalFeatures = Map.empty[Int, Int], numClasses)
+ }
+
+ test("Multiclass classification stump with unordered categorical features," +
+ " with just enough bins") {
+ val maxBins = 2 * (math.pow(2, 3 - 1).toInt - 1) // just enough bins to allow unordered features
+ val rdd = categoricalDataPointsForMulticlassRDD
+ val dt = new DecisionTreeClassifier()
+ .setImpurity("Gini")
+ .setMaxDepth(4)
+ .setMaxBins(maxBins)
+ val categoricalFeatures = Map(0 -> 3, 1 -> 3)
+ val numClasses = 3
+ compareAPIs(rdd, dt, categoricalFeatures, numClasses)
+ }
+
+ test("Multiclass classification stump with continuous features") {
+ val rdd = continuousDataPointsForMulticlassRDD
+ val dt = new DecisionTreeClassifier()
+ .setImpurity("Gini")
+ .setMaxDepth(4)
+ .setMaxBins(100)
+ val numClasses = 3
+ compareAPIs(rdd, dt, categoricalFeatures = Map.empty[Int, Int], numClasses)
+ }
+
+ test("Multiclass classification stump with continuous + unordered categorical features") {
+ val rdd = continuousDataPointsForMulticlassRDD
+ val dt = new DecisionTreeClassifier()
+ .setImpurity("Gini")
+ .setMaxDepth(4)
+ .setMaxBins(100)
+ val categoricalFeatures = Map(0 -> 3)
+ val numClasses = 3
+ compareAPIs(rdd, dt, categoricalFeatures, numClasses)
+ }
+
+ test("Multiclass classification stump with 10-ary (ordered) categorical features") {
+ val rdd = categoricalDataPointsForMulticlassForOrderedFeaturesRDD
+ val dt = new DecisionTreeClassifier()
+ .setImpurity("Gini")
+ .setMaxDepth(4)
+ .setMaxBins(100)
+ val categoricalFeatures = Map(0 -> 10, 1 -> 10)
+ val numClasses = 3
+ compareAPIs(rdd, dt, categoricalFeatures, numClasses)
+ }
+
+ test("Multiclass classification tree with 10-ary (ordered) categorical features," +
+ " with just enough bins") {
+ val rdd = categoricalDataPointsForMulticlassForOrderedFeaturesRDD
+ val dt = new DecisionTreeClassifier()
+ .setImpurity("Gini")
+ .setMaxDepth(4)
+ .setMaxBins(10)
+ val categoricalFeatures = Map(0 -> 10, 1 -> 10)
+ val numClasses = 3
+ compareAPIs(rdd, dt, categoricalFeatures, numClasses)
+ }
+
+ test("split must satisfy min instances per node requirements") {
+ val arr = Array(
+ LabeledPoint(0.0, Vectors.sparse(2, Seq((0, 0.0)))),
+ LabeledPoint(1.0, Vectors.sparse(2, Seq((1, 1.0)))),
+ LabeledPoint(0.0, Vectors.sparse(2, Seq((0, 1.0)))))
+ val rdd = sc.parallelize(arr)
+ val dt = new DecisionTreeClassifier()
+ .setImpurity("Gini")
+ .setMaxDepth(2)
+ .setMinInstancesPerNode(2)
+ val numClasses = 2
+ compareAPIs(rdd, dt, categoricalFeatures = Map.empty[Int, Int], numClasses)
+ }
+
+ test("do not choose split that does not satisfy min instance per node requirements") {
+ // if a split does not satisfy min instances per node requirements,
+ // this split is invalid, even though the information gain of split is large.
+ val arr = Array(
+ LabeledPoint(0.0, Vectors.dense(0.0, 1.0)),
+ LabeledPoint(1.0, Vectors.dense(1.0, 1.0)),
+ LabeledPoint(0.0, Vectors.dense(0.0, 0.0)),
+ LabeledPoint(0.0, Vectors.dense(0.0, 0.0)))
+ val rdd = sc.parallelize(arr)
+ val dt = new DecisionTreeClassifier()
+ .setImpurity("Gini")
+ .setMaxBins(2)
+ .setMaxDepth(2)
+ .setMinInstancesPerNode(2)
+ val categoricalFeatures = Map(0 -> 2, 1-> 2)
+ val numClasses = 2
+ compareAPIs(rdd, dt, categoricalFeatures, numClasses)
+ }
+
+ test("split must satisfy min info gain requirements") {
+ val arr = Array(
+ LabeledPoint(0.0, Vectors.sparse(2, Seq((0, 0.0)))),
+ LabeledPoint(1.0, Vectors.sparse(2, Seq((1, 1.0)))),
+ LabeledPoint(0.0, Vectors.sparse(2, Seq((0, 1.0)))))
+ val rdd = sc.parallelize(arr)
+
+ val dt = new DecisionTreeClassifier()
+ .setImpurity("Gini")
+ .setMaxDepth(2)
+ .setMinInfoGain(1.0)
+ val numClasses = 2
+ compareAPIs(rdd, dt, categoricalFeatures = Map.empty[Int, Int], numClasses)
+ }
+
+ /////////////////////////////////////////////////////////////////////////////
+ // Tests of model save/load
+ /////////////////////////////////////////////////////////////////////////////
+
+ // TODO: Reinstate test once save/load are implemented
+ /*
+ test("model save/load") {
+ val tempDir = Utils.createTempDir()
+ val path = tempDir.toURI.toString
+
+ val oldModel = OldDecisionTreeSuite.createModel(OldAlgo.Classification)
+ val newModel = DecisionTreeClassificationModel.fromOld(oldModel)
+
+ // Save model, load it back, and compare.
+ try {
+ newModel.save(sc, path)
+ val sameNewModel = DecisionTreeClassificationModel.load(sc, path)
+ TreeTests.checkEqual(newModel, sameNewModel)
+ } finally {
+ Utils.deleteRecursively(tempDir)
+ }
+ }
+ */
+}
+
+private[ml] object DecisionTreeClassifierSuite extends FunSuite {
+
+ /**
+ * Train 2 decision trees on the given dataset, one using the old API and one using the new API.
+ * Convert the old tree to the new format, compare them, and fail if they are not exactly equal.
+ */
+ def compareAPIs(
+ data: RDD[LabeledPoint],
+ dt: DecisionTreeClassifier,
+ categoricalFeatures: Map[Int, Int],
+ numClasses: Int): Unit = {
+ val oldStrategy = dt.getOldStrategy(categoricalFeatures, numClasses)
+ val oldTree = OldDecisionTree.train(data, oldStrategy)
+ val newData: DataFrame = TreeTests.setMetadata(data, categoricalFeatures, numClasses)
+ val newTree = dt.fit(newData)
+ // Use parent, fittingParamMap from newTree since these are not checked anyways.
+ val oldTreeAsNew = DecisionTreeClassificationModel.fromOld(oldTree, newTree.parent,
+ newTree.fittingParamMap, categoricalFeatures)
+ TreeTests.checkEqual(oldTreeAsNew, newTree)
+ }
+}
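
The `maxBins` value used in the "just enough bins" tests above follows from the number of candidate splits for an unordered categorical feature: a feature with M categories has 2^(M-1) - 1 candidate splits, and each split consumes two bins. A tiny self-contained check of that arithmetic (the helper name is illustrative, not part of the patch):

    // Illustrative helper: smallest maxBins that still lets an unordered categorical
    // feature with `arity` categories be considered.
    def minBinsForUnorderedFeature(arity: Int): Int =
      2 * (math.pow(2, arity - 1).toInt - 1)

    assert(minBinsForUnorderedFeature(3) == 6) // the value computed in the tests for 3-ary features
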
diff --git a/mllib/src/test/scala/org/apache/spark/ml/feature/VectorIndexerSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/feature/VectorIndexerSuite.scala
index 81ef831c42e55..1b261b2643854 100644
--- a/mllib/src/test/scala/org/apache/spark/ml/feature/VectorIndexerSuite.scala
+++ b/mllib/src/test/scala/org/apache/spark/ml/feature/VectorIndexerSuite.scala
@@ -228,7 +228,7 @@ class VectorIndexerSuite extends FunSuite with MLlibTestSparkContext {
}
val attrGroup = new AttributeGroup("features", featureAttributes)
val densePoints1WithMeta =
- densePoints1.select(densePoints1("features").as("features", attrGroup.toMetadata))
+ densePoints1.select(densePoints1("features").as("features", attrGroup.toMetadata()))
val vectorIndexer = getIndexer.setMaxCategories(2)
val model = vectorIndexer.fit(densePoints1WithMeta)
// Check that ML metadata are preserved.
diff --git a/mllib/src/test/scala/org/apache/spark/ml/impl/TreeTests.scala b/mllib/src/test/scala/org/apache/spark/ml/impl/TreeTests.scala
new file mode 100644
index 0000000000000..2e57d4ce37f1d
--- /dev/null
+++ b/mllib/src/test/scala/org/apache/spark/ml/impl/TreeTests.scala
@@ -0,0 +1,132 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.ml.impl
+
+import scala.collection.JavaConverters._
+
+import org.scalatest.FunSuite
+
+import org.apache.spark.api.java.JavaRDD
+import org.apache.spark.ml.attribute.{AttributeGroup, NominalAttribute, NumericAttribute}
+import org.apache.spark.ml.impl.tree._
+import org.apache.spark.ml.tree.{DecisionTreeModel, InternalNode, LeafNode, Node}
+import org.apache.spark.mllib.regression.LabeledPoint
+import org.apache.spark.rdd.RDD
+import org.apache.spark.sql.{SQLContext, DataFrame}
+
+
+private[ml] object TreeTests extends FunSuite {
+
+ /**
+ * Convert the given data to a DataFrame, and set the features and label metadata.
+ * @param data Dataset. Categorical features and labels must already have 0-based indices.
+ * This must be non-empty.
+ * @param categoricalFeatures Map: categorical feature index -> number of distinct values
+ * @param numClasses Number of classes label can take. If 0, mark as continuous.
+ * @return DataFrame with metadata
+ */
+ def setMetadata(
+ data: RDD[LabeledPoint],
+ categoricalFeatures: Map[Int, Int],
+ numClasses: Int): DataFrame = {
+ val sqlContext = new SQLContext(data.sparkContext)
+ import sqlContext.implicits._
+ val df = data.toDF()
+ val numFeatures = data.first().features.size
+ val featuresAttributes = Range(0, numFeatures).map { feature =>
+ if (categoricalFeatures.contains(feature)) {
+ NominalAttribute.defaultAttr.withIndex(feature).withNumValues(categoricalFeatures(feature))
+ } else {
+ NumericAttribute.defaultAttr.withIndex(feature)
+ }
+ }.toArray
+ val featuresMetadata = new AttributeGroup("features", featuresAttributes).toMetadata()
+ val labelAttribute = if (numClasses == 0) {
+ NumericAttribute.defaultAttr.withName("label")
+ } else {
+ NominalAttribute.defaultAttr.withName("label").withNumValues(numClasses)
+ }
+ val labelMetadata = labelAttribute.toMetadata()
+ df.select(df("features").as("features", featuresMetadata),
+ df("label").as("label", labelMetadata))
+ }
+
+ /** Java-friendly version of [[setMetadata()]] */
+ def setMetadata(
+ data: JavaRDD[LabeledPoint],
+ categoricalFeatures: java.util.Map[java.lang.Integer, java.lang.Integer],
+ numClasses: Int): DataFrame = {
+ setMetadata(data.rdd, categoricalFeatures.asInstanceOf[java.util.Map[Int, Int]].asScala.toMap,
+ numClasses)
+ }
+
+ /**
+ * Check if the two trees are exactly the same.
+ * Note: I hesitate to override Node.equals since it could cause problems if users
+ * make mistakes such as creating loops of Nodes.
+   * If the trees are not equal, this throws an AssertionError whose message includes both trees.
+ */
+ def checkEqual(a: DecisionTreeModel, b: DecisionTreeModel): Unit = {
+ try {
+ checkEqual(a.rootNode, b.rootNode)
+ } catch {
+ case ex: Exception =>
+ throw new AssertionError("checkEqual failed since the two trees were not identical.\n" +
+ "TREE A:\n" + a.toDebugString + "\n" +
+ "TREE B:\n" + b.toDebugString + "\n", ex)
+ }
+ }
+
+ /**
+   * Check that the two nodes and their descendants are exactly the same, failing otherwise.
+ * Note: I hesitate to override Node.equals since it could cause problems if users
+ * make mistakes such as creating loops of Nodes.
+ */
+ private def checkEqual(a: Node, b: Node): Unit = {
+ assert(a.prediction === b.prediction)
+ assert(a.impurity === b.impurity)
+ (a, b) match {
+ case (aye: InternalNode, bee: InternalNode) =>
+ assert(aye.split === bee.split)
+ checkEqual(aye.leftChild, bee.leftChild)
+ checkEqual(aye.rightChild, bee.rightChild)
+ case (aye: LeafNode, bee: LeafNode) => // do nothing
+ case _ =>
+ throw new AssertionError("Found mismatched nodes")
+ }
+ }
+
+ // TODO: Reinstate after adding ensembles
+ /**
+ * Check if the two models are exactly the same.
+ * If the models are not equal, this throws an exception.
+ */
+ /*
+ def checkEqual(a: TreeEnsembleModel, b: TreeEnsembleModel): Unit = {
+ try {
+ a.getTrees.zip(b.getTrees).foreach { case (treeA, treeB) =>
+ TreeTests.checkEqual(treeA, treeB)
+ }
+ assert(a.getTreeWeights === b.getTreeWeights)
+ } catch {
+ case ex: Exception => throw new AssertionError(
+ "checkEqual failed since the two tree ensembles were not identical")
+ }
+ }
+ */
+}
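
A short sketch of how the metadata attached by `setMetadata` can be read back through the attribute API used in the suites above (assumes an existing SparkContext `sc` and code living under org.apache.spark.ml, since TreeTests is private[ml]; the values are illustrative):

    // Sketch only: mark feature 0 as categorical with 3 values and confirm the
    // group metadata round-trips through the DataFrame schema.
    import org.apache.spark.ml.attribute.AttributeGroup
    import org.apache.spark.mllib.linalg.Vectors
    import org.apache.spark.mllib.regression.LabeledPoint

    val data = sc.parallelize(Seq(
      LabeledPoint(0.0, Vectors.dense(0.0, 1.5)),
      LabeledPoint(1.0, Vectors.dense(1.0, 2.5))))
    val df = TreeTests.setMetadata(data, Map(0 -> 3), numClasses = 2)

    val group = AttributeGroup.fromStructField(df.schema("features"))
    assert(group.size == 2)
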
diff --git a/mllib/src/test/scala/org/apache/spark/ml/regression/DecisionTreeRegressorSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/regression/DecisionTreeRegressorSuite.scala
new file mode 100644
index 0000000000000..0b40fe33fae9d
--- /dev/null
+++ b/mllib/src/test/scala/org/apache/spark/ml/regression/DecisionTreeRegressorSuite.scala
@@ -0,0 +1,91 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.ml.regression
+
+import org.scalatest.FunSuite
+
+import org.apache.spark.ml.impl.TreeTests
+import org.apache.spark.mllib.regression.LabeledPoint
+import org.apache.spark.mllib.tree.{DecisionTree => OldDecisionTree,
+ DecisionTreeSuite => OldDecisionTreeSuite}
+import org.apache.spark.mllib.util.MLlibTestSparkContext
+import org.apache.spark.rdd.RDD
+import org.apache.spark.sql.DataFrame
+
+
+class DecisionTreeRegressorSuite extends FunSuite with MLlibTestSparkContext {
+
+ import DecisionTreeRegressorSuite.compareAPIs
+
+ private var categoricalDataPointsRDD: RDD[LabeledPoint] = _
+
+ override def beforeAll() {
+ super.beforeAll()
+ categoricalDataPointsRDD =
+ sc.parallelize(OldDecisionTreeSuite.generateCategoricalDataPoints())
+ }
+
+ /////////////////////////////////////////////////////////////////////////////
+ // Tests calling train()
+ /////////////////////////////////////////////////////////////////////////////
+
+ test("Regression stump with 3-ary (ordered) categorical features") {
+ val dt = new DecisionTreeRegressor()
+ .setImpurity("variance")
+ .setMaxDepth(2)
+ .setMaxBins(100)
+ val categoricalFeatures = Map(0 -> 3, 1-> 3)
+ compareAPIs(categoricalDataPointsRDD, dt, categoricalFeatures)
+ }
+
+ test("Regression stump with binary (ordered) categorical features") {
+ val dt = new DecisionTreeRegressor()
+ .setImpurity("variance")
+ .setMaxDepth(2)
+ .setMaxBins(100)
+ val categoricalFeatures = Map(0 -> 2, 1-> 2)
+ compareAPIs(categoricalDataPointsRDD, dt, categoricalFeatures)
+ }
+
+ /////////////////////////////////////////////////////////////////////////////
+ // Tests of model save/load
+ /////////////////////////////////////////////////////////////////////////////
+
+ // TODO: test("model save/load")
+}
+
+private[ml] object DecisionTreeRegressorSuite extends FunSuite {
+
+ /**
+ * Train 2 decision trees on the given dataset, one using the old API and one using the new API.
+ * Convert the old tree to the new format, compare them, and fail if they are not exactly equal.
+ */
+ def compareAPIs(
+ data: RDD[LabeledPoint],
+ dt: DecisionTreeRegressor,
+ categoricalFeatures: Map[Int, Int]): Unit = {
+ val oldStrategy = dt.getOldStrategy(categoricalFeatures)
+ val oldTree = OldDecisionTree.train(data, oldStrategy)
+ val newData: DataFrame = TreeTests.setMetadata(data, categoricalFeatures, numClasses = 0)
+ val newTree = dt.fit(newData)
+ // Use parent, fittingParamMap from newTree since these are not checked anyways.
+ val oldTreeAsNew = DecisionTreeRegressionModel.fromOld(oldTree, newTree.parent,
+ newTree.fittingParamMap, categoricalFeatures)
+ TreeTests.checkEqual(oldTreeAsNew, newTree)
+ }
+}
diff --git a/mllib/src/test/scala/org/apache/spark/mllib/tree/DecisionTreeSuite.scala b/mllib/src/test/scala/org/apache/spark/mllib/tree/DecisionTreeSuite.scala
index 4c162df810bb2..249b8eae19b17 100644
--- a/mllib/src/test/scala/org/apache/spark/mllib/tree/DecisionTreeSuite.scala
+++ b/mllib/src/test/scala/org/apache/spark/mllib/tree/DecisionTreeSuite.scala
@@ -36,6 +36,10 @@ import org.apache.spark.util.Utils
class DecisionTreeSuite extends FunSuite with MLlibTestSparkContext {
+ /////////////////////////////////////////////////////////////////////////////
+ // Tests examining individual elements of training
+ /////////////////////////////////////////////////////////////////////////////
+
test("Binary classification with continuous features: split and bin calculation") {
val arr = DecisionTreeSuite.generateOrderedLabeledPointsWithLabel1()
assert(arr.length === 1000)
@@ -254,6 +258,165 @@ class DecisionTreeSuite extends FunSuite with MLlibTestSparkContext {
assert(bins(0).length === 0)
}
+ test("Avoid aggregation on the last level") {
+ val arr = Array(
+ LabeledPoint(0.0, Vectors.dense(1.0, 0.0, 0.0)),
+ LabeledPoint(1.0, Vectors.dense(0.0, 1.0, 1.0)),
+ LabeledPoint(0.0, Vectors.dense(2.0, 0.0, 0.0)),
+ LabeledPoint(1.0, Vectors.dense(0.0, 2.0, 1.0)))
+ val input = sc.parallelize(arr)
+
+ val strategy = new Strategy(algo = Classification, impurity = Gini, maxDepth = 1,
+ numClasses = 2, categoricalFeaturesInfo = Map(0 -> 3))
+ val metadata = DecisionTreeMetadata.buildMetadata(input, strategy)
+ val (splits, bins) = DecisionTree.findSplitsBins(input, metadata)
+
+ val treeInput = TreePoint.convertToTreeRDD(input, bins, metadata)
+ val baggedInput = BaggedPoint.convertToBaggedRDD(treeInput, 1.0, 1, false)
+
+ val topNode = Node.emptyNode(nodeIndex = 1)
+ assert(topNode.predict.predict === Double.MinValue)
+ assert(topNode.impurity === -1.0)
+ assert(topNode.isLeaf === false)
+
+ val nodesForGroup = Map((0, Array(topNode)))
+ val treeToNodeToIndexInfo = Map((0, Map(
+ (topNode.id, new RandomForest.NodeIndexInfo(0, None))
+ )))
+ val nodeQueue = new mutable.Queue[(Int, Node)]()
+ DecisionTree.findBestSplits(baggedInput, metadata, Array(topNode),
+ nodesForGroup, treeToNodeToIndexInfo, splits, bins, nodeQueue)
+
+ // don't enqueue leaf nodes into node queue
+ assert(nodeQueue.isEmpty)
+
+ // set impurity and predict for topNode
+ assert(topNode.predict.predict !== Double.MinValue)
+ assert(topNode.impurity !== -1.0)
+
+ // set impurity and predict for child nodes
+ assert(topNode.leftNode.get.predict.predict === 0.0)
+ assert(topNode.rightNode.get.predict.predict === 1.0)
+ assert(topNode.leftNode.get.impurity === 0.0)
+ assert(topNode.rightNode.get.impurity === 0.0)
+ }
+
+ test("Avoid aggregation if impurity is 0.0") {
+ val arr = Array(
+ LabeledPoint(0.0, Vectors.dense(1.0, 0.0, 0.0)),
+ LabeledPoint(1.0, Vectors.dense(0.0, 1.0, 1.0)),
+ LabeledPoint(0.0, Vectors.dense(2.0, 0.0, 0.0)),
+ LabeledPoint(1.0, Vectors.dense(0.0, 2.0, 1.0)))
+ val input = sc.parallelize(arr)
+
+ val strategy = new Strategy(algo = Classification, impurity = Gini, maxDepth = 5,
+ numClasses = 2, categoricalFeaturesInfo = Map(0 -> 3))
+ val metadata = DecisionTreeMetadata.buildMetadata(input, strategy)
+ val (splits, bins) = DecisionTree.findSplitsBins(input, metadata)
+
+ val treeInput = TreePoint.convertToTreeRDD(input, bins, metadata)
+ val baggedInput = BaggedPoint.convertToBaggedRDD(treeInput, 1.0, 1, false)
+
+ val topNode = Node.emptyNode(nodeIndex = 1)
+ assert(topNode.predict.predict === Double.MinValue)
+ assert(topNode.impurity === -1.0)
+ assert(topNode.isLeaf === false)
+
+ val nodesForGroup = Map((0, Array(topNode)))
+ val treeToNodeToIndexInfo = Map((0, Map(
+ (topNode.id, new RandomForest.NodeIndexInfo(0, None))
+ )))
+ val nodeQueue = new mutable.Queue[(Int, Node)]()
+ DecisionTree.findBestSplits(baggedInput, metadata, Array(topNode),
+ nodesForGroup, treeToNodeToIndexInfo, splits, bins, nodeQueue)
+
+ // don't enqueue a node into node queue if its impurity is 0.0
+ assert(nodeQueue.isEmpty)
+
+ // set impurity and predict for topNode
+ assert(topNode.predict.predict !== Double.MinValue)
+ assert(topNode.impurity !== -1.0)
+
+ // set impurity and predict for child nodes
+ assert(topNode.leftNode.get.predict.predict === 0.0)
+ assert(topNode.rightNode.get.predict.predict === 1.0)
+ assert(topNode.leftNode.get.impurity === 0.0)
+ assert(topNode.rightNode.get.impurity === 0.0)
+ }
+
+ test("Second level node building with vs. without groups") {
+ val arr = DecisionTreeSuite.generateOrderedLabeledPoints()
+ assert(arr.length === 1000)
+ val rdd = sc.parallelize(arr)
+ val strategy = new Strategy(Classification, Entropy, 3, 2, 100)
+ val metadata = DecisionTreeMetadata.buildMetadata(rdd, strategy)
+ val (splits, bins) = DecisionTree.findSplitsBins(rdd, metadata)
+ assert(splits.length === 2)
+ assert(splits(0).length === 99)
+ assert(bins.length === 2)
+ assert(bins(0).length === 100)
+
+ // Train a 1-node model
+ val strategyOneNode = new Strategy(Classification, Entropy, maxDepth = 1,
+ numClasses = 2, maxBins = 100)
+ val modelOneNode = DecisionTree.train(rdd, strategyOneNode)
+ val rootNode1 = modelOneNode.topNode.deepCopy()
+ val rootNode2 = modelOneNode.topNode.deepCopy()
+ assert(rootNode1.leftNode.nonEmpty)
+ assert(rootNode1.rightNode.nonEmpty)
+
+ val treeInput = TreePoint.convertToTreeRDD(rdd, bins, metadata)
+ val baggedInput = BaggedPoint.convertToBaggedRDD(treeInput, 1.0, 1, false)
+
+ // Single group second level tree construction.
+ val nodesForGroup = Map((0, Array(rootNode1.leftNode.get, rootNode1.rightNode.get)))
+ val treeToNodeToIndexInfo = Map((0, Map(
+ (rootNode1.leftNode.get.id, new RandomForest.NodeIndexInfo(0, None)),
+ (rootNode1.rightNode.get.id, new RandomForest.NodeIndexInfo(1, None)))))
+ val nodeQueue = new mutable.Queue[(Int, Node)]()
+ DecisionTree.findBestSplits(baggedInput, metadata, Array(rootNode1),
+ nodesForGroup, treeToNodeToIndexInfo, splits, bins, nodeQueue)
+ val children1 = new Array[Node](2)
+ children1(0) = rootNode1.leftNode.get
+ children1(1) = rootNode1.rightNode.get
+
+ // Train one second-level node at a time.
+ val nodesForGroupA = Map((0, Array(rootNode2.leftNode.get)))
+ val treeToNodeToIndexInfoA = Map((0, Map(
+ (rootNode2.leftNode.get.id, new RandomForest.NodeIndexInfo(0, None)))))
+ nodeQueue.clear()
+ DecisionTree.findBestSplits(baggedInput, metadata, Array(rootNode2),
+ nodesForGroupA, treeToNodeToIndexInfoA, splits, bins, nodeQueue)
+ val nodesForGroupB = Map((0, Array(rootNode2.rightNode.get)))
+ val treeToNodeToIndexInfoB = Map((0, Map(
+ (rootNode2.rightNode.get.id, new RandomForest.NodeIndexInfo(0, None)))))
+ nodeQueue.clear()
+ DecisionTree.findBestSplits(baggedInput, metadata, Array(rootNode2),
+ nodesForGroupB, treeToNodeToIndexInfoB, splits, bins, nodeQueue)
+ val children2 = new Array[Node](2)
+ children2(0) = rootNode2.leftNode.get
+ children2(1) = rootNode2.rightNode.get
+
+ // Verify whether the splits obtained using single group and multiple group level
+ // construction strategies are the same.
+ for (i <- 0 until 2) {
+ assert(children1(i).stats.nonEmpty && children1(i).stats.get.gain > 0)
+ assert(children2(i).stats.nonEmpty && children2(i).stats.get.gain > 0)
+ assert(children1(i).split === children2(i).split)
+ assert(children1(i).stats.nonEmpty && children2(i).stats.nonEmpty)
+ val stats1 = children1(i).stats.get
+ val stats2 = children2(i).stats.get
+ assert(stats1.gain === stats2.gain)
+ assert(stats1.impurity === stats2.impurity)
+ assert(stats1.leftImpurity === stats2.leftImpurity)
+ assert(stats1.rightImpurity === stats2.rightImpurity)
+ assert(children1(i).predict.predict === children2(i).predict.predict)
+ }
+ }
+
+ /////////////////////////////////////////////////////////////////////////////
+ // Tests calling train()
+ /////////////////////////////////////////////////////////////////////////////
test("Binary classification stump with ordered categorical features") {
val arr = DecisionTreeSuite.generateCategoricalDataPoints()
@@ -438,76 +601,6 @@ class DecisionTreeSuite extends FunSuite with MLlibTestSparkContext {
assert(rootNode.predict.predict === 1)
}
- test("Second level node building with vs. without groups") {
- val arr = DecisionTreeSuite.generateOrderedLabeledPoints()
- assert(arr.length === 1000)
- val rdd = sc.parallelize(arr)
- val strategy = new Strategy(Classification, Entropy, 3, 2, 100)
- val metadata = DecisionTreeMetadata.buildMetadata(rdd, strategy)
- val (splits, bins) = DecisionTree.findSplitsBins(rdd, metadata)
- assert(splits.length === 2)
- assert(splits(0).length === 99)
- assert(bins.length === 2)
- assert(bins(0).length === 100)
-
- // Train a 1-node model
- val strategyOneNode = new Strategy(Classification, Entropy, maxDepth = 1,
- numClasses = 2, maxBins = 100)
- val modelOneNode = DecisionTree.train(rdd, strategyOneNode)
- val rootNode1 = modelOneNode.topNode.deepCopy()
- val rootNode2 = modelOneNode.topNode.deepCopy()
- assert(rootNode1.leftNode.nonEmpty)
- assert(rootNode1.rightNode.nonEmpty)
-
- val treeInput = TreePoint.convertToTreeRDD(rdd, bins, metadata)
- val baggedInput = BaggedPoint.convertToBaggedRDD(treeInput, 1.0, 1, false)
-
- // Single group second level tree construction.
- val nodesForGroup = Map((0, Array(rootNode1.leftNode.get, rootNode1.rightNode.get)))
- val treeToNodeToIndexInfo = Map((0, Map(
- (rootNode1.leftNode.get.id, new RandomForest.NodeIndexInfo(0, None)),
- (rootNode1.rightNode.get.id, new RandomForest.NodeIndexInfo(1, None)))))
- val nodeQueue = new mutable.Queue[(Int, Node)]()
- DecisionTree.findBestSplits(baggedInput, metadata, Array(rootNode1),
- nodesForGroup, treeToNodeToIndexInfo, splits, bins, nodeQueue)
- val children1 = new Array[Node](2)
- children1(0) = rootNode1.leftNode.get
- children1(1) = rootNode1.rightNode.get
-
- // Train one second-level node at a time.
- val nodesForGroupA = Map((0, Array(rootNode2.leftNode.get)))
- val treeToNodeToIndexInfoA = Map((0, Map(
- (rootNode2.leftNode.get.id, new RandomForest.NodeIndexInfo(0, None)))))
- nodeQueue.clear()
- DecisionTree.findBestSplits(baggedInput, metadata, Array(rootNode2),
- nodesForGroupA, treeToNodeToIndexInfoA, splits, bins, nodeQueue)
- val nodesForGroupB = Map((0, Array(rootNode2.rightNode.get)))
- val treeToNodeToIndexInfoB = Map((0, Map(
- (rootNode2.rightNode.get.id, new RandomForest.NodeIndexInfo(0, None)))))
- nodeQueue.clear()
- DecisionTree.findBestSplits(baggedInput, metadata, Array(rootNode2),
- nodesForGroupB, treeToNodeToIndexInfoB, splits, bins, nodeQueue)
- val children2 = new Array[Node](2)
- children2(0) = rootNode2.leftNode.get
- children2(1) = rootNode2.rightNode.get
-
- // Verify whether the splits obtained using single group and multiple group level
- // construction strategies are the same.
- for (i <- 0 until 2) {
- assert(children1(i).stats.nonEmpty && children1(i).stats.get.gain > 0)
- assert(children2(i).stats.nonEmpty && children2(i).stats.get.gain > 0)
- assert(children1(i).split === children2(i).split)
- assert(children1(i).stats.nonEmpty && children2(i).stats.nonEmpty)
- val stats1 = children1(i).stats.get
- val stats2 = children2(i).stats.get
- assert(stats1.gain === stats2.gain)
- assert(stats1.impurity === stats2.impurity)
- assert(stats1.leftImpurity === stats2.leftImpurity)
- assert(stats1.rightImpurity === stats2.rightImpurity)
- assert(children1(i).predict.predict === children2(i).predict.predict)
- }
- }
-
test("Multiclass classification stump with 3-ary (unordered) categorical features") {
val arr = DecisionTreeSuite.generateCategoricalDataPointsForMulticlass()
val rdd = sc.parallelize(arr)
@@ -528,11 +621,11 @@ class DecisionTreeSuite extends FunSuite with MLlibTestSparkContext {
}
test("Binary classification stump with 1 continuous feature, to check off-by-1 error") {
- val arr = new Array[LabeledPoint](4)
- arr(0) = new LabeledPoint(0.0, Vectors.dense(0.0))
- arr(1) = new LabeledPoint(1.0, Vectors.dense(1.0))
- arr(2) = new LabeledPoint(1.0, Vectors.dense(2.0))
- arr(3) = new LabeledPoint(1.0, Vectors.dense(3.0))
+ val arr = Array(
+ LabeledPoint(0.0, Vectors.dense(0.0)),
+ LabeledPoint(1.0, Vectors.dense(1.0)),
+ LabeledPoint(1.0, Vectors.dense(2.0)),
+ LabeledPoint(1.0, Vectors.dense(3.0)))
val rdd = sc.parallelize(arr)
val strategy = new Strategy(algo = Classification, impurity = Gini, maxDepth = 4,
numClasses = 2)
@@ -544,11 +637,11 @@ class DecisionTreeSuite extends FunSuite with MLlibTestSparkContext {
}
test("Binary classification stump with 2 continuous features") {
- val arr = new Array[LabeledPoint](4)
- arr(0) = new LabeledPoint(0.0, Vectors.sparse(2, Seq((0, 0.0))))
- arr(1) = new LabeledPoint(1.0, Vectors.sparse(2, Seq((1, 1.0))))
- arr(2) = new LabeledPoint(0.0, Vectors.sparse(2, Seq((0, 0.0))))
- arr(3) = new LabeledPoint(1.0, Vectors.sparse(2, Seq((1, 2.0))))
+ val arr = Array(
+ LabeledPoint(0.0, Vectors.sparse(2, Seq((0, 0.0)))),
+ LabeledPoint(1.0, Vectors.sparse(2, Seq((1, 1.0)))),
+ LabeledPoint(0.0, Vectors.sparse(2, Seq((0, 0.0)))),
+ LabeledPoint(1.0, Vectors.sparse(2, Seq((1, 2.0)))))
val rdd = sc.parallelize(arr)
val strategy = new Strategy(algo = Classification, impurity = Gini, maxDepth = 4,
@@ -668,11 +761,10 @@ class DecisionTreeSuite extends FunSuite with MLlibTestSparkContext {
}
test("split must satisfy min instances per node requirements") {
- val arr = new Array[LabeledPoint](3)
- arr(0) = new LabeledPoint(0.0, Vectors.sparse(2, Seq((0, 0.0))))
- arr(1) = new LabeledPoint(1.0, Vectors.sparse(2, Seq((1, 1.0))))
- arr(2) = new LabeledPoint(0.0, Vectors.sparse(2, Seq((0, 1.0))))
-
+ val arr = Array(
+ LabeledPoint(0.0, Vectors.sparse(2, Seq((0, 0.0)))),
+ LabeledPoint(1.0, Vectors.sparse(2, Seq((1, 1.0)))),
+ LabeledPoint(0.0, Vectors.sparse(2, Seq((0, 1.0)))))
val rdd = sc.parallelize(arr)
val strategy = new Strategy(algo = Classification, impurity = Gini,
maxDepth = 2, numClasses = 2, minInstancesPerNode = 2)
@@ -695,11 +787,11 @@ class DecisionTreeSuite extends FunSuite with MLlibTestSparkContext {
test("do not choose split that does not satisfy min instance per node requirements") {
// if a split does not satisfy min instances per node requirements,
// this split is invalid, even though the information gain of split is large.
- val arr = new Array[LabeledPoint](4)
- arr(0) = new LabeledPoint(0.0, Vectors.dense(0.0, 1.0))
- arr(1) = new LabeledPoint(1.0, Vectors.dense(1.0, 1.0))
- arr(2) = new LabeledPoint(0.0, Vectors.dense(0.0, 0.0))
- arr(3) = new LabeledPoint(0.0, Vectors.dense(0.0, 0.0))
+ val arr = Array(
+ LabeledPoint(0.0, Vectors.dense(0.0, 1.0)),
+ LabeledPoint(1.0, Vectors.dense(1.0, 1.0)),
+ LabeledPoint(0.0, Vectors.dense(0.0, 0.0)),
+ LabeledPoint(0.0, Vectors.dense(0.0, 0.0)))
val rdd = sc.parallelize(arr)
val strategy = new Strategy(algo = Classification, impurity = Gini,
@@ -715,10 +807,10 @@ class DecisionTreeSuite extends FunSuite with MLlibTestSparkContext {
}
test("split must satisfy min info gain requirements") {
- val arr = new Array[LabeledPoint](3)
- arr(0) = new LabeledPoint(0.0, Vectors.sparse(2, Seq((0, 0.0))))
- arr(1) = new LabeledPoint(1.0, Vectors.sparse(2, Seq((1, 1.0))))
- arr(2) = new LabeledPoint(0.0, Vectors.sparse(2, Seq((0, 1.0))))
+ val arr = Array(
+ LabeledPoint(0.0, Vectors.sparse(2, Seq((0, 0.0)))),
+ LabeledPoint(1.0, Vectors.sparse(2, Seq((1, 1.0)))),
+ LabeledPoint(0.0, Vectors.sparse(2, Seq((0, 1.0)))))
val input = sc.parallelize(arr)
val strategy = new Strategy(algo = Classification, impurity = Gini, maxDepth = 2,
@@ -739,91 +831,9 @@ class DecisionTreeSuite extends FunSuite with MLlibTestSparkContext {
assert(gain == InformationGainStats.invalidInformationGainStats)
}
- test("Avoid aggregation on the last level") {
- val arr = new Array[LabeledPoint](4)
- arr(0) = new LabeledPoint(0.0, Vectors.dense(1.0, 0.0, 0.0))
- arr(1) = new LabeledPoint(1.0, Vectors.dense(0.0, 1.0, 1.0))
- arr(2) = new LabeledPoint(0.0, Vectors.dense(2.0, 0.0, 0.0))
- arr(3) = new LabeledPoint(1.0, Vectors.dense(0.0, 2.0, 1.0))
- val input = sc.parallelize(arr)
-
- val strategy = new Strategy(algo = Classification, impurity = Gini, maxDepth = 1,
- numClasses = 2, categoricalFeaturesInfo = Map(0 -> 3))
- val metadata = DecisionTreeMetadata.buildMetadata(input, strategy)
- val (splits, bins) = DecisionTree.findSplitsBins(input, metadata)
-
- val treeInput = TreePoint.convertToTreeRDD(input, bins, metadata)
- val baggedInput = BaggedPoint.convertToBaggedRDD(treeInput, 1.0, 1, false)
-
- val topNode = Node.emptyNode(nodeIndex = 1)
- assert(topNode.predict.predict === Double.MinValue)
- assert(topNode.impurity === -1.0)
- assert(topNode.isLeaf === false)
-
- val nodesForGroup = Map((0, Array(topNode)))
- val treeToNodeToIndexInfo = Map((0, Map(
- (topNode.id, new RandomForest.NodeIndexInfo(0, None))
- )))
- val nodeQueue = new mutable.Queue[(Int, Node)]()
- DecisionTree.findBestSplits(baggedInput, metadata, Array(topNode),
- nodesForGroup, treeToNodeToIndexInfo, splits, bins, nodeQueue)
-
- // don't enqueue leaf nodes into node queue
- assert(nodeQueue.isEmpty)
-
- // set impurity and predict for topNode
- assert(topNode.predict.predict !== Double.MinValue)
- assert(topNode.impurity !== -1.0)
-
- // set impurity and predict for child nodes
- assert(topNode.leftNode.get.predict.predict === 0.0)
- assert(topNode.rightNode.get.predict.predict === 1.0)
- assert(topNode.leftNode.get.impurity === 0.0)
- assert(topNode.rightNode.get.impurity === 0.0)
- }
-
- test("Avoid aggregation if impurity is 0.0") {
- val arr = new Array[LabeledPoint](4)
- arr(0) = new LabeledPoint(0.0, Vectors.dense(1.0, 0.0, 0.0))
- arr(1) = new LabeledPoint(1.0, Vectors.dense(0.0, 1.0, 1.0))
- arr(2) = new LabeledPoint(0.0, Vectors.dense(2.0, 0.0, 0.0))
- arr(3) = new LabeledPoint(1.0, Vectors.dense(0.0, 2.0, 1.0))
- val input = sc.parallelize(arr)
-
- val strategy = new Strategy(algo = Classification, impurity = Gini, maxDepth = 5,
- numClasses = 2, categoricalFeaturesInfo = Map(0 -> 3))
- val metadata = DecisionTreeMetadata.buildMetadata(input, strategy)
- val (splits, bins) = DecisionTree.findSplitsBins(input, metadata)
-
- val treeInput = TreePoint.convertToTreeRDD(input, bins, metadata)
- val baggedInput = BaggedPoint.convertToBaggedRDD(treeInput, 1.0, 1, false)
-
- val topNode = Node.emptyNode(nodeIndex = 1)
- assert(topNode.predict.predict === Double.MinValue)
- assert(topNode.impurity === -1.0)
- assert(topNode.isLeaf === false)
-
- val nodesForGroup = Map((0, Array(topNode)))
- val treeToNodeToIndexInfo = Map((0, Map(
- (topNode.id, new RandomForest.NodeIndexInfo(0, None))
- )))
- val nodeQueue = new mutable.Queue[(Int, Node)]()
- DecisionTree.findBestSplits(baggedInput, metadata, Array(topNode),
- nodesForGroup, treeToNodeToIndexInfo, splits, bins, nodeQueue)
-
- // don't enqueue a node into node queue if its impurity is 0.0
- assert(nodeQueue.isEmpty)
-
- // set impurity and predict for topNode
- assert(topNode.predict.predict !== Double.MinValue)
- assert(topNode.impurity !== -1.0)
-
- // set impurity and predict for child nodes
- assert(topNode.leftNode.get.predict.predict === 0.0)
- assert(topNode.rightNode.get.predict.predict === 1.0)
- assert(topNode.leftNode.get.impurity === 0.0)
- assert(topNode.rightNode.get.impurity === 0.0)
- }
+ /////////////////////////////////////////////////////////////////////////////
+ // Tests of model save/load
+ /////////////////////////////////////////////////////////////////////////////
test("Node.subtreeIterator") {
val model = DecisionTreeSuite.createModel(Classification)
@@ -996,8 +1006,9 @@ object DecisionTreeSuite extends FunSuite {
/**
* Create a tree model. This is deterministic and contains a variety of node and feature types.
+ * TODO: Update this to be a correct tree (with matching probabilities, impurities, etc.)
*/
- private[tree] def createModel(algo: Algo): DecisionTreeModel = {
+ private[mllib] def createModel(algo: Algo): DecisionTreeModel = {
val topNode = createInternalNode(id = 1, Continuous)
val (node2, node3) = (createLeafNode(id = 2), createInternalNode(id = 3, Categorical))
val (node6, node7) = (createLeafNode(id = 6), createLeafNode(id = 7))
@@ -1017,7 +1028,7 @@ object DecisionTreeSuite extends FunSuite {
* make mistakes such as creating loops of Nodes.
* If the trees are not equal, this prints the two trees and throws an exception.
*/
- private[tree] def checkEqual(a: DecisionTreeModel, b: DecisionTreeModel): Unit = {
+ private[mllib] def checkEqual(a: DecisionTreeModel, b: DecisionTreeModel): Unit = {
try {
assert(a.algo === b.algo)
checkEqual(a.topNode, b.topNode)
diff --git a/project/MimaExcludes.scala b/project/MimaExcludes.scala
index 1564babefa62f..7ef363a2f07ad 100644
--- a/project/MimaExcludes.scala
+++ b/project/MimaExcludes.scala
@@ -68,6 +68,10 @@ object MimaExcludes {
// SPARK-6693 add tostring with max lines and width for matrix
ProblemFilters.exclude[MissingMethodProblem](
"org.apache.spark.mllib.linalg.Matrix.toString")
+ )++ Seq(
+ // SPARK-6703 Add getOrCreate method to SparkContext
+ ProblemFilters.exclude[IncompatibleResultTypeProblem]
+ ("org.apache.spark.SparkContext.org$apache$spark$SparkContext$$activeContext")
)
case v if v.startsWith("1.3") =>
diff --git a/python/pyspark/sql/dataframe.py b/python/pyspark/sql/dataframe.py
index 326d22e72f104..d70c5b0a6930c 100644
--- a/python/pyspark/sql/dataframe.py
+++ b/python/pyspark/sql/dataframe.py
@@ -489,8 +489,9 @@ def sort(self, *cols, **kwargs):
"""Returns a new :class:`DataFrame` sorted by the specified column(s).
:param cols: list of :class:`Column` or column names to sort by.
- :param ascending: sort by ascending order or not, could be bool, int
- or list of bool, int (default: True).
+ :param ascending: boolean or list of boolean (default True).
+ Sort ascending vs. descending. Specify list for multiple sort orders.
+            If a list is specified, the length of the list must equal the length of `cols`.
>>> df.sort(df.age.desc()).collect()
[Row(age=5, name=u'Bob'), Row(age=2, name=u'Alice')]
@@ -519,7 +520,7 @@ def sort(self, *cols, **kwargs):
jcols = [jc if asc else jc.desc()
for asc, jc in zip(ascending, jcols)]
else:
- raise TypeError("ascending can only be bool or list, but got %s" % type(ascending))
+ raise TypeError("ascending can only be boolean or list, but got %s" % type(ascending))
jdf = self._jdf.sort(self._jseq(jcols))
return DataFrame(jdf, self.sql_ctx)
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/api/r/SQLUtils.scala b/sql/core/src/main/scala/org/apache/spark/sql/api/r/SQLUtils.scala
index d1ea7cc3e9162..ae77f72998a22 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/api/r/SQLUtils.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/api/r/SQLUtils.scala
@@ -23,7 +23,7 @@ import org.apache.spark.api.java.{JavaRDD, JavaSparkContext}
import org.apache.spark.api.r.SerDe
import org.apache.spark.rdd.RDD
import org.apache.spark.sql.catalyst.expressions.{Alias, Expression, NamedExpression}
-import org.apache.spark.sql.types.{DataType, StructType}
+import org.apache.spark.sql.types._
import org.apache.spark.sql.{Column, DataFrame, GroupedData, Row, SQLContext, SaveMode}
private[r] object SQLUtils {
@@ -39,8 +39,34 @@ private[r] object SQLUtils {
arr.toSeq
}
- def createDF(rdd: RDD[Array[Byte]], schemaString: String, sqlContext: SQLContext): DataFrame = {
- val schema = DataType.fromJson(schemaString).asInstanceOf[StructType]
+ def createStructType(fields : Seq[StructField]): StructType = {
+ StructType(fields)
+ }
+
+ def getSQLDataType(dataType: String): DataType = {
+ dataType match {
+ case "byte" => org.apache.spark.sql.types.ByteType
+ case "integer" => org.apache.spark.sql.types.IntegerType
+ case "double" => org.apache.spark.sql.types.DoubleType
+ case "numeric" => org.apache.spark.sql.types.DoubleType
+ case "character" => org.apache.spark.sql.types.StringType
+ case "string" => org.apache.spark.sql.types.StringType
+ case "binary" => org.apache.spark.sql.types.BinaryType
+ case "raw" => org.apache.spark.sql.types.BinaryType
+ case "logical" => org.apache.spark.sql.types.BooleanType
+ case "boolean" => org.apache.spark.sql.types.BooleanType
+ case "timestamp" => org.apache.spark.sql.types.TimestampType
+ case "date" => org.apache.spark.sql.types.DateType
+ case _ => throw new IllegalArgumentException(s"Invaid type $dataType")
+ }
+ }
+
+ def createStructField(name: String, dataType: String, nullable: Boolean): StructField = {
+ val dtObj = getSQLDataType(dataType)
+ StructField(name, dtObj, nullable)
+ }
+
+ def createDF(rdd: RDD[Array[Byte]], schema: StructType, sqlContext: SQLContext): DataFrame = {
val num = schema.fields.size
val rowRDD = rdd.map(bytesToRow)
sqlContext.createDataFrame(rowRDD, schema)
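
For orientation, a minimal sketch of how these helpers compose into a schema; since SQLUtils is private[r], the snippet is written as if it lived in the same package (in practice the calls arrive through the R-to-JVM backend, and the object name here is illustrative):

    // Sketch only: build a StructType from R-style type names via the helpers above.
    package org.apache.spark.sql.api.r

    import org.apache.spark.sql.types.{StructField, StructType}

    object SchemaSketch {
      val fields: Seq[StructField] = Seq(
        SQLUtils.createStructField("name", "character", nullable = true),  // maps to StringType
        SQLUtils.createStructField("age", "integer", nullable = true),
        SQLUtils.createStructField("height", "numeric", nullable = true))  // maps to DoubleType
      val schema: StructType = SQLUtils.createStructType(fields)
    }
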
diff --git a/yarn/src/main/scala/org/apache/spark/deploy/yarn/ApplicationMaster.scala b/yarn/src/main/scala/org/apache/spark/deploy/yarn/ApplicationMaster.scala
index c357b7ae9d4da..f7a84207e9da6 100644
--- a/yarn/src/main/scala/org/apache/spark/deploy/yarn/ApplicationMaster.scala
+++ b/yarn/src/main/scala/org/apache/spark/deploy/yarn/ApplicationMaster.scala
@@ -373,14 +373,7 @@ private[spark] class ApplicationMaster(
private def waitForSparkContextInitialized(): SparkContext = {
logInfo("Waiting for spark context initialization")
sparkContextRef.synchronized {
- val waitTries = sparkConf.getOption("spark.yarn.applicationMaster.waitTries")
- .map(_.toLong * 10000L)
- if (waitTries.isDefined) {
- logWarning(
- "spark.yarn.applicationMaster.waitTries is deprecated, use spark.yarn.am.waitTime")
- }
- val totalWaitTime = sparkConf.getTimeAsMs("spark.yarn.am.waitTime",
- s"${waitTries.getOrElse(100000L)}ms")
+ val totalWaitTime = sparkConf.getTimeAsMs("spark.yarn.am.waitTime", "100s")
val deadline = System.currentTimeMillis() + totalWaitTime
while (sparkContextRef.get() == null && System.currentTimeMillis < deadline && !finished) {
diff --git a/yarn/src/main/scala/org/apache/spark/deploy/yarn/Client.scala b/yarn/src/main/scala/org/apache/spark/deploy/yarn/Client.scala
index 52e4dee46c535..019afbd1a1743 100644
--- a/yarn/src/main/scala/org/apache/spark/deploy/yarn/Client.scala
+++ b/yarn/src/main/scala/org/apache/spark/deploy/yarn/Client.scala
@@ -17,15 +17,18 @@
package org.apache.spark.deploy.yarn
+import java.io.{File, FileOutputStream}
import java.net.{InetAddress, UnknownHostException, URI, URISyntaxException}
import java.nio.ByteBuffer
+import java.util.zip.{ZipEntry, ZipOutputStream}
import scala.collection.JavaConversions._
-import scala.collection.mutable.{ArrayBuffer, HashMap, ListBuffer, Map}
+import scala.collection.mutable.{ArrayBuffer, HashMap, HashSet, ListBuffer, Map}
import scala.reflect.runtime.universe
import scala.util.{Try, Success, Failure}
import com.google.common.base.Objects
+import com.google.common.io.Files
import org.apache.hadoop.io.DataOutputBuffer
import org.apache.hadoop.conf.Configuration
@@ -77,12 +80,6 @@ private[spark] class Client(
def stop(): Unit = yarnClient.stop()
- /* ------------------------------------------------------------------------------------- *
- | The following methods have much in common in the stable and alpha versions of Client, |
- | but cannot be implemented in the parent trait due to subtle API differences across |
- | hadoop versions. |
- * ------------------------------------------------------------------------------------- */
-
/**
* Submit an application running our ApplicationMaster to the ResourceManager.
*
@@ -223,6 +220,10 @@ private[spark] class Client(
val fs = FileSystem.get(hadoopConf)
val dst = new Path(fs.getHomeDirectory(), appStagingDir)
val nns = getNameNodesToAccess(sparkConf) + dst
+ // Used to keep track of URIs added to the distributed cache. If the same URI is added
+ // multiple times, YARN will fail to launch containers for the app with an internal
+ // error.
+ val distributedUris = new HashSet[String]
obtainTokensForNamenodes(nns, hadoopConf, credentials)
obtainTokenForHiveMetastore(hadoopConf, credentials)
@@ -241,6 +242,17 @@ private[spark] class Client(
"for alternatives.")
}
+ def addDistributedUri(uri: URI): Boolean = {
+ val uriStr = uri.toString()
+ if (distributedUris.contains(uriStr)) {
+ logWarning(s"Resource $uri added multiple times to distributed cache.")
+ false
+ } else {
+ distributedUris += uriStr
+ true
+ }
+ }
+
/**
* Copy the given main resource to the distributed cache if the scheme is not "local".
* Otherwise, set the corresponding key in our SparkConf to handle it downstream.
@@ -258,11 +270,13 @@ private[spark] class Client(
if (!localPath.isEmpty()) {
val localURI = new URI(localPath)
if (localURI.getScheme != LOCAL_SCHEME) {
- val src = getQualifiedLocalPath(localURI, hadoopConf)
- val destPath = copyFileToRemote(dst, src, replication)
- val destFs = FileSystem.get(destPath.toUri(), hadoopConf)
- distCacheMgr.addResource(destFs, hadoopConf, destPath,
- localResources, LocalResourceType.FILE, destName, statCache)
+ if (addDistributedUri(localURI)) {
+ val src = getQualifiedLocalPath(localURI, hadoopConf)
+ val destPath = copyFileToRemote(dst, src, replication)
+ val destFs = FileSystem.get(destPath.toUri(), hadoopConf)
+ distCacheMgr.addResource(destFs, hadoopConf, destPath,
+ localResources, LocalResourceType.FILE, destName, statCache)
+ }
} else if (confKey != null) {
// If the resource is intended for local use only, handle this downstream
// by setting the appropriate property
@@ -271,6 +285,13 @@ private[spark] class Client(
}
}
+ createConfArchive().foreach { file =>
+ require(addDistributedUri(file.toURI()))
+ val destPath = copyFileToRemote(dst, new Path(file.toURI()), replication)
+ distCacheMgr.addResource(fs, hadoopConf, destPath, localResources, LocalResourceType.ARCHIVE,
+ LOCALIZED_HADOOP_CONF_DIR, statCache, appMasterOnly = true)
+ }
+
/**
* Do the same for any additional resources passed in through ClientArguments.
* Each resource category is represented by a 3-tuple of:
@@ -288,13 +309,15 @@ private[spark] class Client(
flist.split(',').foreach { file =>
val localURI = new URI(file.trim())
if (localURI.getScheme != LOCAL_SCHEME) {
- val localPath = new Path(localURI)
- val linkname = Option(localURI.getFragment()).getOrElse(localPath.getName())
- val destPath = copyFileToRemote(dst, localPath, replication)
- distCacheMgr.addResource(
- fs, hadoopConf, destPath, localResources, resType, linkname, statCache)
- if (addToClasspath) {
- cachedSecondaryJarLinks += linkname
+ if (addDistributedUri(localURI)) {
+ val localPath = new Path(localURI)
+ val linkname = Option(localURI.getFragment()).getOrElse(localPath.getName())
+ val destPath = copyFileToRemote(dst, localPath, replication)
+ distCacheMgr.addResource(
+ fs, hadoopConf, destPath, localResources, resType, linkname, statCache)
+ if (addToClasspath) {
+ cachedSecondaryJarLinks += linkname
+ }
}
} else if (addToClasspath) {
// Resource is intended for local use only and should be added to the class path
@@ -310,6 +333,57 @@ private[spark] class Client(
localResources
}
+ /**
+ * Create an archive with the Hadoop config files for distribution.
+ *
+ * These are only used by the AM, since executors will use the configuration object broadcast by
+ * the driver. The files are zipped and added to the job as an archive, so that YARN will explode
+ * it when distributing to the AM. This directory is then added to the classpath of the AM
+ * process, just to make sure that everybody is using the same default config.
+ *
+ * This follows the order of precedence set by the startup scripts, in which HADOOP_CONF_DIR
+ * shows up in the classpath before YARN_CONF_DIR.
+ *
+ * Currently this makes a shallow copy of the conf directory. If there are cases where a
+ * Hadoop config directory contains subdirectories, this code will have to be fixed.
+ */
+ private def createConfArchive(): Option[File] = {
+ val hadoopConfFiles = new HashMap[String, File]()
+ Seq("HADOOP_CONF_DIR", "YARN_CONF_DIR").foreach { envKey =>
+ sys.env.get(envKey).foreach { path =>
+ val dir = new File(path)
+ if (dir.isDirectory()) {
+ dir.listFiles().foreach { file =>
+ if (!hadoopConfFiles.contains(file.getName())) {
+ hadoopConfFiles(file.getName()) = file
+ }
+ }
+ }
+ }
+ }
+
+ if (!hadoopConfFiles.isEmpty) {
+ val hadoopConfArchive = File.createTempFile(LOCALIZED_HADOOP_CONF_DIR, ".zip",
+ new File(Utils.getLocalDir(sparkConf)))
+
+ val hadoopConfStream = new ZipOutputStream(new FileOutputStream(hadoopConfArchive))
+ try {
+ hadoopConfStream.setLevel(0)
+ hadoopConfFiles.foreach { case (name, file) =>
+ hadoopConfStream.putNextEntry(new ZipEntry(name))
+ Files.copy(file, hadoopConfStream)
+ hadoopConfStream.closeEntry()
+ }
+ } finally {
+ hadoopConfStream.close()
+ }
+
+ Some(hadoopConfArchive)
+ } else {
+ None
+ }
+ }
+
/**
* Set up the environment for launching our ApplicationMaster container.
*/
@@ -317,7 +391,7 @@ private[spark] class Client(
logInfo("Setting up the launch environment for our AM container")
val env = new HashMap[String, String]()
val extraCp = sparkConf.getOption("spark.driver.extraClassPath")
- populateClasspath(args, yarnConf, sparkConf, env, extraCp)
+ populateClasspath(args, yarnConf, sparkConf, env, true, extraCp)
env("SPARK_YARN_MODE") = "true"
env("SPARK_YARN_STAGING_DIR") = stagingDir
env("SPARK_USER") = UserGroupInformation.getCurrentUser().getShortUserName()
@@ -718,6 +792,9 @@ object Client extends Logging {
// Distribution-defined classpath to add to processes
val ENV_DIST_CLASSPATH = "SPARK_DIST_CLASSPATH"
+ // Subdirectory where the user's hadoop config files will be placed.
+ val LOCALIZED_HADOOP_CONF_DIR = "__hadoop_conf__"
+
/**
* Find the user-defined Spark jar if configured, or return the jar containing this
* class if not.
@@ -831,11 +908,19 @@ object Client extends Logging {
conf: Configuration,
sparkConf: SparkConf,
env: HashMap[String, String],
+ isAM: Boolean,
extraClassPath: Option[String] = None): Unit = {
extraClassPath.foreach(addClasspathEntry(_, env))
addClasspathEntry(
YarnSparkHadoopUtil.expandEnvironment(Environment.PWD), env
)
+
+ if (isAM) {
+ addClasspathEntry(
+ YarnSparkHadoopUtil.expandEnvironment(Environment.PWD) + Path.SEPARATOR +
+ LOCALIZED_HADOOP_CONF_DIR, env)
+ }
+
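Taken together, LOCALIZED_HADOOP_CONF_DIR and the new isAM branch mean that only the ApplicationMaster container gets the localized config directory on its classpath; executors keep using the configuration broadcast by the driver. A rough sketch of the entry the branch produces, assuming the {{PWD}} placeholder that YARN expands at container launch (the values below are examples, not output captured from a real launch):

    object AmClasspathSketch {
      val LOCALIZED_HADOOP_CONF_DIR = "__hadoop_conf__"

      def main(args: Array[String]): Unit = {
        val pwd = "{{PWD}}"   // expanded by YARN to the container's working directory
        val sep = "/"         // Path.SEPARATOR on Unix-like hosts
        val amOnlyEntry = pwd + sep + LOCALIZED_HADOOP_CONF_DIR
        println(amOnlyEntry)  // {{PWD}}/__hadoop_conf__ -- where YARN unpacks the conf archive
      }
    }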
if (sparkConf.getBoolean("spark.yarn.user.classpath.first", false)) {
val userClassPath =
if (args != null) {
diff --git a/yarn/src/main/scala/org/apache/spark/deploy/yarn/ClientArguments.scala b/yarn/src/main/scala/org/apache/spark/deploy/yarn/ClientArguments.scala
index da6798cb1b279..1423533470fc0 100644
--- a/yarn/src/main/scala/org/apache/spark/deploy/yarn/ClientArguments.scala
+++ b/yarn/src/main/scala/org/apache/spark/deploy/yarn/ClientArguments.scala
@@ -103,9 +103,13 @@ private[spark] class ClientArguments(args: Array[String], sparkConf: SparkConf)
* This is intended to be called only after the provided arguments have been parsed.
*/
private def validateArgs(): Unit = {
- if (numExecutors <= 0) {
+ if (numExecutors < 0 || (!isDynamicAllocationEnabled && numExecutors == 0)) {
throw new IllegalArgumentException(
- "You must specify at least 1 executor!\n" + getUsageMessage())
+ s"""
+ |Number of executors was $numExecutors, but must be at least 1
+ |(or 0 if dynamic executor allocation is enabled).
+ |${getUsageMessage()}
+ """.stripMargin)
}
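A self-contained sketch of the relaxed executor-count check; the helper name and the boolean parameter are illustrative (in ClientArguments the dynamic-allocation flag is a field and the usage message is appended to the error):

    object NumExecutorsValidationSketch {
      // Stand-in for the executor-count check in ClientArguments.validateArgs.
      def validateNumExecutors(numExecutors: Int, dynamicAllocationEnabled: Boolean): Unit = {
        if (numExecutors < 0 || (!dynamicAllocationEnabled && numExecutors == 0)) {
          throw new IllegalArgumentException(
            s"Number of executors was $numExecutors, but must be at least 1 " +
              "(or 0 if dynamic executor allocation is enabled).")
        }
      }

      def main(args: Array[String]): Unit = {
        validateNumExecutors(2, dynamicAllocationEnabled = false)  // ok: static allocation, explicit count
        validateNumExecutors(0, dynamicAllocationEnabled = true)   // ok: dynamic allocation picks the count
        // validateNumExecutors(0, dynamicAllocationEnabled = false)  // would throw IllegalArgumentException
        // validateNumExecutors(-1, dynamicAllocationEnabled = true)  // negative counts are always rejected
      }
    }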
if (executorCores < sparkConf.getInt("spark.task.cpus", 1)) {
throw new SparkException("Executor cores must not be less than " +
diff --git a/yarn/src/main/scala/org/apache/spark/deploy/yarn/ExecutorRunnable.scala b/yarn/src/main/scala/org/apache/spark/deploy/yarn/ExecutorRunnable.scala
index b06069c07f451..9d04d241dae9e 100644
--- a/yarn/src/main/scala/org/apache/spark/deploy/yarn/ExecutorRunnable.scala
+++ b/yarn/src/main/scala/org/apache/spark/deploy/yarn/ExecutorRunnable.scala
@@ -277,7 +277,7 @@ class ExecutorRunnable(
private def prepareEnvironment(container: Container): HashMap[String, String] = {
val env = new HashMap[String, String]()
val extraCp = sparkConf.getOption("spark.executor.extraClassPath")
- Client.populateClasspath(null, yarnConf, sparkConf, env, extraCp)
+ Client.populateClasspath(null, yarnConf, sparkConf, env, false, extraCp)
sparkConf.getExecutorEnv.foreach { case (key, value) =>
// This assumes each executor environment variable set here is a path
diff --git a/yarn/src/test/scala/org/apache/spark/deploy/yarn/ClientSuite.scala b/yarn/src/test/scala/org/apache/spark/deploy/yarn/ClientSuite.scala
index c1b94ac9c5bdd..a51c2005cb472 100644
--- a/yarn/src/test/scala/org/apache/spark/deploy/yarn/ClientSuite.scala
+++ b/yarn/src/test/scala/org/apache/spark/deploy/yarn/ClientSuite.scala
@@ -20,6 +20,11 @@ package org.apache.spark.deploy.yarn
import java.io.File
import java.net.URI
+import scala.collection.JavaConversions._
+import scala.collection.mutable.{ HashMap => MutableHashMap }
+import scala.reflect.ClassTag
+import scala.util.Try
+
import org.apache.hadoop.conf.Configuration
import org.apache.hadoop.fs.Path
import org.apache.hadoop.mapreduce.MRJobConfig
@@ -30,11 +35,6 @@ import org.mockito.Matchers._
import org.mockito.Mockito._
import org.scalatest.{BeforeAndAfterAll, FunSuite, Matchers}
-import scala.collection.JavaConversions._
-import scala.collection.mutable.{ HashMap => MutableHashMap }
-import scala.reflect.ClassTag
-import scala.util.Try
-
import org.apache.spark.{SparkException, SparkConf}
import org.apache.spark.util.Utils
@@ -93,7 +93,7 @@ class ClientSuite extends FunSuite with Matchers with BeforeAndAfterAll {
val env = new MutableHashMap[String, String]()
val args = new ClientArguments(Array("--jar", USER, "--addJars", ADDED), sparkConf)
- Client.populateClasspath(args, conf, sparkConf, env)
+ Client.populateClasspath(args, conf, sparkConf, env, true)
val cp = env("CLASSPATH").split(":|;|<CPS>")
s"$SPARK,$USER,$ADDED".split(",").foreach({ entry =>
@@ -104,13 +104,16 @@ class ClientSuite extends FunSuite with Matchers with BeforeAndAfterAll {
cp should not contain (uri.getPath())
}
})
- if (classOf[Environment].getMethods().exists(_.getName == "$$")) {
- cp should contain("{{PWD}}")
- } else if (Utils.isWindows) {
- cp should contain("%PWD%")
- } else {
- cp should contain(Environment.PWD.$())
- }
+ val pwdVar =
+ if (classOf[Environment].getMethods().exists(_.getName == "$$")) {
+ "{{PWD}}"
+ } else if (Utils.isWindows) {
+ "%PWD%"
+ } else {
+ Environment.PWD.$()
+ }
+ cp should contain(pwdVar)
+ cp should contain (s"$pwdVar${Path.SEPARATOR}${Client.LOCALIZED_HADOOP_CONF_DIR}")
cp should not contain (Client.SPARK_JAR)
cp should not contain (Client.APP_JAR)
}
diff --git a/yarn/src/test/scala/org/apache/spark/deploy/yarn/YarnClusterSuite.scala b/yarn/src/test/scala/org/apache/spark/deploy/yarn/YarnClusterSuite.scala
index a18c94d4ab4a8..3877da4120e7c 100644
--- a/yarn/src/test/scala/org/apache/spark/deploy/yarn/YarnClusterSuite.scala
+++ b/yarn/src/test/scala/org/apache/spark/deploy/yarn/YarnClusterSuite.scala
@@ -77,6 +77,7 @@ class YarnClusterSuite extends FunSuite with BeforeAndAfterAll with Matchers wit
private var yarnCluster: MiniYARNCluster = _
private var tempDir: File = _
private var fakeSparkJar: File = _
+ private var hadoopConfDir: File = _
private var logConfDir: File = _
override def beforeAll() {
@@ -120,6 +121,9 @@ class YarnClusterSuite extends FunSuite with BeforeAndAfterAll with Matchers wit
logInfo(s"RM address in configuration is ${config.get(YarnConfiguration.RM_ADDRESS)}")
fakeSparkJar = File.createTempFile("sparkJar", null, tempDir)
+ hadoopConfDir = new File(tempDir, Client.LOCALIZED_HADOOP_CONF_DIR)
+ assert(hadoopConfDir.mkdir())
+ File.createTempFile("token", ".txt", hadoopConfDir)
}
override def afterAll() {
@@ -258,7 +262,7 @@ class YarnClusterSuite extends FunSuite with BeforeAndAfterAll with Matchers wit
appArgs
Utils.executeAndGetOutput(argv,
- extraEnvironment = Map("YARN_CONF_DIR" -> tempDir.getAbsolutePath()))
+ extraEnvironment = Map("YARN_CONF_DIR" -> hadoopConfDir.getAbsolutePath()))
}
/**