diff --git a/r/R/dataset-scan.R b/r/R/dataset-scan.R index 750401e173638..a0ac930edc2fd 100644 --- a/r/R/dataset-scan.R +++ b/r/R/dataset-scan.R @@ -161,16 +161,12 @@ ScannerBuilder <- R6Class("ScannerBuilder", inherit = ArrowObject, # cols is either a character vector or a named list of Expressions if (is.character(cols)) { dataset___ScannerBuilder__ProjectNames(self, cols) + } else if (length(cols) == 0) { + # Empty projection + dataset___ScannerBuilder__ProjectNames(self, character(0)) } else { - # If we have expressions, but they all turn out to be field_refs, - # we can still call the simple method - field_names <- get_field_names(cols) - if (all(nzchar(field_names))) { - dataset___ScannerBuilder__ProjectNames(self, field_names) - } else { - # Else, we are projecting/mutating - dataset___ScannerBuilder__ProjectExprs(self, cols, names(cols)) - } + # List of Expressions + dataset___ScannerBuilder__ProjectExprs(self, cols, names(cols)) } self }, diff --git a/r/R/dplyr.R b/r/R/dplyr.R index a6b0ed660a570..7e45ea23efdb7 100644 --- a/r/R/dplyr.R +++ b/r/R/dplyr.R @@ -151,13 +151,20 @@ as.data.frame.arrow_dplyr_query <- function(x, row.names = NULL, optional = FALS } #' @export -head.arrow_dplyr_query <- head.Dataset +head.arrow_dplyr_query <- function(x, n = 6L, ...) { + out <- head.Dataset(x, n, ...) + restore_dplyr_features(out, x) +} #' @export -tail.arrow_dplyr_query <- tail.Dataset +tail.arrow_dplyr_query <- function(x, n = 6L, ...) { + out <- tail.Dataset(x, n, ...) + restore_dplyr_features(out, x) +} #' @export `[.arrow_dplyr_query` <- `[.Dataset` +# TODO: ^ should also probably restore_dplyr_features, and/or that should be moved down # The following S3 methods are registered on load if dplyr is present tbl_vars.arrow_dplyr_query <- function(x) names(x$selected_columns) diff --git a/r/tests/testthat/test-dataset.R b/r/tests/testthat/test-dataset.R index 4570c1f576294..de585431a351d 100644 --- a/r/tests/testthat/test-dataset.R +++ b/r/tests/testthat/test-dataset.R @@ -1197,8 +1197,8 @@ test_that("compute()/collect(as_data_frame=FALSE)", { # the group_by() prevents compute() from returning a Table... expect_is(tab5, "arrow_dplyr_query") - # ... but $.data is a Table... - expect_is(tab5$.data, "Table") + # ... but $.data is a Table (InMemoryDataset)... + expect_r6_class(tab5$.data, "InMemoryDataset") # ... and the mutate() was evaluated expect_true("negint" %in% names(tab5$.data)) @@ -1561,13 +1561,14 @@ test_that("Dataset writing: dplyr methods", { dst_dir2 <- tempfile() ds %>% group_by(int) %>% - select(chr, dbl) %>% + select(chr, dubs = dbl) %>% write_dataset(dst_dir2, format = "feather") new_ds <- open_dataset(dst_dir2, format = "feather") + # Renaming doesn't work, but mutating does?? expect_equivalent( - collect(new_ds) %>% arrange(int), - rbind(df1[c("chr", "dbl", "int")], df2[c("chr", "dbl", "int")]) + collect(new_ds) %>% arrange(int) %>% print(), + rbind(df1[c("chr", "dbl", "int")], df2[c("chr", "dbl", "int")]) %>% rename(dubs = dbl) %>% print() ) # filter to restrict written rows @@ -1758,10 +1759,6 @@ test_that("Dataset writing: unsupported features/input validation", { expect_error(write_dataset(4), 'dataset must be a "Dataset"') ds <- open_dataset(hive_dir) - expect_error( - select(ds, integer = int) %>% write_dataset(ds), - "Renaming columns when writing a dataset is not yet supported" - ) expect_error( write_dataset(ds, partitioning = c("int", "NOTACOLUMN"), format = "ipc"), 'Invalid field name: "NOTACOLUMN"' diff --git a/r/tests/testthat/test-dplyr.R b/r/tests/testthat/test-dplyr.R index e9806c549971e..08501f75c9976 100644 --- a/r/tests/testthat/test-dplyr.R +++ b/r/tests/testthat/test-dplyr.R @@ -267,7 +267,6 @@ test_that("head", { filter(int > 5) %>% head(2) expect_r6_class(b3, "Table") - print(as.data.frame(b3)) expect_equal(as.data.frame(b3), set_names(expected, c("int", "strng"))) b4 <- batch %>% @@ -275,7 +274,7 @@ test_that("head", { filter(int > 5) %>% group_by(int) %>% head(2) - expect_s3_class(b4, "Table") + expect_s3_class(b4, "arrow_dplyr_query") expect_equal( as.data.frame(b4), expected %>% @@ -308,7 +307,7 @@ test_that("tail", { filter(int > 5) %>% group_by(int) %>% tail(2) - expect_s3_class(b4, "Table") + expect_s3_class(b4, "arrow_dplyr_query") expect_equal( as.data.frame(b4), expected %>%