diff --git a/cpp/src/gandiva/external_stub_functions.cc b/cpp/src/gandiva/external_stub_functions.cc index 693dd61aff98c..669b97669f827 100644 --- a/cpp/src/gandiva/external_stub_functions.cc +++ b/cpp/src/gandiva/external_stub_functions.cc @@ -48,37 +48,45 @@ static arrow::Result AsLLVMType(const DataTypePtr& from_type, } } +// map from a NativeFunction's signature to the corresponding LLVM signature +static arrow::Result, llvm::Type*>> MapToLLVMSignature( + const FunctionSignature& sig, const NativeFunction& func, LLVMTypes* types) { + std::vector args; + args.reserve(sig.param_types().size()); + if (func.NeedsContext()) { + args.emplace_back(types->i64_type()); + } + if (func.NeedsFunctionHolder()) { + args.emplace_back(types->i64_type()); + } + for (auto const& arg : sig.param_types()) { + if (arg->id() == arrow::Type::STRING) { + args.emplace_back(types->i8_ptr_type()); + args.emplace_back(types->i32_type()); + } else { + ARROW_ASSIGN_OR_RAISE(auto arg_llvm_type, AsLLVMType(arg, types)); + args.emplace_back(arg_llvm_type); + } + } + llvm::Type* ret_llvm_type; + if (sig.ret_type()->id() == arrow::Type::STRING) { + // for string output, the last arg is the output length + args.emplace_back(types->i32_ptr_type()); + ret_llvm_type = types->i8_ptr_type(); + } else { + ARROW_ASSIGN_OR_RAISE(ret_llvm_type, AsLLVMType(sig.ret_type(), types)); + } + auto return_type = AsLLVMType(sig.ret_type(), types); + return std::make_pair(args, ret_llvm_type); +} + arrow::Status ExternalStubFunctions::AddMappings(Engine* engine) const { auto external_stub_funcs = function_registry_->GetStubFunctions(); auto types = engine->types(); for (auto& [func, func_ptr] : external_stub_funcs) { - for (auto& sig : func.signatures()) { - std::vector args; - args.reserve(sig.param_types().size()); - if (func.NeedsContext()) { - args.emplace_back(types->i64_type()); - } - if (func.NeedsFunctionHolder()) { - args.emplace_back(types->i64_type()); - } - for (auto const& arg : sig.param_types()) { - if (arg->id() == arrow::Type::STRING) { - args.emplace_back(types->i8_ptr_type()); - args.emplace_back(types->i32_type()); - } else { - ARROW_ASSIGN_OR_RAISE(auto arg_llvm_type, AsLLVMType(arg, types)); - args.emplace_back(arg_llvm_type); - } - } - llvm::Type* ret_llvm_type; - if (sig.ret_type()->id() == arrow::Type::STRING) { - // for string output, the last arg is the output length - args.emplace_back(types->i32_ptr_type()); - ret_llvm_type = types->i8_ptr_type(); - } else { - ARROW_ASSIGN_OR_RAISE(ret_llvm_type, AsLLVMType(sig.ret_type(), types)); - } - auto return_type = AsLLVMType(sig.ret_type(), types); + for (auto const& sig : func.signatures()) { + ARROW_ASSIGN_OR_RAISE(auto llvm_signature, MapToLLVMSignature(sig, func, types)); + auto& [args, ret_llvm_type] = llvm_signature; engine->AddGlobalMappingForFunc(func.pc_name(), ret_llvm_type, args, func_ptr); } } diff --git a/cpp/src/gandiva/llvm_generator_test.cc b/cpp/src/gandiva/llvm_generator_test.cc index 671ce91e870f6..9ff1de006be08 100644 --- a/cpp/src/gandiva/llvm_generator_test.cc +++ b/cpp/src/gandiva/llvm_generator_test.cc @@ -118,7 +118,7 @@ TEST_F(TestLLVMGenerator, TestAdd) { TEST_F(TestLLVMGenerator, VerifyExtendedPCFunctions) { auto external_registry = std::make_shared(); auto config_with_func_registry = - TestConfigurationWithFunctionRegistry(std::move(external_registry)); + TestConfigWithFunctionRegistry(std::move(external_registry)); std::unique_ptr generator; ASSERT_OK(LLVMGenerator::Make(config_with_func_registry, false, &generator)); diff --git a/cpp/src/gandiva/tests/projector_test.cc b/cpp/src/gandiva/tests/projector_test.cc index eb8b03f4315f9..86a1f0a9ebac4 100644 --- a/cpp/src/gandiva/tests/projector_test.cc +++ b/cpp/src/gandiva/tests/projector_test.cc @@ -3594,7 +3594,7 @@ TEST_F(TestProjector, TestExtendedFunctions) { std::shared_ptr projector; auto external_registry = std::make_shared(); auto config_with_func_registry = - TestConfigurationWithFunctionRegistry(std::move(external_registry)); + TestConfigWithFunctionRegistry(std::move(external_registry)); ARROW_EXPECT_OK( Projector::Make(schema, {multiply}, config_with_func_registry, &projector)); @@ -3618,7 +3618,7 @@ TEST_F(TestProjector, TestExtendedStubFunctions) { std::shared_ptr projector; auto external_registry = std::make_shared(); auto config_with_func_registry = - TestConfigurationWithExternalStubFunctionRegistry(std::move(external_registry)); + TestConfigWithStubFunction(std::move(external_registry)); ARROW_EXPECT_OK( Projector::Make(schema, {multiply}, config_with_func_registry, &projector)); @@ -3646,7 +3646,7 @@ TEST_F(TestProjector, TestExtendedStubFunctionsWithFunctionHolder) { std::shared_ptr projector; auto external_registry = std::make_shared(); auto config_with_func_registry = - TestConfigurationWithFunctionHolderRegistry(std::move(external_registry)); + TestConfigWithHolderFunction(std::move(external_registry)); ARROW_EXPECT_OK( Projector::Make(schema, {multiply}, config_with_func_registry, &projector)); @@ -3660,4 +3660,28 @@ TEST_F(TestProjector, TestExtendedStubFunctionsWithFunctionHolder) { EXPECT_ARROW_ARRAY_EQUALS(out, outs.at(0)); } +TEST_F(TestProjector, TestExtendedStubFunctionThatNeedsContext) { + auto in_field = field("in", arrow::utf8()); + auto schema = arrow::schema({in_field}); + auto out_field = field("out", arrow::utf8()); + auto multiply = + TreeExprBuilder::MakeExpression("multiply_by_two_formula", {in_field}, out_field); + + std::shared_ptr projector; + auto external_registry = std::make_shared(); + auto config_with_func_registry = + TestConfigWithContextFunction(std::move(external_registry)); + ARROW_EXPECT_OK( + Projector::Make(schema, {multiply}, config_with_func_registry, &projector)); + + int num_records = 4; + auto array = MakeArrowArrayUtf8({"1", "2", "3", "10"}, {true, true, true, true}); + auto in_batch = arrow::RecordBatch::Make(schema, num_records, {array}); + auto out = MakeArrowArrayUtf8({"1x2", "2x2", "3x2", "10x2"}, {true, true, true, true}); + + arrow::ArrayVector outs; + ARROW_EXPECT_OK(projector->Evaluate(*in_batch, pool_, &outs)); + EXPECT_ARROW_ARRAY_EQUALS(out, outs.at(0)); +} + } // namespace gandiva diff --git a/cpp/src/gandiva/tests/test_util.cc b/cpp/src/gandiva/tests/test_util.cc index 838d8b31d9f04..a1db8ec037453 100644 --- a/cpp/src/gandiva/tests/test_util.cc +++ b/cpp/src/gandiva/tests/test_util.cc @@ -19,8 +19,10 @@ #include #include +#include #include "gandiva/function_holder.h" +#include "gandiva/gdv_function_stubs.h" namespace gandiva { std::shared_ptr TestConfiguration() { @@ -61,12 +63,27 @@ static NativeFunction GetTestFunctionWithFunctionHolder() { return multiply_by_n_func; } -std::shared_ptr TestConfigurationWithFunctionRegistry( +static NativeFunction GetTestFunctionWithContext() { + NativeFunction multiply_by_two_formula( + "multiply_by_two_formula", {}, {arrow::utf8()}, arrow::utf8(), + ResultNullableType::kResultNullIfNull, "multiply_by_two_formula_utf8", + NativeFunction::kNeedsContext); + return multiply_by_two_formula; +} + +static std::shared_ptr BuildConfigurationWithRegistry( + std::shared_ptr registry, + const std::function)>& + register_func) { + ARROW_EXPECT_OK(register_func(registry)); + return ConfigurationBuilder().build(std::move(registry)); +} + +std::shared_ptr TestConfigWithFunctionRegistry( std::shared_ptr registry) { - ARROW_EXPECT_OK( - registry->Register({GetTestExternalFunction()}, GetTestFunctionLLVMIRPath())); - auto external_func_config = ConfigurationBuilder().build(std::move(registry)); - return external_func_config; + return BuildConfigurationWithRegistry(std::move(registry), [](auto reg) { + return reg->Register({GetTestExternalFunction()}, GetTestFunctionLLVMIRPath()); + }); } class MultiplyHolder : public FunctionHolder { @@ -110,26 +127,49 @@ static int64_t multiply_by_n(int64_t holder_ptr, int32_t value) { MultiplyHolder* holder = reinterpret_cast(holder_ptr); return value * (*holder)(); } + +// given a number string, return a string "{number}x2" +static const char* multiply_by_two_formula(int64_t ctx, const char* value, + int32_t value_len, int32_t* out_len) { + auto result = std::string(value, value_len) + "x2"; + *out_len = static_cast(result.length()); + auto out = reinterpret_cast(gdv_fn_context_arena_malloc(ctx, *out_len)); + if (out == nullptr) { + gdv_fn_context_set_error_msg(ctx, "Could not allocate memory for output string"); + *out_len = 0; + return ""; + } + memcpy(out, result.c_str(), *out_len); + return out; +} +} + +std::shared_ptr TestConfigWithStubFunction( + std::shared_ptr registry) { + return BuildConfigurationWithRegistry(std::move(registry), [](auto reg) { + return reg->Register(GetTestExternalStubFunction(), + reinterpret_cast(multiply_by_three)); + }); } -std::shared_ptr TestConfigurationWithExternalStubFunctionRegistry( +std::shared_ptr TestConfigWithHolderFunction( std::shared_ptr registry) { - ARROW_EXPECT_OK(registry->Register(GetTestExternalStubFunction(), - reinterpret_cast(multiply_by_three))); - auto external_func_config = ConfigurationBuilder().build(std::move(registry)); - return external_func_config; + return BuildConfigurationWithRegistry(std::move(registry), [](auto reg) { + return reg->Register( + GetTestFunctionWithFunctionHolder(), reinterpret_cast(multiply_by_n), + [](const FunctionNode& node) -> arrow::Result { + std::shared_ptr derived_instance; + ARROW_RETURN_NOT_OK(MultiplyHolder::Make(node, &derived_instance)); + return derived_instance; + }); + }); } -std::shared_ptr TestConfigurationWithFunctionHolderRegistry( +std::shared_ptr TestConfigWithContextFunction( std::shared_ptr registry) { - ARROW_EXPECT_OK(registry->Register( - GetTestFunctionWithFunctionHolder(), reinterpret_cast(multiply_by_n), - [](const FunctionNode& node) -> arrow::Result { - std::shared_ptr derived_instance; - ARROW_RETURN_NOT_OK(MultiplyHolder::Make(node, &derived_instance)); - return derived_instance; - })); - auto external_func_config = ConfigurationBuilder().build(std::move(registry)); - return external_func_config; + return BuildConfigurationWithRegistry(std::move(registry), [](auto reg) { + return reg->Register(GetTestFunctionWithContext(), + reinterpret_cast(multiply_by_two_formula)); + }); } } // namespace gandiva diff --git a/cpp/src/gandiva/tests/test_util.h b/cpp/src/gandiva/tests/test_util.h index 5bc5bc3cdf287..3bd52fc91cbde 100644 --- a/cpp/src/gandiva/tests/test_util.h +++ b/cpp/src/gandiva/tests/test_util.h @@ -98,13 +98,26 @@ static inline ArrayPtr MakeArrowTypeArray(const std::shared_ptr std::shared_ptr TestConfiguration(); -std::shared_ptr TestConfigurationWithFunctionRegistry( +// helper function to create a Configuration with an external function registered to the +// given function registry +std::shared_ptr TestConfigWithFunctionRegistry( std::shared_ptr registry); -std::shared_ptr TestConfigurationWithExternalStubFunctionRegistry( +// helper function to create a Configuration with an external stub function registered to +// the given function registry +std::shared_ptr TestConfigWithStubFunction( std::shared_ptr registry); -std::shared_ptr TestConfigurationWithFunctionHolderRegistry( +// helper function to create a Configuration with an external function registered +// to the given function registry, and the external function is a function with a function +// holder +std::shared_ptr TestConfigWithHolderFunction( + std::shared_ptr registry); + +// helper function to create a Configuration with an external function registered +// to the given function registry, and the external function is a function that needs +// context +std::shared_ptr TestConfigWithContextFunction( std::shared_ptr registry); std::string GetTestFunctionLLVMIRPath();