diff --git a/cpp/src/arrow/compute/api_vector.cc b/cpp/src/arrow/compute/api_vector.cc index d3ca2ae15ddda..5044d4f25690a 100644 --- a/cpp/src/arrow/compute/api_vector.cc +++ b/cpp/src/arrow/compute/api_vector.cc @@ -144,8 +144,7 @@ static auto kSelectKOptionsType = GetFunctionOptionsType( DataMember("sort_keys", &SelectKOptions::sort_keys)); static auto kCumulativeSumOptionsType = GetFunctionOptionsType( DataMember("start", &CumulativeSumOptions::start), - DataMember("skip_nulls", &CumulativeSumOptions::skip_nulls), - DataMember("check_overflow", &CumulativeSumOptions::check_overflow)); + DataMember("skip_nulls", &CumulativeSumOptions::skip_nulls)); static auto kRankOptionsType = GetFunctionOptionsType( DataMember("sort_keys", &RankOptions::sort_keys), DataMember("null_placement", &RankOptions::null_placement), @@ -199,16 +198,12 @@ SelectKOptions::SelectKOptions(int64_t k, std::vector sort_keys) sort_keys(std::move(sort_keys)) {} constexpr char SelectKOptions::kTypeName[]; -CumulativeSumOptions::CumulativeSumOptions(double start, bool skip_nulls, - bool check_overflow) - : CumulativeSumOptions(std::make_shared(start), skip_nulls, - check_overflow) {} -CumulativeSumOptions::CumulativeSumOptions(std::shared_ptr start, bool skip_nulls, - bool check_overflow) +CumulativeSumOptions::CumulativeSumOptions(double start, bool skip_nulls) + : CumulativeSumOptions(std::make_shared(start), skip_nulls) {} +CumulativeSumOptions::CumulativeSumOptions(std::shared_ptr start, bool skip_nulls) : FunctionOptions(internal::kCumulativeSumOptionsType), start(std::move(start)), - skip_nulls(skip_nulls), - check_overflow(check_overflow) {} + skip_nulls(skip_nulls) {} constexpr char CumulativeSumOptions::kTypeName[]; RankOptions::RankOptions(std::vector sort_keys, NullPlacement null_placement, @@ -381,8 +376,8 @@ Result> DropNull(const Array& values, ExecContext* ctx) { // Cumulative functions Result CumulativeSum(const Datum& values, const CumulativeSumOptions& options, - ExecContext* ctx) { - auto func_name = (options.check_overflow) ? "cumulative_sum_checked" : "cumulative_sum"; + bool check_overflow, ExecContext* ctx) { + auto func_name = check_overflow ? "cumulative_sum_checked" : "cumulative_sum"; return CallFunction(func_name, {Datum(values)}, &options, ctx); } diff --git a/cpp/src/arrow/compute/api_vector.h b/cpp/src/arrow/compute/api_vector.h index 2ec1cf959cbdc..d02c505f3e59a 100644 --- a/cpp/src/arrow/compute/api_vector.h +++ b/cpp/src/arrow/compute/api_vector.h @@ -213,10 +213,8 @@ class ARROW_EXPORT PartitionNthOptions : public FunctionOptions { /// \brief Options for cumulative sum function class ARROW_EXPORT CumulativeSumOptions : public FunctionOptions { public: - explicit CumulativeSumOptions(double start = 0, bool skip_nulls = false, - bool check_overflow = false); - explicit CumulativeSumOptions(std::shared_ptr start, bool skip_nulls = false, - bool check_overflow = false); + explicit CumulativeSumOptions(double start = 0, bool skip_nulls = false); + explicit CumulativeSumOptions(std::shared_ptr start, bool skip_nulls = false); static constexpr char const kTypeName[] = "CumulativeSumOptions"; static CumulativeSumOptions Defaults() { return CumulativeSumOptions(); } @@ -226,9 +224,6 @@ class ARROW_EXPORT CumulativeSumOptions : public FunctionOptions { /// If true, nulls in the input are ignored and produce a corresponding null output. /// When false, the first null encountered is propagated through the remaining output. bool skip_nulls = false; - - /// When true, returns an Invalid Status when overflow is detected - bool check_overflow = false; }; /// @} @@ -597,11 +592,18 @@ Result RunEndEncode( ARROW_EXPORT Result RunEndDecode(const Datum& value, ExecContext* ctx = NULLPTR); +/// \brief Compute the cumulative sum of an array-like object +/// +/// \param[in] values array-like input +/// \param[in] options configures cumulative sum behavior +/// \param[in] check_overflow whether to check for overflow, if true, return Invalid +/// status on overflow, otherwise wrap around on overflow +/// \param[in] ctx the function execution context, optional ARROW_EXPORT Result CumulativeSum( const Datum& values, const CumulativeSumOptions& options = CumulativeSumOptions::Defaults(), - ExecContext* ctx = NULLPTR); + bool check_overflow = false, ExecContext* ctx = NULLPTR); // ---------------------------------------------------------------------- // Deprecated functions diff --git a/cpp/src/arrow/compute/kernels/vector_cumulative_ops_test.cc b/cpp/src/arrow/compute/kernels/vector_cumulative_ops_test.cc index 9ec287b537d64..3c6bb3c1d10d9 100644 --- a/cpp/src/arrow/compute/kernels/vector_cumulative_ops_test.cc +++ b/cpp/src/arrow/compute/kernels/vector_cumulative_ops_test.cc @@ -23,6 +23,7 @@ #include "arrow/array.h" #include "arrow/chunked_array.h" +#include "arrow/compute/api_vector.h" #include "arrow/testing/gtest_util.h" #include "arrow/testing/util.h" #include "arrow/type.h" @@ -30,6 +31,7 @@ #include "arrow/array/builder_primitive.h" #include "arrow/compute/api.h" #include "arrow/compute/kernels/test_util.h" +#include "arrow/type_fwd.h" namespace arrow { namespace compute { @@ -344,5 +346,15 @@ TEST(TestCumulativeSum, HasStartDoSkip) { } } +TEST(TestCumulativeSum, ConvenienceFunctionCheckOverflow) { + ASSERT_ARRAYS_EQUAL(*CumulativeSum(ArrayFromJSON(int8(), "[127, 1]"), + CumulativeSumOptions::Defaults(), false) + ->make_array(), + *ArrayFromJSON(int8(), "[127, -128]")); + + EXPECT_RAISES_WITH_MESSAGE_THAT(Invalid, HasSubstr("overflow"), + CumulativeSum(ArrayFromJSON(int8(), "[127, 1]"), + CumulativeSumOptions::Defaults(), true)); +} } // namespace compute } // namespace arrow