Skip to content

Commit

Permalink
Add replacements column support for Java replaceNulls (#7750)
Browse files Browse the repository at this point in the history
Adds Java bindings for `cudf::replace_nulls` with a columnar replacement parameter.

Authors:
  - Jason Lowe (@jlowe)

Approvers:
  - Robert (Bobby) Evans (@revans2)

URL: #7750
  • Loading branch information
jlowe authored Mar 29, 2021
1 parent 54dfaaa commit cddafd9
Show file tree
Hide file tree
Showing 3 changed files with 78 additions and 10 deletions.
18 changes: 16 additions & 2 deletions java/src/main/java/ai/rapids/cudf/ColumnView.java
Original file line number Diff line number Diff line change
Expand Up @@ -388,7 +388,19 @@ public final ColumnVector findAndReplaceAll(ColumnView oldValues, ColumnView new
* @return - ColumnVector with nulls replaced by scalar
*/
public final ColumnVector replaceNulls(Scalar scalar) {
return new ColumnVector(replaceNulls(getNativeView(), scalar.getScalarHandle()));
return new ColumnVector(replaceNullsScalar(getNativeView(), scalar.getScalarHandle()));
}

/**
 * Builds a new column in which each null row of this column is filled with the value found
 * at the same row index of the given replacement column.
 * The replacement column must match this column in both data type and row count.
 *
 * @param replacements column supplying the replacement value for each row
 * @return column with nulls replaced by the matching row of {@code replacements}
 */
public final ColumnVector replaceNulls(ColumnView replacements) {
  // Resolve both native handles first (this column, then the replacements) before crossing JNI.
  final long inputHandle = getNativeView();
  final long replacementsHandle = replacements.getNativeView();
  return new ColumnVector(replaceNullsColumn(inputHandle, replacementsHandle));
}

/**
Expand Down Expand Up @@ -2840,7 +2852,9 @@ private static native long rollingWindow(

private static native long charLengths(long viewHandle) throws CudfException;

private static native long replaceNulls(long viewHandle, long scalarHandle) throws CudfException;
// Native entry point used by replaceNulls(Scalar): takes the native view handle of the column
// and the native handle of the replacement scalar, and returns a native handle that the caller
// wraps in a new ColumnVector.
private static native long replaceNullsScalar(long viewHandle, long scalarHandle) throws CudfException;

// Native entry point used by replaceNulls(ColumnView): takes the native view handle of the
// column and the native view handle of the replacements column, and returns a native handle
// that the caller wraps in a new ColumnVector.
private static native long replaceNullsColumn(long viewHandle, long replaceViewHandle) throws CudfException;

private static native long ifElseVV(long predVec, long trueVec, long falseVec) throws CudfException;

Expand Down
20 changes: 18 additions & 2 deletions java/src/main/native/src/ColumnViewJni.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -121,8 +121,9 @@ JNIEXPORT jlong JNICALL Java_ai_rapids_cudf_ColumnView_lowerStrings(JNIEnv *env,
CATCH_STD(env, 0);
}

JNIEXPORT jlong JNICALL Java_ai_rapids_cudf_ColumnView_replaceNulls(JNIEnv *env, jclass,
jlong j_col, jlong j_scalar) {
JNIEXPORT jlong JNICALL Java_ai_rapids_cudf_ColumnView_replaceNullsScalar(JNIEnv *env, jclass,
jlong j_col,
jlong j_scalar) {
JNI_NULL_CHECK(env, j_col, "column is null", 0);
JNI_NULL_CHECK(env, j_scalar, "scalar is null", 0);
try {
Expand All @@ -135,6 +136,21 @@ JNIEXPORT jlong JNICALL Java_ai_rapids_cudf_ColumnView_replaceNulls(JNIEnv *env,
CATCH_STD(env, 0);
}

// JNI binding for ColumnView.replaceNullsColumn: forwards both column views to
// cudf::replace_nulls and hands ownership of the resulting column back to Java as a jlong.
// Yields 0 when either handle fails the null check or when CATCH_STD handles an exception.
JNIEXPORT jlong JNICALL Java_ai_rapids_cudf_ColumnView_replaceNullsColumn(JNIEnv *env, jclass,
                                                                          jlong j_col,
                                                                          jlong j_replace_col) {
  JNI_NULL_CHECK(env, j_col, "column is null", 0);
  JNI_NULL_CHECK(env, j_replace_col, "replacement column is null", 0);
  try {
    cudf::jni::auto_set_device(env);
    // The Java side passes raw column_view pointers encoded as longs.
    cudf::column_view *input_view = reinterpret_cast<cudf::column_view *>(j_col);
    cudf::column_view *replacement_view = reinterpret_cast<cudf::column_view *>(j_replace_col);
    // release() gives up C++ ownership; the returned handle now owns the column.
    return reinterpret_cast<jlong>(cudf::replace_nulls(*input_view, *replacement_view).release());
  }
  CATCH_STD(env, 0);
}

JNIEXPORT jlong JNICALL Java_ai_rapids_cudf_ColumnView_ifElseVV(JNIEnv *env, jclass,
jlong j_pred_vec,
jlong j_true_vec,
Expand Down
50 changes: 44 additions & 6 deletions java/src/test/java/ai/rapids/cudf/ColumnVectorTest.java
Original file line number Diff line number Diff line change
Expand Up @@ -1368,7 +1368,7 @@ void testFromScalarNullByte() {
}

@Test
void testReplaceEmptyColumn() {
void testReplaceNullsScalarEmptyColumn() {
try (ColumnVector input = ColumnVector.fromBoxedBooleans();
ColumnVector expected = ColumnVector.fromBoxedBooleans();
Scalar s = Scalar.fromBool(false);
Expand All @@ -1378,7 +1378,7 @@ void testReplaceEmptyColumn() {
}

@Test
void testReplaceNullBoolsWithAllNulls() {
void testReplaceNullsScalarBoolsWithAllNulls() {
try (ColumnVector input = ColumnVector.fromBoxedBooleans(null, null, null, null);
ColumnVector expected = ColumnVector.fromBoxedBooleans(false, false, false, false);
Scalar s = Scalar.fromBool(false);
Expand All @@ -1388,7 +1388,7 @@ void testReplaceNullBoolsWithAllNulls() {
}

@Test
void testReplaceSomeNullBools() {
void testReplaceNullsScalarSomeNullBools() {
try (ColumnVector input = ColumnVector.fromBoxedBooleans(false, null, null, false);
ColumnVector expected = ColumnVector.fromBoxedBooleans(false, true, true, false);
Scalar s = Scalar.fromBool(true);
Expand All @@ -1398,7 +1398,7 @@ void testReplaceSomeNullBools() {
}

@Test
void testReplaceNullIntegersWithAllNulls() {
void testReplaceNullsScalarIntegersWithAllNulls() {
try (ColumnVector input = ColumnVector.fromBoxedInts(null, null, null, null);
ColumnVector expected = ColumnVector.fromBoxedInts(0, 0, 0, 0);
Scalar s = Scalar.fromInt(0);
Expand All @@ -1408,7 +1408,7 @@ void testReplaceNullIntegersWithAllNulls() {
}

@Test
void testReplaceSomeNullIntegers() {
void testReplaceNullsScalarSomeNullIntegers() {
try (ColumnVector input = ColumnVector.fromBoxedInts(1, 2, null, 4, null);
ColumnVector expected = ColumnVector.fromBoxedInts(1, 2, 999, 4, 999);
Scalar s = Scalar.fromInt(999);
Expand All @@ -1418,7 +1418,7 @@ void testReplaceSomeNullIntegers() {
}

@Test
void testReplaceNullsFailsOnTypeMismatch() {
void testReplaceNullsScalarFailsOnTypeMismatch() {
try (ColumnVector input = ColumnVector.fromBoxedInts(1, 2, null, 4, null);
Scalar s = Scalar.fromBool(true)) {
assertThrows(CudfException.class, () -> input.replaceNulls(s).close());
Expand All @@ -1434,6 +1434,44 @@ void testReplaceNullsWithNullScalar() {
}
}

@Test
void testReplaceNullsColumnEmptyColumn() {
  // Zero-row input paired with zero-row replacements should produce a zero-row result.
  try (ColumnVector input = ColumnVector.fromBoxedBooleans();
       ColumnVector replacements = ColumnVector.fromBoxedBooleans();
       ColumnVector expected = ColumnVector.fromBoxedBooleans();
       ColumnVector actual = input.replaceNulls(replacements)) {
    assertColumnsAreEqual(expected, actual);
  }
}

@Test
void testReplaceNullsColumnBools() {
  // Rows 0 and 2 are null in the input and take their values from the replacements column;
  // the null at replacement row 1 is irrelevant because the input row there is valid.
  try (ColumnVector input = ColumnVector.fromBoxedBooleans(null, true, null, false);
       ColumnVector replacements = ColumnVector.fromBoxedBooleans(false, null, true, true);
       ColumnVector expected = ColumnVector.fromBoxedBooleans(false, true, true, false);
       ColumnVector actual = input.replaceNulls(replacements)) {
    assertColumnsAreEqual(expected, actual);
  }
}

@Test
void testReplaceNullsColumnIntegers() {
  // Row 2 is filled from the replacements column; row 4 stays null because the
  // replacement value at that row is null as well.
  try (ColumnVector input = ColumnVector.fromBoxedInts(1, 2, null, 4, null);
       ColumnVector replacements = ColumnVector.fromBoxedInts(996, 997, 998, 909, null);
       ColumnVector expected = ColumnVector.fromBoxedInts(1, 2, 998, 4, null);
       ColumnVector actual = input.replaceNulls(replacements)) {
    assertColumnsAreEqual(expected, actual);
  }
}

@Test
void testReplaceNullsColumnFailsOnTypeMismatch() {
  // The replacement column deliberately has the same number of rows as the input so that the
  // data type (BOOL8 vs INT32) is the only mismatch. With differing row counts the call could
  // be rejected for the size mismatch alone and this test would pass without ever exercising
  // the type check it is named for.
  try (ColumnVector input = ColumnVector.fromBoxedInts(1, 2, null, 4, null);
       ColumnVector r = ColumnVector.fromBoxedBooleans(true, false, true, false, true)) {
    // close() guards against leaking the result if the call unexpectedly succeeds.
    assertThrows(CudfException.class, () -> input.replaceNulls(r).close());
  }
}

static QuantileMethod[] methods = {LINEAR, LOWER, HIGHER, MIDPOINT, NEAREST};
static double[] quantiles = {0.0, 0.25, 0.33, 0.5, 1.0};

Expand Down

0 comments on commit cddafd9

Please sign in to comment.