Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Check if a map contains a specific key [skip ci] #8209

Merged
merged 13 commits into from
May 14, 2021
Merged
Show file tree
Hide file tree
Changes from 5 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
21 changes: 21 additions & 0 deletions java/src/main/java/ai/rapids/cudf/ColumnView.java
Original file line number Diff line number Diff line change
Expand Up @@ -2526,6 +2526,18 @@ public final ColumnVector getMapValue(Scalar key) {
return new ColumnVector(mapLookup(getNativeView(), key.getScalarHandle()));
}

/**
* For a column of type List<Struct<String, String>> (a "map" column) and a passed in String key,
* return a boolean column telling, for each row, whether the key is present among that row's
* struct keys; false if the key doesn't exist.
* The column-type and key preconditions below are enforced with Java assertions only.
* @param key the String scalar to look up in the column, must not be null
* @return a boolean column based on the lookup result
*/
public final ColumnVector getMapKeyExistence(Scalar key) {
firestarman marked this conversation as resolved.
Show resolved Hide resolved
assert type.equals(DType.LIST) : "column type must be a LIST";
assert key != null : "target string may not be null";
assert key.getType().equals(DType.STRING) : "target must be a string scalar";

// Delegate to the native mapContains and wrap the returned handle in a new ColumnVector.
return new ColumnVector(mapContains(getNativeView(), key.getScalarHandle()));
}

/**
* Create a new struct column view of existing column views. Note that this will NOT copy
Expand Down Expand Up @@ -2844,6 +2856,15 @@ private static native long stringReplaceWithBackrefs(long columnView, String pat
* @throws CudfException
*/
private static native long mapLookup(long columnView, long key) throws CudfException;

/**
* Native method to check the existence of a key over a column of List<Struct<String,String>>
* @param columnView the column view handle of the map
* @param key the handle of the string scalar that is the key for lookup
* @return a handle to the resulting boolean column
jlowe marked this conversation as resolved.
Show resolved Hide resolved
* @throws CudfException
*/
private static native long mapContains(long columnView, long key) throws CudfException;
/**
* Native method to add zeros as padding to the left of each string.
*/
Expand Down
16 changes: 16 additions & 0 deletions java/src/main/native/src/ColumnViewJni.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -1169,6 +1169,22 @@ JNIEXPORT jlong JNICALL Java_ai_rapids_cudf_ColumnView_mapLookup(JNIEnv *env, jc
CATCH_STD(env, 0);
}

// JNI binding for the Java-side native method ColumnView.mapContains.
// Reinterprets the two jlong handles as a cudf::column_view and a cudf::string_scalar, calls
// cudf::jni::map_contains on them, and returns the released result column pointer as a jlong.
// NOTE(review): JNI_NULL_CHECK and CATCH_STD are macros defined elsewhere; their trailing 0
// argument looks like the value returned to Java on failure -- confirm against the macro definitions.
JNIEXPORT jlong JNICALL Java_ai_rapids_cudf_ColumnView_mapContains(JNIEnv *env, jclass,
jlong map_column_view,
jlong lookup_key) {
jlowe marked this conversation as resolved.
Show resolved Hide resolved
// Both handles must be non-zero before they are reinterpreted as pointers below.
JNI_NULL_CHECK(env, map_column_view, "column is null", 0);
JNI_NULL_CHECK(env, lookup_key, "target string scalar is null", 0);
try {
cudf::jni::auto_set_device(env);
cudf::column_view *cv = reinterpret_cast<cudf::column_view *>(map_column_view);
cudf::string_scalar *ss_key = reinterpret_cast<cudf::string_scalar *>(lookup_key);

// has_nulls, stream and mr are not passed, so map_contains uses its declared defaults.
std::unique_ptr<cudf::column> result = cudf::jni::map_contains(*cv, *ss_key);
firestarman marked this conversation as resolved.
Show resolved Hide resolved
// release() gives up ownership here; the raw pointer is handed back to the caller as a jlong.
return reinterpret_cast<jlong>(result.release());
}
CATCH_STD(env, 0);
}

