Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add replacements column support for Java replaceNulls [skip ci] #7750

Merged
merged 2 commits into from
Mar 29, 2021
Merged
Show file tree
Hide file tree
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
18 changes: 16 additions & 2 deletions java/src/main/java/ai/rapids/cudf/ColumnView.java
Original file line number Diff line number Diff line change
Expand Up @@ -388,7 +388,19 @@ public final ColumnVector findAndReplaceAll(ColumnView oldValues, ColumnView new
* @return - ColumnVector with nulls replaced by scalar
*/
public final ColumnVector replaceNulls(Scalar scalar) {
return new ColumnVector(replaceNulls(getNativeView(), scalar.getScalarHandle()));
return new ColumnVector(replaceNullsScalar(getNativeView(), scalar.getScalarHandle()));
}

/**
* Returns a ColumnVector with any null values replaced with the corresponding row in the
* specified replacement column.
* This column and the replacement column must have the same type and number of rows.
*
* @param replacements column of replacement values
* @return column with nulls replaced by corresponding row of replacements column
*/
public final ColumnVector replaceNulls(ColumnView replacements) {
return new ColumnVector(replaceNullsColumn(getNativeView(), replacements.getNativeView()));
}

/**
Expand Down Expand Up @@ -2840,7 +2852,9 @@ private static native long rollingWindow(

private static native long charLengths(long viewHandle) throws CudfException;

private static native long replaceNulls(long viewHandle, long scalarHandle) throws CudfException;
private static native long replaceNullsScalar(long viewHandle, long scalarHandle) throws CudfException;

private static native long replaceNullsColumn(long viewHandle, long replaceViewHandle) throws CudfException;

private static native long ifElseVV(long predVec, long trueVec, long falseVec) throws CudfException;

Expand Down
20 changes: 18 additions & 2 deletions java/src/main/native/src/ColumnViewJni.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -121,8 +121,9 @@ JNIEXPORT jlong JNICALL Java_ai_rapids_cudf_ColumnView_lowerStrings(JNIEnv *env,
CATCH_STD(env, 0);
}

JNIEXPORT jlong JNICALL Java_ai_rapids_cudf_ColumnView_replaceNulls(JNIEnv *env, jclass,
jlong j_col, jlong j_scalar) {
JNIEXPORT jlong JNICALL Java_ai_rapids_cudf_ColumnView_replaceNullsScalar(JNIEnv *env, jclass,
jlong j_col,
jlong j_scalar) {
JNI_NULL_CHECK(env, j_col, "column is null", 0);
JNI_NULL_CHECK(env, j_scalar, "scalar is null", 0);
try {
Expand All @@ -135,6 +136,21 @@ JNIEXPORT jlong JNICALL Java_ai_rapids_cudf_ColumnView_replaceNulls(JNIEnv *env,
CATCH_STD(env, 0);
}

JNIEXPORT jlong JNICALL Java_ai_rapids_cudf_ColumnView_replaceNullsColumn(JNIEnv *env, jclass,
jlong j_col,
jlong j_replace_col) {
JNI_NULL_CHECK(env, j_col, "column is null", 0);
JNI_NULL_CHECK(env, j_replace_col, "replacement column is null", 0);
try {
cudf::jni::auto_set_device(env);
auto col = reinterpret_cast<cudf::column_view *>(j_col);
auto replacements = reinterpret_cast<cudf::column_view *>(j_replace_col);
std::unique_ptr<cudf::column> result = cudf::replace_nulls(*col, *replacements);
return reinterpret_cast<jlong>(result.release());
}
CATCH_STD(env, 0);
}

JNIEXPORT jlong JNICALL Java_ai_rapids_cudf_ColumnView_ifElseVV(JNIEnv *env, jclass,
jlong j_pred_vec,
jlong j_true_vec,
Expand Down
50 changes: 44 additions & 6 deletions java/src/test/java/ai/rapids/cudf/ColumnVectorTest.java
Original file line number Diff line number Diff line change
Expand Up @@ -1368,7 +1368,7 @@ void testFromScalarNullByte() {
}

@Test
void testReplaceEmptyColumn() {
void testReplaceNullsScalarEmptyColumn() {
try (ColumnVector input = ColumnVector.fromBoxedBooleans();
ColumnVector expected = ColumnVector.fromBoxedBooleans();
Scalar s = Scalar.fromBool(false);
Expand All @@ -1378,7 +1378,7 @@ void testReplaceEmptyColumn() {
}

@Test
void testReplaceNullBoolsWithAllNulls() {
void testReplaceNullsScalarBoolsWithAllNulls() {
try (ColumnVector input = ColumnVector.fromBoxedBooleans(null, null, null, null);
ColumnVector expected = ColumnVector.fromBoxedBooleans(false, false, false, false);
Scalar s = Scalar.fromBool(false);
Expand All @@ -1388,7 +1388,7 @@ void testReplaceNullBoolsWithAllNulls() {
}

@Test
void testReplaceSomeNullBools() {
void testReplaceNullsScalarSomeNullBools() {
try (ColumnVector input = ColumnVector.fromBoxedBooleans(false, null, null, false);
ColumnVector expected = ColumnVector.fromBoxedBooleans(false, true, true, false);
Scalar s = Scalar.fromBool(true);
Expand All @@ -1398,7 +1398,7 @@ void testReplaceSomeNullBools() {
}

@Test
void testReplaceNullIntegersWithAllNulls() {
void testReplaceNullsScalarIntegersWithAllNulls() {
try (ColumnVector input = ColumnVector.fromBoxedInts(null, null, null, null);
ColumnVector expected = ColumnVector.fromBoxedInts(0, 0, 0, 0);
Scalar s = Scalar.fromInt(0);
Expand All @@ -1408,7 +1408,7 @@ void testReplaceNullIntegersWithAllNulls() {
}

@Test
void testReplaceSomeNullIntegers() {
void testReplaceNullsScalarSomeNullIntegers() {
try (ColumnVector input = ColumnVector.fromBoxedInts(1, 2, null, 4, null);
ColumnVector expected = ColumnVector.fromBoxedInts(1, 2, 999, 4, 999);
Scalar s = Scalar.fromInt(999);
Expand All @@ -1418,7 +1418,7 @@ void testReplaceSomeNullIntegers() {
}

@Test
void testReplaceNullsFailsOnTypeMismatch() {
void testReplaceNullsScalarFailsOnTypeMismatch() {
try (ColumnVector input = ColumnVector.fromBoxedInts(1, 2, null, 4, null);
Scalar s = Scalar.fromBool(true)) {
assertThrows(CudfException.class, () -> input.replaceNulls(s).close());
Expand All @@ -1434,6 +1434,44 @@ void testReplaceNullsWithNullScalar() {
}
}

@Test
void testReplaceNullsColumnEmptyColumn() {
try (ColumnVector input = ColumnVector.fromBoxedBooleans();
ColumnVector r = ColumnVector.fromBoxedBooleans();
ColumnVector expected = ColumnVector.fromBoxedBooleans();
ColumnVector result = input.replaceNulls(r)) {
assertColumnsAreEqual(expected, result);
}
}

@Test
void testReplaceNullsColumnBools() {
try (ColumnVector input = ColumnVector.fromBoxedBooleans(null, true, null, false);
ColumnVector r = ColumnVector.fromBoxedBooleans(false, null, true, true);
ColumnVector expected = ColumnVector.fromBoxedBooleans(false, true, true, false);
ColumnVector result = input.replaceNulls(r)) {
assertColumnsAreEqual(expected, result);
}
}

@Test
void testReplaceNullsColumnIntegers() {
try (ColumnVector input = ColumnVector.fromBoxedInts(1, 2, null, 4, null);
ColumnVector r = ColumnVector.fromBoxedInts(996, 997, 998, 909, null);
ColumnVector expected = ColumnVector.fromBoxedInts(1, 2, 998, 4, null);
ColumnVector result = input.replaceNulls(r)) {
assertColumnsAreEqual(expected, result);
}
}

@Test
void testReplaceNullsColumnFailsOnTypeMismatch() {
try (ColumnVector input = ColumnVector.fromBoxedInts(1, 2, null, 4, null);
ColumnVector r = ColumnVector.fromBoxedBooleans(true)) {
assertThrows(CudfException.class, () -> input.replaceNulls(r).close());
}
}

static QuantileMethod[] methods = {LINEAR, LOWER, HIGHER, MIDPOINT, NEAREST};
static double[] quantiles = {0.0, 0.25, 0.33, 0.5, 1.0};

Expand Down