diff --git a/cpp/src/arrow/array/array_nested.cc b/cpp/src/arrow/array/array_nested.cc index df60074c78470..d8308c824953a 100644 --- a/cpp/src/arrow/array/array_nested.cc +++ b/cpp/src/arrow/array/array_nested.cc @@ -627,6 +627,22 @@ std::shared_ptr StructArray::GetFieldByName(const std::string& name) cons return i == -1 ? nullptr : field(i); } +Status StructArray::CanReferenceFieldByName(const std::string& name) const { + if (GetFieldByName(name) == nullptr) { + return Status::Invalid("Field named '", name, + "' not found or not unique in the struct."); + } + return Status::OK(); +} + +Status StructArray::CanReferenceFieldsByNames( + const std::vector& names) const { + for (const auto& name : names) { + ARROW_RETURN_NOT_OK(CanReferenceFieldByName(name)); + } + return Status::OK(); +} + Result StructArray::Flatten(MemoryPool* pool) const { ArrayVector flattened; flattened.resize(data_->child_data.size()); diff --git a/cpp/src/arrow/array/array_nested.h b/cpp/src/arrow/array/array_nested.h index 47c1db039ccc9..8d5cc95fec00d 100644 --- a/cpp/src/arrow/array/array_nested.h +++ b/cpp/src/arrow/array/array_nested.h @@ -404,6 +404,12 @@ class ARROW_EXPORT StructArray : public Array { /// Returns null if name not found std::shared_ptr GetFieldByName(const std::string& name) const; + /// Indicate if field named `name` can be found unambiguously in the struct. + Status CanReferenceFieldByName(const std::string& name) const; + + /// Indicate if fields named `names` can be found unambiguously in the struct. + Status CanReferenceFieldsByNames(const std::vector& names) const; + /// \brief Flatten this array as a vector of arrays, one for each field /// /// \param[in] pool The pool to allocate null bitmaps from, if necessary diff --git a/cpp/src/arrow/array/array_struct_test.cc b/cpp/src/arrow/array/array_struct_test.cc index 318c83860e009..73d53a7efa59b 100644 --- a/cpp/src/arrow/array/array_struct_test.cc +++ b/cpp/src/arrow/array/array_struct_test.cc @@ -303,6 +303,58 @@ TEST(StructArray, FlattenOfSlice) { ASSERT_OK(arr->ValidateFull()); } +TEST(StructArray, CanReferenceFieldByName) { + auto a = ArrayFromJSON(int8(), "[4, 5]"); + auto b = ArrayFromJSON(int16(), "[6, 7]"); + auto c = ArrayFromJSON(int32(), "[8, 9]"); + auto d = ArrayFromJSON(int64(), "[10, 11]"); + auto children = std::vector>{a, b, c, d}; + + auto f0 = field("f0", int8()); + auto f1 = field("f1", int16()); + auto f2 = field("f2", int32()); + auto f3 = field("f1", int64()); + auto type = struct_({f0, f1, f2, f3}); + + auto arr = std::make_shared(type, 2, children); + + ASSERT_OK(arr->CanReferenceFieldByName("f0")); + ASSERT_OK(arr->CanReferenceFieldByName("f2")); + // Not found + ASSERT_RAISES(Invalid, arr->CanReferenceFieldByName("nope")); + + // Duplicates + ASSERT_RAISES(Invalid, arr->CanReferenceFieldByName("f1")); +} + +TEST(StructArray, CanReferenceFieldsByNames) { + auto a = ArrayFromJSON(int8(), "[4, 5]"); + auto b = ArrayFromJSON(int16(), "[6, 7]"); + auto c = ArrayFromJSON(int32(), "[8, 9]"); + auto d = ArrayFromJSON(int64(), "[10, 11]"); + auto children = std::vector>{a, b, c, d}; + + auto f0 = field("f0", int8()); + auto f1 = field("f1", int16()); + auto f2 = field("f2", int32()); + auto f3 = field("f1", int64()); + auto type = struct_({f0, f1, f2, f3}); + + auto arr = std::make_shared(type, 2, children); + + ASSERT_OK(arr->CanReferenceFieldsByNames({"f0", "f2"})); + ASSERT_OK(arr->CanReferenceFieldsByNames({"f2", "f0"})); + + // Not found + ASSERT_RAISES(Invalid, arr->CanReferenceFieldsByNames({"nope"})); + ASSERT_RAISES(Invalid, arr->CanReferenceFieldsByNames({"f0", "nope"})); + // Duplicates + ASSERT_RAISES(Invalid, arr->CanReferenceFieldsByNames({"f1"})); + ASSERT_RAISES(Invalid, arr->CanReferenceFieldsByNames({"f0", "f1"})); + // Both + ASSERT_RAISES(Invalid, arr->CanReferenceFieldsByNames({"f0", "f1", "nope"})); +} + // ---------------------------------------------------------------------------------- // Struct test class TestStructBuilder : public ::testing::Test { diff --git a/cpp/src/arrow/type.cc b/cpp/src/arrow/type.cc index 3d294a3fa8642..47bf52660ffe9 100644 --- a/cpp/src/arrow/type.cc +++ b/cpp/src/arrow/type.cc @@ -1847,14 +1847,18 @@ std::vector Schema::GetAllFieldIndices(const std::string& name) const { return result; } +Status Schema::CanReferenceFieldByName(const std::string& name) const { + if (GetFieldByName(name) == nullptr) { + return Status::Invalid("Field named '", name, + "' not found or not unique in the schema."); + } + return Status::OK(); +} + Status Schema::CanReferenceFieldsByNames(const std::vector& names) const { for (const auto& name : names) { - if (GetFieldByName(name) == nullptr) { - return Status::Invalid("Field named '", name, - "' not found or not unique in the schema."); - } + ARROW_RETURN_NOT_OK(CanReferenceFieldByName(name)); } - return Status::OK(); } diff --git a/cpp/src/arrow/type.h b/cpp/src/arrow/type.h index 718540d449226..19910979287cc 100644 --- a/cpp/src/arrow/type.h +++ b/cpp/src/arrow/type.h @@ -2048,6 +2048,9 @@ class ARROW_EXPORT Schema : public detail::Fingerprintable, /// Return the indices of all fields having this name std::vector GetAllFieldIndices(const std::string& name) const; + /// Indicate if field named `name` can be found unambiguously in the schema. + Status CanReferenceFieldByName(const std::string& name) const; + /// Indicate if fields named `names` can be found unambiguously in the schema. Status CanReferenceFieldsByNames(const std::vector& names) const; diff --git a/cpp/src/arrow/type_test.cc b/cpp/src/arrow/type_test.cc index c55b33b4151e4..3dbefdcf0c564 100644 --- a/cpp/src/arrow/type_test.cc +++ b/cpp/src/arrow/type_test.cc @@ -548,6 +548,24 @@ TEST_F(TestSchema, GetFieldDuplicates) { ASSERT_EQ(results.size(), 0); } +TEST_F(TestSchema, CanReferenceFieldByName) { + auto f0 = field("f0", int32()); + auto f1 = field("f1", uint8(), false); + auto f2 = field("f2", utf8()); + auto f3 = field("f1", list(int16())); + + auto schema = ::arrow::schema({f0, f1, f2, f3}); + + ASSERT_OK(schema->CanReferenceFieldByName("f0")); + ASSERT_OK(schema->CanReferenceFieldByName("f2")); + + // Not found + ASSERT_RAISES(Invalid, schema->CanReferenceFieldByName("nope")); + + // Duplicates + ASSERT_RAISES(Invalid, schema->CanReferenceFieldByName("f1")); +} + TEST_F(TestSchema, CanReferenceFieldsByNames) { auto f0 = field("f0", int32()); auto f1 = field("f1", uint8(), false);