From b32f3f9acca14dd474d9bcf59786a8bffd02a0a4 Mon Sep 17 00:00:00 2001 From: PAN <46820719+pandalee99@users.noreply.github.com> Date: Sun, 14 Jul 2024 15:11:29 +0800 Subject: [PATCH] feat(C++): String detection is performed using SIMD techniques (#1720) ## What does this PR do? ref: https://arxiv.org/pdf/1902.08318.pdf ref: https://github.com/simdutf/simdutf I learned about the related simd technology, as well as this paper and project implementation. Using SIMD technique for string detection. First, I need to implement the logic and complete the latin character detection ``` c++ // Baseline implementation bool isLatin_Baseline(const std::string& str) { for (char c : str) { if (static_cast(c) >= 128) { return false; } } return true; } ``` image Then, I tried to use SSE2 to speed it up, which is obviously a little bit faster, the logic is to read multiple characters at once and then do the bit arithmetic Obviously, there was a speed boost, but I didn't think it was enough, so I tried it again with AVX2 image I think in terms of efficiency, it's already much faster than before. But how do you prove that it's also logically true? I added test samples to verify ``` C++ TEST(StringUtilTest, TestIsLatinLogic) ``` Finally, I ran the test image done. ## Related issues Closes #313 ## Does this PR introduce any user-facing change? - [x] Does this PR introduce any public API change? - [ ] Does this PR introduce any binary protocol compatibility change? ## Benchmark --- cpp/fury/util/BUILD | 11 +++ cpp/fury/util/string_util.cc | 121 ++++++++++++++++++++++++++++++ cpp/fury/util/string_util.h | 28 +++++++ cpp/fury/util/string_util_test.cc | 106 ++++++++++++++++++++++++++ 4 files changed, 266 insertions(+) create mode 100644 cpp/fury/util/string_util.cc create mode 100644 cpp/fury/util/string_util.h create mode 100644 cpp/fury/util/string_util_test.cc diff --git a/cpp/fury/util/BUILD b/cpp/fury/util/BUILD index 1aa1b87c80..8f605dc75e 100644 --- a/cpp/fury/util/BUILD +++ b/cpp/fury/util/BUILD @@ -4,6 +4,8 @@ cc_library( name = "fury_util", srcs = glob(["*.cc"], exclude=["*test.cc"]), hdrs = glob(["*.h"]), + copts = ["-mavx2"], # Enable AVX2 support + linkopts = ["-mavx2"], # Ensure linker also knows about AVX2 strip_include_prefix = "/cpp", alwayslink=True, linkstatic=True, @@ -52,3 +54,12 @@ cc_test( "@com_google_googletest//:gtest", ], ) + +cc_test( + name = "string_util_test", + srcs = ["string_util_test.cc"], + deps = [ + ":fury_util", + "@com_google_googletest//:gtest", + ], +) \ No newline at end of file diff --git a/cpp/fury/util/string_util.cc b/cpp/fury/util/string_util.cc new file mode 100644 index 0000000000..1f57b76fdd --- /dev/null +++ b/cpp/fury/util/string_util.cc @@ -0,0 +1,121 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +#include "string_util.h" + +#if defined(__x86_64__) || defined(_M_X64) +#include +#elif defined(__ARM_NEON) || defined(__ARM_NEON__) +#include +#elif defined(__riscv) && __riscv_vector +#include +#endif + +namespace fury { + +#if defined(__x86_64__) || defined(_M_X64) + +bool isLatin(const std::string &str) { + const char *data = str.data(); + size_t len = str.size(); + + size_t i = 0; + __m256i latin_mask = _mm256_set1_epi8(0x80); + for (; i + 32 <= len; i += 32) { + __m256i chars = + _mm256_loadu_si256(reinterpret_cast(data + i)); + __m256i result = _mm256_and_si256(chars, latin_mask); + if (!_mm256_testz_si256(result, result)) { + return false; + } + } + + for (; i < len; ++i) { + if (static_cast(data[i]) >= 128) { + return false; + } + } + + return true; +} + +#elif defined(__ARM_NEON) || defined(__ARM_NEON__) + +bool isLatin(const std::string &str) { + const char *data = str.data(); + size_t len = str.size(); + + size_t i = 0; + uint8x16_t latin_mask = vdupq_n_u8(0x80); + for (; i + 16 <= len; i += 16) { + uint8x16_t chars = vld1q_u8(reinterpret_cast(data + i)); + uint8x16_t result = vandq_u8(chars, latin_mask); + if (vmaxvq_u8(result) != 0) { + return false; + } + } + + for (; i < len; ++i) { + if (static_cast(data[i]) >= 128) { + return false; + } + } + + return true; +} + +#elif defined(__riscv) && __riscv_vector + +bool isLatin(const std::string &str) { + const char *data = str.data(); + size_t len = str.size(); + + size_t i = 0; + for (; i + 16 <= len; i += 16) { + auto chars = vle8_v_u8m1(reinterpret_cast(data + i), 16); + auto mask = vmv_v_x_u8m1(0x80, 16); + auto result = vand_vv_u8m1(chars, mask, 16); + if (vmax_v_u8m1(result, 16) != 0) { + return false; + } + } + + for (; i < len; ++i) { + if (static_cast(data[i]) >= 128) { + return false; + } + } + + return true; +} + +#else + +bool isLatin(const std::string &str) { + for (char c : str) { + if (static_cast(c) >= 128) { + return false; + } + } + return true; +} + +#endif + +} // namespace fury diff --git a/cpp/fury/util/string_util.h b/cpp/fury/util/string_util.h new file mode 100644 index 0000000000..0824d1a246 --- /dev/null +++ b/cpp/fury/util/string_util.h @@ -0,0 +1,28 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +#pragma once + +#include + +namespace fury { + +bool isLatin(const std::string &str); + +} // namespace fury diff --git a/cpp/fury/util/string_util_test.cc b/cpp/fury/util/string_util_test.cc new file mode 100644 index 0000000000..045454db96 --- /dev/null +++ b/cpp/fury/util/string_util_test.cc @@ -0,0 +1,106 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +#include +#include +#include + +#include "fury/util/logging.h" +#include "string_util.h" +#include "gtest/gtest.h" + +namespace fury { + +// Function to generate a random string +std::string generateRandomString(size_t length) { + const char charset[] = + "abcdefghijklmnopqrstuvwxyzABCDEFGHIJKLMNOPQRSTUVWXYZ0123456789"; + std::default_random_engine rng(std::random_device{}()); + std::uniform_int_distribution<> dist(0, sizeof(charset) - 2); + + std::string result; + result.reserve(length); + for (size_t i = 0; i < length; ++i) { + result += charset[dist(rng)]; + } + + return result; +} + +bool isLatin_BaseLine(const std::string &str) { + for (char c : str) { + if (static_cast(c) >= 128) { + return false; + } + } + return true; +} + +TEST(StringUtilTest, TestIsLatinFunctions) { + std::string testStr = generateRandomString(100000); + auto start_time = std::chrono::high_resolution_clock::now(); + bool result = isLatin_BaseLine(testStr); + auto end_time = std::chrono::high_resolution_clock::now(); + auto duration = std::chrono::duration_cast( + end_time - start_time) + .count(); + FURY_LOG(INFO) << "BaseLine Running Time: " << duration << " ns."; + + start_time = std::chrono::high_resolution_clock::now(); + result = isLatin(testStr); + end_time = std::chrono::high_resolution_clock::now(); + duration = std::chrono::duration_cast(end_time - + start_time) + .count(); + FURY_LOG(INFO) << "Optimized Running Time: " << duration << " ns."; + + EXPECT_TRUE(result); +} + +TEST(StringUtilTest, TestIsLatinLogic) { + // Test strings with only Latin characters + EXPECT_TRUE(isLatin("Fury")); + EXPECT_TRUE(isLatin(generateRandomString(80))); + + // Test unaligned strings with only Latin characters + EXPECT_TRUE(isLatin(generateRandomString(80) + "1")); + EXPECT_TRUE(isLatin(generateRandomString(80) + "12")); + EXPECT_TRUE(isLatin(generateRandomString(80) + "123")); + + // Test strings with non-Latin characters + EXPECT_FALSE(isLatin("你好, Fury")); + EXPECT_FALSE(isLatin(generateRandomString(80) + "你好")); + EXPECT_FALSE(isLatin(generateRandomString(80) + "1你好")); + EXPECT_FALSE(isLatin(generateRandomString(11) + "你")); + EXPECT_FALSE(isLatin(generateRandomString(10) + "你好")); + EXPECT_FALSE(isLatin(generateRandomString(9) + "性能好")); + EXPECT_FALSE(isLatin("\u1234")); + EXPECT_FALSE(isLatin("a\u1234")); + EXPECT_FALSE(isLatin("ab\u1234")); + EXPECT_FALSE(isLatin("abc\u1234")); + EXPECT_FALSE(isLatin("abcd\u1234")); + EXPECT_FALSE(isLatin("Javaone Keynote\u1234")); +} + +} // namespace fury + +int main(int argc, char **argv) { + ::testing::InitGoogleTest(&argc, argv); + return RUN_ALL_TESTS(); +}