Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Check if a map contains a specific key [skip ci] #8209

Merged
merged 13 commits into from
May 14, 2021
Merged
Show file tree
Hide file tree
Changes from 3 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
9 changes: 9 additions & 0 deletions java/src/main/java/ai/rapids/cudf/ColumnView.java
Original file line number Diff line number Diff line change
Expand Up @@ -2526,6 +2526,13 @@ public final ColumnVector getMapValue(Scalar key) {
return new ColumnVector(mapLookup(getNativeView(), key.getScalarHandle()));
}

public final ColumnVector getMapKeyExistence(Scalar key) {
firestarman marked this conversation as resolved.
Show resolved Hide resolved
assert type.equals(DType.LIST) : "column type must be a LIST";
assert key != null : "target string may not be null";
assert key.getType().equals(DType.STRING) : "target string must be a string scalar";
firestarman marked this conversation as resolved.
Show resolved Hide resolved

return new ColumnVector(mapContains(getNativeView(), key.getScalarHandle()));
}

/**
* Create a new struct column view of existing column views. Note that this will NOT copy
Expand Down Expand Up @@ -2844,6 +2851,8 @@ private static native long stringReplaceWithBackrefs(long columnView, String pat
* @throws CudfException
*/
private static native long mapLookup(long columnView, long key) throws CudfException;

private static native long mapContains(long columnView, long key) throws CudfException;
/**
* Native method to add zeros as padding to the left of each string.
*/
Expand Down
16 changes: 16 additions & 0 deletions java/src/main/native/src/ColumnViewJni.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -1169,6 +1169,22 @@ JNIEXPORT jlong JNICALL Java_ai_rapids_cudf_ColumnView_mapLookup(JNIEnv *env, jc
CATCH_STD(env, 0);
}

JNIEXPORT jlong JNICALL Java_ai_rapids_cudf_ColumnView_mapContains(JNIEnv *env, jclass,
jlong map_column_view,
jlong lookup_key) {
jlowe marked this conversation as resolved.
Show resolved Hide resolved
JNI_NULL_CHECK(env, map_column_view, "column is null", 0);
JNI_NULL_CHECK(env, lookup_key, "target string scalar is null", 0);
try {
cudf::jni::auto_set_device(env);
cudf::column_view *cv = reinterpret_cast<cudf::column_view *>(map_column_view);
cudf::string_scalar *ss_key = reinterpret_cast<cudf::string_scalar *>(lookup_key);

std::unique_ptr<cudf::column> result = cudf::jni::map_contains(*cv, *ss_key);
firestarman marked this conversation as resolved.
Show resolved Hide resolved
return reinterpret_cast<jlong>(result.release());
}
CATCH_STD(env, 0);
}

