Skip to content

Commit

Permalink
Verify the external gandiva function that needs context to alloc memo…
Browse files Browse the repository at this point in the history
…ry can work.
  • Loading branch information
niyue committed Nov 8, 2023
1 parent 6de2416 commit c11101b
Show file tree
Hide file tree
Showing 5 changed files with 139 additions and 54 deletions.
62 changes: 35 additions & 27 deletions cpp/src/gandiva/external_stub_functions.cc
Original file line number Diff line number Diff line change
Expand Up @@ -48,37 +48,45 @@ static arrow::Result<llvm::Type*> AsLLVMType(const DataTypePtr& from_type,
}
}

// Maps a NativeFunction's signature to the corresponding LLVM signature.
//
// The LLVM argument list is assembled in this order:
//   1. one i64 entry when the function needs the execution context,
//   2. one i64 entry when the function needs a function holder,
//   3. the declared parameters: every STRING parameter expands to an
//      (i8 pointer, i32) pair, every other type is converted via AsLLVMType,
//   4. for a STRING return type, a trailing i32 pointer used for the output length.
//
// Returns the (argument types, return type) pair, or the error produced by
// AsLLVMType when a parameter/return type cannot be converted.
static arrow::Result<std::pair<std::vector<llvm::Type*>, llvm::Type*>> MapToLLVMSignature(
    const FunctionSignature& sig, const NativeFunction& func, LLVMTypes* types) {
  std::vector<llvm::Type*> args;
  // only a lower bound: context/holder entries and string expansion add more
  args.reserve(sig.param_types().size());
  if (func.NeedsContext()) {
    args.emplace_back(types->i64_type());
  }
  if (func.NeedsFunctionHolder()) {
    args.emplace_back(types->i64_type());
  }
  for (auto const& arg : sig.param_types()) {
    if (arg->id() == arrow::Type::STRING) {
      args.emplace_back(types->i8_ptr_type());
      args.emplace_back(types->i32_type());
    } else {
      ARROW_ASSIGN_OR_RAISE(auto arg_llvm_type, AsLLVMType(arg, types));
      args.emplace_back(arg_llvm_type);
    }
  }
  llvm::Type* ret_llvm_type = nullptr;
  if (sig.ret_type()->id() == arrow::Type::STRING) {
    // for string output, the last arg is the output length
    args.emplace_back(types->i32_ptr_type());
    ret_llvm_type = types->i8_ptr_type();
  } else {
    ARROW_ASSIGN_OR_RAISE(ret_llvm_type, AsLLVMType(sig.ret_type(), types));
  }
  // hand the argument vector over instead of copying it into the pair
  return std::make_pair(std::move(args), ret_llvm_type);
}

