-
Notifications
You must be signed in to change notification settings - Fork 908
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Pure-python masked UDFs #9174
Pure-python masked UDFs #9174
Changes from 38 commits
8d98b73
45c47ee
c2e0218
69f92cf
b636af9
bdf5823
470d25e
e271ce4
0a971f0
2f7e6f8
13a94cb
b2a68e6
5cb75e7
49d9978
2ba8bd2
11b2fd1
7379fe1
775dd57
04c38e6
7a01bdb
627d197
d3e2e0b
394fad3
306f5e1
05adec7
edbae6c
e224bee
a446b75
b54e11e
16406ff
ba2d898
30d6013
a369641
e51d780
3c0c76f
6deb96a
51b4fc9
4249334
9f3c60e
71c71b8
b0580e9
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -1,28 +1,42 @@ | ||
import cachetools | ||
import numpy as np | ||
from numba import cuda | ||
from numba.np import numpy_support | ||
from numba.types import Tuple, boolean, int64, void | ||
from nvtx import annotate | ||
|
||
from cudf.core.udf.api import Masked, pack_return | ||
from cudf.core.udf.typing import MaskedType | ||
from cudf.utils import cudautils | ||
|
||
libcudf_bitmask_type = numpy_support.from_dtype(np.dtype("int32")) | ||
MASK_BITSIZE = np.dtype("int32").itemsize * 8 | ||
precompiled: cachetools.LRUCache = cachetools.LRUCache(maxsize=32) | ||
|
||
cuda.jit(device=True)(pack_return) | ||
gmarkall marked this conversation as resolved.
Show resolved
Hide resolved
|
||
|
||
|
||
@annotate("NUMBA JIT", color="green", domain="cudf_python") | ||
def compile_masked_udf(func, dtypes): | ||
def get_udf_return_type(func, dtypes): | ||
""" | ||
Generate an inlineable PTX function that will be injected into | ||
a variadic kernel inside libcudf | ||
|
||
assume all input types are `MaskedType(input_col.dtype)` and then | ||
compile the requestied PTX function as a function over those types | ||
Get the return type of a masked UDF for a given set of argument dtypes. It | ||
is assumed that a `MaskedType(dtype)` is passed to the function for each | ||
input dtype. | ||
""" | ||
to_compiler_sig = tuple( | ||
MaskedType(arg) | ||
for arg in (numpy_support.from_dtype(np_type) for np_type in dtypes) | ||
) | ||
# Get the inlineable PTX function | ||
ptx, numba_output_type = cudautils.compile_udf(func, to_compiler_sig) | ||
numpy_output_type = numpy_support.as_dtype(numba_output_type.value_type) | ||
# Get the return type. The PTX is also returned by compile_udf, but is not | ||
# needed here. | ||
ptx, output_type = cudautils.compile_udf(func, to_compiler_sig) | ||
|
||
return numpy_output_type, ptx | ||
if not isinstance(output_type, MaskedType): | ||
numba_output_type = numpy_support.from_dtype(np.dtype(output_type)) | ||
else: | ||
numba_output_type = output_type | ||
|
||
return numba_output_type | ||
|
||
|
||
def nulludf(func): | ||
|
@@ -50,3 +64,159 @@ def wrapper(*args): | |
return to_udf_table._apply(func) | ||
|
||
return wrapper | ||
|
||
|
||
def masked_array_type_from_col(col): | ||
""" | ||
Return a type representing a tuple of arrays, | ||
the first element an array of the numba type | ||
corresponding to `dtype`, and the second an | ||
array of bools representing a mask. | ||
""" | ||
nb_scalar_ty = numpy_support.from_dtype(col.dtype) | ||
if col.mask is None: | ||
return nb_scalar_ty[::1] | ||
else: | ||
return Tuple((nb_scalar_ty[::1], libcudf_bitmask_type[::1])) | ||
|
||
|
||
def construct_signature(df, return_type): | ||
""" | ||
Build the signature of numba types that will be used to | ||
actually JIT the kernel itself later, accounting for types | ||
and offsets | ||
""" | ||
|
||
# Tuple of arrays, first the output data array, then the mask | ||
return_type = Tuple((return_type[::1], boolean[::1])) | ||
offsets = [] | ||
sig = [return_type] | ||
for col in df._data.values(): | ||
sig.append(masked_array_type_from_col(col)) | ||
offsets.append(int64) | ||
|
||
# return_type + data,masks + offsets + size | ||
sig = void(*(sig + offsets + [int64])) | ||
|
||
return sig | ||
|
||
|
||
@cuda.jit(device=True) | ||
def mask_get(mask, pos): | ||
return (mask[pos // MASK_BITSIZE] >> (pos % MASK_BITSIZE)) & 1 | ||
|
||
|
||
kernel_template = """\ | ||
def _kernel(retval, {input_columns}, {input_offsets}, size): | ||
i = cuda.grid(1) | ||
ret_data_arr, ret_mask_arr = retval | ||
if i < size: | ||
{masked_input_initializers} | ||
ret = {user_udf_call} | ||
ret_masked = pack_return(ret) | ||
ret_data_arr[i] = ret_masked.value | ||
ret_mask_arr[i] = ret_masked.valid | ||
""" | ||
|
||
unmasked_input_initializer_template = """\ | ||
d_{idx} = input_col_{idx} | ||
masked_{idx} = Masked(d_{idx}[i], True) | ||
""" | ||
|
||
masked_input_initializer_template = """\ | ||
d_{idx}, m_{idx} = input_col_{idx} | ||
masked_{idx} = Masked(d_{idx}[i], mask_get(m_{idx}, i + offset_{idx})) | ||
""" | ||
|
||
|
||
def _define_function(df, scalar_return=False): | ||
# Create argument list for kernel | ||
input_columns = ", ".join( | ||
[f"input_col_{i}" for i in range(len(df.columns))] | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more.
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Fixed and migrated to frame There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Does that mean that this code is now dead? Or is an update here needed too? |
||
) | ||
|
||
input_offsets = ", ".join([f"offset_{i}" for i in range(len(df._data))]) | ||
|
||
# Create argument list to pass to device function | ||
args = ", ".join([f"masked_{i}" for i in range(len(df._data))]) | ||
user_udf_call = f"f_({args})" | ||
|
||
# Generate the initializers for each device function argument | ||
initializers = [] | ||
for i, col in enumerate(df._data.values()): | ||
idx = str(i) | ||
if col.mask is not None: | ||
template = masked_input_initializer_template | ||
else: | ||
template = unmasked_input_initializer_template | ||
|
||
initializer = template.format(idx=idx) | ||
|
||
initializers.append(initializer) | ||
|
||
masked_input_initializers = "\n".join(initializers) | ||
|
||
# Incorporate all of the above into the kernel code template | ||
d = { | ||
"input_columns": input_columns, | ||
"input_offsets": input_offsets, | ||
"masked_input_initializers": masked_input_initializers, | ||
"user_udf_call": user_udf_call, | ||
} | ||
|
||
return kernel_template.format(**d) | ||
|
||
|
||
@annotate("UDF COMPILATION", color="darkgreen", domain="cudf_python") | ||
def compile_or_get(df, f): | ||
""" | ||
Return a compiled kernel in terms of MaskedTypes that launches a | ||
kernel equivalent of `f` for the dtypes of `df`. The kernel uses | ||
a thread for each row and calls `f` using that rows data / mask | ||
to produce an output value and output valdity for each row. | ||
|
||
If the UDF has already been compiled for this requested dtypes, | ||
a cached version will be returned instead of running compilation. | ||
|
||
""" | ||
|
||
# check to see if we already compiled this function | ||
cache_key = ( | ||
*cudautils.make_cache_key(f, tuple(df.dtypes)), | ||
*(col.mask is None for col in df._data.values()), | ||
) | ||
if precompiled.get(cache_key) is not None: | ||
kernel, scalar_return_type = precompiled[cache_key] | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Could you just return |
||
else: | ||
|
||
numba_return_type = get_udf_return_type(f, df.dtypes) | ||
_is_scalar_return = not isinstance(numba_return_type, MaskedType) | ||
scalar_return_type = ( | ||
numba_return_type | ||
if _is_scalar_return | ||
else numba_return_type.value_type | ||
) | ||
|
||
sig = construct_signature(df, scalar_return_type) | ||
f_ = cuda.jit(device=True)(f) | ||
|
||
# Dict of 'local' variables into which `_kernel` is defined | ||
local_exec_context = {} | ||
global_exec_context = { | ||
"f_": f_, | ||
"cuda": cuda, | ||
"Masked": Masked, | ||
"mask_get": mask_get, | ||
"pack_return": pack_return, | ||
} | ||
exec( | ||
_define_function(df, scalar_return=_is_scalar_return), | ||
global_exec_context, | ||
local_exec_context, | ||
) | ||
# The python function definition representing the kernel | ||
_kernel = local_exec_context["_kernel"] | ||
kernel = cuda.jit(sig)(_kernel) | ||
precompiled[cache_key] = (kernel, scalar_return_type) | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Could you cache |
||
|
||
return kernel, numpy_support.as_dtype(scalar_return_type) |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Can we make these imports absolute to be consistent with the rest of the code-base?