Skip to content

Commit

Permalink
ARROW-13095: [C++] Implement trig compute functions
Browse files Browse the repository at this point in the history
Adds sin/cos/tan and their inverses. Checked variants check for what would be domain errors (this does not apply to atan/atan2).

Closes #10544 from lidavidm/arrow-13095

Authored-by: David Li <[email protected]>
Signed-off-by: Antoine Pitrou <[email protected]>
  • Loading branch information
lidavidm authored and pitrou committed Jun 30, 2021
1 parent 6bc94da commit 01f3338
Show file tree
Hide file tree
Showing 9 changed files with 607 additions and 54 deletions.
7 changes: 7 additions & 0 deletions cpp/src/arrow/compute/api_scalar.cc
Original file line number Diff line number Diff line change
Expand Up @@ -49,6 +49,12 @@ namespace compute {

SCALAR_ARITHMETIC_UNARY(AbsoluteValue, "abs", "abs_checked")
SCALAR_ARITHMETIC_UNARY(Negate, "negate", "negate_checked")
SCALAR_ARITHMETIC_UNARY(Sin, "sin", "sin_checked")
SCALAR_ARITHMETIC_UNARY(Cos, "cos", "cos_checked")
SCALAR_ARITHMETIC_UNARY(Asin, "asin", "asin_checked")
SCALAR_ARITHMETIC_UNARY(Acos, "acos", "acos_checked")
SCALAR_ARITHMETIC_UNARY(Tan, "tan", "tan_checked")
SCALAR_EAGER_UNARY(Atan, "atan")

#define SCALAR_ARITHMETIC_BINARY(NAME, REGISTRY_NAME, REGISTRY_CHECKED_NAME) \
Result<Datum> NAME(const Datum& left, const Datum& right, ArithmeticOptions options, \
Expand All @@ -64,6 +70,7 @@ SCALAR_ARITHMETIC_BINARY(Divide, "divide", "divide_checked")
SCALAR_ARITHMETIC_BINARY(Power, "power", "power_checked")
SCALAR_ARITHMETIC_BINARY(ShiftLeft, "shift_left", "shift_left_checked")
SCALAR_ARITHMETIC_BINARY(ShiftRight, "shift_right", "shift_right_checked")
SCALAR_EAGER_BINARY(Atan2, "atan2")

Result<Datum> MaxElementWise(const std::vector<Datum>& args,
ElementWiseAggregateOptions options, ExecContext* ctx) {
Expand Down
61 changes: 61 additions & 0 deletions cpp/src/arrow/compute/api_scalar.h
Original file line number Diff line number Diff line change
Expand Up @@ -335,6 +335,67 @@ Result<Datum> ShiftRight(const Datum& left, const Datum& right,
ArithmeticOptions options = ArithmeticOptions(),
ExecContext* ctx = NULLPTR);

/// \brief Compute the sine of the array values.
/// \param[in] arg The values to compute the sine for.
/// \param[in] options arithmetic options (enable/disable overflow checking), optional
/// \param[in] ctx the function execution context, optional
/// \return the elementwise sine of the values
ARROW_EXPORT
Result<Datum> Sin(const Datum& arg, ArithmeticOptions options = ArithmeticOptions(),
ExecContext* ctx = NULLPTR);

/// \brief Compute the cosine of the array values.
/// \param[in] arg The values to compute the cosine for.
/// \param[in] options arithmetic options (enable/disable overflow checking), optional
/// \param[in] ctx the function execution context, optional
/// \return the elementwise cosine of the values
ARROW_EXPORT
Result<Datum> Cos(const Datum& arg, ArithmeticOptions options = ArithmeticOptions(),
ExecContext* ctx = NULLPTR);

/// \brief Compute the inverse sine (arcsine) of the array values.
/// \param[in] arg The values to compute the inverse sine for.
/// \param[in] options arithmetic options (enable/disable overflow checking), optional
/// \param[in] ctx the function execution context, optional
/// \return the elementwise inverse sine of the values
ARROW_EXPORT
Result<Datum> Asin(const Datum& arg, ArithmeticOptions options = ArithmeticOptions(),
ExecContext* ctx = NULLPTR);

/// \brief Compute the inverse cosine (arccosine) of the array values.
/// \param[in] arg The values to compute the inverse cosine for.
/// \param[in] options arithmetic options (enable/disable overflow checking), optional
/// \param[in] ctx the function execution context, optional
/// \return the elementwise inverse cosine of the values
ARROW_EXPORT
Result<Datum> Acos(const Datum& arg, ArithmeticOptions options = ArithmeticOptions(),
ExecContext* ctx = NULLPTR);

/// \brief Compute the tangent of the array values.
/// \param[in] arg The values to compute the tangent for.
/// \param[in] options arithmetic options (enable/disable overflow checking), optional
/// \param[in] ctx the function execution context, optional
/// \return the elementwise tangent of the values
ARROW_EXPORT
Result<Datum> Tan(const Datum& arg, ArithmeticOptions options = ArithmeticOptions(),
ExecContext* ctx = NULLPTR);

/// \brief Compute the inverse tangent (arctangent) of the array values.
/// \param[in] arg The values to compute the inverse tangent for.
/// \param[in] ctx the function execution context, optional
/// \return the elementwise inverse tangent of the values
ARROW_EXPORT
Result<Datum> Atan(const Datum& arg, ExecContext* ctx = NULLPTR);

/// \brief Compute the inverse tangent (arctangent) of y/x, using the
/// argument signs to determine the correct quadrant.
/// \param[in] y The y-values to compute the inverse tangent for.
/// \param[in] x The x-values to compute the inverse tangent for.
/// \param[in] ctx the function execution context, optional
/// \return the elementwise inverse tangent of the values
ARROW_EXPORT
Result<Datum> Atan2(const Datum& y, const Datum& x, ExecContext* ctx = NULLPTR);

/// \brief Find the element-wise maximum of any number of arrays or scalars.
/// Array values must be the same length.
///
Expand Down
14 changes: 9 additions & 5 deletions cpp/src/arrow/compute/kernels/codegen_internal.h
Original file line number Diff line number Diff line change
Expand Up @@ -826,7 +826,8 @@ struct ScalarBinary {
ArrayIterator<Arg0Type> arg0_it(arg0);
ArrayIterator<Arg1Type> arg1_it(arg1);
RETURN_NOT_OK(OutputAdapter<OutType>::Write(ctx, out, [&]() -> OutValue {
return Op::template Call(ctx, arg0_it(), arg1_it(), &st);
return Op::template Call<OutValue, Arg0Value, Arg1Value>(ctx, arg0_it(), arg1_it(),
&st);
}));
return st;
}
Expand All @@ -837,7 +838,8 @@ struct ScalarBinary {
ArrayIterator<Arg0Type> arg0_it(arg0);
auto arg1_val = UnboxScalar<Arg1Type>::Unbox(arg1);
RETURN_NOT_OK(OutputAdapter<OutType>::Write(ctx, out, [&]() -> OutValue {
return Op::template Call(ctx, arg0_it(), arg1_val, &st);
return Op::template Call<OutValue, Arg0Value, Arg1Value>(ctx, arg0_it(), arg1_val,
&st);
}));
return st;
}
Expand All @@ -848,7 +850,8 @@ struct ScalarBinary {
auto arg0_val = UnboxScalar<Arg0Type>::Unbox(arg0);
ArrayIterator<Arg1Type> arg1_it(arg1);
RETURN_NOT_OK(OutputAdapter<OutType>::Write(ctx, out, [&]() -> OutValue {
return Op::template Call(ctx, arg0_val, arg1_it(), &st);
return Op::template Call<OutValue, Arg0Value, Arg1Value>(ctx, arg0_val, arg1_it(),
&st);
}));
return st;
}
Expand All @@ -859,8 +862,9 @@ struct ScalarBinary {
if (out->scalar()->is_valid) {
auto arg0_val = UnboxScalar<Arg0Type>::Unbox(arg0);
auto arg1_val = UnboxScalar<Arg1Type>::Unbox(arg1);
BoxScalar<OutType>::Box(Op::template Call(ctx, arg0_val, arg1_val, &st),
out->scalar().get());
BoxScalar<OutType>::Box(
Op::template Call<OutValue, Arg0Value, Arg1Value>(ctx, arg0_val, arg1_val, &st),
out->scalar().get());
}
return st;
}
Expand Down
Loading

0 comments on commit 01f3338

Please sign in to comment.