Skip to content

Commit

Permalink
Merge branch 'ARROW-11950-Compute-Add-unary-negative-kernel'
Browse files Browse the repository at this point in the history
  • Loading branch information
edponce committed Apr 13, 2021
2 parents 00a4436 + 26b33ec commit b4806a0
Show file tree
Hide file tree
Showing 4 changed files with 163 additions and 6 deletions.
4 changes: 4 additions & 0 deletions cpp/src/arrow/compute/kernels/codegen_internal.h
Original file line number Diff line number Diff line change
Expand Up @@ -739,6 +739,10 @@ struct ScalarUnaryNotNull {
}
};

// A kernel exec generator for unary kernels.
// This alias forwards its three template arguments unchanged to ScalarUnary,
// so ScalarUnaryType<Out, Arg, Op> names the same type as
// ScalarUnary<Out, Arg, Op>.
template <typename OutType, typename ArgType, typename Op>
using ScalarUnaryType = ScalarUnary<OutType, ArgType, Op>;

// A kernel exec generator for binary functions that addresses both array and
// scalar inputs and dispatches input iteration and output writing to other
// templates
Expand Down
78 changes: 76 additions & 2 deletions cpp/src/arrow/compute/kernels/scalar_arithmetic.cc
Original file line number Diff line number Diff line change
Expand Up @@ -31,6 +31,8 @@ namespace internal {

using applicator::ScalarBinaryEqualTypes;
using applicator::ScalarBinaryNotNullEqualTypes;
// using applicator::ScalarUnary;
using applicator::ScalarUnaryType;

namespace {

Expand Down Expand Up @@ -233,6 +235,43 @@ struct DivideChecked {
}
};

// Element-wise arithmetic negation that never reports overflow.
//
// `T` is the output value type and `Arg0` the input value type. The
// KernelContext argument is unused by both overloads.
struct Negate {
  // Floating point: unary minus flips the sign of the value.
  // NOTE [EPM]: Discuss on 0 vs. -0.
  template <typename T, typename Arg0>
  static constexpr enable_if_floating_point<T> Call(KernelContext*, Arg0 arg) {
    return -arg;
  }

  // Integers (signed and unsigned): negation with wrap-around.
  //
  // A plain `-arg` is signed overflow, i.e. undefined behaviour, when `arg`
  // is the minimum of a signed type that is not promoted to a wider `int`
  // (typically int32_t and int64_t). The subtraction is therefore done in the
  // unsigned counterpart of `T`, where wrap-around is well defined, and the
  // result is converted back to `T`. The minimum signed value maps to itself,
  // and an unsigned `x` maps to `2^N - x`, matching what `-arg` produced
  // wherever the original expression was well defined.
  //
  // NOTE [EPM]: How to handle unsigned integers?
  //  * Promote to signed?
  //  * State that unsigned numbers are not supported (i.e., undefined behavior)?
  //  * Use C++ integral conversions (e.g., Negate(-128) = -128)?
  //  * https://timsong-cpp.github.io/cppwp/n4659/conv.integral
  template <typename T, typename Arg0>
  static constexpr enable_if_integer<T> Call(KernelContext*, Arg0 arg) {
    return static_cast<T>(0 - static_cast<typename std::make_unsigned<T>::type>(arg));
  }
};

// Overflow-reporting counterpart of Negate.
//
// Both overloads require the output value type `T` and the input value type
// `Arg0` to be identical (enforced with static_assert).
struct NegateChecked {
  // Floating point: plain sign flip; no overflow check is performed.
  template <typename T, typename Arg0>
  static enable_if_floating_point<T> Call(KernelContext*, Arg0 arg) {
    static_assert(std::is_same<T, Arg0>::value, "");
    return -arg;
  }

  // Integers: evaluates `0 - arg` through SubtractWithOverflow. A true return
  // from that helper is treated as overflow and recorded on `ctx` as
  // Status::Invalid("overflow"); whatever the helper stored in the output is
  // returned either way.
  // NOTE [EPM]: Check this edge case of overflow. What are we trying to check
  // here? (Presumably the signed minimum and any non-zero unsigned input --
  // to be confirmed.)
  template <typename T, typename Arg0>
  static enable_if_integer<T> Call(KernelContext* ctx, Arg0 arg) {
    static_assert(std::is_same<T, Arg0>::value, "");
    T negated = 0;
    const bool overflowed = SubtractWithOverflow(0, arg, &negated);
    if (ARROW_PREDICT_FALSE(overflowed)) {
      ctx->SetStatus(Status::Invalid("overflow"));
    }
    return negated;
  }
};

// Generate a kernel given an arithmetic functor
template <template <typename... Args> class KernelGenerator, typename Op>
ArrayKernelExec NumericEqualTypesBinary(detail::GetTypeId get_id) {
Expand Down Expand Up @@ -309,6 +348,21 @@ std::shared_ptr<ScalarFunction> MakeArithmeticFunctionNotNull(std::string name,
return func;
}

// Builds a unary ScalarFunction called `name`, documented by `doc`, whose
// kernels apply the functor `Op`.
//
// One kernel is added for every type returned by NumericTypes(); each kernel
// takes a single input of that type and produces an output of the same type.
// The exec is obtained from NumericEqualTypesBinary instantiated with
// ScalarUnaryType as the kernel generator.
template <typename Op>
std::shared_ptr<ScalarFunction> MakeScalarArithmeticFunction(std::string name,
                                                             const FunctionDoc* doc) {
  auto function = std::make_shared<ScalarFunction>(name, Arity::Unary(), doc);

  for (const auto& type : NumericTypes()) {
    auto kernel_exec = NumericEqualTypesBinary<ScalarUnaryType, Op>(type);
    DCHECK_OK(function->AddKernel({type}, type, kernel_exec));
  }

  return function;
}

const FunctionDoc add_doc{"Add the arguments element-wise",
("Results will wrap around on integer overflow.\n"
"Use function \"add_checked\" if you want overflow\n"
Expand All @@ -321,14 +375,14 @@ const FunctionDoc add_checked_doc{
"doesn't fail on overflow, use function \"add\"."),
{"x", "y"}};

const FunctionDoc sub_doc{"Substract the arguments element-wise",
const FunctionDoc sub_doc{"Subtract the arguments element-wise",
("Results will wrap around on integer overflow.\n"
"Use function \"subtract_checked\" if you want overflow\n"
"to return an error."),
{"x", "y"}};

const FunctionDoc sub_checked_doc{
"Substract the arguments element-wise",
"Subtract the arguments element-wise",
("This function returns an error on overflow. For a variant that\n"
"doesn't fail on overflow, use function \"subtract\"."),
{"x", "y"}};
Expand Down Expand Up @@ -359,6 +413,17 @@ const FunctionDoc div_checked_doc{
"integer overflow is encountered."),
{"dividend", "divisor"}};

// Documentation for the unary "negate" function: summary, description and
// the name of its single argument.
const FunctionDoc negate_doc{"Negate the argument element-wise",
("Results will wrap around on integer overflow.\n"
"Use function \"negate_checked\" if you want overflow\n"
"to return an error."),
{"x"}};

// Documentation for the unary "negate_checked" function: summary,
// description and the name of its single argument. The summary uses the
// singular "argument" because the function takes exactly one input.
const FunctionDoc negate_checked_doc{
    "Negate the argument element-wise",
    ("This function returns an error on overflow. For a variant that\n"
     "doesn't fail on overflow, use function \"negate\"."),
    {"x"}};
} // namespace

void RegisterScalarArithmetic(FunctionRegistry* registry) {
Expand Down Expand Up @@ -407,6 +472,15 @@ void RegisterScalarArithmetic(FunctionRegistry* registry) {
auto divide_checked =
MakeArithmeticFunctionNotNull<DivideChecked>("divide_checked", &div_checked_doc);
DCHECK_OK(registry->AddFunction(std::move(divide_checked)));

// ----------------------------------------------------------------------
// "negate": unary function built from the Negate functor and documented by
// negate_doc, then added to the registry.
auto negate = MakeScalarArithmeticFunction<Negate>("negate", &negate_doc);
DCHECK_OK(registry->AddFunction(std::move(negate)));

// ----------------------------------------------------------------------
// "negate_checked": unary function built from the NegateChecked functor and
// documented by negate_checked_doc, then added to the registry.
auto negate_checked = MakeScalarArithmeticFunction<NegateChecked>(
"negate_checked", &negate_checked_doc);
DCHECK_OK(registry->AddFunction(std::move(negate_checked)));
}

} // namespace internal
Expand Down
71 changes: 71 additions & 0 deletions cpp/src/arrow/compute/kernels/scalar_arithmetic_test.cc
Original file line number Diff line number Diff line change
Expand Up @@ -709,5 +709,76 @@ TEST(TestBinaryArithmetic, AddWithImplicitCastsUint64EdgeCase) {
ArrayFromJSON(uint64(), "[18446744073709551615]")}));
}

