Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Support replace in strings_udf #12207

Merged
Merged
36 changes: 36 additions & 0 deletions python/cudf/cudf/core/udf/strings_lowering.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,6 +24,7 @@
len_impl,
lower_impl,
lstrip_impl,
replace_impl,
rfind_impl,
rstrip_impl,
startswith_impl,
Expand All @@ -50,6 +51,41 @@ def masked_len_impl(context, builder, sig, args):
return ret._getvalue()


@cuda_lower(
    "MaskedType.replace",
    MaskedType(string_view),
    MaskedType(string_view),
    MaskedType(string_view),
)
def masked_string_view_replace_impl(context, builder, sig, args):
    """
    Lower ``replace`` called on a ``MaskedType(string_view)``.

    The three masked operands (source, substring to replace, replacement)
    are unpacked into struct proxies, their underlying values are passed
    to ``replace_impl``, and the outcome is stored in a new masked struct
    of type ``sig.return_type``.

    The result is flagged valid only when all three operands are valid.
    """
    ret = cgutils.create_struct_proxy(sig.return_type)(context, builder)

    # Unpack each operand using its own entry in the signature instead of
    # reusing the type of the first argument for all three.
    src_masked = cgutils.create_struct_proxy(sig.args[0])(
        context, builder, value=args[0]
    )
    to_replace_masked = cgutils.create_struct_proxy(sig.args[1])(
        context, builder, value=args[1]
    )
    replacement_masked = cgutils.create_struct_proxy(sig.args[2])(
        context, builder, value=args[2]
    )

    # Delegate the actual string work to the unmasked lowering.
    result = replace_impl(
        context,
        builder,
        nb_signature(udf_string, string_view, string_view, string_view),
        (src_masked.value, to_replace_masked.value, replacement_masked.value),
    )

    ret.value = result
    # valid = src.valid & to_replace.valid & replacement.valid
    ret.valid = builder.and_(
        builder.and_(src_masked.valid, to_replace_masked.valid),
        replacement_masked.valid,
    )

    return ret._getvalue()


