Generalized null support in user defined functions #8213
**New file: `transform/jit/masked_udf_kernel.cu`** (`@@ -0,0 +1,89 @@`)

```cuda
/*
 * Copyright (c) 2021, NVIDIA CORPORATION.
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 *     http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */

// Include Jitify's cstddef header first
#include <cstddef>
```

> **Review:** Why? The convention in cudf is to include from "near" to "far". So, you include …
>
> **Reply:** I think the problem here is that, technically, when this file is runtime compiled later, …
```cuda
#include <cuda/std/climits>
#include <cuda/std/cstddef>
#include <cuda/std/limits>
#include <cuda/std/type_traits>

#include <transform/jit/operation-udf.hpp>

#include <cudf/types.hpp>
#include <cudf/utilities/bit.hpp>
#include <cudf/wrappers/timestamps.hpp>

#include <cuda/std/tuple>
#include <tuple>

namespace cudf {
namespace transformation {
namespace jit {

template <typename T>
struct Masked {
  T value;
  bool valid;
};

template <typename TypeIn, typename MaskType, typename OffsetType>
__device__ auto make_args(cudf::size_type id, TypeIn in_ptr, MaskType in_mask, OffsetType in_offset)
{
  bool valid = in_mask ? cudf::bit_is_set(in_mask, in_offset + id) : true;
  return cuda::std::make_tuple(in_ptr[id], valid);
}

template <typename InType, typename MaskType, typename OffsetType, typename... Arguments>
__device__ auto make_args(cudf::size_type id,
                          InType in_ptr,
                          MaskType in_mask,      // in practice, always cudf::bitmask_type const*
                          OffsetType in_offset,  // in practice, always cudf::size_type
                          Arguments... args)
{
  bool valid = in_mask ? cudf::bit_is_set(in_mask, in_offset + id) : true;
  return cuda::std::tuple_cat(cuda::std::make_tuple(in_ptr[id], valid), make_args(id, args...));
}

template <typename TypeOut, typename... Arguments>
__global__ void generic_udf_kernel(cudf::size_type size,
                                   TypeOut* out_data,
                                   bool* out_mask,
                                   Arguments... args)
{
  int const tid    = threadIdx.x;
  int const blkid  = blockIdx.x;
  int const blksz  = blockDim.x;
  int const gridsz = gridDim.x;
  int const start  = tid + blkid * blksz;
  int const step   = blksz * gridsz;

  Masked<TypeOut> output;
  for (cudf::size_type i = start; i < size; i += step) {
    auto func_args = cuda::std::tuple_cat(
      cuda::std::make_tuple(&output.value),
      make_args(i, args...)  // passed int64*, bool*, int64, int64*, bool*, int64
    );
    cuda::std::apply(GENERIC_OP, func_args);

    out_data[i] = output.value;
    out_mask[i] = output.valid;
  }
}

}  // namespace jit
}  // namespace transformation
}  // namespace cudf
```
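For intuition about the `make_args` recursion above, here is a minimal host-side sketch: plain C++ with `std::` in place of `cuda::std::`, a toy `bit_is_set`, and illustrative names throughout. It shows how each `(data pointer, mask pointer, offset)` triple is flattened into a `(value, valid)` pair, producing one flat tuple that `apply` can hand to the UDF:

```cpp
#include <cstdint>
#include <iostream>
#include <tuple>

// Toy stand-ins for cudf::bitmask_type / cudf::bit_is_set (illustrative only).
using bitmask_type = uint32_t;
bool bit_is_set(bitmask_type const* mask, int bit) { return mask[bit / 32] & (1u << (bit % 32)); }

// Base case: one (pointer, mask, offset) triple -> (value, valid).
template <typename T>
auto make_args(int id, T const* in_ptr, bitmask_type const* in_mask, int in_offset)
{
  bool valid = in_mask ? bit_is_set(in_mask, in_offset + id) : true;
  return std::make_tuple(in_ptr[id], valid);
}

// Recursive case: peel off one triple, then recurse on the remaining columns.
// Overload resolution prefers the base case for the final triple.
template <typename T, typename... Rest>
auto make_args(int id, T const* in_ptr, bitmask_type const* in_mask, int in_offset, Rest... rest)
{
  bool valid = in_mask ? bit_is_set(in_mask, in_offset + id) : true;
  return std::tuple_cat(std::make_tuple(in_ptr[id], valid), make_args(id, rest...));
}

int main()
{
  int64_t a[]       = {10, 20};
  double b[]        = {0.5, 1.5};
  bitmask_type mask = 0b01;  // row 0 valid, row 1 null

  // Two columns: `a` carries a null mask, `b` is all-valid (nullptr mask).
  auto args = make_args(1, a, &mask, 0, b, static_cast<bitmask_type const*>(nullptr), 0);
  // args == (20, false, 1.5, true)
  std::cout << std::get<0>(args) << " " << std::get<1>(args) << "\n";  // prints "20 0"
}
```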
**`transform.cpp`**

`@@ -15,6 +15,7 @@`

```cpp
 */

#include <jit_preprocessed_files/transform/jit/kernel.cu.jit.hpp>
#include <jit_preprocessed_files/transform/jit/masked_udf_kernel.cu.jit.hpp>

#include <jit/cache.hpp>
#include <jit/parser.hpp>
```

> **Review:** I believe that jit headers should be included after cudf headers.
>
> **Reply:** 👍
`@@ -25,6 +26,7 @@`

```cpp
#include <cudf/detail/nvtx/ranges.hpp>
#include <cudf/detail/transform.hpp>
#include <cudf/null_mask.hpp>
#include <cudf/table/table_view.hpp>
#include <cudf/utilities/traits.hpp>
#include <cudf/utilities/type_dispatcher.hpp>
```
`@@ -63,6 +65,81 @@ void unary_operation(mutable_column_view output,`

```cpp
    cudf::jit::get_data_ptr(input));
}

std::vector<std::string> make_template_types(column_view outcol_view, table_view data_view)
{
  std::string mskptr_type =
    cudf::jit::get_type_name(cudf::data_type(cudf::type_to_id<cudf::bitmask_type>())) + "*";
  std::string offset_type =
    cudf::jit::get_type_name(cudf::data_type(cudf::type_to_id<cudf::offset_type>()));

  std::vector<std::string> template_types;
  template_types.reserve(data_view.num_columns() + 1);

  template_types.push_back(cudf::jit::get_type_name(outcol_view.type()));
  for (auto const& col : data_view) {
    template_types.push_back(cudf::jit::get_type_name(col.type()) + "*");
    template_types.push_back(mskptr_type);
    template_types.push_back(offset_type);
  }

  return template_types;
}
```

> **Review:** Wait, I see that you call …
>
> **Reply:** Nice catch - this was unsafe. Fixed.
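For concreteness, a hypothetical example of the resulting list for an `INT64` output column and two input columns of types `INT64` and `FLOAT64`. The exact strings depend on `cudf::jit::get_type_name`, so these are illustrative only:

```cpp
// Illustrative only: assumes get_type_name maps INT64 -> "long", FLOAT64 -> "double",
// bitmask_type -> "unsigned int", and offset_type -> "int".
std::vector<std::string> template_types = {
  "long",           // TypeOut
  "long*",          // column 0 data
  "unsigned int*",  // column 0 null mask
  "int",            // column 0 offset
  "double*",        // column 1 data
  "unsigned int*",  // column 1 null mask
  "int",            // column 1 offset
};
```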
```cpp
void generalized_operation(table_view data_view,
                           std::string const& udf,
                           data_type output_type,
                           mutable_column_view outcol_view,
                           mutable_column_view outmsk_view,
                           rmm::mr::device_memory_resource* mr)
{
  std::vector<std::string> template_types = make_template_types(outcol_view, data_view);
```

> **Review:** One more thing I want to note is that, you can use …
>
> **Reply:** Fixed.
```cpp
  std::string generic_kernel_name =
    jitify2::reflection::Template("cudf::transformation::jit::generic_udf_kernel")
      .instantiate(template_types);

  std::string generic_cuda_source = cudf::jit::parse_single_function_ptx(
    udf, "GENERIC_OP", cudf::jit::get_type_name(output_type), {0});

  // {size, out_ptr, out_mask_ptr, col0_ptr, col0_mask_ptr, col0_offset, col1_ptr...}
  std::vector<void*> kernel_args;
  kernel_args.reserve((data_view.num_columns() * 3) + 3);

  cudf::size_type size   = outcol_view.size();
  const void* outcol_ptr = cudf::jit::get_data_ptr(outcol_view);
  const void* outmsk_ptr = cudf::jit::get_data_ptr(outmsk_view);
  kernel_args.insert(kernel_args.begin(), {&size, &outcol_ptr, &outmsk_ptr});

  std::vector<const void*> data_ptrs;
  std::vector<cudf::bitmask_type const*> mask_ptrs;
  std::vector<cudf::offset_type> offsets;

  data_ptrs.reserve(data_view.num_columns());
  mask_ptrs.reserve(data_view.num_columns());
  offsets.reserve(data_view.num_columns());

  column_view col;
  for (int col_idx = 0; col_idx < data_view.num_columns(); col_idx++) {
    col = data_view.column(col_idx);

    data_ptrs.push_back(cudf::jit::get_data_ptr(col));
    mask_ptrs.push_back(col.null_mask());
    offsets.push_back(col.offset());

    kernel_args.push_back(&data_ptrs[col_idx]);
    kernel_args.push_back(&mask_ptrs[col_idx]);
    kernel_args.push_back(&offsets[col_idx]);
  }
```

> **Review:** Can we use some type of …
>
> **Reply:** This is difficult due to the 1->3 transform going on here. I kept trying to do the same, but couldn't get anything that was cleaner.
>
> **Review:** How about using …
>
> **Reply:** I managed to use …
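One detail worth noting (a property of CUDA driver-style launches in general, not something specific to this diff): `kernel_args` holds the *addresses* of the launch arguments, not their values, so `size`, `outcol_ptr`, `outmsk_ptr`, and the three per-column vectors must stay alive and unmoved until the launch below. This is also why the `reserve` calls matter: a reallocation of `data_ptrs`, `mask_ptrs`, or `offsets` during the loop would invalidate the `&data_ptrs[col_idx]`-style pointers already stored in `kernel_args`.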
```cpp
  rmm::cuda_stream_view generic_stream;
  cudf::jit::get_program_cache(*transform_jit_masked_udf_kernel_cu_jit)
    .get_kernel(generic_kernel_name,
                {},
                {{"transform/jit/operation-udf.hpp", generic_cuda_source}},
                {"-arch=sm_."})                                    //
    ->configure_1d_max_occupancy(0, 0, 0, generic_stream.value())  //
    ->launch(kernel_args.data());
}
```

> **Review:** Why …
>
> **Reply:** This should be fixed.

```cpp
}  // namespace jit
}  // namespace transformation
```
`@@ -89,6 +166,24 @@ std::unique_ptr<column> transform(column_view const& input,`

```cpp
  return output;
}

std::unique_ptr<column> generalized_masked_op(table_view data_view,
                                              std::string const& udf,
                                              data_type output_type,
                                              rmm::mr::device_memory_resource* mr)
{
  rmm::cuda_stream_view stream    = rmm::cuda_stream_default;
  std::unique_ptr<column> output  = make_fixed_width_column(output_type, data_view.num_rows());
  std::unique_ptr<column> output_mask =
    make_fixed_width_column(cudf::data_type{cudf::type_id::BOOL8}, data_view.num_rows());

  transformation::jit::generalized_operation(
    data_view, udf, output_type, *output, *output_mask, mr);

  auto final_output_mask = cudf::bools_to_mask(*output_mask);
  output.get()->set_null_mask(std::move(*(final_output_mask.first)));

  return output;
}
```
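The null-mask flow here: the kernel writes per-row validity into a temporary `BOOL8` column, `cudf::bools_to_mask` then packs those booleans into a bitmask buffer, and `set_null_mask` attaches that buffer to the output column. If I read the API correctly, `bools_to_mask` returns a (mask buffer, null count) pair, of which only the buffer is used here.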
```cpp
}  // namespace detail

std::unique_ptr<column> transform(column_view const& input,
```

`@@ -101,4 +196,12 @@ std::unique_ptr<column> transform(column_view const& input,`

```cpp
  return detail::transform(input, unary_udf, output_type, is_ptx, rmm::cuda_stream_default, mr);
}

std::unique_ptr<column> generalized_masked_op(table_view data_view,
                                              std::string const& udf,
                                              data_type output_type,
                                              rmm::mr::device_memory_resource* mr)
{
  return detail::generalized_masked_op(data_view, udf, output_type, mr);
}

}  // namespace cudf
```

> **Review (on passing `table_view` by value):** Typically we pass in `table_view const&`, as copying it may involve recursively copying its children `column_view`s, which is more expensive.
>
> **Review:** This may need to be modified to use `table_view const&` (not just this, but in other places too).
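To tie the pieces together, a hedged usage sketch of the new entry point. Only the `generalized_masked_op` signature is taken from this diff; the header paths and the `apply_udf` wrapper are assumptions on my part, and in practice `udf_ptx` would be PTX compiled from a user's function (e.g. by Numba), which `parse_single_function_ptx` renames to `GENERIC_OP`:

```cpp
#include <memory>
#include <string>

#include <cudf/column/column.hpp>
#include <cudf/table/table_view.hpp>
#include <cudf/transform.hpp>  // assumed home of the new declaration

#include <rmm/mr/device/per_device_resource.hpp>

// Hypothetical wrapper: run a masked UDF over all columns of `input`,
// producing an INT64 column whose null mask comes from the UDF's `valid` outputs.
std::unique_ptr<cudf::column> apply_udf(cudf::table_view const& input,
                                        std::string const& udf_ptx)
{
  return cudf::generalized_masked_op(input,
                                     udf_ptx,
                                     cudf::data_type{cudf::type_id::INT64},
                                     rmm::mr::get_current_device_resource());
}
```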
**New file (Python package `__init__.py`)** (`@@ -0,0 +1 @@`)

```python
from . import typing, lowering
```
**New file (Python operator lists)** (`@@ -0,0 +1,20 @@`)

```python
import operator

arith_ops = [
    operator.add,
    operator.sub,
    operator.mul,
    operator.truediv,
    operator.floordiv,
    operator.mod,
    operator.pow,
]

comparison_ops = [
    operator.eq,
    operator.ne,
    operator.lt,
    operator.le,
    operator.gt,
    operator.ge,
]
```

> **Review (on `arith_ops`):** I think there are more ops that can be added here, but my expectation / understanding is that this list is fine for this PR, and that we might add more in future PRs (along with tests for them) - does that match your thoughts / plans?
>
> **Reply:** I think we should add everything we can, provided that it will just work if we add it to this list as well as the tests - which ones am I missing?
>
> **Review:** From looking at Numba's implementation of the … Bitwise operations (for integer types):
>
> ```python
> bit_ops = (
>     operator.lshift,
>     operator.rshift,
>     operator.and_,
>     operator.or_,
>     operator.xor,
> )
> ```
>
> Inplace operations: … Most of the unary operations: … Maybe some instances of … Some of the unary ops: …
>
> I think it's fairly straightforward to add everything, but needs some thought / care to make sure that we only provide typing for things that make sense - e.g. not providing bitwise operations on floats. I do still think it would be better to get this over the line and then add more in another PR than to add to the workload of this PR (and I'm happy to help with the addition of more when I get a chance).
>
> **Reply:** Great summary. I'll make sure to file some issues for these once this merges. When we don't support something, should we just not provide typing for those cases, and let the typing machinery raise the standard numba error? Or would it be better to provide typing that detects the scenario and then errors with something more specific (I assume this would bubble up through numba's normal traceback system anyways)? Just thinking forward for disabling cases like …
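Presumably these lists are consumed by the `typing` and `lowering` submodules imported in the package `__init__.py` above, to register masked-aware overloads for each operator with Numba; that registration code is outside this excerpt.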
**New file (the Python `Masked` class)** (`@@ -0,0 +1,4 @@`)

```python
class Masked:
    def __init__(self, value, valid):
        self.value = value
        self.valid = valid
```

> **Review:** This isn't generally used in Python code outside of testing the implementation, but exists to give Numba something concrete to refer to for type inference, and to illustrate to the reader what a masked type looks like - is it worth adding a docstring explaining this here?
>
> **Reply:** Added a docstring here.
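For orientation: this Python `Masked` appears to mirror the device-side `Masked<T>` struct in `masked_udf_kernel.cu` above (a value paired with a validity flag), presumably the per-row shape that Numba-compiled UDFs exchange with the JIT kernel.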