Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Java support for casting of nested child columns [skip ci] #7417

Merged
merged 19 commits into from
Mar 8, 2021
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
10 changes: 5 additions & 5 deletions java/src/main/java/ai/rapids/cudf/ColumnVector.java
Original file line number Diff line number Diff line change
Expand Up @@ -167,16 +167,16 @@ private static long getColumnViewFromColumn(long nativePointer) {
}
}


private static long initViewHandle(DType type, int rows, int nc, DeviceMemoryBuffer dataBuffer,
DeviceMemoryBuffer validityBuffer,
DeviceMemoryBuffer offsetBuffer, long[] childHandles) {
static long initViewHandle(DType type, int rows, int nc,
BaseDeviceMemoryBuffer dataBuffer,
BaseDeviceMemoryBuffer validityBuffer,
BaseDeviceMemoryBuffer offsetBuffer, long[] childHandles) {
long cd = dataBuffer == null ? 0 : dataBuffer.address;
long cdSize = dataBuffer == null ? 0 : dataBuffer.length;
long od = offsetBuffer == null ? 0 : offsetBuffer.address;
long vd = validityBuffer == null ? 0 : validityBuffer.address;
return makeCudfColumnView(type.typeId.getNativeId(), type.getScale(), cd, cdSize,
od, vd, nc, rows, childHandles) ;
od, vd, nc, rows, childHandles);
}

