From 1e186630b27025bfdc6969c8faa814c25a698ae3 Mon Sep 17 00:00:00 2001 From: Yue Ni Date: Tue, 19 Sep 2023 14:57:57 +0800 Subject: [PATCH] Add external function registry support to gandiva. JSON based function registry can be used for describing the function metadata, and LLVM bitcode can be automatically loaded as pre-compiled external functions. --- cpp/src/gandiva/CMakeLists.txt | 5 + cpp/src/gandiva/cmake/GenerateBitcode.cmake | 84 +++++ cpp/src/gandiva/engine.cc | 40 +++ cpp/src/gandiva/engine.h | 2 + .../gandiva/extension_tests/CMakeLists.txt | 39 +++ .../complex_type_registry/registry.json | 30 ++ .../extended_funcs/multiply_by_two.cc | 5 + .../extended_funcs/multiply_by_two.h | 7 + .../extended_funcs/registry.json | 19 ++ .../multiple_functions_registry/registry.json | 35 ++ .../multiple_registries/reg_1.json | 21 ++ .../multiple_registries/reg_2.json | 18 ++ .../no_name_func_registry/registry.json | 7 + .../simple_registry/registry.json | 22 ++ cpp/src/gandiva/function_registry.cc | 21 ++ cpp/src/gandiva/function_registry.h | 4 + cpp/src/gandiva/function_registry_external.cc | 301 ++++++++++++++++++ cpp/src/gandiva/function_registry_external.h | 27 ++ .../function_registry_external_test.cc | 89 ++++++ cpp/src/gandiva/function_registry_test.cc | 35 +- cpp/src/gandiva/llvm_generator.h | 1 + cpp/src/gandiva/llvm_generator_test.cc | 15 + cpp/src/gandiva/precompiled/CMakeLists.txt | 66 +--- cpp/src/gandiva/tests/projector_test.cc | 27 ++ cpp/src/gandiva/tests/test_util.h | 26 ++ 25 files changed, 883 insertions(+), 63 deletions(-) create mode 100644 cpp/src/gandiva/cmake/GenerateBitcode.cmake create mode 100644 cpp/src/gandiva/extension_tests/CMakeLists.txt create mode 100644 cpp/src/gandiva/extension_tests/complex_type_registry/registry.json create mode 100644 cpp/src/gandiva/extension_tests/extended_funcs/multiply_by_two.cc create mode 100644 cpp/src/gandiva/extension_tests/extended_funcs/multiply_by_two.h create mode 100644 
cpp/src/gandiva/extension_tests/extended_funcs/registry.json create mode 100644 cpp/src/gandiva/extension_tests/multiple_functions_registry/registry.json create mode 100644 cpp/src/gandiva/extension_tests/multiple_registries/reg_1.json create mode 100644 cpp/src/gandiva/extension_tests/multiple_registries/reg_2.json create mode 100644 cpp/src/gandiva/extension_tests/no_name_func_registry/registry.json create mode 100644 cpp/src/gandiva/extension_tests/simple_registry/registry.json create mode 100644 cpp/src/gandiva/function_registry_external.cc create mode 100644 cpp/src/gandiva/function_registry_external.h create mode 100644 cpp/src/gandiva/function_registry_external_test.cc diff --git a/cpp/src/gandiva/CMakeLists.txt b/cpp/src/gandiva/CMakeLists.txt index db260b5acc933..bc970299b997f 100644 --- a/cpp/src/gandiva/CMakeLists.txt +++ b/cpp/src/gandiva/CMakeLists.txt @@ -65,6 +65,7 @@ set(SRC_FILES function_registry.cc function_registry_arithmetic.cc function_registry_datetime.cc + function_registry_external.cc function_registry_hash.cc function_registry_math_ops.cc function_registry_string.cc @@ -232,6 +233,7 @@ add_gandiva_test(internals-test bitmap_accumulator_test.cc engine_llvm_test.cc function_registry_test.cc + function_registry_external_test.cc function_signature_test.cc llvm_types_test.cc llvm_generator_test.cc @@ -251,5 +253,8 @@ add_gandiva_test(internals-test gdv_function_stubs_test.cc interval_holder_test.cc) + add_subdirectory(precompiled) add_subdirectory(tests) +add_subdirectory(extension_tests) + diff --git a/cpp/src/gandiva/cmake/GenerateBitcode.cmake b/cpp/src/gandiva/cmake/GenerateBitcode.cmake new file mode 100644 index 0000000000000..3a56c042c3312 --- /dev/null +++ b/cpp/src/gandiva/cmake/GenerateBitcode.cmake @@ -0,0 +1,84 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. 
The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. + +# Compile each source file to LLVM bitcode (.bc) and return the paths in OUTPUT_VAR. +function(generate_bitcode PRECOMPILED_SRC_LIST OUTPUT_DIR OUTPUT_VAR) + set(LOCAL_BC_FILES "") + + if(MSVC) + # clang pretends to be a particular version of MSVC. The standard + # library uses C++14 features, so we have to use that -std version + # to get the IR compilation to work. + # See https://cmake.org/cmake/help/latest/variable/MSVC_VERSION.html + # for MSVC_VERSION and Visual Studio version. 
+ if(MSVC_VERSION LESS 1920) + set(FMS_COMPATIBILITY 19.10) + elseif(MSVC_VERSION LESS 1930) + set(FMS_COMPATIBILITY 19.20) + else() + message(FATAL_ERROR "Unsupported MSVC_VERSION=${MSVC_VERSION}") + endif() + set(PLATFORM_CLANG_OPTIONS -std=c++17 -fms-compatibility + -fms-compatibility-version=${FMS_COMPATIBILITY}) + else() + set(PLATFORM_CLANG_OPTIONS -std=c++17) + endif() + + foreach(SRC_FILE ${PRECOMPILED_SRC_LIST}) + get_filename_component(SRC_BASE ${SRC_FILE} NAME_WE) + get_filename_component(ABSOLUTE_SRC ${SRC_FILE} ABSOLUTE) + set(BC_FILE ${CMAKE_CURRENT_BINARY_DIR}/${OUTPUT_DIR}${SRC_BASE}.bc) + set(PRECOMPILE_COMMAND) + if(CMAKE_OSX_SYSROOT) + list(APPEND + PRECOMPILE_COMMAND + ${CMAKE_COMMAND} + -E + env + SDKROOT=${CMAKE_OSX_SYSROOT}) + endif() + list(APPEND + PRECOMPILE_COMMAND + ${CLANG_EXECUTABLE} + ${PLATFORM_CLANG_OPTIONS} + -DGANDIVA_IR + -DNDEBUG # DCHECK macros not implemented in precompiled code + -DARROW_STATIC # Do not set __declspec(dllimport) on MSVC on Arrow symbols + -DGANDIVA_STATIC # Do not set __declspec(dllimport) on MSVC on Gandiva symbols + -fno-use-cxa-atexit # Workaround for unresolved __dso_handle + -emit-llvm + -O3 + -c + ${ABSOLUTE_SRC} + -o + ${BC_FILE} + ${ARROW_GANDIVA_PC_CXX_FLAGS} + -I${CMAKE_SOURCE_DIR}/src + -I${ARROW_BINARY_DIR}/src) + + if(NOT ARROW_USE_NATIVE_INT128) + foreach(boost_include_dir ${Boost_INCLUDE_DIRS}) + list(APPEND PRECOMPILE_COMMAND -I${boost_include_dir}) + endforeach() + endif() + add_custom_command(OUTPUT ${BC_FILE} + COMMAND ${PRECOMPILE_COMMAND} + DEPENDS ${SRC_FILE}) + list(APPEND LOCAL_BC_FILES ${BC_FILE}) + endforeach() + set(${OUTPUT_VAR} "${LOCAL_BC_FILES}" PARENT_SCOPE) +endfunction() diff --git a/cpp/src/gandiva/engine.cc b/cpp/src/gandiva/engine.cc index 7d75793a3e9e7..be7e7adb85dbf 100644 --- a/cpp/src/gandiva/engine.cc +++ b/cpp/src/gandiva/engine.cc @@ -137,6 +137,10 @@ Status Engine::LoadFunctionIRs() { if (!functions_loaded_) { ARROW_RETURN_NOT_OK(LoadPreCompiledIR()); 
ARROW_RETURN_NOT_OK(DecimalIR::AddFunctions(this)); + const char* ext_dir_env = std::getenv("GANDIVA_EXTENSION_DIR"); + if (ext_dir_env) { + ARROW_RETURN_NOT_OK(LoadExtendedPreCompiledIR(ext_dir_env)); + } functions_loaded_ = true; } return Status::OK(); @@ -220,6 +224,42 @@ static void SetDataLayout(llvm::Module* module) { } // end of the mofified method from MLIR +// Loading extended IR files from the given directory +// all .bc files under the given directory will be loaded and parsed +Status Engine::LoadExtendedPreCompiledIR(const std::filesystem::path& dir_path) { + for (const auto& entry : std::filesystem::directory_iterator(dir_path)) { + if (entry.is_regular_file() && entry.path().extension() == ".bc") { + llvm::ErrorOr> buffer_or_error = + llvm::MemoryBuffer::getFile(entry.path().string()); + + ARROW_RETURN_IF(!buffer_or_error, + Status::CodeGenError("Could not load module from IR file: ", + entry.path().string() + " Error: " + + buffer_or_error.getError().message())); + + std::unique_ptr buffer = std::move(buffer_or_error.get()); + + llvm::Expected> module_or_error = + llvm::parseBitcodeFile(buffer->getMemBufferRef(), *context()); + if (!module_or_error) { + std::string str; + llvm::raw_string_ostream stream(str); + stream << module_or_error.takeError(); + return Status::CodeGenError("Failed to parse bitcode file: " + + entry.path().string() + " Error: " + stream.str()); + } + std::unique_ptr ir_module = std::move(module_or_error.get()); + + ARROW_RETURN_IF(llvm::verifyModule(*ir_module, &llvm::errs()), + Status::CodeGenError("verify of IR Module failed")); + ARROW_RETURN_IF(llvm::Linker::linkModules(*module_, std::move(ir_module)), + Status::CodeGenError("failed to link IR Modules")); + } + } + + return Status::OK(); +} + // Handling for pre-compiled IR libraries. 
Status Engine::LoadPreCompiledIR() { auto bitcode = llvm::StringRef(reinterpret_cast(kPrecompiledBitcode), diff --git a/cpp/src/gandiva/engine.h b/cpp/src/gandiva/engine.h index a4d6a5fd1a758..639b6444d0210 100644 --- a/cpp/src/gandiva/engine.h +++ b/cpp/src/gandiva/engine.h @@ -17,6 +17,7 @@ #pragma once +#include #include #include #include @@ -92,6 +93,7 @@ class GANDIVA_EXPORT Engine { /// load pre-compiled IR modules from precompiled_bitcode.cc and merge them into /// the main module. Status LoadPreCompiledIR(); + Status LoadExtendedPreCompiledIR(const std::filesystem::path& dir_path); // Create and add mappings for cpp functions that can be accessed from LLVM. void AddGlobalMappings(); diff --git a/cpp/src/gandiva/extension_tests/CMakeLists.txt b/cpp/src/gandiva/extension_tests/CMakeLists.txt new file mode 100644 index 0000000000000..3a69bb3b99765 --- /dev/null +++ b/cpp/src/gandiva/extension_tests/CMakeLists.txt @@ -0,0 +1,39 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. 
+ +# copy the testing data into the build directory +add_custom_target(extension-tests-data + COMMAND ${CMAKE_COMMAND} -E copy_directory + ${CMAKE_CURRENT_SOURCE_DIR} ${CMAKE_BINARY_DIR}/gandiva_extension_tests) + +include(../cmake/GenerateBitcode.cmake) + +set(TEST_EXT_DIR ${CMAKE_CURRENT_SOURCE_DIR}/extended_funcs) +set(TEST_PRECOMPILED_SRCS ${TEST_EXT_DIR}/multiply_by_two.cc) +generate_bitcode("${TEST_PRECOMPILED_SRCS}" "../../../gandiva_extension_tests/extended_funcs/" TEST_BC_FILES) +add_custom_target(extension-tests ALL DEPENDS extension-tests-data ${TEST_BC_FILES}) + +add_dependencies(gandiva-internals-test extension-tests) +add_dependencies(gandiva-tests extension-tests) + +# set the GANDIVA_EXTENSION_TEST_DIR macro so that the tests can pass regardless where they are run from +target_compile_definitions(gandiva-internals-test + PRIVATE -DGANDIVA_EXTENSION_TEST_DIR="${CMAKE_BINARY_DIR}/gandiva_extension_tests") +target_compile_definitions(gandiva-projector-test + PRIVATE -DGANDIVA_EXTENSION_TEST_DIR="${CMAKE_BINARY_DIR}/gandiva_extension_tests") + + diff --git a/cpp/src/gandiva/extension_tests/complex_type_registry/registry.json b/cpp/src/gandiva/extension_tests/complex_type_registry/registry.json new file mode 100644 index 0000000000000..88a9c28e1398d --- /dev/null +++ b/cpp/src/gandiva/extension_tests/complex_type_registry/registry.json @@ -0,0 +1,30 @@ +{ + "version": "1.0", + "functions": [ + { + "name": "greet", + "aliases": [ + ], + "param_types": [ + { + "type": "timestamp", + "unit": "second" + }, + { + "type": "list", + "value_type": { + "type": "int32" + } + } + ], + "return_type": { + "type": "decimal", + "precision": 10, + "scale": 2 + }, + "result_nullable": "never", + "can_return_errors": true, + "pc_name": "greet_timestamp_list" + } + ] +} \ No newline at end of file diff --git a/cpp/src/gandiva/extension_tests/extended_funcs/multiply_by_two.cc b/cpp/src/gandiva/extension_tests/extended_funcs/multiply_by_two.cc new file mode 100644 index 
0000000000000..2d82233c1433d --- /dev/null +++ b/cpp/src/gandiva/extension_tests/extended_funcs/multiply_by_two.cc @@ -0,0 +1,5 @@ +#include "multiply_by_two.h" + +int64_t multiply_by_two_int32(int32_t value) { // returns 2 * value as a 64-bit result + return static_cast<int64_t>(value) * 2; // widen first: value * 2 can overflow int32 +} \ No newline at end of file diff --git a/cpp/src/gandiva/extension_tests/extended_funcs/multiply_by_two.h b/cpp/src/gandiva/extension_tests/extended_funcs/multiply_by_two.h new file mode 100644 index 0000000000000..1c28a8da8f9a9 --- /dev/null +++ b/cpp/src/gandiva/extension_tests/extended_funcs/multiply_by_two.h @@ -0,0 +1,7 @@ +#pragma once + +#include + +extern "C" { +int64_t multiply_by_two_int32(int32_t value); +} \ No newline at end of file diff --git a/cpp/src/gandiva/extension_tests/extended_funcs/registry.json b/cpp/src/gandiva/extension_tests/extended_funcs/registry.json new file mode 100644 index 0000000000000..7b434d9475811 --- /dev/null +++ b/cpp/src/gandiva/extension_tests/extended_funcs/registry.json @@ -0,0 +1,19 @@ +{ + "version": "1.0", + "functions": [ + { + "name": "multiply_by_two", + "aliases": [ + ], + "param_types": [ + { + "type": "int32" + } + ], + "return_type": { + "type": "int64" + }, + "pc_name": "multiply_by_two_int32" + } + ] +} \ No newline at end of file diff --git a/cpp/src/gandiva/extension_tests/multiple_functions_registry/registry.json b/cpp/src/gandiva/extension_tests/multiple_functions_registry/registry.json new file mode 100644 index 0000000000000..11ef732d9c01d --- /dev/null +++ b/cpp/src/gandiva/extension_tests/multiple_functions_registry/registry.json @@ -0,0 +1,35 @@ +{ + "version": "1.0", + "functions": [ + { + "name": "say_hello", + "aliases": [ + "hello" + ], + "param_types": [ + { + "type": "utf8" + } + ], + "return_type": { + "type": "int64" + }, + "result_nullable": "never", + "can_return_errors": true, + "pc_name": "say_hello_utf8" + }, + { + "name": "say_goodbye", + "aliases": [ + ], + "param_types": [ + ], + "return_type": { + "type": "utf8" + }, + "result_nullable": "never", + 
"can_return_errors": true, + "pc_name": "say_goodbye" + } + ] +} \ No newline at end of file diff --git a/cpp/src/gandiva/extension_tests/multiple_registries/reg_1.json b/cpp/src/gandiva/extension_tests/multiple_registries/reg_1.json new file mode 100644 index 0000000000000..6c34e588c56c8 --- /dev/null +++ b/cpp/src/gandiva/extension_tests/multiple_registries/reg_1.json @@ -0,0 +1,21 @@ +{ + "version": "1.0", + "functions": [ + { + "name": "say_hello", + "aliases": [ + "hello" + ], + "param_types": [ + { + "type": "utf8" + } + ], + "return_type": { + "type": "int64" + }, + "result_nullable": "never", + "pc_name": "say_hello_utf8" + } + ] +} \ No newline at end of file diff --git a/cpp/src/gandiva/extension_tests/multiple_registries/reg_2.json b/cpp/src/gandiva/extension_tests/multiple_registries/reg_2.json new file mode 100644 index 0000000000000..70d9e48bf6379 --- /dev/null +++ b/cpp/src/gandiva/extension_tests/multiple_registries/reg_2.json @@ -0,0 +1,18 @@ +{ + "version": "1.0", + "functions": [ + { + "name": "say_goodbye", + "aliases": [ + "goodbye" + ], + "param_types": [ + ], + "return_type": { + "type": "utf8" + }, + "result_nullable": "ifnull", + "pc_name": "say_goodbye" + } + ] +} \ No newline at end of file diff --git a/cpp/src/gandiva/extension_tests/no_name_func_registry/registry.json b/cpp/src/gandiva/extension_tests/no_name_func_registry/registry.json new file mode 100644 index 0000000000000..680d3486e8eab --- /dev/null +++ b/cpp/src/gandiva/extension_tests/no_name_func_registry/registry.json @@ -0,0 +1,7 @@ +{ + "version": "1.0", + "functions": [ + { + } + ] +} \ No newline at end of file diff --git a/cpp/src/gandiva/extension_tests/simple_registry/registry.json b/cpp/src/gandiva/extension_tests/simple_registry/registry.json new file mode 100644 index 0000000000000..b2e3a3a6703ef --- /dev/null +++ b/cpp/src/gandiva/extension_tests/simple_registry/registry.json @@ -0,0 +1,22 @@ +{ + "version": "1.0", + "functions": [ + { + "name": "say_hello", + 
"aliases": [ + "hello" + ], + "param_types": [ + { + "type": "utf8" + } + ], + "return_type": { + "type": "int64" + }, + "result_nullable": "never", + "can_return_errors": true, + "pc_name": "say_hello_utf8" + } + ] +} \ No newline at end of file diff --git a/cpp/src/gandiva/function_registry.cc b/cpp/src/gandiva/function_registry.cc index 67b7b404b325c..80d2bdfa9ec3f 100644 --- a/cpp/src/gandiva/function_registry.cc +++ b/cpp/src/gandiva/function_registry.cc @@ -16,8 +16,10 @@ // under the License. #include "gandiva/function_registry.h" +#include "arrow/util/logging.h" #include "gandiva/function_registry_arithmetic.h" #include "gandiva/function_registry_datetime.h" +#include "gandiva/function_registry_external.h" #include "gandiva/function_registry_hash.h" #include "gandiva/function_registry_math_ops.h" #include "gandiva/function_registry_string.h" @@ -45,6 +47,21 @@ std::vector FunctionRegistry::pc_registry_; SignatureMap FunctionRegistry::pc_registry_map_ = InitPCMap(); +std::vector LoadExternalFunctionRegistry() { + std::string ext_dir; + const char* ext_dir_env = std::getenv("GANDIVA_EXTENSION_DIR"); + + auto result = GetExternalFunctionRegistry(ext_dir_env ? 
ext_dir_env : ""); + std::vector funcs; + if (result.ok()) { + funcs = result.ValueUnsafe(); + } else { + ARROW_LOG(WARNING) << "Failed to load external function registry: " + << result.status().message(); + } + return funcs; +} + SignatureMap FunctionRegistry::InitPCMap() { SignatureMap map; @@ -64,6 +81,10 @@ SignatureMap FunctionRegistry::InitPCMap() { auto v6 = GetDateTimeArithmeticFunctionRegistry(); pc_registry_.insert(std::end(pc_registry_), v6.begin(), v6.end()); + + auto v7 = LoadExternalFunctionRegistry(); + pc_registry_.insert(std::end(pc_registry_), v7.begin(), v7.end()); + for (auto& elem : pc_registry_) { for (auto& func_signature : elem.signatures()) { map.insert(std::make_pair(&(func_signature), &elem)); diff --git a/cpp/src/gandiva/function_registry.h b/cpp/src/gandiva/function_registry.h index d9256326019c6..c2765704655f4 100644 --- a/cpp/src/gandiva/function_registry.h +++ b/cpp/src/gandiva/function_registry.h @@ -42,6 +42,10 @@ class GANDIVA_EXPORT FunctionRegistry { static std::vector pc_registry_; static SignatureMap pc_registry_map_; + + FRIEND_TEST(TestFunctionRegistry, LookupExternalFuncs); + FRIEND_TEST(TestFunctionRegistry, LookupMultipleFuncs); + FRIEND_TEST(TestProjector, TestExtendedFunctions); }; } // namespace gandiva diff --git a/cpp/src/gandiva/function_registry_external.cc b/cpp/src/gandiva/function_registry_external.cc new file mode 100644 index 0000000000000..42ac3d153f2d9 --- /dev/null +++ b/cpp/src/gandiva/function_registry_external.cc @@ -0,0 +1,301 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. 
You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +#include +#include +#include +#include +#include +#include + +namespace gandiva { +namespace rj = rapidjson; + +class JsonRegistryParser { + public: + static arrow::Result> Parse(std::string_view json) { + rj::Document doc; + doc.Parse(reinterpret_cast(json.data()), + static_cast(json.size())); + + if (doc.HasParseError()) { + return Status::Invalid("Json parse error (offset ", doc.GetErrorOffset(), + "): ", doc.GetParseError()); + } + if (!doc.IsObject()) { + return Status::TypeError("Not a json object"); + } + const rapidjson::Value& functions = doc["functions"]; + if (!functions.IsArray()) { + return Status::TypeError("'functions' property is expected to be a JSON array"); + } + + std::vector funcs; + for (const auto& func : functions.GetArray()) { + ARROW_ASSIGN_OR_RAISE(auto name, GetString(func, "name")); + ARROW_ASSIGN_OR_RAISE(auto aliases, GetAliases(func)); + ARROW_ASSIGN_OR_RAISE(DataTypeVector param_types, ParseParamTypes(func)); + ARROW_ASSIGN_OR_RAISE(auto ret_type, ParseDataType(func["return_type"])); + ARROW_ASSIGN_OR_RAISE(ResultNullableType result_nullable_type, + ParseResultNullable(func)); + ARROW_ASSIGN_OR_RAISE(auto pc_name, GetString(func, "pc_name")); + int32_t flags = GetFlags(func); + funcs.emplace_back(name, aliases, param_types, ret_type, result_nullable_type, + pc_name, flags); + } + return funcs; + } + + private: + static arrow::Result GetString(const rj::GenericValue>& func, + const std::string& key) { + if (!func.HasMember(key.c_str())) { + return Status::TypeError("'" + key + "'" + " property is missing"); + } + 
if (!func[key.c_str()].IsString()) { + return Status::TypeError("'" + key + "'" + " property should be a string"); + } + return func[key.c_str()].GetString(); + } + + static arrow::Result ParseResultNullable( + const rj::GenericValue>& func) { + std::string nullable; + if (!func.HasMember("result_nullable")) { + nullable = "ifnull"; + } else { + if (!func["result_nullable"].IsString()) { + return Status::TypeError("result_nullable property should be a string"); + } + nullable = func["result_nullable"].GetString(); + } + if (nullable == "ifnull") { + return ResultNullableType::kResultNullIfNull; + } else if (nullable == "never") { + return ResultNullableType::kResultNullNever; + } else if (nullable == "internal") { + return ResultNullableType::kResultNullInternal; + } else { + return Status::TypeError("Unsupported result_nullable value: " + nullable + + ". Only ifnull/never/internal are supported"); + } + } + static int32_t GetFlags(const rj::GenericValue>& func) { + int32_t flags = 0; + for (auto const& [flag_name, flag_value] : + {std::make_pair("needs_context", NativeFunction::kNeedsContext), + std::make_pair("needs_function_holder", NativeFunction::kNeedsFunctionHolder), + std::make_pair("can_return_errors", NativeFunction::kCanReturnErrors)}) { + if (func.HasMember(flag_name) && func[flag_name].GetBool()) { + flags |= flag_value; + } + } + return flags; + } + + static arrow::Result> GetAliases( + const rj::GenericValue>& func) { + std::vector aliases; + if (!func.HasMember("aliases")) { + return aliases; + } + if (func["aliases"].IsArray()) { + for (const auto& alias : func["aliases"].GetArray()) { + aliases.emplace_back(alias.GetString()); + } + } else { + return Status::TypeError("'aliases' property is expected to be a JSON array"); + } + return aliases; + } + + static arrow::Result ParseParamTypes( + const rj::GenericValue>& func) { + arrow::DataTypeVector param_types; + if (!func.HasMember("param_types")) { + return param_types; + } + if 
(!func["param_types"].IsArray()) { + return Status::TypeError("'param_types' property is expected to be a JSON array"); + } + for (const auto& param_type : func["param_types"].GetArray()) { + ARROW_ASSIGN_OR_RAISE(auto type, ParseDataType(param_type)) + param_types.push_back(type); + } + return param_types; + } + + static arrow::Result> ParseTimestampDataType( + const rj::GenericValue>& data_type) { + if (!data_type.HasMember("unit")) { + return Status::TypeError("'unit' property is required for timestamp data type"); + } + const std::string unit_name = data_type["unit"].GetString(); + arrow::TimeUnit::type unit; + if (unit_name == "second") { + unit = arrow::TimeUnit::SECOND; + } else if (unit_name == "milli") { + unit = arrow::TimeUnit::MILLI; + } else if (unit_name == "micro") { + unit = arrow::TimeUnit::MICRO; + } else if (unit_name == "nano") { + unit = arrow::TimeUnit::NANO; + } else { + return Status::TypeError("Unsupported timestamp unit name: ", unit_name); + } + return arrow::timestamp(unit); + } + + static arrow::Result> ParseDecimalDataType( + const rj::GenericValue>& data_type) { + if (!data_type.HasMember("precision") || !data_type["precision"].IsInt()) { + return Status::TypeError( + "'precision' property is required for decimal data type and should be an " + "integer"); + } + if (!data_type.HasMember("scale") || !data_type["scale"].IsInt()) { + return Status::TypeError( + "'scale' property is required for decimal data type and should be an integer"); + } + auto precision = data_type["precision"].GetInt(); + auto scale = data_type["scale"].GetInt(); + const std::string type_name = data_type["type"].GetString(); + if (type_name == "decimal128") { + return arrow::decimal128(precision, scale); + } else if (type_name == "decimal256") { + return arrow::decimal256(precision, scale); + } + return arrow::decimal(precision, scale); + } + + static arrow::Result> ParseListDataType( + const rj::GenericValue>& data_type) { + if (!data_type.HasMember("value_type") 
|| !data_type["value_type"].IsObject()) { + return Status::TypeError( + "'value_type' property is required for list data type and should be an object"); + } + ARROW_ASSIGN_OR_RAISE(auto value_type, ParseDataType(data_type["value_type"])); + return arrow::list(value_type); + } + + static arrow::Result> ParseComplexDataType( + const rj::GenericValue>& data_type) { + static const std::unordered_map< + std::string, std::function>( + const rj::GenericValue>&)>> + complex_type_map = {{"timestamp", ParseTimestampDataType}, + {"decimal", ParseDecimalDataType}, + {"decimal128", ParseDecimalDataType}, + {"decimal256", ParseDecimalDataType}, + {"list", ParseListDataType}}; + const std::string type_name = data_type["type"].GetString(); + auto it = complex_type_map.find(type_name); + if (it == complex_type_map.end()) { + return Status::TypeError("Unsupported complex type name: ", type_name); + } + return it->second(data_type); + } + + static arrow::Result> ParseDataType( + const rj::GenericValue>& data_type) { + if (!data_type.HasMember("type")) { + return Status::TypeError("'type' property is required for data type"); + } + auto type_name = data_type["type"].GetString(); + auto type = ParseDataTypeFromName(type_name); + if (type == nullptr) { + return ParseComplexDataType(data_type); + } else { + return type; + } + } + + static std::shared_ptr ParseDataTypeFromName( + const std::string& type_name) { + static const std::unordered_map> + simple_type_map = {{"null", arrow::null()}, + {"boolean", arrow::boolean()}, + {"uint8", arrow::uint8()}, + {"int8", arrow::int8()}, + {"uint16", arrow::uint16()}, + {"int16", arrow::int16()}, + {"uint32", arrow::uint32()}, + {"int32", arrow::int32()}, + {"uint64", arrow::uint64()}, + {"int64", arrow::int64()}, + {"float16", arrow::float16()}, + {"float32", arrow::float32()}, + {"float64", arrow::float64()}, + {"utf8", arrow::utf8()}, + {"large_utf8", arrow::large_utf8()}, + {"binary", arrow::binary()}, + {"large_binary", arrow::large_binary()}, 
+ {"date32", arrow::date32()}, + {"date64", arrow::date64()}, + {"day_time_interval", arrow::day_time_interval()}, + {"month_interval", arrow::month_interval()}}; + + auto it = simple_type_map.find(type_name); + return it != simple_type_map.end() ? it->second : nullptr; + } +}; + +// list all directory entries under registry_dir, sorted by path +std::vector ListAllFiles(const std::string& registry_dir) { + if (registry_dir.empty()) { + return {}; + } + std::vector filenames; + for (const auto& entry : std::filesystem::directory_iterator(registry_dir)) { + filenames.push_back(entry.path()); + } + + std::sort(filenames.begin(), filenames.end()); + return filenames; +} + +arrow::Result> GetExternalFunctionRegistry( + const std::string& registry_dir) { + // 1) load all .json files under registry_dir + // 2) parse each file and add to registry, the json file format is: + // {"version": "1.0", "functions": [{"name": "func1", "aliases": [], "param_types": + // [{"type": "utf8"}], "return_type": {"type": "int64"}, "result_nullable": "never", + // "pc_name": "func1", "can_return_errors": true}]} + + std::vector registry; + auto filenames = ListAllFiles(registry_dir); + for (const auto& entry : filenames) { + if (entry.extension() == ".json") { + std::ifstream file(entry); + std::string content((std::istreambuf_iterator(file)), + std::istreambuf_iterator()); + + auto funcs_result = JsonRegistryParser::Parse(content); + if (!funcs_result.ok()) { + return funcs_result.status().WithMessage( + "Failed to parse json file: ", entry.string(), + ". 
Error: ", funcs_result.status().message()); + } + auto funcs = funcs_result.ValueUnsafe(); + // insert all funcs into registry + registry.insert(registry.end(), funcs.begin(), funcs.end()); + } + } + + return registry; +} +} // namespace gandiva \ No newline at end of file diff --git a/cpp/src/gandiva/function_registry_external.h b/cpp/src/gandiva/function_registry_external.h new file mode 100644 index 0000000000000..d3b7b41957d54 --- /dev/null +++ b/cpp/src/gandiva/function_registry_external.h @@ -0,0 +1,27 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +#pragma once + +#include +#include +#include "gandiva/native_function.h" + +namespace gandiva { +arrow::Result> GetExternalFunctionRegistry( + const std::string& registry_dir = ""); +} \ No newline at end of file diff --git a/cpp/src/gandiva/function_registry_external_test.cc b/cpp/src/gandiva/function_registry_external_test.cc new file mode 100644 index 0000000000000..b7c8a30044f74 --- /dev/null +++ b/cpp/src/gandiva/function_registry_external_test.cc @@ -0,0 +1,89 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. 
See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +#include "gandiva/function_registry_external.h" +#include +#include +#include "arrow/testing/gtest_util.h" +#include "gandiva/tests/test_util.h" + +namespace gandiva { +class TestExternalFunctionRegistry : public ::testing::Test { + public: + arrow::Result> GetFuncs(const std::string& registry_dir) { + std::filesystem::path base(GANDIVA_EXTENSION_TEST_DIR); + return GetExternalFunctionRegistry(base / registry_dir); + } +}; + +TEST_F(TestExternalFunctionRegistry, EmptyDir) { + ASSERT_OK_AND_ASSIGN(auto funcs, GetExternalFunctionRegistry("")); + ASSERT_TRUE(funcs.empty()); +} + +TEST_F(TestExternalFunctionRegistry, FunctionWithoutName) { + auto funcs = GetFuncs("no_name_func_registry"); + ASSERT_TRUE(!funcs.ok()); +} + +TEST_F(TestExternalFunctionRegistry, DirWithJsonRegistry) { + ASSERT_OK_AND_ASSIGN(auto funcs, GetFuncs("simple_registry")); + ASSERT_EQ(funcs.size(), 1); + ASSERT_EQ(funcs[0].result_nullable_type(), ResultNullableType::kResultNullNever); + ASSERT_EQ(funcs[0].CanReturnErrors(), true); + ASSERT_EQ(funcs[0].pc_name(), "say_hello_utf8"); +} + +TEST_F(TestExternalFunctionRegistry, DirWithMultiJsonRegistry) { + ASSERT_OK_AND_ASSIGN(auto funcs, GetFuncs("multiple_registries")); + ASSERT_EQ(funcs.size(), 2); + auto sigs_0 = funcs[0].signatures(); + 
ASSERT_EQ(sigs_0.size(), 2); + ASSERT_EQ(sigs_0[0].param_types().size(), 1); + ASSERT_EQ(sigs_0[0].param_types()[0]->id(), arrow::Type::STRING); + ASSERT_EQ(sigs_0[0].ret_type()->id(), arrow::Type::INT64); + ASSERT_EQ(funcs[0].pc_name(), "say_hello_utf8"); + + ASSERT_EQ(funcs[1].result_nullable_type(), ResultNullableType::kResultNullIfNull); + auto sigs_1 = funcs[1].signatures(); + ASSERT_EQ(sigs_1.size(), 2); + ASSERT_TRUE(sigs_1[0].param_types().empty()); + ASSERT_EQ(sigs_1[0].ret_type()->id(), arrow::Type::STRING); + ASSERT_EQ(funcs[1].pc_name(), "say_goodbye"); +} + +TEST_F(TestExternalFunctionRegistry, DirWithMultiFunctionRegistry) { + ASSERT_OK_AND_ASSIGN(auto funcs, GetFuncs("multiple_functions_registry")); + ASSERT_EQ(funcs.size(), 2); + ASSERT_EQ(funcs[0].pc_name(), "say_hello_utf8"); + ASSERT_EQ(funcs[1].pc_name(), "say_goodbye"); +} + +TEST_F(TestExternalFunctionRegistry, DirWithComplexTypeRegistry) { + ASSERT_OK_AND_ASSIGN(auto funcs, GetFuncs("complex_type_registry")); + ASSERT_EQ(funcs.size(), 1); + ASSERT_EQ(funcs[0].pc_name(), "greet_timestamp_list"); + auto sigs = funcs[0].signatures(); + ASSERT_EQ(sigs.size(), 1); + ASSERT_EQ(sigs[0].param_types().size(), 2); + ASSERT_EQ(sigs[0].param_types()[0]->id(), arrow::Type::TIMESTAMP); + ASSERT_EQ(sigs[0].param_types()[1]->id(), arrow::Type::LIST); + ASSERT_EQ(sigs[0].param_types()[1]->ToString(), "list"); + ASSERT_EQ(sigs[0].ret_type()->id(), arrow::Type::DECIMAL); + ASSERT_EQ(sigs[0].ret_type()->ToString(), "decimal128(10, 2)"); +} +} // namespace gandiva \ No newline at end of file diff --git a/cpp/src/gandiva/function_registry_test.cc b/cpp/src/gandiva/function_registry_test.cc index e3c1e85f79cba..d5ca5cdbfdd3d 100644 --- a/cpp/src/gandiva/function_registry_test.cc +++ b/cpp/src/gandiva/function_registry_test.cc @@ -16,12 +16,13 @@ // under the License. 
#include "gandiva/function_registry.h" - #include #include #include +#include #include #include +#include "gandiva/tests/test_util.h" namespace gandiva { @@ -93,4 +94,42 @@ TEST_F(TestFunctionRegistry, TestNoDuplicates) { "different precompiled functions:\n" << stream.str(); }
+
+// Resolves multiply_by_two(int32) -> int64 while GANDIVA_EXTENSION_DIR points at the
+// extended_funcs test directory and checks the precompiled function name.
+TEST_F(TestFunctionRegistry, LookupExternalFuncs) {
+  ExtensionDirSetter ext_dir_setter("extended_funcs", []() {
+    FunctionRegistry::pc_registry_.clear();
+    FunctionRegistry::pc_registry_map_ = FunctionRegistry::InitPCMap();
+  });
+
+  FunctionSignature multiply_by_two_int32("multiply_by_two", {arrow::int32()},
+                                          arrow::int64());
+  auto func = registry_.LookupSignature(multiply_by_two_int32);
+  // Fatal assertion: func is dereferenced on the next line.
+  ASSERT_NE(func, nullptr);
+  EXPECT_EQ(func->pc_name(), "multiply_by_two_int32");
+}
+
+// Resolves two functions while GANDIVA_EXTENSION_DIR points at the
+// multiple_functions_registry test directory.
+TEST_F(TestFunctionRegistry, LookupMultipleFuncs) {
+  ExtensionDirSetter ext_dir_setter("multiple_functions_registry", []() {
+    FunctionRegistry::pc_registry_.clear();
+    FunctionRegistry::pc_registry_map_ = FunctionRegistry::InitPCMap();
+  });
+
+  FunctionSignature say_hello_utf8("say_hello", {arrow::utf8()}, arrow::int64());
+  auto say_hello_func = registry_.LookupSignature(say_hello_utf8);
+  // Fatal assertions: each looked-up pointer is dereferenced right after its check.
+  ASSERT_NE(say_hello_func, nullptr);
+  EXPECT_EQ(say_hello_func->signatures().size(), 2);
+  EXPECT_EQ(say_hello_func->pc_name(), "say_hello_utf8");
+
+  FunctionSignature say_goodbye("say_goodbye", {}, arrow::utf8());
+  auto say_goodbye_func = registry_.LookupSignature(say_goodbye);
+  ASSERT_NE(say_goodbye_func, nullptr);
+  EXPECT_EQ(say_goodbye_func->signatures().size(), 1);
+  EXPECT_EQ(say_goodbye_func->pc_name(), "say_goodbye");
+}
} // namespace gandiva diff --git a/cpp/src/gandiva/llvm_generator.h b/cpp/src/gandiva/llvm_generator.h index 04f9b854b1d29..c996361255edc 100644 --- a/cpp/src/gandiva/llvm_generator.h +++ b/cpp/src/gandiva/llvm_generator.h @@ -85,6 +85,7 @@ class GANDIVA_EXPORT LLVMGenerator { explicit LLVMGenerator(bool cached); FRIEND_TEST(TestLLVMGenerator, VerifyPCFunctions); + FRIEND_TEST(TestLLVMGenerator,
VerifyExtendedPCFunctions); FRIEND_TEST(TestLLVMGenerator, TestAdd); FRIEND_TEST(TestLLVMGenerator, TestNullInternal); diff --git a/cpp/src/gandiva/llvm_generator_test.cc b/cpp/src/gandiva/llvm_generator_test.cc index 028893b0b4594..24c865538074b 100644 --- a/cpp/src/gandiva/llvm_generator_test.cc +++ b/cpp/src/gandiva/llvm_generator_test.cc @@ -38,6 +38,11 @@ class TestLLVMGenerator : public ::testing::Test { FunctionRegistry registry_; }; +class TestLLVMGenerator2 : public ::testing::Test { + protected: + FunctionRegistry registry_; +}; + // Verify that a valid pc function exists for every function in the registry. TEST_F(TestLLVMGenerator, VerifyPCFunctions) { std::unique_ptr generator; @@ -50,6 +55,16 @@ TEST_F(TestLLVMGenerator, VerifyPCFunctions) { } } +TEST_F(TestLLVMGenerator, VerifyExtendedPCFunctions) { + ExtensionDirSetter ext_dir_setter("extended_funcs"); + std::unique_ptr generator; + ASSERT_OK(LLVMGenerator::Make(TestConfiguration(), false, &generator)); + + llvm::Module* module = generator->module(); + ASSERT_OK(generator->engine_->LoadFunctionIRs()); + EXPECT_NE(module->getFunction("multiply_by_two_int32"), nullptr); +} + TEST_F(TestLLVMGenerator, TestAdd) { // Setup LLVM generator to do an arithmetic add of two vectors std::unique_ptr generator; diff --git a/cpp/src/gandiva/precompiled/CMakeLists.txt b/cpp/src/gandiva/precompiled/CMakeLists.txt index 4ca5cc655b2a7..54929db19ddf6 100644 --- a/cpp/src/gandiva/precompiled/CMakeLists.txt +++ b/cpp/src/gandiva/precompiled/CMakeLists.txt @@ -17,6 +17,8 @@ project(gandiva) +include(../cmake/GenerateBitcode.cmake) + set(PRECOMPILED_SRCS arithmetic_ops.cc bitmap.cc @@ -30,68 +32,7 @@ set(PRECOMPILED_SRCS timestamp_arithmetic.cc ../../arrow/util/basic_decimal.cc) -if(MSVC) - # clang pretends to be a particular version of MSVC. Thestandard - # library uses C++14 features, so we have to use that -std version - # to get the IR compilation to work. 
- # See https://cmake.org/cmake/help/latest/variable/MSVC_VERSION.html - # for MSVC_VERSION and Visual Studio version. - if(MSVC_VERSION LESS 1930) - set(FMS_COMPATIBILITY 19.20) - elseif(MSVC_VERSION LESS 1920) - set(FMS_COMPATIBILITY 19.10) - else() - message(FATAL_ERROR "Unsupported MSVC_VERSION=${MSVC_VERSION}") - endif() - set(PLATFORM_CLANG_OPTIONS -std=c++17 -fms-compatibility - -fms-compatibility-version=${FMS_COMPATIBILITY}) -else() - set(PLATFORM_CLANG_OPTIONS -std=c++17) -endif() - -# Create bitcode for each of the source files. -foreach(SRC_FILE ${PRECOMPILED_SRCS}) - get_filename_component(SRC_BASE ${SRC_FILE} NAME_WE) - get_filename_component(ABSOLUTE_SRC ${SRC_FILE} ABSOLUTE) - set(BC_FILE ${CMAKE_CURRENT_BINARY_DIR}/${SRC_BASE}.bc) - set(PRECOMPILE_COMMAND) - if(CMAKE_OSX_SYSROOT) - list(APPEND - PRECOMPILE_COMMAND - ${CMAKE_COMMAND} - -E - env - SDKROOT=${CMAKE_OSX_SYSROOT}) - endif() - list(APPEND - PRECOMPILE_COMMAND - ${CLANG_EXECUTABLE} - ${PLATFORM_CLANG_OPTIONS} - -DGANDIVA_IR - -DNDEBUG # DCHECK macros not implemented in precompiled code - -DARROW_STATIC # Do not set __declspec(dllimport) on MSVC on Arrow symbols - -DGANDIVA_STATIC # Do not set __declspec(dllimport) on MSVC on Gandiva symbols - -fno-use-cxa-atexit # Workaround for unresolved __dso_handle - -emit-llvm - -O3 - -c - ${ABSOLUTE_SRC} - -o - ${BC_FILE} - ${ARROW_GANDIVA_PC_CXX_FLAGS} - -I${CMAKE_SOURCE_DIR}/src - -I${ARROW_BINARY_DIR}/src) - - if(NOT ARROW_USE_NATIVE_INT128) - foreach(boost_include_dir ${Boost_INCLUDE_DIRS}) - list(APPEND PRECOMPILE_COMMAND -I${boost_include_dir}) - endforeach() - endif() - add_custom_command(OUTPUT ${BC_FILE} - COMMAND ${PRECOMPILE_COMMAND} - DEPENDS ${SRC_FILE}) - list(APPEND BC_FILES ${BC_FILE}) -endforeach() +generate_bitcode("${PRECOMPILED_SRCS}" "" BC_FILES) # link all of the bitcode files into a single bitcode file. 
add_custom_command(OUTPUT ${GANDIVA_PRECOMPILED_BC_PATH} @@ -144,5 +85,6 @@ if(ARROW_BUILD_TESTS) set_property(TEST gandiva-precompiled-test APPEND PROPERTY LABELS "unittest;gandiva-tests") + add_dependencies(gandiva-tests gandiva-precompiled-test) endif() diff --git a/cpp/src/gandiva/tests/projector_test.cc b/cpp/src/gandiva/tests/projector_test.cc index 462fae64393fd..dd9107e31d727 100644 --- a/cpp/src/gandiva/tests/projector_test.cc +++ b/cpp/src/gandiva/tests/projector_test.cc @@ -26,6 +26,7 @@ #include #include "arrow/memory_pool.h" +#include "gandiva/function_registry.h" #include "gandiva/literal_holder.h" #include "gandiva/node.h" #include "gandiva/tests/test_util.h" @@ -3582,4 +3583,34 @@ TEST_F(TestProjector, TestSqrtFloat64) { EXPECT_ARROW_ARRAY_EQUALS(out, outs.at(0)); }
+// Builds and evaluates a multiply_by_two projection over an int32 column while
+// GANDIVA_EXTENSION_DIR points at the extended_funcs test directory.
+TEST_F(TestProjector, TestExtendedFunctions) {
+  ExtensionDirSetter ext_dir_setter("extended_funcs", []() {
+    FunctionRegistry::pc_registry_.clear();
+    FunctionRegistry::pc_registry_map_ = FunctionRegistry::InitPCMap();
+  });
+
+  auto in_field = field("in", arrow::int32());
+  auto schema = arrow::schema({in_field});
+  auto out_field = field("out", arrow::int64());
+  // the multiply_by_two function is only available in the extended_funcs dir's bitcode
+  auto multiply =
+      TreeExprBuilder::MakeExpression("multiply_by_two", {in_field}, out_field);
+
+  std::shared_ptr<Projector> projector;
+  // Fatal on failure: projector is dereferenced below.
+  ASSERT_OK(Projector::Make(schema, {multiply}, TestConfiguration(), &projector));
+
+  int num_records = 4;
+  auto array = MakeArrowArrayInt32({1, 2, 3, 4}, {true, true, true, true});
+  auto in_batch = arrow::RecordBatch::Make(schema, num_records, {array});
+  auto out = MakeArrowArrayInt64({2, 4, 6, 8}, {true, true, true, true});
+
+  arrow::ArrayVector outs;
+  // Fatal on failure: outs.at(0) is read below.
+  ASSERT_OK(projector->Evaluate(*in_batch, pool_, &outs));
+  EXPECT_ARROW_ARRAY_EQUALS(out, outs.at(0));
+}
+
} // namespace gandiva diff --git a/cpp/src/gandiva/tests/test_util.h b/cpp/src/gandiva/tests/test_util.h index 99df90769e0ad..373278515d38c
100644 --- a/cpp/src/gandiva/tests/test_util.h +++ b/cpp/src/gandiva/tests/test_util.h @@ -20,6 +20,7 @@ #include #include +#include #include "arrow/testing/builder.h" #include "arrow/testing/gtest_util.h" #include "gandiva/arrow.h" @@ -101,4 +102,43 @@ static inline std::shared_ptr TestConfiguration() { return builder.DefaultConfiguration(); }
+#ifndef GANDIVA_EXTENSION_TEST_DIR
+#define GANDIVA_EXTENSION_TEST_DIR "../gandiva_extension_tests"
+#endif
+
+// Test-only scope guard. On construction it sets GANDIVA_EXTENSION_DIR to
+// GANDIVA_EXTENSION_TEST_DIR/<ext_dir>; on destruction it unsets the variable. If an
+// env_reloader callback is given, it is invoked once after each of those two steps.
+//
+// NOTE(review): setenv/unsetenv are POSIX calls and path::c_str() yields wchar_t on
+// Windows -- confirm these tests are never built with MSVC. The definition also needs
+// <cstdlib>, <filesystem>, <functional>, <optional> and <string>; verify the includes.
+struct ExtensionDirSetter {
+  explicit ExtensionDirSetter(
+      const std::string& ext_dir,
+      const std::optional<std::function<void()>>& env_reloader = std::nullopt)
+      : env_reloader_(env_reloader) {
+    std::filesystem::path base(GANDIVA_EXTENSION_TEST_DIR);
+    setenv("GANDIVA_EXTENSION_DIR", (base / ext_dir).c_str(), 1);
+    if (env_reloader_.has_value()) {
+      env_reloader_.value()();
+    }
+  }
+  ~ExtensionDirSetter() {
+    unsetenv("GANDIVA_EXTENSION_DIR");
+    if (env_reloader_.has_value()) {
+      env_reloader_.value()();
+    }
+  }
+
+  // Copying would unset the variable (and run the callback) once per copy.
+  ExtensionDirSetter(const ExtensionDirSetter&) = delete;
+  ExtensionDirSetter& operator=(const ExtensionDirSetter&) = delete;
+
+  // Held by value. The constructor argument is a temporary at every call site that
+  // passes a lambda or relies on the default, and that temporary is destroyed long
+  // before the destructor runs, so a reference member here would dangle.
+  const std::optional<std::function<void()>> env_reloader_;
+};
+
} // namespace gandiva