Skip to content

Commit

Permalink
ARROW-13072: [C++] Add bit-wise arithmetic kernels
Browse files Browse the repository at this point in the history
Closes #10530 from lidavidm/arrow-13072

Lead-authored-by: David Li <[email protected]>
Co-authored-by: Antoine Pitrou <[email protected]>
Signed-off-by: Antoine Pitrou <[email protected]>
  • Loading branch information
lidavidm and pitrou committed Jun 30, 2021
1 parent 58b3109 commit e9fa304
Show file tree
Hide file tree
Showing 6 changed files with 551 additions and 18 deletions.
2 changes: 2 additions & 0 deletions cpp/src/arrow/compute/api_scalar.cc
Original file line number Diff line number Diff line change
Expand Up @@ -62,6 +62,8 @@ SCALAR_ARITHMETIC_BINARY(Subtract, "subtract", "subtract_checked")
SCALAR_ARITHMETIC_BINARY(Multiply, "multiply", "multiply_checked")
SCALAR_ARITHMETIC_BINARY(Divide, "divide", "divide_checked")
SCALAR_ARITHMETIC_BINARY(Power, "power", "power_checked")
// Shift wrappers: each invocation names the public C++ entry point followed by
// two registry function names ("shift_*" and "shift_*_checked").
// NOTE(review): the macro presumably picks the "_checked" name based on
// ArithmeticOptions -- its definition is not visible in this hunk; confirm.
SCALAR_ARITHMETIC_BINARY(ShiftLeft, "shift_left", "shift_left_checked")
SCALAR_ARITHMETIC_BINARY(ShiftRight, "shift_right", "shift_right_checked")

Result<Datum> MaxElementWise(const std::vector<Datum>& args,
ElementWiseAggregateOptions options, ExecContext* ctx) {
Expand Down
27 changes: 27 additions & 0 deletions cpp/src/arrow/compute/api_scalar.h
Original file line number Diff line number Diff line change
Expand Up @@ -308,6 +308,33 @@ Result<Datum> Power(const Datum& left, const Datum& right,
ArithmeticOptions options = ArithmeticOptions(),
ExecContext* ctx = NULLPTR);

/// \brief Left shift the left array by the right array. Array values must be the
/// same length. If either operand is null, the result will be null.
///
/// The shift operates as if on the two's complement representation of the value.
/// A shift amount is invalid when it is negative or greater than or equal to the
/// precision of the left operand's type: the "shift_left" kernel then returns the
/// left value unchanged, while "shift_left_checked" returns an error.
///
/// \param[in] left the value to shift
/// \param[in] right the value to shift by
/// \param[in] options arithmetic options (enable/disable overflow checking), optional.
/// NOTE(review): for shifts this presumably selects "shift_left_checked" (invalid
/// shift amount is an error) over "shift_left" -- confirm against the
/// SCALAR_ARITHMETIC_BINARY macro
/// \param[in] ctx the function execution context, optional
/// \return the elementwise left value shifted left by the right value
ARROW_EXPORT
Result<Datum> ShiftLeft(const Datum& left, const Datum& right,
ArithmeticOptions options = ArithmeticOptions(),
ExecContext* ctx = NULLPTR);

/// \brief Right shift the left array by the right array. Array values must be the
/// same length. If either operand is null, the result will be null. Performs a
/// logical shift for unsigned values, and an arithmetic shift for signed values.
///
/// A shift amount is invalid when it is negative or greater than or equal to the
/// precision of the left operand's type: the "shift_right" kernel then returns the
/// left value unchanged, while "shift_right_checked" returns an error.
///
/// \param[in] left the value to shift
/// \param[in] right the value to shift by
/// \param[in] options arithmetic options (enable/disable overflow checking), optional.
/// NOTE(review): for shifts this presumably selects "shift_right_checked" (invalid
/// shift amount is an error) over "shift_right" -- confirm against the
/// SCALAR_ARITHMETIC_BINARY macro
/// \param[in] ctx the function execution context, optional
/// \return the elementwise left value shifted right by the right value
ARROW_EXPORT
Result<Datum> ShiftRight(const Datum& left, const Datum& right,
ArithmeticOptions options = ArithmeticOptions(),
ExecContext* ctx = NULLPTR);

/// \brief Find the element-wise maximum of any number of arrays or scalars.
/// Array values must be the same length.
///
Expand Down
259 changes: 259 additions & 0 deletions cpp/src/arrow/compute/kernels/scalar_arithmetic.cc
Original file line number Diff line number Diff line change
Expand Up @@ -454,6 +454,106 @@ struct PowerChecked {
}
};

