Generalized null support in user defined functions #8213

Merged: 148 commits, merged Jul 16, 2021

Changes from 118 commits

Commits
a7bbfb4
just debugging info
brandon-b-miller Feb 24, 2021
f44c196
Merge branch 'branch-0.19' into fea-udf-nulls
brandon-b-miller Mar 3, 2021
193f8e0
initial python MaskedType
brandon-b-miller Mar 7, 2021
91ae6a3
a little cleanup
brandon-b-miller Mar 7, 2021
a855a6f
basic bindings, header, placeholder c++ code
brandon-b-miller Mar 7, 2021
1b2c00c
missed one cython file - bindings work and run
brandon-b-miller Mar 7, 2021
7584ad3
fix bug
brandon-b-miller Mar 7, 2021
4988b14
little more progress
brandon-b-miller Mar 8, 2021
7a6427c
an attempt at NA plumbing
brandon-b-miller Mar 12, 2021
5e6eb06
a little more plumbing and prototyping
brandon-b-miller Mar 12, 2021
ea15da6
lots of progress
brandon-b-miller Mar 12, 2021
d1119b2
Merge branch 'branch-0.19' into fea-udf-nulls
brandon-b-miller Mar 15, 2021
961a9dd
trying to plumb to jitify launcher
brandon-b-miller Mar 16, 2021
664bf79
Merge branch 'branch-0.19' into fea-udf-nulls
brandon-b-miller Mar 16, 2021
5e93094
progress on jitify template/launch
brandon-b-miller Mar 16, 2021
03edceb
null kernel launches with all arguments
brandon-b-miller Mar 16, 2021
2b4c36f
bit_is_set works
brandon-b-miller Mar 16, 2021
3f76df5
successfully passing struct through the ptx function
brandon-b-miller Mar 17, 2021
db88f9e
pipeline fully runs
brandon-b-miller Mar 17, 2021
9a67670
it lives
brandon-b-miller Mar 17, 2021
b9da4bf
cleanup and add notebook
brandon-b-miller Apr 12, 2021
ec302f3
take the plunge and merge 0.20
brandon-b-miller Apr 12, 2021
cb85d88
integrate jitify2
brandon-b-miller Apr 12, 2021
ad067eb
minor cleanup
brandon-b-miller Apr 12, 2021
237af25
pushing forward with ND transform
brandon-b-miller Apr 22, 2021
8e11c7e
variadic kernel up and running
brandon-b-miller Apr 27, 2021
591627c
big plays
brandon-b-miller Apr 28, 2021
c07e187
general logic for building template instantiation arguments
brandon-b-miller Apr 28, 2021
d21b858
cleanup
brandon-b-miller Apr 29, 2021
6806968
attempting to use vector overload in jitify
brandon-b-miller Apr 29, 2021
cef8b71
loop kernel runs finally
brandon-b-miller May 3, 2021
19b88c5
full pipeline works for a dynamic number of columns
brandon-b-miller May 3, 2021
4845f27
code cleanup
brandon-b-miller May 3, 2021
c796dc4
more code cleanup/renaming
brandon-b-miller May 3, 2021
4f0ab9b
even more renaming and cleanup
brandon-b-miller May 3, 2021
3389198
remove old code
brandon-b-miller May 4, 2021
9425d4b
more cleanup
brandon-b-miller May 4, 2021
f7845e5
add a decorator to mimic the pandas api better
brandon-b-miller May 4, 2021
9e89ebd
starting to write tests
brandon-b-miller May 4, 2021
e19c8ba
add tests for constants
brandon-b-miller May 5, 2021
9880081
add failing tests for literal return
brandon-b-miller May 5, 2021
3e6a280
add NA and add constant tests
brandon-b-miller May 7, 2021
3028dba
MaskedType is parameterized
brandon-b-miller May 7, 2021
ecd8527
forward progress on mixed typing
brandon-b-miller May 7, 2021
5791413
generalize MaskedScalarAddConstant
brandon-b-miller May 7, 2021
77c8ee4
write a signature for any incoming types
brandon-b-miller May 7, 2021
85f1fba
reformat code
brandon-b-miller May 7, 2021
22c220c
need a separate __hash__ for different MaskedType
brandon-b-miller May 8, 2021
1ba3338
first sign that mixed typing works end to end
brandon-b-miller May 10, 2021
be06228
add tests for columns of mixed data type
brandon-b-miller May 10, 2021
029203b
incorporate graham's custom unification of extension types with literals
brandon-b-miller May 10, 2021
f024bf7
unify MaskedType and NAType and add a huge comment
brandon-b-miller May 10, 2021
2ef4520
Questionable unification of Masked with Literal
brandon-b-miller May 10, 2021
90d5127
removed unused code
brandon-b-miller May 11, 2021
0953bd1
move a lot of code around and refactor, add comments
brandon-b-miller May 11, 2021
96024c4
Merge branch 'branch-0.20' into fea-udf-nulls
brandon-b-miller May 12, 2021
6287404
remove erroneous header
brandon-b-miller May 12, 2021
aa38be2
typing and lowering for Masked is NA, currently not working
brandon-b-miller May 13, 2021
b5dcd13
remove erroneous logic
brandon-b-miller May 13, 2021
4c29c23
fix lowering for Masked is NA
brandon-b-miller May 13, 2021
837f2ef
roughly fix test_apply_NA_conditional, which was passing by coincidence
brandon-b-miller May 13, 2021
899c4dc
Merge branch 'branch-0.20' into fea-udf-nulls
brandon-b-miller May 14, 2021
2d7104d
support and test all arithmetic operators
brandon-b-miller May 14, 2021
b63b435
typing, lowering, tests for masked+constant
brandon-b-miller May 16, 2021
a3e1444
try and type mixed return value, and fail to do so
brandon-b-miller May 16, 2021
ca79d72
continued adding/refactoring of tests
brandon-b-miller May 16, 2021
3be8b16
merge latest
brandon-b-miller May 17, 2021
e1defcb
fix ops between masked and const of different dtype
brandon-b-miller May 18, 2021
33d3dcb
update tests
brandon-b-miller May 18, 2021
5769ded
add test for returning NA
brandon-b-miller May 18, 2021
1579b39
add masked v masked comparison ops
brandon-b-miller May 18, 2021
df28144
add tests for comparing masked to constant
brandon-b-miller May 18, 2021
8bab890
NA <-> Unmasked unification
brandon-b-miller May 19, 2021
6e7ac8d
partially address reviews
brandon-b-miller May 19, 2021
d58234e
just use args
brandon-b-miller May 19, 2021
739f6fc
add reflected ops vs NA
brandon-b-miller May 19, 2021
14e3ab8
add tests for reflected masked/na ops
brandon-b-miller May 19, 2021
417c130
typing for const + masked, lowering can wait for now
brandon-b-miller May 19, 2021
195e9b8
add graham's fix for Masked + const of a different type
brandon-b-miller May 19, 2021
6350909
Merge branch 'branch-21.06' into fea-udf-nulls
brandon-b-miller May 28, 2021
32f54d4
refactor a little c++
brandon-b-miller May 28, 2021
37a9257
minor docstring updates
brandon-b-miller May 28, 2021
22d610f
Add compilation tests for masked extensions + fix
gmarkall May 24, 2021
8a1b053
Fix flake8 in masked ops code
gmarkall May 24, 2021
58dab99
Add tests of comparisons, start testing unary ops
gmarkall May 24, 2021
4f06497
Don't test as-yet unimplemented not
gmarkall May 24, 2021
a59b240
Add execution test for masked ops
gmarkall May 24, 2021
e440770
Begin adding tests for operator.is_ with NA
gmarkall May 25, 2021
a6f67fa
Some tidy-ups in typing
gmarkall May 25, 2021
d9e8fdb
Fix test and implementation of `is NA`
gmarkall May 25, 2021
1d6755a
Add tests of comparison with NA behaviour
gmarkall May 26, 2021
671792c
test reflected const/masked ops - separate lowering to account for no…
brandon-b-miller May 28, 2021
c3007de
unify masked with masked
brandon-b-miller May 28, 2021
100ac44
allocate and build the final column in libcudf rather than cython
brandon-b-miller May 28, 2021
91c91eb
refactor c++ a bit
brandon-b-miller Jun 1, 2021
1fa3cab
use offset_type rather than hardcoding int64_t incorrectly
brandon-b-miller Jun 1, 2021
6125dc0
a little bit more refactoring
brandon-b-miller Jun 1, 2021
59e1209
remove debugging code
brandon-b-miller Jun 1, 2021
c1324b8
move repeated imports to their own function
brandon-b-miller Jun 1, 2021
ed79368
remove old ipython notebook
brandon-b-miller Jun 1, 2021
62ddca7
cleanup
brandon-b-miller Jun 1, 2021
821d11d
cpp style fixes
brandon-b-miller Jun 1, 2021
fb8f1cf
cache ptx
brandon-b-miller Jun 7, 2021
5d77b2b
partially address reviews
brandon-b-miller Jun 7, 2021
c3254f9
Merge branch 'branch-21.08' into fea-udf-nulls
brandon-b-miller Jun 8, 2021
bbefb7d
Merge branch 'branch-21.08' into fea-udf-nulls
brandon-b-miller Jun 10, 2021
f863ba1
switch to push_back
brandon-b-miller Jun 10, 2021
92cd6eb
more pushing back
brandon-b-miller Jun 10, 2021
4b08c51
xfail pow tests due to issue cudf/8470
brandon-b-miller Jun 10, 2021
48733b2
style fixes
brandon-b-miller Jun 15, 2021
16018f6
more style fixes
brandon-b-miller Jun 15, 2021
c91737e
update tests and _ops
brandon-b-miller Jun 15, 2021
d426589
merge 21.08
brandon-b-miller Jun 17, 2021
0da7fc7
address reviews
brandon-b-miller Jun 17, 2021
9048879
fix typing for NA
brandon-b-miller Jun 17, 2021
9fa05a3
minor name change
brandon-b-miller Jun 17, 2021
b807534
Update cpp/src/transform/jit/masked_udf_kernel.cu
brandon-b-miller Jun 23, 2021
f00cefc
Merge branch 'branch-21.08' into fea-udf-nulls
brandon-b-miller Jun 23, 2021
fe5eb30
Merge branch 'branch-21.08' into fea-udf-nulls
brandon-b-miller Jun 28, 2021
f56ffbb
add back missing header
brandon-b-miller Jun 29, 2021
7f07452
revise headers again
brandon-b-miller Jun 29, 2021
968e91b
update docstring with examples
brandon-b-miller Jul 1, 2021
699239d
add error checking
brandon-b-miller Jul 1, 2021
2d07152
Address reviews
brandon-b-miller Jul 1, 2021
95098e6
Apply suggestions from code review
brandon-b-miller Jul 1, 2021
b724410
address more reviews
brandon-b-miller Jul 1, 2021
593cbd2
simplify masked/unmasked typing logic
brandon-b-miller Jul 1, 2021
a31c15a
style fixes
brandon-b-miller Jul 2, 2021
448e4ea
refactor lowering for reflected const ops
brandon-b-miller Jul 2, 2021
6780814
cleanup
brandon-b-miller Jul 2, 2021
e622d30
Merge branch 'branch-21.08' into fea-udf-nulls
brandon-b-miller Jul 6, 2021
6bf3cf5
fix import and address reviews
brandon-b-miller Jul 6, 2021
6ed7a49
capture libcudacxx version for debugging
brandon-b-miller Jul 7, 2021
4ab7bd8
error for cuda<11.1
brandon-b-miller Jul 12, 2021
1ffce5b
remove CI debugging
brandon-b-miller Jul 12, 2021
169bcf2
skip testing cuda 11.0
brandon-b-miller Jul 12, 2021
aec243d
fix pytest
brandon-b-miller Jul 13, 2021
51ce28f
Merge branch 'branch-21.08' into fea-udf-nulls
brandon-b-miller Jul 13, 2021
993d841
Apply suggestions from code review
brandon-b-miller Jul 13, 2021
512555b
partially address reviews
brandon-b-miller Jul 13, 2021
8f1add4
Apply suggestions from code review
brandon-b-miller Jul 13, 2021
d683db9
Merge branch 'fea-udf-nulls' of github.com:brandon-b-miller/cudf into…
brandon-b-miller Jul 13, 2021
b061710
updates
brandon-b-miller Jul 14, 2021
7c722dd
style
brandon-b-miller Jul 14, 2021
7a7ee83
use table_view const&
brandon-b-miller Jul 15, 2021
a20d630
switch to a lambda
brandon-b-miller Jul 15, 2021
a13e935
Update cpp/src/transform/transform.cpp
brandon-b-miller Jul 16, 2021
9acc7a9
updates
brandon-b-miller Jul 16, 2021
1 change: 1 addition & 0 deletions cpp/cmake/Modules/JitifyPreprocessKernels.cmake
@@ -56,6 +56,7 @@ endfunction()

jit_preprocess_files(SOURCE_DIRECTORY ${CUDF_SOURCE_DIR}/src
                     FILES binaryop/jit/kernel.cu
                           transform/jit/masked_udf_kernel.cu
                           transform/jit/kernel.cu
                           rolling/jit/kernel.cu
)
6 changes: 6 additions & 0 deletions cpp/include/cudf/transform.hpp
@@ -53,6 +53,12 @@ std::unique_ptr<column> transform(
  bool is_ptx,
  rmm::mr::device_memory_resource* mr = rmm::mr::get_current_device_resource());

std::unique_ptr<column> generalized_masked_op(
  table_view data_view,
Contributor: Typically we pass table_view const&, since copying a table_view may involve recursively copying its child column_views, which is more expensive.

Contributor: This may need to be modified to use table_view const& (not just here, but in other places too).

  std::string const& binary_udf,
  data_type output_type,
  rmm::mr::device_memory_resource* mr = rmm::mr::get_current_device_resource());

/**
* @brief Creates a null_mask from `input` by converting `NaN` to null and
* preserving existing null values and also returns new null_count.
89 changes: 89 additions & 0 deletions cpp/src/transform/jit/masked_udf_kernel.cu
@@ -0,0 +1,89 @@
/*
* Copyright (c) 2021, NVIDIA CORPORATION.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/

// Include Jitify's cstddef header first
Contributor: Why? The convention in cudf is to include from "near" to "far": <transform/...> first, then <cudf/...>, then <cuda/...>, and finally the std headers.

Contributor (author): The problem is that when this file is runtime-compiled later, transform/jit/operation-udf.hpp gets string-replaced by an actual function definition that might use types from the std headers, so the relative order of those two headers is critical.

#include <cstddef>

#include <cuda/std/climits>
#include <cuda/std/cstddef>
#include <cuda/std/limits>
#include <cuda/std/type_traits>

#include <transform/jit/operation-udf.hpp>

#include <cudf/types.hpp>
#include <cudf/utilities/bit.hpp>
#include <cudf/wrappers/timestamps.hpp>

#include <cuda/std/tuple>
#include <tuple>

namespace cudf {
namespace transformation {
namespace jit {

template <typename T>
struct Masked {
  T value;
  bool valid;
};

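// Base case: convert one column's (data pointer, null-mask pointer, offset)
// triple into a single (value, valid) tuple; a null-mask pointer of nullptr
// means the column has no nulls, so the element is always valid.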
template <typename TypeIn, typename MaskType, typename OffsetType>
__device__ auto make_args(cudf::size_type id, TypeIn in_ptr, MaskType in_mask, OffsetType in_offset)
{
  bool valid = in_mask ? cudf::bit_is_set(in_mask, in_offset + id) : true;
  return cuda::std::make_tuple(in_ptr[id], valid);
}

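// Recursive case: peel off the leading (data pointer, null mask, offset)
// triple, convert it to a (value, valid) pair, and concatenate with the
// tuples produced from the remaining columns.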
template <typename InType, typename MaskType, typename OffsetType, typename... Arguments>
__device__ auto make_args(cudf::size_type id,
                          InType in_ptr,
                          MaskType in_mask,      // in practice, always cudf::bitmask_type const*
                          OffsetType in_offset,  // in practice, always cudf::size_type
                          Arguments... args)
{
  bool valid = in_mask ? cudf::bit_is_set(in_mask, in_offset + id) : true;
  return cuda::std::tuple_cat(cuda::std::make_tuple(in_ptr[id], valid), make_args(id, args...));
}

template <typename TypeOut, typename... Arguments>
__global__ void generic_udf_kernel(cudf::size_type size,
                                   TypeOut* out_data,
                                   bool* out_mask,
                                   Arguments... args)
{
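  // Standard grid-stride loop setup: each thread starts at its global index
  // and advances by the total number of threads in the grid.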
  int const tid    = threadIdx.x;
  int const blkid  = blockIdx.x;
  int const blksz  = blockDim.x;
  int const gridsz = gridDim.x;
  int const start  = tid + blkid * blksz;
  int const step   = blksz * gridsz;

  Masked<TypeOut> output;
  for (cudf::size_type i = start; i < size; i += step) {
    auto func_args = cuda::std::tuple_cat(
      cuda::std::make_tuple(&output.value),
      make_args(i, args...)  // passed int64*, bool*, int64, int64*, bool*, int64
    );
    cuda::std::apply(GENERIC_OP, func_args);
    out_data[i] = output.value;
    out_mask[i] = output.valid;
  }
}

} // namespace jit
} // namespace transformation
} // namespace cudf
103 changes: 103 additions & 0 deletions cpp/src/transform/transform.cpp
@@ -15,6 +15,7 @@
*/

#include <jit_preprocessed_files/transform/jit/kernel.cu.jit.hpp>
#include <jit_preprocessed_files/transform/jit/masked_udf_kernel.cu.jit.hpp>
Contributor: I believe the jit headers should be included after the cudf headers.

Contributor (author): 👍


#include <jit/cache.hpp>
#include <jit/parser.hpp>
@@ -25,6 +26,7 @@
#include <cudf/detail/nvtx/ranges.hpp>
#include <cudf/detail/transform.hpp>
#include <cudf/null_mask.hpp>
#include <cudf/table/table_view.hpp>
#include <cudf/utilities/traits.hpp>
#include <cudf/utilities/type_dispatcher.hpp>

@@ -63,6 +65,81 @@ void unary_operation(mutable_column_view output,
                     cudf::jit::get_data_ptr(input));
}

std::vector<std::string> make_template_types(column_view outcol_view, table_view data_view)
{
  std::string mskptr_type =
    cudf::jit::get_type_name(cudf::data_type(cudf::type_to_id<cudf::bitmask_type>())) + "*";
  std::string offset_type =
    cudf::jit::get_type_name(cudf::data_type(cudf::type_to_id<cudf::offset_type>()));

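  // The kernel is instantiated as generic_udf_kernel<OutType, Col0Type*,
  // bitmask_type*, offset_type, Col1Type*, ...>: one entry for the output
  // type, then three entries per input column.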
  std::vector<std::string> template_types;
  template_types.reserve(data_view.num_columns() + 1);

  template_types.push_back(cudf::jit::get_type_name(outcol_view.type()));
  for (auto const& col : data_view) {
    template_types.push_back(cudf::jit::get_type_name(col.type()) + "*");
    template_types.push_back(mskptr_type);
    template_types.push_back(offset_type);
  }
Contributor: Wait - I see that you call push_back 3 * num_columns() + 1 times, but reserve only num_columns() + 1.

Contributor (author): Nice catch - this was unsafe. Fixed.

  return template_types;
}

void generalized_operation(table_view data_view,
                           std::string const& udf,
                           data_type output_type,
                           mutable_column_view outcol_view,
                           mutable_column_view outmsk_view,
                           rmm::mr::device_memory_resource* mr)
{
  std::vector<std::string> template_types = make_template_types(outcol_view, data_view);
Contributor (@ttnghia, Jul 13, 2021): One more thing to note: you can use auto const for declaring almost everything, instead of writing out lengthy types like this, i.e.

    auto const template_types = ...

Contributor (author): Fixed.


  std::string generic_kernel_name =
    jitify2::reflection::Template("cudf::transformation::jit::generic_udf_kernel")
      .instantiate(template_types);

  std::string generic_cuda_source = cudf::jit::parse_single_function_ptx(
    udf, "GENERIC_OP", cudf::jit::get_type_name(output_type), {0});

  // {size, out_ptr, out_mask_ptr, col0_ptr, col0_mask_ptr, col0_offset, col1_ptr...}
  std::vector<void*> kernel_args;
  kernel_args.reserve((data_view.num_columns() * 3) + 3);

  cudf::size_type size   = outcol_view.size();
  const void* outcol_ptr = cudf::jit::get_data_ptr(outcol_view);
  const void* outmsk_ptr = cudf::jit::get_data_ptr(outmsk_view);
  kernel_args.insert(kernel_args.begin(), {&size, &outcol_ptr, &outmsk_ptr});

  std::vector<const void*> data_ptrs;
  std::vector<cudf::bitmask_type const*> mask_ptrs;
  std::vector<cudf::offset_type> offsets;

  data_ptrs.reserve(data_view.num_columns());
  mask_ptrs.reserve(data_view.num_columns());
  offsets.reserve(data_view.num_columns());

  column_view col;
  for (int col_idx = 0; col_idx < data_view.num_columns(); col_idx++) {
    col = data_view.column(col_idx);

    data_ptrs.push_back(cudf::jit::get_data_ptr(col));
    mask_ptrs.push_back(col.null_mask());
    offsets.push_back(col.offset());

    kernel_args.push_back(&data_ptrs[col_idx]);
    kernel_args.push_back(&mask_ptrs[col_idx]);
    kernel_args.push_back(&offsets[col_idx]);
  }
Contributor: Can we use some kind of std::transform instead? Raw loops are discouraged.

Contributor: This is difficult due to the 1-to-3 transform going on here. I kept trying to do the same, but couldn't get anything cleaner.

Contributor (@ttnghia, Jul 14, 2021): How about using thrust::zip_iterator (host callable)? You can output three values at the same time.

Contributor (author): I managed to use zip_iterator to replace about half of the logic here; one loop, though, I did not see how to simplify - open to suggestions.

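  // Fetch (or JIT-compile) the kernel from the program cache, substituting
  // the parsed PTX for the operation-udf.hpp placeholder header, then launch
  // with a max-occupancy 1D configuration.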
  rmm::cuda_stream_view generic_stream;
  cudf::jit::get_program_cache(*transform_jit_masked_udf_kernel_cu_jit)
    .get_kernel(generic_kernel_name,
                {},
                {{"transform/jit/operation-udf.hpp", generic_cuda_source}},
                {"-arch=sm_."})                                    //
    ->configure_1d_max_occupancy(0, 0, 0, generic_stream.value())  //
Contributor: Why is generic_stream used without initialization? If you are using the default stream, call the default stream directly.

Contributor (author): This should be fixed.

    ->launch(kernel_args.data());
}

} // namespace jit
} // namespace transformation

@@ -89,6 +166,24 @@ std::unique_ptr<column> transform(column_view const& input,
  return output;
}

std::unique_ptr<column> generalized_masked_op(table_view data_view,
                                              std::string const& udf,
                                              data_type output_type,
                                              rmm::mr::device_memory_resource* mr)
{
  rmm::cuda_stream_view stream = rmm::cuda_stream_default;
  std::unique_ptr<column> output = make_fixed_width_column(output_type, data_view.num_rows());
  std::unique_ptr<column> output_mask =
    make_fixed_width_column(cudf::data_type{cudf::type_id::BOOL8}, data_view.num_rows());

  transformation::jit::generalized_operation(
    data_view, udf, output_type, *output, *output_mask, mr);

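  // bools_to_mask packs the BOOL8 validity column written by the kernel into
  // a bitmask, which is then attached to the output column.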
  auto final_output_mask = cudf::bools_to_mask(*output_mask);
  output.get()->set_null_mask(std::move(*(final_output_mask.first)));
  return output;
}

} // namespace detail

std::unique_ptr<column> transform(column_view const& input,
@@ -101,4 +196,12 @@ std::unique_ptr<column> transform(column_view const& input,
  return detail::transform(input, unary_udf, output_type, is_ptx, rmm::cuda_stream_default, mr);
}

std::unique_ptr<column> generalized_masked_op(table_view data_view,
                                              std::string const& udf,
                                              data_type output_type,
                                              rmm::mr::device_memory_resource* mr)
{
  return detail::generalized_masked_op(data_view, udf, output_type, mr);
}

} // namespace cudf
6 changes: 6 additions & 0 deletions python/cudf/cudf/_lib/cpp/transform.pxd
@@ -38,6 +38,12 @@ cdef extern from "cudf/transform.hpp" namespace "cudf" nogil:
        bool is_ptx
    ) except +

    cdef unique_ptr[column] generalized_masked_op(
        table_view data_view,
        string udf,
        data_type output_type,
    ) except +

    cdef pair[unique_ptr[table], unique_ptr[column]] encode(
        table_view input
    ) except +
21 changes: 21 additions & 0 deletions python/cudf/cudf/_lib/transform.pyx
@@ -123,6 +123,27 @@ def transform(Column input, op):
    return Column.from_unique_ptr(move(c_output))


def masked_udf(Table incols, op, output_type):
    cdef table_view data_view = incols.data_view()
    cdef string c_str = op.encode("UTF-8")
    cdef type_id c_tid
    cdef data_type c_dtype

    c_tid = <type_id> (
        <underlying_type_t_type_id> np_to_cudf_types[output_type]
    )
    c_dtype = data_type(c_tid)

    with nogil:
        c_output = move(libcudf_transform.generalized_masked_op(
            data_view,
            c_str,
            c_dtype,
        ))

    return Column.from_unique_ptr(move(c_output))


def table_encode(Table input):
    cdef table_view c_input = input.data_view()
    cdef pair[unique_ptr[table], unique_ptr[column]] c_result
1 change: 1 addition & 0 deletions python/cudf/cudf/core/__init__.py
@@ -27,4 +27,5 @@
from cudf.core.multiindex import MultiIndex
from cudf.core.scalar import NA, Scalar
from cudf.core.series import Series
import cudf.core.udf
from cudf.core.cut import cut
29 changes: 29 additions & 0 deletions python/cudf/cudf/core/dataframe.py
@@ -4697,6 +4697,35 @@ def query(self, expr, local_dict=None):
        boolmask = queryutils.query_execute(self, expr, callenv)
        return self._apply_boolean_mask(boolmask)

    def apply(self, func, axis=1):
        """
        Apply a function along an axis of the DataFrame.

        Designed to mimic `pandas.DataFrame.apply`. Applies a
        user-defined function row-wise over a dataframe, with true
        null handling. Works with UDFs using `core.udf.pipeline.nulludf`
        and returns a single series. Uses numba to JIT compile the
        function to PTX via LLVM.

        Parameters
        ----------
        func : function
            Function to apply to each row.

        axis : {0 or 'index', 1 or 'columns'}, default 1
            Axis along which the function is applied:
            * 0 or 'index': apply function to each column.
              Note: axis=0 is not yet supported.
            * 1 or 'columns': apply function to each row.

        """
        if axis != 1:
            raise ValueError(
                "DataFrame.apply currently only supports row wise ops"
            )

        return func(self)
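
A minimal usage sketch of this API (assuming the `nulludf` decorator this PR adds in `cudf.core.udf.pipeline`; the column names and UDF body are illustrative only):

    import cudf
    from cudf.core.udf.pipeline import nulludf

    @nulludf
    def f(x, y):
        # cudf.NA can be tested for and returned inside the UDF
        if x is cudf.NA:
            return y
        return x + y

    df = cudf.DataFrame({'a': [1, None, 3], 'b': [4, 5, None]})
    # apply passes the frame to the wrapped UDF, which runs row-wise
    out = df.apply(lambda row: f(row['a'], row['b']))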

    @applyutils.doc_apply()
    def apply_rows(
        self,
11 changes: 11 additions & 0 deletions python/cudf/cudf/core/frame.py
@@ -1456,6 +1456,17 @@ def _quantiles(
        result._copy_type_metadata(self)
        return result

    @annotate("APPLY", color="purple", domain="cudf_python")
    def _apply(self, func):
        """
        Apply `func` across the rows of the frame.
        """
        output_dtype, ptx = cudf.core.udf.pipeline.compile_masked_udf(
            func, self.dtypes
        )
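        # compile_masked_udf compiles `func` with Numba against masked
        # versions of the column dtypes, yielding the inferred output dtype
        # and the kernel PTX; masked_udf below hands that PTX to libcudf,
        # which JIT-builds the kernel and launches it over the table.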
        result = cudf._lib.transform.masked_udf(self, ptx, output_dtype)
        return result

    def rank(
        self,
        axis=0,
1 change: 1 addition & 0 deletions python/cudf/cudf/core/udf/__init__.py
@@ -0,0 +1 @@
from . import typing, lowering
20 changes: 20 additions & 0 deletions python/cudf/cudf/core/udf/_ops.py
@@ -0,0 +1,20 @@
import operator

arith_ops = [
    operator.add,
    operator.sub,
    operator.mul,
    operator.truediv,
    operator.floordiv,
    operator.mod,
    operator.pow,
]

comparison_ops = [
    operator.eq,
    operator.ne,
    operator.lt,
    operator.le,
    operator.gt,
    operator.ge,
]

Contributor: I think there are more ops that can be added here, but my expectation / understanding is that this list is fine for this PR, and that we might add more in future PRs (along with tests for them) - does that match your thoughts / plans?

Contributor (author, Jul 1, 2021): I think we should add everything we can, provided that it will just work if we add it to this list as well as the tests - which ones am I missing?

Contributor: From looking at Numba's implementation of the operator module, I see the following that would need to be tested / checked.

Bitwise operations (for integer types):

    bit_ops = (
        operator.lshift,
        operator.rshift,
        operator.and_,
        operator.or_,
        operator.xor,
    )

In-place operations:

    inplace_ops = (
        operator.iadd,
        operator.isub,
        operator.imul,
        operator.imod,
        operator.itruediv,
        operator.ifloordiv,
        operator.ipow,
        operator.ilshift,
        operator.irshift,
        operator.iand,
    )

Most of the unary operations:

    unary_ops = (
        operator.truth,
        operator.invert,
        operator.neg,
        operator.pos,
        operator.not_,
    )

Maybe some instances of is and is not from the comparison operations:

    comparison_ops = (
        operator.lt,
        operator.le,
        operator.eq,
        operator.ne,
        operator.ge,
        operator.gt,
        operator.is_,
        operator.is_not,
    )

I think it's fairly straightforward to add everything, but it needs some thought / care to make sure we only provide typing for things that make sense - e.g. not providing bitwise operations on floats. I still think it would be better to get this over the line and then add more in another PR than to add to the workload of this one (and I'm happy to help with the additions when I get a chance).

Contributor (author): Great summary. I'll make sure to file some issues for these once this merges.

When we don't support something, should we just not provide typing for those cases and let the typing machinery raise the standard numba error? Or would it be better to provide typing that detects the scenario and errors with something more specific (I assume this would bubble up through numba's normal traceback system anyway)? Just thinking forward to disabling cases like __pow__ and some of the other functions we currently have problems with.
4 changes: 4 additions & 0 deletions python/cudf/cudf/core/udf/classes.py
@@ -0,0 +1,4 @@
class Masked:
    def __init__(self, value, valid):
        self.value = value
        self.valid = valid

Contributor: This isn't generally used in Python code outside of testing the implementation, but exists to give Numba something concrete to refer to for type inference, and to illustrate to the reader what a masked type looks like - is it worth adding a docstring explaining this here?

Contributor (author): Added a docstring here.
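
A sketch of the docstring described above (wording assumed, not the exact text that was added):

    class Masked:
        """A value paired with a validity flag, mirroring the device-side
        Masked struct. Rarely used directly outside of tests; it exists to
        give Numba something concrete to register typing and lowering
        against, and to show the reader what a masked value looks like."""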