Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

ARROW-14999: [C++] Don't check field name in ListType Equals() #13851

Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
6 changes: 4 additions & 2 deletions cpp/src/arrow/chunked_array.cc
Original file line number Diff line number Diff line change
Expand Up @@ -93,7 +93,8 @@ bool ChunkedArray::Equals(const ChunkedArray& other) const {
return false;
}
// We cannot toggle check_metadata here yet, so we don't check it
if (!type_->Equals(*other.type_, /*check_metadata=*/false)) {
if (!type_->Equals(*other.type_, /*check_metadata=*/false,
/*check_internal_field_names=*/false)) {
return false;
}

Expand Down Expand Up @@ -130,7 +131,8 @@ bool ChunkedArray::ApproxEquals(const ChunkedArray& other,
return false;
}
// We cannot toggle check_metadata here yet, so we don't check it
if (!type_->Equals(*other.type_, /*check_metadata=*/false)) {
if (!type_->Equals(*other.type_, /*check_metadata=*/false,
/*check_internal_field_names=*/false)) {
return false;
}

Expand Down
78 changes: 62 additions & 16 deletions cpp/src/arrow/compare.cc
Original file line number Diff line number Diff line change
Expand Up @@ -43,6 +43,7 @@
#include "arrow/util/bitmap_ops.h"
#include "arrow/util/bitmap_reader.h"
#include "arrow/util/checked_cast.h"
#include "arrow/util/key_value_metadata.h"
#include "arrow/util/logging.h"
#include "arrow/util/macros.h"
#include "arrow/util/memory.h"
Expand Down Expand Up @@ -530,7 +531,8 @@ bool CompareArrayRanges(const ArrayData& left, const ArrayData& right,
int64_t right_start_idx, const EqualOptions& options,
bool floating_approximate) {
if (left.type->id() != right.type->id() ||
!TypeEquals(*left.type, *right.type, false /* check_metadata */)) {
!TypeEquals(*left.type, *right.type, /*check_metadata=*/false,
/*check_internal_field_names=*/false)) {
return false;
}

Expand All @@ -556,8 +558,20 @@ bool CompareArrayRanges(const ArrayData& left, const ArrayData& right,

class TypeEqualsVisitor {
public:
explicit TypeEqualsVisitor(const DataType& right, bool check_metadata)
: right_(right), check_metadata_(check_metadata), result_(false) {}
explicit TypeEqualsVisitor(const DataType& right, bool check_metadata,
bool check_internal_field_names)
: right_(right),
check_metadata_(check_metadata),
check_internal_field_names_(check_internal_field_names),
result_(false) {}

// Compare the key-value metadata attached to two fields.
//
// Returns true when neither field carries metadata, or when both do and the
// two metadata objects compare equal. Returns false when only one of the
// fields has metadata.
bool MetadataEqual(const Field& left, const Field& right) {
  const bool left_has_metadata = left.HasMetadata();
  const bool right_has_metadata = right.HasMetadata();
  if (left_has_metadata != right_has_metadata) {
    // Exactly one side has metadata: never equal.
    return false;
  }
  if (!left_has_metadata) {
    // Neither side has metadata: trivially equal.
    return true;
  }
  return left.metadata()->Equals(*right.metadata());
}

Status VisitChildren(const DataType& left) {
if (left.num_fields() != right_.num_fields()) {
Expand Down Expand Up @@ -626,8 +640,23 @@ class TypeEqualsVisitor {
}

template <typename T>
enable_if_t<is_list_like_type<T>::value || is_struct_type<T>::value, Status> Visit(
const T& left) {
// List-like types are compared through their single child field (index 0).
//
// result_ is set to true only if the child fields agree on nullability and on
// value type and, when the corresponding flag is enabled, also on field name
// (check_internal_field_names_) and on metadata (check_metadata_).
// The returned Status is always OK; the verdict is carried by result_.
enable_if_t<is_list_like_type<T>::value, Status> Visit(const T& left) {
  const std::shared_ptr<Field> lhs_child = left.field(0);
  const std::shared_ptr<Field> rhs_child = checked_cast<const T&>(right_).field(0);

  // Assume inequality until every check below has passed.
  result_ = false;

  if (check_internal_field_names_ && lhs_child->name() != rhs_child->name()) {
    return Status::OK();
  }
  if (check_metadata_ && !MetadataEqual(*lhs_child, *rhs_child)) {
    return Status::OK();
  }
  if (lhs_child->nullable() != rhs_child->nullable()) {
    return Status::OK();
  }

  // Recurse into the value type, forwarding both comparison flags.
  result_ = lhs_child->type()->Equals(*rhs_child->type(), check_metadata_,
                                      check_internal_field_names_);
  return Status::OK();
}

// Struct types: this overload performs no comparison of its own and simply
// forwards to VisitChildren(), returning whatever Status that call produces.
template <typename T>
enable_if_t<is_struct_type<T>::value, Status> Visit(const T& left) {
return VisitChildren(left);
}

Expand All @@ -637,8 +666,23 @@ class TypeEqualsVisitor {
result_ = false;
return Status::OK();
}
result_ = left.key_type()->Equals(*right.key_type(), check_metadata_) &&
left.item_type()->Equals(*right.item_type(), check_metadata_);
if (check_internal_field_names_ &&
(left.item_field()->name() != right.item_field()->name() ||
left.key_field()->name() != right.key_field()->name() ||
left.value_field()->name() != right.value_field()->name())) {
result_ = false;
return Status::OK();
}
if (check_metadata_ && !(MetadataEqual(*left.item_field(), *right.item_field()) &&
MetadataEqual(*left.key_field(), *right.key_field()) &&
MetadataEqual(*left.value_field(), *right.value_field()))) {
result_ = false;
return Status::OK();
}
result_ = left.key_type()->Equals(*right.key_type(), check_metadata_,
check_internal_field_names_) &&
left.item_type()->Equals(*right.item_type(), check_metadata_,
check_internal_field_names_);
return Status::OK();
}

Expand Down Expand Up @@ -676,6 +720,7 @@ class TypeEqualsVisitor {
protected:
const DataType& right_;
bool check_metadata_;
bool check_internal_field_names_;
bool result_;
};

Expand Down Expand Up @@ -1267,13 +1312,14 @@ bool SparseTensorEquals(const SparseTensor& left, const SparseTensor& right,
}
}

bool TypeEquals(const DataType& left, const DataType& right, bool check_metadata) {
bool TypeEquals(const DataType& left, const DataType& right, bool check_metadata,
bool check_internal_field_names) {
// The arrays are the same object
if (&left == &right) {
return true;
} else if (left.id() != right.id()) {
return false;
} else {
} else if (!check_internal_field_names) {
// First try to compute fingerprints
if (check_metadata) {
const auto& left_metadata_fp = left.metadata_fingerprint();
Expand All @@ -1288,15 +1334,15 @@ bool TypeEquals(const DataType& left, const DataType& right, bool check_metadata
if (!left_fp.empty() && !right_fp.empty()) {
return left_fp == right_fp;
}
}

// TODO remove check_metadata here?
TypeEqualsVisitor visitor(right, check_metadata);
auto error = VisitTypeInline(left, &visitor);
if (!error.ok()) {
DCHECK(false) << "Types are not comparable: " << error.ToString();
}
return visitor.result();
// TODO remove check_metadata here?
TypeEqualsVisitor visitor(right, check_metadata, check_internal_field_names);
auto error = VisitTypeInline(left, &visitor);
if (!error.ok()) {
DCHECK(false) << "Types are not comparable: " << error.ToString();
}
return visitor.result();
}

} // namespace arrow
5 changes: 4 additions & 1 deletion cpp/src/arrow/compare.h
Original file line number Diff line number Diff line change
Expand Up @@ -124,8 +124,11 @@ ARROW_EXPORT bool SparseTensorEquals(const SparseTensor& left, const SparseTenso
/// \param[in] right a DataType
/// \param[in] check_metadata whether to compare KeyValueMetadata for child
/// fields
/// \param[in] check_internal_field_names whether list or map types whose
/// internal (child) field names differ are considered unequal
ARROW_EXPORT bool TypeEquals(const DataType& left, const DataType& right,
bool check_metadata = true);
bool check_metadata = true,
bool check_internal_field_names = true);

/// Returns true if scalars are equal
/// \param[in] left a Scalar
Expand Down
6 changes: 3 additions & 3 deletions cpp/src/arrow/compute/cast.cc
Original file line number Diff line number Diff line change
Expand Up @@ -97,14 +97,14 @@ class CastMetaFunction : public MetaFunction {
ARROW_ASSIGN_OR_RAISE(auto cast_options, ValidateOptions(options));
// args[0].type() could be a nullptr so check for that before
// we do anything with it.
if (args[0].type() && args[0].type()->Equals(*cast_options->to_type)) {
if (args[0].type() &&
args[0].type()->Equals(*cast_options->to_type, /*check_metadata=*/false,
/*check_internal_field_names=*/false)) {
// Nested types might differ in field names but still be considered equal,
// so we can only return non-nested types as-is.
if (!is_nested(args[0].type()->id())) {
return args[0];
} else if (args[0].is_array()) {
// TODO(ARROW-14999): if types are equal except for field names of list
// types, we can also use this code path.
ARROW_ASSIGN_OR_RAISE(std::shared_ptr<ArrayData> array,
::arrow::internal::GetArrayView(
args[0].array(), cast_options->to_type.owned_type));
Expand Down
3 changes: 2 additions & 1 deletion cpp/src/arrow/table.cc
Original file line number Diff line number Diff line change
Expand Up @@ -451,7 +451,8 @@ Result<std::shared_ptr<Table>> PromoteTableToSchema(const std::shared_ptr<Table>
const std::shared_ptr<Schema>& schema,
MemoryPool* pool) {
const std::shared_ptr<Schema> current_schema = table->schema();
if (current_schema->Equals(*schema, /*check_metadata=*/false)) {
if (current_schema->Equals(*schema, /*check_metadata=*/false,
/*check_internal_field_names=*/true)) {
return table->ReplaceSchemaMetadata(schema->metadata());
}

Expand Down
79 changes: 58 additions & 21 deletions cpp/src/arrow/type.cc
Original file line number Diff line number Diff line change
Expand Up @@ -320,7 +320,7 @@ Result<std::shared_ptr<Field>> Field::MergeWith(const Field& other,
other.name());
}

if (Equals(other, /*check_metadata=*/false)) {
if (Equals(other, /*check_metadata=*/false, /*check_internal_field_names=*/false)) {
return Copy();
}

Expand Down Expand Up @@ -362,12 +362,14 @@ std::shared_ptr<Field> Field::Copy() const {
return ::arrow::field(name_, type_, nullable_, metadata_);
}

bool Field::Equals(const Field& other, bool check_metadata) const {
bool Field::Equals(const Field& other, bool check_metadata,
bool check_internal_field_names) const {
if (this == &other) {
return true;
}
if (this->name_ == other.name_ && this->nullable_ == other.nullable_ &&
this->type_->Equals(*other.type_.get(), check_metadata)) {
this->type_->Equals(*other.type_.get(), check_metadata,
check_internal_field_names)) {
if (!check_metadata) {
return true;
} else if (this->HasMetadata() && other.HasMetadata()) {
Expand All @@ -381,8 +383,9 @@ bool Field::Equals(const Field& other, bool check_metadata) const {
return false;
}

bool Field::Equals(const std::shared_ptr<Field>& other, bool check_metadata) const {
return Equals(*other.get(), check_metadata);
bool Field::Equals(const std::shared_ptr<Field>& other, bool check_metadata,
bool check_internal_field_names) const {
return Equals(*other.get(), check_metadata, check_internal_field_names);
}

// Returns true exactly when MergeWith(other) reports success (ok()).
bool Field::IsCompatibleWith(const Field& other) const { return MergeWith(other).ok(); }
Expand All @@ -408,15 +411,17 @@ void PrintTo(const Field& field, std::ostream* os) { *os << field.ToString(); }

// Out-of-line destructor definition; nothing to release beyond what the
// members' own destructors already handle.
DataType::~DataType() = default;

bool DataType::Equals(const DataType& other, bool check_metadata) const {
return TypeEquals(*this, other, check_metadata);
bool DataType::Equals(const DataType& other, bool check_metadata,
bool check_internal_field_names) const {
return TypeEquals(*this, other, check_metadata, check_internal_field_names);
}

bool DataType::Equals(const std::shared_ptr<DataType>& other) const {
bool DataType::Equals(const std::shared_ptr<DataType>& other, bool check_metadata,
bool check_internal_field_names) const {
if (!other) {
return false;
}
return Equals(*other.get());
return Equals(*other.get(), check_metadata, check_internal_field_names);
}

size_t DataType::Hash() const {
Expand Down Expand Up @@ -1557,7 +1562,8 @@ const std::vector<std::shared_ptr<Field>>& Schema::fields() const {
return impl_->fields_;
}

bool Schema::Equals(const Schema& other, bool check_metadata) const {
bool Schema::Equals(const Schema& other, bool check_metadata,
bool check_internal_field_names) const {
if (this == &other) {
return true;
}
Expand Down Expand Up @@ -1589,20 +1595,22 @@ bool Schema::Equals(const Schema& other, bool check_metadata) const {

// Fall back on field-by-field comparison
for (int i = 0; i < num_fields(); ++i) {
if (!field(i)->Equals(*other.field(i).get(), check_metadata)) {
if (!field(i)->Equals(*other.field(i).get(), check_metadata,
check_internal_field_names)) {
return false;
}
}

return true;
}

bool Schema::Equals(const std::shared_ptr<Schema>& other, bool check_metadata) const {
bool Schema::Equals(const std::shared_ptr<Schema>& other, bool check_metadata,
bool check_internal_field_names) const {
if (other == nullptr) {
return false;
}

return Equals(*other, check_metadata);
return Equals(*other, check_metadata, check_internal_field_names);
}

std::shared_ptr<Field> Schema::GetFieldByName(const std::string& name) const {
Expand Down Expand Up @@ -2136,17 +2144,33 @@ std::string DictionaryType::ComputeFingerprint() const {
}

std::string ListType::ComputeFingerprint() const {
const auto& child_fingerprint = children_[0]->fingerprint();
const auto& child_fingerprint = value_type()->fingerprint();
if (!child_fingerprint.empty()) {
return TypeIdFingerprint(*this) + "{" + child_fingerprint + "}";
std::stringstream ss;
ss << TypeIdFingerprint(*this);
if (value_field()->nullable()) {
ss << 'n';
} else {
ss << 'N';
}
Comment on lines +2151 to +2155
Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Nullability of the internal field is now part of the fingerprint.

ss << '{' << child_fingerprint << '}';
return ss.str();
}
return "";
}

std::string LargeListType::ComputeFingerprint() const {
const auto& child_fingerprint = children_[0]->fingerprint();
const auto& child_fingerprint = value_type()->fingerprint();
if (!child_fingerprint.empty()) {
return TypeIdFingerprint(*this) + "{" + child_fingerprint + "}";
std::stringstream ss;
ss << TypeIdFingerprint(*this);
if (value_field()->nullable()) {
ss << 'n';
} else {
ss << 'N';
}
ss << '{' << child_fingerprint << '}';
return ss.str();
}
return "";
}
Expand All @@ -2155,20 +2179,33 @@ std::string MapType::ComputeFingerprint() const {
const auto& key_fingerprint = key_type()->fingerprint();
const auto& item_fingerprint = item_type()->fingerprint();
if (!key_fingerprint.empty() && !item_fingerprint.empty()) {
std::stringstream ss;
ss << TypeIdFingerprint(*this);
if (keys_sorted_) {
return TypeIdFingerprint(*this) + "s{" + key_fingerprint + item_fingerprint + "}";
ss << 's';
}
if (item_field()->nullable()) {
ss << 'n';
} else {
return TypeIdFingerprint(*this) + "{" + key_fingerprint + item_fingerprint + "}";
ss << 'N';
}
ss << '{' << key_fingerprint + item_fingerprint << '}';
return ss.str();
}
return "";
}

std::string FixedSizeListType::ComputeFingerprint() const {
const auto& child_fingerprint = children_[0]->fingerprint();
const auto& child_fingerprint = value_type()->fingerprint();
if (!child_fingerprint.empty()) {
std::stringstream ss;
ss << TypeIdFingerprint(*this) << "[" << list_size_ << "]"
ss << TypeIdFingerprint(*this);
if (value_field()->nullable()) {
ss << 'n';
} else {
ss << 'N';
}
ss << "[" << list_size_ << "]"
<< "{" << child_fingerprint << "}";
return ss.str();
}
Expand Down
Loading