def create_binary_string_func(op, cuda_func, retty):
"""
Provide a wrapper around numba's low-level extension API which
Expand Down
17 changes: 17 additions & 0 deletions python/cudf/cudf/core/udf/strings_typing.py
Original file line number Diff line number Diff line change
Expand Up @@ -155,9 +155,26 @@ def generic(self, args, kws):
)


class MaskedStringViewReplace(AbstractTemplate):
    """Typing template for ``replace`` on a masked string view."""

    key = "MaskedType.replace"

    def generic(self, args, kws):
        """
        Build the signature of the bound ``replace`` call.

        Both call arguments are typed as ``MaskedType(string_view)``, the
        receiver is ``self.this`` and the return type is
        ``MaskedType(udf_string)``.  ``args`` and ``kws`` are not inspected.
        """
        masked_view = MaskedType(string_view)
        masked_result = MaskedType(udf_string)
        return nb_signature(
            masked_result, masked_view, masked_view, recvr=self.this
        )


class MaskedStringViewAttrs(AttributeTemplate):
key = MaskedType(string_view)

def resolve_replace(self, mod):
    """Type the ``replace`` attribute as a bound function on a masked string view."""
    receiver_type = MaskedType(string_view)
    return types.BoundFunction(MaskedStringViewReplace, receiver_type)

def resolve_count(self, mod):
return types.BoundFunction(
MaskedStringViewCount, MaskedType(string_view)
Expand Down
10 changes: 10 additions & 0 deletions python/cudf/cudf/tests/test_udf_masked_ops.py
Original file line number Diff line number Diff line change
Expand Up @@ -928,6 +928,16 @@ def func(row):
run_masked_udf_test(func, str_udf_data, check_dtype=False)


@string_udf_test
@pytest.mark.parametrize("to_replace", ["a", "1", "", "@"])
@pytest.mark.parametrize("replacement", ["a", "1", "", "@"])
def test_string_udf_replace(str_udf_data, to_replace, replacement):
    """Run ``str.replace`` inside a masked row UDF for each parameter pair."""

    def func(row):
        text = row["str_col"]
        return text.replace(to_replace, replacement)

    run_masked_udf_test(func, str_udf_data, check_dtype=False)


@pytest.mark.parametrize(
"data", [[1.0, 0.0, 1.5], [1, 0, 2], [True, False, True]]
)
Expand Down
18 changes: 18 additions & 0 deletions python/strings_udf/cpp/src/strings/udf/shim.cu
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,7 @@

#include <cudf/strings/udf/case.cuh>
#include <cudf/strings/udf/char_types.cuh>
#include <cudf/strings/udf/replace.cuh>
#include <cudf/strings/udf/search.cuh>
#include <cudf/strings/udf/starts_with.cuh>
#include <cudf/strings/udf/strip.cuh>
Expand Down Expand Up @@ -329,3 +330,20 @@ extern "C" __device__ int concat(int* nb_retval, void* udf_str, void* const* lhs
*udf_str_ptr = result;
return 0;
}

// Shim exposing string replace to callers through C linkage.
//
// `udf_str` is the caller-provided storage in which the resulting
// udf_string is constructed; `src`, `to_replace` and `replacement` are
// reinterpreted as pointers to cudf::string_view. `nb_retval` is not
// written by this function, and the return value is always 0.
extern "C" __device__ int replace(
  int* nb_retval, void* udf_str, void* const src, void* const to_replace, void* const replacement)
{
  auto src_ptr         = reinterpret_cast<cudf::string_view const*>(src);
  auto to_replace_ptr  = reinterpret_cast<cudf::string_view const*>(to_replace);
  auto replacement_ptr = reinterpret_cast<cudf::string_view const*>(replacement);

  // Construct the output object in place, then assign the returned temporary
  // straight into it rather than going through a named local and an extra
  // copy-assignment.
  auto udf_str_ptr = new (udf_str) udf_string;
  *udf_str_ptr     = replace(*src_ptr, *to_replace_ptr, *replacement_ptr);

  return 0;
}
39 changes: 25 additions & 14 deletions python/strings_udf/strings_udf/_typing.py
Original file line number Diff line number Diff line change
Expand Up @@ -153,20 +153,6 @@ def generic(self, args, kws):
cuda_decl_registry.register_global(op)(StringViewBinaryOp)


# Register typing for the comparison operators; each one is declared
# with a boolean return type.
register_stringview_binaryop(operator.eq, types.boolean)
register_stringview_binaryop(operator.ne, types.boolean)
register_stringview_binaryop(operator.lt, types.boolean)
register_stringview_binaryop(operator.gt, types.boolean)
register_stringview_binaryop(operator.le, types.boolean)
register_stringview_binaryop(operator.ge, types.boolean)

# Containment (`st in other`), declared with a boolean return type
register_stringview_binaryop(operator.contains, types.boolean)

# Concatenation (`st + other`), declared with a udf_string return type
register_stringview_binaryop(operator.add, udf_string)


def create_binary_attr(attrname, retty):
"""
Helper function wrapping numba's low level extension API. Provides
Expand Down Expand Up @@ -212,13 +198,25 @@ def generic(self, args, kws):
return nb_signature(size_type, string_view, recvr=self.this)


class StringViewReplace(AbstractTemplate):
    """Typing template for ``replace`` on a ``string_view``."""

    key = "StringView.replace"

    def generic(self, args, kws):
        """
        Build the signature of the bound ``replace`` call.

        Both call arguments are typed as ``string_view``, the receiver is
        ``self.this`` and the return type is ``udf_string``.  ``args`` and
        ``kws`` are not inspected.
        """
        arg_types = (string_view, string_view)
        return nb_signature(udf_string, *arg_types, recvr=self.this)


@cuda_decl_registry.register_attr
class StringViewAttrs(AttributeTemplate):
    """Attribute typing for ``string_view``: resolves ``count`` and ``replace``."""

    # NOTE(review): this class also appears to be passed to
    # cuda_decl_registry.register_attr again further down the module;
    # confirm the double registration is intended.
    key = string_view

    def resolve_count(self, mod):
        """Type ``count`` as a function bound to a ``string_view`` receiver."""
        template = StringViewCount
        return types.BoundFunction(template, string_view)

    def resolve_replace(self, mod):
        """Type ``replace`` as a function bound to a ``string_view`` receiver."""
        template = StringViewReplace
        return types.BoundFunction(template, string_view)


# Build attributes for `MaskedType(string_view)`
bool_binary_funcs = ["startswith", "endswith"]
Expand Down Expand Up @@ -272,3 +270,16 @@ def resolve_count(self, mod):
)

