Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[SPARK-8301][SQL] Improve UTF8String substring/startsWith/endsWith/contains performance #6804

Closed
wants to merge 11 commits into from
Closed
Original file line number Diff line number Diff line change
Expand Up @@ -310,7 +310,6 @@ public double getDouble(int i) {

public UTF8String getUTF8String(int i) {
assertIndexIsValid(i);
final UTF8String str = new UTF8String();
final long offsetToStringSize = getLong(i);
final int stringSizeInBytes =
(int) PlatformDependent.UNSAFE.getLong(baseObject, baseOffset + offsetToStringSize);
Expand All @@ -322,8 +321,7 @@ public UTF8String getUTF8String(int i) {
PlatformDependent.BYTE_ARRAY_OFFSET,
stringSizeInBytes
);
str.set(strBytes);
return str;
return UTF8String.fromBytes(strBytes);
}

@Override
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -437,17 +437,17 @@ case class Cast(child: Expression, dataType: DataType) extends UnaryExpression w

case (BinaryType, StringType) =>
defineCodeGen (ctx, ev, c =>
s"new ${ctx.stringType}().set($c)")
s"${ctx.stringType}.fromBytes($c)")
case (DateType, StringType) =>
defineCodeGen(ctx, ev, c =>
s"""new ${ctx.stringType}().set(
s"""${ctx.stringType}.fromString(
org.apache.spark.sql.catalyst.util.DateUtils.toString($c))""")
// Special handling required for timestamps in hive test cases since the toString function
// does not match the expected output.
case (TimestampType, StringType) =>
super.genCode(ctx, ev)
case (_, StringType) =>
defineCodeGen(ctx, ev, c => s"new ${ctx.stringType}().set(String.valueOf($c))")
defineCodeGen(ctx, ev, c => s"${ctx.stringType}.fromString(String.valueOf($c))")

// fallback for DecimalType, this must be before other numeric types
case (_, dt: DecimalType) =>
Expand Down
30 changes: 18 additions & 12 deletions unsafe/src/main/java/org/apache/spark/unsafe/types/UTF8String.java
Original file line number Diff line number Diff line change
Expand Up @@ -20,7 +20,7 @@
import java.io.Serializable;
import java.io.UnsupportedEncodingException;
import java.util.Arrays;
import javax.annotation.Nullable;
import javax.annotation.Nonnull;

import org.apache.spark.unsafe.PlatformDependent;

Expand All @@ -34,7 +34,7 @@
*/
public final class UTF8String implements Comparable<UTF8String>, Serializable {

@Nullable
@Nonnull
private byte[] bytes;

private static int[] bytesOfCodePointInUTF8 = {2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2,
Expand All @@ -55,7 +55,7 @@ public static UTF8String fromString(String str) {
/**
* Updates the UTF8String with String.
*/
public UTF8String set(final String str) {
protected UTF8String set(final String str) {
try {
bytes = str.getBytes("utf-8");
} catch (UnsupportedEncodingException e) {
Expand All @@ -69,7 +69,7 @@ public UTF8String set(final String str) {
/**
* Updates the UTF8String with byte[], which should be encoded in UTF-8.
*/
public UTF8String set(final byte[] bytes) {
protected UTF8String set(final byte[] bytes) {
this.bytes = bytes;
return this;
}
Expand Down Expand Up @@ -131,24 +131,30 @@ public boolean contains(final UTF8String substring) {
}

for (int i = 0; i <= bytes.length - b.length; i++) {
// TODO: Avoid copying.
if (bytes[i] == b[0] && Arrays.equals(Arrays.copyOfRange(bytes, i, i + b.length), b)) {
if (bytes[i] == b[0] && startsWith(b, i)) {
return true;
}
}
return false;
}

private boolean startsWith(final byte[] prefix, int offsetInBytes) {
if (prefix.length + offsetInBytes > bytes.length || offsetInBytes < 0) {
return false;
}
int i = 0;
while (i < prefix.length && prefix[i] == bytes[i + offsetInBytes]) {
i++;
}
return i == prefix.length;
}

public boolean startsWith(final UTF8String prefix) {
final byte[] b = prefix.getBytes();
// TODO: Avoid copying.
return b.length <= bytes.length && Arrays.equals(Arrays.copyOfRange(bytes, 0, b.length), b);
return startsWith(prefix.getBytes(), 0);
}

public boolean endsWith(final UTF8String suffix) {
final byte[] b = suffix.getBytes();
return b.length <= bytes.length &&
Arrays.equals(Arrays.copyOfRange(bytes, bytes.length - b.length, bytes.length), b);
return startsWith(suffix.getBytes(), bytes.length - suffix.getBytes().length);
}

public UTF8String toUpperCase() {
Expand Down