Skip to content

Commit

Permalink
feat: helpers for conversion functions
Browse files Browse the repository at this point in the history
  • Loading branch information
sanjibansg committed Jun 3, 2022
1 parent 5fd670a commit 9cd57d1
Show file tree
Hide file tree
Showing 2 changed files with 65 additions and 169 deletions.
9 changes: 9 additions & 0 deletions cpp/src/arrow/engine/substrait/expression_internal.cc
Original file line number Diff line number Diff line change
Expand Up @@ -164,6 +164,15 @@ Result<compute::Expression> FromProto(const substrait::Expression& expr,
return arrow_function(scalar_fn);
}

case substrait::Expression::kEnum: {
auto enum_expr = expr.enum_();
if(enum_expr.has_specified()){
return compute::literal(std::move(enum_expr.specified()));
} else {
return Status::Invalid("Substrait Enum value not specified");
}
}

default:
break;
}
Expand Down
225 changes: 56 additions & 169 deletions cpp/src/arrow/engine/substrait/extension_set.cc
Original file line number Diff line number Diff line change
Expand Up @@ -360,7 +360,7 @@ Status FunctionMapping::AddArrowToSubstrait(std::string arrow_function_name, Arr
arrow_to_substrait[arrow_function_name] = conversion_func;
return Status::OK();
} else{
return Status::Invalid("Arrow function already exist in the conversion map");
return Status::AlreadyExists("Arrow function already exist in the conversion map");
}
}

Expand All @@ -369,50 +369,58 @@ Status FunctionMapping::AddSubstraitToArrow(std::string substrait_function_name,
substrait_to_arrow[substrait_function_name] = conversion_func;
return Status::OK();
} else{
return Status::Invalid("Substrait function already exist in the conversion map");
return Status::AlreadyExists("Substrait function already exist in the conversion map");
}
}

Result<SubstraitToArrow> FunctionMapping::GetArrowFromSubstrait(std::string name) const {
if (FunctionMapping::substrait_to_arrow.find(name)!=FunctionMapping::substrait_to_arrow.end()){
return FunctionMapping::substrait_to_arrow.at(name);
} else {
return Status::Invalid("Substrait function doesn't exist in the mapping registry");
return Status::KeyError("Substrait function doesn't exist in the mapping registry");
}
}

Result<ArrowToSubstrait> FunctionMapping::GetSubstraitFromArrow(std::string name) const {
if (FunctionMapping::arrow_to_substrait.find(name)!=FunctionMapping::arrow_to_substrait.end()){
return FunctionMapping::arrow_to_substrait.at(name);
} else {
return Status::Invalid("Arrow function doesn't exist in the mapping registry");
return Status::KeyError("Arrow function doesn't exist in the mapping registry");
}
}


