Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Java support for casting of nested child columns [skip ci] #7417

Merged
merged 19 commits into from
Mar 8, 2021
Merged
Show file tree
Hide file tree
Changes from 3 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
20 changes: 20 additions & 0 deletions java/src/main/java/ai/rapids/cudf/ColumnVector.java
Original file line number Diff line number Diff line change
Expand Up @@ -27,6 +27,7 @@
import java.math.RoundingMode;
import java.nio.ByteBuffer;
import java.util.ArrayList;
import java.util.Arrays;
import java.util.List;
import java.util.Optional;
import java.util.function.Consumer;
Expand Down Expand Up @@ -394,6 +395,24 @@ public static ColumnVector makeStruct(long rows, ColumnView... columns) {
}
}

/**
* This is a very specialized method that has only one job. It takes in a list and returns a
revans2 marked this conversation as resolved.
Show resolved Hide resolved
* new list if the Leaf node is a 64-bit Decimal, converting the leaf to a 32-bit Decimal.
* Note: this is not a true cast as it assumes that the 64-bit Decimal column is a 32-bit Decimal
* column that happens to be stored as a 64-bit Decimal.
*
* Ex 1:
* replace(col( type: List<List<Struct<int, List, String>>>)) => throws an assert error
*
* Ex 2:
* replace(col(type: List<List<D64>>) => col(type: List<List<D32>>) no rounding is done
*
*/
public ColumnVector castLeafD64ToD32() {
razajafri marked this conversation as resolved.
Show resolved Hide resolved
assert(type == DType.LIST);
return new ColumnVector(castLeafD64ToD32(offHeap.columnHandle));
}

/**
* Create a LIST column from the given columns. Each list in the returned column will have the
* same number of entries in it as columns passed into this method. Be careful about the
Expand Down Expand Up @@ -725,6 +744,7 @@ static void closeBuffers(AutoCloseable buffer) {

private static native void setNativeNullCountColumn(long cudfColumnHandle, int nullCount) throws CudfException;

/**
 * Native implementation backing {@link #castLeafD64ToD32()}.
 * @param cudfColumnHandle the pointer to the cudf::column; the native side requires it to be
 *                         a LIST column
 * @return the pointer to a newly created cudf::column holding the converted list
 */
private static native long castLeafD64ToD32(long cudfColumnHandle) throws CudfException;
/**
* Create a cudf::column_view from a cudf::column.
* @param cudfColumnHandle the pointer to the cudf::column
Expand Down
13 changes: 13 additions & 0 deletions java/src/main/java/ai/rapids/cudf/ColumnView.java
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,7 @@
package ai.rapids.cudf;

import java.util.ArrayList;
import java.util.Arrays;
import java.util.List;
import java.util.Optional;

Expand Down Expand Up @@ -1296,6 +1297,16 @@ public ColumnVector castTo(DType type) {
return new ColumnVector(castTo(getNativeView(), type.typeId.getNativeId(), type.getScale()));
}

/**
 * Build a new STRUCT column from this one in which selected children are swapped out for
 * the supplied columns.
 * @param indices positions of the children of this struct that are to be replaced
 * @param views   replacement columns; {@code views[i]} takes the place of child
 *                {@code indices[i]}
 * @return the newly created struct column
 */
public ColumnVector replaceColumnsInStruct(int[] indices,
                                           ColumnView[] views) {
  assert(type == DType.STRUCT);
  long structHandle = getNativeView();
  // Collect the native view handle of every replacement column, in order.
  long[] replacementHandles = new long[views.length];
  for (int pos = 0; pos < views.length; pos++) {
    replacementHandles[pos] = views[pos].getNativeView();
  }
  long resultHandle = replaceColumnsInStruct(structHandle, indices, replacementHandles);
  return new ColumnVector(resultHandle);
}

/**
* Zero-copy cast between types with the same underlying representation.
*
Expand Down Expand Up @@ -2437,6 +2448,8 @@ static DeviceMemoryBufferView getOffsetsBuffer(long viewHandle) {
*/
private static native long timestampToStringTimestamp(long viewHandle, String format);

/**
 * Native method that builds a new struct column in which the children at the given indices
 * are replaced by the given columns.
 * @param cudfColumnHandle native handle of the struct column. Despite the name, the Java
 *                         caller passes a view handle and the native side treats it as a
 *                         cudf::column_view.
 * @param indices child indices to replace; must not be null and must have the same length
 *                as viewHandles
 * @param viewHandles native view handles of the replacement columns; must not be null
 * @return native handle of the newly created struct column
 */
