diff --git a/cpp/src/arrow/compute/exec/options.cc b/cpp/src/arrow/compute/exec/options.cc index c09ab1c1b68c2..ef1a0c7e2eb2a 100644 --- a/cpp/src/arrow/compute/exec/options.cc +++ b/cpp/src/arrow/compute/exec/options.cc @@ -25,6 +25,8 @@ namespace arrow { namespace compute { +constexpr int64_t TableSourceNodeOptions::kDefaultMaxBatchSize; + std::string ToString(JoinType t) { switch (t) { case JoinType::LEFT_SEMI: diff --git a/cpp/src/arrow/compute/exec/options.h b/cpp/src/arrow/compute/exec/options.h index 4a0cd602efb54..a8e8c1ee23096 100644 --- a/cpp/src/arrow/compute/exec/options.h +++ b/cpp/src/arrow/compute/exec/options.h @@ -64,7 +64,9 @@ class ARROW_EXPORT SourceNodeOptions : public ExecNodeOptions { /// \brief An extended Source node which accepts a table class ARROW_EXPORT TableSourceNodeOptions : public ExecNodeOptions { public: - TableSourceNodeOptions(std::shared_ptr table, int64_t max_batch_size) + static constexpr int64_t kDefaultMaxBatchSize = 1 << 20; + TableSourceNodeOptions(std::shared_ptr
table, + int64_t max_batch_size = kDefaultMaxBatchSize) : table(table), max_batch_size(max_batch_size) {} // arrow table which acts as the data source diff --git a/cpp/src/arrow/compute/exec/sink_node.cc b/cpp/src/arrow/compute/exec/sink_node.cc index a1426265cf94f..8af4e8e996cce 100644 --- a/cpp/src/arrow/compute/exec/sink_node.cc +++ b/cpp/src/arrow/compute/exec/sink_node.cc @@ -388,47 +388,6 @@ class ConsumingSinkNode : public ExecNode, public BackpressureControl { std::vector names_; int32_t backpressure_counter_ = 0; }; - -/** - * @brief This node is an extension on ConsumingSinkNode - * to facilitate to get the output from an execution plan - * as a table. We define a custom SinkNodeConsumer to - * enable this functionality. - */ - -struct TableSinkNodeConsumer : public SinkNodeConsumer { - public: - TableSinkNodeConsumer(std::shared_ptr
* out, MemoryPool* pool) - : out_(out), pool_(pool) {} - - Status Init(const std::shared_ptr& schema, - BackpressureControl* backpressure_control) override { - // If the user is collecting into a table then backpressure is meaningless - ARROW_UNUSED(backpressure_control); - schema_ = schema; - return Status::OK(); - } - - Status Consume(ExecBatch batch) override { - std::lock_guard guard(consume_mutex_); - ARROW_ASSIGN_OR_RAISE(auto rb, batch.ToRecordBatch(schema_, pool_)); - batches_.push_back(rb); - return Status::OK(); - } - - Future<> Finish() override { - ARROW_ASSIGN_OR_RAISE(*out_, Table::FromRecordBatches(batches_)); - return Status::OK(); - } - - private: - std::shared_ptr
* out_; - MemoryPool* pool_; - std::shared_ptr schema_; - std::vector> batches_; - std::mutex consume_mutex_; -}; - static Result MakeTableConsumingSinkNode( compute::ExecPlan* plan, std::vector inputs, const compute::ExecNodeOptions& options) { diff --git a/cpp/src/arrow/compute/exec/util.cc b/cpp/src/arrow/compute/exec/util.cc index ae70cfcd46f50..a34a9c6271322 100644 --- a/cpp/src/arrow/compute/exec/util.cc +++ b/cpp/src/arrow/compute/exec/util.cc @@ -383,5 +383,25 @@ size_t ThreadIndexer::Check(size_t thread_index) { return thread_index; } +Status TableSinkNodeConsumer::Init(const std::shared_ptr& schema, + BackpressureControl* backpressure_control) { + // If the user is collecting into a table then backpressure is meaningless + ARROW_UNUSED(backpressure_control); + schema_ = schema; + return Status::OK(); +} + +Status TableSinkNodeConsumer::Consume(ExecBatch batch) { + auto guard = consume_mutex_.Lock(); + ARROW_ASSIGN_OR_RAISE(auto rb, batch.ToRecordBatch(schema_, pool_)); + batches_.push_back(std::move(rb)); + return Status::OK(); +} + +Future<> TableSinkNodeConsumer::Finish() { + ARROW_ASSIGN_OR_RAISE(*out_, Table::FromRecordBatches(batches_)); + return Status::OK(); +} + } // namespace compute } // namespace arrow diff --git a/cpp/src/arrow/compute/exec/util.h b/cpp/src/arrow/compute/exec/util.h index 30526cb835ab1..7e716808fa008 100644 --- a/cpp/src/arrow/compute/exec/util.h +++ b/cpp/src/arrow/compute/exec/util.h @@ -24,6 +24,7 @@ #include #include "arrow/buffer.h" +#include "arrow/compute/exec/options.h" #include "arrow/compute/type_fwd.h" #include "arrow/memory_pool.h" #include "arrow/result.h" @@ -342,5 +343,23 @@ class TailSkipForSIMD { } }; +/// \brief A consumer that collects results into an in-memory table +struct ARROW_EXPORT TableSinkNodeConsumer : public SinkNodeConsumer { + public: + TableSinkNodeConsumer(std::shared_ptr
* out, MemoryPool* pool) + : out_(out), pool_(pool) {} + Status Init(const std::shared_ptr& schema, + BackpressureControl* backpressure_control) override; + Status Consume(ExecBatch batch) override; + Future<> Finish() override; + + private: + std::shared_ptr
* out_; + MemoryPool* pool_; + std::shared_ptr schema_; + std::vector> batches_; + util::Mutex consume_mutex_; +}; + } // namespace compute } // namespace arrow diff --git a/cpp/src/arrow/engine/CMakeLists.txt b/cpp/src/arrow/engine/CMakeLists.txt index 8edd22900e6cb..4109b7b3bcdf3 100644 --- a/cpp/src/arrow/engine/CMakeLists.txt +++ b/cpp/src/arrow/engine/CMakeLists.txt @@ -23,9 +23,10 @@ set(ARROW_SUBSTRAIT_SRCS substrait/expression_internal.cc substrait/extension_set.cc substrait/extension_types.cc - substrait/serde.cc substrait/plan_internal.cc substrait/relation_internal.cc + substrait/serde.cc + substrait/test_plan_builder.cc substrait/type_internal.cc substrait/util.cc) @@ -67,6 +68,7 @@ endif() add_arrow_test(substrait_test SOURCES substrait/ext_test.cc + substrait/function_test.cc substrait/serde_test.cc EXTRA_LINK_LIBS ${ARROW_SUBSTRAIT_TEST_LINK_LIBS} diff --git a/cpp/src/arrow/engine/substrait/expression_internal.cc b/cpp/src/arrow/engine/substrait/expression_internal.cc index 07c222bc4cfd1..589d7e6ac695f 100644 --- a/cpp/src/arrow/engine/substrait/expression_internal.cc +++ b/cpp/src/arrow/engine/substrait/expression_internal.cc @@ -41,6 +41,84 @@ namespace internal { using ::arrow::internal::make_unique; } // namespace internal +Status DecodeArg(const substrait::FunctionArgument& arg, uint32_t idx, + SubstraitCall* call, const ExtensionSet& ext_set, + const ConversionOptions& conversion_options) { + if (arg.has_enum_()) { + const substrait::FunctionArgument::Enum& enum_val = arg.enum_(); + if (enum_val.has_specified()) { + call->SetEnumArg(idx, enum_val.specified()); + } else { + call->SetEnumArg(idx, util::nullopt); + } + } else if (arg.has_value()) { + ARROW_ASSIGN_OR_RAISE(compute::Expression expr, + FromProto(arg.value(), ext_set, conversion_options)); + call->SetValueArg(idx, std::move(expr)); + } else if (arg.has_type()) { + return Status::NotImplemented("Type arguments not currently supported"); + } else { + return Status::NotImplemented("Unrecognized function argument class"); + } + return Status::OK(); +} + +Result DecodeScalarFunction( + Id id, const substrait::Expression::ScalarFunction& scalar_fn, + const ExtensionSet& ext_set, const ConversionOptions& conversion_options) { + ARROW_ASSIGN_OR_RAISE(auto output_type_and_nullable, + FromProto(scalar_fn.output_type(), ext_set, conversion_options)); + SubstraitCall call(id, output_type_and_nullable.first, output_type_and_nullable.second); + for (int i = 0; i < scalar_fn.arguments_size(); i++) { + ARROW_RETURN_NOT_OK(DecodeArg(scalar_fn.arguments(i), static_cast(i), &call, + ext_set, conversion_options)); + } + return std::move(call); +} + +std::string EnumToString(int value, const google::protobuf::EnumDescriptor* descriptor) { + const google::protobuf::EnumValueDescriptor* value_desc = + descriptor->FindValueByNumber(value); + if (value_desc == nullptr) { + return "unknown"; + } + return value_desc->name(); +} + +Result FromProto(const substrait::AggregateFunction& func, bool is_hash, + const ExtensionSet& ext_set, + const ConversionOptions& conversion_options) { + if (func.phase() != substrait::AggregationPhase::AGGREGATION_PHASE_INITIAL_TO_RESULT) { + return Status::NotImplemented( + "Unsupported aggregation phase '", + EnumToString(func.phase(), substrait::AggregationPhase_descriptor()), + "'. Only INITIAL_TO_RESULT is supported"); + } + if (func.invocation() != + substrait::AggregateFunction::AggregationInvocation:: + AggregateFunction_AggregationInvocation_AGGREGATION_INVOCATION_ALL) { + return Status::NotImplemented( + "Unsupported aggregation invocation '", + EnumToString(func.invocation(), + substrait::AggregateFunction::AggregationInvocation_descriptor()), + "'. Only AGGREGATION_INVOCATION_ALL is " + "supported"); + } + if (func.sorts_size() > 0) { + return Status::NotImplemented("Aggregation sorts are not supported"); + } + ARROW_ASSIGN_OR_RAISE(auto output_type_and_nullable, + FromProto(func.output_type(), ext_set, conversion_options)); + ARROW_ASSIGN_OR_RAISE(Id id, ext_set.DecodeFunction(func.function_reference())); + SubstraitCall call(id, output_type_and_nullable.first, output_type_and_nullable.second, + is_hash); + for (int i = 0; i < func.arguments_size(); i++) { + ARROW_RETURN_NOT_OK(DecodeArg(func.arguments(i), static_cast(i), &call, + ext_set, conversion_options)); + } + return std::move(call); +} + Result FromProto(const substrait::Expression& expr, const ExtensionSet& ext_set, const ConversionOptions& conversion_options) { @@ -166,34 +244,14 @@ Result FromProto(const substrait::Expression& expr, case substrait::Expression::kScalarFunction: { const auto& scalar_fn = expr.scalar_function(); - ARROW_ASSIGN_OR_RAISE(auto decoded_function, + ARROW_ASSIGN_OR_RAISE(Id function_id, ext_set.DecodeFunction(scalar_fn.function_reference())); - - std::vector arguments(scalar_fn.arguments_size()); - for (int i = 0; i < scalar_fn.arguments_size(); ++i) { - const auto& argument = scalar_fn.arguments(i); - switch (argument.arg_type_case()) { - case substrait::FunctionArgument::kValue: { - ARROW_ASSIGN_OR_RAISE( - arguments[i], FromProto(argument.value(), ext_set, conversion_options)); - break; - } - default: - return Status::NotImplemented( - "only value arguments are currently supported for functions"); - } - } - - auto func_name = decoded_function.name.to_string(); - if (func_name != "cast") { - return compute::call(func_name, std::move(arguments)); - } else { - ARROW_ASSIGN_OR_RAISE( - auto output_type_desc, - FromProto(scalar_fn.output_type(), ext_set, conversion_options)); - auto cast_options = compute::CastOptions::Safe(std::move(output_type_desc.first)); - return compute::call(func_name, std::move(arguments), std::move(cast_options)); - } + ARROW_ASSIGN_OR_RAISE(ExtensionIdRegistry::SubstraitCallToArrow function_converter, + ext_set.registry()->GetSubstraitCallToArrow(function_id)); + ARROW_ASSIGN_OR_RAISE( + SubstraitCall substrait_call, + DecodeScalarFunction(function_id, scalar_fn, ext_set, conversion_options)); + return function_converter(substrait_call); } default: @@ -827,6 +885,42 @@ static Result> MakeListElementReference( return MakeDirectReference(std::move(expr), std::move(ref_segment)); } +Result> EncodeSubstraitCall( + const SubstraitCall& call, ExtensionSet* ext_set, + const ConversionOptions& conversion_options) { + ARROW_ASSIGN_OR_RAISE(uint32_t anchor, ext_set->EncodeFunction(call.id())); + auto scalar_fn = internal::make_unique(); + scalar_fn->set_function_reference(anchor); + ARROW_ASSIGN_OR_RAISE( + std::unique_ptr output_type, + ToProto(*call.output_type(), call.output_nullable(), ext_set, conversion_options)); + scalar_fn->set_allocated_output_type(output_type.release()); + + for (uint32_t i = 0; i < call.size(); i++) { + substrait::FunctionArgument* arg = scalar_fn->add_arguments(); + if (call.HasEnumArg(i)) { + auto enum_val = internal::make_unique(); + ARROW_ASSIGN_OR_RAISE(util::optional enum_arg, + call.GetEnumArg(i)); + if (enum_arg) { + enum_val->set_specified(enum_arg->to_string()); + } else { + enum_val->set_allocated_unspecified(new google::protobuf::Empty()); + } + arg->set_allocated_enum_(enum_val.release()); + } else if (call.HasValueArg(i)) { + ARROW_ASSIGN_OR_RAISE(compute::Expression value_arg, call.GetValueArg(i)); + ARROW_ASSIGN_OR_RAISE(std::unique_ptr value_expr, + ToProto(value_arg, ext_set, conversion_options)); + arg->set_allocated_value(value_expr.release()); + } else { + return Status::Invalid("Call reported having ", call.size(), + " arguments but no argument could be found at index ", i); + } + } + return std::move(scalar_fn); +} + Result> ToProto( const compute::Expression& expr, ExtensionSet* ext_set, const ConversionOptions& conversion_options) { @@ -933,17 +1027,12 @@ Result> ToProto( } // other expression types dive into extensions immediately - ARROW_ASSIGN_OR_RAISE(auto anchor, ext_set->EncodeFunction(call->function_name)); - - auto scalar_fn = internal::make_unique(); - scalar_fn->set_function_reference(anchor); - scalar_fn->mutable_arguments()->Reserve(static_cast(arguments.size())); - for (auto& arg : arguments) { - auto argument = internal::make_unique(); - argument->set_allocated_value(arg.release()); - scalar_fn->mutable_arguments()->AddAllocated(argument.release()); - } - + ARROW_ASSIGN_OR_RAISE( + ExtensionIdRegistry::ArrowToSubstraitCall converter, + ext_set->registry()->GetArrowToSubstraitCall(call->function_name)); + ARROW_ASSIGN_OR_RAISE(SubstraitCall substrait_call, converter(*call)); + ARROW_ASSIGN_OR_RAISE(std::unique_ptr scalar_fn, + EncodeSubstraitCall(substrait_call, ext_set, conversion_options)); out->set_allocated_scalar_function(scalar_fn.release()); return std::move(out); } diff --git a/cpp/src/arrow/engine/substrait/expression_internal.h b/cpp/src/arrow/engine/substrait/expression_internal.h index 2b4dec2a00b21..f132afc0c1ac9 100644 --- a/cpp/src/arrow/engine/substrait/expression_internal.h +++ b/cpp/src/arrow/engine/substrait/expression_internal.h @@ -50,5 +50,9 @@ Result> ToProto(const Datum&, ExtensionSet*, const ConversionOptions&); +ARROW_ENGINE_EXPORT +Result FromProto(const substrait::AggregateFunction&, bool is_hash, + const ExtensionSet&, const ConversionOptions&); + } // namespace engine } // namespace arrow diff --git a/cpp/src/arrow/engine/substrait/ext_test.cc b/cpp/src/arrow/engine/substrait/ext_test.cc index 8e41cb7c98cee..4b37aa8fcdba3 100644 --- a/cpp/src/arrow/engine/substrait/ext_test.cc +++ b/cpp/src/arrow/engine/substrait/ext_test.cc @@ -56,12 +56,10 @@ struct DefaultExtensionIdRegistryProvider : public ExtensionIdRegistryProvider { struct NestedExtensionIdRegistryProvider : public ExtensionIdRegistryProvider { virtual ~NestedExtensionIdRegistryProvider() {} - std::shared_ptr registry_ = substrait::MakeExtensionIdRegistry(); + std::shared_ptr registry_ = MakeExtensionIdRegistry(); ExtensionIdRegistry* get() const override { return &*registry_; } }; -using Id = ExtensionIdRegistry::Id; - bool operator==(const Id& id1, const Id& id2) { return id1.uri == id2.uri && id1.name == id2.name; } @@ -85,8 +83,8 @@ static const std::vector kTypeNames = { TypeName{month_day_nano_interval(), "interval_month_day_nano"}, }; -static const std::vector kFunctionNames = { - "add", +static const std::vector kFunctionIds = { + {kSubstraitArithmeticFunctionsUri, "add"}, }; static const std::vector kTempFunctionNames = { @@ -141,15 +139,12 @@ TEST_P(ExtensionIdRegistryTest, GetFunctions) { auto provider = std::get<0>(GetParam()); auto registry = provider->get(); - for (util::string_view name : kFunctionNames) { - auto id = Id{kArrowExtTypesUri, name}; - for (auto funcrec_opt : {registry->GetFunction(id), registry->GetFunction(name)}) { - ASSERT_TRUE(funcrec_opt); - auto funcrec = funcrec_opt.value(); - ASSERT_EQ(id, funcrec.id); - ASSERT_EQ(name, funcrec.function_name); - } + for (Id func_id : kFunctionIds) { + ASSERT_OK_AND_ASSIGN(ExtensionIdRegistry::SubstraitCallToArrow converter, + registry->GetSubstraitCallToArrow(func_id)); + ASSERT_TRUE(converter); } + ASSERT_RAISES(NotImplemented, registry->GetSubstraitCallToArrow(kNonExistentId)); ASSERT_FALSE(registry->GetType(kNonExistentId)); ASSERT_FALSE(registry->GetType(*kNonExistentTypeName.type)); } @@ -158,10 +153,10 @@ TEST_P(ExtensionIdRegistryTest, ReregisterFunctions) { auto provider = std::get<0>(GetParam()); auto registry = provider->get(); - for (util::string_view name : kFunctionNames) { - auto id = Id{kArrowExtTypesUri, name}; - ASSERT_RAISES(Invalid, registry->CanRegisterFunction(id, name.to_string())); - ASSERT_RAISES(Invalid, registry->RegisterFunction(id, name.to_string())); + for (Id function_id : kFunctionIds) { + ASSERT_RAISES(Invalid, registry->CanAddSubstraitCallToArrow(function_id)); + ASSERT_RAISES(Invalid, registry->AddSubstraitCallToArrow( + function_id, function_id.name.to_string())); } } @@ -173,11 +168,26 @@ INSTANTIATE_TEST_SUITE_P( std::make_tuple(std::make_shared(), "nested"))); +TEST(ExtensionIdRegistryTest, GetSupportedSubstraitFunctions) { + ExtensionIdRegistry* default_registry = default_extension_id_registry(); + std::vector supported_functions = + default_registry->GetSupportedSubstraitFunctions(); + std::size_t num_functions = supported_functions.size(); + ASSERT_GT(num_functions, 0); + + std::shared_ptr nested = + nested_extension_id_registry(default_registry); + ASSERT_OK(nested->AddSubstraitCallToArrow(kNonExistentId, "some_function")); + + std::size_t num_nested_functions = nested->GetSupportedSubstraitFunctions().size(); + ASSERT_EQ(num_functions + 1, num_nested_functions); +} + TEST(ExtensionIdRegistryTest, RegisterTempTypes) { auto default_registry = default_extension_id_registry(); constexpr int rounds = 3; for (int i = 0; i < rounds; i++) { - auto registry = substrait::MakeExtensionIdRegistry(); + auto registry = MakeExtensionIdRegistry(); for (TypeName e : kTempTypeNames) { auto id = Id{kArrowExtTypesUri, e.name}; @@ -194,15 +204,15 @@ TEST(ExtensionIdRegistryTest, RegisterTempFunctions) { auto default_registry = default_extension_id_registry(); constexpr int rounds = 3; for (int i = 0; i < rounds; i++) { - auto registry = substrait::MakeExtensionIdRegistry(); + auto registry = MakeExtensionIdRegistry(); for (util::string_view name : kTempFunctionNames) { auto id = Id{kArrowExtTypesUri, name}; - ASSERT_OK(registry->CanRegisterFunction(id, name.to_string())); - ASSERT_OK(registry->RegisterFunction(id, name.to_string())); - ASSERT_RAISES(Invalid, registry->CanRegisterFunction(id, name.to_string())); - ASSERT_RAISES(Invalid, registry->RegisterFunction(id, name.to_string())); - ASSERT_OK(default_registry->CanRegisterFunction(id, name.to_string())); + ASSERT_OK(registry->CanAddSubstraitCallToArrow(id)); + ASSERT_OK(registry->AddSubstraitCallToArrow(id, name.to_string())); + ASSERT_RAISES(Invalid, registry->CanAddSubstraitCallToArrow(id)); + ASSERT_RAISES(Invalid, registry->AddSubstraitCallToArrow(id, name.to_string())); + ASSERT_OK(default_registry->CanAddSubstraitCallToArrow(id)); } } } @@ -246,24 +256,24 @@ TEST(ExtensionIdRegistryTest, RegisterNestedFunctions) { auto default_registry = default_extension_id_registry(); constexpr int rounds = 3; for (int i = 0; i < rounds; i++) { - auto registry1 = substrait::MakeExtensionIdRegistry(); + auto registry1 = MakeExtensionIdRegistry(); - ASSERT_OK(registry1->CanRegisterFunction(id1, name1.to_string())); - ASSERT_OK(registry1->RegisterFunction(id1, name1.to_string())); + ASSERT_OK(registry1->CanAddSubstraitCallToArrow(id1)); + ASSERT_OK(registry1->AddSubstraitCallToArrow(id1, name1.to_string())); for (int j = 0; j < rounds; j++) { - auto registry2 = substrait::MakeExtensionIdRegistry(); + auto registry2 = MakeExtensionIdRegistry(); - ASSERT_OK(registry2->CanRegisterFunction(id2, name2.to_string())); - ASSERT_OK(registry2->RegisterFunction(id2, name2.to_string())); - ASSERT_RAISES(Invalid, registry2->CanRegisterFunction(id2, name2.to_string())); - ASSERT_RAISES(Invalid, registry2->RegisterFunction(id2, name2.to_string())); - ASSERT_OK(default_registry->CanRegisterFunction(id2, name2.to_string())); + ASSERT_OK(registry2->CanAddSubstraitCallToArrow(id2)); + ASSERT_OK(registry2->AddSubstraitCallToArrow(id2, name2.to_string())); + ASSERT_RAISES(Invalid, registry2->CanAddSubstraitCallToArrow(id2)); + ASSERT_RAISES(Invalid, registry2->AddSubstraitCallToArrow(id2, name2.to_string())); + ASSERT_OK(default_registry->CanAddSubstraitCallToArrow(id2)); } - ASSERT_RAISES(Invalid, registry1->CanRegisterFunction(id1, name1.to_string())); - ASSERT_RAISES(Invalid, registry1->RegisterFunction(id1, name1.to_string())); - ASSERT_OK(default_registry->CanRegisterFunction(id1, name1.to_string())); + ASSERT_RAISES(Invalid, registry1->CanAddSubstraitCallToArrow(id1)); + ASSERT_RAISES(Invalid, registry1->AddSubstraitCallToArrow(id1, name1.to_string())); + ASSERT_OK(default_registry->CanAddSubstraitCallToArrow(id1)); } } diff --git a/cpp/src/arrow/engine/substrait/extension_set.cc b/cpp/src/arrow/engine/substrait/extension_set.cc index 08eb6acc9ca89..493d576e839bb 100644 --- a/cpp/src/arrow/engine/substrait/extension_set.cc +++ b/cpp/src/arrow/engine/substrait/extension_set.cc @@ -17,9 +17,9 @@ #include "arrow/engine/substrait/extension_set.h" -#include -#include +#include +#include "arrow/engine/substrait/expression_internal.h" #include "arrow/util/hash_util.h" #include "arrow/util/hashing.h" #include "arrow/util/string_view.h" @@ -28,6 +28,9 @@ namespace arrow { namespace engine { namespace { +// TODO(ARROW-16988): replace this with EXACT_ROUNDTRIP mode +constexpr bool kExactRoundTrip = true; + struct TypePtrHashEq { template size_t operator()(const Ptr& type) const { @@ -42,16 +45,115 @@ struct TypePtrHashEq { } // namespace -size_t ExtensionIdRegistry::IdHashEq::operator()(ExtensionIdRegistry::Id id) const { +std::string Id::ToString() const { + std::stringstream sstream; + sstream << uri; + sstream << '#'; + sstream << name; + return sstream.str(); +} + +size_t IdHashEq::operator()(Id id) const { constexpr ::arrow::internal::StringViewHash hash = {}; auto out = static_cast(hash(id.uri)); ::arrow::internal::hash_combine(out, hash(id.name)); return out; } -bool ExtensionIdRegistry::IdHashEq::operator()(ExtensionIdRegistry::Id l, - ExtensionIdRegistry::Id r) const { - return l.uri == r.uri && l.name == r.name; +bool IdHashEq::operator()(Id l, Id r) const { return l.uri == r.uri && l.name == r.name; } + +Id IdStorage::Emplace(Id id) { + util::string_view owned_uri = EmplaceUri(id.uri); + + util::string_view owned_name; + auto name_itr = names_.find(id.name); + if (name_itr == names_.end()) { + owned_names_.emplace_back(id.name); + owned_name = owned_names_.back(); + names_.insert(owned_name); + } else { + owned_name = *name_itr; + } + + return {owned_uri, owned_name}; +} + +util::optional IdStorage::Find(Id id) const { + util::optional maybe_owned_uri = FindUri(id.uri); + if (!maybe_owned_uri) { + return util::nullopt; + } + + auto name_itr = names_.find(id.name); + if (name_itr == names_.end()) { + return util::nullopt; + } else { + return Id{*maybe_owned_uri, *name_itr}; + } +} + +util::optional IdStorage::FindUri(util::string_view uri) const { + auto uri_itr = uris_.find(uri); + if (uri_itr == uris_.end()) { + return util::nullopt; + } + return *uri_itr; +} + +util::string_view IdStorage::EmplaceUri(util::string_view uri) { + auto uri_itr = uris_.find(uri); + if (uri_itr == uris_.end()) { + owned_uris_.emplace_back(uri); + util::string_view owned_uri = owned_uris_.back(); + uris_.insert(owned_uri); + return owned_uri; + } + return *uri_itr; +} + +Result> SubstraitCall::GetEnumArg( + uint32_t index) const { + if (index >= size_) { + return Status::Invalid("Expected Substrait call to have an enum argument at index ", + index, " but it did not have enough arguments"); + } + auto enum_arg_it = enum_args_.find(index); + if (enum_arg_it == enum_args_.end()) { + return Status::Invalid("Expected Substrait call to have an enum argument at index ", + index, " but the argument was not an enum."); + } + return enum_arg_it->second; +} + +bool SubstraitCall::HasEnumArg(uint32_t index) const { + return enum_args_.find(index) != enum_args_.end(); +} + +void SubstraitCall::SetEnumArg(uint32_t index, util::optional enum_arg) { + size_ = std::max(size_, index + 1); + enum_args_[index] = std::move(enum_arg); +} + +Result SubstraitCall::GetValueArg(uint32_t index) const { + if (index >= size_) { + return Status::Invalid("Expected Substrait call to have a value argument at index ", + index, " but it did not have enough arguments"); + } + auto value_arg_it = value_args_.find(index); + if (value_arg_it == value_args_.end()) { + return Status::Invalid("Expected Substrait call to have a value argument at index ", + index, " but the argument was not a value"); + } + return value_arg_it->second; +} + +bool SubstraitCall::HasValueArg(uint32_t index) const { + return value_args_.find(index) != value_args_.end(); +} + +void SubstraitCall::SetValueArg(uint32_t index, compute::Expression value_arg) { + size_ = std::max(size_, index + 1); + value_args_[index] = std::move(value_arg); } // A builder used when creating a Substrait plan from an Arrow execution plan. In @@ -97,54 +199,54 @@ Result ExtensionSet::Make( std::unordered_map uris, std::unordered_map type_ids, std::unordered_map function_ids, const ExtensionIdRegistry* registry) { - ExtensionSet set; + ExtensionSet set(default_extension_id_registry()); set.registry_ = registry; - // TODO(bkietz) move this into the registry as registry->OwnUris(&uris) or so - std::unordered_set - uris_owned_by_registry; - for (util::string_view uri : registry->Uris()) { - uris_owned_by_registry.insert(uri); - } - for (auto& uri : uris) { - auto it = uris_owned_by_registry.find(uri.second); - if (it == uris_owned_by_registry.end()) { - return Status::KeyError("Uri '", uri.second, "' not found in registry"); + util::optional maybe_uri_internal = registry->FindUri(uri.second); + if (maybe_uri_internal) { + set.uris_[uri.first] = *maybe_uri_internal; + } else { + if (kExactRoundTrip) { + return Status::Invalid( + "Plan contained a URI that the extension registry is unaware of: ", + uri.second); + } + set.uris_[uri.first] = set.plan_specific_ids_.EmplaceUri(uri.second); } - uri.second = *it; // Ensure uris point into the registry's memory - set.AddUri(uri); } set.types_.reserve(type_ids.size()); + for (const auto& type_id : type_ids) { + if (type_id.second.empty()) continue; + RETURN_NOT_OK(set.CheckHasUri(type_id.second.uri)); - for (unsigned int i = 0; i < static_cast(type_ids.size()); ++i) { - if (type_ids[i].empty()) continue; - RETURN_NOT_OK(set.CheckHasUri(type_ids[i].uri)); - - if (auto rec = registry->GetType(type_ids[i])) { - set.types_[i] = {rec->id, rec->type}; + if (auto rec = registry->GetType(type_id.second)) { + set.types_[type_id.first] = {rec->id, rec->type}; continue; } - return Status::Invalid("Type ", type_ids[i].uri, "#", type_ids[i].name, " not found"); + return Status::Invalid("Type ", type_id.second.uri, "#", type_id.second.name, + " not found"); } set.functions_.reserve(function_ids.size()); - - for (unsigned int i = 0; i < static_cast(function_ids.size()); ++i) { - if (function_ids[i].empty()) continue; - RETURN_NOT_OK(set.CheckHasUri(function_ids[i].uri)); - - if (auto rec = registry->GetFunction(function_ids[i])) { - set.functions_[i] = {rec->id, rec->function_name}; - continue; + for (const auto& function_id : function_ids) { + if (function_id.second.empty()) continue; + RETURN_NOT_OK(set.CheckHasUri(function_id.second.uri)); + util::optional maybe_id_internal = registry->FindId(function_id.second); + if (maybe_id_internal) { + set.functions_[function_id.first] = *maybe_id_internal; + } else { + if (kExactRoundTrip) { + return Status::Invalid( + "Plan contained a function id that the extension registry is unaware of: ", + function_id.second.uri, "#", function_id.second.name); + } + set.functions_[function_id.first] = + set.plan_specific_ids_.Emplace(function_id.second); } - return Status::Invalid("Function ", function_ids[i].uri, "#", function_ids[i].name, - " not found"); } - set.uris_ = std::move(uris); - return std::move(set); } @@ -162,39 +264,34 @@ Result ExtensionSet::EncodeType(const DataType& type) { auto it_success = types_map_.emplace(rec->id, static_cast(types_map_.size())); if (it_success.second) { - DCHECK_EQ(types_.find(static_cast(types_.size())), types_.end()) + DCHECK_EQ(types_.find(static_cast(types_.size())), types_.end()) << "Type existed in types_ but not types_map_. ExtensionSet is inconsistent"; - types_[static_cast(types_.size())] = {rec->id, rec->type}; + types_[static_cast(types_.size())] = {rec->id, rec->type}; } return it_success.first->second; } return Status::KeyError("type ", type.ToString(), " not found in the registry"); } -Result ExtensionSet::DecodeFunction(uint32_t anchor) const { - if (functions_.find(anchor) == functions_.end() || functions_.at(anchor).id.empty()) { +Result ExtensionSet::DecodeFunction(uint32_t anchor) const { + if (functions_.find(anchor) == functions_.end() || functions_.at(anchor).empty()) { return Status::Invalid("User defined function reference ", anchor, " did not have a corresponding anchor in the extension set"); } return functions_.at(anchor); } -Result ExtensionSet::EncodeFunction(util::string_view function_name) { - if (auto rec = registry_->GetFunction(function_name)) { - RETURN_NOT_OK(this->AddUri(rec->id)); - auto it_success = - functions_map_.emplace(rec->id, static_cast(functions_map_.size())); - if (it_success.second) { - DCHECK_EQ(functions_.find(static_cast(functions_.size())), - functions_.end()) - << "Function existed in functions_ but not functions_map_. ExtensionSet is " - "inconsistent"; - functions_[static_cast(functions_.size())] = {rec->id, - rec->function_name}; - } - return it_success.first->second; +Result ExtensionSet::EncodeFunction(Id function_id) { + RETURN_NOT_OK(this->AddUri(function_id)); + auto it_success = + functions_map_.emplace(function_id, static_cast(functions_map_.size())); + if (it_success.second) { + DCHECK_EQ(functions_.find(static_cast(functions_.size())), functions_.end()) + << "Function existed in functions_ but not functions_map_. ExtensionSet is " + "inconsistent"; + functions_[static_cast(functions_.size())] = function_id; } - return Status::KeyError("function ", function_name, " not found in the registry"); + return it_success.first->second; } template @@ -207,16 +304,38 @@ const int* GetIndex(const KeyToIndex& key_to_index, const Key& key) { namespace { struct ExtensionIdRegistryImpl : ExtensionIdRegistry { + ExtensionIdRegistryImpl() : parent_(nullptr) {} + explicit ExtensionIdRegistryImpl(const ExtensionIdRegistry* parent) : parent_(parent) {} + virtual ~ExtensionIdRegistryImpl() {} - std::vector Uris() const override { - return {uris_.begin(), uris_.end()}; + util::optional FindUri(util::string_view uri) const override { + if (parent_) { + util::optional parent_uri = parent_->FindUri(uri); + if (parent_uri) { + return parent_uri; + } + } + return ids_.FindUri(uri); + } + + util::optional FindId(Id id) const override { + if (parent_) { + util::optional parent_id = parent_->FindId(id); + if (parent_id) { + return parent_id; + } + } + return ids_.Find(id); } util::optional GetType(const DataType& type) const override { if (auto index = GetIndex(type_to_index_, &type)) { return TypeRecord{type_ids_[*index], types_[*index]}; } + if (parent_) { + return parent_->GetType(type); + } return {}; } @@ -224,6 +343,9 @@ struct ExtensionIdRegistryImpl : ExtensionIdRegistry { if (auto index = GetIndex(id_to_index_, id)) { return TypeRecord{type_ids_[*index], types_[*index]}; } + if (parent_) { + return parent_->GetType(id); + } return {}; } @@ -234,14 +356,20 @@ struct ExtensionIdRegistryImpl : ExtensionIdRegistry { if (type_to_index_.find(&*type) != type_to_index_.end()) { return Status::Invalid("Type was already registered"); } + if (parent_) { + return parent_->CanRegisterType(id, type); + } return Status::OK(); } Status RegisterType(Id id, std::shared_ptr type) override { DCHECK_EQ(type_ids_.size(), types_.size()); - Id copied_id{*uris_.emplace(id.uri.to_string()).first, - *names_.emplace(id.name.to_string()).first}; + if (parent_) { + ARROW_RETURN_NOT_OK(parent_->CanRegisterType(id, type)); + } + + Id copied_id = ids_.Emplace(id); auto index = static_cast(type_ids_.size()); @@ -261,155 +389,394 @@ struct ExtensionIdRegistryImpl : ExtensionIdRegistry { return Status::OK(); } - util::optional GetFunction( - util::string_view arrow_function_name) const override { - if (auto index = GetIndex(function_name_to_index_, arrow_function_name)) { - return FunctionRecord{function_ids_[*index], *function_name_ptrs_[*index]}; + Status CanAddSubstraitCallToArrow(Id substrait_function_id) const override { + if (substrait_to_arrow_.find(substrait_function_id) != substrait_to_arrow_.end()) { + return Status::Invalid("Cannot register function converter for Substrait id ", + substrait_function_id.ToString(), + " because a converter already exists"); } - return {}; + if (parent_) { + return parent_->CanAddSubstraitCallToArrow(substrait_function_id); + } + return Status::OK(); } - util::optional GetFunction(Id id) const override { - if (auto index = GetIndex(function_id_to_index_, id)) { - return FunctionRecord{function_ids_[*index], *function_name_ptrs_[*index]}; + Status CanAddSubstraitAggregateToArrow(Id substrait_function_id) const override { + if (substrait_to_arrow_agg_.find(substrait_function_id) != + substrait_to_arrow_agg_.end()) { + return Status::Invalid( + "Cannot register aggregate function converter for Substrait id ", + substrait_function_id.ToString(), + " because an aggregate converter already exists"); } - return {}; + if (parent_) { + return parent_->CanAddSubstraitAggregateToArrow(substrait_function_id); + } + return Status::OK(); + } + + template + Status AddSubstraitToArrowFunc( + Id substrait_id, ConverterType conversion_func, + std::unordered_map* dest) { + // Convert id to view into registry-owned memory + Id copied_id = ids_.Emplace(substrait_id); + + auto add_result = dest->emplace(copied_id, std::move(conversion_func)); + if (!add_result.second) { + return Status::Invalid( + "Failed to register Substrait to Arrow function converter because a converter " + "already existed for Substrait id ", + substrait_id.ToString()); + } + + return Status::OK(); + } + + Status AddSubstraitCallToArrow(Id substrait_function_id, + SubstraitCallToArrow conversion_func) override { + if (parent_) { + ARROW_RETURN_NOT_OK(parent_->CanAddSubstraitCallToArrow(substrait_function_id)); + } + return AddSubstraitToArrowFunc( + substrait_function_id, std::move(conversion_func), &substrait_to_arrow_); } - Status CanRegisterFunction(Id id, - const std::string& arrow_function_name) const override { - if (function_id_to_index_.find(id) != function_id_to_index_.end()) { - return Status::Invalid("Function id was already registered"); + Status AddSubstraitAggregateToArrow( + Id substrait_function_id, SubstraitAggregateToArrow conversion_func) override { + if (parent_) { + ARROW_RETURN_NOT_OK( + parent_->CanAddSubstraitAggregateToArrow(substrait_function_id)); } - if (function_name_to_index_.find(arrow_function_name) != - function_name_to_index_.end()) { - return Status::Invalid("Function name was already registered"); + return AddSubstraitToArrowFunc( + substrait_function_id, std::move(conversion_func), &substrait_to_arrow_agg_); + } + + template + Status AddArrowToSubstraitFunc(std::string arrow_function_name, ConverterType converter, + std::unordered_map* dest) { + auto add_result = dest->emplace(std::move(arrow_function_name), std::move(converter)); + if (!add_result.second) { + return Status::Invalid( + "Failed to register Arrow to Substrait function converter for Arrow function ", + arrow_function_name, " because a converter already existed"); } return Status::OK(); } - Status RegisterFunction(Id id, std::string arrow_function_name) override { - DCHECK_EQ(function_ids_.size(), function_name_ptrs_.size()); + Status AddArrowToSubstraitCall(std::string arrow_function_name, + ArrowToSubstraitCall converter) override { + if (parent_) { + ARROW_RETURN_NOT_OK(parent_->CanAddArrowToSubstraitCall(arrow_function_name)); + } + return AddArrowToSubstraitFunc(std::move(arrow_function_name), converter, + &arrow_to_substrait_); + } - Id copied_id{*uris_.emplace(id.uri.to_string()).first, - *names_.emplace(id.name.to_string()).first}; + Status AddArrowToSubstraitAggregate(std::string arrow_function_name, + ArrowToSubstraitAggregate converter) override { + if (parent_) { + ARROW_RETURN_NOT_OK(parent_->CanAddArrowToSubstraitAggregate(arrow_function_name)); + } + return AddArrowToSubstraitFunc(std::move(arrow_function_name), converter, + &arrow_to_substrait_agg_); + } - const std::string& copied_function_name{ - *function_names_.emplace(std::move(arrow_function_name)).first}; + Status CanAddArrowToSubstraitCall(const std::string& function_name) const override { + if (arrow_to_substrait_.find(function_name) != arrow_to_substrait_.end()) { + return Status::Invalid( + "Cannot register function converter because a converter already exists"); + } + if (parent_) { + return parent_->CanAddArrowToSubstraitCall(function_name); + } + return Status::OK(); + } - auto index = static_cast(function_ids_.size()); + Status CanAddArrowToSubstraitAggregate( + const std::string& function_name) const override { + if (arrow_to_substrait_agg_.find(function_name) != arrow_to_substrait_agg_.end()) { + return Status::Invalid( + "Cannot register function converter because a converter already exists"); + } + if (parent_) { + return parent_->CanAddArrowToSubstraitAggregate(function_name); + } + return Status::OK(); + } - auto it_success = function_id_to_index_.emplace(copied_id, index); + Result GetSubstraitCallToArrow( + Id substrait_function_id) const override { + auto maybe_converter = substrait_to_arrow_.find(substrait_function_id); + if (maybe_converter == substrait_to_arrow_.end()) { + if (parent_) { + return parent_->GetSubstraitCallToArrow(substrait_function_id); + } + return Status::NotImplemented( + "No conversion function exists to convert the Substrait function ", + substrait_function_id.uri, "#", substrait_function_id.name, + " to an Arrow call expression"); + } + return maybe_converter->second; + } - if (!it_success.second) { - return Status::Invalid("Function id was already registered"); + Result GetSubstraitAggregateToArrow( + Id substrait_function_id) const override { + auto maybe_converter = substrait_to_arrow_agg_.find(substrait_function_id); + if (maybe_converter == substrait_to_arrow_agg_.end()) { + if (parent_) { + return parent_->GetSubstraitAggregateToArrow(substrait_function_id); + } + return Status::NotImplemented( + "No conversion function exists to convert the Substrait aggregate function ", + substrait_function_id.uri, "#", substrait_function_id.name, + " to an Arrow aggregate"); } + return maybe_converter->second; + } - if (!function_name_to_index_.emplace(copied_function_name, index).second) { - function_id_to_index_.erase(it_success.first); - return Status::Invalid("Function name was already registered"); + Result GetArrowToSubstraitCall( + const std::string& arrow_function_name) const override { + auto maybe_converter = arrow_to_substrait_.find(arrow_function_name); + if (maybe_converter == arrow_to_substrait_.end()) { + if (parent_) { + return parent_->GetArrowToSubstraitCall(arrow_function_name); + } + return Status::NotImplemented( + "No conversion function exists to convert the Arrow function ", + arrow_function_name, " to a Substrait call"); } + return maybe_converter->second; + } - function_name_ptrs_.push_back(&copied_function_name); - function_ids_.push_back(copied_id); - return Status::OK(); + Result GetArrowToSubstraitAggregate( + const std::string& arrow_function_name) const override { + auto maybe_converter = arrow_to_substrait_agg_.find(arrow_function_name); + if (maybe_converter == arrow_to_substrait_agg_.end()) { + if (parent_) { + return parent_->GetArrowToSubstraitAggregate(arrow_function_name); + } + return Status::NotImplemented( + "No conversion function exists to convert the Arrow aggregate ", + arrow_function_name, " to a Substrait aggregate"); + } + return maybe_converter->second; } - Status RegisterFunction(std::string uri, std::string name, - std::string arrow_function_name) override { - return RegisterFunction({uri, name}, arrow_function_name); + std::vector GetSupportedSubstraitFunctions() const override { + std::vector encoded_ids; + for (const auto& entry : substrait_to_arrow_) { + encoded_ids.push_back(entry.first.ToString()); + } + for (const auto& entry : substrait_to_arrow_agg_) { + encoded_ids.push_back(entry.first.ToString()); + } + if (parent_) { + std::vector parent_ids = parent_->GetSupportedSubstraitFunctions(); + encoded_ids.insert(encoded_ids.end(), make_move_iterator(parent_ids.begin()), + make_move_iterator(parent_ids.end())); + } + std::sort(encoded_ids.begin(), encoded_ids.end()); + return encoded_ids; } - // owning storage of uris, names, (arrow::)function_names, types - // note that storing strings like this is safe since references into an - // unordered_set are not invalidated on insertion - std::unordered_set uris_, names_, function_names_; + // Defined below since it depends on some helper functions defined below + Status AddSubstraitCallToArrow(Id substrait_function_id, + std::string arrow_function_name) override; + + // Parent registry, null for the root, non-null for nested + const ExtensionIdRegistry* parent_; + + // owning storage of ids & types + IdStorage ids_; DataTypeVector types_; + // There should only be one entry per Arrow function so there is no need + // to separate ownership and lookup + std::unordered_map arrow_to_substrait_; + std::unordered_map arrow_to_substrait_agg_; // non-owning lookup helpers - std::vector type_ids_, function_ids_; + std::vector type_ids_; std::unordered_map id_to_index_; std::unordered_map type_to_index_; - - std::vector function_name_ptrs_; - std::unordered_map function_id_to_index_; - std::unordered_map - function_name_to_index_; + std::unordered_map substrait_to_arrow_; + std::unordered_map + substrait_to_arrow_agg_; }; -struct NestedExtensionIdRegistryImpl : ExtensionIdRegistryImpl { - explicit NestedExtensionIdRegistryImpl(const ExtensionIdRegistry* parent) - : parent_(parent) {} - - virtual ~NestedExtensionIdRegistryImpl() {} +template +using EnumParser = std::function(util::optional)>; - std::vector Uris() const override { - std::vector uris = parent_->Uris(); - std::unordered_set uri_set; - uri_set.insert(uris.begin(), uris.end()); - uri_set.insert(uris_.begin(), uris_.end()); - return std::vector(uris); +template +EnumParser GetEnumParser(const std::vector& options) { + std::unordered_map parse_map; + for (std::size_t i = 0; i < options.size(); i++) { + parse_map[options[i]] = static_cast(i + 1); } - - util::optional GetType(const DataType& type) const override { - auto type_opt = ExtensionIdRegistryImpl::GetType(type); - if (type_opt) { - return type_opt; + return [parse_map](util::optional enum_val) -> Result { + if (!enum_val) { + // Assumes 0 is always kUnspecified in Enum + return static_cast(0); } - return parent_->GetType(type); - } - - util::optional GetType(Id id) const override { - auto type_opt = ExtensionIdRegistryImpl::GetType(id); - if (type_opt) { - return type_opt; + auto maybe_parsed = parse_map.find(enum_val->to_string()); + if (maybe_parsed == parse_map.end()) { + return Status::Invalid("The value ", *enum_val, " is not an expected enum value"); } - return parent_->GetType(id); - } + return maybe_parsed->second; + }; +} - Status CanRegisterType(Id id, const std::shared_ptr& type) const override { - return parent_->CanRegisterType(id, type) & - ExtensionIdRegistryImpl::CanRegisterType(id, type); - } +enum class TemporalComponent { kUnspecified = 0, kYear, kMonth, kDay, kSecond }; +static std::vector kTemporalComponentOptions = {"YEAR", "MONTH", "DAY", + "SECOND"}; +static EnumParser kTemporalComponentParser = + GetEnumParser(kTemporalComponentOptions); + +enum class OverflowBehavior { kUnspecified = 0, kSilent, kSaturate, kError }; +static std::vector kOverflowOptions = {"SILENT", "SATURATE", "ERROR"}; +static EnumParser kOverflowParser = + GetEnumParser(kOverflowOptions); + +template +Result ParseEnumArg(const SubstraitCall& call, uint32_t arg_index, + const EnumParser& parser) { + ARROW_ASSIGN_OR_RAISE(util::optional enum_arg, + call.GetEnumArg(arg_index)); + return parser(enum_arg); +} - Status RegisterType(Id id, std::shared_ptr type) override { - return parent_->CanRegisterType(id, type) & - ExtensionIdRegistryImpl::RegisterType(id, type); +Result> GetValueArgs(const SubstraitCall& call, + int start_index) { + std::vector expressions; + for (uint32_t index = start_index; index < call.size(); index++) { + ARROW_ASSIGN_OR_RAISE(compute::Expression arg, call.GetValueArg(index)); + expressions.push_back(arg); } + return std::move(expressions); +} - util::optional GetFunction( - util::string_view arrow_function_name) const override { - auto func_opt = ExtensionIdRegistryImpl::GetFunction(arrow_function_name); - if (func_opt) { - return func_opt; +ExtensionIdRegistry::SubstraitCallToArrow DecodeOptionlessOverflowableArithmetic( + const std::string& function_name) { + return [function_name](const SubstraitCall& call) -> Result { + ARROW_ASSIGN_OR_RAISE(OverflowBehavior overflow_behavior, + ParseEnumArg(call, 0, kOverflowParser)); + ARROW_ASSIGN_OR_RAISE(std::vector value_args, + GetValueArgs(call, 1)); + if (overflow_behavior == OverflowBehavior::kUnspecified) { + overflow_behavior = OverflowBehavior::kSilent; } - return parent_->GetFunction(arrow_function_name); - } + if (overflow_behavior == OverflowBehavior::kSilent) { + return arrow::compute::call(function_name, std::move(value_args)); + } else if (overflow_behavior == OverflowBehavior::kError) { + return arrow::compute::call(function_name + "_checked", std::move(value_args)); + } else { + return Status::NotImplemented( + "Only SILENT and ERROR arithmetic kernels are currently implemented but ", + kOverflowOptions[static_cast(overflow_behavior) - 1], " was requested"); + } + }; +} - util::optional GetFunction(Id id) const override { - auto func_opt = ExtensionIdRegistryImpl::GetFunction(id); - if (func_opt) { - return func_opt; +template +ExtensionIdRegistry::ArrowToSubstraitCall EncodeOptionlessOverflowableArithmetic( + Id substrait_fn_id) { + return + [substrait_fn_id](const compute::Expression::Call& call) -> Result { + // nullable=true isn't quite correct but we don't know the nullability of + // the inputs + SubstraitCall substrait_call(substrait_fn_id, call.type.GetSharedPtr(), + /*nullable=*/true); + if (kChecked) { + substrait_call.SetEnumArg(0, "ERROR"); + } else { + substrait_call.SetEnumArg(0, "SILENT"); + } + for (std::size_t i = 0; i < call.arguments.size(); i++) { + substrait_call.SetValueArg(static_cast(i + 1), call.arguments[i]); + } + return std::move(substrait_call); + }; +} + +ExtensionIdRegistry::SubstraitCallToArrow DecodeOptionlessBasicMapping( + const std::string& function_name, uint32_t max_args) { + return [function_name, + max_args](const SubstraitCall& call) -> Result { + if (call.size() > max_args) { + return Status::NotImplemented("Acero does not have a kernel for ", function_name, + " that receives ", call.size(), " arguments"); } - return parent_->GetFunction(id); - } + ARROW_ASSIGN_OR_RAISE(std::vector value_args, + GetValueArgs(call, 0)); + return arrow::compute::call(function_name, std::move(value_args)); + }; +} - Status CanRegisterFunction(Id id, - const std::string& arrow_function_name) const override { - return parent_->CanRegisterFunction(id, arrow_function_name) & - ExtensionIdRegistryImpl::CanRegisterFunction(id, arrow_function_name); - } +ExtensionIdRegistry::SubstraitCallToArrow DecodeTemporalExtractionMapping() { + return [](const SubstraitCall& call) -> Result { + ARROW_ASSIGN_OR_RAISE(TemporalComponent temporal_component, + ParseEnumArg(call, 0, kTemporalComponentParser)); + if (temporal_component == TemporalComponent::kUnspecified) { + return Status::Invalid( + "The temporal component enum is a require option for the extract function " + "and is not specified"); + } + ARROW_ASSIGN_OR_RAISE(std::vector value_args, + GetValueArgs(call, 1)); + std::string func_name; + switch (temporal_component) { + case TemporalComponent::kYear: + func_name = "year"; + break; + case TemporalComponent::kMonth: + func_name = "month"; + break; + case TemporalComponent::kDay: + func_name = "day"; + break; + case TemporalComponent::kSecond: + func_name = "second"; + break; + default: + return Status::Invalid("Unexpected value for temporal component in extract call"); + } + return compute::call(func_name, std::move(value_args)); + }; +} - Status RegisterFunction(Id id, std::string arrow_function_name) override { - return parent_->CanRegisterFunction(id, arrow_function_name) & - ExtensionIdRegistryImpl::RegisterFunction(id, arrow_function_name); - } +ExtensionIdRegistry::SubstraitCallToArrow DecodeConcatMapping() { + return [](const SubstraitCall& call) -> Result { + ARROW_ASSIGN_OR_RAISE(std::vector value_args, + GetValueArgs(call, 0)); + value_args.push_back(compute::literal("")); + return compute::call("binary_join_element_wise", std::move(value_args)); + }; +} - const ExtensionIdRegistry* parent_; -}; +ExtensionIdRegistry::SubstraitAggregateToArrow DecodeBasicAggregate( + const std::string& arrow_function_name) { + return [arrow_function_name](const SubstraitCall& call) -> Result { + if (call.size() != 1) { + return Status::NotImplemented( + "Only unary aggregate functions are currently supported"); + } + ARROW_ASSIGN_OR_RAISE(compute::Expression arg, call.GetValueArg(0)); + const FieldRef* arg_ref = arg.field_ref(); + if (!arg_ref) { + return Status::Invalid("Expected an aggregate call ", call.id().uri, "#", + call.id().name, " to have a direct reference"); + } + std::string fixed_arrow_func = arrow_function_name; + if (call.is_hash()) { + fixed_arrow_func = "hash_" + arrow_function_name; + } + return compute::Aggregate{std::move(fixed_arrow_func), nullptr, *arg_ref, ""}; + }; +} struct DefaultExtensionIdRegistry : ExtensionIdRegistryImpl { DefaultExtensionIdRegistry() { + // ----------- Extension Types ---------------------------- struct TypeName { std::shared_ptr type; util::string_view name; @@ -428,32 +795,91 @@ struct DefaultExtensionIdRegistry : ExtensionIdRegistryImpl { DCHECK_OK(RegisterType({kArrowExtTypesUri, e.name}, std::move(e.type))); } - for (TypeName e : { - TypeName{null(), "null"}, - TypeName{month_interval(), "interval_month"}, - TypeName{day_time_interval(), "interval_day_milli"}, - TypeName{month_day_nano_interval(), "interval_month_day_nano"}, - }) { + for (TypeName e : + {TypeName{null(), "null"}, TypeName{month_interval(), "interval_month"}, + TypeName{day_time_interval(), "interval_day_milli"}, + TypeName{month_day_nano_interval(), "interval_month_day_nano"}}) { DCHECK_OK(RegisterType({kArrowExtTypesUri, e.name}, std::move(e.type))); } - // TODO: this is just a placeholder right now. We'll need a YAML file for - // all functions (and prototypes) that Arrow provides that are relevant - // for Substrait, and include mappings for all of them here. See - // ARROW-15535. - for (util::string_view name : { - "add", - "equal", - "is_not_distinct_from", - "hash_count", - }) { - DCHECK_OK(RegisterFunction({kArrowExtTypesUri, name}, name.to_string())); + // -------------- Substrait -> Arrow Functions ----------------- + // Mappings with a _checked variant + for (const auto& function_name : {"add", "subtract", "multiply", "divide"}) { + DCHECK_OK( + AddSubstraitCallToArrow({kSubstraitArithmeticFunctionsUri, function_name}, + DecodeOptionlessOverflowableArithmetic(function_name))); + } + // Basic mappings that need _kleene appended to them + for (const auto& function_name : {"or", "and"}) { + DCHECK_OK(AddSubstraitCallToArrow( + {kSubstraitBooleanFunctionsUri, function_name}, + DecodeOptionlessBasicMapping(std::string(function_name) + "_kleene", + /*max_args=*/2))); + } + // Basic binary mappings + for (const auto& function_name : + std::vector>{ + {kSubstraitBooleanFunctionsUri, "xor"}, + {kSubstraitComparisonFunctionsUri, "equal"}, + {kSubstraitComparisonFunctionsUri, "not_equal"}}) { + DCHECK_OK( + AddSubstraitCallToArrow({function_name.first, function_name.second}, + DecodeOptionlessBasicMapping( + function_name.second.to_string(), /*max_args=*/2))); + } + for (const auto& uri : + {kSubstraitComparisonFunctionsUri, kSubstraitDatetimeFunctionsUri}) { + DCHECK_OK(AddSubstraitCallToArrow( + {uri, "lt"}, DecodeOptionlessBasicMapping("less", /*max_args=*/2))); + DCHECK_OK(AddSubstraitCallToArrow( + {uri, "lte"}, DecodeOptionlessBasicMapping("less_equal", /*max_args=*/2))); + DCHECK_OK(AddSubstraitCallToArrow( + {uri, "gt"}, DecodeOptionlessBasicMapping("greater", /*max_args=*/2))); + DCHECK_OK(AddSubstraitCallToArrow( + {uri, "gte"}, DecodeOptionlessBasicMapping("greater_equal", /*max_args=*/2))); + } + // One-off mappings + DCHECK_OK( + AddSubstraitCallToArrow({kSubstraitBooleanFunctionsUri, "not"}, + DecodeOptionlessBasicMapping("invert", /*max_args=*/1))); + DCHECK_OK(AddSubstraitCallToArrow({kSubstraitDatetimeFunctionsUri, "extract"}, + DecodeTemporalExtractionMapping())); + DCHECK_OK(AddSubstraitCallToArrow({kSubstraitStringFunctionsUri, "concat"}, + DecodeConcatMapping())); + + // --------------- Substrait -> Arrow Aggregates -------------- + for (const auto& fn_name : {"sum", "min", "max"}) { + DCHECK_OK(AddSubstraitAggregateToArrow({kSubstraitArithmeticFunctionsUri, fn_name}, + DecodeBasicAggregate(fn_name))); + } + DCHECK_OK(AddSubstraitAggregateToArrow({kSubstraitArithmeticFunctionsUri, "avg"}, + DecodeBasicAggregate("mean"))); + + // --------------- Arrow -> Substrait Functions --------------- + for (const auto& fn_name : {"add", "subtract", "multiply", "divide"}) { + Id fn_id{kSubstraitArithmeticFunctionsUri, fn_name}; + DCHECK_OK(AddArrowToSubstraitCall( + fn_name, EncodeOptionlessOverflowableArithmetic(fn_id))); + DCHECK_OK( + AddArrowToSubstraitCall(std::string(fn_name) + "_checked", + EncodeOptionlessOverflowableArithmetic(fn_id))); } } }; } // namespace +Status ExtensionIdRegistryImpl::AddSubstraitCallToArrow(Id substrait_function_id, + std::string arrow_function_name) { + return AddSubstraitCallToArrow( + substrait_function_id, + [arrow_function_name](const SubstraitCall& call) -> Result { + ARROW_ASSIGN_OR_RAISE(std::vector value_args, + GetValueArgs(call, 0)); + return compute::call(arrow_function_name, std::move(value_args)); + }); +} + ExtensionIdRegistry* default_extension_id_registry() { static DefaultExtensionIdRegistry impl_; return &impl_; @@ -461,7 +887,7 @@ ExtensionIdRegistry* default_extension_id_registry() { std::shared_ptr nested_extension_id_registry( const ExtensionIdRegistry* parent) { - return std::make_shared(parent); + return std::make_shared(parent); } } // namespace engine diff --git a/cpp/src/arrow/engine/substrait/extension_set.h b/cpp/src/arrow/engine/substrait/extension_set.h index 04e4586a9f5e2..9cb42f66136b9 100644 --- a/cpp/src/arrow/engine/substrait/extension_set.h +++ b/cpp/src/arrow/engine/substrait/extension_set.h @@ -19,26 +19,130 @@ #pragma once +#include #include +#include #include +#include "arrow/compute/exec/exec_plan.h" +#include "arrow/compute/exec/expression.h" #include "arrow/engine/substrait/visibility.h" +#include "arrow/result.h" #include "arrow/type_fwd.h" +#include "arrow/util/hash_util.h" +#include "arrow/util/hashing.h" #include "arrow/util/optional.h" #include "arrow/util/string_view.h" -#include "arrow/util/hash_util.h" - namespace arrow { namespace engine { +constexpr const char* kSubstraitArithmeticFunctionsUri = + "https://github.com/substrait-io/substrait/blob/main/extensions/" + "functions_arithmetic.yaml"; +constexpr const char* kSubstraitBooleanFunctionsUri = + "https://github.com/substrait-io/substrait/blob/main/extensions/" + "functions_boolean.yaml"; +constexpr const char* kSubstraitComparisonFunctionsUri = + "https://github.com/substrait-io/substrait/blob/main/extensions/" + "functions_comparison.yaml"; +constexpr const char* kSubstraitDatetimeFunctionsUri = + "https://github.com/substrait-io/substrait/blob/main/extensions/" + "functions_datetime.yaml"; +constexpr const char* kSubstraitStringFunctionsUri = + "https://github.com/substrait-io/substrait/blob/main/extensions/" + "functions_string.yaml"; + +struct Id { + util::string_view uri, name; + bool empty() const { return uri.empty() && name.empty(); } + std::string ToString() const; +}; +struct IdHashEq { + size_t operator()(Id id) const; + bool operator()(Id l, Id r) const; +}; + +/// \brief Owning storage for ids +/// +/// Substrait plans may reuse URIs and names in many places. For convenience +/// and performance Substarit ids are typically passed around as views. As we +/// convert a plan from Substrait to Arrow we need to copy these strings out of +/// the Substrait buffer and into owned storage. This class serves as that owned +/// storage. +class IdStorage { + public: + /// \brief Get an equivalent id pointing into this storage + /// + /// This operation will copy the ids into storage if they do not already exist + Id Emplace(Id id); + /// \brief Get an equivalent view pointing into this storage for a URI + /// + /// If no URI is found then the uri will be copied into storage + util::string_view EmplaceUri(util::string_view uri); + /// \brief Get an equivalent id pointing into this storage + /// + /// If no id is found then nullopt will be returned + util::optional Find(Id id) const; + /// \brief Get an equivalent view pointing into this storage for a URI + /// + /// If no URI is found then nullopt will be returned + util::optional FindUri(util::string_view uri) const; + + private: + std::unordered_set uris_; + std::unordered_set names_; + std::list owned_uris_; + std::list owned_names_; +}; + +/// \brief Describes a Substrait call +/// +/// Substrait call expressions contain a list of arguments which can either +/// be enum arguments (which are serialized as strings), value arguments (which) +/// are Arrow expressions, or type arguments (not yet implemented) +class SubstraitCall { + public: + SubstraitCall(Id id, std::shared_ptr output_type, bool output_nullable, + bool is_hash = false) + : id_(id), + output_type_(std::move(output_type)), + output_nullable_(output_nullable), + is_hash_(is_hash) {} + + const Id& id() const { return id_; } + const std::shared_ptr& output_type() const { return output_type_; } + bool output_nullable() const { return output_nullable_; } + bool is_hash() const { return is_hash_; } + + bool HasEnumArg(uint32_t index) const; + Result> GetEnumArg(uint32_t index) const; + void SetEnumArg(uint32_t index, util::optional enum_arg); + Result GetValueArg(uint32_t index) const; + bool HasValueArg(uint32_t index) const; + void SetValueArg(uint32_t index, compute::Expression value_arg); + uint32_t size() const { return size_; } + + private: + Id id_; + std::shared_ptr output_type_; + bool output_nullable_; + // Only needed when converting from Substrait -> Arrow aggregates. The + // Arrow function name depends on whether or not there are any groups + bool is_hash_; + std::unordered_map> enum_args_; + std::unordered_map value_args_; + uint32_t size_ = 0; +}; + /// Substrait identifies functions and custom data types using a (uri, name) pair. /// -/// This registry is a bidirectional mapping between Substrait IDs and their corresponding -/// Arrow counterparts (arrow::DataType and function names in a function registry) +/// This registry is a bidirectional mapping between Substrait IDs and their +/// corresponding Arrow counterparts (arrow::DataType and function names in a function +/// registry) /// -/// Substrait extension types and variations must be registered with their corresponding -/// arrow::DataType before they can be used! +/// Substrait extension types and variations must be registered with their +/// corresponding arrow::DataType before they can be used! /// /// Conceptually this can be thought of as two pairs of `unordered_map`s. One pair to /// go back and forth between Substrait ID and arrow::DataType and another pair to go @@ -49,56 +153,103 @@ namespace engine { /// instance). class ARROW_ENGINE_EXPORT ExtensionIdRegistry { public: - /// All uris registered in this ExtensionIdRegistry - virtual std::vector Uris() const = 0; - - struct Id { - util::string_view uri, name; - - bool empty() const { return uri.empty() && name.empty(); } - }; - - struct IdHashEq { - size_t operator()(Id id) const; - bool operator()(Id l, Id r) const; - }; + using ArrowToSubstraitCall = + std::function(const arrow::compute::Expression::Call&)>; + using SubstraitCallToArrow = + std::function(const SubstraitCall&)>; + using ArrowToSubstraitAggregate = + std::function(const arrow::compute::Aggregate&)>; + using SubstraitAggregateToArrow = + std::function(const SubstraitCall&)>; /// \brief A mapping between a Substrait ID and an arrow::DataType struct TypeRecord { Id id; const std::shared_ptr& type; }; + + /// \brief Return a uri view owned by this registry + /// + /// If the URI has never been emplaced it will return nullopt + virtual util::optional FindUri(util::string_view uri) const = 0; + /// \brief Return a id view owned by this registry + /// + /// If the id has never been emplaced it will return nullopt + virtual util::optional FindId(Id id) const = 0; virtual util::optional GetType(const DataType&) const = 0; virtual util::optional GetType(Id) const = 0; virtual Status CanRegisterType(Id, const std::shared_ptr& type) const = 0; virtual Status RegisterType(Id, std::shared_ptr) = 0; + /// \brief Register a converter that converts an Arrow call to a Substrait call + /// + /// Note that there may not be 1:1 parity between ArrowToSubstraitCall and + /// SubstraitCallToArrow because some standard functions (e.g. add) may map to + /// multiple Arrow functions (e.g. add, add_checked) + virtual Status AddArrowToSubstraitCall(std::string arrow_function_name, + ArrowToSubstraitCall conversion_func) = 0; + /// \brief Check to see if a converter can be registered + /// + /// \return Status::OK if there are no conflicts, otherwise an error is returned + virtual Status CanAddArrowToSubstraitCall( + const std::string& arrow_function_name) const = 0; - /// \brief A mapping between a Substrait ID and an Arrow function + /// \brief Register a converter that converts an Arrow aggregate to a Substrait + /// aggregate + virtual Status AddArrowToSubstraitAggregate( + std::string arrow_function_name, ArrowToSubstraitAggregate conversion_func) = 0; + /// \brief Check to see if a converter can be registered /// - /// Note: At the moment we identify functions solely by the name - /// of the function in the function registry. + /// \return Status::OK if there are no conflicts, otherwise an error is returned + virtual Status CanAddArrowToSubstraitAggregate( + const std::string& arrow_function_name) const = 0; + + /// \brief Register a converter that converts a Substrait call to an Arrow call + virtual Status AddSubstraitCallToArrow(Id substrait_function_id, + SubstraitCallToArrow conversion_func) = 0; + /// \brief Check to see if a converter can be registered /// - /// TODO(ARROW-15582) some functions will not be simple enough to convert without access - /// to their arguments/options. For example is_in embeds the set in options rather than - /// using an argument: - /// is_in(x, SetLookupOptions(set)) <-> (k...Uri, "is_in")(x, set) + /// \return Status::OK if there are no conflicts, otherwise an error is returned + virtual Status CanAddSubstraitCallToArrow(Id substrait_function_id) const = 0; + /// \brief Register a simple mapping function /// - /// ... for another example, depending on the value of the first argument to - /// substrait::add it either corresponds to arrow::add or arrow::add_checked - struct FunctionRecord { - Id id; - const std::string& function_name; - }; - virtual util::optional GetFunction(Id) const = 0; - virtual util::optional GetFunction( - util::string_view arrow_function_name) const = 0; - virtual Status CanRegisterFunction(Id, - const std::string& arrow_function_name) const = 0; - // registers a function without taking ownership of uri and name within Id - virtual Status RegisterFunction(Id, std::string arrow_function_name) = 0; - // registers a function while taking ownership of uri and name - virtual Status RegisterFunction(std::string uri, std::string name, - std::string arrow_function_name) = 0; + /// All calls to the function must pass only value arguments. The arguments + /// will be converted to expressions and passed to the Arrow function + virtual Status AddSubstraitCallToArrow(Id substrait_function_id, + std::string arrow_function_name) = 0; + + /// \brief Register a converter that converts a Substrait aggregate to an Arrow + /// aggregate + virtual Status AddSubstraitAggregateToArrow( + Id substrait_function_id, SubstraitAggregateToArrow conversion_func) = 0; + /// \brief Check to see if a converter can be registered + /// + /// \return Status::OK if there are no conflicts, otherwise an error is returned + virtual Status CanAddSubstraitAggregateToArrow(Id substrait_function_id) const = 0; + + /// \brief Return a list of Substrait functions that have a converter + /// + /// The function ids are encoded as strings using the pattern {uri}#{name} + virtual std::vector GetSupportedSubstraitFunctions() const = 0; + + /// \brief Find a converter to map Arrow calls to Substrait calls + /// \return A converter function or an invalid status if no converter is registered + virtual Result GetArrowToSubstraitCall( + const std::string& arrow_function_name) const = 0; + + /// \brief Find a converter to map Arrow aggregates to Substrait aggregates + /// \return A converter function or an invalid status if no converter is registered + virtual Result GetArrowToSubstraitAggregate( + const std::string& arrow_function_name) const = 0; + + /// \brief Find a converter to map a Substrait aggregate to an Arrow aggregate + /// \return A converter function or an invalid status if no converter is registered + virtual Result GetSubstraitAggregateToArrow( + Id substrait_function_id) const = 0; + + /// \brief Find a converter to map a Substrait call to an Arrow call + /// \return A converter function or an invalid status if no converter is registered + virtual Result GetSubstraitCallToArrow( + Id substrait_function_id) const = 0; }; constexpr util::string_view kArrowExtTypesUri = @@ -153,9 +304,6 @@ ARROW_ENGINE_EXPORT std::shared_ptr nested_extension_id_reg /// ExtensionIdRegistry. class ARROW_ENGINE_EXPORT ExtensionSet { public: - using Id = ExtensionIdRegistry::Id; - using IdHashEq = ExtensionIdRegistry::IdHashEq; - struct FunctionRecord { Id id; util::string_view name; @@ -219,12 +367,12 @@ class ARROW_ENGINE_EXPORT ExtensionSet { /// \return An anchor that can be used to refer to the type within a plan Result EncodeType(const DataType& type); - /// \brief Returns a function given an anchor + /// \brief Return a function id given an anchor /// /// This is used when converting a Substrait plan to an Arrow execution plan. /// /// If the anchor does not exist in this extension set an error will be returned. - Result DecodeFunction(uint32_t anchor) const; + Result DecodeFunction(uint32_t anchor) const; /// \brief Lookup the anchor for a given function /// @@ -239,26 +387,30 @@ class ARROW_ENGINE_EXPORT ExtensionSet { /// returned. /// /// \return An anchor that can be used to refer to the function within a plan - Result EncodeFunction(util::string_view function_name); + Result EncodeFunction(Id function_id); - /// \brief Returns the number of custom functions in this extension set - /// - /// Note: the functions are currently stored as a sparse vector, so this may return a - /// value larger than the actual number of functions. This behavior may change in the - /// future; see ARROW-15583. + /// \brief Return the number of custom functions in this extension set std::size_t num_functions() const { return functions_.size(); } + const ExtensionIdRegistry* registry() const { return registry_; } + private: const ExtensionIdRegistry* registry_; + // If the registry is not aware of an id then we probably can't do anything + // with it. However, in some cases, these may represent extensions or features + // that we can safely ignore. For example, we can usually safely ignore + // extension type variations if we assume the plan is valid. These ignorable + // ids are stored here. + IdStorage plan_specific_ids_; // Map from anchor values to URI values referenced by this extension set std::unordered_map uris_; // Map from anchor values to type definitions, used during Substrait->Arrow // and populated from the Substrait extension set std::unordered_map types_; - // Map from anchor values to function definitions, used during Substrait->Arrow + // Map from anchor values to function ids, used during Substrait->Arrow // and populated from the Substrait extension set - std::unordered_map functions_; + std::unordered_map functions_; // Map from type names to anchor values. Used during Arrow->Substrait // and built as the plan is created. std::unordered_map types_map_; diff --git a/cpp/src/arrow/engine/substrait/function_test.cc b/cpp/src/arrow/engine/substrait/function_test.cc new file mode 100644 index 0000000000000..225bc56d13681 --- /dev/null +++ b/cpp/src/arrow/engine/substrait/function_test.cc @@ -0,0 +1,495 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +#include +#include +#include + +#include + +#include "arrow/array.h" +#include "arrow/array/builder_binary.h" +#include "arrow/compute/cast.h" +#include "arrow/compute/exec/options.h" +#include "arrow/compute/exec/util.h" +#include "arrow/engine/substrait/extension_set.h" +#include "arrow/engine/substrait/plan_internal.h" +#include "arrow/engine/substrait/serde.h" +#include "arrow/engine/substrait/test_plan_builder.h" +#include "arrow/engine/substrait/type_internal.h" +#include "arrow/record_batch.h" +#include "arrow/table.h" +#include "arrow/testing/future_util.h" +#include "arrow/testing/gtest_util.h" +#include "arrow/type.h" + +namespace arrow { + +namespace engine { +struct FunctionTestCase { + Id function_id; + std::vector arguments; + std::vector> data_types; + // For a test case that should fail just use the empty string + std::string expected_output; + std::shared_ptr expected_output_type; +}; + +Result> GetArray(const std::string& value, + const std::shared_ptr& data_type) { + StringBuilder str_builder; + if (value.empty()) { + ARROW_EXPECT_OK(str_builder.AppendNull()); + } else { + ARROW_EXPECT_OK(str_builder.Append(value)); + } + ARROW_ASSIGN_OR_RAISE(std::shared_ptr value_str, str_builder.Finish()); + ARROW_ASSIGN_OR_RAISE(Datum value_datum, compute::Cast(value_str, data_type)); + return value_datum.make_array(); +} + +Result> GetInputTable( + const std::vector& arguments, + const std::vector>& data_types) { + std::vector> columns; + std::vector> fields; + EXPECT_EQ(arguments.size(), data_types.size()); + for (std::size_t i = 0; i < arguments.size(); i++) { + if (data_types[i]) { + ARROW_ASSIGN_OR_RAISE(std::shared_ptr arg_array, + GetArray(arguments[i], data_types[i])); + columns.push_back(std::move(arg_array)); + fields.push_back(field("arg_" + std::to_string(i), data_types[i])); + } + } + std::shared_ptr batch = + RecordBatch::Make(schema(std::move(fields)), 1, columns); + ARROW_ASSIGN_OR_RAISE(std::shared_ptr
table, Table::FromRecordBatches({batch})); + return table; +} + +Result> GetOutputTable( + const std::string& output_value, const std::shared_ptr& output_type) { + std::vector> columns(1); + std::vector> fields(1); + ARROW_ASSIGN_OR_RAISE(columns[0], GetArray(output_value, output_type)); + fields[0] = field("output", output_type); + std::shared_ptr batch = + RecordBatch::Make(schema(std::move(fields)), 1, columns); + ARROW_ASSIGN_OR_RAISE(std::shared_ptr
table, Table::FromRecordBatches({batch})); + return table; +} + +Result> PlanFromTestCase( + const FunctionTestCase& test_case, std::shared_ptr
* output_table) { + ARROW_ASSIGN_OR_RAISE(std::shared_ptr
input_table, + GetInputTable(test_case.arguments, test_case.data_types)); + ARROW_ASSIGN_OR_RAISE(std::shared_ptr substrait, + internal::CreateScanProjectSubstrait( + test_case.function_id, input_table, test_case.arguments, + test_case.data_types, *test_case.expected_output_type)); + std::shared_ptr consumer = + std::make_shared(output_table, + default_memory_pool()); + + // Mock table provider that ignores the table name and returns input_table + NamedTableProvider table_provider = [input_table](const std::vector&) { + std::shared_ptr options = + std::make_shared(input_table); + return compute::Declaration("table_source", {}, options, "mock_source"); + }; + + ConversionOptions conversion_options; + conversion_options.named_table_provider = std::move(table_provider); + + ARROW_ASSIGN_OR_RAISE( + std::shared_ptr plan, + DeserializePlan(*substrait, std::move(consumer), default_extension_id_registry(), + /*ext_set_out=*/nullptr, conversion_options)); + return plan; +} + +void CheckValidTestCases(const std::vector& valid_cases) { + for (const FunctionTestCase& test_case : valid_cases) { + std::shared_ptr
output_table; + ASSERT_OK_AND_ASSIGN(std::shared_ptr plan, + PlanFromTestCase(test_case, &output_table)); + ASSERT_OK(plan->StartProducing()); + ASSERT_FINISHES_OK(plan->finished()); + + // Could also modify the Substrait plan with an emit to drop the leading columns + ASSERT_OK_AND_ASSIGN(output_table, + output_table->SelectColumns({output_table->num_columns() - 1})); + + ASSERT_OK_AND_ASSIGN( + std::shared_ptr
expected_output, + GetOutputTable(test_case.expected_output, test_case.expected_output_type)); + AssertTablesEqual(*expected_output, *output_table, /*same_chunk_layout=*/false); + } +} + +void CheckErrorTestCases(const std::vector& error_cases) { + for (const FunctionTestCase& test_case : error_cases) { + std::shared_ptr
output_table; + ASSERT_OK_AND_ASSIGN(std::shared_ptr plan, + PlanFromTestCase(test_case, &output_table)); + ASSERT_OK(plan->StartProducing()); + ASSERT_FINISHES_AND_RAISES(Invalid, plan->finished()); + } +} + +// These are not meant to be an exhaustive test of Substrait +// conformance. Instead, we should test just enough to ensure +// we are mapping to the correct function +TEST(FunctionMapping, ValidCases) { + const std::vector valid_test_cases = { + {{kSubstraitArithmeticFunctionsUri, "add"}, + {"SILENT", "127", "10"}, + {nullptr, int8(), int8()}, + "-119", + int8()}, + {{kSubstraitArithmeticFunctionsUri, "subtract"}, + {"SILENT", "-119", "10"}, + {nullptr, int8(), int8()}, + "127", + int8()}, + {{kSubstraitArithmeticFunctionsUri, "multiply"}, + {"SILENT", "10", "13"}, + {nullptr, int8(), int8()}, + "-126", + int8()}, + {{kSubstraitArithmeticFunctionsUri, "divide"}, + {"SILENT", "-128", "-1"}, + {nullptr, int8(), int8()}, + "0", + int8()}, + {{kSubstraitBooleanFunctionsUri, "or"}, + {"1", ""}, + {boolean(), boolean()}, + "1", + boolean()}, + {{kSubstraitBooleanFunctionsUri, "and"}, + {"1", ""}, + {boolean(), boolean()}, + "", + boolean()}, + {{kSubstraitBooleanFunctionsUri, "xor"}, + {"1", "1"}, + {boolean(), boolean()}, + "0", + boolean()}, + {{kSubstraitBooleanFunctionsUri, "not"}, {"1"}, {boolean()}, "0", boolean()}, + {{kSubstraitComparisonFunctionsUri, "equal"}, + {"57", "57"}, + {int8(), int8()}, + "1", + boolean()}, + {{kSubstraitComparisonFunctionsUri, "not_equal"}, + {"57", "57"}, + {int8(), int8()}, + "0", + boolean()}, + {{kSubstraitComparisonFunctionsUri, "lt"}, + {"57", "80"}, + {int8(), int8()}, + "1", + boolean()}, + {{kSubstraitComparisonFunctionsUri, "lt"}, + {"57", "57"}, + {int8(), int8()}, + "0", + boolean()}, + {{kSubstraitComparisonFunctionsUri, "gt"}, + {"57", "30"}, + {int8(), int8()}, + "1", + boolean()}, + {{kSubstraitComparisonFunctionsUri, "gt"}, + {"57", "57"}, + {int8(), int8()}, + "0", + boolean()}, + {{kSubstraitComparisonFunctionsUri, "lte"}, + {"57", "57"}, + {int8(), int8()}, + "1", + boolean()}, + {{kSubstraitComparisonFunctionsUri, "lte"}, + {"50", "57"}, + {int8(), int8()}, + "1", + boolean()}, + {{kSubstraitComparisonFunctionsUri, "gte"}, + {"57", "57"}, + {int8(), int8()}, + "1", + boolean()}, + {{kSubstraitComparisonFunctionsUri, "gte"}, + {"60", "57"}, + {int8(), int8()}, + "1", + boolean()}, + {{kSubstraitDatetimeFunctionsUri, "extract"}, + {"YEAR", "2022-07-15T14:33:14"}, + {nullptr, timestamp(TimeUnit::MICRO)}, + "2022", + int64()}, + {{kSubstraitDatetimeFunctionsUri, "extract"}, + {"MONTH", "2022-07-15T14:33:14"}, + {nullptr, timestamp(TimeUnit::MICRO)}, + "7", + int64()}, + {{kSubstraitDatetimeFunctionsUri, "extract"}, + {"DAY", "2022-07-15T14:33:14"}, + {nullptr, timestamp(TimeUnit::MICRO)}, + "15", + int64()}, + {{kSubstraitDatetimeFunctionsUri, "extract"}, + {"SECOND", "2022-07-15T14:33:14"}, + {nullptr, timestamp(TimeUnit::MICRO)}, + "14", + int64()}, + {{kSubstraitDatetimeFunctionsUri, "extract"}, + {"YEAR", "2022-07-15T14:33:14Z"}, + {nullptr, timestamp(TimeUnit::MICRO, "UTC")}, + "2022", + int64()}, + {{kSubstraitDatetimeFunctionsUri, "extract"}, + {"MONTH", "2022-07-15T14:33:14Z"}, + {nullptr, timestamp(TimeUnit::MICRO, "UTC")}, + "7", + int64()}, + {{kSubstraitDatetimeFunctionsUri, "extract"}, + {"DAY", "2022-07-15T14:33:14Z"}, + {nullptr, timestamp(TimeUnit::MICRO, "UTC")}, + "15", + int64()}, + {{kSubstraitDatetimeFunctionsUri, "extract"}, + {"SECOND", "2022-07-15T14:33:14Z"}, + {nullptr, timestamp(TimeUnit::MICRO, "UTC")}, + "14", + int64()}, + {{kSubstraitDatetimeFunctionsUri, "lt"}, + {"2022-07-15T14:33:14", "2022-07-15T14:33:20"}, + {timestamp(TimeUnit::MICRO), timestamp(TimeUnit::MICRO)}, + "1", + boolean()}, + {{kSubstraitDatetimeFunctionsUri, "lte"}, + {"2022-07-15T14:33:14", "2022-07-15T14:33:14"}, + {timestamp(TimeUnit::MICRO), timestamp(TimeUnit::MICRO)}, + "1", + boolean()}, + {{kSubstraitDatetimeFunctionsUri, "gt"}, + {"2022-07-15T14:33:30", "2022-07-15T14:33:14"}, + {timestamp(TimeUnit::MICRO), timestamp(TimeUnit::MICRO)}, + "1", + boolean()}, + {{kSubstraitDatetimeFunctionsUri, "gte"}, + {"2022-07-15T14:33:14", "2022-07-15T14:33:14"}, + {timestamp(TimeUnit::MICRO), timestamp(TimeUnit::MICRO)}, + "1", + boolean()}, + {{kSubstraitStringFunctionsUri, "concat"}, + {"abc", "def"}, + {utf8(), utf8()}, + "abcdef", + utf8()}}; + CheckValidTestCases(valid_test_cases); +} + +TEST(FunctionMapping, ErrorCases) { + const std::vector error_test_cases = { + {{kSubstraitArithmeticFunctionsUri, "add"}, + {"ERROR", "127", "10"}, + {nullptr, int8(), int8()}, + "", + int8()}, + {{kSubstraitArithmeticFunctionsUri, "subtract"}, + {"ERROR", "-119", "10"}, + {nullptr, int8(), int8()}, + "", + int8()}, + {{kSubstraitArithmeticFunctionsUri, "multiply"}, + {"ERROR", "10", "13"}, + {nullptr, int8(), int8()}, + "", + int8()}, + {{kSubstraitArithmeticFunctionsUri, "divide"}, + {"ERROR", "-128", "-1"}, + {nullptr, int8(), int8()}, + "", + int8()}}; + CheckErrorTestCases(error_test_cases); +} + +// For each aggregate test case we take in three values. We compute the +// aggregate both on the entire set (all three values) and on groups. The +// first two rows will be in the first group and the last row will be in the +// second group. It's important to test both for coverage since the arrow +// function used actually changes when group ids are present +struct AggregateTestCase { + // The substrait function id + Id function_id; + // The three values, as a JSON string + std::string arguments; + // The data type of the three values + std::shared_ptr data_type; + // The result of the aggregate on all three + std::string combined_output; + // The result of the aggregate on each group (i.e. the first two rows + // and the last row). Should be a json-encoded array of size 2 + std::string group_outputs; + // The data type of the outputs + std::shared_ptr output_type; +}; + +std::shared_ptr
GetInputTableForAggregateCase(const AggregateTestCase& test_case) { + std::vector> columns(2); + std::vector> fields(2); + columns[0] = ArrayFromJSON(int8(), "[1, 1, 2]"); + columns[1] = ArrayFromJSON(test_case.data_type, test_case.arguments); + fields[0] = field("key", int8()); + fields[1] = field("value", test_case.data_type); + std::shared_ptr batch = + RecordBatch::Make(schema(std::move(fields)), /*num_rows=*/3, std::move(columns)); + EXPECT_OK_AND_ASSIGN(std::shared_ptr
table, Table::FromRecordBatches({batch})); + return table; +} + +std::shared_ptr
GetOutputTableForAggregateCase( + const std::shared_ptr& output_type, const std::string& json_data) { + std::shared_ptr out_arr = ArrayFromJSON(output_type, json_data); + std::shared_ptr batch = + RecordBatch::Make(schema({field("", output_type)}), 1, {out_arr}); + EXPECT_OK_AND_ASSIGN(std::shared_ptr
table, Table::FromRecordBatches({batch})); + return table; +} + +std::shared_ptr PlanFromAggregateCase( + const AggregateTestCase& test_case, std::shared_ptr
* output_table, + bool with_keys) { + std::shared_ptr
input_table = GetInputTableForAggregateCase(test_case); + std::vector key_idxs = {}; + if (with_keys) { + key_idxs = {0}; + } + EXPECT_OK_AND_ASSIGN( + std::shared_ptr substrait, + internal::CreateScanAggSubstrait(test_case.function_id, input_table, key_idxs, + /*arg_idx=*/1, *test_case.output_type)); + std::shared_ptr consumer = + std::make_shared(output_table, + default_memory_pool()); + + // Mock table provider that ignores the table name and returns input_table + NamedTableProvider table_provider = [input_table](const std::vector&) { + std::shared_ptr options = + std::make_shared(input_table); + return compute::Declaration("table_source", {}, options, "mock_source"); + }; + + ConversionOptions conversion_options; + conversion_options.named_table_provider = std::move(table_provider); + + EXPECT_OK_AND_ASSIGN( + std::shared_ptr plan, + DeserializePlan(*substrait, std::move(consumer), default_extension_id_registry(), + /*ext_set_out=*/nullptr, conversion_options)); + return plan; +} + +void CheckWholeAggregateCase(const AggregateTestCase& test_case) { + std::shared_ptr
output_table; + std::shared_ptr plan = + PlanFromAggregateCase(test_case, &output_table, /*with_keys=*/false); + + ASSERT_OK(plan->StartProducing()); + ASSERT_FINISHES_OK(plan->finished()); + + ASSERT_OK_AND_ASSIGN(output_table, + output_table->SelectColumns({output_table->num_columns() - 1})); + + std::shared_ptr
expected_output = + GetOutputTableForAggregateCase(test_case.output_type, test_case.combined_output); + AssertTablesEqual(*expected_output, *output_table, /*same_chunk_layout=*/false); +} + +void CheckGroupedAggregateCase(const AggregateTestCase& test_case) { + std::shared_ptr
output_table; + std::shared_ptr plan = + PlanFromAggregateCase(test_case, &output_table, /*with_keys=*/true); + + ASSERT_OK(plan->StartProducing()); + ASSERT_FINISHES_OK(plan->finished()); + + // The aggregate node's output is unpredictable so we sort by the key column + ASSERT_OK_AND_ASSIGN( + std::shared_ptr sort_indices, + compute::SortIndices(output_table, compute::SortOptions({compute::SortKey( + output_table->num_columns() - 1, + compute::SortOrder::Ascending)}))); + ASSERT_OK_AND_ASSIGN(Datum sorted_table_datum, + compute::Take(output_table, sort_indices)); + output_table = sorted_table_datum.table(); + // TODO(ARROW-17245) We should be selecting N-1 here but Acero + // currently emits things in reverse order + ASSERT_OK_AND_ASSIGN(output_table, output_table->SelectColumns({0})); + + std::shared_ptr
expected_output = + GetOutputTableForAggregateCase(test_case.output_type, test_case.group_outputs); + + AssertTablesEqual(*expected_output, *output_table, /*same_chunk_layout=*/false); +} + +void CheckAggregateCases(const std::vector& test_cases) { + for (const AggregateTestCase& test_case : test_cases) { + CheckWholeAggregateCase(test_case); + CheckGroupedAggregateCase(test_case); + } +} + +TEST(FunctionMapping, AggregateCases) { + const std::vector test_cases = { + {{kSubstraitArithmeticFunctionsUri, "sum"}, + "[1, 2, 3]", + int8(), + "[6]", + "[3, 3]", + int64()}, + {{kSubstraitArithmeticFunctionsUri, "min"}, + "[1, 2, 3]", + int8(), + "[1]", + "[1, 3]", + int8()}, + {{kSubstraitArithmeticFunctionsUri, "max"}, + "[1, 2, 3]", + int8(), + "[3]", + "[2, 3]", + int8()}, + {{kSubstraitArithmeticFunctionsUri, "avg"}, + "[1, 2, 3]", + float64(), + "[2]", + "[1.5, 3]", + float64()}}; + CheckAggregateCases(test_cases); +} + +} // namespace engine +} // namespace arrow diff --git a/cpp/src/arrow/engine/substrait/options.h b/cpp/src/arrow/engine/substrait/options.h index dcb2088416f69..eace200f0acb1 100644 --- a/cpp/src/arrow/engine/substrait/options.h +++ b/cpp/src/arrow/engine/substrait/options.h @@ -54,11 +54,20 @@ enum class ConversionStrictness { BEST_EFFORT, }; +using NamedTableProvider = + std::function(const std::vector&)>; +static NamedTableProvider kDefaultNamedTableProvider; + /// Options that control the conversion between Substrait and Acero representations of a /// plan. struct ConversionOptions { /// \brief How strictly the converter should adhere to the structure of the input. ConversionStrictness strictness = ConversionStrictness::BEST_EFFORT; + /// \brief A custom strategy to be used for providing named tables + /// + /// The default behavior will return an invalid status if the plan has any + /// named table relations. + NamedTableProvider named_table_provider = kDefaultNamedTableProvider; }; } // namespace engine diff --git a/cpp/src/arrow/engine/substrait/plan_internal.cc b/cpp/src/arrow/engine/substrait/plan_internal.cc index 2da037000cf70..b0fdb9bdc2fcd 100644 --- a/cpp/src/arrow/engine/substrait/plan_internal.cc +++ b/cpp/src/arrow/engine/substrait/plan_internal.cc @@ -74,13 +74,12 @@ Status AddExtensionSetToPlan(const ExtensionSet& ext_set, substrait::Plan* plan) } for (uint32_t anchor = 0; anchor < ext_set.num_functions(); ++anchor) { - ARROW_ASSIGN_OR_RAISE(auto function_record, ext_set.DecodeFunction(anchor)); - if (function_record.id.empty()) continue; + ARROW_ASSIGN_OR_RAISE(Id function_id, ext_set.DecodeFunction(anchor)); auto fn = internal::make_unique(); - fn->set_extension_uri_reference(map[function_record.id.uri]); + fn->set_extension_uri_reference(map[function_id.uri]); fn->set_function_anchor(anchor); - fn->set_name(function_record.id.name.to_string()); + fn->set_name(function_id.name.to_string()); auto ext_decl = internal::make_unique(); ext_decl->set_allocated_extension_function(fn.release()); @@ -104,8 +103,6 @@ Result GetExtensionSetFromPlan(const substrait::Plan& plan, // NOTE: it's acceptable to use views to memory owned by plan; ExtensionSet::Make // will only store views to memory owned by registry. - using Id = ExtensionSet::Id; - std::unordered_map type_ids, function_ids; for (const auto& ext : plan.extensions()) { switch (ext.mapping_type_case()) { diff --git a/cpp/src/arrow/engine/substrait/relation_internal.cc b/cpp/src/arrow/engine/substrait/relation_internal.cc index 8cc1da4d9030a..c5c02f51558c9 100644 --- a/cpp/src/arrow/engine/substrait/relation_internal.cc +++ b/cpp/src/arrow/engine/substrait/relation_internal.cc @@ -67,6 +67,7 @@ Result FromProto(const substrait::Rel& rel, const ExtensionSet& ARROW_ASSIGN_OR_RAISE(auto base_schema, FromProto(read.base_schema(), ext_set, conversion_options)); + auto num_columns = static_cast(base_schema->fields().size()); auto scan_options = std::make_shared(); scan_options->use_threads = true; @@ -82,6 +83,22 @@ Result FromProto(const substrait::Rel& rel, const ExtensionSet& return Status::NotImplemented("substrait::ReadRel::projection"); } + if (read.has_named_table()) { + if (!conversion_options.named_table_provider) { + return Status::Invalid( + "plan contained a named table but a NamedTableProvider has not been " + "configured"); + } + const NamedTableProvider& named_table_provider = + conversion_options.named_table_provider; + const substrait::ReadRel::NamedTable& named_table = read.named_table(); + std::vector table_names(named_table.names().begin(), + named_table.names().end()); + ARROW_ASSIGN_OR_RAISE(compute::Declaration source_decl, + named_table_provider(table_names)); + return DeclarationInfo{std::move(source_decl), num_columns}; + } + if (!read.has_local_files()) { return Status::NotImplemented( "substrait::ReadRel with read_type other than LocalFiles"); @@ -182,7 +199,6 @@ Result FromProto(const substrait::Rel& rel, const ExtensionSet& std::move(filesystem), std::move(files), std::move(format), {})); - auto num_columns = static_cast(base_schema->fields().size()); ARROW_ASSIGN_OR_RAISE(auto ds, ds_factory->Finish(std::move(base_schema))); return DeclarationInfo{ @@ -349,17 +365,20 @@ Result FromProto(const substrait::Rel& rel, const ExtensionSet& "than one item"); } std::vector keys; - auto group = aggregate.groupings(0); - keys.reserve(group.grouping_expressions_size()); - for (int exp_id = 0; exp_id < group.grouping_expressions_size(); exp_id++) { - ARROW_ASSIGN_OR_RAISE(auto expr, FromProto(group.grouping_expressions(exp_id), - ext_set, conversion_options)); - const auto* field_ref = expr.field_ref(); - if (field_ref) { - keys.emplace_back(std::move(*field_ref)); - } else { - return Status::Invalid( - "The grouping expression for an aggregate must be a direct reference."); + if (aggregate.groupings_size() > 0) { + const substrait::AggregateRel::Grouping& group = aggregate.groupings(0); + keys.reserve(group.grouping_expressions_size()); + for (int exp_id = 0; exp_id < group.grouping_expressions_size(); exp_id++) { + ARROW_ASSIGN_OR_RAISE( + compute::Expression expr, + FromProto(group.grouping_expressions(exp_id), ext_set, conversion_options)); + const FieldRef* field_ref = expr.field_ref(); + if (field_ref) { + keys.emplace_back(std::move(*field_ref)); + } else { + return Status::Invalid( + "The grouping expression for an aggregate must be a direct reference."); + } } } @@ -373,25 +392,14 @@ Result FromProto(const substrait::Rel& rel, const ExtensionSet& return Status::NotImplemented("Aggregate filters are not supported."); } const auto& agg_func = agg_measure.measure(); - if (agg_func.arguments_size() != 1) { - return Status::NotImplemented("Aggregate function must be a unary function."); - } - int func_reference = agg_func.function_reference(); - ARROW_ASSIGN_OR_RAISE(auto func_record, ext_set.DecodeFunction(func_reference)); - // aggreagte function name - auto func_name = std::string(func_record.id.name); - // aggregate target - auto subs_func_args = agg_func.arguments(0); - ARROW_ASSIGN_OR_RAISE(auto field_expr, FromProto(subs_func_args.value(), - ext_set, conversion_options)); - auto target = field_expr.field_ref(); - if (!target) { - return Status::Invalid( - "The input expression to an aggregate function must be a direct " - "reference."); - } - aggregates.emplace_back(compute::Aggregate{std::move(func_name), NULLPTR, - std::move(*target), std::move("")}); + ARROW_ASSIGN_OR_RAISE( + SubstraitCall aggregate_call, + FromProto(agg_func, !keys.empty(), ext_set, conversion_options)); + ARROW_ASSIGN_OR_RAISE( + ExtensionIdRegistry::SubstraitAggregateToArrow converter, + ext_set.registry()->GetSubstraitAggregateToArrow(aggregate_call.id())); + ARROW_ASSIGN_OR_RAISE(compute::Aggregate arrow_agg, converter(aggregate_call)); + aggregates.push_back(std::move(arrow_agg)); } else { return Status::Invalid("substrait::AggregateFunction not provided"); } diff --git a/cpp/src/arrow/engine/substrait/serde.cc b/cpp/src/arrow/engine/substrait/serde.cc index 87ad88dccb45c..9f7d979e2f02e 100644 --- a/cpp/src/arrow/engine/substrait/serde.cc +++ b/cpp/src/arrow/engine/substrait/serde.cc @@ -172,7 +172,7 @@ Result> MakeSingleDeclarationPlan( } else { ARROW_ASSIGN_OR_RAISE(auto plan, compute::ExecPlan::Make()); ARROW_RETURN_NOT_OK(declarations[0].AddToPlan(plan.get())); - return plan; + return std::move(plan); } } @@ -182,17 +182,21 @@ Result> DeserializePlan( const Buffer& buf, const std::shared_ptr& consumer, const ExtensionIdRegistry* registry, ExtensionSet* ext_set_out, const ConversionOptions& conversion_options) { - bool factory_done = false; - auto single_consumer = [&factory_done, &consumer] { - if (factory_done) { - return std::shared_ptr{}; + struct SingleConsumer { + std::shared_ptr operator()() { + if (factory_done) { + Status::Invalid("SingleConsumer invoked more than once").Warn(); + return std::shared_ptr{}; + } + factory_done = true; + return consumer; } - factory_done = true; - return consumer; + bool factory_done; + std::shared_ptr consumer; }; - ARROW_ASSIGN_OR_RAISE( - auto declarations, - DeserializePlans(buf, single_consumer, registry, ext_set_out, conversion_options)); + ARROW_ASSIGN_OR_RAISE(auto declarations, + DeserializePlans(buf, SingleConsumer{false, consumer}, registry, + ext_set_out, conversion_options)); return MakeSingleDeclarationPlan(declarations); } diff --git a/cpp/src/arrow/engine/substrait/serde.h b/cpp/src/arrow/engine/substrait/serde.h index 5214606e1c8c7..6c2083fb56a15 100644 --- a/cpp/src/arrow/engine/substrait/serde.h +++ b/cpp/src/arrow/engine/substrait/serde.h @@ -75,7 +75,7 @@ ARROW_ENGINE_EXPORT Result> DeserializePlans( /// Plan is returned here. /// \return an ExecNode corresponding to the single toplevel relation in the Substrait /// Plan -Result> DeserializePlan( +ARROW_ENGINE_EXPORT Result> DeserializePlan( const Buffer& buf, const std::shared_ptr& consumer, const ExtensionIdRegistry* registry = NULLPTR, ExtensionSet* ext_set_out = NULLPTR, const ConversionOptions& conversion_options = {}); diff --git a/cpp/src/arrow/engine/substrait/serde_test.cc b/cpp/src/arrow/engine/substrait/serde_test.cc index 3bb4de4e920a0..04405b316807d 100644 --- a/cpp/src/arrow/engine/substrait/serde_test.cc +++ b/cpp/src/arrow/engine/substrait/serde_test.cc @@ -701,7 +701,12 @@ TEST(Substrait, ExtensionSetFromPlan) { "extension_uris": [ { "extension_uri_anchor": 7, - "uri": ")" + substrait::default_extension_types_uri() + + "uri": ")" + default_extension_types_uri() + + R"(" + }, + { + "extension_uri_anchor": 18, + "uri": ")" + kSubstraitArithmeticFunctionsUri + R"(" } ], @@ -712,15 +717,15 @@ TEST(Substrait, ExtensionSetFromPlan) { "name": "null" }}, {"extension_function": { - "extension_uri_reference": 7, + "extension_uri_reference": 18, "function_anchor": 42, "name": "add" }} ] - })"; +})"; ASSERT_OK_AND_ASSIGN(auto buf, internal::SubstraitFromJSON("Plan", substrait_json)); for (auto sp_ext_id_reg : - {std::shared_ptr(), substrait::MakeExtensionIdRegistry()}) { + {std::shared_ptr(), MakeExtensionIdRegistry()}) { ExtensionIdRegistry* ext_id_reg = sp_ext_id_reg.get(); ExtensionSet ext_set(ext_id_reg); ASSERT_OK_AND_ASSIGN(auto sink_decls, @@ -732,10 +737,9 @@ TEST(Substrait, ExtensionSetFromPlan) { EXPECT_EQ(decoded_null_type.id.name, "null"); EXPECT_EQ(*decoded_null_type.type, NullType()); - EXPECT_OK_AND_ASSIGN(auto decoded_add_func, ext_set.DecodeFunction(42)); - EXPECT_EQ(decoded_add_func.id.uri, kArrowExtTypesUri); - EXPECT_EQ(decoded_add_func.id.name, "add"); - EXPECT_EQ(decoded_add_func.name, "add"); + EXPECT_OK_AND_ASSIGN(Id decoded_add_func_id, ext_set.DecodeFunction(42)); + EXPECT_EQ(decoded_add_func_id.uri, kSubstraitArithmeticFunctionsUri); + EXPECT_EQ(decoded_add_func_id.name, "add"); } } @@ -745,7 +749,7 @@ TEST(Substrait, ExtensionSetFromPlanMissingFunc) { "extension_uris": [ { "extension_uri_anchor": 7, - "uri": ")" + substrait::default_extension_types_uri() + + "uri": ")" + default_extension_types_uri() + R"(" } ], @@ -760,7 +764,7 @@ TEST(Substrait, ExtensionSetFromPlanMissingFunc) { ASSERT_OK_AND_ASSIGN(auto buf, internal::SubstraitFromJSON("Plan", substrait_json)); for (auto sp_ext_id_reg : - {std::shared_ptr(), substrait::MakeExtensionIdRegistry()}) { + {std::shared_ptr(), MakeExtensionIdRegistry()}) { ExtensionIdRegistry* ext_id_reg = sp_ext_id_reg.get(); ExtensionSet ext_set(ext_id_reg); ASSERT_RAISES(Invalid, DeserializePlans( @@ -786,7 +790,7 @@ TEST(Substrait, ExtensionSetFromPlanExhaustedFactory) { "extension_uris": [ { "extension_uri_anchor": 7, - "uri": ")" + substrait::default_extension_types_uri() + + "uri": ")" + default_extension_types_uri() + R"(" } ], @@ -801,7 +805,7 @@ TEST(Substrait, ExtensionSetFromPlanExhaustedFactory) { ASSERT_OK_AND_ASSIGN(auto buf, internal::SubstraitFromJSON("Plan", substrait_json)); for (auto sp_ext_id_reg : - {std::shared_ptr(), substrait::MakeExtensionIdRegistry()}) { + {std::shared_ptr(), MakeExtensionIdRegistry()}) { ExtensionIdRegistry* ext_id_reg = sp_ext_id_reg.get(); ExtensionSet ext_set(ext_id_reg); ASSERT_RAISES( @@ -823,7 +827,7 @@ TEST(Substrait, ExtensionSetFromPlanRegisterFunc) { "extension_uris": [ { "extension_uri_anchor": 7, - "uri": ")" + substrait::default_extension_types_uri() + + "uri": ")" + default_extension_types_uri() + R"(" } ], @@ -837,24 +841,23 @@ TEST(Substrait, ExtensionSetFromPlanRegisterFunc) { })"; ASSERT_OK_AND_ASSIGN(auto buf, internal::SubstraitFromJSON("Plan", substrait_json)); - auto sp_ext_id_reg = substrait::MakeExtensionIdRegistry(); + auto sp_ext_id_reg = MakeExtensionIdRegistry(); ExtensionIdRegistry* ext_id_reg = sp_ext_id_reg.get(); // invalid before registration ExtensionSet ext_set_invalid(ext_id_reg); ASSERT_RAISES(Invalid, DeserializePlans( *buf, [] { return kNullConsumer; }, ext_id_reg, &ext_set_invalid)); - ASSERT_OK(substrait::RegisterFunction( - *ext_id_reg, substrait::default_extension_types_uri(), "new_func", "multiply")); + ASSERT_OK(ext_id_reg->AddSubstraitCallToArrow( + {default_extension_types_uri(), "new_func"}, "multiply")); // valid after registration ExtensionSet ext_set_valid(ext_id_reg); ASSERT_OK_AND_ASSIGN(auto sink_decls, DeserializePlans( *buf, [] { return kNullConsumer; }, ext_id_reg, &ext_set_valid)); - EXPECT_OK_AND_ASSIGN(auto decoded_add_func, ext_set_valid.DecodeFunction(42)); - EXPECT_EQ(decoded_add_func.id.uri, kArrowExtTypesUri); - EXPECT_EQ(decoded_add_func.id.name, "new_func"); - EXPECT_EQ(decoded_add_func.name, "multiply"); + EXPECT_OK_AND_ASSIGN(Id decoded_add_func_id, ext_set_valid.DecodeFunction(42)); + EXPECT_EQ(decoded_add_func_id.uri, kArrowExtTypesUri); + EXPECT_EQ(decoded_add_func_id.name, "new_func"); } Result GetSubstraitJSON() { @@ -900,7 +903,7 @@ TEST(Substrait, DeserializeWithConsumerFactory) { GTEST_SKIP() << "ARROW-16392: Substrait File URI not supported for Windows"; #else ASSERT_OK_AND_ASSIGN(std::string substrait_json, GetSubstraitJSON()); - ASSERT_OK_AND_ASSIGN(auto buf, substrait::SerializeJsonPlan(substrait_json)); + ASSERT_OK_AND_ASSIGN(auto buf, SerializeJsonPlan(substrait_json)); ASSERT_OK_AND_ASSIGN(auto declarations, DeserializePlans(*buf, NullSinkNodeConsumer::Make)); ASSERT_EQ(declarations.size(), 1); @@ -923,7 +926,7 @@ TEST(Substrait, DeserializeSinglePlanWithConsumerFactory) { GTEST_SKIP() << "ARROW-16392: Substrait File URI not supported for Windows"; #else ASSERT_OK_AND_ASSIGN(std::string substrait_json, GetSubstraitJSON()); - ASSERT_OK_AND_ASSIGN(auto buf, substrait::SerializeJsonPlan(substrait_json)); + ASSERT_OK_AND_ASSIGN(auto buf, SerializeJsonPlan(substrait_json)); ASSERT_OK_AND_ASSIGN(std::shared_ptr plan, DeserializePlan(*buf, NullSinkNodeConsumer::Make())); ASSERT_EQ(1, plan->sinks().size()); @@ -960,7 +963,7 @@ TEST(Substrait, DeserializeWithWriteOptionsFactory) { return std::make_shared(options); }; ASSERT_OK_AND_ASSIGN(std::string substrait_json, GetSubstraitJSON()); - ASSERT_OK_AND_ASSIGN(auto buf, substrait::SerializeJsonPlan(substrait_json)); + ASSERT_OK_AND_ASSIGN(auto buf, SerializeJsonPlan(substrait_json)); ASSERT_OK_AND_ASSIGN(auto declarations, DeserializePlans(*buf, write_options_factory)); ASSERT_EQ(declarations.size(), 1); compute::Declaration* decl = &declarations[0]; @@ -984,7 +987,7 @@ TEST(Substrait, DeserializeWithWriteOptionsFactory) { static void test_with_registries( std::function test) { auto default_func_reg = compute::GetFunctionRegistry(); - auto nested_ext_id_reg = substrait::MakeExtensionIdRegistry(); + auto nested_ext_id_reg = MakeExtensionIdRegistry(); auto nested_func_reg = compute::FunctionRegistry::Make(default_func_reg); test(nullptr, default_func_reg); test(nullptr, nested_func_reg.get()); @@ -999,8 +1002,8 @@ TEST(Substrait, GetRecordBatchReader) { ASSERT_OK_AND_ASSIGN(std::string substrait_json, GetSubstraitJSON()); test_with_registries([&substrait_json](ExtensionIdRegistry* ext_id_reg, compute::FunctionRegistry* func_registry) { - ASSERT_OK_AND_ASSIGN(auto buf, substrait::SerializeJsonPlan(substrait_json)); - ASSERT_OK_AND_ASSIGN(auto reader, substrait::ExecuteSerializedPlan(*buf)); + ASSERT_OK_AND_ASSIGN(auto buf, SerializeJsonPlan(substrait_json)); + ASSERT_OK_AND_ASSIGN(auto reader, ExecuteSerializedPlan(*buf)); ASSERT_OK_AND_ASSIGN(auto table, Table::FromRecordBatchReader(reader.get())); // Note: assuming the binary.parquet file contains fixed amount of records // in case of a test failure, re-evalaute the content in the file @@ -1016,8 +1019,8 @@ TEST(Substrait, InvalidPlan) { })"; test_with_registries([&substrait_json](ExtensionIdRegistry* ext_id_reg, compute::FunctionRegistry* func_registry) { - ASSERT_OK_AND_ASSIGN(auto buf, substrait::SerializeJsonPlan(substrait_json)); - ASSERT_RAISES(Invalid, substrait::ExecuteSerializedPlan(*buf)); + ASSERT_OK_AND_ASSIGN(auto buf, SerializeJsonPlan(substrait_json)); + ASSERT_RAISES(Invalid, ExecuteSerializedPlan(*buf)); }); } @@ -1101,7 +1104,10 @@ TEST(Substrait, JoinPlanBasic) { } } } - }] + }], + "output_type": { + "bool": {} + } } }, "type": "JOIN_TYPE_INNER" @@ -1111,7 +1117,7 @@ TEST(Substrait, JoinPlanBasic) { "extension_uris": [ { "extension_uri_anchor": 0, - "uri": ")" + substrait::default_extension_types_uri() + + "uri": ")" + std::string(kSubstraitComparisonFunctionsUri) + R"(" } ], @@ -1125,7 +1131,7 @@ TEST(Substrait, JoinPlanBasic) { })"; ASSERT_OK_AND_ASSIGN(auto buf, internal::SubstraitFromJSON("Plan", substrait_json)); for (auto sp_ext_id_reg : - {std::shared_ptr(), substrait::MakeExtensionIdRegistry()}) { + {std::shared_ptr(), MakeExtensionIdRegistry()}) { ExtensionIdRegistry* ext_id_reg = sp_ext_id_reg.get(); ExtensionSet ext_set(ext_id_reg); ASSERT_OK_AND_ASSIGN(auto sink_decls, @@ -1241,7 +1247,10 @@ TEST(Substrait, JoinPlanInvalidKeyCmp) { } } } - }] + }], + "output_type": { + "bool": {} + } } }, "type": "JOIN_TYPE_INNER" @@ -1251,7 +1260,7 @@ TEST(Substrait, JoinPlanInvalidKeyCmp) { "extension_uris": [ { "extension_uri_anchor": 0, - "uri": ")" + substrait::default_extension_types_uri() + + "uri": ")" + std::string(kSubstraitArithmeticFunctionsUri) + R"(" } ], @@ -1265,7 +1274,7 @@ TEST(Substrait, JoinPlanInvalidKeyCmp) { })"; ASSERT_OK_AND_ASSIGN(auto buf, internal::SubstraitFromJSON("Plan", substrait_json)); for (auto sp_ext_id_reg : - {std::shared_ptr(), substrait::MakeExtensionIdRegistry()}) { + {std::shared_ptr(), MakeExtensionIdRegistry()}) { ExtensionIdRegistry* ext_id_reg = sp_ext_id_reg.get(); ExtensionSet ext_set(ext_id_reg); ASSERT_RAISES(Invalid, DeserializePlans( @@ -1333,7 +1342,7 @@ TEST(Substrait, JoinPlanInvalidExpression) { }] })")); for (auto sp_ext_id_reg : - {std::shared_ptr(), substrait::MakeExtensionIdRegistry()}) { + {std::shared_ptr(), MakeExtensionIdRegistry()}) { ExtensionIdRegistry* ext_id_reg = sp_ext_id_reg.get(); ExtensionSet ext_set(ext_id_reg); ASSERT_RAISES(Invalid, DeserializePlans( @@ -1406,7 +1415,7 @@ TEST(Substrait, JoinPlanInvalidKeys) { }] })")); for (auto sp_ext_id_reg : - {std::shared_ptr(), substrait::MakeExtensionIdRegistry()}) { + {std::shared_ptr(), MakeExtensionIdRegistry()}) { ExtensionIdRegistry* ext_id_reg = sp_ext_id_reg.get(); ExtensionSet ext_set(ext_id_reg); ASSERT_RAISES(Invalid, DeserializePlans( @@ -1470,6 +1479,7 @@ TEST(Substrait, AggregateBasic) { }], "sorts": [], "phase": "AGGREGATION_PHASE_INITIAL_TO_RESULT", + "invocation": "AGGREGATION_INVOCATION_ALL", "outputType": { "i64": {} } @@ -1480,18 +1490,18 @@ TEST(Substrait, AggregateBasic) { }], "extensionUris": [{ "extension_uri_anchor": 0, - "uri": "https://github.com/apache/arrow/blob/master/format/substrait/extension_types.yaml" + "uri": "https://github.com/substrait-io/substrait/blob/main/extensions/functions_arithmetic.yaml" }], "extensions": [{ "extension_function": { "extension_uri_reference": 0, "function_anchor": 0, - "name": "hash_count" + "name": "sum" } }], })")); - auto sp_ext_id_reg = substrait::MakeExtensionIdRegistry(); + auto sp_ext_id_reg = MakeExtensionIdRegistry(); ASSERT_OK_AND_ASSIGN(auto sink_decls, DeserializePlans(*buf, [] { return kNullConsumer; })); auto agg_decl = sink_decls[0].inputs[0]; @@ -1503,7 +1513,7 @@ TEST(Substrait, AggregateBasic) { EXPECT_EQ(agg_rel->factory_name, "aggregate"); EXPECT_EQ(agg_options.aggregates[0].name, ""); - EXPECT_EQ(agg_options.aggregates[0].function, "hash_count"); + EXPECT_EQ(agg_options.aggregates[0].function, "hash_sum"); } TEST(Substrait, AggregateInvalidRel) { @@ -1516,13 +1526,13 @@ TEST(Substrait, AggregateInvalidRel) { }], "extensionUris": [{ "extension_uri_anchor": 0, - "uri": "https://github.com/apache/arrow/blob/master/format/substrait/extension_types.yaml" + "uri": "https://github.com/substrait-io/substrait/blob/main/extensions/functions_arithmetic.yaml" }], "extensions": [{ "extension_function": { "extension_uri_reference": 0, "function_anchor": 0, - "name": "hash_count" + "name": "sum" } }], })")); @@ -1577,13 +1587,13 @@ TEST(Substrait, AggregateInvalidFunction) { }], "extensionUris": [{ "extension_uri_anchor": 0, - "uri": "https://github.com/apache/arrow/blob/master/format/substrait/extension_types.yaml" + "uri": "https://github.com/substrait-io/substrait/blob/main/extensions/functions_arithmetic.yaml" }], "extensions": [{ "extension_function": { "extension_uri_reference": 0, "function_anchor": 0, - "name": "hash_count" + "name": "sum" } }], })")); @@ -1637,6 +1647,7 @@ TEST(Substrait, AggregateInvalidAggFuncArgs) { "args": [], "sorts": [], "phase": "AGGREGATION_PHASE_INITIAL_TO_RESULT", + "invocation": "AGGREGATION_INVOCATION_ALL", "outputType": { "i64": {} } @@ -1647,13 +1658,13 @@ TEST(Substrait, AggregateInvalidAggFuncArgs) { }], "extensionUris": [{ "extension_uri_anchor": 0, - "uri": "https://github.com/apache/arrow/blob/master/format/substrait/extension_types.yaml" + "uri": "https://github.com/substrait-io/substrait/blob/main/extensions/functions_arithmetic.yaml" }], "extensions": [{ "extension_function": { "extension_uri_reference": 0, "function_anchor": 0, - "name": "hash_count" + "name": "sum" } }], })")); @@ -1707,6 +1718,78 @@ TEST(Substrait, AggregateWithFilter) { "args": [], "sorts": [], "phase": "AGGREGATION_PHASE_INITIAL_TO_RESULT", + "invocation": "AGGREGATION_INVOCATION_ALL", + "outputType": { + "i64": {} + } + } + }] + } + } + }], + "extensionUris": [{ + "extension_uri_anchor": 0, + "uri": "https://github.com/apache/arrow/blob/master/format/substrait/extension_types.yaml" + }], + "extensions": [{ + "extension_function": { + "extension_uri_reference": 0, + "function_anchor": 0, + "name": "equal" + } + }], + })")); + + ASSERT_RAISES(NotImplemented, DeserializePlans(*buf, [] { return kNullConsumer; })); +} + +TEST(Substrait, AggregateBadPhase) { + ASSERT_OK_AND_ASSIGN(auto buf, internal::SubstraitFromJSON("Plan", R"({ + "relations": [{ + "rel": { + "aggregate": { + "input": { + "read": { + "base_schema": { + "names": ["A", "B", "C"], + "struct": { + "types": [{ + "i32": {} + }, { + "i32": {} + }, { + "i32": {} + }] + } + }, + "local_files": { + "items": [ + { + "uri_file": "file:///tmp/dat.parquet", + "parquet": {} + } + ] + } + } + }, + "groupings": [{ + "groupingExpressions": [{ + "selection": { + "directReference": { + "structField": { + "field": 0 + } + } + } + }] + }], + "measures": [{ + "measure": { + "functionReference": 0, + "args": [], + "sorts": [], + "phase": "AGGREGATION_PHASE_INITIAL_TO_RESULT", + "invocation": "AGGREGATION_INVOCATION_DISTINCT", "outputType": { "i64": {} } diff --git a/cpp/src/arrow/engine/substrait/test_plan_builder.cc b/cpp/src/arrow/engine/substrait/test_plan_builder.cc new file mode 100644 index 0000000000000..3bd373ae5fa56 --- /dev/null +++ b/cpp/src/arrow/engine/substrait/test_plan_builder.cc @@ -0,0 +1,216 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +#include "arrow/engine/substrait/test_plan_builder.h" + +#include + +#include "arrow/compute/exec/exec_plan.h" +#include "arrow/engine/substrait/plan_internal.h" +#include "arrow/engine/substrait/type_internal.h" +#include "arrow/util/macros.h" +#include "arrow/util/make_unique.h" + +#include "substrait/algebra.pb.h" +#include "substrait/plan.pb.h" +#include "substrait/type.pb.h" + +namespace arrow { + +using internal::make_unique; + +namespace engine { +namespace internal { + +static const ConversionOptions kPlanBuilderConversionOptions; + +Result> CreateRead(const Table& table, + ExtensionSet* ext_set) { + auto read = make_unique(); + + ARROW_ASSIGN_OR_RAISE(std::unique_ptr schema, + ToProto(*table.schema(), ext_set, kPlanBuilderConversionOptions)); + read->set_allocated_base_schema(schema.release()); + + auto named_table = make_unique(); + named_table->add_names("test"); + read->set_allocated_named_table(named_table.release()); + + return read; +} + +void CreateDirectReference(int32_t index, substrait::Expression* expr) { + auto reference = make_unique(); + auto reference_segment = make_unique(); + auto struct_field = make_unique(); + struct_field->set_field(index); + reference_segment->set_allocated_struct_field(struct_field.release()); + reference->set_allocated_direct_reference(reference_segment.release()); + + auto root_reference = + make_unique(); + reference->set_allocated_root_reference(root_reference.release()); + expr->set_allocated_selection(reference.release()); +} + +Result> CreateProject( + Id function_id, const std::vector& arguments, + const std::vector>& arg_types, const DataType& output_type, + ExtensionSet* ext_set) { + auto project = make_unique(); + + auto call = make_unique(); + ARROW_ASSIGN_OR_RAISE(uint32_t function_anchor, ext_set->EncodeFunction(function_id)); + call->set_function_reference(function_anchor); + + std::size_t arg_index = 0; + std::size_t table_arg_index = 0; + for (const std::shared_ptr& arg_type : arg_types) { + substrait::FunctionArgument* argument = call->add_arguments(); + if (arg_type) { + // If it has a type then it's a reference to the input table + auto expression = make_unique(); + CreateDirectReference(static_cast(table_arg_index++), expression.get()); + argument->set_allocated_value(expression.release()); + } else { + // If it doesn't have a type then it's an enum + const std::string& enum_value = arguments[arg_index]; + auto enum_ = make_unique(); + if (enum_value.size() > 0) { + enum_->set_specified(enum_value); + } else { + auto unspecified = make_unique(); + enum_->set_allocated_unspecified(unspecified.release()); + } + argument->set_allocated_enum_(enum_.release()); + } + arg_index++; + } + + ARROW_ASSIGN_OR_RAISE( + std::unique_ptr output_type_substrait, + ToProto(output_type, /*nullable=*/true, ext_set, kPlanBuilderConversionOptions)); + call->set_allocated_output_type(output_type_substrait.release()); + + substrait::Expression* call_expression = project->add_expressions(); + call_expression->set_allocated_scalar_function(call.release()); + + return project; +} + +Result> CreateAgg(Id function_id, + const std::vector& keys, + int arg_idx, + const DataType& output_type, + ExtensionSet* ext_set) { + auto agg = make_unique(); + + if (!keys.empty()) { + substrait::AggregateRel::Grouping* grouping = agg->add_groupings(); + for (int key : keys) { + substrait::Expression* key_expr = grouping->add_grouping_expressions(); + CreateDirectReference(key, key_expr); + } + } + + substrait::AggregateRel::Measure* measure_wrapper = agg->add_measures(); + auto agg_func = make_unique(); + ARROW_ASSIGN_OR_RAISE(uint32_t function_anchor, ext_set->EncodeFunction(function_id)); + + agg_func->set_function_reference(function_anchor); + + substrait::FunctionArgument* arg = agg_func->add_arguments(); + auto arg_expr = make_unique(); + CreateDirectReference(arg_idx, arg_expr.get()); + arg->set_allocated_value(arg_expr.release()); + + agg_func->set_phase(substrait::AggregationPhase::AGGREGATION_PHASE_INITIAL_TO_RESULT); + agg_func->set_invocation( + substrait::AggregateFunction::AggregationInvocation:: + AggregateFunction_AggregationInvocation_AGGREGATION_INVOCATION_ALL); + + ARROW_ASSIGN_OR_RAISE( + std::unique_ptr output_type_substrait, + ToProto(output_type, /*nullable=*/true, ext_set, kPlanBuilderConversionOptions)); + agg_func->set_allocated_output_type(output_type_substrait.release()); + measure_wrapper->set_allocated_measure(agg_func.release()); + + return agg; +} + +Result> CreatePlan(std::unique_ptr root, + ExtensionSet* ext_set) { + auto plan = make_unique(); + + substrait::PlanRel* plan_rel = plan->add_relations(); + auto rel_root = make_unique(); + rel_root->set_allocated_input(root.release()); + plan_rel->set_allocated_root(rel_root.release()); + + ARROW_RETURN_NOT_OK(AddExtensionSetToPlan(*ext_set, plan.get())); + return plan; +} + +Result> CreateScanProjectSubstrait( + Id function_id, const std::shared_ptr
& input_table, + const std::vector& arguments, + const std::vector>& data_types, + const DataType& output_type) { + ExtensionSet ext_set; + ARROW_ASSIGN_OR_RAISE(std::unique_ptr read, + CreateRead(*input_table, &ext_set)); + ARROW_ASSIGN_OR_RAISE( + std::unique_ptr project, + CreateProject(function_id, arguments, data_types, output_type, &ext_set)); + + auto read_rel = make_unique(); + read_rel->set_allocated_read(read.release()); + project->set_allocated_input(read_rel.release()); + + auto project_rel = make_unique(); + project_rel->set_allocated_project(project.release()); + + ARROW_ASSIGN_OR_RAISE(std::unique_ptr plan, + CreatePlan(std::move(project_rel), &ext_set)); + return Buffer::FromString(plan->SerializeAsString()); +} + +Result> CreateScanAggSubstrait( + Id function_id, const std::shared_ptr
& input_table, + const std::vector& key_idxs, int arg_idx, const DataType& output_type) { + ExtensionSet ext_set; + + ARROW_ASSIGN_OR_RAISE(std::unique_ptr read, + CreateRead(*input_table, &ext_set)); + ARROW_ASSIGN_OR_RAISE(std::unique_ptr agg, + CreateAgg(function_id, key_idxs, arg_idx, output_type, &ext_set)); + + auto read_rel = make_unique(); + read_rel->set_allocated_read(read.release()); + agg->set_allocated_input(read_rel.release()); + + auto agg_rel = make_unique(); + agg_rel->set_allocated_aggregate(agg.release()); + + ARROW_ASSIGN_OR_RAISE(std::unique_ptr plan, + CreatePlan(std::move(agg_rel), &ext_set)); + return Buffer::FromString(plan->SerializeAsString()); +} + +} // namespace internal +} // namespace engine +} // namespace arrow diff --git a/cpp/src/arrow/engine/substrait/test_plan_builder.h b/cpp/src/arrow/engine/substrait/test_plan_builder.h new file mode 100644 index 0000000000000..9d2d97a8cc9cc --- /dev/null +++ b/cpp/src/arrow/engine/substrait/test_plan_builder.h @@ -0,0 +1,72 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +// These utilities are for internal / unit test use only. +// They allow for the construction of simple Substrait plans +// programmatically without first requiring the construction +// of an ExecPlan + +// These utilities have to be here, and not in a test_util.cc +// file (or in a unit test) because only one .so is allowed +// to include each .pb.h file or else protobuf will encounter +// global namespace conflicts. + +#include +#include +#include + +#include "arrow/buffer.h" +#include "arrow/engine/substrait/extension_set.h" +#include "arrow/result.h" +#include "arrow/table.h" +#include "arrow/type.h" + +namespace arrow { +namespace engine { +namespace internal { + +/// \brief Create a scan->project->sink plan for tests +/// +/// The plan will project one additional column using the function +/// defined by `function_id`, `arguments`, and data_types. `arguments` +/// and `data_types` should have the same length but only one of each +/// should be defined at each index. +/// +/// If `data_types` is defined at an index then the plan will create a +/// direct reference (starting at index 0 and increasing by 1 for each +/// argument of this type). +/// +/// If `arguments` is defined at an index then the plan will create an +/// enum argument with that value. +ARROW_ENGINE_EXPORT Result> CreateScanProjectSubstrait( + Id function_id, const std::shared_ptr
& input_table, + const std::vector& arguments, + const std::vector>& data_types, + const DataType& output_type); + +/// \brief Create a scan->aggregate->sink plan for tests +/// +/// The plan will create an aggregate with one grouping set (defined by +/// key_idxs) and one measure. The measure will be a unary function +/// defined by `function_id` and a direct reference to `arg_idx`. +ARROW_ENGINE_EXPORT Result> CreateScanAggSubstrait( + Id function_id, const std::shared_ptr
& input_table, + const std::vector& key_idxs, int arg_idx, const DataType& output_type); + +} // namespace internal +} // namespace engine +} // namespace arrow diff --git a/cpp/src/arrow/engine/substrait/util.cc b/cpp/src/arrow/engine/substrait/util.cc index 36240d468278c..936bde5c652e5 100644 --- a/cpp/src/arrow/engine/substrait/util.cc +++ b/cpp/src/arrow/engine/substrait/util.cc @@ -23,8 +23,6 @@ namespace arrow { namespace engine { -namespace substrait { - namespace { /// \brief A SinkNodeConsumer specialized to output ExecBatches via PushGenerator @@ -136,19 +134,11 @@ std::shared_ptr MakeExtensionIdRegistry() { return nested_extension_id_registry(default_extension_id_registry()); } -Status RegisterFunction(ExtensionIdRegistry& registry, const std::string& id_uri, - const std::string& id_name, - const std::string& arrow_function_name) { - return registry.RegisterFunction(id_uri, id_name, arrow_function_name); -} - const std::string& default_extension_types_uri() { static std::string uri = engine::kArrowExtTypesUri.to_string(); return uri; } -} // namespace substrait - } // namespace engine } // namespace arrow diff --git a/cpp/src/arrow/engine/substrait/util.h b/cpp/src/arrow/engine/substrait/util.h index 134d633bb33d3..3ac9320e1da76 100644 --- a/cpp/src/arrow/engine/substrait/util.h +++ b/cpp/src/arrow/engine/substrait/util.h @@ -27,8 +27,6 @@ namespace arrow { namespace engine { -namespace substrait { - /// \brief Retrieve a RecordBatchReader from a Substrait plan. ARROW_ENGINE_EXPORT Result> ExecuteSerializedPlan( const Buffer& substrait_buffer, const ExtensionIdRegistry* registry = NULLPTR, @@ -43,24 +41,8 @@ ARROW_ENGINE_EXPORT Result> SerializeJsonPlan( /// See arrow::engine::nested_extension_id_registry for details. ARROW_ENGINE_EXPORT std::shared_ptr MakeExtensionIdRegistry(); -/// \brief Register a function manually. -/// -/// Register an arrow function name by an ID, defined by a URI and a name, on a given -/// extension-id-registry. -/// -/// \param[in] registry an extension-id-registry to use -/// \param[in] id_uri a URI of the ID to register by -/// \param[in] id_name a name of the ID to register by -/// \param[in] arrow_function_name name of arrow function to register -ARROW_ENGINE_EXPORT Status RegisterFunction(ExtensionIdRegistry& registry, - const std::string& id_uri, - const std::string& id_name, - const std::string& arrow_function_name); - ARROW_ENGINE_EXPORT const std::string& default_extension_types_uri(); -} // namespace substrait - } // namespace engine } // namespace arrow diff --git a/docs/source/cpp/streaming_execution.rst b/docs/source/cpp/streaming_execution.rst index e49225637df13..daa5f4be2f013 100644 --- a/docs/source/cpp/streaming_execution.rst +++ b/docs/source/cpp/streaming_execution.rst @@ -144,6 +144,17 @@ Join Relations join key is supported. * The ``post_join_filter`` property is not supported and will be ignored. +Aggregate Relations +^^^^^^^^^^^^^^^^^^^ + + * At most one grouping set is supported. + * Each grouping expression must be a direct reference. + * Each measure's arguments must be direct references. + * A measure may not have a filter + * A measure may not have sorts + * A measure's invocation must be AGGREGATION_INVOCATION_ALL + * A measure's phase must be AGGREGATION_PHASE_INITIAL_TO_RESULT + Expressions (general) ^^^^^^^^^^^^^^^^^^^^^ @@ -152,20 +163,128 @@ Expressions (general) grouping set. Acero typically expects these expressions to be direct references. Planners should extract the implicit projection into a formal project relation before delivering the plan to Acero. + * Older versions of Isthmus would omit optional arguments instead of including them + as unspecified enums. Acero will not support these plans. Literals ^^^^^^^^ * A literal with non-default nullability will cause a plan to be rejected. +Types +^^^^^ + + * Acero does not have full support for non-nullable types and may allow input + to have nulls without rejecting it. + * The table below shows the mapping between Arrow types and Substrait type + classes that are currently supported + +.. list-table:: Substrait / Arrow Type Mapping + :widths: 25 25 + :header-rows: 1 + + * - Substrait Type + - Arrow Type + - Caveat + * - boolean + - boolean + - + * - i8 + - int8 + - + * - i16 + - int16 + - + * - i16 + - int16 + - + * - i32 + - int32 + - + * - i64 + - int64 + - + * - fp32 + - float32 + - + * - fp64 + - float64 + - + * - string + - string + - + * - binary + - binary + - + * - timestamp + - timestamp + - + * - timestamp_tz + - timestamp + - + * - date + - date32 + - + * - time + - time64 + - + * - interval_year + - + - Not currently supported + * - interval_day + - + - Not currently supported + * - uuid + - + - Not currently supported + * - FIXEDCHAR + - + - Not currently supported + * - VARCHAR + - + - Not currently supported + * - FIXEDBINARY + - fixed_size_binary + - + * - DECIMAL + - decimal128 + - + * - STRUCT + - struct + - Arrow struct fields will have no name (empty string) + * - NSTRUCT + - + - Not currently supported + * - LIST + - list + - + * - MAP + - map + - K must not be nullable + Functions ^^^^^^^^^ - * The only functions currently supported by Acero are: - - * add - * equal - * is_not_distinct_from + * Acero does not support the legacy ``args`` style of declaring arguments + * The following functions have caveats or are not supported at all. Note that + this is not a comprehensive list. Functions are being added to Substrait at + a rapid pace and new functions may be missing. + + * Acero does not support the SATURATE option for overflow + * Acero does not support kernels that take more than two arguments + for the functions ``and``, ``or``, ``xor`` + * Acero does not support temporal arithmetic + * Acero does not support the following standard functions: + + * ``is_not_distinct_from`` + * ``like`` + * ``substring`` + * ``starts_with`` + * ``ends_with`` + * ``contains`` + * ``count`` + * ``count_distinct`` + * ``approx_count_distinct`` * The functions above must be referenced using the URI ``https://github.com/apache/arrow/blob/master/format/substrait/extension_types.yaml`` diff --git a/python/pyarrow/_substrait.pyx b/python/pyarrow/_substrait.pyx index 7f079fb717b79..05794a95a20ee 100644 --- a/python/pyarrow/_substrait.pyx +++ b/python/pyarrow/_substrait.pyx @@ -19,6 +19,7 @@ from cython.operator cimport dereference as deref from pyarrow import Buffer +from pyarrow.lib import frombytes from pyarrow.lib cimport * from pyarrow.includes.libarrow cimport * from pyarrow.includes.libarrow_substrait cimport * @@ -77,3 +78,27 @@ def _parse_json_plan(plan): with nogil: c_buf_plan = GetResultValue(c_res_buffer) return pyarrow_wrap_buffer(c_buf_plan) + + +def get_supported_functions(): + """ + Get a list of Substrait functions that the underlying + engine currently supports. + + Returns + ------- + list[str] + A list of function ids encoded as '{uri}#{name}' + """ + + cdef: + ExtensionIdRegistry* c_id_registry + std_vector[c_string] c_ids + + c_id_registry = default_extension_id_registry() + c_ids = c_id_registry.GetSupportedSubstraitFunctions() + + functions_list = [] + for c_id in c_ids: + functions_list.append(frombytes(c_id)) + return functions_list diff --git a/python/pyarrow/includes/libarrow_substrait.pxd b/python/pyarrow/includes/libarrow_substrait.pxd index 2e1a17b06bddd..0b3ace75d92b0 100644 --- a/python/pyarrow/includes/libarrow_substrait.pxd +++ b/python/pyarrow/includes/libarrow_substrait.pxd @@ -17,10 +17,20 @@ # distutils: language = c++ +from libcpp.vector cimport vector as std_vector + from pyarrow.includes.common cimport * from pyarrow.includes.libarrow cimport * -cdef extern from "arrow/engine/substrait/util.h" namespace "arrow::engine::substrait" nogil: +cdef extern from "arrow/engine/substrait/util.h" namespace "arrow::engine" nogil: CResult[shared_ptr[CRecordBatchReader]] ExecuteSerializedPlan(const CBuffer& substrait_buffer) CResult[shared_ptr[CBuffer]] SerializeJsonPlan(const c_string& substrait_json) + +cdef extern from "arrow/engine/substrait/extension_set.h" \ + namespace "arrow::engine" nogil: + + cdef cppclass ExtensionIdRegistry: + std_vector[c_string] GetSupportedSubstraitFunctions() + + ExtensionIdRegistry* default_extension_id_registry() diff --git a/python/pyarrow/substrait.py b/python/pyarrow/substrait.py index e3ff28f4ebaea..590d03521fe50 100644 --- a/python/pyarrow/substrait.py +++ b/python/pyarrow/substrait.py @@ -16,5 +16,6 @@ # under the License. from pyarrow._substrait import ( # noqa + get_supported_functions, run_query, ) diff --git a/python/pyarrow/tests/test_substrait.py b/python/pyarrow/tests/test_substrait.py index f05d68a95a14f..c8fa6afcb9ffa 100644 --- a/python/pyarrow/tests/test_substrait.py +++ b/python/pyarrow/tests/test_substrait.py @@ -145,3 +145,23 @@ def test_binary_conversion_with_json_options(tmpdir): res_tb = reader.read_all() assert table.select(["bar"]) == res_tb.select(["bar"]) + + +# Substrait has not finalized what the URI should be for standard functions +# In the meantime, lets just check the suffix +def has_function(fns, ext_file, fn_name): + suffix = f'{ext_file}#{fn_name}' + for fn in fns: + if fn.endswith(suffix): + return True + return False + + +def test_get_supported_functions(): + supported_functions = pa._substrait.get_supported_functions() + # It probably doesn't make sense to exhaustively verfiy this list but + # we can check a sample aggregate and a sample non-aggregate entry + assert has_function(supported_functions, + 'functions_arithmetic.yaml', 'add') + assert has_function(supported_functions, + 'functions_arithmetic.yaml', 'sum')