Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

GH-35789: [C++] Remove check_overflow from CumulativeSumOptions #35790

Merged
merged 2 commits into from
Jun 1, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
19 changes: 7 additions & 12 deletions cpp/src/arrow/compute/api_vector.cc
Original file line number Diff line number Diff line change
Expand Up @@ -144,8 +144,7 @@ static auto kSelectKOptionsType = GetFunctionOptionsType<SelectKOptions>(
DataMember("sort_keys", &SelectKOptions::sort_keys));
static auto kCumulativeSumOptionsType = GetFunctionOptionsType<CumulativeSumOptions>(
DataMember("start", &CumulativeSumOptions::start),
DataMember("skip_nulls", &CumulativeSumOptions::skip_nulls),
DataMember("check_overflow", &CumulativeSumOptions::check_overflow));
DataMember("skip_nulls", &CumulativeSumOptions::skip_nulls));
static auto kRankOptionsType = GetFunctionOptionsType<RankOptions>(
DataMember("sort_keys", &RankOptions::sort_keys),
DataMember("null_placement", &RankOptions::null_placement),
Expand Down Expand Up @@ -199,16 +198,12 @@ SelectKOptions::SelectKOptions(int64_t k, std::vector<SortKey> sort_keys)
sort_keys(std::move(sort_keys)) {}
constexpr char SelectKOptions::kTypeName[];

CumulativeSumOptions::CumulativeSumOptions(double start, bool skip_nulls,
bool check_overflow)
: CumulativeSumOptions(std::make_shared<DoubleScalar>(start), skip_nulls,
check_overflow) {}
CumulativeSumOptions::CumulativeSumOptions(std::shared_ptr<Scalar> start, bool skip_nulls,
bool check_overflow)
CumulativeSumOptions::CumulativeSumOptions(double start, bool skip_nulls)
: CumulativeSumOptions(std::make_shared<DoubleScalar>(start), skip_nulls) {}
CumulativeSumOptions::CumulativeSumOptions(std::shared_ptr<Scalar> start, bool skip_nulls)
: FunctionOptions(internal::kCumulativeSumOptionsType),
start(std::move(start)),
skip_nulls(skip_nulls),
check_overflow(check_overflow) {}
skip_nulls(skip_nulls) {}
constexpr char CumulativeSumOptions::kTypeName[];

RankOptions::RankOptions(std::vector<SortKey> sort_keys, NullPlacement null_placement,
Expand Down Expand Up @@ -381,8 +376,8 @@ Result<std::shared_ptr<Array>> DropNull(const Array& values, ExecContext* ctx) {
// Cumulative functions

Result<Datum> CumulativeSum(const Datum& values, const CumulativeSumOptions& options,
ExecContext* ctx) {
auto func_name = (options.check_overflow) ? "cumulative_sum_checked" : "cumulative_sum";
bool check_overflow, ExecContext* ctx) {
auto func_name = check_overflow ? "cumulative_sum_checked" : "cumulative_sum";
return CallFunction(func_name, {Datum(values)}, &options, ctx);
}

Expand Down
18 changes: 10 additions & 8 deletions cpp/src/arrow/compute/api_vector.h
Original file line number Diff line number Diff line change
Expand Up @@ -213,10 +213,8 @@ class ARROW_EXPORT PartitionNthOptions : public FunctionOptions {
/// \brief Options for cumulative sum function
class ARROW_EXPORT CumulativeSumOptions : public FunctionOptions {
public:
explicit CumulativeSumOptions(double start = 0, bool skip_nulls = false,
bool check_overflow = false);
explicit CumulativeSumOptions(std::shared_ptr<Scalar> start, bool skip_nulls = false,
bool check_overflow = false);
explicit CumulativeSumOptions(double start = 0, bool skip_nulls = false);
explicit CumulativeSumOptions(std::shared_ptr<Scalar> start, bool skip_nulls = false);
static constexpr char const kTypeName[] = "CumulativeSumOptions";
static CumulativeSumOptions Defaults() { return CumulativeSumOptions(); }

Expand All @@ -226,9 +224,6 @@ class ARROW_EXPORT CumulativeSumOptions : public FunctionOptions {
/// If true, nulls in the input are ignored and produce a corresponding null output.
/// When false, the first null encountered is propagated through the remaining output.
bool skip_nulls = false;

/// When true, returns an Invalid Status when overflow is detected
bool check_overflow = false;
};

/// @}
Expand Down Expand Up @@ -597,11 +592,18 @@ Result<Datum> RunEndEncode(
ARROW_EXPORT
Result<Datum> RunEndDecode(const Datum& value, ExecContext* ctx = NULLPTR);

/// \brief Compute the cumulative sum of an array-like object
///
/// \param[in] values array-like input
/// \param[in] options configures cumulative sum behavior
/// \param[in] check_overflow whether to check for overflow, if true, return Invalid
/// status on overflow, otherwise wrap around on overflow
/// \param[in] ctx the function execution context, optional
ARROW_EXPORT
Result<Datum> CumulativeSum(
const Datum& values,
const CumulativeSumOptions& options = CumulativeSumOptions::Defaults(),
ExecContext* ctx = NULLPTR);
bool check_overflow = false, ExecContext* ctx = NULLPTR);
pitrou marked this conversation as resolved.
Show resolved Hide resolved

// ----------------------------------------------------------------------
// Deprecated functions
Expand Down
12 changes: 12 additions & 0 deletions cpp/src/arrow/compute/kernels/vector_cumulative_ops_test.cc
Original file line number Diff line number Diff line change
Expand Up @@ -23,13 +23,15 @@

#include "arrow/array.h"
#include "arrow/chunked_array.h"
#include "arrow/compute/api_vector.h"
#include "arrow/testing/gtest_util.h"
#include "arrow/testing/util.h"
#include "arrow/type.h"

#include "arrow/array/builder_primitive.h"
#include "arrow/compute/api.h"
#include "arrow/compute/kernels/test_util.h"
#include "arrow/type_fwd.h"

namespace arrow {
namespace compute {
Expand Down Expand Up @@ -344,5 +346,15 @@ TEST(TestCumulativeSum, HasStartDoSkip) {
}
}

TEST(TestCumulativeSum, ConvenienceFunctionCheckOverflow) {
ASSERT_ARRAYS_EQUAL(*CumulativeSum(ArrayFromJSON(int8(), "[127, 1]"),
CumulativeSumOptions::Defaults(), false)
->make_array(),
*ArrayFromJSON(int8(), "[127, -128]"));

EXPECT_RAISES_WITH_MESSAGE_THAT(Invalid, HasSubstr("overflow"),
CumulativeSum(ArrayFromJSON(int8(), "[127, 1]"),
CumulativeSumOptions::Defaults(), true));
}
} // namespace compute
} // namespace arrow