SubstraitToArrow substrait_add_to_arrow = [] (const substrait::Expression::ScalarFunction& call) -> Result<arrow::compute::Expression> {
auto value_1 = call.args(1);
auto value_2 = call.args(2);
std::vector<arrow::compute::Expression> substrait_convert_arguments(const substrait::Expression::ScalarFunction& call){
substrait::Expression value;
ExtensionSet ext_set_;
ARROW_ASSIGN_OR_RAISE(auto expression_1, FromProto(value_1, ext_set_));
ARROW_ASSIGN_OR_RAISE(auto expression_2, FromProto(value_2, ext_set_));
auto options = call.args(0);
if (options.has_enum_()) {
auto overflow_handling = options.enum_();
if(overflow_handling.has_specified()){
std::string overflow_type = overflow_handling.specified();
if(overflow_type == "SILENT"){
return arrow::compute::call("add", {expression_1,expression_2}, compute::ArithmeticOptions());
} else if (overflow_type == "SATURATE") {
return Status::Invalid("Arrow does not support a saturating add");
} else {
return arrow::compute::call("add_checked", {expression_1,expression_2}, compute::ArithmeticOptions(true));
}
} else {
return arrow::compute::call("add", {expression_1,expression_2}, compute::ArithmeticOptions());
arrow::compute::Expression expression;
std::vector<compute::Expression> func_args;
for(int i=0; i<call.args_size(); ++i){
value = call.args(i);
expression = FromProto(value, ext_set_).ValueOrDie();
func_args.push_back(expression);
}
return func_args;
}

substrait::Expression::ScalarFunction arrow_convert_arguments(const arrow::compute::Expression::Call& call, substrait::Expression::ScalarFunction& substrait_call, ExtensionSet* ext_set_){
arrow::compute::Expression expression;
std::unique_ptr<substrait::Expression> value;
for(size_t i = 0; i<call.arguments.size(); ++i){
expression = call.arguments[i];
value = ToProto(expression, ext_set_).ValueOrDie();
substrait_call.add_args()->CopyFrom(*value);
}
return std::move(substrait_call);
}

SubstraitToArrow substrait_add_to_arrow = [] (const substrait::Expression::ScalarFunction& call) -> Result<arrow::compute::Expression> {
auto func_args = substrait_convert_arguments(call);
if(func_args[0].ToString() == "SILENT"){
return arrow::compute::call("add", {func_args[1], func_args[2]}, compute::ArithmeticOptions());
} else if (func_args[0].ToString() == "SATURATE") {
return Status::Invalid("Arrow does not support a saturating add");
} else {
return Status::Invalid("Substrait Function Options should be an enum");
return arrow::compute::call("add_checked", {func_args[1], func_args[2]}, compute::ArithmeticOptions(true));
}
};

Expand All @@ -426,16 +434,7 @@ ArrowToSubstrait arrow_add_to_substrait = [] (const arrow::compute::Expression::
std::string overflow_handling = "ERROR";
options.set_specified(overflow_handling);
substrait_call.add_args()->set_allocated_enum_(&options);

auto expression_1 = call.arguments[0];
auto expression_2 = call.arguments[1];

ARROW_ASSIGN_OR_RAISE(auto value_1, ToProto(expression_1, ext_set_));
ARROW_ASSIGN_OR_RAISE(auto value_2, ToProto(expression_2, ext_set_));

substrait_call.add_args()->CopyFrom(*value_1);
substrait_call.add_args()->CopyFrom(*value_2);
return std::move(substrait_call);
return arrow_convert_arguments(call, substrait_call, ext_set_);
};

ArrowToSubstrait arrow_unchecked_add_to_substrait = [] (const arrow::compute::Expression::Call& call, ExtensionSet* ext_set_) -> Result<substrait::Expression::ScalarFunction> {
Expand All @@ -448,205 +447,93 @@ ArrowToSubstrait arrow_unchecked_add_to_substrait = [] (const arrow::compute::Ex
std::string overflow_handling = "SILENT";
options.set_specified(overflow_handling);
substrait_call.add_args()->set_allocated_enum_(&options);

auto expression_1 = call.arguments[0];
auto expression_2 = call.arguments[1];

ARROW_ASSIGN_OR_RAISE(auto value_1, ToProto(expression_1, ext_set_));
ARROW_ASSIGN_OR_RAISE(auto value_2, ToProto(expression_2, ext_set_));

substrait_call.add_args()->CopyFrom(*value_1);
substrait_call.add_args()->CopyFrom(*value_2);
return std::move(substrait_call);
return arrow_convert_arguments(call, substrait_call, ext_set_);
};


// Boolean Functions mapping
SubstraitToArrow substrait_not_to_arrow = [] (const substrait::Expression::ScalarFunction& call) -> Result<arrow::compute::Expression> {
auto value_1 = call.args(1);
ExtensionSet ext_set_;
ARROW_ASSIGN_OR_RAISE(auto expression_1, FromProto(value_1, ext_set_));
return arrow::compute::call("invert", {expression_1});
return arrow::compute::call("invert", substrait_convert_arguments(call));
};

SubstraitToArrow substrait_or_to_arrow = [] (const substrait::Expression::ScalarFunction& call) -> Result<arrow::compute::Expression> {
int num_args = call.args_size(); // OR function has variadic arguments
substrait::Expression value;
ExtensionSet ext_set_;
arrow::compute::Expression expression;
std::vector<arrow::compute::Expression> func_args;
for(int i=0; i<num_args; ++i){
value = call.args(i);
ARROW_ASSIGN_OR_RAISE(expression, FromProto(value, ext_set_));
func_args.push_back(expression);
}
return arrow::compute::call("or_kleene", func_args);
return arrow::compute::call("or_kleene", substrait_convert_arguments(call));
};

SubstraitToArrow substrait_and_to_arrow = [] (const substrait::Expression::ScalarFunction& call) -> Result<arrow::compute::Expression> {
int num_args = call.args_size(); // AND function has variadic arguments
substrait::Expression value;
ExtensionSet ext_set_;
arrow::compute::Expression expression;
std::vector<arrow::compute::Expression> func_args;
for(int i=0; i<num_args; ++i){
value = call.args(i);
ARROW_ASSIGN_OR_RAISE(expression, FromProto(value, ext_set_));
func_args.push_back(expression);
}
return arrow::compute::call("and_kleene", func_args);
return arrow::compute::call("and_kleene", substrait_convert_arguments(call));
};

SubstraitToArrow substrait_xor_to_arrow = [] (const substrait::Expression::ScalarFunction& call) -> Result<arrow::compute::Expression> {
auto value_1 = call.args(0);
auto value_2 = call.args(1);
ExtensionSet ext_set_;
ARROW_ASSIGN_OR_RAISE(auto expression_1, FromProto(value_1, ext_set_));
ARROW_ASSIGN_OR_RAISE(auto expression_2, FromProto(value_2, ext_set_));
return arrow::compute::call("xor", {expression_1, expression_2});
return arrow::compute::call("xor", substrait_convert_arguments(call));
};

ArrowToSubstrait arrow_invert_to_substrait = [] (const arrow::compute::Expression::Call& call, ExtensionSet* ext_set_) -> Result<substrait::Expression::ScalarFunction> {
substrait::Expression::ScalarFunction substrait_call;

ARROW_ASSIGN_OR_RAISE(auto function_reference, ext_set_->EncodeFunction("not"));
substrait_call.set_function_reference(function_reference);

auto expression_1 = call.arguments[0];
auto expression_2 = call.arguments[1];

ARROW_ASSIGN_OR_RAISE(auto value_1, ToProto(expression_1, ext_set_));
ARROW_ASSIGN_OR_RAISE(auto value_2, ToProto(expression_2, ext_set_));

substrait_call.add_args()->CopyFrom(*value_1);
substrait_call.add_args()->CopyFrom(*value_2);
return std::move(substrait_call);

return arrow_convert_arguments(call, substrait_call, ext_set_);
};

ArrowToSubstrait arrow_or_kleene_to_substrait = [] (const arrow::compute::Expression::Call& call, ExtensionSet* ext_set_) -> Result<substrait::Expression::ScalarFunction> {
substrait::Expression::ScalarFunction substrait_call;

ARROW_ASSIGN_OR_RAISE(auto function_reference, ext_set_->EncodeFunction("or"));
substrait_call.set_function_reference(function_reference);

arrow::compute::Expression expression;
std::unique_ptr<substrait::Expression> value;
for(size_t i = 0; i<call.arguments.size(); ++i){
expression = call.arguments[i];
ARROW_ASSIGN_OR_RAISE(value, ToProto(expression, ext_set_));
substrait_call.add_args()->CopyFrom(*value);
}
return std::move(substrait_call);
return arrow_convert_arguments(call, substrait_call, ext_set_);
};


ArrowToSubstrait arrow_and_kleene_to_substrait = [] (const arrow::compute::Expression::Call& call, ExtensionSet* ext_set_) -> Result<substrait::Expression::ScalarFunction> {
substrait::Expression::ScalarFunction substrait_call;

ARROW_ASSIGN_OR_RAISE(auto function_reference, ext_set_->EncodeFunction("and"));
substrait_call.set_function_reference(function_reference);

arrow::compute::Expression expression;
std::unique_ptr<substrait::Expression> value;
for(size_t i = 0; i<call.arguments.size(); ++i){
expression = call.arguments[i];
ARROW_ASSIGN_OR_RAISE(value, ToProto(expression, ext_set_));
substrait_call.add_args()->CopyFrom(*value);
}

return std::move(substrait_call);
return arrow_convert_arguments(call, substrait_call, ext_set_);
};

ArrowToSubstrait arrow_xor_to_substrait = [] (const arrow::compute::Expression::Call& call, ExtensionSet* ext_set_) -> Result<substrait::Expression::ScalarFunction> {
substrait::Expression::ScalarFunction substrait_call;

ARROW_ASSIGN_OR_RAISE(auto function_reference, ext_set_->EncodeFunction("xor"));
substrait_call.set_function_reference(function_reference);

auto expression_1 = call.arguments[0];
auto expression_2 = call.arguments[1];

ARROW_ASSIGN_OR_RAISE(auto value_1, ToProto(expression_1, ext_set_));
ARROW_ASSIGN_OR_RAISE(auto value_2, ToProto(expression_2, ext_set_));

substrait_call.add_args()->CopyFrom(*value_1);
substrait_call.add_args()->CopyFrom(*value_2);
return std::move(substrait_call);
return arrow_convert_arguments(call, substrait_call, ext_set_);
};

// Comparison Functions mapping
SubstraitToArrow substrait_lt_to_arrow = [] (const substrait::Expression::ScalarFunction& call) -> Result<arrow::compute::Expression> {
auto value_1 = call.args(0);
auto value_2 = call.args(1);
ExtensionSet ext_set_;
ARROW_ASSIGN_OR_RAISE(auto expression_1, FromProto(value_1, ext_set_));
ARROW_ASSIGN_OR_RAISE(auto expression_2, FromProto(value_2, ext_set_));
return arrow::compute::call("less", {expression_1, expression_2});
return arrow::compute::call("less", substrait_convert_arguments(call));
};

SubstraitToArrow substrait_gt_to_arrow = [] (const substrait::Expression::ScalarFunction& call) -> Result<arrow::compute::Expression> {
auto value_1 = call.args(0);
auto value_2 = call.args(1);
ExtensionSet ext_set_;
ARROW_ASSIGN_OR_RAISE(auto expression_1, FromProto(value_1, ext_set_));
ARROW_ASSIGN_OR_RAISE(auto expression_2, FromProto(value_2, ext_set_));
return arrow::compute::call("greater", {expression_1, expression_2});
return arrow::compute::call("greater", substrait_convert_arguments(call));
};

SubstraitToArrow substrait_lte_to_arrow = [] (const substrait::Expression::ScalarFunction& call) -> Result<arrow::compute::Expression> {
auto value_1 = call.args(0);
auto value_2 = call.args(1);
ExtensionSet ext_set_;
ARROW_ASSIGN_OR_RAISE(auto expression_1, FromProto(value_1, ext_set_));
ARROW_ASSIGN_OR_RAISE(auto expression_2, FromProto(value_2, ext_set_));
return arrow::compute::call("less_equal", {expression_1, expression_2});
SubstraitToArrow substrait_lte_to_arrow = [] (const substrait::Expression::ScalarFunction& call) -> Result<arrow::compute::Expression> {
return arrow::compute::call("less_equal", substrait_convert_arguments(call));
};

SubstraitToArrow substrait_not_equal_to_arrow = [] (const substrait::Expression::ScalarFunction& call) -> Result<arrow::compute::Expression> {
auto value_1 = call.args(0);
auto value_2 = call.args(1);
ExtensionSet ext_set_;
ARROW_ASSIGN_OR_RAISE(auto expression_1, FromProto(value_1, ext_set_));
ARROW_ASSIGN_OR_RAISE(auto expression_2, FromProto(value_2, ext_set_));
return arrow::compute::call("not_equal", {expression_1, expression_2});
return arrow::compute::call("not_equal", substrait_convert_arguments(call));
};

SubstraitToArrow substrait_equal_to_arrow = [] (const substrait::Expression::ScalarFunction& call) -> Result<arrow::compute::Expression> {
auto value_1 = call.args(0);
auto value_2 = call.args(1);
ExtensionSet ext_set_;
ARROW_ASSIGN_OR_RAISE(auto expression_1, FromProto(value_1, ext_set_));
ARROW_ASSIGN_OR_RAISE(auto expression_2, FromProto(value_2, ext_set_));
return arrow::compute::call("equal", {expression_1, expression_2});
SubstraitToArrow substrait_equal_to_arrow = [] (const substrait::Expression::ScalarFunction& call) -> Result<arrow::compute::Expression> {
return arrow::compute::call("equal", substrait_convert_arguments(call));
};

SubstraitToArrow substrait_is_null_to_arrow = [] (const substrait::Expression::ScalarFunction& call) -> Result<arrow::compute::Expression> {
auto value_1 = call.args(0);
ExtensionSet ext_set_;
ARROW_ASSIGN_OR_RAISE(auto expression_1, FromProto(value_1, ext_set_));
return arrow::compute::call("is_null", {expression_1});
return arrow::compute::call("is_null", substrait_convert_arguments(call));
};

SubstraitToArrow substrait_is_not_null_to_arrow = [] (const substrait::Expression::ScalarFunction& call) -> Result<arrow::compute::Expression> {
auto value_1 = call.args(0);
ExtensionSet ext_set_;
ARROW_ASSIGN_OR_RAISE(auto expression_1, FromProto(value_1, ext_set_));
return arrow::compute::call("is_valid", {expression_1});
return arrow::compute::call("is_valid", substrait_convert_arguments(call));
};

SubstraitToArrow substrait_is_not_distinct_from_to_arrow = [] (const substrait::Expression::ScalarFunction& call) -> Result<arrow::compute::Expression> {
auto value_1 = call.args(0);
auto value_2 = call.args(1);
ExtensionSet ext_set_;
ARROW_ASSIGN_OR_RAISE(auto expression_1, FromProto(value_1, ext_set_));
ARROW_ASSIGN_OR_RAISE(auto expression_2, FromProto(value_2, ext_set_));
auto null_check_1 = arrow::compute::call("is_null", {expression_1});
auto null_check_2 = arrow::compute::call("is_null", {expression_2});
std::vector<compute::Expression> func_args = substrait_convert_arguments(call);
auto null_check_1 = arrow::compute::call("is_null", {func_args[0]});
auto null_check_2 = arrow::compute::call("is_null", {func_args[1]});
if(null_check_1.IsNullLiteral() && null_check_1.IsNullLiteral()){
return arrow::compute::call("not_equal", {null_check_1, null_check_2});
}
return arrow::compute::call("not_equal", {expression_1, expression_2});
return arrow::compute::call("not_equal", func_args);
};

} // namespace engine
Expand Down

0 comments on commit 9cd57d1

Please sign in to comment.