Generalized null support in user defined functions (#8213)
**Draft**

- Adds `DataFrame.apply`, similar to pandas
- Adds support for automatically incorporating the validity of the operand columns into the computation of the result
- Adds support for referring to `cudf.NA` explicitly inside user defined functions

This PR creates the following API:

```python
import cudf

@nulludf  # marks the UDF for null-aware compilation
def func_gdf(x, y):
    if x is cudf.NA:
        return y
    else:
        return x + y


gdf = cudf.DataFrame({
    'a': [1, None, 3, None],
    'b': [4, 5, None, None]
})
gdf.apply(lambda row: func_gdf(row['a'], row['b']), axis=1)

# 0       5
# 1       5
# 2    <NA>
# 3    <NA>
# dtype: int64
```
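For readers tracing the implementation: each operand reaches the JIT kernel as a value paired with a validity flag (the `Masked<T>` struct in `masked_udf_kernel.cu` below). The following host-side C++ sketch mirrors the semantics of the example above under that representation; `func_gdf` and the data come from the example, everything else is illustrative and not cudf API:

```cpp
#include <cstdint>
#include <iostream>
#include <vector>

// Host-side analogue of the device-side Masked<T> struct: a value
// paired with a validity flag, standing in for cudf's null semantics.
template <typename T>
struct Masked {
  T value;
  bool valid;
};

// The UDF from the example: if x is NA, return y; otherwise x + y.
Masked<int64_t> func_gdf(Masked<int64_t> x, Masked<int64_t> y)
{
  if (!x.valid) { return y; }           // x is NA -> result is y (with y's validity)
  if (!y.valid) { return {0, false}; }  // x + NA -> NA
  return {x.value + y.value, true};
}

int main()
{
  // Columns 'a' and 'b' from the example: {1, NA, 3, NA} and {4, 5, NA, NA}.
  std::vector<Masked<int64_t>> a{{1, true}, {0, false}, {3, true}, {0, false}};
  std::vector<Masked<int64_t>> b{{4, true}, {5, true}, {0, false}, {0, false}};

  for (std::size_t i = 0; i < a.size(); ++i) {
    auto out = func_gdf(a[i], b[i]);
    if (out.valid) { std::cout << i << "  " << out.value << '\n'; }
    else           { std::cout << i << "  <NA>\n"; }
  }
  // Prints 5, 5, <NA>, <NA>, matching the DataFrame.apply output above.
}
```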

Authors:
  - https://github.com/brandon-b-miller
  - Graham Markall (https://github.com/gmarkall)

Approvers:
  - Robert Maynard (https://github.com/robertmaynard)
  - Mike Wilson (https://github.com/hyperbolic2346)
  - Michael Wang (https://github.com/isVoid)
  - GALI PREM SAGAR (https://github.com/galipremsagar)
  - Nghia Truong (https://github.com/ttnghia)

URL: #8213
brandon-b-miller authored Jul 16, 2021
1 parent 18f7c01 commit 7ff4724
Showing 18 changed files with 1,672 additions and 8 deletions.
1 change: 1 addition & 0 deletions cpp/cmake/Modules/JitifyPreprocessKernels.cmake
@@ -56,6 +56,7 @@ endfunction()

jit_preprocess_files(SOURCE_DIRECTORY ${CUDF_SOURCE_DIR}/src
FILES binaryop/jit/kernel.cu
transform/jit/masked_udf_kernel.cu
transform/jit/kernel.cu
rolling/jit/kernel.cu
)
6 changes: 6 additions & 0 deletions cpp/include/cudf/transform.hpp
@@ -53,6 +53,12 @@ std::unique_ptr<column> transform(
bool is_ptx,
rmm::mr::device_memory_resource* mr = rmm::mr::get_current_device_resource());

std::unique_ptr<column> generalized_masked_op(
table_view const& data_view,
std::string const& binary_udf,
data_type output_type,
rmm::mr::device_memory_resource* mr = rmm::mr::get_current_device_resource());

/**
* @brief Creates a null_mask from `input` by converting `NaN` to null and
* preserving existing null values and also returns new null_count.
85 changes: 85 additions & 0 deletions cpp/src/transform/jit/masked_udf_kernel.cu
@@ -0,0 +1,85 @@
/*
* Copyright (c) 2021, NVIDIA CORPORATION.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/

#include <cstddef>
#include <cstdint>
#include <transform/jit/operation-udf.hpp>

#include <cudf/types.hpp>
#include <cudf/utilities/bit.hpp>

#include <cuda/std/climits>
#include <cuda/std/cstddef>
#include <cuda/std/limits>
#include <cuda/std/tuple>
#include <cuda/std/type_traits>

namespace cudf {
namespace transformation {
namespace jit {

template <typename T>
struct Masked {
T value;
bool valid;
};

template <typename TypeIn, typename MaskType, typename OffsetType>
__device__ auto make_args(cudf::size_type id, TypeIn in_ptr, MaskType in_mask, OffsetType in_offset)
{
bool valid = in_mask ? cudf::bit_is_set(in_mask, in_offset + id) : true;
return cuda::std::make_tuple(in_ptr[id], valid);
}

template <typename InType, typename MaskType, typename OffsetType, typename... Arguments>
__device__ auto make_args(cudf::size_type id,
InType in_ptr,
MaskType in_mask, // in practice, always cudf::bitmask_type const*
OffsetType in_offset, // in practice, always cudf::size_type
Arguments... args)
{
bool valid = in_mask ? cudf::bit_is_set(in_mask, in_offset + id) : true;
return cuda::std::tuple_cat(cuda::std::make_tuple(in_ptr[id], valid), make_args(id, args...));
}

template <typename TypeOut, typename... Arguments>
__global__ void generic_udf_kernel(cudf::size_type size,
TypeOut* out_data,
bool* out_mask,
Arguments... args)
{
int const tid = threadIdx.x;
int const blkid = blockIdx.x;
int const blksz = blockDim.x;
int const gridsz = gridDim.x;
int const start = tid + blkid * blksz;
int const step = blksz * gridsz;

Masked<TypeOut> output;
for (cudf::size_type i = start; i < size; i += step) {
auto func_args = cuda::std::tuple_cat(
cuda::std::make_tuple(&output.value),
make_args(i, args...) // passed int64*, bool*, int64, int64*, bool*, int64
);
cuda::std::apply(GENERIC_OP, func_args);
out_data[i] = output.value;
out_mask[i] = output.valid;
}
}

} // namespace jit
} // namespace transformation
} // namespace cudf
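The variadic `make_args` recursion above consumes the kernel's flat argument list three entries at a time (data pointer, mask pointer, and offset for each column) and concatenates one `(value, valid)` pair per column into a single tuple, which `cuda::std::apply` then unpacks into the UDF. Below is a host-side sketch of the same recursion using `std::tuple`; it substitutes a simplified one-byte-per-row mask for cudf's packed bitmask and is illustrative only:

```cpp
#include <cstdint>
#include <iostream>
#include <tuple>

// Base case: one column's (pointer, mask, offset) triple becomes a (value, valid) pair.
template <typename T>
auto make_args(std::size_t id, T const* data, uint8_t const* mask, int64_t offset)
{
  bool valid = mask ? (mask[offset + id] != 0) : true;  // simplified: one byte per row
  return std::make_tuple(data[id], valid);
}

// Recursive case: peel off one column's triple, recurse on the rest, concatenate.
template <typename T, typename... Rest>
auto make_args(std::size_t id, T const* data, uint8_t const* mask, int64_t offset, Rest... rest)
{
  bool valid = mask ? (mask[offset + id] != 0) : true;
  return std::tuple_cat(std::make_tuple(data[id], valid), make_args(id, rest...));
}

int main()
{
  int64_t a[]      = {1, 0, 3};
  int64_t b[]      = {4, 5, 0};
  uint8_t a_mask[] = {1, 0, 1};
  uint8_t b_mask[] = {1, 1, 0};

  // Row 1: builds the tuple (a[1], a_valid, b[1], b_valid) = (0, false, 5, true).
  auto args = make_args(1, a, a_mask, int64_t{0}, b, b_mask, int64_t{0});
  auto udf  = [](int64_t x, bool xv, int64_t y, bool yv) { return xv ? x + y : y; };
  std::cout << std::apply(udf, args) << '\n';  // prints 5
}
```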
114 changes: 108 additions & 6 deletions cpp/src/transform/transform.cpp
@@ -14,20 +14,22 @@
* limitations under the License.
*/

#include <jit_preprocessed_files/transform/jit/kernel.cu.jit.hpp>

#include <jit/cache.hpp>
#include <jit/parser.hpp>
#include <jit/type.hpp>

#include <cudf/column/column.hpp>
#include <cudf/column/column_factories.hpp>
#include <cudf/detail/nvtx/ranges.hpp>
#include <cudf/detail/transform.hpp>
#include <cudf/null_mask.hpp>
#include <cudf/table/table_view.hpp>
#include <cudf/utilities/traits.hpp>
#include <cudf/utilities/type_dispatcher.hpp>

#include <jit_preprocessed_files/transform/jit/kernel.cu.jit.hpp>
#include <jit_preprocessed_files/transform/jit/masked_udf_kernel.cu.jit.hpp>

#include <jit/cache.hpp>
#include <jit/parser.hpp>
#include <jit/type.hpp>

#include <rmm/cuda_stream_view.hpp>

namespace cudf {
@@ -63,6 +65,80 @@ void unary_operation(mutable_column_view output,
cudf::jit::get_data_ptr(input));
}

std::vector<std::string> make_template_types(column_view outcol_view, table_view const& data_view)
{
std::string mskptr_type =
cudf::jit::get_type_name(cudf::data_type(cudf::type_to_id<cudf::bitmask_type>())) + "*";
std::string offset_type =
cudf::jit::get_type_name(cudf::data_type(cudf::type_to_id<cudf::offset_type>()));

std::vector<std::string> template_types;
template_types.reserve((3 * data_view.num_columns()) + 1);

template_types.push_back(cudf::jit::get_type_name(outcol_view.type()));
for (auto const& col : data_view) {
template_types.push_back(cudf::jit::get_type_name(col.type()) + "*");
template_types.push_back(mskptr_type);
template_types.push_back(offset_type);
}
return template_types;
}

void generalized_operation(table_view const& data_view,
std::string const& udf,
data_type output_type,
mutable_column_view outcol_view,
mutable_column_view outmsk_view,
rmm::cuda_stream_view stream,
rmm::mr::device_memory_resource* mr)
{
auto const template_types = make_template_types(outcol_view, data_view);

std::string generic_kernel_name =
jitify2::reflection::Template("cudf::transformation::jit::generic_udf_kernel")
.instantiate(template_types);

std::string generic_cuda_source = cudf::jit::parse_single_function_ptx(
udf, "GENERIC_OP", cudf::jit::get_type_name(output_type), {0});

std::vector<void*> kernel_args;
kernel_args.reserve((data_view.num_columns() * 3) + 3);

cudf::size_type size = outcol_view.size();
const void* outcol_ptr = cudf::jit::get_data_ptr(outcol_view);
const void* outmsk_ptr = cudf::jit::get_data_ptr(outmsk_view);
kernel_args.insert(kernel_args.begin(), {&size, &outcol_ptr, &outmsk_ptr});

std::vector<const void*> data_ptrs;
std::vector<cudf::bitmask_type const*> mask_ptrs;
std::vector<cudf::offset_type> offsets;

data_ptrs.reserve(data_view.num_columns());
mask_ptrs.reserve(data_view.num_columns());
offsets.reserve(data_view.num_columns());

auto const iters = thrust::make_zip_iterator(
thrust::make_tuple(data_ptrs.begin(), mask_ptrs.begin(), offsets.begin()));

std::for_each(iters, iters + data_view.num_columns(), [&](auto const& tuple_vals) {
kernel_args.push_back(&thrust::get<0>(tuple_vals));
kernel_args.push_back(&thrust::get<1>(tuple_vals));
kernel_args.push_back(&thrust::get<2>(tuple_vals));
});

std::transform(data_view.begin(), data_view.end(), iters, [&](column_view const& col) {
return thrust::make_tuple(cudf::jit::get_data_ptr(col), col.null_mask(), col.offset());
});

cudf::jit::get_program_cache(*transform_jit_masked_udf_kernel_cu_jit)
.get_kernel(generic_kernel_name,
{},
{{"transform/jit/operation-udf.hpp", generic_cuda_source}},
{"-arch=sm_."})
->configure_1d_max_occupancy(0, 0, 0, stream.value())
->launch(kernel_args.data());
}

} // namespace jit
} // namespace transformation

@@ -89,6 +165,24 @@ std::unique_ptr<column> transform(column_view const& input,
return output;
}

std::unique_ptr<column> generalized_masked_op(table_view const& data_view,
std::string const& udf,
data_type output_type,
rmm::cuda_stream_view stream,
rmm::mr::device_memory_resource* mr)
{
std::unique_ptr<column> output = make_fixed_width_column(output_type, data_view.num_rows());
std::unique_ptr<column> output_mask =
make_fixed_width_column(cudf::data_type{cudf::type_id::BOOL8}, data_view.num_rows());

transformation::jit::generalized_operation(
data_view, udf, output_type, *output, *output_mask, stream, mr);

auto final_output_mask = cudf::bools_to_mask(*output_mask);
output.get()->set_null_mask(std::move(*(final_output_mask.first)));
return output;
}

} // namespace detail

std::unique_ptr<column> transform(column_view const& input,
@@ -101,4 +195,12 @@ std::unique_ptr<column> transform(column_view const& input,
return detail::transform(input, unary_udf, output_type, is_ptx, rmm::cuda_stream_default, mr);
}

std::unique_ptr<column> generalized_masked_op(table_view const& data_view,
std::string const& udf,
data_type output_type,
rmm::mr::device_memory_resource* mr)
{
return detail::generalized_masked_op(data_view, udf, output_type, rmm::cuda_stream_default, mr);
}

} // namespace cudf
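Note how `generalized_masked_op` gets validity back from the kernel as one `bool` per row (the BOOL8 `output_mask` column) and only then converts it to cudf's packed bitmask via `cudf::bools_to_mask`. A minimal host-side sketch of that packing step (one bit per row, least-significant bit first within each 32-bit word); this is a simplified stand-in for the real device implementation:

```cpp
#include <cstdint>
#include <iostream>
#include <vector>

// Pack one bool per row into 32-bit bitmask words: bit i % 32 of word i / 32
// is set iff row i is valid. Mirrors what cudf::bools_to_mask produces,
// but sequential and host-only.
std::vector<uint32_t> bools_to_mask(std::vector<bool> const& valid)
{
  std::vector<uint32_t> mask((valid.size() + 31) / 32, 0);
  for (std::size_t i = 0; i < valid.size(); ++i) {
    if (valid[i]) { mask[i / 32] |= (uint32_t{1} << (i % 32)); }
  }
  return mask;
}

int main()
{
  // Validity produced by the kernel for the example: {true, true, false, false}.
  auto mask = bools_to_mask({true, true, false, false});
  std::cout << std::hex << mask[0] << '\n';  // 3: rows 0 and 1 valid, rows 2 and 3 null
}
```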
6 changes: 6 additions & 0 deletions python/cudf/cudf/_lib/cpp/transform.pxd
@@ -34,6 +34,12 @@ cdef extern from "cudf/transform.hpp" namespace "cudf" nogil:
bool is_ptx
) except +

cdef unique_ptr[column] generalized_masked_op(
const table_view& data_view,
string udf,
data_type output_type,
) except +

cdef pair[unique_ptr[table], unique_ptr[column]] encode(
table_view input
) except +
21 changes: 21 additions & 0 deletions python/cudf/cudf/_lib/transform.pyx
@@ -122,6 +122,27 @@ def transform(Column input, op):
return Column.from_unique_ptr(move(c_output))


def masked_udf(Table incols, op, output_type):
cdef table_view data_view = incols.data_view()
cdef string c_str = op.encode("UTF-8")
cdef type_id c_tid
cdef data_type c_dtype

c_tid = <type_id> (
<underlying_type_t_type_id> np_to_cudf_types[output_type]
)
c_dtype = data_type(c_tid)

with nogil:
c_output = move(libcudf_transform.generalized_masked_op(
data_view,
c_str,
c_dtype,
))

return Column.from_unique_ptr(move(c_output))


def table_encode(Table input):
cdef table_view c_input = input.data_view()
cdef pair[unique_ptr[table], unique_ptr[column]] c_result
1 change: 1 addition & 0 deletions python/cudf/cudf/core/__init__.py
@@ -27,4 +27,5 @@
from cudf.core.multiindex import MultiIndex
from cudf.core.scalar import NA, Scalar
from cudf.core.series import Series
import cudf.core.udf
from cudf.core.cut import cut