diff --git a/cpp/src/arrow/compute/kernels/scalar_string.cc b/cpp/src/arrow/compute/kernels/scalar_string.cc index 3fce9dd8e4aac..09e189bc562cb 100644 --- a/cpp/src/arrow/compute/kernels/scalar_string.cc +++ b/cpp/src/arrow/compute/kernels/scalar_string.cc @@ -31,11 +31,11 @@ #include "arrow/array/builder_binary.h" #include "arrow/array/builder_nested.h" #include "arrow/buffer_builder.h" - #include "arrow/builder.h" #include "arrow/compute/api_scalar.h" #include "arrow/compute/kernels/common.h" #include "arrow/util/checked_cast.h" +#include "arrow/util/string.h" #include "arrow/util/utf8.h" #include "arrow/util/value_parsing.h" #include "arrow/visitor_inline.h" @@ -330,15 +330,25 @@ struct StringTransformBase { return input_ncodeunits; } - virtual Status InvalidStatus() { + virtual Status InvalidInputSequence() { return Status::Invalid("Invalid UTF8 sequence in input"); } - - // Derived classes should also define this method: - // int64_t Transform(const uint8_t* input, int64_t input_string_ncodeunits, - // uint8_t* output); }; +/// Kernel exec generator for unary string transforms. Types of template +/// parameter StringTransform need to define a transform method with the +/// following signature: +/// +/// int64_t Transform(const uint8_t* input, int64_t input_string_ncodeunits, +/// uint8_t* output); +/// +/// where +/// * `input` is the input sequence (binary or string) +/// * `input_string_ncodeunits` is the length of input sequence in codeunits +/// * `output` is the output sequence (binary or string) +/// +/// and returns the number of codeunits of the `output` sequence or a negative +/// value if an invalid input sequence is detected. template struct StringTransformExecBase { using offset_type = typename Type::offset_type; @@ -356,27 +366,21 @@ struct StringTransformExecBase { static Status ExecArray(KernelContext* ctx, StringTransform* transform, const std::shared_ptr& data, Datum* out) { ArrayType input(data); - ArrayData* output = out->mutable_array(); - const int64_t input_ncodeunits = input.total_values_length(); const int64_t input_nstrings = input.length(); - - const int64_t output_ncodeunits_max = + const int64_t max_output_ncodeunits = transform->MaxCodeunits(input_nstrings, input_ncodeunits); - if (output_ncodeunits_max > std::numeric_limits::max()) { - return Status::CapacityError( - "Result might not fit in a 32bit utf8 array, convert to large_utf8"); - } + RETURN_NOT_OK(CheckOutputCapacity(max_output_ncodeunits)); - ARROW_ASSIGN_OR_RAISE(auto values_buffer, ctx->Allocate(output_ncodeunits_max)); + ArrayData* output = out->mutable_array(); + ARROW_ASSIGN_OR_RAISE(auto values_buffer, ctx->Allocate(max_output_ncodeunits)); output->buffers[2] = values_buffer; // String offsets are preallocated offset_type* output_string_offsets = output->GetMutableValues(1); uint8_t* output_str = output->buffers[2]->mutable_data(); offset_type output_ncodeunits = 0; - - output_string_offsets[0] = 0; + output_string_offsets[0] = output_ncodeunits; for (int64_t i = 0; i < input_nstrings; i++) { if (!input.IsNull(i)) { offset_type input_string_ncodeunits; @@ -384,15 +388,15 @@ struct StringTransformExecBase { auto encoded_nbytes = static_cast(transform->Transform( input_string, input_string_ncodeunits, output_str + output_ncodeunits)); if (encoded_nbytes < 0) { - return transform->InvalidStatus(); + return transform->InvalidInputSequence(); } output_ncodeunits += encoded_nbytes; } output_string_offsets[i + 1] = output_ncodeunits; } - DCHECK_LE(output_ncodeunits, output_ncodeunits_max); + DCHECK_LE(output_ncodeunits, max_output_ncodeunits); - // Trim the codepoint buffer, since we allocated too much + // Trim the codepoint buffer, since we may have allocated too much return values_buffer->Resize(output_ncodeunits, /*shrink_to_fit=*/true); } @@ -402,25 +406,30 @@ struct StringTransformExecBase { if (!input.is_valid) { return Status::OK(); } - auto* result = checked_cast(out->scalar().get()); - result->is_valid = true; const int64_t data_nbytes = static_cast(input.value->size()); + const int64_t max_output_ncodeunits = transform->MaxCodeunits(1, data_nbytes); + RETURN_NOT_OK(CheckOutputCapacity(max_output_ncodeunits)); - const int64_t output_ncodeunits_max = transform->MaxCodeunits(1, data_nbytes); - if (output_ncodeunits_max > std::numeric_limits::max()) { - return Status::CapacityError( - "Result might not fit in a 32bit utf8 array, convert to large_utf8"); - } - ARROW_ASSIGN_OR_RAISE(auto value_buffer, ctx->Allocate(output_ncodeunits_max)); + ARROW_ASSIGN_OR_RAISE(auto value_buffer, ctx->Allocate(max_output_ncodeunits)); + auto* result = checked_cast(out->scalar().get()); + result->is_valid = true; result->value = value_buffer; auto encoded_nbytes = static_cast(transform->Transform( input.value->data(), data_nbytes, value_buffer->mutable_data())); if (encoded_nbytes < 0) { - return transform->InvalidStatus(); + return transform->InvalidInputSequence(); } - DCHECK_LE(encoded_nbytes, output_ncodeunits_max); + DCHECK_LE(encoded_nbytes, max_output_ncodeunits); return value_buffer->Resize(encoded_nbytes, /*shrink_to_fit=*/true); } + + static Status CheckOutputCapacity(int64_t ncodeunits) { + if (ncodeunits > std::numeric_limits::max()) { + return Status::CapacityError( + "Result might not fit in a 32bit utf8 array, convert to large_utf8"); + } + return Status::OK(); + } }; template @@ -478,7 +487,7 @@ struct FixedSizeBinaryTransformExecBase { auto encoded_nbytes = static_cast( transform->Transform(input_string, input_width, output_str)); if (encoded_nbytes != output_width) { - return transform->InvalidStatus(); + return transform->InvalidInputSequence(); } } else { std::memset(output_str, 0x00, output_width); @@ -505,7 +514,7 @@ struct FixedSizeBinaryTransformExecBase { auto encoded_nbytes = static_cast(transform->Transform( input.value->data(), data_nbytes, value_buffer->mutable_data())); if (encoded_nbytes != out_width) { - return transform->InvalidStatus(); + return transform->InvalidInputSequence(); } result->is_valid = true; @@ -537,6 +546,362 @@ struct FixedSizeBinaryTransformExecWithState } }; +template +struct StringBinaryTransformBase { + using ViewType2 = typename GetViewType::T; + using ArrayType1 = typename TypeTraits::ArrayType; + using ArrayType2 = typename TypeTraits::ArrayType; + + virtual ~StringBinaryTransformBase() = default; + + virtual Status PreExec(KernelContext* ctx, const ExecBatch& batch, Datum* out) { + return Status::OK(); + } + + virtual Status InvalidInputSequence() { + return Status::Invalid("Invalid UTF8 sequence in input"); + } + + // Return the maximum total size of the output in codeunits (i.e. bytes) + // given input characteristics for different input shapes. + // The Status parameter should only be set if an error needs to be signaled. + + // Scalar-Scalar + virtual Result MaxCodeunits(const int64_t input1_ncodeunits, const ViewType2) { + return input1_ncodeunits; + } + + // Scalar-Array + virtual Result MaxCodeunits(const int64_t input1_ncodeunits, + const ArrayType2&) { + return input1_ncodeunits; + } + + // Array-Scalar + virtual Result MaxCodeunits(const ArrayType1& input1, const ViewType2) { + return input1.total_values_length(); + } + + // Array-Array + virtual Result MaxCodeunits(const ArrayType1& input1, const ArrayType2&) { + return input1.total_values_length(); + } + + // Not all combinations of input shapes are meaningful to string binary + // transforms, so these flags serve as control toggles for enabling/disabling + // the corresponding ones. These flags should be set in the PreExec() method. + // + // This is an example of a StringTransform that disables support for arguments + // with mixed Scalar/Array shapes. + // + // template + // struct MyStringTransform : public StringBinaryTransformBase { + // Status PreExec(KernelContext* ctx, const ExecBatch& batch, Datum* out) override { + // enable_scalar_array_ = false; + // enable_array_scalar_ = false; + // return StringBinaryTransformBase::PreExec(ctx, batch, out); + // } + // ... + // }; + bool enable_scalar_scalar_ = true; + bool enable_scalar_array_ = true; + bool enable_array_scalar_ = true; + bool enable_array_array_ = true; +}; + +/// Kernel exec generator for binary (two parameters) string transforms. +/// The first parameter is expected to always be a Binary/StringType while the +/// second parameter is generic. Types of template parameter StringTransform +/// need to define a transform method with the following signature: +/// +/// Result Transform( +/// const uint8_t* input, const int64_t input_string_ncodeunits, +/// const ViewType2 value2, uint8_t* output); +/// +/// where +/// * `input` - input sequence (binary or string) +/// * `input_string_ncodeunits` - length of input sequence in codeunits +/// * `value2` - second argument to the string transform +/// * `output` - output sequence (binary or string) +/// * `st` - Status code, only set if transform needs to signal an error +/// +/// and returns the number of codeunits of the `output` sequence or a negative +/// value if an invalid input sequence is detected. +template +struct StringBinaryTransformExecBase { + using offset_type = typename Type1::offset_type; + using ViewType2 = typename GetViewType::T; + using ArrayType1 = typename TypeTraits::ArrayType; + using ArrayType2 = typename TypeTraits::ArrayType; + + static Status Execute(KernelContext* ctx, StringTransform* transform, + const ExecBatch& batch, Datum* out) { + if (batch[0].is_scalar()) { + if (batch[1].is_scalar()) { + if (transform->enable_scalar_scalar_) { + return ExecScalarScalar(ctx, transform, batch[0].scalar(), batch[1].scalar(), + out); + } + } else if (batch[1].is_array()) { + if (transform->enable_scalar_array_) { + return ExecScalarArray(ctx, transform, batch[0].scalar(), batch[1].array(), + out); + } + } + } else if (batch[0].is_array()) { + if (batch[1].is_scalar()) { + if (transform->enable_array_scalar_) { + return ExecArrayScalar(ctx, transform, batch[0].array(), batch[1].scalar(), + out); + } + } else if (batch[1].is_array()) { + if (transform->enable_array_array_) { + return ExecArrayArray(ctx, transform, batch[0].array(), batch[1].array(), out); + } + } + } + + if (!(transform->enable_scalar_scalar_ && transform->enable_scalar_array_ && + transform->enable_array_scalar_ && transform->enable_array_array_)) { + return Status::Invalid( + "Binary string transform has no combination of operand kinds enabled."); + } + + return Status::TypeError("Invalid combination of operands (", batch[0].ToString(), + ", ", batch[1].ToString(), ") for binary string transform."); + } + + static Status ExecScalarScalar(KernelContext* ctx, StringTransform* transform, + const std::shared_ptr& scalar1, + const std::shared_ptr& scalar2, Datum* out) { + if (!scalar1->is_valid || !scalar2->is_valid) { + return Status::OK(); + } + const auto& binary_scalar1 = checked_cast(*scalar1); + const auto input_string = binary_scalar1.value->data(); + const auto input_ncodeunits = binary_scalar1.value->size(); + const auto value2 = UnboxScalar::Unbox(*scalar2); + + // Calculate max number of output codeunits + ARROW_ASSIGN_OR_RAISE(const auto max_output_ncodeunits, + transform->MaxCodeunits(input_ncodeunits, value2)); + RETURN_NOT_OK(CheckOutputCapacity(max_output_ncodeunits)); + + // Allocate output string + const auto output = checked_cast(out->scalar().get()); + output->is_valid = true; + ARROW_ASSIGN_OR_RAISE(auto value_buffer, ctx->Allocate(max_output_ncodeunits)); + output->value = value_buffer; + auto output_string = output->value->mutable_data(); + + // Apply transform + ARROW_ASSIGN_OR_RAISE( + auto encoded_nbytes_, + transform->Transform(input_string, input_ncodeunits, value2, output_string)); + auto encoded_nbytes = static_cast(encoded_nbytes_); + if (encoded_nbytes < 0) { + return transform->InvalidInputSequence(); + } + DCHECK_LE(encoded_nbytes, max_output_ncodeunits); + + // Trim the codepoint buffer, since we may have allocated too much + return value_buffer->Resize(encoded_nbytes, /*shrink_to_fit=*/true); + } + + static Status ExecArrayScalar(KernelContext* ctx, StringTransform* transform, + const std::shared_ptr& data1, + const std::shared_ptr& scalar2, Datum* out) { + if (!scalar2->is_valid) { + return Status::OK(); + } + const ArrayType1 array1(data1); + const auto value2 = UnboxScalar::Unbox(*scalar2); + + // Calculate max number of output codeunits + ARROW_ASSIGN_OR_RAISE(const auto max_output_ncodeunits, + transform->MaxCodeunits(array1, value2)); + RETURN_NOT_OK(CheckOutputCapacity(max_output_ncodeunits)); + + // Allocate output strings + const auto output = out->mutable_array(); + ARROW_ASSIGN_OR_RAISE(auto values_buffer, ctx->Allocate(max_output_ncodeunits)); + output->buffers[2] = values_buffer; + const auto output_string = output->buffers[2]->mutable_data(); + + // String offsets are preallocated + auto output_offsets = output->GetMutableValues(1); + output_offsets[0] = 0; + offset_type output_ncodeunits = 0; + + // Apply transform + RETURN_NOT_OK(VisitArrayDataInline( + *data1, + [&](util::string_view input_string_view) { + auto input_ncodeunits = static_cast(input_string_view.length()); + auto input_string = reinterpret_cast(input_string_view.data()); + ARROW_ASSIGN_OR_RAISE( + auto encoded_nbytes_, + transform->Transform(input_string, input_ncodeunits, value2, + output_string + output_ncodeunits)); + auto encoded_nbytes = static_cast(encoded_nbytes_); + if (encoded_nbytes < 0) { + return transform->InvalidInputSequence(); + } + output_ncodeunits += encoded_nbytes; + *(++output_offsets) = output_ncodeunits; + return Status::OK(); + }, + [&]() { + *(++output_offsets) = output_ncodeunits; + return Status::OK(); + })); + DCHECK_LE(output_ncodeunits, max_output_ncodeunits); + + // Trim the codepoint buffer, since we may have allocated too much + return values_buffer->Resize(output_ncodeunits, /*shrink_to_fit=*/true); + } + + static Status ExecScalarArray(KernelContext* ctx, StringTransform* transform, + const std::shared_ptr& scalar1, + const std::shared_ptr& data2, Datum* out) { + if (!scalar1->is_valid) { + return Status::OK(); + } + const auto& binary_scalar1 = checked_cast(*scalar1); + const auto input_string = binary_scalar1.value->data(); + const auto input_ncodeunits = binary_scalar1.value->size(); + const ArrayType2 array2(data2); + + // Calculate max number of output codeunits + ARROW_ASSIGN_OR_RAISE(const auto max_output_ncodeunits, + transform->MaxCodeunits(input_ncodeunits, array2)); + RETURN_NOT_OK(CheckOutputCapacity(max_output_ncodeunits)); + + // Allocate output strings + const auto output = out->mutable_array(); + ARROW_ASSIGN_OR_RAISE(auto values_buffer, ctx->Allocate(max_output_ncodeunits)); + output->buffers[2] = values_buffer; + const auto output_string = output->buffers[2]->mutable_data(); + + // String offsets are preallocated + auto output_offsets = output->GetMutableValues(1); + output_offsets[0] = 0; + offset_type output_ncodeunits = 0; + + // Apply transform + RETURN_NOT_OK(arrow::internal::VisitBitBlocks( + data2->buffers[0], data2->offset, data2->length, + [&](int64_t i) { + auto value2 = array2.GetView(i); + ARROW_ASSIGN_OR_RAISE( + auto encoded_nbytes_, + transform->Transform(input_string, input_ncodeunits, value2, + output_string + output_ncodeunits)); + auto encoded_nbytes = static_cast(encoded_nbytes_); + if (encoded_nbytes < 0) { + return transform->InvalidInputSequence(); + } + output_ncodeunits += encoded_nbytes; + *(++output_offsets) = output_ncodeunits; + return Status::OK(); + }, + [&]() { + *(++output_offsets) = output_ncodeunits; + return Status::OK(); + })); + DCHECK_LE(output_ncodeunits, max_output_ncodeunits); + + // Trim the codepoint buffer, since we may have allocated too much + return values_buffer->Resize(output_ncodeunits, /*shrink_to_fit=*/true); + } + + static Status ExecArrayArray(KernelContext* ctx, StringTransform* transform, + const std::shared_ptr& data1, + const std::shared_ptr& data2, Datum* out) { + const ArrayType1 array1(data1); + const ArrayType2 array2(data2); + + // Calculate max number of output codeunits + ARROW_ASSIGN_OR_RAISE(const auto max_output_ncodeunits, + transform->MaxCodeunits(array1, array2)); + RETURN_NOT_OK(CheckOutputCapacity(max_output_ncodeunits)); + + // Allocate output strings + const auto output = out->mutable_array(); + ARROW_ASSIGN_OR_RAISE(auto values_buffer, ctx->Allocate(max_output_ncodeunits)); + output->buffers[2] = values_buffer; + const auto output_string = output->buffers[2]->mutable_data(); + + // String offsets are preallocated + auto output_offsets = output->GetMutableValues(1); + output_offsets[0] = 0; + offset_type output_ncodeunits = 0; + + // Apply transform + RETURN_NOT_OK(arrow::internal::VisitTwoBitBlocks( + data1->buffers[0], data1->offset, data2->buffers[0], data2->offset, data1->length, + [&](int64_t i) { + auto input_string_view = array1.GetView(i); + auto input_ncodeunits = static_cast(input_string_view.length()); + auto input_string = reinterpret_cast(input_string_view.data()); + auto value2 = array2.GetView(i); + ARROW_ASSIGN_OR_RAISE( + auto encoded_nbytes_, + transform->Transform(input_string, input_ncodeunits, value2, + output_string + output_ncodeunits)); + auto encoded_nbytes = static_cast(encoded_nbytes_); + if (encoded_nbytes < 0) { + return transform->InvalidInputSequence(); + } + output_ncodeunits += encoded_nbytes; + *(++output_offsets) = output_ncodeunits; + return Status::OK(); + }, + [&]() { + *(++output_offsets) = output_ncodeunits; + return Status::OK(); + })); + DCHECK_LE(output_ncodeunits, max_output_ncodeunits); + + // Trim the codepoint buffer, since we may have allocated too much + return values_buffer->Resize(output_ncodeunits, /*shrink_to_fit=*/true); + } + + static Status CheckOutputCapacity(int64_t ncodeunits) { + if (ncodeunits > std::numeric_limits::max()) { + return Status::CapacityError( + "Result might not fit in requested binary/string array. " + "If possible, convert to a large binary/string."); + } + return Status::OK(); + } +}; + +template +struct StringBinaryTransformExec + : public StringBinaryTransformExecBase { + using StringBinaryTransformExecBase::Execute; + + static Status Exec(KernelContext* ctx, const ExecBatch& batch, Datum* out) { + StringTransform transform; + RETURN_NOT_OK(transform.PreExec(ctx, batch, out)); + return Execute(ctx, &transform, batch, out); + } +}; + +template +struct StringBinaryTransformExecWithState + : public StringBinaryTransformExecBase { + using State = typename StringTransform::State; + using StringBinaryTransformExecBase::Execute; + + static Status Exec(KernelContext* ctx, const ExecBatch& batch, Datum* out) { + StringTransform transform(State::Get(ctx)); + RETURN_NOT_OK(transform.PreExec(ctx, batch, out)); + return Execute(ctx, &transform, batch, out); + } +}; + #ifdef ARROW_WITH_UTF8PROC struct FunctionalCaseMappingTransform : public StringTransformBase { @@ -552,7 +917,7 @@ struct FunctionalCaseMappingTransform : public StringTransformBase { // in bytes is actually only at max 3/2 (as covered by the unittest). // Note that rounding down the 3/2 is ok, since only codepoints encoded by // two code units (even) can grow to 3 code units. - return static_cast(input_ncodeunits) * 3 / 2; + return input_ncodeunits * 3 / 2; } }; @@ -686,7 +1051,7 @@ struct AsciiReverseTransform : public StringTransformBase { return utf8_char_found ? kTransformError : input_string_ncodeunits; } - Status InvalidStatus() override { + Status InvalidInputSequence() override { return Status::Invalid("Non-ASCII sequence in input"); } }; @@ -2513,6 +2878,137 @@ void AddSplit(FunctionRegistry* registry) { #endif } +/// An ScalarFunction that promotes integer arguments to Int64. +struct ScalarCTypeToInt64Function : public ScalarFunction { + using ScalarFunction::ScalarFunction; + + Result DispatchBest(std::vector* values) const override { + RETURN_NOT_OK(CheckArity(*values)); + + using arrow::compute::detail::DispatchExactImpl; + if (auto kernel = DispatchExactImpl(this, *values)) return kernel; + + EnsureDictionaryDecoded(values); + + for (auto& descr : *values) { + if (is_integer(descr.type->id())) { + descr.type = int64(); + } + } + + if (auto kernel = DispatchExactImpl(this, *values)) return kernel; + return arrow::compute::detail::NoMatchingKernel(this, *values); + } +}; + +template +struct BinaryRepeatTransform : public StringBinaryTransformBase { + using ArrayType1 = typename TypeTraits::ArrayType; + using ArrayType2 = typename TypeTraits::ArrayType; + + Result MaxCodeunits(const int64_t input1_ncodeunits, + const int64_t num_repeats) override { + ARROW_RETURN_NOT_OK(ValidateRepeatCount(num_repeats)); + return input1_ncodeunits * num_repeats; + } + + Result MaxCodeunits(const int64_t input1_ncodeunits, + const ArrayType2& input2) override { + int64_t total_num_repeats = 0; + for (int64_t i = 0; i < input2.length(); ++i) { + auto num_repeats = input2.GetView(i); + ARROW_RETURN_NOT_OK(ValidateRepeatCount(num_repeats)); + total_num_repeats += num_repeats; + } + return input1_ncodeunits * total_num_repeats; + } + + Result MaxCodeunits(const ArrayType1& input1, + const int64_t num_repeats) override { + ARROW_RETURN_NOT_OK(ValidateRepeatCount(num_repeats)); + return input1.total_values_length() * num_repeats; + } + + Result MaxCodeunits(const ArrayType1& input1, + const ArrayType2& input2) override { + int64_t total_codeunits = 0; + for (int64_t i = 0; i < input2.length(); ++i) { + auto num_repeats = input2.GetView(i); + ARROW_RETURN_NOT_OK(ValidateRepeatCount(num_repeats)); + total_codeunits += input1.GetView(i).length() * num_repeats; + } + return total_codeunits; + } + + static Result TransformSimpleLoop(const uint8_t* input, + const int64_t input_string_ncodeunits, + const int64_t num_repeats, uint8_t* output) { + uint8_t* output_start = output; + for (int64_t i = 0; i < num_repeats; ++i) { + std::memcpy(output, input, input_string_ncodeunits); + output += input_string_ncodeunits; + } + return output - output_start; + } + + static Result TransformDoublingString(const uint8_t* input, + const int64_t input_string_ncodeunits, + const int64_t num_repeats, + uint8_t* output) { + uint8_t* output_start = output; + // Repeated doubling of string + // NB: This implementation expects `num_repeats > 0`. + std::memcpy(output, input, input_string_ncodeunits); + output += input_string_ncodeunits; + int64_t irep = 1; + for (int64_t ilen = input_string_ncodeunits; irep <= (num_repeats / 2); + irep *= 2, ilen *= 2) { + std::memcpy(output, output_start, ilen); + output += ilen; + } + + // Epilogue remainder + int64_t rem = (num_repeats - irep) * input_string_ncodeunits; + std::memcpy(output, output_start, rem); + output += rem; + return output - output_start; + } + + static Result Transform(const uint8_t* input, + const int64_t input_string_ncodeunits, + const int64_t num_repeats, uint8_t* output) { + auto transform = (num_repeats < 4) ? TransformSimpleLoop : TransformDoublingString; + return transform(input, input_string_ncodeunits, num_repeats, output); + } + + static Status ValidateRepeatCount(const int64_t num_repeats) { + if (num_repeats < 0) { + return Status::Invalid("Repeat count must be a non-negative integer"); + } + return Status::OK(); + } +}; + +template +using BinaryRepeat = + StringBinaryTransformExec>; + +const FunctionDoc binary_repeat_doc( + "Repeat a binary string", + ("For each binary string in `strings`, return a replicated version."), + {"strings", "num_repeats"}); + +void AddBinaryRepeat(FunctionRegistry* registry) { + auto func = std::make_shared( + "binary_repeat", Arity::Binary(), &binary_repeat_doc); + for (const auto& ty : BaseBinaryTypes()) { + auto exec = GenerateVarBinaryToVarBinary(ty); + ScalarKernel kernel{{ty, int64()}, ty, exec}; + DCHECK_OK(func->AddKernel(std::move(kernel))); + } + DCHECK_OK(registry->AddFunction(std::move(func))); +} + // ---------------------------------------------------------------------- // Replace substring (plain, regex) @@ -4430,7 +4926,6 @@ const FunctionDoc utf8_reverse_doc( "clusters. Hence, it will not correctly reverse grapheme clusters\n" "composed of multiple codepoints."), {"strings"}); - } // namespace void RegisterScalarStringAscii(FunctionRegistry* registry) { @@ -4454,7 +4949,6 @@ void RegisterScalarStringAscii(FunctionRegistry* registry) { &ascii_rtrim_whitespace_doc); MakeUnaryStringBatchKernel("ascii_reverse", registry, &ascii_reverse_doc); MakeUnaryStringBatchKernel("utf8_reverse", registry, &utf8_reverse_doc); - MakeUnaryStringBatchKernelWithState("ascii_center", registry, &ascii_center_doc); MakeUnaryStringBatchKernelWithState("ascii_lpad", registry, &ascii_lpad_doc); @@ -4534,6 +5028,7 @@ void RegisterScalarStringAscii(FunctionRegistry* registry) { AddSplit(registry); AddStrptime(registry); AddBinaryJoin(registry); + AddBinaryRepeat(registry); } } // namespace internal diff --git a/cpp/src/arrow/compute/kernels/scalar_string_benchmark.cc b/cpp/src/arrow/compute/kernels/scalar_string_benchmark.cc index ddc3a56f00fdc..0977ea7806cb4 100644 --- a/cpp/src/arrow/compute/kernels/scalar_string_benchmark.cc +++ b/cpp/src/arrow/compute/kernels/scalar_string_benchmark.cc @@ -210,6 +210,29 @@ static void BinaryJoinElementWiseArrayArray(benchmark::State& state) { }); } +static void BinaryRepeat(benchmark::State& state) { + const int64_t array_length = 1 << 20; + const int64_t value_min_size = 0; + const int64_t value_max_size = 32; + const double null_probability = 0.01; + const int64_t repeat_min_size = 0; + const int64_t repeat_max_size = 8; + random::RandomArrayGenerator rng(kSeed); + + // NOTE: this produces only-Ascii data + auto values = + rng.String(array_length, value_min_size, value_max_size, null_probability); + auto num_repeats = rng.Int64(array_length, repeat_min_size, repeat_max_size, 0); + // Make sure lookup tables are initialized before measuring + ABORT_NOT_OK(CallFunction("binary_repeat", {values, num_repeats})); + + for (auto _ : state) { + ABORT_NOT_OK(CallFunction("binary_repeat", {values, num_repeats})); + } + state.SetItemsProcessed(state.iterations() * array_length); + state.SetBytesProcessed(state.iterations() * values->data()->buffers[2]->size()); +} + BENCHMARK(AsciiLower); BENCHMARK(AsciiUpper); BENCHMARK(IsAlphaNumericAscii); @@ -236,5 +259,7 @@ BENCHMARK(BinaryJoinArrayArray); BENCHMARK(BinaryJoinElementWiseArrayScalar)->RangeMultiplier(8)->Range(2, 128); BENCHMARK(BinaryJoinElementWiseArrayArray)->RangeMultiplier(8)->Range(2, 128); +BENCHMARK(BinaryRepeat); + } // namespace compute } // namespace arrow diff --git a/cpp/src/arrow/compute/kernels/scalar_string_test.cc b/cpp/src/arrow/compute/kernels/scalar_string_test.cc index be22ef4a7c1b3..4551e8c61e58f 100644 --- a/cpp/src/arrow/compute/kernels/scalar_string_test.cc +++ b/cpp/src/arrow/compute/kernels/scalar_string_test.cc @@ -17,6 +17,8 @@ #include #include +#include +#include #include #include @@ -26,8 +28,10 @@ #endif #include "arrow/compute/api_scalar.h" +#include "arrow/compute/kernels/codegen_internal.h" #include "arrow/compute/kernels/test_util.h" #include "arrow/testing/gtest_util.h" +#include "arrow/type.h" namespace arrow { namespace compute { @@ -64,14 +68,6 @@ class BaseTestStringKernels : public ::testing::Test { CheckScalar(func_name, {Datum(input)}, Datum(expected), options); } - void CheckBinaryScalar(std::string func_name, std::string json_left_input, - std::string json_right_scalar, std::shared_ptr out_ty, - std::string json_expected, - const FunctionOptions* options = nullptr) { - CheckScalarBinaryScalar(func_name, type(), json_left_input, json_right_scalar, out_ty, - json_expected, options); - } - void CheckVarArgsScalar(std::string func_name, std::string json_input, std::shared_ptr out_ty, std::string json_expected, const FunctionOptions* options = nullptr) { @@ -1041,6 +1037,73 @@ TYPED_TEST(TestStringKernels, Utf8Title) { R"([null, "", "B", "Aaaz;Zææ&", "Ɑɽɽow", "Ii", "Ⱥ.Ⱥ.Ⱥ..Ⱥ", "Hello, World!", "Foo Bar;Héhé0Zop", "!%$^.,;"])"); } +TYPED_TEST(TestStringKernels, BinaryRepeatWithScalarRepeat) { + auto values = ArrayFromJSON(this->type(), + R"(["aAazZæÆ&", null, "", "b", "ɑɽⱤoW", "ıI", + "ⱥⱥⱥȺ", "hEllO, WoRld!", "$. A3", "!ɑⱤⱤow"])"); + std::vector> nrepeats_and_expected{{ + {0, R"(["", null, "", "", "", "", "", "", "", ""])"}, + {1, R"(["aAazZæÆ&", null, "", "b", "ɑɽⱤoW", "ıI", "ⱥⱥⱥȺ", "hEllO, WoRld!", + "$. A3", "!ɑⱤⱤow"])"}, + {4, R"(["aAazZæÆ&aAazZæÆ&aAazZæÆ&aAazZæÆ&", null, "", "bbbb", + "ɑɽⱤoWɑɽⱤoWɑɽⱤoWɑɽⱤoW", "ıIıIıIıI", "ⱥⱥⱥȺⱥⱥⱥȺⱥⱥⱥȺⱥⱥⱥȺ", + "hEllO, WoRld!hEllO, WoRld!hEllO, WoRld!hEllO, WoRld!", + "$. A3$. A3$. A3$. A3", "!ɑⱤⱤow!ɑⱤⱤow!ɑⱤⱤow!ɑⱤⱤow"])"}, + }}; + + for (const auto& pair : nrepeats_and_expected) { + auto num_repeat = pair.first; + auto expected = pair.second; + for (const auto& ty : IntTypes()) { + this->CheckVarArgs("binary_repeat", + {values, Datum(*arrow::MakeScalar(ty, num_repeat))}, + this->type(), expected); + } + } + + // Negative repeat count + for (auto num_repeat_ : {-1, -2, -5}) { + auto num_repeat = *arrow::MakeScalar(int64(), num_repeat_); + EXPECT_RAISES_WITH_MESSAGE_THAT( + Invalid, ::testing::HasSubstr("Repeat count must be a non-negative integer"), + CallFunction("binary_repeat", {values, num_repeat})); + } + + // Floating-point repeat count + for (auto num_repeat_ : {0.0, 1.2, -1.3}) { + auto num_repeat = *arrow::MakeScalar(float64(), num_repeat_); + EXPECT_RAISES_WITH_MESSAGE_THAT( + NotImplemented, ::testing::HasSubstr("has no kernel matching input types"), + CallFunction("binary_repeat", {values, num_repeat})); + } +} + +TYPED_TEST(TestStringKernels, BinaryRepeatWithArrayRepeat) { + auto values = ArrayFromJSON(this->type(), + R"([null, "aAazZæÆ&", "", "b", "ɑɽⱤoW", "ıI", + "ⱥⱥⱥȺ", "hEllO, WoRld!", "$. A3", "!ɑⱤⱤow"])"); + for (const auto& ty : IntTypes()) { + auto num_repeats = ArrayFromJSON(ty, R"([100, 1, 2, 5, 2, 0, 1, 3, null, 3])"); + std::string expected = + R"([null, "aAazZæÆ&", "", "bbbbb", "ɑɽⱤoWɑɽⱤoW", "", "ⱥⱥⱥȺ", + "hEllO, WoRld!hEllO, WoRld!hEllO, WoRld!", null, + "!ɑⱤⱤow!ɑⱤⱤow!ɑⱤⱤow"])"; + this->CheckVarArgs("binary_repeat", {values, num_repeats}, this->type(), expected); + } + + // Negative repeat count + auto num_repeats = ArrayFromJSON(int64(), R"([100, -1, 2, -5, 2, -1, 3, -2, 3, -100])"); + EXPECT_RAISES_WITH_MESSAGE_THAT( + Invalid, ::testing::HasSubstr("Repeat count must be a non-negative integer"), + CallFunction("binary_repeat", {values, num_repeats})); + + // Floating-point repeat count + num_repeats = ArrayFromJSON(float64(), R"([0.0, 1.2, -1.3])"); + EXPECT_RAISES_WITH_MESSAGE_THAT( + NotImplemented, ::testing::HasSubstr("has no kernel matching input types"), + CallFunction("binary_repeat", {values, num_repeats})); +} + TYPED_TEST(TestStringKernels, IsAlphaNumericUnicode) { // U+08BE (utf8: \xE0\xA2\xBE) is undefined, but utf8proc things it is // UTF8PROC_CATEGORY_LO diff --git a/cpp/src/arrow/util/bit_block_counter.h b/cpp/src/arrow/util/bit_block_counter.h index 460799036050c..9d3a75a5a3174 100644 --- a/cpp/src/arrow/util/bit_block_counter.h +++ b/cpp/src/arrow/util/bit_block_counter.h @@ -491,6 +491,54 @@ static void VisitBitBlocksVoid(const std::shared_ptr& bitmap_buf, int64_ } } +template +static Status VisitTwoBitBlocks(const std::shared_ptr& left_bitmap_buf, + int64_t left_offset, + const std::shared_ptr& right_bitmap_buf, + int64_t right_offset, int64_t length, + VisitNotNull&& visit_not_null, VisitNull&& visit_null) { + if (left_bitmap_buf == NULLPTR || right_bitmap_buf == NULLPTR) { + // At most one bitmap is present + if (left_bitmap_buf == NULLPTR) { + return VisitBitBlocks(right_bitmap_buf, right_offset, length, + std::forward(visit_not_null), + std::forward(visit_null)); + } else { + return VisitBitBlocks(left_bitmap_buf, left_offset, length, + std::forward(visit_not_null), + std::forward(visit_null)); + } + } + // Both bitmaps are present + const uint8_t* left_bitmap = left_bitmap_buf->data(); + const uint8_t* right_bitmap = right_bitmap_buf->data(); + BinaryBitBlockCounter bit_counter(left_bitmap, left_offset, right_bitmap, right_offset, + length); + int64_t position = 0; + while (position < length) { + BitBlockCount block = bit_counter.NextAndWord(); + if (block.AllSet()) { + for (int64_t i = 0; i < block.length; ++i, ++position) { + ARROW_RETURN_NOT_OK(visit_not_null(position)); + } + } else if (block.NoneSet()) { + for (int64_t i = 0; i < block.length; ++i, ++position) { + ARROW_RETURN_NOT_OK(visit_null()); + } + } else { + for (int64_t i = 0; i < block.length; ++i, ++position) { + if (BitUtil::GetBit(left_bitmap, left_offset + position) && + BitUtil::GetBit(right_bitmap, right_offset + position)) { + ARROW_RETURN_NOT_OK(visit_not_null(position)); + } else { + ARROW_RETURN_NOT_OK(visit_null()); + } + } + } + } + return Status::OK(); +} + template static void VisitTwoBitBlocksVoid(const std::shared_ptr& left_bitmap_buf, int64_t left_offset, diff --git a/docs/source/cpp/compute.rst b/docs/source/cpp/compute.rst index 0a87752e92d4b..34b1f3448da7c 100644 --- a/docs/source/cpp/compute.rst +++ b/docs/source/cpp/compute.rst @@ -812,45 +812,47 @@ The third set of functions examines string elements on a byte-per-byte basis: String transforms ~~~~~~~~~~~~~~~~~ -+-------------------------+-------+------------------------+------------------------+-----------------------------------+-------+ -| Function name | Arity | Input types | Output type | Options class | Notes | -+=========================+=======+========================+========================+===================================+=======+ -| ascii_capitalize | Unary | String-like | String-like | | \(1) | -+-------------------------+-------+------------------------+------------------------+-----------------------------------+-------+ -| ascii_lower | Unary | String-like | String-like | | \(1) | -+-------------------------+-------+------------------------+------------------------+-----------------------------------+-------+ -| ascii_reverse | Unary | String-like | String-like | | \(2) | -+-------------------------+-------+------------------------+------------------------+-----------------------------------+-------+ -| ascii_swapcase | Unary | String-like | String-like | | \(1) | -+-------------------------+-------+------------------------+------------------------+-----------------------------------+-------+ -| ascii_title | Unary | String-like | String-like | | \(1) | -+-------------------------+-------+------------------------+------------------------+-----------------------------------+-------+ -| ascii_upper | Unary | String-like | String-like | | \(1) | -+-------------------------+-------+------------------------+------------------------+-----------------------------------+-------+ -| binary_length | Unary | Binary- or String-like | Int32 or Int64 | | \(3) | -+-------------------------+-------+------------------------+------------------------+-----------------------------------+-------+ -| binary_replace_slice | Unary | Binary- or String-like | Binary- or String-like | :struct:`ReplaceSliceOptions` | \(4) | -+-------------------------+-------+------------------------+------------------------+-----------------------------------+-------+ -| replace_substring | Unary | Binary- or String-like | Binary- or String-like | :struct:`ReplaceSubstringOptions` | \(5) | -+-------------------------+-------+------------------------+------------------------+-----------------------------------+-------+ -| replace_substring_regex | Unary | Binary- or String-like | Binary- or String-like | :struct:`ReplaceSubstringOptions` | \(6) | -+-------------------------+-------+------------------------+------------------------+-----------------------------------+-------+ -| utf8_capitalize | Unary | String-like | String-like | | \(8) | -+-------------------------+-------+------------------------+------------------------+-----------------------------------+-------+ -| utf8_length | Unary | String-like | Int32 or Int64 | | \(7) | -+-------------------------+-------+------------------------+------------------------+-----------------------------------+-------+ -| utf8_lower | Unary | String-like | String-like | | \(8) | -+-------------------------+-------+------------------------+------------------------+-----------------------------------+-------+ -| utf8_replace_slice | Unary | String-like | String-like | :struct:`ReplaceSliceOptions` | \(4) | -+-------------------------+-------+------------------------+------------------------+-----------------------------------+-------+ -| utf8_reverse | Unary | String-like | String-like | | \(9) | -+-------------------------+-------+------------------------+------------------------+-----------------------------------+-------+ -| utf8_swapcase | Unary | String-like | String-like | | \(8) | -+-------------------------+-------+------------------------+------------------------+-----------------------------------+-------+ -| utf8_title | Unary | String-like | String-like | | \(8) | -+-------------------------+-------+------------------------+------------------------+-----------------------------------+-------+ -| utf8_upper | Unary | String-like | String-like | | \(8) | -+-------------------------+-------+------------------------+------------------------+-----------------------------------+-------+ ++-------------------------+--------+-----------------------------------------+------------------------+-----------------------------------+-------+ +| Function name | Arity | Input types | Output type | Options class | Notes | ++=========================+========+=========================================+========================+===================================+=======+ +| ascii_capitalize | Unary | String-like | String-like | | \(1) | ++-------------------------+--------+-----------------------------------------+------------------------+-----------------------------------+-------+ +| ascii_lower | Unary | String-like | String-like | | \(1) | ++-------------------------+--------+-----------------------------------------+------------------------+-----------------------------------+-------+ +| ascii_reverse | Unary | String-like | String-like | | \(2) | ++-------------------------+--------+-----------------------------------------+------------------------+-----------------------------------+-------+ +| ascii_swapcase | Unary | String-like | String-like | | \(1) | ++-------------------------+--------+-----------------------------------------+------------------------+-----------------------------------+-------+ +| ascii_title | Unary | String-like | String-like | | \(1) | ++-------------------------+--------+-----------------------------------------+------------------------+-----------------------------------+-------+ +| ascii_upper | Unary | String-like | String-like | | \(1) | ++-------------------------+--------+-----------------------------------------+------------------------+-----------------------------------+-------+ +| binary_length | Unary | Binary- or String-like | Int32 or Int64 | | \(3) | ++-------------------------+--------+-----------------------------------------+------------------------+-----------------------------------+-------+ +| binary_repeat | Binary | Binary/String (Arg 0); Integral (Arg 1) | Binary- or String-like | | \(4) | ++-------------------------+--------+-----------------------------------------+------------------------+-----------------------------------+-------+ +| binary_replace_slice | Unary | String-like | Binary- or String-like | :struct:`ReplaceSliceOptions` | \(5) | ++-------------------------+--------+-----------------------------------------+------------------------+-----------------------------------+-------+ +| replace_substring | Unary | String-like | String-like | :struct:`ReplaceSubstringOptions` | \(6) | ++-------------------------+--------+-----------------------------------------+------------------------+-----------------------------------+-------+ +| replace_substring_regex | Unary | String-like | String-like | :struct:`ReplaceSubstringOptions` | \(7) | ++-------------------------+--------+-----------------------------------------+------------------------+-----------------------------------+-------+ +| utf8_capitalize | Unary | String-like | String-like | | \(8) | ++-------------------------+--------+-----------------------------------------+------------------------+-----------------------------------+-------+ +| utf8_length | Unary | String-like | Int32 or Int64 | | \(9) | ++-------------------------+--------+-----------------------------------------+------------------------+-----------------------------------+-------+ +| utf8_lower | Unary | String-like | String-like | | \(8) | ++-------------------------+--------+-----------------------------------------+------------------------+-----------------------------------+-------+ +| utf8_replace_slice | Unary | String-like | String-like | :struct:`ReplaceSliceOptions` | \(6) | ++-------------------------+--------+-----------------------------------------+------------------------+-----------------------------------+-------+ +| utf8_reverse | Unary | String-like | String-like | | \(10) | ++-------------------------+--------+-----------------------------------------+------------------------+-----------------------------------+-------+ +| utf8_swapcase | Unary | String-like | String-like | | \(8) | ++-------------------------+--------+-----------------------------------------+------------------------+-----------------------------------+-------+ +| utf8_title | Unary | String-like | String-like | | \(8) | ++-------------------------+--------+-----------------------------------------+------------------------+-----------------------------------+-------+ +| utf8_upper | Unary | String-like | String-like | | \(8) | ++-------------------------+--------+-----------------------------------------+------------------------+-----------------------------------+-------+ * \(1) Each ASCII character in the input is converted to lowercase or uppercase. Non-ASCII characters are left untouched. @@ -861,31 +863,33 @@ String transforms * \(3) Output is the physical length in bytes of each input element. Output type is Int32 for Binary/String, Int64 for LargeBinary/LargeString. -* \(4) Replace the slice of the substring from :member:`ReplaceSliceOptions::start` +* \(4) Repeat the input binary string a given number of times. + +* \(5) Replace the slice of the substring from :member:`ReplaceSliceOptions::start` (inclusive) to :member:`ReplaceSliceOptions::stop` (exclusive) by :member:`ReplaceSubstringOptions::replacement`. The binary kernel measures the slice in bytes, while the UTF8 kernel measures the slice in codeunits. -* \(5) Replace non-overlapping substrings that match to +* \(6) Replace non-overlapping substrings that match to :member:`ReplaceSubstringOptions::pattern` by :member:`ReplaceSubstringOptions::replacement`. If :member:`ReplaceSubstringOptions::max_replacements` != -1, it determines the maximum number of replacements made, counting from the left. -* \(6) Replace non-overlapping substrings that match to the regular expression +* \(7) Replace non-overlapping substrings that match to the regular expression :member:`ReplaceSubstringOptions::pattern` by :member:`ReplaceSubstringOptions::replacement`, using the Google RE2 library. If :member:`ReplaceSubstringOptions::max_replacements` != -1, it determines the maximum number of replacements made, counting from the left. Note that if the pattern contains groups, backreferencing can be used. -* \(7) Output is the number of characters (not bytes) of each input element. - Output type is Int32 for String, Int64 for LargeString. - * \(8) Each UTF8-encoded character in the input is converted to lowercase or uppercase. -* \(9) Each UTF8-encoded code unit is written in reverse order to the output. +* \(9) Output is the number of characters (not bytes) of each input element. + Output type is Int32 for String, Int64 for LargeString. + +* \(10) Each UTF8-encoded code unit is written in reverse order to the output. If the input is not valid UTF8, then the output is undefined (but the size of output buffers will be preserved). diff --git a/docs/source/python/api/compute.rst b/docs/source/python/api/compute.rst index 521182f8a41f5..225d853718fe1 100644 --- a/docs/source/python/api/compute.rst +++ b/docs/source/python/api/compute.rst @@ -267,6 +267,7 @@ String Transforms ascii_title ascii_upper binary_length + binary_repeat binary_replace_slice replace_substring replace_substring_regex diff --git a/r/R/expression.R b/r/R/expression.R index b1b6635f53812..9c2554f9e05c7 100644 --- a/r/R/expression.R +++ b/r/R/expression.R @@ -100,7 +100,9 @@ # use `%/%` above. "%%" = "divide_checked", "^" = "power_checked", - "%in%" = "is_in_meta_binary" + "%in%" = "is_in_meta_binary", + "strrep" = "binary_repeat", + "str_dup" = "binary_repeat" ) .array_function_map <- c(.unary_function_map, .binary_function_map) diff --git a/r/R/type.R b/r/R/type.R index 4ef7cefb56e2d..afa9a094af15f 100644 --- a/r/R/type.R +++ b/r/R/type.R @@ -481,12 +481,12 @@ canonical_type_str <- function(type_str) { } # vctrs support ----------------------------------------------------------- -str_dup <- function(x, times) { +duplicate_string <- function(x, times) { paste0(rep(x, times = times), collapse = "") } indent <- function(x, n) { - pad <- str_dup(" ", n) + pad <- duplicate_string(" ", n) sapply(x, gsub, pattern = "(\n+)", replacement = paste0("\\1", pad)) } diff --git a/r/tests/testthat/test-dplyr-funcs-string.R b/r/tests/testthat/test-dplyr-funcs-string.R index 05cf319978829..f0965926f291c 100644 --- a/r/tests/testthat/test-dplyr-funcs-string.R +++ b/r/tests/testthat/test-dplyr-funcs-string.R @@ -467,6 +467,25 @@ test_that("strsplit and str_split", { ) }) +test_that("strrep and str_dup", { + df <- tibble(x = c("foo1", " \tB a R\n", "!apACHe aRroW!")) + for (times in 0:8) { + compare_dplyr_binding( + .input %>% + mutate(x = strrep(x, times)) %>% + collect(), + df + ) + + compare_dplyr_binding( + .input %>% + mutate(x = str_dup(x, times)) %>% + collect(), + df + ) + } +}) + test_that("str_to_lower, str_to_upper, and str_to_title", { df <- tibble(x = c("foo1", " \tB a R\n", "!apACHe aRroW!")) compare_dplyr_binding(