Skip to content

Commit

Permalink
feat(spark): bitwise functions
Browse files Browse the repository at this point in the history
Adds support in the Spark module for 8-bit and 16-bit integer types and for some bitwise functions.
The Catalyst optimizer generates expressions that use these types and functions for certain query types.

Note that `shift_right` (and other bit-shifting functions) might be worth considering for the core Substrait function catalog,
but it has been added here (temporarily?) as a Spark extension, pending a longer-term discussion/decision on the wider utility of these functions.

Signed-off-by: Andrew Coleman <[email protected]>
  • Loading branch information
andrew-coleman committed Oct 23, 2024
1 parent ac0b7d1 commit 3bf4782
Show file tree
Hide file tree
Showing 6 changed files with 31 additions and 3 deletions.
13 changes: 13 additions & 0 deletions spark/src/main/resources/spark.yml
Original file line number Diff line number Diff line change
Expand Up @@ -41,3 +41,16 @@ scalar_functions:
- args:
- value: i64
return: DECIMAL<P,S>
# Spark-specific extension function: signed right shift of an i64 `base` by an
# i32 `shift` bit count, returning i64. Declared here rather than in the core
# Substrait function catalog.
- name: shift_right
description: >-
Bitwise (signed) shift right.
Params:
base – the base number to shift.
shift – number of bits to right shift.
impls:
- args:
- name: base
value: i64
- name: shift
value: i32
return: i64
2 changes: 2 additions & 0 deletions spark/src/main/scala/io/substrait/spark/ToSubstraitType.scala
Original file line number Diff line number Diff line change
Expand Up @@ -29,6 +29,8 @@ import scala.collection.JavaConverters.asScalaBufferConverter
private class ToSparkType
extends TypeVisitor.TypeThrowsVisitor[DataType, RuntimeException]("Unknown expression type.") {

// Signed integer types: each Substrait integer type is mapped to the Spark
// DataType named on the right-hand side.
// I8 -> ByteType
override def visit(expr: Type.I8): DataType = ByteType
// I16 -> ShortType
override def visit(expr: Type.I16): DataType = ShortType
// I32 -> IntegerType
override def visit(expr: Type.I32): DataType = IntegerType
// I64 -> LongType
override def visit(expr: Type.I64): DataType = LongType

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -48,7 +48,6 @@ class FunctionMappings {
s[GreaterThanOrEqual]("gte"),
s[EqualTo]("equal"),
s[EqualNullSafe]("is_not_distinct_from"),
// s[BitwiseXor]("xor"),
s[IsNull]("is_null"),
s[IsNotNull]("is_not_null"),
s[EndsWith]("ends_with"),
Expand All @@ -57,6 +56,10 @@ class FunctionMappings {
s[StartsWith]("starts_with"),
s[Substring]("substring"),
s[Year]("year"),
s[ShiftRight]("shift_right"),
s[BitwiseAnd]("bitwise_and"),
s[BitwiseOr]("bitwise_or"),
s[BitwiseXor]("bitwise_xor"),

// internal
s[MakeDecimal]("make_decimal"),
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -44,6 +44,15 @@ class ToSparkExpression(
Literal.FalseLiteral
}
}

/**
 * Converts a Substrait 8-bit integer literal into a Spark [[Literal]].
 *
 * The literal's value is narrowed to a `Byte` before being wrapped. The
 * narrowing is written as an explicit `.toByte` conversion rather than an
 * `asInstanceOf` cast, so that it reads as the numeric conversion it is.
 *
 * @param expr the Substrait I8 literal to convert
 * @return a Spark literal holding the value as a `Byte`, typed with the result
 *         of `ToSubstraitType.convert(expr.getType)`
 */
override def visit(expr: SExpression.I8Literal): Expression = {
  // NOTE(review): assumes expr.value() is a numeric (Int) value, as the original cast implies.
  Literal(expr.value().toByte, ToSubstraitType.convert(expr.getType))
}

/**
 * Converts a Substrait 16-bit integer literal into a Spark [[Literal]].
 *
 * The literal's value is narrowed to a `Short` before being wrapped. The
 * narrowing is written as an explicit `.toShort` conversion rather than an
 * `asInstanceOf` cast, so that it reads as the numeric conversion it is.
 *
 * @param expr the Substrait I16 literal to convert
 * @return a Spark literal holding the value as a `Short`, typed with the result
 *         of `ToSubstraitType.convert(expr.getType)`
 */
override def visit(expr: SExpression.I16Literal): Expression = {
  // NOTE(review): assumes expr.value() is a numeric (Int) value, as the original cast implies.
  Literal(expr.value().toShort, ToSubstraitType.convert(expr.getType))
}

/**
 * Converts a Substrait 32-bit integer literal into a Spark [[Literal]].
 *
 * @param expr the Substrait I32 literal to convert
 * @return a Spark literal holding `expr.value()`, typed with the result of
 *         `ToSubstraitType.convert(expr.getType)`
 */
override def visit(expr: SExpression.I32Literal): Expression = {
  val literalValue = expr.value()
  val sparkType = ToSubstraitType.convert(expr.getType)
  Literal(literalValue, sparkType)
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -144,6 +144,7 @@ abstract class ToSubstraitExpression extends HasOutputStack[Seq[Attribute]] {
"org.apache.spark.sql.catalyst.expressions.PromotePrecision") =>
translateUp(p.children.head)
case CaseWhen(branches, elseValue) => translateCaseWhen(branches, elseValue)
case InSet(value, set) => translateIn(value, set.toSeq.map(v => Literal(v)))
case scalar @ ScalarFunction(children) =>
Util
.seqToOption(children.map(translateUp))
Expand Down
4 changes: 2 additions & 2 deletions spark/src/test/scala/io/substrait/spark/TPCDSPlan.scala
Original file line number Diff line number Diff line change
Expand Up @@ -32,9 +32,9 @@ class TPCDSPlan extends TPCDSBase with SubstraitPlanTestBase {
}

// spotless:off
val successfulSQL: Set[String] = Set("q1", "q3", "q4", "q7",
val successfulSQL: Set[String] = Set("q1", "q3", "q4", "q7", "q8",
"q11", "q13", "q14b", "q15", "q16", "q18", "q19",
"q21", "q22", "q23a", "q23b", "q25", "q26", "q28", "q29",
"q21", "q22", "q23a", "q23b", "q25", "q26", "q27", "q28", "q29",
"q30", "q31", "q32", "q33", "q37", "q38",
"q41", "q42", "q43", "q46", "q48",
"q50", "q52", "q54", "q55", "q56", "q58", "q59",
Expand Down

0 comments on commit 3bf4782

Please sign in to comment.