From 35241632812f000baffb79a7a1c2457a78f4c652 Mon Sep 17 00:00:00 2001 From: David Wendt Date: Wed, 20 Oct 2021 07:22:24 -0400 Subject: [PATCH 01/18] Add support for single-line regex anchors ^/$ in contains_re --- conda/recipes/libcudf/meta.yaml | 3 +- cpp/include/cudf/strings/contains.hpp | 7 +++ cpp/include/cudf/strings/regex/flags.hpp | 44 +++++++++++++++++ cpp/src/strings/contains.cu | 23 ++++++--- cpp/src/strings/regex/regcomp.cpp | 56 +++++++++++++++++----- cpp/src/strings/regex/regcomp.h | 5 +- cpp/src/strings/regex/regex.cuh | 22 +++++++++ cpp/src/strings/regex/regexec.cu | 13 ++++- cpp/tests/strings/contains_tests.cpp | 60 ++++++++++++++++++++++++ 9 files changed, 210 insertions(+), 23 deletions(-) create mode 100644 cpp/include/cudf/strings/regex/flags.hpp diff --git a/conda/recipes/libcudf/meta.yaml b/conda/recipes/libcudf/meta.yaml index 8ccfdaa4aed..bec603e3cdd 100644 --- a/conda/recipes/libcudf/meta.yaml +++ b/conda/recipes/libcudf/meta.yaml @@ -203,8 +203,9 @@ test: - test -f $PREFIX/include/cudf/strings/find_multiple.hpp - test -f $PREFIX/include/cudf/strings/json.hpp - test -f $PREFIX/include/cudf/strings/padding.hpp + - test -f $PREFIX/include/cudf/strings/regex/flags.hpp - test -f $PREFIX/include/cudf/strings/repeat_strings.hpp - - test -f $PREFIX/include/cudf/strings/replace.hpp + - test -f $PREFIX/include/cudf/strings/replace.hpp - test -f $PREFIX/include/cudf/strings/replace_re.hpp - test -f $PREFIX/include/cudf/strings/split/partition.hpp - test -f $PREFIX/include/cudf/strings/split/split.hpp diff --git a/cpp/include/cudf/strings/contains.hpp b/cpp/include/cudf/strings/contains.hpp index a650fdc239a..a9de166a240 100644 --- a/cpp/include/cudf/strings/contains.hpp +++ b/cpp/include/cudf/strings/contains.hpp @@ -16,6 +16,7 @@ #pragma once #include +#include #include namespace cudf { @@ -44,12 +45,14 @@ namespace strings { * * @param strings Strings instance for this operation. * @param pattern Regex pattern to match to each string. + * @param flags Regex flags for processing the pattern * @param mr Device memory resource used to allocate the returned column's device memory. * @return New column of boolean results for each string. */ std::unique_ptr contains_re( strings_column_view const& strings, std::string const& pattern, + regex_flags const flags = regex_flags::DEFAULT, rmm::mr::device_memory_resource* mr = rmm::mr::get_current_device_resource()); /** @@ -69,12 +72,14 @@ std::unique_ptr contains_re( * * @param strings Strings instance for this operation. * @param pattern Regex pattern to match to each string. + * @param flags Regex flags for processing the pattern * @param mr Device memory resource used to allocate the returned column's device memory. * @return New column of boolean results for each string. */ std::unique_ptr matches_re( strings_column_view const& strings, std::string const& pattern, + regex_flags const flags = regex_flags::DEFAULT, rmm::mr::device_memory_resource* mr = rmm::mr::get_current_device_resource()); /** @@ -94,12 +99,14 @@ std::unique_ptr matches_re( * * @param strings Strings instance for this operation. * @param pattern Regex pattern to match within each string. + * @param flags Regex flags for processing the pattern * @param mr Device memory resource used to allocate the returned column's device memory. * @return New INT32 column with counts for each string. */ std::unique_ptr count_re( strings_column_view const& strings, std::string const& pattern, + regex_flags const flags = regex_flags::DEFAULT, rmm::mr::device_memory_resource* mr = rmm::mr::get_current_device_resource()); /** @} */ // end of doxygen group diff --git a/cpp/include/cudf/strings/regex/flags.hpp b/cpp/include/cudf/strings/regex/flags.hpp new file mode 100644 index 00000000000..608f66d8ce2 --- /dev/null +++ b/cpp/include/cudf/strings/regex/flags.hpp @@ -0,0 +1,44 @@ +/* + * Copyright (c) 2021, NVIDIA CORPORATION. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +#pragma once + +#include + +namespace cudf { +namespace strings { + +/** + * @addtogroup strings_contains + * @{ + */ + +/** + * @brief Regex flags. + * + * These types can be or'd to combine them. + */ +enum regex_flags : uint32_t { + DEFAULT = 0, /// default + SINGLE_LINE = 1, /// the '^' and '$' ignore new-line characters + DOT_ALL = 2 /// the '.' matching includes new-line characters +}; + +#define IS_SINGLE_LINE(f) ((f & regex_flags::SINGLE_LINE) == regex_flags::SINGLE_LINE) +#define IS_DOT_ALL(f) ((f & regex_flags::DOT_ALL) == regex_flags::DOT_ALL) + +/** @} */ // end of doxygen group +} // namespace strings +} // namespace cudf diff --git a/cpp/src/strings/contains.cu b/cpp/src/strings/contains.cu index 628dbcb8755..9376a0082a8 100644 --- a/cpp/src/strings/contains.cu +++ b/cpp/src/strings/contains.cu @@ -66,6 +66,7 @@ struct contains_fn { std::unique_ptr contains_util( strings_column_view const& strings, std::string const& pattern, + regex_flags const flags, bool beginning_only = false, rmm::cuda_stream_view stream = rmm::cuda_stream_default, rmm::mr::device_memory_resource* mr = rmm::mr::get_current_device_resource()) @@ -75,7 +76,8 @@ std::unique_ptr contains_util( auto d_column = *strings_column; // compile regex into device object - auto prog = reprog_device::create(pattern, get_character_flags_table(), strings_count, stream); + auto prog = + reprog_device::create(pattern, flags, get_character_flags_table(), strings_count, stream); auto d_prog = *prog; // create the output column @@ -123,19 +125,21 @@ std::unique_ptr contains_util( std::unique_ptr contains_re( strings_column_view const& strings, std::string const& pattern, + regex_flags const flags, rmm::cuda_stream_view stream, rmm::mr::device_memory_resource* mr = rmm::mr::get_current_device_resource()) { - return contains_util(strings, pattern, false, stream, mr); + return contains_util(strings, pattern, flags, false, stream, mr); } std::unique_ptr matches_re( strings_column_view const& strings, std::string const& pattern, + regex_flags const flags, rmm::cuda_stream_view stream, rmm::mr::device_memory_resource* mr = rmm::mr::get_current_device_resource()) { - return contains_util(strings, pattern, true, stream, mr); + return contains_util(strings, pattern, flags, true, stream, mr); } } // namespace detail @@ -144,18 +148,20 @@ std::unique_ptr matches_re( std::unique_ptr contains_re(strings_column_view const& strings, std::string const& pattern, + regex_flags const flags, rmm::mr::device_memory_resource* mr) { CUDF_FUNC_RANGE(); - return detail::contains_re(strings, pattern, rmm::cuda_stream_default, mr); + return detail::contains_re(strings, pattern, flags, rmm::cuda_stream_default, mr); } std::unique_ptr matches_re(strings_column_view const& strings, std::string const& pattern, + regex_flags const flags, rmm::mr::device_memory_resource* mr) { CUDF_FUNC_RANGE(); - return detail::matches_re(strings, pattern, rmm::cuda_stream_default, mr); + return detail::matches_re(strings, pattern, flags, rmm::cuda_stream_default, mr); } namespace detail { @@ -190,6 +196,7 @@ struct count_fn { std::unique_ptr count_re( strings_column_view const& strings, std::string const& pattern, + regex_flags const flags, rmm::cuda_stream_view stream, rmm::mr::device_memory_resource* mr = rmm::mr::get_current_device_resource()) { @@ -198,7 +205,8 @@ std::unique_ptr count_re( auto d_column = *strings_column; // compile regex into device object - auto prog = reprog_device::create(pattern, get_character_flags_table(), strings_count, stream); + auto prog = + reprog_device::create(pattern, flags, get_character_flags_table(), strings_count, stream); auto d_prog = *prog; // create the output column @@ -247,10 +255,11 @@ std::unique_ptr count_re( std::unique_ptr count_re(strings_column_view const& strings, std::string const& pattern, + regex_flags const flags, rmm::mr::device_memory_resource* mr) { CUDF_FUNC_RANGE(); - return detail::count_re(strings, pattern, rmm::cuda_stream_default, mr); + return detail::count_re(strings, pattern, flags, rmm::cuda_stream_default, mr); } } // namespace strings diff --git a/cpp/src/strings/regex/regcomp.cpp b/cpp/src/strings/regex/regcomp.cpp index 0e00221dabf..9d51bc5325c 100644 --- a/cpp/src/strings/regex/regcomp.cpp +++ b/cpp/src/strings/regex/regcomp.cpp @@ -15,6 +15,7 @@ */ #include + #include #include @@ -523,6 +524,8 @@ class regex_compiler { bool lastwasand; int nbra; + regex_flags flags; + inline void pushand(int f, int l) { andstack.push_back({f, l}); } inline Node popand(int op) @@ -567,11 +570,11 @@ class regex_compiler { case LBRA: /* must have been RBRA */ op1 = popand('('); id_inst2 = m_prog.add_inst(RBRA); - m_prog.inst_at(id_inst2).u1.subid = ator.subid; // subidstack[subidstack.size()-1]; + m_prog.inst_at(id_inst2).u1.subid = ator.subid; m_prog.inst_at(op1.id_last).u2.next_id = id_inst2; id_inst1 = m_prog.add_inst(LBRA); - m_prog.inst_at(id_inst1).u1.subid = ator.subid; // subidstack[subidstack.size() - 1]; - m_prog.inst_at(id_inst1).u2.next_id = op1.id_first; + m_prog.inst_at(id_inst1).u1.subid = ator.subid; + m_prog.inst_at(id_inst1).u2.next_id = op1.id_first; pushand(id_inst1, id_inst2); return; case OR: @@ -664,10 +667,13 @@ class regex_compiler { { if (lastwasand) Operator(CAT); /* catenate is implicit */ int inst_id = m_prog.add_inst(t); - if (t == CCLASS || t == NCCLASS) + if (t == CCLASS || t == NCCLASS) { m_prog.inst_at(inst_id).u1.cls_id = yyclass_id; - else if (t == CHAR || t == BOL || t == EOL) + } else if (t == CHAR) { m_prog.inst_at(inst_id).u1.c = yy; + } else if (t == BOL || t == EOL) { + m_prog.inst_at(inst_id).u1.c = IS_SINGLE_LINE(flags) ? '\n' : yy; + } pushand(inst_id, inst_id); lastwasand = true; } @@ -766,13 +772,20 @@ class regex_compiler { } public: - regex_compiler(const char32_t* pattern, int dot_type, reprog& prog) - : m_prog(prog), cursubid(0), pushsubid(0), lastwasand(false), nbra(0), yy(0), yyclass_id(0) + regex_compiler(const char32_t* pattern, regex_flags const flags, reprog& prog) + : m_prog(prog), + cursubid(0), + pushsubid(0), + lastwasand(false), + nbra(0), + flags(flags), + yy(0), + yyclass_id(0) { // Parse std::vector items; { - regex_parser parser(pattern, dot_type, m_prog); + regex_parser parser(pattern, IS_DOT_ALL(flags) ? ANYNL : ANY, m_prog); // Expand counted repetitions if (parser.m_has_counted) @@ -822,11 +835,12 @@ class regex_compiler { }; // Convert pattern into program -reprog reprog::create_from(const char32_t* pattern) +reprog reprog::create_from(const char32_t* pattern, regex_flags const flags) { reprog rtn; - regex_compiler compiler(pattern, ANY, rtn); // future feature: ANYNL - // rtn->print(); + // regex_compiler compiler(pattern, ANY, rtn); // future feature: ANYNL + regex_compiler compiler(pattern, flags, rtn); + if (std::getenv("CUDF_REGEX_DEBUG")) rtn.print(); return rtn; } @@ -941,8 +955,24 @@ void reprog::print() case ANY: printf("ANY, nextid= %d", inst.u2.next_id); break; case ANYNL: printf("ANYNL, nextid= %d", inst.u2.next_id); break; case NOP: printf("NOP, nextid= %d", inst.u2.next_id); break; - case BOL: printf("BOL, c = '%c', nextid= %d", inst.u1.c, inst.u2.next_id); break; - case EOL: printf("EOL, c = '%c', nextid= %d", inst.u1.c, inst.u2.next_id); break; + case BOL: { + printf("BOL, c = "); + if (inst.u1.c == '\n') + printf("'\\n'"); + else + printf("'%c'", inst.u1.c); + printf(", nextid= %d", inst.u2.next_id); + break; + } + case EOL: { + printf("EOL, c = "); + if (inst.u1.c == '\n') + printf("'\\n'"); + else + printf("'%c'", inst.u1.c); + printf(", nextid= %d", inst.u2.next_id); + break; + } case CCLASS: printf("CCLASS, cls_id=%d , nextid= %d", inst.u1.cls_id, inst.u2.next_id); break; case NCCLASS: printf("NCCLASS, cls_id=%d , nextid= %d", inst.u1.cls_id, inst.u2.next_id); diff --git a/cpp/src/strings/regex/regcomp.h b/cpp/src/strings/regex/regcomp.h index 90bbc90f622..ea44987bcfb 100644 --- a/cpp/src/strings/regex/regcomp.h +++ b/cpp/src/strings/regex/regcomp.h @@ -14,6 +14,9 @@ * limitations under the License. */ #pragma once + +#include + #include #include @@ -89,7 +92,7 @@ class reprog { * @brief Parses the given regex pattern and compiles * into a list of chained instructions. */ - static reprog create_from(const char32_t* pattern); + static reprog create_from(const char32_t* pattern, regex_flags const flags); int32_t add_inst(int32_t type); int32_t add_inst(reinst inst); diff --git a/cpp/src/strings/regex/regex.cuh b/cpp/src/strings/regex/regex.cuh index 564f742b2cd..20be1829c2e 100644 --- a/cpp/src/strings/regex/regex.cuh +++ b/cpp/src/strings/regex/regex.cuh @@ -17,6 +17,7 @@ #include +#include #include #include @@ -108,6 +109,27 @@ class reprog_device { int32_t strings_count, rmm::cuda_stream_view stream); + /** + * @brief Create device program instance from a regex pattern. + * + * The number of strings is needed to compute the state data size required when evaluating the + * regex. + * + * @param pattern The regex pattern to compile. + * @param cp_flags The code-point lookup table for character types. + * @param strings_count Number of strings that will be evaluated. + * @param stream CUDA stream for asynchronous memory allocations. To ensure correct + * synchronization on destruction, the same stream should be used for all operations with the + * created objects. + * @return The program device object. + */ + static std::unique_ptr> create( + std::string const& pattern, + regex_flags const re_flags, + const uint8_t* cp_flags, + int32_t strings_count, + rmm::cuda_stream_view stream); + /** * @brief Called automatically by the unique_ptr returned from create(). */ diff --git a/cpp/src/strings/regex/regexec.cu b/cpp/src/strings/regex/regexec.cu index bd040eecaa6..8664a9fb5b5 100644 --- a/cpp/src/strings/regex/regexec.cu +++ b/cpp/src/strings/regex/regexec.cu @@ -72,16 +72,27 @@ reprog_device::reprog_device(reprog& prog) { } +std::unique_ptr> reprog_device::create( + std::string const& pattern, + uint8_t const* codepoint_flags, + int32_t strings_count, + rmm::cuda_stream_view stream) +{ + return reprog_device::create( + pattern, regex_flags::DEFAULT, codepoint_flags, strings_count, stream); +} + // Create instance of the reprog that can be passed into a device kernel std::unique_ptr> reprog_device::create( std::string const& pattern, + regex_flags const flags, uint8_t const* codepoint_flags, int32_t strings_count, rmm::cuda_stream_view stream) { std::vector pattern32 = string_to_char32_vector(pattern); // compile pattern into host object - reprog h_prog = reprog::create_from(pattern32.data()); + reprog h_prog = reprog::create_from(pattern32.data(), flags); // compute size to hold all the member data auto insts_count = h_prog.insts_count(); auto classes_count = h_prog.classes_count(); diff --git a/cpp/tests/strings/contains_tests.cpp b/cpp/tests/strings/contains_tests.cpp index ddd6fc9e1dc..48b509ad10a 100644 --- a/cpp/tests/strings/contains_tests.cpp +++ b/cpp/tests/strings/contains_tests.cpp @@ -275,6 +275,66 @@ TEST_F(StringsContainsTests, CountTest) } } +TEST_F(StringsContainsTests, SingleLine) +{ + auto input = cudf::test::strings_column_wrapper({"abc\nfff\nabc", "fff\nabc\nlll", "abc", ""}); + auto view = cudf::strings_column_view(input); + + auto results = cudf::strings::contains_re(view, "^abc$", cudf::strings::regex_flags::SINGLE_LINE); + auto expected_contains = cudf::test::fixed_width_column_wrapper({0, 0, 1, 0}); + CUDF_TEST_EXPECT_COLUMNS_EQUIVALENT(*results, expected_contains); + results = cudf::strings::contains_re(view, "^abc$"); + expected_contains = cudf::test::fixed_width_column_wrapper({1, 1, 1, 0}); + CUDF_TEST_EXPECT_COLUMNS_EQUIVALENT(*results, expected_contains); + + results = cudf::strings::matches_re(view, "^abc$", cudf::strings::regex_flags::SINGLE_LINE); + auto expected_matches = cudf::test::fixed_width_column_wrapper({0, 0, 1, 0}); + CUDF_TEST_EXPECT_COLUMNS_EQUIVALENT(*results, expected_matches); + results = cudf::strings::matches_re(view, "^abc$"); + expected_matches = cudf::test::fixed_width_column_wrapper({1, 0, 1, 0}); + CUDF_TEST_EXPECT_COLUMNS_EQUIVALENT(*results, expected_matches); + + results = cudf::strings::count_re(view, "^abc$", cudf::strings::regex_flags::SINGLE_LINE); + auto expected_count = cudf::test::fixed_width_column_wrapper({0, 0, 1, 0}); + CUDF_TEST_EXPECT_COLUMNS_EQUIVALENT(*results, expected_count); + results = cudf::strings::count_re(view, "^abc$"); + expected_count = cudf::test::fixed_width_column_wrapper({2, 1, 1, 0}); + CUDF_TEST_EXPECT_COLUMNS_EQUIVALENT(*results, expected_count); +} + +TEST_F(StringsContainsTests, DotAll) +{ + auto input = cudf::test::strings_column_wrapper({"abc\nfa\nef", "fff\nabbc\nfff", "abcdef", ""}); + auto view = cudf::strings_column_view(input); + + auto results = cudf::strings::contains_re(view, "a.*f", cudf::strings::regex_flags::DOT_ALL); + auto expected_contains = cudf::test::fixed_width_column_wrapper({1, 1, 1, 0}); + CUDF_TEST_EXPECT_COLUMNS_EQUIVALENT(*results, expected_contains); + results = cudf::strings::contains_re(view, "a.*f"); + expected_contains = cudf::test::fixed_width_column_wrapper({0, 0, 1, 0}); + CUDF_TEST_EXPECT_COLUMNS_EQUIVALENT(*results, expected_contains); + + results = cudf::strings::matches_re(view, "a.*f", cudf::strings::regex_flags::DOT_ALL); + auto expected_matches = cudf::test::fixed_width_column_wrapper({1, 0, 1, 0}); + CUDF_TEST_EXPECT_COLUMNS_EQUIVALENT(*results, expected_matches); + results = cudf::strings::matches_re(view, "a.*f"); + expected_matches = cudf::test::fixed_width_column_wrapper({0, 0, 1, 0}); + CUDF_TEST_EXPECT_COLUMNS_EQUIVALENT(*results, expected_matches); + + results = cudf::strings::count_re(view, "a.*?f", cudf::strings::regex_flags::DOT_ALL); + auto expected_count = cudf::test::fixed_width_column_wrapper({2, 1, 1, 0}); + CUDF_TEST_EXPECT_COLUMNS_EQUIVALENT(*results, expected_count); + results = cudf::strings::count_re(view, "a.*?f"); + expected_count = cudf::test::fixed_width_column_wrapper({0, 0, 1, 0}); + CUDF_TEST_EXPECT_COLUMNS_EQUIVALENT(*results, expected_count); + + auto both_flags = cudf::strings::regex_flags::DOT_ALL | cudf::strings::regex_flags::SINGLE_LINE; + results = + cudf::strings::count_re(view, "a.*?f", static_cast(both_flags)); + expected_count = cudf::test::fixed_width_column_wrapper({2, 1, 1, 0}); + CUDF_TEST_EXPECT_COLUMNS_EQUIVALENT(*results, expected_count); +} + TEST_F(StringsContainsTests, MediumRegex) { // This results in 95 regex instructions and falls in the 'medium' range. From bfb7b6c577157debca500db51c44e53d088259ff Mon Sep 17 00:00:00 2001 From: David Wendt Date: Wed, 20 Oct 2021 08:19:22 -0400 Subject: [PATCH 02/18] fix meta.yaml --- conda/recipes/libcudf/meta.yaml | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/conda/recipes/libcudf/meta.yaml b/conda/recipes/libcudf/meta.yaml index bec603e3cdd..b10a2a2ca54 100644 --- a/conda/recipes/libcudf/meta.yaml +++ b/conda/recipes/libcudf/meta.yaml @@ -205,7 +205,7 @@ test: - test -f $PREFIX/include/cudf/strings/padding.hpp - test -f $PREFIX/include/cudf/strings/regex/flags.hpp - test -f $PREFIX/include/cudf/strings/repeat_strings.hpp - - test -f $PREFIX/include/cudf/strings/replace.hpp + - test -f $PREFIX/include/cudf/strings/replace.hpp - test -f $PREFIX/include/cudf/strings/replace_re.hpp - test -f $PREFIX/include/cudf/strings/split/partition.hpp - test -f $PREFIX/include/cudf/strings/split/split.hpp From 8690abea84dab6a598c98346238f280b0523c66a Mon Sep 17 00:00:00 2001 From: David Wendt Date: Wed, 20 Oct 2021 09:54:12 -0400 Subject: [PATCH 03/18] change default to !multiline --- cpp/include/cudf/strings/regex/flags.hpp | 10 ++++----- cpp/src/strings/regex/regcomp.cpp | 4 ++-- cpp/src/strings/regex/regexec.cu | 2 +- cpp/tests/strings/contains_tests.cpp | 28 ++++++++++++------------ 4 files changed, 22 insertions(+), 22 deletions(-) diff --git a/cpp/include/cudf/strings/regex/flags.hpp b/cpp/include/cudf/strings/regex/flags.hpp index 608f66d8ce2..65b16b24b41 100644 --- a/cpp/include/cudf/strings/regex/flags.hpp +++ b/cpp/include/cudf/strings/regex/flags.hpp @@ -31,13 +31,13 @@ namespace strings { * These types can be or'd to combine them. */ enum regex_flags : uint32_t { - DEFAULT = 0, /// default - SINGLE_LINE = 1, /// the '^' and '$' ignore new-line characters - DOT_ALL = 2 /// the '.' matching includes new-line characters + DEFAULT = 0, /// default + MULTILINE = 1, /// the '^' and '$' honor new-line characters + DOTALL = 2 /// the '.' matching includes new-line characters }; -#define IS_SINGLE_LINE(f) ((f & regex_flags::SINGLE_LINE) == regex_flags::SINGLE_LINE) -#define IS_DOT_ALL(f) ((f & regex_flags::DOT_ALL) == regex_flags::DOT_ALL) +#define IS_MULTILINE(f) ((f & regex_flags::MULTILINE) == regex_flags::MULTILINE) +#define IS_DOTALL(f) ((f & regex_flags::DOTALL) == regex_flags::DOTALL) /** @} */ // end of doxygen group } // namespace strings diff --git a/cpp/src/strings/regex/regcomp.cpp b/cpp/src/strings/regex/regcomp.cpp index 9d51bc5325c..a2f259cacdb 100644 --- a/cpp/src/strings/regex/regcomp.cpp +++ b/cpp/src/strings/regex/regcomp.cpp @@ -672,7 +672,7 @@ class regex_compiler { } else if (t == CHAR) { m_prog.inst_at(inst_id).u1.c = yy; } else if (t == BOL || t == EOL) { - m_prog.inst_at(inst_id).u1.c = IS_SINGLE_LINE(flags) ? '\n' : yy; + m_prog.inst_at(inst_id).u1.c = IS_MULTILINE(flags) ? yy : '\n'; } pushand(inst_id, inst_id); lastwasand = true; @@ -785,7 +785,7 @@ class regex_compiler { // Parse std::vector items; { - regex_parser parser(pattern, IS_DOT_ALL(flags) ? ANYNL : ANY, m_prog); + regex_parser parser(pattern, IS_DOTALL(flags) ? ANYNL : ANY, m_prog); // Expand counted repetitions if (parser.m_has_counted) diff --git a/cpp/src/strings/regex/regexec.cu b/cpp/src/strings/regex/regexec.cu index 8664a9fb5b5..604cc5b5264 100644 --- a/cpp/src/strings/regex/regexec.cu +++ b/cpp/src/strings/regex/regexec.cu @@ -79,7 +79,7 @@ std::unique_ptr> reprog_devic rmm::cuda_stream_view stream) { return reprog_device::create( - pattern, regex_flags::DEFAULT, codepoint_flags, strings_count, stream); + pattern, regex_flags::MULTILINE, codepoint_flags, strings_count, stream); } // Create instance of the reprog that can be passed into a device kernel diff --git a/cpp/tests/strings/contains_tests.cpp b/cpp/tests/strings/contains_tests.cpp index 48b509ad10a..0eabd0d144e 100644 --- a/cpp/tests/strings/contains_tests.cpp +++ b/cpp/tests/strings/contains_tests.cpp @@ -275,30 +275,30 @@ TEST_F(StringsContainsTests, CountTest) } } -TEST_F(StringsContainsTests, SingleLine) +TEST_F(StringsContainsTests, MultiLine) { auto input = cudf::test::strings_column_wrapper({"abc\nfff\nabc", "fff\nabc\nlll", "abc", ""}); auto view = cudf::strings_column_view(input); - auto results = cudf::strings::contains_re(view, "^abc$", cudf::strings::regex_flags::SINGLE_LINE); - auto expected_contains = cudf::test::fixed_width_column_wrapper({0, 0, 1, 0}); + auto results = cudf::strings::contains_re(view, "^abc$", cudf::strings::regex_flags::MULTILINE); + auto expected_contains = cudf::test::fixed_width_column_wrapper({1, 1, 1, 0}); CUDF_TEST_EXPECT_COLUMNS_EQUIVALENT(*results, expected_contains); results = cudf::strings::contains_re(view, "^abc$"); - expected_contains = cudf::test::fixed_width_column_wrapper({1, 1, 1, 0}); + expected_contains = cudf::test::fixed_width_column_wrapper({0, 0, 1, 0}); CUDF_TEST_EXPECT_COLUMNS_EQUIVALENT(*results, expected_contains); - results = cudf::strings::matches_re(view, "^abc$", cudf::strings::regex_flags::SINGLE_LINE); - auto expected_matches = cudf::test::fixed_width_column_wrapper({0, 0, 1, 0}); + results = cudf::strings::matches_re(view, "^abc$", cudf::strings::regex_flags::MULTILINE); + auto expected_matches = cudf::test::fixed_width_column_wrapper({1, 0, 1, 0}); CUDF_TEST_EXPECT_COLUMNS_EQUIVALENT(*results, expected_matches); results = cudf::strings::matches_re(view, "^abc$"); - expected_matches = cudf::test::fixed_width_column_wrapper({1, 0, 1, 0}); + expected_matches = cudf::test::fixed_width_column_wrapper({0, 0, 1, 0}); CUDF_TEST_EXPECT_COLUMNS_EQUIVALENT(*results, expected_matches); - results = cudf::strings::count_re(view, "^abc$", cudf::strings::regex_flags::SINGLE_LINE); - auto expected_count = cudf::test::fixed_width_column_wrapper({0, 0, 1, 0}); + results = cudf::strings::count_re(view, "^abc$", cudf::strings::regex_flags::MULTILINE); + auto expected_count = cudf::test::fixed_width_column_wrapper({2, 1, 1, 0}); CUDF_TEST_EXPECT_COLUMNS_EQUIVALENT(*results, expected_count); results = cudf::strings::count_re(view, "^abc$"); - expected_count = cudf::test::fixed_width_column_wrapper({2, 1, 1, 0}); + expected_count = cudf::test::fixed_width_column_wrapper({0, 0, 1, 0}); CUDF_TEST_EXPECT_COLUMNS_EQUIVALENT(*results, expected_count); } @@ -307,28 +307,28 @@ TEST_F(StringsContainsTests, DotAll) auto input = cudf::test::strings_column_wrapper({"abc\nfa\nef", "fff\nabbc\nfff", "abcdef", ""}); auto view = cudf::strings_column_view(input); - auto results = cudf::strings::contains_re(view, "a.*f", cudf::strings::regex_flags::DOT_ALL); + auto results = cudf::strings::contains_re(view, "a.*f", cudf::strings::regex_flags::DOTALL); auto expected_contains = cudf::test::fixed_width_column_wrapper({1, 1, 1, 0}); CUDF_TEST_EXPECT_COLUMNS_EQUIVALENT(*results, expected_contains); results = cudf::strings::contains_re(view, "a.*f"); expected_contains = cudf::test::fixed_width_column_wrapper({0, 0, 1, 0}); CUDF_TEST_EXPECT_COLUMNS_EQUIVALENT(*results, expected_contains); - results = cudf::strings::matches_re(view, "a.*f", cudf::strings::regex_flags::DOT_ALL); + results = cudf::strings::matches_re(view, "a.*f", cudf::strings::regex_flags::DOTALL); auto expected_matches = cudf::test::fixed_width_column_wrapper({1, 0, 1, 0}); CUDF_TEST_EXPECT_COLUMNS_EQUIVALENT(*results, expected_matches); results = cudf::strings::matches_re(view, "a.*f"); expected_matches = cudf::test::fixed_width_column_wrapper({0, 0, 1, 0}); CUDF_TEST_EXPECT_COLUMNS_EQUIVALENT(*results, expected_matches); - results = cudf::strings::count_re(view, "a.*?f", cudf::strings::regex_flags::DOT_ALL); + results = cudf::strings::count_re(view, "a.*?f", cudf::strings::regex_flags::DOTALL); auto expected_count = cudf::test::fixed_width_column_wrapper({2, 1, 1, 0}); CUDF_TEST_EXPECT_COLUMNS_EQUIVALENT(*results, expected_count); results = cudf::strings::count_re(view, "a.*?f"); expected_count = cudf::test::fixed_width_column_wrapper({0, 0, 1, 0}); CUDF_TEST_EXPECT_COLUMNS_EQUIVALENT(*results, expected_count); - auto both_flags = cudf::strings::regex_flags::DOT_ALL | cudf::strings::regex_flags::SINGLE_LINE; + auto both_flags = cudf::strings::regex_flags::DOTALL | cudf::strings::regex_flags::MULTILINE; results = cudf::strings::count_re(view, "a.*?f", static_cast(both_flags)); expected_count = cudf::test::fixed_width_column_wrapper({2, 1, 1, 0}); From ac43c6a9caffc06705ab8adc915222cbb2b8a708 Mon Sep 17 00:00:00 2001 From: David Wendt Date: Wed, 20 Oct 2021 10:10:11 -0400 Subject: [PATCH 04/18] add regex_flags to contains.pxd --- python/cudf/cudf/_lib/cpp/strings/contains.pxd | 7 +++++++ 1 file changed, 7 insertions(+) diff --git a/python/cudf/cudf/_lib/cpp/strings/contains.pxd b/python/cudf/cudf/_lib/cpp/strings/contains.pxd index bde0b4fdfb7..0bf7d15a705 100644 --- a/python/cudf/cudf/_lib/cpp/strings/contains.pxd +++ b/python/cudf/cudf/_lib/cpp/strings/contains.pxd @@ -6,6 +6,13 @@ from libcpp.string cimport string from cudf._lib.cpp.column.column cimport column from cudf._lib.cpp.column.column_view cimport column_view +cdef extern from "cudf/strings/regex/flags.hpp" \ + namespace "cudf::strings" nogil: + + ctypedef enum regex_flags: + DEFAULT 'cudf::strings::regex_flags::DEFAULT' + MULTILINE 'cudf::strings::regex_flags::MULTILINE' + DOTALL 'cudf::strings::regex_flags::DOTALL' cdef extern from "cudf/strings/contains.hpp" namespace "cudf::strings" nogil: From d1e28e67a88dc053c2603d569215eaa52c5d457d Mon Sep 17 00:00:00 2001 From: David Wendt Date: Wed, 20 Oct 2021 10:56:08 -0400 Subject: [PATCH 05/18] fix isort style check failure in .pxd --- python/cudf/cudf/_lib/cpp/strings/contains.pxd | 1 + 1 file changed, 1 insertion(+) diff --git a/python/cudf/cudf/_lib/cpp/strings/contains.pxd b/python/cudf/cudf/_lib/cpp/strings/contains.pxd index 0bf7d15a705..84afb876222 100644 --- a/python/cudf/cudf/_lib/cpp/strings/contains.pxd +++ b/python/cudf/cudf/_lib/cpp/strings/contains.pxd @@ -6,6 +6,7 @@ from libcpp.string cimport string from cudf._lib.cpp.column.column cimport column from cudf._lib.cpp.column.column_view cimport column_view + cdef extern from "cudf/strings/regex/flags.hpp" \ namespace "cudf::strings" nogil: From d9447efe11ad2011ff740eebfd9903adb93155bd Mon Sep 17 00:00:00 2001 From: David Wendt Date: Wed, 20 Oct 2021 14:55:23 -0400 Subject: [PATCH 06/18] pass regex flags through from python/cython --- cpp/include/cudf/strings/regex/flags.hpp | 6 ++++-- cpp/src/strings/regex/regcomp.cpp | 5 +++-- cpp/src/strings/regex/regcomp.h | 2 +- .../cudf/cudf/_lib/cpp/strings/contains.pxd | 9 ++++++--- python/cudf/cudf/_lib/strings/contains.pyx | 20 +++++++++++++------ python/cudf/cudf/core/column/string.py | 16 +++++++-------- python/cudf/cudf/tests/test_string.py | 2 +- 7 files changed, 36 insertions(+), 24 deletions(-) diff --git a/cpp/include/cudf/strings/regex/flags.hpp b/cpp/include/cudf/strings/regex/flags.hpp index 65b16b24b41..f459c5bdad0 100644 --- a/cpp/include/cudf/strings/regex/flags.hpp +++ b/cpp/include/cudf/strings/regex/flags.hpp @@ -29,11 +29,13 @@ namespace strings { * @brief Regex flags. * * These types can be or'd to combine them. + * The values are chosen to leave room for future flags + * and to match the Python flag values. */ enum regex_flags : uint32_t { DEFAULT = 0, /// default - MULTILINE = 1, /// the '^' and '$' honor new-line characters - DOTALL = 2 /// the '.' matching includes new-line characters + MULTILINE = 8, /// the '^' and '$' honor new-line characters + DOTALL = 16 /// the '.' matching includes new-line characters }; #define IS_MULTILINE(f) ((f & regex_flags::MULTILINE) == regex_flags::MULTILINE) diff --git a/cpp/src/strings/regex/regcomp.cpp b/cpp/src/strings/regex/regcomp.cpp index a2f259cacdb..ab8b01323eb 100644 --- a/cpp/src/strings/regex/regcomp.cpp +++ b/cpp/src/strings/regex/regcomp.cpp @@ -840,7 +840,7 @@ reprog reprog::create_from(const char32_t* pattern, regex_flags const flags) reprog rtn; // regex_compiler compiler(pattern, ANY, rtn); // future feature: ANYNL regex_compiler compiler(pattern, flags, rtn); - if (std::getenv("CUDF_REGEX_DEBUG")) rtn.print(); + if (std::getenv("CUDF_REGEX_DEBUG")) rtn.print(flags); return rtn; } @@ -926,8 +926,9 @@ void reprog::optimize2() _startinst_ids.push_back(-1); // terminator mark } -void reprog::print() +void reprog::print(regex_flags const flags) { + printf("Flags = 0x%02x\n", static_cast(flags)); printf("Instructions:\n"); for (std::size_t i = 0; i < _insts.size(); i++) { const reinst& inst = _insts[i]; diff --git a/cpp/src/strings/regex/regcomp.h b/cpp/src/strings/regex/regcomp.h index ea44987bcfb..63d7933eebe 100644 --- a/cpp/src/strings/regex/regcomp.h +++ b/cpp/src/strings/regex/regcomp.h @@ -116,7 +116,7 @@ class reprog { void optimize1(); void optimize2(); - void print(); // for debugging + void print(regex_flags const flags); private: std::vector _insts; diff --git a/python/cudf/cudf/_lib/cpp/strings/contains.pxd b/python/cudf/cudf/_lib/cpp/strings/contains.pxd index 84afb876222..b48d2f58334 100644 --- a/python/cudf/cudf/_lib/cpp/strings/contains.pxd +++ b/python/cudf/cudf/_lib/cpp/strings/contains.pxd @@ -19,12 +19,15 @@ cdef extern from "cudf/strings/contains.hpp" namespace "cudf::strings" nogil: cdef unique_ptr[column] contains_re( column_view source_strings, - string pattern) except + + string pattern, + regex_flags) except + cdef unique_ptr[column] count_re( column_view source_strings, - string pattern) except + + string pattern, + regex_flags) except + cdef unique_ptr[column] matches_re( column_view source_strings, - string pattern) except + + string pattern, + regex_flags) except + diff --git a/python/cudf/cudf/_lib/strings/contains.pyx b/python/cudf/cudf/_lib/strings/contains.pyx index 1f622378280..f18d0eb7f36 100644 --- a/python/cudf/cudf/_lib/strings/contains.pyx +++ b/python/cudf/cudf/_lib/strings/contains.pyx @@ -1,5 +1,6 @@ # Copyright (c) 2020, NVIDIA CORPORATION. +from libc.stdint cimport uint32_t from libcpp.memory cimport unique_ptr from libcpp.string cimport string from libcpp.utility cimport move @@ -11,11 +12,12 @@ from cudf._lib.cpp.strings.contains cimport ( contains_re as cpp_contains_re, count_re as cpp_count_re, matches_re as cpp_matches_re, + regex_flags as regex_flags, ) from cudf._lib.scalar cimport DeviceScalar -def contains_re(Column source_strings, object reg_ex): +def contains_re(Column source_strings, object reg_ex, uint32_t flags): """ Returns a Column of boolean values with True for `source_strings` that contain regular expression `reg_ex`. @@ -24,17 +26,19 @@ def contains_re(Column source_strings, object reg_ex): cdef column_view source_view = source_strings.view() cdef string reg_ex_string = str(reg_ex).encode() + cdef regex_flags c_flags = flags with nogil: c_result = move(cpp_contains_re( source_view, - reg_ex_string + reg_ex_string, + c_flags )) return Column.from_unique_ptr(move(c_result)) -def count_re(Column source_strings, object reg_ex): +def count_re(Column source_strings, object reg_ex, uint32_t flags): """ Returns a Column with count of occurrences of `reg_ex` in each string of `source_strings` @@ -43,17 +47,19 @@ def count_re(Column source_strings, object reg_ex): cdef column_view source_view = source_strings.view() cdef string reg_ex_string = str(reg_ex).encode() + cdef regex_flags c_flags = flags with nogil: c_result = move(cpp_count_re( source_view, - reg_ex_string + reg_ex_string, + c_flags )) return Column.from_unique_ptr(move(c_result)) -def match_re(Column source_strings, object reg_ex): +def match_re(Column source_strings, object reg_ex, uint32_t flags): """ Returns a Column with each value True if the string matches `reg_ex` regular expression with each record of `source_strings` @@ -62,11 +68,13 @@ def match_re(Column source_strings, object reg_ex): cdef column_view source_view = source_strings.view() cdef string reg_ex_string = str(reg_ex).encode() + cdef regex_flags c_flags = flags with nogil: c_result = move(cpp_matches_re( source_view, - reg_ex_string + reg_ex_string, + c_flags )) return Column.from_unique_ptr(move(c_result)) diff --git a/python/cudf/cudf/core/column/string.py b/python/cudf/cudf/core/column/string.py index 81d4c9adfa1..cedef6f6f7b 100644 --- a/python/cudf/cudf/core/column/string.py +++ b/python/cudf/cudf/core/column/string.py @@ -731,8 +731,6 @@ def contains( """ # noqa W605 if case is not True: raise NotImplementedError("`case` parameter is not yet supported") - elif flags != 0: - raise NotImplementedError("`flags` parameter is not yet supported") elif na is not np.nan: raise NotImplementedError("`na` parameter is not yet supported") @@ -742,7 +740,7 @@ def contains( ) elif is_scalar(pat): if regex is True: - result_col = libstrings.contains_re(self._column, pat) + result_col = libstrings.contains_re(self._column, pat, flags) else: result_col = libstrings.contains( self._column, cudf.Scalar(pat, "str") @@ -3325,10 +3323,10 @@ def count(self, pat: str, flags: int = 0) -> SeriesOrIndex: >>> index.str.count('a') Int64Index([0, 0, 2, 1], dtype='int64') """ # noqa W605 - if flags != 0: - raise NotImplementedError("`flags` parameter is not yet supported") - return self._return_or_inplace(libstrings.count_re(self._column, pat)) + return self._return_or_inplace( + libstrings.count_re(self._column, pat, flags) + ) def findall( self, pat: str, flags: int = 0, expand: bool = True @@ -3879,10 +3877,10 @@ def match( """ if case is not True: raise NotImplementedError("`case` parameter is not yet supported") - if flags != 0: - raise NotImplementedError("`flags` parameter is not yet supported") - return self._return_or_inplace(libstrings.match_re(self._column, pat)) + return self._return_or_inplace( + libstrings.match_re(self._column, pat, flags) + ) def url_decode(self) -> SeriesOrIndex: """ diff --git a/python/cudf/cudf/tests/test_string.py b/python/cudf/cudf/tests/test_string.py index dad0e7581d7..546707f5667 100644 --- a/python/cudf/cudf/tests/test_string.py +++ b/python/cudf/cudf/tests/test_string.py @@ -834,7 +834,7 @@ def test_string_extract(ps_gs, pat, expand, flags, flags_raise): ("FGHI", False), ], ) -@pytest.mark.parametrize("flags,flags_raise", [(0, 0), (1, 1)]) +@pytest.mark.parametrize("flags,flags_raise", [(0, 0), (8, 0)]) @pytest.mark.parametrize("na,na_raise", [(np.nan, 0), (None, 1), ("", 1)]) def test_string_contains(ps_gs, pat, regex, flags, flags_raise, na, na_raise): ps, gs = ps_gs From 96766d17b2d392815fec560b2bb44c38c757b24e Mon Sep 17 00:00:00 2001 From: David Wendt Date: Thu, 21 Oct 2021 11:53:11 -0400 Subject: [PATCH 07/18] update python and pytests for flags checking --- python/cudf/cudf/core/column/string.py | 23 ++++++++++++++++++----- python/cudf/cudf/tests/test_string.py | 18 +++++++++++++----- 2 files changed, 31 insertions(+), 10 deletions(-) diff --git a/python/cudf/cudf/core/column/string.py b/python/cudf/cudf/core/column/string.py index cedef6f6f7b..f272ce63568 100644 --- a/python/cudf/cudf/core/column/string.py +++ b/python/cudf/cudf/core/column/string.py @@ -21,6 +21,7 @@ import numpy as np import pandas as pd import pyarrow as pa +import regex as rex from numba import cuda import cudf @@ -651,9 +652,11 @@ def contains( Notes ----- - The parameters `case`, `flags`, and `na` are not yet supported and - will raise a NotImplementedError if anything other than the default + The parameters `case` and `na` are not yet supported and will + raise a NotImplementedError if anything other than the default value is set. + The `flags` parameter currently only supports re.DOTALL and + re.MULTILINE. Examples -------- @@ -731,8 +734,10 @@ def contains( """ # noqa W605 if case is not True: raise NotImplementedError("`case` parameter is not yet supported") - elif na is not np.nan: + if na is not np.nan: raise NotImplementedError("`na` parameter is not yet supported") + if flags != 0 and (flags & (rex.MULTILINE | rex.DOTALL) == 0): + raise NotImplementedError("invalid `flags` parameter value") if pat is None: result_col = column.column_empty( @@ -3286,7 +3291,8 @@ def count(self, pat: str, flags: int = 0) -> SeriesOrIndex: Notes ----- - - `flags` parameter is currently not supported. + - `flags` parameter currently only supports re.DOTALL + and re.MULTILINE. - Some characters need to be escaped when passing in pat. eg. ``'$'`` has a special meaning in regex and must be escaped when finding this literal character. @@ -3323,6 +3329,8 @@ def count(self, pat: str, flags: int = 0) -> SeriesOrIndex: >>> index.str.count('a') Int64Index([0, 0, 2, 1], dtype='int64') """ # noqa W605 + if flags != 0 and (flags & (rex.MULTILINE | rex.DOTALL) == 0): + raise NotImplementedError("invalid `flags` parameter value") return self._return_or_inplace( libstrings.count_re(self._column, pat, flags) @@ -3852,7 +3860,10 @@ def match( Notes ----- - Parameters currently not supported are: `case`, `flags` and `na`. + Parameters `case` and `na` are currently not supported. + The `flags` parameter currently only supports re.DOTALL and + re.MULTILINE. + Examples -------- @@ -3877,6 +3888,8 @@ def match( """ if case is not True: raise NotImplementedError("`case` parameter is not yet supported") + if flags != 0 and (flags & (rex.MULTILINE | rex.DOTALL) == 0): + raise NotImplementedError("invalid `flags` parameter value") return self._return_or_inplace( libstrings.match_re(self._column, pat, flags) diff --git a/python/cudf/cudf/tests/test_string.py b/python/cudf/cudf/tests/test_string.py index 546707f5667..c8305ce1114 100644 --- a/python/cudf/cudf/tests/test_string.py +++ b/python/cudf/cudf/tests/test_string.py @@ -12,6 +12,7 @@ import pandas as pd import pyarrow as pa import pytest +import regex as rex import cudf from cudf import concat @@ -834,7 +835,7 @@ def test_string_extract(ps_gs, pat, expand, flags, flags_raise): ("FGHI", False), ], ) -@pytest.mark.parametrize("flags,flags_raise", [(0, 0), (8, 0)]) +@pytest.mark.parametrize("flags,flags_raise", [(0, 0), (24, 0), (1, 1)]) @pytest.mark.parametrize("na,na_raise", [(np.nan, 0), (None, 1), ("", 1)]) def test_string_contains(ps_gs, pat, regex, flags, flags_raise, na, na_raise): ps, gs = ps_gs @@ -1728,15 +1729,22 @@ def test_string_wrap(data, width): ["23", "³", "⅕", ""], [" ", "\t\r\n ", ""], ["$", "B", "Aab$", "$$ca", "C$B$", "cat"], - ["line to be wrapped", "another line to be wrapped"], + ["line\nto be wrapped", "another\nline\nto be wrapped"], ], ) -@pytest.mark.parametrize("pat", ["a", " ", "\t", "another", "0", r"\$"]) -def test_string_count(data, pat): +@pytest.mark.parametrize( + "pat", ["a", " ", "\t", "another", "0", r"\$", "^line$", "line.*be"] +) +@pytest.mark.parametrize("flags", [0, rex.MULTILINE, rex.DOTALL]) +def test_string_count(data, pat, flags): gs = cudf.Series(data) ps = pd.Series(data) - assert_eq(gs.str.count(pat=pat), ps.str.count(pat=pat), check_dtype=False) + assert_eq( + gs.str.count(pat=pat, flags=flags), + ps.str.count(pat=pat, flags=flags), + check_dtype=False, + ) assert_eq(as_index(gs).str.count(pat=pat), pd.Index(ps).str.count(pat=pat)) From 5a8a2f65dd027946b4c10427e2cfb152ea17929c Mon Sep 17 00:00:00 2001 From: David Wendt Date: Thu, 21 Oct 2021 14:27:16 -0400 Subject: [PATCH 08/18] add flags doc to string.py --- python/cudf/cudf/core/column/string.py | 6 ++++++ 1 file changed, 6 insertions(+) diff --git a/python/cudf/cudf/core/column/string.py b/python/cudf/cudf/core/column/string.py index f272ce63568..4c9fb2a2515 100644 --- a/python/cudf/cudf/core/column/string.py +++ b/python/cudf/cudf/core/column/string.py @@ -639,6 +639,8 @@ def contains( Character sequence or regular expression. If ``pat`` is list-like then regular expressions are not accepted. + flags : int, default 0 (no flags) + Flags to pass through to the regex engine (e.g. re.MULTILINE) regex : bool, default True If True, assumes the pattern is a regular expression. If False, treats the pattern as a literal string. @@ -3284,6 +3286,8 @@ def count(self, pat: str, flags: int = 0) -> SeriesOrIndex: ---------- pat : str Valid regular expression. + flags : int, default 0 (no flags) + Flags to pass through to the regex engine (e.g. re.MULTILINE) Returns ------- @@ -3853,6 +3857,8 @@ def match( ---------- pat : str Character sequence or regular expression. + flags : int, default 0 (no flags) + Flags to pass through to the regex engine (e.g. re.MULTILINE) Returns ------- From 452c1ae157d5a5c7c6d7b8a652570bc31a74c0b9 Mon Sep 17 00:00:00 2001 From: David Wendt Date: Thu, 21 Oct 2021 15:51:22 -0400 Subject: [PATCH 09/18] use rex values instead of hardcoded int --- python/cudf/cudf/tests/test_string.py | 4 +++- 1 file changed, 3 insertions(+), 1 deletion(-) diff --git a/python/cudf/cudf/tests/test_string.py b/python/cudf/cudf/tests/test_string.py index c8305ce1114..3148ec6bac9 100644 --- a/python/cudf/cudf/tests/test_string.py +++ b/python/cudf/cudf/tests/test_string.py @@ -835,7 +835,9 @@ def test_string_extract(ps_gs, pat, expand, flags, flags_raise): ("FGHI", False), ], ) -@pytest.mark.parametrize("flags,flags_raise", [(0, 0), (24, 0), (1, 1)]) +@pytest.mark.parametrize( + "flags,flags_raise", [(0, 0), (rex.MULTILINE | rex.DOTALL, 0), (1, 1)] +) @pytest.mark.parametrize("na,na_raise", [(np.nan, 0), (None, 1), ("", 1)]) def test_string_contains(ps_gs, pat, regex, flags, flags_raise, na, na_raise): ps, gs = ps_gs From 96d6e01c258111b5a25995875c7a49a1959842a0 Mon Sep 17 00:00:00 2001 From: David Wendt Date: Fri, 22 Oct 2021 07:11:06 -0400 Subject: [PATCH 10/18] import re (builtin) instead of regex --- python/cudf/cudf/core/column/string.py | 2 +- python/cudf/cudf/tests/test_string.py | 2 +- 2 files changed, 2 insertions(+), 2 deletions(-) diff --git a/python/cudf/cudf/core/column/string.py b/python/cudf/cudf/core/column/string.py index f212687317e..2dd916d6f3c 100644 --- a/python/cudf/cudf/core/column/string.py +++ b/python/cudf/cudf/core/column/string.py @@ -4,6 +4,7 @@ import builtins import pickle +import re as rex import warnings from typing import ( TYPE_CHECKING, @@ -21,7 +22,6 @@ import numpy as np import pandas as pd import pyarrow as pa -import regex as rex from numba import cuda import cudf diff --git a/python/cudf/cudf/tests/test_string.py b/python/cudf/cudf/tests/test_string.py index 3148ec6bac9..2f8ea14327f 100644 --- a/python/cudf/cudf/tests/test_string.py +++ b/python/cudf/cudf/tests/test_string.py @@ -2,6 +2,7 @@ import json import re +import re as rex import urllib.parse from contextlib import ExitStack as does_not_raise from decimal import Decimal @@ -12,7 +13,6 @@ import pandas as pd import pyarrow as pa import pytest -import regex as rex import cudf from cudf import concat From d3191f439507e1c1c07b094c3ae7cba20b4987ac Mon Sep 17 00:00:00 2001 From: David Wendt Date: Fri, 22 Oct 2021 07:48:02 -0400 Subject: [PATCH 11/18] remove duplicate import re --- python/cudf/cudf/core/column/string.py | 8 ++++---- python/cudf/cudf/tests/test_string.py | 5 ++--- 2 files changed, 6 insertions(+), 7 deletions(-) diff --git a/python/cudf/cudf/core/column/string.py b/python/cudf/cudf/core/column/string.py index 2dd916d6f3c..5f3e828a603 100644 --- a/python/cudf/cudf/core/column/string.py +++ b/python/cudf/cudf/core/column/string.py @@ -4,7 +4,7 @@ import builtins import pickle -import re as rex +import re import warnings from typing import ( TYPE_CHECKING, @@ -737,7 +737,7 @@ def contains( raise NotImplementedError("`case` parameter is not yet supported") if na is not np.nan: raise NotImplementedError("`na` parameter is not yet supported") - if flags != 0 and (flags & (rex.MULTILINE | rex.DOTALL) == 0): + if flags != 0 and (flags & (re.MULTILINE | re.DOTALL) == 0): raise NotImplementedError("invalid `flags` parameter value") if pat is None: @@ -3332,7 +3332,7 @@ def count(self, pat: str, flags: int = 0) -> SeriesOrIndex: >>> index.str.count('a') Int64Index([0, 0, 2, 1], dtype='int64') """ # noqa W605 - if flags != 0 and (flags & (rex.MULTILINE | rex.DOTALL) == 0): + if flags != 0 and (flags & (re.MULTILINE | re.DOTALL) == 0): raise NotImplementedError("invalid `flags` parameter value") return self._return_or_inplace( @@ -3893,7 +3893,7 @@ def match( """ if case is not True: raise NotImplementedError("`case` parameter is not yet supported") - if flags != 0 and (flags & (rex.MULTILINE | rex.DOTALL) == 0): + if flags != 0 and (flags & (re.MULTILINE | re.DOTALL) == 0): raise NotImplementedError("invalid `flags` parameter value") return self._return_or_inplace( diff --git a/python/cudf/cudf/tests/test_string.py b/python/cudf/cudf/tests/test_string.py index 2f8ea14327f..4c340b85d13 100644 --- a/python/cudf/cudf/tests/test_string.py +++ b/python/cudf/cudf/tests/test_string.py @@ -2,7 +2,6 @@ import json import re -import re as rex import urllib.parse from contextlib import ExitStack as does_not_raise from decimal import Decimal @@ -836,7 +835,7 @@ def test_string_extract(ps_gs, pat, expand, flags, flags_raise): ], ) @pytest.mark.parametrize( - "flags,flags_raise", [(0, 0), (rex.MULTILINE | rex.DOTALL, 0), (1, 1)] + "flags,flags_raise", [(0, 0), (re.MULTILINE | re.DOTALL, 0), (1, 1)] ) @pytest.mark.parametrize("na,na_raise", [(np.nan, 0), (None, 1), ("", 1)]) def test_string_contains(ps_gs, pat, regex, flags, flags_raise, na, na_raise): @@ -1737,7 +1736,7 @@ def test_string_wrap(data, width): @pytest.mark.parametrize( "pat", ["a", " ", "\t", "another", "0", r"\$", "^line$", "line.*be"] ) -@pytest.mark.parametrize("flags", [0, rex.MULTILINE, rex.DOTALL]) +@pytest.mark.parametrize("flags", [0, re.MULTILINE, re.DOTALL]) def test_string_count(data, pat, flags): gs = cudf.Series(data) ps = pd.Series(data) From a2634f3e261b9a30c2e7588b2cc9bd2a3ca3140c Mon Sep 17 00:00:00 2001 From: David Wendt Date: Wed, 27 Oct 2021 11:45:44 -0400 Subject: [PATCH 12/18] add missing doxygen param --- cpp/include/cudf/strings/contains.hpp | 6 +++--- cpp/src/strings/regex/regex.cuh | 10 +++------- 2 files changed, 6 insertions(+), 10 deletions(-) diff --git a/cpp/include/cudf/strings/contains.hpp b/cpp/include/cudf/strings/contains.hpp index a9de166a240..2d8c4419986 100644 --- a/cpp/include/cudf/strings/contains.hpp +++ b/cpp/include/cudf/strings/contains.hpp @@ -45,7 +45,7 @@ namespace strings { * * @param strings Strings instance for this operation. * @param pattern Regex pattern to match to each string. - * @param flags Regex flags for processing the pattern + * @param flags Regex flags for interpretting special characters in the pattern. * @param mr Device memory resource used to allocate the returned column's device memory. * @return New column of boolean results for each string. */ @@ -72,7 +72,7 @@ std::unique_ptr contains_re( * * @param strings Strings instance for this operation. * @param pattern Regex pattern to match to each string. - * @param flags Regex flags for processing the pattern + * @param flags Regex flags for interpretting special characters in the pattern. * @param mr Device memory resource used to allocate the returned column's device memory. * @return New column of boolean results for each string. */ @@ -99,7 +99,7 @@ std::unique_ptr matches_re( * * @param strings Strings instance for this operation. * @param pattern Regex pattern to match within each string. - * @param flags Regex flags for processing the pattern + * @param flags Regex flags for interpretting special characters in the pattern. * @param mr Device memory resource used to allocate the returned column's device memory. * @return New INT32 column with counts for each string. */ diff --git a/cpp/src/strings/regex/regex.cuh b/cpp/src/strings/regex/regex.cuh index 20be1829c2e..1c42deb296d 100644 --- a/cpp/src/strings/regex/regex.cuh +++ b/cpp/src/strings/regex/regex.cuh @@ -110,17 +110,13 @@ class reprog_device { rmm::cuda_stream_view stream); /** - * @brief Create device program instance from a regex pattern. - * - * The number of strings is needed to compute the state data size required when evaluating the - * regex. + * @brief Create the device program instance from a regex pattern. * * @param pattern The regex pattern to compile. + * @param re_flags Regex flags for interpretting special characters in the pattern. * @param cp_flags The code-point lookup table for character types. * @param strings_count Number of strings that will be evaluated. - * @param stream CUDA stream for asynchronous memory allocations. To ensure correct - * synchronization on destruction, the same stream should be used for all operations with the - * created objects. + * @param stream CUDA stream used for device memory operations and kernel launches * @return The program device object. */ static std::unique_ptr> create( From 8106b41413b492b4f4f7f22ae4c5979ef11dc03b Mon Sep 17 00:00:00 2001 From: David Wendt Date: Wed, 27 Oct 2021 13:15:15 -0400 Subject: [PATCH 13/18] change define to constexpr --- cpp/include/cudf/strings/regex/flags.hpp | 23 +++++++++++++++++++++-- cpp/src/strings/regex/regcomp.cpp | 4 ++-- 2 files changed, 23 insertions(+), 4 deletions(-) diff --git a/cpp/include/cudf/strings/regex/flags.hpp b/cpp/include/cudf/strings/regex/flags.hpp index f459c5bdad0..f6aee6d22cc 100644 --- a/cpp/include/cudf/strings/regex/flags.hpp +++ b/cpp/include/cudf/strings/regex/flags.hpp @@ -38,8 +38,27 @@ enum regex_flags : uint32_t { DOTALL = 16 /// the '.' matching includes new-line characters }; -#define IS_MULTILINE(f) ((f & regex_flags::MULTILINE) == regex_flags::MULTILINE) -#define IS_DOTALL(f) ((f & regex_flags::DOTALL) == regex_flags::DOTALL) +/** + * @brief Returns true if the given flags contain MULTILINE. + * + * @param f Regex flags to check + * @return true if `f` includes MULTILINE + */ +constexpr bool is_multiline(regex_flags const f) +{ + return (f & regex_flags::MULTILINE) == regex_flags::MULTILINE; +} + +/** + * @brief Returns true if the given flags contain DOTALL. + * + * @param f Regex flags to check + * @return true if `f` includes DOTALL + */ +constexpr bool is_dotall(regex_flags const f) +{ + return (f & regex_flags::DOTALL) == regex_flags::DOTALL; +} /** @} */ // end of doxygen group } // namespace strings diff --git a/cpp/src/strings/regex/regcomp.cpp b/cpp/src/strings/regex/regcomp.cpp index 5d8f3641633..46646949d70 100644 --- a/cpp/src/strings/regex/regcomp.cpp +++ b/cpp/src/strings/regex/regcomp.cpp @@ -672,7 +672,7 @@ class regex_compiler { } else if (t == CHAR) { m_prog.inst_at(inst_id).u1.c = yy; } else if (t == BOL || t == EOL) { - m_prog.inst_at(inst_id).u1.c = IS_MULTILINE(flags) ? yy : '\n'; + m_prog.inst_at(inst_id).u1.c = is_multiline(flags) ? yy : '\n'; } pushand(inst_id, inst_id); lastwasand = true; @@ -785,7 +785,7 @@ class regex_compiler { // Parse std::vector items; { - regex_parser parser(pattern, IS_DOTALL(flags) ? ANYNL : ANY, m_prog); + regex_parser parser(pattern, is_dotall(flags) ? ANYNL : ANY, m_prog); // Expand counted repetitions if (parser.m_has_counted) From 88cca5f784a4f8da4669e84a646eb7380025e979 Mon Sep 17 00:00:00 2001 From: David Wendt Date: Thu, 28 Oct 2021 16:46:56 -0400 Subject: [PATCH 14/18] check for unsupported flags --- cpp/src/strings/regex/regcomp.cpp | 10 ++++++---- python/cudf/cudf/core/column/string.py | 18 +++++++++++++++--- python/cudf/cudf/tests/test_string.py | 3 ++- 3 files changed, 23 insertions(+), 8 deletions(-) diff --git a/cpp/src/strings/regex/regcomp.cpp b/cpp/src/strings/regex/regcomp.cpp index 46646949d70..0dd7a0620bd 100644 --- a/cpp/src/strings/regex/regcomp.cpp +++ b/cpp/src/strings/regex/regcomp.cpp @@ -959,19 +959,21 @@ void reprog::print(regex_flags const flags) case NOP: printf("NOP, nextid= %d", inst.u2.next_id); break; case BOL: { printf("BOL, c = "); - if (inst.u1.c == '\n') + if (inst.u1.c == '\n') { printf("'\\n'"); - else + } else { printf("'%c'", inst.u1.c); + } printf(", nextid= %d", inst.u2.next_id); break; } case EOL: { printf("EOL, c = "); - if (inst.u1.c == '\n') + if (inst.u1.c == '\n') { printf("'\\n'"); - else + } else { printf("'%c'", inst.u1.c); + } printf(", nextid= %d", inst.u2.next_id); break; } diff --git a/python/cudf/cudf/core/column/string.py b/python/cudf/cudf/core/column/string.py index 90fc84fa2bb..37c278307f4 100644 --- a/python/cudf/cudf/core/column/string.py +++ b/python/cudf/cudf/core/column/string.py @@ -737,7 +737,11 @@ def contains( raise NotImplementedError("`case` parameter is not yet supported") if na is not np.nan: raise NotImplementedError("`na` parameter is not yet supported") - if flags != 0 and (flags & (re.MULTILINE | re.DOTALL) == 0): + if ( + flags != 0 + and (flags & (re.MULTILINE | re.DOTALL) == 0) + or (flags & ~(re.MULTILINE | re.DOTALL) != 0) + ): raise NotImplementedError("invalid `flags` parameter value") if pat is None: @@ -3332,7 +3336,11 @@ def count(self, pat: str, flags: int = 0) -> SeriesOrIndex: >>> index.str.count('a') Int64Index([0, 0, 2, 1], dtype='int64') """ # noqa W605 - if flags != 0 and (flags & (re.MULTILINE | re.DOTALL) == 0): + if ( + flags != 0 + and (flags & (re.MULTILINE | re.DOTALL) == 0) + or (flags & ~(re.MULTILINE | re.DOTALL) != 0) + ): raise NotImplementedError("invalid `flags` parameter value") return self._return_or_inplace( @@ -3893,7 +3901,11 @@ def match( """ if case is not True: raise NotImplementedError("`case` parameter is not yet supported") - if flags != 0 and (flags & (re.MULTILINE | re.DOTALL) == 0): + if ( + flags != 0 + and (flags & (re.MULTILINE | re.DOTALL) == 0) + or (flags & ~(re.MULTILINE | re.DOTALL) != 0) + ): raise NotImplementedError("invalid `flags` parameter value") return self._return_or_inplace( diff --git a/python/cudf/cudf/tests/test_string.py b/python/cudf/cudf/tests/test_string.py index 4c340b85d13..f96a1b12d2d 100644 --- a/python/cudf/cudf/tests/test_string.py +++ b/python/cudf/cudf/tests/test_string.py @@ -835,7 +835,8 @@ def test_string_extract(ps_gs, pat, expand, flags, flags_raise): ], ) @pytest.mark.parametrize( - "flags,flags_raise", [(0, 0), (re.MULTILINE | re.DOTALL, 0), (1, 1)] + "flags,flags_raise", + [(0, 0), (re.MULTILINE | re.DOTALL, 0), (re.I, 1), (re.I | re.DOTALL, 1)], ) @pytest.mark.parametrize("na,na_raise", [(np.nan, 0), (None, 1), ("", 1)]) def test_string_contains(ps_gs, pat, regex, flags, flags_raise, na, na_raise): From dbf6f156c786ee11ae3456ea95ab46ac0d092cc8 Mon Sep 17 00:00:00 2001 From: David Wendt Date: Fri, 29 Oct 2021 11:42:22 -0400 Subject: [PATCH 15/18] fix misspelling of interpreting --- cpp/include/cudf/strings/contains.hpp | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/cpp/include/cudf/strings/contains.hpp b/cpp/include/cudf/strings/contains.hpp index 2d8c4419986..9f408a40314 100644 --- a/cpp/include/cudf/strings/contains.hpp +++ b/cpp/include/cudf/strings/contains.hpp @@ -45,7 +45,7 @@ namespace strings { * * @param strings Strings instance for this operation. * @param pattern Regex pattern to match to each string. - * @param flags Regex flags for interpretting special characters in the pattern. + * @param flags Regex flags for interpreting special characters in the pattern. * @param mr Device memory resource used to allocate the returned column's device memory. * @return New column of boolean results for each string. */ @@ -72,7 +72,7 @@ std::unique_ptr contains_re( * * @param strings Strings instance for this operation. * @param pattern Regex pattern to match to each string. - * @param flags Regex flags for interpretting special characters in the pattern. + * @param flags Regex flags for interpreting special characters in the pattern. * @param mr Device memory resource used to allocate the returned column's device memory. * @return New column of boolean results for each string. */ @@ -99,7 +99,7 @@ std::unique_ptr matches_re( * * @param strings Strings instance for this operation. * @param pattern Regex pattern to match within each string. - * @param flags Regex flags for interpretting special characters in the pattern. + * @param flags Regex flags for interpreting special characters in the pattern. * @param mr Device memory resource used to allocate the returned column's device memory. * @return New INT32 column with counts for each string. */ From 8ba74e8a018317dbd234b2e4c259ee0e72558d5c Mon Sep 17 00:00:00 2001 From: David Wendt Date: Fri, 29 Oct 2021 11:43:16 -0400 Subject: [PATCH 16/18] change int32 to size-type --- cpp/src/strings/regex/regcomp.cpp | 5 +++-- cpp/src/strings/regex/regex.cuh | 18 ++++++++---------- cpp/src/strings/regex/regexec.cu | 4 ++-- 3 files changed, 13 insertions(+), 14 deletions(-) diff --git a/cpp/src/strings/regex/regcomp.cpp b/cpp/src/strings/regex/regcomp.cpp index 0dd7a0620bd..0e3dcb93826 100644 --- a/cpp/src/strings/regex/regcomp.cpp +++ b/cpp/src/strings/regex/regcomp.cpp @@ -841,7 +841,6 @@ reprog reprog::create_from(const char32_t* pattern, regex_flags const flags) regex_compiler compiler(pattern, flags, rtn); // for debugging, it can be helpful to call rtn.print(flags) here to dump // out the instructions that have been created from the given pattern - if (std::getenv("LIBCUDF_REGEX_DEBUG")) rtn.print(flags); return rtn; } @@ -927,9 +926,10 @@ void reprog::optimize2() _startinst_ids.push_back(-1); // terminator mark } +#ifndef NDBUG void reprog::print(regex_flags const flags) { - printf("Flags = 0x%02x\n", static_cast(flags)); + printf("Flags = 0x%08x\n", static_cast(flags)); printf("Instructions:\n"); for (std::size_t i = 0; i < _insts.size(); i++) { const reinst& inst = _insts[i]; @@ -1026,6 +1026,7 @@ void reprog::print(regex_flags const flags) } if (_num_capturing_groups) printf("Number of capturing groups: %d\n", _num_capturing_groups); } +#endif } // namespace detail } // namespace strings diff --git a/cpp/src/strings/regex/regex.cuh b/cpp/src/strings/regex/regex.cuh index 1c42deb296d..27556d90b1b 100644 --- a/cpp/src/strings/regex/regex.cuh +++ b/cpp/src/strings/regex/regex.cuh @@ -96,25 +96,23 @@ class reprog_device { * regex. * * @param pattern The regex pattern to compile. - * @param cp_flags The code-point lookup table for character types. + * @param codepoint_flags The code point lookup table for character types. * @param strings_count Number of strings that will be evaluated. - * @param stream CUDA stream for asynchronous memory allocations. To ensure correct - * synchronization on destruction, the same stream should be used for all operations with the - * created objects. + * @param stream CUDA stream used for device memory operations and kernel launches. * @return The program device object. */ static std::unique_ptr> create( std::string const& pattern, - const uint8_t* cp_flags, - int32_t strings_count, + uint8_t const* codepoint_flags, + size_type strings_count, rmm::cuda_stream_view stream); /** * @brief Create the device program instance from a regex pattern. * * @param pattern The regex pattern to compile. - * @param re_flags Regex flags for interpretting special characters in the pattern. - * @param cp_flags The code-point lookup table for character types. + * @param re_flags Regex flags for interpreting special characters in the pattern. + * @param codepoint_flags The code point lookup table for character types. * @param strings_count Number of strings that will be evaluated. * @param stream CUDA stream used for device memory operations and kernel launches * @return The program device object. @@ -122,8 +120,8 @@ class reprog_device { static std::unique_ptr> create( std::string const& pattern, regex_flags const re_flags, - const uint8_t* cp_flags, - int32_t strings_count, + uint8_t const* codepoint_flags, + size_type strings_count, rmm::cuda_stream_view stream); /** diff --git a/cpp/src/strings/regex/regexec.cu b/cpp/src/strings/regex/regexec.cu index 604cc5b5264..4f93bbd6e7b 100644 --- a/cpp/src/strings/regex/regexec.cu +++ b/cpp/src/strings/regex/regexec.cu @@ -75,7 +75,7 @@ reprog_device::reprog_device(reprog& prog) std::unique_ptr> reprog_device::create( std::string const& pattern, uint8_t const* codepoint_flags, - int32_t strings_count, + size_type strings_count, rmm::cuda_stream_view stream) { return reprog_device::create( @@ -87,7 +87,7 @@ std::unique_ptr> reprog_devic std::string const& pattern, regex_flags const flags, uint8_t const* codepoint_flags, - int32_t strings_count, + size_type strings_count, rmm::cuda_stream_view stream) { std::vector pattern32 = string_to_char32_vector(pattern); From 571afa884e9c85511fe0b415b55c9f5d3b8cae5f Mon Sep 17 00:00:00 2001 From: David Wendt Date: Fri, 29 Oct 2021 11:43:41 -0400 Subject: [PATCH 17/18] add flags parameter name --- python/cudf/cudf/_lib/cpp/strings/contains.pxd | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/python/cudf/cudf/_lib/cpp/strings/contains.pxd b/python/cudf/cudf/_lib/cpp/strings/contains.pxd index b48d2f58334..8014a60617d 100644 --- a/python/cudf/cudf/_lib/cpp/strings/contains.pxd +++ b/python/cudf/cudf/_lib/cpp/strings/contains.pxd @@ -20,14 +20,14 @@ cdef extern from "cudf/strings/contains.hpp" namespace "cudf::strings" nogil: cdef unique_ptr[column] contains_re( column_view source_strings, string pattern, - regex_flags) except + + regex_flags flags) except + cdef unique_ptr[column] count_re( column_view source_strings, string pattern, - regex_flags) except + + regex_flags flags) except + cdef unique_ptr[column] matches_re( column_view source_strings, string pattern, - regex_flags) except + + regex_flags flags) except + From 3e9c3900b6ea24e323ea089da3438a6e8d940a0f Mon Sep 17 00:00:00 2001 From: David Wendt Date: Fri, 29 Oct 2021 11:44:13 -0400 Subject: [PATCH 18/18] add _is_supported_regex_flags --- python/cudf/cudf/core/column/string.py | 35 +++++++++++--------------- python/cudf/cudf/tests/test_string.py | 6 ++++- 2 files changed, 20 insertions(+), 21 deletions(-) diff --git a/python/cudf/cudf/core/column/string.py b/python/cudf/cudf/core/column/string.py index 37c278307f4..558e70c68c2 100644 --- a/python/cudf/cudf/core/column/string.py +++ b/python/cudf/cudf/core/column/string.py @@ -98,6 +98,13 @@ def str_to_boolean(column: StringColumn): } +def _is_supported_regex_flags(flags): + return flags == 0 or ( + (flags & (re.MULTILINE | re.DOTALL) != 0) + and (flags & ~(re.MULTILINE | re.DOTALL) == 0) + ) + + class StringMethods(ColumnMethods): """ Vectorized string functions for Series and Index. @@ -737,19 +744,15 @@ def contains( raise NotImplementedError("`case` parameter is not yet supported") if na is not np.nan: raise NotImplementedError("`na` parameter is not yet supported") - if ( - flags != 0 - and (flags & (re.MULTILINE | re.DOTALL) == 0) - or (flags & ~(re.MULTILINE | re.DOTALL) != 0) - ): - raise NotImplementedError("invalid `flags` parameter value") + if not _is_supported_regex_flags(flags): + raise ValueError("invalid `flags` parameter value") if pat is None: result_col = column.column_empty( len(self._column), dtype="bool", masked=True ) elif is_scalar(pat): - if regex is True: + if regex: result_col = libstrings.contains_re(self._column, pat, flags) else: result_col = libstrings.contains( @@ -3301,7 +3304,7 @@ def count(self, pat: str, flags: int = 0) -> SeriesOrIndex: - `flags` parameter currently only supports re.DOTALL and re.MULTILINE. - Some characters need to be escaped when passing - in pat. eg. ``'$'`` has a special meaning in regex + in pat. e.g. ``'$'`` has a special meaning in regex and must be escaped when finding this literal character. Examples @@ -3336,12 +3339,8 @@ def count(self, pat: str, flags: int = 0) -> SeriesOrIndex: >>> index.str.count('a') Int64Index([0, 0, 2, 1], dtype='int64') """ # noqa W605 - if ( - flags != 0 - and (flags & (re.MULTILINE | re.DOTALL) == 0) - or (flags & ~(re.MULTILINE | re.DOTALL) != 0) - ): - raise NotImplementedError("invalid `flags` parameter value") + if not _is_supported_regex_flags(flags): + raise ValueError("invalid `flags` parameter value") return self._return_or_inplace( libstrings.count_re(self._column, pat, flags) @@ -3901,12 +3900,8 @@ def match( """ if case is not True: raise NotImplementedError("`case` parameter is not yet supported") - if ( - flags != 0 - and (flags & (re.MULTILINE | re.DOTALL) == 0) - or (flags & ~(re.MULTILINE | re.DOTALL) != 0) - ): - raise NotImplementedError("invalid `flags` parameter value") + if not _is_supported_regex_flags(flags): + raise ValueError("invalid `flags` parameter value") return self._return_or_inplace( libstrings.match_re(self._column, pat, flags) diff --git a/python/cudf/cudf/tests/test_string.py b/python/cudf/cudf/tests/test_string.py index f96a1b12d2d..848d59158ea 100644 --- a/python/cudf/cudf/tests/test_string.py +++ b/python/cudf/cudf/tests/test_string.py @@ -842,7 +842,11 @@ def test_string_extract(ps_gs, pat, expand, flags, flags_raise): def test_string_contains(ps_gs, pat, regex, flags, flags_raise, na, na_raise): ps, gs = ps_gs - expectation = raise_builder([flags_raise, na_raise], NotImplementedError) + expectation = does_not_raise() + if flags_raise: + expectation = pytest.raises(ValueError) + if na_raise: + expectation = pytest.raises(NotImplementedError) with expectation: expect = ps.str.contains(pat, flags=flags, na=na, regex=regex)