Skip to content

Commit

Permalink
initial mask kernel impl
Browse files Browse the repository at this point in the history
  • Loading branch information
bkietz committed Jun 19, 2019
1 parent 2785a73 commit 13a1969
Show file tree
Hide file tree
Showing 6 changed files with 460 additions and 1 deletion.
1 change: 1 addition & 0 deletions cpp/src/arrow/CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -174,6 +174,7 @@ if(ARROW_COMPUTE)
compute/kernels/count.cc
compute/kernels/filter.cc
compute/kernels/hash.cc
compute/kernels/mask.cc
compute/kernels/mean.cc
compute/kernels/sum.cc
compute/kernels/take.cc
Expand Down
1 change: 1 addition & 0 deletions cpp/src/arrow/compute/kernels/CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,7 @@ arrow_install_all_headers("arrow/compute/kernels")
add_arrow_test(boolean-test PREFIX "arrow-compute")
add_arrow_test(cast-test PREFIX "arrow-compute")
add_arrow_test(hash-test PREFIX "arrow-compute")
add_arrow_test(mask-test PREFIX "arrow-compute")
add_arrow_test(take-test PREFIX "arrow-compute")
add_arrow_test(util-internal-test PREFIX "arrow-compute")

Expand Down
145 changes: 145 additions & 0 deletions cpp/src/arrow/compute/kernels/mask-test.cc
Original file line number Diff line number Diff line change
@@ -0,0 +1,145 @@
// Licensed to the Apache Software Foundation (ASF) under one
// or more contributor license agreements. See the NOTICE file
// distributed with this work for additional information
// returnGegarding copyright ownership. The ASF licenses this file
// to you under the Apache License, Version 2.0 (the
// "License"); you may not use this file except in compliance
// with the License. You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing,
// software distributed under the License is distributed on an
// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
// KIND, either express or implied. See the License for the
// specific language governing permissions and limitations
// under the License.

#include <memory>
#include <vector>

#include "arrow/compute/context.h"
#include "arrow/compute/kernels/mask.h"
#include "arrow/compute/test-util.h"
#include "arrow/testing/gtest_common.h"
#include "arrow/testing/gtest_util.h"
#include "arrow/testing/random.h"
#include "arrow/testing/util.h"

