From 2e6374f94cbcc236becc3e41797a26127cf06ab0 Mon Sep 17 00:00:00 2001 From: Neal Richardson Date: Mon, 10 May 2021 08:42:42 -0700 Subject: [PATCH] Add Expression(schema) method and improve adq print method --- r/R/arrowExports.R | 4 ++++ r/R/dplyr.R | 15 +++++++++------ r/R/expression.R | 1 + r/src/arrowExports.cpp | 17 +++++++++++++++++ r/src/expression.cpp | 8 ++++++++ r/tests/testthat/test-dataset.R | 4 ++-- r/tests/testthat/test-dplyr-mutate.R | 5 +++-- r/tests/testthat/test-expression.R | 8 ++++++++ 8 files changed, 52 insertions(+), 10 deletions(-) diff --git a/r/R/arrowExports.R b/r/R/arrowExports.R index 0063836970e2e..c026c72899f6f 100644 --- a/r/R/arrowExports.R +++ b/r/R/arrowExports.R @@ -800,6 +800,10 @@ compute___expr__ToString <- function(x){ .Call(`_arrow_compute___expr__ToString`, x) } +compute___expr__type <- function(x, schema){ + .Call(`_arrow_compute___expr__type`, x, schema) +} + ipc___WriteFeather__Table <- function(stream, table, version, chunk_size, compression, compression_level){ invisible(.Call(`_arrow_ipc___WriteFeather__Table`, stream, table, version, chunk_size, compression, compression_level)) } diff --git a/r/R/dplyr.R b/r/R/dplyr.R index 342fabf3c7b53..ca100cdd5e757 100644 --- a/r/R/dplyr.R +++ b/r/R/dplyr.R @@ -70,17 +70,20 @@ make_field_refs <- function(field_names) { #' @export print.arrow_dplyr_query <- function(x, ...) { schm <- x$.data$schema - fields <- map_chr(x$selected_columns, function(expr) { + types <- map_chr(x$selected_columns, function(expr) { name <- expr$field_name if (nzchar(name)) { - schm$GetFieldByName(name)$ToString() + # Just a field_ref, so look up in the schema + schm$GetFieldByName(name)$type$ToString() } else { - # It's "" because this is not a field_ref, it's a more complex expression - "expr" + # Expression, so get its type and append the expression + paste0( + expr$type(schm)$ToString(), + " (", expr$ToString(), ")" + ) } }) - # Strip off the field names as they are in the dataset and add the renamed ones - fields <- paste(names(fields), sub("^.*?: ", "", fields), sep = ": ", collapse = "\n") + fields <- paste(names(types), types, sep = ": ", collapse = "\n") cat(class(x$.data)[1], " (query)\n", sep = "") cat(fields, "\n", sep = "") cat("\n") diff --git a/r/R/expression.R b/r/R/expression.R index 30eb0906d4370..03b2e0face037 100644 --- a/r/R/expression.R +++ b/r/R/expression.R @@ -250,6 +250,7 @@ print.array_expression <- function(x, ...) { Expression <- R6Class("Expression", inherit = ArrowObject, public = list( ToString = function() compute___expr__ToString(self), + type = function(schema) compute___expr__type(self, schema), cast = function(to_type, safe = TRUE, ...) { opts <- list( to_type = to_type, diff --git a/r/src/arrowExports.cpp b/r/src/arrowExports.cpp index b274ac5f3af27..9d75b2da54e2a 100644 --- a/r/src/arrowExports.cpp +++ b/r/src/arrowExports.cpp @@ -3097,6 +3097,22 @@ extern "C" SEXP _arrow_compute___expr__ToString(SEXP x_sexp){ } #endif +// expression.cpp +#if defined(ARROW_R_WITH_ARROW) +std::shared_ptr compute___expr__type(const std::shared_ptr& x, const std::shared_ptr& schema); +extern "C" SEXP _arrow_compute___expr__type(SEXP x_sexp, SEXP schema_sexp){ +BEGIN_CPP11 + arrow::r::Input&>::type x(x_sexp); + arrow::r::Input&>::type schema(schema_sexp); + return cpp11::as_sexp(compute___expr__type(x, schema)); +END_CPP11 +} +#else +extern "C" SEXP _arrow_compute___expr__type(SEXP x_sexp, SEXP schema_sexp){ + Rf_error("Cannot call compute___expr__type(). See https://arrow.apache.org/docs/r/articles/install.html for help installing Arrow C++ libraries. "); +} +#endif + // feather.cpp #if defined(ARROW_R_WITH_ARROW) void ipc___WriteFeather__Table(const std::shared_ptr& stream, const std::shared_ptr& table, int version, int chunk_size, arrow::Compression::type compression, int compression_level); @@ -6884,6 +6900,7 @@ static const R_CallMethodDef CallEntries[] = { { "_arrow_compute___expr__get_field_ref_name", (DL_FUNC) &_arrow_compute___expr__get_field_ref_name, 1}, { "_arrow_compute___expr__scalar", (DL_FUNC) &_arrow_compute___expr__scalar, 1}, { "_arrow_compute___expr__ToString", (DL_FUNC) &_arrow_compute___expr__ToString, 1}, + { "_arrow_compute___expr__type", (DL_FUNC) &_arrow_compute___expr__type, 2}, { "_arrow_ipc___WriteFeather__Table", (DL_FUNC) &_arrow_ipc___WriteFeather__Table, 6}, { "_arrow_ipc___feather___Reader__version", (DL_FUNC) &_arrow_ipc___feather___Reader__version, 1}, { "_arrow_ipc___feather___Reader__Read", (DL_FUNC) &_arrow_ipc___feather___Reader__Read, 2}, diff --git a/r/src/expression.cpp b/r/src/expression.cpp index 798853edd720d..d8745ade47990 100644 --- a/r/src/expression.cpp +++ b/r/src/expression.cpp @@ -68,4 +68,12 @@ std::string compute___expr__ToString(const std::shared_ptr& return x->ToString(); } +// [[arrow::export]] +std::shared_ptr compute___expr__type( + const std::shared_ptr& x, + const std::shared_ptr& schema) { + auto bound = ValueOrStop(x->Bind(*schema)); + return bound.type(); +} + #endif diff --git a/r/tests/testthat/test-dataset.R b/r/tests/testthat/test-dataset.R index 585c44972da54..56a35f98f2b98 100644 --- a/r/tests/testthat/test-dataset.R +++ b/r/tests/testthat/test-dataset.R @@ -965,7 +965,7 @@ test_that("mutate()", { chr: string dbl: double int: int32 -twice: expr +twice: double (multiply_checked(int, 2)) * Filter: ((multiply_checked(dbl, 2) > 14) and (subtract_checked(dbl, 50) < 3)) See $.data for the source Arrow object", @@ -1120,7 +1120,7 @@ test_that("arrange()", { chr: string dbl: double int: int32 -twice: expr +twice: double (multiply_checked(int, 2)) * Filter: ((multiply_checked(dbl, 2) > 14) and (subtract_checked(dbl, 50) < 3)) * Sorted by chr [asc], multiply_checked(int, 2) [desc], add_checked(dbl, int) [asc] diff --git a/r/tests/testthat/test-dplyr-mutate.R b/r/tests/testthat/test-dplyr-mutate.R index 898f3008f9044..29806b841fcad 100644 --- a/r/tests/testthat/test-dplyr-mutate.R +++ b/r/tests/testthat/test-dplyr-mutate.R @@ -346,10 +346,11 @@ test_that("print a mutated table", { print(), 'InMemoryDataset (query) int: int32 -twice: expr +twice: double (multiply_checked(int, 2)) See $.data for the source Arrow object', - fixed = TRUE) + fixed = TRUE + ) # Handling non-expressions/edge cases skip("InMemoryDataset$Project() doesn't accept array (or could it?)") diff --git a/r/tests/testthat/test-expression.R b/r/tests/testthat/test-expression.R index dd61b5e3ca26f..b175a73cfa758 100644 --- a/r/tests/testthat/test-expression.R +++ b/r/tests/testthat/test-expression.R @@ -76,6 +76,14 @@ test_that("C++ expressions", { 'Expression\n(f > 4)', fixed = TRUE ) + expect_type_equal( + f$type(schema(f = float64())), + float64() + ) + expect_type_equal( + (f > 4)$type(schema(f = float64())), + bool() + ) # Interprets that as a list type expect_r6_class(f == c(1L, 2L), "Expression") })