diff --git a/spark/src/main/resources/spark.yml b/spark/src/main/resources/spark.yml index 4f610ded7..9b509a577 100644 --- a/spark/src/main/resources/spark.yml +++ b/spark/src/main/resources/spark.yml @@ -41,3 +41,22 @@ scalar_functions: - args: - value: i64 return: DECIMAL + - name: shift_right + description: >- + Bitwise (signed) shift right. + Params: + base – the base number to shift. + shift – number of bits to right shift. + impls: + - args: + - name: base + value: i64 + - name: shift + value: i32 + return: i64 + - args: + - name: base + value: i32 + - name: shift + value: i32 + return: i32 diff --git a/spark/src/main/scala/io/substrait/spark/ToSubstraitType.scala b/spark/src/main/scala/io/substrait/spark/ToSubstraitType.scala index 295901378..fa5ec4fea 100644 --- a/spark/src/main/scala/io/substrait/spark/ToSubstraitType.scala +++ b/spark/src/main/scala/io/substrait/spark/ToSubstraitType.scala @@ -29,6 +29,8 @@ import scala.collection.JavaConverters.asScalaBufferConverter private class ToSparkType extends TypeVisitor.TypeThrowsVisitor[DataType, RuntimeException]("Unknown expression type.") { + override def visit(expr: Type.I8): DataType = ByteType + override def visit(expr: Type.I16): DataType = ShortType override def visit(expr: Type.I32): DataType = IntegerType override def visit(expr: Type.I64): DataType = LongType diff --git a/spark/src/main/scala/io/substrait/spark/expression/FunctionMappings.scala b/spark/src/main/scala/io/substrait/spark/expression/FunctionMappings.scala index a8662cacf..ac4fed93e 100644 --- a/spark/src/main/scala/io/substrait/spark/expression/FunctionMappings.scala +++ b/spark/src/main/scala/io/substrait/spark/expression/FunctionMappings.scala @@ -60,6 +60,10 @@ class FunctionMappings { s[Concat]("concat"), s[Coalesce]("coalesce"), s[Year]("year"), + s[ShiftRight]("shift_right"), + s[BitwiseAnd]("bitwise_and"), + s[BitwiseOr]("bitwise_or"), + s[BitwiseXor]("bitwise_xor"), // internal s[MakeDecimal]("make_decimal"), diff --git a/spark/src/main/scala/io/substrait/spark/expression/ToSparkExpression.scala b/spark/src/main/scala/io/substrait/spark/expression/ToSparkExpression.scala index 54b472d3b..62b4bfcd9 100644 --- a/spark/src/main/scala/io/substrait/spark/expression/ToSparkExpression.scala +++ b/spark/src/main/scala/io/substrait/spark/expression/ToSparkExpression.scala @@ -44,6 +44,15 @@ class ToSparkExpression( Literal.FalseLiteral } } + + override def visit(expr: SExpression.I8Literal): Expression = { + Literal(expr.value().asInstanceOf[Byte], ToSubstraitType.convert(expr.getType)) + } + + override def visit(expr: SExpression.I16Literal): Expression = { + Literal(expr.value().asInstanceOf[Short], ToSubstraitType.convert(expr.getType)) + } + override def visit(expr: SExpression.I32Literal): Expression = { Literal(expr.value(), ToSubstraitType.convert(expr.getType)) } diff --git a/spark/src/main/scala/io/substrait/spark/expression/ToSubstraitExpression.scala b/spark/src/main/scala/io/substrait/spark/expression/ToSubstraitExpression.scala index 1bd8a6f2f..e43965363 100644 --- a/spark/src/main/scala/io/substrait/spark/expression/ToSubstraitExpression.scala +++ b/spark/src/main/scala/io/substrait/spark/expression/ToSubstraitExpression.scala @@ -146,11 +146,12 @@ abstract class ToSubstraitExpression extends HasOutputStack[Seq[Attribute]] { "org.apache.spark.sql.catalyst.expressions.PromotePrecision") => translateUp(p.children.head) case CaseWhen(branches, elseValue) => translateCaseWhen(branches, elseValue) + case In(value, list) => translateIn(value, list) + case InSet(value, set) => translateIn(value, set.toSeq.map(v => Literal(v))) case scalar @ ScalarFunction(children) => Util .seqToOption(children.map(translateUp)) .flatMap(toScalarFunction.convert(scalar, _)) - case In(value, list) => translateIn(value, list) case p: PlanExpression[_] => translateSubQuery(p) case other => default(other) } diff --git a/spark/src/test/scala/io/substrait/spark/TPCDSPlan.scala b/spark/src/test/scala/io/substrait/spark/TPCDSPlan.scala index e06ddfdb0..bd8570f4d 100644 --- a/spark/src/test/scala/io/substrait/spark/TPCDSPlan.scala +++ b/spark/src/test/scala/io/substrait/spark/TPCDSPlan.scala @@ -32,9 +32,9 @@ class TPCDSPlan extends TPCDSBase with SubstraitPlanTestBase { } // spotless:off - val successfulSQL: Set[String] = Set("q1", "q3", "q4", "q7", + val successfulSQL: Set[String] = Set("q1", "q3", "q4", "q7", "q8", "q11", "q13", "q14b", "q15", "q16", "q18", "q19", - "q21", "q22", "q23a", "q23b", "q24a", "q24b", "q25", "q26", "q28", "q29", + "q21", "q22", "q23a", "q23b", "q24a", "q24b", "q25", "q26", "q27", "q28", "q29", "q30", "q31", "q32", "q33", "q37", "q38", "q40", "q41", "q42", "q43", "q46", "q48", "q50", "q52", "q54", "q55", "q56", "q58", "q59",