Skip to content

Commit

Permalink
ARROW-12712: [C++] String repeat kernel
Browse files Browse the repository at this point in the history
This PR adds the string repeat compute function, registered as "binary_repeat". It is a binary function that accepts binary/string type inputs and the repetition value(s). Repetition values can be:
* a single value applied to all strings
* an array of values where each repeat count corresponds to the string in the same position

To support inputs of different shapes for this kernel, kernel exec generators and base classes for binary string transforms are also included.

Closes #11023 from edponce/ARROW-12712-String-repeat-kernel

Authored-by: Eduardo Ponce <[email protected]>
Signed-off-by: David Li <[email protected]>
  • Loading branch information
edponce authored and lidavidm committed Nov 4, 2021
1 parent 5897217 commit 0ead7c9
Show file tree
Hide file tree
Showing 9 changed files with 750 additions and 93 deletions.
567 changes: 531 additions & 36 deletions cpp/src/arrow/compute/kernels/scalar_string.cc

Large diffs are not rendered by default.

25 changes: 25 additions & 0 deletions cpp/src/arrow/compute/kernels/scalar_string_benchmark.cc
Original file line number Diff line number Diff line change
Expand Up @@ -210,6 +210,29 @@ static void BinaryJoinElementWiseArrayArray(benchmark::State& state) {
});
}

// Benchmark for the "binary_repeat" compute function: a randomly generated string
// array is paired element-wise with a randomly generated int64 array of repeat
// counts, and the kernel is invoked once per benchmark iteration.
static void BinaryRepeat(benchmark::State& state) {
  // Generator parameters for the string values.
  constexpr int64_t kNumValues = 1 << 20;
  constexpr int64_t kMinValueSize = 0;
  constexpr int64_t kMaxValueSize = 32;
  constexpr double kNullProbability = 0.01;
  // Generator bounds for the repeat counts.
  constexpr int64_t kMinRepeat = 0;
  constexpr int64_t kMaxRepeat = 8;

  random::RandomArrayGenerator generator(kSeed);

  // NOTE: the generated strings contain only ASCII data
  auto values =
      generator.String(kNumValues, kMinValueSize, kMaxValueSize, kNullProbability);
  // The trailing 0 is presumably the null probability of the repeat counts
  // (i.e. no null counts) -- confirm against RandomArrayGenerator::Int64.
  auto num_repeats = generator.Int64(kNumValues, kMinRepeat, kMaxRepeat, 0);

  // Size of the buffer at index 2 of the input values, reported as the number of
  // bytes processed per iteration.
  const auto input_bytes = values->data()->buffers[2]->size();

  // Untimed warm-up call so that any one-time initialization happens before
  // measuring.
  ABORT_NOT_OK(CallFunction("binary_repeat", {values, num_repeats}));

  for (auto _ : state) {
    ABORT_NOT_OK(CallFunction("binary_repeat", {values, num_repeats}));
  }

  state.SetItemsProcessed(state.iterations() * kNumValues);
  state.SetBytesProcessed(state.iterations() * input_bytes);
}

BENCHMARK(AsciiLower);
BENCHMARK(AsciiUpper);
BENCHMARK(IsAlphaNumericAscii);
Expand All @@ -236,5 +259,7 @@ BENCHMARK(BinaryJoinArrayArray);
BENCHMARK(BinaryJoinElementWiseArrayScalar)->RangeMultiplier(8)->Range(2, 128);
BENCHMARK(BinaryJoinElementWiseArrayArray)->RangeMultiplier(8)->Range(2, 128);

BENCHMARK(BinaryRepeat);

} // namespace compute
} // namespace arrow
79 changes: 71 additions & 8 deletions cpp/src/arrow/compute/kernels/scalar_string_test.cc
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,8 @@

#include <memory>
#include <string>
#include <utility>
#include <vector>

#include <gmock/gmock.h>
#include <gtest/gtest.h>
Expand All @@ -26,8 +28,10 @@
#endif

#include "arrow/compute/api_scalar.h"
#include "arrow/compute/kernels/codegen_internal.h"
#include "arrow/compute/kernels/test_util.h"
#include "arrow/testing/gtest_util.h"
#include "arrow/type.h"

