Skip to content

Commit

Permalink
Pure-python masked UDFs (#9174)
Browse files Browse the repository at this point in the history
Replaces C++ implementation of masked UDF pipeline with a pure python version which compiles and launches the entire kernel using numba. This solves a bunch of problems:

- CUDA 11.0 support is now available since the impl no longer needs `cuda::std::tuple` to work with NVRTC 11.0. 
- Support for special functions which compile to multiple function definitions, such as `pow`, `sin`, and `cos` is now provided since all the PTX is compiled and linked inside numba (Fixes #8470)
- Allows us to support the following corner case, which would have required a separate C++ kernel in the previous implementation:
```python
def f(x):
    return 42
```

- Makes developing/adding features to the impl much easier

Authors:
  - https://github.com/brandon-b-miller

Approvers:
  - Robert Maynard (https://github.com/robertmaynard)
  - GALI PREM SAGAR (https://github.com/galipremsagar)
  - Graham Markall (https://github.com/gmarkall)
  - Ashwin Srinath (https://github.com/shwina)

URL: #9174
  • Loading branch information
brandon-b-miller authored Sep 29, 2021
1 parent fdb9e3b commit f9ce870
Show file tree
Hide file tree
Showing 9 changed files with 320 additions and 71 deletions.
6 changes: 0 additions & 6 deletions python/cudf/cudf/core/dataframe.py
Original file line number Diff line number Diff line change
Expand Up @@ -4866,12 +4866,6 @@ def apply(
runtime compilation features
"""

# libcudacxx tuples are not compatible with nvrtc 11.0
runtime = cuda.cudadrv.runtime.Runtime()
mjr, mnr = runtime.get_version()
if mjr < 11 or (mjr == 11 and mnr < 1):
raise RuntimeError("DataFrame.apply requires CUDA 11.1+")

for dtype in self.dtypes:
if (
isinstance(dtype, cudf.core.dtypes._BaseDtype)
Expand Down
26 changes: 23 additions & 3 deletions python/cudf/cudf/core/frame.py
Original file line number Diff line number Diff line change
Expand Up @@ -46,6 +46,7 @@
)
from cudf.core.column_accessor import ColumnAccessor
from cudf.core.join import merge
from cudf.core.udf.pipeline import compile_or_get
from cudf.core.window import Rolling
from cudf.utils import ioutils
from cudf.utils.docutils import copy_docstring
Expand Down Expand Up @@ -1455,10 +1456,29 @@ def _apply(self, func):
"""
Apply `func` across the rows of the frame.
"""
output_dtype, ptx = cudf.core.udf.pipeline.compile_masked_udf(
func, self.dtypes
kernel, retty = compile_or_get(self, func)

# Mask and data column preallocated
ans_col = cupy.empty(len(self), dtype=retty)
ans_mask = cudf.core.column.column_empty(len(self), dtype="bool")
launch_args = [(ans_col, ans_mask)]
offsets = []
for col in self._data.values():
data = col.data
mask = col.mask
if mask is None:
launch_args.append(data)
else:
launch_args.append((data, mask))
offsets.append(col.offset)
launch_args += offsets
launch_args.append(len(self)) # size
kernel.forall(len(self))(*launch_args)

result = cudf.Series(ans_col).set_mask(
libcudf.transform.bools_to_mask(ans_mask)
)
result = cudf._lib.transform.masked_udf(self, ptx, output_dtype)

return result

def rank(
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -14,3 +14,11 @@ class Masked:
def __init__(self, value, valid):
self.value = value
self.valid = valid


def pack_return(masked_or_scalar):
    """
    Placeholder that kernels call on the value returned by a UDF.

    The Python body intentionally does nothing: the function only
    exists so that typing and lowering have a symbol to attach to,
    and the call is replaced by the lowered implementation.
    """
    return None
30 changes: 23 additions & 7 deletions python/cudf/cudf/core/udf/lowering.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,11 +9,10 @@
)
from numba.extending import lower_builtin, types

from cudf.core.udf import api
from cudf.core.udf._ops import arith_ops, comparison_ops
from cudf.core.udf.typing import MaskedType, NAType

from . import classes
from ._ops import arith_ops, comparison_ops


@cuda_lowering_registry.lower_constant(NAType)
def constant_na(context, builder, ty, pyval):
Expand Down Expand Up @@ -154,9 +153,8 @@ def register_const_op(op):
to_lower_op = make_const_op(op)
cuda_lower(op, MaskedType, types.Number)(to_lower_op)
cuda_lower(op, types.Number, MaskedType)(to_lower_op)

# to_lower_op_reflected = make_reflected_const_op(op)
# cuda_lower(op, types.Number, MaskedType)(to_lower_op_reflected)
cuda_lower(op, MaskedType, types.Boolean)(to_lower_op)
cuda_lower(op, types.Boolean, MaskedType)(to_lower_op)


# register all lowering at init
Expand Down Expand Up @@ -194,6 +192,24 @@ def masked_scalar_is_null_impl(context, builder, sig, args):
return builder.load(result)


# The main kernel always calls `pack_return` on whatever the user-defined
# function returned. This returns the same data if it's already a `Masked`;
# otherwise it packs the value into a new one that is valid from the get-go.
@cuda_lower(api.pack_return, MaskedType)
def pack_return_masked_impl(context, builder, sig, args):
    """
    Lower `pack_return` for an argument that is already a `MaskedType`.

    Nothing needs to be packed, so the single incoming value is
    handed back unchanged.
    """
    already_masked = args[0]
    return already_masked


@cuda_lower(api.pack_return, types.Boolean)
@cuda_lower(api.pack_return, types.Number)
def pack_return_scalar_impl(context, builder, sig, args):
    """
    Lower `pack_return` for a plain boolean or numeric argument.

    A struct of the signature's return type is created, its `value`
    is set to the incoming scalar and its `valid` flag is set to a
    constant true.
    """
    proxy_cls = cgutils.create_struct_proxy(sig.return_type)
    packed = proxy_cls(context, builder)

    scalar = args[0]
    packed.value = scalar
    # A bare scalar carries no null information, so it is always valid
    packed.valid = context.get_constant(types.boolean, 1)

    return packed._getvalue()


@cuda_lower(operator.truth, MaskedType)
def masked_scalar_truth_impl(context, builder, sig, args):
indata = cgutils.create_struct_proxy(MaskedType(types.boolean))(
Expand Down Expand Up @@ -253,7 +269,7 @@ def cast_masked_to_masked(context, builder, fromty, toty, val):


# Masked constructor for use in a kernel for testing
@lower_builtin(classes.Masked, types.Number, types.boolean)
@lower_builtin(api.Masked, types.Number, types.boolean)
def masked_constructor(context, builder, sig, args):
ty = sig.return_type
value, valid = args
Expand Down
187 changes: 177 additions & 10 deletions python/cudf/cudf/core/udf/pipeline.py
Original file line number Diff line number Diff line change
@@ -1,28 +1,40 @@
import cachetools
import numpy as np
from numba import cuda
from numba.np import numpy_support
from numba.types import Tuple, boolean, int64, void
from nvtx import annotate

from cudf.core.udf.api import Masked, pack_return
from cudf.core.udf.typing import MaskedType
from cudf.utils import cudautils

libcudf_bitmask_type = numpy_support.from_dtype(np.dtype("int32"))
MASK_BITSIZE = np.dtype("int32").itemsize * 8
precompiled: cachetools.LRUCache = cachetools.LRUCache(maxsize=32)


@annotate("NUMBA JIT", color="green", domain="cudf_python")
def compile_masked_udf(func, dtypes):
def get_udf_return_type(func, dtypes):
"""
Generate an inlineable PTX function that will be injected into
a variadic kernel inside libcudf
assume all input types are `MaskedType(input_col.dtype)` and then
compile the requested PTX function as a function over those types
Get the return type of a masked UDF for a given set of argument dtypes. It
is assumed that a `MaskedType(dtype)` is passed to the function for each
input dtype.
"""
to_compiler_sig = tuple(
MaskedType(arg)
for arg in (numpy_support.from_dtype(np_type) for np_type in dtypes)
)
# Get the inlineable PTX function
ptx, numba_output_type = cudautils.compile_udf(func, to_compiler_sig)
numpy_output_type = numpy_support.as_dtype(numba_output_type.value_type)
# Get the return type. The PTX is also returned by compile_udf, but is not
# needed here.
ptx, output_type = cudautils.compile_udf(func, to_compiler_sig)

if not isinstance(output_type, MaskedType):
numba_output_type = numpy_support.from_dtype(np.dtype(output_type))
else:
numba_output_type = output_type

return numpy_output_type, ptx
return numba_output_type


def nulludf(func):
Expand Down Expand Up @@ -50,3 +62,158 @@ def wrapper(*args):
return to_udf_table._apply(func)

return wrapper


def masked_array_type_from_col(col):
    """
    Return the numba type used to pass `col` into the kernel.

    For a column without a mask this is a contiguous array type of
    the numba scalar type corresponding to `col.dtype`. For a column
    with a mask it is a tuple of two array types: that same data
    array type, followed by an array of `libcudf_bitmask_type`
    holding the mask.
    """
    data_array_ty = numpy_support.from_dtype(col.dtype)[::1]
    if col.mask is not None:
        return Tuple((data_array_ty, libcudf_bitmask_type[::1]))
    return data_array_ty


def construct_signature(df, return_type):
    """
    Build the numba signature used to JIT the kernel for `df`.

    The resulting `void(...)` signature lists, in order: the output
    (a tuple of a `return_type` data array and a boolean array), one
    entry per column of `df` as produced by
    `masked_array_type_from_col`, one `int64` offset per column, and
    a final `int64` for the size.
    """
    # Output slot: data array first, then the boolean validity array
    output_ty = Tuple((return_type[::1], boolean[::1]))

    column_tys = [
        masked_array_type_from_col(col) for col in df._data.values()
    ]
    # Every column contributes exactly one offset argument
    offset_tys = [int64] * len(column_tys)

    # output + data/masks + offsets + size
    return void(output_ty, *column_tys, *offset_tys, int64)


@cuda.jit(device=True)
def mask_get(mask, pos):
    """
    Return bit number `pos` (0 or 1) of the bitmask array `mask`,
    where each element of `mask` holds `MASK_BITSIZE` bits.
    """
    word = mask[pos // MASK_BITSIZE]
    bit_in_word = pos % MASK_BITSIZE
    return (word >> bit_in_word) & 1


# Source text of the `_kernel` function. `_define_function` fills in the
# placeholders with `str.format` and `compile_or_get` runs the result
# through `exec`, so whitespace inside these literals is significant.
# Each thread takes index `i = cuda.grid(1)` and, when `i < size`, builds
# the masked inputs, calls the user UDF, normalizes the result with
# `pack_return` and writes its value / validity to the output arrays.
kernel_template = """\
def _kernel(retval, {input_columns}, {input_offsets}, size):
i = cuda.grid(1)
ret_data_arr, ret_mask_arr = retval
if i < size:
{masked_input_initializers}
ret = {user_udf_call}
ret_masked = pack_return(ret)
ret_data_arr[i] = ret_masked.value
ret_mask_arr[i] = ret_masked.valid
"""

# Per-column setup for a column passed without a mask: the kernel argument
# is the data array itself and element `i` is wrapped as a `Masked` whose
# validity is the constant `True`.
unmasked_input_initializer_template = """\
d_{idx} = input_col_{idx}
masked_{idx} = Masked(d_{idx}[i], True)
"""

# Per-column setup for a column passed with a mask: the kernel argument is
# unpacked into data and mask, and the validity of element `i` is read
# with `mask_get` at bit position `i + offset_{idx}`.
masked_input_initializer_template = """\
d_{idx}, m_{idx} = input_col_{idx}
masked_{idx} = Masked(d_{idx}[i], mask_get(m_{idx}, i + offset_{idx}))
"""


def _define_function(df, scalar_return=False):
    """
    Render the Python source of the `_kernel` function for `df`.

    One `input_col_<n>` and one `offset_<n>` kernel parameter is
    generated per column, each column gets an initializer chosen by
    whether it has a mask, and the UDF is called as `f_` with one
    `masked_<n>` argument per column. The pieces are substituted into
    `kernel_template` and the resulting text is returned.

    `scalar_return` is accepted but not used by this function.
    """
    ncols = len(df._data)

    # Parameter lists of the kernel and the argument list of the UDF call
    column_params = ", ".join(f"input_col_{n}" for n in range(ncols))
    offset_params = ", ".join(f"offset_{n}" for n in range(ncols))
    udf_args = ", ".join(f"masked_{n}" for n in range(ncols))

    # One initializer per column, picked by presence of a mask
    per_column_setup = "\n".join(
        (
            unmasked_input_initializer_template
            if col.mask is None
            else masked_input_initializer_template
        ).format(idx=n)
        for n, col in enumerate(df._data.values())
    )

    return kernel_template.format(
        input_columns=column_params,
        input_offsets=offset_params,
        masked_input_initializers=per_column_setup,
        user_udf_call=f"f_({udf_args})",
    )


@annotate("UDF COMPILATION", color="darkgreen", domain="cudf_python")
def compile_or_get(df, f):
    """
    Return a compiled kernel for `f` over the columns of `df`, together
    with the numpy dtype of the values it produces.

    The kernel is expressed in terms of MaskedTypes: it uses one thread
    per row and calls `f` with that row's data and mask to produce an
    output value and an output validity for the row. Results are kept
    in the `precompiled` cache, keyed on the function, the dtypes of
    `df` and which columns have no mask, so a repeated request returns
    the cached pair instead of compiling again.
    """
    # Key on the UDF + dtypes as well as the mask layout of the columns,
    # since the generated kernel differs for masked and unmasked inputs
    cache_key = (
        *cudautils.make_cache_key(f, tuple(df.dtypes)),
        *(col.mask is None for col in df._data.values()),
    )
    cached = precompiled.get(cache_key)
    if cached is not None:
        kernel, np_return_dtype = cached
        return kernel, np_return_dtype

    udf_return_type = get_udf_return_type(f, df.dtypes)
    returns_scalar = not isinstance(udf_return_type, MaskedType)
    if returns_scalar:
        value_type = udf_return_type
    else:
        # Masked results are stored by their underlying value type
        value_type = udf_return_type.value_type

    kernel_sig = construct_signature(df, value_type)
    device_udf = cuda.jit(device=True)(f)

    # Names the generated kernel source refers to
    exec_globals = {
        "f_": device_udf,
        "cuda": cuda,
        "Masked": Masked,
        "mask_get": mask_get,
        "pack_return": pack_return,
    }
    # Receives the `_kernel` definition produced by the exec below
    exec_locals = {}
    kernel_source = _define_function(df, scalar_return=returns_scalar)
    exec(kernel_source, exec_globals, exec_locals)

    # Compile the plain Python kernel definition for the signature
    kernel = cuda.jit(kernel_sig)(exec_locals["_kernel"])
    np_return_dtype = numpy_support.as_dtype(value_type)
    precompiled[cache_key] = (kernel, np_return_dtype)

    return kernel, np_return_dtype
Loading

0 comments on commit f9ce870

Please sign in to comment.