Skip to content

Commit

Permalink
feat: support for subtract and multiply functions
Browse files Browse the repository at this point in the history
  • Loading branch information
sanjibansg committed Jun 3, 2022
1 parent 9cd57d1 commit a56a556
Showing 1 changed file with 58 additions and 13 deletions.
71 changes: 58 additions & 13 deletions cpp/src/arrow/engine/substrait/extension_set.cc
Original file line number Diff line number Diff line change
Expand Up @@ -413,6 +413,13 @@ substrait::Expression::ScalarFunction arrow_convert_arguments(const arrow::compu
return std::move(substrait_call);
}

substrait::Expression::ScalarFunction arrow_convert_arithmetic_arguments(const arrow::compute::Expression::Call& call, substrait::Expression::ScalarFunction& substrait_call, ExtensionSet* ext_set_, std::string overflow_handling){
substrait::Expression::Enum options;
options.set_specified(overflow_handling);
substrait_call.add_args()->set_allocated_enum_(&options);
return arrow_convert_arguments(call, substrait_call, ext_set_);
}

SubstraitToArrow substrait_add_to_arrow = [] (const substrait::Expression::ScalarFunction& call) -> Result<arrow::compute::Expression> {
auto func_args = substrait_convert_arguments(call);
if(func_args[0].ToString() == "SILENT"){
Expand All @@ -424,30 +431,68 @@ SubstraitToArrow substrait_add_to_arrow = [] (const substrait::Expression::Scala
}
};

SubstraitToArrow substrait_subtract_to_arrow = [] (const substrait::Expression::ScalarFunction& call) -> Result<arrow::compute::Expression> {
auto func_args = substrait_convert_arguments(call);
if(func_args[0].ToString() == "SILENT"){
return arrow::compute::call("subtract", {func_args[1], func_args[2]}, compute::ArithmeticOptions());
} else if (func_args[0].ToString() == "SATURATE") {
return Status::Invalid("Arrow does not support a saturating add");
} else {
return arrow::compute::call("subtract_checked", {func_args[1], func_args[2]}, compute::ArithmeticOptions(true));
}
};

SubstraitToArrow substrait_multiply_to_arrow = [] (const substrait::Expression::ScalarFunction& call) -> Result<arrow::compute::Expression> {
auto func_args = substrait_convert_arguments(call);
if(func_args[0].ToString() == "SILENT"){
return arrow::compute::call("multiply", {func_args[1], func_args[2]}, compute::ArithmeticOptions());
} else if (func_args[0].ToString() == "SATURATE") {
return Status::Invalid("Arrow does not support a saturating add");
} else {
return arrow::compute::call("mutiply_checked", {func_args[1], func_args[2]}, compute::ArithmeticOptions(true));
}
};

ArrowToSubstrait arrow_add_to_substrait = [] (const arrow::compute::Expression::Call& call, arrow::engine::ExtensionSet* ext_set_) -> Result<substrait::Expression::ScalarFunction> {
substrait::Expression::ScalarFunction substrait_call;

ARROW_ASSIGN_OR_RAISE(auto function_reference, ext_set_->EncodeFunction("add"));
substrait_call.set_function_reference(function_reference);

substrait::Expression::Enum options;
std::string overflow_handling = "ERROR";
options.set_specified(overflow_handling);
substrait_call.add_args()->set_allocated_enum_(&options);
return arrow_convert_arguments(call, substrait_call, ext_set_);
return arrow_convert_arithmetic_arguments(call, substrait_call, ext_set_, "ERROR");
};

ArrowToSubstrait arrow_unchecked_add_to_substrait = [] (const arrow::compute::Expression::Call& call, ExtensionSet* ext_set_) -> Result<substrait::Expression::ScalarFunction> {
substrait::Expression::ScalarFunction substrait_call;

ARROW_ASSIGN_OR_RAISE(auto function_reference, ext_set_->EncodeFunction("add"));
substrait_call.set_function_reference(function_reference);
return arrow_convert_arithmetic_arguments(call, substrait_call, ext_set_, "SILENT");
};

substrait::Expression::Enum options;
std::string overflow_handling = "SILENT";
options.set_specified(overflow_handling);
substrait_call.add_args()->set_allocated_enum_(&options);
return arrow_convert_arguments(call, substrait_call, ext_set_);
ArrowToSubstrait arrow_subtract_to_substrait = [] (const arrow::compute::Expression::Call& call, arrow::engine::ExtensionSet* ext_set_) -> Result<substrait::Expression::ScalarFunction> {
substrait::Expression::ScalarFunction substrait_call;
ARROW_ASSIGN_OR_RAISE(auto function_reference, ext_set_->EncodeFunction("subtract"));
substrait_call.set_function_reference(function_reference);
return arrow_convert_arithmetic_arguments(call, substrait_call, ext_set_, "ERROR");
};

ArrowToSubstrait arrow_unchecked_subtract_to_substrait = [] (const arrow::compute::Expression::Call& call, ExtensionSet* ext_set_) -> Result<substrait::Expression::ScalarFunction> {
substrait::Expression::ScalarFunction substrait_call;
ARROW_ASSIGN_OR_RAISE(auto function_reference, ext_set_->EncodeFunction("subtract"));
substrait_call.set_function_reference(function_reference);
return arrow_convert_arithmetic_arguments(call, substrait_call, ext_set_, "SILENT") ;
};

ArrowToSubstrait arrow_multiply_to_substrait = [] (const arrow::compute::Expression::Call& call, arrow::engine::ExtensionSet* ext_set_) -> Result<substrait::Expression::ScalarFunction> {
substrait::Expression::ScalarFunction substrait_call;
ARROW_ASSIGN_OR_RAISE(auto function_reference, ext_set_->EncodeFunction("multiply"));
substrait_call.set_function_reference(function_reference);
return arrow_convert_arithmetic_arguments(call, substrait_call, ext_set_, "ERROR");
};

ArrowToSubstrait arrow_unchecked_multiply_to_substrait = [] (const arrow::compute::Expression::Call& call, ExtensionSet* ext_set_) -> Result<substrait::Expression::ScalarFunction> {
substrait::Expression::ScalarFunction substrait_call;
ARROW_ASSIGN_OR_RAISE(auto function_reference, ext_set_->EncodeFunction("multiply"));
substrait_call.set_function_reference(function_reference);
return arrow_convert_arithmetic_arguments(call, substrait_call, ext_set_, "SILENT");
};


Expand Down

0 comments on commit a56a556

Please sign in to comment.