Skip to content

Commit

Permalink
Make Spark 3.1.x source compatible
Browse files Browse the repository at this point in the history
  • Loading branch information
pomadchin committed Oct 18, 2021
1 parent de2d72a commit 56c3fc8
Show file tree
Hide file tree
Showing 3 changed files with 23 additions and 9 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,8 @@ import org.apache.spark.sql.catalyst.expressions._
import org.apache.spark.sql.{functions => sparkFunctions}
import frameless.syntax._

import scala.annotation.nowarn

trait AggregateFunctions {
/** Aggregate function: returns the number of items in a group.
*
Expand Down Expand Up @@ -77,14 +79,15 @@ trait AggregateFunctions {
*
* apache/spark
*/
@nowarn // supress sparkFunstion.sumDistinct call which is used to maintain Spark 3.1.x backwards compat
def sumDistinct[A, T, Out](column: TypedColumn[T, A])(
implicit
summable: CatalystSummable[A, Out],
oencoder: TypedEncoder[Out],
aencoder: TypedEncoder[A]
): TypedAggregate[T, Out] = {
val zeroExpr = Literal.create(summable.zero, TypedEncoder[A].catalystRepr)
val sumExpr = expr(sparkFunctions.sum_distinct(column.untyped))
val sumExpr = expr(sparkFunctions.sumDistinct(column.untyped))
val sumOrZero = Coalesce(Seq(sumExpr, zeroExpr))

new TypedAggregate[T, Out](sumOrZero)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@ package functions

import org.apache.spark.sql.{Column, functions => sparkFunctions}

import scala.annotation.nowarn
import scala.util.matching.Regex

trait NonAggregateFunctions {
Expand Down Expand Up @@ -86,34 +87,37 @@ trait NonAggregateFunctions {
*
* apache/spark
*/
@nowarn // supress sparkFunstion.shiftLeft call which is used to maintain Spark 3.1.x backwards compat
def shiftRightUnsigned[A, B, T](column: AbstractTypedColumn[T, A], numBits: Int)
(implicit
i0: CatalystBitShift[A, B],
i1: TypedEncoder[B]
): column.ThisType[T, B] =
column.typed(sparkFunctions.shiftrightunsigned(column.untyped, numBits))
column.typed(sparkFunctions.shiftRightUnsigned(column.untyped, numBits))

/** Non-Aggregate function: shift the the given value numBits right. If given long, will return long else it will return an integer.
*
* apache/spark
*/
@nowarn // supress sparkFunstion.shiftReft call which is used to maintain Spark 3.1.x backwards compat
def shiftRight[A, B, T](column: AbstractTypedColumn[T, A], numBits: Int)
(implicit
i0: CatalystBitShift[A, B],
i1: TypedEncoder[B]
): column.ThisType[T, B] =
column.typed(sparkFunctions.shiftright(column.untyped, numBits))
column.typed(sparkFunctions.shiftRight(column.untyped, numBits))

/** Non-Aggregate function: shift the the given value numBits left. If given long, will return long else it will return an integer.
*
* apache/spark
*/
@nowarn // supress sparkFunstion.shiftLeft call which is used to maintain Spark 3.1.x backwards compat
def shiftLeft[A, B, T](column: AbstractTypedColumn[T, A], numBits: Int)
(implicit
i0: CatalystBitShift[A, B],
i1: TypedEncoder[B]
): column.ThisType[T, B] =
column.typed(sparkFunctions.shiftleft(column.untyped, numBits))
column.typed(sparkFunctions.shiftLeft(column.untyped, numBits))

/** Non-Aggregate function: returns the absolute value of a numeric column
*
Expand Down Expand Up @@ -491,8 +495,9 @@ trait NonAggregateFunctions {
*
* apache/spark
*/
@nowarn // supress sparkFunstion.bitwiseNOT call which is used to maintain Spark 3.1.x backwards compat
def bitwiseNOT[A: CatalystBitwise, T](column: AbstractTypedColumn[T, A]): column.ThisType[T, A] =
column.typed(sparkFunctions.bitwise_not(column.untyped))(column.uencoder)
column.typed(sparkFunctions.bitwiseNOT(column.untyped))(column.uencoder)

/** Non-Aggregate function: file name of the current Spark task. Empty string if row did not originate from
* a file
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,8 @@ import org.apache.spark.sql.{Column, Encoder, SaveMode, functions => sparkFuncti
import org.scalacheck.Prop._
import org.scalacheck.{Arbitrary, Gen, Prop}

import scala.annotation.nowarn

class NonAggregateFunctionsTests extends TypedDatasetSuite {
val testTempFiles = "target/testoutput"

Expand Down Expand Up @@ -180,11 +182,12 @@ class NonAggregateFunctionsTests extends TypedDatasetSuite {
val spark = session
import spark.implicits._

@nowarn // supress sparkFunstion.shiftRightUnsigned call which is used to maintain Spark 3.1.x backwards compat
def prop[A: TypedEncoder : Encoder, B: TypedEncoder : Encoder]
(values: List[X1[A]], numBits: Int)
(implicit catalystBitShift: CatalystBitShift[A, B], encX1: Encoder[X1[A]]) = {
val typedDS = TypedDataset.create(values)
propBitShift(typedDS)(shiftRightUnsigned(typedDS('a), numBits), sparkFunctions.shiftrightunsigned, numBits)
propBitShift(typedDS)(shiftRightUnsigned(typedDS('a), numBits), sparkFunctions.shiftRightUnsigned, numBits)
}

check(forAll(prop[Byte, Int] _))
Expand All @@ -198,11 +201,12 @@ class NonAggregateFunctionsTests extends TypedDatasetSuite {
val spark = session
import spark.implicits._

@nowarn // supress sparkFunstion.shiftRight call which is used to maintain Spark 3.1.x backwards compat
def prop[A: TypedEncoder : Encoder, B: TypedEncoder : Encoder]
(values: List[X1[A]], numBits: Int)
(implicit catalystBitShift: CatalystBitShift[A, B], encX1: Encoder[X1[A]]) = {
val typedDS = TypedDataset.create(values)
propBitShift(typedDS)(shiftRight(typedDS('a), numBits), sparkFunctions.shiftright, numBits)
propBitShift(typedDS)(shiftRight(typedDS('a), numBits), sparkFunctions.shiftRight, numBits)
}

check(forAll(prop[Byte, Int] _))
Expand All @@ -216,11 +220,12 @@ class NonAggregateFunctionsTests extends TypedDatasetSuite {
val spark = session
import spark.implicits._

@nowarn // supress sparkFunstion.shiftLeft call which is used to maintain Spark 3.1.x backwards compat
def prop[A: TypedEncoder : Encoder, B: TypedEncoder : Encoder]
(values: List[X1[A]], numBits: Int)
(implicit catalystBitShift: CatalystBitShift[A, B], encX1: Encoder[X1[A]]) = {
val typedDS = TypedDataset.create(values)
propBitShift(typedDS)(shiftLeft(typedDS('a), numBits), sparkFunctions.shiftleft, numBits)
propBitShift(typedDS)(shiftLeft(typedDS('a), numBits), sparkFunctions.shiftLeft, numBits)
}

check(forAll(prop[Byte, Int] _))
Expand Down Expand Up @@ -1652,11 +1657,12 @@ class NonAggregateFunctionsTests extends TypedDatasetSuite {
val spark = session
import spark.implicits._

@nowarn // supress sparkFunstion.bitwiseNOT call which is used to maintain Spark 3.1.x backwards compat
def prop[A: CatalystBitwise : TypedEncoder : Encoder]
(values:List[X1[A]])(implicit encX1:Encoder[X1[A]]) = {
val cDS = session.createDataset(values)
val resCompare = cDS
.select(sparkFunctions.bitwise_not(cDS("a")))
.select(sparkFunctions.bitwiseNOT(cDS("a")))
.map(_.getAs[A](0))
.collect().toList

Expand Down

0 comments on commit 56c3fc8

Please sign in to comment.