-
Notifications
You must be signed in to change notification settings - Fork 902
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Support collect_set on rolling window #7881
Support collect_set on rolling window #7881
Conversation
Hey, there, @sperlingxx. This PR should be against Also, pardon me for pointing this out: The reason we postponed making this change was that we are piling on if-else checks in the SFINAE code. @nvdbaranec is in the middle of sorting out that refactor. |
Codecov Report
@@ Coverage Diff @@
## branch-21.06 #7881 +/- ##
===============================================
Coverage ? 82.89%
===============================================
Files ? 105
Lines ? 17875
Branches ? 0
===============================================
Hits ? 14817
Misses ? 3058
Partials ? 0 Continue to review full report at Codecov.
|
Signed-off-by: sperlingxx <[email protected]>
Hi @ttnghia @mythrocks, I apologize for triggering re-requesting review by mistake. I didn't realize the refactor involving multiple PRs until I found #8158. |
No worries, sir. We will prioritize this as soon as #8158 is in. |
Signed-off-by: sperlingxx <[email protected]>
Signed-off-by: sperlingxx <[email protected]>
Signed-off-by: sperlingxx <[email protected]>
Signed-off-by: sperlingxx <[email protected]>
Rerun tests. |
cpp/src/rolling/rolling_detail.cuh
Outdated
min_periods, | ||
agg._null_handling, | ||
stream, | ||
mr); |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
collect_list
will not be returned, thus we will not use mr
. Instead, please use rmm::mr::get_current_device_resource()
.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Oh, my bad! I always forget that. Shall we keep the stream
, or replacing it with rmm::cuda_stream_default
?
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
I made this change.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
All func calls should use the same stream so we will not change that.
@@ -581,7 +581,7 @@ class collect_list_aggregation final : public rolling_aggregation { | |||
/** | |||
* @brief Derived aggregation class for specifying COLLECT_SET aggregation | |||
*/ | |||
class collect_set_aggregation final : public aggregation { | |||
class collect_set_aggregation final : public rolling_aggregation { |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Why do we inherit collect set from rolling_aggregation
instead of aggregation
? We need it for both rolling window and groupby, don't we?
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Because rolling_aggregation
is virtually inherited from aggregation
. I just followed corresponding codes for collect_list.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Maybe both are wrong, as collect_set_aggregation
and collect_set_aggregation
are used not only in rolling window but in groupby. Can you change to:
class collect_list_aggregation final : public aggregation
...
class collect_set_aggregation final : public aggregation
and test if they can compile and unit tests all pass please?
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Hi @ttnghia, I tried on replacing rolling_aggregation
with aggregation
. And I got compiling error on src/aggregation/aggregation.cpp:454
:
/home/alfredxu/workspace/codes/cudf/cpp/src/aggregation/aggregation.cpp:454:74: error: could not convert ‘std::make_unique(_Args&& ...) [with _Tp = cudf::detail::collect_list_aggregation; _Args = {cudf::null_policy&}; typename std::_MakeUniq<_Tp>::__single_object = std::unique_ptr<cudf::detail::collect_list_aggregation, std::default_delete<cudf::detail::collect_list_aggregation> >]()’ from ‘unique_ptr<cudf::detail::collect_list_aggregation,default_delete<cudf::detail::collect_list_aggregation>>’ to ‘unique_ptr<cudf::rolling_aggregation,default_delete<cudf::rolling_aggregation>>’
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
/// Factory to create a COLLECT_LIST aggregation
template <typename Base = aggregation>
std::unique_ptr<Base> make_collect_list_aggregation(null_policy null_handling)
{
return std::make_unique<detail::collect_list_aggregation>(null_handling);
}
template std::unique_ptr<aggregation> make_collect_list_aggregation<aggregation>(
null_policy null_handling);
template std::unique_ptr<rolling_aggregation> make_collect_list_aggregation<rolling_aggregation>(
null_policy null_handling);
I think it is because we can not return std::unique_ptr<rolling_aggregation>
after we made the replacement.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
I see. Thanks 😄
Signed-off-by: sperlingxx <[email protected]>
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
LGTM. Now we have other blocker PRs (#8342 and rapidsai/dask-cuda#623). Wait for them merged first before we can rerun tests.
rerun tests |
Rerun tests. |
@gpucibot merge |
This pull request is to support collect_set on rolling window, which is required in #7809.