diff --git a/R/pkg/R/DataFrame.R b/R/pkg/R/DataFrame.R index 3d912c9fa32ed..0a1012283ed02 100644 --- a/R/pkg/R/DataFrame.R +++ b/R/pkg/R/DataFrame.R @@ -1717,6 +1717,23 @@ getColumn <- function(x, c) { column(callJMethod(x@sdf, "col", c)) } +setColumn <- function(x, c, value) { + if (class(value) != "Column" && !is.null(value)) { + if (isAtomicLengthOne(value)) { + value <- lit(value) + } else { + stop("value must be a Column, literal value as atomic in length of 1, or NULL") + } + } + + if (is.null(value)) { + nx <- drop(x, c) + } else { + nx <- withColumn(x, c, value) + } + nx +} + #' @param name name of a Column (without being wrapped by \code{""}). #' @rdname select #' @name $ @@ -1735,19 +1752,7 @@ setMethod("$", signature(x = "SparkDataFrame"), #' @note $<- since 1.4.0 setMethod("$<-", signature(x = "SparkDataFrame"), function(x, name, value) { - if (class(value) != "Column" && !is.null(value)) { - if (isAtomicLengthOne(value)) { - value <- lit(value) - } else { - stop("value must be a Column, literal value as atomic in length of 1, or NULL") - } - } - - if (is.null(value)) { - nx <- drop(x, name) - } else { - nx <- withColumn(x, name, value) - } + nx <- setColumn(x, name, value) x@sdf <- nx@sdf x }) @@ -1767,6 +1772,21 @@ setMethod("[[", signature(x = "SparkDataFrame", i = "numericOrcharacter"), getColumn(x, i) }) +#' @rdname subset +#' @name [[<- +#' @aliases [[<-,SparkDataFrame,numericOrcharacter-method +#' @note [[<- since 2.1.1 +setMethod("[[<-", signature(x = "SparkDataFrame", i = "numericOrcharacter"), + function(x, i, value) { + if (is.numeric(i)) { + cols <- columns(x) + i <- cols[[i]] + } + nx <- setColumn(x, i, value) + x@sdf <- nx@sdf + x + }) + #' @rdname subset #' @name [ #' @aliases [,SparkDataFrame-method @@ -1814,6 +1834,8 @@ setMethod("[", signature(x = "SparkDataFrame"), #' @param j,select expression for the single Column or a list of columns to select from the SparkDataFrame. #' @param drop if TRUE, a Column will be returned if the resulting dataset has only one column. #' Otherwise, a SparkDataFrame will always be returned. +#' @param value a Column or an atomic vector in the length of 1 as literal value, or \code{NULL}. +#' If \code{NULL}, the specified Column is dropped. #' @param ... currently not used. #' @return A new SparkDataFrame containing only the rows that meet the condition with selected columns. #' @export diff --git a/R/pkg/inst/tests/testthat/test_sparkSQL.R b/R/pkg/inst/tests/testthat/test_sparkSQL.R index 26017427ab5e7..aaa8fb498c85a 100644 --- a/R/pkg/inst/tests/testthat/test_sparkSQL.R +++ b/R/pkg/inst/tests/testthat/test_sparkSQL.R @@ -1021,6 +1021,9 @@ test_that("select operators", { df$age2 <- df$age * 2 expect_equal(columns(df), c("name", "age", "age2")) expect_equal(count(where(df, df$age2 == df$age * 2)), 2) + df$age2 <- df[["age"]] * 3 + expect_equal(columns(df), c("name", "age", "age2")) + expect_equal(count(where(df, df$age2 == df$age * 3)), 2) df$age2 <- 21 expect_equal(columns(df), c("name", "age", "age2")) @@ -1033,6 +1036,23 @@ test_that("select operators", { expect_error(df$age3 <- c(22, NA), "value must be a Column, literal value as atomic in length of 1, or NULL") + df[["age2"]] <- 23 + expect_equal(columns(df), c("name", "age", "age2")) + expect_equal(count(where(df, df$age2 == 23)), 3) + + df[[3]] <- 24 + expect_equal(columns(df), c("name", "age", "age2")) + expect_equal(count(where(df, df$age2 == 24)), 3) + + df[[3]] <- df$age + expect_equal(count(where(df, df$age2 == df$age)), 2) + + df[["age2"]] <- df[["name"]] + expect_equal(count(where(df, df$age2 == df$name)), 3) + + expect_error(df[["age3"]] <- c(22, 23), + "value must be a Column, literal value as atomic in length of 1, or NULL") + # Test parameter drop expect_equal(class(df[, 1]) == "SparkDataFrame", T) expect_equal(class(df[, 1, drop = T]) == "Column", T)