private static native long replaceColumnsInStruct(long cudfColumnHandle,
int[] indices, long[] viewHandles) throws CudfException;
jlowe marked this conversation as resolved.
Show resolved Hide resolved
/**
* Native method for locating the starting index of the first instance of a given substring
* in each string in the column. 0 indexing, returns -1 if the substring is not found. Can be
Expand Down
136 changes: 46 additions & 90 deletions java/src/main/native/src/ColumnVectorJni.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,8 @@
* limitations under the License.
*/

#include <cstdint>
#include <memory>
#include <arrow/api.h>
#include <cudf/column/column_factories.hpp>
#include <cudf/concatenate.hpp>
Expand All @@ -27,13 +29,57 @@
#include <cudf/lists/lists_column_view.hpp>
#include <cudf/scalar/scalar_factories.hpp>
#include <cudf/structs/structs_column_view.hpp>
#include "cudf/null_mask.hpp"
#include "cudf/types.hpp"
#include "cudf/utilities/traits.hpp"
#include "cudf/unary.hpp"
#include "rmm/device_buffer.hpp"

#include "cudf_jni_apis.hpp"
#include "dtype_utils.hpp"
#include "jni.h"
#include "jni_utils.hpp"


extern "C" {

/**
 * @brief Rebuild a (possibly nested) LIST column with its DECIMAL64 leaf cast to DECIMAL32.
 *
 * Walks down the chain of list children recursively. At the innermost level the child is
 * cast to DECIMAL32 with the same scale; every enclosing level is reassembled from its
 * original offsets and null mask around the new child.
 *
 * @param list_column the LIST column to convert; taken by value and consumed
 * @return raw pointer to the newly created LIST column; ownership passes to the caller
 */
cudf::column *replace_column(cudf::column list_column) {
  auto const child_idx = cudf::lists_column_view::child_column_index;
  auto const offsets_idx = cudf::lists_column_view::offsets_column_index;

  // Capture everything that is needed from the column before release() empties it.
  cudf::size_type const size = list_column.size();
  cudf::size_type const null_count = list_column.null_count();
  cudf::size_type child_size = 0;

  std::unique_ptr<cudf::column> new_child;
  {
    cudf::lists_column_view lcv(list_column.view());
    cudf::data_type const child_type = lcv.child().type();
    child_size = lcv.child().size();
    if (child_type.id() != cudf::type_id::LIST) {
      // Leaf level: only a DECIMAL64 leaf is supported.
      assert(child_type.id() == cudf::type_id::DECIMAL64);
      cudf::data_type const to_type(cudf::type_id::DECIMAL32, child_type.scale());
      new_child = cudf::cast(lcv.child(), to_type);
    }
  }

  auto contents = list_column.release();

  if (!new_child) {
    // Nested list: hand the child over by move so it is not copied again on the way down.
    new_child.reset(replace_column(std::move(*contents.children[child_idx])));
  }

  assert(new_child->size() == child_size);
  static_cast<void>(child_size); // only read by the assert above

  // Move the mask out of the buffer while the unique_ptr keeps owning (and later frees) the
  // moved-from device_buffer object; calling release() here would leak that object.
  auto col = cudf::make_lists_column(size, std::move(contents.children[offsets_idx]),
                                     std::move(new_child), null_count,
                                     std::move(*contents.null_mask));
  return col.release();
}

/**
 * JNI entry point for ColumnVector.castLeafD64ToD32.
 *
 * @param j_handle address of the cudf::column to convert; must be a LIST column
 * @return address of the newly created column, or 0 after a Java exception has been raised
 */
JNIEXPORT jlong JNICALL Java_ai_rapids_cudf_ColumnVector_castLeafD64ToD32(JNIEnv *env,
                                                                           jobject j_object,
                                                                           jlong j_handle) {

  JNI_NULL_CHECK(env, j_handle, "native handle is null", 0);

  try {
    // Select the device for this thread before any GPU work, like the other entry points here.
    cudf::jni::auto_set_device(env);
    cudf::column *n_list_col = reinterpret_cast<cudf::column *>(j_handle);
    JNI_ARG_CHECK(env, n_list_col->type().id() == cudf::type_id::LIST,
                  "Only list types are allowed", 0);

    // replace_column takes its argument by value, so the column behind j_handle is
    // copy-constructed here and is not itself consumed.
    return reinterpret_cast<jlong>(replace_column(*n_list_col));
  }
  CATCH_STD(env, 0);
}

JNIEXPORT jlong JNICALL Java_ai_rapids_cudf_ColumnVector_sequence(JNIEnv *env, jclass,
jlong j_initial_val, jlong j_step,
jint row_count) {
Expand Down Expand Up @@ -317,96 +363,6 @@ JNIEXPORT jlong JNICALL Java_ai_rapids_cudf_ColumnVector_makeEmptyCudfColumn(JNI
CATCH_STD(env, 0);
}

/**
 * JNI entry point that allocates a numeric cudf column.
 *
 * @param j_type native id of the element type
 * @param j_size number of rows; a size of 0 is rejected
 * @param j_mask_state native value of the cudf::mask_state to create the column with
 * @return address of the new cudf::column, or 0 after a Java exception has been raised
 */
JNIEXPORT jlong JNICALL Java_ai_rapids_cudf_ColumnVector_makeNumericCudfColumn(
    JNIEnv *env, jobject j_object, jint j_type, jint j_size, jint j_mask_state) {

  JNI_ARG_CHECK(env, (j_size != 0), "size is 0", 0);

  try {
    cudf::jni::auto_set_device(env);
    auto const element_type = cudf::data_type{static_cast<cudf::type_id>(j_type)};
    auto const row_count = static_cast<cudf::size_type>(j_size);
    auto const initial_mask = static_cast<cudf::mask_state>(j_mask_state);
    std::unique_ptr<cudf::column> result =
        cudf::make_numeric_column(element_type, row_count, initial_mask);
    return reinterpret_cast<jlong>(result.release());
  }
  CATCH_STD(env, 0);
}

/**
 * JNI entry point that allocates a timestamp cudf column.
 *
 * @param j_type native id of the timestamp type
 * @param j_size number of rows
 * @param j_mask_state native value of the cudf::mask_state to create the column with
 * @return address of the new cudf::column, or 0 after a Java exception has been raised
 */
JNIEXPORT jlong JNICALL Java_ai_rapids_cudf_ColumnVector_makeTimestampCudfColumn(
    JNIEnv *env, jobject j_object, jint j_type, jint j_size, jint j_mask_state) {

  // These are integer arguments: the null checks reject a value of 0 for either of them.
  JNI_NULL_CHECK(env, j_type, "type id is null", 0);
  JNI_NULL_CHECK(env, j_size, "size is null", 0);

  try {
    cudf::jni::auto_set_device(env);
    cudf::type_id n_type = static_cast<cudf::type_id>(j_type);
    // data_type is a small value type; no heap allocation is needed for it.
    cudf::data_type n_data_type(n_type);
    cudf::size_type n_size = static_cast<cudf::size_type>(j_size);
    cudf::mask_state n_mask_state = static_cast<cudf::mask_state>(j_mask_state);
    std::unique_ptr<cudf::column> column(
        cudf::make_timestamp_column(n_data_type, n_size, n_mask_state));
    return reinterpret_cast<jlong>(column.release());
  }
  CATCH_STD(env, 0);
}

// JNI entry point that builds a cudf strings column from buffers addressed by raw pointers.
// The offsets, character bytes and (optionally) validity mask are copied host-to-device.
//
// j_char_data   address of the character bytes; must not be 0
// j_offset_data address of size + 1 offsets; must not be 0. The last offset is used as the
//               total number of character bytes.
// j_valid_data  address of the validity bitmask, or 0 when there is none
// j_null_count  number of nulls; forced to 0 when no validity bitmask is supplied
// size          number of rows; a size of 0 is rejected
//
// Returns the address of the new cudf::column. 0 is the value handed to the error-handling
// macros as the return value.
JNIEXPORT jlong JNICALL Java_ai_rapids_cudf_ColumnVector_makeStringCudfColumnHostSide(
JNIEnv *env, jobject j_object, jlong j_char_data, jlong j_offset_data, jlong j_valid_data,
jint j_null_count, jint size) {

JNI_ARG_CHECK(env, (size != 0), "size is 0", 0);
JNI_NULL_CHECK(env, j_char_data, "char data is null", 0);
JNI_NULL_CHECK(env, j_offset_data, "offset is null", 0);

try {
cudf::jni::auto_set_device(env);
cudf::size_type *host_offsets = reinterpret_cast<cudf::size_type *>(j_offset_data);
char *n_char_data = reinterpret_cast<char *>(j_char_data);
// The entry after the last row's offset gives the length of the character data.
cudf::size_type n_data_size = host_offsets[size];
cudf::bitmask_type *n_validity = reinterpret_cast<cudf::bitmask_type *>(j_valid_data);

// Without a validity buffer the column is treated as having no nulls.
if (n_validity == nullptr) {
j_null_count = 0;
}

// Device column for the size + 1 INT32 offsets, filled from the caller's buffer.
std::unique_ptr<cudf::column> offsets = cudf::make_numeric_column(
cudf::data_type{cudf::type_id::INT32}, size + 1, cudf::mask_state::UNALLOCATED);
auto offsets_view = offsets->mutable_view();
JNI_CUDA_TRY(env, 0,
cudaMemcpyAsync(offsets_view.data<int32_t>(), host_offsets,
(size + 1) * sizeof(int32_t), cudaMemcpyHostToDevice));

// Device column for the character bytes, stored as INT8.
std::unique_ptr<cudf::column> data = cudf::make_numeric_column(
cudf::data_type{cudf::type_id::INT8}, n_data_size, cudf::mask_state::UNALLOCATED);
auto data_view = data->mutable_view();
JNI_CUDA_TRY(env, 0,
cudaMemcpyAsync(data_view.data<int8_t>(), n_char_data, n_data_size,
cudaMemcpyHostToDevice));

std::unique_ptr<cudf::column> column;
if (j_null_count == 0) {
// No nulls: build the strings column with an empty null mask.
column =
cudf::make_strings_column(size, std::move(offsets), std::move(data), j_null_count, {});
} else {
// NOTE(review): presumably the byte size of the bitmask words covering `size` rows - confirm
// against cudf::word_index.
cudf::size_type bytes = (cudf::word_index(size) + 1) * sizeof(cudf::bitmask_type);
rmm::device_buffer dev_validity(bytes);
JNI_CUDA_TRY(env, 0,
cudaMemcpyAsync(dev_validity.data(), n_validity, bytes, cudaMemcpyHostToDevice));

column = cudf::make_strings_column(size, std::move(offsets), std::move(data), j_null_count,
std::move(dev_validity));
}

// Synchronize stream 0 before returning; presumably this is what makes it safe for the caller
// to reuse its source buffers after the asynchronous copies above - TODO confirm.
JNI_CUDA_TRY(env, 0, cudaStreamSynchronize(0));
return reinterpret_cast<jlong>(column.release());
}
CATCH_STD(env, 0);
}

JNIEXPORT jint JNICALL Java_ai_rapids_cudf_ColumnVector_getNativeNullCountColumn(JNIEnv *env,
jobject j_object,
jlong handle) {
Expand Down
40 changes: 40 additions & 0 deletions java/src/main/native/src/ColumnViewJni.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -60,6 +60,7 @@
#include <cudf/lists/lists_column_view.hpp>
#include <cudf/structs/structs_column_view.hpp>
#include <map_lookup.hpp>
#include "cudf/types.hpp"
razajafri marked this conversation as resolved.
Show resolved Hide resolved

#include "cudf_jni_apis.hpp"
#include "dtype_utils.hpp"
Expand Down Expand Up @@ -1760,4 +1761,43 @@ JNIEXPORT jlong JNICALL Java_ai_rapids_cudf_ColumnView_copyColumnViewToCV(JNIEnv
}
CATCH_STD(env, 0)
}

/**
 * JNI entry point for ColumnView.replaceColumnsInStruct.
 *
 * Builds a new STRUCT column whose children are copies of the original children, except at
 * the requested indices, where copies of the supplied replacement views are used instead.
 * When an index is listed more than once, the last replacement given for it wins.
 *
 * @param j_handle address of the cudf::column_view of the struct column
 * @param j_indices child indices to replace; each must be a valid child index
 * @param j_children addresses of the replacement cudf::column_view objects, parallel to
 *                   j_indices
 * @return address of the new cudf::column, or 0 after a Java exception has been raised
 */
JNIEXPORT jlong JNICALL Java_ai_rapids_cudf_ColumnView_replaceColumnsInStruct(
    JNIEnv *env, jobject j_object, jlong j_handle, jintArray j_indices, jlongArray j_children) {

  JNI_NULL_CHECK(env, j_handle, "native handle is null", 0);
  JNI_NULL_CHECK(env, j_indices, "child indices to replace can't be null", 0);
  JNI_NULL_CHECK(env, j_children, "children to replace can't be null", 0);

  try {
    // Select the device for this thread before any GPU work.
    cudf::jni::auto_set_device(env);
    cudf::jni::native_jpointerArray<cudf::column_view> children_to_replace(env, j_children);
    cudf::jni::native_jintArray indices(env, j_indices);
    JNI_ARG_CHECK(env, indices.size() == children_to_replace.size(),
                  "The indices size and children size should match", 0);

    cudf::column_view *n_struct_col_view = reinterpret_cast<cudf::column_view *>(j_handle);
    JNI_ARG_CHECK(env, n_struct_col_view->type().id() == cudf::type_id::STRUCT,
                  "Only struct types are allowed", 0);

    cudf::size_type const num_children = n_struct_col_view->num_children();

    // replacements[c] is the view that takes the place of child c, or nullptr to keep it.
    std::vector<cudf::column_view const *> replacements(num_children, nullptr);
    for (int i = 0; i < indices.size(); i++) {
      JNI_ARG_CHECK(env, (indices[i] >= 0 && indices[i] < num_children),
                    "child index to replace is out of range", 0);
      JNI_ARG_CHECK(env, (children_to_replace[i] != nullptr),
                    "replacement child can't be null", 0);
      replacements[indices[i]] = children_to_replace[i];
    }

    std::vector<std::unique_ptr<cudf::column>> children;
    children.reserve(num_children);
    for (cudf::size_type c = 0; c < num_children; c++) {
      if (replacements[c] != nullptr) {
        children.emplace_back(std::make_unique<cudf::column>(*replacements[c]));
      } else {
        children.emplace_back(std::make_unique<cudf::column>(n_struct_col_view->child(c)));
      }
    }

    auto col = cudf::make_structs_column(n_struct_col_view->size(), std::move(children),
                                         n_struct_col_view->null_count(),
                                         cudf::copy_bitmask(*n_struct_col_view));
    return reinterpret_cast<jlong>(col.release());
  }
  CATCH_STD(env, 0);
}
} // extern "C"
37 changes: 37 additions & 0 deletions java/src/test/java/ai/rapids/cudf/ColumnVectorTest.java
Original file line number Diff line number Diff line change
Expand Up @@ -3936,4 +3936,41 @@ void testMakeList() {
assertColumnsAreEqual(expected, created);
}
}

