Skip to content

Commit

Permalink
ARROW-14795: [C++] Fix issue on replace with mask for null values
Browse files Browse the repository at this point in the history
ARROW-14795: [C++]  Fix issue on vector replace with mask for null values, which weren't updated on null bitmaps

Closes apache#11759 from AlvinJ15/achunga/14795-fix_vector_replace_with_mask

Authored-by: alvinj15 <Alvin258461@>
Signed-off-by: David Li <[email protected]>
  • Loading branch information
alvinj15 authored and kou committed Dec 1, 2021
1 parent 68cb943 commit ce28368
Show file tree
Hide file tree
Showing 2 changed files with 54 additions and 11 deletions.
22 changes: 11 additions & 11 deletions cpp/src/arrow/compute/kernels/vector_replace.cc
Original file line number Diff line number Diff line change
Expand Up @@ -81,7 +81,7 @@ struct CopyArrayBitmap {

void SetBit(uint8_t* out_bitmap, int64_t out_offset, int64_t offset) const {
BitUtil::SetBitTo(out_bitmap, out_offset,
BitUtil::GetBit(in_bitmap, in_offset + offset));
in_bitmap ? BitUtil::GetBit(in_bitmap, in_offset + offset) : true);
}
};

Expand Down Expand Up @@ -122,7 +122,7 @@ void ReplaceWithArrayMaskImpl(const ArrayData& array, const ArrayData& mask,
if (replacements_bitmap) {
copy_bitmap.CopyBitmap(out_bitmap, out_offset + write_offset, replacements_offset,
block.length);
} else if (!replacements_bitmap && out_bitmap) {
} else if (out_bitmap) {
BitUtil::SetBitsTo(out_bitmap, out_offset + write_offset, block.length, true);
}
replacements_offset += block.length;
Expand All @@ -133,10 +133,9 @@ void ReplaceWithArrayMaskImpl(const ArrayData& array, const ArrayData& mask,
BitUtil::GetBit(mask_bitmap, write_offset + mask.offset + i))) {
Functor::CopyData(*array.type, out_values, out_offset + write_offset + i,
replacements, replacements_offset, /*length=*/1);
if (replacements_bitmap) {
copy_bitmap.SetBit(out_bitmap, out_offset + write_offset + i,
replacements_offset);
}
copy_bitmap.SetBit(out_bitmap, out_offset + write_offset + i,

replacements_offset);
replacements_offset++;
}
}
Expand All @@ -154,9 +153,8 @@ Status ReplaceWithArrayMask(KernelContext* ctx, const ArrayData& array,
uint8_t* out_values = output->buffers[1]->mutable_data();
const uint8_t* mask_bitmap = mask.MayHaveNulls() ? mask.buffers[0]->data() : nullptr;
const uint8_t* mask_values = mask.buffers[1]->data();
const bool replacements_bitmap = replacements.is_array()
? replacements.array()->MayHaveNulls()
: !replacements.scalar()->is_valid;
const bool replacements_bitmap =
replacements.is_array() ? replacements.array()->MayHaveNulls() : true;
if (replacements.is_array()) {
// Check that we have enough replacement values
const int64_t replacements_length = replacements.array()->length;
Expand Down Expand Up @@ -189,7 +187,7 @@ Status ReplaceWithArrayMask(KernelContext* ctx, const ArrayData& array,
const ArrayData& array_repl = *replacements.array();
ReplaceWithArrayMaskImpl<Functor>(
array, mask, array_repl, replacements_bitmap,
CopyArrayBitmap{replacements_bitmap ? array_repl.buffers[0]->data() : nullptr,
CopyArrayBitmap{(replacements_bitmap) ? array_repl.buffers[0]->data() : nullptr,
array_repl.offset},
mask_bitmap, mask_values, out_bitmap, out_values, out_offset);
} else {
Expand Down Expand Up @@ -254,7 +252,9 @@ struct ReplaceWithMask<Type, enable_if_boolean<Type>> {
}
static void CopyData(const DataType&, uint8_t* out, const int64_t out_offset,
const Scalar& in, const int64_t in_offset, const int64_t length) {
BitUtil::SetBitsTo(out, out_offset, length, in.is_valid);
BitUtil::SetBitsTo(
out, out_offset, length,
in.is_valid ? checked_cast<const BooleanScalar&>(in).value : false);
}

static Status ExecScalarMask(KernelContext* ctx, const ArrayData& array,
Expand Down
43 changes: 43 additions & 0 deletions cpp/src/arrow/compute/kernels/vector_replace_test.cc
Original file line number Diff line number Diff line change
Expand Up @@ -235,6 +235,38 @@ TYPED_TEST(TestReplaceNumeric, ReplaceWithMask) {
this->array("[0, null, 10]"));
}

TYPED_TEST(TestReplaceNumeric, ReplaceWithMaskForNullValuesAndMaskEnabled) {
this->Assert(ReplaceWithMask, this->array("[1, null, 1]"),
this->mask("[false, true, false]"), this->array("[7]"),
this->array("[1, 7, 1]"));
this->Assert(ReplaceWithMask, this->array("[1, null, 1, 7]"),
this->mask("[false, true, false, true]"), this->array("[7, 20]"),
this->array("[1, 7, 1, 20]"));
this->Assert(ReplaceWithMask, this->array("[1, 2, 3, 4]"),
this->mask("[false, true, false, true]"), this->array("[null, null]"),
this->array("[1, null, 3, null]"));
this->Assert(ReplaceWithMask, this->array("[null, 2, 3, 4]"),
this->mask("[true, true, false, true]"), this->array("[1, null, null]"),
this->array("[1, null, 3, null]"));
this->Assert(ReplaceWithMask, this->array("[1, null, 1]"),
this->mask("[false, true, false]"), this->scalar("null"),
this->array("[1, null, 1]"));
this->Assert(ReplaceWithMask, this->array("[1, null, 1]"),
this->mask("[true, true, true]"), this->array("[7, 7, 7]"),
this->array("[7, 7, 7]"));
this->Assert(ReplaceWithMask, this->array("[1, null, 1]"),
this->mask("[true, true, true]"), this->array("[null, null, null]"),
this->array("[null, null, null]"));
this->Assert(ReplaceWithMask, this->array("[1, null, 1]"),
this->mask("[false, true, false]"), this->scalar("null"),
this->array("[1, null, 1]"));
this->Assert(ReplaceWithMask, this->array("[1, null, 1]"),
this->mask("[true, true, true]"), this->scalar("null"),
this->array("[null, null, null]"));
this->Assert(ReplaceWithMask, this->array("[null, null]"), this->mask("[true, true]"),
this->array("[1, 1]"), this->array("[1, 1]"));
}

TYPED_TEST(TestReplaceNumeric, ReplaceWithMaskRandom) {
using ArrayType = typename TypeTraits<TypeParam>::ArrayType;
using CType = typename TypeTraits<TypeParam>::CType;
Expand Down Expand Up @@ -340,16 +372,24 @@ TEST_F(TestReplaceBoolean, ReplaceWithMask) {
this->mask("[false, false, null, null, true, true]"),
this->array("[false, null]"),
this->array("[null, null, null, null, false, null]"));
this->Assert(ReplaceWithMask, this->array("[true, null, true]"),
this->mask("[false, true, false]"), this->array("[true]"),
this->array("[true, true, true]"));

this->Assert(ReplaceWithMask, this->array("[]"), this->mask("[]"), this->scalar("true"),
this->array("[]"));
this->Assert(ReplaceWithMask, this->array("[null, false, true]"),
this->mask("[true, false, false]"), this->scalar("false"),
this->array("[false, false, true]"));
this->Assert(ReplaceWithMask, this->array("[false, false]"), this->mask("[true, true]"),
this->scalar("true"), this->array("[true, true]"));
this->Assert(ReplaceWithMask, this->array("[false, false]"), this->mask("[true, true]"),
this->scalar("null"), this->array("[null, null]"));
this->Assert(ReplaceWithMask, this->array("[false, false, false]"),
this->mask("[false, null, true]"), this->scalar("true"),
this->array("[false, null, true]"));
this->Assert(ReplaceWithMask, this->array("[null, null]"), this->mask("[true, true]"),
this->array("[true, true]"), this->array("[true, true]"));
}

TEST_F(TestReplaceBoolean, ReplaceWithMaskErrors) {
Expand Down Expand Up @@ -427,6 +467,9 @@ TEST_F(TestReplaceFixedSizeBinary, ReplaceWithMask) {
this->mask("[false, false, null, null, true, true]"),
this->array(R"(["aaa", null])"),
this->array(R"([null, null, null, null, "aaa", null])"));
this->Assert(ReplaceWithMask, this->array(R"(["aaa", null, "bbb"])"),
this->mask("[false, true, false]"), this->array(R"(["aba"])"),
this->array(R"(["aaa", "aba", "bbb"])"));

this->Assert(ReplaceWithMask, this->array("[]"), this->mask("[]"),
this->scalar(R"("zzz")"), this->array("[]"));
Expand Down

0 comments on commit ce28368

Please sign in to comment.