Skip to content

Commit

Permalink
revert removal of TakeOptions
Browse files Browse the repository at this point in the history
  • Loading branch information
bkietz committed May 22, 2019
1 parent 9606ef0 commit 7f86e37
Show file tree
Hide file tree
Showing 3 changed files with 64 additions and 45 deletions.
86 changes: 48 additions & 38 deletions cpp/src/arrow/compute/kernels/take-test.cc
Original file line number Diff line number Diff line change
Expand Up @@ -35,75 +35,81 @@ template <typename ArrowType>
class TestTakeKernel : public ComputeFixture, public TestBase {
protected:
void AssertTakeArrays(const std::shared_ptr<Array>& values,
const std::shared_ptr<Array>& indices,
const std::shared_ptr<Array>& indices, TakeOptions options,
const std::shared_ptr<Array>& expected) {
std::shared_ptr<Array> actual;
ASSERT_OK(arrow::compute::Take(&this->ctx_, *values, *indices, &actual));
ASSERT_OK(arrow::compute::Take(&this->ctx_, *values, *indices, options, &actual));
AssertArraysEqual(*expected, *actual);
}
void AssertTake(const std::shared_ptr<DataType>& type, const std::string& values,
const std::string& indices, const std::string& expected) {
const std::string& indices, TakeOptions options,
const std::string& expected) {
std::shared_ptr<Array> actual;

for (auto index_type : {int8(), uint32()}) {
ASSERT_OK(this->Take(type, values, index_type, indices, &actual));
ASSERT_OK(this->Take(type, values, index_type, indices, options, &actual));
AssertArraysEqual(*ArrayFromJSON(type, expected), *actual);
}
}
Status Take(const std::shared_ptr<DataType>& type, const std::string& values,
const std::shared_ptr<DataType>& index_type, const std::string& indices,
std::shared_ptr<Array>* out) {
TakeOptions options, std::shared_ptr<Array>* out) {
return arrow::compute::Take(&this->ctx_, *ArrayFromJSON(type, values),
*ArrayFromJSON(index_type, indices), out);
*ArrayFromJSON(index_type, indices), options, out);
}
};

class TestTakeKernelWithNull : public TestTakeKernel<NullType> {
protected:
void AssertTake(const std::string& values, const std::string& indices,
const std::string& expected) {
TestTakeKernel<NullType>::AssertTake(utf8(), values, indices, expected);
TakeOptions options, const std::string& expected) {
TestTakeKernel<NullType>::AssertTake(utf8(), values, indices, options, expected);
}
};

TEST_F(TestTakeKernelWithNull, TakeNull) {
this->AssertTake("[null, null, null]", "[0, 1, 0]", "[null, null, null]");
TakeOptions options;
this->AssertTake("[null, null, null]", "[0, 1, 0]", options, "[null, null, null]");

std::shared_ptr<Array> arr;
ASSERT_RAISES(IndexError,
this->Take(null(), "[null, null, null]", int8(), "[0, 9, 0]", &arr));
ASSERT_RAISES(IndexError, this->Take(null(), "[null, null, null]", int8(), "[0, 9, 0]",
options, &arr));
}

TEST_F(TestTakeKernelWithNull, InvalidIndexType) {
TakeOptions options;
std::shared_ptr<Array> arr;
ASSERT_RAISES(TypeError, this->Take(null(), "[null, null, null]", float32(),
"[0.0, 1.0, 0.1]", &arr));
"[0.0, 1.0, 0.1]", options, &arr));
}

class TestTakeKernelWithBoolean : public TestTakeKernel<BooleanType> {
protected:
void AssertTake(const std::string& values, const std::string& indices,
const std::string& expected) {
TestTakeKernel<BooleanType>::AssertTake(boolean(), values, indices, expected);
TakeOptions options, const std::string& expected) {
TestTakeKernel<BooleanType>::AssertTake(boolean(), values, indices, options,
expected);
}
};

TEST_F(TestTakeKernelWithBoolean, TakeBoolean) {
this->AssertTake("[true, false, true]", "[0, 1, 0]", "[true, false, true]");
this->AssertTake("[null, false, true]", "[0, 1, 0]", "[null, false, null]");
this->AssertTake("[true, false, true]", "[null, 1, 0]", "[null, false, true]");
TakeOptions options;
this->AssertTake("[true, false, true]", "[0, 1, 0]", options, "[true, false, true]");
this->AssertTake("[null, false, true]", "[0, 1, 0]", options, "[null, false, null]");
this->AssertTake("[true, false, true]", "[null, 1, 0]", options, "[null, false, true]");

std::shared_ptr<Array> arr;
ASSERT_RAISES(IndexError,
this->Take(boolean(), "[true, false, true]", int8(), "[0, 9, 0]", &arr));
ASSERT_RAISES(IndexError, this->Take(boolean(), "[true, false, true]", int8(),
"[0, 9, 0]", options, &arr));
}

