diff --git a/cpp/src/arrow/engine/substrait/extension_set.cc b/cpp/src/arrow/engine/substrait/extension_set.cc index da3b38e5c7893..4abbc793b9080 100644 --- a/cpp/src/arrow/engine/substrait/extension_set.cc +++ b/cpp/src/arrow/engine/substrait/extension_set.cc @@ -413,6 +413,13 @@ substrait::Expression::ScalarFunction arrow_convert_arguments(const arrow::compu return std::move(substrait_call); } +substrait::Expression::ScalarFunction arrow_convert_arithmetic_arguments(const arrow::compute::Expression::Call& call, substrait::Expression::ScalarFunction& substrait_call, ExtensionSet* ext_set_, std::string overflow_handling){ + substrait::Expression::Enum options; + options.set_specified(overflow_handling); + substrait_call.add_args()->set_allocated_enum_(&options); + return arrow_convert_arguments(call, substrait_call, ext_set_); +} + SubstraitToArrow substrait_add_to_arrow = [] (const substrait::Expression::ScalarFunction& call) -> Result { auto func_args = substrait_convert_arguments(call); if(func_args[0].ToString() == "SILENT"){ @@ -424,30 +431,68 @@ SubstraitToArrow substrait_add_to_arrow = [] (const substrait::Expression::Scala } }; +SubstraitToArrow substrait_subtract_to_arrow = [] (const substrait::Expression::ScalarFunction& call) -> Result { + auto func_args = substrait_convert_arguments(call); + if(func_args[0].ToString() == "SILENT"){ + return arrow::compute::call("subtract", {func_args[1], func_args[2]}, compute::ArithmeticOptions()); + } else if (func_args[0].ToString() == "SATURATE") { + return Status::Invalid("Arrow does not support a saturating add"); + } else { + return arrow::compute::call("subtract_checked", {func_args[1], func_args[2]}, compute::ArithmeticOptions(true)); + } +}; + +SubstraitToArrow substrait_multiply_to_arrow = [] (const substrait::Expression::ScalarFunction& call) -> Result { + auto func_args = substrait_convert_arguments(call); + if(func_args[0].ToString() == "SILENT"){ + return arrow::compute::call("multiply", {func_args[1], func_args[2]}, compute::ArithmeticOptions()); + } else if (func_args[0].ToString() == "SATURATE") { + return Status::Invalid("Arrow does not support a saturating add"); + } else { + return arrow::compute::call("mutiply_checked", {func_args[1], func_args[2]}, compute::ArithmeticOptions(true)); + } +}; + ArrowToSubstrait arrow_add_to_substrait = [] (const arrow::compute::Expression::Call& call, arrow::engine::ExtensionSet* ext_set_) -> Result { substrait::Expression::ScalarFunction substrait_call; - ARROW_ASSIGN_OR_RAISE(auto function_reference, ext_set_->EncodeFunction("add")); substrait_call.set_function_reference(function_reference); - - substrait::Expression::Enum options; - std::string overflow_handling = "ERROR"; - options.set_specified(overflow_handling); - substrait_call.add_args()->set_allocated_enum_(&options); - return arrow_convert_arguments(call, substrait_call, ext_set_); + return arrow_convert_arithmetic_arguments(call, substrait_call, ext_set_, "ERROR"); }; ArrowToSubstrait arrow_unchecked_add_to_substrait = [] (const arrow::compute::Expression::Call& call, ExtensionSet* ext_set_) -> Result { substrait::Expression::ScalarFunction substrait_call; - ARROW_ASSIGN_OR_RAISE(auto function_reference, ext_set_->EncodeFunction("add")); substrait_call.set_function_reference(function_reference); + return arrow_convert_arithmetic_arguments(call, substrait_call, ext_set_, "SILENT"); +}; - substrait::Expression::Enum options; - std::string overflow_handling = "SILENT"; - options.set_specified(overflow_handling); - substrait_call.add_args()->set_allocated_enum_(&options); - return arrow_convert_arguments(call, substrait_call, ext_set_); +ArrowToSubstrait arrow_subtract_to_substrait = [] (const arrow::compute::Expression::Call& call, arrow::engine::ExtensionSet* ext_set_) -> Result { + substrait::Expression::ScalarFunction substrait_call; + ARROW_ASSIGN_OR_RAISE(auto function_reference, ext_set_->EncodeFunction("subtract")); + substrait_call.set_function_reference(function_reference); + return arrow_convert_arithmetic_arguments(call, substrait_call, ext_set_, "ERROR"); +}; + +ArrowToSubstrait arrow_unchecked_subtract_to_substrait = [] (const arrow::compute::Expression::Call& call, ExtensionSet* ext_set_) -> Result { + substrait::Expression::ScalarFunction substrait_call; + ARROW_ASSIGN_OR_RAISE(auto function_reference, ext_set_->EncodeFunction("subtract")); + substrait_call.set_function_reference(function_reference); + return arrow_convert_arithmetic_arguments(call, substrait_call, ext_set_, "SILENT") ; +}; + +ArrowToSubstrait arrow_multiply_to_substrait = [] (const arrow::compute::Expression::Call& call, arrow::engine::ExtensionSet* ext_set_) -> Result { + substrait::Expression::ScalarFunction substrait_call; + ARROW_ASSIGN_OR_RAISE(auto function_reference, ext_set_->EncodeFunction("multiply")); + substrait_call.set_function_reference(function_reference); + return arrow_convert_arithmetic_arguments(call, substrait_call, ext_set_, "ERROR"); +}; + +ArrowToSubstrait arrow_unchecked_multiply_to_substrait = [] (const arrow::compute::Expression::Call& call, ExtensionSet* ext_set_) -> Result { + substrait::Expression::ScalarFunction substrait_call; + ARROW_ASSIGN_OR_RAISE(auto function_reference, ext_set_->EncodeFunction("multiply")); + substrait_call.set_function_reference(function_reference); + return arrow_convert_arithmetic_arguments(call, substrait_call, ext_set_, "SILENT"); };