Skip to content

Commit

Permalink
Remove @nowarn syntax since this change is anyway not compatible with …
Browse files Browse the repository at this point in the history
…earlier Spark versions
  • Loading branch information
pomadchin committed Oct 18, 2021
1 parent 6d64995 commit 6a248b4
Show file tree
Hide file tree
Showing 4 changed files with 16 additions and 62 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -6,8 +6,6 @@ import org.apache.spark.sql.catalyst.expressions._
import org.apache.spark.sql.{functions => sparkFunctions}
import frameless.syntax._

import scala.annotation.nowarn

trait AggregateFunctions {
/** Aggregate function: returns the number of items in a group.
*
Expand Down Expand Up @@ -79,24 +77,14 @@ trait AggregateFunctions {
*
* apache/spark
*/
@deprecated("Use sum_distinct", "3.2.0")
def sumDistinct[A, T, Out](column: TypedColumn[T, A])(
implicit
summable: CatalystSummable[A, Out],
oencoder: TypedEncoder[Out],
aencoder: TypedEncoder[A]
): TypedAggregate[T, Out] = sum_distinct(column)

  // suppress the sparkFunctions.sumDistinct call, which is used to maintain Spark 3.1.x backwards compat
@nowarn
def sum_distinct[A, T, Out](column: TypedColumn[T, A])(
implicit
summable: CatalystSummable[A, Out],
oencoder: TypedEncoder[Out],
aencoder: TypedEncoder[A]
): TypedAggregate[T, Out] = {
val zeroExpr = Literal.create(summable.zero, TypedEncoder[A].catalystRepr)
val sumExpr = expr(sparkFunctions.sumDistinct(column.untyped))
val sumExpr = expr(sparkFunctions.sum_distinct(column.untyped))
val sumOrZero = Coalesce(Seq(sumExpr, zeroExpr))

new TypedAggregate[T, Out](sumOrZero)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,6 @@ package functions

import org.apache.spark.sql.{Column, functions => sparkFunctions}

import scala.annotation.nowarn
import scala.util.matching.Regex

trait NonAggregateFunctions {
Expand Down Expand Up @@ -87,61 +86,34 @@ trait NonAggregateFunctions {
*
* apache/spark
*/
@deprecated("Use shiftrightunsigned", "3.2.0")
def shiftRightUnsigned[A, B, T](column: AbstractTypedColumn[T, A], numBits: Int)
(implicit
i0: CatalystBitShift[A, B],
i1: TypedEncoder[B]
): column.ThisType[T, B] = shiftrightunsigned(column, numBits)

  // suppress the sparkFunctions.shiftRightUnsigned call, which is used to maintain Spark 3.1.x backwards compat
@nowarn
def shiftrightunsigned[A, B, T](column: AbstractTypedColumn[T, A], numBits: Int)
(implicit
i0: CatalystBitShift[A, B],
i1: TypedEncoder[B]
): column.ThisType[T, B] =
column.typed(sparkFunctions.shiftRightUnsigned(column.untyped, numBits))
column.typed(sparkFunctions.shiftrightunsigned(column.untyped, numBits))

  /** Non-Aggregate function: shift the given value numBits right. If given long, will return long else it will return an integer.
*
* apache/spark
*/
@deprecated("Use shiftright", "3.2.0")
def shiftRight[A, B, T](column: AbstractTypedColumn[T, A], numBits: Int)
(implicit
i0: CatalystBitShift[A, B],
i1: TypedEncoder[B]
): column.ThisType[T, B] = shiftright(column, numBits)

  // suppress the sparkFunctions.shiftRight call, which is used to maintain Spark 3.1.x backwards compat
@nowarn
def shiftright[A, B, T](column: AbstractTypedColumn[T, A], numBits: Int)
(implicit
i0: CatalystBitShift[A, B],
i1: TypedEncoder[B]
): column.ThisType[T, B] =
column.typed(sparkFunctions.shiftRight(column.untyped, numBits))
column.typed(sparkFunctions.shiftright(column.untyped, numBits))

  /** Non-Aggregate function: shift the given value numBits left. If given long, will return long else it will return an integer.
*
* apache/spark
*/
@deprecated("Use shiftleft", "3.2.0")
def shiftLeft[A, B, T](column: AbstractTypedColumn[T, A], numBits: Int)
(implicit
i0: CatalystBitShift[A, B],
i1: TypedEncoder[B]
): column.ThisType[T, B] = shiftleft(column, numBits)

  // suppress the sparkFunctions.shiftLeft call, which is used to maintain Spark 3.1.x backwards compat
@nowarn
def shiftleft[A, B, T](column: AbstractTypedColumn[T, A], numBits: Int)
(implicit
i0: CatalystBitShift[A, B],
i1: TypedEncoder[B]
): column.ThisType[T, B] =
column.typed(sparkFunctions.shiftLeft(column.untyped, numBits))
column.typed(sparkFunctions.shiftleft(column.untyped, numBits))

/** Non-Aggregate function: returns the absolute value of a numeric column
*
Expand Down Expand Up @@ -519,14 +491,8 @@ trait NonAggregateFunctions {
*
* apache/spark
*/
@deprecated("Use bitwise_not", "3.2.0")
def bitwiseNOT[A: CatalystBitwise, T](column: AbstractTypedColumn[T, A]): column.ThisType[T, A] =
bitwise_not(column)

  // suppress the sparkFunctions.bitwiseNOT call, which is used to maintain Spark 3.1.x backwards compat
@nowarn
def bitwise_not[A: CatalystBitwise, T](column: AbstractTypedColumn[T, A]): column.ThisType[T, A] =
column.typed(sparkFunctions.bitwiseNOT(column.untyped))(column.uencoder)
column.typed(sparkFunctions.bitwise_not(column.untyped))(column.uencoder)

/** Non-Aggregate function: file name of the current Spark task. Empty string if row did not originate from
* a file
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -57,7 +57,7 @@ class AggregateFunctionsTests extends TypedDatasetSuite {
check(sparkSchema[Short, Long](sum))
}

test("sum_distinct") {
test("sumDistinct") {
case class Sum4Tests[A, B](sum: Seq[A] => B)

def prop[A: TypedEncoder, Out: TypedEncoder : Numeric](xs: List[A])(
Expand All @@ -68,7 +68,7 @@ class AggregateFunctionsTests extends TypedDatasetSuite {
val dataset = TypedDataset.create(xs.map(X1(_)))
val A = dataset.col[A]('a)

val datasetSum: List[Out] = dataset.agg(sum_distinct(A)).collect().run().toList
val datasetSum: List[Out] = dataset.agg(sumDistinct(A)).collect().run().toList

datasetSum match {
case x :: Nil => approximatelyEqual(summer.sum(xs), x)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -176,15 +176,15 @@ class NonAggregateFunctionsTests extends TypedDatasetSuite {
res ?= resCompare
}

test("shiftrightunsigned") {
test("shiftRightUnsigned") {
val spark = session
import spark.implicits._

def prop[A: TypedEncoder : Encoder, B: TypedEncoder : Encoder]
(values: List[X1[A]], numBits: Int)
(implicit catalystBitShift: CatalystBitShift[A, B], encX1: Encoder[X1[A]]) = {
val typedDS = TypedDataset.create(values)
propBitShift(typedDS)(shiftrightunsigned(typedDS('a), numBits), sparkFunctions.shiftrightunsigned, numBits)
propBitShift(typedDS)(shiftRightUnsigned(typedDS('a), numBits), sparkFunctions.shiftrightunsigned, numBits)
}

check(forAll(prop[Byte, Int] _))
Expand All @@ -194,15 +194,15 @@ class NonAggregateFunctionsTests extends TypedDatasetSuite {
check(forAll(prop[BigDecimal, Int] _))
}

test("shiftright") {
test("shiftRight") {
val spark = session
import spark.implicits._

def prop[A: TypedEncoder : Encoder, B: TypedEncoder : Encoder]
(values: List[X1[A]], numBits: Int)
(implicit catalystBitShift: CatalystBitShift[A, B], encX1: Encoder[X1[A]]) = {
val typedDS = TypedDataset.create(values)
propBitShift(typedDS)(shiftright(typedDS('a), numBits), sparkFunctions.shiftright, numBits)
propBitShift(typedDS)(shiftRight(typedDS('a), numBits), sparkFunctions.shiftright, numBits)
}

check(forAll(prop[Byte, Int] _))
Expand All @@ -212,15 +212,15 @@ class NonAggregateFunctionsTests extends TypedDatasetSuite {
check(forAll(prop[BigDecimal, Int] _))
}

test("shiftleft") {
test("shiftLeft") {
val spark = session
import spark.implicits._

def prop[A: TypedEncoder : Encoder, B: TypedEncoder : Encoder]
(values: List[X1[A]], numBits: Int)
(implicit catalystBitShift: CatalystBitShift[A, B], encX1: Encoder[X1[A]]) = {
val typedDS = TypedDataset.create(values)
propBitShift(typedDS)(shiftleft(typedDS('a), numBits), sparkFunctions.shiftleft, numBits)
propBitShift(typedDS)(shiftLeft(typedDS('a), numBits), sparkFunctions.shiftleft, numBits)
}

check(forAll(prop[Byte, Int] _))
Expand Down Expand Up @@ -1648,21 +1648,21 @@ class NonAggregateFunctionsTests extends TypedDatasetSuite {
check(forAll(prop _))
}

test("bitwise_not"){
test("bitwiseNOT"){
val spark = session
import spark.implicits._

def prop[A: CatalystBitwise : TypedEncoder : Encoder]
(values:List[X1[A]])(implicit encX1:Encoder[X1[A]]) = {
val cDS = session.createDataset(values)
val resCompare = cDS
.select(sparkFunctions.bitwise_not(cDS("a")))
.select(sparkFunctions.bitwiseNOT(cDS("a")))
.map(_.getAs[A](0))
.collect().toList

val typedDS = TypedDataset.create(values)
val res = typedDS
.select(bitwise_not(typedDS('a)))
.select(bitwiseNOT(typedDS('a)))
.collect()
.run()
.toList
Expand Down

0 comments on commit 6a248b4

Please sign in to comment.