Skip to content

Commit

Permalink
ARROW-13096: [C++] Implement logarithm compute functions
Browse files Browse the repository at this point in the history
  • Loading branch information
lidavidm committed Jun 21, 2021
1 parent e990d17 commit d67c715
Show file tree
Hide file tree
Showing 6 changed files with 374 additions and 4 deletions.
3 changes: 3 additions & 0 deletions cpp/src/arrow/compute/api_scalar.cc
Original file line number Diff line number Diff line change
Expand Up @@ -49,6 +49,9 @@ namespace compute {

// Public unary arithmetic entry points. Each invocation names the C++ function
// followed by two registry names (the plain function and its "_checked" variant).
// NOTE(review): SCALAR_ARITHMETIC_UNARY is defined outside this excerpt; presumably
// it picks one of the two registry names based on the ArithmeticOptions -- confirm.
SCALAR_ARITHMETIC_UNARY(AbsoluteValue, "abs", "abs_checked")
SCALAR_ARITHMETIC_UNARY(Negate, "negate", "negate_checked")
// Logarithms: natural, base 10 and base 2.
SCALAR_ARITHMETIC_UNARY(Ln, "ln", "ln_checked")
SCALAR_ARITHMETIC_UNARY(Log10, "log10", "log10_checked")
SCALAR_ARITHMETIC_UNARY(Log2, "log2", "log2_checked")

#define SCALAR_ARITHMETIC_BINARY(NAME, REGISTRY_NAME, REGISTRY_CHECKED_NAME) \
Result<Datum> NAME(const Datum& left, const Datum& right, ArithmeticOptions options, \
Expand Down
33 changes: 33 additions & 0 deletions cpp/src/arrow/compute/api_scalar.h
Original file line number Diff line number Diff line change
Expand Up @@ -298,6 +298,39 @@ Result<Datum> Power(const Datum& left, const Datum& right,
ArithmeticOptions options = ArithmeticOptions(),
ExecContext* ctx = NULLPTR);

/// \brief Get the natural log of a value. Array values can be of arbitrary
/// length. If argument is null the result will be null.
///
/// The underlying compute functions are registered as "ln" and "ln_checked".
/// NOTE(review): `options` presumably selects between those two variants
/// (-inf/NaN results versus an error for non-positive input) -- confirm
/// against the SCALAR_ARITHMETIC_UNARY definition.
///
/// \param[in] arg the value transformed
/// \param[in] options arithmetic options (overflow handling), optional
/// \param[in] ctx the function execution context, optional
/// \return the elementwise natural log
ARROW_EXPORT
Result<Datum> Ln(const Datum& arg, ArithmeticOptions options = ArithmeticOptions(),
ExecContext* ctx = NULLPTR);

/// \brief Get the log base 10 of a value. Array values can be of arbitrary
/// length. If argument is null the result will be null.
///
/// The underlying compute functions are registered as "log10" and
/// "log10_checked".
/// NOTE(review): `options` presumably selects between those two variants
/// (-inf/NaN results versus an error for non-positive input) -- confirm
/// against the SCALAR_ARITHMETIC_UNARY definition.
///
/// \param[in] arg the value transformed
/// \param[in] options arithmetic options (overflow handling), optional
/// \param[in] ctx the function execution context, optional
/// \return the elementwise log base 10
ARROW_EXPORT
Result<Datum> Log10(const Datum& arg, ArithmeticOptions options = ArithmeticOptions(),
ExecContext* ctx = NULLPTR);

/// \brief Get the log base 2 of a value. Array values can be of arbitrary
/// length. If argument is null the result will be null.
///
/// The underlying compute functions are registered as "log2" and
/// "log2_checked".
/// NOTE(review): `options` presumably selects between those two variants
/// (-inf/NaN results versus an error for non-positive input) -- confirm
/// against the SCALAR_ARITHMETIC_UNARY definition.
///
/// \param[in] arg the value transformed
/// \param[in] options arithmetic options (overflow handling), optional
/// \param[in] ctx the function execution context, optional
/// \return the elementwise log base 2
ARROW_EXPORT
Result<Datum> Log2(const Datum& arg, ArithmeticOptions options = ArithmeticOptions(),
ExecContext* ctx = NULLPTR);

/// \brief Find the element-wise maximum of any number of arrays or scalars.
/// Array values must be the same length.
///
Expand Down
268 changes: 264 additions & 4 deletions cpp/src/arrow/compute/kernels/scalar_arithmetic.cc
Original file line number Diff line number Diff line change
Expand Up @@ -58,12 +58,12 @@ using enable_if_signed_integer = enable_if_t<is_signed_integer<T>::value, T>;
template <typename T>
using enable_if_unsigned_integer = enable_if_t<is_unsigned_integer<T>::value, T>;

template <typename T>
template <typename T, typename R = T>
using enable_if_integer =
enable_if_t<is_signed_integer<T>::value || is_unsigned_integer<T>::value, T>;
enable_if_t<is_signed_integer<T>::value || is_unsigned_integer<T>::value, R>;

template <typename T>
using enable_if_floating_point = enable_if_t<std::is_floating_point<T>::value, T>;
template <typename T, typename R = T>
using enable_if_floating_point = enable_if_t<std::is_floating_point<T>::value, R>;

template <typename T>
using enable_if_decimal =
Expand Down Expand Up @@ -454,6 +454,165 @@ struct PowerChecked {
}
};

// Natural logarithm that never reports an error: zero maps to -infinity and
// negative input maps to NaN, matching IEEE754 log without raising a floating
// point exception.
struct LogNatural {
  // Integer input; the result type T must be double.
  template <typename T, typename Arg>
  static enable_if_integer<Arg, T> Call(KernelContext*, Arg arg, Status*) {
    static_assert(std::is_same<T, double>::value, "");
    // The two edge cases are mutually exclusive, so they are handled as guards.
    if (arg < 0) return std::numeric_limits<T>::quiet_NaN();
    if (arg == 0) return -std::numeric_limits<T>::infinity();
    return std::log(arg);
  }

  // Floating point input; the result has the same type as the argument.
  template <typename T, typename Arg>
  static enable_if_floating_point<Arg, T> Call(KernelContext*, Arg arg, Status*) {
    static_assert(std::is_same<T, Arg>::value, "");
    if (arg < 0) return std::numeric_limits<T>::quiet_NaN();
    if (arg == 0) return -std::numeric_limits<T>::infinity();
    return std::log(arg);
  }
};

// Natural logarithm that rejects non-positive input: the failure is stored in
// *st as Status::Invalid and the argument itself (converted to T) is returned.
struct LogNaturalChecked {
  // Integer input; the result type T must be double.
  template <typename T, typename Arg>
  static enable_if_integer<Arg, T> Call(KernelContext*, Arg arg, Status* st) {
    static_assert(std::is_same<T, double>::value, "");
    if (arg < 0) {
      *st = Status::Invalid("domain error");
      return arg;
    }
    if (arg == 0) {
      *st = Status::Invalid("divide by zero");
      return arg;
    }
    return std::log(arg);
  }

  // Floating point input; the result has the same type as the argument.
  template <typename T, typename Arg>
  static enable_if_floating_point<Arg, T> Call(KernelContext*, Arg arg, Status* st) {
    static_assert(std::is_same<T, Arg>::value, "");
    if (arg < 0) {
      *st = Status::Invalid("domain error");
      return arg;
    }
    if (arg == 0) {
      *st = Status::Invalid("divide by zero");
      return arg;
    }
    return std::log(arg);
  }
};

// Base-10 logarithm that never reports an error: zero maps to -infinity and
// negative input maps to NaN, matching IEEE754 log without raising a floating
// point exception.
struct Log10 {
  // Integer input; the result type T must be double.
  template <typename T, typename Arg>
  static enable_if_integer<Arg, T> Call(KernelContext*, Arg arg, Status*) {
    static_assert(std::is_same<T, double>::value, "");
    // The two edge cases are mutually exclusive, so they are handled as guards.
    if (arg < 0) return std::numeric_limits<T>::quiet_NaN();
    if (arg == 0) return -std::numeric_limits<T>::infinity();
    return std::log10(arg);
  }

  // Floating point input; the result has the same type as the argument.
  template <typename T, typename Arg>
  static enable_if_floating_point<Arg, T> Call(KernelContext*, Arg arg, Status*) {
    static_assert(std::is_same<T, Arg>::value, "");
    if (arg < 0) return std::numeric_limits<T>::quiet_NaN();
    if (arg == 0) return -std::numeric_limits<T>::infinity();
    return std::log10(arg);
  }
};

// Base-10 logarithm that rejects non-positive input: the failure is stored in
// *st as Status::Invalid and the argument itself (converted to T) is returned.
struct Log10Checked {
  // Integer input; the result type T must be double.
  template <typename T, typename Arg>
  static enable_if_integer<Arg, T> Call(KernelContext*, Arg arg, Status* st) {
    static_assert(std::is_same<T, double>::value, "");
    if (arg < 0) {
      *st = Status::Invalid("domain error");
      return arg;
    }
    if (arg == 0) {
      *st = Status::Invalid("divide by zero");
      return arg;
    }
    return std::log10(arg);
  }

  // Floating point input; the result has the same type as the argument.
  template <typename T, typename Arg>
  static enable_if_floating_point<Arg, T> Call(KernelContext*, Arg arg, Status* st) {
    static_assert(std::is_same<T, Arg>::value, "");
    if (arg < 0) {
      *st = Status::Invalid("domain error");
      return arg;
    }
    if (arg == 0) {
      *st = Status::Invalid("divide by zero");
      return arg;
    }
    return std::log10(arg);
  }
};

// Base-2 logarithm that never reports an error: zero maps to -infinity and
// negative input maps to NaN, matching IEEE754 log without raising a floating
// point exception.
struct Log2 {
  // Integer input; the result type T must be double.
  template <typename T, typename Arg>
  static enable_if_integer<Arg, T> Call(KernelContext*, Arg arg, Status*) {
    static_assert(std::is_same<T, double>::value, "");
    // The two edge cases are mutually exclusive, so they are handled as guards.
    if (arg < 0) return std::numeric_limits<T>::quiet_NaN();
    if (arg == 0) return -std::numeric_limits<T>::infinity();
    return std::log2(arg);
  }

  // Floating point input; the result has the same type as the argument.
  template <typename T, typename Arg>
  static enable_if_floating_point<Arg, T> Call(KernelContext*, Arg arg, Status*) {
    static_assert(std::is_same<T, Arg>::value, "");
    if (arg < 0) return std::numeric_limits<T>::quiet_NaN();
    if (arg == 0) return -std::numeric_limits<T>::infinity();
    return std::log2(arg);
  }
};

// Base-2 logarithm that rejects non-positive input: the failure is stored in
// *st as Status::Invalid and the argument itself (converted to T) is returned.
struct Log2Checked {
  // Integer input; the result type T must be double.
  template <typename T, typename Arg>
  static enable_if_integer<Arg, T> Call(KernelContext*, Arg arg, Status* st) {
    static_assert(std::is_same<T, double>::value, "");
    if (arg < 0) {
      *st = Status::Invalid("domain error");
      return arg;
    }
    if (arg == 0) {
      *st = Status::Invalid("divide by zero");
      return arg;
    }
    return std::log2(arg);
  }

  // Floating point input; the result has the same type as the argument.
  template <typename T, typename Arg>
  static enable_if_floating_point<Arg, T> Call(KernelContext*, Arg arg, Status* st) {
    static_assert(std::is_same<T, Arg>::value, "");
    if (arg < 0) {
      *st = Status::Invalid("domain error");
      return arg;
    }
    if (arg == 0) {
      *st = Status::Invalid("divide by zero");
      return arg;
    }
    return std::log2(arg);
  }
};

// Generate a kernel given an arithmetic functor
template <template <typename... Args> class KernelGenerator, typename Op>
ArrayKernelExec ArithmeticExecFromOp(detail::GetTypeId get_id) {
Expand Down Expand Up @@ -485,6 +644,37 @@ ArrayKernelExec ArithmeticExecFromOp(detail::GetTypeId get_id) {
}
}

// Pick the exec function for kernels that always return floating results.
//
// Every integer type id is instantiated with DoubleType as the first type
// argument of KernelGenerator; FLOAT and DOUBLE use their own type for both.
// Any other type id trips a DCHECK and yields ExecFail.
template <template <typename... Args> class KernelGenerator, typename Op>
ArrayKernelExec IntToDoubleExecFromOp(detail::GetTypeId get_id) {
  switch (get_id.id) {
    // Signed integers (TIMESTAMP shares the Int64Type instantiation)
    case Type::INT8:
      return KernelGenerator<DoubleType, Int8Type, Op>::Exec;
    case Type::INT16:
      return KernelGenerator<DoubleType, Int16Type, Op>::Exec;
    case Type::INT32:
      return KernelGenerator<DoubleType, Int32Type, Op>::Exec;
    case Type::INT64:
    case Type::TIMESTAMP:
      return KernelGenerator<DoubleType, Int64Type, Op>::Exec;

    // Unsigned integers
    case Type::UINT8:
      return KernelGenerator<DoubleType, UInt8Type, Op>::Exec;
    case Type::UINT16:
      return KernelGenerator<DoubleType, UInt16Type, Op>::Exec;
    case Type::UINT32:
      return KernelGenerator<DoubleType, UInt32Type, Op>::Exec;
    case Type::UINT64:
      return KernelGenerator<DoubleType, UInt64Type, Op>::Exec;

    // Floating point
    case Type::FLOAT:
      return KernelGenerator<FloatType, FloatType, Op>::Exec;
    case Type::DOUBLE:
      return KernelGenerator<DoubleType, DoubleType, Op>::Exec;

    default:
      break;
  }
  // No kernel exists for this type id.
  DCHECK(false);
  return ExecFail;
}

Status CastBinaryDecimalArgs(const std::string& func_name,
std::vector<ValueDescr>* values) {
auto& left_type = (*values)[0].type;
Expand Down Expand Up @@ -734,6 +924,19 @@ std::shared_ptr<ScalarFunction> MakeUnarySignedArithmeticFunctionNotNull(
return func;
}

// Build a unary ArithmeticFunction with one kernel per type in NumericTypes().
// Kernels taking an integer type output float64; every other type outputs
// the input type itself.
template <typename Op>
std::shared_ptr<ScalarFunction> MakeUnaryIntToDoubleNotNull(std::string name,
                                                            const FunctionDoc* doc) {
  auto func = std::make_shared<ArithmeticFunction>(name, Arity::Unary(), doc);
  for (const auto& in_type : NumericTypes()) {
    const bool integer_input = is_integer(in_type->id());
    auto out_type = integer_input ? float64() : in_type;
    auto kernel_exec = IntToDoubleExecFromOp<ScalarUnaryNotNull, Op>(in_type);
    DCHECK_OK(func->AddKernel({in_type}, out_type, kernel_exec));
  }
  return func;
}

const FunctionDoc absolute_value_doc{
"Calculate the absolute value of the argument element-wise",
("Results will wrap around on integer overflow.\n"
Expand Down Expand Up @@ -820,6 +1023,45 @@ const FunctionDoc pow_checked_doc{
("An error is returned when integer to negative integer power is encountered,\n"
"or integer overflow is encountered."),
{"base", "exponent"}};

// Documentation for "ln": non-positive input produces -inf or NaN.
const FunctionDoc ln_doc{
    "Take natural log of arguments element-wise",
    ("Non-positive values return -inf or NaN. Null values return null.\n"
     "Use function \"ln_checked\" if you want non-positive values to raise an error."),
    {"x"}};

// Documentation for "ln_checked": non-positive input is reported as an error
// rather than being mapped to -inf or NaN.
const FunctionDoc ln_checked_doc{
    "Take natural log of arguments element-wise",
    ("Non-positive values raise an error. Null values return null.\n"
     "Use function \"ln\" if you want non-positive values to return "
     "-inf or NaN."),
    {"x"}};

// Documentation for "log10": non-positive input produces -inf or NaN.
const FunctionDoc log10_doc{
    "Take log base 10 of arguments element-wise",
    ("Non-positive values return -inf or NaN. Null values return null.\n"
     "Use function \"log10_checked\" if you want non-positive values to raise an error."),
    {"x"}};

// Documentation for "log10_checked": non-positive input is reported as an
// error rather than being mapped to -inf or NaN.
const FunctionDoc log10_checked_doc{
    "Take log base 10 of arguments element-wise",
    ("Non-positive values raise an error. Null values return null.\n"
     "Use function \"log10\" if you want non-positive values to return "
     "-inf or NaN."),
    {"x"}};

// Documentation for "log2": non-positive input produces -inf or NaN.
const FunctionDoc log2_doc{
    "Take log base 2 of arguments element-wise",
    ("Non-positive values return -inf or NaN. Null values return null.\n"
     "Use function \"log2_checked\" if you want non-positive values to raise an error."),
    {"x"}};

// Documentation for "log2_checked": non-positive input is reported as an
// error rather than being mapped to -inf or NaN.
const FunctionDoc log2_checked_doc{
    "Take log base 2 of arguments element-wise",
    ("Non-positive values raise an error. Null values return null.\n"
     "Use function \"log2\" if you want non-positive values to return "
     "-inf or NaN."),
    {"x"}};
} // namespace

void RegisterScalarArithmetic(FunctionRegistry* registry) {
Expand Down Expand Up @@ -903,6 +1145,24 @@ void RegisterScalarArithmetic(FunctionRegistry* registry) {
auto power_checked =
MakeArithmeticFunctionNotNull<PowerChecked>("power_checked", &pow_checked_doc);
DCHECK_OK(registry->AddFunction(std::move(power_checked)));

// ----------------------------------------------------------------------
// Logarithms
auto ln = MakeUnaryIntToDoubleNotNull<LogNatural>("ln", &ln_doc);
DCHECK_OK(registry->AddFunction(std::move(ln)));
auto ln_checked =
MakeUnaryIntToDoubleNotNull<LogNaturalChecked>("ln_checked", &ln_checked_doc);
DCHECK_OK(registry->AddFunction(std::move(ln_checked)));
auto log10 = MakeUnaryIntToDoubleNotNull<Log10>("log10", &log10_doc);
DCHECK_OK(registry->AddFunction(std::move(log10)));
auto log10_checked =
MakeUnaryIntToDoubleNotNull<Log10Checked>("log10_checked", &log10_checked_doc);
DCHECK_OK(registry->AddFunction(std::move(log10_checked)));
auto log2 = MakeUnaryIntToDoubleNotNull<Log2>("log2", &log2_doc);
DCHECK_OK(registry->AddFunction(std::move(log2)));
auto log2_checked =
MakeUnaryIntToDoubleNotNull<Log2Checked>("log2_checked", &log2_checked_doc);
DCHECK_OK(registry->AddFunction(std::move(log2_checked)));
}

} // namespace internal
Expand Down
Loading

0 comments on commit d67c715

Please sign in to comment.