namespace arrow {
namespace compute {
Expand Down Expand Up @@ -64,14 +68,6 @@ class BaseTestStringKernels : public ::testing::Test {
CheckScalar(func_name, {Datum(input)}, Datum(expected), options);
}

// Convenience wrapper around CheckScalarBinaryScalar for a binary (two-argument)
// function: forwards every argument unchanged and supplies this fixture's type()
// as the type argument (presumably the type used to parse `json_left_input` --
// confirm against CheckScalarBinaryScalar).
//
// func_name:         name of the compute function under test
// json_left_input:   JSON text for the left-hand input
// json_right_scalar: JSON text for the right-hand scalar
// out_ty:            expected output type
// json_expected:     JSON text for the expected result
// options:           optional function options, nullptr by default
void CheckBinaryScalar(std::string func_name, std::string json_left_input,
std::string json_right_scalar, std::shared_ptr<DataType> out_ty,
std::string json_expected,
const FunctionOptions* options = nullptr) {
CheckScalarBinaryScalar(func_name, type(), json_left_input, json_right_scalar, out_ty,
json_expected, options);
}

void CheckVarArgsScalar(std::string func_name, std::string json_input,
std::shared_ptr<DataType> out_ty, std::string json_expected,
const FunctionOptions* options = nullptr) {
Expand Down Expand Up @@ -1041,6 +1037,73 @@ TYPED_TEST(TestStringKernels, Utf8Title) {
R"([null, "", "B", "Aaaz;Zææ&", "Ɑɽɽow", "Ii", "Ⱥ.Ⱥ.Ⱥ..Ⱥ", "Hello, World!", "Foo Bar;Héhé0Zop", "!%$^.,;"])");
}

// Exercises "binary_repeat" with a scalar repeat count: valid counts of every
// integer type, negative counts (expected Invalid) and floating-point counts
// (expected NotImplemented).
//
// Continuation lines of the raw JSON literals are kept flush-left on purpose so
// that the literal text stays unchanged.
TYPED_TEST(TestStringKernels, BinaryRepeatWithScalarRepeat) {
  auto values = ArrayFromJSON(this->type(),
                              R"(["aAazZæÆ&", null, "", "b", "ɑɽⱤoW", "ıI",
"ⱥⱥⱥȺ", "hEllO, WoRld!", "$. A3", "!ɑⱤⱤow"])");

  // Each entry: (repeat count, expected output as JSON).
  std::vector<std::pair<int, std::string>> cases{{
      {0, R"(["", null, "", "", "", "", "", "", "", ""])"},
      {1, R"(["aAazZæÆ&", null, "", "b", "ɑɽⱤoW", "ıI", "ⱥⱥⱥȺ", "hEllO, WoRld!",
"$. A3", "!ɑⱤⱤow"])"},
      {4, R"(["aAazZæÆ&aAazZæÆ&aAazZæÆ&aAazZæÆ&", null, "", "bbbb",
"ɑɽⱤoWɑɽⱤoWɑɽⱤoWɑɽⱤoW", "ıIıIıIıI", "ⱥⱥⱥȺⱥⱥⱥȺⱥⱥⱥȺⱥⱥⱥȺ",
"hEllO, WoRld!hEllO, WoRld!hEllO, WoRld!hEllO, WoRld!",
"$. A3$. A3$. A3$. A3", "!ɑⱤⱤow!ɑⱤⱤow!ɑⱤⱤow!ɑⱤⱤow"])"},
  }};

  // Valid repeat counts, tried with every integer type for the count scalar.
  for (const auto& test_case : cases) {
    int repeat_count = test_case.first;
    const std::string& expected_json = test_case.second;
    for (const auto& int_type : IntTypes()) {
      Datum repeat_datum(*arrow::MakeScalar(int_type, repeat_count));
      this->CheckVarArgs("binary_repeat", {values, repeat_datum}, this->type(),
                         expected_json);
    }
  }

  // Negative repeat count
  for (int negative_count : {-1, -2, -5}) {
    auto repeat_scalar = *arrow::MakeScalar(int64(), negative_count);
    EXPECT_RAISES_WITH_MESSAGE_THAT(
        Invalid, ::testing::HasSubstr("Repeat count must be a non-negative integer"),
        CallFunction("binary_repeat", {values, repeat_scalar}));
  }

  // Floating-point repeat count
  for (double float_count : {0.0, 1.2, -1.3}) {
    auto repeat_scalar = *arrow::MakeScalar(float64(), float_count);
    EXPECT_RAISES_WITH_MESSAGE_THAT(
        NotImplemented, ::testing::HasSubstr("has no kernel matching input types"),
        CallFunction("binary_repeat", {values, repeat_scalar}));
  }
}

// Exercises "binary_repeat" with an array of repeat counts (one count per input
// string): valid counts of every integer type, an array containing negative counts
// (expected Invalid) and a floating-point array (expected NotImplemented).
//
// Continuation lines of the raw JSON literals are kept flush-left on purpose so
// that the literal text stays unchanged.
TYPED_TEST(TestStringKernels, BinaryRepeatWithArrayRepeat) {
  auto values = ArrayFromJSON(this->type(),
                              R"([null, "aAazZæÆ&", "", "b", "ɑɽⱤoW", "ıI",
"ⱥⱥⱥȺ", "hEllO, WoRld!", "$. A3", "!ɑⱤⱤow"])");

  // The expected output is the same for every integer type of the counts array.
  // A null in either the values or the counts yields a null output slot.
  const std::string expected_json =
      R"([null, "aAazZæÆ&", "", "bbbbb", "ɑɽⱤoWɑɽⱤoW", "", "ⱥⱥⱥȺ",
