diff --git a/r/R/dplyr-join.R b/r/R/dplyr-join.R index fad44b5ef2751..2ba3c307c110b 100644 --- a/r/R/dplyr-join.R +++ b/r/R/dplyr-join.R @@ -136,6 +136,17 @@ handle_join_by <- function(by, x, y) { if (is.null(by)) { return(set_names(intersect(names(x), names(y)))) } + if (inherits(by, "dplyr_join_by")) { + if (!all(by$condition == "==" & by$filter == "none")) { + abort( + paste0( + "Inequality conditions and helper functions ", + "are not supported in `join_by()` expressions." + ) + ) + } + by <- set_names(by$y, by$x) + } stopifnot(is.character(by)) if (is.null(names(by))) { by <- set_names(by) diff --git a/r/tests/testthat/test-dplyr-join.R b/r/tests/testthat/test-dplyr-join.R index 5c6798aeebc18..3470a886b3834 100644 --- a/r/tests/testthat/test-dplyr-join.R +++ b/r/tests/testthat/test-dplyr-join.R @@ -67,6 +67,39 @@ test_that("left_join `by` args", { ) }) +test_that("left_join with join_by", { + # only run this test in newer versions of dplyr that include `join_by()` + skip_if_not(packageVersion("dplyr") >= "1.0.99.9000") + + compare_dplyr_binding( + .input %>% + left_join(to_join, join_by(some_grouping)) %>% + collect(), + left + ) + compare_dplyr_binding( + .input %>% + left_join( + to_join %>% + rename(the_grouping = some_grouping), + join_by(some_grouping == the_grouping) + ) %>% + collect(), + left + ) + + compare_dplyr_binding( + .input %>% + rename(the_grouping = some_grouping) %>% + left_join( + to_join, + join_by(the_grouping == some_grouping) + ) %>% + collect(), + left + ) +}) + test_that("join two tables", { expect_identical( arrow_table(left) %>% @@ -136,6 +169,23 @@ test_that("Error handling", { ) }) +test_that("Error handling for unsupported expressions in join_by", { + # only run this test in newer versions of dplyr that include `join_by()` + skip_if_not(packageVersion("dplyr") >= "1.0.99.9000") + + expect_error( + arrow_table(left) %>% + left_join(to_join, join_by(some_grouping >= some_grouping)), + "not supported" + ) + + expect_error( + arrow_table(left) %>% + left_join(to_join, join_by(closest(some_grouping >= some_grouping))), + "not supported" + ) +}) + # TODO: test duplicate col names # TODO: casting: int and float columns?