Skip to content

Commit

Permalink
Add cudf::strings::extract_all API (#9909)
Browse files Browse the repository at this point in the history
Closes #9856 

Adds a new `cudf::strings::extract_all` API that returns a LIST column of extracted strings given a regex pattern.

This is similar to nvstrings version of `extract` called `extract_record` but returns groups from all matches in each string instead of just the first match. Here is pseudo code of it's behavior on various strings input:
```
s = [ "ABC-200 DEF-400", "GHI-60", "JK-800", "900", NULL ]
r =  extract_all( s, "'(\w+)-(\d+)" )
r is a LIST column of strings that looks like this:

[ [ "ABC", "200", "DEF", "400" ], // 2 matches
  [ "GHI", "60" ], // 1 match
  [ "JK", "800" ], // 1 match
  NULL,            // no match
  NULL
]
```
Each match results in two groups as specified in the regex pattern.

Also reorganized the extract source code into `src/strings/extract` directory.
The match-counting has been factored out into new `count_matches.cuh` since it will become common code used with `findall_record` in a follow on PR.

Authors:
  - David Wendt (https://github.com/davidwendt)

Approvers:
  - Robert Maynard (https://github.com/robertmaynard)
  - Bradley Dice (https://github.com/bdice)
  - Mike Wilson (https://github.com/hyperbolic2346)

URL: #9909
  • Loading branch information
davidwendt authored Jan 5, 2022
1 parent 2112757 commit eba4f03
Show file tree
Hide file tree
Showing 7 changed files with 386 additions and 16 deletions.
3 changes: 2 additions & 1 deletion cpp/CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -417,7 +417,8 @@ add_library(
src/strings/copying/concatenate.cu
src/strings/copying/copying.cu
src/strings/copying/shift.cu
src/strings/extract.cu
src/strings/extract/extract.cu
src/strings/extract/extract_all.cu
src/strings/filling/fill.cu
src/strings/filter_chars.cu
src/strings/findall.cu
Expand Down
8 changes: 6 additions & 2 deletions cpp/include/cudf/strings/detail/strings_column_factories.cuh
Original file line number Diff line number Diff line change
Expand Up @@ -33,6 +33,12 @@ namespace cudf {
namespace strings {
namespace detail {

/**
* @brief Basic type expected for iterators passed to `make_strings_column` that represent string
* data in device memory.
*/
using string_index_pair = thrust::pair<const char*, size_type>;

/**
* @brief Average string byte-length threshold for deciding character-level
* vs. row-level parallel algorithm.
Expand Down Expand Up @@ -64,8 +70,6 @@ std::unique_ptr<column> make_strings_column(IndexPairIterator begin,
size_type strings_count = thrust::distance(begin, end);
if (strings_count == 0) return make_empty_column(type_id::STRING);

using string_index_pair = thrust::pair<const char*, size_type>;

// check total size is not too large for cudf column
auto size_checker = [] __device__(string_index_pair const& item) {
return (item.first != nullptr) ? item.second : 0;
Expand Down
50 changes: 42 additions & 8 deletions cpp/include/cudf/strings/extract.hpp
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
/*
* Copyright (c) 2019, NVIDIA CORPORATION.
* Copyright (c) 2019-2021, NVIDIA CORPORATION.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
Expand Down Expand Up @@ -27,20 +27,21 @@ namespace strings {
*/

/**
* @brief Returns a vector of strings columns for each matching group specified in the given regular
* expression pattern.
* @brief Returns a table of strings columns where each column corresponds to the matching
* group specified in the given regular expression pattern.
*
* All the strings for the first group will go in the first output column; the second group
* go in the second column and so on. Null entries are added if the string does match.
* go in the second column and so on. Null entries are added to the columns in row `i` if
* the string at row `i` does not match.
*
* Any null string entries return corresponding null output column entries.
*
* @code{.pseudo}
* Example:
* s = ["a1","b2","c3"]
* r = extract(s,"([ab])(\\d)")
* r is now [["a","b",null],
* ["1","2",null]]
* s = ["a1", "b2", "c3"]
* r = extract(s, "([ab])(\\d)")
* r is now [ ["a", "b", null],
* ["1", "2", null] ]
* @endcode
*
* See the @ref md_regex "Regex Features" page for details on patterns supported by this API.
Expand All @@ -55,6 +56,39 @@ std::unique_ptr<table> extract(
std::string const& pattern,
rmm::mr::device_memory_resource* mr = rmm::mr::get_current_device_resource());

/**
* @brief Returns a lists column of strings where each string column row corresponds to the
* matching group specified in the given regular expression pattern.
*
* All the matching groups for the first row will go in the first row output column; the second
* row results will go into the second row output column and so on.
*
* A null output row will result if the corresponding input string row does not match or
* that input row is null.
*
* @code{.pseudo}
* Example:
* s = ["a1 b4", "b2", "c3 a5", "b", null]
* r = extract_all(s,"([ab])(\\d)")
* r is now [ ["a", "1", "b", "4"],
* ["b", "2"],
* ["a", "5"],
* null,
* null ]
* @endcode
*
* See the @ref md_regex "Regex Features" page for details on patterns supported by this API.
*
* @param strings Strings instance for this operation.
* @param pattern The regular expression pattern with group indicators.
* @param mr Device memory resource used to allocate any returned device memory.
* @return Lists column containing strings extracted from the input column.
*/
std::unique_ptr<column> extract_all(
strings_column_view const& strings,
std::string const& pattern,
rmm::mr::device_memory_resource* mr = rmm::mr::get_current_device_resource());

/** @} */ // end of doxygen group
} // namespace strings
} // namespace cudf
105 changes: 105 additions & 0 deletions cpp/src/strings/count_matches.cuh
Original file line number Diff line number Diff line change
@@ -0,0 +1,105 @@
/*
* Copyright (c) 2021, NVIDIA CORPORATION.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/

#pragma once

#include <strings/regex/regex.cuh>

#include <cudf/column/column.hpp>
#include <cudf/column/column_device_view.cuh>
#include <cudf/column/column_factories.hpp>
#include <cudf/strings/string_view.cuh>

#include <rmm/cuda_stream_view.hpp>
#include <rmm/exec_policy.hpp>

#include <thrust/transform.h>

namespace cudf {
namespace strings {
namespace detail {

/**
* @brief Functor counts the total matches to the given regex in each string.
*/
template <int stack_size>
struct count_matches_fn {
column_device_view const d_strings;
reprog_device prog;

__device__ size_type operator()(size_type idx)
{
if (d_strings.is_null(idx)) { return 0; }
size_type count = 0;
auto const d_str = d_strings.element<string_view>(idx);

int32_t begin = 0;
int32_t end = d_str.length();
while ((begin < end) && (prog.find<stack_size>(idx, d_str, begin, end) > 0)) {
++count;
begin = end;
end = d_str.length();
}
return count;
}
};

/**
* @brief Returns a column of regex match counts for each string in the given column.
*
* A null entry will result in a zero count for that output row.
*
* @param d_strings Device view of the input strings column.
* @param d_prog Regex instance to evaluate on each string.
* @param stream CUDA stream used for device memory operations and kernel launches.
* @param mr Device memory resource used to allocate the returned column's device memory.
*/
std::unique_ptr<column> count_matches(
column_device_view const& d_strings,
reprog_device const& d_prog,
rmm::cuda_stream_view stream,
rmm::mr::device_memory_resource* mr = rmm::mr::get_current_device_resource())
{
// Create output column
auto counts = make_numeric_column(
data_type{type_id::INT32}, d_strings.size() + 1, mask_state::UNALLOCATED, stream, mr);
auto d_counts = counts->mutable_view().data<offset_type>();

auto begin = thrust::make_counting_iterator<size_type>(0);
auto end = thrust::make_counting_iterator<size_type>(d_strings.size());

// Count matches
auto const regex_insts = d_prog.insts_counts();
if (regex_insts <= RX_SMALL_INSTS) {
count_matches_fn<RX_STACK_SMALL> fn{d_strings, d_prog};
thrust::transform(rmm::exec_policy(stream), begin, end, d_counts, fn);
} else if (regex_insts <= RX_MEDIUM_INSTS) {
count_matches_fn<RX_STACK_MEDIUM> fn{d_strings, d_prog};
thrust::transform(rmm::exec_policy(stream), begin, end, d_counts, fn);
} else if (regex_insts <= RX_LARGE_INSTS) {
count_matches_fn<RX_STACK_LARGE> fn{d_strings, d_prog};
thrust::transform(rmm::exec_policy(stream), begin, end, d_counts, fn);
} else {
count_matches_fn<RX_STACK_ANY> fn{d_strings, d_prog};
thrust::transform(rmm::exec_policy(stream), begin, end, d_counts, fn);
}

return counts;
}

} // namespace detail
} // namespace strings
} // namespace cudf
File renamed without changes.
Loading

0 comments on commit eba4f03

Please sign in to comment.