Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Fix return 0 while trying to divide or mod by 0 with decimal #197

Open
wants to merge 1 commit into
base: arrow-4.0.0-oap
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
5 changes: 2 additions & 3 deletions cpp/src/gandiva/function_registry_arithmetic.cc
Original file line number Diff line number Diff line change
Expand Up @@ -90,15 +90,14 @@ std::vector<NativeFunction> GetArithmeticFunctionRegistry() {
// add/sub/multiply/divide/mod
BINARY_SYMMETRIC_FN(add, {}), BINARY_SYMMETRIC_FN(subtract, {}),
BINARY_SYMMETRIC_FN(multiply, {}),
NUMERIC_TYPES_WITHOUT_DECIMAL(BINARY_SYMMETRIC_SAFE_INTERNAL_NULL, divide, {}),
NUMERIC_TYPES(BINARY_SYMMETRIC_SAFE_INTERNAL_NULL, divide, {}),
BINARY_SYMMETRIC_SAFE_INTERNAL_NULL(mod, {"modulo"}, int8),
BINARY_SYMMETRIC_SAFE_INTERNAL_NULL(mod, {"modulo"}, int16),
BINARY_SYMMETRIC_SAFE_INTERNAL_NULL(mod, {"modulo"}, int32),
BINARY_SYMMETRIC_SAFE_INTERNAL_NULL(mod, {"modulo"}, int64),
BINARY_SYMMETRIC_SAFE_INTERNAL_NULL(mod, {"modulo"}, float32),
BINARY_SYMMETRIC_SAFE_INTERNAL_NULL(mod, {"modulo"}, float64),
BINARY_SYMMETRIC_UNSAFE_NULL_IF_NULL(mod, {"modulo"}, decimal128),
BINARY_SYMMETRIC_UNSAFE_NULL_IF_NULL(divide, {}, decimal128),
BINARY_SYMMETRIC_SAFE_INTERNAL_NULL(mod, {"modulo"}, decimal128),
BINARY_SYMMETRIC_UNSAFE_NULL_IF_NULL(div, {}, int32),
BINARY_SYMMETRIC_UNSAFE_NULL_IF_NULL(div, {}, int64),
BINARY_SYMMETRIC_UNSAFE_NULL_IF_NULL(div, {}, float32),
Expand Down
4 changes: 2 additions & 2 deletions cpp/src/gandiva/precompiled/decimal_ops.cc
Original file line number Diff line number Diff line change
Expand Up @@ -347,7 +347,7 @@ BasicDecimal128 Multiply(const BasicDecimalScalar128& x, const BasicDecimalScala
return result;
}

BasicDecimal128 Divide(int64_t context, const BasicDecimalScalar128& x,
BasicDecimal128 Divide(const BasicDecimalScalar128& x,
const BasicDecimalScalar128& y, int32_t out_precision,
int32_t out_scale, bool* overflow) {
if (y.value() == 0) {
Expand Down Expand Up @@ -392,7 +392,7 @@ BasicDecimal128 Divide(int64_t context, const BasicDecimalScalar128& x,
return result;
}

BasicDecimal128 Mod(int64_t context, const BasicDecimalScalar128& x,
BasicDecimal128 Mod(const BasicDecimalScalar128& x,
const BasicDecimalScalar128& y, int32_t out_precision,
int32_t out_scale, bool* overflow) {
if (y.value() == 0) {
Expand Down
4 changes: 2 additions & 2 deletions cpp/src/gandiva/precompiled/decimal_ops.h
Original file line number Diff line number Diff line change
Expand Up @@ -41,12 +41,12 @@ arrow::BasicDecimal128 Multiply(const BasicDecimalScalar128& x,
int32_t out_scale, bool* overflow);

/// Divide 'x' by 'y', and return the result.
arrow::BasicDecimal128 Divide(int64_t context, const BasicDecimalScalar128& x,
arrow::BasicDecimal128 Divide(const BasicDecimalScalar128& x,
const BasicDecimalScalar128& y, int32_t out_precision,
int32_t out_scale, bool* overflow);

/// Divide 'x' by 'y', and return the remainder.
arrow::BasicDecimal128 Mod(int64_t context, const BasicDecimalScalar128& x,
arrow::BasicDecimal128 Mod(const BasicDecimalScalar128& x,
const BasicDecimalScalar128& y, int32_t out_precision,
int32_t out_scale, bool* overflow);

Expand Down
14 changes: 6 additions & 8 deletions cpp/src/gandiva/precompiled/decimal_ops_test.cc
Original file line number Diff line number Diff line change
Expand Up @@ -144,13 +144,13 @@ void TestDecimalSql::Verify(DecimalTypeUtil::Op op, const DecimalScalar128& x,

case DecimalTypeUtil::kOpDivide:
op_name = "divide";
out_value = decimalops::Divide(context, x, y, out_type->precision(),
out_value = decimalops::Divide(x, y, out_type->precision(),
out_type->scale(), &overflow);
break;

case DecimalTypeUtil::kOpMod:
op_name = "mod";
out_value = decimalops::Mod(context, x, y, out_type->precision(), out_type->scale(),
out_value = decimalops::Mod(x, y, out_type->precision(), out_type->scale(),
&overflow);
break;

Expand Down Expand Up @@ -451,32 +451,30 @@ TEST_F(TestDecimalSql, DivideByZero) {
context.Reset();
result_precision = 38;
result_scale = 19;
decimalops::Divide(reinterpret_cast<gdv_int64>(&context),
DecimalScalar128{"201", 20, 3}, DecimalScalar128{"0", 20, 2},
decimalops::Divide(DecimalScalar128{"201", 20, 3}, DecimalScalar128{"0", 20, 2},
result_precision, result_scale, &overflow);
// EXPECT_TRUE(context.has_error());
// EXPECT_EQ(context.get_error(), "divide by zero error");

// divide-by-nonzero should not cause an error.
context.Reset();
decimalops::Divide(reinterpret_cast<gdv_int64>(&context),
DecimalScalar128{"201", 20, 3}, DecimalScalar128{"1", 20, 2},
decimalops::Divide(DecimalScalar128{"201", 20, 3}, DecimalScalar128{"1", 20, 2},
result_precision, result_scale, &overflow);
EXPECT_FALSE(context.has_error());

// mod-by-zero should cause an error.
context.Reset();
result_precision = 20;
result_scale = 3;
decimalops::Mod(reinterpret_cast<gdv_int64>(&context), DecimalScalar128{"201", 20, 3},
decimalops::Mod(DecimalScalar128{"201", 20, 3},
DecimalScalar128{"0", 20, 2}, result_precision, result_scale,
&overflow);
// EXPECT_TRUE(context.has_error());
// EXPECT_EQ(context.get_error(), "divide by zero error");

// mod-by-nonzero should not cause an error.
context.Reset();
decimalops::Mod(reinterpret_cast<gdv_int64>(&context), DecimalScalar128{"201", 20, 3},
decimalops::Mod(DecimalScalar128{"201", 20, 3},
DecimalScalar128{"1", 20, 2}, result_precision, result_scale,
&overflow);
EXPECT_FALSE(context.has_error());
Expand Down
48 changes: 40 additions & 8 deletions cpp/src/gandiva/precompiled/decimal_wrapper.cc
Original file line number Diff line number Diff line change
Expand Up @@ -52,35 +52,67 @@ void multiply_decimal128_decimal128(int64_t x_high, uint64_t x_low, int32_t x_pr
}

FORCE_INLINE
void divide_decimal128_decimal128(int64_t context, int64_t x_high, uint64_t x_low,
int32_t x_precision, int32_t x_scale, int64_t y_high,
uint64_t y_low, int32_t y_precision, int32_t y_scale,
void divide_decimal128_decimal128(int64_t x_high, uint64_t x_low,
int32_t x_precision, int32_t x_scale, bool x_valid,
int64_t y_high, uint64_t y_low, int32_t y_precision,
int32_t y_scale, bool y_valid, bool* out_valid,
int32_t out_precision, int32_t out_scale,
int64_t* out_high, uint64_t* out_low) {
if (!x_valid || !y_valid) {
*out_valid = false;
arrow::BasicDecimal128 out = 0;
*out_high = out.high_bits();
*out_low = out.low_bits();
return;
}
gandiva::BasicDecimalScalar128 x(x_high, x_low, x_precision, x_scale);
gandiva::BasicDecimalScalar128 y(y_high, y_low, y_precision, y_scale);
bool overflow;

if (y.value() == 0) {
*out_valid = false;
arrow::BasicDecimal128 out = 0;
*out_high = out.high_bits();
*out_low = out.low_bits();
return;
}
// TODO ravindra: generate error on overflows (ARROW-4570).
arrow::BasicDecimal128 out =
gandiva::decimalops::Divide(context, x, y, out_precision, out_scale, &overflow);
gandiva::decimalops::Divide(x, y, out_precision, out_scale, &overflow);
*out_valid = true;
*out_high = out.high_bits();
*out_low = out.low_bits();
}

FORCE_INLINE
void mod_decimal128_decimal128(int64_t context, int64_t x_high, uint64_t x_low,
int32_t x_precision, int32_t x_scale, int64_t y_high,
uint64_t y_low, int32_t y_precision, int32_t y_scale,
void mod_decimal128_decimal128(int64_t x_high, uint64_t x_low,
int32_t x_precision, int32_t x_scale, bool x_valid,
int64_t y_high, uint64_t y_low, int32_t y_precision,
int32_t y_scale, bool y_valid, bool* out_valid,
int32_t out_precision, int32_t out_scale,
int64_t* out_high, uint64_t* out_low) {
if (!x_valid || !y_valid) {
*out_valid = false;
arrow::BasicDecimal128 out = 0;
*out_high = out.high_bits();
*out_low = out.low_bits();
return;
}
gandiva::BasicDecimalScalar128 x(x_high, x_low, x_precision, x_scale);
gandiva::BasicDecimalScalar128 y(y_high, y_low, y_precision, y_scale);
bool overflow;

if (y.value() == 0) {
*out_valid = false;
arrow::BasicDecimal128 out = 0;
*out_high = out.high_bits();
*out_low = out.low_bits();
return;
}
// TODO ravindra: generate error on overflows (ARROW-4570).
arrow::BasicDecimal128 out =
gandiva::decimalops::Mod(context, x, y, out_precision, out_scale, &overflow);
gandiva::decimalops::Mod(x, y, out_precision, out_scale, &overflow);
*out_valid = true;
*out_high = out.high_bits();
*out_low = out.low_bits();
}
Expand Down