diff --git a/R/pkg/NAMESPACE b/R/pkg/NAMESPACE index db8e06db18edc..e8de34d9371a0 100644 --- a/R/pkg/NAMESPACE +++ b/R/pkg/NAMESPACE @@ -182,7 +182,8 @@ exportMethods("arrange", exportClasses("Column") -exportMethods("%in%", +exportMethods("%<=>%", + "%in%", "abs", "acos", "add_months", @@ -291,6 +292,7 @@ exportMethods("%in%", "nanvl", "negate", "next_day", + "not", "ntile", "otherwise", "over", diff --git a/R/pkg/R/column.R b/R/pkg/R/column.R index 539d91b0f8797..147ee4b6887b9 100644 --- a/R/pkg/R/column.R +++ b/R/pkg/R/column.R @@ -67,8 +67,7 @@ operators <- list( "+" = "plus", "-" = "minus", "*" = "multiply", "/" = "divide", "%%" = "mod", "==" = "equalTo", ">" = "gt", "<" = "lt", "!=" = "notEqual", "<=" = "leq", ">=" = "geq", # we can not override `&&` and `||`, so use `&` and `|` instead - "&" = "and", "|" = "or", #, "!" = "unary_$bang" - "^" = "pow" + "&" = "and", "|" = "or", "^" = "pow" ) column_functions1 <- c("asc", "desc", "isNaN", "isNull", "isNotNull") column_functions2 <- c("like", "rlike", "getField", "getItem", "contains") @@ -302,3 +301,55 @@ setMethod("otherwise", jc <- callJMethod(x@jc, "otherwise", value) column(jc) }) + +#' \%<=>\% +#' +#' Equality test that is safe for null values. +#' +#' Can be used, unlike standard equality operator, to perform null-safe joins. +#' Equivalent to Scala \code{Column.<=>} and \code{Column.eqNullSafe}. +#' +#' @param x a Column +#' @param value a value to compare +#' @rdname eq_null_safe +#' @name %<=>% +#' @aliases %<=>%,Column-method +#' @export +#' @examples +#' \dontrun{ +#' df1 <- createDataFrame(data.frame( +#' x = c(1, NA, 3, NA), y = c(2, 6, 3, NA) +#' )) +#' +#' head(select(df1, df1$x == df1$y, df1$x %<=>% df1$y)) +#' +#' df2 <- createDataFrame(data.frame(y = c(3, NA))) +#' count(join(df1, df2, df1$y == df2$y)) +#' +#' count(join(df1, df2, df1$y %<=>% df2$y)) +#' } +#' @note \%<=>\% since 2.3.0 +setMethod("%<=>%", + signature(x = "Column", value = "ANY"), + function(x, value) { + value <- if (class(value) == "Column") { value@jc } else { value } + jc <- callJMethod(x@jc, "eqNullSafe", value) + column(jc) + }) + +#' ! +#' +#' Inversion of boolean expression. +#' +#' @rdname not +#' @name not +#' @aliases !,Column-method +#' @export +#' @examples +#' \dontrun{ +#' df <- createDataFrame(data.frame(x = c(-1, 0, 1))) +#' +#' head(select(df, !column("x") > 0)) +#' } +#' @note ! since 2.3.0 +setMethod("!", signature(x = "Column"), function(x) not(x)) diff --git a/R/pkg/R/functions.R b/R/pkg/R/functions.R index f4a34fbabe4d7..f9687d680e7a2 100644 --- a/R/pkg/R/functions.R +++ b/R/pkg/R/functions.R @@ -3859,3 +3859,34 @@ setMethod("posexplode_outer", jc <- callJStatic("org.apache.spark.sql.functions", "posexplode_outer", x@jc) column(jc) }) + +#' not +#' +#' Inversion of boolean expression. +#' +#' \code{not} and \code{!} cannot be applied directly to numerical column. +#' To achieve R-like truthiness column has to be casted to \code{BooleanType}. +#' +#' @param x Column to compute on +#' @rdname not +#' @name not +#' @aliases not,Column-method +#' @export +#' @examples \dontrun{ +#' df <- createDataFrame(data.frame( +#' is_true = c(TRUE, FALSE, NA), +#' flag = c(1, 0, 1) +#' )) +#' +#' head(select(df, not(df$is_true))) +#' +#' # Explicit cast is required when working with numeric column +#' head(select(df, not(cast(df$flag, "boolean")))) +#' } +#' @note not since 2.3.0 +setMethod("not", + signature(x = "Column"), + function(x) { + jc <- callJStatic("org.apache.spark.sql.functions", "not", x@jc) + column(jc) + }) diff --git a/R/pkg/R/generics.R b/R/pkg/R/generics.R index e510ff9a2d80f..d4e4958dc078c 100644 --- a/R/pkg/R/generics.R +++ b/R/pkg/R/generics.R @@ -856,6 +856,10 @@ setGeneric("otherwise", function(x, value) { standardGeneric("otherwise") }) #' @export setGeneric("over", function(x, window) { standardGeneric("over") }) +#' @rdname eq_null_safe +#' @export +setGeneric("%<=>%", function(x, value) { standardGeneric("%<=>%") }) + ###################### WindowSpec Methods ########################## #' @rdname partitionBy @@ -1154,6 +1158,10 @@ setGeneric("nanvl", function(y, x) { standardGeneric("nanvl") }) #' @export setGeneric("negate", function(x) { standardGeneric("negate") }) +#' @rdname not +#' @export +setGeneric("not", function(x) { standardGeneric("not") }) + #' @rdname next_day #' @export setGeneric("next_day", function(y, x) { standardGeneric("next_day") }) diff --git a/R/pkg/inst/tests/testthat/test_context.R b/R/pkg/inst/tests/testthat/test_context.R index c847113491113..c64fe6edcd49e 100644 --- a/R/pkg/inst/tests/testthat/test_context.R +++ b/R/pkg/inst/tests/testthat/test_context.R @@ -21,10 +21,10 @@ test_that("Check masked functions", { # Check that we are not masking any new function from base, stats, testthat unexpectedly # NOTE: We should avoid adding entries to *namesOfMaskedCompletely* as masked functions make it # hard for users to use base R functions. Please check when in doubt. - namesOfMaskedCompletely <- c("cov", "filter", "sample") + namesOfMaskedCompletely <- c("cov", "filter", "sample", "not") namesOfMasked <- c("describe", "cov", "filter", "lag", "na.omit", "predict", "sd", "var", "colnames", "colnames<-", "intersect", "rank", "rbind", "sample", "subset", - "summary", "transform", "drop", "window", "as.data.frame", "union") + "summary", "transform", "drop", "window", "as.data.frame", "union", "not") if (as.numeric(R.version$major) >= 3 && as.numeric(R.version$minor) >= 3) { namesOfMasked <- c("endsWith", "startsWith", namesOfMasked) } diff --git a/R/pkg/inst/tests/testthat/test_sparkSQL.R b/R/pkg/inst/tests/testthat/test_sparkSQL.R index 1828cddffd27c..08296354ca7ed 100644 --- a/R/pkg/inst/tests/testthat/test_sparkSQL.R +++ b/R/pkg/inst/tests/testthat/test_sparkSQL.R @@ -1323,6 +1323,8 @@ test_that("column operators", { c3 <- (c + c2 - c2) * c2 %% c2 c4 <- (c > c2) & (c2 <= c3) | (c == c2) & (c2 != c3) c5 <- c2 ^ c3 ^ c4 + c6 <- c2 %<=>% c3 + c7 <- !c6 }) test_that("column functions", { @@ -1348,6 +1350,7 @@ test_that("column functions", { c19 <- spark_partition_id() + coalesce(c) + coalesce(c1, c2, c3) c20 <- to_timestamp(c) + to_timestamp(c, "yyyy") + to_date(c, "yyyy") c21 <- posexplode_outer(c) + explode_outer(c) + c22 <- not(c) # Test if base::is.nan() is exposed expect_equal(is.nan(c("a", "b")), c(FALSE, FALSE)) @@ -1488,6 +1491,13 @@ test_that("column functions", { lapply( list(list(x = 1, y = -1, z = -2), list(x = 2, y = 3, z = 5)), as.environment)) + + df <- as.DataFrame(data.frame(is_true = c(TRUE, FALSE, NA))) + expect_equal( + collect(select(df, alias(not(df$is_true), "is_false"))), + data.frame(is_false = c(FALSE, TRUE, NA)) + ) + }) test_that("column binary mathfunctions", { @@ -1973,6 +1983,16 @@ test_that("filter() on a DataFrame", { filtered6 <- where(df, df$age %in% c(19, 30)) expect_equal(count(filtered6), 2) + # test suites for %<=>% + dfNa <- read.json(jsonPathNa) + expect_equal(count(filter(dfNa, dfNa$age %<=>% 60)), 1) + expect_equal(count(filter(dfNa, !(dfNa$age %<=>% 60))), 5 - 1) + expect_equal(count(filter(dfNa, dfNa$age %<=>% NULL)), 3) + expect_equal(count(filter(dfNa, !(dfNa$age %<=>% NULL))), 5 - 3) + # match NA from two columns + expect_equal(count(filter(dfNa, dfNa$age %<=>% dfNa$height)), 2) + expect_equal(count(filter(dfNa, !(dfNa$age %<=>% dfNa$height))), 5 - 2) + # Test stats::filter is working #expect_true(is.ts(filter(1:100, rep(1, 3)))) # nolint })