JNIEXPORT jlong JNICALL Java_ai_rapids_cudf_ColumnView_stringReplaceWithBackrefs(JNIEnv *env,
jclass,
jlong column_view,
Expand Down
44 changes: 38 additions & 6 deletions java/src/main/native/src/map_lookup.cu
Original file line number Diff line number Diff line change
Expand Up @@ -27,6 +27,7 @@
#include <cudf/utilities/type_dispatcher.hpp>
#include <rmm/cuda_stream_view.hpp>
#include <rmm/mr/device/device_memory_resource.hpp>
#include <rmm/exec_policy.hpp>
jlowe marked this conversation as resolved.
Show resolved Hide resolved

namespace cudf {
namespace {
Expand Down Expand Up @@ -127,24 +128,55 @@ get_gather_map_for_map_values(column_view const &input, string_scalar &lookup_ke
} // namespace

namespace jni {
std::unique_ptr<column> map_lookup(column_view const &map_column, string_scalar lookup_key,
bool has_nulls, rmm::cuda_stream_view stream,
rmm::mr::device_memory_resource *mr) {
// Defensive checks.


// Validates that map_column has the LIST<STRUCT<STRING,STRING>> layout used to represent a map.
// Checks, in order: the column is a LIST, its sliced child is a STRUCT, that struct has exactly
// two children, and both children are STRING columns. Any failed CUDF_EXPECTS reports
// "Expected LIST<STRUCT<key,value>>.".
//
// @param map_column The column to validate.
// @param stream     CUDA stream passed to lists_column_view::get_sliced_child.
void map_check(column_view const &map_column, rmm::cuda_stream_view stream) {
wjxiz1992 marked this conversation as resolved.
Show resolved Hide resolved
CUDF_EXPECTS(map_column.type().id() == type_id::LIST, "Expected LIST<STRUCT<key,value>>.");

lists_column_view lcv{map_column};
// NOTE(review): this line and the next both declare structs_column; in this diff view they
// appear to be the removed and added versions of the same line -- confirm against the file.
auto structs_column = lcv.get_sliced_child(stream);
column_view structs_column = lcv.get_sliced_child(stream);

CUDF_EXPECTS(structs_column.type().id() == type_id::STRUCT, "Expected LIST<STRUCT<key,value>>.");

// NOTE(review): scv is not referenced again in this function as shown.
structs_column_view scv{structs_column};
CUDF_EXPECTS(structs_column.num_children() == 2, "Expected LIST<STRUCT<key,value>>.");
// child(0) is the key column and child(1) the value column; both must be STRING.
CUDF_EXPECTS(structs_column.child(0).type().id() == type_id::STRING,
"Expected LIST<STRUCT<key,value>>.");
CUDF_EXPECTS(structs_column.child(1).type().id() == type_id::STRING,
"Expected LIST<STRUCT<key,value>>.");
return;
jlowe marked this conversation as resolved.
Show resolved Hide resolved
}

std::unique_ptr<column> map_contains(column_view const &map_column, string_scalar lookup_key,
bool has_nulls, rmm::cuda_stream_view stream,
rmm::mr::device_memory_resource *mr) {
// Defensive checks.
map_check(map_column, stream);

// Two-pass plan: construct gather map, and then gather() on structs_column.child(1). Plan A.
jlowe marked this conversation as resolved.
Show resolved Hide resolved
// (Can do in one pass perhaps, but that's Plan B.)

auto gather_map = has_nulls ?
Copy link
Contributor

@nvdbaranec nvdbaranec May 11, 2021

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I think there is a better alternative here. We have a list<struct<key, value>> column and we want to find whether or not each row contains the incoming key. That to me sounds like cudf::contains() if we were passing a list(string) column where the strings were just the keys. It should be possible to construct a fake column_view here that gives us this structure. Roughly:

lists_column_view lcv(map_column);
structs_column_view scv(lcv.child());
  
std::vector<column_view> children;
children.push_back(lcv.offsets());  // offsets
children.push_back(scv.child(0));   // keys (a string column)

column_view list_of_keys(map_column.type(), map_column.size(),
  nullptr, map_column.null_mask(), map_column.null_count(), 0, children);

@mythrocks is my thinking here about contains() sound?

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I think there is a better alternative here. We have a list<struct<key, value>> column and we want to find whether or not each row contains the incoming key. That to me sounds like cudf::contains() if we were passing a list(string) column where the strings were just the keys. It should be possible to construct a fake column_view here that gives us this structure. Roughly:

lists_column_view lcv(map_column);
structs_column_view scv(lcv.child());
  
std::vector<column_view> children;
children.push_back(lcv.offsets());  // offsets
children.push_back(scv.child(0));   // keys (a string column)

column_view list_of_keys(map_column.type(), map_column.size(),
  nullptr, map_column.null_mask(), map_column.null_count(), 0, children);

@mythrocks is my thinking here about contains() sound?

Thank you Dave, I've updated the code according to cudf::contains().
The only change here is that I manually added a null_mask(0) for the contains_column, to make nulls count in all_aggregation.

Please help review, thanks!

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

@nvdbaranec, that makes perfect sense. Thank you for the suggestion.

I'll need to examine the null mask issue more closely before I can comment.

get_gather_map_for_map_values<true>(map_column, lookup_key, stream, mr) :
get_gather_map_for_map_values<false>(map_column, lookup_key, stream, mr);

auto found = make_numeric_column(data_type{type_id::BOOL8}, gather_map->size(),
mask_state::UNALLOCATED, stream, mr);
thrust::transform(rmm::exec_policy(stream), thrust::make_counting_iterator<size_type>(0),
thrust::make_counting_iterator<size_type>(gather_map->size()),
found->mutable_view().template begin<bool>(),
[d_gather_map = gather_map->view().template begin<size_type>()] __device__(
auto i) { return d_gather_map[i] >= 0; });
return found;
}

std::unique_ptr<column> map_lookup(column_view const &map_column, string_scalar lookup_key,
bool has_nulls, rmm::cuda_stream_view stream,
rmm::mr::device_memory_resource *mr) {
// Defensive checks.
map_check(map_column, stream);

lists_column_view lcv{map_column};
column_view structs_column = lcv.get_sliced_child(stream);
// Two-pass plan: construct gather map, and then gather() on structs_column.child(1). Plan A.
// (Can do in one pass perhaps, but that's Plan B.)

Expand Down
29 changes: 29 additions & 0 deletions java/src/main/native/src/map_lookup.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -51,6 +51,35 @@ map_lookup(column_view const &map_column, string_scalar lookup_key, bool has_nul
rmm::cuda_stream_view stream = rmm::cuda_stream_default,
rmm::mr::device_memory_resource *mr = rmm::mr::get_current_device_resource());


/**
* @brief Looks up a "map" column by specified key to see if the key exists or not,
* and returns a column of boolean values.
jlowe marked this conversation as resolved.
Show resolved Hide resolved
*
* The map-column is represented as follows:
*
* list_view<struct_view< string_view, string_view > >.
* <---KEY---> <--VALUE-->
*
* The string_view struct members are the key and value, respectively.
* For each row in the input list column, the result is true if the key is found and false if it is not.
jlowe marked this conversation as resolved.
Show resolved Hide resolved
*
* @param map_column The input "map" column to be searched. Must be of
* type list_view<struct_view<string_view, string_view>>.
* @param lookup_key The search key, whose presence is to be checked in each list row
* @param has_nulls Whether the input column might contain null list-rows, or null keys.
* @param stream The CUDA stream
* @param mr The device memory resource to be used for allocations
* @return A boolean column reflecting the existence of the key in each list.
* false means the lookup_key is not found.
* @throw cudf::logic_error If the input column is not of type
* list_view<struct_view<string_view, string_view>>
*/
std::unique_ptr<column>
map_contains(column_view const &map_column, string_scalar lookup_key, bool has_nulls = true,
mythrocks marked this conversation as resolved.
Show resolved Hide resolved
rmm::cuda_stream_view stream = rmm::cuda_stream_default,
rmm::mr::device_memory_resource *mr = rmm::mr::get_current_device_resource());

} // namespace jni

} // namespace cudf
17 changes: 17 additions & 0 deletions java/src/test/java/ai/rapids/cudf/ColumnVectorTest.java
Original file line number Diff line number Diff line change
Expand Up @@ -4410,6 +4410,23 @@ void testGetMapValue() {
assertColumnsAreEqual(expected, res);
}
}
// Verifies getMapKeyExistence on seven single-entry map rows searched for the key "a".
// Rows whose key is "a" (including list6, whose value is null) are expected to yield true;
// rows with a different key, and list7 whose key is null, are expected to yield false.
@Test
void testGetMapKeyExistence() {
jlowe marked this conversation as resolved.
Show resolved Hide resolved
List<HostColumnVector.StructData> list1 = Arrays.asList(new HostColumnVector.StructData("a", "b"));
List<HostColumnVector.StructData> list2 = Arrays.asList(new HostColumnVector.StructData("a", "c"));
List<HostColumnVector.StructData> list3 = Arrays.asList(new HostColumnVector.StructData("e", "d"));
List<HostColumnVector.StructData> list4 = Arrays.asList(new HostColumnVector.StructData("a", "g"));
List<HostColumnVector.StructData> list5 = Arrays.asList(new HostColumnVector.StructData("f", "h"));
List<HostColumnVector.StructData> list6 = Arrays.asList(new HostColumnVector.StructData("a", null));
List<HostColumnVector.StructData> list7 = Arrays.asList(new HostColumnVector.StructData(null, null));
Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Since: https://github.com/apache/spark/blob/master/sql/catalyst/src/main/scala/org/apache/spark/sql/types/MapType.scala#L27
this case seems not useful, but still keep it here since we manually added a null_mask(0) for null values.

HostColumnVector.StructType structType = new HostColumnVector.StructType(true, Arrays.asList(new HostColumnVector.BasicType(true, DType.STRING),
new HostColumnVector.BasicType(true, DType.STRING)));
// All three vectors are opened in try-with-resources so they are closed when the test ends.
try (ColumnVector cv = ColumnVector.fromLists(new HostColumnVector.ListType(true, structType), list1, list2, list3, list4, list5, list6, list7);
ColumnVector res = cv.getMapKeyExistence(Scalar.fromString("a"));
ColumnVector expected = ColumnVector.fromBoxedBooleans(true, true, false, true, false, true, false)) {
assertColumnsAreEqual(expected, res);
}
}

@Test
void testListOfStructsOfStructs() {
Expand Down