Add JNI support for converting Arrow buffers to CUDF ColumnVectors [skip ci] #7222

Merged Jan 28, 2021 · 17 commits · Changes from all commits
7 changes: 7 additions & 0 deletions java/pom.xml
@@ -132,6 +132,12 @@
<version>2.25.0</version>
<scope>test</scope>
</dependency>
<dependency>
<groupId>org.apache.arrow</groupId>
<artifactId>arrow-vector</artifactId>
<version>${arrow.version}</version>
<scope>test</scope>
</dependency>
</dependencies>

<properties>
@@ -151,6 +157,7 @@
<GPU_ARCHS>ALL</GPU_ARCHS>
<native.build.path>${project.build.directory}/cmake-build</native.build.path>
<slf4j.version>1.7.30</slf4j.version>
<arrow.version>0.15.1</arrow.version>
</properties>

<profiles>
113 changes: 113 additions & 0 deletions java/src/main/java/ai/rapids/cudf/ArrowColumnBuilder.java
@@ -0,0 +1,113 @@
/*
*
* Copyright (c) 2021, NVIDIA CORPORATION.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*
*/

package ai.rapids.cudf;

import java.nio.ByteBuffer;
import java.util.ArrayList;

/**
* Column builder from Arrow data. This builder takes in byte buffers referencing
* Arrow data and allows efficient building of CUDF ColumnVectors from that Arrow data.
* The caller can add multiple batches where each batch corresponds to Arrow data
* and those batches get concatenated together after being converted to CUDF
* ColumnVectors.
* This currently only supports primitive types and Strings; Decimals and nested types
* such as list and struct are not supported.
*/
public final class ArrowColumnBuilder implements AutoCloseable {
private DType type;
private final ArrayList<ByteBuffer> data = new ArrayList<>();
private final ArrayList<ByteBuffer> validity = new ArrayList<>();
private final ArrayList<ByteBuffer> offsets = new ArrayList<>();
private final ArrayList<Long> nullCount = new ArrayList<>();
private final ArrayList<Long> rows = new ArrayList<>();

public ArrowColumnBuilder(HostColumnVector.DataType type) {
this.type = type.getType();
}

/**
* Add an Arrow buffer. This API allows you to add multiple batches if you want them
* combined into a single ColumnVector.
* Note, this takes all data, validity, and offsets buffers, but they may not all
* be needed based on the data type. The buffer should be null if it is not used
* for that type.
* This API only supports primitive types and Strings; Decimals and nested types
* such as list and struct are not supported.
* @param rows - number of rows in this Arrow buffer
* @param nullCount - number of null values in this Arrow buffer
* @param data - ByteBuffer of the Arrow data buffer
* @param validity - ByteBuffer of the Arrow validity buffer
* @param offsets - ByteBuffer of the Arrow offsets buffer
*/
public void addBatch(long rows, long nullCount, ByteBuffer data, ByteBuffer validity,
ByteBuffer offsets) {
this.rows.add(rows);
this.nullCount.add(nullCount);
this.data.add(data);
this.validity.add(validity);
this.offsets.add(offsets);
}

/**
* Create the immutable ColumnVector, copied to the device based on the Arrow data.
* @return - new ColumnVector
*/
public final ColumnVector buildAndPutOnDevice() {
int numBatches = rows.size();
ArrayList<ColumnVector> allVecs = new ArrayList<>(numBatches);
ColumnVector vecRet;
try {
for (int i = 0; i < numBatches; i++) {
allVecs.add(ColumnVector.fromArrow(type, rows.get(i), nullCount.get(i),
data.get(i), validity.get(i), offsets.get(i)));
}
if (numBatches == 1) {
vecRet = allVecs.get(0);
} else if (numBatches > 1) {
vecRet = ColumnVector.concatenate(allVecs.toArray(new ColumnVector[0]));
} else {
throw new IllegalStateException("Can't build a ColumnVector when no Arrow batches specified");
}
} finally {
// close the vectors that were concatenated
if (numBatches > 1) {
allVecs.forEach(cv -> cv.close());
}
}
return vecRet;
}

@Override
public void close() {
// no-op: the Arrow byte buffers are owned by the caller, not this builder
}

@Override
public String toString() {
return "ArrowColumnBuilder{" +
"type=" + type +
", data=" + data +
", validity=" + validity +
", offsets=" + offsets +
", nullCount=" + nullCount +
", rows=" + rows +
'}';
}
}
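
For context, here is a minimal usage sketch of the builder above (not part of this PR's diff; it assumes Arrow 0.15.1's Netty-backed ArrowBuf.nioBuffer(int, int) accessor and cudf's HostColumnVector.BasicType helper):

import java.nio.ByteBuffer;
import org.apache.arrow.memory.RootAllocator;
import org.apache.arrow.vector.IntVector;
import ai.rapids.cudf.*;

public class ArrowColumnBuilderExample {
  public static void main(String[] args) {
    try (RootAllocator allocator = new RootAllocator(Long.MAX_VALUE);
         IntVector vector = new IntVector("col", allocator)) {
      vector.allocateNew(3);
      vector.setSafe(0, 1);
      vector.setSafe(1, 2);
      vector.setSafe(2, 3);
      vector.setValueCount(3);
      int rows = vector.getValueCount();
      // Expose the Arrow buffers as direct ByteBuffers:
      // 4 bytes per int32 value, one validity bit per row.
      ByteBuffer data = vector.getDataBuffer().nioBuffer(0, rows * 4);
      ByteBuffer validity = vector.getValidityBuffer().nioBuffer(0, (rows + 7) / 8);
      try (ArrowColumnBuilder builder = new ArrowColumnBuilder(
               new HostColumnVector.BasicType(true, DType.INT32))) {
        // Primitive types have no offsets buffer, so pass null for it.
        // addBatch may be called repeatedly; buildAndPutOnDevice concatenates.
        builder.addBatch(rows, vector.getNullCount(), data, validity, null);
        try (ColumnVector cv = builder.buildAndPutOnDevice()) {
          assert cv.getRowCount() == 3;
        }
      }
    }
  }
}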
49 changes: 49 additions & 0 deletions java/src/main/java/ai/rapids/cudf/ColumnVector.java
@@ -25,6 +25,7 @@

import java.math.BigDecimal;
import java.math.RoundingMode;
import java.nio.ByteBuffer;
import java.util.ArrayList;
import java.util.List;
import java.util.Optional;
@@ -310,6 +311,50 @@ public BaseDeviceMemoryBuffer getDeviceBufferFor(BufferType type) {
return srcBuffer;
}

/**
* Ensures the ByteBuffer passed in is a direct byte buffer.
* If it is not, the data is copied into a newly allocated direct
* byte buffer, which is returned in place of the original.
*/
private static ByteBuffer bufferAsDirect(ByteBuffer buf) {
ByteBuffer bufferOut = buf;
if (bufferOut != null && !bufferOut.isDirect()) {
bufferOut = ByteBuffer.allocateDirect(buf.remaining());
Contributor (review comment): is ByteOrder an issue? if buf is in non-native byte order?

bufferOut.put(buf);
bufferOut.flip();
}
return bufferOut;
}

/**
* Create a ColumnVector from the Apache Arrow byte buffers passed in.
* Any of the buffers not used for that datatype should be set to null.
* The buffers are expected to be off-heap (direct) buffers; if they are not,
* they will be copied into direct byte buffers first.
* This only supports primitive types and Strings; Decimals and nested types
* such as list and struct are not supported.
* @param type - type of the column
* @param numRows - Number of rows in the arrow column
* @param nullCount - Null count
* @param data - ByteBuffer of the Arrow data buffer
* @param validity - ByteBuffer of the Arrow validity buffer
* @param offsets - ByteBuffer of the Arrow offsets buffer
* @return - new ColumnVector
*/
public static ColumnVector fromArrow(
DType type,
long numRows,
long nullCount,
ByteBuffer data,
ByteBuffer validity,
ByteBuffer offsets) {
long columnHandle = fromArrow(type.typeId.getNativeId(), numRows, nullCount,
bufferAsDirect(data), bufferAsDirect(validity), bufferAsDirect(offsets));
ColumnVector vec = new ColumnVector(columnHandle);
return vec;
}
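
To illustrate which buffers a given type needs (a sketch, not from this PR): a STRING column uses the data and offsets buffers, while validity may be null when nullCount is zero. Arrow mandates little-endian contents and ByteBuffer.allocateDirect defaults to big-endian, so the offsets buffer's order is set explicitly below; the native side reads buffers via GetDirectBufferAddress/Capacity, so no flip() is needed.

import java.nio.ByteBuffer;
import java.nio.ByteOrder;
import ai.rapids.cudf.ColumnVector;
import ai.rapids.cudf.DType;

public class FromArrowStringExample {
  public static void main(String[] args) {
    // Hand-assembled Arrow buffers for the two-row STRING column ["a", "bc"].
    ByteBuffer data = ByteBuffer.allocateDirect(3);
    data.put((byte) 'a').put((byte) 'b').put((byte) 'c');
    // int32 offsets: row i spans [offsets[i], offsets[i+1]) in the data buffer.
    ByteBuffer offsets = ByteBuffer.allocateDirect(12).order(ByteOrder.LITTLE_ENDIAN);
    offsets.putInt(0).putInt(1).putInt(3);
    try (ColumnVector cv = ColumnVector.fromArrow(DType.STRING, 2, 0, data, null, offsets)) {
      assert cv.getRowCount() == 2; // 2 rows, 0 nulls, validity omitted
    }
  }
}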

/**
* Create a new vector of length rows, where each row is filled with the Scalar's
* value
@@ -615,6 +660,10 @@ public ColumnVector castTo(DType type) {

private static native long sequence(long initialValue, long step, int rows);

private static native long fromArrow(int type, long col_length,
long null_count, ByteBuffer data, ByteBuffer validity,
ByteBuffer offsets) throws CudfException;

private static native long fromScalar(long scalarHandle, int rowCount) throws CudfException;

private static native long makeList(long[] handles, long typeHandle, int scale, long rows)
75 changes: 75 additions & 0 deletions java/src/main/native/src/ColumnVectorJni.cpp
@@ -14,12 +14,15 @@
* limitations under the License.
*/

#include <arrow/api.h>
#include <cudf/column/column_factories.hpp>
#include <cudf/concatenate.hpp>
#include <cudf/filling.hpp>
#include <cudf/interop.hpp>
#include <cudf/hashing.hpp>
#include <cudf/reshape.hpp>
#include <cudf/utilities/bit.hpp>
#include <cudf/detail/interop.hpp>
#include <cudf/lists/detail/concatenate.hpp>
#include <cudf/lists/lists_column_view.hpp>
#include <cudf/scalar/scalar_factories.hpp>
@@ -50,6 +53,78 @@ JNIEXPORT jlong JNICALL Java_ai_rapids_cudf_ColumnVector_sequence(JNIEnv *env, j
CATCH_STD(env, 0);
}

JNIEXPORT jlong JNICALL Java_ai_rapids_cudf_ColumnVector_fromArrow(JNIEnv *env, jclass,
jint j_type,
jlong j_col_length,
jlong j_null_count,
jobject j_data_obj,
jobject j_validity_obj,
jobject j_offsets_obj) {
try {
cudf::jni::auto_set_device(env);
cudf::type_id n_type = static_cast<cudf::type_id>(j_type);
// not all the buffers are used for all types
void const *data_address = 0;
int data_length = 0;
if (j_data_obj != 0) {
data_address = env->GetDirectBufferAddress(j_data_obj);
data_length = env->GetDirectBufferCapacity(j_data_obj);
}
void const *validity_address = 0;
int validity_length = 0;
if (j_validity_obj != 0) {
validity_address = env->GetDirectBufferAddress(j_validity_obj);
validity_length = env->GetDirectBufferCapacity(j_validity_obj);
}
void const *offsets_address = 0;
int offsets_length = 0;
if (j_offsets_obj != 0) {
offsets_address = env->GetDirectBufferAddress(j_offsets_obj);
offsets_length = env->GetDirectBufferCapacity(j_offsets_obj);
}
auto data_buffer = arrow::Buffer::Wrap(static_cast<const char *>(data_address), static_cast<int>(data_length));
auto null_buffer = arrow::Buffer::Wrap(static_cast<const char *>(validity_address), static_cast<int>(validity_length));
auto offsets_buffer = arrow::Buffer::Wrap(static_cast<const char *>(offsets_address), static_cast<int>(offsets_length));

cudf::jni::native_jlongArray outcol_handles(env, 1);
std::shared_ptr<arrow::Array> arrow_array;
switch (n_type) {
case cudf::type_id::DECIMAL32:
JNI_THROW_NEW(env, cudf::jni::ILLEGAL_ARG_CLASS, "Don't support converting DECIMAL32 yet", 0);
break;
case cudf::type_id::DECIMAL64:
JNI_THROW_NEW(env, cudf::jni::ILLEGAL_ARG_CLASS, "Don't support converting DECIMAL64 yet", 0);
break;
case cudf::type_id::STRUCT:
JNI_THROW_NEW(env, cudf::jni::ILLEGAL_ARG_CLASS, "Don't support converting STRUCT yet", 0);
break;
case cudf::type_id::LIST:
JNI_THROW_NEW(env, cudf::jni::ILLEGAL_ARG_CLASS, "Don't support converting LIST yet", 0);
break;
case cudf::type_id::DICTIONARY32:
JNI_THROW_NEW(env, cudf::jni::ILLEGAL_ARG_CLASS, "Don't support converting DICTIONARY32 yet", 0);
break;
case cudf::type_id::STRING:
arrow_array = std::make_shared<arrow::StringArray>(j_col_length, offsets_buffer, data_buffer, null_buffer, j_null_count);
break;
default:
// this handles the primitive types
arrow_array = cudf::detail::to_arrow_array(n_type, j_col_length, data_buffer, null_buffer, j_null_count);
}
auto name_and_type = arrow::field("col", arrow_array->type());
std::vector<std::shared_ptr<arrow::Field>> fields = {name_and_type};
std::shared_ptr<arrow::Schema> schema = std::make_shared<arrow::Schema>(fields);
auto arrow_table = arrow::Table::Make(schema, std::vector<std::shared_ptr<arrow::Array>>{arrow_array});
std::unique_ptr<cudf::table> table_result = cudf::from_arrow(*(arrow_table));
std::vector<std::unique_ptr<cudf::column>> retCols = table_result->release();
if (retCols.size() != 1) {
JNI_THROW_NEW(env, "java/lang/IllegalArgumentException", "Must result in one column", 0);
}
return reinterpret_cast<jlong>(retCols[0].release());
}
CATCH_STD(env, 0);
}
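
On the Java side, the JNI_THROW_NEW calls above surface as IllegalArgumentException. A sketch of what a caller would see for a not-yet-supported type (the DType.create call for DECIMAL32 is an assumption about cudf's Java DType API at this version):

// Hypothetical caller: DECIMAL32 is rejected before any buffers are touched.
try (ColumnVector cv = ColumnVector.fromArrow(
         DType.create(DType.DTypeEnum.DECIMAL32, -2), 0, 0, null, null, null)) {
  // unreachable
} catch (IllegalArgumentException e) {
  System.err.println(e.getMessage()); // "Don't support converting DECIMAL32 yet"
}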

JNIEXPORT jlong JNICALL Java_ai_rapids_cudf_ColumnVector_makeList(JNIEnv *env, jobject j_object,
jlongArray handles,
jlong j_type,