diff --git a/cpp/src/arrow/compute/kernels/vector_replace.cc b/cpp/src/arrow/compute/kernels/vector_replace.cc index 7f204b529ebfa..009809941186b 100644 --- a/cpp/src/arrow/compute/kernels/vector_replace.cc +++ b/cpp/src/arrow/compute/kernels/vector_replace.cc @@ -81,7 +81,7 @@ struct CopyArrayBitmap { void SetBit(uint8_t* out_bitmap, int64_t out_offset, int64_t offset) const { BitUtil::SetBitTo(out_bitmap, out_offset, - BitUtil::GetBit(in_bitmap, in_offset + offset)); + in_bitmap ? BitUtil::GetBit(in_bitmap, in_offset + offset) : true); } }; @@ -122,7 +122,7 @@ void ReplaceWithArrayMaskImpl(const ArrayData& array, const ArrayData& mask, if (replacements_bitmap) { copy_bitmap.CopyBitmap(out_bitmap, out_offset + write_offset, replacements_offset, block.length); - } else if (!replacements_bitmap && out_bitmap) { + } else if (out_bitmap) { BitUtil::SetBitsTo(out_bitmap, out_offset + write_offset, block.length, true); } replacements_offset += block.length; @@ -133,10 +133,9 @@ void ReplaceWithArrayMaskImpl(const ArrayData& array, const ArrayData& mask, BitUtil::GetBit(mask_bitmap, write_offset + mask.offset + i))) { Functor::CopyData(*array.type, out_values, out_offset + write_offset + i, replacements, replacements_offset, /*length=*/1); - if (replacements_bitmap) { - copy_bitmap.SetBit(out_bitmap, out_offset + write_offset + i, - replacements_offset); - } + copy_bitmap.SetBit(out_bitmap, out_offset + write_offset + i, + + replacements_offset); replacements_offset++; } } @@ -154,9 +153,8 @@ Status ReplaceWithArrayMask(KernelContext* ctx, const ArrayData& array, uint8_t* out_values = output->buffers[1]->mutable_data(); const uint8_t* mask_bitmap = mask.MayHaveNulls() ? mask.buffers[0]->data() : nullptr; const uint8_t* mask_values = mask.buffers[1]->data(); - const bool replacements_bitmap = replacements.is_array() - ? replacements.array()->MayHaveNulls() - : !replacements.scalar()->is_valid; + const bool replacements_bitmap = + replacements.is_array() ? replacements.array()->MayHaveNulls() : true; if (replacements.is_array()) { // Check that we have enough replacement values const int64_t replacements_length = replacements.array()->length; @@ -189,8 +187,9 @@ Status ReplaceWithArrayMask(KernelContext* ctx, const ArrayData& array, const ArrayData& array_repl = *replacements.array(); ReplaceWithArrayMaskImpl( array, mask, array_repl, replacements_bitmap, - CopyArrayBitmap{replacements_bitmap ? array_repl.buffers[0]->data() : nullptr, - array_repl.offset}, + CopyArrayBitmap{ + (!!array_repl.buffers[0]) ? array_repl.buffers[0]->data() : nullptr, + array_repl.offset}, mask_bitmap, mask_values, out_bitmap, out_values, out_offset); } else { const Scalar& scalar_repl = *replacements.scalar(); @@ -229,6 +228,12 @@ struct ReplaceWithMask::Unbox(in)); } + static bool HasBitmap(const ArrayData& replacements) { + return replacements.MayHaveNulls(); + } + + static bool HasBitmap(const Scalar& replacements) { return true; } + static Status ExecScalarMask(KernelContext* ctx, const ArrayData& array, const BooleanScalar& mask, const Datum& replacements, ArrayData* output) { @@ -254,9 +259,17 @@ struct ReplaceWithMask> { } static void CopyData(const DataType&, uint8_t* out, const int64_t out_offset, const Scalar& in, const int64_t in_offset, const int64_t length) { - BitUtil::SetBitsTo(out, out_offset, length, in.is_valid); + BitUtil::SetBitsTo( + out, out_offset, length, + in.is_valid ? checked_cast(in).value : false); } + static bool HasBitmap(const ArrayData& replacements) { + return replacements.MayHaveNulls(); + } + + static bool HasBitmap(const Scalar& replacements) { return true; } + static Status ExecScalarMask(KernelContext* ctx, const ArrayData& array, const BooleanScalar& mask, const Datum& replacements, ArrayData* output) { @@ -296,6 +309,12 @@ struct ReplaceWithMask> { } } + static bool HasBitmap(const ArrayData& replacements) { + return replacements.MayHaveNulls(); + } + + static bool HasBitmap(const Scalar& replacements) { return true; } + static Status ExecScalarMask(KernelContext* ctx, const ArrayData& array, const BooleanScalar& mask, const Datum& replacements, ArrayData* output) { diff --git a/cpp/src/arrow/compute/kernels/vector_replace_test.cc b/cpp/src/arrow/compute/kernels/vector_replace_test.cc index e12a42e5254e4..96afae6b12a7b 100644 --- a/cpp/src/arrow/compute/kernels/vector_replace_test.cc +++ b/cpp/src/arrow/compute/kernels/vector_replace_test.cc @@ -235,6 +235,38 @@ TYPED_TEST(TestReplaceNumeric, ReplaceWithMask) { this->array("[0, null, 10]")); } +TYPED_TEST(TestReplaceNumeric, ReplaceWithMaskForNullValuesAndMaskEnabled) { + this->Assert(ReplaceWithMask, this->array("[1, null, 1]"), + this->mask("[false, true, false]"), this->array("[7]"), + this->array("[1, 7, 1]")); + this->Assert(ReplaceWithMask, this->array("[1, null, 1, 7]"), + this->mask("[false, true, false, true]"), this->array("[7, 20]"), + this->array("[1, 7, 1, 20]")); + this->Assert(ReplaceWithMask, this->array("[1, 2, 3, 4]"), + this->mask("[false, true, false, true]"), this->array("[null, null]"), + this->array("[1, null, 3, null]")); + this->Assert(ReplaceWithMask, this->array("[null, 2, 3, 4]"), + this->mask("[true, true, false, true]"), this->array("[1, null, null]"), + this->array("[1, null, 3, null]")); + this->Assert(ReplaceWithMask, this->array("[1, null, 1]"), + this->mask("[false, true, false]"), this->scalar("null"), + this->array("[1, null, 1]")); + this->Assert(ReplaceWithMask, this->array("[1, null, 1]"), + this->mask("[true, true, true]"), this->array("[7, 7, 7]"), + this->array("[7, 7, 7]")); + this->Assert(ReplaceWithMask, this->array("[1, null, 1]"), + this->mask("[true, true, true]"), this->array("[null, null, null]"), + this->array("[null, null, null]")); + this->Assert(ReplaceWithMask, this->array("[1, null, 1]"), + this->mask("[false, true, false]"), this->scalar("null"), + this->array("[1, null, 1]")); + this->Assert(ReplaceWithMask, this->array("[1, null, 1]"), + this->mask("[true, true, true]"), this->scalar("null"), + this->array("[null, null, null]")); + this->Assert(ReplaceWithMask, this->array("[null, null]"), this->mask("[true, true]"), + this->array("[1, 1]"), this->array("[1, 1]")); +} + TYPED_TEST(TestReplaceNumeric, ReplaceWithMaskRandom) { using ArrayType = typename TypeTraits::ArrayType; using CType = typename TypeTraits::CType; @@ -340,9 +372,15 @@ TEST_F(TestReplaceBoolean, ReplaceWithMask) { this->mask("[false, false, null, null, true, true]"), this->array("[false, null]"), this->array("[null, null, null, null, false, null]")); + this->Assert(ReplaceWithMask, this->array("[true, null, true]"), + this->mask("[false, true, false]"), this->array("[true]"), + this->array("[true, true, true]")); this->Assert(ReplaceWithMask, this->array("[]"), this->mask("[]"), this->scalar("true"), this->array("[]")); + this->Assert(ReplaceWithMask, this->array("[null, false, true]"), + this->mask("[true, false, false]"), this->scalar("false"), + this->array("[false, false, true]")); this->Assert(ReplaceWithMask, this->array("[false, false]"), this->mask("[true, true]"), this->scalar("true"), this->array("[true, true]")); this->Assert(ReplaceWithMask, this->array("[false, false]"), this->mask("[true, true]"), @@ -350,6 +388,8 @@ TEST_F(TestReplaceBoolean, ReplaceWithMask) { this->Assert(ReplaceWithMask, this->array("[false, false, false]"), this->mask("[false, null, true]"), this->scalar("true"), this->array("[false, null, true]")); + this->Assert(ReplaceWithMask, this->array("[null, null]"), this->mask("[true, true]"), + this->array("[true, true]"), this->array("[true, true]")); } TEST_F(TestReplaceBoolean, ReplaceWithMaskErrors) { @@ -427,6 +467,9 @@ TEST_F(TestReplaceFixedSizeBinary, ReplaceWithMask) { this->mask("[false, false, null, null, true, true]"), this->array(R"(["aaa", null])"), this->array(R"([null, null, null, null, "aaa", null])")); + this->Assert(ReplaceWithMask, this->array(R"(["aaa", null, "bbb"])"), + this->mask("[false, true, false]"), this->array(R"(["aba"])"), + this->array(R"(["aaa", "aba", "bbb"])")); this->Assert(ReplaceWithMask, this->array("[]"), this->mask("[]"), this->scalar(R"("zzz")"), this->array("[]"));