Skip to content

Commit

Permalink
ARROW-14795: [C++] Fix issue on replace with mask for null values
Browse files Browse the repository at this point in the history
  • Loading branch information
alvinj15 authored and alvinj15 committed Nov 26, 2021
1 parent b305edb commit 15d4b11
Show file tree
Hide file tree
Showing 2 changed files with 74 additions and 12 deletions.
43 changes: 31 additions & 12 deletions cpp/src/arrow/compute/kernels/vector_replace.cc
Original file line number Diff line number Diff line change
Expand Up @@ -81,7 +81,7 @@ struct CopyArrayBitmap {

void SetBit(uint8_t* out_bitmap, int64_t out_offset, int64_t offset) const {
BitUtil::SetBitTo(out_bitmap, out_offset,
BitUtil::GetBit(in_bitmap, in_offset + offset));
in_bitmap ? BitUtil::GetBit(in_bitmap, in_offset + offset) : true);
}
};

Expand Down Expand Up @@ -122,7 +122,7 @@ void ReplaceWithArrayMaskImpl(const ArrayData& array, const ArrayData& mask,
if (replacements_bitmap) {
copy_bitmap.CopyBitmap(out_bitmap, out_offset + write_offset, replacements_offset,
block.length);
} else if (!replacements_bitmap && out_bitmap) {
} else if (out_bitmap) {
BitUtil::SetBitsTo(out_bitmap, out_offset + write_offset, block.length, true);
}
replacements_offset += block.length;
Expand All @@ -133,10 +133,9 @@ void ReplaceWithArrayMaskImpl(const ArrayData& array, const ArrayData& mask,
BitUtil::GetBit(mask_bitmap, write_offset + mask.offset + i))) {
Functor::CopyData(*array.type, out_values, out_offset + write_offset + i,
replacements, replacements_offset, /*length=*/1);
if (replacements_bitmap) {
copy_bitmap.SetBit(out_bitmap, out_offset + write_offset + i,
replacements_offset);
}
copy_bitmap.SetBit(out_bitmap, out_offset + write_offset + i,

replacements_offset);
replacements_offset++;
}
}
Expand All @@ -154,9 +153,8 @@ Status ReplaceWithArrayMask(KernelContext* ctx, const ArrayData& array,
uint8_t* out_values = output->buffers[1]->mutable_data();
const uint8_t* mask_bitmap = mask.MayHaveNulls() ? mask.buffers[0]->data() : nullptr;
const uint8_t* mask_values = mask.buffers[1]->data();
const bool replacements_bitmap = replacements.is_array()
? replacements.array()->MayHaveNulls()
: !replacements.scalar()->is_valid;
const bool replacements_bitmap =
replacements.is_array() ? replacements.array()->MayHaveNulls() : true;
if (replacements.is_array()) {
// Check that we have enough replacement values
const int64_t replacements_length = replacements.array()->length;
Expand Down Expand Up @@ -189,8 +187,9 @@ Status ReplaceWithArrayMask(KernelContext* ctx, const ArrayData& array,
const ArrayData& array_repl = *replacements.array();
ReplaceWithArrayMaskImpl<Functor>(
array, mask, array_repl, replacements_bitmap,
CopyArrayBitmap{replacements_bitmap ? array_repl.buffers[0]->data() : nullptr,
array_repl.offset},
CopyArrayBitmap{
(!!array_repl.buffers[0]) ? array_repl.buffers[0]->data() : nullptr,
array_repl.offset},
mask_bitmap, mask_values, out_bitmap, out_values, out_offset);
} else {
const Scalar& scalar_repl = *replacements.scalar();
Expand Down Expand Up @@ -229,6 +228,12 @@ struct ReplaceWithMask<Type,
std::fill(begin, end, UnboxScalar<Type>::Unbox(in));
}

static bool HasBitmap(const ArrayData& replacements) {
return replacements.MayHaveNulls();
}

static bool HasBitmap(const Scalar& replacements) { return true; }

static Status ExecScalarMask(KernelContext* ctx, const ArrayData& array,
const BooleanScalar& mask, const Datum& replacements,
ArrayData* output) {
Expand All @@ -254,9 +259,17 @@ struct ReplaceWithMask<Type, enable_if_boolean<Type>> {
}
static void CopyData(const DataType&, uint8_t* out, const int64_t out_offset,
const Scalar& in, const int64_t in_offset, const int64_t length) {
BitUtil::SetBitsTo(out, out_offset, length, in.is_valid);
BitUtil::SetBitsTo(
out, out_offset, length,
in.is_valid ? checked_cast<const BooleanScalar&>(in).value : false);
}

static bool HasBitmap(const ArrayData& replacements) {
return replacements.MayHaveNulls();
}

static bool HasBitmap(const Scalar& replacements) { return true; }

static Status ExecScalarMask(KernelContext* ctx, const ArrayData& array,
const BooleanScalar& mask, const Datum& replacements,
ArrayData* output) {
Expand Down Expand Up @@ -296,6 +309,12 @@ struct ReplaceWithMask<Type, enable_if_fixed_size_binary<Type>> {
}
}

static bool HasBitmap(const ArrayData& replacements) {
return replacements.MayHaveNulls();
}

static bool HasBitmap(const Scalar& replacements) { return true; }

static Status ExecScalarMask(KernelContext* ctx, const ArrayData& array,
const BooleanScalar& mask, const Datum& replacements,
ArrayData* output) {
Expand Down
43 changes: 43 additions & 0 deletions cpp/src/arrow/compute/kernels/vector_replace_test.cc
Original file line number Diff line number Diff line change
Expand Up @@ -235,6 +235,38 @@ TYPED_TEST(TestReplaceNumeric, ReplaceWithMask) {
this->array("[0, null, 10]"));
}

TYPED_TEST(TestReplaceNumeric, ReplaceWithMaskForNullValuesAndMaskEnabled) {
this->Assert(ReplaceWithMask, this->array("[1, null, 1]"),
this->mask("[false, true, false]"), this->array("[7]"),
this->array("[1, 7, 1]"));
this->Assert(ReplaceWithMask, this->array("[1, null, 1, 7]"),
this->mask("[false, true, false, true]"), this->array("[7, 20]"),
this->array("[1, 7, 1, 20]"));
this->Assert(ReplaceWithMask, this->array("[1, 2, 3, 4]"),
this->mask("[false, true, false, true]"), this->array("[null, null]"),
this->array("[1, null, 3, null]"));
this->Assert(ReplaceWithMask, this->array("[null, 2, 3, 4]"),
this->mask("[true, true, false, true]"), this->array("[1, null, null]"),
this->array("[1, null, 3, null]"));
this->Assert(ReplaceWithMask, this->array("[1, null, 1]"),
this->mask("[false, true, false]"), this->scalar("null"),
this->array("[1, null, 1]"));
this->Assert(ReplaceWithMask, this->array("[1, null, 1]"),
this->mask("[true, true, true]"), this->array("[7, 7, 7]"),
this->array("[7, 7, 7]"));
this->Assert(ReplaceWithMask, this->array("[1, null, 1]"),
this->mask("[true, true, true]"), this->array("[null, null, null]"),
this->array("[null, null, null]"));
this->Assert(ReplaceWithMask, this->array("[1, null, 1]"),
this->mask("[false, true, false]"), this->scalar("null"),
this->array("[1, null, 1]"));
this->Assert(ReplaceWithMask, this->array("[1, null, 1]"),
this->mask("[true, true, true]"), this->scalar("null"),
this->array("[null, null, null]"));
this->Assert(ReplaceWithMask, this->array("[null, null]"), this->mask("[true, true]"),
this->array("[1, 1]"), this->array("[1, 1]"));
}

TYPED_TEST(TestReplaceNumeric, ReplaceWithMaskRandom) {
using ArrayType = typename TypeTraits<TypeParam>::ArrayType;
using CType = typename TypeTraits<TypeParam>::CType;
Expand Down Expand Up @@ -340,16 +372,24 @@ TEST_F(TestReplaceBoolean, ReplaceWithMask) {
this->mask("[false, false, null, null, true, true]"),
this->array("[false, null]"),
this->array("[null, null, null, null, false, null]"));
this->Assert(ReplaceWithMask, this->array("[true, null, true]"),
this->mask("[false, true, false]"), this->array("[true]"),
this->array("[true, true, true]"));

this->Assert(ReplaceWithMask, this->array("[]"), this->mask("[]"), this->scalar("true"),
this->array("[]"));
this->Assert(ReplaceWithMask, this->array("[null, false, true]"),
this->mask("[true, false, false]"), this->scalar("false"),
this->array("[false, false, true]"));
this->Assert(ReplaceWithMask, this->array("[false, false]"), this->mask("[true, true]"),
this->scalar("true"), this->array("[true, true]"));
this->Assert(ReplaceWithMask, this->array("[false, false]"), this->mask("[true, true]"),
this->scalar("null"), this->array("[null, null]"));
this->Assert(ReplaceWithMask, this->array("[false, false, false]"),
this->mask("[false, null, true]"), this->scalar("true"),
this->array("[false, null, true]"));
this->Assert(ReplaceWithMask, this->array("[null, null]"), this->mask("[true, true]"),
this->array("[true, true]"), this->array("[true, true]"));
}

TEST_F(TestReplaceBoolean, ReplaceWithMaskErrors) {
Expand Down Expand Up @@ -427,6 +467,9 @@ TEST_F(TestReplaceFixedSizeBinary, ReplaceWithMask) {
this->mask("[false, false, null, null, true, true]"),
this->array(R"(["aaa", null])"),
this->array(R"([null, null, null, null, "aaa", null])"));
this->Assert(ReplaceWithMask, this->array(R"(["aaa", null, "bbb"])"),
this->mask("[false, true, false]"), this->array(R"(["aba"])"),
this->array(R"(["aaa", "aba", "bbb"])"));

this->Assert(ReplaceWithMask, this->array("[]"), this->mask("[]"),
this->scalar(R"("zzz")"), this->array("[]"));
Expand Down

0 comments on commit 15d4b11

Please sign in to comment.