arrow::Status ExternalStubFunctions::AddMappings(Engine* engine) const {
auto external_stub_funcs = function_registry_->GetStubFunctions();
auto types = engine->types();
for (auto& [func, func_ptr] : external_stub_funcs) {
for (auto& sig : func.signatures()) {
std::vector<llvm::Type*> args;
args.reserve(sig.param_types().size());
if (func.NeedsContext()) {
args.emplace_back(types->i64_type());
}
if (func.NeedsFunctionHolder()) {
args.emplace_back(types->i64_type());
}
for (auto const& arg : sig.param_types()) {
if (arg->id() == arrow::Type::STRING) {
args.emplace_back(types->i8_ptr_type());
args.emplace_back(types->i32_type());
} else {
ARROW_ASSIGN_OR_RAISE(auto arg_llvm_type, AsLLVMType(arg, types));
args.emplace_back(arg_llvm_type);
}
}
llvm::Type* ret_llvm_type;
if (sig.ret_type()->id() == arrow::Type::STRING) {
// for string output, the last arg is the output length
args.emplace_back(types->i32_ptr_type());
ret_llvm_type = types->i8_ptr_type();
} else {
ARROW_ASSIGN_OR_RAISE(ret_llvm_type, AsLLVMType(sig.ret_type(), types));
}
auto return_type = AsLLVMType(sig.ret_type(), types);
for (auto const& sig : func.signatures()) {
ARROW_ASSIGN_OR_RAISE(auto llvm_signature, MapToLLVMSignature(sig, func, types));
auto& [args, ret_llvm_type] = llvm_signature;
engine->AddGlobalMappingForFunc(func.pc_name(), ret_llvm_type, args, func_ptr);
}
}
Expand Down
2 changes: 1 addition & 1 deletion cpp/src/gandiva/llvm_generator_test.cc
Original file line number Diff line number Diff line change
Expand Up @@ -118,7 +118,7 @@ TEST_F(TestLLVMGenerator, TestAdd) {
TEST_F(TestLLVMGenerator, VerifyExtendedPCFunctions) {
auto external_registry = std::make_shared<FunctionRegistry>();
auto config_with_func_registry =
TestConfigurationWithFunctionRegistry(std::move(external_registry));
TestConfigWithFunctionRegistry(std::move(external_registry));

std::unique_ptr<LLVMGenerator> generator;
ASSERT_OK(LLVMGenerator::Make(config_with_func_registry, false, &generator));
Expand Down
30 changes: 27 additions & 3 deletions cpp/src/gandiva/tests/projector_test.cc
Original file line number Diff line number Diff line change
Expand Up @@ -3594,7 +3594,7 @@ TEST_F(TestProjector, TestExtendedFunctions) {
std::shared_ptr<Projector> projector;
auto external_registry = std::make_shared<FunctionRegistry>();
auto config_with_func_registry =
TestConfigurationWithFunctionRegistry(std::move(external_registry));
TestConfigWithFunctionRegistry(std::move(external_registry));
ARROW_EXPECT_OK(
Projector::Make(schema, {multiply}, config_with_func_registry, &projector));

Expand All @@ -3618,7 +3618,7 @@ TEST_F(TestProjector, TestExtendedStubFunctions) {
std::shared_ptr<Projector> projector;
auto external_registry = std::make_shared<FunctionRegistry>();
auto config_with_func_registry =
TestConfigurationWithExternalStubFunctionRegistry(std::move(external_registry));
TestConfigWithStubFunction(std::move(external_registry));
ARROW_EXPECT_OK(
Projector::Make(schema, {multiply}, config_with_func_registry, &projector));

Expand Down Expand Up @@ -3646,7 +3646,7 @@ TEST_F(TestProjector, TestExtendedStubFunctionsWithFunctionHolder) {
std::shared_ptr<Projector> projector;
auto external_registry = std::make_shared<FunctionRegistry>();
auto config_with_func_registry =
TestConfigurationWithFunctionHolderRegistry(std::move(external_registry));
TestConfigWithHolderFunction(std::move(external_registry));
ARROW_EXPECT_OK(
Projector::Make(schema, {multiply}, config_with_func_registry, &projector));

Expand All @@ -3660,4 +3660,28 @@ TEST_F(TestProjector, TestExtendedStubFunctionsWithFunctionHolder) {
EXPECT_ARROW_ARRAY_EQUALS(out, outs.at(0));
}

// Projects the "multiply_by_two_formula" stub function, registered through
// TestConfigWithContextFunction, over a utf8 column and checks the utf8 result.
TEST_F(TestProjector, TestExtendedStubFunctionThatNeedsContext) {
  // one utf8 input column projected into one utf8 output column
  auto in_field = field("in", arrow::utf8());
  auto out_field = field("out", arrow::utf8());
  auto schema = arrow::schema({in_field});
  auto formula_expr =
      TreeExprBuilder::MakeExpression("multiply_by_two_formula", {in_field}, out_field);

  // build the projector against a registry that has the stub function registered
  auto external_registry = std::make_shared<FunctionRegistry>();
  auto config_with_func_registry =
      TestConfigWithContextFunction(std::move(external_registry));
  std::shared_ptr<Projector> projector;
  ARROW_EXPECT_OK(
      Projector::Make(schema, {formula_expr}, config_with_func_registry, &projector));

  // four valid input rows and the values expected for each of them
  const int row_count = 4;
  auto input = MakeArrowArrayUtf8({"1", "2", "3", "10"}, {true, true, true, true});
  auto expected =
      MakeArrowArrayUtf8({"1x2", "2x2", "3x2", "10x2"}, {true, true, true, true});
  auto in_batch = arrow::RecordBatch::Make(schema, row_count, {input});

  arrow::ArrayVector outs;
  ARROW_EXPECT_OK(projector->Evaluate(*in_batch, pool_, &outs));
  EXPECT_ARROW_ARRAY_EQUALS(expected, outs.at(0));
}

} // namespace gandiva
80 changes: 60 additions & 20 deletions cpp/src/gandiva/tests/test_util.cc
Original file line number Diff line number Diff line change
Expand Up @@ -19,8 +19,10 @@

#include <filesystem>
#include <memory>
#include <utility>

#include "gandiva/function_holder.h"
#include "gandiva/gdv_function_stubs.h"

namespace gandiva {
std::shared_ptr<Configuration> TestConfiguration() {
Expand Down Expand Up @@ -61,12 +63,27 @@ static NativeFunction GetTestFunctionWithFunctionHolder() {
return multiply_by_n_func;
}

std::shared_ptr<Configuration> TestConfigurationWithFunctionRegistry(
// Describes the test function "multiply_by_two_formula": it takes one utf8
// argument, returns utf8, uses ResultNullableType::kResultNullIfNull, is bound to
// the name "multiply_by_two_formula_utf8" and carries the kNeedsContext flag.
// NOTE(review): the empty second argument is presumably the alias list -- confirm
// against the NativeFunction constructor.
static NativeFunction GetTestFunctionWithContext() {
  return NativeFunction("multiply_by_two_formula", {}, {arrow::utf8()}, arrow::utf8(),
                        ResultNullableType::kResultNullIfNull,
                        "multiply_by_two_formula_utf8", NativeFunction::kNeedsContext);
}

static std::shared_ptr<Configuration> BuildConfigurationWithRegistry(
std::shared_ptr<FunctionRegistry> registry,
const std::function<arrow::Status(std::shared_ptr<FunctionRegistry>)>&
register_func) {
ARROW_EXPECT_OK(register_func(registry));
return ConfigurationBuilder().build(std::move(registry));
}

std::shared_ptr<Configuration> TestConfigWithFunctionRegistry(
std::shared_ptr<FunctionRegistry> registry) {
ARROW_EXPECT_OK(
registry->Register({GetTestExternalFunction()}, GetTestFunctionLLVMIRPath()));
auto external_func_config = ConfigurationBuilder().build(std::move(registry));
return external_func_config;
return BuildConfigurationWithRegistry(std::move(registry), [](auto reg) {
return reg->Register({GetTestExternalFunction()}, GetTestFunctionLLVMIRPath());
});
}

class MultiplyHolder : public FunctionHolder {
Expand Down Expand Up @@ -110,26 +127,49 @@ static int64_t multiply_by_n(int64_t holder_ptr, int32_t value) {
MultiplyHolder* holder = reinterpret_cast<MultiplyHolder*>(holder_ptr);
return value * (*holder)();
}

// Given a number string, returns the string "{number}x2".
// The output buffer is obtained from gdv_fn_context_arena_malloc(ctx, ...); if that
// returns nullptr, an error message is set on the context, *out_len is set to 0
// and an empty string literal is returned.
static const char* multiply_by_two_formula(int64_t ctx, const char* value,
                                           int32_t value_len, int32_t* out_len) {
  std::string formula(value, value_len);
  formula += "x2";
  const auto formula_len = static_cast<int32_t>(formula.length());

  auto* buffer = reinterpret_cast<char*>(gdv_fn_context_arena_malloc(ctx, formula_len));
  if (buffer == nullptr) {
    gdv_fn_context_set_error_msg(ctx, "Could not allocate memory for output string");
    *out_len = 0;
    return "";
  }

  *out_len = formula_len;
  memcpy(buffer, formula.data(), formula_len);
  return buffer;
}
}

// Registers the external stub test function (backed by multiply_by_three) with
// `registry` and returns a Configuration built from that registry.
std::shared_ptr<Configuration> TestConfigWithStubFunction(
    std::shared_ptr<FunctionRegistry> registry) {
  auto register_stub = [](auto target_registry) {
    auto* stub_ptr = reinterpret_cast<void*>(multiply_by_three);
    return target_registry->Register(GetTestExternalStubFunction(), stub_ptr);
  };
  return BuildConfigurationWithRegistry(std::move(registry), register_stub);
}

std::shared_ptr<Configuration> TestConfigurationWithExternalStubFunctionRegistry(
std::shared_ptr<Configuration> TestConfigWithHolderFunction(
std::shared_ptr<FunctionRegistry> registry) {
ARROW_EXPECT_OK(registry->Register(GetTestExternalStubFunction(),
reinterpret_cast<void*>(multiply_by_three)));
auto external_func_config = ConfigurationBuilder().build(std::move(registry));
return external_func_config;
return BuildConfigurationWithRegistry(std::move(registry), [](auto reg) {
return reg->Register(
GetTestFunctionWithFunctionHolder(), reinterpret_cast<void*>(multiply_by_n),
[](const FunctionNode& node) -> arrow::Result<FunctionHolderPtr> {
std::shared_ptr<MultiplyHolder> derived_instance;
ARROW_RETURN_NOT_OK(MultiplyHolder::Make(node, &derived_instance));
return derived_instance;
});
});
}

std::shared_ptr<Configuration> TestConfigurationWithFunctionHolderRegistry(
std::shared_ptr<Configuration> TestConfigWithContextFunction(
std::shared_ptr<FunctionRegistry> registry) {
ARROW_EXPECT_OK(registry->Register(
GetTestFunctionWithFunctionHolder(), reinterpret_cast<void*>(multiply_by_n),
[](const FunctionNode& node) -> arrow::Result<FunctionHolderPtr> {
std::shared_ptr<MultiplyHolder> derived_instance;
ARROW_RETURN_NOT_OK(MultiplyHolder::Make(node, &derived_instance));
return derived_instance;
}));
auto external_func_config = ConfigurationBuilder().build(std::move(registry));
return external_func_config;
return BuildConfigurationWithRegistry(std::move(registry), [](auto reg) {
return reg->Register(GetTestFunctionWithContext(),
reinterpret_cast<void*>(multiply_by_two_formula));
});
}
} // namespace gandiva
19 changes: 16 additions & 3 deletions cpp/src/gandiva/tests/test_util.h
Original file line number Diff line number Diff line change
Expand Up @@ -98,13 +98,26 @@ static inline ArrayPtr MakeArrowTypeArray(const std::shared_ptr<arrow::DataType>

std::shared_ptr<Configuration> TestConfiguration();

std::shared_ptr<Configuration> TestConfigurationWithFunctionRegistry(
// helper function to create a Configuration with an external function registered to the
// given function registry
std::shared_ptr<Configuration> TestConfigWithFunctionRegistry(
std::shared_ptr<FunctionRegistry> registry);

std::shared_ptr<Configuration> TestConfigurationWithExternalStubFunctionRegistry(
// helper function to create a Configuration with an external stub function registered to
// the given function registry
std::shared_ptr<Configuration> TestConfigWithStubFunction(
std::shared_ptr<FunctionRegistry> registry);

std::shared_ptr<Configuration> TestConfigurationWithFunctionHolderRegistry(
// helper function to create a Configuration with an external function that uses a
// function holder registered to the given function registry
std::shared_ptr<Configuration> TestConfigWithHolderFunction(
std::shared_ptr<FunctionRegistry> registry);

// Helper function to create a Configuration with an external function that needs
// the execution context registered to the given function registry.
std::shared_ptr<Configuration> TestConfigWithContextFunction(
std::shared_ptr<FunctionRegistry> registry);

std::string GetTestFunctionLLVMIRPath();
Expand Down

0 comments on commit c11101b

Please sign in to comment.