Skip to content

Commit

Permalink
GH-38602: [R] Add missing prod for summarize (#38601)
Browse files Browse the repository at this point in the history
### Rationale for this change

`prod` is currently missing for use in summarize.

### What changes are included in this PR?

Added `prod` for summarize aggregation.

### Are these changes tested?

Yes, included the same tests used for the other aggregation functions for summarize.

### Are there any user-facing changes?

Yes, added `prod` function.

* Closes: #38602

Authored-by: Maximilian Muecke <[email protected]>
Signed-off-by: Dewey Dunnington <[email protected]>
  • Loading branch information
m-muecke authored Nov 6, 2023
1 parent c73cb13 commit 6dcba93
Show file tree
Hide file tree
Showing 5 changed files with 34 additions and 1 deletion.
1 change: 1 addition & 0 deletions r/R/dplyr-funcs-doc.R
Original file line number Diff line number Diff line change
Expand Up @@ -167,6 +167,7 @@
#' * [`paste0()`][base::paste0()]: the `collapse` argument is not yet supported
#' * [`pmax()`][base::pmax()]
#' * [`pmin()`][base::pmin()]
#' * [`prod()`][base::prod()]
#' * [`round()`][base::round()]
#' * [`sign()`][base::sign()]
#' * [`sin()`][base::sin()]
Expand Down
7 changes: 7 additions & 0 deletions r/R/dplyr-summarize.R
Original file line number Diff line number Diff line change
Expand Up @@ -42,6 +42,13 @@ register_bindings_aggregate <- function() {
options = list(skip_nulls = na.rm, min_count = 0L)
)
})
register_binding_agg("base::prod", function(..., na.rm = FALSE) {
list(
fun = "product",
data = ensure_one_arg(list2(...), "prod"),
options = list(skip_nulls = na.rm, min_count = 0L)
)
})
register_binding_agg("base::any", function(..., na.rm = FALSE) {
list(
fun = "any",
Expand Down
1 change: 1 addition & 0 deletions r/man/acero.Rd

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

3 changes: 2 additions & 1 deletion r/src/compute.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -176,7 +176,8 @@ std::shared_ptr<arrow::compute::FunctionOptions> make_compute_options(
func_name == "hash_approximate_median" || func_name == "mean" ||
func_name == "hash_mean" || func_name == "min_max" || func_name == "hash_min_max" ||
func_name == "min" || func_name == "hash_min" || func_name == "max" ||
func_name == "hash_max" || func_name == "sum" || func_name == "hash_sum") {
func_name == "hash_max" || func_name == "sum" || func_name == "hash_sum" ||
func_name == "product" || func_name == "hash_product") {
using Options = arrow::compute::ScalarAggregateOptions;
auto out = std::make_shared<Options>(Options::Defaults());
if (!Rf_isNull(options["min_count"])) {
Expand Down
23 changes: 23 additions & 0 deletions r/tests/testthat/test-dplyr-summarize.R
Original file line number Diff line number Diff line change
Expand Up @@ -91,6 +91,27 @@ test_that("Group by sum on dataset", {
)
})

test_that("Group by prod on dataset", {
compare_dplyr_binding(
.input %>%
group_by(some_grouping) %>%
summarize(prod = prod(int, na.rm = TRUE)) %>%
collect(),
tbl
)

compare_dplyr_binding(
.input %>%
group_by(some_grouping) %>%
summarize(
prod = prod(int, na.rm = FALSE),
prod2 = base::prod(int, na.rm = TRUE)
) %>%
collect(),
tbl
)
})

test_that("Group by mean on dataset", {
compare_dplyr_binding(
.input %>%
Expand Down Expand Up @@ -319,6 +340,7 @@ test_that("Functions that take ... but we only accept a single arg", {
# the agg_funcs directly
expect_error(call_binding_agg("n_distinct"), "n_distinct() with 0 arguments", fixed = TRUE)
expect_error(call_binding_agg("sum"), "sum() with 0 arguments", fixed = TRUE)
expect_error(call_binding_agg("prod"), "prod() with 0 arguments", fixed = TRUE)
expect_error(call_binding_agg("any"), "any() with 0 arguments", fixed = TRUE)
expect_error(call_binding_agg("all"), "all() with 0 arguments", fixed = TRUE)
expect_error(call_binding_agg("min"), "min() with 0 arguments", fixed = TRUE)
Expand Down Expand Up @@ -642,6 +664,7 @@ test_that("summarise() with !!sym()", {
group_by(false) %>%
summarise(
sum = sum(!!sym(test_dbl_col)),
prod = prod(!!sym(test_dbl_col)),
any = any(!!sym(test_lgl_col)),
all = all(!!sym(test_lgl_col)),
mean = mean(!!sym(test_dbl_col)),
Expand Down

0 comments on commit 6dcba93

Please sign in to comment.