From 01f3338f7cf8dccf38602842a9ade3b0f840cc10 Mon Sep 17 00:00:00 2001
From: David Li
Date: Wed, 30 Jun 2021 18:30:37 +0200
Subject: [PATCH] ARROW-13095: [C++] Implement trig compute functions
Adds sin/cos/tan and their inverses. Checked variants check for what would be domain errors (this does not apply to atan/atan2).
Closes #10544 from lidavidm/arrow-13095
Authored-by: David Li
Signed-off-by: Antoine Pitrou
---
cpp/src/arrow/compute/api_scalar.cc | 7 +
cpp/src/arrow/compute/api_scalar.h | 61 +++
.../arrow/compute/kernels/codegen_internal.h | 14 +-
.../compute/kernels/scalar_arithmetic.cc | 360 ++++++++++++++++--
.../compute/kernels/scalar_arithmetic_test.cc | 106 ++++++
.../arrow/compute/kernels/scalar_compare.cc | 45 ++-
cpp/src/arrow/compute/kernels/util_internal.h | 12 +
docs/source/cpp/compute.rst | 34 ++
docs/source/python/api/compute.rst | 22 ++
9 files changed, 607 insertions(+), 54 deletions(-)
diff --git a/cpp/src/arrow/compute/api_scalar.cc b/cpp/src/arrow/compute/api_scalar.cc
index f005e70e3480c..20bba982a74fb 100644
--- a/cpp/src/arrow/compute/api_scalar.cc
+++ b/cpp/src/arrow/compute/api_scalar.cc
@@ -49,6 +49,12 @@ namespace compute {
SCALAR_ARITHMETIC_UNARY(AbsoluteValue, "abs", "abs_checked")
SCALAR_ARITHMETIC_UNARY(Negate, "negate", "negate_checked")
+SCALAR_ARITHMETIC_UNARY(Sin, "sin", "sin_checked")
+SCALAR_ARITHMETIC_UNARY(Cos, "cos", "cos_checked")
+SCALAR_ARITHMETIC_UNARY(Asin, "asin", "asin_checked")
+SCALAR_ARITHMETIC_UNARY(Acos, "acos", "acos_checked")
+SCALAR_ARITHMETIC_UNARY(Tan, "tan", "tan_checked")
+SCALAR_EAGER_UNARY(Atan, "atan")
#define SCALAR_ARITHMETIC_BINARY(NAME, REGISTRY_NAME, REGISTRY_CHECKED_NAME) \
Result NAME(const Datum& left, const Datum& right, ArithmeticOptions options, \
@@ -64,6 +70,7 @@ SCALAR_ARITHMETIC_BINARY(Divide, "divide", "divide_checked")
SCALAR_ARITHMETIC_BINARY(Power, "power", "power_checked")
SCALAR_ARITHMETIC_BINARY(ShiftLeft, "shift_left", "shift_left_checked")
SCALAR_ARITHMETIC_BINARY(ShiftRight, "shift_right", "shift_right_checked")
+SCALAR_EAGER_BINARY(Atan2, "atan2")
Result MaxElementWise(const std::vector& args,
ElementWiseAggregateOptions options, ExecContext* ctx) {
diff --git a/cpp/src/arrow/compute/api_scalar.h b/cpp/src/arrow/compute/api_scalar.h
index b101325740194..2ec9c1d765307 100644
--- a/cpp/src/arrow/compute/api_scalar.h
+++ b/cpp/src/arrow/compute/api_scalar.h
@@ -335,6 +335,67 @@ Result ShiftRight(const Datum& left, const Datum& right,
ArithmeticOptions options = ArithmeticOptions(),
ExecContext* ctx = NULLPTR);
+/// \brief Compute the sine of the array values.
+/// \param[in] arg The values to compute the sine for.
+/// \param[in] options arithmetic options (enable/disable overflow checking), optional
+/// \param[in] ctx the function execution context, optional
+/// \return the elementwise sine of the values
+ARROW_EXPORT
+Result Sin(const Datum& arg, ArithmeticOptions options = ArithmeticOptions(),
+ ExecContext* ctx = NULLPTR);
+
+/// \brief Compute the cosine of the array values.
+/// \param[in] arg The values to compute the cosine for.
+/// \param[in] options arithmetic options (enable/disable overflow checking), optional
+/// \param[in] ctx the function execution context, optional
+/// \return the elementwise cosine of the values
+ARROW_EXPORT
+Result Cos(const Datum& arg, ArithmeticOptions options = ArithmeticOptions(),
+ ExecContext* ctx = NULLPTR);
+
+/// \brief Compute the inverse sine (arcsine) of the array values.
+/// \param[in] arg The values to compute the inverse sine for.
+/// \param[in] options arithmetic options (enable/disable overflow checking), optional
+/// \param[in] ctx the function execution context, optional
+/// \return the elementwise inverse sine of the values
+ARROW_EXPORT
+Result Asin(const Datum& arg, ArithmeticOptions options = ArithmeticOptions(),
+ ExecContext* ctx = NULLPTR);
+
+/// \brief Compute the inverse cosine (arccosine) of the array values.
+/// \param[in] arg The values to compute the inverse cosine for.
+/// \param[in] options arithmetic options (enable/disable overflow checking), optional
+/// \param[in] ctx the function execution context, optional
+/// \return the elementwise inverse cosine of the values
+ARROW_EXPORT
+Result Acos(const Datum& arg, ArithmeticOptions options = ArithmeticOptions(),
+ ExecContext* ctx = NULLPTR);
+
+/// \brief Compute the tangent of the array values.
+/// \param[in] arg The values to compute the tangent for.
+/// \param[in] options arithmetic options (enable/disable overflow checking), optional
+/// \param[in] ctx the function execution context, optional
+/// \return the elementwise tangent of the values
+ARROW_EXPORT
+Result Tan(const Datum& arg, ArithmeticOptions options = ArithmeticOptions(),
+ ExecContext* ctx = NULLPTR);
+
+/// \brief Compute the inverse tangent (arctangent) of the array values.
+/// \param[in] arg The values to compute the inverse tangent for.
+/// \param[in] ctx the function execution context, optional
+/// \return the elementwise inverse tangent of the values
+ARROW_EXPORT
+Result Atan(const Datum& arg, ExecContext* ctx = NULLPTR);
+
+/// \brief Compute the inverse tangent (arctangent) of y/x, using the
+/// argument signs to determine the correct quadrant.
+/// \param[in] y The y-values to compute the inverse tangent for.
+/// \param[in] x The x-values to compute the inverse tangent for.
+/// \param[in] ctx the function execution context, optional
+/// \return the elementwise inverse tangent of the values
+ARROW_EXPORT
+Result Atan2(const Datum& y, const Datum& x, ExecContext* ctx = NULLPTR);
+
/// \brief Find the element-wise maximum of any number of arrays or scalars.
/// Array values must be the same length.
///
diff --git a/cpp/src/arrow/compute/kernels/codegen_internal.h b/cpp/src/arrow/compute/kernels/codegen_internal.h
index 140f9fdc6695e..a68bb970b4a5f 100644
--- a/cpp/src/arrow/compute/kernels/codegen_internal.h
+++ b/cpp/src/arrow/compute/kernels/codegen_internal.h
@@ -826,7 +826,8 @@ struct ScalarBinary {
ArrayIterator arg0_it(arg0);
ArrayIterator arg1_it(arg1);
RETURN_NOT_OK(OutputAdapter::Write(ctx, out, [&]() -> OutValue {
- return Op::template Call(ctx, arg0_it(), arg1_it(), &st);
+ return Op::template Call(ctx, arg0_it(), arg1_it(),
+ &st);
}));
return st;
}
@@ -837,7 +838,8 @@ struct ScalarBinary {
ArrayIterator arg0_it(arg0);
auto arg1_val = UnboxScalar::Unbox(arg1);
RETURN_NOT_OK(OutputAdapter::Write(ctx, out, [&]() -> OutValue {
- return Op::template Call(ctx, arg0_it(), arg1_val, &st);
+ return Op::template Call(ctx, arg0_it(), arg1_val,
+ &st);
}));
return st;
}
@@ -848,7 +850,8 @@ struct ScalarBinary {
auto arg0_val = UnboxScalar::Unbox(arg0);
ArrayIterator arg1_it(arg1);
RETURN_NOT_OK(OutputAdapter::Write(ctx, out, [&]() -> OutValue {
- return Op::template Call(ctx, arg0_val, arg1_it(), &st);
+ return Op::template Call(ctx, arg0_val, arg1_it(),
+ &st);
}));
return st;
}
@@ -859,8 +862,9 @@ struct ScalarBinary {
if (out->scalar()->is_valid) {
auto arg0_val = UnboxScalar::Unbox(arg0);
auto arg1_val = UnboxScalar::Unbox(arg1);
- BoxScalar::Box(Op::template Call(ctx, arg0_val, arg1_val, &st),
- out->scalar().get());
+ BoxScalar::Box(
+ Op::template Call(ctx, arg0_val, arg1_val, &st),
+ out->scalar().get());
}
return st;
}
diff --git a/cpp/src/arrow/compute/kernels/scalar_arithmetic.cc b/cpp/src/arrow/compute/kernels/scalar_arithmetic.cc
index ef9ef78054a61..da3a3095041b0 100644
--- a/cpp/src/arrow/compute/kernels/scalar_arithmetic.cc
+++ b/cpp/src/arrow/compute/kernels/scalar_arithmetic.cc
@@ -21,6 +21,7 @@
#include
#include "arrow/compute/kernels/common.h"
+#include "arrow/compute/kernels/util_internal.h"
#include "arrow/type_traits.h"
#include "arrow/util/decimal.h"
#include "arrow/util/int_util_internal.h"
@@ -58,12 +59,12 @@ using enable_if_signed_integer = enable_if_t::value, T>;
template
using enable_if_unsigned_integer = enable_if_t::value, T>;
-template
+template
using enable_if_integer =
- enable_if_t::value || is_unsigned_integer::value, T>;
+ enable_if_t::value || is_unsigned_integer::value, R>;
-template
-using enable_if_floating_point = enable_if_t::value, T>;
+template
+using enable_if_floating_point = enable_if_t::value, R>;
template
using enable_if_decimal =
@@ -117,20 +118,20 @@ struct AbsoluteValueChecked {
};
struct Add {
- template
- static constexpr enable_if_floating_point Call(KernelContext*, T left, T right,
+ template
+ static constexpr enable_if_floating_point Call(KernelContext*, Arg0 left, Arg1 right,
Status*) {
return left + right;
}
- template
- static constexpr enable_if_unsigned_integer Call(KernelContext*, T left, T right,
- Status*) {
+ template
+ static constexpr enable_if_unsigned_integer Call(KernelContext*, Arg0 left,
+ Arg1 right, Status*) {
return left + right;
}
- template
- static constexpr enable_if_signed_integer Call(KernelContext*, T left, T right,
+ template
+ static constexpr enable_if_signed_integer Call(KernelContext*, Arg0 left, Arg1 right,
Status*) {
return arrow::internal::SafeSignedAdd(left, right);
}
@@ -166,21 +167,24 @@ struct AddChecked {
};
struct Subtract {
- template
- static constexpr enable_if_floating_point Call(KernelContext*, T left, T right,
+ template
+ static constexpr enable_if_floating_point Call(KernelContext*, Arg0 left, Arg1 right,
Status*) {
+ static_assert(std::is_same::value && std::is_same::value, "");
return left - right;
}
- template
- static constexpr enable_if_unsigned_integer Call(KernelContext*, T left, T right,
- Status*) {
+ template
+ static constexpr enable_if_unsigned_integer Call(KernelContext*, Arg0 left,
+ Arg1 right, Status*) {
+ static_assert(std::is_same::value && std::is_same::value, "");
return left - right;
}
- template
- static constexpr enable_if_signed_integer Call(KernelContext*, T left, T right,
+ template
+ static constexpr enable_if_signed_integer Call(KernelContext*, Arg0 left, Arg1 right,
Status*) {
+ static_assert(std::is_same::value && std::is_same::value, "");
return arrow::internal::SafeSignedSubtract(left, right);
}
@@ -224,21 +228,23 @@ struct Multiply {
static_assert(std::is_same::value, "");
static_assert(std::is_same::value, "");
- template
+ template
static constexpr enable_if_floating_point Call(KernelContext*, T left, T right,
Status*) {
return left * right;
}
- template
- static constexpr enable_if_unsigned_integer Call(KernelContext*, T left, T right,
- Status*) {
+ template
+ static constexpr enable_if_t<
+ is_unsigned_integer::value && !std::is_same::value, T>
+ Call(KernelContext*, T left, T right, Status*) {
return left * right;
}
- template
- static constexpr enable_if_signed_integer Call(KernelContext*, T left, T right,
- Status*) {
+ template
+ static constexpr enable_if_t<
+ is_signed_integer::value && !std::is_same::value, T>
+ Call(KernelContext*, T left, T right, Status*) {
return to_unsigned(left) * to_unsigned(right);
}
@@ -246,12 +252,14 @@ struct Multiply {
// integer. However, some inputs may nevertheless overflow (which triggers undefined
// behaviour). Therefore we first cast to 32 bit unsigned integers where overflow is
// well defined.
- template
- static constexpr int16_t Call(KernelContext*, int16_t left, int16_t right, Status*) {
+ template
+ static constexpr enable_if_same Call(KernelContext*, int16_t left,
+ int16_t right, Status*) {
return static_cast(left) * static_cast(right);
}
- template
- static constexpr uint16_t Call(KernelContext*, uint16_t left, uint16_t right, Status*) {
+ template
+ static constexpr enable_if_same Call(KernelContext*, uint16_t left,
+ uint16_t right, Status*) {
return static_cast(left) * static_cast(right);
}
@@ -405,7 +413,7 @@ struct Power {
return pow;
}
- template
+ template
static enable_if_integer Call(KernelContext*, T base, T exp, Status* st) {
if (exp < 0) {
*st = Status::Invalid("integers to negative integer powers are not allowed");
@@ -414,7 +422,7 @@ struct Power {
return static_cast(IntegerPower(base, exp));
}
- template
+ template
static enable_if_floating_point Call(KernelContext*, T base, T exp, Status*) {
return std::pow(base, exp);
}
@@ -554,6 +562,130 @@ struct ShiftRightChecked {
}
};
+struct Sin {
+ template
+ static enable_if_floating_point Call(KernelContext*, Arg0 val, Status*) {
+ static_assert(std::is_same::value, "");
+ return std::sin(val);
+ }
+};
+
+struct SinChecked {
+ template
+ static enable_if_floating_point Call(KernelContext*, Arg0 val, Status* st) {
+ static_assert(std::is_same::value, "");
+ if (ARROW_PREDICT_FALSE(std::isinf(val))) {
+ *st = Status::Invalid("domain error");
+ return val;
+ }
+ return std::sin(val);
+ }
+};
+
+struct Cos {
+ template
+ static enable_if_floating_point Call(KernelContext*, Arg0 val, Status*) {
+ static_assert(std::is_same::value, "");
+ return std::cos(val);
+ }
+};
+
+struct CosChecked {
+ template
+ static enable_if_floating_point Call(KernelContext*, Arg0 val, Status* st) {
+ static_assert(std::is_same::value, "");
+ if (ARROW_PREDICT_FALSE(std::isinf(val))) {
+ *st = Status::Invalid("domain error");
+ return val;
+ }
+ return std::cos(val);
+ }
+};
+
+struct Tan {
+ template
+ static enable_if_floating_point Call(KernelContext*, Arg0 val, Status*) {
+ static_assert(std::is_same::value, "");
+ return std::tan(val);
+ }
+};
+
+struct TanChecked {
+ template
+ static enable_if_floating_point Call(KernelContext*, Arg0 val, Status* st) {
+ static_assert(std::is_same::value, "");
+ if (ARROW_PREDICT_FALSE(std::isinf(val))) {
+ *st = Status::Invalid("domain error");
+ return val;
+ }
+ // Cannot raise range errors (overflow) since PI/2 is not exactly representable
+ return std::tan(val);
+ }
+};
+
+struct Asin {
+ template
+ static enable_if_floating_point Call(KernelContext*, Arg0 val, Status*) {
+ static_assert(std::is_same::value, "");
+ if (ARROW_PREDICT_FALSE(val < -1.0 || val > 1.0)) {
+ return std::numeric_limits::quiet_NaN();
+ }
+ return std::asin(val);
+ }
+};
+
+struct AsinChecked {
+ template
+ static enable_if_floating_point Call(KernelContext*, Arg0 val, Status* st) {
+ static_assert(std::is_same::value, "");
+ if (ARROW_PREDICT_FALSE(val < -1.0 || val > 1.0)) {
+ *st = Status::Invalid("domain error");
+ return val;
+ }
+ return std::asin(val);
+ }
+};
+
+struct Acos {
+ template
+ static enable_if_floating_point Call(KernelContext*, Arg0 val, Status*) {
+ static_assert(std::is_same::value, "");
+ if (ARROW_PREDICT_FALSE((val < -1.0 || val > 1.0))) {
+ return std::numeric_limits::quiet_NaN();
+ }
+ return std::acos(val);
+ }
+};
+
+struct AcosChecked {
+ template
+ static enable_if_floating_point Call(KernelContext*, Arg0 val, Status* st) {
+ static_assert(std::is_same::value, "");
+ if (ARROW_PREDICT_FALSE((val < -1.0 || val > 1.0))) {
+ *st = Status::Invalid("domain error");
+ return val;
+ }
+ return std::acos(val);
+ }
+};
+
+struct Atan {
+ template
+ static enable_if_floating_point Call(KernelContext*, Arg0 val, Status*) {
+ static_assert(std::is_same::value, "");
+ return std::atan(val);
+ }
+};
+
+struct Atan2 {
+ template
+ static enable_if_floating_point Call(KernelContext*, Arg0 y, Arg1 x, Status*) {
+ static_assert(std::is_same::value, "");
+ static_assert(std::is_same::value, "");
+ return std::atan2(y, x);
+ }
+};
+
// Generate a kernel given an arithmetic functor
template class KernelGenerator, typename Op>
ArrayKernelExec ArithmeticExecFromOp(detail::GetTypeId get_id) {
@@ -633,6 +765,19 @@ ArrayKernelExec ShiftExecFromOp(detail::GetTypeId get_id) {
}
}
+template class KernelGenerator, typename Op>
+ArrayKernelExec GenerateArithmeticFloatingPoint(detail::GetTypeId get_id) {
+ switch (get_id.id) {
+ case Type::FLOAT:
+ return KernelGenerator::Exec;
+ case Type::DOUBLE:
+ return KernelGenerator::Exec;
+ default:
+ DCHECK(false);
+ return ExecFail;
+ }
+}
+
Status CastBinaryDecimalArgs(const std::string& func_name,
std::vector* values) {
auto& left_type = (*values)[0].type;
@@ -904,6 +1049,42 @@ std::shared_ptr MakeShiftFunctionNotNull(std::string name,
return func;
}
+template
+std::shared_ptr MakeUnaryArithmeticFunctionFloatingPoint(
+ std::string name, const FunctionDoc* doc) {
+ auto func = std::make_shared(name, Arity::Unary(), doc);
+ for (const auto& ty : FloatingPointTypes()) {
+ auto output = is_integer(ty->id()) ? float64() : ty;
+ auto exec = GenerateArithmeticFloatingPoint(ty);
+ DCHECK_OK(func->AddKernel({ty}, output, exec));
+ }
+ return func;
+}
+
+template
+std::shared_ptr MakeUnaryArithmeticFunctionFloatingPointNotNull(
+ std::string name, const FunctionDoc* doc) {
+ auto func = std::make_shared(name, Arity::Unary(), doc);
+ for (const auto& ty : FloatingPointTypes()) {
+ auto output = is_integer(ty->id()) ? float64() : ty;
+ auto exec = GenerateArithmeticFloatingPoint(ty);
+ DCHECK_OK(func->AddKernel({ty}, output, exec));
+ }
+ return func;
+}
+
+template
+std::shared_ptr MakeArithmeticFunctionFloatingPoint(
+ std::string name, const FunctionDoc* doc) {
+ auto func = std::make_shared(name, Arity::Binary(), doc);
+ for (const auto& ty : FloatingPointTypes()) {
+ auto output = is_integer(ty->id()) ? float64() : ty;
+ auto exec = GenerateArithmeticFloatingPoint(ty);
+ DCHECK_OK(func->AddKernel({ty, ty}, output, exec));
+ }
+ return func;
+}
+
const FunctionDoc absolute_value_doc{
"Calculate the absolute value of the argument element-wise",
("Results will wrap around on integer overflow.\n"
@@ -1041,6 +1222,79 @@ const FunctionDoc shift_right_checked_doc{
"See \"shift_right\" for a variant that doesn't fail for an invalid shift amount"),
{"x", "y"}};
+const FunctionDoc sin_doc{"Compute the sine of the elements argument-wise",
+ ("Integer arguments return double values. "
+ "This function returns NaN on values outside its domain. "
+ "To raise an error instead, see \"sin_checked\"."),
+ {"x"}};
+
+const FunctionDoc sin_checked_doc{
+ "Compute the sine of the elements argument-wise",
+ ("Integer arguments return double values. "
+ "This function raises an error on values outside its domain. "
+ "To return NaN instead, see \"sin\"."),
+ {"x"}};
+
+const FunctionDoc cos_doc{"Compute the cosine of the elements argument-wise",
+ ("Integer arguments return double values. "
+ "This function returns NaN on values outside its domain. "
+ "To raise an error instead, see \"cos_checked\"."),
+ {"x"}};
+
+const FunctionDoc cos_checked_doc{
+ "Compute the cosine of the elements argument-wise",
+ ("Integer arguments return double values. "
+ "This function raises an error on values outside its domain. "
+ "To return NaN instead, see \"cos\"."),
+ {"x"}};
+
+const FunctionDoc tan_doc{"Compute the tangent of the elements argument-wise",
+ ("Integer arguments return double values. "
+ "This function returns NaN on values outside its domain. "
+ "To raise an error instead, see \"tan_checked\"."),
+ {"x"}};
+
+const FunctionDoc tan_checked_doc{
+ "Compute the tangent of the elements argument-wise",
+ ("Integer arguments return double values. "
+ "This function raises an error on values outside its domain. "
+ "To return NaN instead, see \"tan\"."),
+ {"x"}};
+
+const FunctionDoc asin_doc{"Compute the inverse sine of the elements argument-wise",
+ ("Integer arguments return double values. "
+ "This function returns NaN on values outside its domain. "
+ "To raise an error instead, see \"asin_checked\"."),
+ {"x"}};
+
+const FunctionDoc asin_checked_doc{
+ "Compute the inverse sine of the elements argument-wise",
+ ("Integer arguments return double values. "
+ "This function raises an error on values outside its domain. "
+ "To return NaN instead, see \"asin\"."),
+ {"x"}};
+
+const FunctionDoc acos_doc{"Compute the inverse cosine of the elements argument-wise",
+ ("Integer arguments return double values. "
+ "This function returns NaN on values outside its domain. "
+ "To raise an error instead, see \"acos_checked\"."),
+ {"x"}};
+
+const FunctionDoc acos_checked_doc{
+ "Compute the inverse cosine of the elements argument-wise",
+ ("Integer arguments return double values. "
+ "This function raises an error on values outside its domain. "
+ "To return NaN instead, see \"acos\"."),
+ {"x"}};
+
+const FunctionDoc atan_doc{"Compute the principal value of the inverse tangent",
+ ("Integer arguments return double values."),
+ {"x"}};
+
+const FunctionDoc atan2_doc{
+ "Compute the inverse tangent using argument signs to determine the quadrant",
+ ("Integer arguments return double values."),
+ {"y", "x"}};
} // namespace
void RegisterScalarArithmetic(FunctionRegistry* registry) {
@@ -1126,6 +1380,7 @@ void RegisterScalarArithmetic(FunctionRegistry* registry) {
DCHECK_OK(registry->AddFunction(std::move(power_checked)));
// ----------------------------------------------------------------------
+ // Bitwise functions
{
auto bit_wise_not = std::make_shared(
"bit_wise_not", Arity::Unary(), &bit_wise_not_doc);
@@ -1162,6 +1417,49 @@ void RegisterScalarArithmetic(FunctionRegistry* registry) {
auto shift_right_checked = MakeShiftFunctionNotNull(
"shift_right_checked", &shift_right_checked_doc);
DCHECK_OK(registry->AddFunction(std::move(shift_right_checked)));
+
+ // ----------------------------------------------------------------------
+ // Trig functions
+ auto sin = MakeUnaryArithmeticFunctionFloatingPoint("sin", &sin_doc);
+ DCHECK_OK(registry->AddFunction(std::move(sin)));
+
+ auto sin_checked = MakeUnaryArithmeticFunctionFloatingPointNotNull(
+ "sin_checked", &sin_checked_doc);
+ DCHECK_OK(registry->AddFunction(std::move(sin_checked)));
+
+ auto cos = MakeUnaryArithmeticFunctionFloatingPoint("cos", &cos_doc);
+ DCHECK_OK(registry->AddFunction(std::move(cos)));
+
+ auto cos_checked = MakeUnaryArithmeticFunctionFloatingPointNotNull(
+ "cos_checked", &cos_checked_doc);
+ DCHECK_OK(registry->AddFunction(std::move(cos_checked)));
+
+ auto tan = MakeUnaryArithmeticFunctionFloatingPoint("tan", &tan_doc);
+ DCHECK_OK(registry->AddFunction(std::move(tan)));
+
+ auto tan_checked = MakeUnaryArithmeticFunctionFloatingPointNotNull(
+ "tan_checked", &tan_checked_doc);
+ DCHECK_OK(registry->AddFunction(std::move(tan_checked)));
+
+ auto asin = MakeUnaryArithmeticFunctionFloatingPoint("asin", &asin_doc);
+ DCHECK_OK(registry->AddFunction(std::move(asin)));
+
+ auto asin_checked = MakeUnaryArithmeticFunctionFloatingPointNotNull(
+ "asin_checked", &asin_checked_doc);
+ DCHECK_OK(registry->AddFunction(std::move(asin_checked)));
+
+ auto acos = MakeUnaryArithmeticFunctionFloatingPoint("acos", &acos_doc);
+ DCHECK_OK(registry->AddFunction(std::move(acos)));
+
+ auto acos_checked = MakeUnaryArithmeticFunctionFloatingPointNotNull(
+ "acos_checked", &acos_checked_doc);
+ DCHECK_OK(registry->AddFunction(std::move(acos_checked)));
+
+ auto atan = MakeUnaryArithmeticFunctionFloatingPoint("atan", &atan_doc);
+ DCHECK_OK(registry->AddFunction(std::move(atan)));
+
+ auto atan2 = MakeArithmeticFunctionFloatingPoint("atan2", &atan2_doc);
+ DCHECK_OK(registry->AddFunction(std::move(atan2)));
}
} // namespace internal
diff --git a/cpp/src/arrow/compute/kernels/scalar_arithmetic_test.cc b/cpp/src/arrow/compute/kernels/scalar_arithmetic_test.cc
index a94eabb1be0bb..ed24a44484fac 100644
--- a/cpp/src/arrow/compute/kernels/scalar_arithmetic_test.cc
+++ b/cpp/src/arrow/compute/kernels/scalar_arithmetic_test.cc
@@ -16,6 +16,8 @@
// under the License.
#include
+#define _USE_MATH_DEFINES
+#include
#include
#include
#include
@@ -90,6 +92,12 @@ class TestUnaryArithmetic : public TestBase {
void AssertUnaryOp(UnaryFunction func, const std::shared_ptr& arg,
const std::string& expected_json) {
const auto expected = ArrayFromJSON(type_singleton(), expected_json);
+ return AssertUnaryOp(func, arg, expected);
+ }
+
+ // (Array)
+ void AssertUnaryOp(UnaryFunction func, const std::shared_ptr& arg,
+ const std::shared_ptr& expected) {
ASSERT_OK_AND_ASSIGN(Datum actual, func(arg, options_, nullptr));
ValidateAndAssertApproxEqual(actual.make_array(), expected);
@@ -108,6 +116,11 @@ class TestUnaryArithmetic : public TestBase {
auto arg = ArrayFromJSON(type_singleton(), argument);
EXPECT_RAISES_WITH_MESSAGE_THAT(Invalid, ::testing::HasSubstr(expected_msg),
func(arg, options_, nullptr));
+ for (int64_t i = 0; i < arg->length(); i++) {
+ ASSERT_OK_AND_ASSIGN(auto scalar, arg->GetScalar(i));
+ EXPECT_RAISES_WITH_MESSAGE_THAT(Invalid, ::testing::HasSubstr(expected_msg),
+ func(scalar, options_, nullptr));
+ }
}
void AssertUnaryOpNotImplemented(UnaryFunction func, const std::string& argument) {
@@ -232,6 +245,12 @@ class TestBinaryArithmetic : public TestBase {
const std::shared_ptr& right,
const std::string& expected_json) {
const auto expected = ArrayFromJSON(type_singleton(), expected_json);
+ AssertBinop(func, left, right, expected);
+ }
+
+ void AssertBinop(BinaryFunction func, const std::shared_ptr& left,
+ const std::shared_ptr& right,
+ const std::shared_ptr& expected) {
ASSERT_OK_AND_ASSIGN(Datum actual, func(left, right, options_, nullptr));
ValidateAndAssertApproxEqual(actual.make_array(), expected);
@@ -1715,5 +1734,92 @@ TYPED_TEST(TestBinaryArithmeticUnsigned, ShiftRightOverflowRaises) {
"shift amount must be >= 0 and less than precision of type");
}
+TYPED_TEST(TestUnaryArithmeticFloating, TrigSin) {
+ this->SetNansEqual(true);
+ this->AssertUnaryOp(Sin, "[Inf, -Inf]", "[NaN, NaN]");
+ for (auto check_overflow : {false, true}) {
+ this->SetOverflowCheck(check_overflow);
+ this->AssertUnaryOp(Sin, "[]", "[]");
+ this->AssertUnaryOp(Sin, "[null, NaN]", "[null, NaN]");
+ this->AssertUnaryOp(Sin, MakeArray(0, M_PI_2, M_PI), "[0, 1, 0]");
+ }
+ this->AssertUnaryOpRaises(Sin, "[Inf, -Inf]", "domain error");
+}
+
+TYPED_TEST(TestUnaryArithmeticFloating, TrigCos) {
+ this->SetNansEqual(true);
+ this->AssertUnaryOp(Cos, "[Inf, -Inf]", "[NaN, NaN]");
+ for (auto check_overflow : {false, true}) {
+ this->SetOverflowCheck(check_overflow);
+ this->AssertUnaryOp(Cos, "[]", "[]");
+ this->AssertUnaryOp(Cos, "[null, NaN]", "[null, NaN]");
+ this->AssertUnaryOp(Cos, MakeArray(0, M_PI_2, M_PI), "[1, 0, -1]");
+ }
+ this->AssertUnaryOpRaises(Cos, "[Inf, -Inf]", "domain error");
+}
+
+TYPED_TEST(TestUnaryArithmeticFloating, TrigTan) {
+ this->SetNansEqual(true);
+ this->AssertUnaryOp(Tan, "[Inf, -Inf]", "[NaN, NaN]");
+ for (auto check_overflow : {false, true}) {
+ this->SetOverflowCheck(check_overflow);
+ this->AssertUnaryOp(Tan, "[]", "[]");
+ this->AssertUnaryOp(Tan, "[null, NaN]", "[null, NaN]");
+ // N.B. pi/2 isn't representable exactly -> there are no poles
+ // (i.e. tan(pi/2) is merely a large value and not +Inf)
+ this->AssertUnaryOp(Tan, MakeArray(0, M_PI), "[0, 0]");
+ }
+ this->AssertUnaryOpRaises(Tan, "[Inf, -Inf]", "domain error");
+}
+
+TYPED_TEST(TestUnaryArithmeticFloating, TrigAsin) {
+ this->SetNansEqual(true);
+ this->AssertUnaryOp(Asin, "[Inf, -Inf, -2, 2]", "[NaN, NaN, NaN, NaN]");
+ for (auto check_overflow : {false, true}) {
+ this->SetOverflowCheck(check_overflow);
+ this->AssertUnaryOp(Asin, "[]", "[]");
+ this->AssertUnaryOp(Asin, "[null, NaN]", "[null, NaN]");
+ this->AssertUnaryOp(Asin, "[0, 1, -1]", MakeArray(0, M_PI_2, -M_PI_2));
+ }
+ this->AssertUnaryOpRaises(Asin, "[Inf, -Inf, -2, 2]", "domain error");
+}
+
+TYPED_TEST(TestUnaryArithmeticFloating, TrigAcos) {
+ this->SetNansEqual(true);
+ this->AssertUnaryOp(Asin, "[Inf, -Inf, -2, 2]", "[NaN, NaN, NaN, NaN]");
+ for (auto check_overflow : {false, true}) {
+ this->SetOverflowCheck(check_overflow);
+ this->AssertUnaryOp(Acos, "[]", "[]");
+ this->AssertUnaryOp(Acos, "[null, NaN]", "[null, NaN]");
+ this->AssertUnaryOp(Acos, "[0, 1, -1]", MakeArray(M_PI_2, 0, M_PI));
+ }
+ this->AssertUnaryOpRaises(Acos, "[Inf, -Inf, -2, 2]", "domain error");
+}
+
+TYPED_TEST(TestUnaryArithmeticFloating, TrigAtan) {
+ this->SetNansEqual(true);
+ auto atan = [](const Datum& arg, ArithmeticOptions, ExecContext* ctx) {
+ return Atan(arg, ctx);
+ };
+ this->AssertUnaryOp(atan, "[]", "[]");
+ this->AssertUnaryOp(atan, "[null, NaN]", "[null, NaN]");
+ this->AssertUnaryOp(atan, "[0, 1, -1, Inf, -Inf]",
+ MakeArray(0, M_PI_4, -M_PI_4, M_PI_2, -M_PI_2));
+}
+
+TYPED_TEST(TestBinaryArithmeticFloating, TrigAtan2) {
+ this->SetNansEqual(true);
+ auto atan2 = [](const Datum& y, const Datum& x, ArithmeticOptions, ExecContext* ctx) {
+ return Atan2(y, x, ctx);
+ };
+ this->AssertBinop(atan2, "[]", "[]", "[]");
+ this->AssertBinop(atan2, "[0, 0, null, NaN]", "[null, NaN, 0, 0]",
+ "[null, NaN, null, NaN]");
+ this->AssertBinop(atan2, "[0, 0, -0.0, 0, -0.0, 0, 1, 0, -1, Inf, -Inf, 0, 0]",
+ "[0, 0, 0, -0.0, -0.0, 1, 0, -1, 0, 0, 0, Inf, -Inf]",
+ MakeArray(0, 0, -0.0, M_PI, -M_PI, 0, M_PI_2, M_PI, -M_PI_2, M_PI_2,
+ -M_PI_2, 0, M_PI));
+}
+
} // namespace compute
} // namespace arrow
diff --git a/cpp/src/arrow/compute/kernels/scalar_compare.cc b/cpp/src/arrow/compute/kernels/scalar_compare.cc
index 041c6a282f9eb..4342d776c3866 100644
--- a/cpp/src/arrow/compute/kernels/scalar_compare.cc
+++ b/cpp/src/arrow/compute/kernels/scalar_compare.cc
@@ -34,29 +34,33 @@ namespace internal {
namespace {
struct Equal {
- template
- static constexpr bool Call(KernelContext*, const T& left, const T& right, Status*) {
+ template
+ static constexpr T Call(KernelContext*, const Arg0& left, const Arg1& right, Status*) {
+ static_assert(std::is_same::value && std::is_same::value, "");
return left == right;
}
};
struct NotEqual {
- template
- static constexpr bool Call(KernelContext*, const T& left, const T& right, Status*) {
+ template
+ static constexpr T Call(KernelContext*, const Arg0& left, const Arg1& right, Status*) {
+ static_assert(std::is_same::value && std::is_same::value, "");
return left != right;
}
};
struct Greater {
- template
- static constexpr bool Call(KernelContext*, const T& left, const T& right, Status*) {
+ template
+ static constexpr T Call(KernelContext*, const Arg0& left, const Arg1& right, Status*) {
+ static_assert(std::is_same::value && std::is_same::value, "");
return left > right;
}
};
struct GreaterEqual {
- template
- static constexpr bool Call(KernelContext*, const T& left, const T& right, Status*) {
+ template
+ static constexpr T Call(KernelContext*, const Arg0& left, const Arg1& right, Status*) {
+ static_assert(std::is_same::value && std::is_same::value, "");
return left >= right;
}
};
@@ -77,13 +81,15 @@ template
using enable_if_floating_point = enable_if_t::value, T>;
struct Minimum {
- template
- static enable_if_floating_point Call(T left, T right) {
+ template
+ static enable_if_floating_point Call(Arg0 left, Arg1 right) {
+ static_assert(std::is_same::value && std::is_same::value, "");
return std::fmin(left, right);
}
- template
- static enable_if_integer Call(T left, T right) {
+ template
+ static enable_if_integer Call(Arg0 left, Arg1 right) {
+ static_assert(std::is_same::value && std::is_same::value, "");
return std::min(left, right);
}
@@ -104,13 +110,15 @@ struct Minimum {
};
struct Maximum {
- template
- static enable_if_floating_point Call(T left, T right) {
+ template
+ static enable_if_floating_point Call(Arg0 left, Arg1 right) {
+ static_assert(std::is_same::value && std::is_same::value, "");
return std::fmax(left, right);
}
- template
- static enable_if_integer Call(T left, T right) {
+ template
+ static enable_if_integer Call(Arg0 left, Arg1 right) {
+ static_assert(std::is_same::value && std::is_same::value, "");
return std::max(left, right);
}
@@ -291,7 +299,8 @@ struct ScalarMinMax {
value = UnboxScalar::Unbox(scalar);
valid = true;
} else {
- value = Op::Call(value, UnboxScalar::Unbox(scalar));
+ value = Op::template Call(
+ value, UnboxScalar::Unbox(scalar));
}
}
out->is_valid = valid;
@@ -396,7 +405,7 @@ struct ScalarMinMax {
auto u = out_it();
if (!output->buffers[0] ||
BitUtil::GetBit(output->buffers[0]->data(), index)) {
- writer.Write(Op::Call(u, value));
+ writer.Write(Op::template Call(u, value));
} else {
writer.Write(value);
}
diff --git a/cpp/src/arrow/compute/kernels/util_internal.h b/cpp/src/arrow/compute/kernels/util_internal.h
index f230bfbbd6db1..394e08da5813e 100644
--- a/cpp/src/arrow/compute/kernels/util_internal.h
+++ b/cpp/src/arrow/compute/kernels/util_internal.h
@@ -30,6 +30,18 @@ namespace arrow {
namespace compute {
namespace internal {
+// Used in some kernels and testing - not provided by default in MSVC
+// and _USE_MATH_DEFINES is not reliable with unity builds
+#ifndef M_PI
+#define M_PI 3.14159265358979323846
+#endif
+#ifndef M_PI_2
+#define M_PI_2 1.57079632679489661923
+#endif
+#ifndef M_PI_4
+#define M_PI_4 0.785398163397448309616
+#endif
+
// An internal data structure for unpacking a primitive argument to pass to a
// kernel implementation
struct PrimitiveArg {
diff --git a/docs/source/cpp/compute.rst b/docs/source/cpp/compute.rst
index c4ca4d3416cb6..33c1b47445266 100644
--- a/docs/source/cpp/compute.rst
+++ b/docs/source/cpp/compute.rst
@@ -339,6 +339,40 @@ Bit-wise functions
out of bounds for the data type. However, an overflow when shifting the
first input is not error (truncated bits are silently discarded).
+Trigonometric functions
+~~~~~~~~~~~~~~~~~~~~~~~
+
+Trigonometric functions are also supported, and also offer ``_checked``
+variants that check for domain errors if needed.
+
++--------------------------+------------+--------------------+---------------------+
+| Function name | Arity | Input types | Output type |
++==========================+============+====================+=====================+
+| acos | Unary | Float32/Float64 | Float32/Float64 |
++--------------------------+------------+--------------------+---------------------+
+| acos_checked | Unary | Float32/Float64 | Float32/Float64 |
++--------------------------+------------+--------------------+---------------------+
+| asin | Unary | Float32/Float64 | Float32/Float64 |
++--------------------------+------------+--------------------+---------------------+
+| asin_checked | Unary | Float32/Float64 | Float32/Float64 |
++--------------------------+------------+--------------------+---------------------+
+| atan | Unary | Float32/Float64 | Float32/Float64 |
++--------------------------+------------+--------------------+---------------------+
+| atan2 | Binary | Float32/Float64 | Float32/Float64 |
++--------------------------+------------+--------------------+---------------------+
+| cos | Unary | Float32/Float64 | Float32/Float64 |
++--------------------------+------------+--------------------+---------------------+
+| cos_checked | Unary | Float32/Float64 | Float32/Float64 |
++--------------------------+------------+--------------------+---------------------+
+| sin | Unary | Float32/Float64 | Float32/Float64 |
++--------------------------+------------+--------------------+---------------------+
+| sin_checked | Unary | Float32/Float64 | Float32/Float64 |
++--------------------------+------------+--------------------+---------------------+
+| tan | Unary | Float32/Float64 | Float32/Float64 |
++--------------------------+------------+--------------------+---------------------+
+| tan_checked | Unary | Float32/Float64 | Float32/Float64 |
++--------------------------+------------+--------------------+---------------------+
+
Comparisons
~~~~~~~~~~~
diff --git a/docs/source/python/api/compute.rst b/docs/source/python/api/compute.rst
index 461803dc773c8..334a76e75d267 100644
--- a/docs/source/python/api/compute.rst
+++ b/docs/source/python/api/compute.rst
@@ -73,6 +73,28 @@ Bit-wise operations do not offer (or need) a checked variant.
bit_wise_or
bit_wise_xor
+Trigonometric Functions
+-----------------------
+
+Trigonometric functions are also supported, and also offer ``_checked``
+variants which detect domain errors where appropriate.
+
+.. autosummary::
+ :toctree: ../generated/
+
+ acos
+ acos_checked
+ asin
+ asin_checked
+ atan
+ atan2
+ cos
+ cos_checked
+ sin
+ sin_checked
+ tan
+ tan_checked
+
Comparisons
-----------