From 55c70b62f8d9f55865d7cb79af0eb1146e86a094 Mon Sep 17 00:00:00 2001 From: Yue Date: Wed, 8 Nov 2023 08:56:03 +0800 Subject: [PATCH] GH-37753: [C++][Gandiva] Add external function registry support (#38116) # Rationale for this change This PR tries to enhance Gandiva by supporting an external function registry, so that developers can author third-party functions without modifying Gandiva's core codebase. See https://github.com/apache/arrow/issues/37753 for more details. In this PR, the external function needs to be compiled into LLVM IR for integration. # What changes are included in this PR? Two new APIs are added to `FunctionRegistry`: ```C++ /// \brief register a set of functions into the function registry from a given bitcode /// file arrow::Status Register(const std::vector<NativeFunction>& funcs, const std::string& bitcode_path); /// \brief register a set of functions into the function registry from a given bitcode /// buffer arrow::Status Register(const std::vector<NativeFunction>& funcs, std::shared_ptr<arrow::Buffer> bitcode_buffer); ``` Developers can use these two APIs to register external functions. Typically, developers will register a set of function metadata (`funcs`) for all functions in an LLVM bitcode file, by giving either the path to the LLVM bitcode file or an `arrow::Buffer` containing the LLVM bitcode buffer. The overall flow looks like this: ![image](https://github.com/apache/arrow/assets/27754/b2b346fe-931f-4253-b198-4c388c57a56b) # Are these changes tested? Some unit tests are added to verify this enhancement. # Are there any user-facing changes? Some new ways of interfacing with the library are added in this PR: * The `Configuration` class now supports accepting a customized function registry, with which developers can register their own external functions and use it as the function registry * The `FunctionRegistry` class has two new APIs mentioned above * After instantiation, the `FunctionRegistry` class no longer has any built-in functions registered in it. 
And we switch to using a new function `GANDIVA_EXPORT std::shared_ptr<FunctionRegistry> default_function_registry();` to retrieve the default function registry, which contains all the Gandiva built-in functions. * The behavior of some libraries depending on the Gandiva C++ library, such as the `Gandiva::FunctionRegistry` class in Gandiva's Ruby binding, is changed accordingly # Notes * Performance * The code generation time grows with the number of externally added function bitcodes (the more functions are added, the slower the codegen will be), even if the external function is not used in the given expression at all. But this is not a new issue, and it applies to built-in functions as well (the more built-in functions there are, the slower the codegen will be). In my limited testing, this is because `llvm::Linker::linkModules` takes a non-trivial amount of time, which happens for every IR loaded, and `RemoveUnusedFunctions` happens after that, which doesn't help to reduce the time of `linkModules`. We may have to selectively load only the necessary IR (primarily by selectively doing `linkModules` for that IR), but more metadata may be needed to tell which functions can be found in which IR. This could be improved in a separate PR; please advise if anyone has any ideas on improving it. Thanks. * Integration with other programming languages via LLVM IR/bitcode * So far I have only added an external C++ function in the codebase for unit testing purposes. A Rust-based function is possible, but I gave it a try and found another issue (Rust has a std lib which needs to be processed with a different approach). I will do some exploration for other languages such as Zig later. * Non-pre-compiled functions may require a different approach to get the function pointer, and we may discuss and work on it in a separate PR later. Another issue, https://github.com/apache/arrow/issues/38589, was logged for this. 
* The discussion thread in dev mail list, https://lists.apache.org/thread/lm4sbw61w9cl7fsmo7tz3gvkq0ox6rod * I submitted another PR previously (https://github.com/apache/arrow/pull/37787) which introduced JSON based function registry, and after discussion, I will close that PR and use this PR instead * Closes: #37753 Lead-authored-by: Yue Ni Co-authored-by: Sutou Kouhei Signed-off-by: Sutou Kouhei --- c_glib/arrow-glib/version.h.in | 23 ++++ c_glib/doc/gandiva-glib/gandiva-glib-docs.xml | 4 + c_glib/gandiva-glib/function-registry.cpp | 118 +++++++++++++++--- c_glib/gandiva-glib/function-registry.h | 2 + c_glib/gandiva-glib/function-registry.hpp | 30 +++++ c_glib/test/gandiva/test-function-registry.rb | 2 +- c_glib/test/gandiva/test-native-function.rb | 2 +- cpp/cmake_modules/GandivaAddBitcode.cmake | 75 +++++++++++ cpp/cmake_modules/ThirdpartyToolchain.cmake | 21 ++-- cpp/src/gandiva/CMakeLists.txt | 6 +- cpp/src/gandiva/GandivaConfig.cmake.in | 1 + cpp/src/gandiva/configuration.cc | 5 +- cpp/src/gandiva/configuration.h | 26 +++- cpp/src/gandiva/engine.cc | 69 +++++++--- cpp/src/gandiva/engine.h | 4 + cpp/src/gandiva/expr_decomposer_test.cc | 17 ++- cpp/src/gandiva/expr_validator.cc | 2 +- cpp/src/gandiva/expr_validator.h | 9 +- cpp/src/gandiva/expression_registry.cc | 6 +- cpp/src/gandiva/expression_registry.h | 6 +- cpp/src/gandiva/expression_registry_test.cc | 4 +- cpp/src/gandiva/filter.cc | 3 +- cpp/src/gandiva/function_registry.cc | 111 +++++++++++----- cpp/src/gandiva/function_registry.h | 34 ++++- cpp/src/gandiva/function_registry_test.cc | 47 +++++-- cpp/src/gandiva/llvm_generator.cc | 13 +- cpp/src/gandiva/llvm_generator.h | 8 +- cpp/src/gandiva/llvm_generator_test.cc | 19 ++- cpp/src/gandiva/native_function.h | 12 +- cpp/src/gandiva/precompiled/CMakeLists.txt | 67 ++-------- cpp/src/gandiva/projector.cc | 3 +- cpp/src/gandiva/tests/CMakeLists.txt | 47 ++++--- cpp/src/gandiva/tests/date_time_test.cc | 26 ++-- .../tests/external_functions/CMakeLists.txt | 
50 ++++++++ .../external_functions/multiply_by_two.cc | 20 +++ .../external_functions/multiply_by_two.h | 24 ++++ cpp/src/gandiva/tests/filter_test.cc | 6 +- cpp/src/gandiva/tests/huge_table_test.cc | 5 +- .../tests/projector_build_validation_test.cc | 24 ++-- cpp/src/gandiva/tests/projector_test.cc | 26 ++++ cpp/src/gandiva/tests/test_util.cc | 52 ++++++++ cpp/src/gandiva/tests/test_util.h | 11 +- cpp/src/gandiva/tree_expr_test.cc | 16 +-- 43 files changed, 809 insertions(+), 247 deletions(-) create mode 100644 c_glib/gandiva-glib/function-registry.hpp create mode 100644 cpp/cmake_modules/GandivaAddBitcode.cmake create mode 100644 cpp/src/gandiva/tests/external_functions/CMakeLists.txt create mode 100644 cpp/src/gandiva/tests/external_functions/multiply_by_two.cc create mode 100644 cpp/src/gandiva/tests/external_functions/multiply_by_two.h create mode 100644 cpp/src/gandiva/tests/test_util.cc diff --git a/c_glib/arrow-glib/version.h.in b/c_glib/arrow-glib/version.h.in index 60c02936193bc..abb8ba08708de 100644 --- a/c_glib/arrow-glib/version.h.in +++ b/c_glib/arrow-glib/version.h.in @@ -110,6 +110,15 @@ # define GARROW_UNAVAILABLE(major, minor) G_UNAVAILABLE(major, minor) #endif +/** + * GARROW_VERSION_15_0: + * + * You can use this macro value for compile time API version check. 
+ * + * Since: 15.0.0 + */ +#define GARROW_VERSION_15_0 G_ENCODE_VERSION(15, 0) + /** * GARROW_VERSION_14_0: * @@ -346,6 +355,20 @@ #define GARROW_AVAILABLE_IN_ALL +#if GARROW_VERSION_MIN_REQUIRED >= GARROW_VERSION_15_0 +# define GARROW_DEPRECATED_IN_15_0 GARROW_DEPRECATED +# define GARROW_DEPRECATED_IN_15_0_FOR(function) GARROW_DEPRECATED_FOR(function) +#else +# define GARROW_DEPRECATED_IN_15_0 +# define GARROW_DEPRECATED_IN_15_0_FOR(function) +#endif + +#if GARROW_VERSION_MAX_ALLOWED < GARROW_VERSION_15_0 +# define GARROW_AVAILABLE_IN_15_0 GARROW_UNAVAILABLE(15, 0) +#else +# define GARROW_AVAILABLE_IN_15_0 +#endif + #if GARROW_VERSION_MIN_REQUIRED >= GARROW_VERSION_14_0 # define GARROW_DEPRECATED_IN_14_0 GARROW_DEPRECATED # define GARROW_DEPRECATED_IN_14_0_FOR(function) GARROW_DEPRECATED_FOR(function) diff --git a/c_glib/doc/gandiva-glib/gandiva-glib-docs.xml b/c_glib/doc/gandiva-glib/gandiva-glib-docs.xml index 182bbfb527eb2..a5c32f11337e8 100644 --- a/c_glib/doc/gandiva-glib/gandiva-glib-docs.xml +++ b/c_glib/doc/gandiva-glib/gandiva-glib-docs.xml @@ -100,6 +100,10 @@ Index of deprecated API + + Index of new symbols in 15.0.0 + + Index of new symbols in 4.0.0 diff --git a/c_glib/gandiva-glib/function-registry.cpp b/c_glib/gandiva-glib/function-registry.cpp index a95019bd62c2b..f47262986db82 100644 --- a/c_glib/gandiva-glib/function-registry.cpp +++ b/c_glib/gandiva-glib/function-registry.cpp @@ -18,8 +18,8 @@ */ #include -#include +#include #include #include @@ -34,18 +34,86 @@ G_BEGIN_DECLS * Since: 0.14.0 */ -G_DEFINE_TYPE(GGandivaFunctionRegistry, - ggandiva_function_registry, - G_TYPE_OBJECT) +struct GGandivaFunctionRegistryPrivate { + std::shared_ptr function_registry; +}; + +enum { + PROP_FUNCTION_REGISTRY = 1, +}; + +G_DEFINE_TYPE_WITH_PRIVATE(GGandivaFunctionRegistry, + ggandiva_function_registry, + G_TYPE_OBJECT) + +#define GGANDIVA_FUNCTION_REGISTRY_GET_PRIVATE(object) \ + static_cast( \ + ggandiva_function_registry_get_instance_private( \ + 
GGANDIVA_FUNCTION_REGISTRY(object))) + +static void +ggandiva_function_registry_finalize(GObject *object) +{ + auto priv = GGANDIVA_FUNCTION_REGISTRY_GET_PRIVATE(object); + priv->function_registry.~shared_ptr(); + G_OBJECT_CLASS(ggandiva_function_registry_parent_class)->finalize(object); +} + +static void +ggandiva_function_registry_set_property(GObject *object, + guint prop_id, + const GValue *value, + GParamSpec *pspec) +{ + auto priv = GGANDIVA_FUNCTION_REGISTRY_GET_PRIVATE(object); + + switch (prop_id) { + case PROP_FUNCTION_REGISTRY: + priv->function_registry = + *static_cast *>( + g_value_get_pointer(value)); + break; + default: + G_OBJECT_WARN_INVALID_PROPERTY_ID(object, prop_id, pspec); + break; + } +} static void ggandiva_function_registry_init(GGandivaFunctionRegistry *object) { + auto priv = GGANDIVA_FUNCTION_REGISTRY_GET_PRIVATE(object); + new(&priv->function_registry) std::shared_ptr; } static void ggandiva_function_registry_class_init(GGandivaFunctionRegistryClass *klass) { + auto gobject_class = G_OBJECT_CLASS(klass); + gobject_class->finalize = ggandiva_function_registry_finalize; + gobject_class->set_property = ggandiva_function_registry_set_property; + + GParamSpec *spec; + spec = g_param_spec_pointer("function-registry", + "Function registry", + "The raw std::shared_ptr *", + static_cast(G_PARAM_WRITABLE | + G_PARAM_CONSTRUCT_ONLY)); + g_object_class_install_property(gobject_class, PROP_FUNCTION_REGISTRY, spec); +} + +/** + * ggandiva_function_registry_default: + * + * Returns: (transfer full): The process-wide default function registry. 
+ * + * Since: 15.0.0 + */ +GGandivaFunctionRegistry * +ggandiva_function_registry_default(void) +{ + auto gandiva_function_registry = gandiva::default_function_registry(); + return ggandiva_function_registry_new_raw(&gandiva_function_registry); } /** @@ -58,7 +126,8 @@ ggandiva_function_registry_class_init(GGandivaFunctionRegistryClass *klass) GGandivaFunctionRegistry * ggandiva_function_registry_new(void) { - return GGANDIVA_FUNCTION_REGISTRY(g_object_new(GGANDIVA_TYPE_FUNCTION_REGISTRY, NULL)); + auto gandiva_function_registry = std::make_shared(); + return ggandiva_function_registry_new_raw(&gandiva_function_registry); } /** @@ -75,15 +144,16 @@ GGandivaNativeFunction * ggandiva_function_registry_lookup(GGandivaFunctionRegistry *function_registry, GGandivaFunctionSignature *function_signature) { - gandiva::FunctionRegistry gandiva_function_registry; + auto gandiva_function_registry = + ggandiva_function_registry_get_raw(function_registry); auto gandiva_function_signature = ggandiva_function_signature_get_raw(function_signature); auto gandiva_native_function = - gandiva_function_registry.LookupSignature(*gandiva_function_signature); + gandiva_function_registry->LookupSignature(*gandiva_function_signature); if (gandiva_native_function) { return ggandiva_native_function_new_raw(gandiva_native_function); } else { - return NULL; + return nullptr; } } @@ -99,18 +169,32 @@ ggandiva_function_registry_lookup(GGandivaFunctionRegistry *function_registry, GList * ggandiva_function_registry_get_native_functions(GGandivaFunctionRegistry *function_registry) { - gandiva::FunctionRegistry gandiva_function_registry; - + auto gandiva_function_registry = + ggandiva_function_registry_get_raw(function_registry); GList *native_functions = nullptr; - for (auto gandiva_native_function = gandiva_function_registry.begin(); - gandiva_native_function != gandiva_function_registry.end(); - ++gandiva_native_function) { - auto native_function = 
ggandiva_native_function_new_raw(gandiva_native_function); + for (const auto &gandiva_native_function : *gandiva_function_registry) { + auto native_function = ggandiva_native_function_new_raw(&gandiva_native_function); native_functions = g_list_prepend(native_functions, native_function); } - native_functions = g_list_reverse(native_functions); - - return native_functions; + return g_list_reverse(native_functions); } G_END_DECLS + +GGandivaFunctionRegistry * +ggandiva_function_registry_new_raw( + std::shared_ptr *gandiva_function_registry) +{ + return GGANDIVA_FUNCTION_REGISTRY( + g_object_new(GGANDIVA_TYPE_FUNCTION_REGISTRY, + "function-registry", gandiva_function_registry, + nullptr)); +} + +std::shared_ptr +ggandiva_function_registry_get_raw(GGandivaFunctionRegistry *function_registry) +{ + auto priv = GGANDIVA_FUNCTION_REGISTRY_GET_PRIVATE(function_registry); + return priv->function_registry; +} + diff --git a/c_glib/gandiva-glib/function-registry.h b/c_glib/gandiva-glib/function-registry.h index 1a0d767d45354..8ff6027cf1734 100644 --- a/c_glib/gandiva-glib/function-registry.h +++ b/c_glib/gandiva-glib/function-registry.h @@ -35,6 +35,8 @@ struct _GGandivaFunctionRegistryClass GObjectClass parent_class; }; +GARROW_AVAILABLE_IN_15_0 +GGandivaFunctionRegistry *ggandiva_function_registry_default(void); GGandivaFunctionRegistry *ggandiva_function_registry_new(void); GGandivaNativeFunction * ggandiva_function_registry_lookup(GGandivaFunctionRegistry *function_registry, diff --git a/c_glib/gandiva-glib/function-registry.hpp b/c_glib/gandiva-glib/function-registry.hpp new file mode 100644 index 0000000000000..0430fc57dead2 --- /dev/null +++ b/c_glib/gandiva-glib/function-registry.hpp @@ -0,0 +1,30 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. 
The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +#pragma once + +#include + +#include + +GGandivaFunctionRegistry * +ggandiva_function_registry_new_raw( + std::shared_ptr *gandiva_function_registry); +std::shared_ptr +ggandiva_function_registry_get_raw(GGandivaFunctionRegistry *function_registry); diff --git a/c_glib/test/gandiva/test-function-registry.rb b/c_glib/test/gandiva/test-function-registry.rb index 25bac6673105e..d0f959a1c5f5f 100644 --- a/c_glib/test/gandiva/test-function-registry.rb +++ b/c_glib/test/gandiva/test-function-registry.rb @@ -20,7 +20,7 @@ class TestGandivaFunctionRegistry < Test::Unit::TestCase def setup omit("Gandiva is required") unless defined?(::Gandiva) - @registry = Gandiva::FunctionRegistry.new + @registry = Gandiva::FunctionRegistry.default end sub_test_case("lookup") do diff --git a/c_glib/test/gandiva/test-native-function.rb b/c_glib/test/gandiva/test-native-function.rb index 7888f96b678b7..630a1f7c32d2a 100644 --- a/c_glib/test/gandiva/test-native-function.rb +++ b/c_glib/test/gandiva/test-native-function.rb @@ -20,7 +20,7 @@ class TestGandivaNativeFunction < Test::Unit::TestCase def setup omit("Gandiva is required") unless defined?(::Gandiva) - @registry = Gandiva::FunctionRegistry.new + @registry = Gandiva::FunctionRegistry.default @not = lookup("not", [boolean_data_type], boolean_data_type) @isnull = lookup("isnull", [int8_data_type], boolean_data_type) end diff --git 
a/cpp/cmake_modules/GandivaAddBitcode.cmake b/cpp/cmake_modules/GandivaAddBitcode.cmake new file mode 100644 index 0000000000000..98847f8a186fe --- /dev/null +++ b/cpp/cmake_modules/GandivaAddBitcode.cmake @@ -0,0 +1,75 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. + +# Create bitcode for the given source file. +function(gandiva_add_bitcode SOURCE) + set(CLANG_OPTIONS -std=c++17) + if(MSVC) + # "19.20" means that it's compatible with Visual Studio 16 2019. + # We can update this to "19.30" when we dropped support for Visual + # Studio 16 2019. + # + # See https://cmake.org/cmake/help/latest/variable/MSVC_VERSION.html + # for MSVC_VERSION and Visual Studio version. 
+ set(FMS_COMPATIBILITY 19.20) + list(APPEND CLANG_OPTIONS -fms-compatibility + -fms-compatibility-version=${FMS_COMPATIBILITY}) + endif() + + get_filename_component(SOURCE_BASE ${SOURCE} NAME_WE) + get_filename_component(ABSOLUTE_SOURCE ${SOURCE} ABSOLUTE) + set(BC_FILE ${CMAKE_CURRENT_BINARY_DIR}/${SOURCE_BASE}.bc) + set(PRECOMPILE_COMMAND) + if(CMAKE_OSX_SYSROOT) + list(APPEND + PRECOMPILE_COMMAND + ${CMAKE_COMMAND} + -E + env + SDKROOT=${CMAKE_OSX_SYSROOT}) + endif() + list(APPEND + PRECOMPILE_COMMAND + ${CLANG_EXECUTABLE} + ${CLANG_OPTIONS} + -DGANDIVA_IR + -DNDEBUG # DCHECK macros not implemented in precompiled code + -DARROW_STATIC # Do not set __declspec(dllimport) on MSVC on Arrow symbols + -DGANDIVA_STATIC # Do not set __declspec(dllimport) on MSVC on Gandiva symbols + -fno-use-cxa-atexit # Workaround for unresolved __dso_handle + -emit-llvm + -O3 + -c + ${ABSOLUTE_SOURCE} + -o + ${BC_FILE} + ${ARROW_GANDIVA_PC_CXX_FLAGS}) + if(ARROW_BINARY_DIR) + list(APPEND PRECOMPILE_COMMAND -I${ARROW_BINARY_DIR}/src) + endif() + if(ARROW_SOURCE_DIR) + list(APPEND PRECOMPILE_COMMAND -I${ARROW_SOURCE_DIR}/src) + endif() + if(NOT ARROW_USE_NATIVE_INT128) + foreach(boost_include_dir ${Boost_INCLUDE_DIRS}) + list(APPEND PRECOMPILE_COMMAND -I${boost_include_dir}) + endforeach() + endif() + add_custom_command(OUTPUT ${BC_FILE} + COMMAND ${PRECOMPILE_COMMAND} + DEPENDS ${SOURCE_FILE}) +endfunction() diff --git a/cpp/cmake_modules/ThirdpartyToolchain.cmake b/cpp/cmake_modules/ThirdpartyToolchain.cmake index 5de8ff9b1cb11..52632d554aafb 100644 --- a/cpp/cmake_modules/ThirdpartyToolchain.cmake +++ b/cpp/cmake_modules/ThirdpartyToolchain.cmake @@ -230,18 +230,21 @@ macro(build_dependency DEPENDENCY_NAME) endif() endmacro() -# Find modules are needed by the consumer in case of a static build, or if the -# linkage is PUBLIC or INTERFACE. 
-macro(provide_find_module PACKAGE_NAME ARROW_CMAKE_PACKAGE_NAME) - set(module_ "${CMAKE_SOURCE_DIR}/cmake_modules/Find${PACKAGE_NAME}.cmake") - if(EXISTS "${module_}") - message(STATUS "Providing CMake module for ${PACKAGE_NAME} as part of ${ARROW_CMAKE_PACKAGE_NAME} CMake package" +function(provide_cmake_module MODULE_NAME ARROW_CMAKE_PACKAGE_NAME) + set(module "${CMAKE_SOURCE_DIR}/cmake_modules/${MODULE_NAME}.cmake") + if(EXISTS "${module}") + message(STATUS "Providing CMake module for ${MODULE_NAME} as part of ${ARROW_CMAKE_PACKAGE_NAME} CMake package" ) - install(FILES "${module_}" + install(FILES "${module}" DESTINATION "${ARROW_CMAKE_DIR}/${ARROW_CMAKE_PACKAGE_NAME}") endif() - unset(module_) -endmacro() +endfunction() + +# Find modules are needed by the consumer in case of a static build, or if the +# linkage is PUBLIC or INTERFACE. +function(provide_find_module PACKAGE_NAME ARROW_CMAKE_PACKAGE_NAME) + provide_cmake_module("Find${PACKAGE_NAME}" ${ARROW_CMAKE_PACKAGE_NAME}) +endfunction() macro(resolve_dependency DEPENDENCY_NAME) set(options) diff --git a/cpp/src/gandiva/CMakeLists.txt b/cpp/src/gandiva/CMakeLists.txt index 6b6743bc8e52f..3448d516768bb 100644 --- a/cpp/src/gandiva/CMakeLists.txt +++ b/cpp/src/gandiva/CMakeLists.txt @@ -25,11 +25,14 @@ add_custom_target(gandiva-benchmarks) add_dependencies(gandiva-all gandiva gandiva-tests gandiva-benchmarks) +include(GandivaAddBitcode) + find_package(LLVMAlt REQUIRED) provide_find_module(LLVMAlt "Gandiva") if(ARROW_WITH_ZSTD AND "${zstd_SOURCE}" STREQUAL "SYSTEM") provide_find_module(zstdAlt "Gandiva") endif() +provide_cmake_module(GandivaAddBitcode "Gandiva") # Set the path where the bitcode file generated, see precompiled/CMakeLists.txt set(GANDIVA_PRECOMPILED_BC_PATH "${CMAKE_CURRENT_BINARY_DIR}/irhelpers.bc") @@ -249,7 +252,8 @@ add_gandiva_test(internals-test random_generator_holder_test.cc hash_utils_test.cc gdv_function_stubs_test.cc - interval_holder_test.cc) + interval_holder_test.cc + 
tests/test_util.cc) add_subdirectory(precompiled) add_subdirectory(tests) diff --git a/cpp/src/gandiva/GandivaConfig.cmake.in b/cpp/src/gandiva/GandivaConfig.cmake.in index f02e29f25bb3a..68579debd183b 100644 --- a/cpp/src/gandiva/GandivaConfig.cmake.in +++ b/cpp/src/gandiva/GandivaConfig.cmake.in @@ -49,6 +49,7 @@ else() endif() include("${CMAKE_CURRENT_LIST_DIR}/GandivaTargets.cmake") +include("${CMAKE_CURRENT_LIST_DIR}/GandivaAddBitcode.cmake") arrow_keep_backward_compatibility(Gandiva gandiva) diff --git a/cpp/src/gandiva/configuration.cc b/cpp/src/gandiva/configuration.cc index 1e26c5c70d4ec..b79f4118e07f2 100644 --- a/cpp/src/gandiva/configuration.cc +++ b/cpp/src/gandiva/configuration.cc @@ -29,11 +29,14 @@ std::size_t Configuration::Hash() const { size_t result = kHashSeed; arrow::internal::hash_combine(result, static_cast(optimize_)); arrow::internal::hash_combine(result, static_cast(target_host_cpu_)); + arrow::internal::hash_combine( + result, reinterpret_cast(function_registry_.get())); return result; } bool Configuration::operator==(const Configuration& other) const { - return optimize_ == other.optimize_ && target_host_cpu_ == other.target_host_cpu_; + return optimize_ == other.optimize_ && target_host_cpu_ == other.target_host_cpu_ && + function_registry_ == other.function_registry_; } bool Configuration::operator!=(const Configuration& other) const { diff --git a/cpp/src/gandiva/configuration.h b/cpp/src/gandiva/configuration.h index 9cd301524d03d..f43a2b190731f 100644 --- a/cpp/src/gandiva/configuration.h +++ b/cpp/src/gandiva/configuration.h @@ -21,6 +21,7 @@ #include #include "arrow/status.h" +#include "gandiva/function_registry.h" #include "gandiva/visibility.h" namespace gandiva { @@ -34,8 +35,14 @@ class GANDIVA_EXPORT Configuration { public: friend class ConfigurationBuilder; - Configuration() : optimize_(true), target_host_cpu_(true) {} - explicit Configuration(bool optimize) : optimize_(optimize), target_host_cpu_(true) {} + explicit 
Configuration(bool optimize, + std::shared_ptr function_registry = + gandiva::default_function_registry()) + : optimize_(optimize), + target_host_cpu_(true), + function_registry_(function_registry) {} + + Configuration() : Configuration(true) {} std::size_t Hash() const; bool operator==(const Configuration& other) const; @@ -43,13 +50,21 @@ class GANDIVA_EXPORT Configuration { bool optimize() const { return optimize_; } bool target_host_cpu() const { return target_host_cpu_; } + std::shared_ptr function_registry() const { + return function_registry_; + } void set_optimize(bool optimize) { optimize_ = optimize; } void target_host_cpu(bool target_host_cpu) { target_host_cpu_ = target_host_cpu; } + void set_function_registry(std::shared_ptr function_registry) { + function_registry_ = std::move(function_registry); + } private: bool optimize_; /* optimise the generated llvm IR */ bool target_host_cpu_; /* set the mcpu flag to host cpu while compiling llvm ir */ + std::shared_ptr + function_registry_; /* function registry that may contain external functions */ }; /// \brief configuration builder for gandiva @@ -68,6 +83,13 @@ class GANDIVA_EXPORT ConfigurationBuilder { return configuration; } + std::shared_ptr build( + std::shared_ptr function_registry) { + std::shared_ptr configuration( + new Configuration(true, std::move(function_registry))); + return configuration; + } + static std::shared_ptr DefaultConfiguration() { return default_configuration_; } diff --git a/cpp/src/gandiva/engine.cc b/cpp/src/gandiva/engine.cc index 8ebe927437567..5ae1d76876148 100644 --- a/cpp/src/gandiva/engine.cc +++ b/cpp/src/gandiva/engine.cc @@ -141,7 +141,8 @@ Engine::Engine(const std::shared_ptr& conf, module_(module), types_(*context_), optimize_(conf->optimize()), - cached_(cached) {} + cached_(cached), + function_registry_(conf->function_registry()) {} Status Engine::Init() { std::call_once(register_exported_funcs_flag, gandiva::RegisterExportedFuncs); @@ -155,6 +156,7 @@ Status 
Engine::LoadFunctionIRs() { if (!functions_loaded_) { ARROW_RETURN_NOT_OK(LoadPreCompiledIR()); ARROW_RETURN_NOT_OK(DecimalIR::AddFunctions(this)); + ARROW_RETURN_NOT_OK(LoadExternalPreCompiledIR()); functions_loaded_ = true; } return Status::OK(); @@ -236,7 +238,38 @@ static void SetDataLayout(llvm::Module* module) { module->setDataLayout(machine->createDataLayout()); } -// end of the mofified method from MLIR +// end of the modified method from MLIR + +template +static arrow::Result AsArrowResult(llvm::Expected& expected) { + if (!expected) { + std::string str; + llvm::raw_string_ostream stream(str); + stream << expected.takeError(); + return Status::CodeGenError(stream.str()); + } + return std::move(expected.get()); +} + +static arrow::Status VerifyAndLinkModule( + llvm::Module* dest_module, + llvm::Expected> src_module_or_error) { + ARROW_ASSIGN_OR_RAISE(auto src_ir_module, AsArrowResult(src_module_or_error)); + + // set dataLayout + SetDataLayout(src_ir_module.get()); + + std::string error_info; + llvm::raw_string_ostream error_stream(error_info); + ARROW_RETURN_IF( + llvm::verifyModule(*src_ir_module, &error_stream), + Status::CodeGenError("verify of IR Module failed: " + error_stream.str())); + + ARROW_RETURN_IF(llvm::Linker::linkModules(*dest_module, std::move(src_ir_module)), + Status::CodeGenError("failed to link IR Modules")); + + return Status::OK(); +} // Handling for pre-compiled IR libraries. Status Engine::LoadPreCompiledIR() { @@ -256,23 +289,25 @@ Status Engine::LoadPreCompiledIR() { /// Parse the IR module. 
llvm::Expected> module_or_error = llvm::getOwningLazyBitcodeModule(std::move(buffer), *context()); - if (!module_or_error) { - // NOTE: llvm::handleAllErrors() fails linking with RTTI-disabled LLVM builds - // (ARROW-5148) - std::string str; - llvm::raw_string_ostream stream(str); - stream << module_or_error.takeError(); - return Status::CodeGenError(stream.str()); - } - std::unique_ptr ir_module = std::move(module_or_error.get()); + // NOTE: llvm::handleAllErrors() fails linking with RTTI-disabled LLVM builds + // (ARROW-5148) + ARROW_RETURN_NOT_OK(VerifyAndLinkModule(module_, std::move(module_or_error))); + return Status::OK(); +} - // set dataLayout - SetDataLayout(ir_module.get()); +static llvm::MemoryBufferRef AsLLVMMemoryBuffer(const arrow::Buffer& arrow_buffer) { + auto data = reinterpret_cast(arrow_buffer.data()); + auto size = arrow_buffer.size(); + return llvm::MemoryBufferRef(llvm::StringRef(data, size), "external_bitcode"); +} - ARROW_RETURN_IF(llvm::verifyModule(*ir_module, &llvm::errs()), - Status::CodeGenError("verify of IR Module failed")); - ARROW_RETURN_IF(llvm::Linker::linkModules(*module_, std::move(ir_module)), - Status::CodeGenError("failed to link IR Modules")); +Status Engine::LoadExternalPreCompiledIR() { + auto const& buffers = function_registry_->GetBitcodeBuffers(); + for (auto const& buffer : buffers) { + auto llvm_memory_buffer_ref = AsLLVMMemoryBuffer(*buffer); + auto module_or_error = llvm::parseBitcodeFile(llvm_memory_buffer_ref, *context()); + ARROW_RETURN_NOT_OK(VerifyAndLinkModule(module_, std::move(module_or_error))); + } return Status::OK(); } diff --git a/cpp/src/gandiva/engine.h b/cpp/src/gandiva/engine.h index a4d6a5fd1a758..566977dc4adad 100644 --- a/cpp/src/gandiva/engine.h +++ b/cpp/src/gandiva/engine.h @@ -93,6 +93,9 @@ class GANDIVA_EXPORT Engine { /// the main module. 
Status LoadPreCompiledIR(); + // load external pre-compiled bitcodes into module + Status LoadExternalPreCompiledIR(); + // Create and add mappings for cpp functions that can be accessed from LLVM. void AddGlobalMappings(); @@ -111,6 +114,7 @@ class GANDIVA_EXPORT Engine { bool module_finalized_ = false; bool cached_; bool functions_loaded_ = false; + std::shared_ptr function_registry_; }; } // namespace gandiva diff --git a/cpp/src/gandiva/expr_decomposer_test.cc b/cpp/src/gandiva/expr_decomposer_test.cc index 638ceebcb19fd..7681d9e646297 100644 --- a/cpp/src/gandiva/expr_decomposer_test.cc +++ b/cpp/src/gandiva/expr_decomposer_test.cc @@ -24,7 +24,6 @@ #include "gandiva/function_registry.h" #include "gandiva/gandiva_aliases.h" #include "gandiva/node.h" -#include "gandiva/tree_expr_builder.h" namespace gandiva { @@ -32,12 +31,12 @@ using arrow::int32; class TestExprDecomposer : public ::testing::Test { protected: - FunctionRegistry registry_; + std::shared_ptr registry_ = default_function_registry(); }; TEST_F(TestExprDecomposer, TestStackSimple) { Annotator annotator; - ExprDecomposer decomposer(registry_, annotator); + ExprDecomposer decomposer(*registry_, annotator); // if (a) _ // else _ @@ -58,7 +57,7 @@ TEST_F(TestExprDecomposer, TestStackSimple) { TEST_F(TestExprDecomposer, TestNested) { Annotator annotator; - ExprDecomposer decomposer(registry_, annotator); + ExprDecomposer decomposer(*registry_, annotator); // if (a) _ // else _ @@ -97,7 +96,7 @@ TEST_F(TestExprDecomposer, TestNested) { TEST_F(TestExprDecomposer, TestInternalIf) { Annotator annotator; - ExprDecomposer decomposer(registry_, annotator); + ExprDecomposer decomposer(*registry_, annotator); // if (a) _ // if (b) _ @@ -136,7 +135,7 @@ TEST_F(TestExprDecomposer, TestInternalIf) { TEST_F(TestExprDecomposer, TestParallelIf) { Annotator annotator; - ExprDecomposer decomposer(registry_, annotator); + ExprDecomposer decomposer(*registry_, annotator); // if (a) _ // else _ @@ -174,7 +173,7 @@ 
TEST_F(TestExprDecomposer, TestParallelIf) { TEST_F(TestExprDecomposer, TestIfInCondition) { Annotator annotator; - ExprDecomposer decomposer(registry_, annotator); + ExprDecomposer decomposer(*registry_, annotator); // if (if _ else _) : a // - @@ -245,7 +244,7 @@ TEST_F(TestExprDecomposer, TestIfInCondition) { TEST_F(TestExprDecomposer, TestFunctionBetweenNestedIf) { Annotator annotator; - ExprDecomposer decomposer(registry_, annotator); + ExprDecomposer decomposer(*registry_, annotator); // if (a) _ // else @@ -286,7 +285,7 @@ TEST_F(TestExprDecomposer, TestFunctionBetweenNestedIf) { TEST_F(TestExprDecomposer, TestComplexIfCondition) { Annotator annotator; - ExprDecomposer decomposer(registry_, annotator); + ExprDecomposer decomposer(*registry_, annotator); // if (if _ // else diff --git a/cpp/src/gandiva/expr_validator.cc b/cpp/src/gandiva/expr_validator.cc index 35a13494523d0..8a6f86e6f0419 100644 --- a/cpp/src/gandiva/expr_validator.cc +++ b/cpp/src/gandiva/expr_validator.cc @@ -93,7 +93,7 @@ Status ExprValidator::Visit(const FunctionNode& node) { const auto& desc = node.descriptor(); FunctionSignature signature(desc->name(), desc->params(), desc->return_type()); - const NativeFunction* native_function = registry_.LookupSignature(signature); + const NativeFunction* native_function = registry_->LookupSignature(signature); ARROW_RETURN_IF(native_function == nullptr, Status::ExpressionValidationError("Function ", signature.ToString(), " not supported yet. 
")); diff --git a/cpp/src/gandiva/expr_validator.h b/cpp/src/gandiva/expr_validator.h index 7f6d7fd131fbe..8a423fc93b02b 100644 --- a/cpp/src/gandiva/expr_validator.h +++ b/cpp/src/gandiva/expr_validator.h @@ -37,8 +37,9 @@ class FunctionRegistry; /// data types, signatures and return types class ExprValidator : public NodeVisitor { public: - explicit ExprValidator(LLVMTypes* types, SchemaPtr schema) - : types_(types), schema_(schema) { + explicit ExprValidator(LLVMTypes* types, SchemaPtr schema, + std::shared_ptr registry) + : types_(types), schema_(schema), registry_(std::move(registry)) { for (auto& field : schema_->fields()) { field_map_[field->name()] = field; } @@ -65,12 +66,12 @@ class ExprValidator : public NodeVisitor { Status Visit(const InExpressionNode& node) override; Status Visit(const InExpressionNode& node) override; - FunctionRegistry registry_; - LLVMTypes* types_; SchemaPtr schema_; + std::shared_ptr registry_; + using FieldMap = std::unordered_map; FieldMap field_map_; }; diff --git a/cpp/src/gandiva/expression_registry.cc b/cpp/src/gandiva/expression_registry.cc index 9bff97f5ad269..dd964a7cb8a7a 100644 --- a/cpp/src/gandiva/expression_registry.cc +++ b/cpp/src/gandiva/expression_registry.cc @@ -22,9 +22,9 @@ namespace gandiva { -ExpressionRegistry::ExpressionRegistry() { - function_registry_.reset(new FunctionRegistry()); -} +ExpressionRegistry::ExpressionRegistry( + std::shared_ptr function_registry) + : function_registry_{function_registry} {} ExpressionRegistry::~ExpressionRegistry() {} diff --git a/cpp/src/gandiva/expression_registry.h b/cpp/src/gandiva/expression_registry.h index 609a2dbbe21f9..156a6392564f9 100644 --- a/cpp/src/gandiva/expression_registry.h +++ b/cpp/src/gandiva/expression_registry.h @@ -21,6 +21,7 @@ #include #include "gandiva/arrow.h" +#include "gandiva/function_registry.h" #include "gandiva/function_signature.h" #include "gandiva/gandiva_aliases.h" #include "gandiva/visibility.h" @@ -37,7 +38,8 @@ class GANDIVA_EXPORT 
ExpressionRegistry { public: using native_func_iterator_type = const NativeFunction*; using func_sig_iterator_type = const FunctionSignature*; - ExpressionRegistry(); + explicit ExpressionRegistry(std::shared_ptr function_registry = + gandiva::default_function_registry()); ~ExpressionRegistry(); static DataTypeVector supported_types() { return supported_types_; } class GANDIVA_EXPORT FunctionSignatureIterator { @@ -62,7 +64,7 @@ class GANDIVA_EXPORT ExpressionRegistry { private: static DataTypeVector supported_types_; - std::unique_ptr function_registry_; + std::shared_ptr function_registry_; }; /// \brief Get the list of all function signatures. diff --git a/cpp/src/gandiva/expression_registry_test.cc b/cpp/src/gandiva/expression_registry_test.cc index c254ff4f3aa5e..cd784192c194e 100644 --- a/cpp/src/gandiva/expression_registry_test.cc +++ b/cpp/src/gandiva/expression_registry_test.cc @@ -31,7 +31,7 @@ typedef int64_t (*add_vector_func_t)(int64_t* elements, int nelements); class TestExpressionRegistry : public ::testing::Test { protected: - FunctionRegistry registry_; + std::shared_ptr registry_ = default_function_registry(); }; // Verify all functions in registry are exported. @@ -42,7 +42,7 @@ TEST_F(TestExpressionRegistry, VerifySupportedFunctions) { iter != expr_registry.function_signature_end(); iter++) { functions.push_back((*iter)); } - for (auto& iter : registry_) { + for (auto& iter : *registry_) { for (auto& func_iter : iter.signatures()) { auto element = std::find(functions.begin(), functions.end(), func_iter); EXPECT_NE(element, functions.end()) << "function signature " << func_iter.ToString() diff --git a/cpp/src/gandiva/filter.cc b/cpp/src/gandiva/filter.cc index 78917467a0f56..416d97b5dbd1d 100644 --- a/cpp/src/gandiva/filter.cc +++ b/cpp/src/gandiva/filter.cc @@ -71,7 +71,8 @@ Status Filter::Make(SchemaPtr schema, ConditionPtr condition, if (!is_cached) { // Run the validation on the expression. 
// Return if the expression is invalid since we will not be able to process further. - ExprValidator expr_validator(llvm_gen->types(), schema); + ExprValidator expr_validator(llvm_gen->types(), schema, + configuration->function_registry()); ARROW_RETURN_NOT_OK(expr_validator.Validate(condition)); } diff --git a/cpp/src/gandiva/function_registry.cc b/cpp/src/gandiva/function_registry.cc index 67b7b404b325c..5d676dfa8df74 100644 --- a/cpp/src/gandiva/function_registry.cc +++ b/cpp/src/gandiva/function_registry.cc @@ -16,6 +16,13 @@ // under the License. #include "gandiva/function_registry.h" + +#include +#include + +#include + +#include "arrow/util/logging.h" #include "gandiva/function_registry_arithmetic.h" #include "gandiva/function_registry_datetime.h" #include "gandiva/function_registry_hash.h" @@ -23,12 +30,26 @@ #include "gandiva/function_registry_string.h" #include "gandiva/function_registry_timestamp_arithmetic.h" -#include -#include -#include - namespace gandiva { +static constexpr uint32_t kMaxFunctionSignatures = 2048; + +// encapsulates an llvm memory buffer in an arrow buffer +// this is needed because we don't expose the llvm memory buffer to the outside world in +// the header file +class LLVMMemoryArrowBuffer : public arrow::Buffer { + public: + explicit LLVMMemoryArrowBuffer(std::unique_ptr llvm_buffer) + : arrow::Buffer(reinterpret_cast(llvm_buffer->getBufferStart()), + static_cast(llvm_buffer->getBufferSize())), + llvm_buffer_(std::move(llvm_buffer)) {} + + private: + std::unique_ptr llvm_buffer_; +}; + +FunctionRegistry::FunctionRegistry() { pc_registry_.reserve(kMaxFunctionSignatures); } + FunctionRegistry::iterator FunctionRegistry::begin() const { return &(*pc_registry_.begin()); } @@ -41,42 +62,74 @@ FunctionRegistry::iterator FunctionRegistry::back() const { return &(pc_registry_.back()); } -std::vector FunctionRegistry::pc_registry_; +const NativeFunction* FunctionRegistry::LookupSignature( + const FunctionSignature& signature) const { + 
auto got = pc_registry_map_.find(&signature); + return got == pc_registry_map_.end() ? nullptr : got->second; +} -SignatureMap FunctionRegistry::pc_registry_map_ = InitPCMap(); +Status FunctionRegistry::Add(NativeFunction func) { + if (pc_registry_.size() == kMaxFunctionSignatures) { + return Status::CapacityError("Exceeded max function signatures limit of ", + kMaxFunctionSignatures); + } + pc_registry_.emplace_back(std::move(func)); + auto const& last_func = pc_registry_.back(); + for (auto const& func_signature : last_func.signatures()) { + pc_registry_map_.emplace(&func_signature, &last_func); + } + return arrow::Status::OK(); +} -SignatureMap FunctionRegistry::InitPCMap() { - SignatureMap map; +arrow::Result> GetBufferFromFile( + const std::string& bitcode_file_path) { + auto buffer_or_error = llvm::MemoryBuffer::getFile(bitcode_file_path); - auto v1 = GetArithmeticFunctionRegistry(); - pc_registry_.insert(std::end(pc_registry_), v1.begin(), v1.end()); - auto v2 = GetDateTimeFunctionRegistry(); - pc_registry_.insert(std::end(pc_registry_), v2.begin(), v2.end()); + ARROW_RETURN_IF(!buffer_or_error, + Status::IOError("Could not load module from bitcode file: ", + bitcode_file_path + + " Error: " + buffer_or_error.getError().message())); - auto v3 = GetHashFunctionRegistry(); - pc_registry_.insert(std::end(pc_registry_), v3.begin(), v3.end()); + return std::move(buffer_or_error.get()); +} - auto v4 = GetMathOpsFunctionRegistry(); - pc_registry_.insert(std::end(pc_registry_), v4.begin(), v4.end()); +Status FunctionRegistry::Register(const std::vector& funcs, + const std::string& bitcode_path) { + ARROW_ASSIGN_OR_RAISE(auto llvm_buffer, GetBufferFromFile(bitcode_path)); + auto buffer = std::make_shared(std::move(llvm_buffer)); + return Register(funcs, std::move(buffer)); +} - auto v5 = GetStringFunctionRegistry(); - pc_registry_.insert(std::end(pc_registry_), v5.begin(), v5.end()); +arrow::Status FunctionRegistry::Register(const std::vector& funcs, + 
std::shared_ptr bitcode_buffer) { + bitcode_memory_buffers_.emplace_back(std::move(bitcode_buffer)); + for (const auto& func : funcs) { + ARROW_RETURN_NOT_OK(FunctionRegistry::Add(func)); + } + return Status::OK(); +} - auto v6 = GetDateTimeArithmeticFunctionRegistry(); - pc_registry_.insert(std::end(pc_registry_), v6.begin(), v6.end()); - for (auto& elem : pc_registry_) { - for (auto& func_signature : elem.signatures()) { - map.insert(std::make_pair(&(func_signature), &elem)); +const std::vector>& FunctionRegistry::GetBitcodeBuffers() + const { + return bitcode_memory_buffers_; +} + +arrow::Result> MakeDefaultFunctionRegistry() { + auto registry = std::make_shared(); + for (auto const& funcs : + {GetArithmeticFunctionRegistry(), GetDateTimeFunctionRegistry(), + GetHashFunctionRegistry(), GetMathOpsFunctionRegistry(), + GetStringFunctionRegistry(), GetDateTimeArithmeticFunctionRegistry()}) { + for (auto const& func_signature : funcs) { + ARROW_RETURN_NOT_OK(registry->Add(func_signature)); } } - - return map; + return std::move(registry); } -const NativeFunction* FunctionRegistry::LookupSignature( - const FunctionSignature& signature) const { - auto got = pc_registry_map_.find(&signature); - return got == pc_registry_map_.end() ? 
nullptr : got->second; +std::shared_ptr default_function_registry() { + static auto default_registry = *MakeDefaultFunctionRegistry(); + return default_registry; } } // namespace gandiva diff --git a/cpp/src/gandiva/function_registry.h b/cpp/src/gandiva/function_registry.h index d9256326019c6..01984961dc90f 100644 --- a/cpp/src/gandiva/function_registry.h +++ b/cpp/src/gandiva/function_registry.h @@ -17,7 +17,12 @@ #pragma once +#include +#include #include + +#include "arrow/buffer.h" +#include "arrow/status.h" #include "gandiva/function_registry_common.h" #include "gandiva/gandiva_aliases.h" #include "gandiva/native_function.h" @@ -30,18 +35,41 @@ class GANDIVA_EXPORT FunctionRegistry { public: using iterator = const NativeFunction*; + FunctionRegistry(); + FunctionRegistry(const FunctionRegistry&) = delete; + FunctionRegistry& operator=(const FunctionRegistry&) = delete; + /// Lookup a pre-compiled function by its signature. const NativeFunction* LookupSignature(const FunctionSignature& signature) const; + /// \brief register a set of functions into the function registry from a given bitcode + /// file + arrow::Status Register(const std::vector& funcs, + const std::string& bitcode_path); + + /// \brief register a set of functions into the function registry from a given bitcode + /// buffer + arrow::Status Register(const std::vector& funcs, + std::shared_ptr bitcode_buffer); + + /// \brief get a list of bitcode memory buffers saved in the registry + const std::vector>& GetBitcodeBuffers() const; + iterator begin() const; iterator end() const; iterator back() const; + friend arrow::Result> MakeDefaultFunctionRegistry(); + private: - static SignatureMap InitPCMap(); + std::vector pc_registry_; + SignatureMap pc_registry_map_; + std::vector> bitcode_memory_buffers_; - static std::vector pc_registry_; - static SignatureMap pc_registry_map_; + Status Add(NativeFunction func); }; +/// \brief get the default function registry +GANDIVA_EXPORT std::shared_ptr 
default_function_registry(); + } // namespace gandiva diff --git a/cpp/src/gandiva/function_registry_test.cc b/cpp/src/gandiva/function_registry_test.cc index e3c1e85f79cba..bbe72c0ee970c 100644 --- a/cpp/src/gandiva/function_registry_test.cc +++ b/cpp/src/gandiva/function_registry_test.cc @@ -23,17 +23,26 @@ #include #include +#include "gandiva/tests/test_util.h" + namespace gandiva { class TestFunctionRegistry : public ::testing::Test { protected: - FunctionRegistry registry_; + std::shared_ptr registry_ = gandiva::default_function_registry(); + + static std::unique_ptr MakeFunctionRegistryWithExternalFunction() { + auto registry = std::make_unique(); + ARROW_EXPECT_OK( + registry->Register({GetTestExternalFunction()}, GetTestFunctionLLVMIRPath())); + return registry; + } }; TEST_F(TestFunctionRegistry, TestFound) { FunctionSignature add_i32_i32("add", {arrow::int32(), arrow::int32()}, arrow::int32()); - const NativeFunction* function = registry_.LookupSignature(add_i32_i32); + const NativeFunction* function = registry_->LookupSignature(add_i32_i32); EXPECT_NE(function, nullptr); EXPECT_THAT(function->signatures(), testing::Contains(add_i32_i32)); EXPECT_EQ(function->pc_name(), "add_int32_int32"); @@ -42,11 +51,32 @@ TEST_F(TestFunctionRegistry, TestFound) { TEST_F(TestFunctionRegistry, TestNotFound) { FunctionSignature addX_i32_i32("addX", {arrow::int32(), arrow::int32()}, arrow::int32()); - EXPECT_EQ(registry_.LookupSignature(addX_i32_i32), nullptr); + EXPECT_EQ(registry_->LookupSignature(addX_i32_i32), nullptr); FunctionSignature add_i32_i32_ret64("add", {arrow::int32(), arrow::int32()}, arrow::int64()); - EXPECT_EQ(registry_.LookupSignature(add_i32_i32_ret64), nullptr); + EXPECT_EQ(registry_->LookupSignature(add_i32_i32_ret64), nullptr); +} + +TEST_F(TestFunctionRegistry, TestCustomFunctionRegistry) { + auto registry = MakeFunctionRegistryWithExternalFunction(); + + auto multiply_by_two_func = GetTestExternalFunction(); + auto multiply_by_two_int32_ret64 = 
multiply_by_two_func.signatures().front(); + EXPECT_NE(registry->LookupSignature(multiply_by_two_int32_ret64), nullptr); + + FunctionSignature add_i32_i32_ret64("add", {arrow::int32(), arrow::int32()}, + arrow::int64()); + EXPECT_EQ(registry->LookupSignature(add_i32_i32_ret64), nullptr); +} + +TEST_F(TestFunctionRegistry, TestGetBitcodeMemoryBuffersDefaultFunctionRegistry) { + EXPECT_EQ(registry_->GetBitcodeBuffers().size(), 0); +} + +TEST_F(TestFunctionRegistry, TestGetBitcodeMemoryBuffersCustomFunctionRegistry) { + auto registry = MakeFunctionRegistryWithExternalFunction(); + EXPECT_EQ(registry->GetBitcodeBuffers().size(), 1); } // one nativefunction object per precompiled function @@ -55,10 +85,9 @@ TEST_F(TestFunctionRegistry, TestNoDuplicates) { std::unordered_set native_func_duplicates; std::unordered_set func_sigs; std::unordered_set func_sig_duplicates; - for (auto native_func_it = registry_.begin(); native_func_it != registry_.end(); - ++native_func_it) { - auto& first_sig = native_func_it->signatures().front(); - auto pc_func_sig = FunctionSignature(native_func_it->pc_name(), + for (const auto& native_func_it : *registry_) { + auto& first_sig = native_func_it.signatures().front(); + auto pc_func_sig = FunctionSignature(native_func_it.pc_name(), first_sig.param_types(), first_sig.ret_type()) .ToString(); if (pc_func_sigs.count(pc_func_sig) == 0) { @@ -67,7 +96,7 @@ TEST_F(TestFunctionRegistry, TestNoDuplicates) { native_func_duplicates.insert(pc_func_sig); } - for (auto& sig : native_func_it->signatures()) { + for (auto& sig : native_func_it.signatures()) { auto sig_str = sig.ToString(); if (func_sigs.count(sig_str) == 0) { func_sigs.insert(sig_str); diff --git a/cpp/src/gandiva/llvm_generator.cc b/cpp/src/gandiva/llvm_generator.cc index fa1d97be301a8..41cbe0ffe3a3a 100644 --- a/cpp/src/gandiva/llvm_generator.cc +++ b/cpp/src/gandiva/llvm_generator.cc @@ -36,11 +36,16 @@ namespace gandiva { AddTrace(__VA_ARGS__); \ } -LLVMGenerator::LLVMGenerator(bool 
cached) : cached_(cached), enable_ir_traces_(false) {} +LLVMGenerator::LLVMGenerator(bool cached, + std::shared_ptr function_registry) + : cached_(cached), + function_registry_(std::move(function_registry)), + enable_ir_traces_(false) {} -Status LLVMGenerator::Make(std::shared_ptr config, bool cached, +Status LLVMGenerator::Make(const std::shared_ptr& config, bool cached, std::unique_ptr* llvm_generator) { - std::unique_ptr llvmgen_obj(new LLVMGenerator(cached)); + std::unique_ptr llvmgen_obj( + new LLVMGenerator(cached, config->function_registry())); ARROW_RETURN_NOT_OK(Engine::Make(config, cached, &(llvmgen_obj->engine_))); *llvm_generator = std::move(llvmgen_obj); @@ -64,7 +69,7 @@ void LLVMGenerator::SetLLVMObjectCache(GandivaObjectCache& object_cache) { Status LLVMGenerator::Add(const ExpressionPtr expr, const FieldDescriptorPtr output) { int idx = static_cast(compiled_exprs_.size()); // decompose the expression to separate out value and validities. - ExprDecomposer decomposer(function_registry_, annotator_); + ExprDecomposer decomposer(*function_registry_, annotator_); ValueValidityPairPtr value_validity; ARROW_RETURN_NOT_OK(decomposer.Decompose(*expr->root(), &value_validity)); // Generate the IR function for the decomposed expression. diff --git a/cpp/src/gandiva/llvm_generator.h b/cpp/src/gandiva/llvm_generator.h index 04f9b854b1d29..1921e2565338b 100644 --- a/cpp/src/gandiva/llvm_generator.h +++ b/cpp/src/gandiva/llvm_generator.h @@ -47,7 +47,7 @@ class FunctionHolder; class GANDIVA_EXPORT LLVMGenerator { public: /// \brief Factory method to initialize the generator. - static Status Make(std::shared_ptr config, bool cached, + static Status Make(const std::shared_ptr& config, bool cached, std::unique_ptr* llvm_generator); /// \brief Get the cache to be used for LLVM ObjectCache. 
@@ -82,11 +82,13 @@ class GANDIVA_EXPORT LLVMGenerator { std::string DumpIR() { return engine_->DumpIR(); } private: - explicit LLVMGenerator(bool cached); + explicit LLVMGenerator(bool cached, + std::shared_ptr function_registry); FRIEND_TEST(TestLLVMGenerator, VerifyPCFunctions); FRIEND_TEST(TestLLVMGenerator, TestAdd); FRIEND_TEST(TestLLVMGenerator, TestNullInternal); + FRIEND_TEST(TestLLVMGenerator, VerifyExtendedPCFunctions); llvm::LLVMContext* context() { return engine_->context(); } llvm::IRBuilder<>* ir_builder() { return engine_->ir_builder(); } @@ -250,7 +252,7 @@ class GANDIVA_EXPORT LLVMGenerator { std::unique_ptr engine_; std::vector> compiled_exprs_; bool cached_; - FunctionRegistry function_registry_; + std::shared_ptr function_registry_; Annotator annotator_; SelectionVector::Mode selection_vector_mode_; diff --git a/cpp/src/gandiva/llvm_generator_test.cc b/cpp/src/gandiva/llvm_generator_test.cc index 028893b0b4594..671ce91e870f6 100644 --- a/cpp/src/gandiva/llvm_generator_test.cc +++ b/cpp/src/gandiva/llvm_generator_test.cc @@ -35,7 +35,7 @@ typedef int64_t (*add_vector_func_t)(int64_t* elements, int nelements); class TestLLVMGenerator : public ::testing::Test { protected: - FunctionRegistry registry_; + std::shared_ptr registry_ = default_function_registry(); }; // Verify that a valid pc function exists for every function in the registry. 
@@ -45,7 +45,7 @@ TEST_F(TestLLVMGenerator, VerifyPCFunctions) { llvm::Module* module = generator->module(); ASSERT_OK(generator->engine_->LoadFunctionIRs()); - for (auto& iter : registry_) { + for (auto& iter : *registry_) { EXPECT_NE(module->getFunction(iter.pc_name()), nullptr); } } @@ -73,7 +73,7 @@ TEST_F(TestLLVMGenerator, TestAdd) { FunctionSignature signature(func_desc->name(), func_desc->params(), func_desc->return_type()); const NativeFunction* native_func = - generator->function_registry_.LookupSignature(signature); + generator->function_registry_->LookupSignature(signature); std::vector pairs{pair0, pair1}; auto func_dex = std::make_shared( @@ -115,4 +115,17 @@ TEST_F(TestLLVMGenerator, TestAdd) { EXPECT_EQ(out_bitmap, 0ULL); } +TEST_F(TestLLVMGenerator, VerifyExtendedPCFunctions) { + auto external_registry = std::make_shared(); + auto config_with_func_registry = + TestConfigurationWithFunctionRegistry(std::move(external_registry)); + + std::unique_ptr generator; + ASSERT_OK(LLVMGenerator::Make(config_with_func_registry, false, &generator)); + + auto module = generator->module(); + ASSERT_OK(generator->engine_->LoadFunctionIRs()); + EXPECT_NE(module->getFunction("multiply_by_two_int32"), nullptr); +} + } // namespace gandiva diff --git a/cpp/src/gandiva/native_function.h b/cpp/src/gandiva/native_function.h index 1268a25674a9d..c20de3dbdd54d 100644 --- a/cpp/src/gandiva/native_function.h +++ b/cpp/src/gandiva/native_function.h @@ -54,16 +54,16 @@ class GANDIVA_EXPORT NativeFunction { bool CanReturnErrors() const { return (flags_ & kCanReturnErrors) != 0; } NativeFunction(const std::string& base_name, const std::vector& aliases, - const DataTypeVector& param_types, DataTypePtr ret_type, - const ResultNullableType& result_nullable_type, - const std::string& pc_name, int32_t flags = 0) + const DataTypeVector& param_types, const DataTypePtr& ret_type, + const ResultNullableType& result_nullable_type, std::string pc_name, + int32_t flags = 0) : signatures_(), 
flags_(flags), result_nullable_type_(result_nullable_type), - pc_name_(pc_name) { - signatures_.push_back(FunctionSignature(base_name, param_types, ret_type)); + pc_name_(std::move(pc_name)) { + signatures_.emplace_back(base_name, param_types, ret_type); for (auto& func_name : aliases) { - signatures_.push_back(FunctionSignature(func_name, param_types, ret_type)); + signatures_.emplace_back(func_name, param_types, ret_type); } } diff --git a/cpp/src/gandiva/precompiled/CMakeLists.txt b/cpp/src/gandiva/precompiled/CMakeLists.txt index 3e41640861123..e62a8e3d4a375 100644 --- a/cpp/src/gandiva/precompiled/CMakeLists.txt +++ b/cpp/src/gandiva/precompiled/CMakeLists.txt @@ -15,8 +15,6 @@ # specific language governing permissions and limitations # under the License. -project(gandiva) - set(PRECOMPILED_SRCS arithmetic_ops.cc bitmap.cc @@ -29,69 +27,18 @@ set(PRECOMPILED_SRCS time.cc timestamp_arithmetic.cc ../../arrow/util/basic_decimal.cc) - -set(PLATFORM_CLANG_OPTIONS -std=c++17) -if(MSVC) - # "19.20" means that it's compatible with Visual Studio 16 2019. - # We can update this to "19.30" when we dropped support for Visual - # Studio 16 2019. - # - # See https://cmake.org/cmake/help/latest/variable/MSVC_VERSION.html - # for MSVC_VERSION and Visual Studio version. - set(FMS_COMPATIBILITY 19.20) - list(APPEND PLATFORM_CLANG_OPTIONS -fms-compatibility - -fms-compatibility-version=${FMS_COMPATIBILITY}) -endif() - -# Create bitcode for each of the source files. 
-foreach(SRC_FILE ${PRECOMPILED_SRCS}) - get_filename_component(SRC_BASE ${SRC_FILE} NAME_WE) - get_filename_component(ABSOLUTE_SRC ${SRC_FILE} ABSOLUTE) - set(BC_FILE ${CMAKE_CURRENT_BINARY_DIR}/${SRC_BASE}.bc) - set(PRECOMPILE_COMMAND) - if(CMAKE_OSX_SYSROOT) - list(APPEND - PRECOMPILE_COMMAND - ${CMAKE_COMMAND} - -E - env - SDKROOT=${CMAKE_OSX_SYSROOT}) - endif() - list(APPEND - PRECOMPILE_COMMAND - ${CLANG_EXECUTABLE} - ${PLATFORM_CLANG_OPTIONS} - -DGANDIVA_IR - -DNDEBUG # DCHECK macros not implemented in precompiled code - -DARROW_STATIC # Do not set __declspec(dllimport) on MSVC on Arrow symbols - -DGANDIVA_STATIC # Do not set __declspec(dllimport) on MSVC on Gandiva symbols - -fno-use-cxa-atexit # Workaround for unresolved __dso_handle - -emit-llvm - -O3 - -c - ${ABSOLUTE_SRC} - -o - ${BC_FILE} - ${ARROW_GANDIVA_PC_CXX_FLAGS} - -I${CMAKE_SOURCE_DIR}/src - -I${ARROW_BINARY_DIR}/src) - - if(NOT ARROW_USE_NATIVE_INT128) - foreach(boost_include_dir ${Boost_INCLUDE_DIRS}) - list(APPEND PRECOMPILE_COMMAND -I${boost_include_dir}) - endforeach() - endif() - add_custom_command(OUTPUT ${BC_FILE} - COMMAND ${PRECOMPILE_COMMAND} - DEPENDS ${SRC_FILE}) - list(APPEND BC_FILES ${BC_FILE}) +set(GANDIVA_PRECOMPILED_BC_FILES) +foreach(SOURCE ${PRECOMPILED_SRCS}) + gandiva_add_bitcode(${SOURCE}) + get_filename_component(SOURCE_BASE ${SOURCE} NAME_WE) + list(APPEND GANDIVA_PRECOMPILED_BC_FILES ${CMAKE_CURRENT_BINARY_DIR}/${SOURCE_BASE}.bc) endforeach() # link all of the bitcode files into a single bitcode file. add_custom_command(OUTPUT ${GANDIVA_PRECOMPILED_BC_PATH} COMMAND ${LLVM_LINK_EXECUTABLE} -o ${GANDIVA_PRECOMPILED_BC_PATH} - ${BC_FILES} - DEPENDS ${BC_FILES}) + ${GANDIVA_PRECOMPILED_BC_FILES} + DEPENDS ${GANDIVA_PRECOMPILED_BC_FILES}) # turn the bitcode file into a C++ static data variable. 
add_custom_command(OUTPUT ${GANDIVA_PRECOMPILED_CC_PATH} diff --git a/cpp/src/gandiva/projector.cc b/cpp/src/gandiva/projector.cc index 7024a3bc208af..e717e825dfc71 100644 --- a/cpp/src/gandiva/projector.cc +++ b/cpp/src/gandiva/projector.cc @@ -87,7 +87,8 @@ Status Projector::Make(SchemaPtr schema, const ExpressionVector& exprs, // Return if any of the expression is invalid since // we will not be able to process further. if (!is_cached) { - ExprValidator expr_validator(llvm_gen->types(), schema); + ExprValidator expr_validator(llvm_gen->types(), schema, + configuration->function_registry()); for (auto& expr : exprs) { ARROW_RETURN_NOT_OK(expr_validator.Validate(expr)); } diff --git a/cpp/src/gandiva/tests/CMakeLists.txt b/cpp/src/gandiva/tests/CMakeLists.txt index 5fa2da16c632f..68138f50d813d 100644 --- a/cpp/src/gandiva/tests/CMakeLists.txt +++ b/cpp/src/gandiva/tests/CMakeLists.txt @@ -15,28 +15,41 @@ # specific language governing permissions and limitations # under the License. -add_gandiva_test(filter_test) -add_gandiva_test(projector_test) -add_gandiva_test(projector_build_validation_test) -add_gandiva_test(if_expr_test) -add_gandiva_test(literal_test) -add_gandiva_test(boolean_expr_test) -add_gandiva_test(binary_test) -add_gandiva_test(date_time_test) -add_gandiva_test(to_string_test) -add_gandiva_test(utf8_test) -add_gandiva_test(hash_test) -add_gandiva_test(in_expr_test) -add_gandiva_test(null_validity_test) -add_gandiva_test(decimal_test) -add_gandiva_test(decimal_single_test) -add_gandiva_test(filter_project_test) +add_gandiva_test(projector-test + SOURCES + binary_test.cc + boolean_expr_test.cc + date_time_test.cc + decimal_single_test.cc + decimal_test.cc + filter_project_test.cc + filter_test.cc + hash_test.cc + huge_table_test.cc + if_expr_test.cc + in_expr_test.cc + literal_test.cc + null_validity_test.cc + projector_build_validation_test.cc + projector_test.cc + test_util.cc + to_string_test.cc + utf8_test.cc) if(ARROW_BUILD_STATIC) - 
add_gandiva_test(projector_test_static SOURCES projector_test.cc USE_STATIC_LINKING) + add_gandiva_test(projector_test_static + SOURCES + projector_test.cc + test_util.cc + USE_STATIC_LINKING) add_arrow_benchmark(micro_benchmarks + SOURCES + micro_benchmarks.cc + test_util.cc PREFIX "gandiva" EXTRA_LINK_LIBS gandiva_static) endif() + +add_subdirectory(external_functions) diff --git a/cpp/src/gandiva/tests/date_time_test.cc b/cpp/src/gandiva/tests/date_time_test.cc index ce1c3d05f6638..6208f1ecba9b5 100644 --- a/cpp/src/gandiva/tests/date_time_test.cc +++ b/cpp/src/gandiva/tests/date_time_test.cc @@ -36,7 +36,7 @@ using arrow::int32; using arrow::int64; using arrow::timestamp; -class TestProjector : public ::testing::Test { +class DateTimeTestProjector : public ::testing::Test { public: void SetUp() { pool_ = arrow::default_memory_pool(); } @@ -111,7 +111,7 @@ int32_t DaysSince(time_t base_line, int32_t yy, int32_t mm, int32_t dd, int32_t return static_cast(((ts - base_line) * 1000 + millis) / MILLIS_IN_DAY); } -TEST_F(TestProjector, TestIsNull) { +TEST_F(DateTimeTestProjector, TestIsNull) { auto d0 = field("d0", date64()); auto t0 = field("t0", time32(arrow::TimeUnit::MILLI)); auto schema = arrow::schema({d0, t0}); @@ -155,7 +155,7 @@ TEST_F(TestProjector, TestIsNull) { EXPECT_ARROW_ARRAY_EQUALS(exp_isnotnull, outputs.at(1)); } -TEST_F(TestProjector, TestDate32IsNull) { +TEST_F(DateTimeTestProjector, TestDate32IsNull) { auto d0 = field("d0", date32()); auto schema = arrow::schema({d0}); @@ -191,7 +191,7 @@ TEST_F(TestProjector, TestDate32IsNull) { EXPECT_ARROW_ARRAY_EQUALS(exp_isnull, outputs.at(0)); } -TEST_F(TestProjector, TestDateTime) { +TEST_F(DateTimeTestProjector, TestDateTime) { auto field0 = field("f0", date64()); auto field1 = field("f1", date32()); auto field2 = field("f2", timestamp(arrow::TimeUnit::MILLI)); @@ -292,7 +292,7 @@ TEST_F(TestProjector, TestDateTime) { EXPECT_ARROW_ARRAY_EQUALS(exp_dd_from_ts, outputs.at(5)); } -TEST_F(TestProjector, 
TestTime) { +TEST_F(DateTimeTestProjector, TestTime) { auto field0 = field("f0", time32(arrow::TimeUnit::MILLI)); auto schema = arrow::schema({field0}); @@ -339,7 +339,7 @@ TEST_F(TestProjector, TestTime) { EXPECT_ARROW_ARRAY_EQUALS(exp_hour, outputs.at(1)); } -TEST_F(TestProjector, TestTimestampDiff) { +TEST_F(DateTimeTestProjector, TestTimestampDiff) { auto f0 = field("f0", timestamp(arrow::TimeUnit::MILLI)); auto f1 = field("f1", timestamp(arrow::TimeUnit::MILLI)); auto schema = arrow::schema({f0, f1}); @@ -439,7 +439,7 @@ TEST_F(TestProjector, TestTimestampDiff) { } } -TEST_F(TestProjector, TestTimestampDiffMonth) { +TEST_F(DateTimeTestProjector, TestTimestampDiffMonth) { auto f0 = field("f0", timestamp(arrow::TimeUnit::MILLI)); auto f1 = field("f1", timestamp(arrow::TimeUnit::MILLI)); auto schema = arrow::schema({f0, f1}); @@ -497,7 +497,7 @@ TEST_F(TestProjector, TestTimestampDiffMonth) { } } -TEST_F(TestProjector, TestMonthsBetween) { +TEST_F(DateTimeTestProjector, TestMonthsBetween) { auto f0 = field("f0", arrow::date64()); auto f1 = field("f1", arrow::date64()); auto schema = arrow::schema({f0, f1}); @@ -550,7 +550,7 @@ TEST_F(TestProjector, TestMonthsBetween) { EXPECT_ARROW_ARRAY_EQUALS(exp_output, outputs.at(0)); } -TEST_F(TestProjector, TestCastTimestampFromInt64) { +TEST_F(DateTimeTestProjector, TestCastTimestampFromInt64) { auto f0 = field("f0", arrow::int64()); auto schema = arrow::schema({f0}); @@ -600,7 +600,7 @@ TEST_F(TestProjector, TestCastTimestampFromInt64) { EXPECT_ARROW_ARRAY_EQUALS(exp_output, outputs.at(0)); } -TEST_F(TestProjector, TestLastDay) { +TEST_F(DateTimeTestProjector, TestLastDay) { auto f0 = field("f0", arrow::date64()); auto schema = arrow::schema({f0}); @@ -650,7 +650,7 @@ TEST_F(TestProjector, TestLastDay) { EXPECT_ARROW_ARRAY_EQUALS(exp_output, outputs.at(0)); } -TEST_F(TestProjector, TestToTimestampFromInt) { +TEST_F(DateTimeTestProjector, TestToTimestampFromInt) { auto f0 = field("f0", arrow::int32()); auto f1 = 
field("f1", arrow::int64()); auto f2 = field("f2", arrow::float32()); @@ -721,7 +721,7 @@ TEST_F(TestProjector, TestToTimestampFromInt) { EXPECT_ARROW_ARRAY_EQUALS(exp_output1, outputs.at(3)); } -TEST_F(TestProjector, TestToUtcTimestamp) { +TEST_F(DateTimeTestProjector, TestToUtcTimestamp) { auto f0 = field("f0", timestamp(arrow::TimeUnit::MILLI)); auto f1 = field("f1", arrow::utf8()); @@ -775,7 +775,7 @@ TEST_F(TestProjector, TestToUtcTimestamp) { EXPECT_ARROW_ARRAY_EQUALS(exp_output, outputs.at(0)); } -TEST_F(TestProjector, TestFromUtcTimestamp) { +TEST_F(DateTimeTestProjector, TestFromUtcTimestamp) { auto f0 = field("f0", timestamp(arrow::TimeUnit::MILLI)); auto f1 = field("f1", arrow::utf8()); diff --git a/cpp/src/gandiva/tests/external_functions/CMakeLists.txt b/cpp/src/gandiva/tests/external_functions/CMakeLists.txt new file mode 100644 index 0000000000000..c309549e874e3 --- /dev/null +++ b/cpp/src/gandiva/tests/external_functions/CMakeLists.txt @@ -0,0 +1,50 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. 
+ +if(NO_TESTS) + return() +endif() +# +## copy the testing data into the build directory +add_custom_target(extension-tests-data + COMMAND ${CMAKE_COMMAND} -E copy_directory ${CMAKE_CURRENT_SOURCE_DIR} + ${CMAKE_CURRENT_BINARY_DIR}) + +set(TEST_PRECOMPILED_SOURCES multiply_by_two.cc) +set(TEST_PRECOMPILED_BC_FILES) +foreach(SOURCE ${TEST_PRECOMPILED_SOURCES}) + gandiva_add_bitcode(${SOURCE}) + get_filename_component(SOURCE_BASE ${SOURCE} NAME_WE) + list(APPEND TEST_PRECOMPILED_BC_FILES ${CMAKE_CURRENT_BINARY_DIR}/${SOURCE_BASE}.bc) +endforeach() +add_custom_target(extension-tests ALL DEPENDS extension-tests-data + ${TEST_PRECOMPILED_BC_FILES}) +# +## set the GANDIVA_EXTENSION_TEST_DIR macro so that the tests can pass regardless where they are run from +## corresponding extension test data files and bitcode will be copied/generated +set(TARGETS_DEPENDING_ON_TEST_BITCODE_FILES gandiva-internals-test gandiva-projector-test + gandiva-projector-test-static) +foreach(TARGET ${TARGETS_DEPENDING_ON_TEST_BITCODE_FILES}) + if(TARGET ${TARGET}) + add_dependencies(${TARGET} extension-tests) + target_compile_definitions(${TARGET} + PRIVATE -DGANDIVA_EXTENSION_TEST_DIR="${CMAKE_CURRENT_BINARY_DIR}" + ) + endif() +endforeach() + +add_dependencies(gandiva-tests extension-tests) diff --git a/cpp/src/gandiva/tests/external_functions/multiply_by_two.cc b/cpp/src/gandiva/tests/external_functions/multiply_by_two.cc new file mode 100644 index 0000000000000..cc7e2b0f8267f --- /dev/null +++ b/cpp/src/gandiva/tests/external_functions/multiply_by_two.cc @@ -0,0 +1,20 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. 
You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +#include "multiply_by_two.h" // NOLINT + +int64_t multiply_by_two_int32(int32_t value) { return value * 2; } diff --git a/cpp/src/gandiva/tests/external_functions/multiply_by_two.h b/cpp/src/gandiva/tests/external_functions/multiply_by_two.h new file mode 100644 index 0000000000000..b8aec5185457b --- /dev/null +++ b/cpp/src/gandiva/tests/external_functions/multiply_by_two.h @@ -0,0 +1,24 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. 
+ +#pragma once + +#include <cstdint> + +extern "C" { +int64_t multiply_by_two_int32(int32_t value); +} diff --git a/cpp/src/gandiva/tests/filter_test.cc b/cpp/src/gandiva/tests/filter_test.cc index effd31cc27aa0..749000aa0cf27 100644 --- a/cpp/src/gandiva/tests/filter_test.cc +++ b/cpp/src/gandiva/tests/filter_test.cc @@ -42,8 +42,8 @@ class TestFilter : public ::testing::Test { TEST_F(TestFilter, TestFilterCache) { // schema for input fields - auto field0 = field("f0", int32()); - auto field1 = field("f1", int32()); + auto field0 = field("f0_filter_cache", int32()); + auto field1 = field("f1_filter_cache", int32()); auto schema = arrow::schema({field0, field1}); // Build condition f0 + f1 < 10 @@ -69,7 +69,7 @@ TEST_F(TestFilter, TestFilterCache) { EXPECT_TRUE(cached_filter->GetBuiltFromCache()); // schema is different should return a new filter. - auto field2 = field("f2", int32()); + auto field2 = field("f2_filter_cache", int32()); auto different_schema = arrow::schema({field0, field1, field2}); std::shared_ptr<Filter> should_be_new_filter; status = diff --git a/cpp/src/gandiva/tests/huge_table_test.cc b/cpp/src/gandiva/tests/huge_table_test.cc index 46f814b472d84..34c8512f1b0a9 100644 --- a/cpp/src/gandiva/tests/huge_table_test.cc +++ b/cpp/src/gandiva/tests/huge_table_test.cc @@ -139,8 +139,11 @@ TEST_F(LARGE_MEMORY_TEST(TestHugeFilter), TestSimpleHugeFilter) { auto status = Filter::Make(schema, condition, TestConfiguration(), &filter); EXPECT_TRUE(status.ok()); + auto array1 = MakeArrowArray(arr1, validity); + auto array2 = MakeArrowArray(arr2, validity); + // prepare input record batch - auto in_batch = arrow::RecordBatch::Make(schema, num_records, {arr1, arr2}); + auto in_batch = arrow::RecordBatch::Make(schema, num_records, {array1, array2}); std::shared_ptr<SelectionVector> selection_vector; status = SelectionVector::MakeInt64(num_records, pool_, &selection_vector); diff --git a/cpp/src/gandiva/tests/projector_build_validation_test.cc
b/cpp/src/gandiva/tests/projector_build_validation_test.cc index 5b86844f940bf..1ed4c77a074ab 100644 --- a/cpp/src/gandiva/tests/projector_build_validation_test.cc +++ b/cpp/src/gandiva/tests/projector_build_validation_test.cc @@ -27,7 +27,7 @@ using arrow::boolean; using arrow::float32; using arrow::int32; -class TestProjector : public ::testing::Test { +class ValidationTestProjector : public ::testing::Test { public: void SetUp() { pool_ = arrow::default_memory_pool(); } @@ -35,7 +35,7 @@ class TestProjector : public ::testing::Test { arrow::MemoryPool* pool_; }; -TEST_F(TestProjector, TestNonexistentFunction) { +TEST_F(ValidationTestProjector, TestNonexistentFunction) { // schema for input fields auto field0 = field("f0", float32()); auto field1 = field("f2", float32()); @@ -57,7 +57,7 @@ TEST_F(TestProjector, TestNonexistentFunction) { EXPECT_TRUE(status.message().find(expected_error) != std::string::npos); } -TEST_F(TestProjector, TestNotMatchingDataType) { +TEST_F(ValidationTestProjector, TestNotMatchingDataType) { // schema for input fields auto field0 = field("f0", float32()); auto schema = arrow::schema({field0}); @@ -78,7 +78,7 @@ TEST_F(TestProjector, TestNotMatchingDataType) { EXPECT_TRUE(status.message().find(expected_error) != std::string::npos); } -TEST_F(TestProjector, TestNotSupportedDataType) { +TEST_F(ValidationTestProjector, TestNotSupportedDataType) { // schema for input fields auto field0 = field("f0", list(int32())); auto schema = arrow::schema({field0}); @@ -98,7 +98,7 @@ TEST_F(TestProjector, TestNotSupportedDataType) { EXPECT_TRUE(status.message().find(expected_error) != std::string::npos); } -TEST_F(TestProjector, TestIncorrectSchemaMissingField) { +TEST_F(ValidationTestProjector, TestIncorrectSchemaMissingField) { // schema for input fields auto field0 = field("f0", float32()); auto field1 = field("f2", float32()); @@ -119,7 +119,7 @@ TEST_F(TestProjector, TestIncorrectSchemaMissingField) { 
EXPECT_TRUE(status.message().find(expected_error) != std::string::npos); } -TEST_F(TestProjector, TestIncorrectSchemaTypeNotMatching) { +TEST_F(ValidationTestProjector, TestIncorrectSchemaTypeNotMatching) { // schema for input fields auto field0 = field("f0", float32()); auto field1 = field("f2", float32()); @@ -142,7 +142,7 @@ TEST_F(TestProjector, TestIncorrectSchemaTypeNotMatching) { EXPECT_TRUE(status.message().find(expected_error) != std::string::npos); } -TEST_F(TestProjector, TestIfNotSupportedFunction) { +TEST_F(ValidationTestProjector, TestIfNotSupportedFunction) { // schema for input fields auto fielda = field("a", int32()); auto fieldb = field("b", int32()); @@ -170,7 +170,7 @@ TEST_F(TestProjector, TestIfNotSupportedFunction) { EXPECT_TRUE(status.IsExpressionValidationError()); } -TEST_F(TestProjector, TestIfNotMatchingReturnType) { +TEST_F(ValidationTestProjector, TestIfNotMatchingReturnType) { // schema for input fields auto fielda = field("a", int32()); auto fieldb = field("b", int32()); @@ -193,7 +193,7 @@ TEST_F(TestProjector, TestIfNotMatchingReturnType) { EXPECT_TRUE(status.IsExpressionValidationError()); } -TEST_F(TestProjector, TestElseNotMatchingReturnType) { +TEST_F(ValidationTestProjector, TestElseNotMatchingReturnType) { // schema for input fields auto fielda = field("a", int32()); auto fieldb = field("b", int32()); @@ -218,7 +218,7 @@ TEST_F(TestProjector, TestElseNotMatchingReturnType) { EXPECT_TRUE(status.IsExpressionValidationError()); } -TEST_F(TestProjector, TestElseNotSupportedType) { +TEST_F(ValidationTestProjector, TestElseNotSupportedType) { // schema for input fields auto fielda = field("a", int32()); auto fieldb = field("b", int32()); @@ -244,7 +244,7 @@ TEST_F(TestProjector, TestElseNotSupportedType) { EXPECT_EQ(status.code(), StatusCode::ExpressionValidationError); } -TEST_F(TestProjector, TestAndMinChildren) { +TEST_F(ValidationTestProjector, TestAndMinChildren) { // schema for input fields auto fielda = field("a", 
boolean()); auto schema = arrow::schema({fielda}); @@ -263,7 +263,7 @@ TEST_F(TestProjector, TestAndMinChildren) { EXPECT_TRUE(status.IsExpressionValidationError()); } -TEST_F(TestProjector, TestAndBooleanArgType) { +TEST_F(ValidationTestProjector, TestAndBooleanArgType) { // schema for input fields auto fielda = field("a", boolean()); auto fieldb = field("b", int32()); diff --git a/cpp/src/gandiva/tests/projector_test.cc b/cpp/src/gandiva/tests/projector_test.cc index 462fae64393fd..38566fb408ab5 100644 --- a/cpp/src/gandiva/tests/projector_test.cc +++ b/cpp/src/gandiva/tests/projector_test.cc @@ -26,6 +26,7 @@ #include #include "arrow/memory_pool.h" +#include "gandiva/function_registry.h" #include "gandiva/literal_holder.h" #include "gandiva/node.h" #include "gandiva/tests/test_util.h" @@ -3582,4 +3583,29 @@ TEST_F(TestProjector, TestSqrtFloat64) { EXPECT_ARROW_ARRAY_EQUALS(out, outs.at(0)); } +TEST_F(TestProjector, TestExtendedFunctions) { + auto in_field = field("in", arrow::int32()); + auto schema = arrow::schema({in_field}); + auto out_field = field("out", arrow::int64()); + // the multiply_by_two function is only available in the external function's IR bitcode + auto multiply = + TreeExprBuilder::MakeExpression("multiply_by_two", {in_field}, out_field); + + std::shared_ptr<Projector> projector; + auto external_registry = std::make_shared<FunctionRegistry>(); + auto config_with_func_registry = + TestConfigurationWithFunctionRegistry(std::move(external_registry)); + ARROW_EXPECT_OK( + Projector::Make(schema, {multiply}, config_with_func_registry, &projector)); + + int num_records = 4; + auto array = MakeArrowArrayInt32({1, 2, 3, 4}, {true, true, true, true}); + auto in_batch = arrow::RecordBatch::Make(schema, num_records, {array}); + auto out = MakeArrowArrayInt64({2, 4, 6, 8}, {true, true, true, true}); + + arrow::ArrayVector outs; + ARROW_EXPECT_OK(projector->Evaluate(*in_batch, pool_, &outs)); + EXPECT_ARROW_ARRAY_EQUALS(out, outs.at(0)); +} + } // namespace gandiva diff --git
a/cpp/src/gandiva/tests/test_util.cc b/cpp/src/gandiva/tests/test_util.cc new file mode 100644 index 0000000000000..42f67d3824a21 --- /dev/null +++ b/cpp/src/gandiva/tests/test_util.cc @@ -0,0 +1,52 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +#include "gandiva/tests/test_util.h" + +#include <filesystem> + +namespace gandiva { +std::shared_ptr<Configuration> TestConfiguration() { + auto builder = ConfigurationBuilder(); + return builder.DefaultConfiguration(); +} + +#ifndef GANDIVA_EXTENSION_TEST_DIR +#define GANDIVA_EXTENSION_TEST_DIR "."
+#endif + +std::string GetTestFunctionLLVMIRPath() { + std::filesystem::path base(GANDIVA_EXTENSION_TEST_DIR); + std::filesystem::path ir_file = base / "multiply_by_two.bc"; + return ir_file.string(); +} + +NativeFunction GetTestExternalFunction() { + NativeFunction multiply_by_two_func( + "multiply_by_two", {}, {arrow::int32()}, arrow::int64(), + ResultNullableType::kResultNullIfNull, "multiply_by_two_int32"); + return multiply_by_two_func; +} + +std::shared_ptr<Configuration> TestConfigurationWithFunctionRegistry( + std::shared_ptr<FunctionRegistry> registry) { + ARROW_EXPECT_OK( + registry->Register({GetTestExternalFunction()}, GetTestFunctionLLVMIRPath())); + auto external_func_config = ConfigurationBuilder().build(std::move(registry)); + return external_func_config; +} +} // namespace gandiva diff --git a/cpp/src/gandiva/tests/test_util.h b/cpp/src/gandiva/tests/test_util.h index 99df90769e0ad..e431e53096c2c 100644 --- a/cpp/src/gandiva/tests/test_util.h +++ b/cpp/src/gandiva/tests/test_util.h @@ -96,9 +96,12 @@ static inline ArrayPtr MakeArrowTypeArray(const std::shared_ptr<arrow::DataType> EXPECT_TRUE((a)->Equals(b)) << "expected type: " << (a)->ToString() \ << " actual type: " << (b)->ToString() -static inline std::shared_ptr<Configuration> TestConfiguration() { - auto builder = ConfigurationBuilder(); - return builder.DefaultConfiguration(); -} +std::shared_ptr<Configuration> TestConfiguration(); + +std::shared_ptr<Configuration> TestConfigurationWithFunctionRegistry( + std::shared_ptr<FunctionRegistry> registry); + +std::string GetTestFunctionLLVMIRPath(); +NativeFunction GetTestExternalFunction(); } // namespace gandiva diff --git a/cpp/src/gandiva/tree_expr_test.cc b/cpp/src/gandiva/tree_expr_test.cc index e70cf12898124..86a826f29367f 100644 --- a/cpp/src/gandiva/tree_expr_test.cc +++ b/cpp/src/gandiva/tree_expr_test.cc @@ -45,7 +45,7 @@ class TestExprTree : public ::testing::Test { FieldPtr i1_; // int32 FieldPtr b0_; // bool - FunctionRegistry registry_; + std::shared_ptr<FunctionRegistry> registry_ = gandiva::default_function_registry(); }; TEST_F(TestExprTree, TestField) { @@
-57,7 +57,7 @@ TEST_F(TestExprTree, TestField) { auto n1 = TreeExprBuilder::MakeField(b0_); EXPECT_EQ(n1->return_type(), boolean()); - ExprDecomposer decomposer(registry_, annotator); + ExprDecomposer decomposer(*registry_, annotator); ValueValidityPairPtr pair; auto status = decomposer.Decompose(*n1, &pair); DCHECK_EQ(status.ok(), true) << status.message(); @@ -88,7 +88,7 @@ TEST_F(TestExprTree, TestBinary) { EXPECT_EQ(add->return_type(), int32()); EXPECT_TRUE(sign == FunctionSignature("add", {int32(), int32()}, int32())); - ExprDecomposer decomposer(registry_, annotator); + ExprDecomposer decomposer(*registry_, annotator); ValueValidityPairPtr pair; auto status = decomposer.Decompose(*n, &pair); DCHECK_EQ(status.ok(), true) << status.message(); @@ -97,7 +97,7 @@ TEST_F(TestExprTree, TestBinary) { auto null_if_null = std::dynamic_pointer_cast<NonNullableFuncDex>(value); FunctionSignature signature("add", {int32(), int32()}, int32()); - const NativeFunction* fn = registry_.LookupSignature(signature); + const NativeFunction* fn = registry_->LookupSignature(signature); EXPECT_EQ(null_if_null->native_function(), fn); } @@ -114,7 +114,7 @@ TEST_F(TestExprTree, TestUnary) { EXPECT_EQ(unaryFn->return_type(), boolean()); EXPECT_TRUE(sign == FunctionSignature("isnumeric", {int32()}, boolean())); - ExprDecomposer decomposer(registry_, annotator); + ExprDecomposer decomposer(*registry_, annotator); ValueValidityPairPtr pair; auto status = decomposer.Decompose(*n, &pair); DCHECK_EQ(status.ok(), true) << status.message(); @@ -123,7 +123,7 @@ TEST_F(TestExprTree, TestUnary) { auto never_null = std::dynamic_pointer_cast<NullableNeverFuncDex>(value); FunctionSignature signature("isnumeric", {int32()}, boolean()); - const NativeFunction* fn = registry_.LookupSignature(signature); + const NativeFunction* fn = registry_->LookupSignature(signature); EXPECT_EQ(never_null->native_function(), fn); } @@ -143,7 +143,7 @@ TEST_F(TestExprTree, TestExpression) { func_desc->return_type()); EXPECT_TRUE(sign == FunctionSignature("add",
{int32(), int32()}, int32())); - ExprDecomposer decomposer(registry_, annotator); + ExprDecomposer decomposer(*registry_, annotator); ValueValidityPairPtr pair; auto status = decomposer.Decompose(*root_node, &pair); DCHECK_EQ(status.ok(), true) << status.message(); @@ -152,7 +152,7 @@ TEST_F(TestExprTree, TestExpression) { auto null_if_null = std::dynamic_pointer_cast<NonNullableFuncDex>(value); FunctionSignature signature("add", {int32(), int32()}, int32()); - const NativeFunction* fn = registry_.LookupSignature(signature); + const NativeFunction* fn = registry_->LookupSignature(signature); EXPECT_EQ(null_if_null->native_function(), fn); }