Skip to content

Commit

Permalink
[SPARK-23926][SQL] Extending reverse function to support ArrayType ar…
Browse files Browse the repository at this point in the history
…guments

## What changes were proposed in this pull request?

This PR extends `reverse` functions to be able to operate over array columns and covers:
- Introduction of `Reverse` expression that represents logic for reversing arrays and also strings
- Removal of `StringReverse` expression
- A wrapper for PySpark

## How was this patch tested?

New tests added into:
- CollectionExpressionsSuite
- DataFrameFunctionsSuite

## Codegen examples
### Primitive type
```
val df = Seq(
  Seq(1, 3, 4, 2),
  null
).toDF("i")
df.filter($"i".isNotNull || $"i".isNull).select(reverse($"i")).debugCodegen
```
Result:
```
/* 032 */         boolean inputadapter_isNull = inputadapter_row.isNullAt(0);
/* 033 */         ArrayData inputadapter_value = inputadapter_isNull ?
/* 034 */         null : (inputadapter_row.getArray(0));
/* 035 */
/* 036 */         boolean filter_value = true;
/* 037 */
/* 038 */         if (!(!inputadapter_isNull)) {
/* 039 */           filter_value = inputadapter_isNull;
/* 040 */         }
/* 041 */         if (!filter_value) continue;
/* 042 */
/* 043 */         ((org.apache.spark.sql.execution.metric.SQLMetric) references[0] /* numOutputRows */).add(1);
/* 044 */
/* 045 */         boolean project_isNull = inputadapter_isNull;
/* 046 */         ArrayData project_value = null;
/* 047 */
/* 048 */         if (!inputadapter_isNull) {
/* 049 */           final int project_length = inputadapter_value.numElements();
/* 050 */           project_value = inputadapter_value.copy();
/* 051 */           for(int k = 0; k < project_length / 2; k++) {
/* 052 */             int l = project_length - k - 1;
/* 053 */             boolean isNullAtK = project_value.isNullAt(k);
/* 054 */             boolean isNullAtL = project_value.isNullAt(l);
/* 055 */             if(!isNullAtK) {
/* 056 */               int el = project_value.getInt(k);
/* 057 */               if(!isNullAtL) {
/* 058 */                 project_value.setInt(k, project_value.getInt(l));
/* 059 */               } else {
/* 060 */                 project_value.setNullAt(k);
/* 061 */               }
/* 062 */               project_value.setInt(l, el);
/* 063 */             } else if (!isNullAtL) {
/* 064 */               project_value.setInt(k, project_value.getInt(l));
/* 065 */               project_value.setNullAt(l);
/* 066 */             }
/* 067 */           }
/* 068 */
/* 069 */         }
```
### Non-primitive type
```
val df = Seq(
  Seq("a", "c", "d", "b"),
  null
).toDF("s")
df.filter($"s".isNotNull || $"s".isNull).select(reverse($"s")).debugCodegen
```
Result:
```
/* 032 */         boolean inputadapter_isNull = inputadapter_row.isNullAt(0);
/* 033 */         ArrayData inputadapter_value = inputadapter_isNull ?
/* 034 */         null : (inputadapter_row.getArray(0));
/* 035 */
/* 036 */         boolean filter_value = true;
/* 037 */
/* 038 */         if (!(!inputadapter_isNull)) {
/* 039 */           filter_value = inputadapter_isNull;
/* 040 */         }
/* 041 */         if (!filter_value) continue;
/* 042 */
/* 043 */         ((org.apache.spark.sql.execution.metric.SQLMetric) references[0] /* numOutputRows */).add(1);
/* 044 */
/* 045 */         boolean project_isNull = inputadapter_isNull;
/* 046 */         ArrayData project_value = null;
/* 047 */
/* 048 */         if (!inputadapter_isNull) {
/* 049 */           final int project_length = inputadapter_value.numElements();
/* 050 */           project_value = new org.apache.spark.sql.catalyst.util.GenericArrayData(new Object[project_length]);
/* 051 */           for(int k = 0; k < project_length; k++) {
/* 052 */             int l = project_length - k - 1;
/* 053 */             project_value.update(k, inputadapter_value.getUTF8String(l));
/* 054 */           }
/* 055 */
/* 056 */         }
```