/**
 * A list built from DECIMAL64 columns must, after castLeafD64ToD32, equal the list built from
 * DECIMAL32 columns holding the same values at the same scale.
 */
@Test
void testCastLeafNodeInList() {
  final DType dec32 = DType.create(DType.DTypeEnum.DECIMAL32, 3);
  final DType dec64 = DType.create(DType.DTypeEnum.DECIMAL64, 3);
  final RoundingMode mode = RoundingMode.HALF_UP;
  try (ColumnVector want1 = ColumnVector.decimalFromDoubles(dec32, mode, 770.892, 961.110);
       ColumnVector want2 = ColumnVector.decimalFromDoubles(dec32, mode, 524.982, 479.946);
       ColumnVector want3 = ColumnVector.decimalFromDoubles(dec32, mode, 346.997, 479.946);
       ColumnVector want4 = ColumnVector.decimalFromDoubles(dec32, mode, 87.764, 414.239);
       ColumnVector expected = ColumnVector.makeList(want1, want2, want3, want4);
       ColumnVector in1 = ColumnVector.decimalFromDoubles(dec64, mode, 770.892, 961.110);
       ColumnVector in2 = ColumnVector.decimalFromDoubles(dec64, mode, 524.982, 479.946);
       ColumnVector in3 = ColumnVector.decimalFromDoubles(dec64, mode, 346.997, 479.946);
       ColumnVector in4 = ColumnVector.decimalFromDoubles(dec64, mode, 87.764, 414.239);
       ColumnVector wideList = ColumnVector.makeList(in1, in2, in3, in4);
       ColumnVector narrowed = wideList.castLeafD64ToD32()) {
    assertColumnsAreEqual(expected, narrowed);
  }
}

/**
 * Replacing the middle child of a three-int struct must give a struct whose middle values
 * come from the replacement column while the outer children are unchanged.
 */
@Test
void testReplaceColumnInStruct() {
  StructType threeInts = new StructType(false,
      Arrays.asList(
          new BasicType(false, DType.INT32),
          new BasicType(false, DType.INT32),
          new BasicType(false, DType.INT32)));
  try (ColumnVector expected = ColumnVector.fromStructs(threeInts,
           new HostColumnVector.StructData(1, 5, 3),
           new HostColumnVector.StructData(4, 9, 6));
       ColumnVector first = ColumnVector.fromInts(1, 4);
       ColumnVector second = ColumnVector.fromInts(2, 5);
       ColumnVector third = ColumnVector.fromInts(3, 6);
       ColumnVector original = ColumnVector.makeStruct(first, second, third);
       ColumnVector substitute = ColumnVector.fromInts(5, 9);
       ColumnVector actual = original.replaceColumnsInStruct(new int[]{1},
           new ColumnVector[]{substitute})) {
    assertColumnsAreEqual(expected, actual);
  }
}
}