Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Vectorize basic_string::find (the string needle overload) #5048

Merged
11 commits merged on Oct 30, 2024
19 changes: 19 additions & 0 deletions benchmarks/src/search.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -111,6 +111,22 @@ void search_default_searcher(benchmark::State& state) {
}
}

// Benchmark for the member `find` of a string-like type T.
// The haystack and needle are taken from the `patterns` entry selected by
// `state.range()` and converted to T once, outside the timed loop, so that
// only the `find` call itself is measured.
template <class T>
void member_find(benchmark::State& state) {
    const auto pattern_index = static_cast<size_t>(state.range());
    const auto& entry        = patterns[pattern_index];

    const T haystack(entry.data.begin(), entry.data.end());
    const T needle(entry.pattern.begin(), entry.pattern.end());

    for (auto _ : state) {
        // Keep the inputs opaque to the optimizer so the search is not hoisted or folded away.
        benchmark::DoNotOptimize(haystack);
        benchmark::DoNotOptimize(needle);
        auto res = haystack.find(needle);
        // Keep the result alive so the call is not eliminated as dead code.
        benchmark::DoNotOptimize(res);
    }
}

template <class T>
void classic_find_end(benchmark::State& state) {
const auto& src_haystack = patterns[static_cast<size_t>(state.range())].data;
Expand Down Expand Up @@ -158,6 +174,9 @@ BENCHMARK(ranges_search<std::uint16_t>)->Apply(common_args);
// std::search with default_searcher, for 8-bit and 16-bit element types.
BENCHMARK(search_default_searcher<std::uint8_t>)->Apply(common_args);
BENCHMARK(search_default_searcher<std::uint16_t>)->Apply(common_args);

// Member find (string needle overload), for std::string and std::wstring.
BENCHMARK(member_find<std::string>)->Apply(common_args);
BENCHMARK(member_find<std::wstring>)->Apply(common_args);

// classic_find_end, for 8-bit and 16-bit element types.
BENCHMARK(classic_find_end<std::uint8_t>)->Apply(common_args);
BENCHMARK(classic_find_end<std::uint16_t>)->Apply(common_args);

Expand Down
15 changes: 15 additions & 0 deletions stl/inc/__msvc_string_view.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -628,6 +628,21 @@ constexpr size_t _Traits_find(_In_reads_(_Hay_size) const _Traits_ptr_t<_Traits>
return _Start_at;
}

#if _USE_STD_VECTOR_ALGORITHMS
// Fast path: only for implementation-handled char_traits whose character type is at most 2 bytes wide.
if constexpr (_Is_implementation_handled_char_traits<_Traits> && sizeof(typename _Traits::char_type) <= 2) {
// The vectorized routine is only called at run time; constant evaluation falls through to the
// generic loop below.
if (!_STD _Is_constant_evaluated()) {
const auto _End = _Haystack + _Hay_size;
// Search [_Haystack + _Start_at, _End) for the _Needle_size characters starting at _Needle.
const auto _Ptr = _STD _Search_vectorized(_Haystack + _Start_at, _End, _Needle, _Needle_size);

// NOTE(review): presumably _Search_vectorized returns _End when there is no match - confirm.
if (_Ptr != _End) {
// Convert the match pointer into an offset from the start of the haystack.
return static_cast<size_t>(_Ptr - _Haystack);
} else {
return static_cast<size_t>(-1);
}
}
}
#endif // _USE_STD_VECTOR_ALGORITHMS

const auto _Possible_matches_end = _Haystack + (_Hay_size - _Needle_size) + 1;
for (auto _Match_try = _Haystack + _Start_at;; ++_Match_try) {
_Match_try = _Traits::find(_Match_try, static_cast<size_t>(_Possible_matches_end - _Match_try), *_Needle);
Expand Down
35 changes: 35 additions & 0 deletions tests/std/tests/VSO_0000000_vector_algorithms/test.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -1247,22 +1247,57 @@ void test_case_string_find_last_of(const basic_string<T>& input_haystack, const
assert(expected == actual);
}

// Verifies basic_string::find (string needle overload) for one haystack/needle pair.
//
// The expected offset is computed independently with last_known_good_search and
// compared against the offset returned by input_haystack.find(input_needle).
// A missing match is represented as -1 on both sides: the size_t result of find
// is cast to ptrdiff_t, so the "not found" value compares equal to -1.
template <class T>
void test_case_string_find_str(const basic_string<T>& input_haystack, const basic_string<T>& input_needle) {
    ptrdiff_t expected;
    if (input_needle.empty()) {
        // An empty needle is found at offset 0 of any haystack, including an empty one.
        expected = 0;
    } else {
        const auto expected_iter = last_known_good_search(
            input_haystack.begin(), input_haystack.end(), input_needle.begin(), input_needle.end());

        if (expected_iter != input_haystack.end()) {
            expected = expected_iter - input_haystack.begin();
        } else {
            expected = -1;
        }
    }
    const auto actual = static_cast<ptrdiff_t>(input_haystack.find(input_needle));
    assert(expected == actual);
}

template <class T, class D>
void test_basic_string_dis(mt19937_64& gen, D& dis) {
basic_string<T> input_haystack;
basic_string<T> input_needle;
basic_string<T> temp;
input_haystack.reserve(haystackDataCount);
input_needle.reserve(needleDataCount);
temp.reserve(needleDataCount);

for (;;) {
input_needle.clear();

test_case_string_find_first_of(input_haystack, input_needle);
test_case_string_find_last_of(input_haystack, input_needle);
test_case_string_find_str(input_haystack, input_needle);

for (size_t attempts = 0; attempts < needleDataCount; ++attempts) {
input_needle.push_back(static_cast<T>(dis(gen)));
test_case_string_find_first_of(input_haystack, input_needle);
test_case_string_find_last_of(input_haystack, input_needle);
test_case_string_find_str(input_haystack, input_needle);

// For large needles the chance of a match is low, so test a guaranteed match
if (input_haystack.size() > input_needle.size() * 2) {
uniform_int_distribution<size_t> pos_dis(0, input_haystack.size() - input_needle.size());
const size_t pos = pos_dis(gen);
const auto overwritten_first = input_haystack.begin() + static_cast<ptrdiff_t>(pos);
temp.assign(overwritten_first, overwritten_first + static_cast<ptrdiff_t>(input_needle.size()));
copy(input_needle.begin(), input_needle.end(), overwritten_first);
test_case_string_find_str(input_haystack, input_needle);
copy(temp.begin(), temp.end(), overwritten_first);
}
}

if (input_haystack.size() == haystackDataCount) {
Expand Down