From 01f3338f7cf8dccf38602842a9ade3b0f840cc10 Mon Sep 17 00:00:00 2001 From: David Li Date: Wed, 30 Jun 2021 18:30:37 +0200 Subject: [PATCH] ARROW-13095: [C++] Implement trig compute functions Adds sin/cos/tan and their inverses. Checked variants check for what would be domain errors (this does not apply to atan/atan2). Closes #10544 from lidavidm/arrow-13095 Authored-by: David Li Signed-off-by: Antoine Pitrou --- cpp/src/arrow/compute/api_scalar.cc | 7 + cpp/src/arrow/compute/api_scalar.h | 61 +++ .../arrow/compute/kernels/codegen_internal.h | 14 +- .../compute/kernels/scalar_arithmetic.cc | 360 ++++++++++++++++-- .../compute/kernels/scalar_arithmetic_test.cc | 106 ++++++ .../arrow/compute/kernels/scalar_compare.cc | 45 ++- cpp/src/arrow/compute/kernels/util_internal.h | 12 + docs/source/cpp/compute.rst | 34 ++ docs/source/python/api/compute.rst | 22 ++ 9 files changed, 607 insertions(+), 54 deletions(-) diff --git a/cpp/src/arrow/compute/api_scalar.cc b/cpp/src/arrow/compute/api_scalar.cc index f005e70e3480c..20bba982a74fb 100644 --- a/cpp/src/arrow/compute/api_scalar.cc +++ b/cpp/src/arrow/compute/api_scalar.cc @@ -49,6 +49,12 @@ namespace compute { SCALAR_ARITHMETIC_UNARY(AbsoluteValue, "abs", "abs_checked") SCALAR_ARITHMETIC_UNARY(Negate, "negate", "negate_checked") +SCALAR_ARITHMETIC_UNARY(Sin, "sin", "sin_checked") +SCALAR_ARITHMETIC_UNARY(Cos, "cos", "cos_checked") +SCALAR_ARITHMETIC_UNARY(Asin, "asin", "asin_checked") +SCALAR_ARITHMETIC_UNARY(Acos, "acos", "acos_checked") +SCALAR_ARITHMETIC_UNARY(Tan, "tan", "tan_checked") +SCALAR_EAGER_UNARY(Atan, "atan") #define SCALAR_ARITHMETIC_BINARY(NAME, REGISTRY_NAME, REGISTRY_CHECKED_NAME) \ Result NAME(const Datum& left, const Datum& right, ArithmeticOptions options, \ @@ -64,6 +70,7 @@ SCALAR_ARITHMETIC_BINARY(Divide, "divide", "divide_checked") SCALAR_ARITHMETIC_BINARY(Power, "power", "power_checked") SCALAR_ARITHMETIC_BINARY(ShiftLeft, "shift_left", "shift_left_checked") SCALAR_ARITHMETIC_BINARY(ShiftRight, "shift_right", "shift_right_checked") +SCALAR_EAGER_BINARY(Atan2, "atan2") Result MaxElementWise(const std::vector& args, ElementWiseAggregateOptions options, ExecContext* ctx) { diff --git a/cpp/src/arrow/compute/api_scalar.h b/cpp/src/arrow/compute/api_scalar.h index b101325740194..2ec9c1d765307 100644 --- a/cpp/src/arrow/compute/api_scalar.h +++ b/cpp/src/arrow/compute/api_scalar.h @@ -335,6 +335,67 @@ Result ShiftRight(const Datum& left, const Datum& right, ArithmeticOptions options = ArithmeticOptions(), ExecContext* ctx = NULLPTR); +/// \brief Compute the sine of the array values. +/// \param[in] arg The values to compute the sine for. +/// \param[in] options arithmetic options (enable/disable overflow checking), optional +/// \param[in] ctx the function execution context, optional +/// \return the elementwise sine of the values +ARROW_EXPORT +Result Sin(const Datum& arg, ArithmeticOptions options = ArithmeticOptions(), + ExecContext* ctx = NULLPTR); + +/// \brief Compute the cosine of the array values. +/// \param[in] arg The values to compute the cosine for. +/// \param[in] options arithmetic options (enable/disable overflow checking), optional +/// \param[in] ctx the function execution context, optional +/// \return the elementwise cosine of the values +ARROW_EXPORT +Result Cos(const Datum& arg, ArithmeticOptions options = ArithmeticOptions(), + ExecContext* ctx = NULLPTR); + +/// \brief Compute the inverse sine (arcsine) of the array values. +/// \param[in] arg The values to compute the inverse sine for. +/// \param[in] options arithmetic options (enable/disable overflow checking), optional +/// \param[in] ctx the function execution context, optional +/// \return the elementwise inverse sine of the values +ARROW_EXPORT +Result Asin(const Datum& arg, ArithmeticOptions options = ArithmeticOptions(), + ExecContext* ctx = NULLPTR); + +/// \brief Compute the inverse cosine (arccosine) of the array values. +/// \param[in] arg The values to compute the inverse cosine for. +/// \param[in] options arithmetic options (enable/disable overflow checking), optional +/// \param[in] ctx the function execution context, optional +/// \return the elementwise inverse cosine of the values +ARROW_EXPORT +Result Acos(const Datum& arg, ArithmeticOptions options = ArithmeticOptions(), + ExecContext* ctx = NULLPTR); + +/// \brief Compute the tangent of the array values. +/// \param[in] arg The values to compute the tangent for. +/// \param[in] options arithmetic options (enable/disable overflow checking), optional +/// \param[in] ctx the function execution context, optional +/// \return the elementwise tangent of the values +ARROW_EXPORT +Result Tan(const Datum& arg, ArithmeticOptions options = ArithmeticOptions(), + ExecContext* ctx = NULLPTR); + +/// \brief Compute the inverse tangent (arctangent) of the array values. +/// \param[in] arg The values to compute the inverse tangent for. +/// \param[in] ctx the function execution context, optional +/// \return the elementwise inverse tangent of the values +ARROW_EXPORT +Result Atan(const Datum& arg, ExecContext* ctx = NULLPTR); + +/// \brief Compute the inverse tangent (arctangent) of y/x, using the +/// argument signs to determine the correct quadrant. +/// \param[in] y The y-values to compute the inverse tangent for. +/// \param[in] x The x-values to compute the inverse tangent for. +/// \param[in] ctx the function execution context, optional +/// \return the elementwise inverse tangent of the values +ARROW_EXPORT +Result Atan2(const Datum& y, const Datum& x, ExecContext* ctx = NULLPTR); + /// \brief Find the element-wise maximum of any number of arrays or scalars. /// Array values must be the same length. /// diff --git a/cpp/src/arrow/compute/kernels/codegen_internal.h b/cpp/src/arrow/compute/kernels/codegen_internal.h index 140f9fdc6695e..a68bb970b4a5f 100644 --- a/cpp/src/arrow/compute/kernels/codegen_internal.h +++ b/cpp/src/arrow/compute/kernels/codegen_internal.h @@ -826,7 +826,8 @@ struct ScalarBinary { ArrayIterator arg0_it(arg0); ArrayIterator arg1_it(arg1); RETURN_NOT_OK(OutputAdapter::Write(ctx, out, [&]() -> OutValue { - return Op::template Call(ctx, arg0_it(), arg1_it(), &st); + return Op::template Call(ctx, arg0_it(), arg1_it(), + &st); })); return st; } @@ -837,7 +838,8 @@ struct ScalarBinary { ArrayIterator arg0_it(arg0); auto arg1_val = UnboxScalar::Unbox(arg1); RETURN_NOT_OK(OutputAdapter::Write(ctx, out, [&]() -> OutValue { - return Op::template Call(ctx, arg0_it(), arg1_val, &st); + return Op::template Call(ctx, arg0_it(), arg1_val, + &st); })); return st; } @@ -848,7 +850,8 @@ struct ScalarBinary { auto arg0_val = UnboxScalar::Unbox(arg0); ArrayIterator arg1_it(arg1); RETURN_NOT_OK(OutputAdapter::Write(ctx, out, [&]() -> OutValue { - return Op::template Call(ctx, arg0_val, arg1_it(), &st); + return Op::template Call(ctx, arg0_val, arg1_it(), + &st); })); return st; } @@ -859,8 +862,9 @@ struct ScalarBinary { if (out->scalar()->is_valid) { auto arg0_val = UnboxScalar::Unbox(arg0); auto arg1_val = UnboxScalar::Unbox(arg1); - BoxScalar::Box(Op::template Call(ctx, arg0_val, arg1_val, &st), - out->scalar().get()); + BoxScalar::Box( + Op::template Call(ctx, arg0_val, arg1_val, &st), + out->scalar().get()); } return st; } diff --git a/cpp/src/arrow/compute/kernels/scalar_arithmetic.cc b/cpp/src/arrow/compute/kernels/scalar_arithmetic.cc index ef9ef78054a61..da3a3095041b0 100644 --- a/cpp/src/arrow/compute/kernels/scalar_arithmetic.cc +++ b/cpp/src/arrow/compute/kernels/scalar_arithmetic.cc @@ -21,6 +21,7 @@ #include #include "arrow/compute/kernels/common.h" +#include "arrow/compute/kernels/util_internal.h" #include "arrow/type_traits.h" #include "arrow/util/decimal.h" #include "arrow/util/int_util_internal.h" @@ -58,12 +59,12 @@ using enable_if_signed_integer = enable_if_t::value, T>; template using enable_if_unsigned_integer = enable_if_t::value, T>; -template +template using enable_if_integer = - enable_if_t::value || is_unsigned_integer::value, T>; + enable_if_t::value || is_unsigned_integer::value, R>; -template -using enable_if_floating_point = enable_if_t::value, T>; +template +using enable_if_floating_point = enable_if_t::value, R>; template using enable_if_decimal = @@ -117,20 +118,20 @@ struct AbsoluteValueChecked { }; struct Add { - template - static constexpr enable_if_floating_point Call(KernelContext*, T left, T right, + template + static constexpr enable_if_floating_point Call(KernelContext*, Arg0 left, Arg1 right, Status*) { return left + right; } - template - static constexpr enable_if_unsigned_integer Call(KernelContext*, T left, T right, - Status*) { + template + static constexpr enable_if_unsigned_integer Call(KernelContext*, Arg0 left, + Arg1 right, Status*) { return left + right; } - template - static constexpr enable_if_signed_integer Call(KernelContext*, T left, T right, + template + static constexpr enable_if_signed_integer Call(KernelContext*, Arg0 left, Arg1 right, Status*) { return arrow::internal::SafeSignedAdd(left, right); } @@ -166,21 +167,24 @@ struct AddChecked { }; struct Subtract { - template - static constexpr enable_if_floating_point Call(KernelContext*, T left, T right, + template + static constexpr enable_if_floating_point Call(KernelContext*, Arg0 left, Arg1 right, Status*) { + static_assert(std::is_same::value && std::is_same::value, ""); return left - right; } - template - static constexpr enable_if_unsigned_integer Call(KernelContext*, T left, T right, - Status*) { + template + static constexpr enable_if_unsigned_integer Call(KernelContext*, Arg0 left, + Arg1 right, Status*) { + static_assert(std::is_same::value && std::is_same::value, ""); return left - right; } - template - static constexpr enable_if_signed_integer Call(KernelContext*, T left, T right, + template + static constexpr enable_if_signed_integer Call(KernelContext*, Arg0 left, Arg1 right, Status*) { + static_assert(std::is_same::value && std::is_same::value, ""); return arrow::internal::SafeSignedSubtract(left, right); } @@ -224,21 +228,23 @@ struct Multiply { static_assert(std::is_same::value, ""); static_assert(std::is_same::value, ""); - template + template static constexpr enable_if_floating_point Call(KernelContext*, T left, T right, Status*) { return left * right; } - template - static constexpr enable_if_unsigned_integer Call(KernelContext*, T left, T right, - Status*) { + template + static constexpr enable_if_t< + is_unsigned_integer::value && !std::is_same::value, T> + Call(KernelContext*, T left, T right, Status*) { return left * right; } - template - static constexpr enable_if_signed_integer Call(KernelContext*, T left, T right, - Status*) { + template + static constexpr enable_if_t< + is_signed_integer::value && !std::is_same::value, T> + Call(KernelContext*, T left, T right, Status*) { return to_unsigned(left) * to_unsigned(right); } @@ -246,12 +252,14 @@ struct Multiply { // integer. However, some inputs may nevertheless overflow (which triggers undefined // behaviour). Therefore we first cast to 32 bit unsigned integers where overflow is // well defined. - template - static constexpr int16_t Call(KernelContext*, int16_t left, int16_t right, Status*) { + template + static constexpr enable_if_same Call(KernelContext*, int16_t left, + int16_t right, Status*) { return static_cast(left) * static_cast(right); } - template - static constexpr uint16_t Call(KernelContext*, uint16_t left, uint16_t right, Status*) { + template + static constexpr enable_if_same Call(KernelContext*, uint16_t left, + uint16_t right, Status*) { return static_cast(left) * static_cast(right); } @@ -405,7 +413,7 @@ struct Power { return pow; } - template + template static enable_if_integer Call(KernelContext*, T base, T exp, Status* st) { if (exp < 0) { *st = Status::Invalid("integers to negative integer powers are not allowed"); @@ -414,7 +422,7 @@ struct Power { return static_cast(IntegerPower(base, exp)); } - template + template static enable_if_floating_point Call(KernelContext*, T base, T exp, Status*) { return std::pow(base, exp); } @@ -554,6 +562,130 @@ struct ShiftRightChecked { } }; +struct Sin { + template + static enable_if_floating_point Call(KernelContext*, Arg0 val, Status*) { + static_assert(std::is_same::value, ""); + return std::sin(val); + } +}; + +struct SinChecked { + template + static enable_if_floating_point Call(KernelContext*, Arg0 val, Status* st) { + static_assert(std::is_same::value, ""); + if (ARROW_PREDICT_FALSE(std::isinf(val))) { + *st = Status::Invalid("domain error"); + return val; + } + return std::sin(val); + } +}; + +struct Cos { + template + static enable_if_floating_point Call(KernelContext*, Arg0 val, Status*) { + static_assert(std::is_same::value, ""); + return std::cos(val); + } +}; + +struct CosChecked { + template + static enable_if_floating_point Call(KernelContext*, Arg0 val, Status* st) { + static_assert(std::is_same::value, ""); + if (ARROW_PREDICT_FALSE(std::isinf(val))) { + *st = Status::Invalid("domain error"); + return val; + } + return std::cos(val); + } +}; + +struct Tan { + template + static enable_if_floating_point Call(KernelContext*, Arg0 val, Status*) { + static_assert(std::is_same::value, ""); + return std::tan(val); + } +}; + +struct TanChecked { + template + static enable_if_floating_point Call(KernelContext*, Arg0 val, Status* st) { + static_assert(std::is_same::value, ""); + if (ARROW_PREDICT_FALSE(std::isinf(val))) { + *st = Status::Invalid("domain error"); + return val; + } + // Cannot raise range errors (overflow) since PI/2 is not exactly representable + return std::tan(val); + } +}; + +struct Asin { + template + static enable_if_floating_point Call(KernelContext*, Arg0 val, Status*) { + static_assert(std::is_same::value, ""); + if (ARROW_PREDICT_FALSE(val < -1.0 || val > 1.0)) { + return std::numeric_limits::quiet_NaN(); + } + return std::asin(val); + } +}; + +struct AsinChecked { + template + static enable_if_floating_point Call(KernelContext*, Arg0 val, Status* st) { + static_assert(std::is_same::value, ""); + if (ARROW_PREDICT_FALSE(val < -1.0 || val > 1.0)) { + *st = Status::Invalid("domain error"); + return val; + } + return std::asin(val); + } +}; + +struct Acos { + template + static enable_if_floating_point Call(KernelContext*, Arg0 val, Status*) { + static_assert(std::is_same::value, ""); + if (ARROW_PREDICT_FALSE((val < -1.0 || val > 1.0))) { + return std::numeric_limits::quiet_NaN(); + } + return std::acos(val); + } +}; + +struct AcosChecked { + template + static enable_if_floating_point Call(KernelContext*, Arg0 val, Status* st) { + static_assert(std::is_same::value, ""); + if (ARROW_PREDICT_FALSE((val < -1.0 || val > 1.0))) { + *st = Status::Invalid("domain error"); + return val; + } + return std::acos(val); + } +}; + +struct Atan { + template + static enable_if_floating_point Call(KernelContext*, Arg0 val, Status*) { + static_assert(std::is_same::value, ""); + return std::atan(val); + } +}; + +struct Atan2 { + template + static enable_if_floating_point Call(KernelContext*, Arg0 y, Arg1 x, Status*) { + static_assert(std::is_same::value, ""); + static_assert(std::is_same::value, ""); + return std::atan2(y, x); + } +}; + // Generate a kernel given an arithmetic functor template