// Bitwise operations

// Unary functor producing the one's complement of its operand (every bit flipped).
// The KernelContext and Status arguments are accepted but not used.
struct BitWiseNot {
  template <typename T, typename Arg>
  static T Call(KernelContext*, Arg arg, Status*) {
    T inverted = ~arg;
    return inverted;
  }
};

// Binary functor producing the bit-wise AND of its two operands.
// The KernelContext and Status arguments are accepted but not used.
struct BitWiseAnd {
  template <typename T, typename Arg0, typename Arg1>
  static T Call(KernelContext*, Arg0 lhs, Arg1 rhs, Status*) {
    T conjunction = lhs & rhs;
    return conjunction;
  }
};

// Binary functor producing the bit-wise (inclusive) OR of its two operands.
// The KernelContext and Status arguments are accepted but not used.
struct BitWiseOr {
  template <typename T, typename Arg0, typename Arg1>
  static T Call(KernelContext*, Arg0 lhs, Arg1 rhs, Status*) {
    T disjunction = lhs | rhs;
    return disjunction;
  }
};

// Binary functor producing the bit-wise exclusive OR of its two operands.
// The KernelContext and Status arguments are accepted but not used.
struct BitWiseXor {
  template <typename T, typename Arg0, typename Arg1>
  static T Call(KernelContext*, Arg0 lhs, Arg1 rhs, Status*) {
    T toggled = lhs ^ rhs;
    return toggled;
  }
};

struct ShiftLeft {
template <typename T, typename Arg0, typename Arg1>
static T Call(KernelContext*, Arg0 lhs, Arg1 rhs, Status*) {
using Unsigned = typename std::make_unsigned<Arg0>::type;
static_assert(std::is_same<T, Arg0>::value, "");
if (ARROW_PREDICT_FALSE(rhs < 0 || rhs >= std::numeric_limits<Arg0>::digits)) {
return lhs;
}
return static_cast<T>(static_cast<Unsigned>(lhs) << static_cast<Unsigned>(rhs));
}
};

// Checked left shift functor; see SEI CERT C Coding Standard rule INT34-C.
//
// Both overloads reject a shift amount that is negative or not smaller than the
// number of value bits of Arg0: `*st` is set to Status::Invalid and `lhs` is
// returned unchanged.
struct ShiftLeftChecked {
  // Overload selected for unsigned integer T: shifts the value directly.
  template <typename T, typename Arg0, typename Arg1>
  static enable_if_unsigned_integer<T> Call(KernelContext*, Arg0 lhs, Arg1 rhs,
                                            Status* st) {
    static_assert(std::is_same<T, Arg0>::value, "");
    const bool shift_is_valid = rhs >= 0 && rhs < std::numeric_limits<Arg0>::digits;
    if (ARROW_PREDICT_FALSE(!shift_is_valid)) {
      *st = Status::Invalid("shift amount must be >= 0 and less than precision of type");
      return lhs;
    }
    return lhs << rhs;
  }

  // Overload selected for signed integer T.
  //
  // Left-shifting a negative number is undefined in C/C++ (C++11 standard 5.8.2),
  // so, like Java and others, the shift is defined here in terms of the two's
  // complement representation: convert to the unsigned counterpart, shift, and
  // convert back. This assumes a two's complement machine.
  template <typename T, typename Arg0, typename Arg1>
  static enable_if_signed_integer<T> Call(KernelContext*, Arg0 lhs, Arg1 rhs,
                                          Status* st) {
    static_assert(std::is_same<T, Arg0>::value, "");
    using UnsignedArg = typename std::make_unsigned<Arg0>::type;
    const bool shift_is_valid = rhs >= 0 && rhs < std::numeric_limits<Arg0>::digits;
    if (ARROW_PREDICT_FALSE(!shift_is_valid)) {
      *st = Status::Invalid("shift amount must be >= 0 and less than precision of type");
      return lhs;
    }
    const auto bits = static_cast<UnsignedArg>(lhs);
    const auto amount = static_cast<UnsignedArg>(rhs);
    return static_cast<T>(bits << amount);
  }
};