static ColumnVector fromViewWithContiguousAllocation(long columnViewAddress, DeviceMemoryBuffer buffer) {
Expand Down
144 changes: 141 additions & 3 deletions java/src/main/java/ai/rapids/cudf/ColumnView.java
Original file line number Diff line number Diff line change
Expand Up @@ -18,9 +18,8 @@

package ai.rapids.cudf;

import java.util.ArrayList;
import java.util.List;
import java.util.Optional;
import java.util.*;
import java.util.stream.IntStream;

import static ai.rapids.cudf.HostColumnVector.OFFSET_SIZE;

Expand Down Expand Up @@ -49,6 +48,65 @@ protected ColumnView(long address) {
this.nullCount = ColumnView.getNativeNullCount(viewHandle);
}

/**
* Create a new column view based off of data already on the device. Ref count on the buffers
* is not incremented and none of the underlying buffers are owned by this view. The returned
* ColumnView is only valid as long as the underlying buffers remain valid. If the buffers are
* closed before this ColumnView is closed, it will result in undefined behavior.
*
* If ownership is needed, call {@link ColumnView#copyToColumnVector}
*
* @param type the type of the vector
* @param rows the number of rows in this vector.
* @param nullCount the number of nulls in the dataset.
* @param validityBuffer an optional validity buffer. Must be provided if nullCount != 0.
* The ownership doesn't change on this buffer
* @param offsetBuffer a device buffer required for nested types including strings and string
* categories. The ownership doesn't change on this buffer
* @param children an array of ColumnView children
*/
public ColumnView(DType type, long rows, Optional<Long> nullCount,
    BaseDeviceMemoryBuffer validityBuffer,
    BaseDeviceMemoryBuffer offsetBuffer, ColumnView[] children) {
  // This overload has no data buffer parameter, so null is forwarded in that slot.
  this(type, (int) rows, nullCount.orElse(UNKNOWN_NULL_COUNT).intValue(),
      null, validityBuffer, offsetBuffer, children);
  // This overload is only meant for nested types.
  assert type.isNestedType();
  // An absent null count is accepted; a present one has to fit in an int.
  assert !nullCount.isPresent() || nullCount.get() <= Integer.MAX_VALUE;
}

/**
* Create a new column view based off of data already on the device. Ref count on the buffers
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Similar comment on emphasizing lifetime requirements on input parameters relative to this instance.

* is not incremented and none of the underlying buffers are owned by this view. The returned
* ColumnView is only valid as long as the underlying buffers remain valid. If the buffers are
* closed before this ColumnView is closed, it will result in undefined behavior.
*
* If ownership is needed, call {@link ColumnView#copyToColumnVector}
*
* @param type the type of the vector
* @param rows the number of rows in this vector.
* @param nullCount the number of nulls in the dataset.
* @param dataBuffer a device buffer holding the data of this non-nested column.
* The ownership doesn't change on this buffer
* @param validityBuffer an optional validity buffer. Must be provided if nullCount != 0.
* The ownership doesn't change on this buffer
*/
public ColumnView(DType type, long rows, Optional<Long> nullCount,
    BaseDeviceMemoryBuffer dataBuffer,
    BaseDeviceMemoryBuffer validityBuffer) {
  // This overload has no offsets or children, so null is forwarded for both.
  this(type, (int) rows, nullCount.orElse(UNKNOWN_NULL_COUNT).intValue(),
      dataBuffer, validityBuffer, null, null);
  // This overload is only meant for non-nested types.
  assert !type.isNestedType();
  // An absent null count is accepted; a present one has to fit in an int.
  assert !nullCount.isPresent() || nullCount.get() <= Integer.MAX_VALUE;
}

/**
 * Common delegate of the public buffer-based constructors. Creates the native view through
 * {@link ColumnVector#initViewHandle} and passes the resulting handle on to the handle-based
 * constructor.
 *
 * @param type the type of the column
 * @param rows the number of rows; narrowed to an int before being handed to the native layer
 * @param nullCount the number of nulls
 * @param dataBuffer the data buffer, or null
 * @param validityBuffer the validity buffer, or null
 * @param offsetBuffer the offsets buffer, or null
 * @param children the child views whose native view handles are forwarded; may be null when
 *                 the column has no children, in which case an empty handle array is forwarded
 */
private ColumnView(DType type, long rows, int nullCount,
    BaseDeviceMemoryBuffer dataBuffer, BaseDeviceMemoryBuffer validityBuffer,
    BaseDeviceMemoryBuffer offsetBuffer, ColumnView[] children) {
  // The non-nested public constructor passes children == null; streaming over a null array
  // would throw a NullPointerException, so substitute an empty handle array in that case.
  this(ColumnVector.initViewHandle(type, (int) rows, nullCount, dataBuffer, validityBuffer,
      offsetBuffer,
      children == null
          ? new long[0]
          : Arrays.stream(children).mapToLong(c -> c.getNativeView()).toArray()));
}

/** Creates a ColumnVector from a column view handle
* @return a new ColumnVector
Expand Down Expand Up @@ -1296,6 +1354,86 @@ public ColumnVector castTo(DType type) {
return new ColumnVector(castTo(getNativeView(), type.typeId.getNativeId(), type.getScale()));
}

/**
* This method takes in a nested type and replaces its children with the given views
revans2 marked this conversation as resolved.
Show resolved Hide resolved
* Note: Make sure the numbers of rows in the leaf node are the same as the child replacing it
* otherwise the list can point to elements outside of the column values.
razajafri marked this conversation as resolved.
Show resolved Hide resolved
*
* Note: this method returns a ColumnView that won't live past the ColumnVector that it's
* pointing to.
*
* Ex: List<Int> list = col{{1,3}, {9,3,5}}
*
* validNewChild = col{8, 3, 9, 2, 0}
*
* list.replaceChildrenWithViews(1, validNewChild) => col{{8, 3}, {9, 2, 0}}
*
* invalidNewChild = col{3, 2}
* list.replaceChildrenWithViews(1, invalidNewChild) => col{{3, 2}, {invalid, invalid, invalid}}
*
* invalidNewChild = col{8, 3, 9, 2, 0, 0, 7}
* list.replaceChildrenWithViews(1, invalidNewChild) => col{{8, 3}, {9, 2, 0}} // undefined result
*/
public ColumnView replaceChildrenWithViews(int[] indices,
jlowe marked this conversation as resolved.
Show resolved Hide resolved
ColumnView[] views) {
assert (type.isNestedType());
assert (indices.length == views.length);
if (type == DType.LIST) {
assert (indices.length == 1);
}
if (indices.length != views.length) {
throw new IllegalArgumentException("The indices size and children size should match");
}
Map<Integer, ColumnView> map = new HashMap<>();
IntStream.range(0, indices.length).forEach(index -> {
if (map.containsKey(indices[index])) {
throw new IllegalArgumentException("Duplicate mapping found for replacing child index");
}
map.put(indices[index], views[index]);
});
List<ColumnView> newChildren = new ArrayList<>(getNumChildren());
IntStream.range(0, getNumChildren()).forEach(i -> {
ColumnView view = map.remove(i);
if (view == null) {
newChildren.add(getChildColumnView(i));
} else {
newChildren.add(view);
}
});
if (!map.isEmpty()) {
throw new IllegalArgumentException("One or more invalid child indices passed to be replaced");
}
return new ColumnView(type, getRowCount(), Optional.of(getNullCount()), getValid(),
getOffsets(), newChildren.stream().toArray(n -> new ColumnView[n]));
}

/**
* This method takes in a list and returns a new list with the leaf node replaced with the given
* view. Make sure the numbers of rows in the leaf node are the same as the child replacing it
* otherwise the list can point to elements outside of the column values.
*
* Note: this method returns a ColumnView that won't live past the ColumnVector that it's
* pointing to.
*
* Ex: List<Int> list = col{{1,3}, {9,3,5}}
*
* validNewChild = col{8, 3, 9, 2, 0}
*
* list.replaceChildrenWithViews(1, validNewChild) => col{{8, 3}, {9, 2, 0}}
*
* invalidNewChild = col{3, 2}
* list.replaceChildrenWithViews(1, invalidNewChild) =>
* col{{3, 2}, {invalid, invalid, invalid}} throws an exception
*
* invalidNewChild = col{8, 3, 9, 2, 0, 0, 7}
* list.replaceChildrenWithViews(1, invalidNewChild) =>
* col{{8, 3}, {9, 2, 0}} throws an exception
*/
public ColumnView replaceListChild(ColumnView child) {
jlowe marked this conversation as resolved.
Show resolved Hide resolved
assert(type == DType.LIST);
return replaceChildrenWithViews(new int[]{1}, new ColumnView[]{child});
}

/**
* Zero-copy cast between types with the same underlying representation.
*
Expand Down
91 changes: 0 additions & 91 deletions java/src/main/native/src/ColumnVectorJni.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -31,7 +31,6 @@
#include "cudf_jni_apis.hpp"
#include "dtype_utils.hpp"


extern "C" {

JNIEXPORT jlong JNICALL Java_ai_rapids_cudf_ColumnVector_sequence(JNIEnv *env, jclass,
Expand Down Expand Up @@ -315,96 +314,6 @@ JNIEXPORT jlong JNICALL Java_ai_rapids_cudf_ColumnVector_makeEmptyCudfColumn(JNI
CATCH_STD(env, 0);
}

// JNI entry point: creates a numeric cudf column of the requested type id, size and mask
// state, and returns the released column pointer to Java as a jlong handle.
JNIEXPORT jlong JNICALL Java_ai_rapids_cudf_ColumnVector_makeNumericCudfColumn(
    JNIEnv *env, jobject j_object, jint j_type, jint j_size, jint j_mask_state) {

  // A size of zero is rejected up front.
  JNI_ARG_CHECK(env, (j_size != 0), "size is 0", 0);

  try {
    cudf::jni::auto_set_device(env);
    // Convert the raw JNI integers into their cudf counterparts.
    auto const dtype = cudf::data_type{static_cast<cudf::type_id>(j_type)};
    auto const num_rows = static_cast<cudf::size_type>(j_size);
    auto const mask = static_cast<cudf::mask_state>(j_mask_state);
    std::unique_ptr<cudf::column> result = cudf::make_numeric_column(dtype, num_rows, mask);
    // Ownership of the column leaves this function with the returned handle.
    return reinterpret_cast<jlong>(result.release());
  }
  CATCH_STD(env, 0);
}

// JNI entry point: creates a timestamp cudf column of the requested type id, size and mask
// state, and returns the released column pointer to Java as a jlong handle.
JNIEXPORT jlong JNICALL Java_ai_rapids_cudf_ColumnVector_makeTimestampCudfColumn(
    JNIEnv *env, jobject j_object, jint j_type, jint j_size, jint j_mask_state) {

  // NOTE(review): these null checks are applied to integer arguments, so a value of 0 is what
  // they test for; kept exactly as the original has them.
  JNI_NULL_CHECK(env, j_type, "type id is null", 0);
  JNI_NULL_CHECK(env, j_size, "size is null", 0);

  try {
    cudf::jni::auto_set_device(env);
    // Convert the raw JNI integers into their cudf counterparts.
    auto const type_id = static_cast<cudf::type_id>(j_type);
    cudf::data_type const dtype(type_id);
    auto const num_rows = static_cast<cudf::size_type>(j_size);
    auto const mask = static_cast<cudf::mask_state>(j_mask_state);
    std::unique_ptr<cudf::column> result = cudf::make_timestamp_column(dtype, num_rows, mask);
    // Ownership of the column leaves this function with the returned handle.
    return reinterpret_cast<jlong>(result.release());
  }
  CATCH_STD(env, 0);
}

// JNI entry point: builds a cudf strings column from char, offset and (optional) validity
// buffers that are passed in as raw addresses and read as host memory. The buffers are copied
// to device-side columns/buffers and the released column pointer is returned as a jlong.
//   j_char_data   - address of the character bytes (must be non-zero)
//   j_offset_data - address of size + 1 offsets (must be non-zero)
//   j_valid_data  - address of the validity bitmask, or 0 for none
//   j_null_count  - number of nulls; forced to 0 when no validity is given
//   size          - number of strings (must be non-zero)
JNIEXPORT jlong JNICALL Java_ai_rapids_cudf_ColumnVector_makeStringCudfColumnHostSide(
JNIEnv *env, jobject j_object, jlong j_char_data, jlong j_offset_data, jlong j_valid_data,
jint j_null_count, jint size) {

JNI_ARG_CHECK(env, (size != 0), "size is 0", 0);
JNI_NULL_CHECK(env, j_char_data, "char data is null", 0);
JNI_NULL_CHECK(env, j_offset_data, "offset is null", 0);

try {
cudf::jni::auto_set_device(env);
cudf::size_type *host_offsets = reinterpret_cast<cudf::size_type *>(j_offset_data);
char *n_char_data = reinterpret_cast<char *>(j_char_data);
// The offset at index `size` is used as the total number of character bytes to copy.
cudf::size_type n_data_size = host_offsets[size];
cudf::bitmask_type *n_validity = reinterpret_cast<cudf::bitmask_type *>(j_valid_data);

// Without a validity buffer the column is treated as having no nulls.
if (n_validity == nullptr) {
j_null_count = 0;
}

// Offsets child: size + 1 INT32 values, filled from the host offsets.
std::unique_ptr<cudf::column> offsets = cudf::make_numeric_column(
cudf::data_type{cudf::type_id::INT32}, size + 1, cudf::mask_state::UNALLOCATED);
auto offsets_view = offsets->mutable_view();
JNI_CUDA_TRY(env, 0,
cudaMemcpyAsync(offsets_view.data<int32_t>(), host_offsets,
(size + 1) * sizeof(int32_t), cudaMemcpyHostToDevice));

// Chars child: n_data_size INT8 values, filled from the host character bytes.
std::unique_ptr<cudf::column> data = cudf::make_numeric_column(
cudf::data_type{cudf::type_id::INT8}, n_data_size, cudf::mask_state::UNALLOCATED);
auto data_view = data->mutable_view();
JNI_CUDA_TRY(env, 0,
cudaMemcpyAsync(data_view.data<int8_t>(), n_char_data, n_data_size,
cudaMemcpyHostToDevice));

std::unique_ptr<cudf::column> column;
if (j_null_count == 0) {
// No nulls: the strings column is built with an empty validity buffer.
column =
cudf::make_strings_column(size, std::move(offsets), std::move(data), j_null_count, {});
} else {
// Number of bytes copied for the validity mask: (word_index(size) + 1) bitmask words.
cudf::size_type bytes = (cudf::word_index(size) + 1) * sizeof(cudf::bitmask_type);
rmm::device_buffer dev_validity(bytes);
JNI_CUDA_TRY(env, 0,
cudaMemcpyAsync(dev_validity.data(), n_validity, bytes, cudaMemcpyHostToDevice));

column = cudf::make_strings_column(size, std::move(offsets), std::move(data), j_null_count,
std::move(dev_validity));
}

// The copies above are issued without an explicit stream argument; stream 0 is synchronized
// here before the handle is returned.
JNI_CUDA_TRY(env, 0, cudaStreamSynchronize(0));
return reinterpret_cast<jlong>(column.release());
}
CATCH_STD(env, 0);
}

JNIEXPORT jint JNICALL Java_ai_rapids_cudf_ColumnVector_getNativeNullCountColumn(JNIEnv *env,
jobject j_object,
jlong handle) {
Expand Down
101 changes: 101 additions & 0 deletions java/src/test/java/ai/rapids/cudf/ColumnVectorTest.java
Original file line number Diff line number Diff line change
Expand Up @@ -3951,4 +3951,105 @@ void testMakeList() {
assertColumnsAreEqual(expected, created);
}
}

// Replaces the decimal leaf of a list column with an int column and checks that the copied
// result equals a list built directly from the corresponding int rows.
@Test
void testReplaceLeafNodeInList() {
  try (ColumnVector intRow0 = ColumnVector.fromInts(1, 2);
       ColumnVector intRow1 = ColumnVector.fromInts(8, 3);
       ColumnVector intRow2 = ColumnVector.fromInts(9, 8);
       ColumnVector intRow3 = ColumnVector.fromInts(2, 6);
       ColumnVector expected = ColumnVector.makeList(intRow0, intRow1, intRow2, intRow3);
       ColumnVector decRow0 =
           ColumnVector.decimalFromDoubles(DType.create(DType.DTypeEnum.DECIMAL64, 3),
               RoundingMode.HALF_UP, 770.892, 961.110);
       ColumnVector decRow1 =
           ColumnVector.decimalFromDoubles(DType.create(DType.DTypeEnum.DECIMAL64, 3),
               RoundingMode.HALF_UP, 524.982, 479.946);
       ColumnVector decRow2 =
           ColumnVector.decimalFromDoubles(DType.create(DType.DTypeEnum.DECIMAL64, 3),
               RoundingMode.HALF_UP, 346.997, 479.946);
       ColumnVector decRow3 =
           ColumnVector.decimalFromDoubles(DType.create(DType.DTypeEnum.DECIMAL64, 3),
               RoundingMode.HALF_UP, 87.764, 414.239);
       ColumnVector decimalList = ColumnVector.makeList(decRow0, decRow1, decRow2, decRow3);
       ColumnVector intLeaf = ColumnVector.fromInts(1, 8, 9, 2, 2, 3, 8, 6);
       ColumnView replacedView = decimalList.replaceListChild(intLeaf);
       // Declared last so it is closed first, before the view it was copied from.
       ColumnVector replaced = replacedView.copyToColumnVector()) {
    assertColumnsAreEqual(expected, replaced);
  }
}

// Replacing the 8-element leaf of a list with a 9-element column is expected to be rejected
// with an IllegalArgumentException.
@Test
void testReplaceLeafNodeInListWithIllegal() {
  assertThrows(IllegalArgumentException.class, () -> {
    try (ColumnVector decRow0 =
             ColumnVector.decimalFromDoubles(DType.create(DType.DTypeEnum.DECIMAL64, 3),
                 RoundingMode.HALF_UP, 770.892, 961.110);
         ColumnVector decRow1 =
             ColumnVector.decimalFromDoubles(DType.create(DType.DTypeEnum.DECIMAL64, 3),
                 RoundingMode.HALF_UP, 524.982, 479.946);
         ColumnVector decRow2 =
             ColumnVector.decimalFromDoubles(DType.create(DType.DTypeEnum.DECIMAL64, 3),
                 RoundingMode.HALF_UP, 346.997, 479.946);
         ColumnVector decRow3 =
             ColumnVector.decimalFromDoubles(DType.create(DType.DTypeEnum.DECIMAL64, 3),
                 RoundingMode.HALF_UP, 87.764, 414.239);
         ColumnVector decimalList = ColumnVector.makeList(decRow0, decRow1, decRow2, decRow3);
         ColumnVector oversizedLeaf = ColumnVector.fromInts(0, 1, 8, 9, 2, 2, 3, 8, 6);
         ColumnView replacedView = decimalList.replaceListChild(oversizedLeaf)) {
      // Nothing to verify here: the exception is expected while the resources are set up.
    }
  });
}

// Replaces the middle child of a three-column struct and compares the copied result with a
// struct column built directly from the expected values.
@Test
void testReplaceColumnInStruct() {
  StructType threeInts = new StructType(false,
      Arrays.asList(
          new BasicType(false, DType.INT32),
          new BasicType(false, DType.INT32),
          new BasicType(false, DType.INT32)));
  try (ColumnVector expected = ColumnVector.fromStructs(threeInts,
           new HostColumnVector.StructData(1, 5, 3),
           new HostColumnVector.StructData(4, 9, 6));
       ColumnVector first = ColumnVector.fromInts(1, 4);
       ColumnVector second = ColumnVector.fromInts(2, 5);
       ColumnVector third = ColumnVector.fromInts(3, 6);
       ColumnVector struct = ColumnVector.makeStruct(first, second, third);
       ColumnVector newSecond = ColumnVector.fromInts(5, 9);
       ColumnView replacedView = struct.replaceChildrenWithViews(new int[]{1},
           new ColumnVector[]{newSecond});
       // Declared last so it is closed first, before the view it was copied from.
       ColumnVector replaced = replacedView.copyToColumnVector()) {
    assertColumnsAreEqual(expected, replaced);
  }
}

// Child index 5 does not exist in a three-column struct, so the replacement is expected to be
// rejected with an IllegalArgumentException.
@Test
void testReplaceIllegalIndexColumnInStruct() {
  assertThrows(IllegalArgumentException.class, () -> {
    try (ColumnVector first = ColumnVector.fromInts(1, 4);
         ColumnVector second = ColumnVector.fromInts(2, 5);
         ColumnVector third = ColumnVector.fromInts(3, 6);
         ColumnVector struct = ColumnVector.makeStruct(first, second, third);
         ColumnVector substitute = ColumnVector.fromInts(5, 9);
         ColumnView replacedView = struct.replaceChildrenWithViews(new int[]{5},
             new ColumnVector[]{substitute})) {
      // Nothing to verify here: the exception is expected while the resources are set up.
    }
  });
}

// Listing the same child index twice is expected to be rejected with an
// IllegalArgumentException.
@Test
void testReplaceSameIndexColumnInStruct() {
  assertThrows(IllegalArgumentException.class, () -> {
    try (ColumnVector first = ColumnVector.fromInts(1, 4);
         ColumnVector second = ColumnVector.fromInts(2, 5);
         ColumnVector third = ColumnVector.fromInts(3, 6);
         ColumnVector struct = ColumnVector.makeStruct(first, second, third);
         ColumnVector substitute = ColumnVector.fromInts(5, 9);
         ColumnView replacedView = struct.replaceChildrenWithViews(new int[]{1, 1},
             new ColumnVector[]{substitute, substitute})) {
      // Nothing to verify here: the exception is expected while the resources are set up.
    }
  });
}
}