// Checks "negate" on arrays of every signed integer type.
TEST(TestUnaryArithmeticSigned, Negate) {
  // Each case pairs a JSON input array with the JSON array "negate" must yield.
  struct Case {
    const char* input;
    const char* expected;
  };
  static const Case kCases[] = {
      {"[]", "[]"},                          // no input
      {"[null]", "[null]"},                  // null input
      {"[0, 0, -0]", "[0, -0, 0]"},          // zeros as inputs
      {"[1, 10, 100]", "[-1, -10, -100]"},   // positive inputs
      {"[-1, -10, -100]", "[1, 10, 100]"},   // negative inputs
  };

  for (const auto& ty : internal::SignedIntTypes()) {
    for (const auto& test_case : kCases) {
      CheckScalarUnary("negate", ArrayFromJSON(ty, test_case.input),
                       ArrayFromJSON(ty, test_case.expected));
    }
  }
}

// Checks "negate" on the extreme values of every signed integer type, using
// scalar inputs.
TEST(TestUnaryArithmeticSignedMinMax, Negate) {
  // NOTE [EPM]: Can these tests be done by iterating types?

  // Min input
  // The negated minimum is out of bounds for the type; the unchecked kernel is
  // expected to wrap around (two's complement), returning the minimum again.
  auto int8_min = std::numeric_limits<int8_t>::min();
  CheckScalarUnary("negate", MakeScalar(int8_min), MakeScalar(int8_min));
  auto int16_min = std::numeric_limits<int16_t>::min();
  CheckScalarUnary("negate", MakeScalar(int16_min), MakeScalar(int16_min));
  auto int32_min = std::numeric_limits<int32_t>::min();
  CheckScalarUnary("negate", MakeScalar(int32_min), MakeScalar(int32_min));
  auto int64_min = std::numeric_limits<int64_t>::min();
  CheckScalarUnary("negate", MakeScalar(int64_min), MakeScalar(int64_min));

  // Max input
  // For the 8- and 16-bit types, `-value` is computed after integral promotion
  // and therefore has type `int`; without a cast the expected scalar would be
  // built as int32 and not match the kernel's output type. Cast back to the
  // narrow type so the expected scalar has the same type as the input.
  auto int8_max = std::numeric_limits<int8_t>::max();
  CheckScalarUnary("negate", MakeScalar(int8_max),
                   MakeScalar(static_cast<int8_t>(-int8_max)));
  auto int16_max = std::numeric_limits<int16_t>::max();
  CheckScalarUnary("negate", MakeScalar(int16_max),
                   MakeScalar(static_cast<int16_t>(-int16_max)));
  auto int32_max = std::numeric_limits<int32_t>::max();
  CheckScalarUnary("negate", MakeScalar(int32_max), MakeScalar(-int32_max));
  auto int64_max = std::numeric_limits<int64_t>::max();
  CheckScalarUnary("negate", MakeScalar(int64_max), MakeScalar(-int64_max));
}

