Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

GH-37782: [C++] Add CanReferenceFieldsByNames method to arrow::StructArray #37823

Merged
merged 6 commits into from
Sep 25, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
16 changes: 16 additions & 0 deletions cpp/src/arrow/array/array_nested.cc
Original file line number Diff line number Diff line change
Expand Up @@ -627,6 +627,22 @@ std::shared_ptr<Array> StructArray::GetFieldByName(const std::string& name) cons
return i == -1 ? nullptr : field(i);
}

Status StructArray::CanReferenceFieldByName(const std::string& name) const {
if (GetFieldByName(name) == nullptr) {
return Status::Invalid("Field named '", name,
"' not found or not unique in the struct.");
}
return Status::OK();
}

Status StructArray::CanReferenceFieldsByNames(
const std::vector<std::string>& names) const {
for (const auto& name : names) {
ARROW_RETURN_NOT_OK(CanReferenceFieldByName(name));
}
return Status::OK();
}

Result<ArrayVector> StructArray::Flatten(MemoryPool* pool) const {
ArrayVector flattened;
flattened.resize(data_->child_data.size());
Expand Down
6 changes: 6 additions & 0 deletions cpp/src/arrow/array/array_nested.h
Original file line number Diff line number Diff line change
Expand Up @@ -404,6 +404,12 @@ class ARROW_EXPORT StructArray : public Array {
/// Returns null if name not found
std::shared_ptr<Array> GetFieldByName(const std::string& name) const;

/// Indicate if field named `name` can be found unambiguously in the struct.
Status CanReferenceFieldByName(const std::string& name) const;

/// Indicate if fields named `names` can be found unambiguously in the struct.
Status CanReferenceFieldsByNames(const std::vector<std::string>& names) const;

/// \brief Flatten this array as a vector of arrays, one for each field
///
/// \param[in] pool The pool to allocate null bitmaps from, if necessary
Expand Down
52 changes: 52 additions & 0 deletions cpp/src/arrow/array/array_struct_test.cc
Original file line number Diff line number Diff line change
Expand Up @@ -303,6 +303,58 @@ TEST(StructArray, FlattenOfSlice) {
ASSERT_OK(arr->ValidateFull());
}

TEST(StructArray, CanReferenceFieldByName) {
auto a = ArrayFromJSON(int8(), "[4, 5]");
auto b = ArrayFromJSON(int16(), "[6, 7]");
auto c = ArrayFromJSON(int32(), "[8, 9]");
auto d = ArrayFromJSON(int64(), "[10, 11]");
auto children = std::vector<std::shared_ptr<Array>>{a, b, c, d};

auto f0 = field("f0", int8());
auto f1 = field("f1", int16());
auto f2 = field("f2", int32());
auto f3 = field("f1", int64());
auto type = struct_({f0, f1, f2, f3});

auto arr = std::make_shared<StructArray>(type, 2, children);

ASSERT_OK(arr->CanReferenceFieldByName("f0"));
ASSERT_OK(arr->CanReferenceFieldByName("f2"));
// Not found
ASSERT_RAISES(Invalid, arr->CanReferenceFieldByName("nope"));

// Duplicates
ASSERT_RAISES(Invalid, arr->CanReferenceFieldByName("f1"));
}

TEST(StructArray, CanReferenceFieldsByNames) {
auto a = ArrayFromJSON(int8(), "[4, 5]");
auto b = ArrayFromJSON(int16(), "[6, 7]");
auto c = ArrayFromJSON(int32(), "[8, 9]");
auto d = ArrayFromJSON(int64(), "[10, 11]");
auto children = std::vector<std::shared_ptr<Array>>{a, b, c, d};

auto f0 = field("f0", int8());
auto f1 = field("f1", int16());
auto f2 = field("f2", int32());
auto f3 = field("f1", int64());
auto type = struct_({f0, f1, f2, f3});

auto arr = std::make_shared<StructArray>(type, 2, children);

ASSERT_OK(arr->CanReferenceFieldsByNames({"f0", "f2"}));
ASSERT_OK(arr->CanReferenceFieldsByNames({"f2", "f0"}));

// Not found
ASSERT_RAISES(Invalid, arr->CanReferenceFieldsByNames({"nope"}));
ASSERT_RAISES(Invalid, arr->CanReferenceFieldsByNames({"f0", "nope"}));
// Duplicates
ASSERT_RAISES(Invalid, arr->CanReferenceFieldsByNames({"f1"}));
ASSERT_RAISES(Invalid, arr->CanReferenceFieldsByNames({"f0", "f1"}));
// Both
ASSERT_RAISES(Invalid, arr->CanReferenceFieldsByNames({"f0", "f1", "nope"}));
}

// ----------------------------------------------------------------------------------
// Struct test
class TestStructBuilder : public ::testing::Test {
Expand Down
14 changes: 9 additions & 5 deletions cpp/src/arrow/type.cc
Original file line number Diff line number Diff line change
Expand Up @@ -1847,14 +1847,18 @@ std::vector<int> Schema::GetAllFieldIndices(const std::string& name) const {
return result;
}

Status Schema::CanReferenceFieldByName(const std::string& name) const {
if (GetFieldByName(name) == nullptr) {
return Status::Invalid("Field named '", name,
"' not found or not unique in the schema.");
}
return Status::OK();
}

Status Schema::CanReferenceFieldsByNames(const std::vector<std::string>& names) const {
for (const auto& name : names) {
if (GetFieldByName(name) == nullptr) {
return Status::Invalid("Field named '", name,
"' not found or not unique in the schema.");
}
ARROW_RETURN_NOT_OK(CanReferenceFieldByName(name));
}

return Status::OK();
}

Expand Down
3 changes: 3 additions & 0 deletions cpp/src/arrow/type.h
Original file line number Diff line number Diff line change
Expand Up @@ -2048,6 +2048,9 @@ class ARROW_EXPORT Schema : public detail::Fingerprintable,
/// Return the indices of all fields having this name
std::vector<int> GetAllFieldIndices(const std::string& name) const;

/// Indicate if field named `name` can be found unambiguously in the schema.
Status CanReferenceFieldByName(const std::string& name) const;

/// Indicate if fields named `names` can be found unambiguously in the schema.
Status CanReferenceFieldsByNames(const std::vector<std::string>& names) const;

Expand Down
18 changes: 18 additions & 0 deletions cpp/src/arrow/type_test.cc
Original file line number Diff line number Diff line change
Expand Up @@ -548,6 +548,24 @@ TEST_F(TestSchema, GetFieldDuplicates) {
ASSERT_EQ(results.size(), 0);
}

TEST_F(TestSchema, CanReferenceFieldByName) {
auto f0 = field("f0", int32());
auto f1 = field("f1", uint8(), false);
auto f2 = field("f2", utf8());
auto f3 = field("f1", list(int16()));

auto schema = ::arrow::schema({f0, f1, f2, f3});

ASSERT_OK(schema->CanReferenceFieldByName("f0"));
ASSERT_OK(schema->CanReferenceFieldByName("f2"));

// Not found
ASSERT_RAISES(Invalid, schema->CanReferenceFieldByName("nope"));

// Duplicates
ASSERT_RAISES(Invalid, schema->CanReferenceFieldByName("f1"));
}

TEST_F(TestSchema, CanReferenceFieldsByNames) {
auto f0 = field("f0", int32());
auto f1 = field("f1", uint8(), false);
Expand Down
Loading