Skip to content

Commit

Permalink
Increase the precision of decimals during compute
Browse files Browse the repository at this point in the history
This is an initial pass in which a scalar aggregate over a Decimal type
widens its result precision to the maximum. For example, the sum of an
array of decimal128(3, 2) values is now a decimal128(38, 2).

Previously, the exact decimal type was preserved (e.g., a sum of
decimal128(3, 2)'s was a decimal128(3, 2)) *regardless* of whether
that was enough precision to capture the full decimal value.
  • Loading branch information
khwilson committed Sep 17, 2024
1 parent a5d40d0 commit 944d15d
Show file tree
Hide file tree
Showing 3 changed files with 130 additions and 12 deletions.
38 changes: 30 additions & 8 deletions cpp/src/arrow/compute/kernels/aggregate_basic.cc
Original file line number Diff line number Diff line change
Expand Up @@ -365,11 +365,20 @@ struct ProductImpl : public ScalarAggregator {
}

// Emit the accumulated product into `out` as a scalar.
//
// If the stored output type is a Decimal128Type or Decimal256Type, the emitted
// scalar's type is rebuilt with that class's kMaxPrecision and the original
// scale; any other type is passed through unchanged. An error from
// Decimal*Type::Make is returned to the caller via ARROW_ASSIGN_OR_RAISE.
//
// NOTE(review): this text comes from a diff view, so each pair of consecutive
// `out->value = ...` lines below appears to be the removed line followed by its
// replacement — confirm against the post-commit file.
Status Finalize(KernelContext*, Datum* out) override {
// Type attached to the emitted scalar (a local, despite the trailing underscore).
std::shared_ptr<DataType> out_type_;
if (auto decimal128_type = std::dynamic_pointer_cast<Decimal128Type>(this->out_type)) {
ARROW_ASSIGN_OR_RAISE(out_type_, Decimal128Type::Make(Decimal128Type::kMaxPrecision, decimal128_type->scale()));
} else if (auto decimal256_type = std::dynamic_pointer_cast<Decimal256Type>(this->out_type)) {
ARROW_ASSIGN_OR_RAISE(out_type_, Decimal256Type::Make(Decimal256Type::kMaxPrecision, decimal256_type->scale()));
} else {
// Not a decimal: keep the type as stored.
out_type_ = out_type;
}

// Type-only scalar (no value) when nulls were observed while skip_nulls is off,
// or when count is below options.min_count; otherwise the scalar carries the product.
if ((!options.skip_nulls && this->nulls_observed) ||
(this->count < options.min_count)) {
out->value = std::make_shared<OutputType>(out_type);
out->value = std::make_shared<OutputType>(out_type_);
} else {
out->value = std::make_shared<OutputType>(this->product, out_type);
out->value = std::make_shared<OutputType>(this->product, out_type_);
}
return Status::OK();
}
Expand Down Expand Up @@ -1020,6 +1029,19 @@ const FunctionDoc index_doc{"Find the index of the first occurrence of a given v

} // namespace


// Output-type resolver that widens a decimal argument to maximum precision.
//
// Only the first entry of `types` is inspected. A Decimal128Type resolves to a
// Decimal128Type with Decimal128Type::kMaxPrecision, and a Decimal256Type to a
// Decimal256Type with Decimal256Type::kMaxPrecision; the input scale is kept in
// both cases.
//
// Returns:
//   - the widened type on success;
//   - Status::Invalid if `types` is empty;
//   - Status::TypeError if the first argument is neither a 128-bit nor a
//     256-bit decimal;
//   - any error produced by Decimal*Type::Make.
Result<TypeHolder> MaxPrecisionDecimalType(KernelContext*,
                                           const std::vector<TypeHolder>& types) {
  // front() on an empty vector is undefined behaviour, so reject it explicitly.
  if (types.empty()) {
    return Status::Invalid(
        "MaxPrecisionDecimalType requires at least one argument type");
  }

  // Materialize the shared_ptr once; both casts below inspect the same object.
  const std::shared_ptr<DataType> in_type = types.front().GetSharedPtr();

  std::shared_ptr<DataType> out_type;
  if (auto decimal128_type = std::dynamic_pointer_cast<Decimal128Type>(in_type)) {
    ARROW_ASSIGN_OR_RAISE(out_type,
                          Decimal128Type::Make(Decimal128Type::kMaxPrecision,
                                               decimal128_type->scale()));
  } else if (auto decimal256_type =
                 std::dynamic_pointer_cast<Decimal256Type>(in_type)) {
    ARROW_ASSIGN_OR_RAISE(out_type,
                          Decimal256Type::Make(Decimal256Type::kMaxPrecision,
                                               decimal256_type->scale()));
  } else {
    // Name the offending type so a mis-registered kernel is diagnosable.
    return Status::TypeError(
        "MaxPrecisionDecimalType expected a decimal128 or decimal256 type, got ",
        in_type ? in_type->ToString() : std::string("<null type>"));
  }
  return TypeHolder(out_type);
}

void RegisterScalarAggregateBasic(FunctionRegistry* registry) {
static auto default_scalar_aggregate_options = ScalarAggregateOptions::Defaults();
static auto default_count_options = CountOptions::Defaults();
Expand Down Expand Up @@ -1048,9 +1070,9 @@ void RegisterScalarAggregateBasic(FunctionRegistry* registry) {
func = std::make_shared<ScalarAggregateFunction>("sum", Arity::Unary(), sum_doc,
&default_scalar_aggregate_options);
AddArrayScalarAggKernels(SumInit, {boolean()}, uint64(), func.get());
AddAggKernel(KernelSignature::Make({Type::DECIMAL128}, FirstType), SumInit, func.get(),
AddAggKernel(KernelSignature::Make({Type::DECIMAL128}, MaxPrecisionDecimalType), SumInit, func.get(),
SimdLevel::NONE);
AddAggKernel(KernelSignature::Make({Type::DECIMAL256}, FirstType), SumInit, func.get(),
AddAggKernel(KernelSignature::Make({Type::DECIMAL256}, MaxPrecisionDecimalType), SumInit, func.get(),
SimdLevel::NONE);
AddArrayScalarAggKernels(SumInit, SignedIntTypes(), int64(), func.get());
AddArrayScalarAggKernels(SumInit, UnsignedIntTypes(), uint64(), func.get());
Expand All @@ -1076,9 +1098,9 @@ void RegisterScalarAggregateBasic(FunctionRegistry* registry) {
&default_scalar_aggregate_options);
AddArrayScalarAggKernels(MeanInit, {boolean()}, float64(), func.get());
AddArrayScalarAggKernels(MeanInit, NumericTypes(), float64(), func.get());
AddAggKernel(KernelSignature::Make({Type::DECIMAL128}, FirstType), MeanInit, func.get(),
AddAggKernel(KernelSignature::Make({Type::DECIMAL128}, MaxPrecisionDecimalType), MeanInit, func.get(),
SimdLevel::NONE);
AddAggKernel(KernelSignature::Make({Type::DECIMAL256}, FirstType), MeanInit, func.get(),
AddAggKernel(KernelSignature::Make({Type::DECIMAL256}, MaxPrecisionDecimalType), MeanInit, func.get(),
SimdLevel::NONE);
AddArrayScalarAggKernels(MeanInit, {null()}, float64(), func.get());
// Add the SIMD variants for mean
Expand Down Expand Up @@ -1160,9 +1182,9 @@ void RegisterScalarAggregateBasic(FunctionRegistry* registry) {
AddArrayScalarAggKernels(ProductInit::Init, UnsignedIntTypes(), uint64(), func.get());
AddArrayScalarAggKernels(ProductInit::Init, FloatingPointTypes(), float64(),
func.get());
AddAggKernel(KernelSignature::Make({Type::DECIMAL128}, FirstType), ProductInit::Init,
AddAggKernel(KernelSignature::Make({Type::DECIMAL128}, MaxPrecisionDecimalType), ProductInit::Init,
func.get(), SimdLevel::NONE);
AddAggKernel(KernelSignature::Make({Type::DECIMAL256}, FirstType), ProductInit::Init,
AddAggKernel(KernelSignature::Make({Type::DECIMAL256}, MaxPrecisionDecimalType), ProductInit::Init,
func.get(), SimdLevel::NONE);
AddArrayScalarAggKernels(ProductInit::Init, {null()}, int64(), func.get());
DCHECK_OK(registry->AddFunction(std::move(func)));
Expand Down
30 changes: 26 additions & 4 deletions cpp/src/arrow/compute/kernels/aggregate_basic.inc.cc
Original file line number Diff line number Diff line change
Expand Up @@ -92,11 +92,20 @@ struct SumImpl : public ScalarAggregator {
}

// Emit the accumulated sum into `out` as a scalar.
//
// If the stored output type is a Decimal128Type or Decimal256Type, the emitted
// scalar's type is rebuilt with that class's kMaxPrecision and the original
// scale; any other type is passed through unchanged. An error from
// Decimal*Type::Make is returned to the caller via ARROW_ASSIGN_OR_RAISE.
//
// NOTE(review): this text comes from a diff view, so each pair of consecutive
// `out->value = ...` lines below appears to be the removed line followed by its
// replacement — confirm against the post-commit file.
Status Finalize(KernelContext*, Datum* out) override {
// Type attached to the emitted scalar (a local, despite the trailing underscore).
std::shared_ptr<DataType> out_type_;
if (auto decimal128_type = std::dynamic_pointer_cast<Decimal128Type>(out_type)) {
ARROW_ASSIGN_OR_RAISE(out_type_, Decimal128Type::Make(Decimal128Type::kMaxPrecision, decimal128_type->scale()));
} else if (auto decimal256_type = std::dynamic_pointer_cast<Decimal256Type>(out_type)) {
ARROW_ASSIGN_OR_RAISE(out_type_, Decimal256Type::Make(Decimal256Type::kMaxPrecision, decimal256_type->scale()));
} else {
// Not a decimal: keep the type as stored.
out_type_ = out_type;
}

// Type-only scalar (no value) when nulls were observed while skip_nulls is off,
// or when count is below options.min_count; otherwise the scalar carries the sum.
if ((!options.skip_nulls && this->nulls_observed) ||
(this->count < options.min_count)) {
out->value = std::make_shared<OutputType>(out_type);
out->value = std::make_shared<OutputType>(out_type_);
} else {
out->value = std::make_shared<OutputType>(this->sum, out_type);
out->value = std::make_shared<OutputType>(this->sum, out_type_);
}
return Status::OK();
}
Expand Down Expand Up @@ -219,9 +228,22 @@ struct MeanImpl<ArrowType, SimdLevel, enable_if_decimal<ArrowType>>

template <typename T = ArrowType>
Status FinalizeImpl(Datum* out) {
std::shared_ptr<DataType> out_type_;
if (auto decimal128_type = std::dynamic_pointer_cast<Decimal128Type>(this->out_type)) {
ARROW_ASSIGN_OR_RAISE(out_type_, Decimal128Type::Make(Decimal128Type::kMaxPrecision, decimal128_type->scale()));
} else if (auto decimal256_type = std::dynamic_pointer_cast<Decimal256Type>(this->out_type)) {
ARROW_ASSIGN_OR_RAISE(out_type_, Decimal256Type::Make(Decimal256Type::kMaxPrecision, decimal256_type->scale()));
} else {
return Status::TypeError(
"The decimal specialization of MeanImpl was passed a type ",
this->out_type->ToString(),
" and not a decimal type"
);
}

if ((!options.skip_nulls && this->nulls_observed) ||
(this->count < options.min_count) || (this->count == 0)) {
out->value = std::make_shared<OutputType>(this->out_type);
out->value = std::make_shared<OutputType>(out_type_);
} else {
SumCType quotient, remainder;
ARROW_ASSIGN_OR_RAISE(std::tie(quotient, remainder), this->sum.Divide(this->count));
Expand All @@ -234,7 +256,7 @@ struct MeanImpl<ArrowType, SimdLevel, enable_if_decimal<ArrowType>>
quotient -= 1;
}
}
out->value = std::make_shared<OutputType>(quotient, this->out_type);
out->value = std::make_shared<OutputType>(quotient, out_type_);
}
return Status::OK();
}
Expand Down
74 changes: 74 additions & 0 deletions python/pyarrow/tests/test_compute.py
Original file line number Diff line number Diff line change
Expand Up @@ -349,10 +349,51 @@ def test_sum_array(arrow_type):

arr = pa.array([], type=arrow_type)
assert arr.sum().as_py() is None # noqa: E711
assert pc.sum(arr).as_py() is None # noqa: E711
assert arr.sum(min_count=0).as_py() == 0
assert pc.sum(arr, min_count=0).as_py() == 0


@pytest.mark.parametrize(
    "arrow_type", [pa.decimal128(3, 2), pa.decimal256(3, 2)])
def test_sum_decimal_array(arrow_type):
    """Summing a decimal array yields a maximum-precision decimal scalar.

    The result keeps the input scale; only the precision widens (38 for
    decimal128, 76 for decimal256). This must hold for both ``Array.sum``
    and ``pc.sum``, including when the result is null.
    """
    from decimal import Decimal

    if pa.types.is_decimal128(arrow_type):
        widened_type = pa.decimal128(38, arrow_type.scale)
    else:
        widened_type = pa.decimal256(76, arrow_type.scale)
    total = Decimal("5.79")
    zero = Decimal("0.00")

    def check(result, expected):
        # ``expected is None`` means the scalar must be null.
        if expected is None:
            assert result.as_py() is None
        else:
            assert result.as_py() == expected
        assert result.type == widened_type

    # Arrays holding real values, with and without a trailing null.
    for values in ([Decimal("1.23"), Decimal("4.56")],
                   [Decimal("1.23"), Decimal("4.56"), None]):
        arr = pa.array(values, type=arrow_type)
        check(arr.sum(), total)
        check(pc.sum(arr), total)

    # All-null and empty arrays: null by default, zero once min_count=0.
    for values in ([None], []):
        arr = pa.array(values, type=arrow_type)
        check(arr.sum(), None)
        check(pc.sum(arr), None)
        check(arr.sum(min_count=0), zero)
        check(pc.sum(arr, min_count=0), zero)


@pytest.mark.parametrize('arrow_type', numerical_arrow_types)
def test_sum_chunked_array(arrow_type):
arr = pa.chunked_array([pa.array([1, 2, 3, 4], type=arrow_type)])
Expand All @@ -376,6 +417,39 @@ def test_sum_chunked_array(arrow_type):
assert pc.sum(arr, min_count=0).as_py() == 0


@pytest.mark.parametrize(
    'arrow_type', [pa.decimal128(3, 2), pa.decimal256(3, 2)])
def test_sum_chunked_array_decimal_type(arrow_type):
    """``pc.sum`` over decimal chunked arrays returns max-precision decimals.

    Covers a single chunk, several chunks, an empty chunk in the middle,
    and a chunked array with no chunks at all.
    """
    from decimal import Decimal

    if pa.types.is_decimal128(arrow_type):
        widened_type = pa.decimal128(38, arrow_type.scale)
    else:
        widened_type = pa.decimal256(76, arrow_type.scale)
    total = Decimal("5.79")
    zero = Decimal("0.00")

    def chunk(values):
        return pa.array(values, type=arrow_type)

    # Every layout below holds the same two values, so the sum is the same.
    layouts = [
        [chunk([Decimal("1.23"), Decimal("4.56")])],
        [chunk([Decimal("1.23")]), chunk([Decimal("4.56")])],
        [chunk([Decimal("1.23")]), chunk([]), chunk([Decimal("4.56")])],
    ]
    for chunks in layouts:
        arr = pa.chunked_array(chunks)
        result = pc.sum(arr)
        assert result.as_py() == total
        assert result.type == widened_type

    # With zero chunks the type has to be given explicitly.
    arr = pa.chunked_array((), type=arrow_type)
    assert arr.num_chunks == 0

    null_result = pc.sum(arr)
    assert null_result.as_py() is None
    assert null_result.type == widened_type

    zero_result = pc.sum(arr, min_count=0)
    assert zero_result.as_py() == zero
    assert zero_result.type == widened_type


def test_mode_array():
# ARROW-9917
arr = pa.array([1, 1, 3, 4, 3, 5], type='int64')
Expand Down

0 comments on commit 944d15d

Please sign in to comment.