struct ShiftRight {
template <typename T, typename Arg0, typename Arg1>
static T Call(KernelContext*, Arg0 lhs, Arg1 rhs, Status*) {
static_assert(std::is_same<T, Arg0>::value, "");
// Logical right shift when Arg0 is unsigned
// Arithmetic otherwise (this is implementation-defined but GCC and MSVC document this
// as arithmetic right shift)
// https://gcc.gnu.org/onlinedocs/gcc/Integers-implementation.html#Integers-implementation
// https://docs.microsoft.com/en-us/cpp/cpp/left-shift-and-right-shift-operators-input-and-output?view=msvc-160
// Clang doesn't document their behavior.
if (ARROW_PREDICT_FALSE(rhs < 0 || rhs >= std::numeric_limits<Arg0>::digits)) {
return lhs;
}
return lhs >> rhs;
}
};

struct ShiftRightChecked {
template <typename T, typename Arg0, typename Arg1>
static T Call(KernelContext*, Arg0 lhs, Arg1 rhs, Status* st) {
static_assert(std::is_same<T, Arg0>::value, "");
if (ARROW_PREDICT_FALSE(rhs < 0 || rhs >= std::numeric_limits<Arg0>::digits)) {
*st = Status::Invalid("shift amount must be >= 0 and less than precision of type");
return lhs;
}
return lhs >> rhs;
}
};

// Generate a kernel given an arithmetic functor
template <template <typename... Args> class KernelGenerator, typename Op>
ArrayKernelExec ArithmeticExecFromOp(detail::GetTypeId get_id) {
Expand Down Expand Up @@ -485,6 +585,54 @@ ArrayKernelExec ArithmeticExecFromOp(detail::GetTypeId get_id) {
}
}

// Pick the exec function for a bit-wise functor `Op` based on a type id.
//
// The functor is assumed to treat all integer types of equal width identically,
// so signed and unsigned ids of the same width share the unsigned instantiation.
// Any other type id trips DCHECK(false) and yields ExecFail.
template <template <typename... Args> class KernelGenerator, typename Op>
ArrayKernelExec TypeAgnosticBitWiseExecFromOp(detail::GetTypeId get_id) {
  using Kernel8 = KernelGenerator<UInt8Type, UInt8Type, Op>;
  using Kernel16 = KernelGenerator<UInt16Type, UInt16Type, Op>;
  using Kernel32 = KernelGenerator<UInt32Type, UInt32Type, Op>;
  using Kernel64 = KernelGenerator<UInt64Type, UInt64Type, Op>;

  switch (get_id.id) {
    case Type::INT8:
    case Type::UINT8:
      return Kernel8::Exec;
    case Type::INT16:
    case Type::UINT16:
      return Kernel16::Exec;
    case Type::INT32:
    case Type::UINT32:
      return Kernel32::Exec;
    case Type::INT64:
    case Type::UINT64:
      return Kernel64::Exec;
    default:
      break;
  }
  // Not an integer type id.
  DCHECK(false);
  return ExecFail;
}

// Pick the exec function for a shift functor `Op` based on a type id.
//
// Each of the eight integer type ids gets its own instantiation, keeping signed
// and unsigned types distinct. Any other type id trips DCHECK(false) and yields
// ExecFail.
template <template <typename... Args> class KernelGenerator, typename Op>
ArrayKernelExec ShiftExecFromOp(detail::GetTypeId get_id) {
  switch (get_id.id) {
    // Signed integers
    case Type::INT8:
      return KernelGenerator<Int8Type, Int8Type, Op>::Exec;
    case Type::INT16:
      return KernelGenerator<Int16Type, Int16Type, Op>::Exec;
    case Type::INT32:
      return KernelGenerator<Int32Type, Int32Type, Op>::Exec;
    case Type::INT64:
      return KernelGenerator<Int64Type, Int64Type, Op>::Exec;
    // Unsigned integers
    case Type::UINT8:
      return KernelGenerator<UInt8Type, UInt8Type, Op>::Exec;
    case Type::UINT16:
      return KernelGenerator<UInt16Type, UInt16Type, Op>::Exec;
    case Type::UINT32:
      return KernelGenerator<UInt32Type, UInt32Type, Op>::Exec;
    case Type::UINT64:
      return KernelGenerator<UInt64Type, UInt64Type, Op>::Exec;
    default:
      break;
  }
  // Not an integer type id.
  DCHECK(false);
  return ExecFail;
}

