From 9cd57d1ad7191142d483a58de7c637b5b1833337 Mon Sep 17 00:00:00 2001 From: Sanjiban Sengupta Date: Fri, 3 Jun 2022 15:23:51 +0530 Subject: [PATCH] feat: helpers for conversion functions --- .../engine/substrait/expression_internal.cc | 9 + .../arrow/engine/substrait/extension_set.cc | 225 +++++------------- 2 files changed, 65 insertions(+), 169 deletions(-) diff --git a/cpp/src/arrow/engine/substrait/expression_internal.cc b/cpp/src/arrow/engine/substrait/expression_internal.cc index 5992110c34f27..c36aba6beb1bd 100644 --- a/cpp/src/arrow/engine/substrait/expression_internal.cc +++ b/cpp/src/arrow/engine/substrait/expression_internal.cc @@ -164,6 +164,15 @@ Result FromProto(const substrait::Expression& expr, return arrow_function(scalar_fn); } + case substrait::Expression::kEnum: { + auto enum_expr = expr.enum_(); + if(enum_expr.has_specified()){ + return compute::literal(std::move(enum_expr.specified())); + } else { + return Status::Invalid("Substrait Enum value not specified"); + } + } + default: break; } diff --git a/cpp/src/arrow/engine/substrait/extension_set.cc b/cpp/src/arrow/engine/substrait/extension_set.cc index 70d9ef1acd0f8..da3b38e5c7893 100644 --- a/cpp/src/arrow/engine/substrait/extension_set.cc +++ b/cpp/src/arrow/engine/substrait/extension_set.cc @@ -360,7 +360,7 @@ Status FunctionMapping::AddArrowToSubstrait(std::string arrow_function_name, Arr arrow_to_substrait[arrow_function_name] = conversion_func; return Status::OK(); } else{ - return Status::Invalid("Arrow function already exist in the conversion map"); + return Status::AlreadyExists("Arrow function already exist in the conversion map"); } } @@ -369,7 +369,7 @@ Status FunctionMapping::AddSubstraitToArrow(std::string substrait_function_name, substrait_to_arrow[substrait_function_name] = conversion_func; return Status::OK(); } else{ - return Status::Invalid("Substrait function already exist in the conversion map"); + return Status::AlreadyExists("Substrait function already exist in the conversion map"); } } @@ -377,7 +377,7 @@ Result FunctionMapping::GetArrowFromSubstrait(std::string name if (FunctionMapping::substrait_to_arrow.find(name)!=FunctionMapping::substrait_to_arrow.end()){ return FunctionMapping::substrait_to_arrow.at(name); } else { - return Status::Invalid("Substrait function doesn't exist in the mapping registry"); + return Status::KeyError("Substrait function doesn't exist in the mapping registry"); } } @@ -385,34 +385,42 @@ Result FunctionMapping::GetSubstraitFromArrow(std::string name if (FunctionMapping::arrow_to_substrait.find(name)!=FunctionMapping::arrow_to_substrait.end()){ return FunctionMapping::arrow_to_substrait.at(name); } else { - return Status::Invalid("Arrow function doesn't exist in the mapping registry"); + return Status::KeyError("Arrow function doesn't exist in the mapping registry"); } } - -SubstraitToArrow substrait_add_to_arrow = [] (const substrait::Expression::ScalarFunction& call) -> Result { - auto value_1 = call.args(1); - auto value_2 = call.args(2); +std::vector substrait_convert_arguments(const substrait::Expression::ScalarFunction& call){ + substrait::Expression value; ExtensionSet ext_set_; - ARROW_ASSIGN_OR_RAISE(auto expression_1, FromProto(value_1, ext_set_)); - ARROW_ASSIGN_OR_RAISE(auto expression_2, FromProto(value_2, ext_set_)); - auto options = call.args(0); - if (options.has_enum_()) { - auto overflow_handling = options.enum_(); - if(overflow_handling.has_specified()){ - std::string overflow_type = overflow_handling.specified(); - if(overflow_type == "SILENT"){ - return arrow::compute::call("add", {expression_1,expression_2}, compute::ArithmeticOptions()); - } else if (overflow_type == "SATURATE") { - return Status::Invalid("Arrow does not support a saturating add"); - } else { - return arrow::compute::call("add_checked", {expression_1,expression_2}, compute::ArithmeticOptions(true)); - } - } else { - return arrow::compute::call("add", {expression_1,expression_2}, compute::ArithmeticOptions()); + arrow::compute::Expression expression; + std::vector func_args; + for(int i=0; i value; + for(size_t i = 0; iCopyFrom(*value); + } + return std::move(substrait_call); +} + +SubstraitToArrow substrait_add_to_arrow = [] (const substrait::Expression::ScalarFunction& call) -> Result { + auto func_args = substrait_convert_arguments(call); + if(func_args[0].ToString() == "SILENT"){ + return arrow::compute::call("add", {func_args[1], func_args[2]}, compute::ArithmeticOptions()); + } else if (func_args[0].ToString() == "SATURATE") { + return Status::Invalid("Arrow does not support a saturating add"); } else { - return Status::Invalid("Substrait Function Options should be an enum"); + return arrow::compute::call("add_checked", {func_args[1], func_args[2]}, compute::ArithmeticOptions(true)); } }; @@ -426,16 +434,7 @@ ArrowToSubstrait arrow_add_to_substrait = [] (const arrow::compute::Expression:: std::string overflow_handling = "ERROR"; options.set_specified(overflow_handling); substrait_call.add_args()->set_allocated_enum_(&options); - - auto expression_1 = call.arguments[0]; - auto expression_2 = call.arguments[1]; - - ARROW_ASSIGN_OR_RAISE(auto value_1, ToProto(expression_1, ext_set_)); - ARROW_ASSIGN_OR_RAISE(auto value_2, ToProto(expression_2, ext_set_)); - - substrait_call.add_args()->CopyFrom(*value_1); - substrait_call.add_args()->CopyFrom(*value_2); - return std::move(substrait_call); + return arrow_convert_arguments(call, substrait_call, ext_set_); }; ArrowToSubstrait arrow_unchecked_add_to_substrait = [] (const arrow::compute::Expression::Call& call, ExtensionSet* ext_set_) -> Result { @@ -448,205 +447,93 @@ ArrowToSubstrait arrow_unchecked_add_to_substrait = [] (const arrow::compute::Ex std::string overflow_handling = "SILENT"; options.set_specified(overflow_handling); substrait_call.add_args()->set_allocated_enum_(&options); - - auto expression_1 = call.arguments[0]; - auto expression_2 = call.arguments[1]; - - ARROW_ASSIGN_OR_RAISE(auto value_1, ToProto(expression_1, ext_set_)); - ARROW_ASSIGN_OR_RAISE(auto value_2, ToProto(expression_2, ext_set_)); - - substrait_call.add_args()->CopyFrom(*value_1); - substrait_call.add_args()->CopyFrom(*value_2); - return std::move(substrait_call); + return arrow_convert_arguments(call, substrait_call, ext_set_); }; // Boolean Functions mapping SubstraitToArrow substrait_not_to_arrow = [] (const substrait::Expression::ScalarFunction& call) -> Result { - auto value_1 = call.args(1); - ExtensionSet ext_set_; - ARROW_ASSIGN_OR_RAISE(auto expression_1, FromProto(value_1, ext_set_)); - return arrow::compute::call("invert", {expression_1}); + return arrow::compute::call("invert", substrait_convert_arguments(call)); }; SubstraitToArrow substrait_or_to_arrow = [] (const substrait::Expression::ScalarFunction& call) -> Result { - int num_args = call.args_size(); // OR function has variadic arguments - substrait::Expression value; - ExtensionSet ext_set_; - arrow::compute::Expression expression; - std::vector func_args; - for(int i=0; i Result { - int num_args = call.args_size(); // AND function has variadic arguments - substrait::Expression value; - ExtensionSet ext_set_; - arrow::compute::Expression expression; - std::vector func_args; - for(int i=0; i Result { - auto value_1 = call.args(0); - auto value_2 = call.args(1); - ExtensionSet ext_set_; - ARROW_ASSIGN_OR_RAISE(auto expression_1, FromProto(value_1, ext_set_)); - ARROW_ASSIGN_OR_RAISE(auto expression_2, FromProto(value_2, ext_set_)); - return arrow::compute::call("xor", {expression_1, expression_2}); + return arrow::compute::call("xor", substrait_convert_arguments(call)); }; ArrowToSubstrait arrow_invert_to_substrait = [] (const arrow::compute::Expression::Call& call, ExtensionSet* ext_set_) -> Result { substrait::Expression::ScalarFunction substrait_call; - ARROW_ASSIGN_OR_RAISE(auto function_reference, ext_set_->EncodeFunction("not")); substrait_call.set_function_reference(function_reference); - - auto expression_1 = call.arguments[0]; - auto expression_2 = call.arguments[1]; - - ARROW_ASSIGN_OR_RAISE(auto value_1, ToProto(expression_1, ext_set_)); - ARROW_ASSIGN_OR_RAISE(auto value_2, ToProto(expression_2, ext_set_)); - - substrait_call.add_args()->CopyFrom(*value_1); - substrait_call.add_args()->CopyFrom(*value_2); - return std::move(substrait_call); - + return arrow_convert_arguments(call, substrait_call, ext_set_); }; ArrowToSubstrait arrow_or_kleene_to_substrait = [] (const arrow::compute::Expression::Call& call, ExtensionSet* ext_set_) -> Result { substrait::Expression::ScalarFunction substrait_call; - ARROW_ASSIGN_OR_RAISE(auto function_reference, ext_set_->EncodeFunction("or")); substrait_call.set_function_reference(function_reference); - - arrow::compute::Expression expression; - std::unique_ptr value; - for(size_t i = 0; iCopyFrom(*value); - } - return std::move(substrait_call); + return arrow_convert_arguments(call, substrait_call, ext_set_); }; ArrowToSubstrait arrow_and_kleene_to_substrait = [] (const arrow::compute::Expression::Call& call, ExtensionSet* ext_set_) -> Result { substrait::Expression::ScalarFunction substrait_call; - ARROW_ASSIGN_OR_RAISE(auto function_reference, ext_set_->EncodeFunction("and")); substrait_call.set_function_reference(function_reference); - - arrow::compute::Expression expression; - std::unique_ptr value; - for(size_t i = 0; iCopyFrom(*value); - } - - return std::move(substrait_call); + return arrow_convert_arguments(call, substrait_call, ext_set_); }; ArrowToSubstrait arrow_xor_to_substrait = [] (const arrow::compute::Expression::Call& call, ExtensionSet* ext_set_) -> Result { substrait::Expression::ScalarFunction substrait_call; - ARROW_ASSIGN_OR_RAISE(auto function_reference, ext_set_->EncodeFunction("xor")); substrait_call.set_function_reference(function_reference); - - auto expression_1 = call.arguments[0]; - auto expression_2 = call.arguments[1]; - - ARROW_ASSIGN_OR_RAISE(auto value_1, ToProto(expression_1, ext_set_)); - ARROW_ASSIGN_OR_RAISE(auto value_2, ToProto(expression_2, ext_set_)); - - substrait_call.add_args()->CopyFrom(*value_1); - substrait_call.add_args()->CopyFrom(*value_2); - return std::move(substrait_call); + return arrow_convert_arguments(call, substrait_call, ext_set_); }; // Comparison Functions mapping SubstraitToArrow substrait_lt_to_arrow = [] (const substrait::Expression::ScalarFunction& call) -> Result { - auto value_1 = call.args(0); - auto value_2 = call.args(1); - ExtensionSet ext_set_; - ARROW_ASSIGN_OR_RAISE(auto expression_1, FromProto(value_1, ext_set_)); - ARROW_ASSIGN_OR_RAISE(auto expression_2, FromProto(value_2, ext_set_)); - return arrow::compute::call("less", {expression_1, expression_2}); + return arrow::compute::call("less", substrait_convert_arguments(call)); }; SubstraitToArrow substrait_gt_to_arrow = [] (const substrait::Expression::ScalarFunction& call) -> Result { - auto value_1 = call.args(0); - auto value_2 = call.args(1); - ExtensionSet ext_set_; - ARROW_ASSIGN_OR_RAISE(auto expression_1, FromProto(value_1, ext_set_)); - ARROW_ASSIGN_OR_RAISE(auto expression_2, FromProto(value_2, ext_set_)); - return arrow::compute::call("greater", {expression_1, expression_2}); + return arrow::compute::call("greater", substrait_convert_arguments(call)); }; -SubstraitToArrow substrait_lte_to_arrow = [] (const substrait::Expression::ScalarFunction& call) -> Result { - auto value_1 = call.args(0); - auto value_2 = call.args(1); - ExtensionSet ext_set_; - ARROW_ASSIGN_OR_RAISE(auto expression_1, FromProto(value_1, ext_set_)); - ARROW_ASSIGN_OR_RAISE(auto expression_2, FromProto(value_2, ext_set_)); - return arrow::compute::call("less_equal", {expression_1, expression_2}); +SubstraitToArrow substrait_lte_to_arrow = [] (const substrait::Expression::ScalarFunction& call) -> Result { + return arrow::compute::call("less_equal", substrait_convert_arguments(call)); }; SubstraitToArrow substrait_not_equal_to_arrow = [] (const substrait::Expression::ScalarFunction& call) -> Result { - auto value_1 = call.args(0); - auto value_2 = call.args(1); - ExtensionSet ext_set_; - ARROW_ASSIGN_OR_RAISE(auto expression_1, FromProto(value_1, ext_set_)); - ARROW_ASSIGN_OR_RAISE(auto expression_2, FromProto(value_2, ext_set_)); - return arrow::compute::call("not_equal", {expression_1, expression_2}); + return arrow::compute::call("not_equal", substrait_convert_arguments(call)); }; -SubstraitToArrow substrait_equal_to_arrow = [] (const substrait::Expression::ScalarFunction& call) -> Result { - auto value_1 = call.args(0); - auto value_2 = call.args(1); - ExtensionSet ext_set_; - ARROW_ASSIGN_OR_RAISE(auto expression_1, FromProto(value_1, ext_set_)); - ARROW_ASSIGN_OR_RAISE(auto expression_2, FromProto(value_2, ext_set_)); - return arrow::compute::call("equal", {expression_1, expression_2}); +SubstraitToArrow substrait_equal_to_arrow = [] (const substrait::Expression::ScalarFunction& call) -> Result { + return arrow::compute::call("equal", substrait_convert_arguments(call)); }; SubstraitToArrow substrait_is_null_to_arrow = [] (const substrait::Expression::ScalarFunction& call) -> Result { - auto value_1 = call.args(0); - ExtensionSet ext_set_; - ARROW_ASSIGN_OR_RAISE(auto expression_1, FromProto(value_1, ext_set_)); - return arrow::compute::call("is_null", {expression_1}); + return arrow::compute::call("is_null", substrait_convert_arguments(call)); }; SubstraitToArrow substrait_is_not_null_to_arrow = [] (const substrait::Expression::ScalarFunction& call) -> Result { - auto value_1 = call.args(0); - ExtensionSet ext_set_; - ARROW_ASSIGN_OR_RAISE(auto expression_1, FromProto(value_1, ext_set_)); - return arrow::compute::call("is_valid", {expression_1}); + return arrow::compute::call("is_valid", substrait_convert_arguments(call)); }; SubstraitToArrow substrait_is_not_distinct_from_to_arrow = [] (const substrait::Expression::ScalarFunction& call) -> Result { - auto value_1 = call.args(0); - auto value_2 = call.args(1); - ExtensionSet ext_set_; - ARROW_ASSIGN_OR_RAISE(auto expression_1, FromProto(value_1, ext_set_)); - ARROW_ASSIGN_OR_RAISE(auto expression_2, FromProto(value_2, ext_set_)); - auto null_check_1 = arrow::compute::call("is_null", {expression_1}); - auto null_check_2 = arrow::compute::call("is_null", {expression_2}); + std::vector func_args = substrait_convert_arguments(call); + auto null_check_1 = arrow::compute::call("is_null", {func_args[0]}); + auto null_check_2 = arrow::compute::call("is_null", {func_args[1]}); if(null_check_1.IsNullLiteral() && null_check_1.IsNullLiteral()){ return arrow::compute::call("not_equal", {null_check_1, null_check_2}); } - return arrow::compute::call("not_equal", {expression_1, expression_2}); + return arrow::compute::call("not_equal", func_args); }; } // namespace engine