From 9347731fe611c25f51c8d4831f1198c9438babd5 Mon Sep 17 00:00:00 2001
From: Neal Richardson
Date: Thu, 13 May 2021 08:47:16 -0700
Subject: [PATCH] ARROW-12731: [R] Use InMemoryDataset for Table/RecordBatch in
 dplyr code

Discussing with @bkietz on #10166, we realized that we could already evaluate filter/project on a Table/RecordBatch by wrapping it in InMemoryDataset and using the Dataset machinery, so I wanted to see how well that worked. Mostly it does, with a couple of caveats:

* You can't dictionary_encode a dataset column: `Error: Invalid: ExecuteScalarExpression cannot Execute non-scalar expression {x=dictionary_encode(x, {NON-REPRESENTABLE OPTIONS})}` (ARROW-12632). I will remove the `as.factor` method and leave a TODO to restore it after that JIRA is resolved.
* With the existing array_expressions, you could supply an additional Array (or R data convertible to an Array) when doing `mutate()`; this is not implemented for Datasets, and that's ok. For Tables/RecordBatches, the behavior in this PR is to pull the data into R, which is fine.

There are a lot of changes here, so the diff is big, but I've tried to group the main action into distinct commits. Highlights:

* https://github.com/apache/arrow/pull/10191/commits/5b501c508e8da7313dce0e361369dc62aa645a8f is the main switch to use InMemoryDataset
* https://github.com/apache/arrow/pull/10191/commits/b31fb5e594bc49628f7a4459109784caafe99cb4 deletes `array_expression`
* https://github.com/apache/arrow/pull/10191/commits/0d3193863fc578d93d9319ea2184e46e9f2f36e1 simplifies the interface for adding functions to the dplyr data_mask; definitely check this one out and see what you think of the new way. I hope it makes adding new functions much simpler
* https://github.com/apache/arrow/pull/10191/commits/2e6374f94cbcc236becc3e41797a26127cf06ab0 improves the print method for queries by showing both the expression and the expected type of the output column, per suggestion from @bkietz
* https://github.com/apache/arrow/pull/10191/commits/d12f584e67531e251a1c72a5b67e14361d31f503 just splits up dplyr.R into many files; https://github.com/apache/arrow/pull/10191/commits/34dc1e6589ca622c8b1baeba7ce03c1d2b0b4c28 deletes tests that were duplicated between test-dplyr*.R and test-dataset.R (since they now go through a common C++ interface)
* https://github.com/apache/arrow/pull/10191/commits/a0914f67319e659348396f106024d69064ea3943 + https://github.com/apache/arrow/pull/10191/commits/eee491a4e9e6735a0f304d1d71306bfd091f702b contain ARROW-12696

A quick sketch of the resulting user-facing behavior follows.
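To make that concrete, here is a minimal illustrative sketch (not part of the diff; the table and column names are invented for the example):

    library(arrow)
    library(dplyr)

    tab <- Table$create(x = 1:10, y = letters[1:10])

    # filter() and mutate() now build Expressions and defer evaluation to the
    # Dataset machinery (via InMemoryDataset) rather than computing eagerly in R
    tab %>%
      filter(x > 5) %>%
      mutate(z = x * 2) %>%
      collect()

Per the first caveat above, `as.factor()` is removed until ARROW-12632 is resolved, so it is not part of this sketch.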
Closes #10191 from nealrichardson/dplyr-in-memory

Authored-by: Neal Richardson
Signed-off-by: Neal Richardson
---
 r/DESCRIPTION                                |   11 +-
 r/NAMESPACE                                  |    5 +-
 r/R/arrow-datum.R                            |   71 +-
 r/R/arrow-package.R                          |   16 +-
 r/R/arrow-tabular.R                          |    6 +-
 r/R/arrowExports.R                           |    4 +
 r/R/dataset-scan.R                           |   14 +-
 r/R/dataset-write.R                          |    8 -
 r/R/dplyr-arrange.R                          |   93 ++
 r/R/dplyr-collect.R                          |   62 +
 r/R/dplyr-eval.R                             |   99 ++
 r/R/dplyr-filter.R                           |   84 ++
 r/R/dplyr-functions.R                        |  352 ++++++
 r/R/dplyr-group-by.R                         |   65 ++
 r/R/dplyr-mutate.R                           |  117 ++
 r/R/dplyr-select.R                           |  120 ++
 r/R/dplyr-summarize.R                        |   36 +
 r/R/dplyr.R                                  | 1005 +----------------
 r/R/expression.R                             |  194 +---
 r/man/contains_regex.Rd                      |    2 +-
 r/man/get_stringr_pattern_options.Rd         |    2 +-
 r/src/arrowExports.cpp                       |   17 +
 r/src/expression.cpp                         |    8 +
 r/tests/testthat/helper-arrow.R              |    2 +-
 r/tests/testthat/test-RecordBatch.R          |    7 +-
 r/tests/testthat/test-Table.R                |    7 +-
 r/tests/testthat/test-compute-arith.R        |    3 +-
 r/tests/testthat/test-compute-sort.R         |   17 +-
 r/tests/testthat/test-dataset.R              |  339 +-----
 r/tests/testthat/test-dplyr-arrange.R        |    2 +
 r/tests/testthat/test-dplyr-filter.R         |   57 +-
 r/tests/testthat/test-dplyr-group-by.R       |    2 +
 r/tests/testthat/test-dplyr-mutate.R         |   39 +-
 .../testthat/test-dplyr-string-functions.R   |  114 +-
 r/tests/testthat/test-dplyr.R                |   28 +-
 r/tests/testthat/test-expression.R           |   56 +-
 r/tests/testthat/test-filesystem.R           |    4 +
 37 files changed, 1360 insertions(+), 1708 deletions(-)
 create mode 100644 r/R/dplyr-arrange.R
 create mode 100644 r/R/dplyr-collect.R
 create mode 100644 r/R/dplyr-eval.R
 create mode 100644 r/R/dplyr-filter.R
 create mode 100644 r/R/dplyr-functions.R
 create mode 100644 r/R/dplyr-group-by.R
 create mode 100644 r/R/dplyr-mutate.R
 create mode 100644 r/R/dplyr-select.R
 create mode 100644 r/R/dplyr-summarize.R

diff --git a/r/DESCRIPTION b/r/DESCRIPTION
index 7f88320fb3d6a..82ca6fed617f1 100644
--- a/r/DESCRIPTION
+++ b/r/DESCRIPTION
@@ -78,9 +78,18 @@ Collate:
    'dataset-write.R'
    'deprecated.R'
    'dictionary.R'
+    'dplyr-arrange.R'
+    'dplyr-collect.R'
+    'dplyr-eval.R'
+    'dplyr-filter.R'
+    'expression.R'
+    'dplyr-functions.R'
+    'dplyr-group-by.R'
+    'dplyr-mutate.R'
+    'dplyr-select.R'
+    'dplyr-summarize.R'
    'record-batch.R'
    'table.R'
-    'expression.R'
    'dplyr.R'
    'feather.R'
    'field.R'
diff --git a/r/NAMESPACE b/r/NAMESPACE
index 9a05b87476a12..f89d2effea73c 100644
--- a/r/NAMESPACE
+++ b/r/NAMESPACE
@@ -21,7 +21,6 @@ S3method("[[<-",Schema)
 S3method("names<-",ArrowTabular)
 S3method(Ops,ArrowDatum)
 S3method(Ops,Expression)
-S3method(Ops,array_expression)
 S3method(all,ArrowDatum)
 S3method(all.equal,ArrowObject)
 S3method(any,ArrowDatum)
@@ -37,7 +36,6 @@ S3method(as.list,ArrowTabular)
 S3method(as.list,Schema)
 S3method(as.raw,Buffer)
 S3method(as.vector,ArrowDatum)
-S3method(as.vector,array_expression)
 S3method(c,Dataset)
 S3method(dim,ArrowTabular)
 S3method(dim,Dataset)
@@ -51,7 +49,6 @@ S3method(head,arrow_dplyr_query)
 S3method(is.na,ArrowDatum)
 S3method(is.na,Expression)
 S3method(is.na,Scalar)
-S3method(is.na,array_expression)
 S3method(is.nan,ArrowDatum)
 S3method(is_in,ArrowDatum)
 S3method(is_in,default)
@@ -80,7 +77,6 @@ S3method(names,StructArray)
 S3method(names,Table)
 S3method(names,arrow_dplyr_query)
 S3method(print,"arrow-enum")
-S3method(print,array_expression)
 S3method(print,arrow_dplyr_query)
 S3method(print,arrow_info)
 S3method(print,arrow_r_metadata)
@@ -295,6 +291,7 @@ importFrom(purrr,as_mapper) importFrom(purrr,keep) importFrom(purrr,map) importFrom(purrr,map2) +importFrom(purrr,map2_chr) importFrom(purrr,map_chr) importFrom(purrr,map_dfr) importFrom(purrr,map_int) diff --git a/r/R/arrow-datum.R b/r/R/arrow-datum.R index 4edcb200ea01c..f7c1d4d4ed776 100644 --- a/r/R/arrow-datum.R +++ b/r/R/arrow-datum.R @@ -46,6 +46,73 @@ as.vector.ArrowDatum <- function(x, mode) { ) } +#' @export +Ops.ArrowDatum <- function(e1, e2) { + if (.Generic == "!") { + eval_array_expression(.Generic, e1) + } else if (.Generic %in% names(.array_function_map)) { + eval_array_expression(.Generic, e1, e2) + } else { + stop(paste0("Unsupported operation on `", class(e1)[1L], "` : "), .Generic, call. = FALSE) + } +} + +# Wrapper around call_function that: +# (1) maps R function names to Arrow C++ compute ("/" --> "divide_checked") +# (2) wraps R input args as Array or Scalar +eval_array_expression <- function(FUN, + ..., + args = list(...), + options = empty_named_list()) { + if (FUN == "-" && length(args) == 1L) { + if (inherits(args[[1]], "ArrowObject")) { + return(eval_array_expression("negate_checked", args[[1]])) + } else { + return(-args[[1]]) + } + } + args <- lapply(args, .wrap_arrow, FUN) + + # In Arrow, "divide" is one function, which does integer division on + # integer inputs and floating-point division on floats + if (FUN == "/") { + # TODO: omg so many ways it's wrong to assume these types + args <- map(args, ~.$cast(float64())) + } else if (FUN == "%/%") { + # In R, integer division works like floor(float division) + out <- eval_array_expression("/", args = args, options = options) + return(out$cast(int32(), allow_float_truncate = TRUE)) + } else if (FUN == "%%") { + # {e1 - e2 * ( e1 %/% e2 )} + # ^^^ form doesn't work because Ops.Array evaluates eagerly, + # but we can build that up + quotient <- eval_array_expression("%/%", args = args) + base <- eval_array_expression("*", quotient, args[[2]]) + # this cast is to ensure that the result of this and e1 are the same + # (autocasting only applies to scalars) + base <- base$cast(args[[1]]$type) + return(eval_array_expression("-", args[[1]], base)) + } + + call_function( + .array_function_map[[FUN]] %||% FUN, + args = args, + options = options + ) +} + +.wrap_arrow <- function(arg, fun) { + if (!inherits(arg, "ArrowObject")) { + # TODO: Array$create if lengths are equal? + if (fun == "%in%") { + arg <- Array$create(arg) + } else { + arg <- Scalar$create(arg) + } + } + arg +} + #' @export na.omit.ArrowDatum <- function(object, ...){ object$Filter(!is.na(object)) @@ -66,10 +133,6 @@ filter_rows <- function(x, i, keep_na = TRUE, ...) 
{ # General purpose function for [ row subsetting with R semantics # Based on the input for `i`, calls x$Filter, x$Slice, or x$Take nrows <- x$num_rows %||% x$length() # Depends on whether Array or Table-like - if (inherits(i, "array_expression")) { - # Evaluate it - i <- eval_array_expression(i) - } if (is.logical(i)) { if (isTRUE(i)) { # Shortcut without doing any work diff --git a/r/R/arrow-package.R b/r/R/arrow-package.R index f6f01fe623a4c..9e8d629e08a5f 100644 --- a/r/R/arrow-package.R +++ b/r/R/arrow-package.R @@ -17,7 +17,7 @@ #' @importFrom stats quantile median na.omit na.exclude na.pass na.fail #' @importFrom R6 R6Class -#' @importFrom purrr as_mapper map map2 map_chr map_dfr map_int map_lgl keep +#' @importFrom purrr as_mapper map map2 map_chr map2_chr map_dfr map_int map_lgl keep #' @importFrom assertthat assert_that is.string #' @importFrom rlang list2 %||% is_false abort dots_n warn enquo quo_is_null enquos is_integerish quos eval_tidy new_data_mask syms env new_environment env_bind as_label set_names exec is_bare_character quo_get_expr quo_set_expr .data seq2 is_quosure enexpr enexprs expr #' @importFrom tidyselect vars_pull vars_rename vars_select eval_select @@ -49,8 +49,18 @@ # Create these once, at package build time if (arrow_available()) { - dplyr_functions$dataset <- build_function_list(build_dataset_expression) - dplyr_functions$array <- build_function_list(build_array_expression) + # Also include all available Arrow Compute functions, + # namespaced as arrow_fun. + # We can't do this at install time because list_compute_functions() may error + all_arrow_funs <- list_compute_functions() + arrow_funcs <- set_names( + lapply(all_arrow_funs, function(fun) { + force(fun) + function(...) build_expr(fun, ...) + }), + paste0("arrow_", all_arrow_funs) + ) + .cache$functions <- c(nse_funcs, arrow_funcs) } invisible() } diff --git a/r/R/arrow-tabular.R b/r/R/arrow-tabular.R index bba5ad5f5e611..2bd0a99534f4c 100644 --- a/r/R/arrow-tabular.R +++ b/r/R/arrow-tabular.R @@ -223,13 +223,13 @@ na.fail.ArrowTabular <- function(object, ...){ #' @export na.omit.ArrowTabular <- function(object, ...){ - not_na <- map(object$columns, ~build_array_expression("is_valid", .x)) + not_na <- map(object$columns, ~call_function("is_valid", .x)) not_na_agg <- Reduce("&", not_na) - object$Filter(eval_array_expression(not_na_agg)) + object$Filter(not_na_agg) } #' @export -na.exclude.ArrowTabular <- na.omit.ArrowTabular +na.exclude.ArrowTabular <- na.omit.ArrowTabular ToString_tabular <- function(x, ...) 
{ # Generic to work with both RecordBatch and Table diff --git a/r/R/arrowExports.R b/r/R/arrowExports.R index 0063836970e2e..c026c72899f6f 100644 --- a/r/R/arrowExports.R +++ b/r/R/arrowExports.R @@ -800,6 +800,10 @@ compute___expr__ToString <- function(x){ .Call(`_arrow_compute___expr__ToString`, x) } +compute___expr__type <- function(x, schema){ + .Call(`_arrow_compute___expr__type`, x, schema) +} + ipc___WriteFeather__Table <- function(stream, table, version, chunk_size, compression, compression_level){ invisible(.Call(`_arrow_ipc___WriteFeather__Table`, stream, table, version, chunk_size, compression, compression_level)) } diff --git a/r/R/dataset-scan.R b/r/R/dataset-scan.R index 84949bbd3972d..a73bfb3dd74cd 100644 --- a/r/R/dataset-scan.R +++ b/r/R/dataset-scan.R @@ -162,16 +162,12 @@ ScannerBuilder <- R6Class("ScannerBuilder", inherit = ArrowObject, # cols is either a character vector or a named list of Expressions if (is.character(cols)) { dataset___ScannerBuilder__ProjectNames(self, cols) + } else if (length(cols) == 0) { + # Empty projection + dataset___ScannerBuilder__ProjectNames(self, character(0)) } else { - # If we have expressions, but they all turn out to be field_refs, - # we can still call the simple method - field_names <- get_field_names(cols) - if (all(nzchar(field_names))) { - dataset___ScannerBuilder__ProjectNames(self, field_names) - } else { - # Else, we are projecting/mutating - dataset___ScannerBuilder__ProjectExprs(self, cols, names(cols)) - } + # List of Expressions + dataset___ScannerBuilder__ProjectExprs(self, cols, names(cols)) } self }, diff --git a/r/R/dataset-write.R b/r/R/dataset-write.R index 8c9a1efc8d85d..90413e9b9ed68 100644 --- a/r/R/dataset-write.R +++ b/r/R/dataset-write.R @@ -64,14 +64,6 @@ write_dataset <- function(dataset, ...) { format <- match.arg(format) if (inherits(dataset, "arrow_dplyr_query")) { - if (inherits(dataset$.data, "ArrowTabular")) { - # collect() to materialize any mutate/rename - dataset <- dplyr::collect(dataset, as_data_frame = FALSE) - } - # We can select a subset of columns but we can't rename them - if (!all(get_field_names(dataset) == names(dataset$selected_columns))) { - stop("Renaming columns when writing a dataset is not yet supported", call. = FALSE) - } # partitioning vars need to be in the `select` schema dataset <- ensure_group_vars(dataset) } else if (inherits(dataset, "grouped_df")) { diff --git a/r/R/dplyr-arrange.R b/r/R/dplyr-arrange.R new file mode 100644 index 0000000000000..59afa4fe6a0a6 --- /dev/null +++ b/r/R/dplyr-arrange.R @@ -0,0 +1,93 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. 
+
+
+# The following S3 methods are registered on load if dplyr is present
+
+arrange.arrow_dplyr_query <- function(.data, ..., .by_group = FALSE) {
+  call <- match.call()
+  exprs <- quos(...)
+  if (.by_group) {
+    # when the data is grouped and .by_group is TRUE, order the result by
+    # the grouping columns first
+    exprs <- c(quos(!!!dplyr::groups(.data)), exprs)
+  }
+  if (length(exprs) == 0) {
+    # Nothing to do
+    return(.data)
+  }
+  .data <- arrow_dplyr_query(.data)
+  # find and remove any dplyr::desc() and tidy-eval
+  # the arrange expressions inside an Arrow data_mask
+  sorts <- vector("list", length(exprs))
+  descs <- logical(0)
+  mask <- arrow_mask(.data)
+  for (i in seq_along(exprs)) {
+    x <- find_and_remove_desc(exprs[[i]])
+    exprs[[i]] <- x[["quos"]]
+    sorts[[i]] <- arrow_eval(exprs[[i]], mask)
+    if (inherits(sorts[[i]], "try-error")) {
+      msg <- paste('Expression', as_label(exprs[[i]]), 'not supported in Arrow')
+      return(abandon_ship(call, .data, msg))
+    }
+    names(sorts)[i] <- as_label(exprs[[i]])
+    descs[i] <- x[["desc"]]
+  }
+  .data$arrange_vars <- c(sorts, .data$arrange_vars)
+  .data$arrange_desc <- c(descs, .data$arrange_desc)
+  .data
+}
+arrange.Dataset <- arrange.ArrowTabular <- arrange.arrow_dplyr_query
+
+# Helper to handle desc() in arrange()
+# * Takes a quosure as input
+# * Returns a list with two elements:
+#   1. The quosure with any wrapping parentheses and desc() removed
+#   2. A logical value indicating whether desc() was found
+# * Performs some other validation
+find_and_remove_desc <- function(quosure) {
+  expr <- quo_get_expr(quosure)
+  descending <- FALSE
+  if (length(all.vars(expr)) < 1L) {
+    stop(
+      "Expression in arrange() does not contain any field names: ",
+      deparse(expr),
+      call. = FALSE
+    )
+  }
+  # Use a while loop to remove any number of nested pairs of enclosing
+  # parentheses and any number of nested desc() calls. In the case of multiple
+  # nested desc() calls, each one toggles the sort order.
+  while (identical(typeof(expr), "language") && is.call(expr)) {
+    if (identical(expr[[1]], quote(`(`))) {
+      # remove enclosing parentheses
+      expr <- expr[[2]]
+    } else if (identical(expr[[1]], quote(desc))) {
+      # remove desc() and toggle descending
+      expr <- expr[[2]]
+      descending <- !descending
+    } else {
+      break
+    }
+  }
+  return(
+    list(
+      quos = quo_set_expr(quosure, expr),
+      desc = descending
+    )
+  )
+}
\ No newline at end of file
diff --git a/r/R/dplyr-collect.R b/r/R/dplyr-collect.R
new file mode 100644
index 0000000000000..55716291dcbf7
--- /dev/null
+++ b/r/R/dplyr-collect.R
@@ -0,0 +1,62 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements. See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership. The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License. You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied. See the License for the
+# specific language governing permissions and limitations
+# under the License.
+
+
+# The following S3 methods are registered on load if dplyr is present
+
+collect.arrow_dplyr_query <- function(x, as_data_frame = TRUE, ...)
{ + x <- ensure_group_vars(x) + x <- ensure_arrange_vars(x) # this sets x$temp_columns + # Pull only the selected rows and cols into R + # See dataset.R for Dataset and Scanner(Builder) classes + tab <- Scanner$create(x)$ToTable() + # Arrange rows + if (length(x$arrange_vars) > 0) { + tab <- tab[ + tab$SortIndices(names(x$arrange_vars), x$arrange_desc), + names(x$selected_columns), # this omits x$temp_columns from the result + drop = FALSE + ] + } + if (as_data_frame) { + df <- as.data.frame(tab) + tab$invalidate() + restore_dplyr_features(df, x) + } else { + restore_dplyr_features(tab, x) + } +} +collect.ArrowTabular <- function(x, as_data_frame = TRUE, ...) { + if (as_data_frame) { + as.data.frame(x, ...) + } else { + x + } +} +collect.Dataset <- function(x, ...) dplyr::collect(arrow_dplyr_query(x), ...) + +compute.arrow_dplyr_query <- function(x, ...) dplyr::collect(x, as_data_frame = FALSE) +compute.ArrowTabular <- function(x, ...) x +compute.Dataset <- compute.arrow_dplyr_query + +pull.arrow_dplyr_query <- function(.data, var = -1) { + .data <- arrow_dplyr_query(.data) + var <- vars_pull(names(.data), !!enquo(var)) + .data$selected_columns <- set_names(.data$selected_columns[var], var) + dplyr::collect(.data)[[1]] +} +pull.Dataset <- pull.ArrowTabular <- pull.arrow_dplyr_query \ No newline at end of file diff --git a/r/R/dplyr-eval.R b/r/R/dplyr-eval.R new file mode 100644 index 0000000000000..2d19bd4cb9052 --- /dev/null +++ b/r/R/dplyr-eval.R @@ -0,0 +1,99 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. + +arrow_eval <- function (expr, mask) { + # filter(), mutate(), etc. work by evaluating the quoted `exprs` to generate Expressions + # with references to Arrays (if .data is Table/RecordBatch) or Fields (if + # .data is a Dataset). + + # This yields an Expression as long as the `exprs` are implemented in Arrow. + # Otherwise, it returns a try-error + tryCatch(eval_tidy(expr, mask), error = function(e) { + # Look for the cases where bad input was given, i.e. this would fail + # in regular dplyr anyway, and let those raise those as errors; + # else, for things not supported by Arrow return a "try-error", + # which we'll handle differently + msg <- conditionMessage(e) + patterns <- .cache$i18ized_error_pattern + if (is.null(patterns)) { + patterns <- i18ize_error_messages() + # Memoize it + .cache$i18ized_error_pattern <- patterns + } + if (grepl(patterns, msg)) { + stop(e) + } + + out <- structure(msg, class = "try-error", condition = e) + if (grepl("not supported.*Arrow", msg)) { + # One of ours. 
Mark it so that consumers can handle it differently + class(out) <- c("arrow-try-error", class(out)) + } + invisible(out) + }) +} + +handle_arrow_not_supported <- function(err, lab) { + # Look for informative message from the Arrow function version (see above) + if (inherits(err, "arrow-try-error")) { + # Include it if found + paste0('In ', lab, ', ', as.character(err)) + } else { + # Otherwise be opaque (the original error is probably not useful) + paste('Expression', lab, 'not supported in Arrow') + } +} + +i18ize_error_messages <- function() { + # Figure out what the error messages will be with this LANGUAGE + # so that we can look for them + out <- list( + obj = tryCatch(eval(parse(text = "X_____X")), error = function(e) conditionMessage(e)), + fun = tryCatch(eval(parse(text = "X_____X()")), error = function(e) conditionMessage(e)) + ) + paste(map(out, ~sub("X_____X", ".*", .)), collapse = "|") +} + +# Helper to raise a common error +arrow_not_supported <- function(msg) { + # TODO: raise a classed error? + stop(paste(msg, "not supported by Arrow"), call. = FALSE) +} + +# Create a data mask for evaluating a dplyr expression +arrow_mask <- function(.data) { + f_env <- new_environment(.cache$functions) + + # Add functions that need to error hard and clear. + # Some R functions will still try to evaluate on an Expression + # and return NA with a warning + fail <- function(...) stop("Not implemented") + for (f in c("mean", "sd")) { + f_env[[f]] <- fail + } + + # Add the column references and make the mask + out <- new_data_mask( + new_environment(.data$selected_columns, parent = f_env), + f_env + ) + # Then insert the data pronoun + # TODO: figure out what rlang::as_data_pronoun does/why we should use it + # (because if we do we get `Error: Can't modify the data pronoun` in mutate()) + out$.data <- .data$selected_columns + out +} diff --git a/r/R/dplyr-filter.R b/r/R/dplyr-filter.R new file mode 100644 index 0000000000000..3cbc34511a4ca --- /dev/null +++ b/r/R/dplyr-filter.R @@ -0,0 +1,84 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. + + +# The following S3 methods are registered on load if dplyr is present + +filter.arrow_dplyr_query <- function(.data, ..., .preserve = FALSE) { + # TODO something with the .preserve argument + filts <- quos(...) 
+  if (length(filts) == 0) {
+    # Nothing to do
+    return(.data)
+  }
+
+  .data <- arrow_dplyr_query(.data)
+  # tidy-eval the filter expressions inside an Arrow data_mask
+  filters <- lapply(filts, arrow_eval, arrow_mask(.data))
+  bad_filters <- map_lgl(filters, ~inherits(., "try-error"))
+  if (any(bad_filters)) {
+    # This is similar to abandon_ship() except that the filter eval is
+    # vectorized, and we apply filters that _did_ work before abandoning ship
+    # with the rest
+    expr_labs <- map_chr(filts[bad_filters], as_label)
+    if (query_on_dataset(.data)) {
+      # Abort. We don't want to auto-collect if this is a Dataset because that
+      # could blow up, too big.
+      stop(
+        "Filter expression not supported for Arrow Datasets: ",
+        oxford_paste(expr_labs, quote = FALSE),
+        "\nCall collect() first to pull data into R.",
+        call. = FALSE
+      )
+    } else {
+      arrow_errors <- map2_chr(
+        filters[bad_filters], expr_labs,
+        handle_arrow_not_supported
+      )
+      if (length(arrow_errors) == 1) {
+        msg <- paste0(arrow_errors, "; ")
+      } else {
+        msg <- paste0("* ", arrow_errors, "\n", collapse = "")
+      }
+      warning(
+        msg, "pulling data into R",
+        immediate. = TRUE,
+        call. = FALSE
+      )
+      # Set any valid filters first, then collect and then apply the invalid ones in R
+      .data <- set_filters(.data, filters[!bad_filters])
+      return(dplyr::filter(dplyr::collect(.data), !!!filts[bad_filters]))
+    }
+  }
+
+  set_filters(.data, filters)
+}
+filter.Dataset <- filter.ArrowTabular <- filter.arrow_dplyr_query
+
+set_filters <- function(.data, expressions) {
+  if (length(expressions)) {
+    # expressions is a list of Expressions. AND them together and set them on .data
+    new_filter <- Reduce("&", expressions)
+    if (isTRUE(.data$filtered_rows)) {
+      # TRUE is default (i.e. no filter yet), so we don't need to & with it
+      .data$filtered_rows <- new_filter
+    } else {
+      .data$filtered_rows <- .data$filtered_rows & new_filter
+    }
+  }
+  .data
+}
\ No newline at end of file
diff --git a/r/R/dplyr-functions.R b/r/R/dplyr-functions.R
new file mode 100644
index 0000000000000..bee06a7cb6aa6
--- /dev/null
+++ b/r/R/dplyr-functions.R
@@ -0,0 +1,352 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements. See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership. The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License. You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied. See the License for the
+# specific language governing permissions and limitations
+# under the License.
+
+
+#' @include expression.R
+NULL
+
+# This environment is an internal cache for things including data mask functions
+# We'll populate it at package load time.
+.cache <- NULL
+init_env <- function () {
+  .cache <<- new.env(hash = TRUE)
+}
+init_env()
+
+# nse_funcs is a list of functions that operate on (and return) Expressions
+# These will be the basis for a data_mask inside dplyr methods
+# and will be added to .cache at package load time
+
+# Start with mappings from R function name spellings
+nse_funcs <- lapply(set_names(names(.array_function_map)), function(operator) {
+  force(operator)
+  function(...)
build_expr(operator, ...) +}) + +# Now add functions to that list where the mapping from R to Arrow isn't 1:1 +# Each of these functions should have the same signature as the R function +# they're replacing. +# +# When to use `build_expr()` vs. `Expression$create()`? +# +# Use `build_expr()` if you need to +# (1) map R function names to Arrow C++ functions +# (2) wrap R inputs (vectors) as Array/Scalar +# +# `Expression$create()` is lower level. Most of the functions below use it +# because they manage the preparation of the user-provided inputs +# and don't need to wrap scalars + +nse_funcs$cast <- function(x, target_type, safe = TRUE, ...) { + opts <- cast_options(safe, ...) + opts$to_type <- as_type(target_type) + Expression$create("cast", x, options = opts) +} + +nse_funcs$dictionary_encode <- function(x, + null_encoding_behavior = c("mask", "encode")) { + behavior <- toupper(match.arg(null_encoding_behavior)) + null_encoding_behavior <- NullEncodingBehavior[[behavior]] + Expression$create( + "dictionary_encode", + x, + options = list(null_encoding_behavior = null_encoding_behavior) + ) +} + +nse_funcs$between <- function(x, left, right) { + x >= left & x <= right +} + +# as.* type casting functions +# as.factor() is mapped in expression.R +nse_funcs$as.character <- function(x) { + Expression$create("cast", x, options = cast_options(to_type = string())) +} +nse_funcs$as.double <- function(x) { + Expression$create("cast", x, options = cast_options(to_type = float64())) +} +nse_funcs$as.integer <- function(x) { + Expression$create( + "cast", + x, + options = cast_options( + to_type = int32(), + allow_float_truncate = TRUE, + allow_decimal_truncate = TRUE + ) + ) +} +nse_funcs$as.integer64 <- function(x) { + Expression$create( + "cast", + x, + options = cast_options( + to_type = int64(), + allow_float_truncate = TRUE, + allow_decimal_truncate = TRUE + ) + ) +} +nse_funcs$as.logical <- function(x) { + Expression$create("cast", x, options = cast_options(to_type = boolean())) +} +nse_funcs$as.numeric <- function(x) { + Expression$create("cast", x, options = cast_options(to_type = float64())) +} + +# String functions +nse_funcs$nchar <- function(x, type = "chars", allowNA = FALSE, keepNA = NA) { + if (allowNA) { + arrow_not_supported("allowNA = TRUE") + } + if (is.na(keepNA)) { + keepNA <- !identical(type, "width") + } + if (!keepNA) { + # TODO: I think there is a fill_null kernel we could use, set null to 2 + arrow_not_supported("keepNA = TRUE") + } + if (identical(type, "bytes")) { + Expression$create("binary_length", x) + } else { + Expression$create("utf8_length", x) + } +} + +nse_funcs$str_trim <- function(string, side = c("both", "left", "right")) { + side <- match.arg(side) + trim_fun <- switch(side, + left = "utf8_ltrim_whitespace", + right = "utf8_rtrim_whitespace", + both = "utf8_trim_whitespace" + ) + Expression$create(trim_fun, string) +} + +nse_funcs$grepl <- function(pattern, x, ignore.case = FALSE, fixed = FALSE) { + arrow_fun <- ifelse(fixed && !ignore.case, "match_substring", "match_substring_regex") + Expression$create( + arrow_fun, + x, + options = list(pattern = format_string_pattern(pattern, ignore.case, fixed)) + ) +} + +nse_funcs$str_detect <- function(string, pattern, negate = FALSE) { + opts <- get_stringr_pattern_options(enexpr(pattern)) + out <- nse_funcs$grepl( + pattern = opts$pattern, + x = string, + ignore.case = opts$ignore_case, + fixed = opts$fixed + ) + if (negate) { + out <- !out + } + out +} + +# Encapsulate some common logic for 
sub/gsub/str_replace/str_replace_all
+arrow_r_string_replace_function <- function(max_replacements) {
+  function(pattern, replacement, x, ignore.case = FALSE, fixed = FALSE) {
+    Expression$create(
+      ifelse(fixed && !ignore.case, "replace_substring", "replace_substring_regex"),
+      x,
+      options = list(
+        pattern = format_string_pattern(pattern, ignore.case, fixed),
+        replacement = format_string_replacement(replacement, ignore.case, fixed),
+        max_replacements = max_replacements
+      )
+    )
+  }
+}
+
+arrow_stringr_string_replace_function <- function(max_replacements) {
+  function(string, pattern, replacement) {
+    opts <- get_stringr_pattern_options(enexpr(pattern))
+    arrow_r_string_replace_function(max_replacements)(
+      pattern = opts$pattern,
+      replacement = replacement,
+      x = string,
+      ignore.case = opts$ignore_case,
+      fixed = opts$fixed
+    )
+  }
+}
+
+nse_funcs$sub <- arrow_r_string_replace_function(1L)
+nse_funcs$gsub <- arrow_r_string_replace_function(-1L)
+nse_funcs$str_replace <- arrow_stringr_string_replace_function(1L)
+nse_funcs$str_replace_all <- arrow_stringr_string_replace_function(-1L)
+
+nse_funcs$strsplit <- function(x,
+                               split,
+                               fixed = FALSE,
+                               perl = FALSE,
+                               useBytes = FALSE) {
+  assert_that(is.string(split))
+
+  # The Arrow C++ library does not support splitting a string by a regular
+  # expression pattern (ARROW-12608) but the default behavior of
+  # base::strsplit() is to interpret the split pattern as a regex
+  # (fixed = FALSE). R users commonly pass non-regex split patterns to
+  # strsplit() without bothering to set fixed = TRUE. It would be annoying if
+  # that didn't work here. So: if fixed = FALSE, let's check the split pattern
+  # to see if it is a regex (if it contains any regex metacharacters). If not,
+  # then allow it to proceed.
+  if (!fixed && contains_regex(split)) {
+    arrow_not_supported("Regular expression matching in strsplit()")
+  }
+  # warn when the user specifies both fixed = TRUE and perl = TRUE, for
+  # consistency with the behavior of base::strsplit()
+  if (fixed && perl) {
+    warning("Argument 'perl = TRUE' will be ignored", call. = FALSE)
+  }
+  # since split is not a regex, proceed without any warnings or errors
+  # regardless of the value of perl, for consistency with the behavior of
+  # base::strsplit()
+  Expression$create(
+    "split_pattern",
+    x,
+    options = list(pattern = split, reverse = FALSE, max_splits = -1L)
+  )
+}
+
+nse_funcs$str_split <- function(string, pattern, n = Inf, simplify = FALSE) {
+  opts <- get_stringr_pattern_options(enexpr(pattern))
+  if (!opts$fixed && contains_regex(opts$pattern)) {
+    arrow_not_supported("Regular expression matching in str_split()")
+  }
+  if (opts$ignore_case) {
+    arrow_not_supported("Case-insensitive string splitting")
+  }
+  if (n == 0) {
+    arrow_not_supported("Splitting strings into zero parts")
+  }
+  if (identical(n, Inf)) {
+    n <- 0L
+  }
+  if (simplify) {
+    warning("Argument 'simplify = TRUE' will be ignored", call. = FALSE)
+  }
+  # The max_splits option in the Arrow C++ library controls the maximum number
+  # of places at which the string is split, whereas the argument n to
+  # str_split() controls the maximum number of pieces to return. So we must
+  # subtract 1 from n to get max_splits.
+ Expression$create( + "split_pattern", + string, + options = list( + pattern = + opts$pattern, + reverse = FALSE, + max_splits = n - 1L + ) + ) +} + +# String function helpers + +# format `pattern` as needed for case insensitivity and literal matching by RE2 +format_string_pattern <- function(pattern, ignore.case, fixed) { + # Arrow lacks native support for case-insensitive literal string matching and + # replacement, so we use the regular expression engine (RE2) to do this. + # https://github.com/google/re2/wiki/Syntax + if (ignore.case) { + if (fixed) { + # Everything between "\Q" and "\E" is treated as literal text. + # If the search text contains any literal "\E" strings, make them + # lowercase so they won't signal the end of the literal text: + pattern <- gsub("\\E", "\\e", pattern, fixed = TRUE) + pattern <- paste0("\\Q", pattern, "\\E") + } + # Prepend "(?i)" for case-insensitive matching + pattern <- paste0("(?i)", pattern) + } + pattern +} + +# format `replacement` as needed for literal replacement by RE2 +format_string_replacement <- function(replacement, ignore.case, fixed) { + # Arrow lacks native support for case-insensitive literal string + # replacement, so we use the regular expression engine (RE2) to do this. + # https://github.com/google/re2/wiki/Syntax + if (ignore.case && fixed) { + # Escape single backslashes in the regex replacement text so they are + # interpreted as literal backslashes: + replacement <- gsub("\\", "\\\\", replacement, fixed = TRUE) + } + replacement +} + +#' Get `stringr` pattern options +#' +#' This function assigns definitions for the `stringr` pattern modifier +#' functions (`fixed()`, `regex()`, etc.) inside itself, and uses them to +#' evaluate the quoted expression `pattern`, returning a list that is used +#' to control pattern matching behavior in internal `arrow` functions. +#' +#' @param pattern Unevaluated expression containing a call to a `stringr` +#' pattern modifier function +#' +#' @return List containing elements `pattern`, `fixed`, and `ignore_case` +#' @keywords internal +get_stringr_pattern_options <- function(pattern) { + fixed <- function(pattern, ignore_case = FALSE, ...) { + check_dots(...) + list(pattern = pattern, fixed = TRUE, ignore_case = ignore_case) + } + regex <- function(pattern, ignore_case = FALSE, ...) { + check_dots(...) + list(pattern = pattern, fixed = FALSE, ignore_case = ignore_case) + } + coll <- function(...) { + arrow_not_supported("Pattern modifier `coll()`") + } + boundary <- function(...) { + arrow_not_supported("Pattern modifier `boundary()`") + } + check_dots <- function(...) { + dots <- list(...) + if (length(dots)) { + warning( + "Ignoring pattern modifier ", + ngettext(length(dots), "argument ", "arguments "), + "not supported in Arrow: ", + oxford_paste(names(dots)), + call. = FALSE + ) + } + } + ensure_opts <- function(opts) { + if (is.character(opts)) { + opts <- list(pattern = opts, fixed = FALSE, ignore_case = FALSE) + } + opts + } + ensure_opts(eval(pattern)) +} + +#' Does this string contain regex metacharacters? +#' +#' @param string String to be tested +#' @keywords internal +#' @return Logical: does `string` contain regex metacharacters? +contains_regex <- function(string) { + grepl("[.\\|()[{^$*+?]", string) +} diff --git a/r/R/dplyr-group-by.R b/r/R/dplyr-group-by.R new file mode 100644 index 0000000000000..d2cf79253a51b --- /dev/null +++ b/r/R/dplyr-group-by.R @@ -0,0 +1,65 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. 
See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. + + +# The following S3 methods are registered on load if dplyr is present + +group_by.arrow_dplyr_query <- function(.data, + ..., + .add = FALSE, + add = .add, + .drop = dplyr::group_by_drop_default(.data)) { + .data <- arrow_dplyr_query(.data) + # ... can contain expressions (i.e. can add (or rename?) columns) + # Check for those (they show up as named expressions) + new_groups <- enquos(...) + new_groups <- new_groups[nzchar(names(new_groups))] + if (length(new_groups)) { + # Add them to the data + .data <- dplyr::mutate(.data, !!!new_groups) + } + if (".add" %in% names(formals(dplyr::group_by))) { + # dplyr >= 1.0 + gv <- dplyr::group_by_prepare(.data, ..., .add = .add)$group_names + } else { + gv <- dplyr::group_by_prepare(.data, ..., add = add)$group_names + } + .data$group_by_vars <- gv + .data$drop_empty_groups <- ifelse(length(gv), .drop, dplyr::group_by_drop_default(.data)) + .data +} +group_by.Dataset <- group_by.ArrowTabular <- group_by.arrow_dplyr_query + +groups.arrow_dplyr_query <- function(x) syms(dplyr::group_vars(x)) +groups.Dataset <- groups.ArrowTabular <- function(x) NULL + +group_vars.arrow_dplyr_query <- function(x) x$group_by_vars +group_vars.Dataset <- group_vars.ArrowTabular <- function(x) NULL + +# the logical literal in the two functions below controls the default value of +# the .drop argument to group_by() +group_by_drop_default.arrow_dplyr_query <- + function(.tbl) .tbl$drop_empty_groups %||% TRUE +group_by_drop_default.Dataset <- group_by_drop_default.ArrowTabular <- + function(.tbl) TRUE + +ungroup.arrow_dplyr_query <- function(x, ...) { + x$group_by_vars <- character() + x$drop_empty_groups <- NULL + x +} +ungroup.Dataset <- ungroup.ArrowTabular <- force \ No newline at end of file diff --git a/r/R/dplyr-mutate.R b/r/R/dplyr-mutate.R new file mode 100644 index 0000000000000..8513a45f6e9ac --- /dev/null +++ b/r/R/dplyr-mutate.R @@ -0,0 +1,117 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. 
+ + +# The following S3 methods are registered on load if dplyr is present + +mutate.arrow_dplyr_query <- function(.data, + ..., + .keep = c("all", "used", "unused", "none"), + .before = NULL, + .after = NULL) { + call <- match.call() + exprs <- quos(...) + + .keep <- match.arg(.keep) + .before <- enquo(.before) + .after <- enquo(.after) + + if (.keep %in% c("all", "unused") && length(exprs) == 0) { + # Nothing to do + return(.data) + } + + .data <- arrow_dplyr_query(.data) + + # Restrict the cases we support for now + if (length(dplyr::group_vars(.data)) > 0) { + # mutate() on a grouped dataset does calculations within groups + # This doesn't matter on scalar ops (arithmetic etc.) but it does + # for things with aggregations (e.g. subtracting the mean) + return(abandon_ship(call, .data, 'mutate() on grouped data not supported in Arrow')) + } + + # Check for unnamed expressions and fix if any + unnamed <- !nzchar(names(exprs)) + # Deparse and take the first element in case they're long expressions + names(exprs)[unnamed] <- map_chr(exprs[unnamed], as_label) + + mask <- arrow_mask(.data) + results <- list() + for (i in seq_along(exprs)) { + # Iterate over the indices and not the names because names may be repeated + # (which overwrites the previous name) + new_var <- names(exprs)[i] + results[[new_var]] <- arrow_eval(exprs[[i]], mask) + if (inherits(results[[new_var]], "try-error")) { + msg <- handle_arrow_not_supported( + results[[new_var]], + as_label(exprs[[i]]) + ) + return(abandon_ship(call, .data, msg)) + } else if (!inherits(results[[new_var]], "Expression") && + !is.null(results[[new_var]])) { + # We need some wrapping to handle literal values + if (length(results[[new_var]]) != 1) { + msg <- paste0('In ', new_var, " = ", as_label(exprs[[i]]), ", only values of size one are recycled") + return(abandon_ship(call, .data, msg)) + } + results[[new_var]] <- Expression$scalar(results[[new_var]]) + } + # Put it in the data mask too + mask[[new_var]] <- mask$.data[[new_var]] <- results[[new_var]] + } + + old_vars <- names(.data$selected_columns) + # Note that this is names(exprs) not names(results): + # if results$new_var is NULL, that means we are supposed to remove it + new_vars <- names(exprs) + + # Assign the new columns into the .data$selected_columns + for (new_var in new_vars) { + .data$selected_columns[[new_var]] <- results[[new_var]] + } + + # Deduplicate new_vars and remove NULL columns from new_vars + new_vars <- intersect(new_vars, names(.data$selected_columns)) + + # Respect .before and .after + if (!quo_is_null(.before) || !quo_is_null(.after)) { + new <- setdiff(new_vars, old_vars) + .data <- dplyr::relocate(.data, !!new, .before = !!.before, .after = !!.after) + } + + # Respect .keep + if (.keep == "none") { + .data$selected_columns <- .data$selected_columns[new_vars] + } else if (.keep != "all") { + # "used" or "unused" + used_vars <- unlist(lapply(exprs, all.vars), use.names = FALSE) + if (.keep == "used") { + .data$selected_columns[setdiff(old_vars, used_vars)] <- NULL + } else { + # "unused" + .data$selected_columns[intersect(old_vars, used_vars)] <- NULL + } + } + # Even if "none", we still keep group vars + ensure_group_vars(.data) +} +mutate.Dataset <- mutate.ArrowTabular <- mutate.arrow_dplyr_query + +transmute.arrow_dplyr_query <- function(.data, ...) 
dplyr::mutate(.data, ..., .keep = "none") +transmute.Dataset <- transmute.ArrowTabular <- transmute.arrow_dplyr_query \ No newline at end of file diff --git a/r/R/dplyr-select.R b/r/R/dplyr-select.R new file mode 100644 index 0000000000000..3730fe63fec12 --- /dev/null +++ b/r/R/dplyr-select.R @@ -0,0 +1,120 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. + + +# The following S3 methods are registered on load if dplyr is present + +tbl_vars.arrow_dplyr_query <- function(x) names(x$selected_columns) + +select.arrow_dplyr_query <- function(.data, ...) { + check_select_helpers(enexprs(...)) + column_select(arrow_dplyr_query(.data), !!!enquos(...)) +} +select.Dataset <- select.ArrowTabular <- select.arrow_dplyr_query + +rename.arrow_dplyr_query <- function(.data, ...) { + check_select_helpers(enexprs(...)) + column_select(arrow_dplyr_query(.data), !!!enquos(...), .FUN = vars_rename) +} +rename.Dataset <- rename.ArrowTabular <- rename.arrow_dplyr_query + +column_select <- function(.data, ..., .FUN = vars_select) { + # .FUN is either tidyselect::vars_select or tidyselect::vars_rename + # It operates on the names() of selected_columns, i.e. 
the column names + # factoring in any renaming that may already have happened + out <- .FUN(names(.data), !!!enquos(...)) + # Make sure that the resulting selected columns map back to the original data, + # as in when there are multiple renaming steps + .data$selected_columns <- set_names(.data$selected_columns[out], names(out)) + + # If we've renamed columns, we need to project that renaming into other + # query parameters we've collected + renamed <- out[names(out) != out] + if (length(renamed)) { + # Massage group_by + gbv <- .data$group_by_vars + renamed_groups <- gbv %in% renamed + gbv[renamed_groups] <- names(renamed)[match(gbv[renamed_groups], renamed)] + .data$group_by_vars <- gbv + # No need to massage filters because those contain references to Arrow objects + } + .data +} + +relocate.arrow_dplyr_query <- function(.data, ..., .before = NULL, .after = NULL) { + # The code in this function is adapted from the code in dplyr::relocate.data.frame + # at https://github.com/tidyverse/dplyr/blob/master/R/relocate.R + # TODO: revisit this after https://github.com/tidyverse/dplyr/issues/5829 + check_select_helpers(c(enexprs(...), enexpr(.before), enexpr(.after))) + + .data <- arrow_dplyr_query(.data) + + to_move <- eval_select(expr(c(...)), .data$selected_columns) + + .before <- enquo(.before) + .after <- enquo(.after) + has_before <- !quo_is_null(.before) + has_after <- !quo_is_null(.after) + + if (has_before && has_after) { + abort("Must supply only one of `.before` and `.after`.") + } else if (has_before) { + where <- min(unname(eval_select(.before, .data$selected_columns))) + if (!where %in% to_move) { + to_move <- c(to_move, where) + } + } else if (has_after) { + where <- max(unname(eval_select(.after, .data$selected_columns))) + if (!where %in% to_move) { + to_move <- c(where, to_move) + } + } else { + where <- 1L + if (!where %in% to_move) { + to_move <- c(to_move, where) + } + } + + lhs <- setdiff(seq2(1, where - 1), to_move) + rhs <- setdiff(seq2(where + 1, length(.data$selected_columns)), to_move) + + pos <- vec_unique(c(lhs, to_move, rhs)) + new_names <- names(pos) + .data$selected_columns <- .data$selected_columns[pos] + + if (!is.null(new_names)) { + names(.data$selected_columns)[new_names != ""] <- new_names[new_names != ""] + } + .data +} +relocate.Dataset <- relocate.ArrowTabular <- relocate.arrow_dplyr_query + +check_select_helpers <- function(exprs) { + # Throw an error if unsupported tidyselect selection helpers in `exprs` + exprs <- lapply(exprs, function(x) if (is_quosure(x)) quo_get_expr(x) else x) + unsup_select_helpers <- "where" + funs_in_exprs <- unlist(lapply(exprs, all_funs)) + unsup_funs <- funs_in_exprs[funs_in_exprs %in% unsup_select_helpers] + if (length(unsup_funs)) { + stop( + "Unsupported selection ", + ngettext(length(unsup_funs), "helper: ", "helpers: "), + oxford_paste(paste0(unsup_funs, "()"), quote = FALSE), + call. = FALSE + ) + } +} \ No newline at end of file diff --git a/r/R/dplyr-summarize.R b/r/R/dplyr-summarize.R new file mode 100644 index 0000000000000..ecb459c982ce3 --- /dev/null +++ b/r/R/dplyr-summarize.R @@ -0,0 +1,36 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. 
You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. + + +# The following S3 methods are registered on load if dplyr is present + +summarise.arrow_dplyr_query <- function(.data, ...) { + call <- match.call() + .data <- arrow_dplyr_query(.data) + if (query_on_dataset(.data)) { + not_implemented_for_dataset("summarize()") + } + exprs <- quos(...) + # Only retain the columns we need to do our aggregations + vars_to_keep <- unique(c( + unlist(lapply(exprs, all.vars)), # vars referenced in summarise + dplyr::group_vars(.data) # vars needed for grouping + )) + .data <- dplyr::select(.data, vars_to_keep) + dplyr::summarise(dplyr::collect(.data), ...) +} +summarise.Dataset <- summarise.ArrowTabular <- summarise.arrow_dplyr_query \ No newline at end of file diff --git a/r/R/dplyr.R b/r/R/dplyr.R index 264c4929f729d..56be8cff1dbde 100644 --- a/r/R/dplyr.R +++ b/r/R/dplyr.R @@ -30,14 +30,19 @@ arrow_dplyr_query <- function(.data) { if (inherits(.data, "arrow_dplyr_query")) { return(.data) } + structure( list( - .data = .data$clone(), + .data = if (inherits(.data, "Dataset")) { + .data$clone() + } else { + InMemoryDataset$create(.data) + }, # selected_columns is a named list: # * contents are references/expressions pointing to the data # * names are the names they should be in the end (i.e. this # records any renaming) - selected_columns = make_field_refs(names(.data), dataset = inherits(.data, "Dataset")), + selected_columns = make_field_refs(names(.data)), # filtered_rows will be an Expression filtered_rows = TRUE, # group_by_vars is a character vector of columns (as renamed) @@ -58,40 +63,39 @@ arrow_dplyr_query <- function(.data) { ) } +make_field_refs <- function(field_names) { + set_names(lapply(field_names, Expression$field_ref), field_names) +} + #' @export print.arrow_dplyr_query <- function(x, ...) 
{ schm <- x$.data$schema - cols <- get_field_names(x) - # If cols are expressions, they won't be in the schema and will be "" in cols - fields <- map_chr(cols, function(name) { + types <- map_chr(x$selected_columns, function(expr) { + name <- expr$field_name if (nzchar(name)) { - schm$GetFieldByName(name)$ToString() + # Just a field_ref, so look up in the schema + schm$GetFieldByName(name)$type$ToString() } else { - "expr" + # Expression, so get its type and append the expression + paste0( + expr$type(schm)$ToString(), + " (", expr$ToString(), ")" + ) } }) - # Strip off the field names as they are in the dataset and add the renamed ones - fields <- paste(names(cols), sub("^.*?: ", "", fields), sep = ": ", collapse = "\n") + fields <- paste(names(types), types, sep = ": ", collapse = "\n") cat(class(x$.data)[1], " (query)\n", sep = "") cat(fields, "\n", sep = "") cat("\n") if (!isTRUE(x$filtered_rows)) { - if (query_on_dataset(x)) { - filter_string <- x$filtered_rows$ToString() - } else { - filter_string <- .format_array_expression(x$filtered_rows) - } + filter_string <- x$filtered_rows$ToString() cat("* Filter: ", filter_string, "\n", sep = "") } if (length(x$group_by_vars)) { cat("* Grouped by ", paste(x$group_by_vars, collapse = ", "), "\n", sep = "") } if (length(x$arrange_vars)) { - if (query_on_dataset(x)) { - arrange_strings <- map_chr(x$arrange_vars, function(x) x$ToString()) - } else { - arrange_strings <- map_chr(x$arrange_vars, .format_array_expression) - } + arrange_strings <- map_chr(x$arrange_vars, function(x) x$ToString()) cat( "* Sorted by ", paste( @@ -109,33 +113,6 @@ print.arrow_dplyr_query <- function(x, ...) { invisible(x) } -get_field_names <- function(selected_cols) { - if (inherits(selected_cols, "arrow_dplyr_query")) { - selected_cols <- selected_cols$selected_columns - } - map_chr(selected_cols, function(x) { - if (inherits(x, "Expression")) { - out <- x$field_name - } else if (inherits(x, "array_expression")) { - out <- x$args$field_name - } else { - out <- NULL - } - # If x isn't some kind of field reference, out is NULL, - # but we always need to return a string - out %||% "" - }) -} - -make_field_refs <- function(field_names, dataset = TRUE) { - if (dataset) { - out <- lapply(field_names, Expression$field_ref) - } else { - out <- lapply(field_names, function(x) array_expression("array_ref", field_name = x)) - } - set_names(out, field_names) -} - # These are the names reflecting all select/rename, not what is in Arrow #' @export names.arrow_dplyr_query <- function(x) names(x$selected_columns) @@ -146,12 +123,8 @@ dim.arrow_dplyr_query <- function(x) { if (isTRUE(x$filtered)) { rows <- x$.data$num_rows - } else if (query_on_dataset(x)) { - scanner <- Scanner$create(x) - rows <- scanner$CountRows() } else { - # Evaluate the filter expression to a BooleanArray and count - rows <- as.integer(sum(eval_array_expression(x$filtered_rows, x$.data), na.rm = TRUE)) + rows <- Scanner$create(x)$CountRows() } c(rows, cols) } @@ -163,631 +136,19 @@ as.data.frame.arrow_dplyr_query <- function(x, row.names = NULL, optional = FALS #' @export head.arrow_dplyr_query <- function(x, n = 6L, ...) { - if (query_on_dataset(x)) { - head.Dataset(x, n, ...) - } else { - out <- collect.arrow_dplyr_query(x, as_data_frame = FALSE) - if (inherits(out, "arrow_dplyr_query")) { - out$.data <- head(out$.data, n) - } else { - out <- head(out, n) - } - out - } + out <- head.Dataset(x, n, ...) + restore_dplyr_features(out, x) } #' @export tail.arrow_dplyr_query <- function(x, n = 6L, ...) 
{ - if (query_on_dataset(x)) { - tail.Dataset(x, n, ...) - } else { - out <- collect.arrow_dplyr_query(x, as_data_frame = FALSE) - if (inherits(out, "arrow_dplyr_query")) { - out$.data <- tail(out$.data, n) - } else { - out <- tail(out, n) - } - out - } + out <- tail.Dataset(x, n, ...) + restore_dplyr_features(out, x) } #' @export -`[.arrow_dplyr_query` <- function(x, i, j, ..., drop = FALSE) { - if (query_on_dataset(x)) { - `[.Dataset`(x, i, j, ..., drop = FALSE) - } else { - stop( - "[ method not implemented for queries. Call 'collect(x, as_data_frame = FALSE)' first", - call. = FALSE - ) - } -} - -# The following S3 methods are registered on load if dplyr is present -tbl_vars.arrow_dplyr_query <- function(x) names(x$selected_columns) - -select.arrow_dplyr_query <- function(.data, ...) { - check_select_helpers(enexprs(...)) - column_select(arrow_dplyr_query(.data), !!!enquos(...)) -} -select.Dataset <- select.ArrowTabular <- select.arrow_dplyr_query - -rename.arrow_dplyr_query <- function(.data, ...) { - check_select_helpers(enexprs(...)) - column_select(arrow_dplyr_query(.data), !!!enquos(...), .FUN = vars_rename) -} -rename.Dataset <- rename.ArrowTabular <- rename.arrow_dplyr_query - -column_select <- function(.data, ..., .FUN = vars_select) { - # .FUN is either tidyselect::vars_select or tidyselect::vars_rename - # It operates on the names() of selected_columns, i.e. the column names - # factoring in any renaming that may already have happened - out <- .FUN(names(.data), !!!enquos(...)) - # Make sure that the resulting selected columns map back to the original data, - # as in when there are multiple renaming steps - .data$selected_columns <- set_names(.data$selected_columns[out], names(out)) - - # If we've renamed columns, we need to project that renaming into other - # query parameters we've collected - renamed <- out[names(out) != out] - if (length(renamed)) { - # Massage group_by - gbv <- .data$group_by_vars - renamed_groups <- gbv %in% renamed - gbv[renamed_groups] <- names(renamed)[match(gbv[renamed_groups], renamed)] - .data$group_by_vars <- gbv - # No need to massage filters because those contain references to Arrow objects - } - .data -} - -relocate.arrow_dplyr_query <- function(.data, ..., .before = NULL, .after = NULL) { - # The code in this function is adapted from the code in dplyr::relocate.data.frame - # at https://github.com/tidyverse/dplyr/blob/master/R/relocate.R - # TODO: revisit this after https://github.com/tidyverse/dplyr/issues/5829 - check_select_helpers(c(enexprs(...), enexpr(.before), enexpr(.after))) - - .data <- arrow_dplyr_query(.data) - - to_move <- eval_select(expr(c(...)), .data$selected_columns) - - .before <- enquo(.before) - .after <- enquo(.after) - has_before <- !quo_is_null(.before) - has_after <- !quo_is_null(.after) - - if (has_before && has_after) { - abort("Must supply only one of `.before` and `.after`.") - } else if (has_before) { - where <- min(unname(eval_select(.before, .data$selected_columns))) - if (!where %in% to_move) { - to_move <- c(to_move, where) - } - } else if (has_after) { - where <- max(unname(eval_select(.after, .data$selected_columns))) - if (!where %in% to_move) { - to_move <- c(where, to_move) - } - } else { - where <- 1L - if (!where %in% to_move) { - to_move <- c(to_move, where) - } - } - - lhs <- setdiff(seq2(1, where - 1), to_move) - rhs <- setdiff(seq2(where + 1, length(.data$selected_columns)), to_move) - - pos <- vec_unique(c(lhs, to_move, rhs)) - new_names <- names(pos) - .data$selected_columns <- 
.data$selected_columns[pos] - - if (!is.null(new_names)) { - names(.data$selected_columns)[new_names != ""] <- new_names[new_names != ""] - } - .data -} -relocate.Dataset <- relocate.ArrowTabular <- relocate.arrow_dplyr_query - -check_select_helpers <- function(exprs) { - # Throw an error if unsupported tidyselect selection helpers in `exprs` - exprs <- lapply(exprs, function(x) if (is_quosure(x)) quo_get_expr(x) else x) - unsup_select_helpers <- "where" - funs_in_exprs <- unlist(lapply(exprs, all_funs)) - unsup_funs <- funs_in_exprs[funs_in_exprs %in% unsup_select_helpers] - if (length(unsup_funs)) { - stop( - "Unsupported selection ", - ngettext(length(unsup_funs), "helper: ", "helpers: "), - oxford_paste(paste0(unsup_funs, "()"), quote = FALSE), - call. = FALSE - ) - } -} - -filter.arrow_dplyr_query <- function(.data, ..., .preserve = FALSE) { - # TODO something with the .preserve argument - filts <- quos(...) - if (length(filts) == 0) { - # Nothing to do - return(.data) - } - - .data <- arrow_dplyr_query(.data) - # tidy-eval the filter expressions inside an Arrow data_mask - filters <- lapply(filts, arrow_eval, arrow_mask(.data)) - bad_filters <- map_lgl(filters, ~inherits(., "try-error")) - if (any(bad_filters)) { - bads <- oxford_paste(map_chr(filts, as_label)[bad_filters], quote = FALSE) - if (query_on_dataset(.data)) { - # Abort. We don't want to auto-collect if this is a Dataset because that - # could blow up, too big. - stop( - "Filter expression not supported for Arrow Datasets: ", bads, - "\nCall collect() first to pull data into R.", - call. = FALSE - ) - } else { - # TODO: only show this in some debug mode? - warning( - "Filter expression not implemented in Arrow: ", bads, "; pulling data into R", - immediate. = TRUE, - call. = FALSE - ) - # Set any valid filters first, then collect and then apply the invalid ones in R - .data <- set_filters(.data, filters[!bad_filters]) - return(dplyr::filter(dplyr::collect(.data), !!!filts[bad_filters])) - } - } - - set_filters(.data, filters) -} -filter.Dataset <- filter.ArrowTabular <- filter.arrow_dplyr_query - -arrow_eval <- function (expr, mask) { - # filter(), mutate(), etc. work by evaluating the quoted `exprs` to generate Expressions - # with references to Arrays (if .data is Table/RecordBatch) or Fields (if - # .data is a Dataset). - - # This yields an Expression as long as the `exprs` are implemented in Arrow. - # Otherwise, it returns a try-error - tryCatch(eval_tidy(expr, mask), error = function(e) { - # Look for the cases where bad input was given, i.e. 
this would fail - # in regular dplyr anyway, and let those raise those as errors; - # else, for things not supported by Arrow return a "try-error", - # which we'll handle differently - msg <- conditionMessage(e) - patterns <- dplyr_functions$i18ized_error_pattern - if (is.null(patterns)) { - patterns <- i18ize_error_messages() - # Memoize it - dplyr_functions$i18ized_error_pattern <- patterns - } - if (grepl(patterns, msg)) { - stop(e) - } - invisible(structure(msg, class = "try-error", condition = e)) - }) -} - -i18ize_error_messages <- function() { - # Figure out what the error messages will be with this LANGUAGE - # so that we can look for them - out <- list( - obj = tryCatch(eval(parse(text = "X_____X")), error = function(e) conditionMessage(e)), - fun = tryCatch(eval(parse(text = "X_____X()")), error = function(e) conditionMessage(e)) - ) - paste(map(out, ~sub("X_____X", ".*", .)), collapse = "|") -} - -# Helper to assemble the functions that go in the NSE data mask -# The only difference between the Dataset and the Table/RecordBatch versions -# is that they use a different wrapping function (FUN) to hold the unevaluated -# expression. -build_function_list <- function(FUN) { - wrapper <- function(operator) { - force(operator) - function(...) FUN(operator, ...) - } - all_arrow_funs <- list_compute_functions() - - c( - # Include mappings from R function name spellings - lapply(set_names(names(.array_function_map)), wrapper), - # Plus some special handling where it's not 1:1 - cast = function(x, target_type, safe = TRUE, ...) { - opts <- cast_options(safe, ...) - opts$to_type <- as_type(target_type) - FUN("cast", x, options = opts) - }, - dictionary_encode = function(x, null_encoding_behavior = c("mask", "encode")) { - null_encoding_behavior <- - NullEncodingBehavior[[toupper(match.arg(null_encoding_behavior))]] - FUN( - "dictionary_encode", - x, - options = list(null_encoding_behavior = null_encoding_behavior) - ) - }, - # as.factor() is mapped in expression.R - as.character = function(x) { - FUN("cast", x, options = cast_options(to_type = string())) - }, - as.double = function(x) { - FUN("cast", x, options = cast_options(to_type = float64())) - }, - as.integer = function(x) { - FUN( - "cast", - x, - options = cast_options( - to_type = int32(), - allow_float_truncate = TRUE, - allow_decimal_truncate = TRUE - ) - ) - }, - as.integer64 = function(x) { - FUN( - "cast", - x, - options = cast_options( - to_type = int64(), - allow_float_truncate = TRUE, - allow_decimal_truncate = TRUE - ) - ) - }, - as.logical = function(x) { - FUN("cast", x, options = cast_options(to_type = boolean())) - }, - as.numeric = function(x) { - FUN("cast", x, options = cast_options(to_type = float64())) - }, - nchar = function(x, type = "chars", allowNA = FALSE, keepNA = NA) { - if (allowNA) { - stop("allowNA = TRUE not supported for Arrow", call. = FALSE) - } - if (is.na(keepNA)) { - keepNA <- !identical(type, "width") - } - if (!keepNA) { - # TODO: I think there is a fill_null kernel we could use, set null to 2 - stop("keepNA = TRUE not supported for Arrow", call. 
= FALSE) - } - if (identical(type, "bytes")) { - FUN("binary_length", x) - } else { - FUN("utf8_length", x) - } - }, - str_trim = function(string, side = c("both", "left", "right")) { - side <- match.arg(side) - switch( - side, - left = FUN("utf8_ltrim_whitespace", string), - right = FUN("utf8_rtrim_whitespace", string), - both = FUN("utf8_trim_whitespace", string) - ) - }, - grepl = arrow_r_string_match_function(FUN), - str_detect = arrow_stringr_string_match_function(FUN), - sub = arrow_r_string_replace_function(FUN, 1L), - gsub = arrow_r_string_replace_function(FUN, -1L), - str_replace = arrow_stringr_string_replace_function(FUN, 1L), - str_replace_all = arrow_stringr_string_replace_function(FUN, -1L), - strsplit = arrow_r_string_split_function(FUN), - str_split = arrow_stringr_string_split_function(FUN), - between = function(x, left, right) { - x >= left & x <= right - }, - # Now also include all available Arrow Compute functions, - # namespaced as arrow_fun - set_names( - lapply(all_arrow_funs, wrapper), - paste0("arrow_", all_arrow_funs) - ) - ) -} - -arrow_r_string_match_function <- function(FUN) { - function(pattern, x, ignore.case = FALSE, fixed = FALSE) { - FUN( - ifelse(fixed && !ignore.case, "match_substring", "match_substring_regex"), - x, - options = list(pattern = format_string_pattern(pattern, ignore.case, fixed)) - ) - } -} - -arrow_stringr_string_match_function <- function(FUN) { - function(string, pattern, negate = FALSE) { - opts <- get_stringr_pattern_options(enexpr(pattern)) - out <- arrow_r_string_match_function(FUN)( - pattern = opts$pattern, - x = string, - ignore.case = opts$ignore_case, - fixed = opts$fixed - ) - if (negate) out <- FUN("invert", out) - out - } -} - -arrow_r_string_replace_function <- function(FUN, max_replacements) { - function(pattern, replacement, x, ignore.case = FALSE, fixed = FALSE) { - FUN( - ifelse(fixed && !ignore.case, "replace_substring", "replace_substring_regex"), - x, - options = list( - pattern = format_string_pattern(pattern, ignore.case, fixed), - replacement = format_string_replacement(replacement, ignore.case, fixed), - max_replacements = max_replacements - ) - ) - } -} - -arrow_stringr_string_replace_function <- function(FUN, max_replacements) { - function(string, pattern, replacement) { - opts <- get_stringr_pattern_options(enexpr(pattern)) - arrow_r_string_replace_function(FUN, max_replacements)( - pattern = opts$pattern, - replacement = replacement, - x = string, - ignore.case = opts$ignore_case, - fixed = opts$fixed - ) - } -} - -arrow_r_string_split_function <- function(FUN, reverse = FALSE, max_splits = -1) { - function(x, split, fixed = FALSE, perl = FALSE, useBytes = FALSE) { - - assert_that(is.string(split)) - - # The Arrow C++ library does not support splitting a string by a regular - # expression pattern (ARROW-12608) but the default behavior of - # base::strsplit() is to interpret the split pattern as a regex - # (fixed = FALSE). R users commonly pass non-regex split patterns to - # strsplit() without bothering to set fixed = TRUE. It would be annoying if - # that didn't work here. So: if fixed = FALSE, let's check the split pattern - # to see if it is a regex (if it contains any regex metacharacters). If not, - # then allow it to proceed. - if (!fixed && contains_regex(split)) { - stop("Regular expression matching not supported in strsplit for Arrow", call.
= FALSE) - } - # warn when the user specifies both fixed = TRUE and perl = TRUE, for - # consistency with the behavior of base::strsplit() - if (fixed && perl) { - warning("Argument 'perl = TRUE' will be ignored", call. = FALSE) - } - # since split is not a regex, proceed without any warnings or errors - # regardless of the value of perl, for consistency with the behavior of - # base::strsplit() - FUN("split_pattern", x, options = list(pattern = split, reverse = reverse, max_splits = max_splits)) - } -} - -arrow_stringr_string_split_function <- function(FUN, reverse = FALSE) { - function(string, pattern, n = Inf, simplify = FALSE) { - opts <- get_stringr_pattern_options(enexpr(pattern)) - if (!opts$fixed && contains_regex(opts$pattern)) { - stop("Regular expression matching not supported in str_split() for Arrow", call. = FALSE) - } - if (opts$ignore_case) { - stop("Case-insensitive string splitting not supported in Arrow", call. = FALSE) - } - if (n == 0) { - stop("Splitting strings into zero parts not supported in Arrow", call. = FALSE) - } - if (identical(n, Inf)) { - n <- 0L - } - if (simplify) { - warning("Argument 'simplify = TRUE' will be ignored", call. = FALSE) - } - # The max_splits option in the Arrow C++ library controls the maximum number - # of places at which the string is split, whereas the argument n to - # str_split() controls the maximum number of pieces to return. So we must - # subtract 1 from n to get max_splits. - FUN("split_pattern", string, options = list(pattern = opts$pattern, reverse = reverse, max_splits = n - 1L)) - } -} - -# format `pattern` as needed for case insensitivity and literal matching by RE2 -format_string_pattern <- function(pattern, ignore.case, fixed) { - # Arrow lacks native support for case-insensitive literal string matching and - # replacement, so we use the regular expression engine (RE2) to do this. - # https://github.com/google/re2/wiki/Syntax - if (ignore.case) { - if (fixed) { - # Everything between "\Q" and "\E" is treated as literal text. - # If the search text contains any literal "\E" strings, make them - # lowercase so they won't signal the end of the literal text: - pattern <- gsub("\\E", "\\e", pattern, fixed = TRUE) - pattern <- paste0("\\Q", pattern, "\\E") - } - # Prepend "(?i)" for case-insensitive matching - pattern <- paste0("(?i)", pattern) - } - pattern -} - -# format `replacement` as needed for literal replacement by RE2 -format_string_replacement <- function(replacement, ignore.case, fixed) { - # Arrow lacks native support for case-insensitive literal string - # replacement, so we use the regular expression engine (RE2) to do this. - # https://github.com/google/re2/wiki/Syntax - if (ignore.case && fixed) { - # Escape single backslashes in the regex replacement text so they are - # interpreted as literal backslashes: - replacement <- gsub("\\", "\\\\", replacement, fixed = TRUE) - } - replacement -} - -#' Get `stringr` pattern options -#' -#' This function assigns definitions for the `stringr` pattern modifier -#' functions (`fixed()`, `regex()`, etc.) inside itself, and uses them to -#' evaluate the quoted expression `pattern`, returning a list that is used -#' to control pattern matching behavior in internal `arrow` functions.
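To make the two helpers above concrete, here is what format_string_pattern() returns for a few inputs (a sketch; outputs shown as comments):

    # case-insensitive literal match: wrap in \Q...\E, then prepend (?i)
    format_string_pattern("a.b", ignore.case = TRUE, fixed = TRUE)
    #> [1] "(?i)\\Qa.b\\E"
    # case-insensitive regex: just prepend (?i)
    format_string_pattern("a.b", ignore.case = TRUE, fixed = FALSE)
    #> [1] "(?i)a.b"
    # case-sensitive patterns pass through unchanged
    format_string_pattern("a.b", ignore.case = FALSE, fixed = TRUE)
    #> [1] "a.b"

Likewise for the n-to-max_splits conversion in the str_split() binding: asking for n = 2 pieces means splitting at most once, so str_split("a,b,c", ",", n = 2) yields "a" and "b,c" via max_splits = 1.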
-#' -#' @param pattern Unevaluated expression containing a call to a `stringr` -#' pattern modifier function -#' -#' @return List containing elements `pattern`, `fixed`, and `ignore_case` -#' @keywords internal -get_stringr_pattern_options <- function(pattern) { - fixed <- function(pattern, ignore_case = FALSE, ...) { - check_dots(...) - list(pattern = pattern, fixed = TRUE, ignore_case = ignore_case) - } - regex <- function(pattern, ignore_case = FALSE, ...) { - check_dots(...) - list(pattern = pattern, fixed = FALSE, ignore_case = ignore_case) - } - coll <- boundary <- function(...) { - stop( - "Pattern modifier `", - match.call()[[1]], - "()` is not supported in Arrow", - call. = FALSE - ) - } - check_dots <- function(...) { - dots <- list(...) - if (length(dots)) { - warning( - "Ignoring pattern modifier ", - ngettext(length(dots), "argument ", "arguments "), - "not supported in Arrow: ", - oxford_paste(names(dots)), - call. = FALSE - ) - } - } - ensure_opts <- function(opts) { - if (is.character(opts)) { - opts <- list(pattern = opts, fixed = FALSE, ignore_case = FALSE) - } - opts - } - ensure_opts(eval(pattern)) -} - -# We'll populate these at package load time. -dplyr_functions <- NULL -init_env <- function () { - dplyr_functions <<- new.env(hash = TRUE) -} -init_env() - -# Create a data mask for evaluating a dplyr expression -arrow_mask <- function(.data) { - if (query_on_dataset(.data)) { - f_env <- new_environment(dplyr_functions$dataset) - } else { - f_env <- new_environment(dplyr_functions$array) - } - - # Add functions that need to error hard and clear. - # Some R functions will still try to evaluate on an Expression - # and return NA with a warning - fail <- function(...) stop("Not implemented") - for (f in c("mean", "sd")) { - f_env[[f]] <- fail - } - - # Add the column references and make the mask - out <- new_data_mask( - new_environment(.data$selected_columns, parent = f_env), - f_env - ) - # Then insert the data pronoun - # TODO: figure out what rlang::as_data_pronoun does/why we should use it - # (because if we do we get `Error: Can't modify the data pronoun` in mutate()) - out$.data <- .data$selected_columns - out -} - -set_filters <- function(.data, expressions) { - if (length(expressions)) { - # expressions is a list of Expressions. AND them together and set them on .data - new_filter <- Reduce("&", expressions) - if (isTRUE(.data$filtered_rows)) { - # TRUE is default (i.e. no filter yet), so we don't need to & with it - .data$filtered_rows <- new_filter - } else { - .data$filtered_rows <- .data$filtered_rows & new_filter - } - } - .data -} - -collect.arrow_dplyr_query <- function(x, as_data_frame = TRUE, ...) { - x <- ensure_group_vars(x) - x <- ensure_arrange_vars(x) # this sets x$temp_columns - # Pull only the selected rows and cols into R - if (query_on_dataset(x)) { - # See dataset.R for Dataset and Scanner(Builder) classes - tab <- Scanner$create(x)$ToTable() - } else { - # This is a Table or RecordBatch - - # Filter and select the data referenced in selected columns - if (isTRUE(x$filtered_rows)) { - filter <- TRUE - } else { - filter <- eval_array_expression(x$filtered_rows, x$.data) - } - # TODO: shortcut if identical(names(x$.data), find_array_refs(c(x$selected_columns, x$temp_columns)))? 
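To see what set_filters() above accumulates: multiple filter() arguments are simply Reduce()d into a single Expression with "&". A minimal sketch against the Expression API (the column names are hypothetical and the printed form is approximate):

    filters <- list(
      Expression$field_ref("int") > Expression$scalar(5L),
      Expression$field_ref("dbl") < Expression$scalar(50)
    )
    # AND the individual filter Expressions together, as set_filters() does
    combined <- Reduce("&", filters)
    combined$ToString()
    #> ((int > 5) and (dbl < 50))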
- tab <- x$.data[ - filter, - find_array_refs(c(x$selected_columns, x$temp_columns)), - keep_na = FALSE - ] - # Now evaluate those expressions on the filtered table - cols <- lapply(c(x$selected_columns, x$temp_columns), eval_array_expression, data = tab) - if (length(cols) == 0) { - tab <- tab[, integer(0)] - } else { - if (inherits(x$.data, "Table")) { - tab <- Table$create(!!!cols) - } else { - tab <- RecordBatch$create(!!!cols) - } - } - } - # Arrange rows - if (length(x$arrange_vars) > 0) { - tab <- tab[ - tab$SortIndices(names(x$arrange_vars), x$arrange_desc), - names(x$selected_columns), # this omits x$temp_columns from the result - drop = FALSE - ] - } - if (as_data_frame) { - df <- as.data.frame(tab) - tab$invalidate() - restore_dplyr_features(df, x) - } else { - restore_dplyr_features(tab, x) - } -} -collect.ArrowTabular <- function(x, as_data_frame = TRUE, ...) { - if (as_data_frame) { - as.data.frame(x, ...) - } else { - x - } -} -collect.Dataset <- function(x, ...) dplyr::collect(arrow_dplyr_query(x), ...) - -compute.arrow_dplyr_query <- function(x, ...) dplyr::collect(x, as_data_frame = FALSE) -compute.ArrowTabular <- function(x, ...) x -compute.Dataset <- compute.arrow_dplyr_query +`[.arrow_dplyr_query` <- `[.Dataset` +# TODO: ^ should also probably restore_dplyr_features, and/or that should be moved down ensure_group_vars <- function(x) { if (inherits(x, "arrow_dplyr_query")) { @@ -797,7 +158,7 @@ ensure_group_vars <- function(x) { # Add them back x$selected_columns <- c( x$selected_columns, - make_field_refs(gv, dataset = query_on_dataset(.data)) + make_field_refs(gv) ) } } @@ -822,13 +183,7 @@ restore_dplyr_features <- function(df, query) { # An arrow_dplyr_query holds some attributes that Arrow doesn't know about # After calling collect(), make sure these features are carried over - grouped <- length(query$group_by_vars) > 0 - renamed <- ncol(df) && !identical(names(df), names(query)) - if (renamed) { - # In case variables were renamed, apply those names - names(df) <- names(query) - } - if (grouped) { + if (length(query$group_by_vars) > 0) { # Preserve groupings, if present if (is.data.frame(df)) { df <- dplyr::grouped_df( @@ -846,217 +201,6 @@ restore_dplyr_features <- function(df, query) { df } -pull.arrow_dplyr_query <- function(.data, var = -1) { - .data <- arrow_dplyr_query(.data) - var <- vars_pull(names(.data), !!enquo(var)) - .data$selected_columns <- set_names(.data$selected_columns[var], var) - dplyr::collect(.data)[[1]] -} -pull.Dataset <- pull.ArrowTabular <- pull.arrow_dplyr_query - -summarise.arrow_dplyr_query <- function(.data, ...) { - call <- match.call() - .data <- arrow_dplyr_query(.data) - if (query_on_dataset(.data)) { - not_implemented_for_dataset("summarize()") - } - exprs <- quos(...) - # Only retain the columns we need to do our aggregations - vars_to_keep <- unique(c( - unlist(lapply(exprs, all.vars)), # vars referenced in summarise - dplyr::group_vars(.data) # vars needed for grouping - )) - .data <- dplyr::select(.data, vars_to_keep) - if (isTRUE(getOption("arrow.summarize", FALSE))) { - # Try stuff, if successful return() - out <- try(do_arrow_group_by(.data, ...), silent = TRUE) - if (inherits(out, "try-error")) { - return(abandon_ship(call, .data, format(out))) - } else { - return(out) - } - } else { - # If unsuccessful or if option not set, do the work in R - dplyr::summarise(dplyr::collect(.data), ...) - } -} -summarise.Dataset <- summarise.ArrowTabular <- summarise.arrow_dplyr_query - -do_arrow_group_by <- function(.data, ...) 
{ - exprs <- quos(...) - mask <- arrow_mask(.data) - # Add aggregation wrappers to arrow_mask somehow - # (this is not ideal, would overwrite same-named objects) - mask$sum <- function(x, na.rm = FALSE) { - list( - fun = "sum", - data = x, - options = list(na.rm = na.rm) - ) - } - results <- list() - for (i in seq_along(exprs)) { - # Iterate over the indices and not the names because names may be repeated - # (which overwrites the previous name) - new_var <- names(exprs)[i] - results[[new_var]] <- arrow_eval(exprs[[i]], mask) - if (inherits(results[[new_var]], "try-error")) { - msg <- paste('Expression', as_label(exprs[[i]]), 'not supported in Arrow') - stop(msg, call. = FALSE) - } - # Put it in the data mask too? - #mask[[new_var]] <- mask$.data[[new_var]] <- results[[new_var]] - } - # Now, from that, split out the array (expressions) and options - opts <- lapply(results, function(x) x[c("fun", "options")]) - inputs <- lapply(results, function(x) eval_array_expression(x$data, .data$.data)) - grouping_vars <- lapply(.data$group_by_vars, function(x) eval_array_expression(.data$selected_columns[[x]], .data$.data)) - compute__GroupBy(inputs, grouping_vars, opts) -} - -group_by.arrow_dplyr_query <- function(.data, - ..., - .add = FALSE, - add = .add, - .drop = dplyr::group_by_drop_default(.data)) { - .data <- arrow_dplyr_query(.data) - # ... can contain expressions (i.e. can add (or rename?) columns) - # Check for those (they show up as named expressions) - new_groups <- enquos(...) - new_groups <- new_groups[nzchar(names(new_groups))] - if (length(new_groups)) { - # Add them to the data - .data <- dplyr::mutate(.data, !!!new_groups) - } - if (".add" %in% names(formals(dplyr::group_by))) { - # dplyr >= 1.0 - gv <- dplyr::group_by_prepare(.data, ..., .add = .add)$group_names - } else { - gv <- dplyr::group_by_prepare(.data, ..., add = add)$group_names - } - .data$group_by_vars <- gv - .data$drop_empty_groups <- ifelse(length(gv), .drop, dplyr::group_by_drop_default(.data)) - .data -} -group_by.Dataset <- group_by.ArrowTabular <- group_by.arrow_dplyr_query - -groups.arrow_dplyr_query <- function(x) syms(dplyr::group_vars(x)) -groups.Dataset <- groups.ArrowTabular <- function(x) NULL - -group_vars.arrow_dplyr_query <- function(x) x$group_by_vars -group_vars.Dataset <- group_vars.ArrowTabular <- function(x) NULL - -# the logical literal in the two functions below controls the default value of -# the .drop argument to group_by() -group_by_drop_default.arrow_dplyr_query <- - function(.tbl) .tbl$drop_empty_groups %||% TRUE -group_by_drop_default.Dataset <- group_by_drop_default.ArrowTabular <- - function(.tbl) TRUE - -ungroup.arrow_dplyr_query <- function(x, ...) { - x$group_by_vars <- character() - x$drop_empty_groups <- NULL - x -} -ungroup.Dataset <- ungroup.ArrowTabular <- force - -mutate.arrow_dplyr_query <- function(.data, - ..., - .keep = c("all", "used", "unused", "none"), - .before = NULL, - .after = NULL) { - call <- match.call() - exprs <- quos(...) - - .keep <- match.arg(.keep) - .before <- enquo(.before) - .after <- enquo(.after) - - if (.keep %in% c("all", "unused") && length(exprs) == 0) { - # Nothing to do - return(.data) - } - - .data <- arrow_dplyr_query(.data) - - # Restrict the cases we support for now - if (length(dplyr::group_vars(.data)) > 0) { - # mutate() on a grouped dataset does calculations within groups - # This doesn't matter on scalar ops (arithmetic etc.) but it does - # for things with aggregations (e.g. 
subtracting the mean) - return(abandon_ship(call, .data, 'mutate() on grouped data not supported in Arrow')) - } - - # Check for unnamed expressions and fix if any - unnamed <- !nzchar(names(exprs)) - # Deparse and take the first element in case they're long expressions - names(exprs)[unnamed] <- map_chr(exprs[unnamed], as_label) - - is_dataset <- query_on_dataset(.data) - mask <- arrow_mask(.data) - results <- list() - for (i in seq_along(exprs)) { - # Iterate over the indices and not the names because names may be repeated - # (which overwrites the previous name) - new_var <- names(exprs)[i] - results[[new_var]] <- arrow_eval(exprs[[i]], mask) - if (inherits(results[[new_var]], "try-error")) { - msg <- paste('Expression', as_label(exprs[[i]]), 'not supported in Arrow') - return(abandon_ship(call, .data, msg)) - } else if (is_dataset && - !inherits(results[[new_var]], "Expression") && - !is.null(results[[new_var]])) { - # We need some wrapping to handle literal values - if (length(results[[new_var]]) != 1) { - msg <- paste0('In ', new_var, " = ", as_label(exprs[[i]]), ", only values of size one are recycled") - return(abandon_ship(call, .data, msg)) - } - results[[new_var]] <- Expression$scalar(results[[new_var]]) - } - # Put it in the data mask too - mask[[new_var]] <- mask$.data[[new_var]] <- results[[new_var]] - } - - old_vars <- names(.data$selected_columns) - # Note that this is names(exprs) not names(results): - # if results$new_var is NULL, that means we are supposed to remove it - new_vars <- names(exprs) - - # Assign the new columns into the .data$selected_columns - for (new_var in new_vars) { - .data$selected_columns[[new_var]] <- results[[new_var]] - } - - # Deduplicate new_vars and remove NULL columns from new_vars - new_vars <- intersect(new_vars, names(.data$selected_columns)) - - # Respect .before and .after - if (!quo_is_null(.before) || !quo_is_null(.after)) { - new <- setdiff(new_vars, old_vars) - .data <- dplyr::relocate(.data, !!new, .before = !!.before, .after = !!.after) - } - - # Respect .keep - if (.keep == "none") { - .data$selected_columns <- .data$selected_columns[new_vars] - } else if (.keep != "all") { - # "used" or "unused" - used_vars <- unlist(lapply(exprs, all.vars), use.names = FALSE) - if (.keep == "used") { - .data$selected_columns[setdiff(old_vars, used_vars)] <- NULL - } else { - # "unused" - .data$selected_columns[intersect(old_vars, used_vars)] <- NULL - } - } - # Even if "none", we still keep group vars - ensure_group_vars(.data) -} -mutate.Dataset <- mutate.ArrowTabular <- mutate.arrow_dplyr_query - -transmute.arrow_dplyr_query <- function(.data, ...) dplyr::mutate(.data, ..., .keep = "none") -transmute.Dataset <- transmute.ArrowTabular <- transmute.arrow_dplyr_query - # Helper to handle unsupported dplyr features # * For Table/RecordBatch, we collect() and then call the dplyr method in R # * For Dataset, we just error @@ -1079,81 +223,7 @@ abandon_ship <- function(call, .data, msg = NULL) { eval.parent(call, 2) } -arrange.arrow_dplyr_query <- function(.data, ..., .by_group = FALSE) { - call <- match.call() - exprs <- quos(...) 
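For the literal-value handling in mutate() above, a sketch of the two paths (using a 10-row Table invented for illustration):

    library(dplyr)
    tab <- Table$create(int = 1:10)
    # length-1 value: wrapped as Expression$scalar() and recycled by Arrow
    tab %>% mutate(the_answer = 42) %>% collect()
    # longer vector: abandon_ship() warns and finishes the mutate() in R
    tab %>% mutate(again = 1:10) %>% collect()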
- if (.by_group) { - # when the data is grouped and .by_group is TRUE, order the result by - # the grouping columns first - exprs <- c(quos(!!!dplyr::groups(.data)), exprs) - } - if (length(exprs) == 0) { - # Nothing to do - return(.data) - } - .data <- arrow_dplyr_query(.data) - # find and remove any dplyr::desc() and tidy-eval - # the arrange expressions inside an Arrow data_mask - sorts <- vector("list", length(exprs)) - descs <- logical(0) - mask <- arrow_mask(.data) - for (i in seq_along(exprs)) { - x <- find_and_remove_desc(exprs[[i]]) - exprs[[i]] <- x[["quos"]] - sorts[[i]] <- arrow_eval(exprs[[i]], mask) - if (inherits(sorts[[i]], "try-error")) { - msg <- paste('Expression', as_label(exprs[[i]]), 'not supported in Arrow') - return(abandon_ship(call, .data, msg)) - } - names(sorts)[i] <- as_label(exprs[[i]]) - descs[i] <- x[["desc"]] - } - .data$arrange_vars <- c(sorts, .data$arrange_vars) - .data$arrange_desc <- c(descs, .data$arrange_desc) - .data -} -arrange.Dataset <- arrange.ArrowTabular <- arrange.arrow_dplyr_query - -# Helper to handle desc() in arrange() -# * Takes a quosure as input -# * Returns a list with two elements: -# 1. The quosure with any wrapping parentheses and desc() removed -# 2. A logical value indicating whether desc() was found -# * Performs some other validation -find_and_remove_desc <- function(quosure) { - expr <- quo_get_expr(quosure) - descending <- FALSE - if (length(all.vars(expr)) < 1L) { - stop( - "Expression in arrange() does not contain any field names: ", - deparse(expr), - call. = FALSE - ) - } - # Use a while loop to remove any number of nested pairs of enclosing - # parentheses and any number of nested desc() calls. In the case of multiple - # nested desc() calls, each one toggles the sort order. - while (identical(typeof(expr), "language") && is.call(expr)) { - if (identical(expr[[1]], quote(`(`))) { - # remove enclosing parentheses - expr <- expr[[2]] - } else if (identical(expr[[1]], quote(desc))) { - # remove desc() and toggle descending - expr <- expr[[2]] - descending <- !descending - } else { - break - } - } - return( - list( - quos = quo_set_expr(quosure, expr), - desc = descending - ) - ) -} - -query_on_dataset <- function(x) inherits(x$.data, "Dataset") +query_on_dataset <- function(x) !inherits(x$.data, "InMemoryDataset") not_implemented_for_dataset <- function(method) { stop( @@ -1162,12 +232,3 @@ not_implemented_for_dataset <- function(method) { call. = FALSE ) } - -#' Does this string contain regex metacharacters? -#' -#' @param string String to be tested -#' @keywords internal -#' @return Logical: does `string` contain regex metacharacters? -contains_regex <- function(string) { - grepl("[.\\|()[{^$*+?]", string) -} diff --git a/r/R/expression.R b/r/R/expression.R index 30eb0906d4370..3b24b09bb8b99 100644 --- a/r/R/expression.R +++ b/r/R/expression.R @@ -17,100 +17,6 @@ #' @include arrowExports.R -array_expression <- function(FUN, - ..., - args = list(...), - options = empty_named_list()) { - structure( - list( - fun = FUN, - args = args, - options = options - ), - class = "array_expression" - ) -} - -#' @export -Ops.ArrowDatum <- function(e1, e2) { - if (.Generic == "!") { - eval_array_expression(build_array_expression(.Generic, e1)) - } else if (.Generic %in% names(.array_function_map)) { - eval_array_expression(build_array_expression(.Generic, e1, e2)) - } else { - stop(paste0("Unsupported operation on `", class(e1)[1L], "` : "), .Generic, call.
= FALSE) - } -} - -#' @export -Ops.array_expression <- function(e1, e2) { - if (.Generic == "!") { - build_array_expression(.Generic, e1) - } else { - build_array_expression(.Generic, e1, e2) - } -} - -build_array_expression <- function(FUN, - ..., - args = list(...), - options = empty_named_list()) { - if (FUN == "-" && length(args) == 1L) { - if (inherits(args[[1]], c("ArrowObject", "array_expression"))) { - return(build_array_expression("negate_checked", args[[1]])) - } else { - return(-args[[1]]) - } - } - args <- lapply(args, .wrap_arrow, FUN) - - # In Arrow, "divide" is one function, which does integer division on - # integer inputs and floating-point division on floats - if (FUN == "/") { - # TODO: omg so many ways it's wrong to assume these types - args <- lapply(args, cast_array_expression, float64()) - } else if (FUN == "%/%") { - # In R, integer division works like floor(float division) - out <- build_array_expression("/", args = args, options = options) - return(cast_array_expression(out, int32(), allow_float_truncate = TRUE)) - } else if (FUN == "%%") { - # {e1 - e2 * ( e1 %/% e2 )} - # ^^^ form doesn't work because Ops.Array evaluates eagerly, - # but we can build that up - quotient <- build_array_expression("%/%", args = args) - base <- build_array_expression("*", quotient, args[[2]]) - # this cast is to ensure that the result of this and e1 are the same - # (autocasting only applies to scalars) - base <- cast_array_expression(base, args[[1]]$type) - return(build_array_expression("-", args[[1]], base)) - } - - array_expression(.array_function_map[[FUN]] %||% FUN, args = args, options = options) -} - -cast_array_expression <- function(x, to_type, safe = TRUE, ...) { - opts <- list( - to_type = to_type, - allow_int_overflow = !safe, - allow_time_truncate = !safe, - allow_float_truncate = !safe - ) - array_expression("cast", x, options = modifyList(opts, list(...))) -} - -.wrap_arrow <- function(arg, fun) { - if (!inherits(arg, c("ArrowObject", "array_expression"))) { - # TODO: Array$create if lengths are equal? - # TODO: these kernels should autocast like the dataset ones do (e.g. int vs. float) - if (fun == "%in%") { - arg <- Array$create(arg) - } else { - arg <- Scalar$create(arg) - } - } - arg -} - .unary_function_map <- list( "!" = "invert", "as.factor" = "dictionary_encode", @@ -150,86 +56,6 @@ cast_array_expression <- function(x, to_type, safe = TRUE, ...) 
{ .array_function_map <- c(.unary_function_map, .binary_function_map) -eval_array_expression <- function(x, data = NULL) { - if (!is.null(data)) { - x <- bind_array_refs(x, data) - } - if (!inherits(x, "array_expression")) { - # Nothing to evaluate - return(x) - } - x$args <- lapply(x$args, function (a) { - if (inherits(a, "array_expression")) { - eval_array_expression(a) - } else { - a - } - }) - if (x$fun == "is_in_meta_binary" && inherits(x$args[[2]], "Scalar")) { - x$args[[2]] <- Array$create(x$args[[2]]) - } - call_function(x$fun, args = x$args, options = x$options %||% empty_named_list()) -} - -find_array_refs <- function(x) { - if (identical(x$fun, "array_ref")) { - out <- x$args$field_name - } else { - out <- lapply(x$args, find_array_refs) - } - unlist(out) -} - -# Take an array_expression and replace array_refs with arrays/chunkedarrays from data -bind_array_refs <- function(x, data) { - if (inherits(x, "array_expression")) { - if (identical(x$fun, "array_ref")) { - x <- data[[x$args$field_name]] - } else { - x$args <- lapply(x$args, bind_array_refs, data) - } - } - x -} - -#' @export -is.na.array_expression <- function(x) array_expression("is.na", x) - -#' @export -as.vector.array_expression <- function(x, ...) { - as.vector(eval_array_expression(x)) -} - -#' @export -print.array_expression <- function(x, ...) { - cat(.format_array_expression(x), "\n", sep = "") - invisible(x) -} - -.format_array_expression <- function(x) { - printed_args <- map_chr(x$args, function(arg) { - if (inherits(arg, "Scalar")) { - deparse(as.vector(arg)) - } else if (inherits(arg, "ArrowObject")) { - paste0("<", class(arg)[1], ">") - } else if (inherits(arg, "array_expression")) { - .format_array_expression(arg) - } else { - # Should not happen - deparse(arg) - } - }) - if (identical(x$fun, "array_ref")) { - x$args$field_name - } else { - # Prune this for readability - function_name <- sub("_kleene", "", x$fun) - paste0(function_name, "(", paste(printed_args, collapse = ", "), ")") - } -} - -########### - #' Arrow expressions #' #' @description @@ -250,6 +76,7 @@ print.array_expression <- function(x, ...) { Expression <- R6Class("Expression", inherit = ArrowObject, public = list( ToString = function() compute___expr__ToString(self), + type = function(schema) compute___expr__type(self, schema), cast = function(to_type, safe = TRUE, ...) 
{ opts <- list( to_type = to_type, @@ -279,13 +106,16 @@ Expression$scalar <- function(x) { compute___expr__scalar(Scalar$create(x)) } -build_dataset_expression <- function(FUN, - ..., - args = list(...), - options = empty_named_list()) { +# Wrapper around Expression$create that: +# (1) maps R function names to Arrow C++ compute ("/" --> "divide_checked") +# (2) wraps R input args as Array or Scalar +build_expr <- function(FUN, + ..., + args = list(...), + options = empty_named_list()) { if (FUN == "-" && length(args) == 1L) { if (inherits(args[[1]], c("ArrowObject", "Expression"))) { - return(build_dataset_expression("negate_checked", args[[1]])) + return(build_expr("negate_checked", args[[1]])) } else { return(-args[[1]]) } @@ -315,7 +145,7 @@ build_dataset_expression <- function(FUN, args <- lapply(args, function(x) x$cast(float64())) } else if (FUN == "%/%") { # In R, integer division works like floor(float division) - out <- build_dataset_expression("/", args = args) + out <- build_expr("/", args = args) return(out$cast(int32(), allow_float_truncate = TRUE)) } else if (FUN == "%%") { return(args[[1]] - args[[2]] * ( args[[1]] %/% args[[2]] )) @@ -329,9 +159,9 @@ build_dataset_expression <- function(FUN, #' @export Ops.Expression <- function(e1, e2) { if (.Generic == "!") { - build_dataset_expression(.Generic, e1) + build_expr(.Generic, e1) } else { - build_dataset_expression(.Generic, e1, e2) + build_expr(.Generic, e1, e2) } } diff --git a/r/man/contains_regex.Rd b/r/man/contains_regex.Rd index d8fee96d99b94..f05f11d02793e 100644 --- a/r/man/contains_regex.Rd +++ b/r/man/contains_regex.Rd @@ -1,5 +1,5 @@ % Generated by roxygen2: do not edit by hand -% Please edit documentation in R/dplyr.R +% Please edit documentation in R/dplyr-functions.R \name{contains_regex} \alias{contains_regex} \title{Does this string contain regex metacharacters?} diff --git a/r/man/get_stringr_pattern_options.Rd b/r/man/get_stringr_pattern_options.Rd index 79a9a72b7cfb3..7107b9060244e 100644 --- a/r/man/get_stringr_pattern_options.Rd +++ b/r/man/get_stringr_pattern_options.Rd @@ -1,5 +1,5 @@ % Generated by roxygen2: do not edit by hand -% Please edit documentation in R/dplyr.R +% Please edit documentation in R/dplyr-functions.R \name{get_stringr_pattern_options} \alias{get_stringr_pattern_options} \title{Get \code{stringr} pattern options} diff --git a/r/src/arrowExports.cpp b/r/src/arrowExports.cpp index b274ac5f3af27..9d75b2da54e2a 100644 --- a/r/src/arrowExports.cpp +++ b/r/src/arrowExports.cpp @@ -3097,6 +3097,22 @@ extern "C" SEXP _arrow_compute___expr__ToString(SEXP x_sexp){ } #endif +// expression.cpp +#if defined(ARROW_R_WITH_ARROW) +std::shared_ptr compute___expr__type(const std::shared_ptr& x, const std::shared_ptr& schema); +extern "C" SEXP _arrow_compute___expr__type(SEXP x_sexp, SEXP schema_sexp){ +BEGIN_CPP11 + arrow::r::Input&>::type x(x_sexp); + arrow::r::Input&>::type schema(schema_sexp); + return cpp11::as_sexp(compute___expr__type(x, schema)); +END_CPP11 +} +#else +extern "C" SEXP _arrow_compute___expr__type(SEXP x_sexp, SEXP schema_sexp){ + Rf_error("Cannot call compute___expr__type(). See https://arrow.apache.org/docs/r/articles/install.html for help installing Arrow C++ libraries. 
"); +} +#endif + // feather.cpp #if defined(ARROW_R_WITH_ARROW) void ipc___WriteFeather__Table(const std::shared_ptr& stream, const std::shared_ptr& table, int version, int chunk_size, arrow::Compression::type compression, int compression_level); @@ -6884,6 +6900,7 @@ static const R_CallMethodDef CallEntries[] = { { "_arrow_compute___expr__get_field_ref_name", (DL_FUNC) &_arrow_compute___expr__get_field_ref_name, 1}, { "_arrow_compute___expr__scalar", (DL_FUNC) &_arrow_compute___expr__scalar, 1}, { "_arrow_compute___expr__ToString", (DL_FUNC) &_arrow_compute___expr__ToString, 1}, + { "_arrow_compute___expr__type", (DL_FUNC) &_arrow_compute___expr__type, 2}, { "_arrow_ipc___WriteFeather__Table", (DL_FUNC) &_arrow_ipc___WriteFeather__Table, 6}, { "_arrow_ipc___feather___Reader__version", (DL_FUNC) &_arrow_ipc___feather___Reader__version, 1}, { "_arrow_ipc___feather___Reader__Read", (DL_FUNC) &_arrow_ipc___feather___Reader__Read, 2}, diff --git a/r/src/expression.cpp b/r/src/expression.cpp index 798853edd720d..d8745ade47990 100644 --- a/r/src/expression.cpp +++ b/r/src/expression.cpp @@ -68,4 +68,12 @@ std::string compute___expr__ToString(const std::shared_ptr& return x->ToString(); } +// [[arrow::export]] +std::shared_ptr compute___expr__type( + const std::shared_ptr& x, + const std::shared_ptr& schema) { + auto bound = ValueOrStop(x->Bind(*schema)); + return bound.type(); +} + #endif diff --git a/r/tests/testthat/helper-arrow.R b/r/tests/testthat/helper-arrow.R index 89d9bf07ee6ea..0abbfb6a13ad5 100644 --- a/r/tests/testthat/helper-arrow.R +++ b/r/tests/testthat/helper-arrow.R @@ -47,7 +47,7 @@ with_language <- function(lang, expr) { Sys.setenv(LANGUAGE = lang) on.exit({ Sys.setenv(LANGUAGE = old) - dplyr_functions$i18ized_error_pattern <<- NULL + .cache$i18ized_error_pattern <<- NULL }) if (!identical(before, i18ize_error_messages())) { skip(paste("This OS either does not support changing languages to", lang, "or it caches translations")) diff --git a/r/tests/testthat/test-RecordBatch.R b/r/tests/testthat/test-RecordBatch.R index d60ed4fbaba69..eef79100950ae 100644 --- a/r/tests/testthat/test-RecordBatch.R +++ b/r/tests/testthat/test-RecordBatch.R @@ -500,13 +500,14 @@ test_that("Handling string data with embedded nuls", { }) test_that("ARROW-11769 - grouping preserved in record batch creation", { - + skip_if_not_available("dataset") + tbl <- tibble::tibble( int = 1:10, fct = factor(rep(c("A", "B"), 5)), fct2 = factor(rep(c("C", "D"), each = 5)), ) - + expect_identical( tbl %>% dplyr::group_by(fct, fct2) %>% @@ -514,5 +515,5 @@ test_that("ARROW-11769 - grouping preserved in record batch creation", { dplyr::group_vars(), c("fct", "fct2") ) - + }) diff --git a/r/tests/testthat/test-Table.R b/r/tests/testthat/test-Table.R index b88b1ba65e30d..ba41c2be705c8 100644 --- a/r/tests/testthat/test-Table.R +++ b/r/tests/testthat/test-Table.R @@ -476,13 +476,14 @@ test_that("Table$create() with different length columns", { }) test_that("ARROW-11769 - grouping preserved in table creation", { - + skip_if_not_available("dataset") + tbl <- tibble::tibble( int = 1:10, fct = factor(rep(c("A", "B"), 5)), fct2 = factor(rep(c("C", "D"), each = 5)), ) - + expect_identical( tbl %>% dplyr::group_by(fct, fct2) %>% @@ -490,5 +491,5 @@ test_that("ARROW-11769 - grouping preserved in table creation", { dplyr::group_vars(), c("fct", "fct2") ) - + }) diff --git a/r/tests/testthat/test-compute-arith.R b/r/tests/testthat/test-compute-arith.R index 0b6d8e8dd1789..2586ba865b3bc 100644 --- 
a/r/tests/testthat/test-compute-arith.R +++ b/r/tests/testthat/test-compute-arith.R @@ -111,6 +111,7 @@ test_that("Power", { test_that("Dates casting", { a <- Array$create(c(Sys.Date() + 1:4, NA_integer_)) - skip("autocasting should happen in compute kernels; R workaround fails on this ARROW-8919") + skip("ARROW-11090 (date/datetime arithmetic)") + # Error: NotImplemented: Function add_checked has no kernel matching input types (array[date32[day]], scalar[double]) expect_equal(a + 2, Array$create(c((Sys.Date() + 1:4 ) + 2), NA_integer_)) }) diff --git a/r/tests/testthat/test-compute-sort.R b/r/tests/testthat/test-compute-sort.R index ba38d4ce37eac..63977b554143c 100644 --- a/r/tests/testthat/test-compute-sort.R +++ b/r/tests/testthat/test-compute-sort.R @@ -138,28 +138,21 @@ test_that("sort(vector), sort(Array), sort(ChunkedArray) give equivalent results }) test_that("Table$SortIndices()", { + x <- Table$create(tbl) expect_identical( - { - x <- tbl %>% Table$create() - x$Take(x$SortIndices("chr")) %>% pull(chr) - }, + as.vector(x$Take(x$SortIndices("chr"))$chr), sort(tbl$chr, na.last = TRUE) ) expect_identical( - { - x <- tbl %>% Table$create() - x$Take(x$SortIndices(c("int", "dbl"), c(FALSE, FALSE))) %>% collect() - }, + as.data.frame(x$Take(x$SortIndices(c("int", "dbl"), c(FALSE, FALSE)))), tbl %>% arrange(int, dbl) ) }) test_that("RecordBatch$SortIndices()", { + x <- record_batch(tbl) expect_identical( - { - x <- tbl %>% record_batch() - x$Take(x$SortIndices(c("chr", "int", "dbl"), TRUE)) %>% collect() - }, + as.data.frame(x$Take(x$SortIndices(c("chr", "int", "dbl"), TRUE))), tbl %>% arrange(desc(chr), desc(int), desc(dbl)) ) }) diff --git a/r/tests/testthat/test-dataset.R b/r/tests/testthat/test-dataset.R index 1b0dcc07128a4..334ff6d06f7ba 100644 --- a/r/tests/testthat/test-dataset.R +++ b/r/tests/testthat/test-dataset.R @@ -613,22 +613,6 @@ test_that("Creating UnionDataset", { expect_error(c(ds1, 42), "character") }) -test_that("InMemoryDataset", { - ds <- InMemoryDataset$create(rbind(df1, df2)) - expect_r6_class(ds, "InMemoryDataset") - expect_equivalent( - ds %>% - select(chr, dbl) %>% - filter(dbl > 7 & dbl < 53L) %>% - collect() %>% - arrange(dbl), - rbind( - df1[8:10, c("chr", "dbl")], - df2[1:2, c("chr", "dbl")] - ) - ) -}) - test_that("map_batches", { skip_if_not_available("parquet") ds <- open_dataset(dataset_dir, partitioning = "part") @@ -647,18 +631,6 @@ test_that("partitioning = NULL to ignore partition information (but why?)", { expect_identical(names(ds), names(df1)) # i.e. 
not c(names(df1), "group", "other") }) -test_that("filter() with is.na()", { - skip_if_not_available("parquet") - ds <- open_dataset(dataset_dir, partitioning = schema(part = uint8())) - expect_equivalent( - ds %>% - select(part, lgl) %>% - filter(!is.na(lgl), part == 1) %>% - collect(), - tibble(part = 1L, lgl = df1$lgl[!is.na(df1$lgl)]) - ) -}) - test_that("filter() with is.nan()", { skip_if_not_available("parquet") ds <- open_dataset(dataset_dir, partitioning = schema(part = uint8())) @@ -693,103 +665,6 @@ test_that("filter() with %in%", { ) }) -test_that("filter() with negative scalar", { - skip_if_not_available("parquet") - ds <- open_dataset(dataset_dir, partitioning = schema(part = uint8())) - expect_equivalent( - ds %>% - filter(part == 1) %>% - select(chr, int) %>% - filter(int > -2) %>% - collect(), - df1[, c("chr", "int")] - ) - - expect_equivalent( - ds %>% - filter(part == 1) %>% - select(chr, int) %>% - filter(int %in% -2) %>% - collect(), - df1[FALSE, c("chr", "int")] - ) - - expect_equivalent( - ds %>% - filter(part == 1) %>% - select(chr, int) %>% - filter(-int < -2) %>% - collect(), - df1[df1$int > 2, c("chr", "int")] - ) -}) - -test_that("filter() with strings", { - skip_if_not_available("parquet") - ds <- open_dataset(dataset_dir, partitioning = schema(part = uint8())) - expect_equivalent( - ds %>% - select(chr, part) %>% - filter(chr == "b", part == 1) %>% - collect(), - tibble(chr = "b", part = 1) - ) - - skip_if_not_available("utf8proc") - expect_equivalent( - ds %>% - select(chr, part) %>% - filter(toupper(chr) == "B", part == 1) %>% - collect(), - tibble(chr = "b", part = 1) - ) -}) - -test_that("filter() with arrow compute functions by name", { - skip_if_not_available("parquet") - ds <- open_dataset(dataset_dir, partitioning = schema(part = uint8())) - expect_equivalent( - ds %>% - select(part, lgl) %>% - filter(arrow_is_valid(lgl), arrow_equal(part, 1)) %>% - collect(), - ds %>% - select(part, lgl) %>% - filter(!is.na(lgl), part == 1L) %>% - collect() - ) -}) - -test_that("filter() with .data", { - skip_if_not_available("parquet") - ds <- open_dataset(dataset_dir, partitioning = schema(part = uint8())) - expect_equivalent( - ds %>% - select(.data$int, .data$part) %>% - filter(.data$int == 3, .data$part == 1) %>% - collect(), - tibble(int = df1$int[3], part = 1) - ) - - expect_equivalent( - ds %>% - select(.data$int, .data$part) %>% - filter(.data$int %in% c(6, 4, 3, 103, 107), .data$part == 1) %>% - collect(), - tibble(int = df1$int[c(3, 4, 6)], part = 1) - ) - - # and the .env pronoun too! - chr <- 1 - expect_equivalent( - ds %>% - select(.data$int, .data$part) %>% - filter(.data$int %in% c(6, 4, 3, 103, 107), .data$part == .env$chr) %>% - collect(), - tibble(int = df1$int[c(3, 4, 6)], part = 1) - ) -}) - test_that("filter() on timestamp columns", { skip_if_not_available("parquet") ds <- open_dataset(dataset_dir, partitioning = schema(part = uint8())) @@ -849,109 +724,6 @@ test_that("filter() on date32 columns", { ) }) -test_that("filter() with expressions", { - skip_if_not_available("parquet") - ds <- open_dataset(dataset_dir, partitioning = schema(part = uint8())) - expect_r6_class(ds$format, "ParquetFileFormat") - expect_r6_class(ds$filesystem, "LocalFileSystem") - expect_r6_class(ds, "Dataset") - expect_equivalent( - ds %>% - select(chr, dbl) %>% - filter(dbl * 2 > 14 & dbl - 50 < 3L) %>% - collect() %>% - arrange(dbl), - rbind( - df1[8:10, c("chr", "dbl")], - df2[1:2, c("chr", "dbl")] - ) - ) - - # check division's special casing. 
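The "division special casing" these tests exercise: Arrow's divide kernel does integer division on integer inputs, so the bindings cast both arguments to float64 for /, implement %/% as that float division cast back to int32, and build %% as e1 - e2 * (e1 %/% e2). A sketch of the intended base-R-matching semantics, with ds as in these tests:

    ds %>%
      transmute(
        fdiv = int / 2,    # always float division: 7L / 2 is 3.5, via cast to float64
        idiv = int %/% 2L, # "/" then cast(int32, allow_float_truncate = TRUE): 7L %/% 2L is 3L
        mod  = int %% 2L   # int - 2L * (int %/% 2L): 7L %% 2L is 1L
      ) %>%
      collect()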
- expect_equivalent( - ds %>% - select(chr, dbl) %>% - filter(dbl / 2 > 3.5 & dbl < 53) %>% - collect() %>% - arrange(dbl), - rbind( - df1[8:10, c("chr", "dbl")], - df2[1:2, c("chr", "dbl")] - ) - ) - - expect_equivalent( - ds %>% - select(chr, dbl, int) %>% - filter(int %/% 2L > 3 & dbl < 53) %>% - collect() %>% - arrange(dbl), - rbind( - df1[8:10, c("chr", "dbl", "int")], - df2[1:2, c("chr", "dbl", "int")] - ) - ) - - expect_equivalent( - ds %>% - select(chr, dbl, int) %>% - filter(int %/% 2 > 3 & dbl < 53) %>% - collect() %>% - arrange(dbl), - rbind( - df1[8:10, c("chr", "dbl", "int")], - df2[1:2, c("chr", "dbl", "int")] - ) - ) - - expect_equivalent( - ds %>% - select(chr, dbl, int) %>% - filter(int %% 2L > 0 & dbl < 53) %>% - collect() %>% - arrange(dbl), - rbind( - df1[c(1, 3, 5, 7, 9), c("chr", "dbl", "int")], - df2[1, c("chr", "dbl", "int")] - ) - ) - - expect_equivalent( - ds %>% - select(chr, dbl, int) %>% - filter(int %% 2L > 0 & dbl < 53) %>% - collect() %>% - arrange(dbl), - rbind( - df1[c(1, 3, 5, 7, 9), c("chr", "dbl", "int")], - df2[1, c("chr", "dbl", "int")] - ) - ) - - expect_equivalent( - ds %>% - select(chr, dbl, int) %>% - filter(int %% 2 > 0 & dbl < 53) %>% - collect() %>% - arrange(dbl), - rbind( - df1[c(1, 3, 5, 7, 9), c("chr", "dbl", "int")], - df2[1, c("chr", "dbl", "int")] - ) - ) - - expect_equivalent( - ds %>% - select(chr, dbl, int) %>% - filter(dbl + int > 15 & dbl < 53L) %>% - collect() %>% - arrange(dbl), - rbind( - df1[8:10, c("chr", "dbl", "int")], - df2[1:2, c("chr", "dbl", "int")] - ) - ) -}) test_that("mutate()", { ds <- open_dataset(dataset_dir, partitioning = schema(part = uint8())) @@ -965,7 +737,7 @@ test_that("mutate()", { chr: string dbl: double int: int32 -twice: expr +twice: double (multiply_checked(int, 2)) * Filter: ((multiply_checked(dbl, 2) > 14) and (subtract_checked(dbl, 50) < 3)) See $.data for the source Arrow object", @@ -985,26 +757,6 @@ See $.data for the source Arrow object", ) }) -test_that("transmute()", { - ds <- open_dataset(dataset_dir, partitioning = schema(part = uint8())) - mutated <- - expect_equivalent( - ds %>% - select(chr, dbl, int) %>% - filter(dbl * 2 > 14 & dbl - 50 < 3L) %>% - transmute(twice = int * 2) %>% - collect() %>% - arrange(twice), - rbind( - df1[8:10, "int", drop = FALSE], - df2[1:2, "int", drop = FALSE] - ) %>% - transmute( - twice = int * 2 - ) - ) -}) - test_that("mutate() features not yet implemented", { expect_error( ds %>% @@ -1015,66 +767,6 @@ test_that("mutate() features not yet implemented", { ) }) - -test_that("mutate() with scalar (length 1) literal inputs", { - expect_equal( - ds %>% - mutate(the_answer = 42) %>% - collect() %>% - pull(the_answer), - rep(42, nrow(ds)) - ) - - expect_error( - ds %>% mutate(the_answer = c(42, 42)), - "In the_answer = c(42, 42), only values of size one are recycled\nCall collect() first to pull data into R.", - fixed = TRUE - ) -}) - -test_that("mutate() with NULL inputs", { - expect_equal( - ds %>% - mutate(int = NULL) %>% - collect(), - ds %>% - select(-int) %>% - collect() - ) -}) - -test_that("empty mutate()", { - expect_equal( - ds %>% - mutate() %>% - collect(), - ds %>% - collect() - ) -}) - -test_that("transmute() with NULL inputs", { - expect_equal( - ds %>% - transmute(int = NULL) %>% - collect(), - ds %>% - select() %>% - collect() - ) -}) - -test_that("empty transmute()", { - expect_equal( - ds %>% - transmute() %>% - collect(), - ds %>% - select() %>% - collect() - ) -}) - test_that("filter scalar validation doesn't crash (ARROW-7772)", { 
expect_error( ds %>% @@ -1120,7 +812,7 @@ test_that("arrange()", { chr: string dbl: double int: int32 -twice: expr +twice: double (multiply_checked(int, 2)) * Filter: ((multiply_checked(dbl, 2) > 14) and (subtract_checked(dbl, 50) < 3)) * Sorted by chr [asc], multiply_checked(int, 2) [desc], add_checked(dbl, int) [asc] @@ -1189,8 +881,8 @@ test_that("compute()/collect(as_data_frame=FALSE)", { # the group_by() prevents compute() from returning a Table... expect_is(tab5, "arrow_dplyr_query") - # ... but $.data is a Table... - expect_is(tab5$.data, "Table") + # ... but $.data is a Table (InMemoryDataset)... + expect_r6_class(tab5$.data, "InMemoryDataset") # ... and the mutate() was evaluated expect_true("negint" %in% names(tab5$.data)) @@ -1549,17 +1241,17 @@ test_that("Dataset writing: dplyr methods", { expect_true(dir.exists(dst_dir)) expect_identical(dir(dst_dir), sort(paste("int", c(1:10, 101:110), sep = "="))) - # select to specify schema + # select to specify schema (and rename) dst_dir2 <- tempfile() ds %>% group_by(int) %>% - select(chr, dbl) %>% + select(chr, dubs = dbl) %>% write_dataset(dst_dir2, format = "feather") new_ds <- open_dataset(dst_dir2, format = "feather") expect_equivalent( collect(new_ds) %>% arrange(int), - rbind(df1[c("chr", "dbl", "int")], df2[c("chr", "dbl", "int")]) + rbind(df1[c("chr", "dbl", "int")], df2[c("chr", "dbl", "int")]) %>% rename(dubs = dbl) ) # filter to restrict written rows @@ -1573,6 +1265,19 @@ test_that("Dataset writing: dplyr methods", { new_ds %>% select(names(df1)) %>% collect(), df1 %>% filter(int == 4) ) + + # mutate + dst_dir3 <- tempfile() + ds %>% + filter(int == 4) %>% + mutate(twice = int * 2) %>% + write_dataset(dst_dir3, format = "feather") + new_ds <- open_dataset(dst_dir3, format = "feather") + + expect_equivalent( + new_ds %>% select(c(names(df1), "twice")) %>% collect(), + df1 %>% filter(int == 4) %>% mutate(twice = int * 2) + ) }) test_that("Dataset writing: non-hive", { @@ -1750,10 +1455,6 @@ test_that("Dataset writing: unsupported features/input validation", { expect_error(write_dataset(4), 'dataset must be a "Dataset"') ds <- open_dataset(hive_dir) - expect_error( - select(ds, integer = int) %>% write_dataset(ds), - "Renaming columns when writing a dataset is not yet supported" - ) expect_error( write_dataset(ds, partitioning = c("int", "NOTACOLUMN"), format = "ipc"), 'Invalid field name: "NOTACOLUMN"' diff --git a/r/tests/testthat/test-dplyr-arrange.R b/r/tests/testthat/test-dplyr-arrange.R index b476c0329452d..45cd687e84814 100644 --- a/r/tests/testthat/test-dplyr-arrange.R +++ b/r/tests/testthat/test-dplyr-arrange.R @@ -15,6 +15,8 @@ # specific language governing permissions and limitations # under the License. +skip_if_not_available("dataset") + library(dplyr) # randomize order of rows in test data diff --git a/r/tests/testthat/test-dplyr-filter.R b/r/tests/testthat/test-dplyr-filter.R index d1bd3cec60782..6bba58a7e06d7 100644 --- a/r/tests/testthat/test-dplyr-filter.R +++ b/r/tests/testthat/test-dplyr-filter.R @@ -15,6 +15,8 @@ # specific language governing permissions and limitations # under the License. 
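Worth highlighting from the writing tests above: since select()/rename()/filter()/mutate() all flow through the Dataset scanner now, a dplyr query can be piped straight into write_dataset() (renaming on write was previously an error). A usage sketch, with ds as in these tests and a hypothetical output directory:

    dst <- tempfile()
    ds %>%
      filter(int == 4) %>%
      mutate(twice = int * 2) %>%
      write_dataset(dst, format = "feather")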
+skip_if_not_available("dataset") + library(dplyr) library(stringr) @@ -193,7 +195,6 @@ test_that("Negative scalar values", { ) }) - test_that("filter() with between()", { expect_dplyr_equal( input %>% @@ -243,34 +244,6 @@ test_that("filter() with between()", { test_that("filter() with string ops", { skip_if_not_available("utf8proc") - skip_if(getRversion() < "3.4.0", "R < 3.4") - # Extra instrumentation to ensure that we're calling Arrow compute here - # because many base R string functions implicitly call as.character, - # which means they still work on Arrays but actually force data into R - # 1) wrapper that raises a warning if as.character is called. Can't wrap - # the whole test because as.character apparently gets called in other - # (presumably legitimate) places - # 2) Wrap the test in expect_warning(expr, NA) to catch the warning - with_no_as_character <- function(expr) { - trace( - "as.character", - tracer = quote(warning("as.character was called")), - print = FALSE, - where = toupper - ) - on.exit(untrace("as.character", where = toupper)) - force(expr) - } - - expect_warning( - expect_dplyr_equal( - input %>% - filter(dbl > 2, with_no_as_character(toupper(chr)) %in% c("D", "F")) %>% - collect(), - tbl - ), - NA) - expect_dplyr_equal( input %>% filter(dbl > 2, str_length(verses) > 25) %>% @@ -303,9 +276,9 @@ test_that("filter environment scope", { skip("Need to substitute in user defined function too") # TODO: fix this: this isEqualTo function is eagerly evaluating; it should - # instead yield array_expressions. Probably bc the parent env of the function - # has the Ops.Array methods defined; we need to move it so that the parent - # env is the data mask we use in the dplyr eval + # instead yield Expressions. Probably bc the parent env of the function + # has the Ops.Expression methods defined; we need to move it so that the + # parent env is the data mask we use in the dplyr eval isEqualTo <- function(x, y) x == y & !is.na(x) expect_dplyr_equal( input %>% @@ -341,7 +314,7 @@ test_that("Filtering on a column that doesn't exist errors correctly", { }) }) -test_that("Filtering with a function that doesn't have an Array/expr method still works", { +test_that("Filtering with unsupported functions", { expect_warning( expect_dplyr_equal( input %>% @@ -349,7 +322,23 @@ test_that("Filtering with a function that doesn't have an Array/expr method stil collect(), tbl ), - 'Filter expression not implemented in Arrow: pnorm(dbl) > 0.99; pulling data into R', + 'Expression pnorm(dbl) > 0.99 not supported in Arrow; pulling data into R', + fixed = TRUE + ) + expect_warning( + expect_dplyr_equal( + input %>% + filter( + nchar(chr, type = "bytes", allowNA = TRUE) == 1, # bad, Arrow msg + int > 2, # good + pnorm(dbl) > .99 # bad, opaque + ) %>% + collect(), + tbl + ), +'* In nchar(chr, type = "bytes", allowNA = TRUE) == 1, allowNA = TRUE not supported by Arrow +* Expression pnorm(dbl) > 0.99 not supported in Arrow +pulling data into R', fixed = TRUE ) }) diff --git a/r/tests/testthat/test-dplyr-group-by.R b/r/tests/testthat/test-dplyr-group-by.R index 6f5d5672d1948..8583c2f9024d5 100644 --- a/r/tests/testthat/test-dplyr-group-by.R +++ b/r/tests/testthat/test-dplyr-group-by.R @@ -15,6 +15,8 @@ # specific language governing permissions and limitations # under the License. 
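The fallback these filter tests pin down, in short: for a Table/RecordBatch query, an expression Arrow cannot evaluate makes abandon_ship() warn, apply whatever filters it can, and finish in R; for a file-backed Dataset the same query errors rather than materializing everything. A sketch, with tbl as in these tests:

    Table$create(tbl) %>%
      filter(int > 2, pnorm(dbl) > 0.99) %>%  # pnorm() has no Arrow binding
      collect()
    #> Warning: Expression pnorm(dbl) > 0.99 not supported in Arrow; pulling data into R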
+skip_if_not_available("dataset") + library(dplyr) library(stringr) diff --git a/r/tests/testthat/test-dplyr-mutate.R b/r/tests/testthat/test-dplyr-mutate.R index 4f202fa5958c6..98eb4983d32d3 100644 --- a/r/tests/testthat/test-dplyr-mutate.R +++ b/r/tests/testthat/test-dplyr-mutate.R @@ -15,6 +15,8 @@ # specific language governing permissions and limitations # under the License. +skip_if_not_available("dataset") + library(dplyr) library(stringr) @@ -116,6 +118,7 @@ test_that("nchar() arguments", { collect(), tbl ) + # This tests the whole abandon_ship() machinery expect_warning( expect_dplyr_equal( input %>% @@ -128,7 +131,8 @@ test_that("nchar() arguments", { collect(), tbl ), - "not supported" + 'In nchar(verses, type = "bytes", allowNA = TRUE), allowNA = TRUE not supported by Arrow; pulling data into R', + fixed = TRUE ) }) @@ -173,7 +177,6 @@ test_that("mutate with reassigning same name", { }) test_that("mutate with single value for recycling", { - skip("Not implemented (ARROW-11705") expect_dplyr_equal( input %>% select(int, padded_strings) %>% @@ -338,31 +341,31 @@ test_that("handle bad expressions", { }) }) +test_that("Can't just add a vector column with mutate()", { + expect_warning( + expect_equal( + Table$create(tbl) %>% + select(int) %>% + mutate(again = 1:10), + tibble::tibble(int = tbl$int, again = 1:10) + ), + "In again = 1:10, only values of size one are recycled; pulling data into R" + ) +}) + test_that("print a mutated table", { expect_output( Table$create(tbl) %>% select(int) %>% mutate(twice = int * 2) %>% print(), -'Table (query) +'InMemoryDataset (query) int: int32 -twice: expr +twice: double (multiply_checked(int, 2)) See $.data for the source Arrow object', - fixed = TRUE) - - # Handling non-expressions/edge cases - expect_output( - Table$create(tbl) %>% - select(int) %>% - mutate(again = 1:10) %>% - print(), -'Table (query) -int: int32 -again: expr - -See $.data for the source Arrow object', - fixed = TRUE) + fixed = TRUE + ) }) test_that("mutate and write_dataset", { diff --git a/r/tests/testthat/test-dplyr-string-functions.R b/r/tests/testthat/test-dplyr-string-functions.R index d7df83cc7a64f..fb5e6752709f3 100644 --- a/r/tests/testthat/test-dplyr-string-functions.R +++ b/r/tests/testthat/test-dplyr-string-functions.R @@ -15,6 +15,7 @@ # specific language governing permissions and limitations # under the License. 
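For reference on the nchar() behavior tested above: the binding maps type = "bytes" to Arrow's binary_length kernel and other types to utf8_length, which differ only for multibyte characters. A sketch using the compute API directly (results as comments):

    call_function("utf8_length", Array$create("café"))    # 4 (characters)
    call_function("binary_length", Array$create("café"))  # 5 (bytes; "é" is 2 bytes in UTF-8)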
+skip_if_not_available("dataset")
 skip_if_not_available("utf8proc")
 
 library(dplyr)
 
@@ -342,99 +343,73 @@ test_that("arrow_*_split_whitespace functions", {
       collect(),
     tibble(x = list(c("Foo\u00A0and", "bar"), c("baz\u2006and\u1680qux\u3000and", "quux")))
   )
-
 })
 
 test_that("errors and warnings in string splitting", {
-  df <- tibble(x = c("Foo and bar", "baz and qux and quux"))
-
   # These conditions generate an error, but abandon_ship() catches the error,
-  # issues a warning, and pulls the data into R
-  expect_warning(
-    df %>%
-      Table$create() %>%
-      mutate(x = strsplit(x, "and.*", fixed = FALSE)) %>%
-      collect(),
-    regexp = "not supported"
+  # issues a warning, and pulls the data into R (if computing on InMemoryDataset)
+  # Elsewhere we test that abandon_ship() works,
+  # so here we can just call the functions directly
+
+  x <- Expression$field_ref("x")
+  expect_error(
+    nse_funcs$strsplit(x, "and.*", fixed = FALSE),
+    'Regular expression matching in strsplit() not supported by Arrow',
+    fixed = TRUE
   )
-  expect_warning(
-    df %>%
-      Table$create() %>%
-      mutate(x = str_split(x, "and.?")) %>%
-      collect()
+  expect_error(
+    nse_funcs$str_split(x, "and.?"),
+    'Regular expression matching in str_split() not supported by Arrow',
+    fixed = TRUE
   )
-  expect_warning(
-    df %>%
-      Table$create() %>%
-      mutate(x = str_split(x, regex("and.?"), n = 2)) %>%
-      collect(),
-    regexp = "not supported"
+  expect_error(
+    nse_funcs$str_split(x, regex("and.*")),
+    'Regular expression matching in str_split() not supported by Arrow',
+    fixed = TRUE
   )
-  expect_warning(
-    df %>%
-      Table$create() %>%
-      mutate(x = str_split(x, fixed("and", ignore_case = TRUE))) %>%
-      collect(),
-    "not supported"
+  expect_error(
+    nse_funcs$str_split(x, fixed("and", ignore_case = TRUE)),
+    "Case-insensitive string splitting not supported by Arrow"
   )
-  expect_warning(
-    df %>%
-      Table$create() %>%
-      mutate(x = str_split(x, coll("and.?"))) %>%
-      collect(),
-    regexp = "not supported"
+  expect_error(
+    nse_funcs$str_split(x, coll("and.?")),
+    "Pattern modifier `coll()` not supported by Arrow",
+    fixed = TRUE
   )
-  expect_warning(
-    df %>%
-      Table$create() %>%
-      mutate(x = str_split(x, boundary(type = "word"))) %>%
-      collect(),
-    regexp = "not supported"
+  expect_error(
+    nse_funcs$str_split(x, boundary(type = "word")),
+    "Pattern modifier `boundary()` not supported by Arrow",
+    fixed = TRUE
   )
-  expect_warning(
-    df %>%
-      Table$create() %>%
-      mutate(x = str_split(x, "and", n = 0)) %>%
-      collect(),
-    regexp = "not supported"
+  expect_error(
+    nse_funcs$str_split(x, "and", n = 0),
+    "Splitting strings into zero parts not supported by Arrow"
   )
 
   # This condition generates a warning
   expect_warning(
-    df %>%
-      Table$create() %>%
-      mutate(x = str_split(x, fixed("and"), simplify = TRUE)) %>%
-      collect(),
-    "ignored"
+    nse_funcs$str_split(x, fixed("and"), simplify = TRUE),
    "Argument 'simplify = TRUE' will be ignored"
   )
-
 })
 
 test_that("errors and warnings in string detection and replacement", {
-  df <- tibble(x = c("Foo", "bar"))
+  x <- Expression$field_ref("x")
 
-  # These conditions generate an error, but abandon_ship() catches the error,
-  # issues a warning, and pulls the data into R
-  expect_warning(
-    df %>%
-      Table$create() %>%
-      filter(str_detect(x, boundary(type = "character"))) %>%
-      collect(),
-    regexp = "not implemented"
+  expect_error(
+    nse_funcs$str_detect(x, boundary(type = "character")),
+    "Pattern modifier `boundary()` not supported by Arrow",
+    fixed = TRUE
   )
-  expect_warning(
-    df %>%
-      Table$create() %>%
coll("o", locale = "en"), "ó")) %>% - collect(), - regexp = "not supported" + expect_error( + nse_funcs$str_replace_all(x, coll("o", locale = "en"), "ó"), + "Pattern modifier `coll()` not supported by Arrow", + fixed = TRUE ) # This condition generates a warning expect_warning( - df %>% - Table$create() %>% - transmute(x = str_replace_all(x, regex("o", multiline = TRUE), "u")), + nse_funcs$str_replace_all(x, regex("o", multiline = TRUE), "u"), "Ignoring pattern modifier argument not supported in Arrow: \"multiline\"" ) @@ -521,5 +496,4 @@ test_that("edge cases in string detection and replacement", { collect(), tibble(x = c("ABC")) ) - }) diff --git a/r/tests/testthat/test-dplyr.R b/r/tests/testthat/test-dplyr.R index a02b00f3d95e8..46d30e378236c 100644 --- a/r/tests/testthat/test-dplyr.R +++ b/r/tests/testthat/test-dplyr.R @@ -15,6 +15,8 @@ # specific language governing permissions and limitations # under the License. +skip_if_not_available("dataset") + library(dplyr) library(stringr) @@ -57,11 +59,11 @@ test_that("Print method", { filter(int < 5) %>% select(int, chr) %>% print(), -'RecordBatch (query) +'InMemoryDataset (query) int: int32 chr: string -* Filter: and(and(greater(dbl, 2), or(equal(chr, "d"), equal(chr, "f"))), less(int, 5)) +* Filter: (((dbl > 2) and ((chr == "d") or (chr == "f"))) and (int < 5)) See $.data for the source Arrow object', fixed = TRUE ) @@ -187,7 +189,8 @@ test_that("collect(as_data_frame=FALSE)", { filter(int > 5) %>% collect(as_data_frame = FALSE) - expect_r6_class(b2, "RecordBatch") + # collect(as_data_frame = FALSE) always returns Table now + expect_r6_class(b2, "Table") expected <- tbl[tbl$int > 5 & !is.na(tbl$int), c("int", "chr")] expect_equal(as.data.frame(b2), expected) @@ -195,7 +198,7 @@ test_that("collect(as_data_frame=FALSE)", { select(int, strng = chr) %>% filter(int > 5) %>% collect(as_data_frame = FALSE) - expect_r6_class(b3, "RecordBatch") + expect_r6_class(b3, "Table") expect_equal(as.data.frame(b3), set_names(expected, c("int", "strng"))) b4 <- batch %>% @@ -217,14 +220,14 @@ test_that("compute()", { b1 <- batch %>% compute() - expect_is(b1, "RecordBatch") + expect_r6_class(b1, "RecordBatch") b2 <- batch %>% select(int, chr) %>% filter(int > 5) %>% compute() - expect_is(b2, "RecordBatch") + expect_r6_class(b2, "Table") expected <- tbl[tbl$int > 5 & !is.na(tbl$int), c("int", "chr")] expect_equal(as.data.frame(b2), expected) @@ -232,7 +235,7 @@ test_that("compute()", { select(int, strng = chr) %>% filter(int > 5) %>% compute() - expect_is(b3, "RecordBatch") + expect_r6_class(b3, "Table") expect_equal(as.data.frame(b3), set_names(expected, c("int", "strng"))) b4 <- batch %>% @@ -240,7 +243,7 @@ test_that("compute()", { filter(int > 5) %>% group_by(int) %>% compute() - expect_is(b4, "arrow_dplyr_query") + expect_s3_class(b4, "arrow_dplyr_query") expect_equal( as.data.frame(b4), expected %>% @@ -257,7 +260,7 @@ test_that("head", { filter(int > 5) %>% head(2) - expect_r6_class(b2, "RecordBatch") + expect_r6_class(b2, "Table") expected <- tbl[tbl$int > 5 & !is.na(tbl$int), c("int", "chr")][1:2, ] expect_equal(as.data.frame(b2), expected) @@ -265,7 +268,7 @@ test_that("head", { select(int, strng = chr) %>% filter(int > 5) %>% head(2) - expect_r6_class(b3, "RecordBatch") + expect_r6_class(b3, "Table") expect_equal(as.data.frame(b3), set_names(expected, c("int", "strng"))) b4 <- batch %>% @@ -290,7 +293,7 @@ test_that("tail", { filter(int > 5) %>% tail(2) - expect_r6_class(b2, "RecordBatch") + expect_r6_class(b2, "Table") expected <- tail(tbl[tbl$int > 
5 & !is.na(tbl$int), c("int", "chr")], 2) expect_equal(as.data.frame(b2), expected) @@ -298,7 +301,7 @@ test_that("tail", { select(int, strng = chr) %>% filter(int > 5) %>% tail(2) - expect_r6_class(b3, "RecordBatch") + expect_r6_class(b3, "Table") expect_equal(as.data.frame(b3), set_names(expected, c("int", "strng"))) b4 <- batch %>% @@ -501,6 +504,7 @@ test_that("explicit type conversions with as.*()", { }) test_that("as.factor()/dictionary_encode()", { + skip("ARROW-12632: ExecuteScalarExpression cannot Execute non-scalar expression {x=dictionary_encode(x, {NON-REPRESENTABLE OPTIONS})}") df1 <- tibble(x = c("C", "D", "B", NA, "D", "B", "S", "A", "B", "Z", "B")) df2 <- tibble(x = c(5, 5, 5, NA, 2, 3, 6, 8)) diff --git a/r/tests/testthat/test-expression.R b/r/tests/testthat/test-expression.R index dd61b5e3ca26f..d0459fde5b523 100644 --- a/r/tests/testthat/test-expression.R +++ b/r/tests/testthat/test-expression.R @@ -17,34 +17,6 @@ context("Expressions") -test_that("Can create an expression", { - expect_s3_class(build_array_expression(">", Array$create(1:5), 4), "array_expression") -}) - -test_that("as.vector(array_expression)", { - expect_equal(as.vector(build_array_expression(">", Array$create(1:5), 4)), c(FALSE, FALSE, FALSE, FALSE, TRUE)) -}) - -test_that("array_expression print method", { - expect_output( - print(build_array_expression(">", Array$create(1:5), 4)), - # Not ideal but it is informative - "greater(, 4)", - fixed = TRUE - ) -}) - -test_that("array_refs", { - tab <- Table$create(a = 1:5) - ex <- build_array_expression(">", array_expression("array_ref", field_name = "a"), 4) - expect_s3_class(ex, "array_expression") - expect_identical(ex$args[[1]]$args$field_name, "a") - expect_identical(find_array_refs(ex), "a") - out <- eval_array_expression(ex, tab) - expect_r6_class(out, "ChunkedArray") - expect_equal(as.vector(out), c(FALSE, FALSE, FALSE, FALSE, TRUE)) -}) - test_that("C++ expressions", { skip_if_not_available("dataset") f <- Expression$field_ref("f") @@ -76,24 +48,14 @@ test_that("C++ expressions", { 'Expression\n(f > 4)', fixed = TRUE ) + expect_type_equal( + f$type(schema(f = float64())), + float64() + ) + expect_type_equal( + (f > 4)$type(schema(f = float64())), + bool() + ) # Interprets that as a list type expect_r6_class(f == c(1L, 2L), "Expression") -}) - -test_that("Can create an expression", { - a <- Array$create(as.numeric(1:5)) - expr <- array_expression("cast", a, options = list(to_type = int32())) - expect_s3_class(expr, "array_expression") - expect_equal(eval_array_expression(expr), Array$create(1:5)) - - b <- Array$create(0.5:4.5) - bad_expr <- array_expression("cast", b, options = list(to_type = int32())) - expect_s3_class(bad_expr, "array_expression") - expect_error( - eval_array_expression(bad_expr), - "Invalid: Float value .* was truncated converting" - ) - expr <- array_expression("cast", b, options = list(to_type = int32(), allow_float_truncate = TRUE)) - expect_s3_class(expr, "array_expression") - expect_equal(eval_array_expression(expr), Array$create(0:4)) -}) +}) \ No newline at end of file diff --git a/r/tests/testthat/test-filesystem.R b/r/tests/testthat/test-filesystem.R index 344865c077a0f..df084f35a494c 100644 --- a/r/tests/testthat/test-filesystem.R +++ b/r/tests/testthat/test-filesystem.R @@ -136,6 +136,7 @@ test_that("LocalFileSystem + Selector", { test_that("FileSystem$from_uri", { skip_on_cran() skip_if_not_available("s3") + skip_if_offline() fs_and_path <- FileSystem$from_uri("s3://ursa-labs-taxi-data") expect_r6_class(fs_and_path$fs, 
"S3FileSystem") expect_identical(fs_and_path$fs$region, "us-east-2") @@ -144,6 +145,7 @@ test_that("FileSystem$from_uri", { test_that("SubTreeFileSystem$create() with URI", { skip_on_cran() skip_if_not_available("s3") + skip_if_offline() fs <- SubTreeFileSystem$create("s3://ursa-labs-taxi-data") expect_r6_class(fs, "SubTreeFileSystem") expect_identical( @@ -155,6 +157,7 @@ test_that("SubTreeFileSystem$create() with URI", { test_that("S3FileSystem", { skip_on_cran() skip_if_not_available("s3") + skip_if_offline() s3fs <- S3FileSystem$create() expect_r6_class(s3fs, "S3FileSystem") }) @@ -162,6 +165,7 @@ test_that("S3FileSystem", { test_that("s3_bucket", { skip_on_cran() skip_if_not_available("s3") + skip_if_offline() bucket <- s3_bucket("ursa-labs-r-test") expect_r6_class(bucket, "SubTreeFileSystem") expect_r6_class(bucket$base_fs, "S3FileSystem")