diff --git a/cpp/src/gandiva/CMakeLists.txt b/cpp/src/gandiva/CMakeLists.txt index 3448d516768bb..3f038f54a7b27 100644 --- a/cpp/src/gandiva/CMakeLists.txt +++ b/cpp/src/gandiva/CMakeLists.txt @@ -62,7 +62,9 @@ set(SRC_FILES expression_registry.cc exported_funcs_registry.cc exported_funcs.cc + external_c_functions.cc filter.cc + function_holder_maker_registry.cc function_ir_builder.cc function_registry.cc function_registry_arithmetic.cc diff --git a/cpp/src/gandiva/cast_time.cc b/cpp/src/gandiva/cast_time.cc index 843ce01f89d57..eeb2ea3fdd88f 100644 --- a/cpp/src/gandiva/cast_time.cc +++ b/cpp/src/gandiva/cast_time.cc @@ -29,7 +29,7 @@ namespace gandiva { -void ExportedTimeFunctions::AddMappings(Engine* engine) const { +arrow::Status ExportedTimeFunctions::AddMappings(Engine* engine) const { std::vector args; auto types = engine->types(); @@ -42,6 +42,7 @@ void ExportedTimeFunctions::AddMappings(Engine* engine) const { engine->AddGlobalMappingForFunc("gdv_fn_time_with_zone", types->i32_type() /*return_type*/, args, reinterpret_cast(gdv_fn_time_with_zone)); + return arrow::Status::OK(); } } // namespace gandiva diff --git a/cpp/src/gandiva/context_helper.cc b/cpp/src/gandiva/context_helper.cc index 224bfd8f56cd3..03bbe1b7a67d9 100644 --- a/cpp/src/gandiva/context_helper.cc +++ b/cpp/src/gandiva/context_helper.cc @@ -25,7 +25,7 @@ namespace gandiva { -void ExportedContextFunctions::AddMappings(Engine* engine) const { +arrow::Status ExportedContextFunctions::AddMappings(Engine* engine) const { std::vector args; auto types = engine->types(); @@ -50,6 +50,7 @@ void ExportedContextFunctions::AddMappings(Engine* engine) const { engine->AddGlobalMappingForFunc("gdv_fn_context_arena_reset", types->void_type(), args, reinterpret_cast(gdv_fn_context_arena_reset)); + return arrow::Status::OK(); } } // namespace gandiva diff --git a/cpp/src/gandiva/decimal_xlarge.cc b/cpp/src/gandiva/decimal_xlarge.cc index caebd8b09e63a..21212422f3d69 100644 --- a/cpp/src/gandiva/decimal_xlarge.cc +++ b/cpp/src/gandiva/decimal_xlarge.cc @@ -38,7 +38,7 @@ namespace gandiva { -void ExportedDecimalFunctions::AddMappings(Engine* engine) const { +arrow::Status ExportedDecimalFunctions::AddMappings(Engine* engine) const { std::vector args; auto types = engine->types(); @@ -93,6 +93,7 @@ void ExportedDecimalFunctions::AddMappings(Engine* engine) const { engine->AddGlobalMappingForFunc("gdv_xlarge_compare", types->i32_type() /*return_type*/, args, reinterpret_cast(gdv_xlarge_compare)); + return arrow::Status::OK(); } } // namespace gandiva diff --git a/cpp/src/gandiva/engine.cc b/cpp/src/gandiva/engine.cc index 5ae1d76876148..1cea1fd2cbf30 100644 --- a/cpp/src/gandiva/engine.cc +++ b/cpp/src/gandiva/engine.cc @@ -147,7 +147,7 @@ Engine::Engine(const std::shared_ptr& conf, Status Engine::Init() { std::call_once(register_exported_funcs_flag, gandiva::RegisterExportedFuncs); // Add mappings for global functions that can be accessed from LLVM/IR module. - AddGlobalMappings(); + ARROW_RETURN_NOT_OK(AddGlobalMappings()); return Status::OK(); } @@ -447,7 +447,11 @@ void Engine::AddGlobalMappingForFunc(const std::string& name, llvm::Type* ret_ty execution_engine_->addGlobalMapping(fn, function_ptr); } -void Engine::AddGlobalMappings() { ExportedFuncsRegistry::AddMappings(this); } +arrow::Status Engine::AddGlobalMappings() { + ARROW_RETURN_NOT_OK(ExportedFuncsRegistry::AddMappings(this)); + ExternalCFunctions c_funcs(function_registry_); + return c_funcs.AddMappings(this); +} std::string Engine::DumpIR() { std::string ir; diff --git a/cpp/src/gandiva/engine.h b/cpp/src/gandiva/engine.h index 566977dc4adad..df2d8b36d9260 100644 --- a/cpp/src/gandiva/engine.h +++ b/cpp/src/gandiva/engine.h @@ -97,7 +97,7 @@ class GANDIVA_EXPORT Engine { Status LoadExternalPreCompiledIR(); // Create and add mappings for cpp functions that can be accessed from LLVM. - void AddGlobalMappings(); + arrow::Status AddGlobalMappings(); // Remove unused functions to reduce compile time. Status RemoveUnusedFunctions(); diff --git a/cpp/src/gandiva/exported_funcs.h b/cpp/src/gandiva/exported_funcs.h index 82aa020a210c0..414ec5c5bfd61 100644 --- a/cpp/src/gandiva/exported_funcs.h +++ b/cpp/src/gandiva/exported_funcs.h @@ -18,6 +18,7 @@ #pragma once #include +#include "gandiva/function_registry.h" #include "gandiva/visibility.h" namespace gandiva { @@ -29,37 +30,48 @@ class ExportedFuncsBase { public: virtual ~ExportedFuncsBase() = default; - virtual void AddMappings(Engine* engine) const = 0; + virtual arrow::Status AddMappings(Engine* engine) const = 0; }; // Class for exporting Stub functions class ExportedStubFunctions : public ExportedFuncsBase { - void AddMappings(Engine* engine) const override; + arrow::Status AddMappings(Engine* engine) const override; }; // Class for exporting Context functions class ExportedContextFunctions : public ExportedFuncsBase { - void AddMappings(Engine* engine) const override; + arrow::Status AddMappings(Engine* engine) const override; }; // Class for exporting Time functions class ExportedTimeFunctions : public ExportedFuncsBase { - void AddMappings(Engine* engine) const override; + arrow::Status AddMappings(Engine* engine) const override; }; // Class for exporting Decimal functions class ExportedDecimalFunctions : public ExportedFuncsBase { - void AddMappings(Engine* engine) const override; + arrow::Status AddMappings(Engine* engine) const override; }; // Class for exporting String functions class ExportedStringFunctions : public ExportedFuncsBase { - void AddMappings(Engine* engine) const override; + arrow::Status AddMappings(Engine* engine) const override; }; // Class for exporting Hash functions class ExportedHashFunctions : public ExportedFuncsBase { - void AddMappings(Engine* engine) const override; + arrow::Status AddMappings(Engine* engine) const override; +}; + +class ExternalCFunctions : public ExportedFuncsBase { + public: + explicit ExternalCFunctions(std::shared_ptr function_registry) + : function_registry_(std::move(function_registry)) {} + + arrow::Status AddMappings(Engine* engine) const override; + + private: + std::shared_ptr function_registry_; }; GANDIVA_EXPORT void RegisterExportedFuncs(); diff --git a/cpp/src/gandiva/exported_funcs_registry.cc b/cpp/src/gandiva/exported_funcs_registry.cc index 2c928a7a2a46f..137d29eefbea1 100644 --- a/cpp/src/gandiva/exported_funcs_registry.cc +++ b/cpp/src/gandiva/exported_funcs_registry.cc @@ -21,10 +21,11 @@ namespace gandiva { -void ExportedFuncsRegistry::AddMappings(Engine* engine) { +arrow::Status ExportedFuncsRegistry::AddMappings(Engine* engine) { for (const auto& entry : *registered()) { - entry->AddMappings(engine); + ARROW_RETURN_NOT_OK(entry->AddMappings(engine)); } + return arrow::Status::OK(); } const ExportedFuncsRegistry::list_type& ExportedFuncsRegistry::Registered() { diff --git a/cpp/src/gandiva/exported_funcs_registry.h b/cpp/src/gandiva/exported_funcs_registry.h index 08c45aec6a1ed..a34308bb96050 100644 --- a/cpp/src/gandiva/exported_funcs_registry.h +++ b/cpp/src/gandiva/exported_funcs_registry.h @@ -34,7 +34,7 @@ class GANDIVA_EXPORT ExportedFuncsRegistry { using list_type = std::vector>; // Add functions from all the registered classes to the engine. - static void AddMappings(Engine* engine); + static arrow::Status AddMappings(Engine* engine); static bool Register(std::shared_ptr entry) { registered()->emplace_back(std::move(entry)); diff --git a/cpp/src/gandiva/expr_decomposer.cc b/cpp/src/gandiva/expr_decomposer.cc index 957d9d046bd57..42566ca035159 100644 --- a/cpp/src/gandiva/expr_decomposer.cc +++ b/cpp/src/gandiva/expr_decomposer.cc @@ -25,11 +25,12 @@ #include "gandiva/annotator.h" #include "gandiva/dex.h" -#include "gandiva/function_holder_registry.h" +#include "gandiva/function_holder_maker_registry.h" #include "gandiva/function_registry.h" #include "gandiva/function_signature.h" #include "gandiva/in_holder.h" #include "gandiva/node.h" +#include "gandiva/regex_functions_holder.h" namespace gandiva { @@ -81,9 +82,10 @@ Status ExprDecomposer::Visit(const FunctionNode& in_node) { std::shared_ptr holder; int holder_idx = -1; if (native_function->NeedsFunctionHolder()) { - auto status = FunctionHolderRegistry::Make(desc->name(), node, &holder); + auto function_holder_maker_registry = registry_.GetFunctionHolderMakerRegistry(); + ARROW_ASSIGN_OR_RAISE(holder, + function_holder_maker_registry.Make(desc->name(), node)); holder_idx = annotator_.AddHolderPointer(holder.get()); - ARROW_RETURN_NOT_OK(status); } if (native_function->result_nullable_type() == kResultNullIfNull) { diff --git a/cpp/src/gandiva/external_c_functions.cc b/cpp/src/gandiva/external_c_functions.cc new file mode 100644 index 0000000000000..fcba00aed3524 --- /dev/null +++ b/cpp/src/gandiva/external_c_functions.cc @@ -0,0 +1,79 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License + +#include + +#include "gandiva/engine.h" +#include "gandiva/exported_funcs.h" + +namespace { +// calculate the number of arguments for a function signature +size_t GetNumArgs(const gandiva::FunctionSignature& sig, + const gandiva::NativeFunction& func) { + auto num_args = 0; + num_args += func.NeedsContext() ? 1 : 0; + num_args += func.NeedsFunctionHolder() ? 1 : 0; + for (auto const& arg : sig.param_types()) { + num_args += arg->id() == arrow::Type::STRING ? 2 : 1; + } + num_args += sig.ret_type()->id() == arrow::Type::STRING ? 1 : 0; + return num_args; +} + +// map from a NativeFunction's signature to the corresponding LLVM signature +arrow::Result, llvm::Type*>> MapToLLVMSignature( + const gandiva::FunctionSignature& sig, const gandiva::NativeFunction& func, + gandiva::LLVMTypes* types) { + std::vector arg_llvm_types; + arg_llvm_types.reserve(GetNumArgs(sig, func)); + + if (func.NeedsContext()) { + arg_llvm_types.push_back(types->i64_type()); + } + if (func.NeedsFunctionHolder()) { + arg_llvm_types.push_back(types->i64_type()); + } + for (auto const& arg : sig.param_types()) { + arg_llvm_types.push_back(types->IRType(arg->id())); + if (arg->id() == arrow::Type::STRING) { + // string type needs an additional length argument + arg_llvm_types.push_back(types->i32_type()); + } + } + if (sig.ret_type()->id() == arrow::Type::STRING) { + // for string output, the last arg is the output length + arg_llvm_types.push_back(types->i32_ptr_type()); + } + auto ret_llvm_type = types->IRType(sig.ret_type()->id()); + return std::make_pair(std::move(arg_llvm_types), ret_llvm_type); +} +} // namespace + +namespace gandiva { +Status ExternalCFunctions::AddMappings(Engine* engine) const { + auto const& c_funcs = function_registry_->GetCFunctions(); + auto const types = engine->types(); + for (auto& [func, func_ptr] : c_funcs) { + for (auto const& sig : func.signatures()) { + ARROW_ASSIGN_OR_RAISE(auto llvm_signature, MapToLLVMSignature(sig, func, types)); + auto& [args, ret_llvm_type] = llvm_signature; + engine->AddGlobalMappingForFunc(func.pc_name(), ret_llvm_type, args, func_ptr); + } + } + return Status::OK(); +} +} // namespace gandiva diff --git a/cpp/src/gandiva/function_holder_maker_registry.cc b/cpp/src/gandiva/function_holder_maker_registry.cc new file mode 100644 index 0000000000000..bb93402475ae8 --- /dev/null +++ b/cpp/src/gandiva/function_holder_maker_registry.cc @@ -0,0 +1,72 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +#include "gandiva/function_holder_maker_registry.h" + +#include + +#include "arrow/util/string.h" +#include "gandiva/function_holder.h" +#include "gandiva/interval_holder.h" +#include "gandiva/random_generator_holder.h" +#include "gandiva/regex_functions_holder.h" +#include "gandiva/to_date_holder.h" + +namespace gandiva { + +using arrow::internal::AsciiToLower; + +FunctionHolderMakerRegistry::FunctionHolderMakerRegistry() + : function_holder_makers_(DefaultHolderMakers()) {} + +arrow::Status FunctionHolderMakerRegistry::Register(const std::string& name, + FunctionHolderMaker holder_maker) { + function_holder_makers_.emplace(AsciiToLower(name), std::move(holder_maker)); + return arrow::Status::OK(); +} + +template +static arrow::Result HolderMaker(const FunctionNode& node) { + std::shared_ptr derived_instance; + ARROW_RETURN_NOT_OK(HolderType::Make(node, &derived_instance)); + return derived_instance; +} + +arrow::Result FunctionHolderMakerRegistry::Make( + const std::string& name, const FunctionNode& node) { + auto lowered_name = AsciiToLower(name); + auto found = function_holder_makers_.find(lowered_name); + if (found == function_holder_makers_.end()) { + return Status::Invalid("function holder not registered for function " + name); + } + + return found->second(node); +} + +FunctionHolderMakerRegistry::MakerMap FunctionHolderMakerRegistry::DefaultHolderMakers() { + static const MakerMap maker_map = { + {"like", HolderMaker}, + {"to_date", HolderMaker}, + {"random", HolderMaker}, + {"rand", HolderMaker}, + {"regexp_replace", HolderMaker}, + {"regexp_extract", HolderMaker}, + {"castintervalday", HolderMaker}, + {"castintervalyear", HolderMaker}}; + return maker_map; +} +} // namespace gandiva diff --git a/cpp/src/gandiva/function_holder_maker_registry.h b/cpp/src/gandiva/function_holder_maker_registry.h new file mode 100644 index 0000000000000..f215a4852aaee --- /dev/null +++ b/cpp/src/gandiva/function_holder_maker_registry.h @@ -0,0 +1,52 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +#pragma once + +#include +#include +#include + +#include "arrow/status.h" +#include "gandiva/function_holder.h" +#include "gandiva/node.h" + +namespace gandiva { + +/// registry of function holder makers +class FunctionHolderMakerRegistry { + public: + using FunctionHolderMaker = + std::function(const FunctionNode&)>; + + FunctionHolderMakerRegistry(); + + arrow::Status Register(const std::string& name, FunctionHolderMaker holder_maker); + + /// \brief lookup a function holder maker using the given function name, + /// and make a FunctionHolderPtr using the found holder maker and the given FunctionNode + arrow::Result Make(const std::string& name, + const FunctionNode& node); + + private: + using MakerMap = std::unordered_map; + + MakerMap function_holder_makers_; + static MakerMap DefaultHolderMakers(); +}; + +} // namespace gandiva diff --git a/cpp/src/gandiva/function_holder_registry.h b/cpp/src/gandiva/function_holder_registry.h deleted file mode 100644 index 7220f0d9d0d5e..0000000000000 --- a/cpp/src/gandiva/function_holder_registry.h +++ /dev/null @@ -1,80 +0,0 @@ -// Licensed to the Apache Software Foundation (ASF) under one -// or more contributor license agreements. See the NOTICE file -// distributed with this work for additional information -// regarding copyright ownership. The ASF licenses this file -// to you under the Apache License, Version 2.0 (the -// "License"); you may not use this file except in compliance -// with the License. You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, -// software distributed under the License is distributed on an -// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY -// KIND, either express or implied. See the License for the -// specific language governing permissions and limitations -// under the License. - -#pragma once - -#include -#include -#include -#include - -#include "arrow/status.h" -#include "gandiva/function_holder.h" -#include "gandiva/interval_holder.h" -#include "gandiva/node.h" -#include "gandiva/random_generator_holder.h" -#include "gandiva/regex_functions_holder.h" -#include "gandiva/to_date_holder.h" - -namespace gandiva { - -#define LAMBDA_MAKER(derived) \ - [](const FunctionNode& node, FunctionHolderPtr* holder) { \ - std::shared_ptr derived_instance; \ - auto status = derived::Make(node, &derived_instance); \ - if (status.ok()) { \ - *holder = derived_instance; \ - } \ - return status; \ - } - -/// Static registry of function holders. -class FunctionHolderRegistry { - public: - using maker_type = std::function; - using map_type = std::unordered_map; - - static Status Make(const std::string& name, const FunctionNode& node, - FunctionHolderPtr* holder) { - std::string data = name; - std::transform(data.begin(), data.end(), data.begin(), - [](unsigned char c) { return std::tolower(c); }); - - auto found = makers().find(data); - if (found == makers().end()) { - return Status::Invalid("function holder not registered for function " + name); - } - - return found->second(node, holder); - } - - private: - static map_type& makers() { - static map_type maker_map = {{"like", LAMBDA_MAKER(LikeHolder)}, - {"ilike", LAMBDA_MAKER(LikeHolder)}, - {"to_date", LAMBDA_MAKER(ToDateHolder)}, - {"random", LAMBDA_MAKER(RandomGeneratorHolder)}, - {"rand", LAMBDA_MAKER(RandomGeneratorHolder)}, - {"regexp_replace", LAMBDA_MAKER(ReplaceHolder)}, - {"regexp_extract", LAMBDA_MAKER(ExtractHolder)}, - {"castintervalday", LAMBDA_MAKER(IntervalDaysHolder)}, - {"castintervalyear", LAMBDA_MAKER(IntervalYearsHolder)}}; - return maker_map; - } -}; - -} // namespace gandiva diff --git a/cpp/src/gandiva/function_registry.cc b/cpp/src/gandiva/function_registry.cc index 5d676dfa8df74..2e392630ee009 100644 --- a/cpp/src/gandiva/function_registry.cc +++ b/cpp/src/gandiva/function_registry.cc @@ -64,7 +64,7 @@ FunctionRegistry::iterator FunctionRegistry::back() const { const NativeFunction* FunctionRegistry::LookupSignature( const FunctionSignature& signature) const { - auto got = pc_registry_map_.find(&signature); + auto const got = pc_registry_map_.find(&signature); return got == pc_registry_map_.end() ? nullptr : got->second; } @@ -109,11 +109,34 @@ arrow::Status FunctionRegistry::Register(const std::vector& func return Status::OK(); } +arrow::Status FunctionRegistry::Register( + NativeFunction func, void* c_function_ptr, + std::optional function_holder_maker) { + if (function_holder_maker.has_value()) { + // all signatures should have the same base name, use the first signature's base name + auto const& func_base_name = func.signatures().begin()->base_name(); + ARROW_RETURN_NOT_OK(holder_maker_registry_.Register( + func_base_name, std::move(function_holder_maker).value())); + } + c_functions_.emplace_back(func, c_function_ptr); + return FunctionRegistry::Add(std::move(func)); +} + const std::vector>& FunctionRegistry::GetBitcodeBuffers() const { return bitcode_memory_buffers_; } +const std::vector>& FunctionRegistry::GetCFunctions() + const { + return c_functions_; +} + +const FunctionHolderMakerRegistry& FunctionRegistry::GetFunctionHolderMakerRegistry() + const { + return holder_maker_registry_; +} + arrow::Result> MakeDefaultFunctionRegistry() { auto registry = std::make_shared(); for (auto const& funcs : diff --git a/cpp/src/gandiva/function_registry.h b/cpp/src/gandiva/function_registry.h index 01984961dc90f..24b64fac5f3fa 100644 --- a/cpp/src/gandiva/function_registry.h +++ b/cpp/src/gandiva/function_registry.h @@ -18,11 +18,14 @@ #pragma once #include +#include #include #include #include "arrow/buffer.h" #include "arrow/status.h" +#include "gandiva/function_holder.h" +#include "gandiva/function_holder_maker_registry.h" #include "gandiva/function_registry_common.h" #include "gandiva/gandiva_aliases.h" #include "gandiva/native_function.h" @@ -34,6 +37,9 @@ namespace gandiva { class GANDIVA_EXPORT FunctionRegistry { public: using iterator = const NativeFunction*; + using FunctionHolderMaker = + std::function>( + const FunctionNode& function_node)>; FunctionRegistry(); FunctionRegistry(const FunctionRegistry&) = delete; @@ -52,9 +58,24 @@ class GANDIVA_EXPORT FunctionRegistry { arrow::Status Register(const std::vector& funcs, std::shared_ptr bitcode_buffer); + /// \brief register a C function into the function registry + /// @param func the registered function's metadata + /// @param c_function_ptr the function pointer to the + /// registered function's implementation + /// @param function_holder_maker this will be used as the function holder if the + /// function requires a function holder + arrow::Status Register( + NativeFunction func, void* c_function_ptr, + std::optional function_holder_maker = std::nullopt); + /// \brief get a list of bitcode memory buffers saved in the registry const std::vector>& GetBitcodeBuffers() const; + /// \brief get a list of C functions saved in the registry + const std::vector>& GetCFunctions() const; + + const FunctionHolderMakerRegistry& GetFunctionHolderMakerRegistry() const; + iterator begin() const; iterator end() const; iterator back() const; @@ -65,6 +86,8 @@ class GANDIVA_EXPORT FunctionRegistry { std::vector pc_registry_; SignatureMap pc_registry_map_; std::vector> bitcode_memory_buffers_; + std::vector> c_functions_; + FunctionHolderMakerRegistry holder_maker_registry_; Status Add(NativeFunction func); }; diff --git a/cpp/src/gandiva/gdv_function_stubs.cc b/cpp/src/gandiva/gdv_function_stubs.cc index 67d39aeba55da..0ad3c1738e835 100644 --- a/cpp/src/gandiva/gdv_function_stubs.cc +++ b/cpp/src/gandiva/gdv_function_stubs.cc @@ -822,7 +822,7 @@ const char* gdv_mask_show_last_n_utf8_int32(int64_t context, const char* data, namespace gandiva { -void ExportedStubFunctions::AddMappings(Engine* engine) const { +arrow::Status ExportedStubFunctions::AddMappings(Engine* engine) const { std::vector args; auto types = engine->types(); @@ -1268,5 +1268,6 @@ void ExportedStubFunctions::AddMappings(Engine* engine) const { engine->AddGlobalMappingForFunc("mask_utf8", types->i8_ptr_type() /*return_type*/, args, reinterpret_cast(mask_utf8)); + return arrow::Status::OK(); } } // namespace gandiva diff --git a/cpp/src/gandiva/gdv_function_stubs.h b/cpp/src/gandiva/gdv_function_stubs.h index 5356a91f3ce59..3f52537ee05ca 100644 --- a/cpp/src/gandiva/gdv_function_stubs.h +++ b/cpp/src/gandiva/gdv_function_stubs.h @@ -74,8 +74,10 @@ int64_t gdv_fn_to_date_utf8_utf8_int32(int64_t context, int64_t ptr, const char* bool in2_validity, int32_t suppress_errors, bool in3_validity, bool* out_valid); +GANDIVA_EXPORT void gdv_fn_context_set_error_msg(int64_t context_ptr, const char* err_msg); +GANDIVA_EXPORT uint8_t* gdv_fn_context_arena_malloc(int64_t context_ptr, int32_t data_len); void gdv_fn_context_arena_reset(int64_t context_ptr); diff --git a/cpp/src/gandiva/gdv_hash_function_stubs.cc b/cpp/src/gandiva/gdv_hash_function_stubs.cc index 018b0fbb709fb..aac70a06be6c7 100644 --- a/cpp/src/gandiva/gdv_hash_function_stubs.cc +++ b/cpp/src/gandiva/gdv_hash_function_stubs.cc @@ -216,7 +216,7 @@ const char* gdv_fn_sha1_decimal128(int64_t context, int64_t x_high, uint64_t x_l namespace gandiva { -void ExportedHashFunctions::AddMappings(Engine* engine) const { +arrow::Status ExportedHashFunctions::AddMappings(Engine* engine) const { std::vector args; auto types = engine->types(); @@ -1041,5 +1041,6 @@ void ExportedHashFunctions::AddMappings(Engine* engine) const { engine->AddGlobalMappingForFunc("gdv_fn_md5_decimal128", types->i8_ptr_type() /*return_type*/, args, reinterpret_cast(gdv_fn_md5_decimal128)); + return arrow::Status::OK(); } } // namespace gandiva diff --git a/cpp/src/gandiva/gdv_string_function_stubs.cc b/cpp/src/gandiva/gdv_string_function_stubs.cc index 3bfb297af141f..9f5b5ce64b4a9 100644 --- a/cpp/src/gandiva/gdv_string_function_stubs.cc +++ b/cpp/src/gandiva/gdv_string_function_stubs.cc @@ -15,8 +15,6 @@ // specific language governing permissions and limitations // under the License. -//#pragma once - #include "gandiva/gdv_function_stubs.h" #include @@ -761,7 +759,7 @@ const char* translate_utf8_utf8_utf8(int64_t context, const char* in, int32_t in namespace gandiva { -void ExportedStringFunctions::AddMappings(Engine* engine) const { +arrow::Status ExportedStringFunctions::AddMappings(Engine* engine) const { std::vector args; auto types = engine->types(); @@ -988,5 +986,6 @@ void ExportedStringFunctions::AddMappings(Engine* engine) const { engine->AddGlobalMappingForFunc("translate_utf8_utf8_utf8", types->i8_ptr_type() /*return_type*/, args, reinterpret_cast(translate_utf8_utf8_utf8)); + return arrow::Status::OK(); } } // namespace gandiva diff --git a/cpp/src/gandiva/llvm_generator.h b/cpp/src/gandiva/llvm_generator.h index 1921e2565338b..fae6ed48defa5 100644 --- a/cpp/src/gandiva/llvm_generator.h +++ b/cpp/src/gandiva/llvm_generator.h @@ -88,7 +88,7 @@ class GANDIVA_EXPORT LLVMGenerator { FRIEND_TEST(TestLLVMGenerator, VerifyPCFunctions); FRIEND_TEST(TestLLVMGenerator, TestAdd); FRIEND_TEST(TestLLVMGenerator, TestNullInternal); - FRIEND_TEST(TestLLVMGenerator, VerifyExtendedPCFunctions); + friend class TestLLVMGenerator; llvm::LLVMContext* context() { return engine_->context(); } llvm::IRBuilder<>* ir_builder() { return engine_->ir_builder(); } diff --git a/cpp/src/gandiva/llvm_generator_test.cc b/cpp/src/gandiva/llvm_generator_test.cc index 671ce91e870f6..853d8ae6c3b8d 100644 --- a/cpp/src/gandiva/llvm_generator_test.cc +++ b/cpp/src/gandiva/llvm_generator_test.cc @@ -36,6 +36,24 @@ typedef int64_t (*add_vector_func_t)(int64_t* elements, int nelements); class TestLLVMGenerator : public ::testing::Test { protected: std::shared_ptr registry_ = default_function_registry(); + + public: + // create a Configuration with the given registry and verify that the given function + // exists in the module. + static void VerifyFunctionMapping( + const std::string& function_name, + const std::function( + std::shared_ptr)>& config_factory) { + auto external_registry = std::make_shared(); + auto config = config_factory(std::move(external_registry)); + + std::unique_ptr generator; + ASSERT_OK(LLVMGenerator::Make(config, false, &generator)); + + auto module = generator->module(); + ASSERT_OK(generator->engine_->LoadFunctionIRs()); + EXPECT_NE(module->getFunction(function_name), nullptr); + } }; // Verify that a valid pc function exists for every function in the registry. @@ -116,16 +134,19 @@ TEST_F(TestLLVMGenerator, TestAdd) { } TEST_F(TestLLVMGenerator, VerifyExtendedPCFunctions) { - auto external_registry = std::make_shared(); - auto config_with_func_registry = - TestConfigurationWithFunctionRegistry(std::move(external_registry)); + VerifyFunctionMapping("multiply_by_two_int32", [](auto registry) { + return TestConfigWithFunctionRegistry(std::move(registry)); + }); +} - std::unique_ptr generator; - ASSERT_OK(LLVMGenerator::Make(config_with_func_registry, false, &generator)); +TEST_F(TestLLVMGenerator, VerifyExtendedCFunctions) { + VerifyFunctionMapping("multiply_by_three_int32", [](auto registry) { + return TestConfigWithCFunction(std::move(registry)); + }); - auto module = generator->module(); - ASSERT_OK(generator->engine_->LoadFunctionIRs()); - EXPECT_NE(module->getFunction("multiply_by_two_int32"), nullptr); + VerifyFunctionMapping("multiply_by_n_int32_int32", [](auto registry) { + return TestConfigWithHolderFunction(std::move(registry)); + }); } } // namespace gandiva diff --git a/cpp/src/gandiva/tests/projector_test.cc b/cpp/src/gandiva/tests/projector_test.cc index 38566fb408ab5..59eeb3d92f19a 100644 --- a/cpp/src/gandiva/tests/projector_test.cc +++ b/cpp/src/gandiva/tests/projector_test.cc @@ -3594,7 +3594,7 @@ TEST_F(TestProjector, TestExtendedFunctions) { std::shared_ptr projector; auto external_registry = std::make_shared(); auto config_with_func_registry = - TestConfigurationWithFunctionRegistry(std::move(external_registry)); + TestConfigWithFunctionRegistry(std::move(external_registry)); ARROW_EXPECT_OK( Projector::Make(schema, {multiply}, config_with_func_registry, &projector)); @@ -3608,4 +3608,79 @@ TEST_F(TestProjector, TestExtendedFunctions) { EXPECT_ARROW_ARRAY_EQUALS(out, outs.at(0)); } +TEST_F(TestProjector, TestExtendedCFunctions) { + auto in_field = field("in", arrow::int32()); + auto schema = arrow::schema({in_field}); + auto out_field = field("out", arrow::int64()); + auto multiply = + TreeExprBuilder::MakeExpression("multiply_by_three", {in_field}, out_field); + + std::shared_ptr projector; + auto external_registry = std::make_shared(); + auto config_with_func_registry = TestConfigWithCFunction(std::move(external_registry)); + ARROW_EXPECT_OK( + Projector::Make(schema, {multiply}, config_with_func_registry, &projector)); + + int num_records = 4; + auto array = MakeArrowArrayInt32({1, 2, 3, 4}, {true, true, true, true}); + auto in_batch = arrow::RecordBatch::Make(schema, num_records, {array}); + auto out = MakeArrowArrayInt64({3, 6, 9, 12}, {true, true, true, true}); + + arrow::ArrayVector outs; + ARROW_EXPECT_OK(projector->Evaluate(*in_batch, pool_, &outs)); + EXPECT_ARROW_ARRAY_EQUALS(out, outs.at(0)); +} + +TEST_F(TestProjector, TestExtendedCFunctionsWithFunctionHolder) { + auto multiple = TreeExprBuilder::MakeLiteral(5); + auto in_field = field("in", arrow::int32()); + auto schema = arrow::schema({in_field}); + auto out_field = field("out", arrow::int64()); + + auto in_node = TreeExprBuilder::MakeField(in_field); + auto multiply_by_n_func = + TreeExprBuilder::MakeFunction("multiply_by_n", {in_node, multiple}, arrow::int64()); + auto multiply = TreeExprBuilder::MakeExpression(multiply_by_n_func, out_field); + + std::shared_ptr projector; + auto external_registry = std::make_shared(); + auto config_with_func_registry = + TestConfigWithHolderFunction(std::move(external_registry)); + ARROW_EXPECT_OK( + Projector::Make(schema, {multiply}, config_with_func_registry, &projector)); + + int num_records = 4; + auto array = MakeArrowArrayInt32({1, 2, 3, 4}, {true, true, true, true}); + auto in_batch = arrow::RecordBatch::Make(schema, num_records, {array}); + auto out = MakeArrowArrayInt64({5, 10, 15, 20}, {true, true, true, true}); + + arrow::ArrayVector outs; + ARROW_EXPECT_OK(projector->Evaluate(*in_batch, pool_, &outs)); + EXPECT_ARROW_ARRAY_EQUALS(out, outs.at(0)); +} + +TEST_F(TestProjector, TestExtendedCFunctionThatNeedsContext) { + auto in_field = field("in", arrow::utf8()); + auto schema = arrow::schema({in_field}); + auto out_field = field("out", arrow::utf8()); + auto multiply = + TreeExprBuilder::MakeExpression("multiply_by_two_formula", {in_field}, out_field); + + std::shared_ptr projector; + auto external_registry = std::make_shared(); + auto config_with_func_registry = + TestConfigWithContextFunction(std::move(external_registry)); + ARROW_EXPECT_OK( + Projector::Make(schema, {multiply}, config_with_func_registry, &projector)); + + int num_records = 4; + auto array = MakeArrowArrayUtf8({"1", "2", "3", "10"}, {true, true, true, true}); + auto in_batch = arrow::RecordBatch::Make(schema, num_records, {array}); + auto out = MakeArrowArrayUtf8({"1x2", "2x2", "3x2", "10x2"}, {true, true, true, true}); + + arrow::ArrayVector outs; + ARROW_EXPECT_OK(projector->Evaluate(*in_batch, pool_, &outs)); + EXPECT_ARROW_ARRAY_EQUALS(out, outs.at(0)); +} + } // namespace gandiva diff --git a/cpp/src/gandiva/tests/test_util.cc b/cpp/src/gandiva/tests/test_util.cc index 4a0a15c7223db..959ea3cd7a446 100644 --- a/cpp/src/gandiva/tests/test_util.cc +++ b/cpp/src/gandiva/tests/test_util.cc @@ -17,8 +17,13 @@ #include "gandiva/tests/test_util.h" +#include +#include + #include "arrow/util/io_util.h" #include "arrow/util/logging.h" +#include "gandiva/function_holder.h" +#include "gandiva/gdv_function_stubs.h" namespace gandiva { std::shared_ptr TestConfiguration() { @@ -43,11 +48,124 @@ NativeFunction GetTestExternalFunction() { return multiply_by_two_func; } -std::shared_ptr TestConfigurationWithFunctionRegistry( +static NativeFunction GetTestExternalCFunction() { + NativeFunction multiply_by_three_func( + "multiply_by_three", {}, {arrow::int32()}, arrow::int64(), + ResultNullableType::kResultNullIfNull, "multiply_by_three_int32"); + return multiply_by_three_func; +} + +static NativeFunction GetTestFunctionWithFunctionHolder() { + // the 2nd parameter is expected to be an int32 literal + NativeFunction multiply_by_n_func("multiply_by_n", {}, {arrow::int32(), arrow::int32()}, + arrow::int64(), ResultNullableType::kResultNullIfNull, + "multiply_by_n_int32_int32", + NativeFunction::kNeedsFunctionHolder); + return multiply_by_n_func; +} + +static NativeFunction GetTestFunctionWithContext() { + NativeFunction multiply_by_two_formula( + "multiply_by_two_formula", {}, {arrow::utf8()}, arrow::utf8(), + ResultNullableType::kResultNullIfNull, "multiply_by_two_formula_utf8", + NativeFunction::kNeedsContext); + return multiply_by_two_formula; +} + +static std::shared_ptr BuildConfigurationWithRegistry( + std::shared_ptr registry, + const std::function)>& + register_func) { + ARROW_EXPECT_OK(register_func(registry)); + return ConfigurationBuilder().build(std::move(registry)); +} + +std::shared_ptr TestConfigWithFunctionRegistry( + std::shared_ptr registry) { + return BuildConfigurationWithRegistry(std::move(registry), [](auto reg) { + return reg->Register({GetTestExternalFunction()}, GetTestFunctionLLVMIRPath()); + }); +} + +class MultiplyHolder : public FunctionHolder { + public: + explicit MultiplyHolder(int32_t num) : num_(num) {} + + static arrow::Result> Make(const FunctionNode& node) { + ARROW_RETURN_IF(node.children().size() != 2, + Status::Invalid("'multiply_by_n' function requires two parameters")); + + auto literal = dynamic_cast(node.children().at(1).get()); + ARROW_RETURN_IF( + literal == nullptr, + Status::Invalid( + "'multiply_by_n' function requires a literal as the 2nd parameter")); + + auto literal_type = literal->return_type()->id(); + ARROW_RETURN_IF( + literal_type != arrow::Type::INT32, + Status::Invalid( + "'multiply_by_n' function requires an int32 literal as the 2nd parameter")); + + return std::make_shared( + literal->is_null() ? 0 : std::get(literal->holder())); + } + + int32_t operator()() const { return num_; } + + private: + int32_t num_; +}; + +extern "C" { +// this function is used as an external C function for testing so it has to be declared +// with extern C +static int64_t multiply_by_three(int32_t value) { return value * 3; } + +// this function requires a function holder +static int64_t multiply_by_n(int64_t holder_ptr, int32_t value) { + auto* holder = reinterpret_cast(holder_ptr); + return value * (*holder)(); +} + +// given a number string, return a string "{number}x2" +static const char* multiply_by_two_formula(int64_t ctx, const char* value, + int32_t value_len, int32_t* out_len) { + auto result = std::string(value, value_len) + "x2"; + *out_len = static_cast(result.length()); + auto out = reinterpret_cast(gdv_fn_context_arena_malloc(ctx, *out_len)); + if (out == nullptr) { + gdv_fn_context_set_error_msg(ctx, "Could not allocate memory for output string"); + *out_len = 0; + return ""; + } + memcpy(out, result.c_str(), *out_len); + return out; +} +} + +std::shared_ptr TestConfigWithCFunction( + std::shared_ptr registry) { + return BuildConfigurationWithRegistry(std::move(registry), [](auto reg) { + return reg->Register(GetTestExternalCFunction(), + reinterpret_cast(multiply_by_three)); + }); +} + +std::shared_ptr TestConfigWithHolderFunction( + std::shared_ptr registry) { + return BuildConfigurationWithRegistry(std::move(registry), [](auto reg) { + return reg->Register( + GetTestFunctionWithFunctionHolder(), reinterpret_cast(multiply_by_n), + [](const FunctionNode& node) { return MultiplyHolder::Make(node); }); + }); +} + +std::shared_ptr TestConfigWithContextFunction( std::shared_ptr registry) { - ARROW_EXPECT_OK( - registry->Register({GetTestExternalFunction()}, GetTestFunctionLLVMIRPath())); - auto external_func_config = ConfigurationBuilder().build(std::move(registry)); - return external_func_config; + return BuildConfigurationWithRegistry(std::move(registry), [](auto reg) { + return reg->Register(GetTestFunctionWithContext(), + reinterpret_cast(multiply_by_two_formula)); + }); } } // namespace gandiva diff --git a/cpp/src/gandiva/tests/test_util.h b/cpp/src/gandiva/tests/test_util.h index e431e53096c2c..69d63732aeeaa 100644 --- a/cpp/src/gandiva/tests/test_util.h +++ b/cpp/src/gandiva/tests/test_util.h @@ -98,7 +98,26 @@ static inline ArrayPtr MakeArrowTypeArray(const std::shared_ptr std::shared_ptr TestConfiguration(); -std::shared_ptr TestConfigurationWithFunctionRegistry( +// helper function to create a Configuration with an external function registered to the +// given function registry +std::shared_ptr TestConfigWithFunctionRegistry( + std::shared_ptr registry); + +// helper function to create a Configuration with an external C function registered to +// the given function registry +std::shared_ptr TestConfigWithCFunction( + std::shared_ptr registry); + +// helper function to create a Configuration with an external function registered +// to the given function registry, and the external function is a function with a function +// holder +std::shared_ptr TestConfigWithHolderFunction( + std::shared_ptr registry); + +// helper function to create a Configuration with an external function registered +// to the given function registry, and the external function is a function that needs +// context +std::shared_ptr TestConfigWithContextFunction( std::shared_ptr registry); std::string GetTestFunctionLLVMIRPath();