Skip to content

Commit

Permalink
fix list_concat_rows
Browse files Browse the repository at this point in the history
Signed-off-by: sperlingxx <[email protected]>
  • Loading branch information
sperlingxx committed May 21, 2021
1 parent 9a85b3b commit fb22fc7
Show file tree
Hide file tree
Showing 3 changed files with 70 additions and 31 deletions.
40 changes: 34 additions & 6 deletions cpp/src/lists/combine/concatenate_list_elements.cu
Original file line number Diff line number Diff line change
Expand Up @@ -41,6 +41,7 @@ namespace {
* concatenation.
*/
std::unique_ptr<column> concatenate_lists_ignore_null(column_view const& input,
bool build_null_mask,
rmm::cuda_stream_view stream,
rmm::mr::device_memory_resource* mr)
{
Expand All @@ -50,9 +51,13 @@ std::unique_ptr<column> concatenate_lists_ignore_null(column_view const& input,
auto out_offsets = make_numeric_column(
data_type{type_id::INT32}, num_rows + 1, mask_state::UNALLOCATED, stream, mr);

// One int8_t validity flag per output row; the buffer is left empty when no null mask is built.
auto validities = rmm::device_uvector<int8_t>(build_null_mask ? num_rows : 0, stream);

auto const d_out_offsets = out_offsets->mutable_view().template begin<offset_type>();
auto const d_row_offsets = lists_column_view(input).offsets_begin();
auto const d_list_offsets = lists_column_view(lists_column_view(input).child()).offsets_begin();
auto const lists_dv_ptr = column_device_view::create(lists_column_view(input).child());

// Concatenating the lists at the same row by converting the entry offsets from the child column
// into row offsets of the root column. Those entry offsets are subtracted by the first entry
Expand All @@ -62,7 +67,22 @@ std::unique_ptr<column> concatenate_lists_ignore_null(column_view const& input,
iter,
iter + num_rows + 1,
d_out_offsets,
[d_row_offsets, d_list_offsets] __device__(auto const idx) {
[d_row_offsets,
d_list_offsets,
lists_dv = *lists_dv_ptr,
d_validities = validities.begin(),
build_null_mask,
iter] __device__(auto const idx) {
if (build_null_mask) {
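// The output row will be null only if all lists on the input row are null.
// NOTE(review): this functor is invoked for idx in [0, num_rows] (the range is
// iter .. iter + num_rows + 1), but `validities` holds only num_rows entries, so at
// idx == num_rows this appears to read d_row_offsets[idx + 1] and write d_validities[idx]
// one past the end. Confirm, and consider guarding with idx < num_rows or computing the
// validities in a separate pass over [0, num_rows).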
auto const is_valid = thrust::any_of(thrust::seq,
iter + d_row_offsets[idx],
iter + d_row_offsets[idx + 1],
[&] __device__(auto const list_idx) {
return lists_dv.is_valid(list_idx);
});
d_validities[idx] = static_cast<int8_t>(is_valid);
}
auto const start_offset = d_list_offsets[d_row_offsets[0]];
return d_list_offsets[d_row_offsets[idx]] - start_offset;
});
Expand All @@ -71,11 +91,18 @@ std::unique_ptr<column> concatenate_lists_ignore_null(column_view const& input,
auto out_entries = std::make_unique<column>(
lists_column_view(lists_column_view(input).get_sliced_child(stream)).get_sliced_child(stream));

auto [null_mask, null_count] = [&] {
return build_null_mask
? cudf::detail::valid_if(
validities.begin(), validities.end(), thrust::identity<int8_t>{}, stream, mr)
: std::make_pair(cudf::detail::copy_bitmask(input, stream, mr), input.null_count());
}();

return make_lists_column(num_rows,
std::move(out_offsets),
std::move(out_entries),
input.null_count(),
cudf::detail::copy_bitmask(input, stream, mr),
null_count,
null_count > 0 ? std::move(null_mask) : rmm::device_buffer{},
stream,
mr);
}
Expand Down Expand Up @@ -241,9 +268,10 @@ std::unique_ptr<column> concatenate_list_elements(column_view const& input,

if (input.size() == 0) { return cudf::empty_like(input); }

return (null_policy == concatenate_null_policy::IGNORE ||
!lists_column_view(input).child().has_nulls())
? concatenate_lists_ignore_null(input, stream, mr)
bool has_null_list = lists_column_view(input).child().has_nulls();

return (null_policy == concatenate_null_policy::IGNORE || !has_null_list)
? concatenate_lists_ignore_null(input, has_null_list, stream, mr)
: concatenate_lists_nullifying_rows(input, stream, mr);
}

Expand Down
25 changes: 15 additions & 10 deletions cpp/tests/lists/combine/concatenate_list_elements_tests.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -147,19 +147,23 @@ TYPED_TEST(ConcatenateListElementsTypedTest, SimpleInputWithNulls)
auto row5 = ListsCol{ListsCol{{1, 2, 3, null}, null_at(3)},
ListsCol{{null}, null_at(0)},
ListsCol{{null, null, null, null, null}, all_nulls()}};
auto const col = build_lists_col(row0, row1, row2, row3, row4, row5);
auto row6 =
ListsCol{{ListsCol{} /*NULL*/, ListsCol{} /*NULL*/, ListsCol{} /*NULL*/}, all_nulls()};
auto const col = build_lists_col(row0, row1, row2, row3, row4, row5, row6);

// Ignore null list elements.
{
auto const results = cudf::lists::concatenate_list_elements(col);
auto const expected =
ListsCol{ListsCol{{1, null, 3, 4, 10, 11, 12, null}, null_at({1, 7})},
ListsCol{{null, 2, 3, 4, 13, 14, 15, 16, 17, null, 20, null}, null_at({0, 9, 11})},
ListsCol{{null, 2, 3, 4, null, 21, null, null}, null_at({0, 4, 6, 7})},
ListsCol{{null, 18}, null_at(0)},
ListsCol{{1, 2, null, 4, 19, 20, null, 22, 23, 24, 25}, null_at({2, 6})},
ListsCol{{1, 2, 3, null, null, null, null, null, null, null},
null_at({3, 4, 5, 6, 7, 8, 9})}};
ListsCol{{ListsCol{{1, null, 3, 4, 10, 11, 12, null}, null_at({1, 7})},
ListsCol{{null, 2, 3, 4, 13, 14, 15, 16, 17, null, 20, null}, null_at({0, 9, 11})},
ListsCol{{null, 2, 3, 4, null, 21, null, null}, null_at({0, 4, 6, 7})},
ListsCol{{null, 18}, null_at(0)},
ListsCol{{1, 2, null, 4, 19, 20, null, 22, 23, 24, 25}, null_at({2, 6})},
ListsCol{{1, 2, 3, null, null, null, null, null, null, null},
null_at({3, 4, 5, 6, 7, 8, 9})},
ListsCol{} /*NULL*/},
null_at(6)};
CUDF_TEST_EXPECT_COLUMNS_EQUAL(expected, *results, print_all);
}

Expand All @@ -174,8 +178,9 @@ TYPED_TEST(ConcatenateListElementsTypedTest, SimpleInputWithNulls)
ListsCol{} /*NULL*/,
ListsCol{{1, 2, null, 4, 19, 20, null, 22, 23, 24, 25}, null_at({2, 6})},
ListsCol{{1, 2, 3, null, null, null, null, null, null, null},
null_at({3, 4, 5, 6, 7, 8, 9})}},
null_at({0, 2, 3})};
null_at({3, 4, 5, 6, 7, 8, 9})},
ListsCol{} /*NULL*/},
null_at({0, 2, 3, 6})};
CUDF_TEST_EXPECT_COLUMNS_EQUAL(expected, *results, print_all);
}
}
Expand Down
36 changes: 21 additions & 15 deletions cpp/tests/lists/combine/concatenate_rows_tests.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -184,38 +184,43 @@ TYPED_TEST(ListConcatenateRowsTypedTest, SimpleInputWithNulls)
ListsCol{{null, 2, 3, 4}, null_at(0)},
ListsCol{} /*NULL*/,
ListsCol{{1, 2, null, 4}, null_at(2)},
ListsCol{{1, 2, 3, null}, null_at(3)}},
null_at(3)}
ListsCol{{1, 2, 3, null}, null_at(3)},
ListsCol{} /*NULL*/},
null_at({3, 6})}
.release();
auto const col2 = ListsCol{{ListsCol{{10, 11, 12, null}, null_at(3)},
ListsCol{{13, 14, 15, 16, 17, null}, null_at(5)},
ListsCol{} /*NULL*/,
ListsCol{{null, 18}, null_at(0)},
ListsCol{{19, 20, null}, null_at(2)},
ListsCol{{null}, null_at(0)}},
null_at(2)}
ListsCol{{null}, null_at(0)},
ListsCol{} /*NULL*/},
null_at({2, 6})}
.release();
auto const col3 = ListsCol{{ListsCol{} /*NULL*/,
ListsCol{{20, null}, null_at(1)},
ListsCol{{null, 21, null, null}, null_at({0, 2, 3})},
ListsCol{},
ListsCol{22, 23, 24, 25},
ListsCol{{null, null, null, null, null}, all_nulls()}},
null_at(0)}
ListsCol{{null, null, null, null, null}, all_nulls()},
ListsCol{} /*NULL*/},
null_at({0, 6})}
.release();

// Ignore null list elements
{
auto const results =
cudf::lists::concatenate_rows(TView{{col1->view(), col2->view(), col3->view()}});
auto const expected =
ListsCol{ListsCol{{1, null, 3, 4, 10, 11, 12, null}, null_at({1, 7})},
ListsCol{{null, 2, 3, 4, 13, 14, 15, 16, 17, null, 20, null}, null_at({0, 9, 11})},
ListsCol{{null, 2, 3, 4, null, 21, null, null}, null_at({0, 4, 6, 7})},
ListsCol{{null, 18}, null_at(0)},
ListsCol{{1, 2, null, 4, 19, 20, null, 22, 23, 24, 25}, null_at({2, 6})},
ListsCol{{1, 2, 3, null, null, null, null, null, null, null},
null_at({3, 4, 5, 6, 7, 8, 9})}}
ListsCol{{ListsCol{{1, null, 3, 4, 10, 11, 12, null}, null_at({1, 7})},
ListsCol{{null, 2, 3, 4, 13, 14, 15, 16, 17, null, 20, null}, null_at({0, 9, 11})},
ListsCol{{null, 2, 3, 4, null, 21, null, null}, null_at({0, 4, 6, 7})},
ListsCol{{null, 18}, null_at(0)},
ListsCol{{1, 2, null, 4, 19, 20, null, 22, 23, 24, 25}, null_at({2, 6})},
ListsCol{{1, 2, 3, null, null, null, null, null, null, null},
null_at({3, 4, 5, 6, 7, 8, 9})},
ListsCol{} /*NULL*/},
null_at(6)}
.release();
CUDF_TEST_EXPECT_COLUMNS_EQUIVALENT(*expected, *results, print_all);
}
Expand All @@ -232,8 +237,9 @@ TYPED_TEST(ListConcatenateRowsTypedTest, SimpleInputWithNulls)
ListsCol{} /*NULL*/,
ListsCol{{1, 2, null, 4, 19, 20, null, 22, 23, 24, 25}, null_at({2, 6})},
ListsCol{{1, 2, 3, null, null, null, null, null, null, null},
null_at({3, 4, 5, 6, 7, 8, 9})}},
null_at({0, 2, 3})}
null_at({3, 4, 5, 6, 7, 8, 9})},
ListsCol{} /*NULL*/},
null_at({0, 2, 3, 6})}
.release();
CUDF_TEST_EXPECT_COLUMNS_EQUIVALENT(*expected, *results, print_all);
}
Expand Down

0 comments on commit fb22fc7

Please sign in to comment.