Author: mn-mikke <mrkAha12346github>

Closes apache#21034 from mn-mikke/feature/array-api-reverse-to-master.
  • Loading branch information
mn-mikke authored and ueshin committed Apr 18, 2018
1 parent cce4694 commit f81fa47
Show file tree
Hide file tree
Showing 8 changed files with 256 additions and 33 deletions.
20 changes: 19 additions & 1 deletion python/pyspark/sql/functions.py
Original file line number Diff line number Diff line change
Expand Up @@ -1414,7 +1414,6 @@ def hash(*cols):
'uppercase. Words are delimited by whitespace.',
'lower': 'Converts a string column to lower case.',
'upper': 'Converts a string column to upper case.',
'reverse': 'Reverses the string column and returns it as a new string column.',
'ltrim': 'Trim the spaces from left end for the specified string value.',
'rtrim': 'Trim the spaces from right end for the specified string value.',
'trim': 'Trim the spaces from both ends for the specified string column.',
Expand Down Expand Up @@ -2128,6 +2127,25 @@ def sort_array(col, asc=True):
return Column(sc._jvm.functions.sort_array(_to_java_column(col), asc))


@since(1.5)
@ignore_unicode_prefix
def reverse(col):
"""
Collection function: returns a reversed string or an array with reverse order of elements.
:param col: name of column or expression
>>> df = spark.createDataFrame([('Spark SQL',)], ['data'])
>>> df.select(reverse(df.data).alias('s')).collect()
[Row(s=u'LQS krapS')]
>>> df = spark.createDataFrame([([2, 1, 3],) ,([1],) ,([],)], ['data'])
>>> df.select(reverse(df.data).alias('r')).collect()
[Row(r=[3, 1, 2]), Row(r=[1]), Row(r=[])]
"""
sc = SparkContext._active_spark_context
return Column(sc._jvm.functions.reverse(_to_java_column(col)))


