Skip to content

Commit

Permalink
[SPARK-22825][SQL] Fix incorrect results of Casting Array to String
Browse files Browse the repository at this point in the history
## What changes were proposed in this pull request?
This pr fixed the issue when casting arrays into strings;
```
scala> val df = spark.range(10).select('id.cast("integer")).agg(collect_list('id).as('ids))
scala> df.write.saveAsTable("t")
scala> sql("SELECT cast(ids as String) FROM t").show(false)
+------------------------------------------------------------------+
|ids                                                               |
+------------------------------------------------------------------+
|org.apache.spark.sql.catalyst.expressions.UnsafeArrayData8bc285df|
+------------------------------------------------------------------+
```

This pr modified the result into;
```
+------------------------------+
|ids                           |
+------------------------------+
|[0, 1, 2, 3, 4, 5, 6, 7, 8, 9]|
+------------------------------+
```

## How was this patch tested?
Added tests in `CastSuite` and `SQLQuerySuite`.

Author: Takeshi Yamamuro <[email protected]>

Closes #20024 from maropu/SPARK-22825.

(cherry picked from commit 52fc5c1)
Signed-off-by: Wenchen Fan <[email protected]>
  • Loading branch information
maropu authored and cloud-fan committed Jan 5, 2018
1 parent 158f7e6 commit 145820b
Show file tree
Hide file tree
Showing 4 changed files with 171 additions and 2 deletions.
Original file line number Diff line number Diff line change
@@ -0,0 +1,78 @@
/*
* Licensed to the Apache Software Foundation (ASF) under one or more
* contributor license agreements. See the NOTICE file distributed with
* this work for additional information regarding copyright ownership.
* The ASF licenses this file to You under the Apache License, Version 2.0
* (the "License"); you may not use this file except in compliance with
* the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/

package org.apache.spark.sql.catalyst.expressions.codegen;

import org.apache.spark.unsafe.Platform;
import org.apache.spark.unsafe.array.ByteArrayMethods;
import org.apache.spark.unsafe.types.UTF8String;

/**
* A helper class to write {@link UTF8String}s to an internal buffer and build the concatenated
* {@link UTF8String} at the end.
*/
public class UTF8StringBuilder {

private static final int ARRAY_MAX = ByteArrayMethods.MAX_ROUNDED_ARRAY_LENGTH;

private byte[] buffer;
private int cursor = Platform.BYTE_ARRAY_OFFSET;

public UTF8StringBuilder() {
// Since initial buffer size is 16 in `StringBuilder`, we set the same size here
this.buffer = new byte[16];
}

// Grows the buffer by at least `neededSize`
private void grow(int neededSize) {
if (neededSize > ARRAY_MAX - totalSize()) {
throw new UnsupportedOperationException(
"Cannot grow internal buffer by size " + neededSize + " because the size after growing " +
"exceeds size limitation " + ARRAY_MAX);
}
final int length = totalSize() + neededSize;
if (buffer.length < length) {
int newLength = length < ARRAY_MAX / 2 ? length * 2 : ARRAY_MAX;
final byte[] tmp = new byte[newLength];
Platform.copyMemory(
buffer,
Platform.BYTE_ARRAY_OFFSET,
tmp,
Platform.BYTE_ARRAY_OFFSET,
totalSize());
buffer = tmp;
}
}

private int totalSize() {
return cursor - Platform.BYTE_ARRAY_OFFSET;
}

public void append(UTF8String value) {
grow(value.numBytes());
value.writeToMemory(buffer, cursor);
cursor += value.numBytes();
}

public void append(String value) {
append(UTF8String.fromString(value));
}

public UTF8String build() {
return UTF8String.fromBytes(buffer, 0, totalSize());
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -206,6 +206,28 @@ case class Cast(child: Expression, dataType: DataType, timeZoneId: Option[String
case DateType => buildCast[Int](_, d => UTF8String.fromString(DateTimeUtils.dateToString(d)))
case TimestampType => buildCast[Long](_,
t => UTF8String.fromString(DateTimeUtils.timestampToString(t, timeZone)))
case ArrayType(et, _) =>
buildCast[ArrayData](_, array => {
val builder = new UTF8StringBuilder
builder.append("[")
if (array.numElements > 0) {
val toUTF8String = castToString(et)
if (!array.isNullAt(0)) {
builder.append(toUTF8String(array.get(0, et)).asInstanceOf[UTF8String])
}
var i = 1
while (i < array.numElements) {
builder.append(",")
if (!array.isNullAt(i)) {
builder.append(" ")
builder.append(toUTF8String(array.get(i, et)).asInstanceOf[UTF8String])
}
i += 1
}
}
builder.append("]")
builder.build()
})
case _ => buildCast[Any](_, o => UTF8String.fromString(o.toString))
}

Expand Down Expand Up @@ -597,6 +619,41 @@ case class Cast(child: Expression, dataType: DataType, timeZoneId: Option[String
"""
}

private def writeArrayToStringBuilder(
et: DataType,
array: String,
buffer: String,
ctx: CodegenContext): String = {
val elementToStringCode = castToStringCode(et, ctx)
val funcName = ctx.freshName("elementToString")
val elementToStringFunc = ctx.addNewFunction(funcName,
s"""
|private UTF8String $funcName(${ctx.javaType(et)} element) {
| UTF8String elementStr = null;
| ${elementToStringCode("element", "elementStr", null /* resultIsNull won't be used */)}
| return elementStr;
|}
""".stripMargin)

val loopIndex = ctx.freshName("loopIndex")
s"""
|$buffer.append("[");
|if ($array.numElements() > 0) {
| if (!$array.isNullAt(0)) {
| $buffer.append($elementToStringFunc(${ctx.getValue(array, et, "0")}));
| }
| for (int $loopIndex = 1; $loopIndex < $array.numElements(); $loopIndex++) {
| $buffer.append(",");
| if (!$array.isNullAt($loopIndex)) {
| $buffer.append(" ");
| $buffer.append($elementToStringFunc(${ctx.getValue(array, et, loopIndex)}));
| }
| }
|}
|$buffer.append("]");
""".stripMargin
}

private[this] def castToStringCode(from: DataType, ctx: CodegenContext): CastFunction = {
from match {
case BinaryType =>
Expand All @@ -608,6 +665,17 @@ case class Cast(child: Expression, dataType: DataType, timeZoneId: Option[String
val tz = ctx.addReferenceObj("timeZone", timeZone)
(c, evPrim, evNull) => s"""$evPrim = UTF8String.fromString(
org.apache.spark.sql.catalyst.util.DateTimeUtils.timestampToString($c, $tz));"""
case ArrayType(et, _) =>
(c, evPrim, evNull) => {
val buffer = ctx.freshName("buffer")
val bufferClass = classOf[UTF8StringBuilder].getName
val writeArrayElemCode = writeArrayToStringBuilder(et, c, buffer, ctx)
s"""
|$bufferClass $buffer = new $bufferClass();
|$writeArrayElemCode;
|$evPrim = $buffer.build();
""".stripMargin
}
case _ =>
(c, evPrim, evNull) => s"$evPrim = UTF8String.fromString(String.valueOf($c));"
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -853,4 +853,29 @@ class CastSuite extends SparkFunSuite with ExpressionEvalHelper {
cast("2", LongType).genCode(ctx)
assert(ctx.inlinedMutableStates.length == 0)
}

test("SPARK-22825 Cast array to string") {
val ret1 = cast(Literal.create(Array(1, 2, 3, 4, 5)), StringType)
checkEvaluation(ret1, "[1, 2, 3, 4, 5]")
val ret2 = cast(Literal.create(Array("ab", "cde", "f")), StringType)
checkEvaluation(ret2, "[ab, cde, f]")
val ret3 = cast(Literal.create(Array("ab", null, "c")), StringType)
checkEvaluation(ret3, "[ab,, c]")
val ret4 = cast(Literal.create(Array("ab".getBytes, "cde".getBytes, "f".getBytes)), StringType)
checkEvaluation(ret4, "[ab, cde, f]")
val ret5 = cast(
Literal.create(Array("2014-12-03", "2014-12-04", "2014-12-06").map(Date.valueOf)),
StringType)
checkEvaluation(ret5, "[2014-12-03, 2014-12-04, 2014-12-06]")
val ret6 = cast(
Literal.create(Array("2014-12-03 13:01:00", "2014-12-04 15:05:00").map(Timestamp.valueOf)),
StringType)
checkEvaluation(ret6, "[2014-12-03 13:01:00, 2014-12-04 15:05:00]")
val ret7 = cast(Literal.create(Array(Array(1, 2, 3), Array(4, 5))), StringType)
checkEvaluation(ret7, "[[1, 2, 3], [4, 5]]")
val ret8 = cast(
Literal.create(Array(Array(Array("a"), Array("b", "c")), Array(Array("d")))),
StringType)
checkEvaluation(ret8, "[[[a], [b, c]], [[d]]]")
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -28,8 +28,6 @@ import org.apache.spark.scheduler.{SparkListener, SparkListenerJobStart}
import org.apache.spark.sql.catalyst.util.StringUtils
import org.apache.spark.sql.execution.aggregate
import org.apache.spark.sql.execution.aggregate.{HashAggregateExec, SortAggregateExec}
import org.apache.spark.sql.execution.datasources.{HadoopFsRelation, LogicalRelation}
import org.apache.spark.sql.execution.datasources.orc.OrcFileFormat
import org.apache.spark.sql.execution.joins.{BroadcastHashJoinExec, CartesianProductExec, SortMergeJoinExec}
import org.apache.spark.sql.functions._
import org.apache.spark.sql.internal.SQLConf
Expand Down

0 comments on commit 145820b

Please sign in to comment.