Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[SPARK-24187][R][SQL]Add array_join function to SparkR #21313

Closed
wants to merge 7 commits into from
Closed
Show file tree
Hide file tree
Changes from 3 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions R/pkg/NAMESPACE
Original file line number Diff line number Diff line change
Expand Up @@ -201,6 +201,7 @@ exportMethods("%<=>%",
"approxCountDistinct",
"approxQuantile",
"array_contains",
"array_join",
"array_max",
"array_min",
"array_position",
Expand Down
30 changes: 27 additions & 3 deletions R/pkg/R/functions.R
Original file line number Diff line number Diff line change
Expand Up @@ -221,7 +221,9 @@ NULL
#' head(select(tmp3, element_at(tmp3$v3, "Valiant")))
#' tmp4 <- mutate(df, v4 = create_array(df$mpg, df$cyl), v5 = create_array(df$cyl, df$hp))
#' head(select(tmp4, concat(tmp4$v4, tmp4$v5), arrays_overlap(tmp4$v4, tmp4$v5)))
#' head(select(tmp, concat(df$mpg, df$cyl, df$hp)))}
#' head(select(tmp, concat(df$mpg, df$cyl, df$hp)))
#' tmp5 <- mutate(df, v6 = create_array(df$model, df$model))
#' head(select(tmp5, array_join(tmp5$v6, "#"), array_join(tmp5$v6, "#", "NULL")))}
NULL

#' Window functions for Column operations
Expand Down Expand Up @@ -3006,6 +3008,28 @@ setMethod("array_contains",
column(jc)
})

#' @details
#' \code{array_join}: Concatenates the elements of column using the delimiter.
#' Null values are replaced with nullReplacement if set, otherwise they are ignored.
#'
#' @param delimiter character(s) to use to concatenate the elements of column.
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I didn't check scala - what's "(s)" here in "character(s)" mean? I ask because "character" refers to the type in R

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

@felixcheung scala doesn't have a doc for param delimiter. I added this myself. What I am trying to say is "one or more characters". I will change to "a character string" so it will be

@param delimiter a character string to use to concatenate the elements of column.

Does this look ok to you?

#' @param nullReplacement character(s) to use to replace the Null values.
#' @rdname column_collection_functions
#' @aliases array_join array_join,Column-method
#' @note array_join since 2.4.0
setMethod("array_join",
signature(x = "Column"),
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

is delimiter supposed to be a string? it's better to have its type in the signature

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

@felixcheung
I will add the type so it will be

signature(x = "Column", delimiter = "character"),

function(x, delimiter, nullReplacement = NA) {
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

wait.. why is nullReplacement default to NA? that's a bit unusual

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

@felixcheung
I will change the default to NULL.

jc <- if (is.na(nullReplacement)) {
callJStatic("org.apache.spark.sql.functions", "array_join", x@jc, delimiter)
}
else {
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

nit: } else {

callJStatic("org.apache.spark.sql.functions", "array_join", x@jc, delimiter,
nullReplacement)
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

as.character(nullReplacement)

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

re #21313 (comment)
so what's the behavior if delimiter is more than one character?
like array_join(df$a, "#Foo", "Beautiful")?

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

It's Hello#FooBeautiful#FooWorld!

}
column(jc)
})

#' @details
#' \code{array_max}: Returns the maximum value of the array.
#'
Expand Down Expand Up @@ -3197,8 +3221,8 @@ setMethod("size",
#' (or starting from the end if start is negative) with the specified length.
#'
#' @rdname column_collection_functions
#' @param start an index indicating the first element occuring in the result.
#' @param length a number of consecutive elements choosen to the result.
#' @param start an index indicating the first element occurring in the result.
#' @param length a number of consecutive elements chosen to the result.
#' @aliases slice slice,Column-method
#' @note slice since 2.4.0
setMethod("slice",
Expand Down
4 changes: 4 additions & 0 deletions R/pkg/R/generics.R
Original file line number Diff line number Diff line change
Expand Up @@ -757,6 +757,10 @@ setGeneric("approxCountDistinct", function(x, ...) { standardGeneric("approxCoun
#' @name NULL
setGeneric("array_contains", function(x, value) { standardGeneric("array_contains") })

#' @rdname column_collection_functions
#' @name NULL
setGeneric("array_join", function(x, delimiter, ...) { standardGeneric("array_join") })

#' @rdname column_collection_functions
#' @name NULL
setGeneric("array_max", function(x) { standardGeneric("array_max") })
Expand Down
10 changes: 10 additions & 0 deletions R/pkg/tests/fulltests/test_sparkSQL.R
Original file line number Diff line number Diff line change
Expand Up @@ -1518,6 +1518,16 @@ test_that("column functions", {
result <- collect(select(df, arrays_overlap(df[[1]], df[[2]])))[[1]]
expect_equal(result, c(TRUE, FALSE, NA))

# Test array_join()
df <- createDataFrame(list(list(list("Hello", "World!"))))
result <- collect(select(df, array_join(df[[1]], "#")))[[1]]
expect_equal(result, "Hello#World!")
df2 <- createDataFrame(list(list(list("Hello", NA, "World!"))))
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

How does it work with NULL?

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

@HyukjinKwon Thank you very much for your review. I will add a test for NULL and also change the }.

result <- collect(select(df2, array_join(df2[[1]], "#", "Beautiful")))[[1]]
expect_equal(result, "Hello#Beautiful#World!")
result <- collect(select(df2, array_join(df2[[1]], "#")))[[1]]
expect_equal(result, "Hello#World!")

# Test array_sort() and sort_array()
df <- createDataFrame(list(list(list(2L, 1L, 3L, NA)), list(list(NA, 6L, 5L, NA, 4L))))

Expand Down