namespace arrow {
namespace compute {

using util::string_view;

template <typename ArrowType>
class TestMaskKernel : public ComputeFixture, public TestBase {
protected:
void AssertMaskArrays(const std::shared_ptr<Array>& values,
const std::shared_ptr<Array>& mask, MaskOptions options,
const std::shared_ptr<Array>& expected) {
std::shared_ptr<Array> actual;
ASSERT_OK(arrow::compute::Mask(&this->ctx_, *values, *mask, options, &actual));
AssertArraysEqual(*expected, *actual);
}
void AssertMask(const std::shared_ptr<DataType>& type, const std::string& values,
const std::string& mask, MaskOptions options,
const std::string& expected) {
std::shared_ptr<Array> actual;
ASSERT_OK(this->Mask(type, values, mask, options, &actual));
AssertArraysEqual(*ArrayFromJSON(type, expected), *actual);
}
Status Mask(const std::shared_ptr<DataType>& type, const std::string& values,
const std::string& mask, MaskOptions options, std::shared_ptr<Array>* out) {
return arrow::compute::Mask(&this->ctx_, *ArrayFromJSON(type, values),
*ArrayFromJSON(boolean(), mask), options, out);
}
};

class TestMaskKernelWithNull : public TestMaskKernel<NullType> {
protected:
void AssertMask(const std::string& values, const std::string& mask, MaskOptions options,
const std::string& expected) {
TestMaskKernel<NullType>::AssertMask(utf8(), values, mask, options, expected);
}
};

TEST_F(TestMaskKernelWithNull, MaskNull) {
MaskOptions options;
this->AssertMask("[null, null, null]", "[0, 1, 0]", options, "[null]");
this->AssertMask("[null, null, null]", "[1, 1, 0]", options, "[null, null]");
}

class TestMaskKernelWithBoolean : public TestMaskKernel<BooleanType> {
protected:
void AssertMask(const std::string& values, const std::string& mask, MaskOptions options,
const std::string& expected) {
TestMaskKernel<BooleanType>::AssertMask(boolean(), values, mask, options, expected);
}
};

TEST_F(TestMaskKernelWithBoolean, MaskBoolean) {
MaskOptions options;
this->AssertMask("[true, false, true]", "[0, 1, 0]", options, "[false]");
this->AssertMask("[null, false, true]", "[0, 1, 0]", options, "[false]");
this->AssertMask("[true, false, true]", "[null, 1, 0]", options, "[null, false]");
}

template <typename ArrowType>
class TestMaskKernelWithNumeric : public TestMaskKernel<ArrowType> {
protected:
void AssertMask(const std::string& values, const std::string& mask, MaskOptions options,
const std::string& expected) {
TestMaskKernel<ArrowType>::AssertMask(type_singleton(), values, mask, options,
expected);
}
std::shared_ptr<DataType> type_singleton() {
return TypeTraits<ArrowType>::type_singleton();
}
};

TYPED_TEST_CASE(TestMaskKernelWithNumeric, NumericArrowTypes);
TYPED_TEST(TestMaskKernelWithNumeric, MaskNumeric) {
MaskOptions options;
this->AssertMask("[7, 8, 9]", "[0, 1, 0]", options, "[8]");
this->AssertMask("[null, 8, 9]", "[0, 1, 0]", options, "[8]");
this->AssertMask("[7, 8, 9]", "[null, 1, 0]", options, "[null, 8]");
}

class TestMaskKernelWithString : public TestMaskKernel<StringType> {
protected:
void AssertMask(const std::string& values, const std::string& mask, MaskOptions options,
const std::string& expected) {
TestMaskKernel<StringType>::AssertMask(utf8(), values, mask, options, expected);
}
void AssertMaskDictionary(const std::string& dictionary_values,
const std::string& dictionary_mask, const std::string& mask,
MaskOptions options, const std::string& expected_mask) {
auto dict = ArrayFromJSON(utf8(), dictionary_values);
auto type = dictionary(int8(), utf8());
std::shared_ptr<Array> values, actual, expected;
ASSERT_OK(DictionaryArray::FromArrays(type, ArrayFromJSON(int8(), dictionary_mask),
dict, &values));
ASSERT_OK(DictionaryArray::FromArrays(type, ArrayFromJSON(int8(), expected_mask),
dict, &expected));
auto take_mask = ArrayFromJSON(boolean(), mask);
this->AssertMaskArrays(values, take_mask, options, expected);
}
};

TEST_F(TestMaskKernelWithString, MaskString) {
MaskOptions options;
this->AssertMask(R"(["a", "b", "c"])", "[0, 1, 0]", options, R"(["b"])");
this->AssertMask(R"([null, "b", "c"])", "[0, 1, 0]", options, R"(["b"])");
this->AssertMask(R"(["a", "b", "c"])", "[null, 1, 0]", options, R"([null, "b"])");
}

TEST_F(TestMaskKernelWithString, MaskDictionary) {
MaskOptions options;
auto dict = R"(["a", "b", "c", "d", "e"])";
this->AssertMaskDictionary(dict, "[3, 4, 2]", "[0, 1, 0]", options, "[4]");
this->AssertMaskDictionary(dict, "[null, 4, 2]", "[0, 1, 0]", options, "[4]");
this->AssertMaskDictionary(dict, "[3, 4, 2]", "[null, 1, 0]", options, "[null, 4]");
}

} // namespace compute
} // namespace arrow
230 changes: 230 additions & 0 deletions cpp/src/arrow/compute/kernels/mask.cc
Original file line number Diff line number Diff line change
@@ -0,0 +1,230 @@
// Licensed to the Apache Software Foundation (ASF) under one
// or more contributor license agreements. See the NOTICE file
// distributed with this work for additional information
// returnGegarding copyright ownership. The ASF licenses this file
// to you under the Apache License, Version 2.0 (the
// "License"); you may not use this file except in compliance
// with the License. You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing,
// software distributed under the License is distributed on an
// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
// KIND, either express or implied. See the License for the
// specific language governing permissions and limitations
// under the License.

