Skip to content

Commit

Permalink
apacheGH-32190: [C++] Implement cumulative product, max, and min comp…
Browse files Browse the repository at this point in the history
…ute functions
  • Loading branch information
js8544 committed Jun 7, 2023
1 parent 5a55fb4 commit e2e8e77
Show file tree
Hide file tree
Showing 14 changed files with 1,085 additions and 185 deletions.
39 changes: 29 additions & 10 deletions cpp/src/arrow/compute/api_vector.cc
Original file line number Diff line number Diff line change
Expand Up @@ -26,6 +26,7 @@
#include "arrow/array/array_nested.h"
#include "arrow/array/builder_primitive.h"
#include "arrow/compute/exec.h"
#include "arrow/compute/function.h"
#include "arrow/compute/function_internal.h"
#include "arrow/compute/kernels/vector_sort_internal.h"
#include "arrow/compute/registry.h"
Expand Down Expand Up @@ -142,9 +143,9 @@ static auto kPartitionNthOptionsType = GetFunctionOptionsType<PartitionNthOption
static auto kSelectKOptionsType = GetFunctionOptionsType<SelectKOptions>(
DataMember("k", &SelectKOptions::k),
DataMember("sort_keys", &SelectKOptions::sort_keys));
static auto kCumulativeSumOptionsType = GetFunctionOptionsType<CumulativeSumOptions>(
DataMember("start", &CumulativeSumOptions::start),
DataMember("skip_nulls", &CumulativeSumOptions::skip_nulls));
static auto kCumulativeOptionsType = GetFunctionOptionsType<CumulativeOptions>(
DataMember("start", &CumulativeOptions::start),
DataMember("skip_nulls", &CumulativeOptions::skip_nulls));
static auto kRankOptionsType = GetFunctionOptionsType<RankOptions>(
DataMember("sort_keys", &RankOptions::sort_keys),
DataMember("null_placement", &RankOptions::null_placement),
Expand Down Expand Up @@ -198,13 +199,15 @@ SelectKOptions::SelectKOptions(int64_t k, std::vector<SortKey> sort_keys)
sort_keys(std::move(sort_keys)) {}
constexpr char SelectKOptions::kTypeName[];

CumulativeSumOptions::CumulativeSumOptions(double start, bool skip_nulls)
: CumulativeSumOptions(std::make_shared<DoubleScalar>(start), skip_nulls) {}
CumulativeSumOptions::CumulativeSumOptions(std::shared_ptr<Scalar> start, bool skip_nulls)
: FunctionOptions(internal::kCumulativeSumOptionsType),
// Constructs options without an explicit starting value: only `skip_nulls` is
// initialized here, so `start` keeps its default-constructed state.
CumulativeOptions::CumulativeOptions(bool skip_nulls)
: FunctionOptions(internal::kCumulativeOptionsType), skip_nulls(skip_nulls) {}
// Convenience overload: wraps `start` in a DoubleScalar and delegates to the
// Scalar-taking constructor below.
// NOTE(review): with both a `bool` and a `double` single-argument overload, a call
// passing one integer literal (e.g. `CumulativeOptions(0)`) looks ambiguous between
// the two -- confirm callers always pass an explicit bool/double.
CumulativeOptions::CumulativeOptions(double start, bool skip_nulls)
: CumulativeOptions(std::make_shared<DoubleScalar>(start), skip_nulls) {}
// Stores the given scalar as the starting value; the pointer is moved into the
// `start` member rather than copied.
CumulativeOptions::CumulativeOptions(std::shared_ptr<Scalar> start, bool skip_nulls)
: FunctionOptions(internal::kCumulativeOptionsType),
start(std::move(start)),
skip_nulls(skip_nulls) {}
constexpr char CumulativeSumOptions::kTypeName[];
constexpr char CumulativeOptions::kTypeName[];

RankOptions::RankOptions(std::vector<SortKey> sort_keys, NullPlacement null_placement,
RankOptions::Tiebreaker tiebreaker)
Expand All @@ -224,7 +227,7 @@ void RegisterVectorOptions(FunctionRegistry* registry) {
DCHECK_OK(registry->AddFunctionOptionsType(kSortOptionsType));
DCHECK_OK(registry->AddFunctionOptionsType(kPartitionNthOptionsType));
DCHECK_OK(registry->AddFunctionOptionsType(kSelectKOptionsType));
DCHECK_OK(registry->AddFunctionOptionsType(kCumulativeSumOptionsType));
DCHECK_OK(registry->AddFunctionOptionsType(kCumulativeOptionsType));
DCHECK_OK(registry->AddFunctionOptionsType(kRankOptionsType));
}
} // namespace internal
Expand Down Expand Up @@ -375,12 +378,28 @@ Result<std::shared_ptr<Array>> DropNull(const Array& values, ExecContext* ctx) {
// ----------------------------------------------------------------------
// Cumulative functions

Result<Datum> CumulativeSum(const Datum& values, const CumulativeSumOptions& options,
Result<Datum> CumulativeSum(const Datum& values, const CumulativeOptions& options,
                            bool check_overflow, ExecContext* ctx) {
  // The overflow-checking variant is a separate compute function, selected by name.
  if (check_overflow) {
    return CallFunction("cumulative_sum_checked", {Datum(values)}, &options, ctx);
  }
  return CallFunction("cumulative_sum", {Datum(values)}, &options, ctx);
}

Result<Datum> CumulativeProd(const Datum& values, const CumulativeOptions& options,
                             bool check_overflow, ExecContext* ctx) {
  // The overflow-checking variant is a separate compute function, selected by name.
  if (check_overflow) {
    return CallFunction("cumulative_prod_checked", {Datum(values)}, &options, ctx);
  }
  return CallFunction("cumulative_prod", {Datum(values)}, &options, ctx);
}

Result<Datum> CumulativeMax(const Datum& values, const CumulativeOptions& options,
                            ExecContext* ctx) {
  // Single-argument call: the input is forwarded as the only function argument.
  const std::vector<Datum> args{Datum(values)};
  return CallFunction("cumulative_max", args, &options, ctx);
}

Result<Datum> CumulativeMin(const Datum& values, const CumulativeOptions& options,
                            ExecContext* ctx) {
  // Single-argument call: the input is forwarded as the only function argument.
  const std::vector<Datum> args{Datum(values)};
  return CallFunction("cumulative_min", args, &options, ctx);
}

// ----------------------------------------------------------------------
// Deprecated functions

Expand Down
61 changes: 50 additions & 11 deletions cpp/src/arrow/compute/api_vector.h
Original file line number Diff line number Diff line change
Expand Up @@ -210,21 +210,29 @@ class ARROW_EXPORT PartitionNthOptions : public FunctionOptions {
NullPlacement null_placement;
};

/// \brief Options for cumulative sum function
class ARROW_EXPORT CumulativeSumOptions : public FunctionOptions {
/// \brief Options for cumulative functions
/// \note Also aliased as CumulativeSumOptions for backward compatibility
class ARROW_EXPORT CumulativeOptions : public FunctionOptions {
public:
explicit CumulativeSumOptions(double start = 0, bool skip_nulls = false);
explicit CumulativeSumOptions(std::shared_ptr<Scalar> start, bool skip_nulls = false);
static constexpr char const kTypeName[] = "CumulativeSumOptions";
static CumulativeSumOptions Defaults() { return CumulativeSumOptions(); }

/// Optional starting value for cumulative operation computation
std::shared_ptr<Scalar> start;
explicit CumulativeOptions(bool skip_nulls = false);
explicit CumulativeOptions(double start, bool skip_nulls = false);
explicit CumulativeOptions(std::shared_ptr<Scalar> start, bool skip_nulls = false);
static constexpr char const kTypeName[] = "CumulativeOptions";
static CumulativeOptions Defaults() { return CumulativeOptions(); }

/// Optional starting value for cumulative operation computation, default depends on the
/// operation and input type.
/// - sum: 0
/// - prod: 1
/// - min: maximum of the input type
/// - max: minimum of the input type
std::optional<std::shared_ptr<Scalar>> start;

/// If true, nulls in the input are ignored and produce a corresponding null output.
/// When false, the first null encountered is propagated through the remaining output.
bool skip_nulls = false;
};
using CumulativeSumOptions = CumulativeOptions; // For backward compatibility

/// @}

Expand Down Expand Up @@ -601,10 +609,41 @@ Result<Datum> RunEndDecode(const Datum& value, ExecContext* ctx = NULLPTR);
/// \param[in] ctx the function execution context, optional
ARROW_EXPORT
Result<Datum> CumulativeSum(
const Datum& values,
const CumulativeSumOptions& options = CumulativeSumOptions::Defaults(),
const Datum& values, const CumulativeOptions& options = CumulativeOptions::Defaults(),
bool check_overflow = false, ExecContext* ctx = NULLPTR);

/// \brief Compute the cumulative product of an array-like object
///
/// \param[in] values array-like input
/// \param[in] options configures cumulative product behavior (optional starting
/// value and null handling, see CumulativeOptions)
/// \param[in] check_overflow whether to check for overflow; if true, return Invalid
/// status on overflow, otherwise wrap around on overflow
/// \param[in] ctx the function execution context, optional
/// \return the resulting Datum, or an error Status
ARROW_EXPORT
Result<Datum> CumulativeProd(
const Datum& values, const CumulativeOptions& options = CumulativeOptions::Defaults(),
bool check_overflow = false, ExecContext* ctx = NULLPTR);

/// \brief Compute the cumulative maximum of an array-like object
///
/// \param[in] values array-like input
/// \param[in] options configures cumulative maximum behavior (optional starting
/// value and null handling, see CumulativeOptions)
/// \param[in] ctx the function execution context, optional
/// \return the resulting Datum, or an error Status
ARROW_EXPORT
Result<Datum> CumulativeMax(
const Datum& values, const CumulativeOptions& options = CumulativeOptions::Defaults(),
ExecContext* ctx = NULLPTR);

/// \brief Compute the cumulative minimum of an array-like object
///
/// \param[in] values array-like input
/// \param[in] options configures cumulative minimum behavior (optional starting
/// value and null handling, see CumulativeOptions)
/// \param[in] ctx the function execution context, optional
/// \return the resulting Datum, or an error Status
ARROW_EXPORT
Result<Datum> CumulativeMin(
const Datum& values, const CumulativeOptions& options = CumulativeOptions::Defaults(),
ExecContext* ctx = NULLPTR);

// ----------------------------------------------------------------------
// Deprecated functions

Expand Down
56 changes: 30 additions & 26 deletions cpp/src/arrow/compute/function_internal.h
Original file line number Diff line number Diff line change
Expand Up @@ -28,6 +28,7 @@
#include "arrow/compute/function.h"
#include "arrow/compute/type_fwd.h"
#include "arrow/result.h"
#include "arrow/scalar.h"
#include "arrow/status.h"
#include "arrow/util/checked_cast.h"
#include "arrow/util/key_value_metadata.h"
Expand Down Expand Up @@ -283,12 +284,6 @@ static inline Result<decltype(MakeScalar(std::declval<T>()))> GenericToScalar(
return MakeScalar(value);
}

template <typename T>
static inline Result<decltype(MakeScalar(std::declval<T>()))> GenericToScalar(
const std::optional<T>& value) {
return value.has_value() ? MakeScalar(value.value()) : MakeScalar("");
}

// For Clang/libc++: when iterating through vector<bool>, we can't
// pass it by reference so the overload above doesn't apply
static inline Result<std::shared_ptr<Scalar>> GenericToScalar(bool value) {
Expand Down Expand Up @@ -382,6 +377,12 @@ static inline Result<std::shared_ptr<Scalar>> GenericToScalar(const Datum& value
}
}

// Serialize an optional value: an empty optional becomes a NullScalar, an engaged
// one is converted through MakeScalar.
template <typename T>
static inline Result<decltype(MakeScalar(std::declval<T>()))> GenericToScalar(
    const std::optional<T>& value) {
  if (!value.has_value()) {
    return std::make_shared<NullScalar>();
  }
  return MakeScalar(*value);
}

template <typename T>
static inline enable_if_primitive_ctype<typename CTypeTraits<T>::ArrowType, Result<T>>
GenericFromScalar(const std::shared_ptr<Scalar>& value) {
Expand All @@ -404,26 +405,6 @@ GenericFromScalar(const std::shared_ptr<Scalar>& value) {
return ValidateEnumValue<T>(raw_val);
}

template <typename>
constexpr bool is_optional_impl = false;
template <typename T>
constexpr bool is_optional_impl<std::optional<T>> = true;

template <typename T>
using is_optional =
std::integral_constant<bool, is_optional_impl<std::decay_t<T>> ||
std::is_same<T, std::nullopt_t>::value>;

template <typename T, typename R = void>
using enable_if_optional = enable_if_t<is_optional<T>::value, Result<T>>;

template <typename T>
static inline enable_if_optional<T> GenericFromScalar(
const std::shared_ptr<Scalar>& value) {
using value_type = typename T::value_type;
return GenericFromScalar<value_type>(value);
}

template <typename T, typename U>
using enable_if_same_result = enable_if_same<T, U, Result<T>>;

Expand Down Expand Up @@ -510,6 +491,29 @@ static inline enable_if_same_result<T, Datum> GenericFromScalar(
return Status::Invalid("Cannot deserialize Datum from ", value->ToString());
}

// Primary template: an arbitrary type is not a std::optional ...
template <typename>
constexpr bool is_optional_impl = false;
// ... while any std::optional<T> specialization is.
template <typename T>
constexpr bool is_optional_impl<std::optional<T>> = true;

// True for (possibly cv/ref-qualified) std::optional types and for std::nullopt_t.
template <typename T>
using is_optional = std::bool_constant<is_optional_impl<std::decay_t<T>> ||
                                       std::is_same_v<T, std::nullopt_t>>;

// Yields Result<T> when T is an optional type. Note that R is not used.
template <typename T, typename R = void>
using enable_if_optional = enable_if_t<is_optional<T>::value, Result<T>>;

// Deserialize an optional: a scalar of type NA maps to std::nullopt, anything else
// is decoded as the optional's value_type.
template <typename T>
static inline enable_if_optional<T> GenericFromScalar(
    const std::shared_ptr<Scalar>& value) {
  if (value->type->id() != Type::NA) {
    return GenericFromScalar<typename T::value_type>(value);
  }
  return std::nullopt;
}

template <typename T>
static enable_if_same<typename CTypeTraits<T>::ArrowType, ListType, Result<T>>
GenericFromScalar(const std::shared_ptr<Scalar>& value) {
Expand Down
84 changes: 84 additions & 0 deletions cpp/src/arrow/compute/kernels/base_arithmetic_internal.h
Original file line number Diff line number Diff line change
Expand Up @@ -60,6 +60,11 @@ struct Add {
static enable_if_decimal_value<T> Call(KernelContext*, Arg0 left, Arg1 right, Status*) {
return left + right;
}

// Neutral element of addition, converted to T.
template <typename T>
static constexpr T Identity() {
  const T additive_identity = static_cast<T>(0);
  return additive_identity;
}
};

struct AddChecked {
Expand All @@ -85,6 +90,11 @@ struct AddChecked {
static enable_if_decimal_value<T> Call(KernelContext*, Arg0 left, Arg1 right, Status*) {
return left + right;
}

// Neutral element of (checked) addition, converted to T.
template <typename T>
static constexpr T Identity() {
  const T zero = static_cast<T>(0);
  return zero;
}
};

template <int64_t multiple>
Expand Down Expand Up @@ -331,6 +341,11 @@ struct Multiply {
static enable_if_decimal_value<T> Call(KernelContext*, Arg0 left, Arg1 right, Status*) {
return left * right;
}

// Neutral element of multiplication, converted to T.
template <typename T>
static constexpr T Identity() {
  const T multiplicative_identity = static_cast<T>(1);
  return multiplicative_identity;
}
};

struct MultiplyChecked {
Expand All @@ -356,6 +371,11 @@ struct MultiplyChecked {
static enable_if_decimal_value<T> Call(KernelContext*, Arg0 left, Arg1 right, Status*) {
return left * right;
}

// Neutral element of (checked) multiplication, converted to T.
template <typename T>
static constexpr T Identity() {
  const T one = static_cast<T>(1);
  return one;
}
};

struct Divide {
Expand Down Expand Up @@ -605,6 +625,70 @@ struct Sign {
}
};

/// Binary maximum. `Call` combines two values of the same type; `Identity` supplies
/// the neutral starting value for folding with `Call`.
struct Max {
  /// Maximum of two non-floating-point values, as determined by std::max.
  template <typename T, typename Arg0, typename Arg1>
  static constexpr enable_if_not_floating_value<T> Call(KernelContext*, Arg0 arg0,
                                                        Arg1 arg1, Status*) {
    static_assert(std::is_same<T, Arg0>::value && std::is_same<Arg0, Arg1>::value);
    return std::max(arg0, arg1);
  }

  /// Maximum of two floating-point values. A NaN operand is skipped: if exactly one
  /// side is NaN the other side is returned; if `left` is NaN, `right` is returned
  /// unconditionally (so NaN only results when both are NaN).
  template <typename T, typename Arg0, typename Arg1>
  static constexpr enable_if_floating_value<T> Call(KernelContext*, Arg0 left, Arg1 right,
                                                    Status*) {
    static_assert(std::is_same<T, Arg0>::value && std::is_same<Arg0, Arg1>::value);
    if (std::isnan(left)) {
      return right;
    } else if (std::isnan(right)) {
      return left;
    } else {
      return std::max(left, right);
    }
  }

  /// Neutral element of the maximum for T:
  /// - decimal types: T::GetMinSentinel()
  /// - types with an infinity (floating point): negative infinity
  /// - everything else: std::numeric_limits<T>::lowest()
  template <typename T>
  static constexpr T Identity() {
    // The literal 0 is an exact match for the `int` overload, so the decimal version
    // wins whenever it is viable; otherwise the variadic fallback is selected. This
    // avoids two equally-ranked zero-argument overloads for decimal types.
    return IdentityImpl<T>(0);
  }

 private:
  template <typename T>
  static constexpr enable_if_decimal_value<T, T> IdentityImpl(int) {
    return T::GetMinSentinel();
  }

  template <typename T>
  static constexpr T IdentityImpl(...) {
    // numeric_limits<T>::min() must not be used here: for floating-point types it is
    // the smallest *positive* normal value, which is not neutral for negative inputs.
    if constexpr (std::numeric_limits<T>::has_infinity) {
      return -std::numeric_limits<T>::infinity();
    } else {
      return std::numeric_limits<T>::lowest();
    }
  }
};

/// Binary minimum. `Call` combines two values of the same type; `Identity` supplies
/// the neutral starting value for folding with `Call`.
struct Min {
  /// Minimum of two non-floating-point values, as determined by std::min.
  template <typename T, typename Arg0, typename Arg1>
  static constexpr enable_if_not_floating_value<T> Call(KernelContext*, Arg0 arg0,
                                                        Arg1 arg1, Status*) {
    static_assert(std::is_same<T, Arg0>::value && std::is_same<Arg0, Arg1>::value);
    return std::min(arg0, arg1);
  }

  /// Minimum of two floating-point values. A NaN operand is skipped: if exactly one
  /// side is NaN the other side is returned; if `left` is NaN, `right` is returned
  /// unconditionally (so NaN only results when both are NaN).
  template <typename T, typename Arg0, typename Arg1>
  static constexpr enable_if_floating_value<T> Call(KernelContext*, Arg0 left, Arg1 right,
                                                    Status*) {
    static_assert(std::is_same<T, Arg0>::value && std::is_same<Arg0, Arg1>::value);
    if (std::isnan(left)) {
      return right;
    } else if (std::isnan(right)) {
      return left;
    } else {
      return std::min(left, right);
    }
  }

  /// Neutral element of the minimum for T:
  /// - decimal types: T::GetMaxSentinel()
  /// - types with an infinity (floating point): positive infinity
  /// - everything else: std::numeric_limits<T>::max()
  template <typename T>
  static constexpr T Identity() {
    // The literal 0 is an exact match for the `int` overload, so the decimal version
    // wins whenever it is viable; otherwise the variadic fallback is selected. This
    // avoids two equally-ranked zero-argument overloads for decimal types.
    return IdentityImpl<T>(0);
  }

 private:
  template <typename T>
  static constexpr enable_if_decimal_value<T, T> IdentityImpl(int) {
    return T::GetMaxSentinel();
  }

  template <typename T>
  static constexpr T IdentityImpl(...) {
    // For floating-point types the largest finite value is not neutral against
    // +infinity inputs, so prefer infinity where the type has one.
    if constexpr (std::numeric_limits<T>::has_infinity) {
      return std::numeric_limits<T>::infinity();
    } else {
      return std::numeric_limits<T>::max();
    }
  }
};

} // namespace internal
} // namespace compute
} // namespace arrow
3 changes: 3 additions & 0 deletions cpp/src/arrow/compute/kernels/codegen_internal.h
Original file line number Diff line number Diff line change
Expand Up @@ -213,6 +213,9 @@ using enable_if_integer_value =
// Constraint alias keyed on std::is_floating_point<T>: names R (T by default) when T
// is a floating-point C type. Presumably removes the overload otherwise, assuming
// enable_if_t behaves like std::enable_if_t.
template <typename T, typename R = T>
using enable_if_floating_value = enable_if_t<std::is_floating_point<T>::value, R>;

// Complement of the floating-point constraint: names R (T by default) when T is NOT
// a floating-point C type according to std::is_floating_point. Presumably removes
// the overload otherwise, assuming enable_if_t behaves like std::enable_if_t.
template <typename T, typename R = T>
using enable_if_not_floating_value = enable_if_t<!std::is_floating_point<T>::value, R>;

template <typename T, typename R = T>
using enable_if_decimal_value =
enable_if_t<std::is_same<Decimal128, T>::value || std::is_same<Decimal256, T>::value,
Expand Down
Loading

0 comments on commit e2e8e77

Please sign in to comment.