
feat(spark): bitwise functions #309

Merged: 1 commit on Oct 25, 2024
19 changes: 19 additions & 0 deletions spark/src/main/resources/spark.yml
@@ -41,3 +41,22 @@ scalar_functions:
- args:
- value: i64
return: DECIMAL<P,S>
- name: shift_right
description: >-
Bitwise (signed) shift right.
Params:
base – the base number to shift.
shift – number of bits to right shift.
impls:
Contributor:

It looks like in Spark the base can be either int or long, and the return type is set accordingly. I think we should add both options?

Quickly testing this, it works for select shiftright(col, 2) from (values (bigint(1234)) as table(col)) but not for select shiftright(col, 2) from (values (1234) as table(col)), so yep I think we need to list both versions here.
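The int/long split the reviewer describes matches plain JVM semantics: signed right shift is defined separately for `Int` and `Long`, keeps the operand's width, and preserves the sign bit. A minimal plain-Scala sketch (no Spark involved):

```scala
// 32-bit signed shift: Int in, Int out (matches the i32 -> i32 impl below)
val i: Int = 1234
val iOut: Int = i >> 2 // 308

// 64-bit signed shift: Long in, Long out (matches the i64 -> i64 impl below)
val l: Long = 1234L
val lOut: Long = l >> 2 // 308L

// The shift is arithmetic: the sign bit is preserved
val neg: Int = -8 >> 1 // -4
```

This is why the YAML ends up listing two impls, one taking and returning i32 and one taking and returning i64.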

Contributor:

Overall I think it's fine to add the function here initially, but it'd be good to also file the PR/Issue on the core functions since this seems general enough, I think :)

Contributor Author:

I agree, although that might take a bit longer ;)

- args:
- name: base
value: i64
- name: shift
value: i32
return: i64
- args:
- name: base
value: i32
- name: shift
value: i32
return: i32
2 changes: 2 additions & 0 deletions spark/src/main/scala/io/substrait/spark/ToSubstraitType.scala
@@ -29,6 +29,8 @@ import scala.collection.JavaConverters.asScalaBufferConverter
private class ToSparkType
extends TypeVisitor.TypeThrowsVisitor[DataType, RuntimeException]("Unknown expression type.") {

override def visit(expr: Type.I8): DataType = ByteType
override def visit(expr: Type.I16): DataType = ShortType
Contributor:

while you're at it, mind adding these also to ToSparkType?

Contributor Author:

This is in ToSparkType. The conversion in ToSubstraitType is already there further down this file. It converts both ways.

Contributor:

Ah right, all good then!
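As the author notes, both directions of the mapping live in this one file. A rough illustration of the round trip, using hypothetical type names rather than the real Substrait and Spark classes:

```scala
// Hypothetical stand-ins for Substrait integer types.
sealed trait SubstraitType
case object I8 extends SubstraitType
case object I16 extends SubstraitType
case object I32 extends SubstraitType
case object I64 extends SubstraitType

// Hypothetical stand-ins for Spark data types.
sealed trait SparkType
case object ByteT extends SparkType
case object ShortT extends SparkType
case object IntT extends SparkType
case object LongT extends SparkType

object TypeRoundTrip {
  // Substrait -> Spark, mirroring what ToSparkType's visit methods do
  def toSpark(t: SubstraitType): SparkType = t match {
    case I8  => ByteT
    case I16 => ShortT
    case I32 => IntT
    case I64 => LongT
  }

  // Spark -> Substrait, the reverse mapping further down the same file
  def toSubstrait(t: SparkType): SubstraitType = t match {
    case ByteT  => I8
    case ShortT => I16
    case IntT   => I32
    case LongT  => I64
  }
}
```

Keeping both directions adjacent makes it easy to check that every type converts losslessly both ways.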

override def visit(expr: Type.I32): DataType = IntegerType
override def visit(expr: Type.I64): DataType = LongType

@@ -60,6 +60,10 @@ class FunctionMappings
s[Concat]("concat"),
s[Coalesce]("coalesce"),
s[Year]("year"),
s[ShiftRight]("shift_right"),
s[BitwiseAnd]("bitwise_and"),
s[BitwiseOr]("bitwise_or"),
s[BitwiseXor]("bitwise_xor"),

// internal
s[MakeDecimal]("make_decimal"),
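Each `s[...]` entry pairs a Catalyst expression class with its Substrait function name, so the translator can look the name up by runtime class. A simplified, hypothetical sketch of that mechanism (the real helper in FunctionMappings may differ):

```scala
import scala.reflect.ClassTag

// Hypothetical stand-ins for Catalyst's bitwise expression classes.
class BitwiseAnd
class BitwiseOr
class BitwiseXor

// A signature: expression class paired with the Substrait function name.
final case class Sig(expressionClass: Class[_], substraitName: String)

object MappingSketch {
  def s[T](name: String)(implicit ct: ClassTag[T]): Sig =
    Sig(ct.runtimeClass, name)

  val sigs: Seq[Sig] = Seq(
    s[BitwiseAnd]("bitwise_and"),
    s[BitwiseOr]("bitwise_or"),
    s[BitwiseXor]("bitwise_xor")
  )

  // Lookup used at translation time: runtime class -> Substrait name.
  val byClass: Map[Class[_], String] =
    sigs.map(sig => sig.expressionClass -> sig.substraitName).toMap
}
```

Adding a function to the mapping table is then a one-line change, as in this PR.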
@@ -44,6 +44,15 @@ class ToSparkExpression(
Literal.FalseLiteral
}
}

override def visit(expr: SExpression.I8Literal): Expression = {
Literal(expr.value().asInstanceOf[Byte], ToSubstraitType.convert(expr.getType))
}

override def visit(expr: SExpression.I16Literal): Expression = {
Literal(expr.value().asInstanceOf[Short], ToSubstraitType.convert(expr.getType))
Contributor:

and here also the other direction for the conversion?

Contributor Author:

It's already covered in the other direction in ToSubstraitExpression on the following line:
case SubstraitLiteral(substraitLiteral) => Some(substraitLiteral)
The unapply method in this object invokes the conversion.
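The `case SubstraitLiteral(substraitLiteral) => …` pattern works because Scala routes it through the object's `unapply`, which can attempt the conversion as part of matching. A small, hypothetical sketch of that extractor style (not the real classes):

```scala
// Hypothetical stand-ins: a Spark-side literal and a Substrait-side literal.
final case class SparkLit(value: Any)
final case class SubLit(value: Any)

object SubstraitLiteralSketch {
  // unapply attempts the Spark -> Substrait conversion; the case matches
  // only when the conversion succeeds.
  def unapply(lit: SparkLit): Option[SubLit] = lit.value match {
    case i: Int  => Some(SubLit(i))
    case l: Long => Some(SubLit(l))
    case _       => None
  }
}

def translate(e: SparkLit): Option[SubLit] = e match {
  case SubstraitLiteralSketch(sub) => Some(sub) // conversion ran inside unapply
  case _                           => None
}
```

So the single `case SubstraitLiteral(...)` line quietly covers every literal type the extractor knows how to convert.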

}

override def visit(expr: SExpression.I32Literal): Expression = {
Literal(expr.value(), ToSubstraitType.convert(expr.getType))
}
@@ -146,11 +146,12 @@ abstract class ToSubstraitExpression extends HasOutputStack[Seq[Attribute]] {
"org.apache.spark.sql.catalyst.expressions.PromotePrecision") =>
translateUp(p.children.head)
case CaseWhen(branches, elseValue) => translateCaseWhen(branches, elseValue)
case In(value, list) => translateIn(value, list)
case InSet(value, set) => translateIn(value, set.toSeq.map(v => Literal(v)))
Contributor:

nit: mind moving this next to the case In(..)?

Contributor Author:

Sure, although InSet needs to come before ScalarFunction otherwise it matches the latter (since it's a UnaryExpression). I've moved In up the list so they are together.
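Scala tries match cases top to bottom, so a specific case placed below a broader extractor never fires; that is the ordering constraint the author describes. A toy illustration with hypothetical extractors (not the Catalyst ones):

```scala
sealed trait Expr
final case class InSetE(values: Set[Int]) extends Expr // stand-in for InSet
final case class Other(name: String) extends Expr

// A deliberately broad extractor: like the ScalarFunction pattern over
// UnaryExpression, it also matches InSetE.
object Broad {
  def unapply(e: Expr): Option[String] = Some("scalar-function")
}

object MatchOrder {
  def translate(e: Expr): String = e match {
    case InSetE(_)  => "in-set" // must precede Broad, or it is unreachable
    case Broad(tag) => tag
  }
}
```

Swapping the two cases would send every `InSetE` down the broad path, which is why `InSet` has to stay above the `ScalarFunction` case.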

case scalar @ ScalarFunction(children) =>
Util
.seqToOption(children.map(translateUp))
.flatMap(toScalarFunction.convert(scalar, _))
case In(value, list) => translateIn(value, list)
case p: PlanExpression[_] => translateSubQuery(p)
case other => default(other)
}
4 changes: 2 additions & 2 deletions spark/src/test/scala/io/substrait/spark/TPCDSPlan.scala
@@ -32,9 +32,9 @@ class TPCDSPlan extends TPCDSBase with SubstraitPlanTestBase {
}

// spotless:off
val successfulSQL: Set[String] = Set("q1", "q3", "q4", "q7",
val successfulSQL: Set[String] = Set("q1", "q3", "q4", "q7", "q8",
"q11", "q13", "q14a", "q14b", "q15", "q16", "q18", "q19",
"q21", "q22", "q23a", "q23b", "q24a", "q24b", "q25", "q26", "q28", "q29",
"q21", "q22", "q23a", "q23b", "q24a", "q24b", "q25", "q26", "q27", "q28", "q29",
"q30", "q31", "q32", "q33", "q37", "q38",
"q40", "q41", "q42", "q43", "q46", "q48",
"q50", "q52", "q54", "q55", "q56", "q58", "q59",