Skip to content

Commit

Permalink
Performance improvement for some libcudf regex functions for long str…
Browse files Browse the repository at this point in the history
…ings (#13322)

Changes the internal regex logic to minimize character counting to help performance with longer strings. The improvement applies mainly to libcudf regex functions that return strings (i.e. extract, replace, split). The changes here also improve the internal device APIs for clarity to improve maintenance. The most significant change makes the position variables input-only and returning an optional pair to indicate a successful match.

There are some more optimizations that are possible here where character positions are passed back and forth that could be replaced with byte positions to further reduce counting. Initial measurements showed this noticeably slowed down small strings so more analysis is required before continuing this optimization. 

Reference: #13480

### More Detail

First, there is a change to some internal regex function signatures. Notable the `reprog_device::find()` and `reprog_device::extract()` member functions declared in `cpp/src/strings/regex/regex.cuh` that are used by all the libcudf regex functions. The in/out parameters are now input-only parameters (pass by value) and the return is an optional pair that includes the match result. Also, the `begin` parameter is now an iterator and the `end` parameter now has a default. This change requires updating all the definitions and uses of the `find` and `extract` member functions.

Using an iterator as the `begin` parameter allows for some optimizations in the calling code to minimize character counting that may be needed for processing multi-byte UTF-8 characters. Rather than using the `cudf::string_view::byte_offset()` member function to convert character positions to byte positions, an iterator can be incremented as we traverse through the string which helps reduce some character counting. So the changes here involve removing some calls to `byte_offset()` and incrementing (really moving) iterators with a pattern like `itr += (new_pos - itr.position());` There is another PR #13428 to make a `move_to` iterator member function.

It is possible to reduce the character counting even more as mentioned above but further optimization requires some deeper analysis.

Authors:
  - David Wendt (https://github.com/davidwendt)

Approvers:
  - Mark Harris (https://github.com/harrism)
  - MithunR (https://github.com/mythrocks)

URL: #13322
  • Loading branch information
davidwendt authored Jun 23, 2023
1 parent 0fc31a7 commit f0c62cb
Show file tree
Hide file tree
Showing 15 changed files with 257 additions and 227 deletions.
14 changes: 9 additions & 5 deletions cpp/include/cudf/strings/detail/utilities.cuh
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,9 @@
#include <cudf/strings/string_view.cuh>
#include <cudf/utilities/error.hpp>

#include <thrust/copy.h>
#include <thrust/execution_policy.h>

#include <mutex>
#include <unordered_map>

Expand All @@ -29,14 +32,15 @@ namespace detail {
* @brief Copies input string data into a buffer and increments the pointer by the number of bytes
* copied.
*
* @param buffer Device buffer to copy to.
* @param input Data to copy from.
* @param bytes Number of bytes to copy.
* @return Pointer to the end of the output buffer after the copy.
* @param buffer Device buffer to copy to
* @param input Data to copy from
* @param bytes Number of bytes to copy
* @return Pointer to the end of the output buffer after the copy
*/
__device__ inline char* copy_and_increment(char* buffer, char const* input, size_type bytes)
{
memcpy(buffer, input, bytes);
// this can be slightly faster than memcpy
thrust::copy_n(thrust::seq, input, bytes, buffer);
return buffer + bytes;
}

Expand Down
7 changes: 3 additions & 4 deletions cpp/src/strings/contains.cu
Original file line number Diff line number Diff line change
Expand Up @@ -50,10 +50,9 @@ struct contains_fn {
if (d_strings.is_null(idx)) return false;
auto const d_str = d_strings.element<string_view>(idx);

size_type begin = 0;
size_type end = beginning_only ? 1 // match only the beginning of the string;
: -1; // match anywhere in the string
return static_cast<bool>(prog.find(thread_idx, d_str, begin, end));
size_type end = beginning_only ? 1 // match only the beginning of the string;
: -1; // match anywhere in the string
return prog.find(thread_idx, d_str, d_str.begin(), end).has_value();
}
};

Expand Down
12 changes: 7 additions & 5 deletions cpp/src/strings/count_matches.cu
Original file line number Diff line number Diff line change
Expand Up @@ -41,12 +41,14 @@ struct count_fn {
auto const nchars = d_str.length();
int32_t count = 0;

size_type begin = 0;
size_type end = -1;
while ((begin <= nchars) && (prog.find(thread_idx, d_str, begin, end) > 0)) {
auto itr = d_str.begin();
while (itr.position() <= nchars) {
auto result = prog.find(thread_idx, d_str, itr);
if (!result) { break; }
++count;
begin = end + (begin == end);
end = -1;
// increment the iterator is faster than creating a new one
// +1 if the match was on a virtual position (e.g. word boundary)
itr += (result->second - itr.position()) + (result->first == result->second);
}
return count;
}
Expand Down
23 changes: 12 additions & 11 deletions cpp/src/strings/extract/extract.cu
Original file line number Diff line number Diff line change
Expand Up @@ -61,18 +61,19 @@ struct extract_fn {

if (d_strings.is_valid(idx)) {
auto const d_str = d_strings.element<string_view>(idx);

size_type begin = 0;
size_type end = -1; // handles empty strings automatically
if (d_prog.find(prog_idx, d_str, begin, end) > 0) {
auto const match = d_prog.find(prog_idx, d_str, d_str.begin());
if (match) {
auto const itr = d_str.begin() + match->first;
auto last_pos = itr;
for (auto col_idx = 0; col_idx < groups; ++col_idx) {
auto const extracted = d_prog.extract(prog_idx, d_str, begin, end, col_idx);
d_output[col_idx] = [&] {
if (!extracted) return string_index_pair{nullptr, 0};
auto const offset = d_str.byte_offset((*extracted).first);
return string_index_pair{d_str.data() + offset,
d_str.byte_offset((*extracted).second) - offset};
}();
auto const extracted = d_prog.extract(prog_idx, d_str, itr, match->second, col_idx);
if (extracted) {
auto const d_extracted = string_from_match(*extracted, d_str, last_pos);
d_output[col_idx] = string_index_pair{d_extracted.data(), d_extracted.size_bytes()};
last_pos += (extracted->second - last_pos.position());
} else {
d_output[col_idx] = string_index_pair{nullptr, 0};
}
}
return;
}
Expand Down
38 changes: 21 additions & 17 deletions cpp/src/strings/extract/extract_all.cu
Original file line number Diff line number Diff line change
Expand Up @@ -59,32 +59,36 @@ struct extract_fn {
{
if (d_strings.is_null(idx)) { return; }

auto const d_str = d_strings.element<string_view>(idx);
auto const nchars = d_str.length();

auto const groups = d_prog.group_counts();
auto d_output = d_indices + d_offsets[idx];
size_type output_idx = 0;

auto const d_str = d_strings.element<string_view>(idx);
auto const nchars = d_str.length();
auto itr = d_str.begin();

size_type begin = 0;
size_type end = nchars;
// match the regex
while ((begin < end) && d_prog.find(prog_idx, d_str, begin, end) > 0) {
while (itr.position() < nchars) {
// first, match the regex
auto const match = d_prog.find(prog_idx, d_str, itr);
if (!match) { break; }
itr += (match->first - itr.position()); // position to beginning of the match
auto last_pos = itr;
// extract each group into the output
for (auto group_idx = 0; group_idx < groups; ++group_idx) {
// result is an optional containing the bounds of the extracted string at group_idx
auto const extracted = d_prog.extract(prog_idx, d_str, begin, end, group_idx);

d_output[group_idx + output_idx] = [&] {
if (!extracted) { return string_index_pair{nullptr, 0}; }
auto const start_offset = d_str.byte_offset(extracted->first);
auto const end_offset = d_str.byte_offset(extracted->second);
return string_index_pair{d_str.data() + start_offset, end_offset - start_offset};
}();
auto const extracted = d_prog.extract(prog_idx, d_str, itr, match->second, group_idx);
if (extracted) {
auto const d_result = string_from_match(*extracted, d_str, last_pos);
d_output[group_idx + output_idx] =
string_index_pair{d_result.data(), d_result.size_bytes()};
} else {
d_output[group_idx + output_idx] = string_index_pair{nullptr, 0};
}
last_pos += (extracted->second - last_pos.position());
}
// continue to next match
begin = end;
end = nchars;
// point to the end of this match to start the next match
itr += (match->second - itr.position());
output_idx += groups;
}
}
Expand Down
75 changes: 47 additions & 28 deletions cpp/src/strings/regex/regex.cuh
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,7 @@
#include <strings/regex/regcomp.h>

#include <cudf/strings/regex/flags.hpp>
#include <cudf/strings/string_view.cuh>
#include <cudf/types.hpp>

#include <rmm/cuda_stream_view.hpp>
Expand All @@ -30,9 +31,6 @@
#include <memory>

namespace cudf {

class string_view;

namespace strings {
namespace detail {

Expand Down Expand Up @@ -184,36 +182,33 @@ class reprog_device {
*
* @param thread_idx The index used for mapping the state memory for this string in global memory.
* @param d_str The string to search.
* @param[in,out] begin Position index to begin the search. If found, returns the position found
* in the string.
* @param[in,out] end Position index to end the search. If found, returns the last position
* matching in the string.
* @return Returns 0 if no match is found.
* @param begin Position to begin the search within `d_str`.
* @param end Character position index to end the search within `d_str`.
* Specify -1 to match any virtual positions past the end of the string.
* @return If match found, returns character positions of the matches.
*/
__device__ inline int32_t find(int32_t const thread_idx,
string_view const d_str,
cudf::size_type& begin,
cudf::size_type& end) const;
__device__ inline match_result find(int32_t const thread_idx,
string_view const d_str,
string_view::const_iterator begin,
cudf::size_type end = -1) const;

/**
* @brief Does an extract evaluation using the compiled expression on the given string.
*
* This will find a specific match within the string when more than match occurs.
* This will find a specific capture group within the string.
* The find() function should be called first to locate the begin/end bounds of the
* the matched section.
*
* @param thread_idx The index used for mapping the state memory for this string in global memory.
* @param d_str The string to search.
* @param begin Position index to begin the search. If found, returns the position found
* in the string.
* @param end Position index to end the search. If found, returns the last position
* matching in the string.
* @param begin Position to begin the search within `d_str`.
* @param end Character position index to end the search within `d_str`.
* @param group_id The specific group to return its matching position values.
* @return If valid, returns the character position of the matched group in the given string,
*/
__device__ inline match_result extract(int32_t const thread_idx,
string_view const d_str,
cudf::size_type begin,
string_view::const_iterator begin,
cudf::size_type end,
cudf::size_type const group_id) const;

Expand Down Expand Up @@ -241,20 +236,20 @@ class reprog_device {
/**
* @brief Executes the regex pattern on the given string.
*/
__device__ inline int32_t regexec(string_view const d_str,
reljunk jnk,
cudf::size_type& begin,
cudf::size_type& end,
cudf::size_type const group_id = 0) const;
__device__ inline match_result regexec(string_view const d_str,
reljunk jnk,
string_view::const_iterator begin,
cudf::size_type end,
cudf::size_type const group_id = 0) const;

/**
* @brief Utility wrapper to setup state memory structures for calling regexec
*/
__device__ inline int32_t call_regexec(int32_t const thread_idx,
string_view const d_str,
cudf::size_type& begin,
cudf::size_type& end,
cudf::size_type const group_id = 0) const;
__device__ inline match_result call_regexec(int32_t const thread_idx,
string_view const d_str,
string_view::const_iterator begin,
cudf::size_type end,
cudf::size_type const group_id = 0) const;

reprog_device(reprog const&);

Expand Down Expand Up @@ -285,6 +280,30 @@ class reprog_device {
*/
std::size_t compute_working_memory_size(int32_t num_threads, int32_t insts_count);

/**
* @brief Converts a match_pair from character positions to byte positions
*/
__device__ __forceinline__ match_pair match_positions_to_bytes(match_pair const result,
string_view d_str,
string_view::const_iterator last)
{
if (d_str.length() == d_str.size_bytes()) { return result; }
auto const begin = (last + (result.first - last.position())).byte_offset();
auto const end = (last + (result.second - last.position())).byte_offset();
return {begin, end};
}

/**
* @brief Creates a string_view from a match result
*/
__device__ __forceinline__ string_view string_from_match(match_pair const result,
string_view d_str,
string_view::const_iterator last)
{
auto const [begin, end] = match_positions_to_bytes(result, d_str, last);
return string_view(d_str.data() + begin, end - begin);
}

} // namespace detail
} // namespace strings
} // namespace cudf
Expand Down
58 changes: 25 additions & 33 deletions cpp/src/strings/regex/regex.inl
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
/*
* Copyright (c) 2019-2022, NVIDIA CORPORATION. All rights reserved.
* Copyright (c) 2019-2023, NVIDIA CORPORATION. All rights reserved.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
Expand All @@ -16,10 +16,6 @@

#include <cudf/detail/utilities/integer_utils.hpp>
#include <cudf/strings/detail/char_tables.hpp>
#include <cudf/strings/detail/utf8.hpp>
#include <cudf/strings/string_view.cuh>

#include <thrust/optional.h>

namespace cudf {
namespace strings {
Expand Down Expand Up @@ -235,21 +231,19 @@ __device__ __forceinline__ reprog_device reprog_device::load(reprog_device const
* @param group_id Index of the group to match in a multi-group regex pattern.
* @return >0 if match found
*/
__device__ __forceinline__ int32_t reprog_device::regexec(string_view const dstr,
reljunk jnk,
cudf::size_type& begin,
cudf::size_type& end,
cudf::size_type const group_id) const
__device__ __forceinline__ match_result reprog_device::regexec(string_view const dstr,
reljunk jnk,
string_view::const_iterator itr,
cudf::size_type end,
cudf::size_type const group_id) const
{
int32_t match = 0;
auto begin = itr.position();
auto pos = begin;
auto eos = end;
char_utf8 c = 0;
auto checkstart = jnk.starttype != 0;
auto last_character = false;

string_view::const_iterator itr = string_view::const_iterator(dstr, pos);

jnk.list1->reset();
do {
// fast check for first CHAR or BOL
Expand All @@ -258,12 +252,12 @@ __device__ __forceinline__ int32_t reprog_device::regexec(string_view const dstr
switch (jnk.starttype) {
case BOL:
if (pos == 0) break;
if (jnk.startchar != '^') { return match; }
if (jnk.startchar != '^') { return thrust::nullopt; }
--pos;
startchar = static_cast<char_utf8>('\n');
case CHAR: {
auto const fidx = dstr.find(startchar, pos);
if (fidx == string_view::npos) { return match; }
if (fidx == string_view::npos) { return thrust::nullopt; }
pos = fidx + (jnk.starttype == BOL);
break;
}
Expand All @@ -279,7 +273,7 @@ __device__ __forceinline__ int32_t reprog_device::regexec(string_view const dstr

last_character = itr.byte_offset() >= dstr.size_bytes();

c = last_character ? 0 : *itr;
char_utf8 const c = last_character ? 0 : *itr;

// expand the non-character types like: LBRA, RBRA, BOL, EOL, BOW, NBOW, and OR
bool expanded = false;
Expand Down Expand Up @@ -394,35 +388,33 @@ __device__ __forceinline__ int32_t reprog_device::regexec(string_view const dstr
checkstart = jnk.list1->get_size() == 0;
} while (!last_character && (!checkstart || !match));

return match;
return match ? match_result({begin, end}) : thrust::nullopt;
}

__device__ __forceinline__ int32_t reprog_device::find(int32_t const thread_idx,
string_view const dstr,
cudf::size_type& begin,
cudf::size_type& end) const
__device__ __forceinline__ match_result reprog_device::find(int32_t const thread_idx,
string_view const dstr,
string_view::const_iterator begin,
cudf::size_type end) const
{
auto const rtn = call_regexec(thread_idx, dstr, begin, end);
if (rtn <= 0) begin = end = -1;
return rtn;
return call_regexec(thread_idx, dstr, begin, end);
}

__device__ __forceinline__ match_result reprog_device::extract(int32_t const thread_idx,
string_view const dstr,
cudf::size_type begin,
string_view::const_iterator begin,
cudf::size_type end,
cudf::size_type const group_id) const
{
end = begin + 1;
return call_regexec(thread_idx, dstr, begin, end, group_id + 1) > 0 ? match_result({begin, end})
: thrust::nullopt;
end = begin.position() + 1;
return call_regexec(thread_idx, dstr, begin, end, group_id + 1);
}

__device__ __forceinline__ int32_t reprog_device::call_regexec(int32_t const thread_idx,
string_view const dstr,
cudf::size_type& begin,
cudf::size_type& end,
cudf::size_type const group_id) const
__device__ __forceinline__ match_result
reprog_device::call_regexec(int32_t const thread_idx,
string_view const dstr,
string_view::const_iterator begin,
cudf::size_type end,
cudf::size_type const group_id) const
{
auto gp_ptr = reinterpret_cast<u_char*>(_buffer);
relist list1(static_cast<int16_t>(_max_insts), _thread_count, gp_ptr, thread_idx);
Expand Down
Loading

0 comments on commit f0c62cb

Please sign in to comment.