From 13045f4b2971ccb1b77c03b12e245015c0e4fcb8 Mon Sep 17 00:00:00 2001 From: Nic Crane Date: Tue, 1 Mar 2022 18:08:33 +0000 Subject: [PATCH] ARROW-15040: [R] Enable write_csv_arrow to take a Dataset or arrow_dplyr_query as input Closes #11971 from thisisnic/ARROW-15040_rb_reader Lead-authored-by: Nic Crane Co-authored-by: Nicola Crane Signed-off-by: Nic Crane --- r/R/arrowExports.R | 4 +++ r/R/csv.R | 20 ++++++++++++-- r/src/arrowExports.cpp | 19 +++++++++++++ r/src/csv.cpp | 8 ++++++ r/tests/testthat/test-csv.R | 55 +++++++++++++++++++++++++++++++------ 5 files changed, 96 insertions(+), 10 deletions(-) diff --git a/r/R/arrowExports.R b/r/R/arrowExports.R index 4689543b5b89f..c9468f52ae3b3 100644 --- a/r/R/arrowExports.R +++ b/r/R/arrowExports.R @@ -516,6 +516,10 @@ csv___WriteCSV__RecordBatch <- function(record_batch, write_options, stream) { invisible(.Call(`_arrow_csv___WriteCSV__RecordBatch`, record_batch, write_options, stream)) } +csv___WriteCSV__RecordBatchReader <- function(reader, write_options, stream) { + invisible(.Call(`_arrow_csv___WriteCSV__RecordBatchReader`, reader, write_options, stream)) +} + dataset___Dataset__NewScan <- function(ds) { .Call(`_arrow_dataset___Dataset__NewScan`, ds) } diff --git a/r/R/csv.R b/r/R/csv.R index 1842394f6b917..0ffd93a9caafa 100644 --- a/r/R/csv.R +++ b/r/R/csv.R @@ -699,7 +699,8 @@ write_csv_arrow <- function(x, if (is.null(write_options)) { write_options <- readr_to_csv_write_options( include_header = include_header, - batch_size = batch_size) + batch_size = batch_size + ) } x_out <- x @@ -707,7 +708,9 @@ write_csv_arrow <- function(x, x <- Table$create(x) } - assert_that(is_writable_table(x)) + if (inherits(x, c("Dataset", "arrow_dplyr_query"))) { + x <- Scanner$create(x)$ToRecordBatchReader() + } if (!inherits(sink, "OutputStream")) { sink <- make_output_stream(sink) @@ -718,6 +721,19 @@ write_csv_arrow <- function(x, csv___WriteCSV__RecordBatch(x, write_options, sink) } else if (inherits(x, "Table")) { csv___WriteCSV__Table(x, write_options, sink) + } else if (inherits(x, c("RecordBatchReader"))) { + csv___WriteCSV__RecordBatchReader(x, write_options, sink) + } else { + abort( + c( + paste0( + paste( + "x must be an object of class 'data.frame', 'RecordBatch',", + "'Dataset', 'Table', or 'RecordBatchReader' not '" + ), class(x)[[1]], "'." + ) + ) + ) } invisible(x_out) diff --git a/r/src/arrowExports.cpp b/r/src/arrowExports.cpp index eb77fdc40d1dd..59762790fa3cf 100644 --- a/r/src/arrowExports.cpp +++ b/r/src/arrowExports.cpp @@ -2027,6 +2027,24 @@ extern "C" SEXP _arrow_csv___WriteCSV__RecordBatch(SEXP record_batch_sexp, SEXP } #endif +// csv.cpp +#if defined(ARROW_R_WITH_ARROW) +void csv___WriteCSV__RecordBatchReader(const std::shared_ptr& reader, const std::shared_ptr& write_options, const std::shared_ptr& stream); +extern "C" SEXP _arrow_csv___WriteCSV__RecordBatchReader(SEXP reader_sexp, SEXP write_options_sexp, SEXP stream_sexp){ +BEGIN_CPP11 + arrow::r::Input&>::type reader(reader_sexp); + arrow::r::Input&>::type write_options(write_options_sexp); + arrow::r::Input&>::type stream(stream_sexp); + csv___WriteCSV__RecordBatchReader(reader, write_options, stream); + return R_NilValue; +END_CPP11 +} +#else +extern "C" SEXP _arrow_csv___WriteCSV__RecordBatchReader(SEXP reader_sexp, SEXP write_options_sexp, SEXP stream_sexp){ + Rf_error("Cannot call csv___WriteCSV__RecordBatchReader(). See https://arrow.apache.org/docs/r/articles/install.html for help installing Arrow C++ libraries. "); +} +#endif + // dataset.cpp #if defined(ARROW_R_WITH_DATASET) std::shared_ptr dataset___Dataset__NewScan(const std::shared_ptr& ds); @@ -7729,6 +7747,7 @@ static const R_CallMethodDef CallEntries[] = { { "_arrow_TimestampParser__MakeISO8601", (DL_FUNC) &_arrow_TimestampParser__MakeISO8601, 0}, { "_arrow_csv___WriteCSV__Table", (DL_FUNC) &_arrow_csv___WriteCSV__Table, 3}, { "_arrow_csv___WriteCSV__RecordBatch", (DL_FUNC) &_arrow_csv___WriteCSV__RecordBatch, 3}, + { "_arrow_csv___WriteCSV__RecordBatchReader", (DL_FUNC) &_arrow_csv___WriteCSV__RecordBatchReader, 3}, { "_arrow_dataset___Dataset__NewScan", (DL_FUNC) &_arrow_dataset___Dataset__NewScan, 1}, { "_arrow_dataset___Dataset__schema", (DL_FUNC) &_arrow_dataset___Dataset__schema, 1}, { "_arrow_dataset___Dataset__type_name", (DL_FUNC) &_arrow_dataset___Dataset__type_name, 1}, diff --git a/r/src/csv.cpp b/r/src/csv.cpp index 93d07d82ed441..bb901e798d73e 100644 --- a/r/src/csv.cpp +++ b/r/src/csv.cpp @@ -202,4 +202,12 @@ void csv___WriteCSV__RecordBatch( StopIfNotOk(arrow::csv::WriteCSV(*record_batch, *write_options, stream.get())); } +// [[arrow::export]] +void csv___WriteCSV__RecordBatchReader( + const std::shared_ptr& reader, + const std::shared_ptr& write_options, + const std::shared_ptr& stream) { + StopIfNotOk(arrow::csv::WriteCSV(reader, *write_options, stream.get())); +} + #endif diff --git a/r/tests/testthat/test-csv.R b/r/tests/testthat/test-csv.R index bc9354342c3ed..2f3ac7a941e04 100644 --- a/r/tests/testthat/test-csv.R +++ b/r/tests/testthat/test-csv.R @@ -389,7 +389,7 @@ test_that("Write a CSV file with invalid input type", { bad_input <- Array$create(1:5) expect_error( write_csv_arrow(bad_input, csv_file), - regexp = "x must be an object of class 'data.frame', 'RecordBatch', or 'Table', not 'Array'." + regexp = "x must be an object of class .* not 'Array'." ) }) @@ -461,9 +461,12 @@ test_that("Writing a CSV errors when unsupported (yet) readr args are used", { append = FALSE, quote = "all", escape = "double", - eol = "\n", ), - paste("The following arguments are not yet supported in Arrow: \"append\",", - "\"quote\", \"escape\", and \"eol\"") + eol = "\n" + ), + paste( + "The following arguments are not yet supported in Arrow: \"append\",", + "\"quote\", \"escape\", and \"eol\"" + ) ) }) @@ -471,8 +474,10 @@ test_that("write_csv_arrow deals with duplication in sink/file", { # errors when both file and sink are supplied expect_error( write_csv_arrow(tbl, file = csv_file, sink = csv_file), - paste("You have supplied both \"file\" and \"sink\" arguments. Please", - "supply only one of them") + paste( + "You have supplied both \"file\" and \"sink\" arguments. Please", + "supply only one of them" + ) ) }) @@ -484,8 +489,10 @@ test_that("write_csv_arrow deals with duplication in include_headers/col_names", include_header = TRUE, col_names = TRUE ), - paste("You have supplied both \"col_names\" and \"include_header\"", - "arguments. Please supply only one of them") + paste( + "You have supplied both \"col_names\" and \"include_header\"", + "arguments. Please supply only one of them" + ) ) written_tbl <- suppressMessages( @@ -503,3 +510,35 @@ test_that("read_csv_arrow() deals with BOMs (byte-order-marks) correctly", { tibble(a = 1, b = 2) ) }) + +test_that("write_csv_arrow can write from Dataset objects", { + skip_if_not_available("dataset") + data_dir <- make_temp_dir() + write_dataset(tbl_no_dates, data_dir, partitioning = "lgl") + data_in <- open_dataset(data_dir) + + csv_file <- tempfile() + tbl_out <- write_csv_arrow(data_in, csv_file) + expect_true(file.exists(csv_file)) + + tbl_in <- read_csv_arrow(csv_file) + expect_named(tbl_in, c("dbl", "false", "chr", "lgl")) + expect_equal(nrow(tbl_in), 10) +}) + +test_that("write_csv_arrow can write from RecordBatchReader objects", { + skip_if_not_available("dataset") + library(dplyr, warn.conflicts = FALSE) + + query_obj <- arrow_table(tbl_no_dates) %>% + filter(lgl == TRUE) + + csv_file <- tempfile() + on.exit(unlink(csv_file)) + tbl_out <- write_csv_arrow(query_obj, csv_file) + expect_true(file.exists(csv_file)) + + tbl_in <- read_csv_arrow(csv_file) + expect_named(tbl_in, c("dbl", "lgl", "false", "chr")) + expect_equal(nrow(tbl_in), 3) +})