Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Fix result column types for empty inputs to rolling window #8274

Merged
Merged
Show file tree
Hide file tree
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
11 changes: 6 additions & 5 deletions cpp/include/cudf/detail/aggregation/aggregation.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -784,9 +784,10 @@ struct target_type_impl<Source, aggregation::ALL> {
// Except for chrono types where result is chrono. (Use FloorDiv)
// TODO: MEAN should only be enabled for duration types - not for timestamps
template <typename Source, aggregation::Kind k>
struct target_type_impl<Source,
k,
std::enable_if_t<!is_chrono<Source>() && (k == aggregation::MEAN)>> {
struct target_type_impl<
Source,
k,
std::enable_if_t<is_fixed_width<Source>() && !is_chrono<Source>() && (k == aggregation::MEAN)>> {
using type = double;
};

Expand Down Expand Up @@ -1032,7 +1033,7 @@ template <typename Element>
struct dispatch_aggregation {
#pragma nv_exec_check_disable
template <aggregation::Kind k, typename F, typename... Ts>
CUDA_HOST_DEVICE_CALLABLE decltype(auto) operator()(F&& f, Ts&&... args) const noexcept
CUDA_HOST_DEVICE_CALLABLE decltype(auto) operator()(F&& f, Ts&&... args) const
{
return f.template operator()<Element, k>(std::forward<Ts>(args)...);
}
Expand All @@ -1043,7 +1044,7 @@ struct dispatch_source {
template <typename Element, typename F, typename... Ts>
CUDA_HOST_DEVICE_CALLABLE decltype(auto) operator()(aggregation::Kind k,
F&& f,
Ts&&... args) const noexcept
Ts&&... args) const
{
return aggregation_dispatcher(
k, dispatch_aggregation<Element>{}, std::forward<F>(f), std::forward<Ts>(args)...);
Expand Down
4 changes: 2 additions & 2 deletions cpp/src/rolling/grouped_rolling.cu
Original file line number Diff line number Diff line change
Expand Up @@ -94,7 +94,7 @@ std::unique_ptr<column> grouped_rolling_window(table_view const& group_keys,
{
CUDF_FUNC_RANGE();

if (input.is_empty()) return empty_like(input);
if (input.is_empty()) { return cudf::detail::empty_output(input, aggr); }

CUDF_EXPECTS((group_keys.num_columns() == 0 || group_keys.num_rows() == input.size()),
"Size mismatch between group_keys and input vector.");
Expand Down Expand Up @@ -949,7 +949,7 @@ std::unique_ptr<column> grouped_range_rolling_window(table_view const& group_key
{
CUDF_FUNC_RANGE();

if (input.is_empty()) return empty_like(input);
if (input.is_empty()) { return cudf::detail::empty_output(input, aggr); }

CUDF_EXPECTS((group_keys.num_columns() == 0 || group_keys.num_rows() == input.size()),
"Size mismatch between group_keys and input vector.");
Expand Down
9 changes: 6 additions & 3 deletions cpp/src/rolling/rolling.cu
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,7 @@
* limitations under the License.
*/

#include <cudf/detail/aggregation/aggregation.hpp>
#include "rolling_detail.cuh"

namespace cudf {
Expand Down Expand Up @@ -46,7 +47,8 @@ std::unique_ptr<column> rolling_window(column_view const& input,
{
CUDF_FUNC_RANGE();

if (input.is_empty()) return empty_like(input);
if (input.is_empty()) { return cudf::detail::empty_output(input, agg); }

CUDF_EXPECTS((min_periods >= 0), "min_periods must be non-negative");

CUDF_EXPECTS((default_outputs.is_empty() || default_outputs.size() == input.size()),
Expand Down Expand Up @@ -88,8 +90,9 @@ std::unique_ptr<column> rolling_window(column_view const& input,
{
CUDF_FUNC_RANGE();

if (preceding_window.is_empty() || following_window.is_empty() || input.is_empty())
return empty_like(input);
if (preceding_window.is_empty() || following_window.is_empty() || input.is_empty()) {
return cudf::detail::empty_output(input, agg);
}

CUDF_EXPECTS(preceding_window.type().id() == type_id::INT32 &&
following_window.type().id() == type_id::INT32,
Expand Down
31 changes: 30 additions & 1 deletion cpp/src/rolling/rolling_detail.cuh
Original file line number Diff line number Diff line change
Expand Up @@ -309,6 +309,35 @@ struct DeviceRollingRowNumber {
}
};

struct agg_specific_empty_output {
template <typename InputType, aggregation::Kind op>
std::unique_ptr<column> operator()(column_view const& input, rolling_aggregation const& agg) const
{
using target_type = cudf::detail::target_type_t<InputType, op>;

if constexpr (std::is_same_v<cudf::detail::target_type_t<InputType, op>, void>) {
CUDF_FAIL("Unsupported combination of column-type and aggregation.");
}

if constexpr (cudf::is_fixed_width<target_type>()) {
return cudf::make_empty_column(data_type{type_to_id<target_type>()});
}

if constexpr (op == aggregation::COLLECT_LIST) {
return cudf::make_lists_column(
0, make_empty_column(data_type{type_to_id<offset_type>()}), empty_like(input), 0, {});
}

return empty_like(input);
}
};

/**
 * @brief Returns the empty result column for a rolling aggregation over `input`.
 *
 * Dispatches on `input.type()` and `agg.kind` to `agg_specific_empty_output`,
 * which selects the column to return for that combination.
 *
 * Declared `inline` because this is a non-template function defined in a
 * header: without it, every translation unit that includes the header would
 * emit its own external definition of the same symbol.
 *
 * @param input The (empty) input column whose type drives the dispatch.
 * @param agg   The rolling aggregation whose kind drives the dispatch.
 * @return The empty column produced by `agg_specific_empty_output`.
 */
inline std::unique_ptr<column> empty_output(column_view const& input,
                                            rolling_aggregation const& agg)
{
  return cudf::detail::dispatch_type_and_aggregation(
    input.type(), agg.kind, agg_specific_empty_output{}, input, agg);
}

/**
* @brief Operator for applying a LEAD rolling aggregation on a single window.
*/
Expand Down Expand Up @@ -1061,7 +1090,7 @@ std::unique_ptr<column> rolling_window(column_view const& input,
static_assert(warp_size == cudf::detail::size_in_bits<cudf::bitmask_type>(),
"bitmask_type size does not match CUDA warp size");

if (input.is_empty()) { return empty_like(input); }
if (input.is_empty()) { return cudf::detail::empty_output(input, agg); }

if (cudf::is_dictionary(input.type())) {
CUDF_EXPECTS(agg.kind == aggregation::COUNT_ALL || agg.kind == aggregation::COUNT_VALID ||
Expand Down
9 changes: 5 additions & 4 deletions cpp/tests/CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -280,12 +280,13 @@ ConfigureTest(STREAM_COMPACTION_TEST
###################################################################################################
# - rolling tests ---------------------------------------------------------------------------------
ConfigureTest(ROLLING_TEST
rolling/rolling_test.cpp
rolling/grouped_rolling_test.cpp
rolling/collect_list_test.cpp
rolling/empty_input_test.cpp
rolling/lead_lag_test.cpp
rolling/range_window_bounds_test.cpp
rolling/grouped_rolling_test.cpp
rolling/range_rolling_window_test.cpp
rolling/collect_list_test.cpp)
rolling/range_window_bounds_test.cpp
rolling/rolling_test.cpp)

###################################################################################################
# - filling test ----------------------------------------------------------------------------------
Expand Down
Loading