Skip to content

Commit

Permalink
ARROW-14795: [C++] Fix issue on replace with mask for null values
Browse files Browse the repository at this point in the history
  • Loading branch information
alvinj15 authored and alvinj15 committed Nov 23, 2021
1 parent b305edb commit 5dd5dc9
Show file tree
Hide file tree
Showing 2 changed files with 68 additions and 9 deletions.
56 changes: 47 additions & 9 deletions cpp/src/arrow/compute/kernels/vector_replace.cc
Original file line number Diff line number Diff line change
Expand Up @@ -73,6 +73,8 @@ struct CopyArrayBitmap {
const uint8_t* in_bitmap;
int64_t in_offset;

const uint8_t* GetInBitmap() const { return in_bitmap; }

void CopyBitmap(uint8_t* out_bitmap, int64_t out_offset, int64_t offset,
int64_t length) const {
arrow::internal::CopyBitmap(in_bitmap, in_offset + offset, length, out_bitmap,
Expand All @@ -83,10 +85,17 @@ struct CopyArrayBitmap {
BitUtil::SetBitTo(out_bitmap, out_offset,
BitUtil::GetBit(in_bitmap, in_offset + offset));
}

void SetBitToTrue(uint8_t* out_bitmap, int64_t out_offset, int64_t offset) const {
BitUtil::SetBitTo(out_bitmap, out_offset, true);
}
};

struct CopyScalarBitmap {
const bool is_valid;
const uint8_t* in_bitmap;

const uint8_t* GetInBitmap() const { return in_bitmap; }

void CopyBitmap(uint8_t* out_bitmap, int64_t out_offset, int64_t offset,
int64_t length) const {
Expand All @@ -96,6 +105,10 @@ struct CopyScalarBitmap {
void SetBit(uint8_t* out_bitmap, int64_t out_offset, int64_t offset) const {
BitUtil::SetBitTo(out_bitmap, out_offset, is_valid);
}

void SetBitToTrue(uint8_t* out_bitmap, int64_t out_offset, int64_t offset) const {
BitUtil::SetBitTo(out_bitmap, out_offset, true);
}
};

// Helper to implement replace_with kernel with array mask for fixed-width types,
Expand All @@ -119,10 +132,10 @@ void ReplaceWithArrayMaskImpl(const ArrayData& array, const ArrayData& mask,
// Copy from replacement array
Functor::CopyData(*array.type, out_values, out_offset + write_offset, replacements,
replacements_offset, block.length);
if (replacements_bitmap) {
if (replacements_bitmap && Functor::HasBitmap(replacements)) {
copy_bitmap.CopyBitmap(out_bitmap, out_offset + write_offset, replacements_offset,
block.length);
} else if (!replacements_bitmap && out_bitmap) {
} else if (out_bitmap) {
BitUtil::SetBitsTo(out_bitmap, out_offset + write_offset, block.length, true);
}
replacements_offset += block.length;
Expand All @@ -134,8 +147,13 @@ void ReplaceWithArrayMaskImpl(const ArrayData& array, const ArrayData& mask,
Functor::CopyData(*array.type, out_values, out_offset + write_offset + i,
replacements, replacements_offset, /*length=*/1);
if (replacements_bitmap) {
copy_bitmap.SetBit(out_bitmap, out_offset + write_offset + i,
replacements_offset);
if (copy_bitmap.GetInBitmap()) {
copy_bitmap.SetBit(out_bitmap, out_offset + write_offset + i,
replacements_offset);
} else {
copy_bitmap.SetBitToTrue(out_bitmap, out_offset + write_offset + i,
replacements_offset);
}
}
replacements_offset++;
}
Expand All @@ -154,9 +172,10 @@ Status ReplaceWithArrayMask(KernelContext* ctx, const ArrayData& array,
uint8_t* out_values = output->buffers[1]->mutable_data();
const uint8_t* mask_bitmap = mask.MayHaveNulls() ? mask.buffers[0]->data() : nullptr;
const uint8_t* mask_values = mask.buffers[1]->data();
const bool replacements_bitmap = replacements.is_array()
? replacements.array()->MayHaveNulls()
: !replacements.scalar()->is_valid;
const bool replacements_bitmap =
replacements.is_array()
? replacements.array()->MayHaveNulls() || array.MayHaveNulls()
: !replacements.scalar()->is_valid;
if (replacements.is_array()) {
// Check that we have enough replacement values
const int64_t replacements_length = replacements.array()->length;
Expand Down Expand Up @@ -189,8 +208,9 @@ Status ReplaceWithArrayMask(KernelContext* ctx, const ArrayData& array,
const ArrayData& array_repl = *replacements.array();
ReplaceWithArrayMaskImpl<Functor>(
array, mask, array_repl, replacements_bitmap,
CopyArrayBitmap{replacements_bitmap ? array_repl.buffers[0]->data() : nullptr,
array_repl.offset},
CopyArrayBitmap{
(!!array_repl.buffers[0]) ? array_repl.buffers[0]->data() : nullptr,
array_repl.offset},
mask_bitmap, mask_values, out_bitmap, out_values, out_offset);
} else {
const Scalar& scalar_repl = *replacements.scalar();
Expand Down Expand Up @@ -229,6 +249,12 @@ struct ReplaceWithMask<Type,
std::fill(begin, end, UnboxScalar<Type>::Unbox(in));
}

static bool HasBitmap(const ArrayData& replacements) {
return replacements.MayHaveNulls();
}

static bool HasBitmap(const Scalar& replacements) { return true; }

static Status ExecScalarMask(KernelContext* ctx, const ArrayData& array,
const BooleanScalar& mask, const Datum& replacements,
ArrayData* output) {
Expand Down Expand Up @@ -257,6 +283,12 @@ struct ReplaceWithMask<Type, enable_if_boolean<Type>> {
BitUtil::SetBitsTo(out, out_offset, length, in.is_valid);
}

static bool HasBitmap(const ArrayData& replacements) {
return replacements.MayHaveNulls();
}

static bool HasBitmap(const Scalar& replacements) { return true; }

static Status ExecScalarMask(KernelContext* ctx, const ArrayData& array,
const BooleanScalar& mask, const Datum& replacements,
ArrayData* output) {
Expand Down Expand Up @@ -296,6 +328,12 @@ struct ReplaceWithMask<Type, enable_if_fixed_size_binary<Type>> {
}
}

static bool HasBitmap(const ArrayData& replacements) {
return replacements.MayHaveNulls();
}

static bool HasBitmap(const Scalar& replacements) { return true; }

static Status ExecScalarMask(KernelContext* ctx, const ArrayData& array,
const BooleanScalar& mask, const Datum& replacements,
ArrayData* output) {
Expand Down
21 changes: 21 additions & 0 deletions cpp/src/arrow/compute/kernels/vector_replace_test.cc
Original file line number Diff line number Diff line change
Expand Up @@ -235,6 +235,21 @@ TYPED_TEST(TestReplaceNumeric, ReplaceWithMask) {
this->array("[0, null, 10]"));
}

TYPED_TEST(TestReplaceNumeric, ReplaceWithMaskForNullValuesAndMaskEnabled) {
this->Assert(ReplaceWithMask, this->array("[1, null, 1]"),
this->mask("[false, true, false]"), this->array("[7]"),
this->array("[1, 7, 1]"));
this->Assert(ReplaceWithMask, this->array("[1, null, 1, 7]"),
this->mask("[false, true, false, true]"), this->array("[7, 20]"),
this->array("[1, 7, 1, 20]"));
this->Assert(ReplaceWithMask, this->array("[1, 2, 3, 4]"),
this->mask("[false, true, false, true]"), this->array("[null, null]"),
this->array("[1, null, 3, null]"));
this->Assert(ReplaceWithMask, this->array("[null, 2, 3, 4]"),
this->mask("[true, true, false, true]"), this->array("[1, null, null]"),
this->array("[1, null, 3, null]"));
}

TYPED_TEST(TestReplaceNumeric, ReplaceWithMaskRandom) {
using ArrayType = typename TypeTraits<TypeParam>::ArrayType;
using CType = typename TypeTraits<TypeParam>::CType;
Expand Down Expand Up @@ -340,6 +355,9 @@ TEST_F(TestReplaceBoolean, ReplaceWithMask) {
this->mask("[false, false, null, null, true, true]"),
this->array("[false, null]"),
this->array("[null, null, null, null, false, null]"));
this->Assert(ReplaceWithMask, this->array("[true, null, true]"),
this->mask("[false, true, false]"), this->array("[true]"),
this->array("[true, true, true]"));

this->Assert(ReplaceWithMask, this->array("[]"), this->mask("[]"), this->scalar("true"),
this->array("[]"));
Expand Down Expand Up @@ -427,6 +445,9 @@ TEST_F(TestReplaceFixedSizeBinary, ReplaceWithMask) {
this->mask("[false, false, null, null, true, true]"),
this->array(R"(["aaa", null])"),
this->array(R"([null, null, null, null, "aaa", null])"));
this->Assert(ReplaceWithMask, this->array(R"(["aaa", null, "bbb"])"),
this->mask("[false, true, false]"), this->array(R"(["aba"])"),
this->array(R"(["aaa", "aba", "bbb"])"));

this->Assert(ReplaceWithMask, this->array("[]"), this->mask("[]"),
this->scalar(R"("zzz")"), this->array("[]"));
Expand Down

0 comments on commit 5dd5dc9

Please sign in to comment.