Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Check if a map contains a specific key [skip ci] #8209

Merged
merged 13 commits into from
May 14, 2021
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
22 changes: 22 additions & 0 deletions java/src/main/java/ai/rapids/cudf/ColumnView.java
Original file line number Diff line number Diff line change
Expand Up @@ -2526,6 +2526,19 @@ public final ColumnVector getMapValue(Scalar key) {
return new ColumnVector(mapLookup(getNativeView(), key.getScalarHandle()));
}

/** For a column of type List<Struct<String, String>> and a passed in String key, return a boolean
* column for all keys in the structs, It is true if the key exists in the corresponding map for
* that row, false otherwise. It will never return null for a row.
* @param key the String scalar to lookup in the column
* @return a boolean column based on the lookup result
*/
public final ColumnVector getMapKeyExistence(Scalar key) {
firestarman marked this conversation as resolved.
Show resolved Hide resolved
assert type.equals(DType.LIST) : "column type must be a LIST";
assert key != null : "target string may not be null";
assert key.getType().equals(DType.STRING) : "target must be a string scalar";

return new ColumnVector(mapContains(getNativeView(), key.getScalarHandle()));
}

/**
* Create a new struct column view of existing column views. Note that this will NOT copy
Expand Down Expand Up @@ -2844,6 +2857,15 @@ private static native long stringReplaceWithBackrefs(long columnView, String pat
* @throws CudfException
*/
private static native long mapLookup(long columnView, long key) throws CudfException;

/**
* Native method for check the existence of a key over a column of List<Struct<String,String>>
* @param columnView the column view handle of the map
* @param key the string scalar that is the key for lookup
* @return boolean column handle of the result
* @throws CudfException
*/
private static native long mapContains(long columnView, long key) throws CudfException;
/**
* Native method to add zeros as padding to the left of each string.
*/
Expand Down
16 changes: 16 additions & 0 deletions java/src/main/native/src/ColumnViewJni.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -1169,6 +1169,22 @@ JNIEXPORT jlong JNICALL Java_ai_rapids_cudf_ColumnView_mapLookup(JNIEnv *env, jc
CATCH_STD(env, 0);
}

JNIEXPORT jlong JNICALL Java_ai_rapids_cudf_ColumnView_mapContains(JNIEnv *env, jclass,
jlong map_column_view,
jlong lookup_key) {
JNI_NULL_CHECK(env, map_column_view, "column is null", 0);
JNI_NULL_CHECK(env, lookup_key, "target string scalar is null", 0);
try {
cudf::jni::auto_set_device(env);
cudf::column_view *cv = reinterpret_cast<cudf::column_view *>(map_column_view);
cudf::string_scalar *ss_key = reinterpret_cast<cudf::string_scalar *>(lookup_key);

std::unique_ptr<cudf::column> result = cudf::jni::map_contains(*cv, *ss_key);
firestarman marked this conversation as resolved.
Show resolved Hide resolved
return reinterpret_cast<jlong>(result.release());
}
CATCH_STD(env, 0);
}

JNIEXPORT jlong JNICALL Java_ai_rapids_cudf_ColumnView_stringReplaceWithBackrefs(JNIEnv *env,
jclass,
jlong column_view,
Expand Down
58 changes: 49 additions & 9 deletions java/src/main/native/src/map_lookup.cu
Original file line number Diff line number Diff line change
Expand Up @@ -13,14 +13,18 @@
* See the License for the specific language governing permissions and
* limitations under the License.
*/

#include <cudf/column/column.hpp>
#include <cudf/column/column_device_view.cuh>
#include <cudf/column/column_factories.hpp>
#include <cudf/detail/gather.hpp>
#include <cudf/detail/utilities/cuda.cuh>
#include <cudf/lists/contains.hpp>
#include <cudf/lists/lists_column_view.hpp>
#include <cudf/replace.hpp>
#include <cudf/scalar/scalar.hpp>
#include <cudf/scalar/scalar_device_view.cuh>
#include <cudf/scalar/scalar_factories.hpp>
#include <cudf/structs/structs_column_view.hpp>
#include <cudf/table/table_view.hpp>
#include <cudf/types.hpp>
Expand Down Expand Up @@ -124,27 +128,63 @@ get_gather_map_for_map_values(column_view const &input, string_scalar &lookup_ke
return gather_map;
}

} // namespace

