Use unsafe to speed up string marshaling #6433

Merged
merged 6 commits into from
Jun 3, 2024
Changes from 5 commits
2 changes: 2 additions & 0 deletions exporters/common/build.gradle.kts
@@ -14,6 +14,7 @@ dependencies {
api(project(":sdk-extensions:autoconfigure-spi"))

compileOnly(project(":sdk:common"))
compileOnly(project(":exporters:common:compile-stub"))
Member

Wondering if you can explain why a separate module is needed for this, even though it's not published.

Contributor Author

When I try to compile a class that uses Unsafe with javac from JDK 17 targeting JDK 8, it fails with:

javac --release 8 Test.java
Test.java:3: error: package sun.misc does not exist
		System.err.println(sun.misc.Unsafe.class.getName());

Curiously, when targeting JDK 11 it produces a warning instead:

javac --release 11 Test.java
Test.java:3: warning: Unsafe is internal proprietary API and may be removed in a future release
		System.err.println(sun.misc.Unsafe.class.getName());

Using --source 8 --target 8 instead of --release 8 also results in a warning:

javac --source 8 --target 8 Test.java
warning: [options] bootstrap class path not set in conjunction with -source 8
Test.java:3: warning: Unsafe is internal proprietary API and may be removed in a future release
		System.err.println(sun.misc.Unsafe.class.getName());

Creating a copy of Unsafe just for compiling felt like the easiest way to make this work. There is a comment in our version of Unsafe:

/**
 * sun.misc.Unsafe from the JDK isn't found by the compiler, we provide our own trimmed down version
 * that we can compile against.
 */

I could add a similar comment here as well. Is this what you were looking for?

Member

Yeah that would be helpful.
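
The javac output quoted in this thread suggests the problem is purely a compile-time one: with `--release 8` the compiler does not see `sun.misc.Unsafe`, while the class is typically still loadable at runtime. A minimal sketch of a runtime probe that needs no compile-time reference to `sun.misc` (the class name `UnsafeRuntimeProbe` is hypothetical and not part of this PR):

```java
final class UnsafeRuntimeProbe {

  /** Returns true if sun.misc.Unsafe can be loaded by this class's class loader. */
  static boolean isUnsafePresent() {
    try {
      // Load without initializing; only presence of the class is being checked.
      Class.forName("sun.misc.Unsafe", false, UnsafeRuntimeProbe.class.getClassLoader());
      return true;
    } catch (ClassNotFoundException e) {
      return false;
    }
  }

  public static void main(String[] args) {
    System.out.println("sun.misc.Unsafe present at runtime: " + isUnsafePresent());
  }
}
```

Because the class is referenced only by name, this compiles under any `--release` setting. The compile stub in this PR addresses the other half: letting code that calls Unsafe methods directly compile at all.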


compileOnly("org.codehaus.mojo:animal-sniffer-annotations")

@@ -31,6 +32,7 @@ dependencies {
testImplementation("org.skyscreamer:jsonassert")
testImplementation("com.google.api.grpc:proto-google-common-protos")
testImplementation("io.grpc:grpc-testing")
testImplementation("edu.berkeley.cs.jqf:jqf-fuzz")
testRuntimeOnly("io.grpc:grpc-netty-shaded")
}

6 changes: 6 additions & 0 deletions exporters/common/compile-stub/build.gradle.kts
@@ -0,0 +1,6 @@
plugins {
id("otel.java-conventions")
}

description = "OpenTelemetry Exporter Compile Stub"
otelJava.moduleName.set("io.opentelemetry.exporter.internal.compile-stub")
@@ -0,0 +1,35 @@
/*
* Copyright The OpenTelemetry Authors
* SPDX-License-Identifier: Apache-2.0
*/

package sun.misc;

import java.lang.reflect.Field;

/**
* sun.misc.Unsafe from the JDK isn't found by the compiler, we provide our own trimmed down version
* that we can compile against.
*/
public class Unsafe {

public long objectFieldOffset(Field f) {
return -1;
}

public Object getObject(Object o, long offset) {
return null;
}

public byte getByte(Object o, long offset) {
return 0;
}

public int arrayBaseOffset(Class<?> arrayClass) {
return 0;
}

public long getLong(Object o, long offset) {
return 0;
}
}
@@ -24,6 +24,7 @@
*/
public final class MarshalerContext {
private final boolean marshalStringNoAllocation;
private final boolean marshalStringUnsafe;

private int[] sizes = new int[16];
private int sizeReadIndex;
@@ -32,19 +33,23 @@ public final class MarshalerContext {
private int dataReadIndex;
private int dataWriteIndex;

@SuppressWarnings("BooleanParameter")
public MarshalerContext() {
- this(true);
+ this(/* marshalStringNoAllocation= */ true, /* marshalStringUnsafe= */ true);
}

- public MarshalerContext(boolean marshalStringNoAllocation) {
+ public MarshalerContext(boolean marshalStringNoAllocation, boolean marshalStringUnsafe) {
this.marshalStringNoAllocation = marshalStringNoAllocation;
this.marshalStringUnsafe = marshalStringUnsafe;
}

public boolean marshalStringNoAllocation() {
return marshalStringNoAllocation;
}

public boolean marshalStringUnsafe() {
return marshalStringUnsafe;
}

public void addSize(int size) {
growSizeIfNeeded();
sizes[sizeWriteIndex++] = size;
@@ -299,16 +299,63 @@ public static <K, V> int sizeMessageWithContext(
}

/** Returns the size of utf8 encoded string in bytes. */
@SuppressWarnings("UnusedVariable")
private static int getUtf8Size(String string, MarshalerContext context) {
- return getUtf8Size(string);
+ return getUtf8Size(string, context.marshalStringUnsafe());
}

// Visible for testing
- static int getUtf8Size(String string) {
+ static int getUtf8Size(String string, boolean useUnsafe) {
if (useUnsafe && UnsafeString.isAvailable() && UnsafeString.isLatin1(string)) {
byte[] bytes = UnsafeString.getBytes(string);
// latin1 bytes with negative value (most significant bit set) are encoded as 2 bytes in utf8
return string.length() + countNegative(bytes);
}

return encodedUtf8Length(string);
}

// Inner loop can process at most 8 * 255 bytes without overflowing counter. To process more bytes
// inner loop has to be run multiple times.
private static final int MAX_INNER_LOOP_SIZE = 8 * 255;
// mask that selects only the most significant bit in every byte of the long
private static final long MOST_SIGNIFICANT_BIT_MASK = 0x8080808080808080L;

/** Returns the count of bytes with negative value. */
private static int countNegative(byte[] bytes) {
Member

Took me a bit to parse what's going on here but I think I get it:

  • We want to use unsafe to read the value of a string's internal byte[] directly when a string is composed of all latin1 characters that fit within 7 bits (first 128 code points)
  • In order to determine if a string meets this condition as fast as possible, we use unsafe and some clever (seriously clever - where did you come up with this??) shifting to read through the string in chunks of 8 bytes, counting the number of times when a byte has a 1 in the most significant bit and therefore doesn't fit in 7 bits. Reading 8 bytes at a time is faster than reading 1 byte at a time, which is what we do currently when trying to compute the size of a string with reusable_data memory mode.
  • Later, when we are serializing the string, we check if the string is latin1 and had a size which was the same as the string length. If true, we can use unsafe to write the string's internal byte[] to the output stream. If false, we fall back to the existing serialization logic.

Does this sound like a correct characterization @laurit?

Contributor Author

> Does this sound like a correct characterization @laurit?

Yes.

> We want to use unsafe to read the value of a string's internal byte[] directly when a string is composed of all latin1 characters that fit within 7 bits (first 128 code points)

When we read the byte array we don't yet know whether it is going to be 7-bit or not.

> where did you come up with this?

Typically such loops are sped up by using vector instructions that allow processing multiple bytes simultaneously. Processing bytes one long at a time is similar to using vector instructions. It would be interesting to see what effect there would be when this method is written using the Vector API incubator feature in recent JDKs.

> If true, we can use unsafe to write the string's internal byte[] to the output stream

That part doesn't use unsafe, just plain System.arraycopy.

Member

> Typically such loops are sped up by using vector instructions that allow processing multiple bytes simultaneously. Processing bytes one long at a time is similar to using vector instructions. It would be interesting to see what effect there would be when this method is written using the Vector API incubator feature in recent JDKs.

Fascinating!
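
The counting trick discussed in this thread can be tried without Unsafe. The sketch below is not the PR's implementation: it assumes `ByteBuffer.getLong` as a stand-in for `UnsafeString.getLong`, and the class and method names are hypothetical. Byte order does not matter here because only the number of set high bits is counted.

```java
import java.nio.ByteBuffer;

final class HighBitCounter {
  // Selects the most significant bit of each of the 8 bytes in a long.
  private static final long MSB_MASK = 0x8080808080808080L;
  // Each byte lane is a counter that holds at most 255, and one long adds at most 1 per lane.
  private static final int MAX_BYTES_PER_BATCH = 8 * 255;

  /** Counts bytes with the high bit set, reading 8 bytes at a time. */
  static int countNegative(byte[] bytes) {
    ByteBuffer buffer = ByteBuffer.wrap(bytes);
    int aligned = bytes.length & ~7; // largest multiple of 8 not exceeding the length
    int count = 0;
    int offset = 0;
    while (offset < aligned) {
      long lanes = 0; // eight independent one-byte counters
      int limit = Math.min(offset + MAX_BYTES_PER_BATCH, aligned);
      for (; offset < limit; offset += 8) {
        // Move each byte's high bit down to its low bit, then add lane-wise.
        lanes += (buffer.getLong(offset) & MSB_MASK) >>> 7;
      }
      // Fold the eight lane counters into the running total.
      for (int lane = 0; lane < 8; lane++) {
        count += (int) (lanes & 0xff);
        lanes >>>= 8;
      }
    }
    // Tail: fewer than 8 bytes remain.
    for (; offset < bytes.length; offset++) {
      count += bytes[offset] >>> 31; // 1 for negative bytes, 0 otherwise
    }
    return count;
  }

  /** Reference implementation, one byte at a time. */
  static int countNegativeSimple(byte[] bytes) {
    int count = 0;
    for (byte b : bytes) {
      if (b < 0) {
        count++;
      }
    }
    return count;
  }
}
```

Comparing `countNegative` against `countNegativeSimple` on arrays longer than 2040 bytes exercises the batching that keeps the lane counters from overflowing.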

int count = 0;
int offset = 0;
// We are processing one long (8 bytes) at a time. In the inner loop we are keeping counts in a
// long where each byte in the long is a separate counter. Due to this the inner loop can
// process a maximum of 8*255 bytes at a time without overflow.
for (int i = 1; i <= bytes.length / MAX_INNER_LOOP_SIZE + 1; i++) {
long tmp = 0; // each byte in this long is a separate counter
int limit = Math.min(i * MAX_INNER_LOOP_SIZE, bytes.length & ~7);
for (; offset < limit; offset += 8) {
long value = UnsafeString.getLong(bytes, offset);
// Mask the value keeping only the most significant bit in each byte and then shift this bit
// to the position of the least significant bit in each byte. If the input byte was not
// negative then after this transformation it will be zero, if it was negative then it will
// be one.
tmp += (value & MOST_SIGNIFICANT_BIT_MASK) >>> 7;
}
// sum up counts
if (tmp != 0) {
for (int j = 0; j < 8; j++) {
count += (int) (tmp & 0xff);
tmp = tmp >>> 8;
}
}
}

// Handle remaining bytes. The previous loop processes 8 bytes at a time; if the input size is not
// divisible by 8, the remaining bytes are handled here.
for (int i = offset; i < bytes.length; i++) {
// same as if (bytes[i] < 0) count++;
count += bytes[i] >>> 31;
}
return count;
}

// adapted from
// https://github.com/protocolbuffers/protobuf/blob/b618f6750aed641a23d5f26fbbaf654668846d24/java/core/src/main/java/com/google/protobuf/Utf8.java#L217
private static int encodedUtf8Length(String string) {
@@ -376,14 +423,24 @@ private static int encodedUtf8LengthGeneral(String string, int start) {
static void writeUtf8(
CodedOutputStream output, String string, int utf8Length, MarshalerContext context)
throws IOException {
- writeUtf8(output, string, utf8Length);
+ writeUtf8(output, string, utf8Length, context.marshalStringUnsafe());
}

// Visible for testing
@SuppressWarnings("UnusedVariable") // utf8Length argument is added for future use
- static void writeUtf8(CodedOutputStream output, String string, int utf8Length)
+ static void writeUtf8(CodedOutputStream output, String string, int utf8Length, boolean useUnsafe)
throws IOException {
- encodeUtf8(output, string);
// if the length of the latin1 string and the utf8 output are the same then the string must be
// composed of only 7bit characters and can be directly copied to the output
if (useUnsafe
&& UnsafeString.isAvailable()
&& string.length() == utf8Length
&& UnsafeString.isLatin1(string)) {
byte[] bytes = UnsafeString.getBytes(string);
output.write(bytes, 0, bytes.length);
} else {
encodeUtf8(output, string);
}
}

// encode utf8 the same way as length is computed in encodedUtf8Length
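
The size shortcut and the direct-copy condition in this hunk both rest on how Latin-1 maps to UTF-8: code points below 0x80 take one byte, and 0x80 to 0xFF take two. A small check of that relationship using only public String APIs (the class and helper names are hypothetical; the PR reads the string's internal byte[] via Unsafe rather than calling `getBytes`, and the sketch assumes every char in the input is at most 0xFF):

```java
import java.nio.charset.StandardCharsets;

final class Latin1Utf8Demo {

  /** UTF-8 size of a Latin-1-only string: its length plus one extra byte per char >= 0x80. */
  static int utf8SizeOfLatin1(String s) {
    byte[] latin1 = s.getBytes(StandardCharsets.ISO_8859_1);
    int negative = 0;
    for (byte b : latin1) {
      if (b < 0) {
        negative++; // high bit set, so this char needs two bytes in UTF-8
      }
    }
    return s.length() + negative;
  }

  /** True when the Latin-1 bytes are already valid UTF-8 and can be written out unchanged. */
  static boolean canCopyDirectly(String s) {
    return utf8SizeOfLatin1(s) == s.length();
  }

  public static void main(String[] args) {
    for (String s : new String[] {"hello", "caf\u00e9", "\u00fc\u00f1\u00ee"}) {
      int expected = s.getBytes(StandardCharsets.UTF_8).length;
      System.out.println(
          s + ": computed=" + utf8SizeOfLatin1(s) + " expected=" + expected
              + " directCopy=" + canCopyDirectly(s));
    }
  }
}
```

When `canCopyDirectly` returns true, the ISO-8859-1 and UTF-8 encodings of the string are byte-for-byte identical, which is why the write path can skip the encoder.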
@@ -0,0 +1,69 @@
/*
* Copyright The OpenTelemetry Authors
* SPDX-License-Identifier: Apache-2.0
*/

package io.opentelemetry.exporter.internal.marshal;

import java.lang.reflect.Field;
import sun.misc.Unsafe;

class UnsafeAccess {
private static final boolean available = checkUnsafe();

static boolean isAvailable() {
return available;
}

private static boolean checkUnsafe() {
try {
Class.forName("sun.misc.Unsafe", false, UnsafeAccess.class.getClassLoader());
return UnsafeHolder.UNSAFE != null;
} catch (ClassNotFoundException e) {
return false;
}
}

static long objectFieldOffset(Field field) {
return UnsafeHolder.UNSAFE.objectFieldOffset(field);
}

static Object getObject(Object object, long offset) {
return UnsafeHolder.UNSAFE.getObject(object, offset);
}

static byte getByte(Object object, long offset) {
return UnsafeHolder.UNSAFE.getByte(object, offset);
}

static int arrayBaseOffset(Class<?> arrayClass) {
return UnsafeHolder.UNSAFE.arrayBaseOffset(arrayClass);
}

static long getLong(Object o, long offset) {
return UnsafeHolder.UNSAFE.getLong(o, offset);
}

private UnsafeAccess() {}

private static class UnsafeHolder {
public static final Unsafe UNSAFE;

static {
UNSAFE = getUnsafe();
}

private UnsafeHolder() {}

@SuppressWarnings("NullAway")
private static Unsafe getUnsafe() {
try {
Field field = Unsafe.class.getDeclaredField("theUnsafe");
field.setAccessible(true);
return (Unsafe) field.get(null);
} catch (Exception ignored) {
return null;
}
}
}
}
@@ -0,0 +1,50 @@
/*
* Copyright The OpenTelemetry Authors
* SPDX-License-Identifier: Apache-2.0
*/

package io.opentelemetry.exporter.internal.marshal;

import java.lang.reflect.Field;

class UnsafeString {
private static final long valueOffset = getStringFieldOffset("value", byte[].class);
private static final long coderOffset = getStringFieldOffset("coder", byte.class);
private static final int byteArrayBaseOffset = UnsafeAccess.arrayBaseOffset(byte[].class);
private static final boolean available = valueOffset != -1 && coderOffset != -1;

static boolean isAvailable() {
return available;
}

static boolean isLatin1(String string) {
// 0 represents latin1, 1 utf16
return UnsafeAccess.getByte(string, coderOffset) == 0;
}

static byte[] getBytes(String string) {
return (byte[]) UnsafeAccess.getObject(string, valueOffset);
}

static long getLong(byte[] bytes, int index) {
return UnsafeAccess.getLong(bytes, byteArrayBaseOffset + index);
}

private static long getStringFieldOffset(String fieldName, Class<?> expectedType) {
if (!UnsafeAccess.isAvailable()) {
return -1;
}

try {
Field field = String.class.getDeclaredField(fieldName);
if (field.getType() != expectedType) {
return -1;
}
return UnsafeAccess.objectFieldOffset(field);
} catch (Exception exception) {
return -1;
}
}

private UnsafeString() {}
}
@@ -0,0 +1,47 @@
/*
* Copyright The OpenTelemetry Authors
* SPDX-License-Identifier: Apache-2.0
*/

package io.opentelemetry.exporter.internal.marshal;

import static io.opentelemetry.exporter.internal.marshal.StatelessMarshalerUtil.getUtf8Size;
import static io.opentelemetry.exporter.internal.marshal.StatelessMarshalerUtilTest.testUtf8;
import static org.assertj.core.api.Assertions.assertThat;

import edu.berkeley.cs.jqf.fuzz.Fuzz;
import edu.berkeley.cs.jqf.fuzz.JQF;
import edu.berkeley.cs.jqf.fuzz.junit.GuidedFuzzing;
import edu.berkeley.cs.jqf.fuzz.random.NoGuidance;
import java.nio.charset.StandardCharsets;
import org.junit.jupiter.api.Test;
import org.junit.runner.Result;
import org.junit.runner.RunWith;

@SuppressWarnings("SystemOut")
class StatelessMarshalerUtilFuzzTest {

@RunWith(JQF.class)
public static class EncodeUf8 {

@Fuzz
public void encodeRandomString(String value) {
int utf8Size = value.getBytes(StandardCharsets.UTF_8).length;
assertThat(getUtf8Size(value, false)).isEqualTo(utf8Size);
assertThat(getUtf8Size(value, true)).isEqualTo(utf8Size);
assertThat(testUtf8(value, utf8Size, /* useUnsafe= */ false)).isEqualTo(value);
assertThat(testUtf8(value, utf8Size, /* useUnsafe= */ true)).isEqualTo(value);
}
}

// driver methods to avoid having to use the vintage junit engine, and to enable increasing the
// number of iterations:

@Test
void encodeUf8WithFuzzing() {
Result result =
GuidedFuzzing.run(
EncodeUf8.class, "encodeRandomString", new NoGuidance(10000, System.out), System.out);
assertThat(result.wasSuccessful()).isTrue();
}
}
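
For readers without JQF on the classpath, the size property checked by the fuzz test can be approximated with a seeded `java.util.Random`. This is a simplified sketch, not a replacement for the test above: `utf8Length` is a hypothetical stand-in for `getUtf8Size`, and the generator only produces valid code points (no unpaired surrogates).

```java
import java.nio.charset.StandardCharsets;
import java.util.Random;

final class Utf8SizePropertyCheck {

  /** Computes the UTF-8 encoded length of a string made of valid code points. */
  static int utf8Length(String s) {
    int size = 0;
    for (int i = 0; i < s.length(); ) {
      int cp = s.codePointAt(i);
      size += cp < 0x80 ? 1 : cp < 0x800 ? 2 : cp < 0x10000 ? 3 : 4;
      i += Character.charCount(cp);
    }
    return size;
  }

  /** Builds a random string of up to maxCodePoints valid code points. */
  static String randomString(Random random, int maxCodePoints) {
    int n = random.nextInt(maxCodePoints + 1);
    StringBuilder sb = new StringBuilder();
    for (int i = 0; i < n; i++) {
      int cp = random.nextInt(Character.MAX_CODE_POINT + 1);
      if (cp >= Character.MIN_SURROGATE && cp <= Character.MAX_SURROGATE) {
        cp = 'a'; // surrogate code points are not valid on their own; substitute ASCII
      }
      sb.appendCodePoint(cp);
    }
    return sb.toString();
  }

  /** Runs the property check and returns the number of mismatches found. */
  static int run(long seed, int iterations) {
    Random random = new Random(seed);
    int mismatches = 0;
    for (int i = 0; i < iterations; i++) {
      String value = randomString(random, 64);
      if (utf8Length(value) != value.getBytes(StandardCharsets.UTF_8).length) {
        mismatches++;
      }
    }
    return mismatches;
  }

  public static void main(String[] args) {
    System.out.println("mismatches: " + run(42L, 10_000));
  }
}
```

A fixed seed keeps failures reproducible, at the cost of the input diversity a fuzzing library provides.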