diff --git a/cpp/src/arrow/compute/exec/options.cc b/cpp/src/arrow/compute/exec/options.cc
index c09ab1c1b68c2..ef1a0c7e2eb2a 100644
--- a/cpp/src/arrow/compute/exec/options.cc
+++ b/cpp/src/arrow/compute/exec/options.cc
@@ -25,6 +25,8 @@
namespace arrow {
namespace compute {
+constexpr int64_t TableSourceNodeOptions::kDefaultMaxBatchSize;
+
std::string ToString(JoinType t) {
switch (t) {
case JoinType::LEFT_SEMI:
diff --git a/cpp/src/arrow/compute/exec/options.h b/cpp/src/arrow/compute/exec/options.h
index 4a0cd602efb54..a8e8c1ee23096 100644
--- a/cpp/src/arrow/compute/exec/options.h
+++ b/cpp/src/arrow/compute/exec/options.h
@@ -64,7 +64,9 @@ class ARROW_EXPORT SourceNodeOptions : public ExecNodeOptions {
/// \brief An extended Source node which accepts a table
class ARROW_EXPORT TableSourceNodeOptions : public ExecNodeOptions {
public:
- TableSourceNodeOptions(std::shared_ptr
table, int64_t max_batch_size)
+ static constexpr int64_t kDefaultMaxBatchSize = 1 << 20;
+ TableSourceNodeOptions(std::shared_ptr table,
+ int64_t max_batch_size = kDefaultMaxBatchSize)
: table(table), max_batch_size(max_batch_size) {}
// arrow table which acts as the data source
diff --git a/cpp/src/arrow/compute/exec/sink_node.cc b/cpp/src/arrow/compute/exec/sink_node.cc
index a1426265cf94f..8af4e8e996cce 100644
--- a/cpp/src/arrow/compute/exec/sink_node.cc
+++ b/cpp/src/arrow/compute/exec/sink_node.cc
@@ -388,47 +388,6 @@ class ConsumingSinkNode : public ExecNode, public BackpressureControl {
std::vector names_;
int32_t backpressure_counter_ = 0;
};
-
-/**
- * @brief This node is an extension on ConsumingSinkNode
- * to facilitate to get the output from an execution plan
- * as a table. We define a custom SinkNodeConsumer to
- * enable this functionality.
- */
-
-struct TableSinkNodeConsumer : public SinkNodeConsumer {
- public:
- TableSinkNodeConsumer(std::shared_ptr* out, MemoryPool* pool)
- : out_(out), pool_(pool) {}
-
- Status Init(const std::shared_ptr& schema,
- BackpressureControl* backpressure_control) override {
- // If the user is collecting into a table then backpressure is meaningless
- ARROW_UNUSED(backpressure_control);
- schema_ = schema;
- return Status::OK();
- }
-
- Status Consume(ExecBatch batch) override {
- std::lock_guard guard(consume_mutex_);
- ARROW_ASSIGN_OR_RAISE(auto rb, batch.ToRecordBatch(schema_, pool_));
- batches_.push_back(rb);
- return Status::OK();
- }
-
- Future<> Finish() override {
- ARROW_ASSIGN_OR_RAISE(*out_, Table::FromRecordBatches(batches_));
- return Status::OK();
- }
-
- private:
- std::shared_ptr* out_;
- MemoryPool* pool_;
- std::shared_ptr schema_;
- std::vector> batches_;
- std::mutex consume_mutex_;
-};
-
static Result MakeTableConsumingSinkNode(
compute::ExecPlan* plan, std::vector inputs,
const compute::ExecNodeOptions& options) {
diff --git a/cpp/src/arrow/compute/exec/util.cc b/cpp/src/arrow/compute/exec/util.cc
index ae70cfcd46f50..a34a9c6271322 100644
--- a/cpp/src/arrow/compute/exec/util.cc
+++ b/cpp/src/arrow/compute/exec/util.cc
@@ -383,5 +383,25 @@ size_t ThreadIndexer::Check(size_t thread_index) {
return thread_index;
}
+Status TableSinkNodeConsumer::Init(const std::shared_ptr& schema,
+ BackpressureControl* backpressure_control) {
+ // If the user is collecting into a table then backpressure is meaningless
+ ARROW_UNUSED(backpressure_control);
+ schema_ = schema;
+ return Status::OK();
+}
+
+Status TableSinkNodeConsumer::Consume(ExecBatch batch) {
+ auto guard = consume_mutex_.Lock();
+ ARROW_ASSIGN_OR_RAISE(auto rb, batch.ToRecordBatch(schema_, pool_));
+ batches_.push_back(std::move(rb));
+ return Status::OK();
+}
+
+Future<> TableSinkNodeConsumer::Finish() {
+ ARROW_ASSIGN_OR_RAISE(*out_, Table::FromRecordBatches(batches_));
+ return Status::OK();
+}
+
} // namespace compute
} // namespace arrow
diff --git a/cpp/src/arrow/compute/exec/util.h b/cpp/src/arrow/compute/exec/util.h
index 30526cb835ab1..7e716808fa008 100644
--- a/cpp/src/arrow/compute/exec/util.h
+++ b/cpp/src/arrow/compute/exec/util.h
@@ -24,6 +24,7 @@
#include
#include "arrow/buffer.h"
+#include "arrow/compute/exec/options.h"
#include "arrow/compute/type_fwd.h"
#include "arrow/memory_pool.h"
#include "arrow/result.h"
@@ -342,5 +343,23 @@ class TailSkipForSIMD {
}
};
+/// \brief A consumer that collects results into an in-memory table
+struct ARROW_EXPORT TableSinkNodeConsumer : public SinkNodeConsumer {
+ public:
+ TableSinkNodeConsumer(std::shared_ptr* out, MemoryPool* pool)
+ : out_(out), pool_(pool) {}
+ Status Init(const std::shared_ptr& schema,
+ BackpressureControl* backpressure_control) override;
+ Status Consume(ExecBatch batch) override;
+ Future<> Finish() override;
+
+ private:
+ std::shared_ptr* out_;
+ MemoryPool* pool_;
+ std::shared_ptr schema_;
+ std::vector> batches_;
+ util::Mutex consume_mutex_;
+};
+
} // namespace compute
} // namespace arrow
diff --git a/cpp/src/arrow/engine/CMakeLists.txt b/cpp/src/arrow/engine/CMakeLists.txt
index 8edd22900e6cb..4109b7b3bcdf3 100644
--- a/cpp/src/arrow/engine/CMakeLists.txt
+++ b/cpp/src/arrow/engine/CMakeLists.txt
@@ -23,9 +23,10 @@ set(ARROW_SUBSTRAIT_SRCS
substrait/expression_internal.cc
substrait/extension_set.cc
substrait/extension_types.cc
- substrait/serde.cc
substrait/plan_internal.cc
substrait/relation_internal.cc
+ substrait/serde.cc
+ substrait/test_plan_builder.cc
substrait/type_internal.cc
substrait/util.cc)
@@ -67,6 +68,7 @@ endif()
add_arrow_test(substrait_test
SOURCES
substrait/ext_test.cc
+ substrait/function_test.cc
substrait/serde_test.cc
EXTRA_LINK_LIBS
${ARROW_SUBSTRAIT_TEST_LINK_LIBS}
diff --git a/cpp/src/arrow/engine/substrait/expression_internal.cc b/cpp/src/arrow/engine/substrait/expression_internal.cc
index 07c222bc4cfd1..589d7e6ac695f 100644
--- a/cpp/src/arrow/engine/substrait/expression_internal.cc
+++ b/cpp/src/arrow/engine/substrait/expression_internal.cc
@@ -41,6 +41,84 @@ namespace internal {
using ::arrow::internal::make_unique;
} // namespace internal
+Status DecodeArg(const substrait::FunctionArgument& arg, uint32_t idx,
+ SubstraitCall* call, const ExtensionSet& ext_set,
+ const ConversionOptions& conversion_options) {
+ if (arg.has_enum_()) {
+ const substrait::FunctionArgument::Enum& enum_val = arg.enum_();
+ if (enum_val.has_specified()) {
+ call->SetEnumArg(idx, enum_val.specified());
+ } else {
+ call->SetEnumArg(idx, util::nullopt);
+ }
+ } else if (arg.has_value()) {
+ ARROW_ASSIGN_OR_RAISE(compute::Expression expr,
+ FromProto(arg.value(), ext_set, conversion_options));
+ call->SetValueArg(idx, std::move(expr));
+ } else if (arg.has_type()) {
+ return Status::NotImplemented("Type arguments not currently supported");
+ } else {
+ return Status::NotImplemented("Unrecognized function argument class");
+ }
+ return Status::OK();
+}
+
+Result DecodeScalarFunction(
+ Id id, const substrait::Expression::ScalarFunction& scalar_fn,
+ const ExtensionSet& ext_set, const ConversionOptions& conversion_options) {
+ ARROW_ASSIGN_OR_RAISE(auto output_type_and_nullable,
+ FromProto(scalar_fn.output_type(), ext_set, conversion_options));
+ SubstraitCall call(id, output_type_and_nullable.first, output_type_and_nullable.second);
+ for (int i = 0; i < scalar_fn.arguments_size(); i++) {
+ ARROW_RETURN_NOT_OK(DecodeArg(scalar_fn.arguments(i), static_cast(i), &call,
+ ext_set, conversion_options));
+ }
+ return std::move(call);
+}
+
+std::string EnumToString(int value, const google::protobuf::EnumDescriptor* descriptor) {
+ const google::protobuf::EnumValueDescriptor* value_desc =
+ descriptor->FindValueByNumber(value);
+ if (value_desc == nullptr) {
+ return "unknown";
+ }
+ return value_desc->name();
+}
+
+Result FromProto(const substrait::AggregateFunction& func, bool is_hash,
+ const ExtensionSet& ext_set,
+ const ConversionOptions& conversion_options) {
+ if (func.phase() != substrait::AggregationPhase::AGGREGATION_PHASE_INITIAL_TO_RESULT) {
+ return Status::NotImplemented(
+ "Unsupported aggregation phase '",
+ EnumToString(func.phase(), substrait::AggregationPhase_descriptor()),
+ "'. Only INITIAL_TO_RESULT is supported");
+ }
+ if (func.invocation() !=
+ substrait::AggregateFunction::AggregationInvocation::
+ AggregateFunction_AggregationInvocation_AGGREGATION_INVOCATION_ALL) {
+ return Status::NotImplemented(
+ "Unsupported aggregation invocation '",
+ EnumToString(func.invocation(),
+ substrait::AggregateFunction::AggregationInvocation_descriptor()),
+ "'. Only AGGREGATION_INVOCATION_ALL is "
+ "supported");
+ }
+ if (func.sorts_size() > 0) {
+ return Status::NotImplemented("Aggregation sorts are not supported");
+ }
+ ARROW_ASSIGN_OR_RAISE(auto output_type_and_nullable,
+ FromProto(func.output_type(), ext_set, conversion_options));
+ ARROW_ASSIGN_OR_RAISE(Id id, ext_set.DecodeFunction(func.function_reference()));
+ SubstraitCall call(id, output_type_and_nullable.first, output_type_and_nullable.second,
+ is_hash);
+ for (int i = 0; i < func.arguments_size(); i++) {
+ ARROW_RETURN_NOT_OK(DecodeArg(func.arguments(i), static_cast(i), &call,
+ ext_set, conversion_options));
+ }
+ return std::move(call);
+}
+
Result FromProto(const substrait::Expression& expr,
const ExtensionSet& ext_set,
const ConversionOptions& conversion_options) {
@@ -166,34 +244,14 @@ Result FromProto(const substrait::Expression& expr,
case substrait::Expression::kScalarFunction: {
const auto& scalar_fn = expr.scalar_function();
- ARROW_ASSIGN_OR_RAISE(auto decoded_function,
+ ARROW_ASSIGN_OR_RAISE(Id function_id,
ext_set.DecodeFunction(scalar_fn.function_reference()));
-
- std::vector arguments(scalar_fn.arguments_size());
- for (int i = 0; i < scalar_fn.arguments_size(); ++i) {
- const auto& argument = scalar_fn.arguments(i);
- switch (argument.arg_type_case()) {
- case substrait::FunctionArgument::kValue: {
- ARROW_ASSIGN_OR_RAISE(
- arguments[i], FromProto(argument.value(), ext_set, conversion_options));
- break;
- }
- default:
- return Status::NotImplemented(
- "only value arguments are currently supported for functions");
- }
- }
-
- auto func_name = decoded_function.name.to_string();
- if (func_name != "cast") {
- return compute::call(func_name, std::move(arguments));
- } else {
- ARROW_ASSIGN_OR_RAISE(
- auto output_type_desc,
- FromProto(scalar_fn.output_type(), ext_set, conversion_options));
- auto cast_options = compute::CastOptions::Safe(std::move(output_type_desc.first));
- return compute::call(func_name, std::move(arguments), std::move(cast_options));
- }
+ ARROW_ASSIGN_OR_RAISE(ExtensionIdRegistry::SubstraitCallToArrow function_converter,
+ ext_set.registry()->GetSubstraitCallToArrow(function_id));
+ ARROW_ASSIGN_OR_RAISE(
+ SubstraitCall substrait_call,
+ DecodeScalarFunction(function_id, scalar_fn, ext_set, conversion_options));
+ return function_converter(substrait_call);
}
default:
@@ -827,6 +885,42 @@ static Result> MakeListElementReference(
return MakeDirectReference(std::move(expr), std::move(ref_segment));
}
+Result> EncodeSubstraitCall(
+ const SubstraitCall& call, ExtensionSet* ext_set,
+ const ConversionOptions& conversion_options) {
+ ARROW_ASSIGN_OR_RAISE(uint32_t anchor, ext_set->EncodeFunction(call.id()));
+ auto scalar_fn = internal::make_unique();
+ scalar_fn->set_function_reference(anchor);
+ ARROW_ASSIGN_OR_RAISE(
+ std::unique_ptr output_type,
+ ToProto(*call.output_type(), call.output_nullable(), ext_set, conversion_options));
+ scalar_fn->set_allocated_output_type(output_type.release());
+
+ for (uint32_t i = 0; i < call.size(); i++) {
+ substrait::FunctionArgument* arg = scalar_fn->add_arguments();
+ if (call.HasEnumArg(i)) {
+ auto enum_val = internal::make_unique();
+ ARROW_ASSIGN_OR_RAISE(util::optional enum_arg,
+ call.GetEnumArg(i));
+ if (enum_arg) {
+ enum_val->set_specified(enum_arg->to_string());
+ } else {
+ enum_val->set_allocated_unspecified(new google::protobuf::Empty());
+ }
+ arg->set_allocated_enum_(enum_val.release());
+ } else if (call.HasValueArg(i)) {
+ ARROW_ASSIGN_OR_RAISE(compute::Expression value_arg, call.GetValueArg(i));
+ ARROW_ASSIGN_OR_RAISE(std::unique_ptr value_expr,
+ ToProto(value_arg, ext_set, conversion_options));
+ arg->set_allocated_value(value_expr.release());
+ } else {
+ return Status::Invalid("Call reported having ", call.size(),
+ " arguments but no argument could be found at index ", i);
+ }
+ }
+ return std::move(scalar_fn);
+}
+
Result> ToProto(
const compute::Expression& expr, ExtensionSet* ext_set,
const ConversionOptions& conversion_options) {
@@ -933,17 +1027,12 @@ Result> ToProto(
}
// other expression types dive into extensions immediately
- ARROW_ASSIGN_OR_RAISE(auto anchor, ext_set->EncodeFunction(call->function_name));
-
- auto scalar_fn = internal::make_unique();
- scalar_fn->set_function_reference(anchor);
- scalar_fn->mutable_arguments()->Reserve(static_cast(arguments.size()));
- for (auto& arg : arguments) {
- auto argument = internal::make_unique();
- argument->set_allocated_value(arg.release());
- scalar_fn->mutable_arguments()->AddAllocated(argument.release());
- }
-
+ ARROW_ASSIGN_OR_RAISE(
+ ExtensionIdRegistry::ArrowToSubstraitCall converter,
+ ext_set->registry()->GetArrowToSubstraitCall(call->function_name));
+ ARROW_ASSIGN_OR_RAISE(SubstraitCall substrait_call, converter(*call));
+ ARROW_ASSIGN_OR_RAISE(std::unique_ptr scalar_fn,
+ EncodeSubstraitCall(substrait_call, ext_set, conversion_options));
out->set_allocated_scalar_function(scalar_fn.release());
return std::move(out);
}
diff --git a/cpp/src/arrow/engine/substrait/expression_internal.h b/cpp/src/arrow/engine/substrait/expression_internal.h
index 2b4dec2a00b21..f132afc0c1ac9 100644
--- a/cpp/src/arrow/engine/substrait/expression_internal.h
+++ b/cpp/src/arrow/engine/substrait/expression_internal.h
@@ -50,5 +50,9 @@ Result> ToProto(const Datum&,
ExtensionSet*,
const ConversionOptions&);
+ARROW_ENGINE_EXPORT
+Result FromProto(const substrait::AggregateFunction&, bool is_hash,
+ const ExtensionSet&, const ConversionOptions&);
+
} // namespace engine
} // namespace arrow
diff --git a/cpp/src/arrow/engine/substrait/ext_test.cc b/cpp/src/arrow/engine/substrait/ext_test.cc
index 8e41cb7c98cee..4b37aa8fcdba3 100644
--- a/cpp/src/arrow/engine/substrait/ext_test.cc
+++ b/cpp/src/arrow/engine/substrait/ext_test.cc
@@ -56,12 +56,10 @@ struct DefaultExtensionIdRegistryProvider : public ExtensionIdRegistryProvider {
struct NestedExtensionIdRegistryProvider : public ExtensionIdRegistryProvider {
virtual ~NestedExtensionIdRegistryProvider() {}
- std::shared_ptr registry_ = substrait::MakeExtensionIdRegistry();
+ std::shared_ptr registry_ = MakeExtensionIdRegistry();
ExtensionIdRegistry* get() const override { return &*registry_; }
};
-using Id = ExtensionIdRegistry::Id;
-
bool operator==(const Id& id1, const Id& id2) {
return id1.uri == id2.uri && id1.name == id2.name;
}
@@ -85,8 +83,8 @@ static const std::vector kTypeNames = {
TypeName{month_day_nano_interval(), "interval_month_day_nano"},
};
-static const std::vector kFunctionNames = {
- "add",
+static const std::vector kFunctionIds = {
+ {kSubstraitArithmeticFunctionsUri, "add"},
};
static const std::vector kTempFunctionNames = {
@@ -141,15 +139,12 @@ TEST_P(ExtensionIdRegistryTest, GetFunctions) {
auto provider = std::get<0>(GetParam());
auto registry = provider->get();
- for (util::string_view name : kFunctionNames) {
- auto id = Id{kArrowExtTypesUri, name};
- for (auto funcrec_opt : {registry->GetFunction(id), registry->GetFunction(name)}) {
- ASSERT_TRUE(funcrec_opt);
- auto funcrec = funcrec_opt.value();
- ASSERT_EQ(id, funcrec.id);
- ASSERT_EQ(name, funcrec.function_name);
- }
+ for (Id func_id : kFunctionIds) {
+ ASSERT_OK_AND_ASSIGN(ExtensionIdRegistry::SubstraitCallToArrow converter,
+ registry->GetSubstraitCallToArrow(func_id));
+ ASSERT_TRUE(converter);
}
+ ASSERT_RAISES(NotImplemented, registry->GetSubstraitCallToArrow(kNonExistentId));
ASSERT_FALSE(registry->GetType(kNonExistentId));
ASSERT_FALSE(registry->GetType(*kNonExistentTypeName.type));
}
@@ -158,10 +153,10 @@ TEST_P(ExtensionIdRegistryTest, ReregisterFunctions) {
auto provider = std::get<0>(GetParam());
auto registry = provider->get();
- for (util::string_view name : kFunctionNames) {
- auto id = Id{kArrowExtTypesUri, name};
- ASSERT_RAISES(Invalid, registry->CanRegisterFunction(id, name.to_string()));
- ASSERT_RAISES(Invalid, registry->RegisterFunction(id, name.to_string()));
+ for (Id function_id : kFunctionIds) {
+ ASSERT_RAISES(Invalid, registry->CanAddSubstraitCallToArrow(function_id));
+ ASSERT_RAISES(Invalid, registry->AddSubstraitCallToArrow(
+ function_id, function_id.name.to_string()));
}
}
@@ -173,11 +168,26 @@ INSTANTIATE_TEST_SUITE_P(
std::make_tuple(std::make_shared(),
"nested")));
+TEST(ExtensionIdRegistryTest, GetSupportedSubstraitFunctions) {
+ ExtensionIdRegistry* default_registry = default_extension_id_registry();
+ std::vector supported_functions =
+ default_registry->GetSupportedSubstraitFunctions();
+ std::size_t num_functions = supported_functions.size();
+ ASSERT_GT(num_functions, 0);
+
+ std::shared_ptr nested =
+ nested_extension_id_registry(default_registry);
+ ASSERT_OK(nested->AddSubstraitCallToArrow(kNonExistentId, "some_function"));
+
+ std::size_t num_nested_functions = nested->GetSupportedSubstraitFunctions().size();
+ ASSERT_EQ(num_functions + 1, num_nested_functions);
+}
+
TEST(ExtensionIdRegistryTest, RegisterTempTypes) {
auto default_registry = default_extension_id_registry();
constexpr int rounds = 3;
for (int i = 0; i < rounds; i++) {
- auto registry = substrait::MakeExtensionIdRegistry();
+ auto registry = MakeExtensionIdRegistry();
for (TypeName e : kTempTypeNames) {
auto id = Id{kArrowExtTypesUri, e.name};
@@ -194,15 +204,15 @@ TEST(ExtensionIdRegistryTest, RegisterTempFunctions) {
auto default_registry = default_extension_id_registry();
constexpr int rounds = 3;
for (int i = 0; i < rounds; i++) {
- auto registry = substrait::MakeExtensionIdRegistry();
+ auto registry = MakeExtensionIdRegistry();
for (util::string_view name : kTempFunctionNames) {
auto id = Id{kArrowExtTypesUri, name};
- ASSERT_OK(registry->CanRegisterFunction(id, name.to_string()));
- ASSERT_OK(registry->RegisterFunction(id, name.to_string()));
- ASSERT_RAISES(Invalid, registry->CanRegisterFunction(id, name.to_string()));
- ASSERT_RAISES(Invalid, registry->RegisterFunction(id, name.to_string()));
- ASSERT_OK(default_registry->CanRegisterFunction(id, name.to_string()));
+ ASSERT_OK(registry->CanAddSubstraitCallToArrow(id));
+ ASSERT_OK(registry->AddSubstraitCallToArrow(id, name.to_string()));
+ ASSERT_RAISES(Invalid, registry->CanAddSubstraitCallToArrow(id));
+ ASSERT_RAISES(Invalid, registry->AddSubstraitCallToArrow(id, name.to_string()));
+ ASSERT_OK(default_registry->CanAddSubstraitCallToArrow(id));
}
}
}
@@ -246,24 +256,24 @@ TEST(ExtensionIdRegistryTest, RegisterNestedFunctions) {
auto default_registry = default_extension_id_registry();
constexpr int rounds = 3;
for (int i = 0; i < rounds; i++) {
- auto registry1 = substrait::MakeExtensionIdRegistry();
+ auto registry1 = MakeExtensionIdRegistry();
- ASSERT_OK(registry1->CanRegisterFunction(id1, name1.to_string()));
- ASSERT_OK(registry1->RegisterFunction(id1, name1.to_string()));
+ ASSERT_OK(registry1->CanAddSubstraitCallToArrow(id1));
+ ASSERT_OK(registry1->AddSubstraitCallToArrow(id1, name1.to_string()));
for (int j = 0; j < rounds; j++) {
- auto registry2 = substrait::MakeExtensionIdRegistry();
+ auto registry2 = MakeExtensionIdRegistry();
- ASSERT_OK(registry2->CanRegisterFunction(id2, name2.to_string()));
- ASSERT_OK(registry2->RegisterFunction(id2, name2.to_string()));
- ASSERT_RAISES(Invalid, registry2->CanRegisterFunction(id2, name2.to_string()));
- ASSERT_RAISES(Invalid, registry2->RegisterFunction(id2, name2.to_string()));
- ASSERT_OK(default_registry->CanRegisterFunction(id2, name2.to_string()));
+ ASSERT_OK(registry2->CanAddSubstraitCallToArrow(id2));
+ ASSERT_OK(registry2->AddSubstraitCallToArrow(id2, name2.to_string()));
+ ASSERT_RAISES(Invalid, registry2->CanAddSubstraitCallToArrow(id2));
+ ASSERT_RAISES(Invalid, registry2->AddSubstraitCallToArrow(id2, name2.to_string()));
+ ASSERT_OK(default_registry->CanAddSubstraitCallToArrow(id2));
}
- ASSERT_RAISES(Invalid, registry1->CanRegisterFunction(id1, name1.to_string()));
- ASSERT_RAISES(Invalid, registry1->RegisterFunction(id1, name1.to_string()));
- ASSERT_OK(default_registry->CanRegisterFunction(id1, name1.to_string()));
+ ASSERT_RAISES(Invalid, registry1->CanAddSubstraitCallToArrow(id1));
+ ASSERT_RAISES(Invalid, registry1->AddSubstraitCallToArrow(id1, name1.to_string()));
+ ASSERT_OK(default_registry->CanAddSubstraitCallToArrow(id1));
}
}
diff --git a/cpp/src/arrow/engine/substrait/extension_set.cc b/cpp/src/arrow/engine/substrait/extension_set.cc
index 08eb6acc9ca89..493d576e839bb 100644
--- a/cpp/src/arrow/engine/substrait/extension_set.cc
+++ b/cpp/src/arrow/engine/substrait/extension_set.cc
@@ -17,9 +17,9 @@
#include "arrow/engine/substrait/extension_set.h"
-#include
-#include
+#include
+#include "arrow/engine/substrait/expression_internal.h"
#include "arrow/util/hash_util.h"
#include "arrow/util/hashing.h"
#include "arrow/util/string_view.h"
@@ -28,6 +28,9 @@ namespace arrow {
namespace engine {
namespace {
+// TODO(ARROW-16988): replace this with EXACT_ROUNDTRIP mode
+constexpr bool kExactRoundTrip = true;
+
struct TypePtrHashEq {
template
size_t operator()(const Ptr& type) const {
@@ -42,16 +45,115 @@ struct TypePtrHashEq {
} // namespace
-size_t ExtensionIdRegistry::IdHashEq::operator()(ExtensionIdRegistry::Id id) const {
+std::string Id::ToString() const {
+ std::stringstream sstream;
+ sstream << uri;
+ sstream << '#';
+ sstream << name;
+ return sstream.str();
+}
+
+size_t IdHashEq::operator()(Id id) const {
constexpr ::arrow::internal::StringViewHash hash = {};
auto out = static_cast(hash(id.uri));
::arrow::internal::hash_combine(out, hash(id.name));
return out;
}
-bool ExtensionIdRegistry::IdHashEq::operator()(ExtensionIdRegistry::Id l,
- ExtensionIdRegistry::Id r) const {
- return l.uri == r.uri && l.name == r.name;
+bool IdHashEq::operator()(Id l, Id r) const { return l.uri == r.uri && l.name == r.name; }
+
+Id IdStorage::Emplace(Id id) {
+ util::string_view owned_uri = EmplaceUri(id.uri);
+
+ util::string_view owned_name;
+ auto name_itr = names_.find(id.name);
+ if (name_itr == names_.end()) {
+ owned_names_.emplace_back(id.name);
+ owned_name = owned_names_.back();
+ names_.insert(owned_name);
+ } else {
+ owned_name = *name_itr;
+ }
+
+ return {owned_uri, owned_name};
+}
+
+util::optional IdStorage::Find(Id id) const {
+ util::optional maybe_owned_uri = FindUri(id.uri);
+ if (!maybe_owned_uri) {
+ return util::nullopt;
+ }
+
+ auto name_itr = names_.find(id.name);
+ if (name_itr == names_.end()) {
+ return util::nullopt;
+ } else {
+ return Id{*maybe_owned_uri, *name_itr};
+ }
+}
+
+util::optional IdStorage::FindUri(util::string_view uri) const {
+ auto uri_itr = uris_.find(uri);
+ if (uri_itr == uris_.end()) {
+ return util::nullopt;
+ }
+ return *uri_itr;
+}
+
+util::string_view IdStorage::EmplaceUri(util::string_view uri) {
+ auto uri_itr = uris_.find(uri);
+ if (uri_itr == uris_.end()) {
+ owned_uris_.emplace_back(uri);
+ util::string_view owned_uri = owned_uris_.back();
+ uris_.insert(owned_uri);
+ return owned_uri;
+ }
+ return *uri_itr;
+}
+
+Result> SubstraitCall::GetEnumArg(
+ uint32_t index) const {
+ if (index >= size_) {
+ return Status::Invalid("Expected Substrait call to have an enum argument at index ",
+ index, " but it did not have enough arguments");
+ }
+ auto enum_arg_it = enum_args_.find(index);
+ if (enum_arg_it == enum_args_.end()) {
+ return Status::Invalid("Expected Substrait call to have an enum argument at index ",
+ index, " but the argument was not an enum.");
+ }
+ return enum_arg_it->second;
+}
+
+bool SubstraitCall::HasEnumArg(uint32_t index) const {
+ return enum_args_.find(index) != enum_args_.end();
+}
+
+void SubstraitCall::SetEnumArg(uint32_t index, util::optional enum_arg) {
+ size_ = std::max(size_, index + 1);
+ enum_args_[index] = std::move(enum_arg);
+}
+
+Result SubstraitCall::GetValueArg(uint32_t index) const {
+ if (index >= size_) {
+ return Status::Invalid("Expected Substrait call to have a value argument at index ",
+ index, " but it did not have enough arguments");
+ }
+ auto value_arg_it = value_args_.find(index);
+ if (value_arg_it == value_args_.end()) {
+ return Status::Invalid("Expected Substrait call to have a value argument at index ",
+ index, " but the argument was not a value");
+ }
+ return value_arg_it->second;
+}
+
+bool SubstraitCall::HasValueArg(uint32_t index) const {
+ return value_args_.find(index) != value_args_.end();
+}
+
+void SubstraitCall::SetValueArg(uint32_t index, compute::Expression value_arg) {
+ size_ = std::max(size_, index + 1);
+ value_args_[index] = std::move(value_arg);
}
// A builder used when creating a Substrait plan from an Arrow execution plan. In
@@ -97,54 +199,54 @@ Result ExtensionSet::Make(
std::unordered_map uris,
std::unordered_map type_ids,
std::unordered_map function_ids, const ExtensionIdRegistry* registry) {
- ExtensionSet set;
+ ExtensionSet set(default_extension_id_registry());
set.registry_ = registry;
- // TODO(bkietz) move this into the registry as registry->OwnUris(&uris) or so
- std::unordered_set
- uris_owned_by_registry;
- for (util::string_view uri : registry->Uris()) {
- uris_owned_by_registry.insert(uri);
- }
-
for (auto& uri : uris) {
- auto it = uris_owned_by_registry.find(uri.second);
- if (it == uris_owned_by_registry.end()) {
- return Status::KeyError("Uri '", uri.second, "' not found in registry");
+ util::optional maybe_uri_internal = registry->FindUri(uri.second);
+ if (maybe_uri_internal) {
+ set.uris_[uri.first] = *maybe_uri_internal;
+ } else {
+ if (kExactRoundTrip) {
+ return Status::Invalid(
+ "Plan contained a URI that the extension registry is unaware of: ",
+ uri.second);
+ }
+ set.uris_[uri.first] = set.plan_specific_ids_.EmplaceUri(uri.second);
}
- uri.second = *it; // Ensure uris point into the registry's memory
- set.AddUri(uri);
}
set.types_.reserve(type_ids.size());
+ for (const auto& type_id : type_ids) {
+ if (type_id.second.empty()) continue;
+ RETURN_NOT_OK(set.CheckHasUri(type_id.second.uri));
- for (unsigned int i = 0; i < static_cast(type_ids.size()); ++i) {
- if (type_ids[i].empty()) continue;
- RETURN_NOT_OK(set.CheckHasUri(type_ids[i].uri));
-
- if (auto rec = registry->GetType(type_ids[i])) {
- set.types_[i] = {rec->id, rec->type};
+ if (auto rec = registry->GetType(type_id.second)) {
+ set.types_[type_id.first] = {rec->id, rec->type};
continue;
}
- return Status::Invalid("Type ", type_ids[i].uri, "#", type_ids[i].name, " not found");
+ return Status::Invalid("Type ", type_id.second.uri, "#", type_id.second.name,
+ " not found");
}
set.functions_.reserve(function_ids.size());
-
- for (unsigned int i = 0; i < static_cast(function_ids.size()); ++i) {
- if (function_ids[i].empty()) continue;
- RETURN_NOT_OK(set.CheckHasUri(function_ids[i].uri));
-
- if (auto rec = registry->GetFunction(function_ids[i])) {
- set.functions_[i] = {rec->id, rec->function_name};
- continue;
+ for (const auto& function_id : function_ids) {
+ if (function_id.second.empty()) continue;
+ RETURN_NOT_OK(set.CheckHasUri(function_id.second.uri));
+ util::optional maybe_id_internal = registry->FindId(function_id.second);
+ if (maybe_id_internal) {
+ set.functions_[function_id.first] = *maybe_id_internal;
+ } else {
+ if (kExactRoundTrip) {
+ return Status::Invalid(
+ "Plan contained a function id that the extension registry is unaware of: ",
+ function_id.second.uri, "#", function_id.second.name);
+ }
+ set.functions_[function_id.first] =
+ set.plan_specific_ids_.Emplace(function_id.second);
}
- return Status::Invalid("Function ", function_ids[i].uri, "#", function_ids[i].name,
- " not found");
}
- set.uris_ = std::move(uris);
-
return std::move(set);
}
@@ -162,39 +264,34 @@ Result ExtensionSet::EncodeType(const DataType& type) {
auto it_success =
types_map_.emplace(rec->id, static_cast(types_map_.size()));
if (it_success.second) {
- DCHECK_EQ(types_.find(static_cast(types_.size())), types_.end())
+ DCHECK_EQ(types_.find(static_cast(types_.size())), types_.end())
<< "Type existed in types_ but not types_map_. ExtensionSet is inconsistent";
- types_[static_cast(types_.size())] = {rec->id, rec->type};
+ types_[static_cast(types_.size())] = {rec->id, rec->type};
}
return it_success.first->second;
}
return Status::KeyError("type ", type.ToString(), " not found in the registry");
}
-Result ExtensionSet::DecodeFunction(uint32_t anchor) const {
- if (functions_.find(anchor) == functions_.end() || functions_.at(anchor).id.empty()) {
+Result ExtensionSet::DecodeFunction(uint32_t anchor) const {
+ if (functions_.find(anchor) == functions_.end() || functions_.at(anchor).empty()) {
return Status::Invalid("User defined function reference ", anchor,
" did not have a corresponding anchor in the extension set");
}
return functions_.at(anchor);
}
-Result ExtensionSet::EncodeFunction(util::string_view function_name) {
- if (auto rec = registry_->GetFunction(function_name)) {
- RETURN_NOT_OK(this->AddUri(rec->id));
- auto it_success =
- functions_map_.emplace(rec->id, static_cast(functions_map_.size()));
- if (it_success.second) {
- DCHECK_EQ(functions_.find(static_cast(functions_.size())),
- functions_.end())
- << "Function existed in functions_ but not functions_map_. ExtensionSet is "
- "inconsistent";
- functions_[static_cast(functions_.size())] = {rec->id,
- rec->function_name};
- }
- return it_success.first->second;
+Result ExtensionSet::EncodeFunction(Id function_id) {
+ RETURN_NOT_OK(this->AddUri(function_id));
+ auto it_success =
+ functions_map_.emplace(function_id, static_cast(functions_map_.size()));
+ if (it_success.second) {
+ DCHECK_EQ(functions_.find(static_cast(functions_.size())), functions_.end())
+ << "Function existed in functions_ but not functions_map_. ExtensionSet is "
+ "inconsistent";
+ functions_[static_cast(functions_.size())] = function_id;
}
- return Status::KeyError("function ", function_name, " not found in the registry");
+ return it_success.first->second;
}
template
@@ -207,16 +304,38 @@ const int* GetIndex(const KeyToIndex& key_to_index, const Key& key) {
namespace {
struct ExtensionIdRegistryImpl : ExtensionIdRegistry {
+ ExtensionIdRegistryImpl() : parent_(nullptr) {}
+ explicit ExtensionIdRegistryImpl(const ExtensionIdRegistry* parent) : parent_(parent) {}
+
virtual ~ExtensionIdRegistryImpl() {}
- std::vector Uris() const override {
- return {uris_.begin(), uris_.end()};
+ util::optional FindUri(util::string_view uri) const override {
+ if (parent_) {
+ util::optional parent_uri = parent_->FindUri(uri);
+ if (parent_uri) {
+ return parent_uri;
+ }
+ }
+ return ids_.FindUri(uri);
+ }
+
+ util::optional FindId(Id id) const override {
+ if (parent_) {
+ util::optional parent_id = parent_->FindId(id);
+ if (parent_id) {
+ return parent_id;
+ }
+ }
+ return ids_.Find(id);
}
util::optional GetType(const DataType& type) const override {
if (auto index = GetIndex(type_to_index_, &type)) {
return TypeRecord{type_ids_[*index], types_[*index]};
}
+ if (parent_) {
+ return parent_->GetType(type);
+ }
return {};
}
@@ -224,6 +343,9 @@ struct ExtensionIdRegistryImpl : ExtensionIdRegistry {
if (auto index = GetIndex(id_to_index_, id)) {
return TypeRecord{type_ids_[*index], types_[*index]};
}
+ if (parent_) {
+ return parent_->GetType(id);
+ }
return {};
}
@@ -234,14 +356,20 @@ struct ExtensionIdRegistryImpl : ExtensionIdRegistry {
if (type_to_index_.find(&*type) != type_to_index_.end()) {
return Status::Invalid("Type was already registered");
}
+ if (parent_) {
+ return parent_->CanRegisterType(id, type);
+ }
return Status::OK();
}
Status RegisterType(Id id, std::shared_ptr type) override {
DCHECK_EQ(type_ids_.size(), types_.size());
- Id copied_id{*uris_.emplace(id.uri.to_string()).first,
- *names_.emplace(id.name.to_string()).first};
+ if (parent_) {
+ ARROW_RETURN_NOT_OK(parent_->CanRegisterType(id, type));
+ }
+
+ Id copied_id = ids_.Emplace(id);
auto index = static_cast(type_ids_.size());
@@ -261,155 +389,394 @@ struct ExtensionIdRegistryImpl : ExtensionIdRegistry {
return Status::OK();
}
- util::optional GetFunction(
- util::string_view arrow_function_name) const override {
- if (auto index = GetIndex(function_name_to_index_, arrow_function_name)) {
- return FunctionRecord{function_ids_[*index], *function_name_ptrs_[*index]};
+ Status CanAddSubstraitCallToArrow(Id substrait_function_id) const override {
+ if (substrait_to_arrow_.find(substrait_function_id) != substrait_to_arrow_.end()) {
+ return Status::Invalid("Cannot register function converter for Substrait id ",
+ substrait_function_id.ToString(),
+ " because a converter already exists");
}
- return {};
+ if (parent_) {
+ return parent_->CanAddSubstraitCallToArrow(substrait_function_id);
+ }
+ return Status::OK();
}
- util::optional GetFunction(Id id) const override {
- if (auto index = GetIndex(function_id_to_index_, id)) {
- return FunctionRecord{function_ids_[*index], *function_name_ptrs_[*index]};
+ Status CanAddSubstraitAggregateToArrow(Id substrait_function_id) const override {
+ if (substrait_to_arrow_agg_.find(substrait_function_id) !=
+ substrait_to_arrow_agg_.end()) {
+ return Status::Invalid(
+ "Cannot register aggregate function converter for Substrait id ",
+ substrait_function_id.ToString(),
+ " because an aggregate converter already exists");
}
- return {};
+ if (parent_) {
+ return parent_->CanAddSubstraitAggregateToArrow(substrait_function_id);
+ }
+ return Status::OK();
+ }
+
+ template
+ Status AddSubstraitToArrowFunc(
+ Id substrait_id, ConverterType conversion_func,
+ std::unordered_map* dest) {
+ // Convert id to view into registry-owned memory
+ Id copied_id = ids_.Emplace(substrait_id);
+
+ auto add_result = dest->emplace(copied_id, std::move(conversion_func));
+ if (!add_result.second) {
+ return Status::Invalid(
+ "Failed to register Substrait to Arrow function converter because a converter "
+ "already existed for Substrait id ",
+ substrait_id.ToString());
+ }
+
+ return Status::OK();
+ }
+
+ Status AddSubstraitCallToArrow(Id substrait_function_id,
+ SubstraitCallToArrow conversion_func) override {
+ if (parent_) {
+ ARROW_RETURN_NOT_OK(parent_->CanAddSubstraitCallToArrow(substrait_function_id));
+ }
+ return AddSubstraitToArrowFunc(
+ substrait_function_id, std::move(conversion_func), &substrait_to_arrow_);
}
- Status CanRegisterFunction(Id id,
- const std::string& arrow_function_name) const override {
- if (function_id_to_index_.find(id) != function_id_to_index_.end()) {
- return Status::Invalid("Function id was already registered");
+ Status AddSubstraitAggregateToArrow(
+ Id substrait_function_id, SubstraitAggregateToArrow conversion_func) override {
+ if (parent_) {
+ ARROW_RETURN_NOT_OK(
+ parent_->CanAddSubstraitAggregateToArrow(substrait_function_id));
}
- if (function_name_to_index_.find(arrow_function_name) !=
- function_name_to_index_.end()) {
- return Status::Invalid("Function name was already registered");
+ return AddSubstraitToArrowFunc(
+ substrait_function_id, std::move(conversion_func), &substrait_to_arrow_agg_);
+ }
+
+ template
+ Status AddArrowToSubstraitFunc(std::string arrow_function_name, ConverterType converter,
+ std::unordered_map* dest) {
+ auto add_result = dest->emplace(std::move(arrow_function_name), std::move(converter));
+ if (!add_result.second) {
+ return Status::Invalid(
+ "Failed to register Arrow to Substrait function converter for Arrow function ",
+ arrow_function_name, " because a converter already existed");
}
return Status::OK();
}
- Status RegisterFunction(Id id, std::string arrow_function_name) override {
- DCHECK_EQ(function_ids_.size(), function_name_ptrs_.size());
+ Status AddArrowToSubstraitCall(std::string arrow_function_name,
+ ArrowToSubstraitCall converter) override {
+ if (parent_) {
+ ARROW_RETURN_NOT_OK(parent_->CanAddArrowToSubstraitCall(arrow_function_name));
+ }
+ return AddArrowToSubstraitFunc(std::move(arrow_function_name), converter,
+ &arrow_to_substrait_);
+ }
- Id copied_id{*uris_.emplace(id.uri.to_string()).first,
- *names_.emplace(id.name.to_string()).first};
+ Status AddArrowToSubstraitAggregate(std::string arrow_function_name,
+ ArrowToSubstraitAggregate converter) override {
+ if (parent_) {
+ ARROW_RETURN_NOT_OK(parent_->CanAddArrowToSubstraitAggregate(arrow_function_name));
+ }
+ return AddArrowToSubstraitFunc(std::move(arrow_function_name), converter,
+ &arrow_to_substrait_agg_);
+ }
- const std::string& copied_function_name{
- *function_names_.emplace(std::move(arrow_function_name)).first};
+ Status CanAddArrowToSubstraitCall(const std::string& function_name) const override {
+ if (arrow_to_substrait_.find(function_name) != arrow_to_substrait_.end()) {
+ return Status::Invalid(
+ "Cannot register function converter because a converter already exists");
+ }
+ if (parent_) {
+ return parent_->CanAddArrowToSubstraitCall(function_name);
+ }
+ return Status::OK();
+ }
- auto index = static_cast(function_ids_.size());
+ Status CanAddArrowToSubstraitAggregate(
+ const std::string& function_name) const override {
+ if (arrow_to_substrait_agg_.find(function_name) != arrow_to_substrait_agg_.end()) {
+ return Status::Invalid(
+ "Cannot register function converter because a converter already exists");
+ }
+ if (parent_) {
+ return parent_->CanAddArrowToSubstraitAggregate(function_name);
+ }
+ return Status::OK();
+ }
- auto it_success = function_id_to_index_.emplace(copied_id, index);
+ Result GetSubstraitCallToArrow(
+ Id substrait_function_id) const override {
+ auto maybe_converter = substrait_to_arrow_.find(substrait_function_id);
+ if (maybe_converter == substrait_to_arrow_.end()) {
+ if (parent_) {
+ return parent_->GetSubstraitCallToArrow(substrait_function_id);
+ }
+ return Status::NotImplemented(
+ "No conversion function exists to convert the Substrait function ",
+ substrait_function_id.uri, "#", substrait_function_id.name,
+ " to an Arrow call expression");
+ }
+ return maybe_converter->second;
+ }
- if (!it_success.second) {
- return Status::Invalid("Function id was already registered");
+ Result GetSubstraitAggregateToArrow(
+ Id substrait_function_id) const override {
+ auto maybe_converter = substrait_to_arrow_agg_.find(substrait_function_id);
+ if (maybe_converter == substrait_to_arrow_agg_.end()) {
+ if (parent_) {
+ return parent_->GetSubstraitAggregateToArrow(substrait_function_id);
+ }
+ return Status::NotImplemented(
+ "No conversion function exists to convert the Substrait aggregate function ",
+ substrait_function_id.uri, "#", substrait_function_id.name,
+ " to an Arrow aggregate");
}
+ return maybe_converter->second;
+ }
- if (!function_name_to_index_.emplace(copied_function_name, index).second) {
- function_id_to_index_.erase(it_success.first);
- return Status::Invalid("Function name was already registered");
+ Result GetArrowToSubstraitCall(
+ const std::string& arrow_function_name) const override {
+ auto maybe_converter = arrow_to_substrait_.find(arrow_function_name);
+ if (maybe_converter == arrow_to_substrait_.end()) {
+ if (parent_) {
+ return parent_->GetArrowToSubstraitCall(arrow_function_name);
+ }
+ return Status::NotImplemented(
+ "No conversion function exists to convert the Arrow function ",
+ arrow_function_name, " to a Substrait call");
}
+ return maybe_converter->second;
+ }
- function_name_ptrs_.push_back(&copied_function_name);
- function_ids_.push_back(copied_id);
- return Status::OK();
+ Result GetArrowToSubstraitAggregate(
+ const std::string& arrow_function_name) const override {
+ auto maybe_converter = arrow_to_substrait_agg_.find(arrow_function_name);
+ if (maybe_converter == arrow_to_substrait_agg_.end()) {
+ if (parent_) {
+ return parent_->GetArrowToSubstraitAggregate(arrow_function_name);
+ }
+ return Status::NotImplemented(
+ "No conversion function exists to convert the Arrow aggregate ",
+ arrow_function_name, " to a Substrait aggregate");
+ }
+ return maybe_converter->second;
}
- Status RegisterFunction(std::string uri, std::string name,
- std::string arrow_function_name) override {
- return RegisterFunction({uri, name}, arrow_function_name);
+ std::vector GetSupportedSubstraitFunctions() const override {
+ std::vector encoded_ids;
+ for (const auto& entry : substrait_to_arrow_) {
+ encoded_ids.push_back(entry.first.ToString());
+ }
+ for (const auto& entry : substrait_to_arrow_agg_) {
+ encoded_ids.push_back(entry.first.ToString());
+ }
+ if (parent_) {
+ std::vector parent_ids = parent_->GetSupportedSubstraitFunctions();
+ encoded_ids.insert(encoded_ids.end(), make_move_iterator(parent_ids.begin()),
+ make_move_iterator(parent_ids.end()));
+ }
+ std::sort(encoded_ids.begin(), encoded_ids.end());
+ return encoded_ids;
}
- // owning storage of uris, names, (arrow::)function_names, types
- // note that storing strings like this is safe since references into an
- // unordered_set are not invalidated on insertion
- std::unordered_set uris_, names_, function_names_;
+ // Defined below since it depends on some helper functions defined below
+ Status AddSubstraitCallToArrow(Id substrait_function_id,
+ std::string arrow_function_name) override;
+
+ // Parent registry, null for the root, non-null for nested
+ const ExtensionIdRegistry* parent_;
+
+ // owning storage of ids & types
+ IdStorage ids_;
DataTypeVector types_;
+ // There should only be one entry per Arrow function so there is no need
+ // to separate ownership and lookup
+ std::unordered_map arrow_to_substrait_;
+ std::unordered_map arrow_to_substrait_agg_;
// non-owning lookup helpers
- std::vector type_ids_, function_ids_;
+ std::vector type_ids_;
std::unordered_map id_to_index_;
std::unordered_map type_to_index_;
-
- std::vector function_name_ptrs_;
- std::unordered_map function_id_to_index_;
- std::unordered_map
- function_name_to_index_;
+ std::unordered_map substrait_to_arrow_;
+ std::unordered_map
+ substrait_to_arrow_agg_;
};
-struct NestedExtensionIdRegistryImpl : ExtensionIdRegistryImpl {
- explicit NestedExtensionIdRegistryImpl(const ExtensionIdRegistry* parent)
- : parent_(parent) {}
-
- virtual ~NestedExtensionIdRegistryImpl() {}
+template
+using EnumParser = std::function(util::optional)>;
- std::vector Uris() const override {
- std::vector uris = parent_->Uris();
- std::unordered_set uri_set;
- uri_set.insert(uris.begin(), uris.end());
- uri_set.insert(uris_.begin(), uris_.end());
- return std::vector(uris);
+template
+EnumParser GetEnumParser(const std::vector& options) {
+ std::unordered_map parse_map;
+ for (std::size_t i = 0; i < options.size(); i++) {
+ parse_map[options[i]] = static_cast(i + 1);
}
-
- util::optional GetType(const DataType& type) const override {
- auto type_opt = ExtensionIdRegistryImpl::GetType(type);
- if (type_opt) {
- return type_opt;
+ return [parse_map](util::optional enum_val) -> Result {
+ if (!enum_val) {
+ // Assumes 0 is always kUnspecified in Enum
+ return static_cast(0);
}
- return parent_->GetType(type);
- }
-
- util::optional GetType(Id id) const override {
- auto type_opt = ExtensionIdRegistryImpl::GetType(id);
- if (type_opt) {
- return type_opt;
+ auto maybe_parsed = parse_map.find(enum_val->to_string());
+ if (maybe_parsed == parse_map.end()) {
+ return Status::Invalid("The value ", *enum_val, " is not an expected enum value");
}
- return parent_->GetType(id);
- }
+ return maybe_parsed->second;
+ };
+}
- Status CanRegisterType(Id id, const std::shared_ptr& type) const override {
- return parent_->CanRegisterType(id, type) &
- ExtensionIdRegistryImpl::CanRegisterType(id, type);
- }
+enum class TemporalComponent { kUnspecified = 0, kYear, kMonth, kDay, kSecond };
+static std::vector kTemporalComponentOptions = {"YEAR", "MONTH", "DAY",
+ "SECOND"};
+static EnumParser kTemporalComponentParser =
+ GetEnumParser(kTemporalComponentOptions);
+
+enum class OverflowBehavior { kUnspecified = 0, kSilent, kSaturate, kError };
+static std::vector kOverflowOptions = {"SILENT", "SATURATE", "ERROR"};
+static EnumParser kOverflowParser =
+ GetEnumParser(kOverflowOptions);
+
+template
+Result ParseEnumArg(const SubstraitCall& call, uint32_t arg_index,
+ const EnumParser& parser) {
+ ARROW_ASSIGN_OR_RAISE(util::optional enum_arg,
+ call.GetEnumArg(arg_index));
+ return parser(enum_arg);
+}
- Status RegisterType(Id id, std::shared_ptr type) override {
- return parent_->CanRegisterType(id, type) &
- ExtensionIdRegistryImpl::RegisterType(id, type);
+Result> GetValueArgs(const SubstraitCall& call,
+ int start_index) {
+ std::vector expressions;
+ for (uint32_t index = start_index; index < call.size(); index++) {
+ ARROW_ASSIGN_OR_RAISE(compute::Expression arg, call.GetValueArg(index));
+ expressions.push_back(arg);
}
+ return std::move(expressions);
+}
- util::optional GetFunction(
- util::string_view arrow_function_name) const override {
- auto func_opt = ExtensionIdRegistryImpl::GetFunction(arrow_function_name);
- if (func_opt) {
- return func_opt;
+ExtensionIdRegistry::SubstraitCallToArrow DecodeOptionlessOverflowableArithmetic(
+ const std::string& function_name) {
+ return [function_name](const SubstraitCall& call) -> Result {
+ ARROW_ASSIGN_OR_RAISE(OverflowBehavior overflow_behavior,
+ ParseEnumArg(call, 0, kOverflowParser));
+ ARROW_ASSIGN_OR_RAISE(std::vector value_args,
+ GetValueArgs(call, 1));
+ if (overflow_behavior == OverflowBehavior::kUnspecified) {
+ overflow_behavior = OverflowBehavior::kSilent;
}
- return parent_->GetFunction(arrow_function_name);
- }
+ if (overflow_behavior == OverflowBehavior::kSilent) {
+ return arrow::compute::call(function_name, std::move(value_args));
+ } else if (overflow_behavior == OverflowBehavior::kError) {
+ return arrow::compute::call(function_name + "_checked", std::move(value_args));
+ } else {
+ return Status::NotImplemented(
+ "Only SILENT and ERROR arithmetic kernels are currently implemented but ",
+ kOverflowOptions[static_cast(overflow_behavior) - 1], " was requested");
+ }
+ };
+}
- util::optional GetFunction(Id id) const override {
- auto func_opt = ExtensionIdRegistryImpl::GetFunction(id);
- if (func_opt) {
- return func_opt;
+template
+ExtensionIdRegistry::ArrowToSubstraitCall EncodeOptionlessOverflowableArithmetic(
+ Id substrait_fn_id) {
+ return
+ [substrait_fn_id](const compute::Expression::Call& call) -> Result {
+ // nullable=true isn't quite correct but we don't know the nullability of
+ // the inputs
+ SubstraitCall substrait_call(substrait_fn_id, call.type.GetSharedPtr(),
+ /*nullable=*/true);
+ if (kChecked) {
+ substrait_call.SetEnumArg(0, "ERROR");
+ } else {
+ substrait_call.SetEnumArg(0, "SILENT");
+ }
+ for (std::size_t i = 0; i < call.arguments.size(); i++) {
+ substrait_call.SetValueArg(static_cast(i + 1), call.arguments[i]);
+ }
+ return std::move(substrait_call);
+ };
+}
+
+ExtensionIdRegistry::SubstraitCallToArrow DecodeOptionlessBasicMapping(
+ const std::string& function_name, uint32_t max_args) {
+ return [function_name,
+ max_args](const SubstraitCall& call) -> Result {
+ if (call.size() > max_args) {
+ return Status::NotImplemented("Acero does not have a kernel for ", function_name,
+ " that receives ", call.size(), " arguments");
}
- return parent_->GetFunction(id);
- }
+ ARROW_ASSIGN_OR_RAISE(std::vector value_args,
+ GetValueArgs(call, 0));
+ return arrow::compute::call(function_name, std::move(value_args));
+ };
+}
- Status CanRegisterFunction(Id id,
- const std::string& arrow_function_name) const override {
- return parent_->CanRegisterFunction(id, arrow_function_name) &
- ExtensionIdRegistryImpl::CanRegisterFunction(id, arrow_function_name);
- }
+ExtensionIdRegistry::SubstraitCallToArrow DecodeTemporalExtractionMapping() {
+ return [](const SubstraitCall& call) -> Result {
+ ARROW_ASSIGN_OR_RAISE(TemporalComponent temporal_component,
+ ParseEnumArg(call, 0, kTemporalComponentParser));
+ if (temporal_component == TemporalComponent::kUnspecified) {
+ return Status::Invalid(
+ "The temporal component enum is a require option for the extract function "
+ "and is not specified");
+ }
+ ARROW_ASSIGN_OR_RAISE(std::vector value_args,
+ GetValueArgs(call, 1));
+ std::string func_name;
+ switch (temporal_component) {
+ case TemporalComponent::kYear:
+ func_name = "year";
+ break;
+ case TemporalComponent::kMonth:
+ func_name = "month";
+ break;
+ case TemporalComponent::kDay:
+ func_name = "day";
+ break;
+ case TemporalComponent::kSecond:
+ func_name = "second";
+ break;
+ default:
+ return Status::Invalid("Unexpected value for temporal component in extract call");
+ }
+ return compute::call(func_name, std::move(value_args));
+ };
+}
- Status RegisterFunction(Id id, std::string arrow_function_name) override {
- return parent_->CanRegisterFunction(id, arrow_function_name) &
- ExtensionIdRegistryImpl::RegisterFunction(id, arrow_function_name);
- }
+ExtensionIdRegistry::SubstraitCallToArrow DecodeConcatMapping() {
+ return [](const SubstraitCall& call) -> Result {
+ ARROW_ASSIGN_OR_RAISE(std::vector value_args,
+ GetValueArgs(call, 0));
+ value_args.push_back(compute::literal(""));
+ return compute::call("binary_join_element_wise", std::move(value_args));
+ };
+}
- const ExtensionIdRegistry* parent_;
-};
+ExtensionIdRegistry::SubstraitAggregateToArrow DecodeBasicAggregate(
+ const std::string& arrow_function_name) {
+ return [arrow_function_name](const SubstraitCall& call) -> Result {
+ if (call.size() != 1) {
+ return Status::NotImplemented(
+ "Only unary aggregate functions are currently supported");
+ }
+ ARROW_ASSIGN_OR_RAISE(compute::Expression arg, call.GetValueArg(0));
+ const FieldRef* arg_ref = arg.field_ref();
+ if (!arg_ref) {
+ return Status::Invalid("Expected an aggregate call ", call.id().uri, "#",
+ call.id().name, " to have a direct reference");
+ }
+ std::string fixed_arrow_func = arrow_function_name;
+ if (call.is_hash()) {
+ fixed_arrow_func = "hash_" + arrow_function_name;
+ }
+ return compute::Aggregate{std::move(fixed_arrow_func), nullptr, *arg_ref, ""};
+ };
+}
struct DefaultExtensionIdRegistry : ExtensionIdRegistryImpl {
DefaultExtensionIdRegistry() {
+ // ----------- Extension Types ----------------------------
struct TypeName {
std::shared_ptr type;
util::string_view name;
@@ -428,32 +795,91 @@ struct DefaultExtensionIdRegistry : ExtensionIdRegistryImpl {
DCHECK_OK(RegisterType({kArrowExtTypesUri, e.name}, std::move(e.type)));
}
- for (TypeName e : {
- TypeName{null(), "null"},
- TypeName{month_interval(), "interval_month"},
- TypeName{day_time_interval(), "interval_day_milli"},
- TypeName{month_day_nano_interval(), "interval_month_day_nano"},
- }) {
+ for (TypeName e :
+ {TypeName{null(), "null"}, TypeName{month_interval(), "interval_month"},
+ TypeName{day_time_interval(), "interval_day_milli"},
+ TypeName{month_day_nano_interval(), "interval_month_day_nano"}}) {
DCHECK_OK(RegisterType({kArrowExtTypesUri, e.name}, std::move(e.type)));
}
- // TODO: this is just a placeholder right now. We'll need a YAML file for
- // all functions (and prototypes) that Arrow provides that are relevant
- // for Substrait, and include mappings for all of them here. See
- // ARROW-15535.
- for (util::string_view name : {
- "add",
- "equal",
- "is_not_distinct_from",
- "hash_count",
- }) {
- DCHECK_OK(RegisterFunction({kArrowExtTypesUri, name}, name.to_string()));
+ // -------------- Substrait -> Arrow Functions -----------------
+ // Mappings with a _checked variant
+ for (const auto& function_name : {"add", "subtract", "multiply", "divide"}) {
+ DCHECK_OK(
+ AddSubstraitCallToArrow({kSubstraitArithmeticFunctionsUri, function_name},
+ DecodeOptionlessOverflowableArithmetic(function_name)));
+ }
+ // Basic mappings that need _kleene appended to them
+ for (const auto& function_name : {"or", "and"}) {
+ DCHECK_OK(AddSubstraitCallToArrow(
+ {kSubstraitBooleanFunctionsUri, function_name},
+ DecodeOptionlessBasicMapping(std::string(function_name) + "_kleene",
+ /*max_args=*/2)));
+ }
+ // Basic binary mappings
+ for (const auto& function_name :
+ std::vector>{
+ {kSubstraitBooleanFunctionsUri, "xor"},
+ {kSubstraitComparisonFunctionsUri, "equal"},
+ {kSubstraitComparisonFunctionsUri, "not_equal"}}) {
+ DCHECK_OK(
+ AddSubstraitCallToArrow({function_name.first, function_name.second},
+ DecodeOptionlessBasicMapping(
+ function_name.second.to_string(), /*max_args=*/2)));
+ }
+ for (const auto& uri :
+ {kSubstraitComparisonFunctionsUri, kSubstraitDatetimeFunctionsUri}) {
+ DCHECK_OK(AddSubstraitCallToArrow(
+ {uri, "lt"}, DecodeOptionlessBasicMapping("less", /*max_args=*/2)));
+ DCHECK_OK(AddSubstraitCallToArrow(
+ {uri, "lte"}, DecodeOptionlessBasicMapping("less_equal", /*max_args=*/2)));
+ DCHECK_OK(AddSubstraitCallToArrow(
+ {uri, "gt"}, DecodeOptionlessBasicMapping("greater", /*max_args=*/2)));
+ DCHECK_OK(AddSubstraitCallToArrow(
+ {uri, "gte"}, DecodeOptionlessBasicMapping("greater_equal", /*max_args=*/2)));
+ }
+ // One-off mappings
+ DCHECK_OK(
+ AddSubstraitCallToArrow({kSubstraitBooleanFunctionsUri, "not"},
+ DecodeOptionlessBasicMapping("invert", /*max_args=*/1)));
+ DCHECK_OK(AddSubstraitCallToArrow({kSubstraitDatetimeFunctionsUri, "extract"},
+ DecodeTemporalExtractionMapping()));
+ DCHECK_OK(AddSubstraitCallToArrow({kSubstraitStringFunctionsUri, "concat"},
+ DecodeConcatMapping()));
+
+ // --------------- Substrait -> Arrow Aggregates --------------
+ for (const auto& fn_name : {"sum", "min", "max"}) {
+ DCHECK_OK(AddSubstraitAggregateToArrow({kSubstraitArithmeticFunctionsUri, fn_name},
+ DecodeBasicAggregate(fn_name)));
+ }
+ DCHECK_OK(AddSubstraitAggregateToArrow({kSubstraitArithmeticFunctionsUri, "avg"},
+ DecodeBasicAggregate("mean")));
+
+ // --------------- Arrow -> Substrait Functions ---------------
+ for (const auto& fn_name : {"add", "subtract", "multiply", "divide"}) {
+ Id fn_id{kSubstraitArithmeticFunctionsUri, fn_name};
+ DCHECK_OK(AddArrowToSubstraitCall(
+ fn_name, EncodeOptionlessOverflowableArithmetic(fn_id)));
+ DCHECK_OK(
+ AddArrowToSubstraitCall(std::string(fn_name) + "_checked",
+ EncodeOptionlessOverflowableArithmetic(fn_id)));
}
}
};
} // namespace
+Status ExtensionIdRegistryImpl::AddSubstraitCallToArrow(Id substrait_function_id,
+ std::string arrow_function_name) {
+ return AddSubstraitCallToArrow(
+ substrait_function_id,
+ [arrow_function_name](const SubstraitCall& call) -> Result {
+ ARROW_ASSIGN_OR_RAISE(std::vector value_args,
+ GetValueArgs(call, 0));
+ return compute::call(arrow_function_name, std::move(value_args));
+ });
+}
+
ExtensionIdRegistry* default_extension_id_registry() {
static DefaultExtensionIdRegistry impl_;
return &impl_;
@@ -461,7 +887,7 @@ ExtensionIdRegistry* default_extension_id_registry() {
std::shared_ptr nested_extension_id_registry(
const ExtensionIdRegistry* parent) {
- return std::make_shared(parent);
+ return std::make_shared(parent);
}
} // namespace engine
diff --git a/cpp/src/arrow/engine/substrait/extension_set.h b/cpp/src/arrow/engine/substrait/extension_set.h
index 04e4586a9f5e2..9cb42f66136b9 100644
--- a/cpp/src/arrow/engine/substrait/extension_set.h
+++ b/cpp/src/arrow/engine/substrait/extension_set.h
@@ -19,26 +19,130 @@
#pragma once
+#include
#include
+#include
#include
+#include "arrow/compute/exec/exec_plan.h"
+#include "arrow/compute/exec/expression.h"
#include "arrow/engine/substrait/visibility.h"
+#include "arrow/result.h"
#include "arrow/type_fwd.h"
+#include "arrow/util/hash_util.h"
+#include "arrow/util/hashing.h"
#include "arrow/util/optional.h"
#include "arrow/util/string_view.h"
-#include "arrow/util/hash_util.h"
-
namespace arrow {
namespace engine {
+constexpr const char* kSubstraitArithmeticFunctionsUri =
+ "https://github.com/substrait-io/substrait/blob/main/extensions/"
+ "functions_arithmetic.yaml";
+constexpr const char* kSubstraitBooleanFunctionsUri =
+ "https://github.com/substrait-io/substrait/blob/main/extensions/"
+ "functions_boolean.yaml";
+constexpr const char* kSubstraitComparisonFunctionsUri =
+ "https://github.com/substrait-io/substrait/blob/main/extensions/"
+ "functions_comparison.yaml";
+constexpr const char* kSubstraitDatetimeFunctionsUri =
+ "https://github.com/substrait-io/substrait/blob/main/extensions/"
+ "functions_datetime.yaml";
+constexpr const char* kSubstraitStringFunctionsUri =
+ "https://github.com/substrait-io/substrait/blob/main/extensions/"
+ "functions_string.yaml";
+
+struct Id {
+ util::string_view uri, name;
+ bool empty() const { return uri.empty() && name.empty(); }
+ std::string ToString() const;
+};
+struct IdHashEq {
+ size_t operator()(Id id) const;
+ bool operator()(Id l, Id r) const;
+};
+
+/// \brief Owning storage for ids
+///
+/// Substrait plans may reuse URIs and names in many places. For convenience
+/// and performance Substarit ids are typically passed around as views. As we
+/// convert a plan from Substrait to Arrow we need to copy these strings out of
+/// the Substrait buffer and into owned storage. This class serves as that owned
+/// storage.
+class IdStorage {
+ public:
+ /// \brief Get an equivalent id pointing into this storage
+ ///
+ /// This operation will copy the ids into storage if they do not already exist
+ Id Emplace(Id id);
+ /// \brief Get an equivalent view pointing into this storage for a URI
+ ///
+ /// If no URI is found then the uri will be copied into storage
+ util::string_view EmplaceUri(util::string_view uri);
+ /// \brief Get an equivalent id pointing into this storage
+ ///
+ /// If no id is found then nullopt will be returned
+ util::optional Find(Id id) const;
+ /// \brief Get an equivalent view pointing into this storage for a URI
+ ///
+ /// If no URI is found then nullopt will be returned
+ util::optional FindUri(util::string_view uri) const;
+
+ private:
+ std::unordered_set uris_;
+ std::unordered_set names_;
+ std::list owned_uris_;
+ std::list owned_names_;
+};
+
+/// \brief Describes a Substrait call
+///
+/// Substrait call expressions contain a list of arguments which can either
+/// be enum arguments (which are serialized as strings), value arguments (which)
+/// are Arrow expressions, or type arguments (not yet implemented)
+class SubstraitCall {
+ public:
+ SubstraitCall(Id id, std::shared_ptr output_type, bool output_nullable,
+ bool is_hash = false)
+ : id_(id),
+ output_type_(std::move(output_type)),
+ output_nullable_(output_nullable),
+ is_hash_(is_hash) {}
+
+ const Id& id() const { return id_; }
+ const std::shared_ptr& output_type() const { return output_type_; }
+ bool output_nullable() const { return output_nullable_; }
+ bool is_hash() const { return is_hash_; }
+
+ bool HasEnumArg(uint32_t index) const;
+ Result> GetEnumArg(uint32_t index) const;
+ void SetEnumArg(uint32_t index, util::optional enum_arg);
+ Result GetValueArg(uint32_t index) const;
+ bool HasValueArg(uint32_t index) const;
+ void SetValueArg(uint32_t index, compute::Expression value_arg);
+ uint32_t size() const { return size_; }
+
+ private:
+ Id id_;
+ std::shared_ptr output_type_;
+ bool output_nullable_;
+ // Only needed when converting from Substrait -> Arrow aggregates. The
+ // Arrow function name depends on whether or not there are any groups
+ bool is_hash_;
+ std::unordered_map> enum_args_;
+ std::unordered_map value_args_;
+ uint32_t size_ = 0;
+};
+
/// Substrait identifies functions and custom data types using a (uri, name) pair.
///
-/// This registry is a bidirectional mapping between Substrait IDs and their corresponding
-/// Arrow counterparts (arrow::DataType and function names in a function registry)
+/// This registry is a bidirectional mapping between Substrait IDs and their
+/// corresponding Arrow counterparts (arrow::DataType and function names in a function
+/// registry)
///
-/// Substrait extension types and variations must be registered with their corresponding
-/// arrow::DataType before they can be used!
+/// Substrait extension types and variations must be registered with their
+/// corresponding arrow::DataType before they can be used!
///
/// Conceptually this can be thought of as two pairs of `unordered_map`s. One pair to
/// go back and forth between Substrait ID and arrow::DataType and another pair to go
@@ -49,56 +153,103 @@ namespace engine {
/// instance).
class ARROW_ENGINE_EXPORT ExtensionIdRegistry {
public:
- /// All uris registered in this ExtensionIdRegistry
- virtual std::vector Uris() const = 0;
-
- struct Id {
- util::string_view uri, name;
-
- bool empty() const { return uri.empty() && name.empty(); }
- };
-
- struct IdHashEq {
- size_t operator()(Id id) const;
- bool operator()(Id l, Id r) const;
- };
+ using ArrowToSubstraitCall =
+ std::function(const arrow::compute::Expression::Call&)>;
+ using SubstraitCallToArrow =
+ std::function(const SubstraitCall&)>;
+ using ArrowToSubstraitAggregate =
+ std::function(const arrow::compute::Aggregate&)>;
+ using SubstraitAggregateToArrow =
+ std::function(const SubstraitCall&)>;
/// \brief A mapping between a Substrait ID and an arrow::DataType
struct TypeRecord {
Id id;
const std::shared_ptr& type;
};
+
+ /// \brief Return a uri view owned by this registry
+ ///
+ /// If the URI has never been emplaced it will return nullopt
+ virtual util::optional FindUri(util::string_view uri) const = 0;
+ /// \brief Return a id view owned by this registry
+ ///
+ /// If the id has never been emplaced it will return nullopt
+ virtual util::optional FindId(Id id) const = 0;
virtual util::optional GetType(const DataType&) const = 0;
virtual util::optional GetType(Id) const = 0;
virtual Status CanRegisterType(Id, const std::shared_ptr& type) const = 0;
virtual Status RegisterType(Id, std::shared_ptr) = 0;
+ /// \brief Register a converter that converts an Arrow call to a Substrait call
+ ///
+ /// Note that there may not be 1:1 parity between ArrowToSubstraitCall and
+ /// SubstraitCallToArrow because some standard functions (e.g. add) may map to
+ /// multiple Arrow functions (e.g. add, add_checked)
+ virtual Status AddArrowToSubstraitCall(std::string arrow_function_name,
+ ArrowToSubstraitCall conversion_func) = 0;
+ /// \brief Check to see if a converter can be registered
+ ///
+ /// \return Status::OK if there are no conflicts, otherwise an error is returned
+ virtual Status CanAddArrowToSubstraitCall(
+ const std::string& arrow_function_name) const = 0;
- /// \brief A mapping between a Substrait ID and an Arrow function
+ /// \brief Register a converter that converts an Arrow aggregate to a Substrait
+ /// aggregate
+ virtual Status AddArrowToSubstraitAggregate(
+ std::string arrow_function_name, ArrowToSubstraitAggregate conversion_func) = 0;
+ /// \brief Check to see if a converter can be registered
///
- /// Note: At the moment we identify functions solely by the name
- /// of the function in the function registry.
+ /// \return Status::OK if there are no conflicts, otherwise an error is returned
+ virtual Status CanAddArrowToSubstraitAggregate(
+ const std::string& arrow_function_name) const = 0;
+
+ /// \brief Register a converter that converts a Substrait call to an Arrow call
+ virtual Status AddSubstraitCallToArrow(Id substrait_function_id,
+ SubstraitCallToArrow conversion_func) = 0;
+ /// \brief Check to see if a converter can be registered
///
- /// TODO(ARROW-15582) some functions will not be simple enough to convert without access
- /// to their arguments/options. For example is_in embeds the set in options rather than
- /// using an argument:
- /// is_in(x, SetLookupOptions(set)) <-> (k...Uri, "is_in")(x, set)
+ /// \return Status::OK if there are no conflicts, otherwise an error is returned
+ virtual Status CanAddSubstraitCallToArrow(Id substrait_function_id) const = 0;
+ /// \brief Register a simple mapping function
///
- /// ... for another example, depending on the value of the first argument to
- /// substrait::add it either corresponds to arrow::add or arrow::add_checked
- struct FunctionRecord {
- Id id;
- const std::string& function_name;
- };
- virtual util::optional GetFunction(Id) const = 0;
- virtual util::optional GetFunction(
- util::string_view arrow_function_name) const = 0;
- virtual Status CanRegisterFunction(Id,
- const std::string& arrow_function_name) const = 0;
- // registers a function without taking ownership of uri and name within Id
- virtual Status RegisterFunction(Id, std::string arrow_function_name) = 0;
- // registers a function while taking ownership of uri and name
- virtual Status RegisterFunction(std::string uri, std::string name,
- std::string arrow_function_name) = 0;
+ /// All calls to the function must pass only value arguments. The arguments
+ /// will be converted to expressions and passed to the Arrow function
+ virtual Status AddSubstraitCallToArrow(Id substrait_function_id,
+ std::string arrow_function_name) = 0;
+
+ /// \brief Register a converter that converts a Substrait aggregate to an Arrow
+ /// aggregate
+ virtual Status AddSubstraitAggregateToArrow(
+ Id substrait_function_id, SubstraitAggregateToArrow conversion_func) = 0;
+ /// \brief Check to see if a converter can be registered
+ ///
+ /// \return Status::OK if there are no conflicts, otherwise an error is returned
+ virtual Status CanAddSubstraitAggregateToArrow(Id substrait_function_id) const = 0;
+
+ /// \brief Return a list of Substrait functions that have a converter
+ ///
+ /// The function ids are encoded as strings using the pattern {uri}#{name}
+ virtual std::vector GetSupportedSubstraitFunctions() const = 0;
+
+ /// \brief Find a converter to map Arrow calls to Substrait calls
+ /// \return A converter function or an invalid status if no converter is registered
+ virtual Result GetArrowToSubstraitCall(
+ const std::string& arrow_function_name) const = 0;
+
+ /// \brief Find a converter to map Arrow aggregates to Substrait aggregates
+ /// \return A converter function or an invalid status if no converter is registered
+ virtual Result GetArrowToSubstraitAggregate(
+ const std::string& arrow_function_name) const = 0;
+
+ /// \brief Find a converter to map a Substrait aggregate to an Arrow aggregate
+ /// \return A converter function or an invalid status if no converter is registered
+ virtual Result GetSubstraitAggregateToArrow(
+ Id substrait_function_id) const = 0;
+
+ /// \brief Find a converter to map a Substrait call to an Arrow call
+ /// \return A converter function or an invalid status if no converter is registered
+ virtual Result GetSubstraitCallToArrow(
+ Id substrait_function_id) const = 0;
};
constexpr util::string_view kArrowExtTypesUri =
@@ -153,9 +304,6 @@ ARROW_ENGINE_EXPORT std::shared_ptr nested_extension_id_reg
/// ExtensionIdRegistry.
class ARROW_ENGINE_EXPORT ExtensionSet {
public:
- using Id = ExtensionIdRegistry::Id;
- using IdHashEq = ExtensionIdRegistry::IdHashEq;
-
struct FunctionRecord {
Id id;
util::string_view name;
@@ -219,12 +367,12 @@ class ARROW_ENGINE_EXPORT ExtensionSet {
/// \return An anchor that can be used to refer to the type within a plan
Result EncodeType(const DataType& type);
- /// \brief Returns a function given an anchor
+ /// \brief Return a function id given an anchor
///
/// This is used when converting a Substrait plan to an Arrow execution plan.
///
/// If the anchor does not exist in this extension set an error will be returned.
- Result DecodeFunction(uint32_t anchor) const;
+ Result DecodeFunction(uint32_t anchor) const;
/// \brief Lookup the anchor for a given function
///
@@ -239,26 +387,30 @@ class ARROW_ENGINE_EXPORT ExtensionSet {
/// returned.
///
/// \return An anchor that can be used to refer to the function within a plan
- Result EncodeFunction(util::string_view function_name);
+ Result EncodeFunction(Id function_id);
- /// \brief Returns the number of custom functions in this extension set
- ///
- /// Note: the functions are currently stored as a sparse vector, so this may return a
- /// value larger than the actual number of functions. This behavior may change in the
- /// future; see ARROW-15583.
+ /// \brief Return the number of custom functions in this extension set
std::size_t num_functions() const { return functions_.size(); }
+ const ExtensionIdRegistry* registry() const { return registry_; }
+
private:
const ExtensionIdRegistry* registry_;
+ // If the registry is not aware of an id then we probably can't do anything
+ // with it. However, in some cases, these may represent extensions or features
+ // that we can safely ignore. For example, we can usually safely ignore
+ // extension type variations if we assume the plan is valid. These ignorable
+ // ids are stored here.
+ IdStorage plan_specific_ids_;
// Map from anchor values to URI values referenced by this extension set
std::unordered_map uris_;
// Map from anchor values to type definitions, used during Substrait->Arrow
// and populated from the Substrait extension set
std::unordered_map types_;
- // Map from anchor values to function definitions, used during Substrait->Arrow
+ // Map from anchor values to function ids, used during Substrait->Arrow
// and populated from the Substrait extension set
- std::unordered_map functions_;
+ std::unordered_map functions_;
// Map from type names to anchor values. Used during Arrow->Substrait
// and built as the plan is created.
std::unordered_map types_map_;
diff --git a/cpp/src/arrow/engine/substrait/function_test.cc b/cpp/src/arrow/engine/substrait/function_test.cc
new file mode 100644
index 0000000000000..225bc56d13681
--- /dev/null
+++ b/cpp/src/arrow/engine/substrait/function_test.cc
@@ -0,0 +1,495 @@
+// Licensed to the Apache Software Foundation (ASF) under one
+// or more contributor license agreements. See the NOTICE file
+// distributed with this work for additional information
+// regarding copyright ownership. The ASF licenses this file
+// to you under the Apache License, Version 2.0 (the
+// "License"); you may not use this file except in compliance
+// with the License. You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing,
+// software distributed under the License is distributed on an
+// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+// KIND, either express or implied. See the License for the
+// specific language governing permissions and limitations
+// under the License.
+
+#include
+#include
+#include
+
+#include
+
+#include "arrow/array.h"
+#include "arrow/array/builder_binary.h"
+#include "arrow/compute/cast.h"
+#include "arrow/compute/exec/options.h"
+#include "arrow/compute/exec/util.h"
+#include "arrow/engine/substrait/extension_set.h"
+#include "arrow/engine/substrait/plan_internal.h"
+#include "arrow/engine/substrait/serde.h"
+#include "arrow/engine/substrait/test_plan_builder.h"
+#include "arrow/engine/substrait/type_internal.h"
+#include "arrow/record_batch.h"
+#include "arrow/table.h"
+#include "arrow/testing/future_util.h"
+#include "arrow/testing/gtest_util.h"
+#include "arrow/type.h"
+
+namespace arrow {
+
+namespace engine {
+struct FunctionTestCase {
+ Id function_id;
+ std::vector arguments;
+ std::vector> data_types;
+ // For a test case that should fail just use the empty string
+ std::string expected_output;
+ std::shared_ptr expected_output_type;
+};
+
+Result> GetArray(const std::string& value,
+ const std::shared_ptr& data_type) {
+ StringBuilder str_builder;
+ if (value.empty()) {
+ ARROW_EXPECT_OK(str_builder.AppendNull());
+ } else {
+ ARROW_EXPECT_OK(str_builder.Append(value));
+ }
+ ARROW_ASSIGN_OR_RAISE(std::shared_ptr value_str, str_builder.Finish());
+ ARROW_ASSIGN_OR_RAISE(Datum value_datum, compute::Cast(value_str, data_type));
+ return value_datum.make_array();
+}
+
+Result> GetInputTable(
+ const std::vector& arguments,
+ const std::vector>& data_types) {
+ std::vector> columns;
+ std::vector> fields;
+ EXPECT_EQ(arguments.size(), data_types.size());
+ for (std::size_t i = 0; i < arguments.size(); i++) {
+ if (data_types[i]) {
+ ARROW_ASSIGN_OR_RAISE(std::shared_ptr arg_array,
+ GetArray(arguments[i], data_types[i]));
+ columns.push_back(std::move(arg_array));
+ fields.push_back(field("arg_" + std::to_string(i), data_types[i]));
+ }
+ }
+ std::shared_ptr batch =
+ RecordBatch::Make(schema(std::move(fields)), 1, columns);
+ ARROW_ASSIGN_OR_RAISE(std::shared_ptr table, Table::FromRecordBatches({batch}));
+ return table;
+}
+
+Result> GetOutputTable(
+ const std::string& output_value, const std::shared_ptr& output_type) {
+ std::vector> columns(1);
+ std::vector> fields(1);
+ ARROW_ASSIGN_OR_RAISE(columns[0], GetArray(output_value, output_type));
+ fields[0] = field("output", output_type);
+ std::shared_ptr batch =
+ RecordBatch::Make(schema(std::move(fields)), 1, columns);
+ ARROW_ASSIGN_OR_RAISE(std::shared_ptr table, Table::FromRecordBatches({batch}));
+ return table;
+}
+
+Result> PlanFromTestCase(
+ const FunctionTestCase& test_case, std::shared_ptr* output_table) {
+ ARROW_ASSIGN_OR_RAISE(std::shared_ptr input_table,
+ GetInputTable(test_case.arguments, test_case.data_types));
+ ARROW_ASSIGN_OR_RAISE(std::shared_ptr substrait,
+ internal::CreateScanProjectSubstrait(
+ test_case.function_id, input_table, test_case.arguments,
+ test_case.data_types, *test_case.expected_output_type));
+ std::shared_ptr