namespace jni {
std::unique_ptr<column> map_lookup(column_view const &map_column, string_scalar lookup_key,
bool has_nulls, rmm::cuda_stream_view stream,
rmm::mr::device_memory_resource *mr) {
// Defensive checks.
/**
* @brief a defensive check for the map column that is going to be processed
*/
void map_input_check(column_view const &map_column, rmm::cuda_stream_view stream) {
CUDF_EXPECTS(map_column.type().id() == type_id::LIST, "Expected LIST<STRUCT<key,value>>.");

lists_column_view lcv{map_column};
auto structs_column = lcv.get_sliced_child(stream);
column_view structs_column = lcv.get_sliced_child(stream);

CUDF_EXPECTS(structs_column.type().id() == type_id::STRUCT, "Expected LIST<STRUCT<key,value>>.");

structs_column_view scv{structs_column};
CUDF_EXPECTS(structs_column.num_children() == 2, "Expected LIST<STRUCT<key,value>>.");
CUDF_EXPECTS(structs_column.child(0).type().id() == type_id::STRING,
"Expected LIST<STRUCT<key,value>>.");
CUDF_EXPECTS(structs_column.child(1).type().id() == type_id::STRING,
"Expected LIST<STRUCT<key,value>>.");
}

} // namespace

namespace jni {

std::unique_ptr<column> map_contains(column_view const &map_column, string_scalar lookup_key,
bool has_nulls, rmm::cuda_stream_view stream,
rmm::mr::device_memory_resource *mr) {
// Defensive checks.
map_input_check(map_column, stream);

lists_column_view lcv(map_column);
structs_column_view scv(lcv.child());

std::vector<column_view> children;
children.push_back(lcv.offsets());
children.push_back(scv.child(0));

column_view list_of_keys(map_column.type(), map_column.size(),
nullptr, map_column.null_mask(), map_column.null_count(), 0, children);
auto contains_column = lists::contains(list_of_keys, lookup_key);
// null will be skipped in all-aggregation when checking if all rows contain the key,
// so replace all nulls with 0.
std::unique_ptr<cudf::scalar> replacement =
cudf::make_numeric_scalar(cudf::data_type(cudf::type_id::BOOL8));
replacement->set_valid(true);
using ScalarType = cudf::scalar_type_t<int8_t>;
static_cast<ScalarType *>(replacement.get())->set_value(0);
auto result = cudf::replace_nulls(contains_column->view(), *replacement);
return result;
}

std::unique_ptr<column> map_lookup(column_view const &map_column, string_scalar lookup_key,
bool has_nulls, rmm::cuda_stream_view stream,
rmm::mr::device_memory_resource *mr) {
// Defensive checks.
map_input_check(map_column, stream);

lists_column_view lcv{map_column};
column_view structs_column = lcv.get_sliced_child(stream);
// Two-pass plan: construct gather map, and then gather() on structs_column.child(1). Plan A.
// (Can do in one pass perhaps, but that's Plan B.)

Expand Down
32 changes: 32 additions & 0 deletions java/src/main/native/src/map_lookup.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -51,6 +51,38 @@ map_lookup(column_view const &map_column, string_scalar lookup_key, bool has_nul
rmm::cuda_stream_view stream = rmm::cuda_stream_default,
rmm::mr::device_memory_resource *mr = rmm::mr::get_current_device_resource());


/**
* @brief Looks up a "map" column by specified key to see if the key exists or not,
* and returns a cudf column of bool value.
*
* The map-column is represented as follows:
*
* list_view<struct_view< string_view, string_view > >.
* <---KEY---> <--VALUE-->
*
* The string_view struct members are the key and value, respectively.
* For each row in the input list column, if the key is not found, false will be returned for that
* row.
* Note: when search for the scalar key of "null", a column full of "false" will be returned because
* map_contains is leveraging cudf::list:contains.
*
* @param map_column The input "map" column to be searched. Must be of
* type list_view<struct_view<string_view, string_view>>.
* @param lookup_key The search key, whose index(offset) is to be returned for each list row
* @param has_nulls Whether the input column might contain null list-rows, or null keys.
* @param stream The CUDA stream
* @param mr The device memory resource to be used for allocations
* @return An boolean column reflecting the existence of the key in each row in the map
* column. True means the lookup_key is found in that row.
* @throw cudf::logic_error If the input column is not of type
* list_view<struct_view<string_view, string_view>>
*/
std::unique_ptr<column>
map_contains(column_view const &map_column, string_scalar lookup_key, bool has_nulls = true,
mythrocks marked this conversation as resolved.
Show resolved Hide resolved
rmm::cuda_stream_view stream = rmm::cuda_stream_default,
rmm::mr::device_memory_resource *mr = rmm::mr::get_current_device_resource());

} // namespace jni

} // namespace cudf
29 changes: 29 additions & 0 deletions java/src/test/java/ai/rapids/cudf/ColumnVectorTest.java
Original file line number Diff line number Diff line change
Expand Up @@ -4411,6 +4411,35 @@ void testGetMapValue() {
}
}