Status CastBinaryDecimalArgs(const std::string& func_name,
std::vector<ValueDescr>* values) {
auto& left_type = (*values)[0].type;
Expand Down Expand Up @@ -734,6 +882,28 @@ std::shared_ptr<ScalarFunction> MakeUnarySignedArithmeticFunctionNotNull(
return func;
}

template <typename Op>
std::shared_ptr<ScalarFunction> MakeBitWiseFunctionNotNull(std::string name,
const FunctionDoc* doc) {
auto func = std::make_shared<ArithmeticFunction>(name, Arity::Binary(), doc);
for (const auto& ty : IntTypes()) {
auto exec = TypeAgnosticBitWiseExecFromOp<ScalarBinaryNotNullEqualTypes, Op>(ty);
DCHECK_OK(func->AddKernel({ty, ty}, ty, exec));
}
return func;
}

template <typename Op>
std::shared_ptr<ScalarFunction> MakeShiftFunctionNotNull(std::string name,
const FunctionDoc* doc) {
auto func = std::make_shared<ArithmeticFunction>(name, Arity::Binary(), doc);
for (const auto& ty : IntTypes()) {
auto exec = ShiftExecFromOp<ScalarBinaryNotNullEqualTypes, Op>(ty);
DCHECK_OK(func->AddKernel({ty, ty}, ty, exec));
}
return func;
}

const FunctionDoc absolute_value_doc{
"Calculate the absolute value of the argument element-wise",
("Results will wrap around on integer overflow.\n"
Expand Down Expand Up @@ -820,6 +990,57 @@ const FunctionDoc pow_checked_doc{
("An error is returned when integer to negative integer power is encountered,\n"
"or integer overflow is encountered."),
{"base", "exponent"}};

// User-facing documentation for the bit-wise functions. Each entry holds a
// one-line summary, a longer description, and the argument names.

// Documentation for the unary "bit_wise_not" function (single argument `x`).
const FunctionDoc bit_wise_not_doc{
"Bit-wise negate the arguments element-wise", ("Null values return null."), {"x"}};

// Documentation for the binary "bit_wise_and" function.
const FunctionDoc bit_wise_and_doc{
"Bit-wise AND the arguments element-wise", ("Null values return null."), {"x", "y"}};

// Documentation for the binary "bit_wise_or" function.
const FunctionDoc bit_wise_or_doc{
"Bit-wise OR the arguments element-wise", ("Null values return null."), {"x", "y"}};

// Documentation for the binary "bit_wise_xor" function.
const FunctionDoc bit_wise_xor_doc{
"Bit-wise XOR the arguments element-wise", ("Null values return null."), {"x", "y"}};

// Documentation for "shift_left": the unchecked variant, which hands back `x`
// unchanged when the shift amount `y` is invalid (see the ShiftLeft functor).
const FunctionDoc shift_left_doc{
"Left shift `x` by `y`",
("This function will return `x` if `y` (the amount to shift by) is: "
"(1) negative or (2) greater than or equal to the precision of `x`.\n"
"The shift operates as if on the two's complement representation of the number. "
"In other words, this is equivalent to multiplying `x` by 2 to the power `y`, "
"even if overflow occurs.\n"
"Use function \"shift_left_checked\" if you want an invalid shift amount to "
"return an error."),
{"x", "y"}};

// Documentation for "shift_left_checked": the variant that reports an error when
// the shift amount `y` is invalid (see the ShiftLeftChecked functor).
const FunctionDoc shift_left_checked_doc{
"Left shift `x` by `y` with invalid shift check",
("This function will raise an error if `y` (the amount to shift by) is: "
"(1) negative or (2) greater than or equal to the precision of `x`. "
"The shift operates as if on the two's complement representation of the number. "
"In other words, this is equivalent to multiplying `x` by 2 to the power `y`, "
"even if overflow occurs.\n"
"See \"shift_left\" for a variant that doesn't fail for an invalid shift amount."),
{"x", "y"}};

// Documentation for "shift_right": the unchecked variant, which hands back `x`
// unchanged when the shift amount `y` is invalid (see the ShiftRight functor).
const FunctionDoc shift_right_doc{
"Right shift `x` by `y`",
("Perform a logical shift for unsigned `x` and an arithmetic shift for signed `x`.\n"
"This function will return `x` if `y` (the amount to shift by) is: "
"(1) negative or (2) greater than or equal to the precision of `x`.\n"
"Use function \"shift_right_checked\" if you want an invalid shift amount to return "
"an error."),
{"x", "y"}};

// Documentation for "shift_right_checked": the variant that reports an error when
// the shift amount `y` is invalid (see the ShiftRightChecked functor).
const FunctionDoc shift_right_checked_doc{
"Right shift `x` by `y` with invalid shift check",
("Perform a logical shift for unsigned `x` and an arithmetic shift for signed `x`.\n"
"This function will raise an error if `y` (the amount to shift by) is: "
"(1) negative or (2) greater than or equal to the precision of `x`.\n"
"See \"shift_right\" for a variant that doesn't fail for an invalid shift amount"),
{"x", "y"}};

} // namespace

void RegisterScalarArithmetic(FunctionRegistry* registry) {
Expand Down Expand Up @@ -903,6 +1124,44 @@ void RegisterScalarArithmetic(FunctionRegistry* registry) {
auto power_checked =
MakeArithmeticFunctionNotNull<PowerChecked>("power_checked", &pow_checked_doc);
DCHECK_OK(registry->AddFunction(std::move(power_checked)));

// ----------------------------------------------------------------------
{
auto bit_wise_not = std::make_shared<ArithmeticFunction>(
"bit_wise_not", Arity::Unary(), &bit_wise_not_doc);
for (const auto& ty : IntTypes()) {
auto exec = TypeAgnosticBitWiseExecFromOp<ScalarUnaryNotNull, BitWiseNot>(ty);
DCHECK_OK(bit_wise_not->AddKernel({ty}, ty, exec));
}
DCHECK_OK(registry->AddFunction(std::move(bit_wise_not)));
}

auto bit_wise_and =
MakeBitWiseFunctionNotNull<BitWiseAnd>("bit_wise_and", &bit_wise_and_doc);
DCHECK_OK(registry->AddFunction(std::move(bit_wise_and)));

auto bit_wise_or =
MakeBitWiseFunctionNotNull<BitWiseOr>("bit_wise_or", &bit_wise_or_doc);
DCHECK_OK(registry->AddFunction(std::move(bit_wise_or)));

auto bit_wise_xor =
MakeBitWiseFunctionNotNull<BitWiseXor>("bit_wise_xor", &bit_wise_xor_doc);
DCHECK_OK(registry->AddFunction(std::move(bit_wise_xor)));

auto shift_left = MakeShiftFunctionNotNull<ShiftLeft>("shift_left", &shift_left_doc);
DCHECK_OK(registry->AddFunction(std::move(shift_left)));

auto shift_left_checked = MakeShiftFunctionNotNull<ShiftLeftChecked>(
"shift_left_checked", &shift_left_checked_doc);
DCHECK_OK(registry->AddFunction(std::move(shift_left_checked)));

auto shift_right =
MakeShiftFunctionNotNull<ShiftRight>("shift_right", &shift_right_doc);
DCHECK_OK(registry->AddFunction(std::move(shift_right)));

auto shift_right_checked = MakeShiftFunctionNotNull<ShiftRightChecked>(
"shift_right_checked", &shift_right_checked_doc);
DCHECK_OK(registry->AddFunction(std::move(shift_right_checked)));
}

} // namespace internal
Expand Down
Loading

0 comments on commit e9fa304

Please sign in to comment.