Skip to content

Commit

Permalink
Fix rolling window result types for empty inputs.
Browse files Browse the repository at this point in the history
  • Loading branch information
mythrocks committed May 26, 2021
1 parent cbbcba7 commit 07ffa1a
Show file tree
Hide file tree
Showing 6 changed files with 440 additions and 15 deletions.
11 changes: 6 additions & 5 deletions cpp/include/cudf/detail/aggregation/aggregation.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -784,9 +784,10 @@ struct target_type_impl<Source, aggregation::ALL> {
// Except for chrono types where result is chrono. (Use FloorDiv)
// TODO: MEAN should only be enabled for duration types - not for timestamps
template <typename Source, aggregation::Kind k>
struct target_type_impl<Source,
k,
std::enable_if_t<!is_chrono<Source>() && (k == aggregation::MEAN)>> {
struct target_type_impl<
Source,
k,
std::enable_if_t<is_fixed_width<Source>() && !is_chrono<Source>() && (k == aggregation::MEAN)>> {
using type = double;
};

Expand Down Expand Up @@ -1032,7 +1033,7 @@ template <typename Element>
struct dispatch_aggregation {
#pragma nv_exec_check_disable
template <aggregation::Kind k, typename F, typename... Ts>
CUDA_HOST_DEVICE_CALLABLE decltype(auto) operator()(F&& f, Ts&&... args) const noexcept
CUDA_HOST_DEVICE_CALLABLE decltype(auto) operator()(F&& f, Ts&&... args) const
{
return f.template operator()<Element, k>(std::forward<Ts>(args)...);
}
Expand All @@ -1043,7 +1044,7 @@ struct dispatch_source {
template <typename Element, typename F, typename... Ts>
CUDA_HOST_DEVICE_CALLABLE decltype(auto) operator()(aggregation::Kind k,
F&& f,
Ts&&... args) const noexcept
Ts&&... args) const
{
return aggregation_dispatcher(
k, dispatch_aggregation<Element>{}, std::forward<F>(f), std::forward<Ts>(args)...);
Expand Down
4 changes: 2 additions & 2 deletions cpp/src/rolling/grouped_rolling.cu
Original file line number Diff line number Diff line change
Expand Up @@ -94,7 +94,7 @@ std::unique_ptr<column> grouped_rolling_window(table_view const& group_keys,
{
CUDF_FUNC_RANGE();

if (input.is_empty()) return empty_like(input);
if (input.is_empty()) { return cudf::detail::empty_output(input, aggr); }

CUDF_EXPECTS((group_keys.num_columns() == 0 || group_keys.num_rows() == input.size()),
"Size mismatch between group_keys and input vector.");
Expand Down Expand Up @@ -949,7 +949,7 @@ std::unique_ptr<column> grouped_range_rolling_window(table_view const& group_key
{
CUDF_FUNC_RANGE();

if (input.is_empty()) return empty_like(input);
if (input.is_empty()) { return cudf::detail::empty_output(input, aggr); }

CUDF_EXPECTS((group_keys.num_columns() == 0 || group_keys.num_rows() == input.size()),
"Size mismatch between group_keys and input vector.");
Expand Down
9 changes: 6 additions & 3 deletions cpp/src/rolling/rolling.cu
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,7 @@
* limitations under the License.
*/

#include <cudf/detail/aggregation/aggregation.hpp>
#include "rolling_detail.cuh"

namespace cudf {
Expand Down Expand Up @@ -46,7 +47,8 @@ std::unique_ptr<column> rolling_window(column_view const& input,
{
CUDF_FUNC_RANGE();

if (input.is_empty()) return empty_like(input);
if (input.is_empty()) { return cudf::detail::empty_output(input, agg); }

CUDF_EXPECTS((min_periods >= 0), "min_periods must be non-negative");

CUDF_EXPECTS((default_outputs.is_empty() || default_outputs.size() == input.size()),
Expand Down Expand Up @@ -88,8 +90,9 @@ std::unique_ptr<column> rolling_window(column_view const& input,
{
CUDF_FUNC_RANGE();

if (preceding_window.is_empty() || following_window.is_empty() || input.is_empty())
return empty_like(input);
if (preceding_window.is_empty() || following_window.is_empty() || input.is_empty()) {
return cudf::detail::empty_output(input, agg);
}

CUDF_EXPECTS(preceding_window.type().id() == type_id::INT32 &&
following_window.type().id() == type_id::INT32,
Expand Down
31 changes: 30 additions & 1 deletion cpp/src/rolling/rolling_detail.cuh
Original file line number Diff line number Diff line change
Expand Up @@ -309,6 +309,35 @@ struct DeviceRollingRowNumber {
}
};

struct agg_specific_empty_output {
template <typename InputType, aggregation::Kind op>
std::unique_ptr<column> operator()(column_view const& input, rolling_aggregation const& agg) const
{
using target_type = cudf::detail::target_type_t<InputType, op>;

if constexpr (std::is_same_v<cudf::detail::target_type_t<InputType, op>, void>) {
CUDF_FAIL("Unsupported combination of column-type and aggregation.");
}

if constexpr (cudf::is_fixed_width<target_type>()) {
return cudf::make_empty_column(data_type{type_to_id<target_type>()});
}

if constexpr (op == aggregation::COLLECT_LIST) {
return cudf::make_lists_column(
0, make_empty_column(data_type{type_to_id<offset_type>()}), empty_like(input), 0, {});
}

return empty_like(input);
}
};

/**
 * @brief Builds the empty result column for a rolling aggregation over `input`.
 *
 * Dispatches on `input.type()` and `agg.kind` to `agg_specific_empty_output`, so the
 * returned column is chosen per (input type, aggregation kind) pair.
 *
 * Declared `inline` because this is a non-template function defined in a header.
 * Without it, every translation unit that includes this header emits its own
 * external definition, which violates the one-definition rule and can fail at
 * link time with duplicate symbols.
 *
 * @param input Column whose type drives the dispatch; forwarded to the functor
 * @param agg Rolling aggregation whose `kind` drives the dispatch; forwarded to the functor
 * @return The column produced by `agg_specific_empty_output` for this combination
 */
inline std::unique_ptr<column> empty_output(column_view const& input,
                                            rolling_aggregation const& agg)
{
  return cudf::detail::dispatch_type_and_aggregation(
    input.type(), agg.kind, agg_specific_empty_output{}, input, agg);
}

/**
* @brief Operator for applying a LEAD rolling aggregation on a single window.
*/
Expand Down Expand Up @@ -1061,7 +1090,7 @@ std::unique_ptr<column> rolling_window(column_view const& input,
static_assert(warp_size == cudf::detail::size_in_bits<cudf::bitmask_type>(),
"bitmask_type size does not match CUDA warp size");

if (input.is_empty()) { return empty_like(input); }
if (input.is_empty()) { return cudf::detail::empty_output(input, agg); }

if (cudf::is_dictionary(input.type())) {
CUDF_EXPECTS(agg.kind == aggregation::COUNT_ALL || agg.kind == aggregation::COUNT_VALID ||
Expand Down
9 changes: 5 additions & 4 deletions cpp/tests/CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -280,12 +280,13 @@ ConfigureTest(STREAM_COMPACTION_TEST
###################################################################################################
# - rolling tests ---------------------------------------------------------------------------------
ConfigureTest(ROLLING_TEST
rolling/rolling_test.cpp
rolling/grouped_rolling_test.cpp
rolling/collect_list_test.cpp
rolling/empty_input_test.cpp
rolling/lead_lag_test.cpp
rolling/range_window_bounds_test.cpp
rolling/grouped_rolling_test.cpp
rolling/range_rolling_window_test.cpp
rolling/collect_list_test.cpp)
rolling/range_window_bounds_test.cpp
rolling/rolling_test.cpp)

###################################################################################################
# - filling test ----------------------------------------------------------------------------------
Expand Down
Loading

0 comments on commit 07ffa1a

Please sign in to comment.