[SPARK-20371][R] Add wrappers for collect_list and collect_set
## What changes were proposed in this pull request?

Adds wrappers for `collect_list` and `collect_set`.

## How was this patch tested?

Unit tests, `check-cran.sh`

Author: zero323 <[email protected]>

Closes apache#17672 from zero323/SPARK-20371.
zero323 authored and Mingjie Tang committed Apr 24, 2017
1 parent 004f7c0 commit 57370bc
Showing 4 changed files with 73 additions and 0 deletions.
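As a quick illustration of the new API, here is a minimal usage sketch. The toy data, column names, and session setup are illustrative and not part of the patch; it assumes an active SparkR session on a build that includes these wrappers (Spark 2.3.0+ per the `@note` tags below).

```r
library(SparkR)
sparkR.session()

# Toy data frame with a duplicated value in one group.
df <- createDataFrame(data.frame(name = c("Andy", "Andy", "Justin"),
                                 age  = c(30, 30, 19)))

# collect_list keeps duplicates within each group; collect_set removes them.
result <- agg(groupBy(df, df$name),
              collect_list(df$age),
              collect_set(df$age))
head(result)
# For name == "Andy": collect_list yields [30, 30], collect_set yields [30].

sparkR.session.stop()
```

Both wrappers return a `Column`, so they compose with `groupBy()`/`agg()` like any other SparkR aggregate function.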
2 changes: 2 additions & 0 deletions R/pkg/NAMESPACE
@@ -203,6 +203,8 @@ exportMethods("%in%",
"cbrt",
"ceil",
"ceiling",
"collect_list",
"collect_set",
"column",
"concat",
"concat_ws",
40 changes: 40 additions & 0 deletions R/pkg/R/functions.R
@@ -3705,3 +3705,43 @@
            jc <- callJStatic("org.apache.spark.sql.functions", "map", jcols)
            column(jc)
          })

#' collect_list
#'
#' Creates a list of objects with duplicates.
#'
#' @param x Column to compute on
#'
#' @rdname collect_list
#' @name collect_list
#' @family agg_funcs
#' @aliases collect_list,Column-method
#' @export
#' @examples \dontrun{collect_list(df$x)}
#' @note collect_list since 2.3.0
setMethod("collect_list",
signature(x = "Column"),
function(x) {
jc <- callJStatic("org.apache.spark.sql.functions", "collect_list", x@jc)
column(jc)
})

#' collect_set
#'
#' Creates a list of objects with duplicate elements eliminated.
#'
#' @param x Column to compute on
#'
#' @rdname collect_set
#' @name collect_set
#' @family agg_funcs
#' @aliases collect_set,Column-method
#' @export
#' @examples \dontrun{collect_set(df$x)}
#' @note collect_set since 2.3.0
setMethod("collect_set",
signature(x = "Column"),
function(x) {
jc <- callJStatic("org.apache.spark.sql.functions", "collect_set", x@jc)
column(jc)
})
9 changes: 9 additions & 0 deletions R/pkg/R/generics.R
@@ -918,6 +918,14 @@ setGeneric("cbrt", function(x) { standardGeneric("cbrt") })
#' @export
setGeneric("ceil", function(x) { standardGeneric("ceil") })

#' @rdname collect_list
#' @export
setGeneric("collect_list", function(x) { standardGeneric("collect_list") })

#' @rdname collect_set
#' @export
setGeneric("collect_set", function(x) { standardGeneric("collect_set") })

#' @rdname column
#' @export
setGeneric("column", function(x) { standardGeneric("column") })
@@ -1358,6 +1366,7 @@ setGeneric("window", function(x, ...) { standardGeneric("window") })
#' @export
setGeneric("year", function(x) { standardGeneric("year") })


###################### Spark.ML Methods ##########################

#' @rdname fitted
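For readers less familiar with R's S4 dispatch, the split above follows the usual SparkR layout: `generics.R` declares the generic, and `functions.R` supplies the `Column`-specific method. A standalone toy example of the same mechanism (invented names, unrelated to Spark):

```r
library(methods)

# Declare a generic, as generics.R does for collect_list/collect_set.
setGeneric("summarize_thing", function(x) { standardGeneric("summarize_thing") })

# Define a class and attach a method for it, as functions.R does for Column.
setClass("Thing", slots = c(values = "numeric"))
setMethod("summarize_thing", signature(x = "Thing"),
          function(x) {
            # Dispatch selects this body whenever x is a Thing.
            sum(x@values)
          })

summarize_thing(new("Thing", values = c(1, 2, 3)))  # 6
```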
22 changes: 22 additions & 0 deletions R/pkg/inst/tests/testthat/test_sparkSQL.R
@@ -1731,6 +1731,28 @@ test_that("group by, agg functions", {
  expect_true(abs(sd(1:2) - 0.7071068) < 1e-6)
  expect_true(abs(var(1:5, 1:5) - 2.5) < 1e-6)

  # Test collect_list and collect_set
  gd3_collections_local <- collect(
    agg(gd3, collect_set(df8$age), collect_list(df8$age))
  )

  expect_equal(
    unlist(gd3_collections_local[gd3_collections_local$name == "Andy", 2]),
    c(30)
  )

  expect_equal(
    unlist(gd3_collections_local[gd3_collections_local$name == "Andy", 3]),
    c(30, 30)
  )

  expect_equal(
    sort(unlist(
      gd3_collections_local[gd3_collections_local$name == "Justin", 3]
    )),
    c(1, 19)
  )

  unlink(jsonPath2)
  unlink(jsonPath3)
})