Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Pure-python masked UDFs #9174

Merged
Merged
Show file tree
Hide file tree
Changes from 38 commits
Commits
Show all changes
41 commits
Select commit Hold shift + click to select a range
8d98b73
impl
brandon-b-miller Aug 30, 2021
45c47ee
purge c++ code
brandon-b-miller Aug 30, 2021
c2e0218
enable cuda 11.0
brandon-b-miller Aug 30, 2021
69f92cf
enable tests for __pow__
brandon-b-miller Aug 30, 2021
b636af9
solve multiple problems
brandon-b-miller Aug 31, 2021
bdf5823
masks are required for entry
brandon-b-miller Sep 2, 2021
470d25e
support returning a single number
brandon-b-miller Sep 3, 2021
e271ce4
formatting
brandon-b-miller Sep 3, 2021
0a971f0
bugfix
brandon-b-miller Sep 3, 2021
2f7e6f8
remove header
brandon-b-miller Sep 3, 2021
13a94cb
fix bool typing
brandon-b-miller Sep 3, 2021
b2a68e6
template kernels
brandon-b-miller Sep 3, 2021
5cb75e7
switch back to forall
brandon-b-miller Sep 3, 2021
49d9978
implement construct_signature
brandon-b-miller Sep 3, 2021
2ba8bd2
support offsets
brandon-b-miller Sep 3, 2021
11b2fd1
cache kernels
brandon-b-miller Sep 3, 2021
7379fe1
merge latest
brandon-b-miller Sep 7, 2021
775dd57
style
brandon-b-miller Sep 7, 2021
04c38e6
skip cases where pandas null logic differs
brandon-b-miller Sep 8, 2021
7a01bdb
style
brandon-b-miller Sep 8, 2021
627d197
update tests slightly
brandon-b-miller Sep 8, 2021
d3e2e0b
updates to pipeline.py
brandon-b-miller Sep 8, 2021
394fad3
Merge branch 'branch-21.10' into fea-masked-udf-pure-python
brandon-b-miller Sep 10, 2021
306f5e1
address many reviews
brandon-b-miller Sep 13, 2021
05adec7
cleanup
brandon-b-miller Sep 13, 2021
edbae6c
minor updtes
brandon-b-miller Sep 13, 2021
e224bee
Apply suggestions from code review
brandon-b-miller Sep 14, 2021
a446b75
address reviews
brandon-b-miller Sep 14, 2021
b54e11e
remove creating buffers if the column has no mask
brandon-b-miller Sep 14, 2021
16406ff
put buffer back in for blank mask for now
brandon-b-miller Sep 14, 2021
ba2d898
merge latest and resolve conflicts
brandon-b-miller Sep 16, 2021
30d6013
fix import bug
brandon-b-miller Sep 17, 2021
a369641
clarify exec context
brandon-b-miller Sep 17, 2021
e51d780
Merge branch 'branch-21.10' into fea-masked-udf-pure-python
brandon-b-miller Sep 17, 2021
3c0c76f
rework unmasked kernels slightly
brandon-b-miller Sep 17, 2021
6deb96a
un purge c++
brandon-b-miller Sep 21, 2021
51b4fc9
cpp cleanup
brandon-b-miller Sep 21, 2021
4249334
Merge branch 'branch-21.10' into fea-masked-udf-pure-python
brandon-b-miller Sep 21, 2021
9f3c60e
Merge branch 'branch-21.10' into fea-masked-udf-pure-python
brandon-b-miller Sep 23, 2021
71c71b8
address reviews
brandon-b-miller Sep 28, 2021
b0580e9
Merge branch 'branch-21.12' into fea-masked-udf-pure-python
brandon-b-miller Sep 29, 2021
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
6 changes: 0 additions & 6 deletions python/cudf/cudf/core/dataframe.py
Original file line number Diff line number Diff line change
Expand Up @@ -4890,12 +4890,6 @@ def apply(self, func, axis=1):
runtime compilation features
"""

# libcudacxx tuples are not compatible with nvrtc 11.0
runtime = cuda.cudadrv.runtime.Runtime()
mjr, mnr = runtime.get_version()
if mjr < 11 or (mjr == 11 and mnr < 1):
raise RuntimeError("DataFrame.apply requires CUDA 11.1+")

for dtype in self.dtypes:
if (
isinstance(dtype, cudf.core.dtypes._BaseDtype)
Expand Down
26 changes: 23 additions & 3 deletions python/cudf/cudf/core/frame.py
Original file line number Diff line number Diff line change
Expand Up @@ -42,6 +42,7 @@
)
from cudf.core.column_accessor import ColumnAccessor
from cudf.core.join import merge
from cudf.core.udf.pipeline import compile_or_get
from cudf.core.window import Rolling
from cudf.utils import ioutils
from cudf.utils.docutils import copy_docstring
Expand Down Expand Up @@ -1431,10 +1432,29 @@ def _apply(self, func):
"""
Apply `func` across the rows of the frame.
"""
output_dtype, ptx = cudf.core.udf.pipeline.compile_masked_udf(
func, self.dtypes
kernel, retty = compile_or_get(self, func)

# Mask and data column preallocated
ans_col = cupy.empty(len(self), dtype=retty)
ans_mask = cudf.core.column.column_empty(len(self), dtype="bool")
launch_args = [(ans_col, ans_mask)]
offsets = []
for col in self._data.values():
data = col.data
mask = col.mask
if mask is None:
launch_args.append(data)
else:
launch_args.append((data, mask))
offsets.append(col.offset)
launch_args += offsets
launch_args.append(len(self)) # size
kernel.forall(len(self))(*launch_args)

result = cudf.Series(ans_col).set_mask(
libcudf.transform.bools_to_mask(ans_mask)
)
result = cudf._lib.transform.masked_udf(self, ptx, output_dtype)

return result

def rank(
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -14,3 +14,11 @@ class Masked:
def __init__(self, value, valid):
self.value = value
self.valid = valid


def pack_return(masked_or_scalar):
# Blank function to give us something for the typing and
# lowering to grab onto. Just a dummy function for us to
# call within kernels that will get replaced later by the
# lowered implementation
pass
27 changes: 22 additions & 5 deletions python/cudf/cudf/core/udf/lowering.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,7 @@

from cudf.core.udf.typing import MaskedType, NAType

from . import classes
from . import api
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Can we make these imports absolute to be consistent with the rest of the code-base?

from ._ops import arith_ops, comparison_ops


Expand Down Expand Up @@ -154,9 +154,8 @@ def register_const_op(op):
to_lower_op = make_const_op(op)
cuda_lower(op, MaskedType, types.Number)(to_lower_op)
cuda_lower(op, types.Number, MaskedType)(to_lower_op)

# to_lower_op_reflected = make_reflected_const_op(op)
# cuda_lower(op, types.Number, MaskedType)(to_lower_op_reflected)
cuda_lower(op, MaskedType, types.Boolean)(to_lower_op)
cuda_lower(op, types.Boolean, MaskedType)(to_lower_op)


# register all lowering at init
Expand Down Expand Up @@ -194,6 +193,24 @@ def masked_scalar_is_null_impl(context, builder, sig, args):
return builder.load(result)


# Main kernel always calls `pack_return` on whatever the user defined
# function returned. This returns the same data if its already a `Masked`
# else packs it up into a new one that is valid from the get go
@cuda_lower(api.pack_return, MaskedType)
def pack_return_masked_impl(context, builder, sig, args):
return args[0]


@cuda_lower(api.pack_return, types.Boolean)
@cuda_lower(api.pack_return, types.Number)
def pack_return_scalar_impl(context, builder, sig, args):
outdata = cgutils.create_struct_proxy(sig.return_type)(context, builder)
outdata.value = args[0]
outdata.valid = context.get_constant(types.boolean, 1)

return outdata._getvalue()


@cuda_lower(operator.truth, MaskedType)
def masked_scalar_truth_impl(context, builder, sig, args):
indata = cgutils.create_struct_proxy(MaskedType(types.boolean))(
Expand Down Expand Up @@ -253,7 +270,7 @@ def cast_masked_to_masked(context, builder, fromty, toty, val):


# Masked constructor for use in a kernel for testing
@lower_builtin(classes.Masked, types.Number, types.boolean)
@lower_builtin(api.Masked, types.Number, types.boolean)
def masked_constructor(context, builder, sig, args):
ty = sig.return_type
value, valid = args
Expand Down
190 changes: 180 additions & 10 deletions python/cudf/cudf/core/udf/pipeline.py
Original file line number Diff line number Diff line change
@@ -1,28 +1,42 @@
import cachetools
import numpy as np
from numba import cuda
from numba.np import numpy_support
from numba.types import Tuple, boolean, int64, void
from nvtx import annotate

from cudf.core.udf.api import Masked, pack_return
from cudf.core.udf.typing import MaskedType
from cudf.utils import cudautils

libcudf_bitmask_type = numpy_support.from_dtype(np.dtype("int32"))
MASK_BITSIZE = np.dtype("int32").itemsize * 8
precompiled: cachetools.LRUCache = cachetools.LRUCache(maxsize=32)

cuda.jit(device=True)(pack_return)
gmarkall marked this conversation as resolved.
Show resolved Hide resolved


@annotate("NUMBA JIT", color="green", domain="cudf_python")
def compile_masked_udf(func, dtypes):
def get_udf_return_type(func, dtypes):
"""
Generate an inlineable PTX function that will be injected into
a variadic kernel inside libcudf

