Skip to content

Commit

Permalink
feat(C++): String detection is performed using SIMD techniques (#1720)
Browse files Browse the repository at this point in the history
## What does this PR do?
ref: https://arxiv.org/pdf/1902.08318.pdf
ref: https://github.com/simdutf/simdutf
I studied the relevant SIMD techniques, along with the paper and the
project linked above.
This PR uses SIMD to speed up string detection.
First, I implemented the basic logic for Latin character detection as a
baseline:
``` c++
// Baseline implementation: scalar scan, one byte at a time.
// Returns true only when no byte of `str` has a value of 128 or above.
bool isLatin_Baseline(const std::string& str) {
    const std::string::size_type total = str.size();
    std::string::size_type pos = 0;
    while (pos < total) {
        const unsigned char byte = static_cast<unsigned char>(str[pos]);
        if (byte >= 128) {
            return false;
        }
        ++pos;
    }
    return true;
}
```
<img width="393" alt="image"
src="https://raw.githubusercontent.com/pandalee99/image_store/master/hexo/simd_base_line_test1.png">
Then I tried SSE2 to speed it up. It is noticeably faster: the idea is
to load multiple characters at once and test them with bitwise
operations.
That gave a clear speed-up, but I didn't think it was enough, so I tried
again with AVX2:
<img width="493" alt="image"
src="https://raw.githubusercontent.com/pandalee99/image_store/master/hexo/simd_test_all_1.png">
In terms of efficiency, this is already much faster than before.
But how do we show that it is also logically correct?
I added test cases to verify it:

``` C++
TEST(StringUtilTest, TestIsLatinLogic)
```

Finally, I ran the test
<img width="493" alt="image"
src="https://raw.githubusercontent.com/pandalee99/image_store/master/hexo/simd_ubantu_test_1.png">
done.


<!-- Describe the purpose of this PR. -->


## Related issues
Closes #313 

<!--
Is there any related issue? Please attach here.

- #xxxx0
- #xxxx1
- #xxxx2
-->


## Does this PR introduce any user-facing change?

<!--
If any user-facing interface changes, please [open an
issue](https://github.com/apache/fury/issues/new/choose) describing the
need to do so and update the document if necessary.
-->

- [x] Does this PR introduce any public API change?
- [ ] Does this PR introduce any binary protocol compatibility change?


## Benchmark

<!--
When the PR has an impact on performance (if you don't know whether the
PR will have an impact on performance, you can submit the PR first, and
if it will have impact on performance, the code reviewer will explain
it), be sure to attach a benchmark data here.
-->
  • Loading branch information
pandalee99 authored Jul 14, 2024
1 parent 46d48c3 commit b32f3f9
Show file tree
Hide file tree
Showing 4 changed files with 266 additions and 0 deletions.
11 changes: 11 additions & 0 deletions cpp/fury/util/BUILD
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,8 @@ cc_library(
name = "fury_util",
srcs = glob(["*.cc"], exclude=["*test.cc"]),
hdrs = glob(["*.h"]),
copts = ["-mavx2"], # Enable AVX2 support
linkopts = ["-mavx2"], # Ensure linker also knows about AVX2
strip_include_prefix = "/cpp",
alwayslink=True,
linkstatic=True,
Expand Down Expand Up @@ -52,3 +54,12 @@ cc_test(
"@com_google_googletest//:gtest",
],
)

# Unit-test target for string_util: builds string_util_test.cc and links it
# against the :fury_util library and GoogleTest.
cc_test(
name = "string_util_test",
srcs = ["string_util_test.cc"],
deps = [
":fury_util",
"@com_google_googletest//:gtest",
],
)
121 changes: 121 additions & 0 deletions cpp/fury/util/string_util.cc
Original file line number Diff line number Diff line change
@@ -0,0 +1,121 @@
/*
* Licensed to the Apache Software Foundation (ASF) under one
* or more contributor license agreements. See the NOTICE file
* distributed with this work for additional information
* regarding copyright ownership. The ASF licenses this file
* to you under the Apache License, Version 2.0 (the
* "License"); you may not use this file except in compliance
* with the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing,
* software distributed under the License is distributed on an
* "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
* KIND, either express or implied. See the License for the
* specific language governing permissions and limitations
* under the License.
*/

#include "string_util.h"

#if defined(__x86_64__) || defined(_M_X64)
#include <immintrin.h>
#elif defined(__ARM_NEON) || defined(__ARM_NEON__)
#include <arm_neon.h>
#elif defined(__riscv) && __riscv_vector
#include <riscv_vector.h>
#endif

namespace fury {

#if defined(__x86_64__) || defined(_M_X64)

// x86-64 implementation.
//
// Returns true when every byte of `str` is below 128 (high bit clear), and
// false as soon as a byte >= 128 is found. An empty string yields true.
// Note that, despite the name, this is a 7-bit ASCII check: bytes in the
// 0x80-0xFF range are rejected.
//
// The vector width is chosen at compile time: 32-byte AVX2 chunks when the
// translation unit is built with AVX2 enabled (__AVX2__), otherwise 16-byte
// SSE2 chunks. SSE2 is part of the x86-64 baseline, so this function compiles
// for any x86-64 target regardless of the -mavx2 / /arch:AVX2 setting.
bool isLatin(const std::string &str) {
  const char *data = str.data();
  const size_t len = str.size();
  size_t i = 0;

#if defined(__AVX2__)
  for (; i + 32 <= len; i += 32) {
    const __m256i chars =
        _mm256_loadu_si256(reinterpret_cast<const __m256i *>(data + i));
    // movemask gathers the top bit of each of the 32 bytes; any set bit
    // means a byte >= 128 is present in this chunk.
    if (_mm256_movemask_epi8(chars) != 0) {
      return false;
    }
  }
#else
  for (; i + 16 <= len; i += 16) {
    const __m128i chars =
        _mm_loadu_si128(reinterpret_cast<const __m128i *>(data + i));
    // Same idea with 16 bytes per step.
    if (_mm_movemask_epi8(chars) != 0) {
      return false;
    }
  }
#endif

  // Scalar tail for the bytes that do not fill a whole vector.
  for (; i < len; ++i) {
    if (static_cast<unsigned char>(data[i]) >= 128) {
      return false;
    }
  }

  return true;
}

#elif defined(__ARM_NEON) || defined(__ARM_NEON__)

// ARM NEON implementation.
//
// Returns true when every byte of `str` is below 128 (high bit clear), and
// false as soon as a byte >= 128 is found. An empty string yields true.
// Note that, despite the name, this is a 7-bit ASCII check.
//
// The input is scanned 16 bytes at a time. On AArch64 the across-vector
// maximum (vmaxvq_u8) is used; that intrinsic is not available on 32-bit ARM,
// so there the two 8-byte halves are OR-ed together and the high bits are
// tested through a single 64-bit lane.
bool isLatin(const std::string &str) {
  const uint8_t *bytes = reinterpret_cast<const uint8_t *>(str.data());
  const size_t len = str.size();
  size_t i = 0;

  for (; i + 16 <= len; i += 16) {
    const uint8x16_t chars = vld1q_u8(bytes + i);
#if defined(__aarch64__) || defined(_M_ARM64)
    // The largest byte in the chunk is >= 0x80 iff some byte has its high
    // bit set.
    if (vmaxvq_u8(chars) >= 0x80) {
      return false;
    }
#else
    // OR-ing the halves preserves every set high bit; check all eight
    // remaining bytes at once via their 64-bit view.
    const uint8x8_t folded = vorr_u8(vget_low_u8(chars), vget_high_u8(chars));
    if ((vget_lane_u64(vreinterpret_u64_u8(folded), 0) &
         0x8080808080808080ULL) != 0) {
      return false;
    }
#endif
  }

  // Scalar tail for the bytes that do not fill a whole vector.
  for (; i < len; ++i) {
    if (bytes[i] >= 128) {
      return false;
    }
  }

  return true;
}

#elif defined(__riscv) && __riscv_vector

// RISC-V Vector (RVV) implementation.
//
// Returns true when every byte of `str` is below 128 (high bit clear), and
// false as soon as a byte >= 128 is found. An empty string yields true.
// Note that, despite the name, this is a 7-bit ASCII check.
//
// The loop is strip-mined: vsetvl reports how many bytes the hardware will
// process in this step, so no fixed vector length is assumed and no separate
// scalar tail is needed.
//
// NOTE(review): written against the `__riscv_`-prefixed RVV C intrinsics;
// confirm the project's RISC-V toolchain provides that API version.
bool isLatin(const std::string &str) {
  const uint8_t *bytes = reinterpret_cast<const uint8_t *>(str.data());
  const size_t len = str.size();

  for (size_t i = 0; i < len;) {
    // Number of elements handled in this step (at least 1 while i < len).
    const size_t vl = __riscv_vsetvl_e8m1(len - i);
    const vuint8m1_t chars = __riscv_vle8_v_u8m1(bytes + i, vl);
    // Mask of lanes whose unsigned value exceeds 0x7F, i.e. is >= 128.
    const vbool8_t high = __riscv_vmsgtu_vx_u8m1_b8(chars, 0x7F, vl);
    // vfirst yields the index of the first set mask bit, or -1 if none.
    if (__riscv_vfirst_m_b8(high, vl) >= 0) {
      return false;
    }
    i += vl;
  }

  return true;
}

#else

// Portable scalar fallback used when no SIMD path is selected.
//
// Walks the string byte by byte and returns false at the first byte whose
// unsigned value is 128 or above; returns true otherwise (including for an
// empty string).
bool isLatin(const std::string &str) {
  const char *cursor = str.data();
  const char *const stop = cursor + str.size();
  while (cursor != stop) {
    const unsigned char byte = static_cast<unsigned char>(*cursor);
    if (byte >= 128) {
      return false;
    }
    ++cursor;
  }
  return true;
}

#endif

} // namespace fury
28 changes: 28 additions & 0 deletions cpp/fury/util/string_util.h
Original file line number Diff line number Diff line change
@@ -0,0 +1,28 @@
/*
* Licensed to the Apache Software Foundation (ASF) under one
* or more contributor license agreements. See the NOTICE file
* distributed with this work for additional information
* regarding copyright ownership. The ASF licenses this file
* to you under the Apache License, Version 2.0 (the
* "License"); you may not use this file except in compliance
* with the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing,
* software distributed under the License is distributed on an
* "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
* KIND, either express or implied. See the License for the
* specific language governing permissions and limitations
* under the License.
*/

#pragma once

#include <string>

namespace fury {

// Reports whether `str` contains only bytes with a value below 128.
//
// The check is byte-wise: any byte >= 128 (for example, any byte of a
// multi-byte UTF-8 sequence) makes the result false.
// NOTE(review): the name says "Latin", but the accepted range is 7-bit ASCII;
// confirm that callers do not expect Latin-1 (0x80-0xFF) to be accepted.
//
// @param str The string to inspect; it is not modified.
// @return true if no byte of `str` is >= 128, false otherwise.
bool isLatin(const std::string &str);

} // namespace fury
106 changes: 106 additions & 0 deletions cpp/fury/util/string_util_test.cc
Original file line number Diff line number Diff line change
@@ -0,0 +1,106 @@
/*
* Licensed to the Apache Software Foundation (ASF) under one
* or more contributor license agreements. See the NOTICE file
* distributed with this work for additional information
* regarding copyright ownership. The ASF licenses this file
* to you under the Apache License, Version 2.0 (the
* "License"); you may not use this file except in compliance
* with the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing,
* software distributed under the License is distributed on an
* "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
* KIND, either express or implied. See the License for the
* specific language governing permissions and limitations
* under the License.
*/

#include <chrono>
#include <iostream>
#include <random>

#include "fury/util/logging.h"
#include "string_util.h"
#include "gtest/gtest.h"

namespace fury {

// Function to generate a random string
// Builds a string of `length` characters, each drawn uniformly at random
// from the ASCII letters (both cases) and digits. The generator is seeded
// from std::random_device on every call, so results differ between calls.
std::string generateRandomString(size_t length) {
  const char alphabet[] =
      "abcdefghijklmnopqrstuvwxyzABCDEFGHIJKLMNOPQRSTUVWXYZ0123456789";
  std::random_device seed_source;
  std::default_random_engine engine(seed_source());
  // sizeof counts the terminating NUL, so the last usable index is size - 2.
  std::uniform_int_distribution<> pick(0, sizeof(alphabet) - 2);

  std::string out(length, '\0');
  for (char &slot : out) {
    slot = alphabet[pick(engine)];
  }
  return out;
}

// Scalar reference implementation used to cross-check and time isLatin.
// Returns false at the first byte whose unsigned value is 128 or above,
// true if there is no such byte.
bool isLatin_BaseLine(const std::string &str) {
  const std::string::size_type total = str.size();
  for (std::string::size_type pos = 0; pos < total; ++pos) {
    const unsigned char byte = static_cast<unsigned char>(str[pos]);
    if (byte >= 128) {
      return false;
    }
  }
  return true;
}

// Times the scalar baseline and the optimized isLatin on the same random
// 100000-character alphanumeric string, logs both durations, and checks
// that both report the string as Latin.
//
// Each result is kept in its own variable and asserted, so the baseline
// call is actually verified rather than having its result overwritten.
TEST(StringUtilTest, TestIsLatinFunctions) {
  const std::string testStr = generateRandomString(100000);

  auto start_time = std::chrono::high_resolution_clock::now();
  const bool baseline_result = isLatin_BaseLine(testStr);
  auto end_time = std::chrono::high_resolution_clock::now();
  auto duration = std::chrono::duration_cast<std::chrono::nanoseconds>(
                      end_time - start_time)
                      .count();
  FURY_LOG(INFO) << "BaseLine Running Time: " << duration << " ns.";

  start_time = std::chrono::high_resolution_clock::now();
  const bool optimized_result = isLatin(testStr);
  end_time = std::chrono::high_resolution_clock::now();
  duration = std::chrono::duration_cast<std::chrono::nanoseconds>(end_time -
                                                                  start_time)
                 .count();
  FURY_LOG(INFO) << "Optimized Running Time: " << duration << " ns.";

  // The input is purely alphanumeric, so both implementations must agree
  // that it is Latin.
  EXPECT_TRUE(baseline_result);
  EXPECT_TRUE(optimized_result);
}

// Functional checks for isLatin: pure-ASCII inputs of various lengths must
// be accepted, and inputs containing any non-ASCII character must be
// rejected wherever that character sits in the string.
TEST(StringUtilTest, TestIsLatinLogic) {
  // Test strings with only Latin characters
  EXPECT_TRUE(isLatin(""));
  EXPECT_TRUE(isLatin("Fury"));
  EXPECT_TRUE(isLatin(generateRandomString(80)));

  // Test unaligned strings with only Latin characters
  EXPECT_TRUE(isLatin(generateRandomString(80) + "1"));
  EXPECT_TRUE(isLatin(generateRandomString(80) + "12"));
  EXPECT_TRUE(isLatin(generateRandomString(80) + "123"));

  // Test strings with non-Latin characters
  EXPECT_FALSE(isLatin("你好, Fury"));
  EXPECT_FALSE(isLatin(generateRandomString(80) + "你好"));
  EXPECT_FALSE(isLatin(generateRandomString(80) + "1你好"));
  // The appended literal must be non-empty and non-ASCII; an empty suffix
  // would leave a purely alphanumeric string, which is Latin.
  EXPECT_FALSE(isLatin(generateRandomString(11) + "你"));
  EXPECT_FALSE(isLatin(generateRandomString(10) + "你好"));
  EXPECT_FALSE(isLatin(generateRandomString(9) + "性能好"));
  EXPECT_FALSE(isLatin("\u1234"));
  EXPECT_FALSE(isLatin("a\u1234"));
  EXPECT_FALSE(isLatin("ab\u1234"));
  EXPECT_FALSE(isLatin("abc\u1234"));
  EXPECT_FALSE(isLatin("abcd\u1234"));
  EXPECT_FALSE(isLatin("Javaone Keynote\u1234"));

  // Non-Latin characters at the start and in the middle of a long string,
  // so that they fall inside a full vector-width chunk rather than in the
  // short remainder at the end.
  EXPECT_FALSE(isLatin("你好" + generateRandomString(80)));
  EXPECT_FALSE(
      isLatin(generateRandomString(40) + "你好" + generateRandomString(40)));
}

} // namespace fury

// Test-binary entry point: passes the command line to GoogleTest, runs all
// registered tests, and returns their aggregate status as the exit code.
int main(int argc, char **argv) {
  ::testing::InitGoogleTest(&argc, argv);
  const int exit_code = RUN_ALL_TESTS();
  return exit_code;
}

0 comments on commit b32f3f9

Please sign in to comment.