From 3a5f28620e155308e4909fc8ec06fb8af3e19374 Mon Sep 17 00:00:00 2001 From: benibus Date: Sat, 8 Jul 2023 23:58:38 -0400 Subject: [PATCH 1/5] Avoid overflows in `SliceCodunitsTransform` --- .../compute/kernels/scalar_string_test.cc | 42 +++++++++++++++++++ .../compute/kernels/scalar_string_utf8.cc | 6 ++- 2 files changed, 46 insertions(+), 2 deletions(-) diff --git a/cpp/src/arrow/compute/kernels/scalar_string_test.cc b/cpp/src/arrow/compute/kernels/scalar_string_test.cc index 4581e6377a7fc..46fa06ec5afd0 100644 --- a/cpp/src/arrow/compute/kernels/scalar_string_test.cc +++ b/cpp/src/arrow/compute/kernels/scalar_string_test.cc @@ -2140,6 +2140,48 @@ TYPED_TEST(TestStringKernels, SliceCodeunitsNegPos) { this->type(), R"(["", "", "ö", "õ", "ḍö", "šõ"])", &options_step_neg); } +// Tests where `start` is positive and `stop` is the max (positive) value +TYPED_TEST(TestStringKernels, SliceCodeunitsPosMax) { + // Test cases used here: https://github.com/apache/arrow/issues/36311 + SliceOptions options{/*start=*/0}; + options.step = 1; + this->CheckUnary("utf8_slice_codeunits", R"(["AB🎭C🎭ㇱD"])", this->type(), + R"(["AB🎭C🎭ㇱD"])", &options); + options.start = 2; + options.step = 4; + this->CheckUnary("utf8_slice_codeunits", R"(["AB🎭C🎭ㇱD"])", this->type(), R"(["🎭D"])", + &options); + + options.start = 2; + options.step = 1; + this->CheckUnary("utf8_slice_codeunits", R"(["", "𝑓", "𝑓ö", "𝑓öõ", "𝑓öõḍ", "𝑓öõḍš"])", + this->type(), R"(["", "", "", "õ", "õḍ", "õḍš"])", &options); + options.start = 1; + options.step = 2; + this->CheckUnary("utf8_slice_codeunits", R"(["", "𝑓", "𝑓ö", "𝑓öõ", "𝑓öõḍ", "𝑓öõḍš"])", + this->type(), R"(["", "", "ö", "ö", "öḍ", "öḍ"])", &options); + options.start = 3; + options.step = -2; + this->CheckUnary("utf8_slice_codeunits", R"(["", "𝑓", "𝑓ö", "𝑓öõ", "𝑓öõḍ", "𝑓öõḍš"])", + this->type(), R"(["", "", "", "", "", ""])", &options); +} + +// Tests where `start` is negative and `stop` is the max (positive) value +TYPED_TEST(TestStringKernels, SliceCodeunitsNegMax) { + SliceOptions options{/*start=*/-2}; + options.step = 1; + this->CheckUnary("utf8_slice_codeunits", R"(["", "𝑓", "𝑓ö", "𝑓öõ", "𝑓öõḍ", "𝑓öõḍš"])", + this->type(), R"(["", "𝑓", "𝑓ö", "öõ", "õḍ", "ḍš"])", &options); + options.start = -3; + options.step = 2; + this->CheckUnary("utf8_slice_codeunits", R"(["", "𝑓", "𝑓ö", "𝑓öõ", "𝑓öõḍ", "𝑓öõḍš"])", + this->type(), R"(["", "𝑓", "𝑓", "𝑓õ", "öḍ", "õš"])", &options); + options.start = -3; + options.step = -1; + this->CheckUnary("utf8_slice_codeunits", R"(["", "𝑓", "𝑓ö", "𝑓öõ", "𝑓öõḍ", "𝑓öõḍš"])", + this->type(), R"(["", "", "", "", "", ""])", &options); +} + #endif // ARROW_WITH_UTF8PROC TYPED_TEST(TestBinaryKernels, SliceBytesBasic) { diff --git a/cpp/src/arrow/compute/kernels/scalar_string_utf8.cc b/cpp/src/arrow/compute/kernels/scalar_string_utf8.cc index fb197e13a688b..e46c433541aad 100644 --- a/cpp/src/arrow/compute/kernels/scalar_string_utf8.cc +++ b/cpp/src/arrow/compute/kernels/scalar_string_utf8.cc @@ -1090,7 +1090,8 @@ struct SliceCodeunitsTransform : StringSliceTransformBase { // on the resulting slice lengths, so return a worst case estimate. return input_ncodeunits; } - int64_t max_slice_codepoints = (opt.stop - opt.start + opt.step - 1) / opt.step; + int64_t stop = std::min(opt.stop, input_ncodeunits); + int64_t max_slice_codepoints = (stop - opt.start + opt.step - 1) / opt.step; // The maximum UTF8 byte size of a codepoint is 4 return std::min(input_ncodeunits, 4 * ninputs * std::max(0, max_slice_codepoints)); @@ -1214,8 +1215,9 @@ struct SliceCodeunitsTransform : StringSliceTransformBase { // similar to opt.start if (opt.stop >= 0) { + int64_t length = std::min(opt.stop, std::numeric_limits::max() - 1) + 1; RETURN_IF_UTF8_ERROR( - arrow::util::UTF8AdvanceCodepoints(begin, end, &end_sliced, opt.stop + 1)); + arrow::util::UTF8AdvanceCodepoints(begin, end, &end_sliced, length)); } else { RETURN_IF_UTF8_ERROR(arrow::util::UTF8AdvanceCodepointsReverse( begin, end, &end_sliced, -opt.stop - 1)); From 40a7afc8612c434a80cf388ca22f08e1f06f1545 Mon Sep 17 00:00:00 2001 From: benibus Date: Mon, 10 Jul 2023 15:45:21 -0400 Subject: [PATCH 2/5] Handle `INT64_MIN` in `SliceCodeunitsTransform` Also reorganizes C++ tests --- .../compute/kernels/scalar_string_test.cc | 71 +++++++++---------- .../compute/kernels/scalar_string_utf8.cc | 2 +- 2 files changed, 33 insertions(+), 40 deletions(-) diff --git a/cpp/src/arrow/compute/kernels/scalar_string_test.cc b/cpp/src/arrow/compute/kernels/scalar_string_test.cc index 46fa06ec5afd0..ff14f5e7a5c5d 100644 --- a/cpp/src/arrow/compute/kernels/scalar_string_test.cc +++ b/cpp/src/arrow/compute/kernels/scalar_string_test.cc @@ -2091,6 +2091,14 @@ TYPED_TEST(TestStringKernels, SliceCodeunitsPosPos) { options_step_neg.stop = 0; this->CheckUnary("utf8_slice_codeunits", R"(["", "𝑓", "𝑓ö", "𝑓öõ", "𝑓öõḍ","𝑓öõḍš"])", this->type(), R"(["", "", "ö", "õ", "ḍö", "šõ"])", &options_step_neg); + + constexpr auto max = std::numeric_limits::max(); + SliceOptions options_max_step{1, max, 2}; + this->CheckUnary("utf8_slice_codeunits", R"(["", "𝑓", "𝑓ö", "𝑓öõ", "𝑓öõḍ", "𝑓öõḍš"])", + this->type(), R"(["", "", "ö", "ö", "öḍ", "öḍ"])", &options_max_step); + SliceOptions options_max_step_neg{1, max, -2}; + this->CheckUnary("utf8_slice_codeunits", R"(["", "𝑓", "𝑓ö", "𝑓öõ", "𝑓öõḍ", "𝑓öõḍš"])", + this->type(), R"(["", "", "", "", "", ""])", &options_max_step_neg); } TYPED_TEST(TestStringKernels, SliceCodeunitsPosNeg) { @@ -2107,6 +2115,15 @@ TYPED_TEST(TestStringKernels, SliceCodeunitsPosNeg) { this->CheckUnary("utf8_slice_codeunits", R"(["", "𝑓", "𝑓ö", "𝑓öõ", "𝑓öõḍ","𝑓öõḍš"])", this->type(), R"(["", "𝑓", "ö", "õ𝑓", "ḍö", "ḍö"])", &options_step_neg); + + constexpr auto min = std::numeric_limits::min(); + SliceOptions options_min_step{2, min, 2}; + this->CheckUnary("utf8_slice_codeunits", R"(["", "𝑓", "𝑓ö", "𝑓öõ", "𝑓öõḍ", "𝑓öõḍš"])", + this->type(), R"(["", "", "", "", "", ""])", &options_min_step); + SliceOptions options_min_step_neg{2, min, -2}; + this->CheckUnary("utf8_slice_codeunits", R"(["", "𝑓", "𝑓ö", "𝑓öõ", "𝑓öõḍ", "𝑓öõḍš"])", + this->type(), R"(["", "𝑓", "ö", "õ𝑓", "õ𝑓", "õ𝑓"])", + &options_min_step_neg); } TYPED_TEST(TestStringKernels, SliceCodeunitsNegNeg) { @@ -2123,6 +2140,15 @@ TYPED_TEST(TestStringKernels, SliceCodeunitsNegNeg) { this->CheckUnary("utf8_slice_codeunits", R"(["", "𝑓", "𝑓ö", "𝑓öõ", "𝑓öõḍ", "𝑓öõḍš"])", this->type(), R"(["", "𝑓", "ö", "õ𝑓", "ḍö", "šõ"])", &options_step_neg); + + constexpr auto min = std::numeric_limits::min(); + SliceOptions options_min_step{-2, min, 2}; + this->CheckUnary("utf8_slice_codeunits", R"(["", "𝑓", "𝑓ö", "𝑓öõ", "𝑓öõḍ", "𝑓öõḍš"])", + this->type(), R"(["", "", "", "", "", ""])", &options_min_step); + SliceOptions options_min_step_neg{-2, min, -2}; + this->CheckUnary("utf8_slice_codeunits", R"(["", "𝑓", "𝑓ö", "𝑓öõ", "𝑓öõḍ", "𝑓öõḍš"])", + this->type(), R"(["", "", "𝑓", "ö", "õ𝑓", "ḍö"])", + &options_min_step_neg); } TYPED_TEST(TestStringKernels, SliceCodeunitsNegPos) { @@ -2138,48 +2164,15 @@ TYPED_TEST(TestStringKernels, SliceCodeunitsNegPos) { options_step_neg.stop = 0; this->CheckUnary("utf8_slice_codeunits", R"(["", "𝑓", "𝑓ö", "𝑓öõ", "𝑓öõḍ", "𝑓öõḍš"])", this->type(), R"(["", "", "ö", "õ", "ḍö", "šõ"])", &options_step_neg); -} - -// Tests where `start` is positive and `stop` is the max (positive) value -TYPED_TEST(TestStringKernels, SliceCodeunitsPosMax) { - // Test cases used here: https://github.com/apache/arrow/issues/36311 - SliceOptions options{/*start=*/0}; - options.step = 1; - this->CheckUnary("utf8_slice_codeunits", R"(["AB🎭C🎭ㇱD"])", this->type(), - R"(["AB🎭C🎭ㇱD"])", &options); - options.start = 2; - options.step = 4; - this->CheckUnary("utf8_slice_codeunits", R"(["AB🎭C🎭ㇱD"])", this->type(), R"(["🎭D"])", - &options); - options.start = 2; - options.step = 1; - this->CheckUnary("utf8_slice_codeunits", R"(["", "𝑓", "𝑓ö", "𝑓öõ", "𝑓öõḍ", "𝑓öõḍš"])", - this->type(), R"(["", "", "", "õ", "õḍ", "õḍš"])", &options); - options.start = 1; - options.step = 2; - this->CheckUnary("utf8_slice_codeunits", R"(["", "𝑓", "𝑓ö", "𝑓öõ", "𝑓öõḍ", "𝑓öõḍš"])", - this->type(), R"(["", "", "ö", "ö", "öḍ", "öḍ"])", &options); - options.start = 3; - options.step = -2; - this->CheckUnary("utf8_slice_codeunits", R"(["", "𝑓", "𝑓ö", "𝑓öõ", "𝑓öõḍ", "𝑓öõḍš"])", - this->type(), R"(["", "", "", "", "", ""])", &options); -} - -// Tests where `start` is negative and `stop` is the max (positive) value -TYPED_TEST(TestStringKernels, SliceCodeunitsNegMax) { - SliceOptions options{/*start=*/-2}; - options.step = 1; - this->CheckUnary("utf8_slice_codeunits", R"(["", "𝑓", "𝑓ö", "𝑓öõ", "𝑓öõḍ", "𝑓öõḍš"])", - this->type(), R"(["", "𝑓", "𝑓ö", "öõ", "õḍ", "ḍš"])", &options); - options.start = -3; - options.step = 2; + constexpr auto max = std::numeric_limits::max(); + SliceOptions options_max_step{-3, max, 2}; this->CheckUnary("utf8_slice_codeunits", R"(["", "𝑓", "𝑓ö", "𝑓öõ", "𝑓öõḍ", "𝑓öõḍš"])", - this->type(), R"(["", "𝑓", "𝑓", "𝑓õ", "öḍ", "õš"])", &options); - options.start = -3; - options.step = -1; + this->type(), R"(["", "𝑓", "𝑓", "𝑓õ", "öḍ", "õš"])", + &options_max_step); + SliceOptions options_max_step_neg{-3, max, -2}; this->CheckUnary("utf8_slice_codeunits", R"(["", "𝑓", "𝑓ö", "𝑓öõ", "𝑓öõḍ", "𝑓öõḍš"])", - this->type(), R"(["", "", "", "", "", ""])", &options); + this->type(), R"(["", "", "", "", "", ""])", &options_max_step_neg); } #endif // ARROW_WITH_UTF8PROC diff --git a/cpp/src/arrow/compute/kernels/scalar_string_utf8.cc b/cpp/src/arrow/compute/kernels/scalar_string_utf8.cc index e46c433541aad..d6b984c27a481 100644 --- a/cpp/src/arrow/compute/kernels/scalar_string_utf8.cc +++ b/cpp/src/arrow/compute/kernels/scalar_string_utf8.cc @@ -1090,7 +1090,7 @@ struct SliceCodeunitsTransform : StringSliceTransformBase { // on the resulting slice lengths, so return a worst case estimate. return input_ncodeunits; } - int64_t stop = std::min(opt.stop, input_ncodeunits); + int64_t stop = std::clamp(opt.stop, -input_ncodeunits, input_ncodeunits); int64_t max_slice_codepoints = (stop - opt.start + opt.step - 1) / opt.step; // The maximum UTF8 byte size of a codepoint is 4 return std::min(input_ncodeunits, From f6e013e50f547ac87b95d0c14264a9ccafbb6ddf Mon Sep 17 00:00:00 2001 From: benibus Date: Mon, 10 Jul 2023 15:47:34 -0400 Subject: [PATCH 3/5] Update Python bindings for `SliceOptions` When `stop=None`, negates the default value of `sys.maxsize` if `step < 0` to mimic the behavior of Python arrays --- python/pyarrow/_compute.pyx | 2 ++ 1 file changed, 2 insertions(+) diff --git a/python/pyarrow/_compute.pyx b/python/pyarrow/_compute.pyx index ab63a7a19f7f6..ac7efeff41aba 100644 --- a/python/pyarrow/_compute.pyx +++ b/python/pyarrow/_compute.pyx @@ -1201,6 +1201,8 @@ class SliceOptions(_SliceOptions): def __init__(self, start, stop=None, step=1): if stop is None: stop = sys.maxsize + if step < 0: + stop = -stop self._set_options(start, stop, step) From 03fa729d317545c9a41668cb34fc4aa6b83a75e2 Mon Sep 17 00:00:00 2001 From: benibus Date: Mon, 10 Jul 2023 15:52:29 -0400 Subject: [PATCH 4/5] Test `stop=None` for Python --- python/pyarrow/tests/test_compute.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/python/pyarrow/tests/test_compute.py b/python/pyarrow/tests/test_compute.py index 865fecc7b2291..e47e5d3f3eb3b 100644 --- a/python/pyarrow/tests/test_compute.py +++ b/python/pyarrow/tests/test_compute.py @@ -537,7 +537,7 @@ def test_trim(): def test_slice_compatibility(): arr = pa.array(["", "𝑓", "𝑓ö", "𝑓öõ", "𝑓öõḍ", "𝑓öõḍš"]) for start in range(-6, 6): - for stop in range(-6, 6): + for stop in itertools.chain(range(-6, 6), [None]): for step in [-3, -2, -1, 1, 2, 3]: expected = pa.array([k.as_py()[start:stop:step] for k in arr]) From 58eb8737405fe177c17ed8167b94acd9dad2cb3b Mon Sep 17 00:00:00 2001 From: benibus Date: Mon, 10 Jul 2023 17:07:12 -0400 Subject: [PATCH 5/5] Fix negation UB --- cpp/src/arrow/compute/kernels/scalar_string_utf8.cc | 12 +++++++++--- 1 file changed, 9 insertions(+), 3 deletions(-) diff --git a/cpp/src/arrow/compute/kernels/scalar_string_utf8.cc b/cpp/src/arrow/compute/kernels/scalar_string_utf8.cc index d6b984c27a481..cf8a697fea411 100644 --- a/cpp/src/arrow/compute/kernels/scalar_string_utf8.cc +++ b/cpp/src/arrow/compute/kernels/scalar_string_utf8.cc @@ -1134,7 +1134,7 @@ struct SliceCodeunitsTransform : StringSliceTransformBase { } else if (opt.stop < 0) { // or from the end (but we will never need to < begin_sliced) RETURN_IF_UTF8_ERROR(arrow::util::UTF8AdvanceCodepointsReverse( - begin_sliced, end, &end_sliced, -opt.stop)); + begin_sliced, end, &end_sliced, Negate(opt.stop))); } else { // zero length slice return 0; @@ -1159,7 +1159,7 @@ struct SliceCodeunitsTransform : StringSliceTransformBase { // or begin_sliced), but begin_sliced and opt.start can be 'out of sync', // for instance when start=-100, when the string length is only 10. RETURN_IF_UTF8_ERROR(arrow::util::UTF8AdvanceCodepointsReverse( - begin_sliced, end, &end_sliced, -opt.stop)); + begin_sliced, end, &end_sliced, Negate(opt.stop))); } else { // zero length slice return 0; @@ -1220,7 +1220,7 @@ struct SliceCodeunitsTransform : StringSliceTransformBase { arrow::util::UTF8AdvanceCodepoints(begin, end, &end_sliced, length)); } else { RETURN_IF_UTF8_ERROR(arrow::util::UTF8AdvanceCodepointsReverse( - begin, end, &end_sliced, -opt.stop - 1)); + begin, end, &end_sliced, Negate(opt.stop) - 1)); } end_sliced--; @@ -1242,6 +1242,12 @@ struct SliceCodeunitsTransform : StringSliceTransformBase { } #undef RETURN_IF_UTF8_ERROR + + private: + static int64_t Negate(int64_t v) { + constexpr auto max = std::numeric_limits::max(); + return -max > v ? max : -v; + } }; template