JNIEXPORT jlong JNICALL Java_ai_rapids_cudf_ColumnView_stringReplaceWithBackrefs(JNIEnv *env,
jclass,
jlong column_view,
Expand Down
29 changes: 29 additions & 0 deletions java/src/main/native/src/map_lookup.cu
Original file line number Diff line number Diff line change
Expand Up @@ -127,6 +127,35 @@ get_gather_map_for_map_values(column_view const &input, string_scalar &lookup_ke
} // namespace

namespace jni {

std::unique_ptr<column> map_contains(column_view const &map_column, string_scalar lookup_key,
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I think it's better to reuse this function in map_lookup instead of duplicating codes.

bool has_nulls, rmm::cuda_stream_view stream,
rmm::mr::device_memory_resource *mr) {
// Defensive checks.
CUDF_EXPECTS(map_column.type().id() == type_id::LIST, "Expected LIST<STRUCT<key,value>>.");

lists_column_view lcv{map_column};
auto structs_column = lcv.get_sliced_child(stream);

CUDF_EXPECTS(structs_column.type().id() == type_id::STRUCT, "Expected LIST<STRUCT<key,value>>.");

structs_column_view scv{structs_column};
CUDF_EXPECTS(structs_column.num_children() == 2, "Expected LIST<STRUCT<key,value>>.");
CUDF_EXPECTS(structs_column.child(0).type().id() == type_id::STRING,
"Expected LIST<STRUCT<key,value>>.");
CUDF_EXPECTS(structs_column.child(1).type().id() == type_id::STRING,
"Expected LIST<STRUCT<key,value>>.");

// Two-pass plan: construct gather map, and then gather() on structs_column.child(1). Plan A.
jlowe marked this conversation as resolved.
Show resolved Hide resolved
// (Can do in one pass perhaps, but that's Plan B.)

auto gather_map = has_nulls ?
Copy link
Contributor

@nvdbaranec nvdbaranec May 11, 2021

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I think there is a better alternative here. We have a list<struct<key, value>> column and we want to find whether or not each row contains the incoming key. That to me sounds like cudf::contains() if we were passing a list(string) column where the strings were just the keys. It should be possible to construct a fake column_view here that gives us this structure. Roughly:

lists_column_view lcv(map_column);
structs_column_view scv(lcv.child());
  
std::vector<column_view> children;
children.push_back(lcv.offsets());  // offsets
children.push_back(scv.child(0));   // keys (a string column)

column_view list_of_keys(map_column.type(), map_column.size(),
  nullptr, map_column.null_mask(), map_column.null_count(), 0, children);

@mythrocks is my thinking here about contains() sound?

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I think there is a better alternative here. We have a list<struct<key, value>> column and we want to find whether or not each row contains the incoming key. That to me sounds like cudf::contains() if we were passing a list(string) column where the strings were just the keys. It should be possible to construct a fake column_view here that gives us this structure. Roughly:

lists_column_view lcv(map_column);
structs_column_view scv(lcv.child());
  
std::vector<column_view> children;
children.push_back(lcv.offsets());  // offsets
children.push_back(scv.child(0));   // keys (a string column)

column_view list_of_keys(map_column.type(), map_column.size(),
  nullptr, map_column.null_mask(), map_column.null_count(), 0, children);

@mythrocks is my thinking here about contains() sound?

Thank you Dave, I've updated the code according to cudf::contains().
The only change here is I manully added a null_mask(0) for the contains_column. To make nulls count in all_aggregation.

Please help review, thanks!

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

@nvdbaranec, that makes perfect sense. Thank you for the suggestion.

I'll need to examine the null mask issue more closely before I can comment.

get_gather_map_for_map_values<true>(map_column, lookup_key, stream, mr) :
get_gather_map_for_map_values<false>(map_column, lookup_key, stream, mr);
return gather_map;
}


std::unique_ptr<column> map_lookup(column_view const &map_column, string_scalar lookup_key,
bool has_nulls, rmm::cuda_stream_view stream,
rmm::mr::device_memory_resource *mr) {
Expand Down
29 changes: 29 additions & 0 deletions java/src/main/native/src/map_lookup.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -51,6 +51,35 @@ map_lookup(column_view const &map_column, string_scalar lookup_key, bool has_nul
rmm::cuda_stream_view stream = rmm::cuda_stream_default,
rmm::mr::device_memory_resource *mr = rmm::mr::get_current_device_resource());


/**
* @brief Looks up a "map" column by specified key to see if the key exists or not,
* and returns a column of int values.
jlowe marked this conversation as resolved.
Show resolved Hide resolved
*
* The map-column is represented as follows:
*
* list_view<struct_view< string_view, string_view > >.
* <---KEY---> <--VALUE-->
*
* The string_view struct members are the key and value, respectively.
* For each row in the input list column. If the key is not found, -1 is returned.
jlowe marked this conversation as resolved.
Show resolved Hide resolved
*
* @param map_column The input "map" column to be searched. Must be of
* type list_view<struct_view<string_view, string_view>>.
* @param lookup_key The search key, whose value is to be returned for each list row
* @param has_nulls Whether the input column might contain null list-rows, or null keys.
* @param stream The CUDA stream
* @param mr The device memory resource to be used for allocations
* @return A string_view column with the value from the first match in each list.
firestarman marked this conversation as resolved.
Show resolved Hide resolved
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

The comment for return is not correct.

* A null row is returned for any row where the lookup_key is not found.
* @throw cudf::logic_error If the input column is not of type
* list_view<struct_view<string_view, string_view>>
*/
std::unique_ptr<column>
map_contains(column_view const &map_column, string_scalar lookup_key, bool has_nulls = true,
mythrocks marked this conversation as resolved.
Show resolved Hide resolved
rmm::cuda_stream_view stream = rmm::cuda_stream_default,
rmm::mr::device_memory_resource *mr = rmm::mr::get_current_device_resource());

} // namespace jni

} // namespace cudf
17 changes: 17 additions & 0 deletions java/src/test/java/ai/rapids/cudf/ColumnVectorTest.java
Original file line number Diff line number Diff line change
Expand Up @@ -4410,6 +4410,23 @@ void testGetMapValue() {
assertColumnsAreEqual(expected, res);
}
}
@Test
void testGetMapKeyExistence() {
jlowe marked this conversation as resolved.
Show resolved Hide resolved
List<HostColumnVector.StructData> list1 = Arrays.asList(new HostColumnVector.StructData(Arrays.asList("a", "b")));
firestarman marked this conversation as resolved.
Show resolved Hide resolved
List<HostColumnVector.StructData> list2 = Arrays.asList(new HostColumnVector.StructData(Arrays.asList("a", "c")));
List<HostColumnVector.StructData> list3 = Arrays.asList(new HostColumnVector.StructData(Arrays.asList("e", "d")));
List<HostColumnVector.StructData> list4 = Arrays.asList(new HostColumnVector.StructData(Arrays.asList("a", "g")));
List<HostColumnVector.StructData> list5 = Arrays.asList(new HostColumnVector.StructData(Arrays.asList("f", "h")));
List<HostColumnVector.StructData> list6 = Arrays.asList(new HostColumnVector.StructData(Arrays.asList("a", null)));
List<HostColumnVector.StructData> list7 = Arrays.asList(new HostColumnVector.StructData(Arrays.asList(null, null)));
firestarman marked this conversation as resolved.
Show resolved Hide resolved
HostColumnVector.StructType structType = new HostColumnVector.StructType(true, Arrays.asList(new HostColumnVector.BasicType(true, DType.STRING),
new HostColumnVector.BasicType(true, DType.STRING)));
try (ColumnVector cv = ColumnVector.fromLists(new HostColumnVector.ListType(true, structType), list1, list2, list3, list4, list5, list6, list7);
ColumnVector res = cv.getMapKeyExistence(Scalar.fromString("a"));
ColumnVector expected = ColumnVector.fromInts(0, 1, -1, 3, -1, 5, -1)) {
assertColumnsAreEqual(expected, res);
}
}

@Test
void testListOfStructsOfStructs() {
Expand Down