Skip to content

Commit

Permalink
feat(spark): add byte and short (from substrait-io#309)
Browse files Browse the repository at this point in the history
  • Loading branch information
Blizzara committed Oct 24, 2024
1 parent ea7bdb1 commit c43f399
Show file tree
Hide file tree
Showing 3 changed files with 16 additions and 1 deletion.
2 changes: 2 additions & 0 deletions spark/src/main/scala/io/substrait/spark/ToSubstraitType.scala
Original file line number Diff line number Diff line change
Expand Up @@ -29,6 +29,8 @@ import scala.collection.JavaConverters.asScalaBufferConverter
private class ToSparkType
extends TypeVisitor.TypeThrowsVisitor[DataType, RuntimeException]("Unknown expression type.") {

override def visit(expr: Type.I8): DataType = ByteType
override def visit(expr: Type.I16): DataType = ShortType
override def visit(expr: Type.I32): DataType = IntegerType
override def visit(expr: Type.I64): DataType = LongType

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -45,6 +45,15 @@ class ToSparkExpression(
Literal.FalseLiteral
}
}

override def visit(expr: SExpression.I8Literal): Expression = {
Literal(expr.value().asInstanceOf[Byte], ToSubstraitType.convert(expr.getType))
}

override def visit(expr: SExpression.I16Literal): Expression = {
Literal(expr.value().asInstanceOf[Short], ToSubstraitType.convert(expr.getType))
}

override def visit(expr: SExpression.I32Literal): Expression = {
Literal(expr.value(), ToSubstraitType.convert(expr.getType))
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,7 @@ import io.substrait.spark.expression.{ToSparkExpression, ToSubstraitLiteral}
import org.apache.spark.SparkFunSuite
import org.apache.spark.sql.catalyst.expressions.Literal
import org.apache.spark.sql.catalyst.util.MapData
import org.apache.spark.sql.types.{ArrayType, BinaryType, BooleanType, DataType, DayTimeIntervalType, DecimalType, DoubleType, FloatType, IntegerType, LongType, MapType, StringType, TimestampNTZType, TimestampType, YearMonthIntervalType}
import org.apache.spark.sql.types._
import org.apache.spark.substrait.SparkTypeUtil
import org.apache.spark.unsafe.types.UTF8String

Expand All @@ -16,6 +16,8 @@ class TypesAndLiteralsSuite extends SparkFunSuite {
val toSparkExpression = new ToSparkExpression(null, null)

val types: Seq[DataType] = List(
ByteType,
ShortType,
IntegerType,
LongType,
FloatType,
Expand Down Expand Up @@ -52,6 +54,8 @@ class TypesAndLiteralsSuite extends SparkFunSuite {
val defaultLiterals: Seq[Literal] = types.map(Literal.default)

val literals: Seq[Literal] = List(
Literal(1.toByte),
Literal(1.toShort),
Literal(1),
Literal(1L),
Literal(1.0f),
Expand Down

0 comments on commit c43f399

Please sign in to comment.