From 24663c47924059bd3167b1745bf7b512a33a56e3 Mon Sep 17 00:00:00 2001 From: Vinicius Fraga Date: Tue, 29 Mar 2022 12:55:39 -0300 Subject: [PATCH] Fix rebase automerge --- .github/workflows/matlab.yml | 27 +- c_glib/arrow-flight-glib/client.cpp | 20 + c_glib/arrow-flight-glib/client.h | 5 + c_glib/test/flight/test-client.rb | 7 + ci/appveyor-cpp-build.bat | 2 + ci/conda_env_cpp.txt | 9 +- ci/docker/conda-cpp.dockerfile | 9 +- ci/scripts/c_glib_test.sh | 3 + ci/scripts/cpp_build.sh | 1 + ci/scripts/cpp_test.sh | 3 + ci/scripts/java_jni_macos_build.sh | 1 + ci/scripts/python_test.sh | 3 + ci/scripts/r_docker_configure.sh | 37 +- ci/scripts/r_test.sh | 10 + ci/scripts/ruby_test.sh | 3 + cpp/cmake_modules/ThirdpartyToolchain.cmake | 4 + cpp/examples/arrow/CMakeLists.txt | 23 +- cpp/gdb_arrow.py | 632 +++++++++++++++++- cpp/src/arrow/CMakeLists.txt | 1 + cpp/src/arrow/compute/api_scalar.cc | 3 + cpp/src/arrow/compute/api_scalar.h | 37 + cpp/src/arrow/compute/exec/hash_join.cc | 3 +- cpp/src/arrow/compute/exec/hash_join.h | 4 +- cpp/src/arrow/compute/exec/hash_join_dict.cc | 2 +- cpp/src/arrow/compute/exec/hash_join_node.cc | 82 ++- .../arrow/compute/exec/hash_join_node_test.cc | 67 ++ cpp/src/arrow/compute/exec/options.h | 40 +- cpp/src/arrow/compute/exec/util_test.cc | 44 +- .../arrow/compute/kernels/codegen_internal.cc | 9 +- .../compute/kernels/codegen_internal_test.cc | 36 +- .../compute/kernels/copy_data_internal.h | 112 ++++ .../compute/kernels/scalar_arithmetic.cc | 194 ++++++ .../compute/kernels/scalar_cast_nested.cc | 79 +++ .../arrow/compute/kernels/scalar_cast_test.cc | 111 +++ .../arrow/compute/kernels/scalar_if_else.cc | 117 +--- .../compute/kernels/scalar_temporal_test.cc | 388 ++++++++++- .../compute/kernels/scalar_temporal_unary.cc | 124 +++- .../arrow/compute/kernels/vector_replace.cc | 125 +--- cpp/src/arrow/dataset/scanner_test.cc | 4 +- cpp/src/arrow/filesystem/gcsfs.cc | 2 +- cpp/src/arrow/flight/CMakeLists.txt | 5 + cpp/src/arrow/flight/client.cc | 210 +++--- cpp/src/arrow/flight/client.h | 16 +- cpp/src/arrow/flight/flight_test.cc | 370 ++-------- cpp/src/arrow/flight/server.cc | 11 +- cpp/src/arrow/flight/server.h | 8 +- cpp/src/arrow/flight/sql/client.cc | 2 + cpp/src/arrow/flight/sql/client.h | 3 + cpp/src/arrow/flight/sql/server_test.cc | 1 + cpp/src/arrow/flight/test_util.cc | 6 +- .../flight/try_compile/check_tls_opts_143.cc | 37 + cpp/src/arrow/flight/types.cc | 122 +++- cpp/src/arrow/flight/types.h | 27 + cpp/src/arrow/ipc/stream_to_file.cc | 6 +- cpp/src/arrow/memory_pool.cc | 266 +++++++- cpp/src/arrow/memory_pool_test.h | 9 +- cpp/src/arrow/python/flight.cc | 11 - cpp/src/arrow/python/flight.h | 7 - cpp/src/arrow/python/gdb.cc | 123 +++- cpp/src/arrow/stl_test.cc | 8 +- cpp/src/arrow/type.cc | 2 + cpp/src/arrow/type.h | 2 + cpp/src/arrow/util/cpu_info.cc | 2 +- cpp/src/arrow/util/debug.cc | 31 + cpp/src/arrow/util/debug.h | 29 + cpp/src/arrow/util/mutex.cc | 31 + cpp/src/arrow/util/mutex.h | 21 + cpp/src/arrow/util/thread_pool.cc | 41 +- cpp/src/arrow/util/thread_pool.h | 6 +- cpp/src/arrow/util/thread_pool_test.cc | 122 +++- .../vendored/portable-snippets/debug-trap.h | 83 +++ cpp/src/gandiva/annotator.cc | 6 + cpp/src/gandiva/annotator.h | 12 + cpp/src/gandiva/cache.cc | 4 + cpp/src/gandiva/compiled_expr.h | 4 +- cpp/src/gandiva/dex.h | 32 +- cpp/src/gandiva/expr_decomposer.cc | 65 +- cpp/src/gandiva/expr_decomposer.h | 3 + cpp/src/gandiva/like_holder.cc | 3 +- cpp/src/gandiva/like_holder_test.cc | 6 +- cpp/src/gandiva/llvm_generator.cc | 57 +- 
cpp/src/gandiva/llvm_generator.h | 8 +- cpp/src/gandiva/llvm_generator_test.cc | 6 +- cpp/src/gandiva/tests/filter_test.cc | 39 ++ cpp/src/parquet/encoding.cc | 26 +- cpp/src/parquet/file_reader.cc | 5 + cpp/src/parquet/level_conversion_inc.h | 2 +- cpp/thirdparty/versions.txt | 4 +- dev/archery/README.md | 6 +- dev/release/post-05-ruby.sh | 3 + dev/release/post-13-msys2.sh | 15 + dev/release/rat_exclude_files.txt | 5 +- ...rsion7numpy1.18python3.7.____cpython.yaml} | 26 +- ...rsion7numpy1.18python3.8.____cpython.yaml} | 26 +- ...rsion7numpy1.19python3.9.____cpython.yaml} | 24 +- ...rsion9numpy1.18python3.7.____cpython.yaml} | 22 +- ...rsion9numpy1.18python3.8.____cpython.yaml} | 22 +- ...rsion9numpy1.19python3.9.____cpython.yaml} | 20 +- ...sion9numpy1.21python3.10.____cpython.yaml} | 24 +- ...ion10.2numpy1.17python3.6.____cpython.yaml | 70 -- ...rch64_numpy1.18python3.7.____cpython.yaml} | 21 +- ...rch64_numpy1.18python3.8.____cpython.yaml} | 21 +- ...arch64_numpy1.19python3.9.____cpython.yaml | 19 +- ...ch64_numpy1.21python3.10.____cpython.yaml} | 23 +- ...sx_64_numpy1.18python3.7.____cpython.yaml} | 18 +- ...sx_64_numpy1.18python3.8.____cpython.yaml} | 18 +- ...osx_64_numpy1.19python3.9.____cpython.yaml | 16 +- ...x_64_numpy1.21python3.10.____cpython.yaml} | 20 +- ...arm64_numpy1.19python3.8.____cpython.yaml} | 16 +- ...arm64_numpy1.19python3.9.____cpython.yaml} | 16 +- .../.ci_support/r/linux_64_r_base4.0.yaml | 6 +- .../.ci_support/r/linux_64_r_base4.1.yaml | 6 +- .../.ci_support/r/osx_64_r_base4.0.yaml | 2 +- .../.ci_support/r/osx_64_r_base4.1.yaml | 2 +- .../.ci_support/r/win_64_r_base4.0.yaml | 2 +- .../.ci_support/r/win_64_r_base4.1.yaml | 2 +- ...onNonenumpy1.18python3.7.____cpython.yaml} | 20 +- ...onNonenumpy1.18python3.8.____cpython.yaml} | 20 +- ...ionNonenumpy1.19python3.9.____cpython.yaml | 18 +- ...nNonenumpy1.21python3.10.____cpython.yaml} | 22 +- .../conda-recipes/arrow-cpp/build-arrow.sh | 7 + .../conda-recipes/arrow-cpp/build-pyarrow.sh | 5 +- dev/tasks/conda-recipes/arrow-cpp/meta.yaml | 11 +- dev/tasks/conda-recipes/azure.osx.yml | 2 +- dev/tasks/conda-recipes/azure.win.yml | 14 +- dev/tasks/conda-recipes/build_steps.sh | 7 +- dev/tasks/conda-recipes/r-arrow/meta.yaml | 14 +- dev/tasks/homebrew-formulae/apache-arrow.rb | 4 + dev/tasks/homebrew-formulae/github.macos.yml | 42 +- .../apache-arrow/yum/arrow.spec.in | 159 +++-- dev/tasks/macros.jinja | 29 + dev/tasks/r/github.linux.cran.yml | 3 +- dev/tasks/r/github.linux.offline.build.yml | 7 +- dev/tasks/r/github.macos.autobrew.yml | 2 +- dev/tasks/r/github.macos.brew.yml | 64 ++ dev/tasks/tasks.yml | 117 +++- docker-compose.yml | 2 + docs/source/cpp/compute.rst | 76 ++- docs/source/cpp/gdb.rst | 2 +- .../developers/guide/step_by_step/set_up.rst | 2 + .../guide/tutorials/python_tutorial.rst | 8 +- .../developers/guide/tutorials/r_tutorial.rst | 437 +++++++++++- docs/source/format/CDataInterface.rst | 3 +- docs/source/python/api/compute.rst | 2 + format/FlightSql.proto | 54 +- go/arrow/flight/basic_auth_flight_test.go | 18 +- go/arrow/flight/client.go | 27 +- go/arrow/flight/example_flight_server_test.go | 6 +- go/arrow/flight/flight_middleware_test.go | 24 +- go/arrow/flight/flight_test.go | 40 +- go/arrow/flight/gen.go | 2 +- .../flight/{ => internal/flight}/Flight.pb.go | 27 +- .../{ => internal/flight}/Flight_grpc.pb.go | 555 ++++++--------- go/arrow/flight/server.go | 109 ++- go/arrow/flight/server_auth.go | 89 ++- go/arrow/internal/flatbuf/BodyCompression.go | 6 +- .../internal/flight_integration/scenario.go | 45 
+- go/arrow/tools.go | 1 + go/go.mod | 3 +- go/go.sum | 4 +- java/adapter/orc/pom.xml | 2 +- java/arrow-flight/pom.xml | 34 - java/flight/flight-core/pom.xml | 6 +- .../integration/AuthBasicProtoScenario.java | 97 --- .../integration/IntegrationAssertions.java | 83 --- .../integration/IntegrationTestClient.java | 197 ------ .../integration/IntegrationTestServer.java | 97 --- .../integration/MiddlewareScenario.java | 168 ----- .../flight/example/integration/Scenario.java | 48 -- .../flight/example/integration/Scenarios.java | 91 --- java/flight/flight-grpc/pom.xml | 2 +- java/flight/flight-jdbc-driver/pom.xml | 2 +- ...owFlightJdbcVectorSchemaRootResultSet.java | 70 +- .../resources/properties/flight.properties | 45 -- ...lightJdbcDenseUnionVectorAccessorTest.java | 1 - .../ArrowFlightJdbcBitVectorAccessorTest.java | 2 +- .../driver/jdbc/utils/AccessorTestUtils.java | 21 +- .../flight/sql/FlightSqlColumnMetadata.java | 27 +- .../arrow/flight/sql/FlightSqlUtils.java | 1 - .../apache/arrow/flight/TestFlightSql.java | 8 +- java/flight/pom.xml | 5 +- java/pom.xml | 20 - matlab/README.md | 101 +-- python/pyarrow/_flight.pyx | 56 +- python/pyarrow/array.pxi | 8 +- python/pyarrow/includes/common.pxd | 1 + python/pyarrow/includes/libarrow.pxd | 1 + python/pyarrow/includes/libarrow_flight.pxd | 30 +- python/pyarrow/table.pxi | 4 +- python/pyarrow/tests/test_compute.py | 5 + python/pyarrow/tests/test_flight.py | 289 ++++---- python/pyarrow/tests/test_gdb.py | 278 +++++++- python/pyarrow/tests/test_memory.py | 74 ++ python/pyarrow/tests/test_pandas.py | 48 ++ python/pyproject.toml | 2 +- python/setup.py | 4 +- r/Makefile | 2 +- r/NAMESPACE | 4 + r/R/arrowExports.R | 4 +- r/R/dataset-write.R | 37 +- r/R/filesystem.R | 2 +- r/R/flight.R | 8 + r/R/type.R | 5 +- r/_pkgdown.yml | 1 + r/man/FileFormat.Rd | 2 +- r/man/array.Rd | 1 + r/man/data-type.Rd | 9 + r/man/flight_disconnect.Rd | 14 + r/man/open_dataset.Rd | 2 +- r/man/s3_bucket.Rd | 2 +- r/man/write_dataset.Rd | 24 + r/src/arrowExports.cpp | 24 +- r/src/dataset.cpp | 8 +- r/tests/testthat/test-dataset-write.R | 155 +++++ r/tests/testthat/test-dplyr-join.R | 88 +++ r/tests/testthat/test-python-flight.R | 6 + r/vignettes/arrow.Rmd | 2 +- 217 files changed, 5838 insertions(+), 3353 deletions(-) create mode 100644 cpp/src/arrow/compute/kernels/copy_data_internal.h create mode 100644 cpp/src/arrow/flight/try_compile/check_tls_opts_143.cc create mode 100644 cpp/src/arrow/util/debug.cc create mode 100644 cpp/src/arrow/util/debug.h create mode 100644 cpp/src/arrow/vendored/portable-snippets/debug-trap.h rename dev/tasks/conda-recipes/.ci_support/{linux_64_cuda_compiler_version10.2numpy1.17python3.7.____cpython.yaml => linux_64_c_compiler_version7cuda_compiler_version10.2cxx_compiler_version7numpy1.18python3.7.____cpython.yaml} (79%) rename dev/tasks/conda-recipes/.ci_support/{linux_64_cuda_compiler_version10.2numpy1.17python3.8.____cpython.yaml => linux_64_c_compiler_version7cuda_compiler_version10.2cxx_compiler_version7numpy1.18python3.8.____cpython.yaml} (79%) rename dev/tasks/conda-recipes/.ci_support/{linux_64_cuda_compiler_version10.2numpy1.19python3.9.____cpython.yaml => linux_64_c_compiler_version7cuda_compiler_version10.2cxx_compiler_version7numpy1.19python3.9.____cpython.yaml} (80%) rename dev/tasks/conda-recipes/.ci_support/{linux_64_cuda_compiler_versionNonenumpy1.17python3.7.____cpython.yaml => linux_64_c_compiler_version9cuda_compiler_versionNonecxx_compiler_version9numpy1.18python3.7.____cpython.yaml} (80%) rename 
dev/tasks/conda-recipes/.ci_support/{linux_64_cuda_compiler_versionNonenumpy1.17python3.8.____cpython.yaml => linux_64_c_compiler_version9cuda_compiler_versionNonecxx_compiler_version9numpy1.18python3.8.____cpython.yaml} (80%) rename dev/tasks/conda-recipes/.ci_support/{linux_64_cuda_compiler_versionNonenumpy1.19python3.9.____cpython.yaml => linux_64_c_compiler_version9cuda_compiler_versionNonecxx_compiler_version9numpy1.19python3.9.____cpython.yaml} (81%) rename dev/tasks/conda-recipes/.ci_support/{linux_64_cuda_compiler_versionNonenumpy1.17python3.6.____cpython.yaml => linux_64_c_compiler_version9cuda_compiler_versionNonecxx_compiler_version9numpy1.21python3.10.____cpython.yaml} (78%) delete mode 100644 dev/tasks/conda-recipes/.ci_support/linux_64_cuda_compiler_version10.2numpy1.17python3.6.____cpython.yaml rename dev/tasks/conda-recipes/.ci_support/{linux_aarch64_numpy1.17python3.7.____cpython.yaml => linux_aarch64_numpy1.18python3.7.____cpython.yaml} (80%) rename dev/tasks/conda-recipes/.ci_support/{linux_aarch64_numpy1.17python3.8.____cpython.yaml => linux_aarch64_numpy1.18python3.8.____cpython.yaml} (80%) rename dev/tasks/conda-recipes/.ci_support/{linux_aarch64_numpy1.17python3.6.____cpython.yaml => linux_aarch64_numpy1.21python3.10.____cpython.yaml} (78%) rename dev/tasks/conda-recipes/.ci_support/{osx_64_numpy1.17python3.7.____cpython.yaml => osx_64_numpy1.18python3.7.____cpython.yaml} (87%) rename dev/tasks/conda-recipes/.ci_support/{osx_64_numpy1.17python3.8.____cpython.yaml => osx_64_numpy1.18python3.8.____cpython.yaml} (87%) rename dev/tasks/conda-recipes/.ci_support/{osx_64_numpy1.17python3.6.____cpython.yaml => osx_64_numpy1.21python3.10.____cpython.yaml} (85%) rename dev/tasks/conda-recipes/.ci_support/{osx_arm64_python3.8.____cpython.yaml => osx_arm64_numpy1.19python3.8.____cpython.yaml} (86%) rename dev/tasks/conda-recipes/.ci_support/{osx_arm64_python3.9.____cpython.yaml => osx_arm64_numpy1.19python3.9.____cpython.yaml} (86%) rename dev/tasks/conda-recipes/.ci_support/{win_64_cuda_compiler_versionNonenumpy1.17python3.7.____cpython.yaml => win_64_cuda_compiler_versionNonenumpy1.18python3.7.____cpython.yaml} (81%) rename dev/tasks/conda-recipes/.ci_support/{win_64_cuda_compiler_versionNonenumpy1.17python3.8.____cpython.yaml => win_64_cuda_compiler_versionNonenumpy1.18python3.8.____cpython.yaml} (81%) rename dev/tasks/conda-recipes/.ci_support/{win_64_cuda_compiler_versionNonenumpy1.17python3.6.____cpython.yaml => win_64_cuda_compiler_versionNonenumpy1.21python3.10.____cpython.yaml} (79%) create mode 100644 dev/tasks/r/github.macos.brew.yml rename go/arrow/flight/{ => internal/flight}/Flight.pb.go (98%) rename go/arrow/flight/{ => internal/flight}/Flight_grpc.pb.go (58%) delete mode 100644 java/arrow-flight/pom.xml delete mode 100644 java/flight/flight-core/src/main/java/org/apache/arrow/flight/example/integration/AuthBasicProtoScenario.java delete mode 100644 java/flight/flight-core/src/main/java/org/apache/arrow/flight/example/integration/IntegrationAssertions.java delete mode 100644 java/flight/flight-core/src/main/java/org/apache/arrow/flight/example/integration/IntegrationTestClient.java delete mode 100644 java/flight/flight-core/src/main/java/org/apache/arrow/flight/example/integration/IntegrationTestServer.java delete mode 100644 java/flight/flight-core/src/main/java/org/apache/arrow/flight/example/integration/MiddlewareScenario.java delete mode 100644 java/flight/flight-core/src/main/java/org/apache/arrow/flight/example/integration/Scenario.java delete mode 100644 
java/flight/flight-core/src/main/java/org/apache/arrow/flight/example/integration/Scenarios.java delete mode 100644 java/flight/flight-jdbc-driver/src/main/resources/properties/flight.properties create mode 100644 r/man/flight_disconnect.Rd diff --git a/.github/workflows/matlab.yml b/.github/workflows/matlab.yml index 00d953e428e2b..ad5c602c6dbe9 100644 --- a/.github/workflows/matlab.yml +++ b/.github/workflows/matlab.yml @@ -37,7 +37,7 @@ concurrency: jobs: - matlab: + ubuntu: name: AMD64 Ubuntu 20.04 MATLAB runs-on: ubuntu-latest if: ${{ !contains(github.event.pull_request.title, 'WIP') }} @@ -49,7 +49,7 @@ jobs: - name: Install ninja-build run: sudo apt-get install ninja-build - name: Install MATLAB - uses: matlab-actions/setup-matlab@v0 + uses: matlab-actions/setup-matlab@v1 - name: Build MATLAB Interface run: ci/scripts/matlab_build.sh $(pwd) - name: Run MATLAB Tests @@ -68,3 +68,26 @@ jobs: uses: matlab-actions/run-tests@v1 with: select-by-folder: matlab/test + macos: + name: AMD64 MacOS 10.15 MATLAB + runs-on: macos-latest + if: ${{ !contains(github.event.pull_request.title, 'WIP') }} + steps: + - name: Check out repository + uses: actions/checkout@v2 + with: + fetch-depth: 0 + - name: Install ninja-build + run: brew install ninja + - name: Install MATLAB + uses: matlab-actions/setup-matlab@v1 + - name: Build MATLAB Interface + run: ci/scripts/matlab_build.sh $(pwd) + - name: Run MATLAB Tests + env: + # Add the installation directory to the MATLAB Search Path by + # setting the MATLABPATH environment variable. + MATLABPATH: matlab/install/arrow_matlab + uses: matlab-actions/run-tests@v1 + with: + select-by-folder: matlab/test diff --git a/c_glib/arrow-flight-glib/client.cpp b/c_glib/arrow-flight-glib/client.cpp index 9f9e71e6f0434..0d1961e6c6222 100644 --- a/c_glib/arrow-flight-glib/client.cpp +++ b/c_glib/arrow-flight-glib/client.cpp @@ -264,6 +264,26 @@ gaflight_client_new(GAFlightLocation *location, } } +/** + * gaflight_client_close: + * @client: A #GAFlightClient. + * @error: (nullable): Return location for a #GError or %NULL. + * + * Returns: %TRUE on success, %FALSE if there was an error. + * + * Since: 8.0.0 + */ +gboolean +gaflight_client_close(GAFlightClient *client, + GError **error) +{ + auto flight_client = gaflight_client_get_raw(client); + auto status = flight_client->Close(); + return garrow::check(error, + status, + "[flight-client][close]"); +} + /** * gaflight_client_list_flights: * @client: A #GAFlightClient. 
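The Close() added above is also exposed through the Python bindings touched later in this patch (python/pyarrow/_flight.pyx). As a rough sketch of the same idempotent shutdown pattern from pyarrow (the server address here is hypothetical, and a Flight server is assumed to be listening):

    import pyarrow.flight as flight

    # Connect to a hypothetical local Flight server.
    client = flight.FlightClient("grpc://localhost:8815")
    try:
        for info in client.list_flights():
            print(info.descriptor)
    finally:
        # Close() is idempotent: a second call is a no-op rather than
        # an error, matching the GLib test_close case below.
        client.close()
        client.close()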
diff --git a/c_glib/arrow-flight-glib/client.h b/c_glib/arrow-flight-glib/client.h index bc297116135d1..f601e66e79003 100644 --- a/c_glib/arrow-flight-glib/client.h +++ b/c_glib/arrow-flight-glib/client.h @@ -86,6 +86,11 @@ gaflight_client_new(GAFlightLocation *location, GAFlightClientOptions *options, GError **error); +GARROW_AVAILABLE_IN_8_0 +gboolean +gaflight_client_close(GAFlightClient *client, + GError **error); + GARROW_AVAILABLE_IN_5_0 GList * gaflight_client_list_flights(GAFlightClient *client, diff --git a/c_glib/test/flight/test-client.rb b/c_glib/test/flight/test-client.rb index f6660a4ca49db..48f03223f4568 100644 --- a/c_glib/test/flight/test-client.rb +++ b/c_glib/test/flight/test-client.rb @@ -36,6 +36,13 @@ def teardown @server.shutdown end + def test_close + client = ArrowFlight::Client.new(@location) + client.close + # Idempotent + client.close + end + def test_list_flights client = ArrowFlight::Client.new(@location) generator = Helper::FlightInfoGenerator.new diff --git a/ci/appveyor-cpp-build.bat b/ci/appveyor-cpp-build.bat index 37509b02b3b9b..a69e7a665bd5f 100644 --- a/ci/appveyor-cpp-build.bat +++ b/ci/appveyor-cpp-build.bat @@ -26,6 +26,8 @@ git submodule update --init || exit /B set ARROW_TEST_DATA=%CD%\testing\data set PARQUET_TEST_DATA=%CD%\cpp\submodules\parquet-testing\data +set ARROW_DEBUG_MEMORY_POOL=trap + @rem @rem In the configurations below we disable building the Arrow static library @rem to save some time. Unfortunately this will still build the Parquet static diff --git a/ci/conda_env_cpp.txt b/ci/conda_env_cpp.txt index cd7136cebbdb3..7120c56ebb1fd 100644 --- a/ci/conda_env_cpp.txt +++ b/ci/conda_env_cpp.txt @@ -15,7 +15,7 @@ # specific language governing permissions and limitations # under the License. -aws-sdk-cpp +aws-sdk-cpp=1.8.186 benchmark>=1.6.0 boost-cpp>=1.68.0 brotli @@ -25,13 +25,18 @@ cmake gflags glog gmock>=1.10.0 -grpc-cpp>=1.27.3 +google-cloud-cpp>=1.34.0 +# 1.45.0 appears to segfault on Windows/AppVeyor +grpc-cpp>=1.27.3,<1.45.0 gtest>=1.10.0 libprotobuf libutf8proc lz4-c make ninja +# Required by google-cloud-cpp, the Conda package is missing the dependency: +# https://github.com/conda-forge/google-cloud-cpp-feedstock/issues/28 +nlohmann_json pkg-config python rapidjson diff --git a/ci/docker/conda-cpp.dockerfile b/ci/docker/conda-cpp.dockerfile index c9417da446efb..7fe457ac5e76b 100644 --- a/ci/docker/conda-cpp.dockerfile +++ b/ci/docker/conda-cpp.dockerfile @@ -22,9 +22,6 @@ FROM ${repo}:${arch}-conda COPY ci/scripts/install_minio.sh /arrow/ci/scripts RUN /arrow/ci/scripts/install_minio.sh latest /opt/conda -COPY ci/scripts/install_gcs_testbench.sh /arrow/ci/scripts -RUN /arrow/ci/scripts/install_gcs_testbench.sh default - # install the required conda packages into the test environment COPY ci/conda_env_cpp.txt \ ci/conda_env_gandiva.txt \ @@ -37,12 +34,17 @@ RUN mamba install \ valgrind && \ mamba clean --all +# We want to install the GCS testbench using the same Python binary that the Conda code will use. 
+COPY ci/scripts/install_gcs_testbench.sh /arrow/ci/scripts +RUN /arrow/ci/scripts/install_gcs_testbench.sh default + ENV ARROW_BUILD_TESTS=ON \ ARROW_DATASET=ON \ ARROW_DEPENDENCY_SOURCE=CONDA \ ARROW_FLIGHT=ON \ ARROW_FLIGHT_SQL=ON \ ARROW_GANDIVA=ON \ + ARROW_GCS=ON \ ARROW_HOME=$CONDA_PREFIX \ ARROW_ORC=ON \ ARROW_PARQUET=ON \ @@ -57,6 +59,7 @@ ENV ARROW_BUILD_TESTS=ON \ ARROW_WITH_SNAPPY=ON \ ARROW_WITH_ZLIB=ON \ ARROW_WITH_ZSTD=ON \ + CMAKE_CXX_STANDARD=17 \ GTest_SOURCE=BUNDLED \ PARQUET_BUILD_EXAMPLES=ON \ PARQUET_BUILD_EXECUTABLES=ON \ diff --git a/ci/scripts/c_glib_test.sh b/ci/scripts/c_glib_test.sh index 25c54138ed659..cb576136d4343 100755 --- a/ci/scripts/c_glib_test.sh +++ b/ci/scripts/c_glib_test.sh @@ -26,6 +26,9 @@ export LD_LIBRARY_PATH=${ARROW_HOME}/lib:${LD_LIBRARY_PATH} export PKG_CONFIG_PATH=${ARROW_HOME}/lib/pkgconfig export GI_TYPELIB_PATH=${ARROW_HOME}/lib/girepository-1.0 +# Enable memory debug checks. +export ARROW_DEBUG_MEMORY_POOL=trap + pushd ${source_dir} ruby test/run-test.rb diff --git a/ci/scripts/cpp_build.sh b/ci/scripts/cpp_build.sh index 4bbfcb7f90b49..2e6f35936ab89 100755 --- a/ci/scripts/cpp_build.sh +++ b/ci/scripts/cpp_build.sh @@ -137,6 +137,7 @@ cmake \ -DCMAKE_VERBOSE_MAKEFILE=${CMAKE_VERBOSE_MAKEFILE:-OFF} \ -DCMAKE_C_FLAGS="${CFLAGS:-}" \ -DCMAKE_CXX_FLAGS="${CXXFLAGS:-}" \ + -DCMAKE_CXX_STANDARD="${CMAKE_CXX_STANDARD:-11}" \ -DCMAKE_INSTALL_LIBDIR=${CMAKE_INSTALL_LIBDIR:-lib} \ -DCMAKE_INSTALL_PREFIX=${CMAKE_INSTALL_PREFIX:-${ARROW_HOME}} \ -DCMAKE_UNITY_BUILD=${CMAKE_UNITY_BUILD:-OFF} \ diff --git a/ci/scripts/cpp_test.sh b/ci/scripts/cpp_test.sh index 822557f252c9f..2be1227c49efa 100755 --- a/ci/scripts/cpp_test.sh +++ b/ci/scripts/cpp_test.sh @@ -37,6 +37,9 @@ export LD_LIBRARY_PATH=${ARROW_HOME}/${CMAKE_INSTALL_LIBDIR:-lib}:${LD_LIBRARY_P # to retrieve metadata. Disable this so that S3FileSystem tests run faster. export AWS_EC2_METADATA_DISABLED=TRUE +# Enable memory debug checks. +export ARROW_DEBUG_MEMORY_POOL=trap + ctest_options=() case "$(uname)" in Linux) diff --git a/ci/scripts/java_jni_macos_build.sh b/ci/scripts/java_jni_macos_build.sh index 218d2d3960a56..e9572c334b6ac 100755 --- a/ci/scripts/java_jni_macos_build.sh +++ b/ci/scripts/java_jni_macos_build.sh @@ -50,6 +50,7 @@ mkdir -p "${build_dir}" pushd "${build_dir}" cmake \ + -GNinja \ -DARROW_BOOST_USE_SHARED=OFF \ -DARROW_BROTLI_USE_SHARED=OFF \ -DARROW_BUILD_TESTS=${ARROW_BUILD_TESTS} \ diff --git a/ci/scripts/python_test.sh b/ci/scripts/python_test.sh index 2eb1373593001..e1d06c1872735 100755 --- a/ci/scripts/python_test.sh +++ b/ci/scripts/python_test.sh @@ -30,6 +30,9 @@ export ARROW_GDB_SCRIPT=${arrow_dir}/cpp/gdb_arrow.py # Enable some checks inside Python itself export PYTHONDEVMODE=1 +# Enable memory debug checks. 
+export ARROW_DEBUG_MEMORY_POOL=trap + # By default, force-test all optional components : ${PYARROW_TEST_CUDA:=${ARROW_CUDA:-ON}} : ${PYARROW_TEST_DATASET:=${ARROW_DATASET:-ON}} diff --git a/ci/scripts/r_docker_configure.sh b/ci/scripts/r_docker_configure.sh index d2dc6405a612f..518df1040d1b1 100755 --- a/ci/scripts/r_docker_configure.sh +++ b/ci/scripts/r_docker_configure.sh @@ -30,6 +30,18 @@ fi # Ensure parallel compilation of C/C++ code echo "MAKEFLAGS=-j$(${R_BIN} -s -e 'cat(parallel::detectCores())')" >> $(R RHOME)/etc/Renviron.site +# Figure out what package manager we have +if [ "`which dnf`" ]; then + PACKAGE_MANAGER=dnf +elif [ "`which yum`" ]; then + PACKAGE_MANAGER=yum +elif [ "`which zypper`" ]; then + PACKAGE_MANAGER=zypper +else + PACKAGE_MANAGER=apt-get + apt-get update +fi + # Special hacking to try to reproduce quirks on fedora-clang-devel on CRAN # which uses a bespoke clang compiled to use libc++ # https://www.stats.ox.ac.uk/pub/bdr/Rconfig/r-devel-linux-x86_64-fedora-clang @@ -48,26 +60,16 @@ fi # Special hacking to try to reproduce quirks on centos using non-default build # tooling. if [[ "$DEVTOOLSET_VERSION" -gt 0 ]]; then - if [ "`which dnf`" ]; then - dnf install -y centos-release-scl - dnf install -y "devtoolset-$DEVTOOLSET_VERSION" - else - yum install -y centos-release-scl - yum install -y "devtoolset-$DEVTOOLSET_VERSION" - fi + $PACKAGE_MANAGER install -y centos-release-scl + $PACKAGE_MANAGER install -y "devtoolset-$DEVTOOLSET_VERSION" fi -# Install openssl for S3 support if [ "$ARROW_S3" == "ON" ] || [ "$ARROW_R_DEV" == "TRUE" ]; then - if [ "`which dnf`" ]; then - dnf install -y libcurl-devel openssl-devel - elif [ "`which yum`" ]; then - yum install -y libcurl-devel openssl-devel - elif [ "`which zypper`" ]; then - zypper install -y libcurl-devel libopenssl-devel - else - apt-get update + # Install curl and openssl for S3 support + if [ "$PACKAGE_MANAGER" = "apt-get" ]; then apt-get install -y libcurl4-openssl-dev libssl-dev + else + $PACKAGE_MANAGER install -y libcurl-devel openssl-devel fi # The Dockerfile should have put this file here @@ -80,5 +82,8 @@ if [ "$ARROW_S3" == "ON" ] || [ "$ARROW_R_DEV" == "TRUE" ]; then fi fi +# Install rsync for bundling cpp source +$PACKAGE_MANAGER install -y rsync + # Workaround for html help install failure; see https://github.com/r-lib/devtools/issues/2084#issuecomment-530912786 Rscript -e 'x <- file.path(R.home("doc"), "html"); if (!file.exists(x)) {dir.create(x, recursive=TRUE); file.copy(system.file("html/R.css", package="stats"), x)}' diff --git a/ci/scripts/r_test.sh b/ci/scripts/r_test.sh index 62e423cf5d90d..b9d6d0d684e09 100755 --- a/ci/scripts/r_test.sh +++ b/ci/scripts/r_test.sh @@ -26,6 +26,13 @@ pushd ${source_dir} printenv +# Before release, we always copy the relevant parts of the cpp source into the +# package. In some CI checks, we will use this version of the source: +# this is done by setting ARROW_SOURCE_HOME to something other than "/arrow" +# (which is where the arrow git checkout is found in docker and other CI jobs). +# In the other CI checks, the files are synced but ignored. +make sync-cpp + if [ "$ARROW_USE_PKG_CONFIG" != "false" ]; then export LD_LIBRARY_PATH=${ARROW_HOME}/lib:${LD_LIBRARY_PATH} export R_LD_LIBRARY_PATH=${LD_LIBRARY_PATH} @@ -56,6 +63,9 @@ export _R_CHECK_TESTS_NLINES_=0 # to retrieve metadata. Disable this so that S3FileSystem tests run faster. export AWS_EC2_METADATA_DISABLED=TRUE +# Enable memory debug checks.
+export ARROW_DEBUG_MEMORY_POOL=trap + # Hack so that texlive2020 doesn't pollute the home dir export TEXMFCONFIG=/tmp/texmf-config export TEXMFVAR=/tmp/texmf-var diff --git a/ci/scripts/ruby_test.sh b/ci/scripts/ruby_test.sh index 03d20e19831f3..4fd6a85fe3966 100755 --- a/ci/scripts/ruby_test.sh +++ b/ci/scripts/ruby_test.sh @@ -26,4 +26,7 @@ export LD_LIBRARY_PATH=${ARROW_HOME}/lib:${LD_LIBRARY_PATH} export PKG_CONFIG_PATH=${ARROW_HOME}/lib/pkgconfig export GI_TYPELIB_PATH=${ARROW_HOME}/lib/girepository-1.0 +# Enable memory debug checks. +export ARROW_DEBUG_MEMORY_POOL=trap + rake -f ${source_dir}/Rakefile BUILD_DIR=${build_dir} USE_BUNDLER=yes diff --git a/cpp/cmake_modules/ThirdpartyToolchain.cmake b/cpp/cmake_modules/ThirdpartyToolchain.cmake index 5ad25cd54bb7e..3a0353bc7dbd9 100644 --- a/cpp/cmake_modules/ThirdpartyToolchain.cmake +++ b/cpp/cmake_modules/ThirdpartyToolchain.cmake @@ -1751,6 +1751,10 @@ if(ARROW_JEMALLOC) # See https://github.com/jemalloc/jemalloc/issues/1237 "--disable-initial-exec-tls" ${EP_LOG_OPTIONS}) + if(${UPPERCASE_BUILD_TYPE} STREQUAL "DEBUG") + # Enable jemalloc debug checks when Arrow itself has debugging enabled + list(APPEND JEMALLOC_CONFIGURE_COMMAND "--enable-debug") + endif() set(JEMALLOC_BUILD_COMMAND ${MAKE} ${MAKE_BUILD_ARGS}) if(CMAKE_OSX_SYSROOT) list(APPEND JEMALLOC_BUILD_COMMAND "SDKROOT=${CMAKE_OSX_SYSROOT}") diff --git a/cpp/examples/arrow/CMakeLists.txt b/cpp/examples/arrow/CMakeLists.txt index a593911b1e35a..a38bd88333f83 100644 --- a/cpp/examples/arrow/CMakeLists.txt +++ b/cpp/examples/arrow/CMakeLists.txt @@ -92,6 +92,24 @@ if(ARROW_FLIGHT) EXTRA_SOURCES "${CMAKE_CURRENT_BINARY_DIR}/helloworld.pb.cc" "${CMAKE_CURRENT_BINARY_DIR}/helloworld.grpc.pb.cc") + + if(ARROW_FLIGHT_SQL) + if(ARROW_GRPC_USE_SHARED) + set(FLIGHT_SQL_EXAMPLES_LINK_LIBS arrow_flight_sql_shared) + else() + set(FLIGHT_SQL_EXAMPLES_LINK_LIBS arrow_flight_sql_static) + endif() + + add_arrow_example(flight_sql_example + DEPENDENCIES + flight-sql-test-server + EXTRA_LINK_LIBS + ${FLIGHT_EXAMPLES_LINK_LIBS} + ${FLIGHT_SQL_EXAMPLES_LINK_LIBS} + gRPC::grpc++ + ${ARROW_PROTOBUF_LIBPROTOBUF} + ${GFLAGS_LIBRARIES}) + endif() endif() if(ARROW_PARQUET AND ARROW_DATASET) @@ -111,5 +129,8 @@ if(ARROW_PARQUET AND ARROW_DATASET) add_arrow_example(execution_plan_documentation_examples EXTRA_LINK_LIBS ${DATASET_EXAMPLES_LINK_LIBS}) - add_dependencies(dataset_documentation_example parquet) + add_dependencies(execution-plan-documentation-examples parquet) + + add_arrow_example(join_example EXTRA_LINK_LIBS ${DATASET_EXAMPLES_LINK_LIBS}) + add_dependencies(join-example parquet) endif() diff --git a/cpp/gdb_arrow.py b/cpp/gdb_arrow.py index 0bc89f46bf8b6..8a077a22e5055 100644 --- a/cpp/gdb_arrow.py +++ b/cpp/gdb_arrow.py @@ -17,9 +17,13 @@ from collections import namedtuple from collections.abc import Sequence +import datetime import decimal import enum from functools import lru_cache, partial +import itertools +import math +import operator import struct import sys import warnings @@ -27,10 +31,11 @@ import gdb from gdb.types import get_basic_type -# gdb API docs at https://sourceware.org/gdb/onlinedocs/gdb/Python-API.html#Python-API -# TODO check guidelines here: https://sourceware.org/gdb/onlinedocs/gdb/Writing-a-Pretty_002dPrinter.html -# TODO investigate auto-loading: https://sourceware.org/gdb/onlinedocs/gdb/Auto_002dloading-extensions.html#Auto_002dloading-extensions +assert sys.version_info[0] >= 3, "Arrow GDB extension needs Python 3+" + + +# gdb API docs at 
https://sourceware.org/gdb/onlinedocs/gdb/Python-API.html#Python-API _type_ids = [ @@ -45,6 +50,51 @@ # Mirror the C++ Type::type enum Type = enum.IntEnum('Type', _type_ids, start=0) +# Mirror the C++ TimeUnit::type enum +TimeUnit = enum.IntEnum('TimeUnit', ['SECOND', 'MILLI', 'MICRO', 'NANO'], + start=0) + +type_id_to_struct_code = { + Type.INT8: 'b', + Type.INT16: 'h', + Type.INT32: 'i', + Type.INT64: 'q', + Type.UINT8: 'B', + Type.UINT16: 'H', + Type.UINT32: 'I', + Type.UINT64: 'Q', + Type.HALF_FLOAT: 'e', + Type.FLOAT: 'f', + Type.DOUBLE: 'd', + Type.DATE32: 'i', + Type.DATE64: 'q', + Type.TIME32: 'i', + Type.TIME64: 'q', + Type.INTERVAL_DAY_TIME: 'ii', + Type.INTERVAL_MONTHS: 'i', + Type.INTERVAL_MONTH_DAY_NANO: 'iiq', + Type.DURATION: 'q', + Type.TIMESTAMP: 'q', +} + +TimeUnitTraits = namedtuple('TimeUnitTraits', ('multiplier', + 'fractional_digits')) + +time_unit_traits = { + TimeUnit.SECOND: TimeUnitTraits(1, 0), + TimeUnit.MILLI: TimeUnitTraits(1_000, 3), + TimeUnit.MICRO: TimeUnitTraits(1_000_000, 6), + TimeUnit.NANO: TimeUnitTraits(1_000_000_000, 9), +} + + +def identity(v): + return v + + +def has_null_bitmap(type_id): + return type_id not in (Type.NA, Type.SPARSE_UNION, Type.DENSE_UNION) + @lru_cache() def byte_order(): @@ -203,6 +253,66 @@ def format_month_interval(val): return f"{int(val)}M" +def format_days_milliseconds(days, milliseconds): + return f"{days}d{milliseconds}ms" + + +def format_months_days_nanos(months, days, nanos): + return f"{months}M{days}d{nanos}ns" + + +_date_base = datetime.date(1970, 1, 1).toordinal() + + +def format_date32(val): + """ + Format a date32 value. + """ + val = int(val) + try: + decoded = datetime.date.fromordinal(val + _date_base) + except ValueError: # "ordinal must be >= 1" + return f"{val}d [year <= 0]" + else: + return f"{val}d [{decoded}]" + + +def format_date64(val): + """ + Format a date64 value. + """ + val = int(val) + days, remainder = divmod(val, 86400 * 1000) + if remainder: + return f"{val}ms [non-multiple of 86400000]" + try: + decoded = datetime.date.fromordinal(days + _date_base) + except ValueError: # "ordinal must be >= 1" + return f"{val}ms [year <= 0]" + else: + return f"{val}ms [{decoded}]" + + +def format_timestamp(val, unit): + """ + Format a timestamp value. + """ + val = int(val) + unit = int(unit) + short_unit = short_time_unit(unit) + traits = time_unit_traits[unit] + seconds, subseconds = divmod(val, traits.multiplier) + try: + dt = datetime.datetime.utcfromtimestamp(seconds) + except (ValueError, OSError): # value out of range for datetime.datetime + pretty = "too large to represent" + else: + pretty = dt.isoformat().replace('T', ' ') + if traits.fractional_digits > 0: + pretty += f".{subseconds:0{traits.fractional_digits}d}" + return f"{val}{short_unit} [{pretty}]" + + def cast_to_concrete(val, ty): return (val.reference_value().reinterpret_cast(ty.reference()) .referenced_value()) @@ -416,7 +526,7 @@ def eval_at(self, index, eval_format): Run `eval_format` with the value at `index`. For example, if `eval_format` is "{}.get()", this will evaluate - "{self[0]}.get()". + "{self[index]}.get()". """ self._check_index(index) return gdb.parse_and_eval( @@ -504,6 +614,23 @@ def bytes_literal(self): else: return '""' + def bytes_view(self, offset=0, length=None): + """ + Return a view over the bytes of this buffer. 
+ """ + if self.size > 0: + if length is None: + length = self.size + mem = gdb.selected_inferior().read_memory( + self.val['data_'] + offset, self.size) + else: + mem = memoryview(b"") + # Read individual bytes as unsigned integers rather than + # Python bytes objects + return mem.cast('B') + + view = bytes_view + class BufferPtr: """ @@ -533,6 +660,157 @@ def bytes_literal(self): return self.buf.bytes_literal() +class TypedBuffer(Buffer): + """ + A buffer containing values of a given a struct format code. + """ + _boolean_format = object() + + def __init__(self, val, mem_format): + super().__init__(val) + self.mem_format = mem_format + if not self.is_boolean: + self.byte_width = struct.calcsize('=' + self.mem_format) + + @classmethod + def from_type_id(cls, val, type_id): + assert isinstance(type_id, int) + if type_id == Type.BOOL: + mem_format = cls._boolean_format + else: + mem_format = type_id_to_struct_code[type_id] + return cls(val, mem_format) + + def view(self, offset=0, length=None): + """ + Return a view over the primitive values in this buffer. + + The optional `offset` and `length` are expressed in primitive values, + not bytes. + """ + if self.is_boolean: + return Bitmap.from_buffer(self, offset, length) + + byte_offset = offset * self.byte_width + if length is not None: + mem = self.bytes_view(byte_offset, length * self.byte_width) + else: + mem = self.bytes_view(byte_offset) + return TypedView(mem, self.mem_format) + + @property + def is_boolean(self): + return self.mem_format is self._boolean_format + + +class TypedView(Sequence): + """ + View a bytes-compatible object as a sequence of objects described + by a struct format code. + """ + + def __init__(self, mem, mem_format): + assert isinstance(mem, memoryview) + self.mem = mem + self.mem_format = mem_format + self.byte_width = struct.calcsize('=' + mem_format) + self.length = mem.nbytes // self.byte_width + + def _check_index(self, index): + if not 0 <= index < self.length: + raise IndexError("Wrong index for bitmap") + + def __len__(self): + return self.length + + def __getitem__(self, index): + self._check_index(index) + w = self.byte_width + # Cannot use memoryview.cast() because the 'e' format for half-floats + # is poorly supported. + mem = self.mem[index * w:(index + 1) * w] + return struct.unpack('=' + self.mem_format, mem) + + +class Bitmap(Sequence): + """ + View a bytes-compatible object as a sequence of bools. 
+ """ + _masks = [0x1, 0x2, 0x4, 0x8, 0x10, 0x20, 0x40, 0x80] + + def __init__(self, view, offset, length): + self.view = view + self.offset = offset + self.length = length + + def _check_index(self, index): + if not 0 <= index < self.length: + raise IndexError("Wrong index for bitmap") + + def __len__(self): + return self.length + + def __getitem__(self, index): + self._check_index(index) + index += self.offset + byte_index, bit_index = divmod(index, 8) + byte = self.view[byte_index] + return byte & self._masks[bit_index] != 0 + + @classmethod + def from_buffer(cls, buf, offset, length): + assert isinstance(buf, Buffer) + byte_offset, bit_offset = divmod(offset, 8) + byte_length = math.ceil(length + offset / 8) - byte_offset + return cls(buf.bytes_view(byte_offset, byte_length), + bit_offset, length) + + +class MappedView(Sequence): + + def __init__(self, func, view): + self.view = view + self.func = func + + def __len__(self): + return len(self.view) + + def __getitem__(self, index): + return self.func(self.view[index]) + + +class StarMappedView(Sequence): + + def __init__(self, func, view): + self.view = view + self.func = func + + def __len__(self): + return len(self.view) + + def __getitem__(self, index): + return self.func(*self.view[index]) + + +class NullBitmap(Bitmap): + + def __getitem__(self, index): + self._check_index(index) + if self.view is None: + return True + return super().__getitem__(index) + + @classmethod + def from_buffer(cls, buf, offset, length): + """ + Create a null bitmap from a Buffer (or None if missing, + in which case all values are True). + """ + if buf is None: + return cls(buf, offset, length) + return super().from_buffer(buf, offset, length) + + KeyValue = namedtuple('KeyValue', ('key', 'value')) @@ -572,33 +850,44 @@ def __getitem__(self, i): return self.md[i] -DecimalTraits = namedtuple('DecimalTraits', ('nbits', 'struct_format_le')) +DecimalTraits = namedtuple('DecimalTraits', ('bit_width', 'struct_format_le')) decimal_traits = { 128: DecimalTraits(128, 'Qq'), 256: DecimalTraits(256, 'QQQq'), } -class Decimal: +class BaseDecimal: """ - A arrow::BasicDecimal{128,256...} value. + Base class for arrow::BasicDecimal{128,256...} values. """ - def __init__(self, traits, val): - self.val = val - self.traits = traits + def __init__(self, address): + self.address = address + + @classmethod + def from_value(cls, val): + """ + Create a decimal from a gdb.Value representing the corresponding + arrow::BasicDecimal{128,256...}. + """ + return cls(val['array_'].address) @classmethod - def from_bits(cls, nbits, *args, **kwargs): - return cls(decimal_traits[nbits], *args, **kwargs) + def from_address(cls, address): + """ + Create a decimal from a gdb.Value representing the address of the + raw decimal storage. + """ + return cls(address) @property def words(self): """ The decimal words, from least to most significant. 
""" - mem = gdb.selected_inferior().read_memory( - self.val['array_'].address, self.traits.nbits // 8) + mem = gdb.selected_inferior().read_memory(self.address, + self.traits.bit_width // 8) fmt = self.traits.struct_format_le if byte_order() == 'big': fmt = fmt[::-1] @@ -613,7 +902,7 @@ def __int__(self): """ v = 0 words = self.words - bits_per_word = self.traits.nbits // len(words) + bits_per_word = self.traits.bit_width // len(words) for w in reversed(words): v = (v << bits_per_word) + w return v @@ -629,12 +918,22 @@ def format(self, precision, scale): return str(decimal.Decimal(v).scaleb(-scale)) -Decimal128 = partial(Decimal.from_bits, 128) -Decimal256 = partial(Decimal.from_bits, 256) +class Decimal128(BaseDecimal): + traits = decimal_traits[128] + + +class Decimal256(BaseDecimal): + traits = decimal_traits[256] + + +decimal_bits_to_class = { + 128: Decimal128, + 256: Decimal256, +} decimal_type_to_class = { - 'Decimal128Type': Decimal128, - 'Decimal256Type': Decimal256, + f"Decimal{bits}Type": cls + for (bits, cls) in decimal_bits_to_class.items() } @@ -1015,7 +1314,7 @@ def to_string(self): if not self.is_valid: return self._format_null() value = self.val['value'] - return f"{self._format_type()} of value {value}d" + return f"{self._format_type()} of value {format_date32(value)}" class Date64ScalarPrinter(TimeScalarPrinter): @@ -1027,7 +1326,7 @@ def to_string(self): if not self.is_valid: return self._format_null() value = self.val['value'] - return f"{self._format_type()} of value {value}ms" + return f"{self._format_type()} of value {format_date64(value)}" class TimestampScalarPrinter(ScalarPrinter): @@ -1073,7 +1372,8 @@ def to_string(self): suffix = f"[precision={precision}, scale={scale}]" if not self.is_valid: return f"{self._format_type()} of null value {suffix}" - value = self.decimal_class(self.val['value']).format(precision, scale) + value = self.decimal_class.from_value(self.val['value'] + ).format(precision, scale) return f"{self._format_type()} of value {value} {suffix}" @@ -1223,10 +1523,12 @@ def __new__(cls, name, val): assert issubclass(cls, ArrayDataPrinter) self = object.__new__(cls) self.name = name + self.val = val self.type_class = type_class self.type_name = type_class.name self.type_id = type_id - self.val = val + self.offset = int(self.val['offset']) + self.length = int(self.val['length']) return self @property @@ -1238,15 +1540,254 @@ def type(self): return cast_to_concrete(deref(self.val['type']), concrete_type) def _format_contents(self): - return (f"length {self.val['length']}, " + return (f"length {self.length}, " + f"offset {self.offset}, " f"{format_null_count(self.val['null_count'])}") + def _buffer(self, index, type_id=None): + buffers = StdVector(self.val['buffers']) + bufptr = SharedPtr(buffers[index]).get() + if int(bufptr) == 0: + return None + if type_id is not None: + return TypedBuffer.from_type_id(bufptr.dereference(), type_id) + else: + return Buffer(bufptr.dereference()) + + def _buffer_values(self, index, type_id, length=None): + """ + Return a typed view of values in the buffer with the given index. + + Values are returned as tuples since some types may decode to + multiple values (for example day_time_interval). + """ + buf = self._buffer(index, type_id) + if buf is None: + return None + if length is None: + length = self.length + return buf.view(self.offset, length) + + def _unpacked_buffer_values(self, index, type_id, length=None): + """ + Like _buffer_values(), but assumes values are 1-tuples + and returns them unpacked. 
+ """ + return StarMappedView(identity, + self._buffer_values(index, type_id, length)) + + def _null_bitmap(self): + buf = self._buffer(0) if has_null_bitmap(self.type_id) else None + return NullBitmap.from_buffer(buf, self.offset, self.length) + + def _null_child(self, i): + return str(i), "null" + + def _valid_child(self, i, value): + return str(i), value + + def display_hint(self): + return None + + def children(self): + return () + def to_string(self): ty = self.type return (f"{self.name} of type {ty}, " f"{self._format_contents()}") +class NumericArrayDataPrinter(ArrayDataPrinter): + """ + ArrayDataPrinter specialization for numeric data types. + """ + _format_value = staticmethod(identity) + + def _values_view(self): + return StarMappedView(self._format_value, + self._buffer_values(1, self.type_id)) + + def display_hint(self): + return "array" + + def children(self): + if self.length == 0: + return + values = self._values_view() + null_bits = self._null_bitmap() + for i, (valid, value) in enumerate(zip(null_bits, values)): + if valid: + yield self._valid_child(i, str(value)) + else: + yield self._null_child(i) + + +class BooleanArrayDataPrinter(NumericArrayDataPrinter): + """ + ArrayDataPrinter specialization for boolean. + """ + + def _format_value(self, v): + return str(v).lower() + + def _values_view(self): + return MappedView(self._format_value, + self._buffer_values(1, self.type_id)) + + +class Date32ArrayDataPrinter(NumericArrayDataPrinter): + """ + ArrayDataPrinter specialization for date32. + """ + _format_value = staticmethod(format_date32) + + +class Date64ArrayDataPrinter(NumericArrayDataPrinter): + """ + ArrayDataPrinter specialization for date64. + """ + _format_value = staticmethod(format_date64) + + +class TimeArrayDataPrinter(NumericArrayDataPrinter): + """ + ArrayDataPrinter specialization for time32 and time64. + """ + + def __init__(self, name, val): + self.unit = self.type['unit_'] + self.unit_string = short_time_unit(self.unit) + + def _format_value(self, val): + return f"{val}{self.unit_string}" + + +class TimestampArrayDataPrinter(NumericArrayDataPrinter): + """ + ArrayDataPrinter specialization for timestamp. + """ + + def __init__(self, name, val): + self.unit = self.type['unit_'] + + def _format_value(self, val): + return format_timestamp(val, self.unit) + + +class MonthIntervalArrayDataPrinter(NumericArrayDataPrinter): + """ + ArrayDataPrinter specialization for month_interval. + """ + _format_value = staticmethod(format_month_interval) + + +class DayTimeIntervalArrayDataPrinter(NumericArrayDataPrinter): + """ + ArrayDataPrinter specialization for day_time_interval. + """ + _format_value = staticmethod(format_days_milliseconds) + + +class MonthDayNanoIntervalArrayDataPrinter(NumericArrayDataPrinter): + """ + ArrayDataPrinter specialization for day_time_interval. + """ + _format_value = staticmethod(format_months_days_nanos) + + +class DecimalArrayDataPrinter(ArrayDataPrinter): + """ + ArrayDataPrinter specialization for decimals. 
+ """ + + def __init__(self, name, val): + ty = self.type + self.precision = int(ty['precision_']) + self.scale = int(ty['scale_']) + self.decimal_class = decimal_type_to_class[self.type_name] + self.byte_width = self.decimal_class.traits.bit_width // 8 + + def display_hint(self): + return "array" + + def children(self): + if self.length == 0: + return + null_bits = self._null_bitmap() + address = self._buffer(1).data + self.offset * self.byte_width + for i, valid in enumerate(null_bits): + if valid: + dec = self.decimal_class.from_address(address) + yield self._valid_child( + i, dec.format(self.precision, self.scale)) + else: + yield self._null_child(i) + address += self.byte_width + + +class FixedSizeBinaryArrayDataPrinter(ArrayDataPrinter): + """ + ArrayDataPrinter specialization for fixed_size_binary. + """ + + def __init__(self, name, val): + self.byte_width = self.type['byte_width_'] + + def display_hint(self): + return "array" + + def children(self): + if self.length == 0: + return + null_bits = self._null_bitmap() + address = self._buffer(1).data + self.offset * self.byte_width + for i, valid in enumerate(null_bits): + if valid: + if self.byte_width: + yield self._valid_child( + i, bytes_literal(address, self.byte_width)) + else: + yield self._valid_child(i, '""') + else: + yield self._null_child(i) + address += self.byte_width + + +class BinaryArrayDataPrinter(ArrayDataPrinter): + """ + ArrayDataPrinter specialization for variable-sized binary. + """ + + def __init__(self, name, val): + self.is_large = self.type_id in (Type.LARGE_BINARY, Type.LARGE_STRING) + self.is_utf8 = self.type_id in (Type.STRING, Type.LARGE_STRING) + self.format_string = utf8_literal if self.is_utf8 else bytes_literal + + def display_hint(self): + return "array" + + def children(self): + if self.length == 0: + return + null_bits = self._null_bitmap() + offsets = self._unpacked_buffer_values( + 1, Type.INT64 if self.is_large else Type.INT32, + length=self.length + 1) + values = self._buffer(2).data + for i, valid in enumerate(null_bits): + if valid: + start = offsets[i] + size = offsets[i + 1] - start + if size: + yield self._valid_child( + i, self.format_string(values + start, size)) + else: + yield self._valid_child(i, '""') + else: + yield self._null_child(i) + + class ArrayPrinter: """ Pretty-printer for arrow::Array and subclasses. 
@@ -1267,6 +1808,12 @@ def to_string(self): else: return f"arrow::{self.name} of {self._format_contents()}" + def display_hint(self): + return self.data_printer.display_hint() + + def children(self): + return self.data_printer.children() + class ChunkedArrayPrinter: """ @@ -1294,7 +1841,6 @@ def to_string(self): class DataTypeClass: - array_data_printer = ArrayDataPrinter def __init__(self, name): @@ -1311,72 +1857,91 @@ class NumericTypeClass(DataTypeClass): is_parametric = False type_printer = PrimitiveTypePrinter scalar_printer = NumericScalarPrinter + array_data_printer = NumericArrayDataPrinter + + +class BooleanTypeClass(DataTypeClass): + is_parametric = False + type_printer = PrimitiveTypePrinter + scalar_printer = NumericScalarPrinter + array_data_printer = BooleanArrayDataPrinter class Date32TypeClass(DataTypeClass): is_parametric = False type_printer = PrimitiveTypePrinter scalar_printer = Date32ScalarPrinter + array_data_printer = Date32ArrayDataPrinter class Date64TypeClass(DataTypeClass): is_parametric = False type_printer = PrimitiveTypePrinter scalar_printer = Date64ScalarPrinter + array_data_printer = Date64ArrayDataPrinter class TimeTypeClass(DataTypeClass): is_parametric = True type_printer = TimeTypePrinter scalar_printer = TimeScalarPrinter + array_data_printer = TimeArrayDataPrinter class TimestampTypeClass(DataTypeClass): is_parametric = True type_printer = TimestampTypePrinter scalar_printer = TimestampScalarPrinter + array_data_printer = TimestampArrayDataPrinter class DurationTypeClass(DataTypeClass): is_parametric = True type_printer = TimeTypePrinter scalar_printer = TimeScalarPrinter + array_data_printer = TimeArrayDataPrinter class MonthIntervalTypeClass(DataTypeClass): is_parametric = False type_printer = PrimitiveTypePrinter scalar_printer = MonthIntervalScalarPrinter + array_data_printer = MonthIntervalArrayDataPrinter class DayTimeIntervalTypeClass(DataTypeClass): is_parametric = False type_printer = PrimitiveTypePrinter scalar_printer = NumericScalarPrinter + array_data_printer = DayTimeIntervalArrayDataPrinter class MonthDayNanoIntervalTypeClass(DataTypeClass): is_parametric = False type_printer = PrimitiveTypePrinter scalar_printer = NumericScalarPrinter + array_data_printer = MonthDayNanoIntervalArrayDataPrinter class DecimalTypeClass(DataTypeClass): is_parametric = True type_printer = DecimalTypePrinter scalar_printer = DecimalScalarPrinter + array_data_printer = DecimalArrayDataPrinter class BaseBinaryTypeClass(DataTypeClass): is_parametric = False type_printer = PrimitiveTypePrinter scalar_printer = BaseBinaryScalarPrinter + array_data_printer = BinaryArrayDataPrinter class FixedSizeBinaryTypeClass(DataTypeClass): is_parametric = True type_printer = FixedSizeBinaryTypePrinter scalar_printer = FixedSizeBinaryScalarPrinter + array_data_printer = FixedSizeBinaryArrayDataPrinter class BaseListTypeClass(DataTypeClass): @@ -1427,7 +1992,8 @@ class ExtensionTypeClass(DataTypeClass): type_traits_by_id = { Type.NA: DataTypeTraits(NullTypeClass, 'NullType'), - Type.BOOL: DataTypeTraits(NumericTypeClass, 'BooleanType'), + Type.BOOL: DataTypeTraits(BooleanTypeClass, 'BooleanType'), + Type.UINT8: DataTypeTraits(NumericTypeClass, 'UInt8Type'), Type.INT8: DataTypeTraits(NumericTypeClass, 'Int8Type'), Type.UINT16: DataTypeTraits(NumericTypeClass, 'UInt16Type'), @@ -1794,7 +2360,8 @@ def __init__(self, name, val): self.val = val def to_string(self): - return f"{self.val['days']}d{self.val['milliseconds']}ms" + return format_days_milliseconds(self.val['days'], + 
self.val['milliseconds']) class MonthDayNanosPrinter: @@ -1806,8 +2373,9 @@ def __init__(self, name, val): self.val = val def to_string(self): - return (f"{self.val['months']}M{self.val['days']}d" - f"{self.val['nanoseconds']}ns") + return format_months_days_nanos(self.val['months'], + self.val['days'], + self.val['nanoseconds']) class DecimalPrinter: @@ -1815,13 +2383,13 @@ class DecimalPrinter: Pretty-printer for Arrow decimal values. """ - def __init__(self, nbits, name, val): + def __init__(self, bit_width, name, val): self.name = name self.val = val - self.nbits = nbits + self.bit_width = bit_width def to_string(self): - dec = Decimal.from_bits(self.nbits, self.val) + dec = decimal_bits_to_class[self.bit_width].from_value(self.val) return f"{self.name}({int(dec)})" diff --git a/cpp/src/arrow/CMakeLists.txt b/cpp/src/arrow/CMakeLists.txt index 44f6f2f04eaeb..b6f1e2481faa0 100644 --- a/cpp/src/arrow/CMakeLists.txt +++ b/cpp/src/arrow/CMakeLists.txt @@ -204,6 +204,7 @@ set(ARROW_SRCS util/compression.cc util/counting_semaphore.cc util/cpu_info.cc + util/debug.cc util/decimal.cc util/delimiting.cc util/formatting.cc diff --git a/cpp/src/arrow/compute/api_scalar.cc b/cpp/src/arrow/compute/api_scalar.cc index b2cea4c89879f..a9e2565a3ea2d 100644 --- a/cpp/src/arrow/compute/api_scalar.cc +++ b/cpp/src/arrow/compute/api_scalar.cc @@ -780,6 +780,8 @@ SCALAR_EAGER_UNARY(Day, "day") SCALAR_EAGER_UNARY(DayOfYear, "day_of_year") SCALAR_EAGER_UNARY(Hour, "hour") SCALAR_EAGER_UNARY(YearMonthDay, "year_month_day") +SCALAR_EAGER_UNARY(IsDaylightSavings, "is_dst") +SCALAR_EAGER_UNARY(IsLeapYear, "is_leap_year") SCALAR_EAGER_UNARY(ISOCalendar, "iso_calendar") SCALAR_EAGER_UNARY(ISOWeek, "iso_week") SCALAR_EAGER_UNARY(ISOYear, "iso_year") @@ -792,6 +794,7 @@ SCALAR_EAGER_UNARY(Quarter, "quarter") SCALAR_EAGER_UNARY(Second, "second") SCALAR_EAGER_UNARY(Subsecond, "subsecond") SCALAR_EAGER_UNARY(USWeek, "us_week") +SCALAR_EAGER_UNARY(USYear, "us_year") SCALAR_EAGER_UNARY(Year, "year") Result<Datum> AssumeTimezone(const Datum& arg, AssumeTimezoneOptions options, diff --git a/cpp/src/arrow/compute/api_scalar.h b/cpp/src/arrow/compute/api_scalar.h index 850c65051023e..1ba03fd7a64e5 100644 --- a/cpp/src/arrow/compute/api_scalar.h +++ b/cpp/src/arrow/compute/api_scalar.h @@ -1143,6 +1143,17 @@ Result<Datum> CaseWhen(const Datum& cond, const std::vector<Datum>& cases, ARROW_EXPORT Result<Datum> Year(const Datum& values, ExecContext* ctx = NULLPTR); +/// \brief IsLeapYear returns whether a year is a leap year for each element of `values` +/// +/// \param[in] values input to extract leap year indicator from +/// \param[in] ctx the function execution context, optional +/// \return the resulting datum +/// +/// \since 8.0.0 +/// \note API not yet finalized +ARROW_EXPORT +Result<Datum> IsLeapYear(const Datum& values, ExecContext* ctx = NULLPTR); + /// \brief Month returns month for each element of `values`. /// Month is encoded as January=1, December=12 /// @@ -1219,6 +1230,20 @@ ARROW_EXPORT Result<Datum> DayOfYear(const Datum& values, ExecContext* ctx = NUL ARROW_EXPORT Result<Datum> ISOYear(const Datum& values, ExecContext* ctx = NULLPTR); +/// \brief USYear returns US epidemiological year number for each element of `values`. +/// First week of US epidemiological year has the majority (4 or more) of its +/// days in January. Last week of US epidemiological year has the year's last +/// Wednesday in it. US epidemiological week starts on Sunday.
+/// +/// \param[in] values input to extract US epidemiological year from +/// \param[in] ctx the function execution context, optional +/// \return the resulting datum +/// +/// \since 8.0.0 +/// \note API not yet finalized +ARROW_EXPORT +Result<Datum> USYear(const Datum& values, ExecContext* ctx = NULLPTR); + /// \brief ISOWeek returns ISO week of year number for each element of `values`. /// First ISO week has the majority (4 or more) of its days in January. /// ISO week starts on Monday. Year can have 52 or 53 weeks. @@ -1410,6 +1435,18 @@ ARROW_EXPORT Result<Datum> AssumeTimezone(const Datum& values, AssumeTimezoneOptions options, ExecContext* ctx = NULLPTR); +/// \brief IsDaylightSavings returns whether daylight savings time is observed for each +/// element of `values` +/// +/// \param[in] values input to extract daylight savings indicator from +/// \param[in] ctx the function execution context, optional +/// \return the resulting datum +/// +/// \since 8.0.0 +/// \note API not yet finalized +ARROW_EXPORT Result<Datum> IsDaylightSavings(const Datum& values, + ExecContext* ctx = NULLPTR); + /// \brief Finds either the FIRST, LAST, or ALL items with a key that matches the given /// query key in a map. /// diff --git a/cpp/src/arrow/compute/exec/hash_join.cc b/cpp/src/arrow/compute/exec/hash_join.cc index c1e587598fade..5a9afaa5bdf5f 100644 --- a/cpp/src/arrow/compute/exec/hash_join.cc +++ b/cpp/src/arrow/compute/exec/hash_join.cc @@ -103,7 +103,7 @@ class HashJoinBasicImpl : public HashJoinImpl { filter_ = std::move(filter); output_batch_callback_ = std::move(output_batch_callback); finished_callback_ = std::move(finished_callback); - local_states_.resize(num_threads); + local_states_.resize(num_threads + 1); // +1 for calling thread + worker threads for (size_t i = 0; i < local_states_.size(); ++i) { local_states_[i].is_initialized = false; local_states_[i].is_has_match_initialized = false; @@ -151,6 +151,7 @@ class HashJoinBasicImpl : public HashJoinImpl { } void InitLocalStateIfNeeded(size_t thread_index) { + DCHECK_LT(thread_index, local_states_.size()); ThreadLocalState& local_state = local_states_[thread_index]; if (!local_state.is_initialized) { InitEncoder(0, HashJoinProjection::KEY, &local_state.exec_batch_keys); diff --git a/cpp/src/arrow/compute/exec/hash_join.h b/cpp/src/arrow/compute/exec/hash_join.h index 83ad4cb3f90f1..12455f0c6d021 100644 --- a/cpp/src/arrow/compute/exec/hash_join.h +++ b/cpp/src/arrow/compute/exec/hash_join.h @@ -59,8 +59,8 @@ class ARROW_EXPORT HashJoinSchema { Result<Expression> BindFilter(Expression filter, const Schema& left_schema, const Schema& right_schema); - std::shared_ptr<Schema> MakeOutputSchema(const std::string& left_field_name_prefix, - const std::string& right_field_name_prefix); + std::shared_ptr<Schema> MakeOutputSchema(const std::string& left_field_name_suffix, + const std::string& right_field_name_suffix); bool LeftPayloadIsEmpty() { return PayloadIsEmpty(0); } diff --git a/cpp/src/arrow/compute/exec/hash_join_dict.cc b/cpp/src/arrow/compute/exec/hash_join_dict.cc index b923433b493ee..ac1fbbaa3df00 100644 --- a/cpp/src/arrow/compute/exec/hash_join_dict.cc +++ b/cpp/src/arrow/compute/exec/hash_join_dict.cc @@ -566,7 +566,7 @@ Status HashJoinDictBuildMulti::PostDecode( } void HashJoinDictProbeMulti::Init(size_t num_threads) { - local_states_.resize(num_threads); + local_states_.resize(num_threads + 1); // +1 for calling thread + worker threads for (size_t i = 0; i < local_states_.size(); ++i) { local_states_[i].is_initialized = false; } diff --git
diff --git a/cpp/src/arrow/compute/exec/hash_join_node.cc b/cpp/src/arrow/compute/exec/hash_join_node.cc index 9295b5aaf4d7f..93e54c6400e57 100644 --- a/cpp/src/arrow/compute/exec/hash_join_node.cc +++ b/cpp/src/arrow/compute/exec/hash_join_node.cc @@ -86,8 +86,8 @@ Status HashJoinSchema::Init(JoinType join_type, const Schema& left_schema, const Schema& right_schema, const std::vector<FieldRef>& right_keys, const Expression& filter, - const std::string& left_field_name_prefix, - const std::string& right_field_name_prefix) { + const std::string& left_field_name_suffix, + const std::string& right_field_name_suffix) { std::vector<FieldRef> left_output; if (join_type != JoinType::RIGHT_SEMI && join_type != JoinType::RIGHT_ANTI) { const FieldVector& left_fields = left_schema.fields(); @@ -106,18 +106,18 @@ Status HashJoinSchema::Init(JoinType join_type, const Schema& left_schema, } } return Init(join_type, left_schema, left_keys, left_output, right_schema, right_keys, - right_output, filter, left_field_name_prefix, right_field_name_prefix); + right_output, filter, left_field_name_suffix, right_field_name_suffix); } Status HashJoinSchema::Init( JoinType join_type, const Schema& left_schema, const std::vector<FieldRef>& left_keys, const std::vector<FieldRef>& left_output, const Schema& right_schema, const std::vector<FieldRef>& right_keys, const std::vector<FieldRef>& right_output, - const Expression& filter, const std::string& left_field_name_prefix, - const std::string& right_field_name_prefix) { + const Expression& filter, const std::string& left_field_name_suffix, + const std::string& right_field_name_suffix) { RETURN_NOT_OK(ValidateSchemas(join_type, left_schema, left_keys, left_output, right_schema, right_keys, right_output, - left_field_name_prefix, right_field_name_prefix)); + left_field_name_suffix, right_field_name_suffix)); std::vector<HashJoinProjection> handles; std::vector<const std::vector<FieldRef>*> field_refs; @@ -172,8 +172,8 @@ Status HashJoinSchema::ValidateSchemas(JoinType join_type, const Schema& left_sc const Schema& right_schema, const std::vector<FieldRef>& right_keys, const std::vector<FieldRef>& right_output, - const std::string& left_field_name_prefix, - const std::string& right_field_name_prefix) { + const std::string& left_field_name_suffix, + const std::string& right_field_name_suffix) { // Checks for key fields: // 1. Key field refs must match exactly one input field // 2. Same number of key fields on left and right @@ -241,7 +241,7 @@ Status HashJoinSchema::ValidateSchemas(JoinType join_type, const Schema& left_sc // 4. Left semi/anti join (right semi/anti join) must not output fields from right // (left) // 5. No name collisions in output fields after adding (potentially empty) - // prefixes to left and right output + // suffixes to left and right output // if (left_output.empty() && right_output.empty()) { return Status::Invalid("Join must output at least one field"); @@ -275,30 +275,60 @@ Status HashJoinSchema::ValidateSchemas(JoinType join_type, const Schema& left_sc } std::shared_ptr<Schema> HashJoinSchema::MakeOutputSchema( - const std::string& left_field_name_prefix, - const std::string& right_field_name_prefix) { + const std::string& left_field_name_suffix, + const std::string& right_field_name_suffix) { std::vector<std::shared_ptr<Field>> fields; int left_size = proj_maps[0].num_cols(HashJoinProjection::OUTPUT); int right_size = proj_maps[1].num_cols(HashJoinProjection::OUTPUT); fields.resize(left_size + right_size); - for (int i = 0; i < left_size + right_size; ++i) { - bool is_left = (i < left_size); - int side = (is_left ? 0 : 1); - int input_field_id = proj_maps[side] - .map(HashJoinProjection::OUTPUT, HashJoinProjection::INPUT) - .get(is_left ? i : i - left_size); + std::unordered_multimap<std::string, int> left_field_map; + left_field_map.reserve(left_size); + for (int i = 0; i < left_size; ++i) { + int side = 0; // left + int input_field_id = + proj_maps[side].map(HashJoinProjection::OUTPUT, HashJoinProjection::INPUT).get(i); const std::string& input_field_name = proj_maps[side].field_name(HashJoinProjection::INPUT, input_field_id); const std::shared_ptr<DataType>& input_data_type = proj_maps[side].data_type(HashJoinProjection::INPUT, input_field_id); + left_field_map.insert({input_field_name, i}); + // insert left table field + fields[i] = + std::make_shared<Field>(input_field_name, input_data_type, true /*nullable*/); + } - std::string output_field_name = - (is_left ? left_field_name_prefix : right_field_name_prefix) + input_field_name; + for (int i = 0; i < right_size; ++i) { + int side = 1; // right + int input_field_id = + proj_maps[side].map(HashJoinProjection::OUTPUT, HashJoinProjection::INPUT).get(i); + const std::string& input_field_name = + proj_maps[side].field_name(HashJoinProjection::INPUT, input_field_id); + const std::shared_ptr<DataType>& input_data_type = + proj_maps[side].data_type(HashJoinProjection::INPUT, input_field_id); + // search the map and add suffix to the elements which + // are present both in left and right tables + auto search_it = left_field_map.equal_range(input_field_name); + bool match_found = false; + for (auto search = search_it.first; search != search_it.second; ++search) { + match_found = true; + auto left_index = search->second; + auto left_field = fields[left_index]; + // update left table field with suffix + fields[left_index] = + std::make_shared<Field>(input_field_name + left_field_name_suffix, + left_field->type(), true /*nullable*/); + // insert right table field with suffix + fields[left_size + i] = std::make_shared<Field>( + input_field_name + right_field_name_suffix, input_data_type, true /*nullable*/); + } - // All fields coming out of join are marked as nullable. - fields[i] = - std::make_shared<Field>(output_field_name, input_data_type, true /*nullable*/); + if (!match_found) { + // insert right table field without suffix + fields[left_size + i] = + std::make_shared<Field>(input_field_name, input_data_type, true /*nullable*/); + } } return std::make_shared<Schema>(std::move(fields)); }
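To make the renaming rule above concrete: a multimap of left-side output names is built first, and only names that occur on both sides receive the left/right suffixes, while unique names pass through unchanged. A self-contained sketch of the same rule on plain strings (JoinOutputNames and all variable names are invented for illustration, not part of the patch):

#include <iostream>
#include <string>
#include <unordered_map>
#include <vector>

// Rename join output columns the way the hunk above does: names shared by
// both inputs get the respective suffixes, unique names stay as they are.
std::vector<std::string> JoinOutputNames(const std::vector<std::string>& left,
                                         const std::vector<std::string>& right,
                                         const std::string& left_suffix,
                                         const std::string& right_suffix) {
  std::unordered_multimap<std::string, std::size_t> left_map;
  std::vector<std::string> out(left.size() + right.size());
  for (std::size_t i = 0; i < left.size(); ++i) {
    left_map.insert({left[i], i});
    out[i] = left[i];  // tentatively keep the left name unchanged
  }
  for (std::size_t i = 0; i < right.size(); ++i) {
    auto range = left_map.equal_range(right[i]);
    bool match_found = false;
    for (auto it = range.first; it != range.second; ++it) {
      match_found = true;
      out[it->second] = right[i] + left_suffix;        // rename the left twin
      out[left.size() + i] = right[i] + right_suffix;  // rename the right field
    }
    if (!match_found) out[left.size() + i] = right[i];
  }
  return out;
}

int main() {
  // Prints: lkey, shared_l, rkey, shared_r
  for (const auto& name :
       JoinOutputNames({"lkey", "shared"}, {"rkey", "shared"}, "_l", "_r")) {
    std::cout << name << "\n";
  }
  return 0;
}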
@@ -452,18 +482,19 @@ class HashJoinNode : public ExecNode { const auto& left_schema = *(inputs[0]->output_schema()); const auto& right_schema = *(inputs[1]->output_schema()); + // This will also validate input schemas if (join_options.output_all) { RETURN_NOT_OK(schema_mgr->Init( join_options.join_type, left_schema, join_options.left_keys, right_schema, join_options.right_keys, join_options.filter, - join_options.output_prefix_for_left, join_options.output_prefix_for_right)); + join_options.output_suffix_for_left, join_options.output_suffix_for_right)); } else { RETURN_NOT_OK(schema_mgr->Init( join_options.join_type, left_schema, join_options.left_keys, join_options.left_output, right_schema, join_options.right_keys, join_options.right_output, join_options.filter, - join_options.output_prefix_for_left, join_options.output_prefix_for_right)); + join_options.output_suffix_for_left, join_options.output_suffix_for_right)); } ARROW_ASSIGN_OR_RAISE( @@ -472,8 +503,7 @@ class HashJoinNode : public ExecNode { // Generate output schema std::shared_ptr<Schema> output_schema = schema_mgr->MakeOutputSchema( - join_options.output_prefix_for_left, join_options.output_prefix_for_right); - + join_options.output_suffix_for_left, join_options.output_suffix_for_right); // Create hash join implementation object ARROW_ASSIGN_OR_RAISE(std::unique_ptr<HashJoinImpl> impl, HashJoinImpl::MakeBasic());
diff --git a/cpp/src/arrow/compute/exec/hash_join_node_test.cc b/cpp/src/arrow/compute/exec/hash_join_node_test.cc index 93c5c050aad31..96469a78ab2fb 100644 --- a/cpp/src/arrow/compute/exec/hash_join_node_test.cc +++ b/cpp/src/arrow/compute/exec/hash_join_node_test.cc @@ -922,6 +922,73 @@ void HashJoinWithExecPlan(Random64Bit& rng, bool parallel, ASSERT_OK_AND_ASSIGN(*output, TableFromExecBatches(output_schema, res)); } +TEST(HashJoin, Suffix) { + BatchesWithSchema input_left; + input_left.batches = {ExecBatchFromJSON({int32(), int32(), int32()}, R"([ + [1, 4, 7], + [2, 5, 8], + [3, 6, 9] + ])")}; + input_left.schema = schema( + {field("lkey", int32()), field("shared", int32()), field("ldistinct", int32())}); + + BatchesWithSchema input_right; + input_right.batches = {ExecBatchFromJSON({int32(), int32(), int32()}, R"([ + [1, 10, 13], + [2, 11, 14], + [3, 12, 15] + ])")}; + input_right.schema = schema( + {field("rkey", int32()), field("shared", int32()), field("rdistinct", int32())}); + + BatchesWithSchema expected; + expected.batches = { + ExecBatchFromJSON({int32(), int32(), int32(), int32(), int32(), int32()}, R"([ + [1, 4, 7, 1, 10, 13], + [2, 5, 8, 2, 11, 14], + [3, 6, 9, 3, 12, 15] + ])")}; + + expected.schema = schema({field("lkey", int32()), field("shared_l", int32()), + field("ldistinct", int32()), field("rkey", int32()), + field("shared_r", int32()), field("rdistinct", int32())}); + + ExecContext exec_ctx; + + ASSERT_OK_AND_ASSIGN(auto plan, ExecPlan::Make(&exec_ctx)); + AsyncGenerator<util::optional<ExecBatch>> sink_gen; + + ExecNode* left_source; + ExecNode* right_source; + ASSERT_OK_AND_ASSIGN( + left_source, + MakeExecNode("source", plan.get(), {}, + SourceNodeOptions{input_left.schema, input_left.gen(/*parallel=*/false, /*slow=*/false)})); + + ASSERT_OK_AND_ASSIGN(right_source, + MakeExecNode("source", plan.get(), {}, SourceNodeOptions{input_right.schema, + input_right.gen(/*parallel=*/false, + /*slow=*/false)})); + + HashJoinNodeOptions join_opts{JoinType::INNER, + /*left_keys=*/{"lkey"}, + /*right_keys=*/{"rkey"}, literal(true), "_l", "_r"}; + + ASSERT_OK_AND_ASSIGN( + auto hashjoin, + MakeExecNode("hashjoin", plan.get(), {left_source, right_source}, join_opts)); + + ASSERT_OK_AND_ASSIGN(std::ignore, MakeExecNode("sink", plan.get(), {hashjoin}, + SinkNodeOptions{&sink_gen})); + + ASSERT_FINISHES_OK_AND_ASSIGN(auto result, StartAndCollect(plan.get(), sink_gen)); + + AssertExecBatchesEqual(expected.schema, expected.batches, result); + AssertSchemaEqual(expected.schema, hashjoin->output_schema()); +} + TEST(HashJoin, Random) { Random64Bit rng(42); #if defined(THREAD_SANITIZER)
diff --git a/cpp/src/arrow/compute/exec/options.h b/cpp/src/arrow/compute/exec/options.h index 856ded06e5e81..9e99953e87218 100644 --- a/cpp/src/arrow/compute/exec/options.h +++ b/cpp/src/arrow/compute/exec/options.h @@ -205,19 +205,19 @@ enum class JoinKeyCmp { EQ, IS }; /// \brief Make a node which implements join operation using hash join strategy. class ARROW_EXPORT HashJoinNodeOptions : public ExecNodeOptions { public: - static constexpr const char* default_output_prefix_for_left = ""; - static constexpr const char* default_output_prefix_for_right = ""; + static constexpr const char* default_output_suffix_for_left = ""; + static constexpr const char* default_output_suffix_for_right = ""; HashJoinNodeOptions( JoinType in_join_type, std::vector<FieldRef> in_left_keys, std::vector<FieldRef> in_right_keys, Expression filter = literal(true), - std::string output_prefix_for_left = default_output_prefix_for_left, - std::string output_prefix_for_right = default_output_prefix_for_right) + std::string output_suffix_for_left = default_output_suffix_for_left, + std::string output_suffix_for_right = default_output_suffix_for_right) : join_type(in_join_type), left_keys(std::move(in_left_keys)), right_keys(std::move(in_right_keys)), output_all(true), - output_prefix_for_left(std::move(output_prefix_for_left)), - output_prefix_for_right(std::move(output_prefix_for_right)), + output_suffix_for_left(std::move(output_suffix_for_left)), + output_suffix_for_right(std::move(output_suffix_for_right)), filter(std::move(filter)) { this->key_cmp.resize(this->left_keys.size()); for (size_t i = 0; i < this->left_keys.size(); ++i) { @@ -228,16 +228,16 @@ class ARROW_EXPORT HashJoinNodeOptions : public ExecNodeOptions { JoinType join_type, std::vector<FieldRef> left_keys, std::vector<FieldRef> right_keys, std::vector<FieldRef> left_output, std::vector<FieldRef> right_output, Expression filter = literal(true), - std::string output_prefix_for_left = default_output_prefix_for_left, - std::string output_prefix_for_right = default_output_prefix_for_right) + std::string output_suffix_for_left = default_output_suffix_for_left, + std::string output_suffix_for_right = default_output_suffix_for_right) : join_type(join_type), left_keys(std::move(left_keys)), right_keys(std::move(right_keys)), output_all(false), left_output(std::move(left_output)), right_output(std::move(right_output)), - output_prefix_for_left(std::move(output_prefix_for_left)), - output_prefix_for_right(std::move(output_prefix_for_right)), + output_suffix_for_left(std::move(output_suffix_for_left)), + output_suffix_for_right(std::move(output_suffix_for_right)), filter(std::move(filter)) { this->key_cmp.resize(this->left_keys.size()); for (size_t i = 0; i < this->left_keys.size(); ++i) { @@ -249,8 +249,8 @@ class ARROW_EXPORT HashJoinNodeOptions : public 
ExecNodeOptions { std::vector<FieldRef> right_keys, std::vector<FieldRef> left_output, std::vector<FieldRef> right_output, std::vector<JoinKeyCmp> key_cmp, Expression filter = literal(true), - std::string output_prefix_for_left = default_output_prefix_for_left, - std::string output_prefix_for_right = default_output_prefix_for_right) + std::string output_suffix_for_left = default_output_suffix_for_left, + std::string output_suffix_for_right = default_output_suffix_for_right) : join_type(join_type), left_keys(std::move(left_keys)), right_keys(std::move(right_keys)), @@ -258,8 +258,8 @@ class ARROW_EXPORT HashJoinNodeOptions : public ExecNodeOptions { left_output(std::move(left_output)), right_output(std::move(right_output)), key_cmp(std::move(key_cmp)), - output_prefix_for_left(std::move(output_prefix_for_left)), - output_prefix_for_right(std::move(output_prefix_for_right)), + output_suffix_for_left(std::move(output_suffix_for_left)), + output_suffix_for_right(std::move(output_suffix_for_right)), filter(std::move(filter)) {} // type of join (inner, left, semi...) @@ -278,12 +278,12 @@ class ARROW_EXPORT HashJoinNodeOptions : public ExecNodeOptions { // key comparison function (determines whether a null key is equal another null // key or not) std::vector<JoinKeyCmp> key_cmp; - // prefix added to names of output fields coming from left input (used to - // distinguish, if necessary, between fields of the same name in left and right - // input and can be left empty if there are no name collisions) - std::string output_prefix_for_left; - // prefix added to names of output fields coming from right input - std::string output_prefix_for_right; + // suffix added to names of output fields coming from left input (used to distinguish, + // if necessary, between fields of the same name in left and right input and can be left + // empty if there are no name collisions) + std::string output_suffix_for_left; + // suffix added to names of output fields coming from right input + std::string output_suffix_for_right; // residual filter which is applied to matching rows. Rows that do not match // the filter are not included. The filter is applied against the // concatenated input schema (left fields then right fields) and can reference
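For orientation, a hedged usage sketch of the renamed options (the schema and key names are invented; with collision-free schemas the empty defaults suffice, since suffixes are only needed to disambiguate shared names):

#include "arrow/compute/exec/options.h"

using arrow::compute::HashJoinNodeOptions;
using arrow::compute::JoinType;
using arrow::compute::literal;

// Both inputs expose a column named "shared", so explicit suffixes keep the
// two output columns distinguishable; columns unique to one side keep their
// names untouched.
HashJoinNodeOptions MakeExampleJoinOptions() {
  return HashJoinNodeOptions{JoinType::INNER,
                             /*left_keys=*/{"lkey"},
                             /*right_keys=*/{"rkey"},
                             /*filter=*/literal(true),
                             /*output_suffix_for_left=*/"_l",
                             /*output_suffix_for_right=*/"_r"};
}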
diff --git a/cpp/src/arrow/compute/exec/util_test.cc b/cpp/src/arrow/compute/exec/util_test.cc index 6f4b5315fff5b..6d85991735160 100644 --- a/cpp/src/arrow/compute/exec/util_test.cc +++ b/cpp/src/arrow/compute/exec/util_test.cc @@ -25,8 +25,8 @@ using testing::Eq; namespace arrow { namespace compute { -const char* kLeftPrefix = "left."; -const char* kRightPrefix = "right."; +const char* kLeftSuffix = ".left"; +const char* kRightSuffix = ".right"; TEST(FieldMap, Trivial) { HashJoinSchema schema_mgr; @@ -35,12 +35,12 @@ TEST(FieldMap, Trivial) { auto right = schema({field("i32", int32())}); ASSERT_OK(schema_mgr.Init(JoinType::INNER, *left, {"i32"}, *right, {"i32"}, - literal(true), kLeftPrefix, kRightPrefix)); + literal(true), kLeftSuffix, kRightSuffix)); - auto output = schema_mgr.MakeOutputSchema(kLeftPrefix, kRightPrefix); + auto output = schema_mgr.MakeOutputSchema(kLeftSuffix, kRightSuffix); EXPECT_THAT(*output, Eq(Schema({ - field("left.i32", int32()), - field("right.i32", int32()), + field("i32.left", int32()), + field("i32.right", int32()), }))); auto i = @@ -75,7 +75,7 @@ TEST(FieldMap, SingleKeyField) { auto right = schema({field("f32", float32()), field("i32", int32())}); ASSERT_OK(schema_mgr.Init(JoinType::INNER, *left, {"i32"}, *right, {"i32"}, - literal(true), kLeftPrefix, kRightPrefix)); + literal(true), kLeftSuffix, kRightSuffix)); EXPECT_EQ(schema_mgr.proj_maps[0].num_cols(HashJoinProjection::INPUT), 2); EXPECT_EQ(schema_mgr.proj_maps[1].num_cols(HashJoinProjection::INPUT), 2); @@ -84,12 +84,12 @@ TEST(FieldMap, SingleKeyField) { EXPECT_EQ(schema_mgr.proj_maps[0].num_cols(HashJoinProjection::OUTPUT), 2); EXPECT_EQ(schema_mgr.proj_maps[1].num_cols(HashJoinProjection::OUTPUT), 2); - auto output = schema_mgr.MakeOutputSchema(kLeftPrefix, kRightPrefix); + auto output = schema_mgr.MakeOutputSchema(kLeftSuffix, kRightSuffix); EXPECT_THAT(*output, Eq(Schema({ - field("left.i32", int32()), - field("left.str", utf8()), - field("right.f32", float32()), - field("right.i32", int32()), + field("i32.left", int32()), + field("str", utf8()), + field("f32", float32()), + field("i32.right", int32()), }))); auto i = @@ -113,18 +113,18 @@ TEST(FieldMap, TwoKeyFields) { }); ASSERT_OK(schema_mgr.Init(JoinType::INNER, *left, {"i32", "str"}, *right, - {"i32", "str"}, literal(true), kLeftPrefix, kRightPrefix)); + {"i32", "str"}, literal(true), kLeftSuffix, kRightSuffix)); - auto output = schema_mgr.MakeOutputSchema(kLeftPrefix, kRightPrefix); + auto output = schema_mgr.MakeOutputSchema(kLeftSuffix, kRightSuffix); EXPECT_THAT(*output, Eq(Schema({ - field("left.i32", int32()), - field("left.str", utf8()), - field("left.bool", boolean()), - - field("right.i32", int32()), - field("right.str", utf8()), - field("right.f32", float32()), - field("right.f64", float64()), + field("i32.left", int32()), + field("str.left", utf8()), + field("bool", boolean()), + + field("i32.right", int32()), + field("str.right", utf8()), + field("f32", float32()), + field("f64", float64()), }))); }
diff --git a/cpp/src/arrow/compute/kernels/codegen_internal.cc b/cpp/src/arrow/compute/kernels/codegen_internal.cc index ee23d9598463e..b31ef408b10dc 100644 --- a/cpp/src/arrow/compute/kernels/codegen_internal.cc +++ b/cpp/src/arrow/compute/kernels/codegen_internal.cc @@ -120,7 +120,14 @@ void ReplaceTemporalTypes(const TimeUnit::type unit, std::vector* de continue; } case Type::TIME32: - case 
Type::TIME64: { + if (unit > TimeUnit::MILLI) { + it->type = time64(unit); + } else { + it->type = time32(unit); + } + continue; + } case Type::DURATION: { it->type = duration(unit); continue;
diff --git a/cpp/src/arrow/compute/kernels/codegen_internal_test.cc b/cpp/src/arrow/compute/kernels/codegen_internal_test.cc index 5e9109c0a3e05..31e0ffe1d3402 100644 --- a/cpp/src/arrow/compute/kernels/codegen_internal_test.cc +++ b/cpp/src/arrow/compute/kernels/codegen_internal_test.cc @@ -224,6 +224,14 @@ TEST(TestDispatchBest, CommonTemporalResolution) { ASSERT_EQ(TimeUnit::MILLI, ty); args = {timestamp(TimeUnit::SECOND, "UTC"), timestamp(TimeUnit::SECOND, tz)}; ASSERT_EQ(TimeUnit::SECOND, CommonTemporalResolution(args.data(), args.size())); + args = {time32(TimeUnit::MILLI), duration(TimeUnit::SECOND)}; + ASSERT_EQ(TimeUnit::MILLI, CommonTemporalResolution(args.data(), args.size())); + args = {time64(TimeUnit::MICRO), duration(TimeUnit::NANO)}; + ASSERT_EQ(TimeUnit::NANO, CommonTemporalResolution(args.data(), args.size())); + args = {duration(TimeUnit::SECOND), int64()}; + ASSERT_EQ(TimeUnit::SECOND, CommonTemporalResolution(args.data(), args.size())); + args = {duration(TimeUnit::MILLI), timestamp(TimeUnit::SECOND, tz)}; + ASSERT_EQ(TimeUnit::MILLI, CommonTemporalResolution(args.data(), args.size())); } TEST(TestDispatchBest, ReplaceTemporalTypes) { @@ -241,7 +249,7 @@ TEST(TestDispatchBest, ReplaceTemporalTypes) { ASSERT_TRUE(CommonTemporalResolution(args.data(), args.size(), &ty)); ReplaceTemporalTypes(ty, &args); AssertTypeEqual(args[0].type, timestamp(TimeUnit::MILLI)); - AssertTypeEqual(args[1].type, duration(TimeUnit::MILLI)); + AssertTypeEqual(args[1].type, time32(TimeUnit::MILLI)); args = {duration(TimeUnit::SECOND), date64()}; ASSERT_TRUE(CommonTemporalResolution(args.data(), args.size(), &ty)); @@ -259,7 +267,7 @@ TEST(TestDispatchBest, ReplaceTemporalTypes) { ASSERT_TRUE(CommonTemporalResolution(args.data(), args.size(), &ty)); ReplaceTemporalTypes(ty, &args); AssertTypeEqual(args[0].type, timestamp(TimeUnit::NANO, tz)); - AssertTypeEqual(args[1].type, duration(TimeUnit::NANO)); + AssertTypeEqual(args[1].type, time64(TimeUnit::NANO)); args = {timestamp(TimeUnit::SECOND, tz), date64()}; ASSERT_TRUE(CommonTemporalResolution(args.data(), args.size(), &ty)); @@ -269,8 +277,32 @@ TEST(TestDispatchBest, ReplaceTemporalTypes) { args = {timestamp(TimeUnit::SECOND, "UTC"), timestamp(TimeUnit::SECOND, tz)}; ty = CommonTemporalResolution(args.data(), args.size()); + ReplaceTemporalTypes(ty, &args); AssertTypeEqual(args[0].type, timestamp(TimeUnit::SECOND, "UTC")); AssertTypeEqual(args[1].type, timestamp(TimeUnit::SECOND, tz)); + + args = {time32(TimeUnit::SECOND), duration(TimeUnit::SECOND)}; + ty = CommonTemporalResolution(args.data(), args.size()); + ReplaceTemporalTypes(ty, &args); + AssertTypeEqual(args[0].type, time32(TimeUnit::SECOND)); + AssertTypeEqual(args[1].type, duration(TimeUnit::SECOND)); + + args = {time64(TimeUnit::MICRO), duration(TimeUnit::SECOND)}; + ty = CommonTemporalResolution(args.data(), args.size()); + ReplaceTemporalTypes(ty, &args); + AssertTypeEqual(args[0].type, time64(TimeUnit::MICRO)); + AssertTypeEqual(args[1].type, duration(TimeUnit::MICRO)); + + args = {time32(TimeUnit::SECOND), duration(TimeUnit::NANO)}; + ty = CommonTemporalResolution(args.data(), args.size()); + ReplaceTemporalTypes(ty, &args); + AssertTypeEqual(args[0].type, time64(TimeUnit::NANO)); + AssertTypeEqual(args[1].type, duration(TimeUnit::NANO)); + + args = {duration(TimeUnit::SECOND), int64()}; + ReplaceTemporalTypes(CommonTemporalResolution(args.data(), args.size()), &args); + AssertTypeEqual(args[0].type, duration(TimeUnit::SECOND)); + AssertTypeEqual(args[1].type, int64()); } } // namespace internal
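The TIME32/TIME64 branch above now derives the storage width from the target unit instead of falling through to the duration case. A standalone restatement of that split (TimeTypeForUnit is an illustrative helper, not a function in the patch):

#include <memory>

#include "arrow/type.h"

// Seconds and milliseconds fit time32; microseconds and nanoseconds need
// time64 - the same split the hunk above applies inside ReplaceTemporalTypes.
std::shared_ptr<arrow::DataType> TimeTypeForUnit(arrow::TimeUnit::type unit) {
  return unit > arrow::TimeUnit::MILLI ? arrow::time64(unit) : arrow::time32(unit);
}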
diff --git a/cpp/src/arrow/compute/kernels/copy_data_internal.h b/cpp/src/arrow/compute/kernels/copy_data_internal.h new file mode 100644 index 0000000000000..5a5d4463456aa --- /dev/null +++ b/cpp/src/arrow/compute/kernels/copy_data_internal.h @@ -0,0 +1,112 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +#pragma once + +#include "arrow/compute/kernels/codegen_internal.h" + +namespace arrow { +namespace compute { +namespace internal { + +template <typename Type, typename Enable = void> +struct CopyDataUtils {}; + +template <> +struct CopyDataUtils<BooleanType> { + static void CopyData(const DataType&, const Scalar& in, const int64_t in_offset, + uint8_t* out, const int64_t out_offset, const int64_t length) { + bit_util::SetBitsTo( + out, out_offset, length, + in.is_valid ? checked_cast<const BooleanScalar&>(in).value : false); + } + + static void CopyData(const DataType&, const uint8_t* in, const int64_t in_offset, + uint8_t* out, const int64_t out_offset, const int64_t length) { + arrow::internal::CopyBitmap(in, in_offset, length, out, out_offset); + } + + static void CopyData(const DataType&, const ArrayData& in, const int64_t in_offset, + uint8_t* out, const int64_t out_offset, const int64_t length) { + const auto in_arr = in.GetValues<uint8_t>(1, /*absolute_offset=*/0); + CopyData(*in.type, in_arr, in_offset, out, out_offset, length); + } +}; + +template <> +struct CopyDataUtils<FixedSizeBinaryType> { + static void CopyData(const DataType& ty, const Scalar& in, const int64_t in_offset, + uint8_t* out, const int64_t out_offset, const int64_t length) { + const int32_t width = checked_cast<const FixedSizeBinaryType&>(ty).byte_width(); + uint8_t* begin = out + (width * out_offset); + const auto& scalar = checked_cast<const arrow::internal::PrimitiveScalarBase&>(in); + // Null scalar may have null value buffer + if (!scalar.is_valid) { + std::memset(begin, 0x00, width * length); + } else { + const util::string_view buffer = scalar.view(); + DCHECK_GE(buffer.size(), static_cast<size_t>(width)); + for (int i = 0; i < length; i++) { + std::memcpy(begin, buffer.data(), width); + begin += width; + } + } + } + + static void CopyData(const DataType& ty, const uint8_t* in, const int64_t in_offset, + uint8_t* out, const int64_t out_offset, const int64_t length) { + const int32_t width = checked_cast<const FixedSizeBinaryType&>(ty).byte_width(); + uint8_t* begin = out + (width * out_offset); + std::memcpy(begin, in + in_offset * width, length * width); + } + + static void CopyData(const DataType& ty, const ArrayData& in, const int64_t in_offset, + uint8_t* out, const int64_t out_offset, const int64_t length) { + const int32_t width = checked_cast<const FixedSizeBinaryType&>(ty).byte_width(); + const auto in_arr = in.GetValues<uint8_t>(1, in.offset * width); + CopyData(ty, in_arr, in_offset, out, out_offset, length); + } +};
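+// Illustrative note (an editorial example, not lines from the original
+// header): with width 4 and a valid scalar holding "abcd",
+// CopyData(ty, scalar, 0, out, /*out_offset=*/2, /*length=*/3) writes
+// "abcdabcdabcd" starting at byte 8 = 4 * 2 of `out`.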
+ +template <typename Type> +struct CopyDataUtils< + Type, enable_if_t<is_number_type<Type>::value || is_interval_type<Type>::value>> { + using CType = typename TypeTraits<Type>::CType; + + static void CopyData(const DataType&, const Scalar& in, const int64_t in_offset, + uint8_t* out, const int64_t out_offset, const int64_t length) { + CType* begin = reinterpret_cast<CType*>(out) + out_offset; + CType* end = begin + length; + std::fill(begin, end, UnboxScalar<Type>::Unbox(in)); + } + + static void CopyData(const DataType&, const uint8_t* in, const int64_t in_offset, + uint8_t* out, const int64_t out_offset, const int64_t length) { + std::memcpy(out + out_offset * sizeof(CType), in + in_offset * sizeof(CType), + length * sizeof(CType)); + } + + static void CopyData(const DataType&, const ArrayData& in, const int64_t in_offset, + uint8_t* out, const int64_t out_offset, const int64_t length) { + const auto in_arr = in.GetValues<uint8_t>(1, in.offset * sizeof(CType)); + CopyData(*in.type, in_arr, in_offset, out, out_offset, length); + } +}; + +} // namespace internal +} // namespace compute +} // namespace arrow
diff --git a/cpp/src/arrow/compute/kernels/scalar_arithmetic.cc b/cpp/src/arrow/compute/kernels/scalar_arithmetic.cc index 100f74eee33d0..103f8e66c5017 100644 --- a/cpp/src/arrow/compute/kernels/scalar_arithmetic.cc +++ b/cpp/src/arrow/compute/kernels/scalar_arithmetic.cc @@ -240,6 +240,66 @@ struct SubtractCheckedDate32 { } }; +template <int64_t multiple> +struct AddTimeDuration { + template <typename T, typename Arg0, typename Arg1> + static T Call(KernelContext*, Arg0 left, Arg1 right, Status* st) { + T result = + arrow::internal::SafeSignedAdd(static_cast<T>(left), static_cast<T>(right)); + if (result < 0 || multiple <= result) { + *st = Status::Invalid(result, " is not within the acceptable range of ", "[0, ", + multiple, ") s"); + } + return result; + } +}; + +template <int64_t multiple> +struct AddTimeDurationChecked { + template <typename T, typename Arg0, typename Arg1> + static T Call(KernelContext*, Arg0 left, Arg1 right, Status* st) { + T result = 0; + if (ARROW_PREDICT_FALSE( + AddWithOverflow(static_cast<T>(left), static_cast<T>(right), &result))) { + *st = Status::Invalid("overflow"); + } + if (result < 0 || multiple <= result) { + *st = Status::Invalid(result, " is not within the acceptable range of ", "[0, ", + multiple, ") s"); + } + return result; + } +}; + +template <int64_t multiple> +struct SubtractTimeDuration { + template <typename T, typename Arg0, typename Arg1> + static T Call(KernelContext*, Arg0 left, Arg1 right, Status* st) { + T result = arrow::internal::SafeSignedSubtract(left, static_cast<T>(right)); + if (result < 0 || multiple <= result) { + *st = Status::Invalid(result, " is not within the acceptable range of ", "[0, ", + multiple, ") s"); + } + return result; + } +}; + +template <int64_t multiple> +struct SubtractTimeDurationChecked { + template <typename T, typename Arg0, typename Arg1> + static T Call(KernelContext*, Arg0 left, Arg1 right, Status* st) { + T result = 0; + if (ARROW_PREDICT_FALSE(SubtractWithOverflow(left, static_cast<T>(right), &result))) { + *st = Status::Invalid("overflow"); + } + if (result < 0 || multiple <= result) { + *st = Status::Invalid(result, " is not within the acceptable range of ", "[0, ", + multiple, ") s"); + } + return result; + } +}; + struct Multiply { static_assert(std::is_same::value, ""); static_assert(std::is_same::value, ""); @@ -2157,6 +2217,58 @@ std::shared_ptr<ScalarFunction> MakeArithmeticFunctionFloatingPointNotNull( return func; } +template
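The `multiple` template parameter in the new time kernels bounds a time-of-day result to the half-open range [0, multiple). A self-contained sketch of that check for a seconds-typed time (the function name and the ticks-per-day constant are illustrative assumptions, not taken from the patch):

#include <cassert>
#include <cstdint>

// Mirrors the range check in AddTimeDuration for a time in seconds, where a
// day holds 86400 ticks: a valid time of day must stay within one day.
constexpr int64_t kTicksPerDay = 86400;

bool AddTimeOfDay(int64_t time_ticks, int64_t duration_ticks, int64_t* out) {
  *out = time_ticks + duration_ticks;
  return *out >= 0 && *out < kTicksPerDay;
}

int main() {
  int64_t r = 0;
  assert(AddTimeOfDay(86000, 300, &r) && r == 86300);  // stays inside the day
  assert(!AddTimeOfDay(86000, 500, &r));               // 86500 spills past midnight
  return 0;
}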