Merge pull request apache#108 from concretevitamin/take-optimize
[SPARKR-139] Fix slow take(): deserialize only up to necessary # of elements.
shivaram committed Nov 15, 2014
2 parents e4217dd + c06fc90 commit bc3e9f6
Showing 4 changed files with 54 additions and 25 deletions.
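For orientation, a hedged usage sketch of the API this commit speeds up. `sc` stands for a context returned by sparkR.init(), and the file path is illustrative, not taken from this change:

library(SparkR)
sc <- sparkR.init()

# textFile() yields an unserialized RDD (see the utils.R comment below), so with
# this commit take() converts at most 5 logical elements per collected partition
# instead of converting a whole partition and then truncating the result.
lines <- textFile(sc, "README.md")
take(lines, 5)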
38 changes: 22 additions & 16 deletions pkg/R/RDD.R
@@ -383,13 +383,13 @@ setMethod("length",
count(x)
})

#' Return the count of each unique value in this RDD as a list of
#' (value, count) pairs.
#'
#' Same as countByValue in Spark.
#'
#' @param rdd The RDD to count
#' @return list of (value, count) pairs, where count is number of each unique
#' value in rdd.
#' @rdname countByValue
#' @export
@@ -412,7 +412,7 @@ setMethod("countByValue",

#' Count the number of elements for each key, and return the result to the
#' master as lists of (key, count) pairs.
#'
#' Same as countByKey in Spark.
#'
#' @param rdd The RDD to count keys.
@@ -596,8 +596,8 @@ setMethod("mapPartitionsWithIndex",
lapplyPartitionsWithIndex(X, FUN)
})

#' This function returns a new RDD containing only the elements that satisfy
#' a predicate (i.e. returning TRUE in a given logical function).
#' The same as `filter()' in Spark.
#'
#' @param rdd The RDD to be filtered.
@@ -733,7 +733,7 @@ setMethod("take",
index <- -1
jrdd <- getJRDD(rdd)
numPartitions <- numPartitions(rdd)

# TODO(shivaram): Collect more than one partition based on size
# estimates similar to the scala version of `take`.
while (TRUE) {
@@ -748,16 +748,22 @@
"collectPartitions",
.jarray(as.integer(index)))
partition <- partitionArr[[1]]
elems <- convertJListToRList(partition, flatten = TRUE)

size <- num - length(resList)
# elems is capped to have at most `size` elements
elems <- convertJListToRList(partition,
flatten = TRUE,
logicalUpperBound = size,
serialized = rdd@env$serialized)
# TODO: Check if this append is O(n^2)?
resList <- append(resList, head(elems, n = num - length(resList)))
resList <- append(resList, elems)
}
resList
})
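A minimal sketch of the capping arithmetic in the loop above, in plain R with no JVM involved; `partitions` is a hypothetical stand-in for the per-partition element lists that collectPartitions and convertJListToRList would produce:

# Each pass requests at most `size = num - length(resList)` elements, so the
# last partition touched is never converted in full when only a few elements
# are still needed.
takeSketch <- function(partitions, num) {
  resList <- list()
  index <- 0
  while (length(resList) < num && index < length(partitions)) {
    index <- index + 1
    size <- num - length(resList)                 # elements still needed
    elems <- head(partitions[[index]], n = size)  # capped, like logicalUpperBound
    resList <- append(resList, elems)
  }
  resList
}

partitions <- list(as.list(1:2), as.list(3:4), as.list(5:6))
takeSketch(partitions, num = 3)  # list(1, 2, 3); the second partition is only half-used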

#' Removes the duplicates from RDD.
#'
#' This function returns a new RDD containing the distinct elements in the
#' given RDD. The same as `distinct()' in Spark.
#'
#' @param rdd The RDD to remove duplicates from.
@@ -770,7 +776,7 @@ setMethod("take",
#' rdd <- parallelize(sc, c(1,2,2,3,3,3))
#' sort(unlist(collect(distinct(rdd)))) # c(1, 2, 3)
#'}
setGeneric("distinct",
setGeneric("distinct",
function(rdd, numPartitions) { standardGeneric("distinct") })

setClassUnion("missingOrInteger", c("missing", "integer"))
@@ -783,8 +789,8 @@ setMethod("distinct",
numPartitions <- SparkR::numPartitions(rdd)
}
identical.mapped <- lapply(rdd, function(x) { list(x, NULL) })
reduced <- reduceByKey(identical.mapped,
function(x, y) { x },
numPartitions)
resRDD <- lapply(reduced, function(x) { x[[1]] })
resRDD
@@ -1008,7 +1014,7 @@ setMethod("mapValues",

#' Pass each value in the key-value pair RDD through a flatMap function without
#' changing the keys; this also retains the original RDD's partitioning.
#'
#' The same as 'flatMapValues()' in Spark.
#'
#' @param X The RDD to apply the transformation.
@@ -1338,7 +1344,7 @@ setMethod("combineByKey",
#'
#' @param x An RDD.
#' @param y An RDD.
#' @return a new RDD created by performing the simple union (without removing
#' duplicates) of two input RDDs.
#' @rdname unionRDD
#' @export
@@ -1369,6 +1375,6 @@ setMethod("unionRDD",
jrdd <- .jcall(getJRDD(x), "Lorg/apache/spark/api/java/JavaRDD;",
"union", getJRDD(y))
union.rdd <- RDD(jrdd, TRUE)
}
union.rdd
})
34 changes: 26 additions & 8 deletions pkg/R/utils.R
@@ -1,11 +1,20 @@
# Utilities and Helpers

# Given a JList<T>, returns an R list containing the same elements. Takes care
# of deserializations and type conversions.
convertJListToRList <- function(jList, flatten) {
size <- .jcall(jList, "I", "size")
results <- if (size > 0) {
lapply(0:(size - 1),
# Given a JList<T>, returns an R list containing the same elements, the number
# of which is optionally upper bounded by `logicalUpperBound` (by default,
# return all elements). Takes care of deserializations and type conversions.
convertJListToRList <- function(jList, flatten, logicalUpperBound = NULL, serialized = TRUE) {
arrSize <- .jcall(jList, "I", "size")

# Unserialized datasets (such as an RDD directly generated by textFile()):
# each partition is not dense-packed into one Array[Byte], and `arrSize`
# here corresponds to number of logical elements. Thus we can prune here.
if (!serialized && !is.null(logicalUpperBound)) {
arrSize <- min(arrSize, logicalUpperBound)
}

results <- if (arrSize > 0) {
lapply(0:(arrSize - 1),
function(index) {
jElem <- .jcall(jList,
"Ljava/lang/Object;",
@@ -16,11 +25,21 @@ convertJListToRList <- function(jList, flatten) {
obj <- .jsimplify(jElem)

if (inherits(obj, "jobjRef") && .jinstanceof(obj, "[B")) {
# RDD[Array[Byte]].
# RDD[Array[Byte]]. `obj` is a whole partition.

rRaw <- .jevalArray(.jcastToArray(jElem))
res <- unserialize(rRaw)

# For serialized datasets, `obj` (and `rRaw`) here corresponds to
# one whole partition dense-packed together. We deserialize the
# whole partition first, then cap the number of elements to be returned.

# TODO: is it possible to distinguish element boundary so that we can
# unserialize only what we need?
if (!is.null(logicalUpperBound)) {
res <- head(res, n = logicalUpperBound)
}

} else if (inherits(obj, "jobjRef") &&
.jinstanceof(obj, "scala.Tuple2")) {
# JavaPairRDD[Array[Byte], Array[Byte]].
@@ -53,7 +72,6 @@ convertJListToRList <- function(jList, flatten) {
} else {
as.list(results)
}

}

# Given a Java array of byte arrays, deserialize each, returning an R list of
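A standalone sketch of the two pruning strategies that convertJListToRList now distinguishes, using a plain R list and a serialized raw vector in place of the JList; the function name is hypothetical and this is not the SparkR code path itself:

capElements <- function(partition, logicalUpperBound = NULL, serialized = TRUE) {
  if (serialized) {
    # Serialized case: the partition is one dense raw blob, so the whole blob
    # must be unserialized before the result can be capped.
    elems <- unserialize(partition)
    if (!is.null(logicalUpperBound)) {
      elems <- head(elems, n = logicalUpperBound)
    }
    elems
  } else {
    # Unserialized case: elements are individually addressable, so pruning can
    # happen before any element is converted at all.
    size <- length(partition)
    if (!is.null(logicalUpperBound)) {
      size <- min(size, logicalUpperBound)
    }
    partition[seq_len(size)]
  }
}

blob <- serialize(as.list(1:100), connection = NULL)
capElements(blob, logicalUpperBound = 3, serialized = TRUE)             # whole blob read, then capped
capElements(as.list(1:100), logicalUpperBound = 3, serialized = FALSE)  # only 3 elements touched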
5 changes: 4 additions & 1 deletion pkg/inst/tests/test_take.R
@@ -17,8 +17,11 @@ jsc <- sparkR.init()

test_that("take() gives back the original elements in correct count and order", {
numVectorRDD <- parallelize(jsc, numVector, 10)
# case: number of elements to take is less than the size of the first partition
expect_equal(take(numVectorRDD, 1), as.list(head(numVector, n = 1)))
expect_equal(take(numVectorRDD, 3), as.list(head(numVector, n = 3)))
# case: number of elements to take is the same as the size of the first partition
expect_equal(take(numVectorRDD, 11), as.list(head(numVector, n = 11)))
# case: number of elements to take is greater than all elements
expect_equal(take(numVectorRDD, length(numVector)), as.list(numVector))
expect_equal(take(numVectorRDD, length(numVector) + 1), as.list(numVector))

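Assuming the package's usual testthat workflow (the exact invocation below is an assumption, not part of this change), the new boundary cases can be run on their own roughly like this:

# Hypothetical local run; requires SparkR installed and a JVM on the path.
library(testthat)
library(SparkR)
test_file("pkg/inst/tests/test_take.R")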
2 changes: 2 additions & 0 deletions pkg/src/project/plugins.sbt
@@ -1,3 +1,5 @@
resolvers += "Sonatype snapshots" at "http://oss.sonatype.org/content/repositories/snapshots/"

addSbtPlugin("com.eed3si9n" % "sbt-assembly" % "0.9.1")

addSbtPlugin("com.github.mpeltonen" % "sbt-idea" % "1.6.0")
