ARROW-12731: [R] Use InMemoryDataset for Table/RecordBatch in dplyr code
While discussing #10166 with @bkietz, we realized that we could already evaluate filter/project on a Table/RecordBatch by wrapping it in an InMemoryDataset and using the Dataset machinery, so I wanted to see how well that worked (see the sketch after the caveats below). Mostly it works, with a couple of caveats:

* You can't dictionary_encode a dataset column. `Error: Invalid: ExecuteScalarExpression cannot Execute non-scalar expression {x=dictionary_encode(x, {NON-REPRESENTABLE OPTIONS})}` (ARROW-12632). I will remove the `as.factor` method and leave a TODO to restore it after that JIRA is resolved.
* With the existing array_expressions, you could supply an additional Array (or R data convertible to an Array) when doing `mutate()`; this is not implemented for Datasets, and that's ok. For Tables/RecordBatches, the behavior in this PR is to pull the data into R first, which is fine.
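
For illustration, here is a minimal sketch of the idea (assuming arrow and dplyr are loaded; `InMemoryDataset$create()` is the wrapper the dplyr code now uses internally):

```r
library(arrow)
library(dplyr)

tab <- Table$create(x = 1:5, y = c("a", "b", "c", "d", "e"))

# What the dplyr verbs now do under the hood (sketch): wrap the Table in an
# InMemoryDataset so filter/project run through the Dataset/Scanner machinery
ds <- InMemoryDataset$create(tab)

# User-facing behavior is unchanged: dplyr verbs on a Table still work,
# they just evaluate as Dataset Expressions now
tab %>%
  filter(x > 2) %>%
  mutate(z = x * 2L) %>%
  collect()
```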

There are a lot of changes here, so the diff is big, but I've tried to group the main changes into distinct commits. Highlights:

* 5b501c5 is the main switch to use InMemoryDataset
* b31fb5e deletes `array_expression`
* 0d31938 simplifies the interface for adding functions to the dplyr data_mask; definitely check this one out and see what you think of the new approach. I hope it makes adding new functions much simpler.
* 2e6374f improves the print method for queries by showing both the expression and the expected type of the output column, per suggestion from @bkietz
* d12f584 just splits up dplyr.R into many files; 34dc1e6 deletes tests that are duplicated between test-dplyr*.R and test-dataset.R (since they're now going through a common C++ interface).
* a0914f6 + eee491a contain ARROW-12696

Closes #10191 from nealrichardson/dplyr-in-memory

Authored-by: Neal Richardson <[email protected]>
Signed-off-by: Neal Richardson <[email protected]>
nealrichardson committed May 13, 2021
1 parent b34c8f6 commit 9347731
Showing 37 changed files with 1,360 additions and 1,708 deletions.
11 changes: 10 additions & 1 deletion r/DESCRIPTION
@@ -78,9 +78,18 @@ Collate:
'dataset-write.R'
'deprecated.R'
'dictionary.R'
'dplyr-arrange.R'
'dplyr-collect.R'
'dplyr-eval.R'
'dplyr-filter.R'
'expression.R'
'dplyr-functions.R'
'dplyr-group-by.R'
'dplyr-mutate.R'
'dplyr-select.R'
'dplyr-summarize.R'
'record-batch.R'
'table.R'
'expression.R'
'dplyr.R'
'feather.R'
'field.R'
5 changes: 1 addition & 4 deletions r/NAMESPACE
@@ -21,7 +21,6 @@ S3method("[[<-",Schema)
S3method("names<-",ArrowTabular)
S3method(Ops,ArrowDatum)
S3method(Ops,Expression)
S3method(Ops,array_expression)
S3method(all,ArrowDatum)
S3method(all,equal.ArrowObject)
S3method(any,ArrowDatum)
@@ -37,7 +36,6 @@ S3method(as.list,ArrowTabular)
S3method(as.list,Schema)
S3method(as.raw,Buffer)
S3method(as.vector,ArrowDatum)
S3method(as.vector,array_expression)
S3method(c,Dataset)
S3method(dim,ArrowTabular)
S3method(dim,Dataset)
@@ -51,7 +49,6 @@ S3method(head,arrow_dplyr_query)
S3method(is.na,ArrowDatum)
S3method(is.na,Expression)
S3method(is.na,Scalar)
S3method(is.na,array_expression)
S3method(is.nan,ArrowDatum)
S3method(is_in,ArrowDatum)
S3method(is_in,default)
@@ -80,7 +77,6 @@ S3method(names,StructArray)
S3method(names,Table)
S3method(names,arrow_dplyr_query)
S3method(print,"arrow-enum")
S3method(print,array_expression)
S3method(print,arrow_dplyr_query)
S3method(print,arrow_info)
S3method(print,arrow_r_metadata)
@@ -295,6 +291,7 @@ importFrom(purrr,as_mapper)
importFrom(purrr,keep)
importFrom(purrr,map)
importFrom(purrr,map2)
importFrom(purrr,map2_chr)
importFrom(purrr,map_chr)
importFrom(purrr,map_dfr)
importFrom(purrr,map_int)
71 changes: 67 additions & 4 deletions r/R/arrow-datum.R
@@ -46,6 +46,73 @@ as.vector.ArrowDatum <- function(x, mode) {
)
}

#' @export
Ops.ArrowDatum <- function(e1, e2) {
if (.Generic == "!") {
eval_array_expression(.Generic, e1)
} else if (.Generic %in% names(.array_function_map)) {
eval_array_expression(.Generic, e1, e2)
} else {
stop(paste0("Unsupported operation on `", class(e1)[1L], "` : "), .Generic, call. = FALSE)
}
}

# Wrapper around call_function that:
# (1) maps R function names to Arrow C++ compute ("/" --> "divide_checked")
# (2) wraps R input args as Array or Scalar
eval_array_expression <- function(FUN,
...,
args = list(...),
options = empty_named_list()) {
if (FUN == "-" && length(args) == 1L) {
if (inherits(args[[1]], "ArrowObject")) {
return(eval_array_expression("negate_checked", args[[1]]))
} else {
return(-args[[1]])
}
}
args <- lapply(args, .wrap_arrow, FUN)

# In Arrow, "divide" is one function, which does integer division on
# integer inputs and floating-point division on floats
if (FUN == "/") {
# TODO: omg so many ways it's wrong to assume these types
args <- map(args, ~.$cast(float64()))
} else if (FUN == "%/%") {
# In R, integer division works like floor(float division)
out <- eval_array_expression("/", args = args, options = options)
return(out$cast(int32(), allow_float_truncate = TRUE))
} else if (FUN == "%%") {
# {e1 - e2 * ( e1 %/% e2 )}
# ^^^ form doesn't work because Ops.Array evaluates eagerly,
# but we can build that up
quotient <- eval_array_expression("%/%", args = args)
base <- eval_array_expression("*", quotient, args[[2]])
# this cast is to ensure that the result of this and e1 are the same
# (autocasting only applies to scalars)
base <- base$cast(args[[1]]$type)
return(eval_array_expression("-", args[[1]], base))
}

call_function(
.array_function_map[[FUN]] %||% FUN,
args = args,
options = options
)
}

.wrap_arrow <- function(arg, fun) {
if (!inherits(arg, "ArrowObject")) {
# TODO: Array$create if lengths are equal?
if (fun == "%in%") {
arg <- Array$create(arg)
} else {
arg <- Scalar$create(arg)
}
}
arg
}

#' @export
na.omit.ArrowDatum <- function(object, ...){
object$Filter(!is.na(object))
@@ -66,10 +133,6 @@ filter_rows <- function(x, i, keep_na = TRUE, ...) {
# General purpose function for [ row subsetting with R semantics
# Based on the input for `i`, calls x$Filter, x$Slice, or x$Take
nrows <- x$num_rows %||% x$length() # Depends on whether Array or Table-like
if (inherits(i, "array_expression")) {
# Evaluate it
i <- eval_array_expression(i)
}
if (is.logical(i)) {
if (isTRUE(i)) {
# Shortcut without doing any work
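
A usage sketch of the operator dispatch shown above (results assume an Arrow build with the compute kernels named in `.array_function_map`, e.g. `divide_checked`):

```r
library(arrow)

a <- Array$create(c(7L, 8L, 9L))

# Ops.ArrowDatum routes these through eval_array_expression()/call_function():
a / 2    # both inputs cast to float64, then "divide_checked" -> 3.5, 4, 4.5
a %/% 2  # float division, then truncating cast to int32      -> 3, 4, 4
a %% 2   # built up as a - 2 * (a %/% 2)                       -> 1, 0, 1
a > 8L   # comparison kernel                                   -> FALSE, FALSE, TRUE
```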
16 changes: 13 additions & 3 deletions r/R/arrow-package.R
@@ -17,7 +17,7 @@

#' @importFrom stats quantile median na.omit na.exclude na.pass na.fail
#' @importFrom R6 R6Class
#' @importFrom purrr as_mapper map map2 map_chr map_dfr map_int map_lgl keep
#' @importFrom purrr as_mapper map map2 map_chr map2_chr map_dfr map_int map_lgl keep
#' @importFrom assertthat assert_that is.string
#' @importFrom rlang list2 %||% is_false abort dots_n warn enquo quo_is_null enquos is_integerish quos eval_tidy new_data_mask syms env new_environment env_bind as_label set_names exec is_bare_character quo_get_expr quo_set_expr .data seq2 is_quosure enexpr enexprs expr
#' @importFrom tidyselect vars_pull vars_rename vars_select eval_select
@@ -49,8 +49,18 @@

# Create these once, at package build time
if (arrow_available()) {
dplyr_functions$dataset <- build_function_list(build_dataset_expression)
dplyr_functions$array <- build_function_list(build_array_expression)
# Also include all available Arrow Compute functions,
# namespaced as arrow_fun.
# We can't do this at install time because list_compute_functions() may error
all_arrow_funs <- list_compute_functions()
arrow_funcs <- set_names(
lapply(all_arrow_funs, function(fun) {
force(fun)
function(...) build_expr(fun, ...)
}),
paste0("arrow_", all_arrow_funs)
)
.cache$functions <- c(nse_funcs, arrow_funcs)
}
invisible()
}
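
A hedged sketch of what this registration enables: every compute function returned by `list_compute_functions()` should become callable inside dplyr verbs under an `arrow_` prefix. The specific kernel used below (`ascii_upper`) is an example and depends on the Arrow build:

```r
library(arrow)
library(dplyr)

# Kernels available at load time, each exposed to the data mask as arrow_<name>
grep("^ascii", list_compute_functions(), value = TRUE)

# A kernel without a dedicated R binding can still be used directly:
Table$create(x = c("a", "b", "c")) %>%
  mutate(upper = arrow_ascii_upper(x)) %>%
  collect()
```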
6 changes: 3 additions & 3 deletions r/R/arrow-tabular.R
@@ -223,13 +223,13 @@ na.fail.ArrowTabular <- function(object, ...){

#' @export
na.omit.ArrowTabular <- function(object, ...){
not_na <- map(object$columns, ~build_array_expression("is_valid", .x))
not_na <- map(object$columns, ~call_function("is_valid", .x))
not_na_agg <- Reduce("&", not_na)
object$Filter(eval_array_expression(not_na_agg))
object$Filter(not_na_agg)
}

#' @export
na.exclude.ArrowTabular <- na.omit.ArrowTabular
na.exclude.ArrowTabular <- na.omit.ArrowTabular

ToString_tabular <- function(x, ...) {
# Generic to work with both RecordBatch and Table
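
A small usage sketch of the rewritten `na.omit()` method (it now builds the per-column validity masks eagerly with `call_function()` instead of deferring them as array_expressions):

```r
library(arrow)

tab <- Table$create(x = c(1, NA, 3), y = c("a", "b", NA))

# Drops every row that has a null in any column: only row 1 survives here
na.omit(tab)
```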
4 changes: 4 additions & 0 deletions r/R/arrowExports.R

Some generated files are not rendered by default.

14 changes: 5 additions & 9 deletions r/R/dataset-scan.R
@@ -162,16 +162,12 @@ ScannerBuilder <- R6Class("ScannerBuilder", inherit = ArrowObject,
# cols is either a character vector or a named list of Expressions
if (is.character(cols)) {
dataset___ScannerBuilder__ProjectNames(self, cols)
} else if (length(cols) == 0) {
# Empty projection
dataset___ScannerBuilder__ProjectNames(self, character(0))
} else {
# If we have expressions, but they all turn out to be field_refs,
# we can still call the simple method
field_names <- get_field_names(cols)
if (all(nzchar(field_names))) {
dataset___ScannerBuilder__ProjectNames(self, field_names)
} else {
# Else, we are projecting/mutating
dataset___ScannerBuilder__ProjectExprs(self, cols, names(cols))
}
# List of Expressions
dataset___ScannerBuilder__ProjectExprs(self, cols, names(cols))
}
self
},
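
Illustrative only: after this change, projections passed as plain column names still use ProjectNames, while dplyr queries (whose selected and mutated columns are Expressions) go through ProjectExprs. For example:

```r
library(arrow)
library(dplyr)

ds <- InMemoryDataset$create(Table$create(x = 1:3, y = 4:6))

# Both of these build named lists of Expressions for the ScannerBuilder:
ds %>% select(x, y) %>% collect()                    # field references only
ds %>% mutate(z = x + y) %>% select(z) %>% collect() # computed column
```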
8 changes: 0 additions & 8 deletions r/R/dataset-write.R
@@ -64,14 +64,6 @@ write_dataset <- function(dataset,
...) {
format <- match.arg(format)
if (inherits(dataset, "arrow_dplyr_query")) {
if (inherits(dataset$.data, "ArrowTabular")) {
# collect() to materialize any mutate/rename
dataset <- dplyr::collect(dataset, as_data_frame = FALSE)
}
# We can select a subset of columns but we can't rename them
if (!all(get_field_names(dataset) == names(dataset$selected_columns))) {
stop("Renaming columns when writing a dataset is not yet supported", call. = FALSE)
}
# partitioning vars need to be in the `select` schema
dataset <- ensure_group_vars(dataset)
} else if (inherits(dataset, "grouped_df")) {
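
For context, a hedged example of the path that remains: when writing a grouped query, ensure_group_vars() keeps the partitioning variables in the projection, so the group_by() columns become partition keys (temporary path and Parquet support assumed):

```r
library(arrow)
library(dplyr)

tab <- Table$create(x = 1:4, grp = c("a", "a", "b", "b"))
path <- tempfile()

tab %>%
  group_by(grp) %>%
  write_dataset(path, format = "parquet")

# Hive-style partition directories, e.g. grp=a/..., grp=b/...
list.files(path, recursive = TRUE)
```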
93 changes: 93 additions & 0 deletions r/R/dplyr-arrange.R
@@ -0,0 +1,93 @@
# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements. See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership. The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.


# The following S3 methods are registered on load if dplyr is present

arrange.arrow_dplyr_query <- function(.data, ..., .by_group = FALSE) {
call <- match.call()
exprs <- quos(...)
if (.by_group) {
# when the data is grouped and .by_group is TRUE, order the result by
# the grouping columns first
exprs <- c(quos(!!!dplyr::groups(.data)), exprs)
}
if (length(exprs) == 0) {
# Nothing to do
return(.data)
}
.data <- arrow_dplyr_query(.data)
# find and remove any dplyr::desc() and tidy-eval
# the arrange expressions inside an Arrow data_mask
sorts <- vector("list", length(exprs))
descs <- logical(0)
mask <- arrow_mask(.data)
for (i in seq_along(exprs)) {
x <- find_and_remove_desc(exprs[[i]])
exprs[[i]] <- x[["quos"]]
sorts[[i]] <- arrow_eval(exprs[[i]], mask)
if (inherits(sorts[[i]], "try-error")) {
msg <- paste('Expression', as_label(exprs[[i]]), 'not supported in Arrow')
return(abandon_ship(call, .data, msg))
}
names(sorts)[i] <- as_label(exprs[[i]])
descs[i] <- x[["desc"]]
}
.data$arrange_vars <- c(sorts, .data$arrange_vars)
.data$arrange_desc <- c(descs, .data$arrange_desc)
.data
}
arrange.Dataset <- arrange.ArrowTabular <- arrange.arrow_dplyr_query

# Helper to handle desc() in arrange()
# * Takes a quosure as input
# * Returns a list with two elements:
# 1. The quosure with any wrapping parentheses and desc() removed
# 2. A logical value indicating whether desc() was found
# * Performs some other validation
find_and_remove_desc <- function(quosure) {
expr <- quo_get_expr(quosure)
descending <- FALSE
if (length(all.vars(expr)) < 1L) {
stop(
"Expression in arrange() does not contain any field names: ",
deparse(expr),
call. = FALSE
)
}
# Use a while loop to remove any number of nested pairs of enclosing
# parentheses and any number of nested desc() calls. In the case of multiple
# nested desc() calls, each one toggles the sort order.
while (identical(typeof(expr), "language") && is.call(expr)) {
if (identical(expr[[1]], quote(`(`))) {
# remove enclosing parentheses
expr <- expr[[2]]
} else if (identical(expr[[1]], quote(desc))) {
# remove desc() and toggle descending
expr <- expr[[2]]
descending <- !descending
} else {
break
}
}
return(
list(
quos = quo_set_expr(quosure, expr),
desc = descending
)
)
}
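
A usage sketch of the arrange() support added here; as find_and_remove_desc() implements, nested desc() calls toggle the sort direction:

```r
library(arrow)
library(dplyr)

tab <- Table$create(x = c(2, 1, 3))

tab %>% arrange(desc(x)) %>% collect()        # 3, 2, 1
tab %>% arrange(desc(desc(x))) %>% collect()  # toggles back to ascending: 1, 2, 3
```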
62 changes: 62 additions & 0 deletions r/R/dplyr-collect.R
@@ -0,0 +1,62 @@
# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements. See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership. The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.


# The following S3 methods are registered on load if dplyr is present

collect.arrow_dplyr_query <- function(x, as_data_frame = TRUE, ...) {
x <- ensure_group_vars(x)
x <- ensure_arrange_vars(x) # this sets x$temp_columns
# Pull only the selected rows and cols into R
# See dataset.R for Dataset and Scanner(Builder) classes
tab <- Scanner$create(x)$ToTable()
# Arrange rows
if (length(x$arrange_vars) > 0) {
tab <- tab[
tab$SortIndices(names(x$arrange_vars), x$arrange_desc),
names(x$selected_columns), # this omits x$temp_columns from the result
drop = FALSE
]
}
if (as_data_frame) {
df <- as.data.frame(tab)
tab$invalidate()
restore_dplyr_features(df, x)
} else {
restore_dplyr_features(tab, x)
}
}
collect.ArrowTabular <- function(x, as_data_frame = TRUE, ...) {
if (as_data_frame) {
as.data.frame(x, ...)
} else {
x
}
}
collect.Dataset <- function(x, ...) dplyr::collect(arrow_dplyr_query(x), ...)

compute.arrow_dplyr_query <- function(x, ...) dplyr::collect(x, as_data_frame = FALSE)
compute.ArrowTabular <- function(x, ...) x
compute.Dataset <- compute.arrow_dplyr_query

pull.arrow_dplyr_query <- function(.data, var = -1) {
.data <- arrow_dplyr_query(.data)
var <- vars_pull(names(.data), !!enquo(var))
.data$selected_columns <- set_names(.data$selected_columns[var], var)
dplyr::collect(.data)[[1]]
}
pull.Dataset <- pull.ArrowTabular <- pull.arrow_dplyr_query
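
A usage sketch of the three verbs defined above (collect() materializes to a data frame, compute() keeps the result in Arrow, pull() extracts one column as an R vector):

```r
library(arrow)
library(dplyr)

tab <- Table$create(x = 1:3, y = c("a", "b", "c"))
q <- tab %>% filter(x > 1) %>% select(y)

collect(q)   # evaluates the query, returns a tibble
compute(q)   # evaluates the query, returns an Arrow Table
pull(q, y)   # collects just the "y" column as an R character vector
```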