ARROW-15040: [R] Enable write_csv_arrow to take a Dataset or arrow_dplyr_query as input

Closes apache#11971 from thisisnic/ARROW-15040_rb_reader

Lead-authored-by: Nic Crane <[email protected]>
Co-authored-by: Nicola Crane <[email protected]>
Signed-off-by: Nic Crane <[email protected]>
thisisnic committed Mar 1, 2022
1 parent 676b49f commit 13045f4
Showing 5 changed files with 96 additions and 10 deletions.
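For orientation, a minimal sketch of the usage this change enables (object names and paths are illustrative, not taken from the diff; they mirror the new tests at the end of this commit):

library(arrow)
library(dplyr, warn.conflicts = FALSE)

# A Dataset can now be written to CSV directly, without collecting it first
ds <- open_dataset("path/to/partitioned_dataset")  # illustrative path
write_csv_arrow(ds, "dataset_rows.csv")

# An arrow_dplyr_query (e.g. a filtered Table or Dataset) also works
arrow_table(mtcars) %>%
  filter(cyl > 4) %>%
  write_csv_arrow("filtered.csv")
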
4 changes: 4 additions & 0 deletions r/R/arrowExports.R

Some generated files are not rendered by default.

20 changes: 18 additions & 2 deletions r/R/csv.R
@@ -699,15 +699,18 @@ write_csv_arrow <- function(x,
   if (is.null(write_options)) {
     write_options <- readr_to_csv_write_options(
       include_header = include_header,
-      batch_size = batch_size)
+      batch_size = batch_size
+    )
   }
 
   x_out <- x
   if (is.data.frame(x)) {
     x <- Table$create(x)
   }
 
-  assert_that(is_writable_table(x))
+  if (inherits(x, c("Dataset", "arrow_dplyr_query"))) {
+    x <- Scanner$create(x)$ToRecordBatchReader()
+  }
 
   if (!inherits(sink, "OutputStream")) {
     sink <- make_output_stream(sink)
@@ -718,6 +721,19 @@
     csv___WriteCSV__RecordBatch(x, write_options, sink)
   } else if (inherits(x, "Table")) {
     csv___WriteCSV__Table(x, write_options, sink)
+  } else if (inherits(x, c("RecordBatchReader"))) {
+    csv___WriteCSV__RecordBatchReader(x, write_options, sink)
+  } else {
+    abort(
+      c(
+        paste0(
+          paste(
+            "x must be an object of class 'data.frame', 'RecordBatch',",
+            "'Dataset', 'Table', or 'RecordBatchReader' not '"
+          ), class(x)[[1]], "'."
+        )
+      )
+    )
   }
 
   invisible(x_out)
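
In effect, the new branch streams a Dataset or arrow_dplyr_query through a RecordBatchReader instead of requiring a materialized Table, and a reader built by hand can also be passed straight to write_csv_arrow(). A rough sketch of that equivalence (the dataset path is illustrative):

library(arrow)

ds <- open_dataset("path/to/dataset")  # illustrative path

# Roughly what the new branch does before dispatching to the C++ writer
reader <- Scanner$create(ds)$ToRecordBatchReader()

# Either the reader or ds itself can now be handed to write_csv_arrow()
write_csv_arrow(reader, tempfile(fileext = ".csv"))
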
19 changes: 19 additions & 0 deletions r/src/arrowExports.cpp

Some generated files are not rendered by default.

8 changes: 8 additions & 0 deletions r/src/csv.cpp
@@ -202,4 +202,12 @@ void csv___WriteCSV__RecordBatch(
   StopIfNotOk(arrow::csv::WriteCSV(*record_batch, *write_options, stream.get()));
 }
 
+// [[arrow::export]]
+void csv___WriteCSV__RecordBatchReader(
+    const std::shared_ptr<arrow::RecordBatchReader>& reader,
+    const std::shared_ptr<arrow::csv::WriteOptions>& write_options,
+    const std::shared_ptr<arrow::io::OutputStream>& stream) {
+  StopIfNotOk(arrow::csv::WriteCSV(reader, *write_options, stream.get()));
+}
+
 #endif
55 changes: 47 additions & 8 deletions r/tests/testthat/test-csv.R
@@ -389,7 +389,7 @@ test_that("Write a CSV file with invalid input type", {
   bad_input <- Array$create(1:5)
   expect_error(
     write_csv_arrow(bad_input, csv_file),
-    regexp = "x must be an object of class 'data.frame', 'RecordBatch', or 'Table', not 'Array'."
+    regexp = "x must be an object of class .* not 'Array'."
   )
 })
 
@@ -461,18 +461,23 @@ test_that("Writing a CSV errors when unsupported (yet) readr args are used", {
       append = FALSE,
       quote = "all",
       escape = "double",
-      eol = "\n", ),
-    paste("The following arguments are not yet supported in Arrow: \"append\",",
-      "\"quote\", \"escape\", and \"eol\"")
+      eol = "\n"
+    ),
+    paste(
+      "The following arguments are not yet supported in Arrow: \"append\",",
+      "\"quote\", \"escape\", and \"eol\""
+    )
   )
 })
 
 test_that("write_csv_arrow deals with duplication in sink/file", {
   # errors when both file and sink are supplied
   expect_error(
     write_csv_arrow(tbl, file = csv_file, sink = csv_file),
-    paste("You have supplied both \"file\" and \"sink\" arguments. Please",
-      "supply only one of them")
+    paste(
+      "You have supplied both \"file\" and \"sink\" arguments. Please",
+      "supply only one of them"
+    )
   )
 })
 
@@ -484,8 +489,10 @@ test_that("write_csv_arrow deals with duplication in include_headers/col_names",
       include_header = TRUE,
       col_names = TRUE
     ),
-    paste("You have supplied both \"col_names\" and \"include_header\"",
-      "arguments. Please supply only one of them")
+    paste(
+      "You have supplied both \"col_names\" and \"include_header\"",
+      "arguments. Please supply only one of them"
+    )
   )
 
   written_tbl <- suppressMessages(
@@ -503,3 +510,35 @@ test_that("read_csv_arrow() deals with BOMs (byte-order-marks) correctly", {
     tibble(a = 1, b = 2)
   )
 })
+
+test_that("write_csv_arrow can write from Dataset objects", {
+  skip_if_not_available("dataset")
+  data_dir <- make_temp_dir()
+  write_dataset(tbl_no_dates, data_dir, partitioning = "lgl")
+  data_in <- open_dataset(data_dir)
+
+  csv_file <- tempfile()
+  tbl_out <- write_csv_arrow(data_in, csv_file)
+  expect_true(file.exists(csv_file))
+
+  tbl_in <- read_csv_arrow(csv_file)
+  expect_named(tbl_in, c("dbl", "false", "chr", "lgl"))
+  expect_equal(nrow(tbl_in), 10)
+})
+
+test_that("write_csv_arrow can write from RecordBatchReader objects", {
+  skip_if_not_available("dataset")
+  library(dplyr, warn.conflicts = FALSE)
+
+  query_obj <- arrow_table(tbl_no_dates) %>%
+    filter(lgl == TRUE)
+
+  csv_file <- tempfile()
+  on.exit(unlink(csv_file))
+  tbl_out <- write_csv_arrow(query_obj, csv_file)
+  expect_true(file.exists(csv_file))
+
+  tbl_in <- read_csv_arrow(csv_file)
+  expect_named(tbl_in, c("dbl", "lgl", "false", "chr"))
+  expect_equal(nrow(tbl_in), 3)
+})
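
To exercise just these tests locally, one possible approach (assuming a working development build of the r/ package; paths are relative to the r/ directory of an arrow checkout, and helpers such as tbl_no_dates are expected to be picked up from tests/testthat/ as usual):

# run from the r/ directory with the compiled package loadable
devtools::load_all()
testthat::test_file("tests/testthat/test-csv.R")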