"hEllO, WoRld!hEllO, WoRld!hEllO, WoRld!", null,
"!ɑⱤⱤow!ɑⱤⱤow!ɑⱤⱤow"])";
  for (const auto& int_type : IntTypes()) {
    auto repeat_counts =
        ArrayFromJSON(int_type, R"([100, 1, 2, 5, 2, 0, 1, 3, null, 3])");
    this->CheckVarArgs("binary_repeat", {values, repeat_counts}, this->type(),
                       expected_json);
  }

  // Negative repeat count
  auto negative_counts =
      ArrayFromJSON(int64(), R"([100, -1, 2, -5, 2, -1, 3, -2, 3, -100])");
  EXPECT_RAISES_WITH_MESSAGE_THAT(
      Invalid, ::testing::HasSubstr("Repeat count must be a non-negative integer"),
      CallFunction("binary_repeat", {values, negative_counts}));

  // Floating-point repeat count
  auto float_counts = ArrayFromJSON(float64(), R"([0.0, 1.2, -1.3])");
  EXPECT_RAISES_WITH_MESSAGE_THAT(
      NotImplemented, ::testing::HasSubstr("has no kernel matching input types"),
      CallFunction("binary_repeat", {values, float_counts}));
}

TYPED_TEST(TestStringKernels, IsAlphaNumericUnicode) {
// U+08BE (utf8: \xE0\xA2\xBE) is undefined, but utf8proc thinks it is
// UTF8PROC_CATEGORY_LO
Expand Down
48 changes: 48 additions & 0 deletions cpp/src/arrow/util/bit_block_counter.h
Original file line number Diff line number Diff line change
Expand Up @@ -491,6 +491,54 @@ static void VisitBitBlocksVoid(const std::shared_ptr<Buffer>& bitmap_buf, int64_
}
}

/// \brief Visit positions [0, length) of two bitmaps jointly, with visitors that
/// can fail.
///
/// When both bitmaps are present, `visit_not_null(position)` is invoked for every
/// position whose bit is set in both bitmaps and `visit_null()` for every other
/// position. When at least one bitmap buffer is NULLPTR, the visit is delegated to
/// VisitBitBlocks on the other buffer (on the right one if the left is missing,
/// even if the right is missing too).
///
/// \param left_bitmap_buf left bitmap buffer, may be NULLPTR
/// \param left_offset bit offset of position 0 in the left bitmap
/// \param right_bitmap_buf right bitmap buffer, may be NULLPTR
/// \param right_offset bit offset of position 0 in the right bitmap
/// \param length number of positions to visit
/// \param visit_not_null callable taking the position; its result is checked with
///        ARROW_RETURN_NOT_OK
/// \param visit_null callable taking no argument; its result is checked with
///        ARROW_RETURN_NOT_OK
/// \return the result of the delegated VisitBitBlocks call, or Status::OK() once
///         every position has been visited; visitor results are propagated
///         through ARROW_RETURN_NOT_OK
template <typename VisitNotNull, typename VisitNull>
static Status VisitTwoBitBlocks(const std::shared_ptr<Buffer>& left_bitmap_buf,
                                int64_t left_offset,
                                const std::shared_ptr<Buffer>& right_bitmap_buf,
                                int64_t right_offset, int64_t length,
                                VisitNotNull&& visit_not_null, VisitNull&& visit_null) {
  // At most one bitmap is present: fall back to the single-bitmap visitor.
  if (left_bitmap_buf == NULLPTR) {
    return VisitBitBlocks(right_bitmap_buf, right_offset, length,
                          std::forward<VisitNotNull>(visit_not_null),
                          std::forward<VisitNull>(visit_null));
  }
  if (right_bitmap_buf == NULLPTR) {
    return VisitBitBlocks(left_bitmap_buf, left_offset, length,
                          std::forward<VisitNotNull>(visit_not_null),
                          std::forward<VisitNull>(visit_null));
  }

  // Both bitmaps are present: walk them block by block.
  const uint8_t* left_bits = left_bitmap_buf->data();
  const uint8_t* right_bits = right_bitmap_buf->data();
  BinaryBitBlockCounter block_counter(left_bits, left_offset, right_bits, right_offset,
                                      length);
  int64_t pos = 0;
  while (pos < length) {
    BitBlockCount block = block_counter.NextAndWord();
    // First position past the current block.
    const int64_t block_end = pos + block.length;
    if (block.AllSet()) {
      // Whole block reported as set: no per-bit test needed.
      for (; pos < block_end; ++pos) {
        ARROW_RETURN_NOT_OK(visit_not_null(pos));
      }
    } else if (block.NoneSet()) {
      // Whole block reported as unset.
      for (; pos < block_end; ++pos) {
        ARROW_RETURN_NOT_OK(visit_null());
      }
    } else {
      // Mixed block: test both bitmaps bit by bit.
      for (; pos < block_end; ++pos) {
        const bool both_set = BitUtil::GetBit(left_bits, left_offset + pos) &&
                              BitUtil::GetBit(right_bits, right_offset + pos);
        if (both_set) {
          ARROW_RETURN_NOT_OK(visit_not_null(pos));
        } else {
          ARROW_RETURN_NOT_OK(visit_null());
        }
      }
    }
  }
  return Status::OK();
}

template <typename VisitNotNull, typename VisitNull>
static void VisitTwoBitBlocksVoid(const std::shared_ptr<Buffer>& left_bitmap_buf,
int64_t left_offset,
Expand Down
Loading

0 comments on commit 0ead7c9

Please sign in to comment.