Skip to content

Commit

Permalink
GH-32190: [C++][Compute] Implement cumulative prod, max and min funct…
Browse files Browse the repository at this point in the history
…ions (#36020)

### Rationale for this change

Implement cumulative prod, max and min compute functions
### What changes are included in this PR?

1. Add implementations, docs and tests for the three functions.
2. Refactor `CumulativeSumOptions` to `CumulativeOptions` for reusability.
3. Fix a bug where `GenericFromScalar(GenericToScalar(std::nullopt))  != std::nullopt`.
4. Remove an unnecessary Cast with the default start value.
5. Add tests to check behavior with `NaN`.

I'll explain some of the changes in comments.

### Are these changes tested?

Yes, in vector_accumulative_ops_test.cc and test_compute.py

### Are there any user-facing changes?

No. The data members of `CumulativeSumOptions` have changed, but its member functions behave as before, and `std::optional<T>` can be constructed directly from a `T`, so users should not notice any difference.
* Closes: #32190

Lead-authored-by: Jin Shang <[email protected]>
Co-authored-by: Benjamin Kietzman <[email protected]>
Signed-off-by: Benjamin Kietzman <[email protected]>
  • Loading branch information
js8544 and bkietz authored Jun 22, 2023
1 parent 990c297 commit e3eb589
Show file tree
Hide file tree
Showing 14 changed files with 1,096 additions and 185 deletions.
39 changes: 29 additions & 10 deletions cpp/src/arrow/compute/api_vector.cc
Original file line number Diff line number Diff line change
Expand Up @@ -26,6 +26,7 @@
#include "arrow/array/array_nested.h"
#include "arrow/array/builder_primitive.h"
#include "arrow/compute/exec.h"
#include "arrow/compute/function.h"
#include "arrow/compute/function_internal.h"
#include "arrow/compute/kernels/vector_sort_internal.h"
#include "arrow/compute/registry.h"
Expand Down Expand Up @@ -142,9 +143,9 @@ static auto kPartitionNthOptionsType = GetFunctionOptionsType<PartitionNthOption
static auto kSelectKOptionsType = GetFunctionOptionsType<SelectKOptions>(
DataMember("k", &SelectKOptions::k),
DataMember("sort_keys", &SelectKOptions::sort_keys));
static auto kCumulativeSumOptionsType = GetFunctionOptionsType<CumulativeSumOptions>(
DataMember("start", &CumulativeSumOptions::start),
DataMember("skip_nulls", &CumulativeSumOptions::skip_nulls));
// Options-type descriptor for CumulativeOptions, listing its `start` and `skip_nulls`
// data members by name. It is passed to the FunctionOptions base by the
// CumulativeOptions constructors and registered in RegisterVectorOptions().
static auto kCumulativeOptionsType = GetFunctionOptionsType<CumulativeOptions>(
DataMember("start", &CumulativeOptions::start),
DataMember("skip_nulls", &CumulativeOptions::skip_nulls));
static auto kRankOptionsType = GetFunctionOptionsType<RankOptions>(
DataMember("sort_keys", &RankOptions::sort_keys),
DataMember("null_placement", &RankOptions::null_placement),
Expand Down Expand Up @@ -198,13 +199,15 @@ SelectKOptions::SelectKOptions(int64_t k, std::vector<SortKey> sort_keys)
sort_keys(std::move(sort_keys)) {}
constexpr char SelectKOptions::kTypeName[];

CumulativeSumOptions::CumulativeSumOptions(double start, bool skip_nulls)
: CumulativeSumOptions(std::make_shared<DoubleScalar>(start), skip_nulls) {}
CumulativeSumOptions::CumulativeSumOptions(std::shared_ptr<Scalar> start, bool skip_nulls)
: FunctionOptions(internal::kCumulativeSumOptionsType),
// Construct options without an explicit start value: `start` is not initialized here
// and therefore stays an empty optional; only `skip_nulls` is set.
CumulativeOptions::CumulativeOptions(bool skip_nulls)
: FunctionOptions(internal::kCumulativeOptionsType), skip_nulls(skip_nulls) {}
// Convenience constructor: wraps the double `start` in a DoubleScalar and delegates
// to the std::shared_ptr<Scalar> overload.
CumulativeOptions::CumulativeOptions(double start, bool skip_nulls)
: CumulativeOptions(std::make_shared<DoubleScalar>(start), skip_nulls) {}
// Construct options with an explicit start scalar. The scalar is moved into the
// `start` member, so `start` holds a value after this constructor runs.
CumulativeOptions::CumulativeOptions(std::shared_ptr<Scalar> start, bool skip_nulls)
: FunctionOptions(internal::kCumulativeOptionsType),
start(std::move(start)),
skip_nulls(skip_nulls) {}
constexpr char CumulativeSumOptions::kTypeName[];
constexpr char CumulativeOptions::kTypeName[];

RankOptions::RankOptions(std::vector<SortKey> sort_keys, NullPlacement null_placement,
RankOptions::Tiebreaker tiebreaker)
Expand All @@ -224,7 +227,7 @@ void RegisterVectorOptions(FunctionRegistry* registry) {
DCHECK_OK(registry->AddFunctionOptionsType(kSortOptionsType));
DCHECK_OK(registry->AddFunctionOptionsType(kPartitionNthOptionsType));
DCHECK_OK(registry->AddFunctionOptionsType(kSelectKOptionsType));
DCHECK_OK(registry->AddFunctionOptionsType(kCumulativeSumOptionsType));
DCHECK_OK(registry->AddFunctionOptionsType(kCumulativeOptionsType));
DCHECK_OK(registry->AddFunctionOptionsType(kRankOptionsType));
}
} // namespace internal
Expand Down Expand Up @@ -375,12 +378,28 @@ Result<std::shared_ptr<Array>> DropNull(const Array& values, ExecContext* ctx) {
// ----------------------------------------------------------------------
// Cumulative functions

Result<Datum> CumulativeSum(const Datum& values, const CumulativeSumOptions& options,
Result<Datum> CumulativeSum(const Datum& values, const CumulativeOptions& options,
bool check_overflow, ExecContext* ctx) {
auto func_name = check_overflow ? "cumulative_sum_checked" : "cumulative_sum";
return CallFunction(func_name, {Datum(values)}, &options, ctx);
}

// Compute the cumulative product of `values` by calling the registered compute
// function "cumulative_prod", or "cumulative_prod_checked" when `check_overflow`
// is true. `options` and `ctx` are forwarded unchanged to CallFunction.
Result<Datum> CumulativeProd(const Datum& values, const CumulativeOptions& options,
                             bool check_overflow, ExecContext* ctx) {
  const char* function_name = "cumulative_prod";
  if (check_overflow) {
    function_name = "cumulative_prod_checked";
  }
  return CallFunction(function_name, {Datum(values)}, &options, ctx);
}

// Compute the cumulative maximum of `values` by calling the registered compute
// function "cumulative_max"; `options` and `ctx` are forwarded unchanged.
Result<Datum> CumulativeMax(const Datum& values, const CumulativeOptions& options,
                            ExecContext* ctx) {
  const FunctionOptions* function_options = &options;
  return CallFunction("cumulative_max", {Datum(values)}, function_options, ctx);
}

// Compute the cumulative minimum of `values` by calling the registered compute
// function "cumulative_min"; `options` and `ctx` are forwarded unchanged.
Result<Datum> CumulativeMin(const Datum& values, const CumulativeOptions& options,
                            ExecContext* ctx) {
  const FunctionOptions* function_options = &options;
  return CallFunction("cumulative_min", {Datum(values)}, function_options, ctx);
}

// ----------------------------------------------------------------------
// Deprecated functions

Expand Down
61 changes: 50 additions & 11 deletions cpp/src/arrow/compute/api_vector.h
Original file line number Diff line number Diff line change
Expand Up @@ -210,21 +210,29 @@ class ARROW_EXPORT PartitionNthOptions : public FunctionOptions {
NullPlacement null_placement;
};

/// \brief Options for cumulative sum function
class ARROW_EXPORT CumulativeSumOptions : public FunctionOptions {
/// \brief Options for cumulative functions
/// \note Also aliased as CumulativeSumOptions for backward compatibility
class ARROW_EXPORT CumulativeOptions : public FunctionOptions {
public:
explicit CumulativeSumOptions(double start = 0, bool skip_nulls = false);
explicit CumulativeSumOptions(std::shared_ptr<Scalar> start, bool skip_nulls = false);
static constexpr char const kTypeName[] = "CumulativeSumOptions";
static CumulativeSumOptions Defaults() { return CumulativeSumOptions(); }

/// Optional starting value for cumulative operation computation
std::shared_ptr<Scalar> start;
explicit CumulativeOptions(bool skip_nulls = false);
explicit CumulativeOptions(double start, bool skip_nulls = false);
explicit CumulativeOptions(std::shared_ptr<Scalar> start, bool skip_nulls = false);
static constexpr char const kTypeName[] = "CumulativeOptions";
static CumulativeOptions Defaults() { return CumulativeOptions(); }

/// Optional starting value for cumulative operation computation, default depends on the
/// operation and input type.
/// - sum: 0
/// - prod: 1
/// - min: maximum of the input type
/// - max: minimum of the input type
std::optional<std::shared_ptr<Scalar>> start;

/// If true, nulls in the input are ignored and produce a corresponding null output.
/// When false, the first null encountered is propagated through the remaining output.
bool skip_nulls = false;
};
using CumulativeSumOptions = CumulativeOptions; // For backward compatibility

/// @}

Expand Down Expand Up @@ -607,10 +615,41 @@ Result<Datum> RunEndDecode(const Datum& value, ExecContext* ctx = NULLPTR);
/// \param[in] ctx the function execution context, optional
ARROW_EXPORT
Result<Datum> CumulativeSum(
const Datum& values,
const CumulativeSumOptions& options = CumulativeSumOptions::Defaults(),
const Datum& values, const CumulativeOptions& options = CumulativeOptions::Defaults(),
bool check_overflow = false, ExecContext* ctx = NULLPTR);

/// \brief Compute the cumulative product of an array-like object
///
/// \param[in] values array-like input
/// \param[in] options configures cumulative prod behavior
/// \param[in] check_overflow whether to check for overflow; if true, return an Invalid
/// status on overflow, otherwise wrap around on overflow
/// \param[in] ctx the function execution context, optional
ARROW_EXPORT
Result<Datum> CumulativeProd(
const Datum& values, const CumulativeOptions& options = CumulativeOptions::Defaults(),
bool check_overflow = false, ExecContext* ctx = NULLPTR);

/// \brief Compute the cumulative max of an array-like object
///
/// \param[in] values array-like input
/// \param[in] options configures cumulative max behavior
/// \param[in] ctx the function execution context, optional
ARROW_EXPORT
Result<Datum> CumulativeMax(
const Datum& values, const CumulativeOptions& options = CumulativeOptions::Defaults(),
ExecContext* ctx = NULLPTR);

/// \brief Compute the cumulative min of an array-like object
///
/// \param[in] values array-like input
/// \param[in] options configures cumulative min behavior
/// \param[in] ctx the function execution context, optional
ARROW_EXPORT
Result<Datum> CumulativeMin(
const Datum& values, const CumulativeOptions& options = CumulativeOptions::Defaults(),
ExecContext* ctx = NULLPTR);

// ----------------------------------------------------------------------
// Deprecated functions

Expand Down
54 changes: 28 additions & 26 deletions cpp/src/arrow/compute/function_internal.h
Original file line number Diff line number Diff line change
Expand Up @@ -28,6 +28,7 @@
#include "arrow/compute/function.h"
#include "arrow/compute/type_fwd.h"
#include "arrow/result.h"
#include "arrow/scalar.h"
#include "arrow/status.h"
#include "arrow/util/checked_cast.h"
#include "arrow/util/key_value_metadata.h"
Expand Down Expand Up @@ -283,12 +284,6 @@ static inline Result<decltype(MakeScalar(std::declval<T>()))> GenericToScalar(
return MakeScalar(value);
}

template <typename T>
static inline Result<decltype(MakeScalar(std::declval<T>()))> GenericToScalar(
const std::optional<T>& value) {
return value.has_value() ? MakeScalar(value.value()) : MakeScalar("");
}

// For Clang/libc++: when iterating through vector<bool>, we can't
// pass it by reference so the overload above doesn't apply
static inline Result<std::shared_ptr<Scalar>> GenericToScalar(bool value) {
Expand Down Expand Up @@ -382,6 +377,16 @@ static inline Result<std::shared_ptr<Scalar>> GenericToScalar(const Datum& value
}
}

// Serialize std::nullopt: the result is always a freshly allocated NullScalar.
static inline Result<std::shared_ptr<Scalar>> GenericToScalar(std::nullopt_t) {
  std::shared_ptr<Scalar> null_scalar = std::make_shared<NullScalar>();
  return null_scalar;
}

// Serialize an optional value: an engaged optional is converted with MakeScalar,
// an empty optional becomes a NullScalar.
template <typename T>
static inline auto GenericToScalar(const std::optional<T>& value)
    -> Result<decltype(MakeScalar(value.value()))> {
  using ScalarPtr = decltype(MakeScalar(value.value()));
  if (!value.has_value()) {
    // Implicit conversion to the MakeScalar result type, as in a conditional expression
    ScalarPtr null_scalar = std::make_shared<NullScalar>();
    return null_scalar;
  }
  return MakeScalar(*value);
}

template <typename T>
static inline enable_if_primitive_ctype<typename CTypeTraits<T>::ArrowType, Result<T>>
GenericFromScalar(const std::shared_ptr<Scalar>& value) {
Expand All @@ -404,26 +409,6 @@ GenericFromScalar(const std::shared_ptr<Scalar>& value) {
return ValidateEnumValue<T>(raw_val);
}

template <typename>
constexpr bool is_optional_impl = false;
template <typename T>
constexpr bool is_optional_impl<std::optional<T>> = true;

template <typename T>
using is_optional =
std::integral_constant<bool, is_optional_impl<std::decay_t<T>> ||
std::is_same<T, std::nullopt_t>::value>;

template <typename T, typename R = void>
using enable_if_optional = enable_if_t<is_optional<T>::value, Result<T>>;

template <typename T>
static inline enable_if_optional<T> GenericFromScalar(
const std::shared_ptr<Scalar>& value) {
using value_type = typename T::value_type;
return GenericFromScalar<value_type>(value);
}

template <typename T, typename U>
using enable_if_same_result = enable_if_same<T, U, Result<T>>;

Expand Down Expand Up @@ -510,6 +495,23 @@ static inline enable_if_same_result<T, Datum> GenericFromScalar(
return Status::Invalid("Cannot deserialize Datum from ", value->ToString());
}

// Compile-time trait: true for any std::optional<T> specialization and for
// std::nullopt_t, false for every other type.
template <typename>
constexpr inline bool is_optional_v = false;
template <typename T>
constexpr inline bool is_optional_v<std::optional<T>> = true;
template <>
constexpr inline bool is_optional_v<std::nullopt_t> = true;

// Deserialize an optional-like T from a Scalar. A scalar whose type id is Type::NA
// yields std::nullopt; any other scalar is deserialized as T::value_type through
// the matching GenericFromScalar overload.
template <typename T>
static inline std::enable_if_t<is_optional_v<T>, Result<T>> GenericFromScalar(
    const std::shared_ptr<Scalar>& value) {
  using Inner = typename T::value_type;
  const bool is_null_typed = value->type->id() == Type::NA;
  if (!is_null_typed) {
    return GenericFromScalar<Inner>(value);
  }
  return std::nullopt;
}

template <typename T>
static enable_if_same<typename CTypeTraits<T>::ArrowType, ListType, Result<T>>
GenericFromScalar(const std::shared_ptr<Scalar>& value) {
Expand Down
81 changes: 81 additions & 0 deletions cpp/src/arrow/compute/kernels/base_arithmetic_internal.h
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,7 @@

#pragma once

#include <limits>
#include "arrow/compute/api_scalar.h"
#include "arrow/compute/kernels/common_internal.h"
#include "arrow/compute/kernels/util_internal.h"
Expand Down Expand Up @@ -605,6 +606,86 @@ struct Sign {
}
};

/// Binary "maximum" operation. All three template types must be identical
/// (enforced by static_assert).
struct Max {
  /// Non-floating-point overload: returns the larger argument, or arg0 when the
  /// two compare equivalent (same result as std::max).
  template <typename T, typename Arg0, typename Arg1>
  static constexpr enable_if_not_floating_value<T> Call(KernelContext*, Arg0 arg0,
                                                        Arg1 arg1, Status*) {
    static_assert(std::is_same_v<T, Arg0> && std::is_same_v<Arg0, Arg1>);
    return (arg0 < arg1) ? arg1 : arg0;
  }

  /// Floating-point overload: a NaN operand is skipped in favour of the other
  /// operand (so the result is NaN only when `right` is NaN and `left` is NaN too).
  template <typename T, typename Arg0, typename Arg1>
  static constexpr enable_if_floating_value<T> Call(KernelContext*, Arg0 left, Arg1 right,
                                                    Status*) {
    static_assert(std::is_same_v<T, Arg0> && std::is_same_v<Arg0, Arg1>);
    if (std::isnan(left)) {
      return right;
    }
    if (std::isnan(right)) {
      return left;
    }
    return std::max(left, right);
  }
};

/// Binary "minimum" operation. All three template types must be identical
/// (enforced by static_assert).
struct Min {
  /// Non-floating-point overload: returns the smaller argument, or arg0 when the
  /// two compare equivalent (same result as std::min).
  template <typename T, typename Arg0, typename Arg1>
  static constexpr enable_if_not_floating_value<T> Call(KernelContext*, Arg0 arg0,
                                                        Arg1 arg1, Status*) {
    static_assert(std::is_same_v<T, Arg0> && std::is_same_v<Arg0, Arg1>);
    return (arg1 < arg0) ? arg1 : arg0;
  }

  /// Floating-point overload: a NaN operand is skipped in favour of the other
  /// operand (so the result is NaN only when `right` is NaN and `left` is NaN too).
  template <typename T, typename Arg0, typename Arg1>
  static constexpr enable_if_floating_value<T> Call(KernelContext*, Arg0 left, Arg1 right,
                                                    Status*) {
    static_assert(std::is_same_v<T, Arg0> && std::is_same_v<Arg0, Arg1>);
    if (std::isnan(left)) {
      return right;
    }
    if (std::isnan(right)) {
      return left;
    }
    return std::min(left, right);
  }
};

/// The term "identity" comes from the mathematical notion of a monoid.
/// For an associative binary operation Op, the identity is the element that
/// satisfies Op(identity, x) = x for all x.
///
/// Primary template: declared only; each supported operation provides a
/// specialization exposing a `value<Value>` variable template.
template <typename Op>
struct Identity;

/// Identity for addition: 0, since 0 + x = x.
template <>
struct Identity<Add> {
template <typename Value>
static constexpr Value value{0};
};

/// The checked addition variant shares the identity of Add.
template <>
struct Identity<AddChecked> : Identity<Add> {};

/// Identity for multiplication: 1, since 1 * x = x.
template <>
struct Identity<Multiply> {
template <typename Value>
static constexpr Value value{1};
};

/// The checked multiplication variant shares the identity of Multiply.
template <>
struct Identity<MultiplyChecked> : Identity<Multiply> {};

/// Identity for Max: the lowest finite value representable by Value, so that
/// Max(identity, x) = x for every finite x.
template <>
struct Identity<Max> {
  // std::numeric_limits<Value>::min() must not be used here: for floating-point
  // types it is the smallest *positive* normalized value, which would win against
  // every negative input. lowest() is the most negative finite value for both
  // integer and floating-point types (and equals min() for integers).
  template <typename Value>
  static constexpr Value value{std::numeric_limits<Value>::lowest()};
};

/// Identity for Min: the largest finite value representable by Value
/// (std::numeric_limits<Value>::max()), so that Min(identity, x) = x for every
/// finite x.
template <>
struct Identity<Min> {
template <typename Value>
static constexpr Value value{std::numeric_limits<Value>::max()};
};

} // namespace internal
} // namespace compute
} // namespace arrow
3 changes: 3 additions & 0 deletions cpp/src/arrow/compute/kernels/codegen_internal.h
Original file line number Diff line number Diff line change
Expand Up @@ -213,6 +213,9 @@ using enable_if_integer_value =
template <typename T, typename R = T>
using enable_if_floating_value = enable_if_t<std::is_floating_point<T>::value, R>;

// SFINAE helper: resolves to R (default: T) only when T is NOT a floating-point
// type; the logical complement of enable_if_floating_value.
template <typename T, typename R = T>
using enable_if_not_floating_value = enable_if_t<!std::is_floating_point<T>::value, R>;

template <typename T, typename R = T>
using enable_if_decimal_value =
enable_if_t<std::is_same<Decimal128, T>::value || std::is_same<Decimal256, T>::value,
Expand Down
Loading

0 comments on commit e3eb589

Please sign in to comment.