Skip to content

Commit

Permalink
feat(java): optimize read float/double for jvm jit inline (#1472)
Browse files Browse the repository at this point in the history
- optimize read float jvm jit inline by separate little/big endian
methods
- optimize read double jvm jit inline by separate little/big endian
methods
- generate unsafe get float code online
- generate unsafe get double code online
  • Loading branch information
chaokunyang authored Apr 6, 2024
1 parent 9597962 commit 1d302ef
Show file tree
Hide file tree
Showing 5 changed files with 158 additions and 129 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -572,18 +572,13 @@ protected Expression unsafePutDouble(Expression base, Expression pos, Expression
return new StaticInvoke(MemoryBuffer.class, "unsafePutDouble", base, pos, value);
}

/**
* Build unsafeGet operation.
*
* @see MemoryBuffer#unsafeGet(Object, long)
*/
/** Build unsafeGet operation. */
protected Expression unsafeGet(Expression base, Expression pos) {
return new StaticInvoke(MemoryBuffer.class, "unsafeGet", PRIMITIVE_BYTE_TYPE, base, pos);
return new StaticInvoke(Platform.class, "getByte", PRIMITIVE_BYTE_TYPE, base, pos);
}

protected Expression unsafeGetBoolean(Expression base, Expression pos) {
return new StaticInvoke(
MemoryBuffer.class, "unsafeGetBoolean", PRIMITIVE_BOOLEAN_TYPE, base, pos);
return new StaticInvoke(Platform.class, "getBoolean", PRIMITIVE_BOOLEAN_TYPE, base, pos);
}

protected Expression unsafeGetChar(Expression base, Expression pos) {
Expand Down Expand Up @@ -614,18 +609,25 @@ protected Expression unsafeGetInt(Expression base, Expression pos) {
protected Expression unsafeGetLong(Expression base, Expression pos) {
StaticInvoke expr = new StaticInvoke(Platform.class, "getLong", PRIMITIVE_LONG_TYPE, base, pos);
if (!Platform.IS_LITTLE_ENDIAN) {
expr = new StaticInvoke(Long.class, "reverseBytes", PRIMITIVE_INT_TYPE, expr.inline());
expr = new StaticInvoke(Long.class, "reverseBytes", PRIMITIVE_LONG_TYPE, expr.inline());
}
return expr;
}

protected Expression unsafeGetFloat(Expression base, Expression pos) {
return new StaticInvoke(MemoryBuffer.class, "unsafeGetFloat", PRIMITIVE_FLOAT_TYPE, base, pos);
StaticInvoke expr = new StaticInvoke(Platform.class, "getInt", PRIMITIVE_INT_TYPE, base, pos);
if (!Platform.IS_LITTLE_ENDIAN) {
expr = new StaticInvoke(Integer.class, "reverseBytes", PRIMITIVE_INT_TYPE, expr.inline());
}
return new StaticInvoke(Float.class, "intBitsToFloat", PRIMITIVE_FLOAT_TYPE, expr.inline());
}

protected Expression unsafeGetDouble(Expression base, Expression pos) {
return new StaticInvoke(
MemoryBuffer.class, "unsafeGetDouble", PRIMITIVE_DOUBLE_TYPE, base, pos);
StaticInvoke expr = new StaticInvoke(Platform.class, "getLong", PRIMITIVE_LONG_TYPE, base, pos);
if (!Platform.IS_LITTLE_ENDIAN) {
expr = new StaticInvoke(Long.class, "reverseBytes", PRIMITIVE_LONG_TYPE, expr.inline());
}
return new StaticInvoke(Double.class, "longBitsToDouble", PRIMITIVE_DOUBLE_TYPE, expr.inline());
}

protected Expression readChar(Expression buffer) {
Expand Down Expand Up @@ -655,4 +657,14 @@ protected Expression readLong(Expression buffer) {
public static String readLongFunc() {
return Platform.IS_LITTLE_ENDIAN ? "readLongOnLE" : "readLongOnBE";
}

protected Expression readFloat(Expression buffer) {
String func = Platform.IS_LITTLE_ENDIAN ? "readFloatOnLE" : "readFloatOnBE";
return new Invoke(buffer, func, PRIMITIVE_FLOAT_TYPE);
}

protected Expression readDouble(Expression buffer) {
String func = Platform.IS_LITTLE_ENDIAN ? "readDoubleOnLE" : "readDoubleOnBE";
return new Invoke(buffer, func, PRIMITIVE_DOUBLE_TYPE);
}
}
175 changes: 62 additions & 113 deletions java/fury-core/src/main/java/org/apache/fury/memory/MemoryBuffer.java
Original file line number Diff line number Diff line change
Expand Up @@ -252,10 +252,6 @@ private void checkPosition(long index, long pos, long length) {
}
}

public static byte unsafeGet(Object o, long offset) {
return UNSAFE.getByte(o, offset);
}

public byte unsafeGet(int index) {
final long pos = address + index;
return UNSAFE.getByte(heapMemory, pos);
Expand Down Expand Up @@ -444,15 +440,6 @@ public void put(int index, byte[] src, int offset, int length) {
}
}

public static boolean unsafeGetBoolean(Object o, long offset) {
return UNSAFE.getBoolean(o, offset);
}

public boolean unsafeGetBoolean(int index) {
final long pos = address + index;
return UNSAFE.getByte(heapMemory, pos) != 0;
}

public boolean getBoolean(int index) {
return get(index) != 0;
}
Expand All @@ -475,14 +462,6 @@ public char getCharN(int index) {
return UNSAFE.getChar(heapMemory, pos);
}

public char getCharB(int index) {
if (LITTLE_ENDIAN) {
return Character.reverseBytes(getCharN(index));
} else {
return getCharN(index);
}
}

public char getChar(int index) {
if (LITTLE_ENDIAN) {
return getCharN(index);
Expand Down Expand Up @@ -548,12 +527,6 @@ public short getShort(int index) {
}
}

public void putShortN(int index, short value) {
final long pos = address + index;
checkPosition(index, pos, 2);
UNSAFE.putShort(heapMemory, pos, value);
}

public void putShort(int index, short value) {
final long pos = address + index;
checkPosition(index, pos, 2);
Expand Down Expand Up @@ -590,12 +563,6 @@ public static void unsafePutShort(Object o, long pos, short value) {
}
}

public int getIntN(int index) {
final long pos = address + index;
checkPosition(index, pos, 4);
return UNSAFE.getInt(heapMemory, pos);
}

public int getInt(int index) {
final long pos = address + index;
checkPosition(index, pos, 4);
Expand All @@ -606,12 +573,6 @@ public int getInt(int index) {
}
}

public void putIntN(int index, int value) {
final long pos = address + index;
checkPosition(index, pos, 4);
UNSAFE.putInt(heapMemory, pos, value);
}

public void putInt(int index, int value) {
final long pos = address + index;
checkPosition(index, pos, 4);
Expand All @@ -622,11 +583,6 @@ public void putInt(int index, int value) {
}
}

public int unsafeGetIntN(int index) {
final long pos = address + index;
return UNSAFE.getInt(heapMemory, pos);
}

public int unsafeGetInt(int index) {
final long pos = address + index;
if (LITTLE_ENDIAN) {
Expand All @@ -644,11 +600,6 @@ public static int unsafeGetInt(Object o, long pos) {
}
}

public void unsafePutIntN(int index, int value) {
final long pos = address + index;
UNSAFE.putInt(heapMemory, pos, value);
}

public void unsafePutInt(int index, int value) {
final long pos = address + index;
if (LITTLE_ENDIAN) {
Expand All @@ -666,12 +617,6 @@ public static void unsafePutInt(Object o, long pos, int value) {
}
}

public long getLongN(int index) {
final long pos = address + index;
checkPosition(index, pos, 8);
return UNSAFE.getLong(heapMemory, pos);
}

public long getLong(int index) {
final long pos = address + index;
checkPosition(index, pos, 8);
Expand Down Expand Up @@ -716,11 +661,6 @@ public void putLongB(int index, long value) {
}
}

public long unsafeGetLongN(int index) {
final long pos = address + index;
return UNSAFE.getLong(heapMemory, pos);
}

public long unsafeGetLong(int index) {
final long pos = address + index;
if (LITTLE_ENDIAN) {
Expand All @@ -738,11 +678,6 @@ public static long unsafeGetLong(Object o, long pos) {
}
}

public void unsafePutLongN(int index, long value) {
final long pos = address + index;
UNSAFE.putLong(heapMemory, pos, value);
}

public void unsafePutLong(int index, long value) {
final long pos = address + index;
if (LITTLE_ENDIAN) {
Expand All @@ -760,10 +695,6 @@ public static void unsafePutLong(Object o, long pos, long value) {
}
}

public float getFloatN(int index) {
return Float.intBitsToFloat(getIntN(index));
}

public float getFloat(int index) {
final long pos = address + index;
if (LITTLE_ENDIAN) {
Expand All @@ -773,10 +704,6 @@ public float getFloat(int index) {
}
}

public void putFloatN(int index, float value) {
putIntN(index, Float.floatToRawIntBits(value));
}

public void putFloat(int index, float value) {
final long pos = address + index;
checkPosition(index, pos, 4);
Expand All @@ -787,10 +714,6 @@ public void putFloat(int index, float value) {
}
}

public float unsafeGetFloatN(int index) {
return Float.intBitsToFloat(unsafeGetIntN(index));
}

public float unsafeGetFloat(int index) {
final long pos = address + index;
if (LITTLE_ENDIAN) {
Expand All @@ -800,18 +723,6 @@ public float unsafeGetFloat(int index) {
}
}

public static float unsafeGetFloat(Object o, long pos) {
if (LITTLE_ENDIAN) {
return Float.intBitsToFloat(UNSAFE.getInt(o, pos));
} else {
return Float.intBitsToFloat(Integer.reverseBytes(UNSAFE.getInt(o, pos)));
}
}

public void unsafePutFloatN(int index, float value) {
unsafePutIntN(index, Float.floatToRawIntBits(value));
}

public void unsafePutFloat(int index, float value) {
final long pos = address + index;
if (LITTLE_ENDIAN) {
Expand All @@ -829,10 +740,6 @@ public static void unsafePutFloat(Object o, long pos, float value) {
}
}

public double getDoubleN(int index) {
return Double.longBitsToDouble(getLongN(index));
}

public double getDouble(int index) {
final long pos = address + index;
checkPosition(index, pos, 8);
Expand All @@ -843,10 +750,6 @@ public double getDouble(int index) {
}
}

public void putDoubleN(int index, double value) {
putLongN(index, Double.doubleToRawLongBits(value));
}

public void putDouble(int index, double value) {
final long pos = address + index;
checkPosition(index, pos, 8);
Expand All @@ -857,10 +760,6 @@ public void putDouble(int index, double value) {
}
}

public double unsafeGetDoubleN(int index) {
return Double.longBitsToDouble(unsafeGetLongN(index));
}

public double unsafeGetDouble(int index) {
final long pos = address + index;
if (LITTLE_ENDIAN) {
Expand All @@ -870,18 +769,6 @@ public double unsafeGetDouble(int index) {
}
}

public static double unsafeGetDouble(Object o, long pos) {
if (LITTLE_ENDIAN) {
return Double.longBitsToDouble(UNSAFE.getLong(o, pos));
} else {
return Double.longBitsToDouble(Long.reverseBytes(UNSAFE.getLong(o, pos)));
}
}

public void unsafePutDoubleN(int index, double value) {
unsafePutLongN(index, Double.doubleToRawLongBits(value));
}

public void unsafePutDouble(int index, double value) {
final long pos = address + index;
if (LITTLE_ENDIAN) {
Expand Down Expand Up @@ -2318,6 +2205,37 @@ public float readFloat() {
}
}

// Reduce method body for better inline in the caller.
@CodegenInvoke
public float readFloatOnLE() {
int readerIdx = readerIndex;
// use subtract to avoid overflow
int remaining = size - readerIdx;
if (remaining < 4) {
throw new IndexOutOfBoundsException(
String.format(
"readerIndex(%d) + length(%d) exceeds size(%d): %s", readerIdx, 4, size, this));
}
readerIndex = readerIdx + 4;
return Float.intBitsToFloat(UNSAFE.getInt(heapMemory, address + readerIdx));
}

// Reduce method body for better inline in the caller.
@CodegenInvoke
public float readFloatOnBE() {
int readerIdx = readerIndex;
// use subtract to avoid overflow
int remaining = size - readerIdx;
if (remaining < 4) {
throw new IndexOutOfBoundsException(
String.format(
"readerIndex(%d) + length(%d) exceeds size(%d): %s", readerIdx, 4, size, this));
}
readerIndex = readerIdx + 4;
return Float.intBitsToFloat(
Integer.reverseBytes(UNSAFE.getInt(heapMemory, address + readerIdx)));
}

public double readDouble() {
int readerIdx = readerIndex;
// use subtract to avoid overflow
Expand All @@ -2335,6 +2253,37 @@ public double readDouble() {
}
}

// Reduce method body for better inline in the caller.
@CodegenInvoke
public double readDoubleOnLE() {
int readerIdx = readerIndex;
// use subtract to avoid overflow
int remaining = size - readerIdx;
if (remaining < 8) {
throw new IndexOutOfBoundsException(
String.format(
"readerIndex(%d) + length(%d) exceeds size(%d): %s", readerIdx, 8, size, this));
}
readerIndex = readerIdx + 8;
return Double.longBitsToDouble(UNSAFE.getLong(heapMemory, address + readerIdx));
}

// Reduce method body for better inline in the caller.
@CodegenInvoke
public double readDoubleOnBE() {
int readerIdx = readerIndex;
// use subtract to avoid overflow
int remaining = size - readerIdx;
if (remaining < 8) {
throw new IndexOutOfBoundsException(
String.format(
"readerIndex(%d) + length(%d) exceeds size(%d): %s", readerIdx, 8, size, this));
}
readerIndex = readerIdx + 8;
return Double.longBitsToDouble(
Long.reverseBytes(UNSAFE.getLong(heapMemory, address + readerIdx)));
}

public byte[] readBytes(int length) {
int readerIdx = readerIndex;
// use subtract to avoid overflow
Expand Down
Loading

0 comments on commit 1d302ef

Please sign in to comment.