From 1d302ef1112de7c684af65fd9b0589715129f91c Mon Sep 17 00:00:00 2001 From: Shawn Yang Date: Sat, 6 Apr 2024 20:43:06 +0800 Subject: [PATCH] feat(java): optimize read float/double for jvm jit inline (#1472) - optimize read float jvm jit inline by separate little/big endian methods - optimize read double jvm jit inline by separate little/big endian methods - generate unsafe get float code online - generate unsafe get double code online --- .../org/apache/fury/builder/CodecBuilder.java | 36 ++-- .../org/apache/fury/memory/MemoryBuffer.java | 175 +++++++----------- .../java/org/apache/fury/FuryTestBase.java | 5 + .../apache/fury/memory/MemoryBufferTest.java | 6 +- .../serializer/PrimitiveSerializersTest.java | 65 ++++++- 5 files changed, 158 insertions(+), 129 deletions(-) diff --git a/java/fury-core/src/main/java/org/apache/fury/builder/CodecBuilder.java b/java/fury-core/src/main/java/org/apache/fury/builder/CodecBuilder.java index 746a698d96..9c69757e17 100644 --- a/java/fury-core/src/main/java/org/apache/fury/builder/CodecBuilder.java +++ b/java/fury-core/src/main/java/org/apache/fury/builder/CodecBuilder.java @@ -572,18 +572,13 @@ protected Expression unsafePutDouble(Expression base, Expression pos, Expression return new StaticInvoke(MemoryBuffer.class, "unsafePutDouble", base, pos, value); } - /** - * Build unsafeGet operation. - * - * @see MemoryBuffer#unsafeGet(Object, long) - */ + /** Build unsafeGet operation. */ protected Expression unsafeGet(Expression base, Expression pos) { - return new StaticInvoke(MemoryBuffer.class, "unsafeGet", PRIMITIVE_BYTE_TYPE, base, pos); + return new StaticInvoke(Platform.class, "getByte", PRIMITIVE_BYTE_TYPE, base, pos); } protected Expression unsafeGetBoolean(Expression base, Expression pos) { - return new StaticInvoke( - MemoryBuffer.class, "unsafeGetBoolean", PRIMITIVE_BOOLEAN_TYPE, base, pos); + return new StaticInvoke(Platform.class, "getBoolean", PRIMITIVE_BOOLEAN_TYPE, base, pos); } protected Expression unsafeGetChar(Expression base, Expression pos) { @@ -614,18 +609,25 @@ protected Expression unsafeGetInt(Expression base, Expression pos) { protected Expression unsafeGetLong(Expression base, Expression pos) { StaticInvoke expr = new StaticInvoke(Platform.class, "getLong", PRIMITIVE_LONG_TYPE, base, pos); if (!Platform.IS_LITTLE_ENDIAN) { - expr = new StaticInvoke(Long.class, "reverseBytes", PRIMITIVE_INT_TYPE, expr.inline()); + expr = new StaticInvoke(Long.class, "reverseBytes", PRIMITIVE_LONG_TYPE, expr.inline()); } return expr; } protected Expression unsafeGetFloat(Expression base, Expression pos) { - return new StaticInvoke(MemoryBuffer.class, "unsafeGetFloat", PRIMITIVE_FLOAT_TYPE, base, pos); + StaticInvoke expr = new StaticInvoke(Platform.class, "getInt", PRIMITIVE_INT_TYPE, base, pos); + if (!Platform.IS_LITTLE_ENDIAN) { + expr = new StaticInvoke(Integer.class, "reverseBytes", PRIMITIVE_INT_TYPE, expr.inline()); + } + return new StaticInvoke(Float.class, "intBitsToFloat", PRIMITIVE_FLOAT_TYPE, expr.inline()); } protected Expression unsafeGetDouble(Expression base, Expression pos) { - return new StaticInvoke( - MemoryBuffer.class, "unsafeGetDouble", PRIMITIVE_DOUBLE_TYPE, base, pos); + StaticInvoke expr = new StaticInvoke(Platform.class, "getLong", PRIMITIVE_LONG_TYPE, base, pos); + if (!Platform.IS_LITTLE_ENDIAN) { + expr = new StaticInvoke(Long.class, "reverseBytes", PRIMITIVE_LONG_TYPE, expr.inline()); + } + return new StaticInvoke(Double.class, "longBitsToDouble", PRIMITIVE_DOUBLE_TYPE, expr.inline()); } protected Expression readChar(Expression buffer) { @@ -655,4 +657,14 @@ protected Expression readLong(Expression buffer) { public static String readLongFunc() { return Platform.IS_LITTLE_ENDIAN ? "readLongOnLE" : "readLongOnBE"; } + + protected Expression readFloat(Expression buffer) { + String func = Platform.IS_LITTLE_ENDIAN ? "readFloatOnLE" : "readFloatOnBE"; + return new Invoke(buffer, func, PRIMITIVE_FLOAT_TYPE); + } + + protected Expression readDouble(Expression buffer) { + String func = Platform.IS_LITTLE_ENDIAN ? "readDoubleOnLE" : "readDoubleOnBE"; + return new Invoke(buffer, func, PRIMITIVE_DOUBLE_TYPE); + } } diff --git a/java/fury-core/src/main/java/org/apache/fury/memory/MemoryBuffer.java b/java/fury-core/src/main/java/org/apache/fury/memory/MemoryBuffer.java index 155b8df6cc..025ceb2820 100644 --- a/java/fury-core/src/main/java/org/apache/fury/memory/MemoryBuffer.java +++ b/java/fury-core/src/main/java/org/apache/fury/memory/MemoryBuffer.java @@ -252,10 +252,6 @@ private void checkPosition(long index, long pos, long length) { } } - public static byte unsafeGet(Object o, long offset) { - return UNSAFE.getByte(o, offset); - } - public byte unsafeGet(int index) { final long pos = address + index; return UNSAFE.getByte(heapMemory, pos); @@ -444,15 +440,6 @@ public void put(int index, byte[] src, int offset, int length) { } } - public static boolean unsafeGetBoolean(Object o, long offset) { - return UNSAFE.getBoolean(o, offset); - } - - public boolean unsafeGetBoolean(int index) { - final long pos = address + index; - return UNSAFE.getByte(heapMemory, pos) != 0; - } - public boolean getBoolean(int index) { return get(index) != 0; } @@ -475,14 +462,6 @@ public char getCharN(int index) { return UNSAFE.getChar(heapMemory, pos); } - public char getCharB(int index) { - if (LITTLE_ENDIAN) { - return Character.reverseBytes(getCharN(index)); - } else { - return getCharN(index); - } - } - public char getChar(int index) { if (LITTLE_ENDIAN) { return getCharN(index); @@ -548,12 +527,6 @@ public short getShort(int index) { } } - public void putShortN(int index, short value) { - final long pos = address + index; - checkPosition(index, pos, 2); - UNSAFE.putShort(heapMemory, pos, value); - } - public void putShort(int index, short value) { final long pos = address + index; checkPosition(index, pos, 2); @@ -590,12 +563,6 @@ public static void unsafePutShort(Object o, long pos, short value) { } } - public int getIntN(int index) { - final long pos = address + index; - checkPosition(index, pos, 4); - return UNSAFE.getInt(heapMemory, pos); - } - public int getInt(int index) { final long pos = address + index; checkPosition(index, pos, 4); @@ -606,12 +573,6 @@ public int getInt(int index) { } } - public void putIntN(int index, int value) { - final long pos = address + index; - checkPosition(index, pos, 4); - UNSAFE.putInt(heapMemory, pos, value); - } - public void putInt(int index, int value) { final long pos = address + index; checkPosition(index, pos, 4); @@ -622,11 +583,6 @@ public void putInt(int index, int value) { } } - public int unsafeGetIntN(int index) { - final long pos = address + index; - return UNSAFE.getInt(heapMemory, pos); - } - public int unsafeGetInt(int index) { final long pos = address + index; if (LITTLE_ENDIAN) { @@ -644,11 +600,6 @@ public static int unsafeGetInt(Object o, long pos) { } } - public void unsafePutIntN(int index, int value) { - final long pos = address + index; - UNSAFE.putInt(heapMemory, pos, value); - } - public void unsafePutInt(int index, int value) { final long pos = address + index; if (LITTLE_ENDIAN) { @@ -666,12 +617,6 @@ public static void unsafePutInt(Object o, long pos, int value) { } } - public long getLongN(int index) { - final long pos = address + index; - checkPosition(index, pos, 8); - return UNSAFE.getLong(heapMemory, pos); - } - public long getLong(int index) { final long pos = address + index; checkPosition(index, pos, 8); @@ -716,11 +661,6 @@ public void putLongB(int index, long value) { } } - public long unsafeGetLongN(int index) { - final long pos = address + index; - return UNSAFE.getLong(heapMemory, pos); - } - public long unsafeGetLong(int index) { final long pos = address + index; if (LITTLE_ENDIAN) { @@ -738,11 +678,6 @@ public static long unsafeGetLong(Object o, long pos) { } } - public void unsafePutLongN(int index, long value) { - final long pos = address + index; - UNSAFE.putLong(heapMemory, pos, value); - } - public void unsafePutLong(int index, long value) { final long pos = address + index; if (LITTLE_ENDIAN) { @@ -760,10 +695,6 @@ public static void unsafePutLong(Object o, long pos, long value) { } } - public float getFloatN(int index) { - return Float.intBitsToFloat(getIntN(index)); - } - public float getFloat(int index) { final long pos = address + index; if (LITTLE_ENDIAN) { @@ -773,10 +704,6 @@ public float getFloat(int index) { } } - public void putFloatN(int index, float value) { - putIntN(index, Float.floatToRawIntBits(value)); - } - public void putFloat(int index, float value) { final long pos = address + index; checkPosition(index, pos, 4); @@ -787,10 +714,6 @@ public void putFloat(int index, float value) { } } - public float unsafeGetFloatN(int index) { - return Float.intBitsToFloat(unsafeGetIntN(index)); - } - public float unsafeGetFloat(int index) { final long pos = address + index; if (LITTLE_ENDIAN) { @@ -800,18 +723,6 @@ public float unsafeGetFloat(int index) { } } - public static float unsafeGetFloat(Object o, long pos) { - if (LITTLE_ENDIAN) { - return Float.intBitsToFloat(UNSAFE.getInt(o, pos)); - } else { - return Float.intBitsToFloat(Integer.reverseBytes(UNSAFE.getInt(o, pos))); - } - } - - public void unsafePutFloatN(int index, float value) { - unsafePutIntN(index, Float.floatToRawIntBits(value)); - } - public void unsafePutFloat(int index, float value) { final long pos = address + index; if (LITTLE_ENDIAN) { @@ -829,10 +740,6 @@ public static void unsafePutFloat(Object o, long pos, float value) { } } - public double getDoubleN(int index) { - return Double.longBitsToDouble(getLongN(index)); - } - public double getDouble(int index) { final long pos = address + index; checkPosition(index, pos, 8); @@ -843,10 +750,6 @@ public double getDouble(int index) { } } - public void putDoubleN(int index, double value) { - putLongN(index, Double.doubleToRawLongBits(value)); - } - public void putDouble(int index, double value) { final long pos = address + index; checkPosition(index, pos, 8); @@ -857,10 +760,6 @@ public void putDouble(int index, double value) { } } - public double unsafeGetDoubleN(int index) { - return Double.longBitsToDouble(unsafeGetLongN(index)); - } - public double unsafeGetDouble(int index) { final long pos = address + index; if (LITTLE_ENDIAN) { @@ -870,18 +769,6 @@ public double unsafeGetDouble(int index) { } } - public static double unsafeGetDouble(Object o, long pos) { - if (LITTLE_ENDIAN) { - return Double.longBitsToDouble(UNSAFE.getLong(o, pos)); - } else { - return Double.longBitsToDouble(Long.reverseBytes(UNSAFE.getLong(o, pos))); - } - } - - public void unsafePutDoubleN(int index, double value) { - unsafePutLongN(index, Double.doubleToRawLongBits(value)); - } - public void unsafePutDouble(int index, double value) { final long pos = address + index; if (LITTLE_ENDIAN) { @@ -2318,6 +2205,37 @@ public float readFloat() { } } + // Reduce method body for better inline in the caller. + @CodegenInvoke + public float readFloatOnLE() { + int readerIdx = readerIndex; + // use subtract to avoid overflow + int remaining = size - readerIdx; + if (remaining < 4) { + throw new IndexOutOfBoundsException( + String.format( + "readerIndex(%d) + length(%d) exceeds size(%d): %s", readerIdx, 4, size, this)); + } + readerIndex = readerIdx + 4; + return Float.intBitsToFloat(UNSAFE.getInt(heapMemory, address + readerIdx)); + } + + // Reduce method body for better inline in the caller. + @CodegenInvoke + public float readFloatOnBE() { + int readerIdx = readerIndex; + // use subtract to avoid overflow + int remaining = size - readerIdx; + if (remaining < 4) { + throw new IndexOutOfBoundsException( + String.format( + "readerIndex(%d) + length(%d) exceeds size(%d): %s", readerIdx, 4, size, this)); + } + readerIndex = readerIdx + 4; + return Float.intBitsToFloat( + Integer.reverseBytes(UNSAFE.getInt(heapMemory, address + readerIdx))); + } + public double readDouble() { int readerIdx = readerIndex; // use subtract to avoid overflow @@ -2335,6 +2253,37 @@ public double readDouble() { } } + // Reduce method body for better inline in the caller. + @CodegenInvoke + public double readDoubleOnLE() { + int readerIdx = readerIndex; + // use subtract to avoid overflow + int remaining = size - readerIdx; + if (remaining < 8) { + throw new IndexOutOfBoundsException( + String.format( + "readerIndex(%d) + length(%d) exceeds size(%d): %s", readerIdx, 8, size, this)); + } + readerIndex = readerIdx + 8; + return Double.longBitsToDouble(UNSAFE.getLong(heapMemory, address + readerIdx)); + } + + // Reduce method body for better inline in the caller. + @CodegenInvoke + public double readDoubleOnBE() { + int readerIdx = readerIndex; + // use subtract to avoid overflow + int remaining = size - readerIdx; + if (remaining < 8) { + throw new IndexOutOfBoundsException( + String.format( + "readerIndex(%d) + length(%d) exceeds size(%d): %s", readerIdx, 8, size, this)); + } + readerIndex = readerIdx + 8; + return Double.longBitsToDouble( + Long.reverseBytes(UNSAFE.getLong(heapMemory, address + readerIdx))); + } + public byte[] readBytes(int length) { int readerIdx = readerIndex; // use subtract to avoid overflow diff --git a/java/fury-core/src/test/java/org/apache/fury/FuryTestBase.java b/java/fury-core/src/test/java/org/apache/fury/FuryTestBase.java index acf841fa1e..af3796612f 100644 --- a/java/fury-core/src/test/java/org/apache/fury/FuryTestBase.java +++ b/java/fury-core/src/test/java/org/apache/fury/FuryTestBase.java @@ -95,6 +95,11 @@ public static Object[][] compressNumber() { return new Object[][] {{false}, {true}}; } + @DataProvider + public static Object[][] compressNumberAndCodeGen() { + return new Object[][] {{false, false}, {true, false}, {false, true}, {true, true}}; + } + @DataProvider public static Object[][] refTrackingAndCompressNumber() { return new Object[][] {{false, false}, {true, false}, {false, true}, {true, true}}; diff --git a/java/fury-core/src/test/java/org/apache/fury/memory/MemoryBufferTest.java b/java/fury-core/src/test/java/org/apache/fury/memory/MemoryBufferTest.java index 78fa41ba7a..89d8b8215b 100644 --- a/java/fury-core/src/test/java/org/apache/fury/memory/MemoryBufferTest.java +++ b/java/fury-core/src/test/java/org/apache/fury/memory/MemoryBufferTest.java @@ -75,9 +75,9 @@ public void testBufferUnsafeWrite() { MemoryBuffer.unsafePutDouble(heapMemory, pos, -1); pos += 8; MemoryBuffer.unsafePutFloat(heapMemory, pos, -1); - assertEquals(MemoryBuffer.unsafeGetFloat(heapMemory, pos), -1); + assertEquals(buffer.unsafeGetFloat((int) (pos - Platform.BYTE_ARRAY_OFFSET)), -1); pos -= 8; - assertEquals(MemoryBuffer.unsafeGetDouble(heapMemory, pos), -1); + assertEquals(buffer.unsafeGetDouble((int) (pos - Platform.BYTE_ARRAY_OFFSET)), -1); pos -= 8; assertEquals(MemoryBuffer.unsafeGetLong(heapMemory, pos), Long.MAX_VALUE); pos -= 4; @@ -85,7 +85,7 @@ public void testBufferUnsafeWrite() { pos -= 2; assertEquals(buffer.getShort((int) (pos - Platform.BYTE_ARRAY_OFFSET)), Short.MAX_VALUE); pos -= 1; - assertEquals(MemoryBuffer.unsafeGet(heapMemory, pos), Byte.MIN_VALUE); + assertEquals(buffer.get((int) (pos - Platform.BYTE_ARRAY_OFFSET)), Byte.MIN_VALUE); } { MemoryBuffer buffer = MemoryUtils.buffer(1024); diff --git a/java/fury-core/src/test/java/org/apache/fury/serializer/PrimitiveSerializersTest.java b/java/fury-core/src/test/java/org/apache/fury/serializer/PrimitiveSerializersTest.java index e7d56cf3c5..42a4b3b3d1 100644 --- a/java/fury-core/src/test/java/org/apache/fury/serializer/PrimitiveSerializersTest.java +++ b/java/fury-core/src/test/java/org/apache/fury/serializer/PrimitiveSerializersTest.java @@ -21,12 +21,17 @@ import static org.testng.Assert.*; +import lombok.AllArgsConstructor; +import lombok.Data; import org.apache.fury.Fury; +import org.apache.fury.FuryTestBase; +import org.apache.fury.config.FuryBuilder; import org.apache.fury.config.Language; +import org.apache.fury.config.LongEncoding; import org.apache.fury.memory.MemoryBuffer; import org.testng.annotations.Test; -public class PrimitiveSerializersTest { +public class PrimitiveSerializersTest extends FuryTestBase { @Test public void testUint8Serializer() { Fury fury = Fury.builder().withLanguage(Language.XLANG).requireClassRegistration(false).build(); @@ -54,4 +59,62 @@ public void testUint16Serializer() { assertThrows(IllegalArgumentException.class, () -> serializer.xwrite(buffer, -1)); assertThrows(IllegalArgumentException.class, () -> serializer.xwrite(buffer, 65536)); } + + @Data + @AllArgsConstructor + public static class PrimitiveStruct { + byte byte1; + byte byte2; + char char1; + char char2; + short short1; + short short2; + int int1; + int int2; + long long1; + long long2; + float float1; + float float2; + double double1; + double double2; + } + + @Test(dataProvider = "compressNumberAndCodeGen") + public void testPrimitiveStruct(boolean compressNumber, boolean codegen) { + PrimitiveStruct struct = + new PrimitiveStruct( + Byte.MIN_VALUE, + Byte.MIN_VALUE, + Character.MIN_VALUE, + Character.MIN_VALUE, + Short.MIN_VALUE, + Short.MIN_VALUE, + Integer.MIN_VALUE, + Integer.MIN_VALUE, + Long.MIN_VALUE, + Long.MIN_VALUE, + Float.MIN_VALUE, + Float.MIN_VALUE, + Double.MIN_VALUE, + Double.MIN_VALUE); + if (compressNumber) { + FuryBuilder builder = + Fury.builder() + .withLanguage(Language.JAVA) + .withCodegen(codegen) + .requireClassRegistration(false); + serDeCheck( + builder.withNumberCompressed(true).withLongCompressed(LongEncoding.PVL).build(), struct); + serDeCheck( + builder.withNumberCompressed(true).withLongCompressed(LongEncoding.SLI).build(), struct); + } else { + Fury fury = + Fury.builder() + .withLanguage(Language.JAVA) + .withCodegen(codegen) + .requireClassRegistration(false) + .build(); + serDeCheck(fury, struct); + } + } }