assume all input types are `MaskedType(input_col.dtype)` and then
compile the requestied PTX function as a function over those types
Get the return type of a masked UDF for a given set of argument dtypes. It
is assumed that a `MaskedType(dtype)` is passed to the function for each
input dtype.
"""
to_compiler_sig = tuple(
MaskedType(arg)
for arg in (numpy_support.from_dtype(np_type) for np_type in dtypes)
)
# Get the inlineable PTX function
ptx, numba_output_type = cudautils.compile_udf(func, to_compiler_sig)
numpy_output_type = numpy_support.as_dtype(numba_output_type.value_type)
# Get the return type. The PTX is also returned by compile_udf, but is not
# needed here.
ptx, output_type = cudautils.compile_udf(func, to_compiler_sig)

return numpy_output_type, ptx
if not isinstance(output_type, MaskedType):
numba_output_type = numpy_support.from_dtype(np.dtype(output_type))
else:
numba_output_type = output_type

return numba_output_type


def nulludf(func):
Expand Down Expand Up @@ -50,3 +64,159 @@ def wrapper(*args):
return to_udf_table._apply(func)

return wrapper


def masked_array_type_from_col(col):
"""
Return a type representing a tuple of arrays,
the first element an array of the numba type
corresponding to `dtype`, and the second an
array of bools representing a mask.
"""
nb_scalar_ty = numpy_support.from_dtype(col.dtype)
if col.mask is None:
return nb_scalar_ty[::1]
else:
return Tuple((nb_scalar_ty[::1], libcudf_bitmask_type[::1]))


def construct_signature(df, return_type):
"""
Build the signature of numba types that will be used to
actually JIT the kernel itself later, accounting for types
and offsets
"""

# Tuple of arrays, first the output data array, then the mask
return_type = Tuple((return_type[::1], boolean[::1]))
offsets = []
sig = [return_type]
for col in df._data.values():
sig.append(masked_array_type_from_col(col))
offsets.append(int64)

# return_type + data,masks + offsets + size
sig = void(*(sig + offsets + [int64]))

return sig


@cuda.jit(device=True)
def mask_get(mask, pos):
return (mask[pos // MASK_BITSIZE] >> (pos % MASK_BITSIZE)) & 1


kernel_template = """\
def _kernel(retval, {input_columns}, {input_offsets}, size):
i = cuda.grid(1)
ret_data_arr, ret_mask_arr = retval
if i < size:
{masked_input_initializers}
ret = {user_udf_call}
ret_masked = pack_return(ret)
ret_data_arr[i] = ret_masked.value
ret_mask_arr[i] = ret_masked.valid
"""

unmasked_input_initializer_template = """\
d_{idx} = input_col_{idx}
masked_{idx} = Masked(d_{idx}[i], True)
"""

masked_input_initializer_template = """\
d_{idx}, m_{idx} = input_col_{idx}
masked_{idx} = Masked(d_{idx}[i], mask_get(m_{idx}, i + offset_{idx}))
"""


def _define_function(df, scalar_return=False):
# Create argument list for kernel
input_columns = ", ".join(
[f"input_col_{i}" for i in range(len(df.columns))]
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

.columns is expensive -- use ._data instead (bonus: works for Series as well)

Copy link
Contributor Author

@brandon-b-miller brandon-b-miller Sep 13, 2021

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Fixed and migrated to frame

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Does that mean that this code is now dead? Or is an update here needed too?

)

input_offsets = ", ".join([f"offset_{i}" for i in range(len(df._data))])

# Create argument list to pass to device function
args = ", ".join([f"masked_{i}" for i in range(len(df._data))])
user_udf_call = f"f_({args})"

# Generate the initializers for each device function argument
initializers = []
for i, col in enumerate(df._data.values()):
idx = str(i)
if col.mask is not None:
template = masked_input_initializer_template
else:
template = unmasked_input_initializer_template

initializer = template.format(idx=idx)

initializers.append(initializer)

masked_input_initializers = "\n".join(initializers)

# Incorporate all of the above into the kernel code template
d = {
"input_columns": input_columns,
"input_offsets": input_offsets,
"masked_input_initializers": masked_input_initializers,
"user_udf_call": user_udf_call,
}

return kernel_template.format(**d)


@annotate("UDF COMPILATION", color="darkgreen", domain="cudf_python")
def compile_or_get(df, f):
"""
Return a compiled kernel in terms of MaskedTypes that launches a
kernel equivalent of `f` for the dtypes of `df`. The kernel uses
a thread for each row and calls `f` using that rows data / mask
to produce an output value and output valdity for each row.

