Skip to content

Commit

Permalink
ARROW-15582: [C++] Add support for registering standard Substrait fun…
Browse files Browse the repository at this point in the history
…ctions (#13613)

This picks up where #13285 left off.  It mostly focuses on the Substrait->Arrow direction at the moment.  In addition, basic support is added for named tables, which makes it possible to create unit tests that read from in-memory tables instead of requiring each unit test to do a scan.

The PR creates some utilities in `test_plan_builder.h` which allow for the construction of simple Substrait plans programmatically.  This is used to create unit tests for the function mapping.

The PR extracts id "ownership" out of the `ExtensionIdRegistry` and into its own `IdStorage` class.

The PR gets rid of `NestedExtensionIdRegistryImpl` and instead makes `ExtensionIdRegistryImpl` nested if `parent_ != nullptr`.



Authored-by: Weston Pace <[email protected]>
Signed-off-by: Weston Pace <[email protected]>
  • Loading branch information
westonpace authored Aug 10, 2022
1 parent ae071bb commit cdb5b20
Show file tree
Hide file tree
Showing 27 changed files with 2,196 additions and 480 deletions.
2 changes: 2 additions & 0 deletions cpp/src/arrow/compute/exec/options.cc
Original file line number Diff line number Diff line change
Expand Up @@ -25,6 +25,8 @@
namespace arrow {
namespace compute {

// Out-of-class definition of the static constexpr member that is declared and
// initialized inside TableSourceNodeOptions (options.h).  Before C++17 a static
// constexpr data member still needs a namespace-scope definition like this one if
// it is ever ODR-used (for example, bound to a const reference).
constexpr int64_t TableSourceNodeOptions::kDefaultMaxBatchSize;

std::string ToString(JoinType t) {
switch (t) {
case JoinType::LEFT_SEMI:
Expand Down
4 changes: 3 additions & 1 deletion cpp/src/arrow/compute/exec/options.h
Original file line number Diff line number Diff line change
Expand Up @@ -64,7 +64,9 @@ class ARROW_EXPORT SourceNodeOptions : public ExecNodeOptions {
/// \brief An extended Source node which accepts a table
class ARROW_EXPORT TableSourceNodeOptions : public ExecNodeOptions {
public:
TableSourceNodeOptions(std::shared_ptr<Table> table, int64_t max_batch_size)
static constexpr int64_t kDefaultMaxBatchSize = 1 << 20;
TableSourceNodeOptions(std::shared_ptr<Table> table,
int64_t max_batch_size = kDefaultMaxBatchSize)
: table(table), max_batch_size(max_batch_size) {}

// arrow table which acts as the data source
Expand Down
41 changes: 0 additions & 41 deletions cpp/src/arrow/compute/exec/sink_node.cc
Original file line number Diff line number Diff line change
Expand Up @@ -388,47 +388,6 @@ class ConsumingSinkNode : public ExecNode, public BackpressureControl {
std::vector<std::string> names_;
int32_t backpressure_counter_ = 0;
};

/**
 * @brief A SinkNodeConsumer that gathers the output of an execution plan
 * into an in-memory table.
 *
 * It backs the table-consuming extension of ConsumingSinkNode: every batch
 * delivered to Consume() is converted to a RecordBatch and stored, and Finish()
 * assembles the stored batches into a Table that is written through the `out`
 * pointer supplied at construction.
 */

struct TableSinkNodeConsumer : public SinkNodeConsumer {
 public:
  /// \param out location that receives the resulting table when Finish() runs
  /// \param pool memory pool handed to ExecBatch::ToRecordBatch
  TableSinkNodeConsumer(std::shared_ptr<Table>* out, MemoryPool* pool)
      : out_(out), pool_(pool) {}

  /// \brief Remember the schema of the batches that will be consumed
  Status Init(const std::shared_ptr<Schema>& schema,
              BackpressureControl* backpressure_control) override {
    // If the user is collecting into a table then backpressure is meaningless
    ARROW_UNUSED(backpressure_control);
    schema_ = schema;
    return Status::OK();
  }

  /// \brief Convert `batch` to a RecordBatch and append it to the collected set
  Status Consume(ExecBatch batch) override {
    // The conversion only reads schema_ and pool_, so it is done before taking the
    // lock; the mutex guards nothing but the shared batches_ vector.
    ARROW_ASSIGN_OR_RAISE(auto rb, batch.ToRecordBatch(schema_, pool_));
    std::lock_guard<std::mutex> guard(consume_mutex_);
    batches_.push_back(std::move(rb));
    return Status::OK();
  }

  /// \brief Build the output table from everything consumed so far
  Future<> Finish() override {
    // Pass the schema explicitly: with no batches collected there is nothing to
    // infer a schema from, and the result should simply be an empty table.
    // NOTE(review): batches_ is read without the mutex; assumes no Consume() call
    // is still in flight when Finish() runs -- confirm against the sink node.
    ARROW_ASSIGN_OR_RAISE(*out_, Table::FromRecordBatches(schema_, batches_));
    return Status::OK();
  }

 private:
  std::shared_ptr<Table>* out_;
  MemoryPool* pool_;
  std::shared_ptr<Schema> schema_;
  std::vector<std::shared_ptr<RecordBatch>> batches_;
  std::mutex consume_mutex_;
};

static Result<ExecNode*> MakeTableConsumingSinkNode(
compute::ExecPlan* plan, std::vector<compute::ExecNode*> inputs,
const compute::ExecNodeOptions& options) {
Expand Down
20 changes: 20 additions & 0 deletions cpp/src/arrow/compute/exec/util.cc
Original file line number Diff line number Diff line change
Expand Up @@ -383,5 +383,25 @@ size_t ThreadIndexer::Check(size_t thread_index) {
return thread_index;
}

/// \brief Record the schema of the batches this consumer is about to receive
///
/// \param schema schema later used to convert each ExecBatch to a RecordBatch
/// \param backpressure_control ignored (see below)
/// \return always Status::OK()
Status TableSinkNodeConsumer::Init(const std::shared_ptr<Schema>& schema,
                                   BackpressureControl* backpressure_control) {
  schema_ = schema;
  // Everything is being collected into a table, so there is no use for
  // backpressure here
  ARROW_UNUSED(backpressure_control);
  return Status::OK();
}

/// \brief Convert `batch` to a RecordBatch and append it to the collected batches
///
/// \param batch the batch to store
/// \return the conversion error if ToRecordBatch fails, otherwise Status::OK()
Status TableSinkNodeConsumer::Consume(ExecBatch batch) {
  // The conversion only reads schema_ and pool_, so it runs before the lock is
  // taken; concurrent callers then serialize only on the push_back below.
  ARROW_ASSIGN_OR_RAISE(auto rb, batch.ToRecordBatch(schema_, pool_));
  auto guard = consume_mutex_.Lock();
  batches_.push_back(std::move(rb));
  return Status::OK();
}

/// \brief Assemble the collected batches into a table and store it in `*out_`
///
/// \return a finished future carrying the error from Table::FromRecordBatches, or
/// OK once the table has been written
Future<> TableSinkNodeConsumer::Finish() {
  // Pass the schema explicitly: when no batch was ever consumed there is nothing
  // to infer a schema from, and the expected result is an empty table with the
  // schema received in Init() rather than an error.
  // NOTE(review): batches_ is read without consume_mutex_; assumes every
  // Consume() call has completed before Finish() is invoked -- confirm.
  ARROW_ASSIGN_OR_RAISE(*out_, Table::FromRecordBatches(schema_, batches_));
  return Status::OK();
}

} // namespace compute
} // namespace arrow
19 changes: 19 additions & 0 deletions cpp/src/arrow/compute/exec/util.h
Original file line number Diff line number Diff line change
Expand Up @@ -24,6 +24,7 @@
#include <vector>

#include "arrow/buffer.h"
#include "arrow/compute/exec/options.h"
#include "arrow/compute/type_fwd.h"
#include "arrow/memory_pool.h"
#include "arrow/result.h"
Expand Down Expand Up @@ -342,5 +343,23 @@ class TailSkipForSIMD {
}
};

/// \brief Sink consumer that accumulates every batch it is given and, when the
/// plan finishes, turns the accumulated batches into one in-memory Table
struct ARROW_EXPORT TableSinkNodeConsumer : public SinkNodeConsumer {
  /// \param out location that receives the resulting table in Finish()
  /// \param pool memory pool used when converting batches to record batches
  TableSinkNodeConsumer(std::shared_ptr<Table>* out, MemoryPool* pool)
      : out_{out}, pool_{pool} {}

  /// \brief Store the schema; the backpressure control is not used
  Status Init(const std::shared_ptr<Schema>& schema,
              BackpressureControl* backpressure_control) override;

  /// \brief Convert the batch to a RecordBatch and keep it for Finish()
  Status Consume(ExecBatch batch) override;

  /// \brief Build the table from the kept batches and write it to `out`
  Future<> Finish() override;

 private:
  // Destination for the finished table (not owned)
  std::shared_ptr<Table>* out_;
  // Pool passed to ExecBatch::ToRecordBatch (not owned)
  MemoryPool* pool_;
  // Schema captured in Init()
  std::shared_ptr<Schema> schema_;
  // Batches gathered by Consume(), in arrival order
  std::vector<std::shared_ptr<RecordBatch>> batches_;
  // Taken by Consume() while it appends to batches_
  util::Mutex consume_mutex_;
};

} // namespace compute
} // namespace arrow
4 changes: 3 additions & 1 deletion cpp/src/arrow/engine/CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -23,9 +23,10 @@ set(ARROW_SUBSTRAIT_SRCS
substrait/expression_internal.cc
substrait/extension_set.cc
substrait/extension_types.cc
substrait/serde.cc
substrait/plan_internal.cc
substrait/relation_internal.cc
substrait/serde.cc
substrait/test_plan_builder.cc
substrait/type_internal.cc
substrait/util.cc)

Expand Down Expand Up @@ -67,6 +68,7 @@ endif()
add_arrow_test(substrait_test
SOURCES
substrait/ext_test.cc
substrait/function_test.cc
substrait/serde_test.cc
EXTRA_LINK_LIBS
${ARROW_SUBSTRAIT_TEST_LINK_LIBS}
Expand Down
165 changes: 127 additions & 38 deletions cpp/src/arrow/engine/substrait/expression_internal.cc
Original file line number Diff line number Diff line change
Expand Up @@ -41,6 +41,84 @@ namespace internal {
using ::arrow::internal::make_unique;
} // namespace internal

/// \brief Decode a single Substrait function argument into slot `idx` of `call`
///
/// Enum arguments are stored with SetEnumArg (util::nullopt when the enum is not
/// specified); value arguments are converted to an Arrow expression and stored
/// with SetValueArg.
///
/// \return NotImplemented for type arguments and for any other argument class;
/// any error raised while converting a value argument; otherwise OK
Status DecodeArg(const substrait::FunctionArgument& arg, uint32_t idx,
                 SubstraitCall* call, const ExtensionSet& ext_set,
                 const ConversionOptions& conversion_options) {
  if (arg.has_enum_()) {
    const auto& enum_arg = arg.enum_();
    if (!enum_arg.has_specified()) {
      call->SetEnumArg(idx, util::nullopt);
      return Status::OK();
    }
    call->SetEnumArg(idx, enum_arg.specified());
    return Status::OK();
  }

  if (arg.has_value()) {
    ARROW_ASSIGN_OR_RAISE(compute::Expression value_expr,
                          FromProto(arg.value(), ext_set, conversion_options));
    call->SetValueArg(idx, std::move(value_expr));
    return Status::OK();
  }

  if (arg.has_type()) {
    return Status::NotImplemented("Type arguments not currently supported");
  }
  return Status::NotImplemented("Unrecognized function argument class");
}

/// \brief Build a SubstraitCall from a Substrait scalar function message
///
/// The output type and nullability are decoded first, then each argument is
/// decoded in order with DecodeArg.
///
/// \param id the already-resolved function id for `scalar_fn`
/// \return the populated call, or the first decoding error encountered
Result<SubstraitCall> DecodeScalarFunction(
    Id id, const substrait::Expression::ScalarFunction& scalar_fn,
    const ExtensionSet& ext_set, const ConversionOptions& conversion_options) {
  ARROW_ASSIGN_OR_RAISE(auto type_info,
                        FromProto(scalar_fn.output_type(), ext_set, conversion_options));
  SubstraitCall decoded(id, type_info.first, type_info.second);

  uint32_t arg_index = 0;
  for (const substrait::FunctionArgument& proto_arg : scalar_fn.arguments()) {
    ARROW_RETURN_NOT_OK(
        DecodeArg(proto_arg, arg_index, &decoded, ext_set, conversion_options));
    ++arg_index;
  }
  return std::move(decoded);
}

/// \brief Look up the name of a protobuf enum value
///
/// \param value numeric enum value
/// \param descriptor descriptor of the enum `value` belongs to
/// \return the value's name, or "unknown" if the descriptor has no such number
std::string EnumToString(int value, const google::protobuf::EnumDescriptor* descriptor) {
  if (const google::protobuf::EnumValueDescriptor* found =
          descriptor->FindValueByNumber(value)) {
    return found->name();
  }
  return "unknown";
}

/// \brief Convert a Substrait aggregate function message into a SubstraitCall
///
/// Only the INITIAL_TO_RESULT phase and the AGGREGATION_INVOCATION_ALL invocation
/// are accepted, and the function must not specify sorts; anything else yields
/// NotImplemented.
///
/// \param func the aggregate function message to convert
/// \param is_hash forwarded to the SubstraitCall constructor
/// \return the decoded call, or the first error encountered
Result<SubstraitCall> FromProto(const substrait::AggregateFunction& func, bool is_hash,
                                const ExtensionSet& ext_set,
                                const ConversionOptions& conversion_options) {
  const auto phase = func.phase();
  if (phase != substrait::AggregationPhase::AGGREGATION_PHASE_INITIAL_TO_RESULT) {
    return Status::NotImplemented(
        "Unsupported aggregation phase '",
        EnumToString(phase, substrait::AggregationPhase_descriptor()),
        "'. Only INITIAL_TO_RESULT is supported");
  }

  const auto invocation = func.invocation();
  if (invocation !=
      substrait::AggregateFunction::AggregationInvocation::
          AggregateFunction_AggregationInvocation_AGGREGATION_INVOCATION_ALL) {
    return Status::NotImplemented(
        "Unsupported aggregation invocation '",
        EnumToString(invocation,
                     substrait::AggregateFunction::AggregationInvocation_descriptor()),
        "'. Only AGGREGATION_INVOCATION_ALL is "
        "supported");
  }

  if (func.sorts_size() > 0) {
    return Status::NotImplemented("Aggregation sorts are not supported");
  }

  // Decode the output type before resolving the function reference so that
  // errors surface in the same order as before.
  ARROW_ASSIGN_OR_RAISE(auto type_info,
                        FromProto(func.output_type(), ext_set, conversion_options));
  ARROW_ASSIGN_OR_RAISE(Id function_id,
                        ext_set.DecodeFunction(func.function_reference()));
  SubstraitCall decoded(function_id, type_info.first, type_info.second, is_hash);

  uint32_t arg_index = 0;
  for (const substrait::FunctionArgument& proto_arg : func.arguments()) {
    ARROW_RETURN_NOT_OK(
        DecodeArg(proto_arg, arg_index, &decoded, ext_set, conversion_options));
    ++arg_index;
  }
  return std::move(decoded);
}

Result<compute::Expression> FromProto(const substrait::Expression& expr,
const ExtensionSet& ext_set,
const ConversionOptions& conversion_options) {
Expand Down Expand Up @@ -166,34 +244,14 @@ Result<compute::Expression> FromProto(const substrait::Expression& expr,
case substrait::Expression::kScalarFunction: {
const auto& scalar_fn = expr.scalar_function();

ARROW_ASSIGN_OR_RAISE(auto decoded_function,
ARROW_ASSIGN_OR_RAISE(Id function_id,
ext_set.DecodeFunction(scalar_fn.function_reference()));

std::vector<compute::Expression> arguments(scalar_fn.arguments_size());
for (int i = 0; i < scalar_fn.arguments_size(); ++i) {
const auto& argument = scalar_fn.arguments(i);
switch (argument.arg_type_case()) {
case substrait::FunctionArgument::kValue: {
ARROW_ASSIGN_OR_RAISE(
arguments[i], FromProto(argument.value(), ext_set, conversion_options));
break;
}
default:
return Status::NotImplemented(
"only value arguments are currently supported for functions");
}
}

auto func_name = decoded_function.name.to_string();
if (func_name != "cast") {
return compute::call(func_name, std::move(arguments));
} else {
ARROW_ASSIGN_OR_RAISE(
auto output_type_desc,
FromProto(scalar_fn.output_type(), ext_set, conversion_options));
auto cast_options = compute::CastOptions::Safe(std::move(output_type_desc.first));
return compute::call(func_name, std::move(arguments), std::move(cast_options));
}
ARROW_ASSIGN_OR_RAISE(ExtensionIdRegistry::SubstraitCallToArrow function_converter,
ext_set.registry()->GetSubstraitCallToArrow(function_id));
ARROW_ASSIGN_OR_RAISE(
SubstraitCall substrait_call,
DecodeScalarFunction(function_id, scalar_fn, ext_set, conversion_options));
return function_converter(substrait_call);
}

default:
Expand Down Expand Up @@ -827,6 +885,42 @@ static Result<std::unique_ptr<substrait::Expression>> MakeListElementReference(
return MakeDirectReference(std::move(expr), std::move(ref_segment));
}

/// \brief Serialize a SubstraitCall as a Substrait scalar function message
///
/// The function id is registered with `ext_set` to obtain its anchor, the output
/// type is encoded, and then every argument of the call is emitted in order as
/// either an enum argument or a value argument.
///
/// \return the encoded message; Invalid if an index below call.size() holds
/// neither an enum nor a value argument; or any error from the nested encoders
Result<std::unique_ptr<substrait::Expression::ScalarFunction>> EncodeSubstraitCall(
    const SubstraitCall& call, ExtensionSet* ext_set,
    const ConversionOptions& conversion_options) {
  ARROW_ASSIGN_OR_RAISE(uint32_t function_anchor, ext_set->EncodeFunction(call.id()));
  auto encoded = internal::make_unique<substrait::Expression::ScalarFunction>();
  encoded->set_function_reference(function_anchor);

  ARROW_ASSIGN_OR_RAISE(
      std::unique_ptr<substrait::Type> return_type,
      ToProto(*call.output_type(), call.output_nullable(), ext_set, conversion_options));
  // Ownership of the type message passes to the scalar function message
  encoded->set_allocated_output_type(return_type.release());

  for (uint32_t arg_index = 0; arg_index < call.size(); ++arg_index) {
    substrait::FunctionArgument* proto_arg = encoded->add_arguments();

    if (call.HasEnumArg(arg_index)) {
      auto proto_enum = internal::make_unique<substrait::FunctionArgument::Enum>();
      ARROW_ASSIGN_OR_RAISE(util::optional<util::string_view> enum_option,
                            call.GetEnumArg(arg_index));
      if (!enum_option) {
        // An enum argument without a value is encoded as "unspecified"
        proto_enum->set_allocated_unspecified(new google::protobuf::Empty());
      } else {
        proto_enum->set_specified(enum_option->to_string());
      }
      proto_arg->set_allocated_enum_(proto_enum.release());
      continue;
    }

    if (call.HasValueArg(arg_index)) {
      ARROW_ASSIGN_OR_RAISE(compute::Expression arrow_expr,
                            call.GetValueArg(arg_index));
      ARROW_ASSIGN_OR_RAISE(std::unique_ptr<substrait::Expression> proto_expr,
                            ToProto(arrow_expr, ext_set, conversion_options));
      proto_arg->set_allocated_value(proto_expr.release());
      continue;
    }

    return Status::Invalid("Call reported having ", call.size(),
                           " arguments but no argument could be found at index ",
                           arg_index);
  }
  return std::move(encoded);
}

Result<std::unique_ptr<substrait::Expression>> ToProto(
const compute::Expression& expr, ExtensionSet* ext_set,
const ConversionOptions& conversion_options) {
Expand Down Expand Up @@ -933,17 +1027,12 @@ Result<std::unique_ptr<substrait::Expression>> ToProto(
}

// other expression types dive into extensions immediately
ARROW_ASSIGN_OR_RAISE(auto anchor, ext_set->EncodeFunction(call->function_name));

auto scalar_fn = internal::make_unique<substrait::Expression::ScalarFunction>();
scalar_fn->set_function_reference(anchor);
scalar_fn->mutable_arguments()->Reserve(static_cast<int>(arguments.size()));
for (auto& arg : arguments) {
auto argument = internal::make_unique<substrait::FunctionArgument>();
argument->set_allocated_value(arg.release());
scalar_fn->mutable_arguments()->AddAllocated(argument.release());
}

ARROW_ASSIGN_OR_RAISE(
ExtensionIdRegistry::ArrowToSubstraitCall converter,
ext_set->registry()->GetArrowToSubstraitCall(call->function_name));
ARROW_ASSIGN_OR_RAISE(SubstraitCall substrait_call, converter(*call));
ARROW_ASSIGN_OR_RAISE(std::unique_ptr<substrait::Expression::ScalarFunction> scalar_fn,
EncodeSubstraitCall(substrait_call, ext_set, conversion_options));
out->set_allocated_scalar_function(scalar_fn.release());
return std::move(out);
}
Expand Down
4 changes: 4 additions & 0 deletions cpp/src/arrow/engine/substrait/expression_internal.h
Original file line number Diff line number Diff line change
Expand Up @@ -50,5 +50,9 @@ Result<std::unique_ptr<substrait::Expression::Literal>> ToProto(const Datum&,
ExtensionSet*,
const ConversionOptions&);

/// \brief Convert a Substrait aggregate function into a SubstraitCall
///
/// The `is_hash` flag is forwarded to the resulting SubstraitCall.  Conversion
/// fails with NotImplemented unless the function uses the INITIAL_TO_RESULT phase
/// and the AGGREGATION_INVOCATION_ALL invocation and specifies no sorts.
ARROW_ENGINE_EXPORT
Result<SubstraitCall> FromProto(const substrait::AggregateFunction&, bool is_hash,
const ExtensionSet&, const ConversionOptions&);

} // namespace engine
} // namespace arrow
Loading

0 comments on commit cdb5b20

Please sign in to comment.