Use InMemoryDataset for Table/RecordBatch in dplyr code
nealrichardson committed May 8, 2021
1 parent e2e7732 commit 5b501c5
Showing 4 changed files with 37 additions and 122 deletions.
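
In broad strokes: a dplyr pipeline on an in-memory Table or RecordBatch is now wrapped in an InMemoryDataset, so it goes through the same Expression/Scanner machinery as queries on file-backed Datasets. A rough usage sketch based on the updated tests below (the example columns here are invented):

library(arrow)
library(dplyr)

tab <- Table$create(int = 1:10, chr = letters[1:10])

q <- tab %>%
  select(int, chr) %>%
  filter(int > 5)

print(q)                           # now prints "InMemoryDataset (query)"
collect(q, as_data_frame = FALSE)  # evaluated with a Scanner; returns a Table
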
8 changes: 0 additions & 8 deletions r/R/dataset-write.R
@@ -64,14 +64,6 @@ write_dataset <- function(dataset,
...) {
format <- match.arg(format)
if (inherits(dataset, "arrow_dplyr_query")) {
if (inherits(dataset$.data, "ArrowTabular")) {
# collect() to materialize any mutate/rename
dataset <- dplyr::collect(dataset, as_data_frame = FALSE)
}
# We can select a subset of columns but we can't rename them
if (!all(get_field_names(dataset) == names(dataset$selected_columns))) {
stop("Renaming columns when writing a dataset is not yet supported", call. = FALSE)
}
# partitioning vars need to be in the `select` schema
dataset <- ensure_group_vars(dataset)
} else if (inherits(dataset, "grouped_df")) {
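
For context, write_dataset() continues to accept an arrow_dplyr_query built from an in-memory table. A hedged sketch of one calling pattern (the path and columns are invented, and partitioning-by-groups is an assumption about the defaults):

Table$create(int = 1:10, part = rep(1:2, 5)) %>%
  group_by(part) %>%
  write_dataset("/tmp/example_dataset", format = "parquet")  # groups are added back to the selection via ensure_group_vars()
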
116 changes: 18 additions & 98 deletions r/R/dplyr.R
@@ -30,14 +30,17 @@ arrow_dplyr_query <- function(.data) {
if (inherits(.data, "arrow_dplyr_query")) {
return(.data)
}
if (!inherits(.data, "Dataset")) {
.data <- InMemoryDataset$create(.data)
}
structure(
list(
.data = .data$clone(),
# selected_columns is a named list:
# * contents are references/expressions pointing to the data
# * names are the names they should be in the end (i.e. this
# records any renaming)
selected_columns = make_field_refs(names(.data), dataset = inherits(.data, "Dataset")),
selected_columns = make_field_refs(names(.data)),
# filtered_rows will be an Expression
filtered_rows = TRUE,
# group_by_vars is a character vector of columns (as renamed)
@@ -76,22 +79,14 @@ print.arrow_dplyr_query <- function(x, ...) {
cat(fields, "\n", sep = "")
cat("\n")
if (!isTRUE(x$filtered_rows)) {
if (query_on_dataset(x)) {
filter_string <- x$filtered_rows$ToString()
} else {
filter_string <- .format_array_expression(x$filtered_rows)
}
filter_string <- x$filtered_rows$ToString()
cat("* Filter: ", filter_string, "\n", sep = "")
}
if (length(x$group_by_vars)) {
cat("* Grouped by ", paste(x$group_by_vars, collapse = ", "), "\n", sep = "")
}
if (length(x$arrange_vars)) {
if (query_on_dataset(x)) {
arrange_strings <- map_chr(x$arrange_vars, function(x) x$ToString())
} else {
arrange_strings <- map_chr(x$arrange_vars, .format_array_expression)
}
arrange_strings <- map_chr(x$arrange_vars, function(x) x$ToString())
cat(
"* Sorted by ",
paste(
@@ -127,12 +122,8 @@ get_field_names <- function(selected_cols) {
})
}

make_field_refs <- function(field_names, dataset = TRUE) {
if (dataset) {
out <- lapply(field_names, Expression$field_ref)
} else {
out <- lapply(field_names, function(x) array_expression("array_ref", field_name = x))
}
make_field_refs <- function(field_names) {
out <- lapply(field_names, Expression$field_ref)
set_names(out, field_names)
}

@@ -146,12 +137,8 @@ dim.arrow_dplyr_query <- function(x) {

if (isTRUE(x$filtered)) {
rows <- x$.data$num_rows
} else if (query_on_dataset(x)) {
scanner <- Scanner$create(x)
rows <- scanner$CountRows()
} else {
# Evaluate the filter expression to a BooleanArray and count
rows <- as.integer(sum(eval_array_expression(x$filtered_rows, x$.data), na.rm = TRUE))
rows <- Scanner$create(x)$CountRows()
}
c(rows, cols)
}
@@ -162,46 +149,13 @@ as.data.frame.arrow_dplyr_query <- function(x, row.names = NULL, optional = FALS
}

#' @export
head.arrow_dplyr_query <- function(x, n = 6L, ...) {
if (query_on_dataset(x)) {
head.Dataset(x, n, ...)
} else {
out <- collect.arrow_dplyr_query(x, as_data_frame = FALSE)
if (inherits(out, "arrow_dplyr_query")) {
out$.data <- head(out$.data, n)
} else {
out <- head(out, n)
}
out
}
}
head.arrow_dplyr_query <- head.Dataset

#' @export
tail.arrow_dplyr_query <- function(x, n = 6L, ...) {
if (query_on_dataset(x)) {
tail.Dataset(x, n, ...)
} else {
out <- collect.arrow_dplyr_query(x, as_data_frame = FALSE)
if (inherits(out, "arrow_dplyr_query")) {
out$.data <- tail(out$.data, n)
} else {
out <- tail(out, n)
}
out
}
}
tail.arrow_dplyr_query <- tail.Dataset

#' @export
`[.arrow_dplyr_query` <- function(x, i, j, ..., drop = FALSE) {
if (query_on_dataset(x)) {
`[.Dataset`(x, i, j, ..., drop = FALSE)
} else {
stop(
"[ method not implemented for queries. Call 'collect(x, as_data_frame = FALSE)' first",
call. = FALSE
)
}
}
`[.arrow_dplyr_query` <- `[.Dataset`

# The following S3 methods are registered on load if dplyr is present
tbl_vars.arrow_dplyr_query <- function(x) names(x$selected_columns)
@@ -686,11 +640,7 @@ init_env()

# Create a data mask for evaluating a dplyr expression
arrow_mask <- function(.data) {
if (query_on_dataset(.data)) {
f_env <- new_environment(dplyr_functions$dataset)
} else {
f_env <- new_environment(dplyr_functions$array)
}
f_env <- new_environment(dplyr_functions$dataset)

# Add functions that need to error hard and clear.
# Some R functions will still try to evaluate on an Expression
@@ -730,36 +680,8 @@ collect.arrow_dplyr_query <- function(x, as_data_frame = TRUE, ...) {
x <- ensure_group_vars(x)
x <- ensure_arrange_vars(x) # this sets x$temp_columns
# Pull only the selected rows and cols into R
if (query_on_dataset(x)) {
# See dataset.R for Dataset and Scanner(Builder) classes
tab <- Scanner$create(x)$ToTable()
} else {
# This is a Table or RecordBatch

# Filter and select the data referenced in selected columns
if (isTRUE(x$filtered_rows)) {
filter <- TRUE
} else {
filter <- eval_array_expression(x$filtered_rows, x$.data)
}
# TODO: shortcut if identical(names(x$.data), find_array_refs(c(x$selected_columns, x$temp_columns)))?
tab <- x$.data[
filter,
find_array_refs(c(x$selected_columns, x$temp_columns)),
keep_na = FALSE
]
# Now evaluate those expressions on the filtered table
cols <- lapply(c(x$selected_columns, x$temp_columns), eval_array_expression, data = tab)
if (length(cols) == 0) {
tab <- tab[, integer(0)]
} else {
if (inherits(x$.data, "Table")) {
tab <- Table$create(!!!cols)
} else {
tab <- RecordBatch$create(!!!cols)
}
}
}
# See dataset.R for Dataset and Scanner(Builder) classes
tab <- Scanner$create(x)$ToTable()
# Arrange rows
if (length(x$arrange_vars) > 0) {
tab <- tab[
@@ -797,7 +719,7 @@ ensure_group_vars <- function(x) {
# Add them back
x$selected_columns <- c(
x$selected_columns,
make_field_refs(gv, dataset = query_on_dataset(.data))
make_field_refs(gv)
)
}
}
@@ -992,7 +914,6 @@ mutate.arrow_dplyr_query <- function(.data,
# Deparse and take the first element in case they're long expressions
names(exprs)[unnamed] <- map_chr(exprs[unnamed], as_label)

is_dataset <- query_on_dataset(.data)
mask <- arrow_mask(.data)
results <- list()
for (i in seq_along(exprs)) {
@@ -1003,8 +924,7 @@
if (inherits(results[[new_var]], "try-error")) {
msg <- paste('Expression', as_label(exprs[[i]]), 'not supported in Arrow')
return(abandon_ship(call, .data, msg))
} else if (is_dataset &&
!inherits(results[[new_var]], "Expression") &&
} else if (!inherits(results[[new_var]], "Expression") &&
!is.null(results[[new_var]])) {
# We need some wrapping to handle literal values
if (length(results[[new_var]]) != 1) {
@@ -1153,7 +1073,7 @@ find_and_remove_desc <- function(quosure) {
)
}

query_on_dataset <- function(x) inherits(x$.data, "Dataset")
query_on_dataset <- function(x) inherits(x$.data, "Dataset") && !inherits(x$.data, "InMemoryDataset")

not_implemented_for_dataset <- function(method) {
stop(
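
To summarize the simplification in this file: arrow_dplyr_query() now wraps any non-Dataset input in an InMemoryDataset, so the array_expression code paths disappear and evaluation funnels through Scanner. A minimal sketch using only the classes and calls that appear in the diff (the example data is invented):

tab <- Table$create(int = 1:10, dbl = as.numeric(1:10))

q <- tab %>%
  filter(int > 5) %>%
  select(int)

# arrow_dplyr_query() stored InMemoryDataset$create(tab) in q$.data, the filter
# is an Expression, and collect(q, as_data_frame = FALSE) reduces to:
Scanner$create(q)$ToTable()
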
5 changes: 3 additions & 2 deletions r/tests/testthat/test-dplyr-mutate.R
@@ -344,20 +344,21 @@ test_that("print a mutated table", {
select(int) %>%
mutate(twice = int * 2) %>%
print(),
'Table (query)
'InMemoryDataset (query)
int: int32
twice: expr
See $.data for the source Arrow object',
fixed = TRUE)

# Handling non-expressions/edge cases
skip("InMemoryDataset$Project() doesn't accept array (or could it?)")
expect_output(
Table$create(tbl) %>%
select(int) %>%
mutate(again = 1:10) %>%
print(),
'Table (query)
'InMemoryDataset (query)
int: int32
again: expr
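
The skipped test documents a known gap after this change: mutate() with a full-length literal R vector cannot be projected because, per the skip message, InMemoryDataset$Project() does not accept a materialized array. A small sketch of the case that continues to work (fixture shape assumed):

Table$create(int = 1:10) %>%
  select(int) %>%
  mutate(twice = int * 2) %>%  # an Arrow Expression on a column: supported
  collect()
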
30 changes: 16 additions & 14 deletions r/tests/testthat/test-dplyr.R
@@ -57,11 +57,11 @@ test_that("Print method", {
filter(int < 5) %>%
select(int, chr) %>%
print(),
'RecordBatch (query)
'InMemoryDataset (query)
int: int32
chr: string
* Filter: and(and(greater(dbl, 2), or(equal(chr, "d"), equal(chr, "f"))), less(int, 5))
* Filter: (((dbl > 2) and ((chr == "d") or (chr == "f"))) and (int < 5))
See $.data for the source Arrow object',
fixed = TRUE
)
@@ -187,15 +187,16 @@ test_that("collect(as_data_frame=FALSE)", {
filter(int > 5) %>%
collect(as_data_frame = FALSE)

expect_r6_class(b2, "RecordBatch")
# collect(as_data_frame = FALSE) always returns Table now
expect_r6_class(b2, "Table")
expected <- tbl[tbl$int > 5 & !is.na(tbl$int), c("int", "chr")]
expect_equal(as.data.frame(b2), expected)

b3 <- batch %>%
select(int, strng = chr) %>%
filter(int > 5) %>%
collect(as_data_frame = FALSE)
expect_r6_class(b3, "RecordBatch")
expect_r6_class(b3, "Table")
expect_equal(as.data.frame(b3), set_names(expected, c("int", "strng")))

b4 <- batch %>%
@@ -217,30 +218,30 @@ test_that("compute()", {

b1 <- batch %>% compute()

expect_is(b1, "RecordBatch")
expect_r6_class(b1, "RecordBatch")

b2 <- batch %>%
select(int, chr) %>%
filter(int > 5) %>%
compute()

expect_is(b2, "RecordBatch")
expect_r6_class(b2, "Table")
expected <- tbl[tbl$int > 5 & !is.na(tbl$int), c("int", "chr")]
expect_equal(as.data.frame(b2), expected)

b3 <- batch %>%
select(int, strng = chr) %>%
filter(int > 5) %>%
compute()
expect_is(b3, "RecordBatch")
expect_r6_class(b3, "Table")
expect_equal(as.data.frame(b3), set_names(expected, c("int", "strng")))

b4 <- batch %>%
select(int, strng = chr) %>%
filter(int > 5) %>%
group_by(int) %>%
compute()
expect_is(b4, "arrow_dplyr_query")
expect_s3_class(b4, "arrow_dplyr_query")
expect_equal(
as.data.frame(b4),
expected %>%
@@ -257,23 +258,24 @@ test_that("head", {
filter(int > 5) %>%
head(2)

expect_r6_class(b2, "RecordBatch")
expect_r6_class(b2, "Table")
expected <- tbl[tbl$int > 5 & !is.na(tbl$int), c("int", "chr")][1:2, ]
expect_equal(as.data.frame(b2), expected)

b3 <- batch %>%
select(int, strng = chr) %>%
filter(int > 5) %>%
head(2)
expect_r6_class(b3, "RecordBatch")
expect_r6_class(b3, "Table")
print(as.data.frame(b3))
expect_equal(as.data.frame(b3), set_names(expected, c("int", "strng")))

b4 <- batch %>%
select(int, strng = chr) %>%
filter(int > 5) %>%
group_by(int) %>%
head(2)
expect_s3_class(b4, "arrow_dplyr_query")
expect_s3_class(b4, "Table")
expect_equal(
as.data.frame(b4),
expected %>%
@@ -290,23 +292,23 @@ test_that("tail", {
filter(int > 5) %>%
tail(2)

expect_r6_class(b2, "RecordBatch")
expect_r6_class(b2, "Table")
expected <- tail(tbl[tbl$int > 5 & !is.na(tbl$int), c("int", "chr")], 2)
expect_equal(as.data.frame(b2), expected)

b3 <- batch %>%
select(int, strng = chr) %>%
filter(int > 5) %>%
tail(2)
expect_r6_class(b3, "RecordBatch")
expect_r6_class(b3, "Table")
expect_equal(as.data.frame(b3), set_names(expected, c("int", "strng")))

b4 <- batch %>%
select(int, strng = chr) %>%
filter(int > 5) %>%
group_by(int) %>%
tail(2)
expect_s3_class(b4, "arrow_dplyr_query")
expect_s3_class(b4, "Table")
expect_equal(
as.data.frame(b4),
expected %>%
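
The class expectations updated throughout this file capture the user-visible consequence: once a RecordBatch query passes through the Scanner, materializing it yields a Table rather than a RecordBatch. A brief sketch (example data invented):

batch <- RecordBatch$create(int = 1:10, chr = letters[1:10])

batch %>% compute()                      # no query steps: still a RecordBatch
batch %>% filter(int > 5) %>% compute()  # evaluated via Scanner: now a Table
batch %>% filter(int > 5) %>% head(2)    # head()/tail() likewise return a Table
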
