diff --git a/cpp/src/arrow/chunked_array.cc b/cpp/src/arrow/chunked_array.cc index c5e6d7fa4bdf0..c4a72323d3889 100644 --- a/cpp/src/arrow/chunked_array.cc +++ b/cpp/src/arrow/chunked_array.cc @@ -93,7 +93,8 @@ bool ChunkedArray::Equals(const ChunkedArray& other) const { return false; } // We cannot toggle check_metadata here yet, so we don't check it - if (!type_->Equals(*other.type_, /*check_metadata=*/false)) { + if (!type_->Equals(*other.type_, /*check_metadata=*/false, + /*check_internal_field_names=*/false)) { return false; } @@ -130,7 +131,8 @@ bool ChunkedArray::ApproxEquals(const ChunkedArray& other, return false; } // We cannot toggle check_metadata here yet, so we don't check it - if (!type_->Equals(*other.type_, /*check_metadata=*/false)) { + if (!type_->Equals(*other.type_, /*check_metadata=*/false, + /*check_internal_field_names=*/false)) { return false; } diff --git a/cpp/src/arrow/compare.cc b/cpp/src/arrow/compare.cc index baadd10cca98b..3e4bf2f47044d 100644 --- a/cpp/src/arrow/compare.cc +++ b/cpp/src/arrow/compare.cc @@ -43,6 +43,7 @@ #include "arrow/util/bitmap_ops.h" #include "arrow/util/bitmap_reader.h" #include "arrow/util/checked_cast.h" +#include "arrow/util/key_value_metadata.h" #include "arrow/util/logging.h" #include "arrow/util/macros.h" #include "arrow/util/memory.h" @@ -530,7 +531,8 @@ bool CompareArrayRanges(const ArrayData& left, const ArrayData& right, int64_t right_start_idx, const EqualOptions& options, bool floating_approximate) { if (left.type->id() != right.type->id() || - !TypeEquals(*left.type, *right.type, false /* check_metadata */)) { + !TypeEquals(*left.type, *right.type, /*check_metadata=*/false, + /*check_internal_field_names=*/false)) { return false; } @@ -556,8 +558,20 @@ bool CompareArrayRanges(const ArrayData& left, const ArrayData& right, class TypeEqualsVisitor { public: - explicit TypeEqualsVisitor(const DataType& right, bool check_metadata) - : right_(right), check_metadata_(check_metadata), result_(false) {} + 
explicit TypeEqualsVisitor(const DataType& right, bool check_metadata, + bool check_internal_field_names) + : right_(right), + check_metadata_(check_metadata), + check_internal_field_names_(check_internal_field_names), + result_(false) {} + + bool MetadataEqual(const Field& left, const Field& right) { + if (left.HasMetadata() && right.HasMetadata()) { + return left.metadata()->Equals(*right.metadata()); + } else { + return !left.HasMetadata() && !right.HasMetadata(); + } + } Status VisitChildren(const DataType& left) { if (left.num_fields() != right_.num_fields()) { @@ -626,8 +640,23 @@ class TypeEqualsVisitor { } template - enable_if_t::value || is_struct_type::value, Status> Visit( - const T& left) { + enable_if_t::value, Status> Visit(const T& left) { + std::shared_ptr left_field = left.field(0); + std::shared_ptr right_field = checked_cast(right_).field(0); + bool equal_names = + !check_internal_field_names_ || (left_field->name() == right_field->name()); + bool equal_metadata = !check_metadata_ || MetadataEqual(*left_field, *right_field); + + result_ = equal_names && equal_metadata && + (left_field->nullable() == right_field->nullable()) && + left_field->type()->Equals(*right_field->type(), check_metadata_, + check_internal_field_names_); + + return Status::OK(); + } + + template + enable_if_t::value, Status> Visit(const T& left) { return VisitChildren(left); } @@ -637,8 +666,23 @@ class TypeEqualsVisitor { result_ = false; return Status::OK(); } - result_ = left.key_type()->Equals(*right.key_type(), check_metadata_) && - left.item_type()->Equals(*right.item_type(), check_metadata_); + if (check_internal_field_names_ && + (left.item_field()->name() != right.item_field()->name() || + left.key_field()->name() != right.key_field()->name() || + left.value_field()->name() != right.value_field()->name())) { + result_ = false; + return Status::OK(); + } + if (check_metadata_ && !(MetadataEqual(*left.item_field(), *right.item_field()) && + 
MetadataEqual(*left.key_field(), *right.key_field()) && + MetadataEqual(*left.value_field(), *right.value_field()))) { + result_ = false; + return Status::OK(); + } + result_ = left.key_type()->Equals(*right.key_type(), check_metadata_, + check_internal_field_names_) && + left.item_type()->Equals(*right.item_type(), check_metadata_, + check_internal_field_names_); return Status::OK(); } @@ -676,6 +720,7 @@ class TypeEqualsVisitor { protected: const DataType& right_; bool check_metadata_; + bool check_internal_field_names_; bool result_; }; @@ -1267,13 +1312,14 @@ bool SparseTensorEquals(const SparseTensor& left, const SparseTensor& right, } } -bool TypeEquals(const DataType& left, const DataType& right, bool check_metadata) { +bool TypeEquals(const DataType& left, const DataType& right, bool check_metadata, + bool check_internal_field_names) { // The arrays are the same object if (&left == &right) { return true; } else if (left.id() != right.id()) { return false; - } else { + } else if (!check_internal_field_names) { // First try to compute fingerprints if (check_metadata) { const auto& left_metadata_fp = left.metadata_fingerprint(); @@ -1288,15 +1334,15 @@ bool TypeEquals(const DataType& left, const DataType& right, bool check_metadata if (!left_fp.empty() && !right_fp.empty()) { return left_fp == right_fp; } + } - // TODO remove check_metadata here? - TypeEqualsVisitor visitor(right, check_metadata); - auto error = VisitTypeInline(left, &visitor); - if (!error.ok()) { - DCHECK(false) << "Types are not comparable: " << error.ToString(); - } - return visitor.result(); + // TODO remove check_metadata here? 
+ TypeEqualsVisitor visitor(right, check_metadata, check_internal_field_names); + auto error = VisitTypeInline(left, &visitor); + if (!error.ok()) { + DCHECK(false) << "Types are not comparable: " << error.ToString(); } + return visitor.result(); } } // namespace arrow diff --git a/cpp/src/arrow/compare.h b/cpp/src/arrow/compare.h index 6dbacfa86af59..03db748bcc5e7 100644 --- a/cpp/src/arrow/compare.h +++ b/cpp/src/arrow/compare.h @@ -124,8 +124,11 @@ ARROW_EXPORT bool SparseTensorEquals(const SparseTensor& left, const SparseTenso /// \param[in] right a DataType /// \param[in] check_metadata whether to compare KeyValueMetadata for child /// fields +/// \param[in] check_internal_field_names whether to consider list or map types +/// with differing field names as unequal. ARROW_EXPORT bool TypeEquals(const DataType& left, const DataType& right, - bool check_metadata = true); + bool check_metadata = true, + bool check_internal_field_names = true); /// Returns true if scalars are equal /// \param[in] left a Scalar diff --git a/cpp/src/arrow/compute/cast.cc b/cpp/src/arrow/compute/cast.cc index 99e8b89f1ca13..b0c9c52ca6925 100644 --- a/cpp/src/arrow/compute/cast.cc +++ b/cpp/src/arrow/compute/cast.cc @@ -97,14 +97,14 @@ class CastMetaFunction : public MetaFunction { ARROW_ASSIGN_OR_RAISE(auto cast_options, ValidateOptions(options)); // args[0].type() could be a nullptr so check for that before // we do anything with it. - if (args[0].type() && args[0].type()->Equals(*cast_options->to_type)) { + if (args[0].type() && + args[0].type()->Equals(*cast_options->to_type, /*check_metadata=*/false, + /*check_internal_field_names=*/false)) { // Nested types might differ in field names but still be considered equal, // so we can only return non-nested types as-is. if (!is_nested(args[0].type()->id())) { return args[0]; } else if (args[0].is_array()) { - // TODO(ARROW-14999): if types are equal except for field names of list - // types, we can also use this code path. 
ARROW_ASSIGN_OR_RAISE(std::shared_ptr array, ::arrow::internal::GetArrayView( args[0].array(), cast_options->to_type.owned_type)); diff --git a/cpp/src/arrow/table.cc b/cpp/src/arrow/table.cc index 47f82631782e4..b09b9478c3e69 100644 --- a/cpp/src/arrow/table.cc +++ b/cpp/src/arrow/table.cc @@ -451,7 +451,8 @@ Result> PromoteTableToSchema(const std::shared_ptr const std::shared_ptr& schema, MemoryPool* pool) { const std::shared_ptr current_schema = table->schema(); - if (current_schema->Equals(*schema, /*check_metadata=*/false)) { + if (current_schema->Equals(*schema, /*check_metadata=*/false, + /*check_internal_field_names=*/true)) { return table->ReplaceSchemaMetadata(schema->metadata()); } diff --git a/cpp/src/arrow/type.cc b/cpp/src/arrow/type.cc index ea9525404c816..3602f9d80dc16 100644 --- a/cpp/src/arrow/type.cc +++ b/cpp/src/arrow/type.cc @@ -320,7 +320,7 @@ Result> Field::MergeWith(const Field& other, other.name()); } - if (Equals(other, /*check_metadata=*/false)) { + if (Equals(other, /*check_metadata=*/false, /*check_internal_field_names=*/false)) { return Copy(); } @@ -362,12 +362,14 @@ std::shared_ptr Field::Copy() const { return ::arrow::field(name_, type_, nullable_, metadata_); } -bool Field::Equals(const Field& other, bool check_metadata) const { +bool Field::Equals(const Field& other, bool check_metadata, + bool check_internal_field_names) const { if (this == &other) { return true; } if (this->name_ == other.name_ && this->nullable_ == other.nullable_ && - this->type_->Equals(*other.type_.get(), check_metadata)) { + this->type_->Equals(*other.type_.get(), check_metadata, + check_internal_field_names)) { if (!check_metadata) { return true; } else if (this->HasMetadata() && other.HasMetadata()) { @@ -381,8 +383,9 @@ bool Field::Equals(const Field& other, bool check_metadata) const { return false; } -bool Field::Equals(const std::shared_ptr& other, bool check_metadata) const { - return Equals(*other.get(), check_metadata); +bool Field::Equals(const 
std::shared_ptr& other, bool check_metadata, + bool check_internal_field_names) const { + return Equals(*other.get(), check_metadata, check_internal_field_names); } bool Field::IsCompatibleWith(const Field& other) const { return MergeWith(other).ok(); } @@ -408,15 +411,17 @@ void PrintTo(const Field& field, std::ostream* os) { *os << field.ToString(); } DataType::~DataType() {} -bool DataType::Equals(const DataType& other, bool check_metadata) const { - return TypeEquals(*this, other, check_metadata); +bool DataType::Equals(const DataType& other, bool check_metadata, + bool check_internal_field_names) const { + return TypeEquals(*this, other, check_metadata, check_internal_field_names); } -bool DataType::Equals(const std::shared_ptr& other) const { +bool DataType::Equals(const std::shared_ptr& other, bool check_metadata, + bool check_internal_field_names) const { if (!other) { return false; } - return Equals(*other.get()); + return Equals(*other.get(), check_metadata, check_internal_field_names); } size_t DataType::Hash() const { @@ -1557,7 +1562,8 @@ const std::vector>& Schema::fields() const { return impl_->fields_; } -bool Schema::Equals(const Schema& other, bool check_metadata) const { +bool Schema::Equals(const Schema& other, bool check_metadata, + bool check_internal_field_names) const { if (this == &other) { return true; } @@ -1589,7 +1595,8 @@ bool Schema::Equals(const Schema& other, bool check_metadata) const { // Fall back on field-by-field comparison for (int i = 0; i < num_fields(); ++i) { - if (!field(i)->Equals(*other.field(i).get(), check_metadata)) { + if (!field(i)->Equals(*other.field(i).get(), check_metadata, + check_internal_field_names)) { return false; } } @@ -1597,12 +1604,13 @@ bool Schema::Equals(const Schema& other, bool check_metadata) const { return true; } -bool Schema::Equals(const std::shared_ptr& other, bool check_metadata) const { +bool Schema::Equals(const std::shared_ptr& other, bool check_metadata, + bool 
check_internal_field_names) const { if (other == nullptr) { return false; } - return Equals(*other, check_metadata); + return Equals(*other, check_metadata, check_internal_field_names); } std::shared_ptr Schema::GetFieldByName(const std::string& name) const { @@ -2136,17 +2144,33 @@ std::string DictionaryType::ComputeFingerprint() const { } std::string ListType::ComputeFingerprint() const { - const auto& child_fingerprint = children_[0]->fingerprint(); + const auto& child_fingerprint = value_type()->fingerprint(); if (!child_fingerprint.empty()) { - return TypeIdFingerprint(*this) + "{" + child_fingerprint + "}"; + std::stringstream ss; + ss << TypeIdFingerprint(*this); + if (value_field()->nullable()) { + ss << 'n'; + } else { + ss << 'N'; + } + ss << '{' << child_fingerprint << '}'; + return ss.str(); } return ""; } std::string LargeListType::ComputeFingerprint() const { - const auto& child_fingerprint = children_[0]->fingerprint(); + const auto& child_fingerprint = value_type()->fingerprint(); if (!child_fingerprint.empty()) { - return TypeIdFingerprint(*this) + "{" + child_fingerprint + "}"; + std::stringstream ss; + ss << TypeIdFingerprint(*this); + if (value_field()->nullable()) { + ss << 'n'; + } else { + ss << 'N'; + } + ss << '{' << child_fingerprint << '}'; + return ss.str(); } return ""; } @@ -2155,20 +2179,33 @@ std::string MapType::ComputeFingerprint() const { const auto& key_fingerprint = key_type()->fingerprint(); const auto& item_fingerprint = item_type()->fingerprint(); if (!key_fingerprint.empty() && !item_fingerprint.empty()) { + std::stringstream ss; + ss << TypeIdFingerprint(*this); if (keys_sorted_) { - return TypeIdFingerprint(*this) + "s{" + key_fingerprint + item_fingerprint + "}"; + ss << 's'; + } + if (item_field()->nullable()) { + ss << 'n'; } else { - return TypeIdFingerprint(*this) + "{" + key_fingerprint + item_fingerprint + "}"; + ss << 'N'; } + ss << '{' << key_fingerprint + item_fingerprint << '}'; + return ss.str(); } return ""; } 
std::string FixedSizeListType::ComputeFingerprint() const { - const auto& child_fingerprint = children_[0]->fingerprint(); + const auto& child_fingerprint = value_type()->fingerprint(); if (!child_fingerprint.empty()) { std::stringstream ss; - ss << TypeIdFingerprint(*this) << "[" << list_size_ << "]" + ss << TypeIdFingerprint(*this); + if (value_field()->nullable()) { + ss << 'n'; + } else { + ss << 'N'; + } + ss << "[" << list_size_ << "]" << "{" << child_fingerprint << "}"; return ss.str(); } diff --git a/cpp/src/arrow/type.h b/cpp/src/arrow/type.h index 415aaacf1c9ef..00abd71ba5c05 100644 --- a/cpp/src/arrow/type.h +++ b/cpp/src/arrow/type.h @@ -137,10 +137,24 @@ class ARROW_EXPORT DataType : public std::enable_shared_from_this, /// /// Types that are logically convertible from one to another (e.g. List /// and Binary) are NOT equal. - bool Equals(const DataType& other, bool check_metadata = false) const; + /// + /// \param[in] other the DataType to compare with. + /// \param[in] check_metadata whether to compare KeyValueMetadata for child + /// fields. + /// \param[in] check_internal_field_names whether to consider list or map types + /// with differing field names as unequal. + bool Equals(const DataType& other, bool check_metadata = false, + bool check_internal_field_names = false) const; /// \brief Return whether the types are equal - bool Equals(const std::shared_ptr& other) const; + /// + /// \param[in] other the DataType to compare with. + /// \param[in] check_metadata whether to compare KeyValueMetadata for child + /// fields. + /// \param[in] check_internal_field_names whether to consider list or map types + /// with differing field names as unequal. + bool Equals(const std::shared_ptr& other, bool check_metadata = false, + bool check_internal_field_names = false) const; /// \brief Return the child field at index i. 
const std::shared_ptr& field(int i) const { return children_[i]; } @@ -407,10 +421,14 @@ class ARROW_EXPORT Field : public detail::Fingerprintable, /// \param[in] other field to check equality with. /// \param[in] check_metadata controls if it should check for metadata /// equality. + /// \param[in] check_internal_field_names controls if it should check the + /// field names of nested list or map types for equality. /// /// \return true if fields are equal, false otherwise. - bool Equals(const Field& other, bool check_metadata = false) const; - bool Equals(const std::shared_ptr& other, bool check_metadata = false) const; + bool Equals(const Field& other, bool check_metadata = false, + bool check_internal_field_names = false) const; + bool Equals(const std::shared_ptr& other, bool check_metadata = false, + bool check_internal_field_names = false) const; /// \brief Indicate if fields are compatibles. /// @@ -1885,8 +1903,10 @@ class ARROW_EXPORT Schema : public detail::Fingerprintable, ~Schema() override; /// Returns true if all of the schema fields are equal - bool Equals(const Schema& other, bool check_metadata = false) const; - bool Equals(const std::shared_ptr& other, bool check_metadata = false) const; + bool Equals(const Schema& other, bool check_metadata = false, + bool check_internal_field_names = false) const; + bool Equals(const std::shared_ptr& other, bool check_metadata = false, + bool check_internal_field_names = false) const; /// \brief Set endianness in the schema /// diff --git a/cpp/src/arrow/type_test.cc b/cpp/src/arrow/type_test.cc index 954ad63c8aa68..90d2374f3e8f0 100644 --- a/cpp/src/arrow/type_test.cc +++ b/cpp/src/arrow/type_test.cc @@ -1262,6 +1262,8 @@ TEST(TestLargeListType, Basics) { } TEST(TestMapType, Basics) { + auto md = key_value_metadata({"foo"}, {"foo value"}); + std::shared_ptr kt = std::make_shared(); std::shared_ptr it = std::make_shared(); @@ -1294,6 +1296,46 @@ TEST(TestMapType, Basics) { "some_entries", struct_({field("some_key", kt, false), field("some_value", 
mt)}), false))); AssertTypeEqual(mt3, *mt5); + // ...unless we explicitly ask about them. + ASSERT_FALSE( + mt3.Equals(mt5, /*check_metadata=*/false, /*check_internal_field_names=*/true)); + + // nullability of value type matters in comparisons + MapType map_type_non_nullable(kt, field("value", it, /*nullable=*/false)); + AssertTypeNotEqual(map_type, map_type_non_nullable); +} + +TEST(TestMapType, Metadata) { + auto md1 = key_value_metadata({"foo", "bar"}, {"foo value", "bar value"}); + auto md2 = key_value_metadata({"foo", "bar"}, {"foo value", "bar value"}); + auto md3 = key_value_metadata({"foo"}, {"foo value"}); + + auto t1 = map(utf8(), field("value", int32(), md1)); + auto t2 = map(utf8(), field("value", int32(), md2)); + auto t3 = map(utf8(), field("value", int32(), md3)); + auto t4 = + std::make_shared(field("key", utf8(), md1), field("value", int32(), md2)); + ASSERT_OK_AND_ASSIGN(auto t5, + MapType::Make(field("some_entries", + struct_({field("some_key", utf8(), false), + field("some_value", int32(), md2)}), + false, md2))); + + AssertTypeEqual(*t1, *t2); + AssertTypeEqual(*t1, *t2, /*check_metadata=*/true); + ASSERT_TRUE( + t1->Equals(t2, /*check_metadata=*/true, /*check_internal_field_names=*/true)); + + AssertTypeEqual(*t1, *t3); + AssertTypeNotEqual(*t1, *t3, /*check_metadata=*/true); + ASSERT_FALSE( + t1->Equals(t3, /*check_metadata=*/true, /*check_internal_field_names=*/true)); + + AssertTypeEqual(*t1, *t4); + AssertTypeNotEqual(*t1, *t4, /*check_metadata=*/true); + + AssertTypeEqual(*t1, *t5); + AssertTypeNotEqual(*t1, *t5, /*check_metadata=*/true); } TEST(TestFixedSizeListType, Basics) { @@ -1478,15 +1520,27 @@ TEST(TestListType, Equals) { auto t1 = list(utf8()); auto t2 = list(utf8()); auto t3 = list(binary()); - auto t4 = large_list(binary()); - auto t5 = large_list(binary()); - auto t6 = large_list(float64()); + auto t4 = list(field("item", utf8(), /*nullable=*/false)); + auto tl1 = large_list(binary()); + auto tl2 = large_list(binary()); + 
auto tl3 = large_list(float64()); AssertTypeEqual(*t1, *t2); AssertTypeNotEqual(*t1, *t3); - AssertTypeNotEqual(*t3, *t4); - AssertTypeEqual(*t4, *t5); - AssertTypeNotEqual(*t5, *t6); + AssertTypeNotEqual(*t1, *t4); + AssertTypeNotEqual(*t3, *tl1); + AssertTypeEqual(*tl1, *tl2); + AssertTypeNotEqual(*tl2, *tl3); + + std::shared_ptr vt = std::make_shared(); + std::shared_ptr inner_field = std::make_shared("non_default_name", vt); + + ListType list_type(vt); + ListType list_type_named(inner_field); + + AssertTypeEqual(list_type, list_type_named); + ASSERT_FALSE( + list_type.Equals(list_type_named, false, /*check_internal_field_names*/ true)); } TEST(TestListType, Metadata) { @@ -1507,10 +1561,14 @@ TEST(TestListType, Metadata) { auto t5 = list(f5); AssertTypeEqual(*t1, *t2); - AssertTypeEqual(*t1, *t2, /*check_metadata =*/false); + AssertTypeEqual(*t1, *t2, /*check_metadata =*/true); + ASSERT_TRUE( + t1->Equals(t2, /*check_metadata =*/true, /*check_internal_field_names*/ true)); AssertTypeEqual(*t1, *t3); AssertTypeNotEqual(*t1, *t3, /*check_metadata =*/true); + ASSERT_FALSE( + t1->Equals(t3, /*check_metadata =*/true, /*check_internal_field_names*/ true)); AssertTypeEqual(*t1, *t4); AssertTypeNotEqual(*t1, *t4, /*check_metadata =*/true); diff --git a/java/c/src/test/python/integration_tests.py b/java/c/src/test/python/integration_tests.py index 33ff1cf4a9af5..1773837bc7db5 100644 --- a/java/c/src/test/python/integration_tests.py +++ b/java/c/src/test/python/integration_tests.py @@ -16,6 +16,7 @@ # under the License. 
import decimal +from email.policy import strict import gc import os import sys @@ -142,7 +143,7 @@ def round_trip_field(self, field_generator): expected = field_generator() self.assertEqual(expected, new_field) - def round_trip_array(self, array_generator, expected_diff=None): + def round_trip_array(self, array_generator, ignore_field_names=False): original_arr = array_generator() with self.bridge.java_c.CDataDictionaryProvider() as dictionary_provider, \ self.bridge.python_to_java_array(original_arr, dictionary_provider) as vector: @@ -150,9 +151,11 @@ def round_trip_array(self, array_generator, expected_diff=None): new_array = self.bridge.java_to_python_array(vector, dictionary_provider) expected = array_generator() - if expected_diff: - self.assertEqual(expected, new_array.view(expected.type)) - self.assertEqual(expected.diff(new_array), expected_diff or '') + + self.assertEqual(expected, new_array) + if not ignore_field_names: + self.assertTrue(expected.type.equals(new_array.type, + check_metadata=True, check_internal_field_names=True)) def round_trip_record_batch(self, rb_generator): original_rb = rb_generator() @@ -191,7 +194,7 @@ def test_int_array(self): def test_list_array(self): self.round_trip_array(lambda: pa.array( [[], [0], [1, 2], [4, 5, 6]], pa.list_(pa.int64()) - ), "# Array types differed: list vs list<$data$: int64>\n") + ), ignore_field_names=True) def test_struct_array(self): fields = [ @@ -218,7 +221,7 @@ def test_map(self): keys = pa.array(pykeys, type="binary") items = pa.array(pyitems, type="i4") self.round_trip_array( - lambda: pa.MapArray.from_arrays(offsets, keys, items)) + lambda: pa.MapArray.from_arrays(offsets, keys, items), ignore_field_names=True) def test_field(self): self.round_trip_field(lambda: pa.field("aa", pa.bool_())) diff --git a/python/pyarrow/includes/libarrow.pxd b/python/pyarrow/includes/libarrow.pxd index 9cea340a3090e..69138cd472e4d 100644 --- a/python/pyarrow/includes/libarrow.pxd +++ 
b/python/pyarrow/includes/libarrow.pxd @@ -153,8 +153,8 @@ cdef extern from "arrow/api.h" namespace "arrow" nogil: cdef cppclass CDataType" arrow::DataType": Type id() - c_bool Equals(const CDataType& other) - c_bool Equals(const shared_ptr[CDataType]& other) + c_bool Equals(const CDataType& other, c_bool check_metadata, c_bool check_internal_field_names) + c_bool Equals(const shared_ptr[CDataType]& other, c_bool check_metadata, c_bool check_internal_field_names) shared_ptr[CField] field(int i) const vector[shared_ptr[CField]] fields() diff --git a/python/pyarrow/tests/test_types.py b/python/pyarrow/tests/test_types.py index e922ca0e1caf6..d9efc1f243cac 100644 --- a/python/pyarrow/tests/test_types.py +++ b/python/pyarrow/tests/test_types.py @@ -518,6 +518,21 @@ def test_list_type(): assert ty.value_type == pa.int64() assert ty.value_field == pa.field("item", pa.int64(), nullable=True) + # nullability matters in comparison + ty_non_nullable = pa.list_(pa.field("item", pa.int64(), nullable=False)) + assert ty != ty_non_nullable + + # field names don't matter by default + ty_named = pa.list_(pa.field("element", pa.int64())) + assert ty == ty_named + assert not ty.equals(ty_named, check_internal_field_names=True) + + # metadata doesn't matter by default + ty_metadata = pa.list_( + pa.field("item", pa.int64(), metadata={"hello": "world"})) + assert ty == ty_metadata + assert not ty.equals(ty_metadata, check_metadata=True) + with pytest.raises(TypeError): pa.list_(None) @@ -540,6 +555,23 @@ def test_map_type(): assert ty.item_type == pa.int32() assert ty.item_field == pa.field("value", pa.int32(), nullable=True) + # nullability matters in comparison + ty_non_nullable = pa.map_(pa.utf8(), pa.field( + "value", pa.int32(), nullable=False)) + assert ty != ty_non_nullable + + # field names don't matter by default + ty_named = pa.map_(pa.field("x", pa.utf8(), nullable=False), + pa.field("y", pa.int32())) + assert ty == ty_named + assert not ty.equals(ty_named, 
check_internal_field_names=True) + + # metadata doesn't matter by default + ty_metadata = pa.map_(pa.utf8(), pa.field( + "value", pa.int32(), metadata={"hello": "world"})) + assert ty == ty_metadata + assert not ty.equals(ty_metadata, check_metadata=True) + with pytest.raises(TypeError): pa.map_(None) with pytest.raises(TypeError): diff --git a/python/pyarrow/types.pxi b/python/pyarrow/types.pxi index 8d5b261acb967..f92e9ccaae653 100644 --- a/python/pyarrow/types.pxi +++ b/python/pyarrow/types.pxi @@ -192,22 +192,32 @@ cdef class DataType(_Weakrefable): except (TypeError, ValueError): return NotImplemented - def equals(self, other): + def equals(self, other, check_metadata=False, check_internal_field_names=False): """ Return true if type is equivalent to passed value. Parameters ---------- other : DataType or string convertible to DataType + check_metadata : bool, default False + Whether nested Field metadata equality should be checked as well. + check_internal_field_names : bool, default False + Whether field names of ListType or MapType should be checked as well. 
Returns ------- is_equal : bool """ - cdef DataType other_type + cdef: + DataType other_type + c_bool c_check_metadata + c_bool c_check_internal_field_names + + c_check_metadata = check_metadata + c_check_internal_field_names = check_internal_field_names other_type = ensure_type(other) - return self.type.Equals(deref(other_type.type)) + return self.type.Equals(deref(other_type.type), c_check_metadata, c_check_internal_field_names) def to_pandas_dtype(self): """ @@ -870,7 +880,7 @@ cdef class BaseExtensionType(DataType): f"Expected array or chunked array, got {storage.__class__}") if not c_storage_type.get().Equals(deref(self.ext_type) - .storage_type()): + .storage_type(), False, False): raise TypeError( f"Incompatible storage type for {self}: " f"expected {self.storage_type}, got {storage.type}") diff --git a/r/R/arrowExports.R b/r/R/arrowExports.R index 144044d7e74e0..a30c535e3b12f 100644 --- a/r/R/arrowExports.R +++ b/r/R/arrowExports.R @@ -936,8 +936,8 @@ DataType__name <- function(type) { .Call(`_arrow_DataType__name`, type) } -DataType__Equals <- function(lhs, rhs) { - .Call(`_arrow_DataType__Equals`, lhs, rhs) +DataType__Equals <- function(lhs, rhs, check_metadata, check_internal_field_names) { + .Call(`_arrow_DataType__Equals`, lhs, rhs, check_metadata, check_internal_field_names) } DataType__num_fields <- function(type) { @@ -1956,8 +1956,8 @@ Schema__serialize <- function(schema) { .Call(`_arrow_Schema__serialize`, schema) } -Schema__Equals <- function(schema, other, check_metadata) { - .Call(`_arrow_Schema__Equals`, schema, other, check_metadata) +Schema__Equals <- function(schema, other, check_metadata, check_internal_field_names) { + .Call(`_arrow_Schema__Equals`, schema, other, check_metadata, check_internal_field_names) } arrow__UnifySchemas <- function(schemas) { diff --git a/r/R/field.R b/r/R/field.R index fce193ab53a41..eaabd2eefe164 100644 --- a/r/R/field.R +++ b/r/R/field.R @@ -57,6 +57,7 @@ Field <- R6Class("Field", Field$create <- 
function(name, type, metadata, nullable = TRUE) { assert_that(inherits(name, "character"), length(name) == 1L) type <- as_type(type, name) + # TODO(ARROW-18204): accept field metadata assert_that(missing(metadata), msg = "metadata= is currently ignored") Field__initialize(enc2utf8(name), type, nullable) } diff --git a/r/R/schema.R b/r/R/schema.R index 93e826eff2880..0ba3445513f03 100644 --- a/r/R/schema.R +++ b/r/R/schema.R @@ -118,8 +118,8 @@ Schema <- R6Class("Schema", metadata <- prepare_key_value_metadata(metadata) Schema__WithMetadata(self, metadata) }, - Equals = function(other, check_metadata = FALSE, ...) { - inherits(other, "Schema") && Schema__Equals(self, other, isTRUE(check_metadata)) + Equals = function(other, check_metadata = FALSE, check_internal_field_names = FALSE, ...) { + inherits(other, "Schema") && Schema__Equals(self, other, isTRUE(check_metadata), isTRUE(check_internal_field_names)) }, export_to_c = function(ptr) ExportSchema(self, ptr), code = function() { diff --git a/r/R/type.R b/r/R/type.R index cda606e3fa955..c41214a10d34e 100644 --- a/r/R/type.R +++ b/r/R/type.R @@ -37,8 +37,8 @@ DataType <- R6Class("DataType", ToString = function() { DataType__ToString(self) }, - Equals = function(other, ...) { - inherits(other, "DataType") && DataType__Equals(self, other) + Equals = function(other, check_metadata = FALSE, check_internal_field_names = FALSE, ...) 
{ + inherits(other, "DataType") && DataType__Equals(self, other, isTRUE(check_metadata), isTRUE(check_internal_field_names)) }, fields = function() { DataType__fields(self) diff --git a/r/src/arrowExports.cpp b/r/src/arrowExports.cpp index d3f97f5a99f74..0c2c1f43c1472 100644 --- a/r/src/arrowExports.cpp +++ b/r/src/arrowExports.cpp @@ -2426,12 +2426,14 @@ BEGIN_CPP11 END_CPP11 } // datatype.cpp -bool DataType__Equals(const std::shared_ptr& lhs, const std::shared_ptr& rhs); -extern "C" SEXP _arrow_DataType__Equals(SEXP lhs_sexp, SEXP rhs_sexp){ +bool DataType__Equals(const std::shared_ptr& lhs, const std::shared_ptr& rhs, bool check_metadata, bool check_internal_field_names); +extern "C" SEXP _arrow_DataType__Equals(SEXP lhs_sexp, SEXP rhs_sexp, SEXP check_metadata_sexp, SEXP check_internal_field_names_sexp){ BEGIN_CPP11 arrow::r::Input&>::type lhs(lhs_sexp); arrow::r::Input&>::type rhs(rhs_sexp); - return cpp11::as_sexp(DataType__Equals(lhs, rhs)); + arrow::r::Input::type check_metadata(check_metadata_sexp); + arrow::r::Input::type check_internal_field_names(check_internal_field_names_sexp); + return cpp11::as_sexp(DataType__Equals(lhs, rhs, check_metadata, check_internal_field_names)); END_CPP11 } // datatype.cpp @@ -4947,13 +4949,14 @@ BEGIN_CPP11 END_CPP11 } // schema.cpp -bool Schema__Equals(const std::shared_ptr& schema, const std::shared_ptr& other, bool check_metadata); -extern "C" SEXP _arrow_Schema__Equals(SEXP schema_sexp, SEXP other_sexp, SEXP check_metadata_sexp){ +bool Schema__Equals(const std::shared_ptr& schema, const std::shared_ptr& other, bool check_metadata, bool check_internal_field_names); +extern "C" SEXP _arrow_Schema__Equals(SEXP schema_sexp, SEXP other_sexp, SEXP check_metadata_sexp, SEXP check_internal_field_names_sexp){ BEGIN_CPP11 arrow::r::Input&>::type schema(schema_sexp); arrow::r::Input&>::type other(other_sexp); arrow::r::Input::type check_metadata(check_metadata_sexp); - return cpp11::as_sexp(Schema__Equals(schema, other, 
check_metadata)); + arrow::r::Input::type check_internal_field_names(check_internal_field_names_sexp); + return cpp11::as_sexp(Schema__Equals(schema, other, check_metadata, check_internal_field_names)); END_CPP11 } // schema.cpp @@ -5511,7 +5514,7 @@ static const R_CallMethodDef CallEntries[] = { { "_arrow_struct__", (DL_FUNC) &_arrow_struct__, 1}, { "_arrow_DataType__ToString", (DL_FUNC) &_arrow_DataType__ToString, 1}, { "_arrow_DataType__name", (DL_FUNC) &_arrow_DataType__name, 1}, - { "_arrow_DataType__Equals", (DL_FUNC) &_arrow_DataType__Equals, 2}, + { "_arrow_DataType__Equals", (DL_FUNC) &_arrow_DataType__Equals, 4}, { "_arrow_DataType__num_fields", (DL_FUNC) &_arrow_DataType__num_fields, 1}, { "_arrow_DataType__fields", (DL_FUNC) &_arrow_DataType__fields, 1}, { "_arrow_DataType__id", (DL_FUNC) &_arrow_DataType__id, 1}, @@ -5766,7 +5769,7 @@ static const R_CallMethodDef CallEntries[] = { { "_arrow_Schema__metadata", (DL_FUNC) &_arrow_Schema__metadata, 1}, { "_arrow_Schema__WithMetadata", (DL_FUNC) &_arrow_Schema__WithMetadata, 2}, { "_arrow_Schema__serialize", (DL_FUNC) &_arrow_Schema__serialize, 1}, - { "_arrow_Schema__Equals", (DL_FUNC) &_arrow_Schema__Equals, 3}, + { "_arrow_Schema__Equals", (DL_FUNC) &_arrow_Schema__Equals, 4}, { "_arrow_arrow__UnifySchemas", (DL_FUNC) &_arrow_arrow__UnifySchemas, 1}, { "_arrow_Table__num_columns", (DL_FUNC) &_arrow_Table__num_columns, 1}, { "_arrow_Table__num_rows", (DL_FUNC) &_arrow_Table__num_rows, 1}, diff --git a/r/src/datatype.cpp b/r/src/datatype.cpp index dc8d3b18926ae..d39dedaee8c0d 100644 --- a/r/src/datatype.cpp +++ b/r/src/datatype.cpp @@ -327,8 +327,9 @@ std::string DataType__name(const std::shared_ptr& type) { // [[arrow::export]] bool DataType__Equals(const std::shared_ptr& lhs, - const std::shared_ptr& rhs) { - return lhs->Equals(*rhs); + const std::shared_ptr& rhs, bool check_metadata, + bool check_internal_field_names) { + return lhs->Equals(*rhs, check_metadata, check_internal_field_names); } // 
[[arrow::export]] diff --git a/r/src/schema.cpp b/r/src/schema.cpp index 0dac188ec07d5..59ba4cdca3854 100644 --- a/r/src/schema.cpp +++ b/r/src/schema.cpp @@ -152,8 +152,9 @@ cpp11::writable::raws Schema__serialize(const std::shared_ptr<arrow::Schema>& sc // [[arrow::export]] bool Schema__Equals(const std::shared_ptr<arrow::Schema>& schema, - const std::shared_ptr<arrow::Schema>& other, bool check_metadata) { - return schema->Equals(*other, check_metadata); + const std::shared_ptr<arrow::Schema>& other, bool check_metadata, + bool check_internal_field_names) { + return schema->Equals(*other, check_metadata, check_internal_field_names); } // [[arrow::export]] diff --git a/r/tests/testthat/test-data-type.R b/r/tests/testthat/test-data-type.R index 16fcf8e0a38cb..9f4ce203d9e7b 100644 --- a/r/tests/testthat/test-data-type.R +++ b/r/tests/testthat/test-data-type.R @@ -365,6 +365,20 @@ test_that("list type works as expected", { ) expect_equal(x$value_type, int32()) expect_equal(x$value_field, field("item", int32())) + + # nullability matters in comparison + expect_false(x$Equals(list_of(field("item", int32(), nullable = FALSE)))) + + # field names don't matter by default + other_name <- list_of(field("other", int32())) + expect_equal(x, other_name) + expect_false(x$Equals(other_name, check_internal_field_names = TRUE)) + + # TODO(ARROW-18204): metadata doesn't matter by default + # other_metadata <- list_of(field("item", int32(), # nolint + # metadata = list(hello="world"))) # nolint + # expect_equal(x, other_metadata) # nolint + # expect_false(x$Equals(other_metadata, check_metadata = TRUE)) # nolint }) test_that("map type works as expected", { @@ -388,6 +402,20 @@ test_that("map type works as expected", { # we can make this comparison: # expect_equal(x$value_type, struct(key = x$key_field, value = x$item_field)) # nolint expect_false(x$keys_sorted) + + # nullability matters in comparison + expect_false(x$Equals(map_of(int32(), field("value", utf8(), nullable = FALSE)))) + + # field names don't matter by default + other_name <- 
map_of(int32(), field("other", utf8())) + expect_equal(x, other_name) + expect_false(x$Equals(other_name, check_internal_field_names = TRUE)) + + # TODO(ARROW-18204): metadata doesn't matter by default + # other_metadata <- map_of(int32(), # nolint + # field("value", utf8(), metadata = list(hello="world"))) # nolint + # expect_equal(x, other_metadata) # nolint + # expect_false(x$Equals(other_metadata, check_metadata = TRUE)) # nolint }) test_that("map type validates arguments", { diff --git a/r/tests/testthat/test-parquet.R b/r/tests/testthat/test-parquet.R index 32170534a47c3..591805d4ff5ec 100644 --- a/r/tests/testthat/test-parquet.R +++ b/r/tests/testthat/test-parquet.R @@ -457,9 +457,8 @@ test_that("Can read parquet with nested lists and maps", { skip_if_not(dir.exists(parquet_test_data), "Parquet test data missing") pq <- read_parquet(paste0(parquet_test_data, "/nested_lists.snappy.parquet"), as_data_frame = FALSE) - # value name is "element" from parquet reader, but type default is "item" - expect_equal(pq$a$type, list_of(field("element", list_of(field("element", list_of(field("element", utf8()))))))) + expect_equal(pq$a$type, list_of(list_of(list_of(utf8())))) pq <- read_parquet(paste0(parquet_test_data, "/nested_maps.snappy.parquet"), as_data_frame = FALSE) - expect_equal(pq$a$type, map_of(utf8(), map_of(int32(), boolean()))) + expect_equal(pq$a$type, map_of(utf8(), map_of(int32(), field("val", boolean(), nullable = FALSE)))) })