
Pure-python masked UDFs #9174

Merged
Changes from 31 commits

Commits (41)
8d98b73
impl
brandon-b-miller Aug 30, 2021
45c47ee
purge c++ code
brandon-b-miller Aug 30, 2021
c2e0218
enable cuda 11.0
brandon-b-miller Aug 30, 2021
69f92cf
enable tests for __pow__
brandon-b-miller Aug 30, 2021
b636af9
solve multiple problems
brandon-b-miller Aug 31, 2021
bdf5823
masks are required for entry
brandon-b-miller Sep 2, 2021
470d25e
support returning a single number
brandon-b-miller Sep 3, 2021
e271ce4
formatting
brandon-b-miller Sep 3, 2021
0a971f0
bugfix
brandon-b-miller Sep 3, 2021
2f7e6f8
remove header
brandon-b-miller Sep 3, 2021
13a94cb
fix bool typing
brandon-b-miller Sep 3, 2021
b2a68e6
template kernels
brandon-b-miller Sep 3, 2021
5cb75e7
switch back to forall
brandon-b-miller Sep 3, 2021
49d9978
implement construct_signature
brandon-b-miller Sep 3, 2021
2ba8bd2
support offsets
brandon-b-miller Sep 3, 2021
11b2fd1
cache kernels
brandon-b-miller Sep 3, 2021
7379fe1
merge latest
brandon-b-miller Sep 7, 2021
775dd57
style
brandon-b-miller Sep 7, 2021
04c38e6
skip cases where pandas null logic differs
brandon-b-miller Sep 8, 2021
7a01bdb
style
brandon-b-miller Sep 8, 2021
627d197
update tests slightly
brandon-b-miller Sep 8, 2021
d3e2e0b
updates to pipeline.py
brandon-b-miller Sep 8, 2021
394fad3
Merge branch 'branch-21.10' into fea-masked-udf-pure-python
brandon-b-miller Sep 10, 2021
306f5e1
address many reviews
brandon-b-miller Sep 13, 2021
05adec7
cleanup
brandon-b-miller Sep 13, 2021
edbae6c
minor updates
brandon-b-miller Sep 13, 2021
e224bee
Apply suggestions from code review
brandon-b-miller Sep 14, 2021
a446b75
address reviews
brandon-b-miller Sep 14, 2021
b54e11e
remove creating buffers if the column has no mask
brandon-b-miller Sep 14, 2021
16406ff
put buffer back in for blank mask for now
brandon-b-miller Sep 14, 2021
ba2d898
merge latest and resolve conflicts
brandon-b-miller Sep 16, 2021
30d6013
fix import bug
brandon-b-miller Sep 17, 2021
a369641
clarify exec context
brandon-b-miller Sep 17, 2021
e51d780
Merge branch 'branch-21.10' into fea-masked-udf-pure-python
brandon-b-miller Sep 17, 2021
3c0c76f
rework unmasked kernels slightly
brandon-b-miller Sep 17, 2021
6deb96a
un purge c++
brandon-b-miller Sep 21, 2021
51b4fc9
cpp cleanup
brandon-b-miller Sep 21, 2021
4249334
Merge branch 'branch-21.10' into fea-masked-udf-pure-python
brandon-b-miller Sep 21, 2021
9f3c60e
Merge branch 'branch-21.10' into fea-masked-udf-pure-python
brandon-b-miller Sep 23, 2021
71c71b8
address reviews
brandon-b-miller Sep 28, 2021
b0580e9
Merge branch 'branch-21.12' into fea-masked-udf-pure-python
brandon-b-miller Sep 29, 2021
1 change: 0 additions & 1 deletion cpp/cmake/Modules/JitifyPreprocessKernels.cmake
@@ -56,7 +56,6 @@ endfunction()

jit_preprocess_files(SOURCE_DIRECTORY ${CUDF_SOURCE_DIR}/src
                     FILES binaryop/jit/kernel.cu
                           transform/jit/masked_udf_kernel.cu
                           transform/jit/kernel.cu
                           rolling/jit/kernel.cu
)
6 changes: 0 additions & 6 deletions cpp/include/cudf/transform.hpp
@@ -54,12 +54,6 @@ std::unique_ptr<column> transform(
  bool is_ptx,
  rmm::mr::device_memory_resource* mr = rmm::mr::get_current_device_resource());

std::unique_ptr<column> generalized_masked_op(
  table_view const& data_view,
  std::string const& binary_udf,
  data_type output_type,
  rmm::mr::device_memory_resource* mr = rmm::mr::get_current_device_resource());

/**
 * @brief Creates a null_mask from `input` by converting `NaN` to null and
 * preserving existing null values and also returns new null_count.
85 changes: 0 additions & 85 deletions cpp/src/transform/jit/masked_udf_kernel.cu

This file was deleted.

101 changes: 0 additions & 101 deletions cpp/src/transform/transform.cpp
@@ -24,7 +24,6 @@
#include <cudf/utilities/type_dispatcher.hpp>

#include <jit_preprocessed_files/transform/jit/kernel.cu.jit.hpp>
#include <jit_preprocessed_files/transform/jit/masked_udf_kernel.cu.jit.hpp>

#include <jit/cache.hpp>
#include <jit/parser.hpp>
@@ -65,80 +64,6 @@ void unary_operation(mutable_column_view output,
                     cudf::jit::get_data_ptr(input));
}

std::vector<std::string> make_template_types(column_view outcol_view, table_view const& data_view)
{
  std::string mskptr_type =
    cudf::jit::get_type_name(cudf::data_type(cudf::type_to_id<cudf::bitmask_type>())) + "*";
  std::string offset_type =
    cudf::jit::get_type_name(cudf::data_type(cudf::type_to_id<cudf::offset_type>()));

  std::vector<std::string> template_types;
  template_types.reserve((3 * data_view.num_columns()) + 1);

  template_types.push_back(cudf::jit::get_type_name(outcol_view.type()));
  for (auto const& col : data_view) {
    template_types.push_back(cudf::jit::get_type_name(col.type()) + "*");
    template_types.push_back(mskptr_type);
    template_types.push_back(offset_type);
  }
  return template_types;
}

void generalized_operation(table_view const& data_view,
                           std::string const& udf,
                           data_type output_type,
                           mutable_column_view outcol_view,
                           mutable_column_view outmsk_view,
                           rmm::cuda_stream_view stream,
                           rmm::mr::device_memory_resource* mr)
{
  auto const template_types = make_template_types(outcol_view, data_view);

  std::string generic_kernel_name =
    jitify2::reflection::Template("cudf::transformation::jit::generic_udf_kernel")
      .instantiate(template_types);

  std::string generic_cuda_source = cudf::jit::parse_single_function_ptx(
    udf, "GENERIC_OP", cudf::jit::get_type_name(output_type), {0});

  std::vector<void*> kernel_args;
  kernel_args.reserve((data_view.num_columns() * 3) + 3);

  cudf::size_type size = outcol_view.size();
  const void* outcol_ptr = cudf::jit::get_data_ptr(outcol_view);
  const void* outmsk_ptr = cudf::jit::get_data_ptr(outmsk_view);
  kernel_args.insert(kernel_args.begin(), {&size, &outcol_ptr, &outmsk_ptr});

  std::vector<const void*> data_ptrs;
  std::vector<cudf::bitmask_type const*> mask_ptrs;
  std::vector<cudf::offset_type> offsets;

  data_ptrs.reserve(data_view.num_columns());
  mask_ptrs.reserve(data_view.num_columns());
  offsets.reserve(data_view.num_columns());

  auto const iters = thrust::make_zip_iterator(
    thrust::make_tuple(data_ptrs.begin(), mask_ptrs.begin(), offsets.begin()));

  std::for_each(iters, iters + data_view.num_columns(), [&](auto const& tuple_vals) {
    kernel_args.push_back(&thrust::get<0>(tuple_vals));
    kernel_args.push_back(&thrust::get<1>(tuple_vals));
    kernel_args.push_back(&thrust::get<2>(tuple_vals));
  });

  std::transform(data_view.begin(), data_view.end(), iters, [&](column_view const& col) {
    return thrust::make_tuple(cudf::jit::get_data_ptr(col), col.null_mask(), col.offset());
  });

  cudf::jit::get_program_cache(*transform_jit_masked_udf_kernel_cu_jit)
    .get_kernel(generic_kernel_name,
                {},
                {{"transform/jit/operation-udf.hpp", generic_cuda_source}},
                {"-arch=sm_."})
    ->configure_1d_max_occupancy(0, 0, 0, stream.value())
    ->launch(kernel_args.data());
}

} // namespace jit
} // namespace transformation

@@ -165,24 +90,6 @@ std::unique_ptr<column> transform(column_view const& input,
  return output;
}

std::unique_ptr<column> generalized_masked_op(table_view const& data_view,
                                              std::string const& udf,
                                              data_type output_type,
                                              rmm::cuda_stream_view stream,
                                              rmm::mr::device_memory_resource* mr)
{
  std::unique_ptr<column> output = make_fixed_width_column(output_type, data_view.num_rows());
  std::unique_ptr<column> output_mask =
    make_fixed_width_column(cudf::data_type{cudf::type_id::BOOL8}, data_view.num_rows());

  transformation::jit::generalized_operation(
    data_view, udf, output_type, *output, *output_mask, stream, mr);

  auto final_output_mask = cudf::bools_to_mask(*output_mask);
  output.get()->set_null_mask(std::move(*(final_output_mask.first)));
  return output;
}

} // namespace detail

std::unique_ptr<column> transform(column_view const& input,
@@ -195,12 +102,4 @@ std::unique_ptr<column> transform(column_view const& input,
  return detail::transform(input, unary_udf, output_type, is_ptx, rmm::cuda_stream_default, mr);
}

std::unique_ptr<column> generalized_masked_op(table_view const& data_view,
                                              std::string const& udf,
                                              data_type output_type,
                                              rmm::mr::device_memory_resource* mr)
{
  return detail::generalized_masked_op(data_view, udf, output_type, rmm::cuda_stream_default, mr);
}

} // namespace cudf
6 changes: 0 additions & 6 deletions python/cudf/cudf/_lib/cpp/transform.pxd
@@ -34,12 +34,6 @@ cdef extern from "cudf/transform.hpp" namespace "cudf" nogil:
        bool is_ptx
    ) except +

    cdef unique_ptr[column] generalized_masked_op(
        const table_view& data_view,
        string udf,
        data_type output_type,
    ) except +

    cdef pair[unique_ptr[table], unique_ptr[column]] encode(
        table_view input
    ) except +
24 changes: 0 additions & 24 deletions python/cudf/cudf/_lib/transform.pyx
@@ -126,30 +126,6 @@ def transform(Column input, op):
    return Column.from_unique_ptr(move(c_output))


def masked_udf(Table incols, op, output_type):
    cdef table_view data_view = table_view_from_table(
        incols, ignore_index=True)
    cdef string c_str = op.encode("UTF-8")
    cdef type_id c_tid
    cdef data_type c_dtype

    c_tid = <type_id> (
        <underlying_type_t_type_id> SUPPORTED_NUMPY_TO_LIBCUDF_TYPES[
            output_type
        ]
    )
    c_dtype = data_type(c_tid)

    with nogil:
        c_output = move(libcudf_transform.generalized_masked_op(
            data_view,
            c_str,
            c_dtype,
        ))

    return Column.from_unique_ptr(move(c_output))


def table_encode(Table input):
    cdef table_view c_input = table_view_from_table(
        input, ignore_index=True)
6 changes: 0 additions & 6 deletions python/cudf/cudf/core/dataframe.py
@@ -4918,12 +4918,6 @@ def apply(self, func, axis=1):
            runtime compilation features
        """

        # libcudacxx tuples are not compatible with nvrtc 11.0
        runtime = cuda.cudadrv.runtime.Runtime()
        mjr, mnr = runtime.get_version()
        if mjr < 11 or (mjr == 11 and mnr < 1):
            raise RuntimeError("DataFrame.apply requires CUDA 11.1+")

        for dtype in self.dtypes:
            if (
                isinstance(dtype, cudf.core.dtypes._BaseDtype)
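For context, the check deleted above used to gate the masked-UDF path behind `DataFrame.apply`. A minimal usage sketch of that path — the values and the one-masked-scalar-per-column calling convention are assumptions inferred from the kernel launch in frame.py below, not copied from this PR's tests:

```python
import cudf

df = cudf.DataFrame({"a": [1, None, 3], "b": [4, 5, None]})


def add(x, y):
    # Each argument arrives as one masked scalar per column;
    # <NA> propagates through the arithmetic.
    return x + y


# Rows where either input is null yield <NA> in the result.
result = df.apply(add, axis=1)
```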
25 changes: 22 additions & 3 deletions python/cudf/cudf/core/frame.py
@@ -42,6 +42,7 @@
)
from cudf.core.column_accessor import ColumnAccessor
from cudf.core.join import merge
from cudf.core.udf.pipeline import compile_or_get
from cudf.core.window import Rolling
from cudf.utils import ioutils
from cudf.utils.docutils import copy_docstring
@@ -1342,10 +1343,28 @@ def _apply(self, func):
        """
        Apply `func` across the rows of the frame.
        """
        output_dtype, ptx = cudf.core.udf.pipeline.compile_masked_udf(
            func, self.dtypes
        )
        kernel, retty = compile_or_get(self, func)

        # Mask and data column preallocated
        ans_col = cupy.empty(len(self), dtype=retty)
        ans_mask = cudf.core.column.column_empty(len(self), dtype="bool")
        launch_args = [(ans_col, ans_mask)]
        offsets = []
        for col in self._data.values():
            data = col.data
            mask = col.mask
            if mask is None:
                mask = cudf.core.buffer.Buffer()
            launch_args.append((data, mask))
            offsets.append(col.offset)
        launch_args += offsets
        launch_args.append(len(self))  # size
        kernel.forall(len(self))(*launch_args)

        result = cudf.Series(ans_col).set_mask(
            libcudf.transform.bools_to_mask(ans_mask)
        )
        result = cudf._lib.transform.masked_udf(self, ptx, output_dtype)

        return result

    def rank(
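The launch above pairs each column's data with its bitmask (or an empty placeholder `Buffer` when the column has no nulls, per commit 16406ff) and hands everything to a Numba kernel via `forall`, which picks a one-thread-per-row launch configuration. A rough sketch of what the generated kernel amounts to — `mask_get`, `make_kernel`, and the exact signature are illustrative assumptions, since the real kernel is built from a string template in `pipeline.py`:

```python
from numba import cuda

from cudf.core.udf.api import Masked, pack_return


@cuda.jit(device=True)
def mask_get(mask, pos):
    # Assumed helper: read one bit of a libcudf-style bitmask,
    # least-significant bit first (1 means "valid").
    return (mask[pos // 8] >> (pos % 8)) & 1


def make_kernel(f_):
    # f_ is the user's UDF compiled as a device function; the real
    # codegen emits one (data, mask, offset) triple per input column.
    @cuda.jit
    def kernel(retval, retmask, data_0, mask_0, offset_0, size):
        i = cuda.grid(1)
        if i < size:
            masked_0 = Masked(data_0[i], bool(mask_get(mask_0, i + offset_0)))
            ret = pack_return(f_(masked_0))
            retval[i] = ret.value
            retmask[i] = ret.valid

    return kernel
```

The boolean `ans_mask` filled this way is then converted to a bitmask with `bools_to_mask` and attached to the result Series.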
python/cudf/cudf/core/udf/api.py
@@ -14,3 +14,11 @@ class Masked:
    def __init__(self, value, valid):
        self.value = value
        self.valid = valid


def pack_return(masked_or_scalar):
    # Blank function to give us something for the typing and
    # lowering to grab onto. Just a dummy function for us to
    # call within kernels; it gets replaced later by the
    # lowered implementation.
    pass
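In plain-Python terms, the contract `pack_return` establishes is simple; this behavioral reference (not part of the PR, which only lowers the dummy above) shows what the lowered implementations in lowering.py guarantee:

```python
from cudf.core.udf.api import Masked


def pack_return_reference(masked_or_scalar):
    # Behavioral reference only: a Masked passes through untouched,
    # while a bare scalar is wrapped as an always-valid Masked.
    if isinstance(masked_or_scalar, Masked):
        return masked_or_scalar
    return Masked(masked_or_scalar, True)
```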
27 changes: 22 additions & 5 deletions python/cudf/cudf/core/udf/lowering.py
@@ -11,7 +11,7 @@

from cudf.core.udf.typing import MaskedType, NAType

from . import classes
from . import api
Review comment (Contributor): Can we make these imports absolute to be consistent with the rest of the code-base?

from ._ops import arith_ops, comparison_ops


@@ -154,9 +154,8 @@ def register_const_op(op):
    to_lower_op = make_const_op(op)
    cuda_lower(op, MaskedType, types.Number)(to_lower_op)
    cuda_lower(op, types.Number, MaskedType)(to_lower_op)

    # to_lower_op_reflected = make_reflected_const_op(op)
    # cuda_lower(op, types.Number, MaskedType)(to_lower_op_reflected)
    cuda_lower(op, MaskedType, types.Boolean)(to_lower_op)
    cuda_lower(op, types.Boolean, MaskedType)(to_lower_op)


# register all lowering at init
@@ -194,6 +193,24 @@ def masked_scalar_is_null_impl(context, builder, sig, args):
    return builder.load(result)


# The main kernel always calls `pack_return` on whatever the user-defined
# function returned. This returns the same data if it's already a `Masked`,
# else packs it up into a new one that is valid from the get-go.
@cuda_lower(api.pack_return, MaskedType)
def pack_return_masked_impl(context, builder, sig, args):
    return args[0]


@cuda_lower(api.pack_return, types.Boolean)
@cuda_lower(api.pack_return, types.Number)
def pack_return_scalar_impl(context, builder, sig, args):
    outdata = cgutils.create_struct_proxy(sig.return_type)(context, builder)
    outdata.value = args[0]
    outdata.valid = context.get_constant(types.boolean, 1)

    return outdata._getvalue()


@cuda_lower(operator.truth, MaskedType)
def masked_scalar_truth_impl(context, builder, sig, args):
    indata = cgutils.create_struct_proxy(MaskedType(types.boolean))(
@@ -253,7 +270,7 @@ def cast_masked_to_masked(context, builder, fromty, toty, val):


# Masked constructor for use in a kernel for testing
@lower_builtin(classes.Masked, types.Number, types.boolean)
@lower_builtin(api.Masked, types.Number, types.boolean)
def masked_constructor(context, builder, sig, args):
    ty = sig.return_type
    value, valid = args
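The `masked_constructor` lowering exists so a `Masked` can be built directly inside a test kernel. A hedged usage sketch — it assumes the constructor typing and attribute access on `Masked` are registered in the typing module, which this diff does not show:

```python
import cupy
from numba import cuda

from cudf.core.udf.api import Masked


@cuda.jit
def build_masked(out):
    # Exercises the lowered constructor registered above.
    m = Masked(42, True)
    out[0] = m.value


out = cupy.zeros(1, dtype="int64")
build_masked[1, 1](out)  # out[0] == 42
```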