// Checks "negate" on arrays of every unsigned integer type.
TEST(TestUnaryArithmeticUnsigned, Negate) {
  // Each case pairs a JSON input array with the JSON array "negate" must yield.
  struct Case {
    const char* input;
    const char* expected;
  };
  static const Case kCases[] = {
      {"[]", "[]"},          // no input
      {"[null]", "[null]"},  // null input
      {"[0]", "[0]"},        // zeros as inputs
  };

  for (const auto& ty : internal::UnsignedIntTypes()) {
    for (const auto& test_case : kCases) {
      CheckScalarUnary("negate", ArrayFromJSON(ty, test_case.input),
                       ArrayFromJSON(ty, test_case.expected));
    }
  }
  // TODO: non-zero inputs are not covered. The cases drafted for them
  // ("[1, 10, 100]" <-> "[-1, -10, -100]") contain negative values and were
  // left disabled.
}

// Checks "negate" on arrays of every floating-point type.
TEST(TestUnaryArithmeticFloating, Negate) {
  // Each case pairs a JSON input array with the JSON array "negate" must yield.
  struct Case {
    const char* input;
    const char* expected;
  };
  static const Case kCases[] = {
      {"[]", "[]"},                                               // no input
      {"[null]", "[null]"},                                       // null input
      {"[0.0, 0.0, -0.0]", "[0.0, -0.0, 0.0]"},                   // zeros as inputs
      {"[1.3, 10.80, 12748.001]", "[-1.3, -10.80, -12748.001]"},  // positive inputs
      {"[-1.3, -10.80, -12748.001]", "[1.3, 10.80, 12748.001]"},  // negative inputs
  };

  for (const auto& ty : internal::FloatingPointTypes()) {
    for (const auto& test_case : kCases) {
      CheckScalarUnary("negate", ArrayFromJSON(ty, test_case.input),
                       ArrayFromJSON(ty, test_case.expected));
    }
  }
}

} // namespace compute
} // namespace arrow
16 changes: 12 additions & 4 deletions docs/source/cpp/compute.rst
Original file line number Diff line number Diff line change
Expand Up @@ -250,17 +250,21 @@ Binary functions have the following semantics (which is sometimes called
Arithmetic functions
~~~~~~~~~~~~~~~~~~~~

These functions expect two inputs of numeric type and apply a given binary
operation to each pair of elements gathered from the inputs. If any of the
input elements in a pair is null, the corresponding output element is null.
Inputs will be cast to the :ref:`common numeric type <common-numeric-type>`
These functions expect one or two inputs of numeric type and apply a given
arithmetic operation to each element (or pair of elements) gathered from the
input(s). If any of the input elements is null, the corresponding output
element is null.
Input(s) will be cast to the :ref:`common numeric type <common-numeric-type>`
(and dictionary decoded, if applicable) before the operation is applied.

The default variant of these functions does not detect overflow (the result
then typically wraps around). Each function is also available in an
overflow-checking variant, suffixed ``_checked``, which returns
an ``Invalid`` :class:`Status` when overflow is detected.

For a unary operation, should unsigned integer types be promoted as if in a
binary operation with ``int8``? This would at least ensure the narrowest possible
signed integer output.

+--------------------------+------------+--------------------+---------------------+
| Function name | Arity | Input types | Output type |
+==========================+============+====================+=====================+
Expand All @@ -280,6 +284,10 @@ an ``Invalid`` :class:`Status` when overflow is detected.
+--------------------------+------------+--------------------+---------------------+
| subtract_checked | Binary | Numeric | Numeric |
+--------------------------+------------+--------------------+---------------------+
| negate | Unary | Numeric | Numeric |
+--------------------------+------------+--------------------+---------------------+
| negate_checked | Unary | Numeric | Numeric |
+--------------------------+------------+--------------------+---------------------+

Comparisons
~~~~~~~~~~~
Expand Down

0 comments on commit b4806a0

Please sign in to comment.