From 6fc5dc37479e5fdecde70fc7654b658e1f4bff31 Mon Sep 17 00:00:00 2001 From: Jin Shang Date: Thu, 22 Jun 2023 13:32:45 +0800 Subject: [PATCH 01/12] add integer round kernels --- cpp/src/arrow/compute/kernels/scalar_round.cc | 744 ++++++++++++------ .../kernels/scalar_round_arithmetic_test.cc | 312 +++++++- docs/source/cpp/compute.rst | 28 +- 3 files changed, 811 insertions(+), 273 deletions(-) diff --git a/cpp/src/arrow/compute/kernels/scalar_round.cc b/cpp/src/arrow/compute/kernels/scalar_round.cc index fc2cb5b8a6ee1..522aff38e2cab 100644 --- a/cpp/src/arrow/compute/kernels/scalar_round.cc +++ b/cpp/src/arrow/compute/kernels/scalar_round.cc @@ -25,6 +25,7 @@ #include "arrow/compare.h" #include "arrow/compute/api_scalar.h" #include "arrow/compute/cast.h" +#include "arrow/compute/kernel.h" #include "arrow/compute/kernels/base_arithmetic_internal.h" #include "arrow/compute/kernels/common_internal.h" #include "arrow/compute/kernels/util_internal.h" @@ -34,6 +35,7 @@ #include "arrow/util/int_util_overflow.h" #include "arrow/util/macros.h" #include "arrow/visit_scalar_inline.h" +#include "arrow/visit_type_inline.h" namespace arrow { @@ -43,8 +45,7 @@ using internal::MultiplyWithOverflow; using internal::NegateWithOverflow; using internal::SubtractWithOverflow; -namespace compute { -namespace internal { +namespace compute::internal { using applicator::ScalarBinary; using applicator::ScalarBinaryEqualTypes; @@ -56,6 +57,9 @@ using applicator::ScalarUnaryNotNullStateful; namespace { +// ---------------------------------------------------------------------- +// Begin utility structs for round kernels + // Convenience visitor to detect if a numeric Scalar is positive. struct IsPositiveVisitor { bool result = false; @@ -82,9 +86,25 @@ bool IsPositive(const Scalar& scalar) { // N.B. take care not to conflict with type_traits.h as that can cause surprises in a // unity build +// A constexpr helper struct to compute powers of 10 at compile time +// Can use a consteval function once we force C++20 +template +struct Pow10Struct { + private: + static constexpr uint64_t half_pow = Pow10Struct::value; + + public: + static constexpr uint64_t value = half_pow * half_pow * (Exp % 2 ? 10 : 1); +}; + +template <> +struct Pow10Struct<0> { + static constexpr uint64_t value = 1; +}; + struct RoundUtil { // Calculate powers of ten with arbitrary integer exponent - template + template static enable_if_floating_value Pow10(int64_t power) { static constexpr T lut[] = {1e0F, 1e1F, 1e2F, 1e3F, 1e4F, 1e5F, 1e6F, 1e7F, 1e8F, 1e9F, 1e10F, 1e11F, 1e12F, 1e13F, 1e14F, 1e15F}; @@ -96,8 +116,30 @@ struct RoundUtil { } return (power >= 0) ? pow10 : (1 / pow10); } + + // Calculate powers of ten with arbitrary integer exponent + template + static enable_if_integer_value Pow10(int64_t power) { + DCHECK(power >= 0); + + static constexpr uint64_t lut[] = { + Pow10Struct<0>::value, Pow10Struct<1>::value, Pow10Struct<2>::value, + Pow10Struct<3>::value, Pow10Struct<4>::value, Pow10Struct<5>::value, + Pow10Struct<6>::value, Pow10Struct<7>::value, Pow10Struct<8>::value, + Pow10Struct<9>::value, Pow10Struct<10>::value, Pow10Struct<11>::value, + Pow10Struct<12>::value, Pow10Struct<13>::value, Pow10Struct<14>::value, + Pow10Struct<15>::value, Pow10Struct<16>::value, Pow10Struct<17>::value, + Pow10Struct<18>::value, Pow10Struct<19>::value}; + + auto digits10 = std::numeric_limits::digits10; + return lut[std::min(power, static_cast(digits10))]; + } }; +// End utility structs for round kernels +// ---------------------------------------------------------------------- +// Begin round implementations for single scalar + // Specializations of rounding implementations for round kernels template struct RoundImpl; @@ -117,6 +159,21 @@ struct RoundImpl { (*val) -= pow10; } } + + template + static enable_if_integer_value Round(const T val, const T floor, const T multiple, + Status* st) { + if constexpr (is_signed_integer_value::value) { + if (ARROW_PREDICT_FALSE(val < 0 && + std::numeric_limits::min() + multiple > floor)) { + *st = Status::Invalid("Rounding ", val, " down to multiple of ", multiple, + " would overflow"); + return val; + } + return val < 0 ? floor - multiple : floor; + } + return floor; + } }; template @@ -134,6 +191,18 @@ struct RoundImpl { (*val) += pow10; } } + + template + static enable_if_integer_value Round(const T val, const T floor, const T multiple, + Status* st) { + if (ARROW_PREDICT_FALSE(val > 0 && + std::numeric_limits::max() - multiple < floor)) { + *st = Status::Invalid("Rounding ", val, " up to multiple of ", multiple, + " would overflow"); + return val; + } + return val > 0 ? floor + multiple : floor; + } }; template @@ -148,6 +217,12 @@ struct RoundImpl { const T& pow10, const int32_t scale) { (*val) -= remainder; } + + template + static enable_if_integer_value Round(const T val, const T floor, const T pow10, + Status* st) { + return floor; + } }; template @@ -167,6 +242,32 @@ struct RoundImpl { (*val) += pow10; } } + + template + static enable_if_integer_value Round(const T val, const T floor, const T multiple, + Status* st) { + if constexpr (is_signed_integer_value::value) { + if (ARROW_PREDICT_FALSE(val < 0 && + std::numeric_limits::min() + multiple > floor)) { + *st = Status::Invalid("Rounding ", val, " down to multiple of ", multiple, + " would overflow"); + return val; + } + } + + if (ARROW_PREDICT_FALSE(val > 0 && + std::numeric_limits::max() - multiple < floor)) { + *st = Status::Invalid("Rounding ", val, " up to multiple of ", multiple, + " would overflow"); + return val; + } + + if constexpr (is_signed_integer_value::value) { + return val < 0 ? floor - multiple : floor + multiple; + } + + return floor + multiple; + } }; // NOTE: RoundImpl variants for the HALF_* rounding modes are only @@ -185,6 +286,12 @@ struct RoundImpl { const T& pow10, const int32_t scale) { RoundImpl::Round(val, remainder, pow10, scale); } + + template + static constexpr enable_if_integer_value Round(const T val, const T floor, + const T multiple, Status* st) { + return RoundImpl::Round(val, floor, multiple, st); + } }; template @@ -199,6 +306,12 @@ struct RoundImpl { const T& pow10, const int32_t scale) { RoundImpl::Round(val, remainder, pow10, scale); } + + template + static constexpr enable_if_integer_value Round(const T val, const T floor, + const T multiple, Status* st) { + return RoundImpl::Round(val, floor, multiple, st); + } }; template @@ -213,6 +326,12 @@ struct RoundImpl { const T& pow10, const int32_t scale) { RoundImpl::Round(val, remainder, pow10, scale); } + + template + static constexpr enable_if_integer_value Round(const T val, const T floor, + const T multiple, Status* st) { + return RoundImpl::Round(val, floor, multiple, st); + } }; template @@ -224,8 +343,14 @@ struct RoundImpl { template static enable_if_decimal_value Round(T* val, const T& remainder, - const T& pow10, const int32_t scale) { - RoundImpl::Round(val, remainder, pow10, scale); + const T& multiple, const int32_t scale) { + RoundImpl::Round(val, remainder, multiple, scale); + } + + template + static constexpr enable_if_integer_value Round(const T val, const T floor, + const T multiple, Status* st) { + return RoundImpl::Round(val, floor, multiple, st); } }; @@ -245,6 +370,15 @@ struct RoundImpl { } *val = scaled.IncreaseScaleBy(scale); } + + template + static constexpr enable_if_integer_value Round(const T val, const T floor, + const T multiple, Status* st) { + if ((floor / multiple) % 2 == 0) { + return floor; + } + return RoundImpl::Round(val, floor, multiple, st); + } }; template @@ -263,23 +397,37 @@ struct RoundImpl { } *val = scaled.IncreaseScaleBy(scale); } + + template + static constexpr enable_if_integer_value Round(const T val, const T floor, + const T multiple, Status* st) { + if ((floor / multiple) % 2 == 1) { + return floor; + } + return RoundImpl::Round(val, floor, multiple, st); + } }; +// End round implementations for single scalar +// ---------------------------------------------------------------------- +// Begin round options wrappers + // Specializations of kernel state for round kernels -template +// CType is the physical type used to store pow10 +template struct RoundOptionsWrapper; -template <> -struct RoundOptionsWrapper : public OptionsWrapper { +template +struct RoundOptionsWrapper : public OptionsWrapper { using OptionsType = RoundOptions; - double pow10; + CType pow10; explicit RoundOptionsWrapper(OptionsType options) : OptionsWrapper(std::move(options)) { // Only positive exponents for powers of 10 are used because combining // multiply and division operations produced more stable rounding than // using multiply-only. Refer to NumPy's round implementation: // https://github.com/numpy/numpy/blob/7b2f20b406d27364c812f7a81a9c901afbd3600c/numpy/core/src/multiarray/calculation.c#L589 - pow10 = RoundUtil::Pow10(std::abs(options.ndigits)); + pow10 = RoundUtil::Pow10(std::abs(options.ndigits)); } static Result> Init(KernelContext* ctx, @@ -292,8 +440,8 @@ struct RoundOptionsWrapper : public OptionsWrapper { } }; -template <> -struct RoundOptionsWrapper +template +struct RoundOptionsWrapper : public OptionsWrapper { using OptionsType = RoundBinaryOptions; @@ -310,8 +458,8 @@ struct RoundOptionsWrapper } }; -template <> -struct RoundOptionsWrapper +template +struct RoundOptionsWrapper : public OptionsWrapper { using OptionsType = RoundToMultipleOptions; using OptionsWrapper::OptionsWrapper; @@ -333,14 +481,8 @@ struct RoundOptionsWrapper return Status::Invalid("Rounding multiple must be positive"); } - // Ensure the rounding multiple option matches the kernel's output type. - // The output type is not available here so we use the following rule: - // If `multiple` is neither a floating-point nor a decimal type, then - // cast to float64, else cast to the kernel's input type. - std::shared_ptr to_type = - (!is_floating(multiple->type->id()) && !is_decimal(multiple->type->id())) - ? float64() - : args.inputs[0].GetSharedPtr(); + // Ensure the rounding multiple option matches the kernel's input type. + std::shared_ptr to_type = args.inputs[0].GetSharedPtr(); if (!multiple->type->Equals(to_type)) { ARROW_ASSIGN_OR_RAISE( auto casted_multiple, @@ -355,11 +497,231 @@ struct RoundOptionsWrapper } }; +template +struct RoundOptionsTrait; + +template +struct RoundOptionsTrait> { + using CType = double; +}; + +template +struct RoundOptionsTrait> { + using CType = double; +}; + +template +struct RoundOptionsTrait> { + using CType = typename ArrowType::c_type; +}; + +// End round options wrappers +// ---------------------------------------------------------------------- +// Begin round op implementations + +template +struct RoundToMultiple { + using CType = typename TypeTraits::CType; + using State = RoundOptionsWrapper::CType>; + + CType multiple; + + explicit RoundToMultiple(const State& state, const DataType& out_ty) + : multiple(UnboxScalar::Unbox(*state.options.multiple)) { + const auto& options = state.options; + DCHECK(options.multiple); + DCHECK(options.multiple->is_valid); + DCHECK(is_floating(options.multiple->type->id())); + } + + template ::CType> + enable_if_floating_value Call(KernelContext* ctx, CType arg, Status* st) const { + // Do not process Inf or NaN because they will trigger the overflow error at end of + // function. + if (!std::isfinite(arg)) { + return arg; + } + auto round_val = arg / multiple; + auto frac = round_val - std::floor(round_val); + if (frac != T(0)) { + // Use std::round() if in tie-breaking mode and scaled value is not 0.5. + if ((kRoundMode >= RoundMode::HALF_DOWN) && (frac != T(0.5))) { + round_val = std::round(round_val); + } else { + round_val = RoundImpl::Round(round_val); + } + round_val *= multiple; + if (!std::isfinite(round_val)) { + *st = Status::Invalid("overflow occurred during rounding"); + return arg; + } + } else { + // If scaled value is an integer, then no rounding is needed. + round_val = arg; + } + return round_val; + } +}; + +template +struct RoundToMultiple> { + using CType = typename TypeTraits::CType; + using State = RoundOptionsWrapper; + const ArrowType& ty; + CType multiple, half_multiple, neg_half_multiple; + bool has_halfway_point; + + explicit RoundToMultiple(const State& state, const DataType& out_ty) + : ty(checked_cast(out_ty)), + multiple(UnboxScalar::Unbox(*state.options.multiple)), + half_multiple(multiple / 2), + neg_half_multiple(-half_multiple), + has_halfway_point(multiple.low_bits() % 2 == 0) { + const auto& options = state.options; + DCHECK(options.multiple); + DCHECK(options.multiple->is_valid); + DCHECK(options.multiple->type->Equals(out_ty)); + } + + template ::CType> + enable_if_decimal_value Call(KernelContext* ctx, CType arg, Status* st) const { + std::pair pair; + *st = arg.Divide(multiple).Value(&pair); + if (!st->ok()) return arg; + const auto& remainder = pair.second; + if (remainder == 0) return arg; + if (kRoundMode >= RoundMode::HALF_DOWN) { + if (has_halfway_point && + (remainder == half_multiple || remainder == neg_half_multiple)) { + // On the halfway point, use tiebreaker + // Manually implement rounding since we're not actually rounding a + // decimal value, but rather manipulating the multiple + switch (kRoundMode) { + case RoundMode::HALF_DOWN: + if (remainder.Sign() < 0) pair.first -= 1; + break; + case RoundMode::HALF_UP: + if (remainder.Sign() >= 0) pair.first += 1; + break; + case RoundMode::HALF_TOWARDS_ZERO: + // Do nothing + break; + case RoundMode::HALF_TOWARDS_INFINITY: + pair.first += remainder.Sign() >= 0 ? 1 : -1; + break; + case RoundMode::HALF_TO_EVEN: + if (pair.first.low_bits() % 2 != 0) { + pair.first += remainder.Sign() >= 0 ? 1 : -1; + } + break; + case RoundMode::HALF_TO_ODD: + if (pair.first.low_bits() % 2 == 0) { + pair.first += remainder.Sign() >= 0 ? 1 : -1; + } + break; + default: + DCHECK(false); + } + } else if (remainder.Sign() >= 0) { + // Positive, round up/down + if (remainder > half_multiple) { + pair.first += 1; + } + } else { + // Negative, round up/down + if (remainder < neg_half_multiple) { + pair.first -= 1; + } + } + } else { + // Manually implement rounding since we're not actually rounding a + // decimal value, but rather manipulating the multiple + switch (kRoundMode) { + case RoundMode::DOWN: + if (remainder.Sign() < 0) pair.first -= 1; + break; + case RoundMode::UP: + if (remainder.Sign() >= 0) pair.first += 1; + break; + case RoundMode::TOWARDS_ZERO: + // Do nothing + break; + case RoundMode::TOWARDS_INFINITY: + pair.first += remainder.Sign() >= 0 ? 1 : -1; + break; + default: + DCHECK(false); + } + } + CType round_val = pair.first * multiple; + if (!round_val.FitsInPrecision(ty.precision())) { + *st = Status::Invalid("Rounded value ", round_val.ToString(ty.scale()), + " does not fit in precision of ", ty); + return 0; + } + return round_val; + } +}; + +template +struct RoundToMultiple> { + using CType = typename TypeTraits::CType; + using State = RoundOptionsWrapper; + CType multiple; + + explicit RoundToMultiple(const State& state, const DataType& out_ty) + : multiple(UnboxScalar::Unbox(*state.options.multiple)) { + const auto& options = state.options; + DCHECK(options.multiple); + DCHECK(options.multiple->is_valid); + DCHECK(is_integer(options.multiple->type->id())); + } + + explicit RoundToMultiple(const CType multiple, const DataType& out_ty) + : multiple(multiple) {} + + template ::CType> + enable_if_integer_value Call(KernelContext* ctx, CType arg, Status* st) const { + CType floor = arg / multiple * multiple; + CType remainder = arg > floor ? arg - floor : floor - arg; + + if (remainder == 0) { + return arg; + } + + if (kRoundMode >= RoundMode::HALF_DOWN && remainder * 2 != multiple) { + // not half way, round to nearest multiple of multiple like std::round + if (remainder * 2 > multiple) { + if (arg >= 0) { + if (ARROW_PREDICT_FALSE(std::numeric_limits::max() - multiple < floor)) { + *st = Status::Invalid("Rounding ", arg, " up to multiples of ", multiple, + " would overflow"); + return arg; + } + return floor + multiple; + } else { + if (ARROW_PREDICT_FALSE(std::numeric_limits::min() + multiple > floor)) { + *st = Status::Invalid("Rounding ", arg, " down to multiples of ", multiple, + " would overflow"); + return arg; + } + return floor - multiple; + } + } else { + return floor; + } + } else { + return RoundImpl::Round(arg, floor, multiple, st); + } + } +}; + template struct Round { using CType = typename TypeTraits::CType; - using State = RoundOptionsWrapper; - + using State = + RoundOptionsWrapper::CType>; CType pow10; int64_t ndigits; @@ -400,8 +762,7 @@ struct Round { template struct Round> { using CType = typename TypeTraits::CType; - using State = RoundOptionsWrapper; - + using State = RoundOptionsWrapper; const ArrowType& ty; int64_t ndigits; int32_t pow; @@ -470,10 +831,37 @@ struct Round> { } }; +template +struct Round> { + using CType = typename TypeTraits::CType; + using State = RoundOptionsWrapper; + CType pow10; + int64_t ndigits; + const DataType& out_ty; + + explicit Round(const State& state, const DataType& out_ty) + : pow10(static_cast(state.pow10)), + ndigits(state.options.ndigits), + out_ty(out_ty) {} + + template ::CType> + enable_if_integer_value Call(KernelContext* ctx, CType arg, Status* st) const { + // no-op if ndigits is non-negative + if (ndigits >= 0) { + return arg; + } + + // If ndigits is negative, then round to the nearest multiple of 10^ndigits. + RoundToMultiple round_to_multiple(pow10, out_ty); + return round_to_multiple.Call(ctx, arg, st); + } +}; + template struct RoundBinary { using CType = typename TypeTraits::CType; - using State = RoundOptionsWrapper; + using State = RoundOptionsWrapper::CType>; explicit RoundBinary(const State& state, const DataType& out_ty) {} @@ -491,7 +879,7 @@ struct RoundBinary { // multiply and division operations produced more stable rounding than // using multiply-only. Refer to NumPy's round implementation: // https://github.com/numpy/numpy/blob/7b2f20b406d27364c812f7a81a9c901afbd3600c/numpy/core/src/multiarray/calculation.c#L589 - double pow10 = RoundUtil::Pow10(std::abs(arg1)); + double pow10 = RoundUtil::Pow10(std::abs(arg1)); auto round_val = arg1 >= 0 ? (arg0 * pow10) : (arg0 / pow10); auto frac = round_val - std::floor(round_val); @@ -520,8 +908,7 @@ struct RoundBinary { template struct RoundBinary> { using CType = typename TypeTraits::CType; - using State = RoundOptionsWrapper; - + using State = RoundOptionsWrapper; const ArrowType& ty; int32_t pow; // pow10 is "1" for the given decimal scale. Similarly half_pow10 is "0.5". @@ -591,156 +978,27 @@ struct RoundBinary> { } }; -template -Status FixedRoundDecimalExec(KernelContext* ctx, const ExecSpan& batch, ExecResult* out) { - using Op = Round; - return ScalarUnaryNotNullStateful( - Op(kDigits, *out->type())) - .Exec(ctx, batch, out); -} - -template -struct RoundToMultiple { - using CType = typename TypeTraits::CType; - using State = RoundOptionsWrapper; - - CType multiple; - - explicit RoundToMultiple(const State& state, const DataType& out_ty) - : multiple(UnboxScalar::Unbox(*state.options.multiple)) { - const auto& options = state.options; - DCHECK(options.multiple); - DCHECK(options.multiple->is_valid); - DCHECK(is_floating(options.multiple->type->id())); - } - - template ::CType> - enable_if_floating_value Call(KernelContext* ctx, CType arg, Status* st) const { - // Do not process Inf or NaN because they will trigger the overflow error at end of - // function. - if (!std::isfinite(arg)) { - return arg; - } - auto round_val = arg / multiple; - auto frac = round_val - std::floor(round_val); - if (frac != T(0)) { - // Use std::round() if in tie-breaking mode and scaled value is not 0.5. - if ((kRoundMode >= RoundMode::HALF_DOWN) && (frac != T(0.5))) { - round_val = std::round(round_val); - } else { - round_val = RoundImpl::Round(round_val); - } - round_val *= multiple; - if (!std::isfinite(round_val)) { - *st = Status::Invalid("overflow occurred during rounding"); - return arg; - } - } else { - // If scaled value is an integer, then no rounding is needed. - round_val = arg; - } - return round_val; - } -}; - template -struct RoundToMultiple> { +struct RoundBinary> { using CType = typename TypeTraits::CType; - using State = RoundOptionsWrapper; + using State = RoundOptionsWrapper; - const ArrowType& ty; - CType multiple, half_multiple, neg_half_multiple; - bool has_halfway_point; - - explicit RoundToMultiple(const State& state, const DataType& out_ty) - : ty(checked_cast(out_ty)), - multiple(UnboxScalar::Unbox(*state.options.multiple)), - half_multiple(multiple / 2), - neg_half_multiple(-half_multiple), - has_halfway_point(multiple.low_bits() % 2 == 0) { - const auto& options = state.options; - DCHECK(options.multiple); - DCHECK(options.multiple->is_valid); - DCHECK(options.multiple->type->Equals(out_ty)); - } + const DataType& out_ty; + explicit RoundBinary(const State& state, const DataType& out_ty) : out_ty(out_ty) {} - template ::CType> - enable_if_decimal_value Call(KernelContext* ctx, CType arg, Status* st) const { - std::pair pair; - *st = arg.Divide(multiple).Value(&pair); - if (!st->ok()) return arg; - const auto& remainder = pair.second; - if (remainder == 0) return arg; - if (kRoundMode >= RoundMode::HALF_DOWN) { - if (has_halfway_point && - (remainder == half_multiple || remainder == neg_half_multiple)) { - // On the halfway point, use tiebreaker - // Manually implement rounding since we're not actually rounding a - // decimal value, but rather manipulating the multiple - switch (kRoundMode) { - case RoundMode::HALF_DOWN: - if (remainder.Sign() < 0) pair.first -= 1; - break; - case RoundMode::HALF_UP: - if (remainder.Sign() >= 0) pair.first += 1; - break; - case RoundMode::HALF_TOWARDS_ZERO: - // Do nothing - break; - case RoundMode::HALF_TOWARDS_INFINITY: - pair.first += remainder.Sign() >= 0 ? 1 : -1; - break; - case RoundMode::HALF_TO_EVEN: - if (pair.first.low_bits() % 2 != 0) { - pair.first += remainder.Sign() >= 0 ? 1 : -1; - } - break; - case RoundMode::HALF_TO_ODD: - if (pair.first.low_bits() % 2 == 0) { - pair.first += remainder.Sign() >= 0 ? 1 : -1; - } - break; - default: - DCHECK(false); - } - } else if (remainder.Sign() >= 0) { - // Positive, round up/down - if (remainder > half_multiple) { - pair.first += 1; - } - } else { - // Negative, round up/down - if (remainder < neg_half_multiple) { - pair.first -= 1; - } - } - } else { - // Manually implement rounding since we're not actually rounding a - // decimal value, but rather manipulating the multiple - switch (kRoundMode) { - case RoundMode::DOWN: - if (remainder.Sign() < 0) pair.first -= 1; - break; - case RoundMode::UP: - if (remainder.Sign() >= 0) pair.first += 1; - break; - case RoundMode::TOWARDS_ZERO: - // Do nothing - break; - case RoundMode::TOWARDS_INFINITY: - pair.first += remainder.Sign() >= 0 ? 1 : -1; - break; - default: - DCHECK(false); - } - } - CType round_val = pair.first * multiple; - if (!round_val.FitsInPrecision(ty.precision())) { - *st = Status::Invalid("Rounded value ", round_val.ToString(ty.scale()), - " does not fit in precision of ", ty); - return 0; + template ::CType0, + typename CType1 = typename TypeTraits::CType1> + enable_if_integer_value Call(KernelContext* ctx, CType0 arg0, CType1 arg1, + Status* st) const { + // ndigits >= 0 is a no-op + if (arg1 >= 0) { + return arg0; } - return round_val; + + // If ndigits is negative, then round to the nearest multiple of 10^ndigits. + CType pow10 = RoundUtil::Pow10(std::abs(arg1)); + RoundToMultiple round_to_multiple(pow10, out_ty); + return round_to_multiple.Call(ctx, arg0, st); } }; @@ -748,7 +1006,7 @@ struct Floor { template static constexpr enable_if_floating_value Call(KernelContext*, Arg arg, Status*) { - static_assert(std::is_same::value, ""); + static_assert(std::is_same::value); return RoundImpl::Round(arg); } }; @@ -757,7 +1015,7 @@ struct Ceil { template static constexpr enable_if_floating_value Call(KernelContext*, Arg arg, Status*) { - static_assert(std::is_same::value, ""); + static_assert(std::is_same::value); return RoundImpl::Round(arg); } }; @@ -766,11 +1024,15 @@ struct Trunc { template static constexpr enable_if_floating_value Call(KernelContext*, Arg arg, Status*) { - static_assert(std::is_same::value, ""); + static_assert(std::is_same::value); return RoundImpl::Round(arg); } }; +// End round op implementations +// ---------------------------------------------------------------------- +// Begin round functions + struct RoundFunction : ScalarFunction { using ScalarFunction::ScalarFunction; @@ -782,6 +1044,10 @@ struct RoundFunction : ScalarFunction { EnsureDictionaryDecoded(types); + // for binary round functions, the second scalar must be int32 + if (types->size() == 2 && (*types)[1].id() != Type::INT32) { + (*types)[1] = int32(); + } if (auto kernel = DispatchExactImpl(this, *types)) return kernel; return arrow::compute::detail::NoMatchingKernel(this, *types); } @@ -862,6 +1128,10 @@ struct RoundFloatingPointFunction : public RoundFunction { } }; +// End round functions +// ---------------------------------------------------------------------- +// Begin round kernels + #define ROUND_CASE(MODE) \ case RoundMode::MODE: { \ using Op = OpImpl; \ @@ -874,7 +1144,8 @@ template class OpImpl> struct RoundKernel { static Status Exec(KernelContext* ctx, const ExecSpan& batch, ExecResult* out) { - using State = RoundOptionsWrapper; + using State = + RoundOptionsWrapper::CType>; const auto& state = static_cast(*ctx->state()); switch (state.options.round_mode) { ROUND_CASE(DOWN) @@ -909,7 +1180,8 @@ template class OpImpl> struct RoundBinaryKernel { static Status Exec(KernelContext* ctx, const ExecSpan& batch, ExecResult* out) { - using State = RoundOptionsWrapper; + using State = + RoundOptionsWrapper::CType>; const auto& state = static_cast(*ctx->state()); switch (state.options.round_mode) { ROUND_BINARY_CASE(DOWN) @@ -931,38 +1203,59 @@ struct RoundBinaryKernel { }; #undef ROUND_BINARY_CASE +template +Status FixedRoundDecimalExec(KernelContext* ctx, const ExecSpan& batch, ExecResult* out) { + using Op = Round; + return ScalarUnaryNotNullStateful( + Op(kDigits, *out->type())) + .Exec(ctx, batch, out); +} + +// End round kernels +// ---------------------------------------------------------------------- +// Begin round kernel generation and function registration + +template < + template class Op, + template typename> + class Kernel, + typename OptionsType> +struct RoundKernelGenerator { + template + Status Visit(const ArrowType& type, ArrayKernelExec* exec, KernelInit* init) { + if constexpr (is_integer_type::value || + (is_floating_type::value && + !is_half_float_type::value) || + is_decimal_type::value) { + *exec = Kernel::Exec; + *init = RoundOptionsWrapper::CType>::Init; + } else { + DCHECK(false); + return Status::NotImplemented("Round does not support ", type.ToString()); + } + return Status::OK(); + } +}; + // For unary rounding functions that control kernel dispatch based on RoundMode, only on // non-null output. template