From 56c3fc8c09bc58192032629407fc2413bbe30d3e Mon Sep 17 00:00:00 2001 From: Grigory Pomadchin Date: Mon, 18 Oct 2021 12:27:57 -0400 Subject: [PATCH] Make Spark 3.1.x source compatible --- .../frameless/functions/AggregateFunctions.scala | 5 ++++- .../functions/NonAggregateFunctions.scala | 13 +++++++++---- .../functions/NonAggregateFunctionsTests.scala | 14 ++++++++++---- 3 files changed, 23 insertions(+), 9 deletions(-) diff --git a/dataset/src/main/scala/frameless/functions/AggregateFunctions.scala b/dataset/src/main/scala/frameless/functions/AggregateFunctions.scala index 6b466b3e4..21e9c8aa2 100644 --- a/dataset/src/main/scala/frameless/functions/AggregateFunctions.scala +++ b/dataset/src/main/scala/frameless/functions/AggregateFunctions.scala @@ -6,6 +6,8 @@ import org.apache.spark.sql.catalyst.expressions._ import org.apache.spark.sql.{functions => sparkFunctions} import frameless.syntax._ +import scala.annotation.nowarn + trait AggregateFunctions { /** Aggregate function: returns the number of items in a group. 
* @@ -77,6 +79,7 @@ trait AggregateFunctions { * * apache/spark */ + @nowarn // suppress sparkFunction.sumDistinct call which is used to maintain Spark 3.1.x backwards compat def sumDistinct[A, T, Out](column: TypedColumn[T, A])( implicit summable: CatalystSummable[A, Out], @@ -84,7 +87,7 @@ trait AggregateFunctions { aencoder: TypedEncoder[A] ): TypedAggregate[T, Out] = { val zeroExpr = Literal.create(summable.zero, TypedEncoder[A].catalystRepr) - val sumExpr = expr(sparkFunctions.sum_distinct(column.untyped)) + val sumExpr = expr(sparkFunctions.sumDistinct(column.untyped)) val sumOrZero = Coalesce(Seq(sumExpr, zeroExpr)) new TypedAggregate[T, Out](sumOrZero) diff --git a/dataset/src/main/scala/frameless/functions/NonAggregateFunctions.scala b/dataset/src/main/scala/frameless/functions/NonAggregateFunctions.scala index fc0597d42..4b2b41a84 100644 --- a/dataset/src/main/scala/frameless/functions/NonAggregateFunctions.scala +++ b/dataset/src/main/scala/frameless/functions/NonAggregateFunctions.scala @@ -3,6 +3,7 @@ package functions import org.apache.spark.sql.{Column, functions => sparkFunctions} +import scala.annotation.nowarn import scala.util.matching.Regex trait NonAggregateFunctions { @@ -86,34 +87,37 @@ trait NonAggregateFunctions { * * apache/spark */ + @nowarn // suppress sparkFunction.shiftRightUnsigned call which is used to maintain Spark 3.1.x backwards compat def shiftRightUnsigned[A, B, T](column: AbstractTypedColumn[T, A], numBits: Int) (implicit i0: CatalystBitShift[A, B], i1: TypedEncoder[B] ): column.ThisType[T, B] = - column.typed(sparkFunctions.shiftrightunsigned(column.untyped, numBits)) + column.typed(sparkFunctions.shiftRightUnsigned(column.untyped, numBits)) /** Non-Aggregate function: shift the the given value numBits right. If given long, will return long else it will return an integer. 
* * apache/spark */ + @nowarn // suppress sparkFunction.shiftRight call which is used to maintain Spark 3.1.x backwards compat def shiftRight[A, B, T](column: AbstractTypedColumn[T, A], numBits: Int) (implicit i0: CatalystBitShift[A, B], i1: TypedEncoder[B] ): column.ThisType[T, B] = - column.typed(sparkFunctions.shiftright(column.untyped, numBits)) + column.typed(sparkFunctions.shiftRight(column.untyped, numBits)) /** Non-Aggregate function: shift the the given value numBits left. If given long, will return long else it will return an integer. * * apache/spark */ + @nowarn // suppress sparkFunction.shiftLeft call which is used to maintain Spark 3.1.x backwards compat def shiftLeft[A, B, T](column: AbstractTypedColumn[T, A], numBits: Int) (implicit i0: CatalystBitShift[A, B], i1: TypedEncoder[B] ): column.ThisType[T, B] = - column.typed(sparkFunctions.shiftleft(column.untyped, numBits)) + column.typed(sparkFunctions.shiftLeft(column.untyped, numBits)) /** Non-Aggregate function: returns the absolute value of a numeric column * * apache/spark */ @@ -491,8 +495,9 @@ trait NonAggregateFunctions { * * apache/spark */ + @nowarn // suppress sparkFunction.bitwiseNOT call which is used to maintain Spark 3.1.x backwards compat def bitwiseNOT[A: CatalystBitwise, T](column: AbstractTypedColumn[T, A]): column.ThisType[T, A] = - column.typed(sparkFunctions.bitwise_not(column.untyped))(column.uencoder) + column.typed(sparkFunctions.bitwiseNOT(column.untyped))(column.uencoder) /** Non-Aggregate function: file name of the current Spark task. 
Empty string if row did not originate from * a file diff --git a/dataset/src/test/scala/frameless/functions/NonAggregateFunctionsTests.scala b/dataset/src/test/scala/frameless/functions/NonAggregateFunctionsTests.scala index 4bea844b1..8c3119a22 100644 --- a/dataset/src/test/scala/frameless/functions/NonAggregateFunctionsTests.scala +++ b/dataset/src/test/scala/frameless/functions/NonAggregateFunctionsTests.scala @@ -9,6 +9,8 @@ import org.apache.spark.sql.{Column, Encoder, SaveMode, functions => sparkFuncti import org.scalacheck.Prop._ import org.scalacheck.{Arbitrary, Gen, Prop} +import scala.annotation.nowarn + class NonAggregateFunctionsTests extends TypedDatasetSuite { val testTempFiles = "target/testoutput" @@ -180,11 +182,12 @@ class NonAggregateFunctionsTests extends TypedDatasetSuite { val spark = session import spark.implicits._ + @nowarn // suppress sparkFunction.shiftRightUnsigned call which is used to maintain Spark 3.1.x backwards compat def prop[A: TypedEncoder : Encoder, B: TypedEncoder : Encoder] (values: List[X1[A]], numBits: Int) (implicit catalystBitShift: CatalystBitShift[A, B], encX1: Encoder[X1[A]]) = { val typedDS = TypedDataset.create(values) - propBitShift(typedDS)(shiftRightUnsigned(typedDS('a), numBits), sparkFunctions.shiftrightunsigned, numBits) + propBitShift(typedDS)(shiftRightUnsigned(typedDS('a), numBits), sparkFunctions.shiftRightUnsigned, numBits) } check(forAll(prop[Byte, Int] _)) @@ -198,11 +201,12 @@ class NonAggregateFunctionsTests extends TypedDatasetSuite { val spark = session import spark.implicits._ + @nowarn // suppress sparkFunction.shiftRight call which is used to maintain Spark 3.1.x backwards compat def prop[A: TypedEncoder : Encoder, B: TypedEncoder : Encoder] (values: List[X1[A]], numBits: Int) (implicit catalystBitShift: CatalystBitShift[A, B], encX1: Encoder[X1[A]]) = { val typedDS = TypedDataset.create(values) - propBitShift(typedDS)(shiftRight(typedDS('a), numBits), sparkFunctions.shiftright, numBits) + 
propBitShift(typedDS)(shiftRight(typedDS('a), numBits), sparkFunctions.shiftRight, numBits) } check(forAll(prop[Byte, Int] _)) @@ -216,11 +220,12 @@ class NonAggregateFunctionsTests extends TypedDatasetSuite { val spark = session import spark.implicits._ + @nowarn // suppress sparkFunction.shiftLeft call which is used to maintain Spark 3.1.x backwards compat def prop[A: TypedEncoder : Encoder, B: TypedEncoder : Encoder] (values: List[X1[A]], numBits: Int) (implicit catalystBitShift: CatalystBitShift[A, B], encX1: Encoder[X1[A]]) = { val typedDS = TypedDataset.create(values) - propBitShift(typedDS)(shiftLeft(typedDS('a), numBits), sparkFunctions.shiftleft, numBits) + propBitShift(typedDS)(shiftLeft(typedDS('a), numBits), sparkFunctions.shiftLeft, numBits) } check(forAll(prop[Byte, Int] _)) @@ -1652,11 +1657,12 @@ class NonAggregateFunctionsTests extends TypedDatasetSuite { val spark = session import spark.implicits._ + @nowarn // suppress sparkFunction.bitwiseNOT call which is used to maintain Spark 3.1.x backwards compat def prop[A: CatalystBitwise : TypedEncoder : Encoder] (values:List[X1[A]])(implicit encX1:Encoder[X1[A]]) = { val cDS = session.createDataset(values) val resCompare = cDS - .select(sparkFunctions.bitwise_not(cDS("a"))) + .select(sparkFunctions.bitwiseNOT(cDS("a"))) .map(_.getAs[A](0)) .collect().toList