@Test
void testGetMapKeyExistence() {
List<HostColumnVector.StructData> list1 = Arrays.asList(new HostColumnVector.StructData("a", "b"));
List<HostColumnVector.StructData> list2 = Arrays.asList(new HostColumnVector.StructData("a", "c"));
List<HostColumnVector.StructData> list3 = Arrays.asList(new HostColumnVector.StructData("e", "d"));
List<HostColumnVector.StructData> list4 = Arrays.asList(new HostColumnVector.StructData("a", "g"));
List<HostColumnVector.StructData> list5 = Arrays.asList(new HostColumnVector.StructData("a", null));
List<HostColumnVector.StructData> list6 = Arrays.asList(new HostColumnVector.StructData(null, null));
List<HostColumnVector.StructData> list7 = Arrays.asList(new HostColumnVector.StructData());
HostColumnVector.StructType structType = new HostColumnVector.StructType(true, Arrays.asList(new HostColumnVector.BasicType(true, DType.STRING),
new HostColumnVector.BasicType(true, DType.STRING)));
try (ColumnVector cv = ColumnVector.fromLists(new HostColumnVector.ListType(true, structType), list1, list2, list3, list4, list5, list6, list7);
ColumnVector resValidKey = cv.getMapKeyExistence(Scalar.fromString("a"));
ColumnVector expectedValid = ColumnVector.fromBoxedBooleans(true, true, false, true, true, false, false);
ColumnVector expectedNull = ColumnVector.fromBoxedBooleans(false, false, false, false, false, false, false);
ColumnVector resNullKey = cv.getMapKeyExistence(Scalar.fromNull(DType.STRING))) {
assertColumnsAreEqual(expectedValid, resValidKey);
assertColumnsAreEqual(expectedNull, resNullKey);
}

AssertionError e = assertThrows(AssertionError.class, () -> {
try (ColumnVector cv = ColumnVector.fromLists(new HostColumnVector.ListType(true, structType), list1, list2, list3, list4, list5, list6, list7);
ColumnVector resNullKey = cv.getMapKeyExistence(null)) {
}
});
assertTrue(e.getMessage().contains("target string may not be null"));
}


@Test
void testListOfStructsOfStructs() {
List<HostColumnVector.StructData> list1 = Arrays.asList(
Expand Down