If the UDF has already been compiled for this requested dtypes,
a cached version will be returned instead of running compilation.

"""

# check to see if we already compiled this function
cache_key = (
*cudautils.make_cache_key(f, tuple(df.dtypes)),
*(col.mask is None for col in df._data.values()),
)
if precompiled.get(cache_key) is not None:
kernel, scalar_return_type = precompiled[cache_key]
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Could you just return precompiled[cache_key] here to save the rest of the function living in an else block?

else:

numba_return_type = get_udf_return_type(f, df.dtypes)
_is_scalar_return = not isinstance(numba_return_type, MaskedType)
scalar_return_type = (
numba_return_type
if _is_scalar_return
else numba_return_type.value_type
)

sig = construct_signature(df, scalar_return_type)
f_ = cuda.jit(device=True)(f)

# Dict of 'local' variables into which `_kernel` is defined
local_exec_context = {}
global_exec_context = {
"f_": f_,
"cuda": cuda,
"Masked": Masked,
"mask_get": mask_get,
"pack_return": pack_return,
}
exec(
_define_function(df, scalar_return=_is_scalar_return),
global_exec_context,
local_exec_context,
)
# The python function definition representing the kernel
_kernel = local_exec_context["_kernel"]
kernel = cuda.jit(sig)(_kernel)
precompiled[cache_key] = (kernel, scalar_return_type)
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Could you cache (kernel, numpy_support.as_dtype(scalar_return_type) so that you don't need to call as_dtype on the scalar_return_type each time it's returned?


return kernel, numpy_support.as_dtype(scalar_return_type)
Loading