#include <algorithm>
#include <memory>
#include <utility>

#include "arrow/builder.h"
#include "arrow/compute/context.h"
#include "arrow/compute/kernels/mask.h"
#include "arrow/util/bit-util.h"
#include "arrow/util/checked_cast.h"
#include "arrow/util/logging.h"
#include "arrow/visitor_inline.h"

namespace arrow {
namespace compute {

Status Mask(FunctionContext* context, const Array& values, const Array& mask,
const MaskOptions& options, std::shared_ptr<Array>* out) {
Datum out_datum;
RETURN_NOT_OK(
Mask(context, Datum(values.data()), Datum(mask.data()), options, &out_datum));
*out = out_datum.make_array();
return Status::OK();
}

Status Mask(FunctionContext* context, const Datum& values, const Datum& mask,
const MaskOptions& options, Datum* out) {
MaskKernel kernel(values.type(), options);
RETURN_NOT_OK(kernel.Call(context, values, mask, out));
return Status::OK();
}

struct MaskParameters {
FunctionContext* context;
std::shared_ptr<Array> values, mask;
MaskOptions options;
std::shared_ptr<Array>* out;
};

template <typename Builder, typename Scalar>
static Status UnsafeAppend(Builder* builder, Scalar&& value) {
builder->UnsafeAppend(std::forward<Scalar>(value));
return Status::OK();
}

static Status UnsafeAppend(BinaryBuilder* builder, util::string_view value) {
RETURN_NOT_OK(builder->ReserveData(static_cast<int64_t>(value.size())));
builder->UnsafeAppend(value);
return Status::OK();
}

static Status UnsafeAppend(StringBuilder* builder, util::string_view value) {
RETURN_NOT_OK(builder->ReserveData(static_cast<int64_t>(value.size())));
builder->UnsafeAppend(value);
return Status::OK();
}

// TODO(bkietz) this can be optimized
static int64_t OutputSize(const BooleanArray& mask) {
auto offset = mask.offset();
auto length = mask.length();
internal::BitmapReader mask_data(mask.data()->buffers[1]->data(), offset, length);
int64_t size = 0;
for (auto i = offset; i < offset + length; ++i) {
if (mask.IsNull(i) || mask_data.IsSet()) {
++size;
}
mask_data.Next();
}
return size;
}

template <bool AllValuesValid, bool WholeMaskValid, typename ValueArray,
typename OutBuilder>
Status MaskImpl(FunctionContext*, const ValueArray& values, const BooleanArray& mask,
OutBuilder* builder) {
auto offset = mask.offset();
auto length = mask.length();
internal::BitmapReader mask_data(mask.data()->buffers[1]->data(), offset, length);
for (int64_t i = 0; i < mask.length(); mask_data.Next(), ++i) {
if (!WholeMaskValid && mask.IsNull(i)) {
builder->UnsafeAppendNull();
continue;
}
if (mask_data.IsNotSet()) {
continue;
}
if (!AllValuesValid && values.IsNull(i)) {
builder->UnsafeAppendNull();
continue;
}
RETURN_NOT_OK(UnsafeAppend(builder, values.GetView(i)));
}
return Status::OK();
}

template <bool AllValuesValid, typename ValueArray, typename MaskArray,
typename OutBuilder>
Status UnpackMaskNullCount(FunctionContext* context, const ValueArray& values,
const MaskArray& mask, OutBuilder* builder) {
if (mask.null_count() == 0) {
return MaskImpl<AllValuesValid, true>(context, values, mask, builder);
}
return MaskImpl<AllValuesValid, false>(context, values, mask, builder);
}

template <typename ValueArray, typename MaskArray, typename OutBuilder>
Status UnpackValuesNullCount(FunctionContext* context, const ValueArray& values,
const MaskArray& mask, OutBuilder* builder) {
if (values.null_count() == 0) {
return UnpackMaskNullCount<true>(context, values, mask, builder);
}
return UnpackMaskNullCount<false>(context, values, mask, builder);
}

template <typename T>
using ArrayType = typename TypeTraits<T>::ArrayType;

template <typename MaskType>
struct UnpackValues {
template <typename ValueType>
Status Visit(const ValueType&) {
using OutBuilder = typename TypeTraits<ValueType>::BuilderType;
auto&& mask = static_cast<const ArrayType<MaskType>&>(*params_.mask);
auto&& values = static_cast<const ArrayType<ValueType>&>(*params_.values);
std::unique_ptr<ArrayBuilder> builder;
RETURN_NOT_OK(MakeBuilder(params_.context->memory_pool(), values.type(), &builder));
RETURN_NOT_OK(builder->Reserve(OutputSize(mask)));
RETURN_NOT_OK(UnpackValuesNullCount(params_.context, values, mask,
static_cast<OutBuilder*>(builder.get())));
return builder->Finish(params_.out);
}

Status Visit(const NullType& t) {
auto&& mask = static_cast<const ArrayType<MaskType>&>(*params_.mask);
params_.out->reset(new NullArray(OutputSize(mask)));
return Status::OK();
}

Status Visit(const DictionaryType& t) {
std::shared_ptr<Array> masked_indices;
const auto& values = internal::checked_cast<const DictionaryArray&>(*params_.values);
{
// To take from a dictionary, apply the current kernel to the dictionary's
// mask. (Use UnpackValues<MaskType> since MaskType is already unpacked)
MaskParameters params = params_;
params.values = values.indices();
params.out = &masked_indices;
UnpackValues<MaskType> unpack = {params};
RETURN_NOT_OK(VisitTypeInline(*t.index_type(), &unpack));
}
// create output dictionary from taken mask
*params_.out = std::make_shared<DictionaryArray>(values.type(), masked_indices,
values.dictionary());
return Status::OK();
}

Status Visit(const ExtensionType& t) {
// XXX can we just take from its storage?
return Status::NotImplemented("gathering values of type ", t);
}

Status Visit(const UnionType& t) {
return Status::NotImplemented("gathering values of type ", t);
}

Status Visit(const ListType& t) {
return Status::NotImplemented("gathering values of type ", t);
}

Status Visit(const FixedSizeListType& t) {
return Status::NotImplemented("gathering values of type ", t);
}

Status Visit(const StructType& t) {
return Status::NotImplemented("gathering values of type ", t);
}

const MaskParameters& params_;
};

struct UnpackMask {
Status Visit(const BooleanType&) {
UnpackValues<BooleanType> unpack = {params_};
return VisitTypeInline(*params_.values->type(), &unpack);
}

Status Visit(const DataType& other) {
return Status::TypeError("mask type not supported: ", other);
}

const MaskParameters& params_;
};

Status MaskKernel::Call(FunctionContext* ctx, const Datum& values, const Datum& mask,
Datum* out) {
if (!values.is_array() || !mask.is_array()) {
return Status::Invalid("MaskKernel expects array values and mask");
}
std::shared_ptr<Array> out_array;
MaskParameters params;
params.context = ctx;
params.values = values.make_array();
params.mask = mask.make_array();
params.options = options_;
params.out = &out_array;
UnpackMask unpack = {params};
RETURN_NOT_OK(VisitTypeInline(*mask.type(), &unpack));
*out = Datum(out_array);
return Status::OK();
}

} // namespace compute
} // namespace arrow
Loading

0 comments on commit 13a1969

Please sign in to comment.