From e9efb62e0795c8d5233b7e5bfc276d74953942b8 Mon Sep 17 00:00:00 2001 From: Huaxin Gao Date: Wed, 6 Jun 2018 08:31:35 +0700 Subject: [PATCH] [SPARK-24187][R][SQL] Add array_join function to SparkR ## What changes were proposed in this pull request? This PR adds array_join function to SparkR ## How was this patch tested? Add unit test in test_sparkSQL.R Author: Huaxin Gao Closes #21313 from huaxingao/spark-24187. --- R/pkg/NAMESPACE | 1 + R/pkg/R/functions.R | 29 ++++++++++++++++++++++++--- R/pkg/R/generics.R | 4 ++++ R/pkg/tests/fulltests/test_sparkSQL.R | 15 ++++++++++++++ 4 files changed, 46 insertions(+), 3 deletions(-) diff --git a/R/pkg/NAMESPACE b/R/pkg/NAMESPACE index 73a33af4dd48b..9696f6987ad78 100644 --- a/R/pkg/NAMESPACE +++ b/R/pkg/NAMESPACE @@ -201,6 +201,7 @@ exportMethods("%<=>%", "approxCountDistinct", "approxQuantile", "array_contains", + "array_join", "array_max", "array_min", "array_position", diff --git a/R/pkg/R/functions.R b/R/pkg/R/functions.R index abc91aeeb4825..3bff633fbc1ff 100644 --- a/R/pkg/R/functions.R +++ b/R/pkg/R/functions.R @@ -221,7 +221,9 @@ NULL #' head(select(tmp3, element_at(tmp3$v3, "Valiant"))) #' tmp4 <- mutate(df, v4 = create_array(df$mpg, df$cyl), v5 = create_array(df$cyl, df$hp)) #' head(select(tmp4, concat(tmp4$v4, tmp4$v5), arrays_overlap(tmp4$v4, tmp4$v5))) -#' head(select(tmp, concat(df$mpg, df$cyl, df$hp)))} +#' head(select(tmp, concat(df$mpg, df$cyl, df$hp))) +#' tmp5 <- mutate(df, v6 = create_array(df$model, df$model)) +#' head(select(tmp5, array_join(tmp5$v6, "#"), array_join(tmp5$v6, "#", "NULL")))} NULL #' Window functions for Column operations @@ -3006,6 +3008,27 @@ setMethod("array_contains", column(jc) }) +#' @details +#' \code{array_join}: Concatenates the elements of column using the delimiter. +#' Null values are replaced with nullReplacement if set, otherwise they are ignored. +#' +#' @param delimiter a character string that is used to concatenate the elements of column. 
+#' @param nullReplacement an optional character string that is used to replace null values. +#' @rdname column_collection_functions +#' @aliases array_join array_join,Column-method +#' @note array_join since 2.4.0 +setMethod("array_join", + signature(x = "Column", delimiter = "character"), + function(x, delimiter, nullReplacement = NULL) { + jc <- if (is.null(nullReplacement)) { + callJStatic("org.apache.spark.sql.functions", "array_join", x@jc, delimiter) + } else { + callJStatic("org.apache.spark.sql.functions", "array_join", x@jc, delimiter, + as.character(nullReplacement)) + } + column(jc) + }) + #' @details #' \code{array_max}: Returns the maximum value of the array. #' @@ -3197,8 +3220,8 @@ setMethod("size", #' (or starting from the end if start is negative) with the specified length. #' #' @rdname column_collection_functions -#' @param start an index indicating the first element occuring in the result. -#' @param length a number of consecutive elements choosen to the result. +#' @param start an index indicating the first element occurring in the result. +#' @param length a number of consecutive elements chosen to the result. #' @aliases slice slice,Column-method #' @note slice since 2.4.0 setMethod("slice", diff --git a/R/pkg/R/generics.R b/R/pkg/R/generics.R index 8894cb1c5b92f..9321bbaf96ff8 100644 --- a/R/pkg/R/generics.R +++ b/R/pkg/R/generics.R @@ -757,6 +757,10 @@ setGeneric("approxCountDistinct", function(x, ...) { standardGeneric("approxCoun #' @name NULL setGeneric("array_contains", function(x, value) { standardGeneric("array_contains") }) +#' @rdname column_collection_functions +#' @name NULL +setGeneric("array_join", function(x, delimiter, ...) 
{ standardGeneric("array_join") }) + #' @rdname column_collection_functions #' @name NULL setGeneric("array_max", function(x) { standardGeneric("array_max") }) diff --git a/R/pkg/tests/fulltests/test_sparkSQL.R b/R/pkg/tests/fulltests/test_sparkSQL.R index 16c1fd5a065eb..36e0f78bb0599 100644 --- a/R/pkg/tests/fulltests/test_sparkSQL.R +++ b/R/pkg/tests/fulltests/test_sparkSQL.R @@ -1518,6 +1518,21 @@ test_that("column functions", { result <- collect(select(df, arrays_overlap(df[[1]], df[[2]])))[[1]] expect_equal(result, c(TRUE, FALSE, NA)) + # Test array_join() + df <- createDataFrame(list(list(list("Hello", "World!")))) + result <- collect(select(df, array_join(df[[1]], "#")))[[1]] + expect_equal(result, "Hello#World!") + df2 <- createDataFrame(list(list(list("Hello", NA, "World!")))) + result <- collect(select(df2, array_join(df2[[1]], "#", "Beautiful")))[[1]] + expect_equal(result, "Hello#Beautiful#World!") + result <- collect(select(df2, array_join(df2[[1]], "#")))[[1]] + expect_equal(result, "Hello#World!") + df3 <- createDataFrame(list(list(list("Hello", NULL, "World!")))) + result <- collect(select(df3, array_join(df3[[1]], "#", "Beautiful")))[[1]] + expect_equal(result, "Hello#Beautiful#World!") + result <- collect(select(df3, array_join(df3[[1]], "#")))[[1]] + expect_equal(result, "Hello#World!") + # Test array_sort() and sort_array() df <- createDataFrame(list(list(list(2L, 1L, 3L, NA)), list(list(NA, 6L, 5L, NA, 4L))))