Skip to content

Commit

Permalink
Support the decompress feature in shuffle component. (apache#3)
Browse files Browse the repository at this point in the history
* ARROW-10880: [Java] Support compressing RecordBatch IPC buffers by LZ4

* ARROW-10880: [Java] Support reading/writing big-endian message size

* ARROW-10880: [Java] Adjust variable names

* ARROW-10880: [Java] Support empty buffers

* ARROW-10880: [Java] Support passing raw data

* ARROW-10880: [Java] Switch to commons-compress library

* bug fix and support the fastpfor codec in the IPC framework

* update the access permission from private to protected

* disable the decompress function when loading the buffer

Co-authored-by: liyafan82 <[email protected]>
  • Loading branch information
2 people authored and zhztheplayer committed May 6, 2021
1 parent 9f8ae57 commit 4b1202d
Show file tree
Hide file tree
Showing 11 changed files with 423 additions and 19 deletions.
2 changes: 2 additions & 0 deletions cpp/src/arrow/ipc/metadata_internal.cc
Original file line number Diff line number Diff line change
Expand Up @@ -922,6 +922,8 @@ static Status GetBodyCompression(FBB& fbb, const IpcWriteOptions& options,
codec = flatbuf::CompressionType::LZ4_FRAME;
} else if (options.codec->compression_type() == Compression::ZSTD) {
codec = flatbuf::CompressionType::ZSTD;
} else if (options.codec->compression_type() == Compression::FASTPFOR) {
codec = flatbuf::CompressionType::FASTPFOR;
} else {
return Status::Invalid("Unsupported IPC compression codec: ",
options.codec->name());
Expand Down
2 changes: 2 additions & 0 deletions cpp/src/arrow/ipc/reader.cc
Original file line number Diff line number Diff line change
Expand Up @@ -638,6 +638,8 @@ Status GetCompression(const flatbuf::RecordBatch* batch, Compression::type* out)
*out = Compression::LZ4_FRAME;
} else if (compression->codec() == flatbuf::CompressionType::ZSTD) {
*out = Compression::ZSTD;
} else if (compression->codec() == flatbuf::CompressionType::FASTPFOR) {
*out = Compression::FASTPFOR;
} else {
return Status::Invalid("Unsupported codec in RecordBatch::compression metadata");
}
Expand Down
2 changes: 1 addition & 1 deletion cpp/src/arrow/util/compression.cc
Original file line number Diff line number Diff line change
Expand Up @@ -88,7 +88,7 @@ Result<Compression::type> Codec::GetCompressionType(const std::string& name) {
return Compression::ZSTD;
} else if (name == "bz2") {
return Compression::BZ2;
} else if (name == "FASTPFOR") {
} else if (name == "fastpfor") {
return Compression::FASTPFOR;
} else {
return Status::Invalid("Unrecognized compression type: ", name);
Expand Down
9 changes: 6 additions & 3 deletions cpp/src/generated/Message_generated.h
Original file line number Diff line number Diff line change
Expand Up @@ -32,22 +32,25 @@ struct MessageBuilder;
enum class CompressionType : int8_t {
LZ4_FRAME = 0,
ZSTD = 1,
FASTPFOR = 2,
MIN = LZ4_FRAME,
MAX = ZSTD
};

inline const CompressionType (&EnumValuesCompressionType())[2] {
inline const CompressionType (&EnumValuesCompressionType())[3] {
static const CompressionType values[] = {
CompressionType::LZ4_FRAME,
CompressionType::ZSTD
CompressionType::ZSTD,
CompressionType::FASTPFOR
};
return values;
}

inline const char * const *EnumNamesCompressionType() {
static const char * const names[3] = {
static const char * const names[4] = {
"LZ4_FRAME",
"ZSTD",
"FASTPFOR",
nullptr
};
return names;
Expand Down
3 changes: 2 additions & 1 deletion format/Message.fbs
Original file line number Diff line number Diff line change
Expand Up @@ -49,7 +49,8 @@ enum CompressionType:byte {
LZ4_FRAME,

// Zstandard
ZSTD
ZSTD,
FASTPFOR
}

/// Provided for forward compatibility in case we need to support different
Expand Down
5 changes: 5 additions & 0 deletions java/vector/pom.xml
Original file line number Diff line number Diff line change
Expand Up @@ -74,6 +74,11 @@
<groupId>org.slf4j</groupId>
<artifactId>slf4j-api</artifactId>
</dependency>
<dependency>
<groupId>org.apache.commons</groupId>
<artifactId>commons-compress</artifactId>
<version>1.20</version>
</dependency>
</dependencies>

<pluginRepositories>
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -76,12 +76,9 @@ public VectorLoader(VectorSchemaRoot root, CompressionCodec.Factory factory) {
public void load(ArrowRecordBatch recordBatch) {
Iterator<ArrowBuf> buffers = recordBatch.getBuffers().iterator();
Iterator<ArrowFieldNode> nodes = recordBatch.getNodes().iterator();
CompressionUtil.CodecType codecType =
CompressionUtil.CodecType.fromCompressionType(recordBatch.getBodyCompression().getCodec());
decompressionNeeded = codecType != CompressionUtil.CodecType.NO_COMPRESSION;
CompressionCodec codec = decompressionNeeded ? factory.createCodec(codecType) : NoCompressionCodec.INSTANCE;

for (FieldVector fieldVector : root.getFieldVectors()) {
loadBuffers(fieldVector, fieldVector.getField(), buffers, nodes, codec);
loadBuffers(fieldVector, fieldVector.getField(), buffers, nodes);
}
root.setRowCount(recordBatch.getLength());
if (nodes.hasNext() || buffers.hasNext()) {
Expand All @@ -94,21 +91,20 @@ protected void loadBuffers(
FieldVector vector,
Field field,
Iterator<ArrowBuf> buffers,
Iterator<ArrowFieldNode> nodes,
CompressionCodec codec) {
Iterator<ArrowFieldNode> nodes) {
checkArgument(nodes.hasNext(), "no more field nodes for for field %s and vector %s", field, vector);
ArrowFieldNode fieldNode = nodes.next();
int bufferLayoutCount = TypeLayout.getTypeBufferCount(field.getType());
List<ArrowBuf> ownBuffers = new ArrayList<>(bufferLayoutCount);
for (int j = 0; j < bufferLayoutCount; j++) {
ArrowBuf nextBuf = buffers.next();
// for vectors without nulls, the buffer is empty, so there is no need to decompress it.
ArrowBuf bufferToAdd = nextBuf.writerIndex() > 0 ? codec.decompress(vector.getAllocator(), nextBuf) : nextBuf;
ownBuffers.add(bufferToAdd);
ownBuffers.add(nextBuf);
if (decompressionNeeded) {
// decompression performed
nextBuf.getReferenceManager().retain();
}
ownBuffers.add(nextBuf);
}
try {
vector.loadFieldBuffers(fieldNode, ownBuffers);
Expand All @@ -130,7 +126,7 @@ protected void loadBuffers(
for (int i = 0; i < childrenFromFields.size(); i++) {
Field child = children.get(i);
FieldVector fieldVector = childrenFromFields.get(i);
loadBuffers(fieldVector, child, buffers, nodes, codec);
loadBuffers(fieldVector, child, buffers, nodes);
}
}
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -100,4 +100,33 @@ public static ArrowBuf extractUncompressedBuffer(ArrowBuf inputBuffer) {
return inputBuffer.slice(SIZE_OF_UNCOMPRESSED_LENGTH,
inputBuffer.writerIndex() - SIZE_OF_UNCOMPRESSED_LENGTH);
}

public static CompressionCodec createCodec(byte compressionType) {
switch (compressionType) {
case NoCompressionCodec.COMPRESSION_TYPE:
return NoCompressionCodec.INSTANCE;
case CompressionType.LZ4_FRAME:
return new Lz4CompressionCodec();
default:
throw new IllegalArgumentException("Compression type not supported: " + compressionType);
}
}
/**
* Process compression by compressing the buffer as is.
*/
public static ArrowBuf compressRawBuffer(BufferAllocator allocator, ArrowBuf inputBuffer) {
ArrowBuf compressedBuffer = allocator.buffer(SIZE_OF_UNCOMPRESSED_LENGTH + inputBuffer.writerIndex());
compressedBuffer.setLong(0, NO_COMPRESSION_LENGTH);
compressedBuffer.setBytes(SIZE_OF_UNCOMPRESSED_LENGTH, inputBuffer, 0, inputBuffer.writerIndex());
compressedBuffer.writerIndex(SIZE_OF_UNCOMPRESSED_LENGTH + inputBuffer.writerIndex());
return compressedBuffer;
}

/**
* Process decompression by decompressing the buffer as is.
*/
public static ArrowBuf decompressRawBuffer(ArrowBuf inputBuffer) {
return inputBuffer.slice(SIZE_OF_UNCOMPRESSED_LENGTH,
inputBuffer.writerIndex() - SIZE_OF_UNCOMPRESSED_LENGTH);
}
}
Original file line number Diff line number Diff line change
@@ -0,0 +1,160 @@
/*
* Licensed to the Apache Software Foundation (ASF) under one or more
* contributor license agreements. See the NOTICE file distributed with
* this work for additional information regarding copyright ownership.
* The ASF licenses this file to You under the Apache License, Version 2.0
* (the "License"); you may not use this file except in compliance with
* the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/

package org.apache.arrow.vector.compression;

import static org.apache.arrow.memory.util.MemoryUtil.LITTLE_ENDIAN;
import static org.apache.arrow.vector.compression.CompressionUtil.NO_COMPRESSION_LENGTH;
import static org.apache.arrow.vector.compression.CompressionUtil.SIZE_OF_UNCOMPRESSED_LENGTH;

import java.io.ByteArrayInputStream;
import java.io.ByteArrayOutputStream;
import java.io.IOException;
import java.io.InputStream;
import java.io.OutputStream;

import org.apache.arrow.flatbuf.CompressionType;
import org.apache.arrow.memory.ArrowBuf;
import org.apache.arrow.memory.BufferAllocator;
import org.apache.arrow.util.Preconditions;
import org.apache.commons.compress.compressors.lz4.FramedLZ4CompressorInputStream;
import org.apache.commons.compress.compressors.lz4.FramedLZ4CompressorOutputStream;
import org.apache.commons.compress.utils.IOUtils;

import io.netty.util.internal.PlatformDependent;

/**
* Compression codec for the LZ4 algorithm.
*/
public class Lz4CompressionCodec implements CompressionCodec {

@Override
public ArrowBuf compress(BufferAllocator allocator, ArrowBuf uncompressedBuffer) {
Preconditions.checkArgument(uncompressedBuffer.writerIndex() <= Integer.MAX_VALUE,
"The uncompressed buffer size exceeds the integer limit");

if (uncompressedBuffer.writerIndex() == 0L) {
// shortcut for empty buffer
ArrowBuf compressedBuffer = allocator.buffer(SIZE_OF_UNCOMPRESSED_LENGTH);
compressedBuffer.setLong(0, 0);
compressedBuffer.writerIndex(SIZE_OF_UNCOMPRESSED_LENGTH);
uncompressedBuffer.close();
return compressedBuffer;
}

try {
ArrowBuf compressedBuffer = doCompress(allocator, uncompressedBuffer);
long compressedLength = compressedBuffer.writerIndex() - SIZE_OF_UNCOMPRESSED_LENGTH;
if (compressedLength > uncompressedBuffer.writerIndex()) {
// compressed buffer is larger, send the raw buffer
compressedBuffer.close();
compressedBuffer = CompressionUtil.compressRawBuffer(allocator, uncompressedBuffer);
}

uncompressedBuffer.close();
return compressedBuffer;
} catch (IOException e) {
throw new RuntimeException(e);
}
}

private ArrowBuf doCompress(BufferAllocator allocator, ArrowBuf uncompressedBuffer) throws IOException {
byte[] inBytes = new byte[(int) uncompressedBuffer.writerIndex()];
PlatformDependent.copyMemory(uncompressedBuffer.memoryAddress(), inBytes, 0, uncompressedBuffer.writerIndex());
ByteArrayOutputStream baos = new ByteArrayOutputStream();
try (InputStream in = new ByteArrayInputStream(inBytes);
OutputStream out = new FramedLZ4CompressorOutputStream(baos)) {
IOUtils.copy(in, out);
}

byte[] outBytes = baos.toByteArray();

ArrowBuf compressedBuffer = allocator.buffer(SIZE_OF_UNCOMPRESSED_LENGTH + outBytes.length);

long uncompressedLength = uncompressedBuffer.writerIndex();
if (!LITTLE_ENDIAN) {
uncompressedLength = Long.reverseBytes(uncompressedLength);
}
// first 8 bytes reserved for uncompressed length, to be consistent with the
// C++ implementation.
compressedBuffer.setLong(0, uncompressedLength);

PlatformDependent.copyMemory(
outBytes, 0, compressedBuffer.memoryAddress() + SIZE_OF_UNCOMPRESSED_LENGTH, outBytes.length);
compressedBuffer.writerIndex(SIZE_OF_UNCOMPRESSED_LENGTH + outBytes.length);
return compressedBuffer;
}

@Override
public ArrowBuf decompress(BufferAllocator allocator, ArrowBuf compressedBuffer) {
Preconditions.checkArgument(compressedBuffer.writerIndex() <= Integer.MAX_VALUE,
"The compressed buffer size exceeds the integer limit");

Preconditions.checkArgument(compressedBuffer.writerIndex() >= SIZE_OF_UNCOMPRESSED_LENGTH,
"Not enough data to decompress.");

long decompressedLength = compressedBuffer.getLong(0);
if (!LITTLE_ENDIAN) {
decompressedLength = Long.reverseBytes(decompressedLength);
}

if (decompressedLength == 0L) {
// shortcut for empty buffer
compressedBuffer.close();
return allocator.getEmpty();
}

if (decompressedLength == NO_COMPRESSION_LENGTH) {
// no compression
return CompressionUtil.decompressRawBuffer(compressedBuffer);
}

try {
ArrowBuf decompressedBuffer = doDecompress(allocator, compressedBuffer);
compressedBuffer.close();
return decompressedBuffer;
} catch (IOException e) {
throw new RuntimeException(e);
}
}

private ArrowBuf doDecompress(BufferAllocator allocator, ArrowBuf compressedBuffer) throws IOException {
long decompressedLength = compressedBuffer.getLong(0);
if (!LITTLE_ENDIAN) {
decompressedLength = Long.reverseBytes(decompressedLength);
}

byte[] inBytes = new byte[(int) (compressedBuffer.writerIndex() - SIZE_OF_UNCOMPRESSED_LENGTH)];
PlatformDependent.copyMemory(
compressedBuffer.memoryAddress() + SIZE_OF_UNCOMPRESSED_LENGTH, inBytes, 0, inBytes.length);
ByteArrayOutputStream out = new ByteArrayOutputStream((int) decompressedLength);
try (InputStream in = new FramedLZ4CompressorInputStream(new ByteArrayInputStream(inBytes))) {
IOUtils.copy(in, out);
}

byte[] outBytes = out.toByteArray();
ArrowBuf decompressedBuffer = allocator.buffer(outBytes.length);
PlatformDependent.copyMemory(outBytes, 0, decompressedBuffer.memoryAddress(), outBytes.length);
decompressedBuffer.writerIndex(decompressedLength);
return decompressedBuffer;
}

@Override
public String getCodecName() {
return CompressionType.name(CompressionType.LZ4_FRAME);
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -43,9 +43,9 @@
*/
public class ArrowStreamReader extends ArrowReader {

private MessageChannelReader messageReader;
protected MessageChannelReader messageReader;

private int loadedDictionaryCount;
protected int loadedDictionaryCount;

/**
* Constructs a streaming reader using a MessageChannelReader. Non-blocking.
Expand Down Expand Up @@ -176,7 +176,7 @@ public boolean loadNextBatch() throws IOException {
/**
* When read a record batch, check whether its dictionaries are available.
*/
private void checkDictionaries() throws IOException {
protected void checkDictionaries() throws IOException {
// if all dictionaries are loaded, return.
if (loadedDictionaryCount == dictionaries.size()) {
return;
Expand Down Expand Up @@ -215,7 +215,7 @@ protected Schema readSchema() throws IOException {
}


private ArrowDictionaryBatch readDictionary(MessageResult result) throws IOException {
protected ArrowDictionaryBatch readDictionary(MessageResult result) throws IOException {

ArrowBuf bodyBuffer = result.getBodyBuffer();

Expand Down
Loading

0 comments on commit 4b1202d

Please sign in to comment.