From 71c5b76fa6affd7453be27709ae2dfcc90eae52d Mon Sep 17 00:00:00 2001 From: Nikita Ivchenko Date: Wed, 10 Apr 2024 10:44:02 +0300 Subject: [PATCH] feat(java): channel stream reader (#1483) Co-authored-by: Shawn Yang --- .../main/java/org/apache/fury/BaseFury.java | 9 ++ .../src/main/java/org/apache/fury/Fury.java | 32 +++++ .../java/org/apache/fury/ThreadLocalFury.java | 21 +++ .../org/apache/fury/io/FuryInputStream.java | 1 - .../apache/fury/io/FuryReadableChannel.java | 126 +++++++++++++++--- .../org/apache/fury/io/FuryStreamReader.java | 13 ++ .../org/apache/fury/memory/MemoryBuffer.java | 20 ++- .../org/apache/fury/pool/ThreadPoolFury.java | 21 +++ .../test/java/org/apache/fury/StreamTest.java | 51 +++++++ 9 files changed, 272 insertions(+), 22 deletions(-) diff --git a/java/fury-core/src/main/java/org/apache/fury/BaseFury.java b/java/fury-core/src/main/java/org/apache/fury/BaseFury.java index 1bcb31ee97..91a62cc49e 100644 --- a/java/fury-core/src/main/java/org/apache/fury/BaseFury.java +++ b/java/fury-core/src/main/java/org/apache/fury/BaseFury.java @@ -21,6 +21,7 @@ import java.io.OutputStream; import org.apache.fury.io.FuryInputStream; +import org.apache.fury.io.FuryReadableChannel; import org.apache.fury.memory.MemoryBuffer; import org.apache.fury.serializer.BufferCallback; import org.apache.fury.serializer.Serializer; @@ -114,6 +115,10 @@ public interface BaseFury { Object deserialize(FuryInputStream inputStream, Iterable outOfBandBuffers); + Object deserialize(FuryReadableChannel channel); + + Object deserialize(FuryReadableChannel channel, Iterable outOfBandBuffers); + /** * Serialize java object without class info, deserialization should use {@link * #deserializeJavaObject}. 
@@ -142,6 +147,8 @@ public interface BaseFury { T deserializeJavaObject(FuryInputStream inputStream, Class cls); + T deserializeJavaObject(FuryReadableChannel channel, Class cls); + byte[] serializeJavaObjectAndClass(Object obj); void serializeJavaObjectAndClass(MemoryBuffer buffer, Object obj); @@ -153,4 +160,6 @@ public interface BaseFury { Object deserializeJavaObjectAndClass(MemoryBuffer buffer); Object deserializeJavaObjectAndClass(FuryInputStream inputStream); + + Object deserializeJavaObjectAndClass(FuryReadableChannel channel); } diff --git a/java/fury-core/src/main/java/org/apache/fury/Fury.java b/java/fury-core/src/main/java/org/apache/fury/Fury.java index 63421ada60..e1a8bbd5b2 100644 --- a/java/fury-core/src/main/java/org/apache/fury/Fury.java +++ b/java/fury-core/src/main/java/org/apache/fury/Fury.java @@ -38,6 +38,7 @@ import org.apache.fury.config.LongEncoding; import org.apache.fury.exception.DeserializationException; import org.apache.fury.io.FuryInputStream; +import org.apache.fury.io.FuryReadableChannel; import org.apache.fury.logging.Logger; import org.apache.fury.logging.LoggerFactory; import org.apache.fury.memory.MemoryBuffer; @@ -771,6 +772,17 @@ public Object deserialize(FuryInputStream inputStream, Iterable ou } } + @Override + public Object deserialize(FuryReadableChannel channel) { + return deserialize(channel, null); + } + + @Override + public Object deserialize(FuryReadableChannel channel, Iterable outOfBandBuffers) { + MemoryBuffer buf = channel.getBuffer(); + return deserialize(buf, outOfBandBuffers); + } + private RuntimeException handleReadFailed(Throwable t) { if (refResolver instanceof MapRefResolver) { ObjectArray readObjects = ((MapRefResolver) refResolver).getReadObjects(); @@ -1092,6 +1104,16 @@ public T deserializeJavaObject(FuryInputStream inputStream, Class cls) { } } + /** + * Deserialize java object from binary channel by passing class info, serialization should use + * {@link #serializeJavaObject}. 
+ */ + @Override + public T deserializeJavaObject(FuryReadableChannel channel, Class cls) { + MemoryBuffer buf = channel.getBuffer(); + return deserializeJavaObject(buf, cls); + } + /** * Deserialize java object from binary by passing class info, serialization should use {@link * #deserializeJavaObjectAndClass}. @@ -1181,6 +1203,16 @@ public Object deserializeJavaObjectAndClass(FuryInputStream inputStream) { } } + /** + * Deserialize class info and java object from binary channel, serialization should use {@link + * #serializeJavaObjectAndClass}. + */ + @Override + public Object deserializeJavaObjectAndClass(FuryReadableChannel channel) { + MemoryBuffer buf = channel.getBuffer(); + return deserializeJavaObjectAndClass(buf); + } + private void serializeToStream(OutputStream outputStream, Consumer function) { MemoryBuffer buf = getBuffer(); if (outputStream.getClass() == ByteArrayOutputStream.class) { diff --git a/java/fury-core/src/main/java/org/apache/fury/ThreadLocalFury.java b/java/fury-core/src/main/java/org/apache/fury/ThreadLocalFury.java index cf6e45df92..67928cae37 100644 --- a/java/fury-core/src/main/java/org/apache/fury/ThreadLocalFury.java +++ b/java/fury-core/src/main/java/org/apache/fury/ThreadLocalFury.java @@ -26,6 +26,7 @@ import java.util.function.Function; import javax.annotation.concurrent.ThreadSafe; import org.apache.fury.io.FuryInputStream; +import org.apache.fury.io.FuryReadableChannel; import org.apache.fury.memory.MemoryBuffer; import org.apache.fury.memory.MemoryUtils; import org.apache.fury.resolver.ClassResolver; @@ -164,6 +165,16 @@ public Object deserialize(FuryInputStream inputStream, Iterable ou return bindingThreadLocal.get().get().deserialize(inputStream, outOfBandBuffers); } + @Override + public Object deserialize(FuryReadableChannel channel) { + return bindingThreadLocal.get().get().deserialize(channel); + } + + @Override + public Object deserialize(FuryReadableChannel channel, Iterable outOfBandBuffers) { + return 
bindingThreadLocal.get().get().deserialize(channel, outOfBandBuffers); + } + @Override public byte[] serializeJavaObject(Object obj) { return bindingThreadLocal.get().get().serializeJavaObject(obj); @@ -194,6 +205,11 @@ public T deserializeJavaObject(FuryInputStream inputStream, Class cls) { return bindingThreadLocal.get().get().deserializeJavaObject(inputStream, cls); } + @Override + public T deserializeJavaObject(FuryReadableChannel channel, Class cls) { + return bindingThreadLocal.get().get().deserializeJavaObject(channel, cls); + } + @Override public byte[] serializeJavaObjectAndClass(Object obj) { return bindingThreadLocal.get().get().serializeJavaObjectAndClass(obj); @@ -224,6 +240,11 @@ public Object deserializeJavaObjectAndClass(FuryInputStream inputStream) { return bindingThreadLocal.get().get().deserializeJavaObjectAndClass(inputStream); } + @Override + public Object deserializeJavaObjectAndClass(FuryReadableChannel channel) { + return bindingThreadLocal.get().get().deserializeJavaObjectAndClass(channel); + } + @Override public void setClassLoader(ClassLoader classLoader) { setClassLoader(classLoader, StagingType.SOFT_STAGING); diff --git a/java/fury-core/src/main/java/org/apache/fury/io/FuryInputStream.java b/java/fury-core/src/main/java/org/apache/fury/io/FuryInputStream.java index 4fc321528b..f5006b4b19 100644 --- a/java/fury-core/src/main/java/org/apache/fury/io/FuryInputStream.java +++ b/java/fury-core/src/main/java/org/apache/fury/io/FuryInputStream.java @@ -33,7 +33,6 @@ */ @NotThreadSafe public class FuryInputStream extends InputStream implements FuryStreamReader { - private static final int BUFFER_GROW_STEP_THRESHOLD = 100 * 1024 * 1024; private final InputStream stream; private final int bufferSize; private final MemoryBuffer buffer; diff --git a/java/fury-core/src/main/java/org/apache/fury/io/FuryReadableChannel.java b/java/fury-core/src/main/java/org/apache/fury/io/FuryReadableChannel.java index 74cd6e0377..0aea9e7691 100644 --- 
a/java/fury-core/src/main/java/org/apache/fury/io/FuryReadableChannel.java +++ b/java/fury-core/src/main/java/org/apache/fury/io/FuryReadableChannel.java @@ -22,42 +22,136 @@ import java.io.IOException; import java.nio.ByteBuffer; import java.nio.channels.ReadableByteChannel; +import javax.annotation.concurrent.NotThreadSafe; +import org.apache.fury.exception.DeserializationException; import org.apache.fury.memory.MemoryBuffer; +import org.apache.fury.util.Platform; +import org.apache.fury.util.Preconditions; -// TODO support zero-copy channel reading. -public class FuryReadableChannel extends AbstractStreamReader implements ReadableByteChannel { +@NotThreadSafe +public class FuryReadableChannel implements FuryStreamReader, ReadableByteChannel { private final ReadableByteChannel channel; - private final ByteBuffer byteBuffer; - private final MemoryBuffer buffer; + private final MemoryBuffer memoryBuffer; + private ByteBuffer byteBuffer; public FuryReadableChannel(ReadableByteChannel channel) { - this(channel, ByteBuffer.allocate(4096)); + this(channel, ByteBuffer.allocateDirect(4096)); } - private FuryReadableChannel(ReadableByteChannel channel, ByteBuffer directBuffer) { + public FuryReadableChannel(ReadableByteChannel channel, ByteBuffer directBuffer) { + Preconditions.checkArgument( + directBuffer.isDirect(), "FuryReadableChannel support only direct ByteBuffer."); this.channel = channel; this.byteBuffer = directBuffer; - this.buffer = MemoryBuffer.fromByteBuffer(directBuffer); + this.memoryBuffer = MemoryBuffer.fromDirectByteBuffer(directBuffer, 0, this); + } + + @Override + public int fillBuffer(int minFillSize) { + try { + ByteBuffer byteBuf = byteBuffer; + MemoryBuffer memoryBuf = memoryBuffer; + int position = byteBuf.position(); + int newLimit = position + minFillSize; + if (newLimit > byteBuf.capacity()) { + int newSize = + newLimit < BUFFER_GROW_STEP_THRESHOLD ? 
newLimit << 2 : (int) (newLimit * 1.5); + ByteBuffer newByteBuf = ByteBuffer.allocateDirect(newSize); + byteBuf.position(0); + newByteBuf.put(byteBuf); + byteBuf = byteBuffer = newByteBuf; + memoryBuf.initDirectBuffer(Platform.getAddress(byteBuf), position, byteBuf); + } + byteBuf.limit(newLimit); + int readCount = channel.read(byteBuf); + memoryBuf.increaseSize(readCount); + return readCount; + } catch (IOException e) { + throw new DeserializationException("Failed to read the provided byte channel", e); + } } @Override public int read(ByteBuffer dst) throws IOException { - MemoryBuffer buf = buffer; + int length = dst.remaining(); + MemoryBuffer buf = memoryBuffer; + int remaining = buf.remaining(); + if (remaining >= length) { + buf.read(dst, length); + return length; + } else { + buf.read(dst, remaining); + return channel.read(dst) + remaining; + } + } + + @Override + public void readTo(byte[] dst, int dstIndex, int length) { + MemoryBuffer buf = memoryBuffer; int remaining = buf.remaining(); - int len = dst.remaining(); - if (remaining >= len) { - buf.read(dst, len); - return len; + if (remaining >= length) { + buf.readBytes(dst, dstIndex, length); } else { + buf.readBytes(dst, dstIndex, remaining); try { - buf.read(dst, remaining); - return channel.read(dst) + remaining; + ByteBuffer buffer = ByteBuffer.wrap(dst, dstIndex + remaining, length - remaining); + channel.read(buffer); } catch (IOException e) { - throw new RuntimeException(e); + throw new DeserializationException("Failed to read the provided byte channel", e); } } } + @Override + public void readToUnsafe(Object target, long targetPointer, int numBytes) { + MemoryBuffer buf = memoryBuffer; + int remaining = buf.remaining(); + if (remaining < numBytes) { + fillBuffer(numBytes - remaining); + } + long address = buf.getUnsafeReaderAddress(); + Platform.copyMemory(null, address, target, targetPointer, numBytes); + buf.increaseReaderIndex(numBytes); + } + + @Override + public void 
readToByteBuffer(ByteBuffer dst, int length) { + MemoryBuffer buf = memoryBuffer; + int remaining = buf.remaining(); + if (remaining >= length) { + buf.read(dst, length); + } else { + buf.read(dst, remaining); + try { + int dstLimit = dst.limit(); + int newLimit = dst.position() + length - remaining; + if (dstLimit > newLimit) { + dst.limit(newLimit); + channel.read(dst); + dst.limit(dstLimit); + } else { + channel.read(dst); + } + } catch (IOException e) { + throw new DeserializationException("Failed to read the provided byte channel", e); + } + } + } + + @Override + public int readToByteBuffer(ByteBuffer dst) { + MemoryBuffer buf = memoryBuffer; + int remaining = buf.remaining(); + if (remaining > 0) { + buf.read(dst, remaining); + } + try { + return channel.read(dst) + remaining; + } catch (IOException e) { + throw new DeserializationException("Failed to read the provided byte channel", e); + } + } + @Override public boolean isOpen() { return channel.isOpen(); @@ -70,6 +164,6 @@ public void close() throws IOException { @Override public MemoryBuffer getBuffer() { - return buffer; + return memoryBuffer; } } diff --git a/java/fury-core/src/main/java/org/apache/fury/io/FuryStreamReader.java b/java/fury-core/src/main/java/org/apache/fury/io/FuryStreamReader.java index a40a4b807f..e02d4efbca 100644 --- a/java/fury-core/src/main/java/org/apache/fury/io/FuryStreamReader.java +++ b/java/fury-core/src/main/java/org/apache/fury/io/FuryStreamReader.java @@ -21,10 +21,13 @@ import java.io.InputStream; import java.nio.ByteBuffer; +import java.nio.channels.SeekableByteChannel; import org.apache.fury.memory.MemoryBuffer; /** A streaming reader to make {@link MemoryBuffer} to support streaming reading. */ public interface FuryStreamReader { + int BUFFER_GROW_STEP_THRESHOLD = 100 * 1024 * 1024; + /** * Read stream and fill the data to underlying {@link MemoryBuffer}, which is also the buffer * returned by {@link #getBuffer}. 
@@ -63,4 +66,14 @@ public interface FuryStreamReader { static FuryInputStream of(InputStream stream) { return new FuryInputStream(stream); } + + /** + * Create a {@link FuryReadableChannel} from the provided {@link SeekableByteChannel}. Note that + * the provided channel will be owned by the returned {@link FuryReadableChannel}; do + * not read the provided {@link SeekableByteChannel} anymore, read the returned channel + * instead. + */ + static FuryReadableChannel of(SeekableByteChannel channel) { + return new FuryReadableChannel(channel); + } } diff --git a/java/fury-core/src/main/java/org/apache/fury/memory/MemoryBuffer.java b/java/fury-core/src/main/java/org/apache/fury/memory/MemoryBuffer.java index cd4b10dc8b..0bf0882bcc 100644 --- a/java/fury-core/src/main/java/org/apache/fury/memory/MemoryBuffer.java +++ b/java/fury-core/src/main/java/org/apache/fury/memory/MemoryBuffer.java @@ -158,6 +158,15 @@ private MemoryBuffer(long offHeapAddress, int size, ByteBuffer offHeapBuffer) { */ private MemoryBuffer( long offHeapAddress, int size, ByteBuffer offHeapBuffer, FuryStreamReader streamReader) { + initDirectBuffer(offHeapAddress, size, offHeapBuffer); + if (streamReader != null) { + this.streamReader = streamReader; + } else { + this.streamReader = new BoundChecker(); + } + } + + public void initDirectBuffer(long offHeapAddress, int size, ByteBuffer offHeapBuffer) { this.offHeapBuffer = offHeapBuffer; if (offHeapAddress <= 0) { throw new IllegalArgumentException("negative pointer or size"); @@ -175,11 +184,6 @@ private MemoryBuffer( this.address = offHeapAddress; this.addressLimit = this.address + size; this.size = size; - if (streamReader != null) { - this.streamReader = streamReader; - } else { - this.streamReader = new BoundChecker(); - } } private class BoundChecker extends AbstractStreamReader { @@ -2915,6 +2919,12 @@ public static MemoryBuffer fromByteBuffer(ByteBuffer buffer) { } } + public static MemoryBuffer fromDirectByteBuffer( + ByteBuffer buffer, int 
size, FuryStreamReader streamReader) { + long offHeapAddress = Platform.getAddress(buffer) + buffer.position(); + return new MemoryBuffer(offHeapAddress, size, buffer, streamReader); + } + /** * Creates a new memory buffer that represents the provided native memory. The buffer will change * into a heap buffer automatically if not enough. diff --git a/java/fury-core/src/main/java/org/apache/fury/pool/ThreadPoolFury.java b/java/fury-core/src/main/java/org/apache/fury/pool/ThreadPoolFury.java index 3c207a6aaa..e9367dcfc6 100644 --- a/java/fury-core/src/main/java/org/apache/fury/pool/ThreadPoolFury.java +++ b/java/fury-core/src/main/java/org/apache/fury/pool/ThreadPoolFury.java @@ -28,6 +28,7 @@ import org.apache.fury.AbstractThreadSafeFury; import org.apache.fury.Fury; import org.apache.fury.io.FuryInputStream; +import org.apache.fury.io.FuryReadableChannel; import org.apache.fury.memory.MemoryBuffer; import org.apache.fury.memory.MemoryUtils; import org.apache.fury.serializer.BufferCallback; @@ -153,6 +154,16 @@ public Object deserialize(FuryInputStream inputStream, Iterable ou return execute(fury -> fury.deserialize(inputStream, outOfBandBuffers)); } + @Override + public Object deserialize(FuryReadableChannel channel) { + return execute(fury -> fury.deserialize(channel)); + } + + @Override + public Object deserialize(FuryReadableChannel channel, Iterable outOfBandBuffers) { + return execute(fury -> fury.deserialize(channel, outOfBandBuffers)); + } + @Override public byte[] serializeJavaObject(Object obj) { return execute(fury -> fury.serializeJavaObject(obj)); @@ -191,6 +202,11 @@ public T deserializeJavaObject(FuryInputStream inputStream, Class cls) { return execute(fury -> fury.deserializeJavaObject(inputStream, cls)); } + @Override + public T deserializeJavaObject(FuryReadableChannel channel, Class cls) { + return execute(fury -> fury.deserializeJavaObject(channel, cls)); + } + @Override public byte[] serializeJavaObjectAndClass(Object obj) { return execute(fury 
-> fury.serializeJavaObjectAndClass(obj)); @@ -229,6 +245,11 @@ public Object deserializeJavaObjectAndClass(FuryInputStream inputStream) { return execute(fury -> fury.deserializeJavaObjectAndClass(inputStream)); } + @Override + public Object deserializeJavaObjectAndClass(FuryReadableChannel channel) { + return execute(fury -> fury.deserializeJavaObjectAndClass(channel)); + } + @Override public void setClassLoader(ClassLoader classLoader) { setClassLoader(classLoader, LoaderBinding.StagingType.SOFT_STAGING); diff --git a/java/fury-core/src/test/java/org/apache/fury/StreamTest.java b/java/fury-core/src/test/java/org/apache/fury/StreamTest.java index 9d69bdc6bd..400e83a1e7 100644 --- a/java/fury-core/src/test/java/org/apache/fury/StreamTest.java +++ b/java/fury-core/src/test/java/org/apache/fury/StreamTest.java @@ -28,7 +28,10 @@ import java.io.ByteArrayOutputStream; import java.io.IOException; import java.io.InputStream; +import java.nio.file.Files; +import java.nio.file.Path; import org.apache.fury.io.FuryInputStream; +import org.apache.fury.io.FuryReadableChannel; import org.apache.fury.memory.MemoryBuffer; import org.apache.fury.test.bean.BeanA; import org.apache.fury.util.ReflectionUtils; @@ -204,4 +207,52 @@ public void testJavaOutputStream() throws IOException { assertEquals(newObj, beanA); } } + + @Test + public void testReadableChannel() throws IOException { + Fury fury = Fury.builder().requireClassRegistration(false).build(); + BeanA beanA = BeanA.createBeanA(2); + { + ByteArrayOutputStream bas = new ByteArrayOutputStream(); + fury.serialize(bas, beanA); + + Path tempFile = Files.createTempFile("readable_channel_test", "data_1"); + Files.write(tempFile, bas.toByteArray()); + + try (FuryReadableChannel channel = of(Files.newByteChannel(tempFile))) { + Object newObj = fury.deserialize(channel); + assertEquals(newObj, beanA); + } finally { + Files.delete(tempFile); + } + } + { + ByteArrayOutputStream bas = new ByteArrayOutputStream(); + 
fury.serializeJavaObject(bas, beanA); + + Path tempFile = Files.createTempFile("readable_channel_test", "data_2"); + Files.write(tempFile, bas.toByteArray()); + + try (FuryReadableChannel channel = of(Files.newByteChannel(tempFile))) { + Object newObj = fury.deserializeJavaObject(channel, BeanA.class); + assertEquals(newObj, beanA); + } finally { + Files.delete(tempFile); + } + } + { + ByteArrayOutputStream bas = new ByteArrayOutputStream(); + fury.serializeJavaObjectAndClass(bas, beanA); + + Path tempFile = Files.createTempFile("readable_channel_test", "data_3"); + Files.write(tempFile, bas.toByteArray()); + + try (FuryReadableChannel channel = of(Files.newByteChannel(tempFile))) { + Object newObj = fury.deserializeJavaObjectAndClass(channel); + assertEquals(newObj, beanA); + } finally { + Files.delete(tempFile); + } + } + } }