cuda_decl_registry.register_attr(StringViewAttrs)

# Register typing for the comparison operators; each one is declared
# with a boolean return type.
register_stringview_binaryop(operator.eq, types.boolean)
register_stringview_binaryop(operator.ne, types.boolean)
register_stringview_binaryop(operator.lt, types.boolean)
register_stringview_binaryop(operator.gt, types.boolean)
register_stringview_binaryop(operator.le, types.boolean)
register_stringview_binaryop(operator.ge, types.boolean)

# Containment (`st in other`), declared with a boolean return type
register_stringview_binaryop(operator.contains, types.boolean)

# Concatenation (`st + other`), declared with a udf_string return type
register_stringview_binaryop(operator.add, udf_string)
36 changes: 36 additions & 0 deletions python/strings_udf/strings_udf/lowering.py
Original file line number Diff line number Diff line change
Expand Up @@ -35,6 +35,11 @@
"concat", types.void(_UDF_STRING_PTR, _STR_VIEW_PTR, _STR_VIEW_PTR)
)

# Declaration of the external device function named "replace".
# The first pointer argument is where the udf_string result is written;
# the other three are string_view pointers, passed by
# call_string_view_replace as (src, to_replace, replacement).
# NOTE(review): presumably resolved against the `replace` shim with C
# linkage - confirm the symbol is linked in.
_string_view_replace = cuda.declare_device(
"replace",
types.void(_UDF_STRING_PTR, _STR_VIEW_PTR, _STR_VIEW_PTR, _STR_VIEW_PTR),
)


def _declare_binary_func(lhs, rhs, out, name):
# Declare a binary function
Expand Down Expand Up @@ -209,6 +214,37 @@ def concat_impl(context, builder, sig, args):
return result._getvalue()


def call_string_view_replace(result, src, to_replace, replacement):
    """Forward all four pointers, unchanged and in order, to the declared ``replace`` device function."""
    return _string_view_replace(
        result,
        src,
        to_replace,
        replacement,
    )


@cuda_lower("StringView.replace", string_view, string_view, string_view)
def replace_impl(context, builder, sig, args):
    """
    Lower ``string_view.replace(to_replace, replacement)``.

    Each of the three ``string_view`` operands is spilled to its own
    alloca so it can be passed by pointer, an alloca is reserved for the
    ``udf_string`` result, and ``call_string_view_replace`` is compiled
    and invoked with those four pointers.  The value loaded back from the
    result pointer is returned as a ``udf_string`` struct.
    """
    src_ptr = builder.alloca(args[0].type)
    to_replace_ptr = builder.alloca(args[1].type)
    replacement_ptr = builder.alloca(args[2].type)

    builder.store(args[0], src_ptr)
    builder.store(args[1], to_replace_ptr)
    builder.store(args[2], replacement_ptr)

    # Storage the device function fills in with the resulting string.
    udf_str_ptr = builder.alloca(default_manager[udf_string].get_value_type())

    _ = context.compile_internal(
        builder,
        call_string_view_replace,
        types.void(
            _UDF_STRING_PTR, _STR_VIEW_PTR, _STR_VIEW_PTR, _STR_VIEW_PTR
        ),
        (udf_str_ptr, src_ptr, to_replace_ptr, replacement_ptr),
    )

    result = cgutils.create_struct_proxy(udf_string)(
        context, builder, value=builder.load(udf_str_ptr)
    )
    return result._getvalue()


def create_binary_string_func(binary_func, retty):
"""
Provide a wrapper around numba's low-level extension API which
Expand Down
9 changes: 9 additions & 0 deletions python/strings_udf/strings_udf/tests/test_string_udfs.py
Original file line number Diff line number Diff line change
Expand Up @@ -332,3 +332,12 @@ def func(st):
return concat_char + st

run_udf_test(data, func, "str")


@pytest.mark.parametrize("to_replace", ["a", "1", "", "@"])
@pytest.mark.parametrize("replacement", ["a", "1", "", "@"])
def test_string_udf_replace(data, to_replace, replacement):
    """Run ``str.replace`` inside a string UDF for each parameter pair."""

    def func(st):
        replaced = st.replace(to_replace, replacement)
        return replaced

    run_udf_test(data, func, "str")