template <typename ArrowType>
class TestTakeKernelWithNumeric : public TestTakeKernel<ArrowType> {
protected:
void AssertTake(const std::string& values, const std::string& indices,
const std::string& expected) {
TestTakeKernel<ArrowType>::AssertTake(type_singleton(), values, indices, expected);
TakeOptions options, const std::string& expected) {
TestTakeKernel<ArrowType>::AssertTake(type_singleton(), values, indices, options,
expected);
}
std::shared_ptr<DataType> type_singleton() {
return TypeTraits<ArrowType>::type_singleton();
Expand All @@ -112,25 +118,26 @@ class TestTakeKernelWithNumeric : public TestTakeKernel<ArrowType> {

TYPED_TEST_CASE(TestTakeKernelWithNumeric, NumericArrowTypes);
TYPED_TEST(TestTakeKernelWithNumeric, TakeNumeric) {
this->AssertTake("[7, 8, 9]", "[0, 1, 0]", "[7, 8, 7]");
this->AssertTake("[null, 8, 9]", "[0, 1, 0]", "[null, 8, null]");
this->AssertTake("[7, 8, 9]", "[null, 1, 0]", "[null, 8, 7]");
this->AssertTake("[null, 8, 9]", "[]", "[]");
TakeOptions options;
this->AssertTake("[7, 8, 9]", "[0, 1, 0]", options, "[7, 8, 7]");
this->AssertTake("[null, 8, 9]", "[0, 1, 0]", options, "[null, 8, null]");
this->AssertTake("[7, 8, 9]", "[null, 1, 0]", options, "[null, 8, 7]");
this->AssertTake("[null, 8, 9]", "[]", options, "[]");

std::shared_ptr<Array> arr;
ASSERT_RAISES(IndexError, this->Take(this->type_singleton(), "[7, 8, 9]", int8(),
"[0, 9, 0]", &arr));
"[0, 9, 0]", options, &arr));
}

class TestTakeKernelWithString : public TestTakeKernel<StringType> {
protected:
void AssertTake(const std::string& values, const std::string& indices,
const std::string& expected) {
TestTakeKernel<StringType>::AssertTake(utf8(), values, indices, expected);
TakeOptions options, const std::string& expected) {
TestTakeKernel<StringType>::AssertTake(utf8(), values, indices, options, expected);
}
void AssertTakeDictionary(const std::string& dictionary_values,
const std::string& dictionary_indices,
const std::string& indices,
const std::string& indices, TakeOptions options,
const std::string& expected_indices) {
auto dict = ArrayFromJSON(utf8(), dictionary_values);
auto type = dictionary(int8(), utf8());
Expand All @@ -140,25 +147,28 @@ class TestTakeKernelWithString : public TestTakeKernel<StringType> {
ASSERT_OK(DictionaryArray::FromArrays(type, ArrayFromJSON(int8(), expected_indices),
dict, &expected));
auto take_indices = ArrayFromJSON(int8(), indices);
this->AssertTakeArrays(values, take_indices, expected);
this->AssertTakeArrays(values, take_indices, options, expected);
}
};

TEST_F(TestTakeKernelWithString, TakeString) {
this->AssertTake(R"(["a", "b", "c"])", "[0, 1, 0]", R"(["a", "b", "a"])");
this->AssertTake(R"([null, "b", "c"])", "[0, 1, 0]", "[null, \"b\", null]");
this->AssertTake(R"(["a", "b", "c"])", "[null, 1, 0]", R"([null, "b", "a"])");
TakeOptions options;
this->AssertTake(R"(["a", "b", "c"])", "[0, 1, 0]", options, R"(["a", "b", "a"])");
this->AssertTake(R"([null, "b", "c"])", "[0, 1, 0]", options, "[null, \"b\", null]");
this->AssertTake(R"(["a", "b", "c"])", "[null, 1, 0]", options, R"([null, "b", "a"])");

std::shared_ptr<Array> arr;
ASSERT_RAISES(IndexError,
this->Take(utf8(), R"(["a", "b", "c"])", int8(), "[0, 9, 0]", &arr));
ASSERT_RAISES(IndexError, this->Take(utf8(), R"(["a", "b", "c"])", int8(), "[0, 9, 0]",
options, &arr));
}

TEST_F(TestTakeKernelWithString, TakeDictionary) {
TakeOptions options;
auto dict = R"(["a", "b", "c", "d", "e"])";
this->AssertTakeDictionary(dict, "[3, 4, 2]", "[0, 1, 0]", "[3, 4, 3]");
this->AssertTakeDictionary(dict, "[null, 4, 2]", "[0, 1, 0]", "[null, 4, null]");
this->AssertTakeDictionary(dict, "[3, 4, 2]", "[null, 1, 0]", "[null, 4, 3]");
this->AssertTakeDictionary(dict, "[3, 4, 2]", "[0, 1, 0]", options, "[3, 4, 3]");
this->AssertTakeDictionary(dict, "[null, 4, 2]", "[0, 1, 0]", options,
"[null, 4, null]");
this->AssertTakeDictionary(dict, "[3, 4, 2]", "[null, 1, 0]", options, "[null, 4, 3]");
}

} // namespace compute
Expand Down
11 changes: 7 additions & 4 deletions cpp/src/arrow/compute/kernels/take.cc
Original file line number Diff line number Diff line change
Expand Up @@ -31,23 +31,25 @@ namespace compute {
using internal::checked_cast;

Status Take(FunctionContext* context, const Array& values, const Array& indices,
std::shared_ptr<Array>* out) {
const TakeOptions& options, std::shared_ptr<Array>* out) {
Datum out_datum;
RETURN_NOT_OK(Take(context, Datum(values.data()), Datum(indices.data()), &out_datum));
RETURN_NOT_OK(
Take(context, Datum(values.data()), Datum(indices.data()), options, &out_datum));
*out = out_datum.make_array();
return Status::OK();
}

Status Take(FunctionContext* context, const Datum& values, const Datum& indices,
Datum* out) {
TakeKernel kernel(values.type());
const TakeOptions& options, Datum* out) {
TakeKernel kernel(values.type(), options);
RETURN_NOT_OK(kernel.Call(context, values, indices, out));
return Status::OK();
}

struct TakeParameters {
FunctionContext* context;
std::shared_ptr<Array> values, indices;
TakeOptions options;
std::shared_ptr<Array>* out;
};

Expand Down Expand Up @@ -211,6 +213,7 @@ Status TakeKernel::Call(FunctionContext* ctx, const Datum& values, const Datum&
params.context = ctx;
params.values = values.make_array();
params.indices = indices.make_array();
params.options = options_;
params.out = &out_array;
UnpackIndices unpack = {params};
RETURN_NOT_OK(VisitTypeInline(*indices.type(), &unpack));
Expand Down
12 changes: 9 additions & 3 deletions cpp/src/arrow/compute/kernels/take.h
Original file line number Diff line number Diff line change
Expand Up @@ -31,6 +31,8 @@ namespace compute {

class FunctionContext;

struct ARROW_EXPORT TakeOptions {};

/// \brief Take from an array of values at indices in another array
///
/// The output array will be of the same type as the input values
Expand All @@ -45,25 +47,28 @@ class FunctionContext;
/// \param[in] context the FunctionContext
/// \param[in] values array from which to take
/// \param[in] indices which values to take
/// \param[in] options options
/// \param[out] out resulting array
ARROW_EXPORT
Status Take(FunctionContext* context, const Array& values, const Array& indices,
std::shared_ptr<Array>* out);
const TakeOptions& options, std::shared_ptr<Array>* out);

/// \brief Take from an array of values at indices in another array
///
/// \param[in] context the FunctionContext
/// \param[in] values datum from which to take
/// \param[in] indices which values to take
/// \param[in] options options
/// \param[out] out resulting datum
ARROW_EXPORT
Status Take(FunctionContext* context, const Datum& values, const Datum& indices,
Datum* out);
const TakeOptions& options, Datum* out);

/// \brief BinaryKernel implementing Take operation
class ARROW_EXPORT TakeKernel : public BinaryKernel {
public:
explicit TakeKernel(const std::shared_ptr<DataType>& type) : type_(type) {}
explicit TakeKernel(const std::shared_ptr<DataType>& type, TakeOptions options = {})
: type_(type), options_(options) {}

Status Call(FunctionContext* ctx, const Datum& values, const Datum& indices,
Datum* out) override;
Expand All @@ -72,6 +77,7 @@ class ARROW_EXPORT TakeKernel : public BinaryKernel {

private:
std::shared_ptr<DataType> type_;
TakeOptions options_;
};
} // namespace compute
} // namespace arrow

0 comments on commit 7f86e37

Please sign in to comment.