Skip to content

Commit

Permalink
Fix some failures, find another?
Browse files Browse the repository at this point in the history
  • Loading branch information
nealrichardson committed Apr 30, 2021
1 parent 95ddc57 commit 9640cff
Show file tree
Hide file tree
Showing 4 changed files with 22 additions and 23 deletions.
14 changes: 5 additions & 9 deletions r/R/dataset-scan.R
Original file line number Diff line number Diff line change
Expand Up @@ -161,16 +161,12 @@ ScannerBuilder <- R6Class("ScannerBuilder", inherit = ArrowObject,
# cols is either a character vector or a named list of Expressions
if (is.character(cols)) {
dataset___ScannerBuilder__ProjectNames(self, cols)
} else if (length(cols) == 0) {
# Empty projection
dataset___ScannerBuilder__ProjectNames(self, character(0))
} else {
# If we have expressions, but they all turn out to be field_refs,
# we can still call the simple method
field_names <- get_field_names(cols)
if (all(nzchar(field_names))) {
dataset___ScannerBuilder__ProjectNames(self, field_names)
} else {
# Else, we are projecting/mutating
dataset___ScannerBuilder__ProjectExprs(self, cols, names(cols))
}
# List of Expressions
dataset___ScannerBuilder__ProjectExprs(self, cols, names(cols))
}
self
},
Expand Down
11 changes: 9 additions & 2 deletions r/R/dplyr.R
Original file line number Diff line number Diff line change
Expand Up @@ -151,13 +151,20 @@ as.data.frame.arrow_dplyr_query <- function(x, row.names = NULL, optional = FALS
}

#' @export
head.arrow_dplyr_query <- head.Dataset
head.arrow_dplyr_query <- function(x, n = 6L, ...) {
out <- head.Dataset(x, n, ...)
restore_dplyr_features(out, x)
}

#' @export
tail.arrow_dplyr_query <- tail.Dataset
tail.arrow_dplyr_query <- function(x, n = 6L, ...) {
out <- tail.Dataset(x, n, ...)
restore_dplyr_features(out, x)
}

#' @export
`[.arrow_dplyr_query` <- `[.Dataset`
# TODO: ^ should also probably restore_dplyr_features, and/or that should be moved down

# The following S3 methods are registered on load if dplyr is present
tbl_vars.arrow_dplyr_query <- function(x) names(x$selected_columns)
Expand Down
15 changes: 6 additions & 9 deletions r/tests/testthat/test-dataset.R
Original file line number Diff line number Diff line change
Expand Up @@ -1197,8 +1197,8 @@ test_that("compute()/collect(as_data_frame=FALSE)", {
# the group_by() prevents compute() from returning a Table...
expect_is(tab5, "arrow_dplyr_query")

# ... but $.data is a Table...
expect_is(tab5$.data, "Table")
# ... but $.data is a Table (InMemoryDataset)...
expect_r6_class(tab5$.data, "InMemoryDataset")
# ... and the mutate() was evaluated
expect_true("negint" %in% names(tab5$.data))

Expand Down Expand Up @@ -1561,13 +1561,14 @@ test_that("Dataset writing: dplyr methods", {
dst_dir2 <- tempfile()
ds %>%
group_by(int) %>%
select(chr, dbl) %>%
select(chr, dubs = dbl) %>%
write_dataset(dst_dir2, format = "feather")
new_ds <- open_dataset(dst_dir2, format = "feather")

# Renaming doesn't work, but mutating does??
expect_equivalent(
collect(new_ds) %>% arrange(int),
rbind(df1[c("chr", "dbl", "int")], df2[c("chr", "dbl", "int")])
collect(new_ds) %>% arrange(int) %>% print(),
rbind(df1[c("chr", "dbl", "int")], df2[c("chr", "dbl", "int")]) %>% rename(dubs = dbl) %>% print()
)

# filter to restrict written rows
Expand Down Expand Up @@ -1758,10 +1759,6 @@ test_that("Dataset writing: unsupported features/input validation", {
expect_error(write_dataset(4), 'dataset must be a "Dataset"')

ds <- open_dataset(hive_dir)
expect_error(
select(ds, integer = int) %>% write_dataset(ds),
"Renaming columns when writing a dataset is not yet supported"
)
expect_error(
write_dataset(ds, partitioning = c("int", "NOTACOLUMN"), format = "ipc"),
'Invalid field name: "NOTACOLUMN"'
Expand Down
5 changes: 2 additions & 3 deletions r/tests/testthat/test-dplyr.R
Original file line number Diff line number Diff line change
Expand Up @@ -267,15 +267,14 @@ test_that("head", {
filter(int > 5) %>%
head(2)
expect_r6_class(b3, "Table")
print(as.data.frame(b3))
expect_equal(as.data.frame(b3), set_names(expected, c("int", "strng")))

b4 <- batch %>%
select(int, strng = chr) %>%
filter(int > 5) %>%
group_by(int) %>%
head(2)
expect_s3_class(b4, "Table")
expect_s3_class(b4, "arrow_dplyr_query")
expect_equal(
as.data.frame(b4),
expected %>%
Expand Down Expand Up @@ -308,7 +307,7 @@ test_that("tail", {
filter(int > 5) %>%
group_by(int) %>%
tail(2)
expect_s3_class(b4, "Table")
expect_s3_class(b4, "arrow_dplyr_query")
expect_equal(
as.data.frame(b4),
expected %>%
Expand Down

0 comments on commit 9640cff

Please sign in to comment.