diff --git a/cpp/include/cudf/strings/detail/utilities.cuh b/cpp/include/cudf/strings/detail/utilities.cuh index 64f5d3f0450..5c719cd25d2 100644 --- a/cpp/include/cudf/strings/detail/utilities.cuh +++ b/cpp/include/cudf/strings/detail/utilities.cuh @@ -18,6 +18,9 @@ #include #include +#include +#include + #include #include @@ -29,14 +32,15 @@ namespace detail { * @brief Copies input string data into a buffer and increments the pointer by the number of bytes * copied. * - * @param buffer Device buffer to copy to. - * @param input Data to copy from. - * @param bytes Number of bytes to copy. - * @return Pointer to the end of the output buffer after the copy. + * @param buffer Device buffer to copy to + * @param input Data to copy from + * @param bytes Number of bytes to copy + * @return Pointer to the end of the output buffer after the copy */ __device__ inline char* copy_and_increment(char* buffer, char const* input, size_type bytes) { - memcpy(buffer, input, bytes); + // this can be slightly faster than memcpy + thrust::copy_n(thrust::seq, input, bytes, buffer); return buffer + bytes; } diff --git a/cpp/src/strings/contains.cu b/cpp/src/strings/contains.cu index 44b3faeb38a..22534870409 100644 --- a/cpp/src/strings/contains.cu +++ b/cpp/src/strings/contains.cu @@ -50,10 +50,9 @@ struct contains_fn { if (d_strings.is_null(idx)) return false; auto const d_str = d_strings.element(idx); - size_type begin = 0; - size_type end = beginning_only ? 1 // match only the beginning of the string; - : -1; // match anywhere in the string - return static_cast(prog.find(thread_idx, d_str, begin, end)); + size_type end = beginning_only ? 1 // match only the beginning of the string; + : -1; // match anywhere in the string + return prog.find(thread_idx, d_str, d_str.begin(), end).has_value(); } }; diff --git a/cpp/src/strings/count_matches.cu b/cpp/src/strings/count_matches.cu index 1fde3a54089..6de5d43dc94 100644 --- a/cpp/src/strings/count_matches.cu +++ b/cpp/src/strings/count_matches.cu @@ -41,12 +41,14 @@ struct count_fn { auto const nchars = d_str.length(); int32_t count = 0; - size_type begin = 0; - size_type end = -1; - while ((begin <= nchars) && (prog.find(thread_idx, d_str, begin, end) > 0)) { + auto itr = d_str.begin(); + while (itr.position() <= nchars) { + auto result = prog.find(thread_idx, d_str, itr); + if (!result) { break; } ++count; - begin = end + (begin == end); - end = -1; + // increment the iterator is faster than creating a new one + // +1 if the match was on a virtual position (e.g. word boundary) + itr += (result->second - itr.position()) + (result->first == result->second); } return count; } diff --git a/cpp/src/strings/extract/extract.cu b/cpp/src/strings/extract/extract.cu index ccfc007e7ed..532053e750e 100644 --- a/cpp/src/strings/extract/extract.cu +++ b/cpp/src/strings/extract/extract.cu @@ -61,18 +61,19 @@ struct extract_fn { if (d_strings.is_valid(idx)) { auto const d_str = d_strings.element(idx); - - size_type begin = 0; - size_type end = -1; // handles empty strings automatically - if (d_prog.find(prog_idx, d_str, begin, end) > 0) { + auto const match = d_prog.find(prog_idx, d_str, d_str.begin()); + if (match) { + auto const itr = d_str.begin() + match->first; + auto last_pos = itr; for (auto col_idx = 0; col_idx < groups; ++col_idx) { - auto const extracted = d_prog.extract(prog_idx, d_str, begin, end, col_idx); - d_output[col_idx] = [&] { - if (!extracted) return string_index_pair{nullptr, 0}; - auto const offset = d_str.byte_offset((*extracted).first); - return string_index_pair{d_str.data() + offset, - d_str.byte_offset((*extracted).second) - offset}; - }(); + auto const extracted = d_prog.extract(prog_idx, d_str, itr, match->second, col_idx); + if (extracted) { + auto const d_extracted = string_from_match(*extracted, d_str, last_pos); + d_output[col_idx] = string_index_pair{d_extracted.data(), d_extracted.size_bytes()}; + last_pos += (extracted->second - last_pos.position()); + } else { + d_output[col_idx] = string_index_pair{nullptr, 0}; + } } return; } diff --git a/cpp/src/strings/extract/extract_all.cu b/cpp/src/strings/extract/extract_all.cu index 1252e79be90..fcd05ee9dc6 100644 --- a/cpp/src/strings/extract/extract_all.cu +++ b/cpp/src/strings/extract/extract_all.cu @@ -59,32 +59,36 @@ struct extract_fn { { if (d_strings.is_null(idx)) { return; } + auto const d_str = d_strings.element(idx); + auto const nchars = d_str.length(); + auto const groups = d_prog.group_counts(); auto d_output = d_indices + d_offsets[idx]; size_type output_idx = 0; - auto const d_str = d_strings.element(idx); - auto const nchars = d_str.length(); + auto itr = d_str.begin(); - size_type begin = 0; - size_type end = nchars; - // match the regex - while ((begin < end) && d_prog.find(prog_idx, d_str, begin, end) > 0) { + while (itr.position() < nchars) { + // first, match the regex + auto const match = d_prog.find(prog_idx, d_str, itr); + if (!match) { break; } + itr += (match->first - itr.position()); // position to beginning of the match + auto last_pos = itr; // extract each group into the output for (auto group_idx = 0; group_idx < groups; ++group_idx) { // result is an optional containing the bounds of the extracted string at group_idx - auto const extracted = d_prog.extract(prog_idx, d_str, begin, end, group_idx); - - d_output[group_idx + output_idx] = [&] { - if (!extracted) { return string_index_pair{nullptr, 0}; } - auto const start_offset = d_str.byte_offset(extracted->first); - auto const end_offset = d_str.byte_offset(extracted->second); - return string_index_pair{d_str.data() + start_offset, end_offset - start_offset}; - }(); + auto const extracted = d_prog.extract(prog_idx, d_str, itr, match->second, group_idx); + if (extracted) { + auto const d_result = string_from_match(*extracted, d_str, last_pos); + d_output[group_idx + output_idx] = + string_index_pair{d_result.data(), d_result.size_bytes()}; + } else { + d_output[group_idx + output_idx] = string_index_pair{nullptr, 0}; + } + last_pos += (extracted->second - last_pos.position()); } - // continue to next match - begin = end; - end = nchars; + // point to the end of this match to start the next match + itr += (match->second - itr.position()); output_idx += groups; } } diff --git a/cpp/src/strings/regex/regex.cuh b/cpp/src/strings/regex/regex.cuh index 4d18af69b9c..19d82380350 100644 --- a/cpp/src/strings/regex/regex.cuh +++ b/cpp/src/strings/regex/regex.cuh @@ -18,6 +18,7 @@ #include #include +#include #include #include @@ -30,9 +31,6 @@ #include namespace cudf { - -class string_view; - namespace strings { namespace detail { @@ -184,36 +182,33 @@ class reprog_device { * * @param thread_idx The index used for mapping the state memory for this string in global memory. * @param d_str The string to search. - * @param[in,out] begin Position index to begin the search. If found, returns the position found - * in the string. - * @param[in,out] end Position index to end the search. If found, returns the last position - * matching in the string. - * @return Returns 0 if no match is found. + * @param begin Position to begin the search within `d_str`. + * @param end Character position index to end the search within `d_str`. + * Specify -1 to match any virtual positions past the end of the string. + * @return If match found, returns character positions of the matches. */ - __device__ inline int32_t find(int32_t const thread_idx, - string_view const d_str, - cudf::size_type& begin, - cudf::size_type& end) const; + __device__ inline match_result find(int32_t const thread_idx, + string_view const d_str, + string_view::const_iterator begin, + cudf::size_type end = -1) const; /** * @brief Does an extract evaluation using the compiled expression on the given string. * - * This will find a specific match within the string when more than match occurs. + * This will find a specific capture group within the string. * The find() function should be called first to locate the begin/end bounds of the * the matched section. * * @param thread_idx The index used for mapping the state memory for this string in global memory. * @param d_str The string to search. - * @param begin Position index to begin the search. If found, returns the position found - * in the string. - * @param end Position index to end the search. If found, returns the last position - * matching in the string. + * @param begin Position to begin the search within `d_str`. + * @param end Character position index to end the search within `d_str`. * @param group_id The specific group to return its matching position values. * @return If valid, returns the character position of the matched group in the given string, */ __device__ inline match_result extract(int32_t const thread_idx, string_view const d_str, - cudf::size_type begin, + string_view::const_iterator begin, cudf::size_type end, cudf::size_type const group_id) const; @@ -241,20 +236,20 @@ class reprog_device { /** * @brief Executes the regex pattern on the given string. */ - __device__ inline int32_t regexec(string_view const d_str, - reljunk jnk, - cudf::size_type& begin, - cudf::size_type& end, - cudf::size_type const group_id = 0) const; + __device__ inline match_result regexec(string_view const d_str, + reljunk jnk, + string_view::const_iterator begin, + cudf::size_type end, + cudf::size_type const group_id = 0) const; /** * @brief Utility wrapper to setup state memory structures for calling regexec */ - __device__ inline int32_t call_regexec(int32_t const thread_idx, - string_view const d_str, - cudf::size_type& begin, - cudf::size_type& end, - cudf::size_type const group_id = 0) const; + __device__ inline match_result call_regexec(int32_t const thread_idx, + string_view const d_str, + string_view::const_iterator begin, + cudf::size_type end, + cudf::size_type const group_id = 0) const; reprog_device(reprog const&); @@ -285,6 +280,30 @@ class reprog_device { */ std::size_t compute_working_memory_size(int32_t num_threads, int32_t insts_count); +/** + * @brief Converts a match_pair from character positions to byte positions + */ +__device__ __forceinline__ match_pair match_positions_to_bytes(match_pair const result, + string_view d_str, + string_view::const_iterator last) +{ + if (d_str.length() == d_str.size_bytes()) { return result; } + auto const begin = (last + (result.first - last.position())).byte_offset(); + auto const end = (last + (result.second - last.position())).byte_offset(); + return {begin, end}; +} + +/** + * @brief Creates a string_view from a match result + */ +__device__ __forceinline__ string_view string_from_match(match_pair const result, + string_view d_str, + string_view::const_iterator last) +{ + auto const [begin, end] = match_positions_to_bytes(result, d_str, last); + return string_view(d_str.data() + begin, end - begin); +} + } // namespace detail } // namespace strings } // namespace cudf diff --git a/cpp/src/strings/regex/regex.inl b/cpp/src/strings/regex/regex.inl index d25a0888f32..c5205ae7789 100644 --- a/cpp/src/strings/regex/regex.inl +++ b/cpp/src/strings/regex/regex.inl @@ -1,5 +1,5 @@ /* - * Copyright (c) 2019-2022, NVIDIA CORPORATION. All rights reserved. + * Copyright (c) 2019-2023, NVIDIA CORPORATION. All rights reserved. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. @@ -16,10 +16,6 @@ #include #include -#include -#include - -#include namespace cudf { namespace strings { @@ -235,21 +231,19 @@ __device__ __forceinline__ reprog_device reprog_device::load(reprog_device const * @param group_id Index of the group to match in a multi-group regex pattern. * @return >0 if match found */ -__device__ __forceinline__ int32_t reprog_device::regexec(string_view const dstr, - reljunk jnk, - cudf::size_type& begin, - cudf::size_type& end, - cudf::size_type const group_id) const +__device__ __forceinline__ match_result reprog_device::regexec(string_view const dstr, + reljunk jnk, + string_view::const_iterator itr, + cudf::size_type end, + cudf::size_type const group_id) const { int32_t match = 0; + auto begin = itr.position(); auto pos = begin; auto eos = end; - char_utf8 c = 0; auto checkstart = jnk.starttype != 0; auto last_character = false; - string_view::const_iterator itr = string_view::const_iterator(dstr, pos); - jnk.list1->reset(); do { // fast check for first CHAR or BOL @@ -258,12 +252,12 @@ __device__ __forceinline__ int32_t reprog_device::regexec(string_view const dstr switch (jnk.starttype) { case BOL: if (pos == 0) break; - if (jnk.startchar != '^') { return match; } + if (jnk.startchar != '^') { return thrust::nullopt; } --pos; startchar = static_cast('\n'); case CHAR: { auto const fidx = dstr.find(startchar, pos); - if (fidx == string_view::npos) { return match; } + if (fidx == string_view::npos) { return thrust::nullopt; } pos = fidx + (jnk.starttype == BOL); break; } @@ -279,7 +273,7 @@ __device__ __forceinline__ int32_t reprog_device::regexec(string_view const dstr last_character = itr.byte_offset() >= dstr.size_bytes(); - c = last_character ? 0 : *itr; + char_utf8 const c = last_character ? 0 : *itr; // expand the non-character types like: LBRA, RBRA, BOL, EOL, BOW, NBOW, and OR bool expanded = false; @@ -394,35 +388,33 @@ __device__ __forceinline__ int32_t reprog_device::regexec(string_view const dstr checkstart = jnk.list1->get_size() == 0; } while (!last_character && (!checkstart || !match)); - return match; + return match ? match_result({begin, end}) : thrust::nullopt; } -__device__ __forceinline__ int32_t reprog_device::find(int32_t const thread_idx, - string_view const dstr, - cudf::size_type& begin, - cudf::size_type& end) const +__device__ __forceinline__ match_result reprog_device::find(int32_t const thread_idx, + string_view const dstr, + string_view::const_iterator begin, + cudf::size_type end) const { - auto const rtn = call_regexec(thread_idx, dstr, begin, end); - if (rtn <= 0) begin = end = -1; - return rtn; + return call_regexec(thread_idx, dstr, begin, end); } __device__ __forceinline__ match_result reprog_device::extract(int32_t const thread_idx, string_view const dstr, - cudf::size_type begin, + string_view::const_iterator begin, cudf::size_type end, cudf::size_type const group_id) const { - end = begin + 1; - return call_regexec(thread_idx, dstr, begin, end, group_id + 1) > 0 ? match_result({begin, end}) - : thrust::nullopt; + end = begin.position() + 1; + return call_regexec(thread_idx, dstr, begin, end, group_id + 1); } -__device__ __forceinline__ int32_t reprog_device::call_regexec(int32_t const thread_idx, - string_view const dstr, - cudf::size_type& begin, - cudf::size_type& end, - cudf::size_type const group_id) const +__device__ __forceinline__ match_result +reprog_device::call_regexec(int32_t const thread_idx, + string_view const dstr, + string_view::const_iterator begin, + cudf::size_type end, + cudf::size_type const group_id) const { auto gp_ptr = reinterpret_cast(_buffer); relist list1(static_cast(_max_insts), _thread_count, gp_ptr, thread_idx); diff --git a/cpp/src/strings/regex/regex_program_impl.h b/cpp/src/strings/regex/regex_program_impl.h index eede2225bce..74cc1902739 100644 --- a/cpp/src/strings/regex/regex_program_impl.h +++ b/cpp/src/strings/regex/regex_program_impl.h @@ -1,5 +1,5 @@ /* - * Copyright (c) 2022, NVIDIA CORPORATION. + * Copyright (c) 2022-2023, NVIDIA CORPORATION. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. @@ -13,6 +13,7 @@ * See the License for the specific language governing permissions and * limitations under the License. */ +#pragma once #include "regcomp.h" #include "regex.cuh" diff --git a/cpp/src/strings/replace/backref_re.cuh b/cpp/src/strings/replace/backref_re.cuh index a5f3ace2141..aeaea40358f 100644 --- a/cpp/src/strings/replace/backref_re.cuh +++ b/cpp/src/strings/replace/backref_re.cuh @@ -45,7 +45,7 @@ struct backrefs_fn { string_view const d_repl; // string replacement template Iterator backrefs_begin; Iterator backrefs_end; - int32_t* d_offsets{}; + size_type* d_offsets{}; char* d_chars{}; __device__ void operator()(size_type const idx, reprog_device const prog, int32_t const prog_idx) @@ -59,23 +59,27 @@ struct backrefs_fn { auto const nchars = d_str.length(); // number of characters in input string auto nbytes = d_str.size_bytes(); // number of bytes for the output string auto out_ptr = d_chars ? (d_chars + d_offsets[idx]) : nullptr; - size_type lpos = 0; // last byte position processed in d_str - size_type begin = 0; // first character position matching regex - size_type end = -1; // match through the end of the string + auto itr = d_str.begin(); + auto last_pos = itr; // copy input to output replacing strings as we go - while ((begin <= nchars) && - (prog.find(prog_idx, d_str, begin, end) > 0)) // inits the begin/end vars + while (itr.position() <= nchars) // inits the begin/end vars { - auto spos = d_str.byte_offset(begin); // get offset for the - auto epos = d_str.byte_offset(end); // character position values; - nbytes += d_repl.size_bytes() - (epos - spos); // compute the output size + auto const match = prog.find(prog_idx, d_str, itr); + if (!match) { break; } + + auto const [start_pos, end_pos] = match_positions_to_bytes(*match, d_str, itr); + nbytes += d_repl.size_bytes() - (end_pos - start_pos); // compute the output size // copy the string data before the matched section - if (out_ptr) { out_ptr = copy_and_increment(out_ptr, in_ptr + lpos, spos - lpos); } + if (out_ptr) { + out_ptr = copy_and_increment( + out_ptr, in_ptr + last_pos.byte_offset(), start_pos - last_pos.byte_offset()); + } size_type lpos_template = 0; // last end pos of replace template auto const repl_ptr = d_repl.data(); // replace template pattern + itr += (match->first - itr.position()); thrust::for_each( thrust::seq, backrefs_begin, backrefs_end, [&] __device__(backref_type backref) { if (out_ptr) { @@ -84,17 +88,13 @@ struct backrefs_fn { lpos_template += copy_length; } // extract the specific group's string for this backref's index - auto extracted = prog.extract(prog_idx, d_str, begin, end, backref.first - 1); - if (!extracted || (extracted.value().second < extracted.value().first)) { + auto extracted = prog.extract(prog_idx, d_str, itr, match->second, backref.first - 1); + if (!extracted || (extracted->second < extracted->first)) { return; // no value for this backref number; that is ok } - auto spos_extract = d_str.byte_offset(extracted.value().first); // convert - auto epos_extract = d_str.byte_offset(extracted.value().second); // to bytes - nbytes += epos_extract - spos_extract; - if (out_ptr) { - out_ptr = - copy_and_increment(out_ptr, in_ptr + spos_extract, (epos_extract - spos_extract)); - } + auto const d_str_ex = string_from_match(*extracted, d_str, itr); + nbytes += d_str_ex.size_bytes(); + if (out_ptr) { out_ptr = copy_string(out_ptr, d_str_ex); } }); // copy remainder of template @@ -104,16 +104,16 @@ struct backrefs_fn { } // setup to match the next section - lpos = epos; - begin = end + (begin == end); - end = -1; + last_pos += (match->second - last_pos.position()); + itr = last_pos + (match->first == match->second); } // finally, copy remainder of input string - if (out_ptr && (lpos < d_str.size_bytes())) { - memcpy(out_ptr, in_ptr + lpos, d_str.size_bytes() - lpos); - } else if (!out_ptr) { - d_offsets[idx] = static_cast(nbytes); + if (out_ptr) { + thrust::copy_n( + thrust::seq, in_ptr + itr.byte_offset(), d_str.size_bytes() - itr.byte_offset(), out_ptr); + } else { + d_offsets[idx] = nbytes; } } }; diff --git a/cpp/src/strings/replace/multi_re.cu b/cpp/src/strings/replace/multi_re.cu index b554d0a815c..867b443c036 100644 --- a/cpp/src/strings/replace/multi_re.cu +++ b/cpp/src/strings/replace/multi_re.cu @@ -55,7 +55,7 @@ struct replace_multi_regex_fn { device_span progs; // array of regex progs found_range* d_found_ranges; // working array matched (begin,end) values column_device_view const d_repls; // replacement strings - int32_t* d_offsets{}; + size_type* d_offsets{}; char* d_chars{}; __device__ void operator()(size_type idx) @@ -67,61 +67,69 @@ struct replace_multi_regex_fn { auto const number_of_patterns = static_cast(progs.size()); - auto const d_str = d_strings.element(idx); - auto const nchars = d_str.length(); // number of characters in input string - auto nbytes = d_str.size_bytes(); // number of bytes in input string - auto in_ptr = d_str.data(); // input pointer - auto out_ptr = d_chars ? d_chars + d_offsets[idx] : nullptr; + auto const d_str = d_strings.element(idx); + auto const nchars = d_str.length(); // number of characters in input string + auto nbytes = d_str.size_bytes(); // number of bytes in input string + auto in_ptr = d_str.data(); // input pointer + auto out_ptr = d_chars ? d_chars + d_offsets[idx] : nullptr; + auto itr = d_str.begin(); + auto last_pos = itr; + found_range* d_ranges = d_found_ranges + (idx * number_of_patterns); - size_type lpos = 0; - size_type ch_pos = 0; + // initialize the working ranges memory to -1's thrust::fill(thrust::seq, d_ranges, d_ranges + number_of_patterns, found_range{-1, 1}); + // process string one character at a time - while (ch_pos < nchars) { + while (itr.position() < nchars) { // this minimizes the regex-find calls by only calling it for stale patterns // -- those that have not previously matched up to this point (ch_pos) for (size_type ptn_idx = 0; ptn_idx < number_of_patterns; ++ptn_idx) { - if (d_ranges[ptn_idx].first >= ch_pos) // previously matched here - continue; // or later in the string + if (d_ranges[ptn_idx].first >= itr.position()) { // previously matched here + continue; // or later in the string + } reprog_device prog = progs[ptn_idx]; - auto begin = ch_pos; - auto end = nchars; - if (!prog.is_empty() && prog.find(idx, d_str, begin, end) > 0) - d_ranges[ptn_idx] = found_range{begin, end}; // found a match - else - d_ranges[ptn_idx] = found_range{nchars, nchars}; // this pattern is done + auto const result = !prog.is_empty() ? prog.find(idx, d_str, itr) : thrust::nullopt; + d_ranges[ptn_idx] = + result ? found_range{result->first, result->second} : found_range{nchars, nchars}; } // all the ranges have been updated from each regex match; // look for any that match at this character position (ch_pos) - auto itr = - thrust::find_if(thrust::seq, d_ranges, d_ranges + number_of_patterns, [ch_pos](auto range) { - return range.first == ch_pos; - }); - if (itr != d_ranges + number_of_patterns) { + auto const ptn_itr = + thrust::find_if(thrust::seq, + d_ranges, + d_ranges + number_of_patterns, + [ch_pos = itr.position()](auto range) { return range.first == ch_pos; }); + if (ptn_itr != d_ranges + number_of_patterns) { // match found, compute and replace the string in the output - size_type ptn_idx = static_cast(itr - d_ranges); - size_type begin = d_ranges[ptn_idx].first; - size_type end = d_ranges[ptn_idx].second; - string_view d_repl = d_repls.size() > 1 ? d_repls.element(ptn_idx) - : d_repls.element(0); - auto spos = d_str.byte_offset(begin); - auto epos = d_str.byte_offset(end); - nbytes += d_repl.size_bytes() - (epos - spos); + auto const ptn_idx = static_cast(thrust::distance(d_ranges, ptn_itr)); + + auto d_repl = d_repls.size() > 1 ? d_repls.element(ptn_idx) + : d_repls.element(0); + + auto const d_range = d_ranges[ptn_idx]; + auto const [start_pos, end_pos] = + match_positions_to_bytes({d_range.first, d_range.second}, d_str, last_pos); + nbytes += d_repl.size_bytes() - (end_pos - start_pos); if (out_ptr) { // copy unmodified content plus new replacement string - out_ptr = copy_and_increment(out_ptr, in_ptr + lpos, spos - lpos); + out_ptr = copy_and_increment( + out_ptr, in_ptr + last_pos.byte_offset(), start_pos - last_pos.byte_offset()); out_ptr = copy_string(out_ptr, d_repl); - lpos = epos; } - ch_pos = end - 1; + last_pos += (d_range.second - last_pos.position()); + itr = last_pos - 1; } - ++ch_pos; + ++itr; + } + if (out_ptr) { // copy the remainder + thrust::copy_n(thrust::seq, + in_ptr + last_pos.byte_offset(), + d_str.size_bytes() - last_pos.byte_offset(), + out_ptr); + } else { + d_offsets[idx] = nbytes; } - if (out_ptr) // copy the remainder - memcpy(out_ptr, in_ptr + lpos, d_str.size_bytes() - lpos); - else - d_offsets[idx] = static_cast(nbytes); } }; diff --git a/cpp/src/strings/replace/replace_re.cu b/cpp/src/strings/replace/replace_re.cu index c334d2b2013..460074a5296 100644 --- a/cpp/src/strings/replace/replace_re.cu +++ b/cpp/src/strings/replace/replace_re.cu @@ -42,7 +42,7 @@ struct replace_regex_fn { column_device_view const d_strings; string_view const d_repl; size_type const maxrepl; - int32_t* d_offsets{}; + size_type* d_offsets{}; char* d_chars{}; __device__ void operator()(size_type const idx, reprog_device const prog, int32_t const prog_idx) @@ -54,46 +54,42 @@ struct replace_regex_fn { auto const d_str = d_strings.element(idx); auto const nchars = d_str.length(); - auto nbytes = d_str.size_bytes(); // number of bytes in input string - auto mxn = maxrepl < 0 ? nchars + 1 : maxrepl; // max possible replaces for this string - auto in_ptr = d_str.data(); // input pointer (i) - auto out_ptr = d_chars ? d_chars + d_offsets[idx] // output pointer (o) - : nullptr; - size_type last_pos = 0; - size_type begin = 0; // these are for calling prog.find - size_type end = -1; // matches final word-boundary if at the end of the string + auto nbytes = d_str.size_bytes(); // number of bytes in input string + auto mxn = maxrepl < 0 ? nchars + 1 : maxrepl; // max possible replaces for this string + auto in_ptr = d_str.data(); // input pointer (i) + auto out_ptr = d_chars ? d_chars + d_offsets[idx] // output pointer (o) + : nullptr; + auto itr = d_str.begin(); + auto last_pos = itr; // copy input to output replacing strings as we go - while (mxn-- > 0 && begin <= nchars) { // maximum number of replaces - - if (prog.is_empty() || prog.find(prog_idx, d_str, begin, end) <= 0) { - break; // no more matches - } - - auto const start_pos = d_str.byte_offset(begin); // get offset for these - auto const end_pos = d_str.byte_offset(end); // character position values - nbytes += d_repl.size_bytes() - (end_pos - start_pos); // and compute new size - - if (out_ptr) { // replace: - // i:bbbbsssseeee - out_ptr = copy_and_increment(out_ptr, // ^ - in_ptr + last_pos, // o:bbbb - start_pos - last_pos); // ^ - out_ptr = copy_string(out_ptr, d_repl); // o:bbbbrrrrrr - // out_ptr ---^ - last_pos = end_pos; // i:bbbbsssseeee - } // in_ptr --^ - - begin = end + (begin == end); - end = -1; + while (mxn-- > 0 && itr.position() <= nchars && !prog.is_empty()) { + auto const match = prog.find(prog_idx, d_str, itr); + if (!match) { break; } // no more matches + + auto const [start_pos, end_pos] = match_positions_to_bytes(*match, d_str, last_pos); + nbytes += d_repl.size_bytes() - (end_pos - start_pos); // add new size + + if (out_ptr) { // replace: + // i:bbbbsssseeee + out_ptr = copy_and_increment(out_ptr, // ^ + in_ptr + last_pos.byte_offset(), // o:bbbb + start_pos - last_pos.byte_offset()); // ^ + out_ptr = copy_string(out_ptr, d_repl); // o:bbbbrrrrrr + } // out_ptr ---^ + last_pos += (match->second - last_pos.position()); // i:bbbbsssseeee + // in_ptr --^ + + itr = last_pos + (match->first == match->second); } if (out_ptr) { - memcpy(out_ptr, // copy the remainder - in_ptr + last_pos, // o:bbbbrrrrrreeee - d_str.size_bytes() - last_pos); // ^ ^ + thrust::copy_n(thrust::seq, // copy the remainder + in_ptr + last_pos.byte_offset(), // o:bbbbrrrrrreeee + d_str.size_bytes() - last_pos.byte_offset(), // ^ ^ + out_ptr); } else { - d_offsets[idx] = static_cast(nbytes); + d_offsets[idx] = nbytes; } } }; diff --git a/cpp/src/strings/search/findall.cu b/cpp/src/strings/search/findall.cu index 0c8359928a5..596fbb39d15 100644 --- a/cpp/src/strings/search/findall.cu +++ b/cpp/src/strings/search/findall.cu @@ -62,16 +62,15 @@ struct findall_fn { auto d_output = d_indices + d_offsets[idx]; size_type output_idx = 0; - size_type begin = 0; - size_type end = nchars; - while ((begin < end) && (prog.find(prog_idx, d_str, begin, end) > 0)) { - auto const spos = d_str.byte_offset(begin); // convert - auto const epos = d_str.byte_offset(end); // to bytes + auto itr = d_str.begin(); + while (itr.position() < nchars) { + auto const match = prog.find(prog_idx, d_str, itr); + if (!match) { break; } - d_output[output_idx++] = string_index_pair{d_str.data() + spos, (epos - spos)}; + auto const d_result = string_from_match(*match, d_str, itr); + d_output[output_idx++] = string_index_pair{d_result.data(), d_result.size_bytes()}; - begin = end + (begin == end); - end = nchars; + itr += (match->second - itr.position()); } } }; diff --git a/cpp/src/strings/split/split_re.cu b/cpp/src/strings/split/split_re.cu index 25fe4d00336..f0829eb08ba 100644 --- a/cpp/src/strings/split/split_re.cu +++ b/cpp/src/strings/split/split_re.cu @@ -66,20 +66,25 @@ struct token_reader_fn { __device__ void operator()(size_type const idx, reprog_device const prog, int32_t const prog_idx) { if (d_strings.is_null(idx)) { return; } - auto const d_str = d_strings.element(idx); + auto const d_str = d_strings.element(idx); + auto const nchars = d_str.length(); auto const token_offset = d_token_offsets[idx]; auto const token_count = d_token_offsets[idx + 1] - token_offset; auto const d_result = d_tokens + token_offset; // store tokens here size_type token_idx = 0; - size_type begin = 0; // characters - size_type end = -1; - size_type last_pos = 0; // bytes - while (prog.find(prog_idx, d_str, begin, end) > 0) { + auto itr = d_str.begin(); + auto last_pos = itr; + while (itr.position() <= nchars) { + auto const match = prog.find(prog_idx, d_str, itr); + if (!match) { break; } + + auto const [start_pos, end_pos] = match_positions_to_bytes(*match, d_str, last_pos); + // get the token (characters just before this match) - auto const token = - string_index_pair{d_str.data() + last_pos, d_str.byte_offset(begin) - last_pos}; + auto const token = string_index_pair{d_str.data() + last_pos.byte_offset(), + start_pos - last_pos.byte_offset()}; // store it if we have space if (token_idx < token_count - 1) { d_result[token_idx++] = token; @@ -91,13 +96,13 @@ struct token_reader_fn { d_result[token_idx - 1] = token; } // setup for next match - last_pos = d_str.byte_offset(end); - begin = end + (begin == end); - end = -1; + last_pos += (match->second - last_pos.position()); + itr = last_pos + (match->first == match->second); } // set the last token to the remainder of the string - d_result[token_idx] = string_index_pair{d_str.data() + last_pos, d_str.size_bytes() - last_pos}; + d_result[token_idx] = string_index_pair{d_str.data() + last_pos.byte_offset(), + d_str.size_bytes() - last_pos.byte_offset()}; if (direction == split_direction::BACKWARD) { // update first entry -- this happens when max_tokens is hit before the end of the string diff --git a/cpp/tests/strings/extract_tests.cpp b/cpp/tests/strings/extract_tests.cpp index 312341d6559..70112f7ca75 100644 --- a/cpp/tests/strings/extract_tests.cpp +++ b/cpp/tests/strings/extract_tests.cpp @@ -226,7 +226,7 @@ TEST_F(StringsExtractTests, EmptyExtractTest) TEST_F(StringsExtractTests, ExtractAllTest) { std::vector h_input( - {"123 banana 7 eleven", "41 apple", "6 pear 0 pair", nullptr, "", "bees", "4 pare"}); + {"123 banana 7 eleven", "41 apple", "6 péar 0 pair", nullptr, "", "bees", "4 paré"}); auto validity = thrust::make_transform_iterator(h_input.begin(), [](auto str) { return str != nullptr; }); cudf::test::strings_column_wrapper input(h_input.begin(), h_input.end(), validity); @@ -238,11 +238,11 @@ TEST_F(StringsExtractTests, ExtractAllTest) using LCW = cudf::test::lists_column_wrapper; LCW expected({LCW{"123", "banana", "7", "eleven"}, LCW{"41", "apple"}, - LCW{"6", "pear", "0", "pair"}, + LCW{"6", "péar", "0", "pair"}, LCW{}, LCW{}, LCW{}, - LCW{"4", "pare"}}, + LCW{"4", "paré"}}, valids); auto prog = cudf::strings::regex_program::create(pattern); auto results = cudf::strings::extract_all_record(sv, *prog); diff --git a/cpp/tests/strings/findall_tests.cpp b/cpp/tests/strings/findall_tests.cpp index c7eddb69ee7..fe27beed197 100644 --- a/cpp/tests/strings/findall_tests.cpp +++ b/cpp/tests/strings/findall_tests.cpp @@ -69,12 +69,12 @@ TEST_F(StringsFindallTests, Multiline) TEST_F(StringsFindallTests, DotAll) { - cudf::test::strings_column_wrapper input({"abc\nfa\nef", "fff\nabbc\nfff", "abcdef", ""}); + cudf::test::strings_column_wrapper input({"abc\nfa\nef", "fff\nabbc\nfff", "abcdéf", ""}); auto view = cudf::strings_column_view(input); auto pattern = std::string("(b.*f)"); using LCW = cudf::test::lists_column_wrapper; - LCW expected({LCW{"bc\nfa\nef"}, LCW{"bbc\nfff"}, LCW{"bcdef"}, LCW{}}); + LCW expected({LCW{"bc\nfa\nef"}, LCW{"bbc\nfff"}, LCW{"bcdéf"}, LCW{}}); auto prog = cudf::strings::regex_program::create(pattern, cudf::strings::regex_flags::DOTALL); auto results = cudf::strings::findall(view, *prog); CUDF_TEST_EXPECT_COLUMNS_EQUIVALENT(results->view(), expected);