Skip to content

Commit

Permalink
feat(java): channel stream reader (#1483)
Browse files Browse the repository at this point in the history
Co-authored-by: Shawn Yang <[email protected]>
  • Loading branch information
Munoon and chaokunyang authored Apr 10, 2024
1 parent af97e62 commit 71c5b76
Show file tree
Hide file tree
Showing 9 changed files with 272 additions and 22 deletions.
9 changes: 9 additions & 0 deletions java/fury-core/src/main/java/org/apache/fury/BaseFury.java
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,7 @@

import java.io.OutputStream;
import org.apache.fury.io.FuryInputStream;
import org.apache.fury.io.FuryReadableChannel;
import org.apache.fury.memory.MemoryBuffer;
import org.apache.fury.serializer.BufferCallback;
import org.apache.fury.serializer.Serializer;
Expand Down Expand Up @@ -114,6 +115,10 @@ public interface BaseFury {

Object deserialize(FuryInputStream inputStream, Iterable<MemoryBuffer> outOfBandBuffers);

Object deserialize(FuryReadableChannel channel);

Object deserialize(FuryReadableChannel channel, Iterable<MemoryBuffer> outOfBandBuffers);

/**
* Serialize java object without class info, deserialization should use {@link
* #deserializeJavaObject}.
Expand Down Expand Up @@ -142,6 +147,8 @@ public interface BaseFury {

<T> T deserializeJavaObject(FuryInputStream inputStream, Class<T> cls);

<T> T deserializeJavaObject(FuryReadableChannel channel, Class<T> cls);

byte[] serializeJavaObjectAndClass(Object obj);

void serializeJavaObjectAndClass(MemoryBuffer buffer, Object obj);
Expand All @@ -153,4 +160,6 @@ public interface BaseFury {
Object deserializeJavaObjectAndClass(MemoryBuffer buffer);

Object deserializeJavaObjectAndClass(FuryInputStream inputStream);

Object deserializeJavaObjectAndClass(FuryReadableChannel channel);
}
32 changes: 32 additions & 0 deletions java/fury-core/src/main/java/org/apache/fury/Fury.java
Original file line number Diff line number Diff line change
Expand Up @@ -38,6 +38,7 @@
import org.apache.fury.config.LongEncoding;
import org.apache.fury.exception.DeserializationException;
import org.apache.fury.io.FuryInputStream;
import org.apache.fury.io.FuryReadableChannel;
import org.apache.fury.logging.Logger;
import org.apache.fury.logging.LoggerFactory;
import org.apache.fury.memory.MemoryBuffer;
Expand Down Expand Up @@ -771,6 +772,17 @@ public Object deserialize(FuryInputStream inputStream, Iterable<MemoryBuffer> ou
}
}

@Override
public Object deserialize(FuryReadableChannel channel) {
return deserialize(channel, null);
}

@Override
public Object deserialize(FuryReadableChannel channel, Iterable<MemoryBuffer> outOfBandBuffers) {
MemoryBuffer buf = channel.getBuffer();
return deserialize(buf, outOfBandBuffers);
}

private RuntimeException handleReadFailed(Throwable t) {
if (refResolver instanceof MapRefResolver) {
ObjectArray readObjects = ((MapRefResolver) refResolver).getReadObjects();
Expand Down Expand Up @@ -1092,6 +1104,16 @@ public <T> T deserializeJavaObject(FuryInputStream inputStream, Class<T> cls) {
}
}

/**
* Deserialize java object from binary channel by passing class info, serialization should use
* {@link #serializeJavaObject}.
*/
@Override
public <T> T deserializeJavaObject(FuryReadableChannel channel, Class<T> cls) {
MemoryBuffer buf = channel.getBuffer();
return deserializeJavaObject(buf, cls);
}

/**
* Deserialize java object from binary by passing class info, serialization should use {@link
* #deserializeJavaObjectAndClass}.
Expand Down Expand Up @@ -1181,6 +1203,16 @@ public Object deserializeJavaObjectAndClass(FuryInputStream inputStream) {
}
}

/**
* Deserialize class info and java object from binary channel, serialization should use {@link
* #serializeJavaObjectAndClass}.
*/
@Override
public Object deserializeJavaObjectAndClass(FuryReadableChannel channel) {
MemoryBuffer buf = channel.getBuffer();
return deserializeJavaObjectAndClass(buf);
}

private void serializeToStream(OutputStream outputStream, Consumer<MemoryBuffer> function) {
MemoryBuffer buf = getBuffer();
if (outputStream.getClass() == ByteArrayOutputStream.class) {
Expand Down
21 changes: 21 additions & 0 deletions java/fury-core/src/main/java/org/apache/fury/ThreadLocalFury.java
Original file line number Diff line number Diff line change
Expand Up @@ -26,6 +26,7 @@
import java.util.function.Function;
import javax.annotation.concurrent.ThreadSafe;
import org.apache.fury.io.FuryInputStream;
import org.apache.fury.io.FuryReadableChannel;
import org.apache.fury.memory.MemoryBuffer;
import org.apache.fury.memory.MemoryUtils;
import org.apache.fury.resolver.ClassResolver;
Expand Down Expand Up @@ -164,6 +165,16 @@ public Object deserialize(FuryInputStream inputStream, Iterable<MemoryBuffer> ou
return bindingThreadLocal.get().get().deserialize(inputStream, outOfBandBuffers);
}

@Override
public Object deserialize(FuryReadableChannel channel) {
return bindingThreadLocal.get().get().deserialize(channel);
}

@Override
public Object deserialize(FuryReadableChannel channel, Iterable<MemoryBuffer> outOfBandBuffers) {
return bindingThreadLocal.get().get().deserialize(channel, outOfBandBuffers);
}

@Override
public byte[] serializeJavaObject(Object obj) {
return bindingThreadLocal.get().get().serializeJavaObject(obj);
Expand Down Expand Up @@ -194,6 +205,11 @@ public <T> T deserializeJavaObject(FuryInputStream inputStream, Class<T> cls) {
return bindingThreadLocal.get().get().deserializeJavaObject(inputStream, cls);
}

@Override
public <T> T deserializeJavaObject(FuryReadableChannel channel, Class<T> cls) {
return bindingThreadLocal.get().get().deserializeJavaObject(channel, cls);
}

@Override
public byte[] serializeJavaObjectAndClass(Object obj) {
return bindingThreadLocal.get().get().serializeJavaObjectAndClass(obj);
Expand Down Expand Up @@ -224,6 +240,11 @@ public Object deserializeJavaObjectAndClass(FuryInputStream inputStream) {
return bindingThreadLocal.get().get().deserializeJavaObjectAndClass(inputStream);
}

@Override
public Object deserializeJavaObjectAndClass(FuryReadableChannel channel) {
return bindingThreadLocal.get().get().deserializeJavaObjectAndClass(channel);
}

@Override
public void setClassLoader(ClassLoader classLoader) {
setClassLoader(classLoader, StagingType.SOFT_STAGING);
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -33,7 +33,6 @@
*/
@NotThreadSafe
public class FuryInputStream extends InputStream implements FuryStreamReader {
private static final int BUFFER_GROW_STEP_THRESHOLD = 100 * 1024 * 1024;
private final InputStream stream;
private final int bufferSize;
private final MemoryBuffer buffer;
Expand Down
126 changes: 110 additions & 16 deletions java/fury-core/src/main/java/org/apache/fury/io/FuryReadableChannel.java
Original file line number Diff line number Diff line change
Expand Up @@ -22,42 +22,136 @@
import java.io.IOException;
import java.nio.ByteBuffer;
import java.nio.channels.ReadableByteChannel;
import javax.annotation.concurrent.NotThreadSafe;
import org.apache.fury.exception.DeserializationException;
import org.apache.fury.memory.MemoryBuffer;
import org.apache.fury.util.Platform;
import org.apache.fury.util.Preconditions;

// TODO support zero-copy channel reading.
public class FuryReadableChannel extends AbstractStreamReader implements ReadableByteChannel {
@NotThreadSafe
public class FuryReadableChannel implements FuryStreamReader, ReadableByteChannel {
private final ReadableByteChannel channel;
private final ByteBuffer byteBuffer;
private final MemoryBuffer buffer;
private final MemoryBuffer memoryBuffer;
private ByteBuffer byteBuffer;

public FuryReadableChannel(ReadableByteChannel channel) {
this(channel, ByteBuffer.allocate(4096));
this(channel, ByteBuffer.allocateDirect(4096));
}

private FuryReadableChannel(ReadableByteChannel channel, ByteBuffer directBuffer) {
public FuryReadableChannel(ReadableByteChannel channel, ByteBuffer directBuffer) {
Preconditions.checkArgument(
directBuffer.isDirect(), "FuryReadableChannel support only direct ByteBuffer.");
this.channel = channel;
this.byteBuffer = directBuffer;
this.buffer = MemoryBuffer.fromByteBuffer(directBuffer);
this.memoryBuffer = MemoryBuffer.fromDirectByteBuffer(directBuffer, 0, this);
}

@Override
public int fillBuffer(int minFillSize) {
try {
ByteBuffer byteBuf = byteBuffer;
MemoryBuffer memoryBuf = memoryBuffer;
int position = byteBuf.position();
int newLimit = position + minFillSize;
if (newLimit > byteBuf.capacity()) {
int newSize =
newLimit < BUFFER_GROW_STEP_THRESHOLD ? newLimit << 2 : (int) (newLimit * 1.5);
ByteBuffer newByteBuf = ByteBuffer.allocateDirect(newSize);
byteBuf.position(0);
newByteBuf.put(byteBuf);
byteBuf = byteBuffer = newByteBuf;
memoryBuf.initDirectBuffer(Platform.getAddress(byteBuf), position, byteBuf);
}
byteBuf.limit(newLimit);
int readCount = channel.read(byteBuf);
memoryBuf.increaseSize(readCount);
return readCount;
} catch (IOException e) {
throw new DeserializationException("Failed to read the provided byte channel", e);
}
}

@Override
public int read(ByteBuffer dst) throws IOException {
MemoryBuffer buf = buffer;
int length = dst.remaining();
MemoryBuffer buf = memoryBuffer;
int remaining = buf.remaining();
if (remaining >= length) {
buf.read(dst, length);
return length;
} else {
buf.read(dst, remaining);
return channel.read(dst) + remaining;
}
}

@Override
public void readTo(byte[] dst, int dstIndex, int length) {
MemoryBuffer buf = memoryBuffer;
int remaining = buf.remaining();
int len = dst.remaining();
if (remaining >= len) {
buf.read(dst, len);
return len;
if (remaining >= length) {
buf.readBytes(dst, dstIndex, length);
} else {
buf.readBytes(dst, dstIndex, remaining);
try {
buf.read(dst, remaining);
return channel.read(dst) + remaining;
ByteBuffer buffer = ByteBuffer.wrap(dst, dstIndex + remaining, length - remaining);
channel.read(buffer);
} catch (IOException e) {
throw new RuntimeException(e);
throw new DeserializationException("Failed to read the provided byte channel", e);
}
}
}

@Override
public void readToUnsafe(Object target, long targetPointer, int numBytes) {
MemoryBuffer buf = memoryBuffer;
int remaining = buf.remaining();
if (remaining < numBytes) {
fillBuffer(numBytes - remaining);
}
long address = buf.getUnsafeReaderAddress();
Platform.copyMemory(null, address, target, targetPointer, numBytes);
buf.increaseReaderIndex(numBytes);
}

@Override
public void readToByteBuffer(ByteBuffer dst, int length) {
MemoryBuffer buf = memoryBuffer;
int remaining = buf.remaining();
if (remaining >= length) {
buf.read(dst, length);
} else {
buf.read(dst, remaining);
try {
int dstLimit = dst.limit();
int newLimit = dst.position() + length - remaining;
if (dstLimit > newLimit) {
dst.limit(newLimit);
channel.read(dst);
dst.limit(dstLimit);
} else {
channel.read(dst);
}
} catch (IOException e) {
throw new DeserializationException("Failed to read the provided byte channel", e);
}
}
}

@Override
public int readToByteBuffer(ByteBuffer dst) {
MemoryBuffer buf = memoryBuffer;
int remaining = buf.remaining();
if (remaining > 0) {
buf.read(dst, remaining);
}
try {
return channel.read(dst) + remaining;
} catch (IOException e) {
throw new DeserializationException("Failed to read the provided byte channel", e);
}
}

@Override
public boolean isOpen() {
return channel.isOpen();
Expand All @@ -70,6 +164,6 @@ public void close() throws IOException {

@Override
public MemoryBuffer getBuffer() {
return buffer;
return memoryBuffer;
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -21,10 +21,13 @@

import java.io.InputStream;
import java.nio.ByteBuffer;
import java.nio.channels.SeekableByteChannel;
import org.apache.fury.memory.MemoryBuffer;

/** A streaming reader to make {@link MemoryBuffer} to support streaming reading. */
public interface FuryStreamReader {
int BUFFER_GROW_STEP_THRESHOLD = 100 * 1024 * 1024;

/**
* Read stream and fill the data to underlying {@link MemoryBuffer}, which is also the buffer
* returned by {@link #getBuffer}.
Expand Down Expand Up @@ -63,4 +66,14 @@ public interface FuryStreamReader {
static FuryInputStream of(InputStream stream) {
return new FuryInputStream(stream);
}

/**
* Create a {@link FuryReadableChannel} from the provided {@link SeekableByteChannel}. Note that
* the provided channel will be owned by the returned {@link FuryReadableChannel}, <bold>do
* not</bold> read the provided {@link SeekableByteChannel} anymore, read the returned stream
* instead.
*/
static FuryReadableChannel of(SeekableByteChannel channel) {
return new FuryReadableChannel(channel);
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -158,6 +158,15 @@ private MemoryBuffer(long offHeapAddress, int size, ByteBuffer offHeapBuffer) {
*/
private MemoryBuffer(
long offHeapAddress, int size, ByteBuffer offHeapBuffer, FuryStreamReader streamReader) {
initDirectBuffer(offHeapAddress, size, offHeapBuffer);
if (streamReader != null) {
this.streamReader = streamReader;
} else {
this.streamReader = new BoundChecker();
}
}

public void initDirectBuffer(long offHeapAddress, int size, ByteBuffer offHeapBuffer) {
this.offHeapBuffer = offHeapBuffer;
if (offHeapAddress <= 0) {
throw new IllegalArgumentException("negative pointer or size");
Expand All @@ -175,11 +184,6 @@ private MemoryBuffer(
this.address = offHeapAddress;
this.addressLimit = this.address + size;
this.size = size;
if (streamReader != null) {
this.streamReader = streamReader;
} else {
this.streamReader = new BoundChecker();
}
}

private class BoundChecker extends AbstractStreamReader {
Expand Down Expand Up @@ -2915,6 +2919,12 @@ public static MemoryBuffer fromByteBuffer(ByteBuffer buffer) {
}
}

public static MemoryBuffer fromDirectByteBuffer(
ByteBuffer buffer, int size, FuryStreamReader streamReader) {
long offHeapAddress = Platform.getAddress(buffer) + buffer.position();
return new MemoryBuffer(offHeapAddress, size, buffer, streamReader);
}

/**
* Creates a new memory buffer that represents the provided native memory. The buffer will change
* into a heap buffer automatically if not enough.
Expand Down
Loading

0 comments on commit 71c5b76

Please sign in to comment.