@since(2.3)
def map_keys(col):
"""
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -336,7 +336,6 @@ object FunctionRegistry {
expression[RegExpReplace]("regexp_replace"),
expression[StringRepeat]("repeat"),
expression[StringReplace]("replace"),
expression[StringReverse]("reverse"),
expression[RLike]("rlike"),
expression[StringRPad]("rpad"),
expression[StringTrimRight]("rtrim"),
Expand Down Expand Up @@ -411,6 +410,7 @@ object FunctionRegistry {
expression[SortArray]("sort_array"),
expression[ArrayMin]("array_min"),
expression[ArrayMax]("array_max"),
expression[Reverse]("reverse"),
CreateStruct.registryEntry,

// misc functions
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -23,6 +23,7 @@ import org.apache.spark.sql.catalyst.analysis.TypeCheckResult
import org.apache.spark.sql.catalyst.expressions.codegen._
import org.apache.spark.sql.catalyst.util.{ArrayData, GenericArrayData, MapData, TypeUtils}
import org.apache.spark.sql.types._
import org.apache.spark.unsafe.types.UTF8String

/**
* Given an array or map, returns its size. Returns -1 if null.
Expand Down Expand Up @@ -212,6 +213,93 @@ case class SortArray(base: Expression, ascendingOrder: Expression)
override def prettyName: String = "sort_array"
}

/**
* Returns a reversed string or an array with reverse order of elements.
*/
@ExpressionDescription(
usage = "_FUNC_(array) - Returns a reversed string or an array with reverse order of elements.",
examples = """
Examples:
> SELECT _FUNC_('Spark SQL');
LQS krapS
> SELECT _FUNC_(array(2, 1, 4, 3));
[3, 4, 1, 2]
""",
since = "1.5.0",
note = "Reverse logic for arrays is available since 2.4.0."
)
case class Reverse(child: Expression) extends UnaryExpression with ImplicitCastInputTypes {

// Input types are utilized by type coercion in ImplicitTypeCasts.
override def inputTypes: Seq[AbstractDataType] = Seq(TypeCollection(StringType, ArrayType))

override def dataType: DataType = child.dataType

lazy val elementType: DataType = dataType.asInstanceOf[ArrayType].elementType

override def nullSafeEval(input: Any): Any = input match {
case a: ArrayData => new GenericArrayData(a.toObjectArray(elementType).reverse)
case s: UTF8String => s.reverse()
}

override def doGenCode(ctx: CodegenContext, ev: ExprCode): ExprCode = {
nullSafeCodeGen(ctx, ev, c => dataType match {
case _: StringType => stringCodeGen(ev, c)
case _: ArrayType => arrayCodeGen(ctx, ev, c)
})
}

private def stringCodeGen(ev: ExprCode, childName: String): String = {
s"${ev.value} = ($childName).reverse();"
}

private def arrayCodeGen(ctx: CodegenContext, ev: ExprCode, childName: String): String = {
val length = ctx.freshName("length")
val javaElementType = CodeGenerator.javaType(elementType)
val isPrimitiveType = CodeGenerator.isPrimitiveType(elementType)

val initialization = if (isPrimitiveType) {
s"$childName.copy()"
} else {
s"new ${classOf[GenericArrayData].getName()}(new Object[$length])"
}

val numberOfIterations = if (isPrimitiveType) s"$length / 2" else length

val swapAssigments = if (isPrimitiveType) {
val setFunc = "set" + CodeGenerator.primitiveTypeName(elementType)
val getCall = (index: String) => CodeGenerator.getValue(ev.value, elementType, index)
s"""|boolean isNullAtK = ${ev.value}.isNullAt(k);
|boolean isNullAtL = ${ev.value}.isNullAt(l);
|if(!isNullAtK) {
| $javaElementType el = ${getCall("k")};
| if(!isNullAtL) {
| ${ev.value}.$setFunc(k, ${getCall("l")});
| } else {
| ${ev.value}.setNullAt(k);
| }
| ${ev.value}.$setFunc(l, el);
|} else if (!isNullAtL) {
| ${ev.value}.$setFunc(k, ${getCall("l")});
| ${ev.value}.setNullAt(l);
|}""".stripMargin
} else {
s"${ev.value}.update(k, ${CodeGenerator.getValue(childName, elementType, "l")});"
}

s"""
|final int $length = $childName.numElements();
|${ev.value} = $initialization;
|for(int k = 0; k < $numberOfIterations; k++) {
| int l = $length - k - 1;
| $swapAssigments
|}
""".stripMargin
}

override def prettyName: String = "reverse"
}

/**
* Checks if the array (left) has the element (right)
*/
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -1504,26 +1504,6 @@ case class StringRepeat(str: Expression, times: Expression)
}
}

/**
* Returns the reversed given string.
*/
@ExpressionDescription(
usage = "_FUNC_(str) - Returns the reversed given string.",
examples = """
Examples:
> SELECT _FUNC_('Spark SQL');
LQS krapS
""")
case class StringReverse(child: Expression) extends UnaryExpression with String2StringExpression {
override def convert(v: UTF8String): UTF8String = v.reverse()

override def prettyName: String = "reverse"

override def doGenCode(ctx: CodegenContext, ev: ExprCode): ExprCode = {
defineCodeGen(ctx, ev, c => s"($c).reverse()")
}
}

/**
* Returns a string consisting of n spaces.
*/
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -125,4 +125,48 @@ class CollectionExpressionsSuite extends SparkFunSuite with ExpressionEvalHelper
checkEvaluation(
ArrayMax(Literal.create(Seq(1.123, 0.1234, 1.121), ArrayType(DoubleType))), 1.123)
}

test("Reverse") {
// Primitive-type elements
val ai0 = Literal.create(Seq(2, 1, 4, 3), ArrayType(IntegerType))
val ai1 = Literal.create(Seq(2, 1, 3), ArrayType(IntegerType))
val ai2 = Literal.create(Seq(null, 1, null, 3), ArrayType(IntegerType))
val ai3 = Literal.create(Seq(2, null, 4, null), ArrayType(IntegerType))
val ai4 = Literal.create(Seq(null, null, null), ArrayType(IntegerType))
val ai5 = Literal.create(Seq(1), ArrayType(IntegerType))
val ai6 = Literal.create(Seq.empty, ArrayType(IntegerType))
val ai7 = Literal.create(null, ArrayType(IntegerType))

checkEvaluation(Reverse(ai0), Seq(3, 4, 1, 2))
checkEvaluation(Reverse(ai1), Seq(3, 1, 2))
checkEvaluation(Reverse(ai2), Seq(3, null, 1, null))
checkEvaluation(Reverse(ai3), Seq(null, 4, null, 2))
checkEvaluation(Reverse(ai4), Seq(null, null, null))
checkEvaluation(Reverse(ai5), Seq(1))
checkEvaluation(Reverse(ai6), Seq.empty)
checkEvaluation(Reverse(ai7), null)

// Non-primitive-type elements
val as0 = Literal.create(Seq("b", "a", "d", "c"), ArrayType(StringType))
val as1 = Literal.create(Seq("b", "a", "c"), ArrayType(StringType))
val as2 = Literal.create(Seq(null, "a", null, "c"), ArrayType(StringType))
val as3 = Literal.create(Seq("b", null, "d", null), ArrayType(StringType))
val as4 = Literal.create(Seq(null, null, null), ArrayType(StringType))
val as5 = Literal.create(Seq("a"), ArrayType(StringType))
val as6 = Literal.create(Seq.empty, ArrayType(StringType))
val as7 = Literal.create(null, ArrayType(StringType))
val aa = Literal.create(
Seq(Seq("a", "b"), Seq("c", "d"), Seq("e")),
ArrayType(ArrayType(StringType)))

checkEvaluation(Reverse(as0), Seq("c", "d", "a", "b"))
checkEvaluation(Reverse(as1), Seq("c", "a", "b"))
checkEvaluation(Reverse(as2), Seq("c", null, "a", null))
checkEvaluation(Reverse(as3), Seq(null, "d", null, "b"))
checkEvaluation(Reverse(as4), Seq(null, null, null))
checkEvaluation(Reverse(as5), Seq("a"))
checkEvaluation(Reverse(as6), Seq.empty)
checkEvaluation(Reverse(as7), null)
checkEvaluation(Reverse(aa), Seq(Seq("e"), Seq("c", "d"), Seq("a", "b")))
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -629,9 +629,9 @@ class StringExpressionsSuite extends SparkFunSuite with ExpressionEvalHelper {
test("REVERSE") {
val s = 'a.string.at(0)
val row1 = create_row("abccc")
checkEvaluation(StringReverse(Literal("abccc")), "cccba", row1)
checkEvaluation(StringReverse(s), "cccba", row1)
checkEvaluation(StringReverse(Literal.create(null, StringType)), null, row1)
checkEvaluation(Reverse(Literal("abccc")), "cccba", row1)
checkEvaluation(Reverse(s), "cccba", row1)
checkEvaluation(Reverse(Literal.create(null, StringType)), null, row1)
}

test("SPACE") {
Expand Down
15 changes: 7 additions & 8 deletions sql/core/src/main/scala/org/apache/spark/sql/functions.scala
Original file line number Diff line number Diff line change
Expand Up @@ -2464,14 +2464,6 @@ object functions {
StringRepeat(str.expr, lit(n).expr)
}

/**
* Reverses the string column and returns it as a new string column.
*
* @group string_funcs
* @since 1.5.0
*/
def reverse(str: Column): Column = withExpr { StringReverse(str.expr) }

/**
* Trim the spaces from right end for the specified string value.
*
Expand Down Expand Up @@ -3316,6 +3308,13 @@ object functions {
*/
def array_max(e: Column): Column = withExpr { ArrayMax(e.expr) }

/**
* Returns a reversed string or an array with reverse order of elements.
* @group collection_funcs
* @since 1.5.0
*/
def reverse(e: Column): Column = withExpr { Reverse(e.expr) }

/**
* Returns an unordered array containing the keys of the map.
* @group collection_funcs
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -441,6 +441,100 @@ class DataFrameFunctionsSuite extends QueryTest with SharedSQLContext {
checkAnswer(df.selectExpr("array_max(a)"), answer)
}

test("reverse function") {
val dummyFilter = (c: Column) => c.isNull || c.isNotNull // switch codegen on

// String test cases
val oneRowDF = Seq(("Spark", 3215)).toDF("s", "i")

checkAnswer(
oneRowDF.select(reverse('s)),
Seq(Row("krapS"))
)
checkAnswer(
oneRowDF.selectExpr("reverse(s)"),
Seq(Row("krapS"))
)
checkAnswer(
oneRowDF.select(reverse('i)),
Seq(Row("5123"))
)
checkAnswer(
oneRowDF.selectExpr("reverse(i)"),
Seq(Row("5123"))
)
checkAnswer(
oneRowDF.selectExpr("reverse(null)"),
Seq(Row(null))
)

// Array test cases (primitive-type elements)
val idf = Seq(
Seq(1, 9, 8, 7),
Seq(5, 8, 9, 7, 2),
Seq.empty,
null
).toDF("i")

checkAnswer(
idf.select(reverse('i)),
Seq(Row(Seq(7, 8, 9, 1)), Row(Seq(2, 7, 9, 8, 5)), Row(Seq.empty), Row(null))
)
checkAnswer(
idf.filter(dummyFilter('i)).select(reverse('i)),
Seq(Row(Seq(7, 8, 9, 1)), Row(Seq(2, 7, 9, 8, 5)), Row(Seq.empty), Row(null))
)
checkAnswer(
idf.selectExpr("reverse(i)"),
Seq(Row(Seq(7, 8, 9, 1)), Row(Seq(2, 7, 9, 8, 5)), Row(Seq.empty), Row(null))
)
checkAnswer(
oneRowDF.selectExpr("reverse(array(1, null, 2, null))"),
Seq(Row(Seq(null, 2, null, 1)))
)
checkAnswer(
oneRowDF.filter(dummyFilter('i)).selectExpr("reverse(array(1, null, 2, null))"),
Seq(Row(Seq(null, 2, null, 1)))
)

// Array test cases (non-primitive-type elements)
val sdf = Seq(
Seq("c", "a", "b"),
Seq("b", null, "c", null),
Seq.empty,
null
).toDF("s")

checkAnswer(
sdf.select(reverse('s)),
Seq(Row(Seq("b", "a", "c")), Row(Seq(null, "c", null, "b")), Row(Seq.empty), Row(null))
)
checkAnswer(
sdf.filter(dummyFilter('s)).select(reverse('s)),
Seq(Row(Seq("b", "a", "c")), Row(Seq(null, "c", null, "b")), Row(Seq.empty), Row(null))
)
checkAnswer(
sdf.selectExpr("reverse(s)"),
Seq(Row(Seq("b", "a", "c")), Row(Seq(null, "c", null, "b")), Row(Seq.empty), Row(null))
)
checkAnswer(
oneRowDF.selectExpr("reverse(array(array(1, 2), array(3, 4)))"),
Seq(Row(Seq(Seq(3, 4), Seq(1, 2))))
)
checkAnswer(
oneRowDF.filter(dummyFilter('s)).selectExpr("reverse(array(array(1, 2), array(3, 4)))"),
Seq(Row(Seq(Seq(3, 4), Seq(1, 2))))
)

// Error test cases
intercept[AnalysisException] {
oneRowDF.selectExpr("reverse(struct(1, 'a'))")
}
intercept[AnalysisException] {
oneRowDF.selectExpr("reverse(map(1, 'a'))")
}
}

private def assertValuesDoNotChangeAfterCoalesceOrUnion(v: Column): Unit = {
import DataFrameFunctionsSuite.CodegenFallbackExpr
for ((codegenFallback, wholeStage) <- Seq((true, false), (false, false), (false, true))) {
Expand Down

0 comments on commit f81fa47

Please sign in to comment.