-
Notifications
You must be signed in to change notification settings - Fork 831
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Use unsafe to speed up string marshaling #6433
Changes from 5 commits
995adf2
1edd81a
53d9447
bfa30a7
eb1cc64
43c82d6
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,6 @@ | ||
plugins { | ||
id("otel.java-conventions") | ||
} | ||
|
||
description = "OpenTelemetry Exporter Compile Stub" | ||
otelJava.moduleName.set("io.opentelemetry.exporter.internal.compile-stub") |
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,35 @@ | ||
/* | ||
* Copyright The OpenTelemetry Authors | ||
* SPDX-License-Identifier: Apache-2.0 | ||
*/ | ||
|
||
package sun.misc; | ||
|
||
import java.lang.reflect.Field; | ||
|
||
/** | ||
* sun.misc.Unsafe from the JDK isn't found by the compiler, we provide our own trimmed down version | ||
* that we can compile against. | ||
*/ | ||
public class Unsafe { | ||
|
||
public long objectFieldOffset(Field f) { | ||
return -1; | ||
} | ||
|
||
public Object getObject(Object o, long offset) { | ||
return null; | ||
} | ||
|
||
public byte getByte(Object o, long offset) { | ||
return 0; | ||
} | ||
|
||
public int arrayBaseOffset(Class<?> arrayClass) { | ||
return 0; | ||
} | ||
|
||
public long getLong(Object o, long offset) { | ||
return 0; | ||
} | ||
} |
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -299,16 +299,63 @@ public static <K, V> int sizeMessageWithContext( | |
} | ||
|
||
/** Returns the size of utf8 encoded string in bytes. */ | ||
@SuppressWarnings("UnusedVariable") | ||
private static int getUtf8Size(String string, MarshalerContext context) { | ||
return getUtf8Size(string); | ||
return getUtf8Size(string, context.marshalStringUnsafe()); | ||
} | ||
|
||
// Visible for testing | ||
static int getUtf8Size(String string) { | ||
static int getUtf8Size(String string, boolean useUnsafe) { | ||
if (useUnsafe && UnsafeString.isAvailable() && UnsafeString.isLatin1(string)) { | ||
byte[] bytes = UnsafeString.getBytes(string); | ||
// latin1 bytes with negative value (most significant bit set) are encoded as 2 bytes in utf8 | ||
return string.length() + countNegative(bytes); | ||
} | ||
|
||
return encodedUtf8Length(string); | ||
} | ||
|
||
// Inner loop can process at most 8 * 255 bytes without overflowing counter. To process more bytes | ||
// inner loop has to be run multiple times. | ||
private static final int MAX_INNER_LOOP_SIZE = 8 * 255; | ||
// mask that selects only the most significant bit in every byte of the long | ||
private static final long MOST_SIGNIFICANT_BIT_MASK = 0x8080808080808080L; | ||
|
||
/** Returns the count of bytes with negative value. */ | ||
private static int countNegative(byte[] bytes) { | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Took me a bit to parse what's going on here but I think I get it:
Does this sound like a correct characterization @laurit? There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more.
yes
when we read the byte array we don't know whether it is going to be 7bit or not
Typically such loops are sped up by using vector instructions that allow processing multiple bytes simultaneously. Processing bytes one long at a time is similar to using vector instructions. It would be interesting to see what effect there would be when this method is written using the vector api preview feature in recent jdks.
That part doesn't use unsafe, just plain There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more.
Fascinating! |
||
int count = 0; | ||
int offset = 0; | ||
// We are processing one long (8 bytes) at a time. In the inner loop we are keeping counts in a | ||
// long where each byte in the long is a separate counter. Due to this the inner loop can | ||
// process a maximum of 8*255 bytes at a time without overflow. | ||
for (int i = 1; i <= bytes.length / MAX_INNER_LOOP_SIZE + 1; i++) { | ||
long tmp = 0; // each byte in this long is a separate counter | ||
int limit = Math.min(i * MAX_INNER_LOOP_SIZE, bytes.length & ~7); | ||
for (; offset < limit; offset += 8) { | ||
long value = UnsafeString.getLong(bytes, offset); | ||
// Mask the value keeping only the most significant bit in each byte and then shift this bit | ||
// to the position of the least significant bit in each byte. If the input byte was not | ||
// negative then after this transformation it will be zero, if it was negative then it will | ||
// be one. | ||
tmp += (value & MOST_SIGNIFICANT_BIT_MASK) >>> 7; | ||
} | ||
// sum up counts | ||
if (tmp != 0) { | ||
for (int j = 0; j < 8; j++) { | ||
count += (int) (tmp & 0xff); | ||
tmp = tmp >>> 8; | ||
} | ||
} | ||
} | ||
|
||
// Handle remaining bytes. Previous loop processes 8 bytes a time, if the input size is not | ||
// divisible with 8 the remaining bytes are handled here. | ||
for (int i = offset; i < bytes.length; i++) { | ||
// same as if (bytes[i] < 0) count++; | ||
count += bytes[i] >>> 31; | ||
} | ||
return count; | ||
} | ||
|
||
// adapted from | ||
// https://github.com/protocolbuffers/protobuf/blob/b618f6750aed641a23d5f26fbbaf654668846d24/java/core/src/main/java/com/google/protobuf/Utf8.java#L217 | ||
private static int encodedUtf8Length(String string) { | ||
|
@@ -376,14 +423,24 @@ private static int encodedUtf8LengthGeneral(String string, int start) { | |
static void writeUtf8( | ||
CodedOutputStream output, String string, int utf8Length, MarshalerContext context) | ||
throws IOException { | ||
writeUtf8(output, string, utf8Length); | ||
writeUtf8(output, string, utf8Length, context.marshalStringUnsafe()); | ||
} | ||
|
||
// Visible for testing | ||
@SuppressWarnings("UnusedVariable") // utf8Length argument is added for future use | ||
static void writeUtf8(CodedOutputStream output, String string, int utf8Length) | ||
static void writeUtf8(CodedOutputStream output, String string, int utf8Length, boolean useUnsafe) | ||
throws IOException { | ||
encodeUtf8(output, string); | ||
// if the length of the latin1 string and the utf8 output are the same then the string must be | ||
// composed of only 7bit characters and can be directly copied to the output | ||
if (useUnsafe | ||
&& UnsafeString.isAvailable() | ||
&& string.length() == utf8Length | ||
&& UnsafeString.isLatin1(string)) { | ||
byte[] bytes = UnsafeString.getBytes(string); | ||
output.write(bytes, 0, bytes.length); | ||
} else { | ||
encodeUtf8(output, string); | ||
} | ||
} | ||
|
||
// encode utf8 the same way as length is computed in encodedUtf8Length | ||
|
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,69 @@ | ||
/* | ||
* Copyright The OpenTelemetry Authors | ||
* SPDX-License-Identifier: Apache-2.0 | ||
*/ | ||
|
||
package io.opentelemetry.exporter.internal.marshal; | ||
|
||
import java.lang.reflect.Field; | ||
import sun.misc.Unsafe; | ||
|
||
class UnsafeAccess { | ||
private static final boolean available = checkUnsafe(); | ||
|
||
static boolean isAvailable() { | ||
return available; | ||
} | ||
|
||
private static boolean checkUnsafe() { | ||
try { | ||
jack-berg marked this conversation as resolved.
Show resolved
Hide resolved
|
||
Class.forName("sun.misc.Unsafe", false, UnsafeAccess.class.getClassLoader()); | ||
return UnsafeHolder.UNSAFE != null; | ||
} catch (ClassNotFoundException e) { | ||
return false; | ||
Check warning on line 23 in exporters/common/src/main/java/io/opentelemetry/exporter/internal/marshal/UnsafeAccess.java Codecov / codecov/patchexporters/common/src/main/java/io/opentelemetry/exporter/internal/marshal/UnsafeAccess.java#L22-L23
|
||
} | ||
} | ||
|
||
static long objectFieldOffset(Field field) { | ||
return UnsafeHolder.UNSAFE.objectFieldOffset(field); | ||
} | ||
|
||
static Object getObject(Object object, long offset) { | ||
return UnsafeHolder.UNSAFE.getObject(object, offset); | ||
} | ||
|
||
static byte getByte(Object object, long offset) { | ||
return UnsafeHolder.UNSAFE.getByte(object, offset); | ||
} | ||
|
||
static int arrayBaseOffset(Class<?> arrayClass) { | ||
return UnsafeHolder.UNSAFE.arrayBaseOffset(arrayClass); | ||
} | ||
|
||
static long getLong(Object o, long offset) { | ||
return UnsafeHolder.UNSAFE.getLong(o, offset); | ||
} | ||
|
||
private UnsafeAccess() {} | ||
|
||
private static class UnsafeHolder { | ||
public static final Unsafe UNSAFE; | ||
|
||
static { | ||
UNSAFE = getUnsafe(); | ||
} | ||
|
||
private UnsafeHolder() {} | ||
|
||
@SuppressWarnings("NullAway") | ||
private static Unsafe getUnsafe() { | ||
try { | ||
Field field = Unsafe.class.getDeclaredField("theUnsafe"); | ||
field.setAccessible(true); | ||
return (Unsafe) field.get(null); | ||
} catch (Exception ignored) { | ||
return null; | ||
Check warning on line 65 in exporters/common/src/main/java/io/opentelemetry/exporter/internal/marshal/UnsafeAccess.java Codecov / codecov/patchexporters/common/src/main/java/io/opentelemetry/exporter/internal/marshal/UnsafeAccess.java#L64-L65
|
||
} | ||
} | ||
} | ||
} |
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,50 @@ | ||
/* | ||
* Copyright The OpenTelemetry Authors | ||
* SPDX-License-Identifier: Apache-2.0 | ||
*/ | ||
|
||
package io.opentelemetry.exporter.internal.marshal; | ||
|
||
import java.lang.reflect.Field; | ||
|
||
class UnsafeString { | ||
private static final long valueOffset = getStringFieldOffset("value", byte[].class); | ||
private static final long coderOffset = getStringFieldOffset("coder", byte.class); | ||
private static final int byteArrayBaseOffset = UnsafeAccess.arrayBaseOffset(byte[].class); | ||
private static final boolean available = valueOffset != -1 && coderOffset != -1; | ||
|
||
static boolean isAvailable() { | ||
return available; | ||
} | ||
|
||
static boolean isLatin1(String string) { | ||
// 0 represents latin1, 1 utf16 | ||
return UnsafeAccess.getByte(string, coderOffset) == 0; | ||
} | ||
|
||
static byte[] getBytes(String string) { | ||
return (byte[]) UnsafeAccess.getObject(string, valueOffset); | ||
} | ||
|
||
static long getLong(byte[] bytes, int index) { | ||
return UnsafeAccess.getLong(bytes, byteArrayBaseOffset + index); | ||
} | ||
|
||
private static long getStringFieldOffset(String fieldName, Class<?> expectedType) { | ||
if (!UnsafeAccess.isAvailable()) { | ||
return -1; | ||
Check warning on line 35 in exporters/common/src/main/java/io/opentelemetry/exporter/internal/marshal/UnsafeString.java Codecov / codecov/patchexporters/common/src/main/java/io/opentelemetry/exporter/internal/marshal/UnsafeString.java#L35
|
||
} | ||
|
||
try { | ||
Field field = String.class.getDeclaredField(fieldName); | ||
if (field.getType() != expectedType) { | ||
return -1; | ||
Check warning on line 41 in exporters/common/src/main/java/io/opentelemetry/exporter/internal/marshal/UnsafeString.java Codecov / codecov/patchexporters/common/src/main/java/io/opentelemetry/exporter/internal/marshal/UnsafeString.java#L41
|
||
} | ||
return UnsafeAccess.objectFieldOffset(field); | ||
} catch (Exception exception) { | ||
return -1; | ||
Check warning on line 45 in exporters/common/src/main/java/io/opentelemetry/exporter/internal/marshal/UnsafeString.java Codecov / codecov/patchexporters/common/src/main/java/io/opentelemetry/exporter/internal/marshal/UnsafeString.java#L44-L45
|
||
} | ||
} | ||
|
||
private UnsafeString() {} | ||
} |
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,47 @@ | ||
/* | ||
* Copyright The OpenTelemetry Authors | ||
* SPDX-License-Identifier: Apache-2.0 | ||
*/ | ||
|
||
package io.opentelemetry.exporter.internal.marshal; | ||
|
||
import static io.opentelemetry.exporter.internal.marshal.StatelessMarshalerUtil.getUtf8Size; | ||
import static io.opentelemetry.exporter.internal.marshal.StatelessMarshalerUtilTest.testUtf8; | ||
import static org.assertj.core.api.Assertions.assertThat; | ||
|
||
import edu.berkeley.cs.jqf.fuzz.Fuzz; | ||
import edu.berkeley.cs.jqf.fuzz.JQF; | ||
import edu.berkeley.cs.jqf.fuzz.junit.GuidedFuzzing; | ||
import edu.berkeley.cs.jqf.fuzz.random.NoGuidance; | ||
import java.nio.charset.StandardCharsets; | ||
import org.junit.jupiter.api.Test; | ||
import org.junit.runner.Result; | ||
import org.junit.runner.RunWith; | ||
|
||
@SuppressWarnings("SystemOut") | ||
class StatelessMarshalerUtilFuzzTest { | ||
|
||
@RunWith(JQF.class) | ||
public static class EncodeUf8 { | ||
|
||
@Fuzz | ||
public void encodeRandomString(String value) { | ||
int utf8Size = value.getBytes(StandardCharsets.UTF_8).length; | ||
assertThat(getUtf8Size(value, false)).isEqualTo(utf8Size); | ||
assertThat(getUtf8Size(value, true)).isEqualTo(utf8Size); | ||
assertThat(testUtf8(value, utf8Size, /* useUnsafe= */ false)).isEqualTo(value); | ||
assertThat(testUtf8(value, utf8Size, /* useUnsafe= */ true)).isEqualTo(value); | ||
} | ||
} | ||
|
||
// driver methods to avoid having to use the vintage junit engine, and to enable increasing the | ||
// number of iterations: | ||
|
||
@Test | ||
void encodeUf8WithFuzzing() { | ||
Result result = | ||
GuidedFuzzing.run( | ||
EncodeUf8.class, "encodeRandomString", new NoGuidance(10000, System.out), System.out); | ||
assertThat(result.wasSuccessful()).isTrue(); | ||
} | ||
} |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Wondering if you can explain why a separate module is needed for this, even though it's not published.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
When I try to compile a class that uses
Unsafe
with javac from jdk 17 targeting jdk 8 it fails with a compilation error. Curiously, when targeting jdk 11 it gets a warning instead.
Not using
--release 8
also results in a warning. Creating a copy of
Unsafe
just for compiling felt like the easiest way to make this work. There is a comment in our version of Unsafe
I could add a similar comment here also. Is this what you were looking for?
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Yeah that would be helpful.