diff --git a/c_glib/test/dataset/test-file-system-dataset-factory.rb b/c_glib/test/dataset/test-file-system-dataset-factory.rb index 30944ccd3bba4..c1fd128a40b00 100644 --- a/c_glib/test/dataset/test-file-system-dataset-factory.rb +++ b/c_glib/test/dataset/test-file-system-dataset-factory.rb @@ -109,9 +109,9 @@ def test_validate_fragments point: Arrow::Int16DataType.new) options.validate_fragments = true message = "[file-system-dataset-factory][finish]: " + - "Invalid: Unable to merge: " + + "Type error: Unable to merge: " + "Field point has incompatible types: int16 vs int32" - error = assert_raise(Arrow::Error::Invalid) do + error = assert_raise(Arrow::Error::Type) do @factory.finish(options) end assert_equal(message, error.message.lines(chomp: true).first) diff --git a/cpp/src/arrow/dataset/discovery.cc b/cpp/src/arrow/dataset/discovery.cc index 2aca85e6adcfd..7e5acba992bec 100644 --- a/cpp/src/arrow/dataset/discovery.cc +++ b/cpp/src/arrow/dataset/discovery.cc @@ -47,7 +47,7 @@ Result> DatasetFactory::Inspect(InspectOptions options) return arrow::schema({}); } - return UnifySchemas(schemas); + return UnifySchemas(schemas, options.field_merge_options); } Result> DatasetFactory::Finish() { diff --git a/cpp/src/arrow/dataset/discovery.h b/cpp/src/arrow/dataset/discovery.h index 238b33e40fe25..6d76dcef727e7 100644 --- a/cpp/src/arrow/dataset/discovery.h +++ b/cpp/src/arrow/dataset/discovery.h @@ -58,6 +58,10 @@ struct InspectOptions { /// `kInspectAllFragments`. A value of `0` disables inspection of fragments /// altogether so only the partitioning schema will be inspected. int fragments = 1; + + /// Control how to unify types. By default, types are merged strictly (the + /// type must match exactly, except nulls can be merged with other types). 
+ Field::MergeOptions field_merge_options = Field::MergeOptions::Defaults(); }; struct FinishOptions { diff --git a/cpp/src/arrow/dataset/discovery_test.cc b/cpp/src/arrow/dataset/discovery_test.cc index 5b0590a277346..92cec7f324963 100644 --- a/cpp/src/arrow/dataset/discovery_test.cc +++ b/cpp/src/arrow/dataset/discovery_test.cc @@ -117,9 +117,15 @@ TEST_F(MockDatasetFactoryTest, UnifySchemas) { MakeFactory({schema({i32, f64}), schema({f64, i32_fake})}); // Unification fails when fields with the same name have clashing types. - ASSERT_RAISES(Invalid, factory_->Inspect()); + ASSERT_RAISES(TypeError, factory_->Inspect()); // Return the individual schema for closer inspection should not fail. AssertInspectSchemas({schema({i32, f64}), schema({f64, i32_fake})}); + + MakeFactory({schema({field("num", int32())}), schema({field("num", float64())})}); + ASSERT_RAISES(TypeError, factory_->Inspect()); + InspectOptions permissive_options; + permissive_options.field_merge_options = Field::MergeOptions::Permissive(); + AssertInspect(schema({field("num", float64())}), permissive_options); } class FileSystemDatasetFactoryTest : public DatasetFactoryTest { @@ -335,7 +341,7 @@ TEST_F(FileSystemDatasetFactoryTest, FinishWithIncompatibleSchemaShouldFail) { ASSERT_OK_AND_ASSIGN(auto dataset, factory_->Finish(options)); MakeFactory({fs::File("test")}); - ASSERT_RAISES(Invalid, factory_->Finish(options)); + ASSERT_RAISES(TypeError, factory_->Finish(options)); // Disable validation options.validate_fragments = false; @@ -463,8 +469,8 @@ TEST(UnionDatasetFactoryTest, ConflictingSchemas) { {dataset_factory_1, dataset_factory_2, dataset_factory_3})); // schema_3 conflicts with other, Inspect/Finish should not work - ASSERT_RAISES(Invalid, factory->Inspect()); - ASSERT_RAISES(Invalid, factory->Finish()); + ASSERT_RAISES(TypeError, factory->Inspect()); + ASSERT_RAISES(TypeError, factory->Finish()); // The user can inspect without error ASSERT_OK_AND_ASSIGN(auto schemas, 
factory->InspectSchemas({})); @@ -474,6 +480,12 @@ TEST(UnionDatasetFactoryTest, ConflictingSchemas) { auto i32_schema = schema({i32}); ASSERT_OK_AND_ASSIGN(auto dataset, factory->Finish(i32_schema)); EXPECT_EQ(*dataset->schema(), *i32_schema); + + // The user decided to allow merging the types. + FinishOptions options; + options.inspect_options.field_merge_options = Field::MergeOptions::Permissive(); + ASSERT_OK_AND_ASSIGN(dataset, factory->Finish(options)); + EXPECT_EQ(*dataset->schema(), *schema({f64, i32})); } } // namespace dataset diff --git a/cpp/src/arrow/table.cc b/cpp/src/arrow/table.cc index 47f82631782e4..967e78f6b4db1 100644 --- a/cpp/src/arrow/table.cc +++ b/cpp/src/arrow/table.cc @@ -30,6 +30,7 @@ #include "arrow/array/concatenate.h" #include "arrow/array/util.h" #include "arrow/chunked_array.h" +#include "arrow/compute/cast.h" #include "arrow/pretty_print.h" #include "arrow/record_batch.h" #include "arrow/result.h" @@ -450,6 +451,13 @@ Result> ConcatenateTables( Result> PromoteTableToSchema(const std::shared_ptr& table, const std::shared_ptr& schema, MemoryPool* pool) { + return PromoteTableToSchema(table, schema, compute::CastOptions::Safe(), pool); +} + +Result> PromoteTableToSchema(const std::shared_ptr
& table, + const std::shared_ptr& schema, + const compute::CastOptions& options, + MemoryPool* pool) { const std::shared_ptr current_schema = table->schema(); if (current_schema->Equals(*schema, /*check_metadata=*/false)) { return table->ReplaceSchemaMetadata(schema->metadata()); @@ -487,8 +495,8 @@ Result> PromoteTableToSchema(const std::shared_ptr
const int field_index = field_indices[0]; const auto& current_field = current_schema->field(field_index); if (!field->nullable() && current_field->nullable()) { - return Status::Invalid("Unable to promote field ", current_field->name(), - ": it was nullable but the target schema was not."); + return Status::TypeError("Unable to promote field ", current_field->name(), + ": it was nullable but the target schema was not."); } fields_seen[field_index] = true; @@ -502,9 +510,15 @@ Result> PromoteTableToSchema(const std::shared_ptr
continue; } - return Status::Invalid("Unable to promote field ", field->name(), - ": incompatible types: ", field->type()->ToString(), " vs ", - current_field->type()->ToString()); + if (!compute::CanCast(*current_field->type(), *field->type())) { + return Status::TypeError("Unable to promote field ", field->name(), + ": incompatible types: ", field->type()->ToString(), + " vs ", current_field->type()->ToString()); + } + compute::ExecContext ctx(pool); + ARROW_ASSIGN_OR_RAISE(auto casted, compute::Cast(table->column(field_index), + field->type(), options, &ctx)); + columns.push_back(casted.chunked_array()); } auto unseen_field_iter = std::find(fields_seen.begin(), fields_seen.end(), false); diff --git a/cpp/src/arrow/table.h b/cpp/src/arrow/table.h index 940ff73ae983f..551880f237586 100644 --- a/cpp/src/arrow/table.h +++ b/cpp/src/arrow/table.h @@ -313,16 +313,23 @@ Result> ConcatenateTables( ConcatenateTablesOptions options = ConcatenateTablesOptions::Defaults(), MemoryPool* memory_pool = default_memory_pool()); +namespace compute { +class CastOptions; +} + /// \brief Promotes a table to conform to the given schema. /// -/// If a field in the schema does not have a corresponding column in the -/// table, a column of nulls will be added to the resulting table. -/// If the corresponding column is of type Null, it will be promoted to -/// the type specified by schema, with null values filled. +/// If a field in the schema does not have a corresponding column in +/// the table, a column of nulls will be added to the resulting table. +/// If the corresponding column is of type Null, it will be promoted +/// to the type specified by schema, with null values filled. The +/// column will be casted to the type specified by the schema. +/// /// Returns an error: /// - if the corresponding column's type is not compatible with the /// schema. /// - if there is a column in the table that does not exist in the schema. 
+/// - if the cast fails or casting would be required but is not available. /// /// \param[in] table the input Table /// \param[in] schema the target schema to promote to @@ -333,4 +340,28 @@ Result> PromoteTableToSchema( const std::shared_ptr
& table, const std::shared_ptr& schema, MemoryPool* pool = default_memory_pool()); +/// \brief Promotes a table to conform to the given schema. +/// +/// If a field in the schema does not have a corresponding column in +/// the table, a column of nulls will be added to the resulting table. +/// If the corresponding column is of type Null, it will be promoted +/// to the type specified by schema, with null values filled. The column +/// will be casted to the type specified by the schema. +/// +/// Returns an error: +/// - if the corresponding column's type is not compatible with the +/// schema. +/// - if there is a column in the table that does not exist in the schema. +/// - if the cast fails or casting would be required but is not available. +/// +/// \param[in] table the input Table +/// \param[in] schema the target schema to promote to +/// \param[in] options The cast options to allow promotion of types +/// \param[in] pool The memory pool to be used if null-filled arrays need to +/// be created. +ARROW_EXPORT +Result> PromoteTableToSchema( + const std::shared_ptr
& table, const std::shared_ptr& schema, + const compute::CastOptions& options, MemoryPool* pool = default_memory_pool()); + } // namespace arrow diff --git a/cpp/src/arrow/table_test.cc b/cpp/src/arrow/table_test.cc index 925a1ce12643f..3949caa402846 100644 --- a/cpp/src/arrow/table_test.cc +++ b/cpp/src/arrow/table_test.cc @@ -29,6 +29,7 @@ #include "arrow/array/data.h" #include "arrow/array/util.h" #include "arrow/chunked_array.h" +#include "arrow/compute/cast.h" #include "arrow/record_batch.h" #include "arrow/status.h" #include "arrow/testing/gtest_util.h" @@ -418,16 +419,17 @@ TEST_F(TestPromoteTableToSchema, IncompatibleTypes) { auto table = MakeTableWithOneNullFilledColumn("field", int32(), length); // Invalid promotion: int32 to null. - ASSERT_RAISES(Invalid, PromoteTableToSchema(table, schema({field("field", null())}))); + ASSERT_RAISES(TypeError, PromoteTableToSchema(table, schema({field("field", null())}))); - // Invalid promotion: int32 to uint32. - ASSERT_RAISES(Invalid, PromoteTableToSchema(table, schema({field("field", uint32())}))); + // Invalid promotion: int32 to list. 
+ ASSERT_RAISES(TypeError, + PromoteTableToSchema(table, schema({field("field", list(int32()))}))); } TEST_F(TestPromoteTableToSchema, IncompatibleNullity) { const int length = 10; auto table = MakeTableWithOneNullFilledColumn("field", int32(), length); - ASSERT_RAISES(Invalid, + ASSERT_RAISES(TypeError, PromoteTableToSchema( table, schema({field("field", uint32())->WithNullable(false)}))); } @@ -520,6 +522,36 @@ TEST_F(ConcatenateTablesWithPromotionTest, Simple) { AssertTablesEqualUnorderedFields(*expected, *result); } +TEST_F(ConcatenateTablesWithPromotionTest, Unify) { + auto t_i32 = TableFromJSON(schema({field("f0", int32())}), {"[[0], [1]]"}); + auto t_i64 = TableFromJSON(schema({field("f0", int64())}), {"[[2], [3]]"}); + auto t_null = TableFromJSON(schema({field("f0", null())}), {"[[null], [null]]"}); + + auto expected_int64 = + TableFromJSON(schema({field("f0", int64())}), {"[[0], [1], [2], [3]]"}); + auto expected_null = + TableFromJSON(schema({field("f0", int32())}), {"[[0], [1], [null], [null]]"}); + + ConcatenateTablesOptions options; + EXPECT_RAISES_WITH_MESSAGE_THAT(Invalid, + ::testing::HasSubstr("Schema at index 1 was different"), + ConcatenateTables({t_i32, t_i64}, options)); + EXPECT_RAISES_WITH_MESSAGE_THAT(Invalid, + ::testing::HasSubstr("Schema at index 1 was different"), + ConcatenateTables({t_i32, t_null}, options)); + + options.unify_schemas = true; + EXPECT_RAISES_WITH_MESSAGE_THAT(TypeError, + ::testing::HasSubstr("Field f0 has incompatible types"), + ConcatenateTables({t_i64, t_i32}, options)); + ASSERT_OK_AND_ASSIGN(auto actual, ConcatenateTables({t_i32, t_null}, options)); + AssertTablesEqual(*expected_null, *actual, /*same_chunk_layout=*/false); + + options.field_merge_options.promote_numeric_width = true; + ASSERT_OK_AND_ASSIGN(actual, ConcatenateTables({t_i32, t_i64}, options)); + AssertTablesEqual(*expected_int64, *actual, /*same_chunk_layout=*/false); +} + TEST_F(TestTable, Slice) { const int64_t length = 10; diff --git 
a/cpp/src/arrow/type.cc b/cpp/src/arrow/type.cc index 47bf52660ffe9..a4f43256827da 100644 --- a/cpp/src/arrow/type.cc +++ b/cpp/src/arrow/type.cc @@ -39,17 +39,21 @@ #include "arrow/status.h" #include "arrow/table.h" #include "arrow/util/checked_cast.h" +#include "arrow/util/decimal.h" #include "arrow/util/hash_util.h" #include "arrow/util/hashing.h" #include "arrow/util/key_value_metadata.h" #include "arrow/util/logging.h" #include "arrow/util/range.h" #include "arrow/util/string.h" +#include "arrow/util/unreachable.h" #include "arrow/util/vector.h" #include "arrow/visit_type_inline.h" namespace arrow { +using internal::checked_cast; + constexpr Type::type NullType::type_id; constexpr Type::type ListType::type_id; constexpr Type::type LargeListType::type_id; @@ -261,22 +265,6 @@ namespace { using internal::checked_cast; -// Merges `existing` and `other` if one of them is of NullType, otherwise -// returns nullptr. -// - if `other` if of NullType or is nullable, the unified field will be nullable. -// - if `existing` is of NullType but other is not, the unified field will -// have `other`'s type and will be nullable -std::shared_ptr MaybePromoteNullTypes(const Field& existing, const Field& other) { - if (existing.type()->id() != Type::NA && other.type()->id() != Type::NA) { - return nullptr; - } - if (existing.type()->id() == Type::NA) { - return other.WithNullable(true)->WithMetadata(existing.metadata()); - } - // `other` must be null. 
- return existing.WithNullable(true); -} - FieldVector MakeFields( std::initializer_list>> init_list) { FieldVector fields; @@ -327,6 +315,459 @@ std::shared_ptr Field::WithNullable(const bool nullable) const { return std::make_shared(name_, type_, nullable, metadata_); } +Field::MergeOptions Field::MergeOptions::Permissive() { + MergeOptions options = Defaults(); + options.promote_nullability = true; + options.promote_decimal = true; + options.promote_decimal_to_float = true; + options.promote_integer_to_decimal = true; + options.promote_integer_to_float = true; + options.promote_integer_sign = true; + options.promote_numeric_width = true; + options.promote_binary = true; + options.promote_temporal_unit = true; + options.promote_list = true; + options.promote_dictionary = true; + options.promote_dictionary_ordered = false; + return options; +} + +std::string Field::MergeOptions::ToString() const { + std::stringstream ss; + ss << "MergeOptions{"; + ss << "promote_nullability=" << (promote_nullability ? "true" : "false"); + ss << ", promote_decimal=" << (promote_decimal ? "true" : "false"); + ss << ", promote_decimal_to_float=" << (promote_decimal_to_float ? "true" : "false"); + ss << ", promote_integer_to_decimal=" + << (promote_integer_to_decimal ? "true" : "false"); + ss << ", promote_integer_to_float=" << (promote_integer_to_float ? "true" : "false"); + ss << ", promote_integer_sign=" << (promote_integer_sign ? "true" : "false"); + ss << ", promote_numeric_width=" << (promote_numeric_width ? "true" : "false"); + ss << ", promote_binary=" << (promote_binary ? "true" : "false"); + ss << ", promote_temporal_unit=" << (promote_temporal_unit ? "true" : "false"); + ss << ", promote_list=" << (promote_list ? "true" : "false"); + ss << ", promote_dictionary=" << (promote_dictionary ? "true" : "false"); + ss << ", promote_dictionary_ordered=" + << (promote_dictionary_ordered ? 
"true" : "false"); + ss << '}'; + return ss.str(); +} + +namespace { +// Utilities for Field::MergeWith + +std::shared_ptr MakeBinary(const DataType& type) { + switch (type.id()) { + case Type::BINARY: + case Type::STRING: + return binary(); + case Type::LARGE_BINARY: + case Type::LARGE_STRING: + return large_binary(); + default: + Unreachable("Hit an unknown type"); + } + return nullptr; +} + +Result> WidenDecimals( + const std::shared_ptr& promoted_type, + const std::shared_ptr& other_type, const Field::MergeOptions& options) { + const auto& left = checked_cast(*promoted_type); + const auto& right = checked_cast(*other_type); + if (!options.promote_numeric_width && left.bit_width() != right.bit_width()) { + return Status::TypeError( + "Cannot promote decimal128 to decimal256 without promote_numeric_width=true"); + } + const int32_t max_scale = std::max(left.scale(), right.scale()); + const int32_t common_precision = + std::max(left.precision() + max_scale - left.scale(), + right.precision() + max_scale - right.scale()); + if (left.id() == Type::DECIMAL256 || right.id() == Type::DECIMAL256 || + common_precision > BasicDecimal128::kMaxPrecision) { + return DecimalType::Make(Type::DECIMAL256, common_precision, max_scale); + } + return DecimalType::Make(Type::DECIMAL128, common_precision, max_scale); +} + +Result> MergeTypes(std::shared_ptr promoted_type, + std::shared_ptr other_type, + const Field::MergeOptions& options); + +// Merge temporal types based on options. Returns nullptr for non-temporal types. 
+Result> MaybeMergeTemporalTypes( + const std::shared_ptr& promoted_type, + const std::shared_ptr& other_type, const Field::MergeOptions& options) { + if (options.promote_temporal_unit) { + if (promoted_type->id() == Type::DATE32 && other_type->id() == Type::DATE64) { + return date64(); + } + if (promoted_type->id() == Type::DATE64 && other_type->id() == Type::DATE32) { + return date64(); + } + + if (promoted_type->id() == Type::DURATION && other_type->id() == Type::DURATION) { + const auto& left = checked_cast(*promoted_type); + const auto& right = checked_cast(*other_type); + return duration(std::max(left.unit(), right.unit())); + } + + if (is_time(promoted_type->id()) && is_time(other_type->id())) { + const auto& left = checked_cast(*promoted_type); + const auto& right = checked_cast(*other_type); + const auto unit = std::max(left.unit(), right.unit()); + if (unit == TimeUnit::MICRO || unit == TimeUnit::NANO) { + return time64(unit); + } + return time32(unit); + } + } + + if (promoted_type->id() == Type::TIMESTAMP && other_type->id() == Type::TIMESTAMP) { + const auto& left = checked_cast(*promoted_type); + const auto& right = checked_cast(*other_type); + if (left.timezone().empty() ^ right.timezone().empty()) { + return Status::TypeError( + "Cannot merge timestamp with timezone and timestamp without timezone"); + } + if (left.timezone() != right.timezone()) { + return Status::TypeError("Cannot merge timestamps with differing timezones"); + } + if (options.promote_temporal_unit) { + return timestamp(std::max(left.unit(), right.unit()), left.timezone()); + } + } + + return nullptr; +} + +// Merge numeric types based on options. Returns nullptr for non-numeric types. 
+Result> MaybeMergeNumericTypes( + std::shared_ptr promoted_type, std::shared_ptr other_type, + const Field::MergeOptions& options) { + bool promoted = false; + if (options.promote_decimal_to_float) { + if (is_decimal(promoted_type->id()) && is_floating(other_type->id())) { + promoted_type = other_type; + promoted = true; + } else if (is_floating(promoted_type->id()) && is_decimal(other_type->id())) { + other_type = promoted_type; + promoted = true; + } + } + + if (options.promote_integer_to_decimal && + ((is_decimal(promoted_type->id()) && is_integer(other_type->id())) || + (is_decimal(other_type->id()) && is_integer(promoted_type->id())))) { + if (is_integer(promoted_type->id()) && is_decimal(other_type->id())) { + // Other type is always the int + promoted_type.swap(other_type); + } + ARROW_ASSIGN_OR_RAISE(const int32_t precision, + MaxDecimalDigitsForInteger(other_type->id())); + ARROW_ASSIGN_OR_RAISE(const auto promoted_decimal, + DecimalType::Make(promoted_type->id(), precision, 0)); + ARROW_ASSIGN_OR_RAISE(promoted_type, + WidenDecimals(promoted_type, promoted_decimal, options)); + return promoted_type; + } + + if (options.promote_decimal && is_decimal(promoted_type->id()) && + is_decimal(other_type->id())) { + ARROW_ASSIGN_OR_RAISE(promoted_type, + WidenDecimals(promoted_type, other_type, options)); + return promoted_type; + } + + if (options.promote_integer_sign && ((is_unsigned_integer(promoted_type->id()) && + is_signed_integer(other_type->id())) || + (is_signed_integer(promoted_type->id()) && + is_unsigned_integer(other_type->id())))) { + if (is_signed_integer(promoted_type->id()) && is_unsigned_integer(other_type->id())) { + // Other type is always the signed int + promoted_type.swap(other_type); + } + + if (!options.promote_numeric_width && + bit_width(promoted_type->id()) < bit_width(other_type->id())) { + return Status::TypeError( + "Cannot widen signed integers without promote_numeric_width=true"); + } + int max_width = + 
std::max(bit_width(promoted_type->id()), bit_width(other_type->id())); + + // If the unsigned one is bigger or equal to the signed one, we need another bit + if (bit_width(promoted_type->id()) >= bit_width(other_type->id())) { + ++max_width; + } + + if (max_width > 32) { + promoted_type = int64(); + } else if (max_width > 16) { + promoted_type = int32(); + } else if (max_width > 8) { + promoted_type = int16(); + } else { + promoted_type = int8(); + } + return promoted_type; + } + + if (options.promote_integer_to_float && + ((is_floating(promoted_type->id()) && is_integer(other_type->id())) || + (is_integer(promoted_type->id()) && is_floating(other_type->id())))) { + if (is_integer(promoted_type->id()) && is_floating(other_type->id())) { + // Other type is always the int + promoted_type.swap(other_type); + } + + const int int_width = bit_width(other_type->id()); + promoted = true; + if (int_width <= 8) { + other_type = float16(); + } else if (int_width <= 16) { + other_type = float32(); + } else { + other_type = float64(); + } + + if (!options.promote_numeric_width && + bit_width(promoted_type->id()) != bit_width(other_type->id())) { + return Status::TypeError("Cannot widen float without promote_numeric_width=true"); + } + } + + if (options.promote_numeric_width) { + const int max_width = + std::max(bit_width(promoted_type->id()), bit_width(other_type->id())); + if (is_floating(promoted_type->id()) && is_floating(other_type->id())) { + promoted = true; + if (max_width >= 64) { + promoted_type = float64(); + } else if (max_width >= 32) { + promoted_type = float32(); + } else { + promoted_type = float16(); + } + } else if (is_signed_integer(promoted_type->id()) && + is_signed_integer(other_type->id())) { + promoted = true; + if (max_width >= 64) { + promoted_type = int64(); + } else if (max_width >= 32) { + promoted_type = int32(); + } else if (max_width >= 16) { + promoted_type = int16(); + } else { + promoted_type = int8(); + } + } else if 
(is_unsigned_integer(promoted_type->id()) && + is_unsigned_integer(other_type->id())) { + promoted = true; + if (max_width >= 64) { + promoted_type = uint64(); + } else if (max_width >= 32) { + promoted_type = uint32(); + } else if (max_width >= 16) { + promoted_type = uint16(); + } else { + promoted_type = uint8(); + } + } + } + + return promoted ? promoted_type : nullptr; +} + +// Merge two dictionary types, or else give an error. +Result> MergeDictionaryTypes( + const std::shared_ptr& promoted_type, + const std::shared_ptr& other_type, const Field::MergeOptions& options) { + const auto& left = checked_cast(*promoted_type); + const auto& right = checked_cast(*other_type); + if (!options.promote_dictionary_ordered && left.ordered() != right.ordered()) { + return Status::TypeError( + "Cannot merge ordered and unordered dictionary unless " + "promote_dictionary_ordered=true"); + } + Field::MergeOptions index_options = options; + index_options.promote_integer_sign = true; + index_options.promote_numeric_width = true; + ARROW_ASSIGN_OR_RAISE( + auto indices, + MaybeMergeNumericTypes(left.index_type(), right.index_type(), index_options)); + ARROW_ASSIGN_OR_RAISE(auto values, + MergeTypes(left.value_type(), right.value_type(), options)); + auto ordered = left.ordered() && right.ordered(); + if (indices && values) { + return dictionary(indices, values, ordered); + } else if (values) { + return Status::TypeError("Could not merge dictionary index types"); + } + return Status::TypeError("Could not merge dictionary value types"); +} + +// Merge binary types based on options. Returns nullptr for non-binary types. 
+Result> MaybeMergeBinaryTypes( + std::shared_ptr& promoted_type, std::shared_ptr& other_type, + const Field::MergeOptions& options) { + if (options.promote_binary) { + if (other_type->id() == Type::FIXED_SIZE_BINARY && + is_base_binary_like(promoted_type->id())) { + return MakeBinary(*promoted_type); + } else if (promoted_type->id() == Type::FIXED_SIZE_BINARY && + is_base_binary_like(other_type->id())) { + return MakeBinary(*other_type); + } else if (promoted_type->id() == Type::FIXED_SIZE_BINARY && + other_type->id() == Type::FIXED_SIZE_BINARY) { + return binary(); + } + + if ((other_type->id() == Type::LARGE_STRING || + other_type->id() == Type::LARGE_BINARY) && + (promoted_type->id() == Type::STRING || promoted_type->id() == Type::BINARY) + + ) { + // Promoted type is always large in case there are regular and large types + promoted_type.swap(other_type); + } + + // When one field is binary and the other a string + if (is_string(promoted_type->id()) && is_binary(other_type->id())) { + return MakeBinary(*promoted_type); + } else if (is_binary(promoted_type->id()) && is_string(other_type->id())) { + return MakeBinary(*promoted_type); + } + + // When the types are the same, but one is large + if ((promoted_type->id() == Type::STRING && other_type->id() == Type::LARGE_STRING) || + (promoted_type->id() == Type::LARGE_STRING && other_type->id() == Type::STRING)) { + return large_utf8(); + } else if ((promoted_type->id() == Type::BINARY && + other_type->id() == Type::LARGE_BINARY) || + (promoted_type->id() == Type::LARGE_BINARY && + other_type->id() == Type::BINARY)) { + return large_binary(); + } + } + + return nullptr; +} + +// Merge two struct types field by field, or else give an error. +Result> MergeStructs( + const std::shared_ptr& promoted_type, + const std::shared_ptr& other_type, const Field::MergeOptions& options) { + SchemaBuilder builder(SchemaBuilder::CONFLICT_APPEND, options); + // Add the LHS fields. Duplicates will be preserved. 
+ RETURN_NOT_OK(builder.AddFields(promoted_type->fields())); + + // Add the RHS fields. Duplicates will be merged, unless the field was + // already a duplicate, in which case we error (since we don't know which + // field to merge with). + builder.SetPolicy(SchemaBuilder::CONFLICT_MERGE); + RETURN_NOT_OK(builder.AddFields(other_type->fields())); + + ARROW_ASSIGN_OR_RAISE(auto schema, builder.Finish()); + return struct_(schema->fields()); +} + +// Merge list types based on options. Returns nullptr for non-list types. +Result> MaybeMergeListTypes( + const std::shared_ptr& promoted_type, + const std::shared_ptr& other_type, const Field::MergeOptions& options) { + if (promoted_type->id() == Type::FIXED_SIZE_LIST && + other_type->id() == Type::FIXED_SIZE_LIST) { + const auto& left = checked_cast(*promoted_type); + const auto& right = checked_cast(*other_type); + ARROW_ASSIGN_OR_RAISE( + auto value_field, + left.value_field()->MergeWith( + *right.value_field()->WithName(left.value_field()->name()), options)); + if (left.list_size() == right.list_size()) { + return fixed_size_list(std::move(value_field), left.list_size()); + } else { + return list(std::move(value_field)); + } + } else if (is_list(promoted_type->id()) && is_list(other_type->id())) { + const auto& left = checked_cast(*promoted_type); + const auto& right = checked_cast(*other_type); + ARROW_ASSIGN_OR_RAISE( + auto value_field, + left.value_field()->MergeWith( + *right.value_field()->WithName(left.value_field()->name()), options)); + + if (!options.promote_list && promoted_type->id() != other_type->id()) { + return Status::TypeError("Cannot merge lists unless promote_list=true"); + } + + if (promoted_type->id() == Type::LARGE_LIST || other_type->id() == Type::LARGE_LIST) { + return large_list(std::move(value_field)); + } else { + return list(std::move(value_field)); + } + } else if (promoted_type->id() == Type::MAP && other_type->id() == Type::MAP) { + const auto& left = checked_cast(*promoted_type); + const 
auto& right = checked_cast(*other_type); + ARROW_ASSIGN_OR_RAISE( + auto key_field, + left.key_field()->MergeWith( + *right.key_field()->WithName(left.key_field()->name()), options)); + ARROW_ASSIGN_OR_RAISE( + auto item_field, + left.item_field()->MergeWith( + *right.item_field()->WithName(left.item_field()->name()), options)); + return map(std::move(key_field->type()), std::move(item_field), + /*keys_sorted=*/left.keys_sorted() && right.keys_sorted()); + } else if (promoted_type->id() == Type::STRUCT && other_type->id() == Type::STRUCT) { + return MergeStructs(promoted_type, other_type, options); + } + + return nullptr; +} + +Result> MergeTypes(std::shared_ptr promoted_type, + std::shared_ptr other_type, + const Field::MergeOptions& options) { + if (promoted_type->Equals(*other_type)) return promoted_type; + + bool promoted = false; + if (options.promote_nullability) { + if (promoted_type->id() == Type::NA) { + return other_type; + } else if (other_type->id() == Type::NA) { + return promoted_type; + } + } else if (promoted_type->id() == Type::NA || other_type->id() == Type::NA) { + return Status::TypeError( + "Cannot merge type with null unless promote_nullability=true"); + } + + if (options.promote_dictionary && is_dictionary(promoted_type->id()) && + is_dictionary(other_type->id())) { + return MergeDictionaryTypes(promoted_type, other_type, options); + } + + ARROW_ASSIGN_OR_RAISE(auto maybe_promoted, + MaybeMergeTemporalTypes(promoted_type, other_type, options)); + if (maybe_promoted) return maybe_promoted; + + ARROW_ASSIGN_OR_RAISE(maybe_promoted, + MaybeMergeNumericTypes(promoted_type, other_type, options)); + if (maybe_promoted) return maybe_promoted; + + ARROW_ASSIGN_OR_RAISE(maybe_promoted, + MaybeMergeBinaryTypes(promoted_type, other_type, options)); + if (maybe_promoted) return maybe_promoted; + + ARROW_ASSIGN_OR_RAISE(maybe_promoted, + MaybeMergeListTypes(promoted_type, other_type, options)); + if (maybe_promoted) return maybe_promoted; + + return 
promoted ? promoted_type : nullptr; +} +} // namespace + Result> Field::MergeWith(const Field& other, MergeOptions options) const { if (name() != other.name()) { @@ -338,17 +779,30 @@ Result> Field::MergeWith(const Field& other, return Copy(); } - if (options.promote_nullability) { - if (type()->Equals(other.type())) { - return Copy()->WithNullable(nullable() || other.nullable()); + auto maybe_promoted_type = MergeTypes(type_, other.type(), options); + if (!maybe_promoted_type.ok()) { + return maybe_promoted_type.status().WithMessage( + "Unable to merge: Field ", name(), + " has incompatible types: ", type()->ToString(), " vs ", other.type()->ToString(), + ": ", maybe_promoted_type.status().message()); + } + auto promoted_type = *std::move(maybe_promoted_type); + if (promoted_type) { + bool nullable = nullable_; + if (options.promote_nullability) { + nullable = nullable || other.nullable() || type_->id() == Type::NA || + other.type()->id() == Type::NA; + } else if (nullable_ != other.nullable()) { + return Status::TypeError("Unable to merge: Field ", name(), + " has incompatible nullability: ", nullable_, " vs ", + other.nullable()); } - std::shared_ptr promoted = MaybePromoteNullTypes(*this, other); - if (promoted) return promoted; - } - return Status::Invalid("Unable to merge: Field ", name(), - " has incompatible types: ", type()->ToString(), " vs ", - other.type()->ToString()); + return std::make_shared(name_, promoted_type, nullable, metadata_); + } + return Status::TypeError("Unable to merge: Field ", name(), + " has incompatible types: ", type()->ToString(), " vs ", + other.type()->ToString()); } Result> Field::MergeWith(const std::shared_ptr& other, @@ -2022,7 +2476,8 @@ class SchemaBuilder::Impl { if (policy_ == CONFLICT_REPLACE) { fields_[i] = field; } else if (policy_ == CONFLICT_MERGE) { - ARROW_ASSIGN_OR_RAISE(fields_[i], fields_[i]->MergeWith(field)); + ARROW_ASSIGN_OR_RAISE(fields_[i], + fields_[i]->MergeWith(field, field_merge_options_)); } return 
Status::OK(); @@ -2643,6 +3098,12 @@ std::shared_ptr map(std::shared_ptr key_type, keys_sorted); } +std::shared_ptr map(std::shared_ptr key_field, + std::shared_ptr item_field, bool keys_sorted) { + return std::make_shared(std::move(key_field), std::move(item_field), + keys_sorted); +} + std::shared_ptr fixed_size_list(const std::shared_ptr& value_type, int32_t list_size) { return std::make_shared(value_type, list_size); diff --git a/cpp/src/arrow/type.h b/cpp/src/arrow/type.h index 19910979287cc..3f4dd5c9b21fa 100644 --- a/cpp/src/arrow/type.h +++ b/cpp/src/arrow/type.h @@ -397,14 +397,76 @@ class ARROW_EXPORT Field : public detail::Fingerprintable, /// \brief Options that control the behavior of `MergeWith`. /// Options are to be added to allow type conversions, including integer /// widening, promotion from integer to float, or conversion to or from boolean. - struct MergeOptions { + struct ARROW_EXPORT MergeOptions : public util::ToStringOstreamable { /// If true, a Field of NullType can be unified with a Field of another type. /// The unified field will be of the other type and become nullable. /// Nullability will be promoted to the looser option (nullable if one is not /// nullable). bool promote_nullability = true; + /// Allow a decimal to be unified with another decimal of the same + /// width, adjusting scale and precision as appropriate. May fail + /// if the adjustment is not possible. + bool promote_decimal = false; + + /// Allow a decimal to be promoted to a float. The float type will + /// not itself be promoted (e.g. Decimal128 + Float32 = Float32). + bool promote_decimal_to_float = false; + + /// Allow an integer to be promoted to a decimal. + /// + /// May fail if the decimal has insufficient precision to + /// accommodate the integer (see promote_numeric_width). 
+ bool promote_integer_to_decimal = false; + + /// Allow an integer of a given bit width to be promoted to a + /// float; the result will be a float of an equal or greater bit + /// width to both of the inputs. Examples: + /// - int8 + float32 = float32 + /// - int32 + float32 = float64 + /// - int32 + float64 = float64 + /// Because an int32 cannot always be represented exactly in the + /// 24 bits of a float32 mantissa. + bool promote_integer_to_float = false; + + /// Allow an unsigned integer of a given bit width to be promoted + /// to a signed integer that fits into the signed type: + /// uint8 + int16 = int16 + /// When widening is needed, set promote_numeric_width to true: + /// uint16 + int16 = int32 + bool promote_integer_sign = false; + + /// Allow an integer, float, or decimal of a given bit width to be + /// promoted to an equivalent type of a greater bit width. + bool promote_numeric_width = false; + + /// Allow strings to be promoted to binary types. Also allows promotion of + /// fixed size binary types to variable sized formats, binary to large binary, + /// and string to large string. + bool promote_binary = false; + + /// Second to millisecond, Time32 to Time64, Time32(SECOND) to Time32(MILLI), etc + bool promote_temporal_unit = false; + + /// Allow promotion from a list to a large-list and from a fixed-size list to a + /// variable sized list + bool promote_list = false; + + /// Unify dictionary index types and dictionary value types. + bool promote_dictionary = false; + + /// Allow merging ordered and non-ordered dictionaries. + /// The result will be ordered if and only if both inputs + /// are ordered. + bool promote_dictionary_ordered = false; + + /// Get default options. Only NullType will be merged with other types. static MergeOptions Defaults() { return MergeOptions(); } + /// Get permissive options. All options are enabled, except + /// promote_dictionary_ordered. 
+ static MergeOptions Permissive(); + /// Get a human-readable representation of the options. + std::string ToString() const; }; /// \brief Merge the current field with a field of the same name. diff --git a/cpp/src/arrow/type_test.cc b/cpp/src/arrow/type_test.cc index 3dbefdcf0c564..9ba8cf98dea4f 100644 --- a/cpp/src/arrow/type_test.cc +++ b/cpp/src/arrow/type_test.cc @@ -328,11 +328,11 @@ TEST(TestField, TestMerge) { auto null_field = field("f", null()); Field::MergeOptions options; options.promote_nullability = false; - ASSERT_RAISES(Invalid, f->MergeWith(null_field, options)); - ASSERT_RAISES(Invalid, null_field->MergeWith(f, options)); + ASSERT_RAISES(TypeError, f->MergeWith(null_field, options)); + ASSERT_RAISES(TypeError, null_field->MergeWith(f, options)); // Also rejects fields with different nullability. - ASSERT_RAISES(Invalid, + ASSERT_RAISES(TypeError, f->WithNullable(true)->MergeWith(f->WithNullable(false), options)); } { @@ -349,7 +349,7 @@ TEST(TestField, TestMerge) { ASSERT_TRUE(result->Equals(f->WithNullable(true)->WithMetadata(metadata2))); } { - // promote_nullability == true; merge a nullable field and a in-nullable field. + // promote_nullability == true; merge a nullable field and an in-nullable field. 
Field::MergeOptions options; options.promote_nullability = true; auto f1 = field("f", int32())->WithNullable(false); @@ -840,7 +840,7 @@ TEST(TestSchemaBuilder, PolicyMerge) { AssertSchemaBuilderYield(builder, schema({f0_opt, f1})); // Unsupported merge with a different type - ASSERT_RAISES(Invalid, builder.AddField(f0_other)); + ASSERT_RAISES(TypeError, builder.AddField(f0_other)); // Builder should still contain state AssertSchemaBuilderYield(builder, schema({f0, f1})); @@ -895,8 +895,8 @@ TEST(TestSchemaBuilder, Merge) { ASSERT_OK_AND_ASSIGN(schema, SchemaBuilder::Merge({s2, s3, s1})); AssertSchemaEqual(schema, ::arrow::schema({f1, f0_opt})); - ASSERT_RAISES(Invalid, SchemaBuilder::Merge({s3, broken})); - ASSERT_RAISES(Invalid, SchemaBuilder::AreCompatible({s3, broken})); + ASSERT_RAISES(TypeError, SchemaBuilder::Merge({s3, broken})); + ASSERT_RAISES(TypeError, SchemaBuilder::AreCompatible({s3, broken})); } class TestUnifySchemas : public TestSchema { @@ -917,6 +917,159 @@ class TestUnifySchemas : public TestSchema { << lhs_field->ToString() << " vs " << rhs_field->ToString(); } } + + void CheckUnifyAsymmetric( + const std::shared_ptr& field1, const std::shared_ptr& field2, + const std::shared_ptr& expected, + const Field::MergeOptions& options = Field::MergeOptions::Defaults()) { + ARROW_SCOPED_TRACE("options: ", options); + ARROW_SCOPED_TRACE("field2: ", field2->ToString()); + ARROW_SCOPED_TRACE("field1: ", field1->ToString()); + ASSERT_OK_AND_ASSIGN(auto merged, field1->MergeWith(field2, options)); + AssertFieldEqual(merged, expected); + } + + void CheckPromoteTo( + const std::shared_ptr& field1, const std::shared_ptr& field2, + const std::shared_ptr& expected, + const Field::MergeOptions& options = Field::MergeOptions::Defaults()) { + CheckUnifyAsymmetric(field1, field2, expected, options); + CheckUnifyAsymmetric(field2, field1, expected, options); + } + + void CheckUnifyFailsInvalid( + const std::shared_ptr& field1, const std::shared_ptr& field2, + const 
Field::MergeOptions& options = Field::MergeOptions::Defaults(), + const std::string& match_message = "") { + ARROW_SCOPED_TRACE("options: ", options); + ARROW_SCOPED_TRACE("field2: ", field2->ToString()); + ARROW_SCOPED_TRACE("field1: ", field1->ToString()); + EXPECT_RAISES_WITH_MESSAGE_THAT(Invalid, ::testing::HasSubstr(match_message), + field1->MergeWith(field2, options)); + EXPECT_RAISES_WITH_MESSAGE_THAT(Invalid, ::testing::HasSubstr(match_message), + field2->MergeWith(field1, options)); + } + + void CheckUnifyFailsTypeError( + const std::shared_ptr& field1, const std::shared_ptr& field2, + const Field::MergeOptions& options = Field::MergeOptions::Defaults(), + const std::string& match_message = "") { + ARROW_SCOPED_TRACE("options: ", options); + ARROW_SCOPED_TRACE("field2: ", field2->ToString()); + ARROW_SCOPED_TRACE("field1: ", field1->ToString()); + ASSERT_RAISES(TypeError, field1->MergeWith(field2, options)); + ASSERT_RAISES(TypeError, field2->MergeWith(field1, options)); + EXPECT_RAISES_WITH_MESSAGE_THAT(TypeError, ::testing::HasSubstr(match_message), + field1->MergeWith(field2, options)); + EXPECT_RAISES_WITH_MESSAGE_THAT(TypeError, ::testing::HasSubstr(match_message), + field2->MergeWith(field1, options)); + } + + void CheckPromoteTo( + const std::shared_ptr& left, const std::shared_ptr& right, + const std::shared_ptr& expected, + const Field::MergeOptions& options = Field::MergeOptions::Defaults()) { + auto field1 = field("a", left); + auto field2 = field("a", right); + CheckPromoteTo(field1, field2, field("a", expected), options); + + field1 = field("a", left, /*nullable=*/false); + field2 = field("a", right, /*nullable=*/false); + CheckPromoteTo(field1, field2, field("a", expected, /*nullable=*/false), options); + + field1 = field("a", left); + field2 = field("a", right, /*nullable=*/false); + CheckPromoteTo(field1, field2, field("a", expected, /*nullable=*/true), options); + + field1 = field("a", left, /*nullable=*/false); + field2 = field("a", 
right); + CheckPromoteTo(field1, field2, field("a", expected, /*nullable=*/true), options); + } + + void CheckUnifyAsymmetric( + const std::shared_ptr& left, const std::shared_ptr& right, + const std::shared_ptr& expected, + const Field::MergeOptions& options = Field::MergeOptions::Defaults()) { + auto field1 = field("a", left); + auto field2 = field("a", right); + CheckUnifyAsymmetric(field1, field2, field("a", expected), options); + + field1 = field("a", left, /*nullable=*/false); + field2 = field("a", right, /*nullable=*/false); + CheckUnifyAsymmetric(field1, field2, field("a", expected, /*nullable=*/false), + options); + + field1 = field("a", left); + field2 = field("a", right, /*nullable=*/false); + CheckUnifyAsymmetric(field1, field2, field("a", expected, /*nullable=*/true), + options); + + field1 = field("a", left, /*nullable=*/false); + field2 = field("a", right); + CheckUnifyAsymmetric(field1, field2, field("a", expected, /*nullable=*/true), + options); + } + + void CheckPromoteTo( + const std::shared_ptr& from, + const std::vector>& to, + const Field::MergeOptions& options = Field::MergeOptions::Defaults()) { + for (const auto& ty : to) { + CheckPromoteTo(from, ty, ty, options); + } + } + + void CheckUnifyFailsInvalid( + const std::shared_ptr& left, const std::shared_ptr& right, + const Field::MergeOptions& options = Field::MergeOptions::Defaults()) { + auto field1 = field("a", left); + auto field2 = field("a", right); + CheckUnifyFailsInvalid(field1, field2, options); + } + + void CheckUnifyFailsInvalid( + const std::shared_ptr& from, + const std::vector>& to, + const Field::MergeOptions& options = Field::MergeOptions::Defaults()) { + for (const auto& ty : to) { + CheckUnifyFailsInvalid(from, ty, options); + } + } + + void CheckUnifyFailsInvalid( + const std::vector>& from, + const std::vector>& to, + const Field::MergeOptions& options = Field::MergeOptions::Defaults()) { + for (const auto& ty : from) { + CheckUnifyFailsInvalid(ty, to, options); + } + } 
+ + void CheckUnifyFailsTypeError( + const std::shared_ptr& left, const std::shared_ptr& right, + const Field::MergeOptions& options = Field::MergeOptions::Defaults()) { + auto field1 = field("a", left); + auto field2 = field("a", right); + CheckUnifyFailsTypeError(field1, field2, options); + } + + void CheckUnifyFailsTypeError( + const std::shared_ptr& from, + const std::vector>& to, + const Field::MergeOptions& options = Field::MergeOptions::Defaults()) { + for (const auto& ty : to) { + CheckUnifyFailsTypeError(from, ty, options); + } + } + + void CheckUnifyFailsTypeError( + const std::vector>& from, + const std::vector>& to, + const Field::MergeOptions& options = Field::MergeOptions::Defaults()) { + for (const auto& ty : from) { + CheckUnifyFailsTypeError(ty, to, options); + } + } }; TEST_F(TestUnifySchemas, EmptyInput) { ASSERT_RAISES(Invalid, UnifySchemas({})); } @@ -1008,6 +1161,252 @@ TEST_F(TestUnifySchemas, MoreSchemas) { utf8_field->WithNullable(true)})); } +TEST_F(TestUnifySchemas, Numeric) { + auto options = Field::MergeOptions::Defaults(); + options.promote_numeric_width = true; + options.promote_integer_to_float = true; + options.promote_integer_sign = true; + CheckPromoteTo( + uint8(), + {uint16(), int16(), uint32(), int32(), uint64(), int64(), float32(), float64()}, + options); + CheckPromoteTo(int8(), {int16(), int32(), int64(), float32(), float64()}, options); + CheckPromoteTo(uint16(), {uint32(), int32(), uint64(), int64(), float32(), float64()}, + options); + CheckPromoteTo(int16(), {int32(), int64(), float32(), float64()}, options); + CheckPromoteTo(uint32(), {uint64(), int64(), float64()}, options); + CheckPromoteTo(int32(), {int64(), float64()}, options); + CheckPromoteTo(uint64(), {int64(), float64()}, options); + CheckPromoteTo(int64(), {float64()}, options); + CheckPromoteTo(float16(), {float32(), float64()}, options); + CheckPromoteTo(float32(), {float64()}, options); + CheckPromoteTo(uint64(), float32(), float64(), options); + 
CheckPromoteTo(int64(), float32(), float64(), options); + + options.promote_integer_sign = false; + CheckPromoteTo(uint8(), {uint16(), uint32(), uint64()}, options); + CheckPromoteTo(int8(), {int16(), int32(), int64()}, options); + CheckUnifyFailsTypeError(uint8(), {int8(), int16(), int32(), int64()}, options); + CheckPromoteTo(uint16(), {uint32(), uint64()}, options); + CheckPromoteTo(int16(), {int32(), int64()}, options); + CheckUnifyFailsTypeError(uint16(), {int16(), int32(), int64()}, options); + CheckPromoteTo(uint32(), {uint64()}, options); + CheckPromoteTo(int32(), {int64()}, options); + CheckUnifyFailsTypeError(uint32(), {int32(), int64()}, options); + CheckUnifyFailsTypeError(uint64(), {int64()}, options); + + options.promote_integer_sign = true; + options.promote_integer_to_float = false; + CheckUnifyFailsTypeError(IntTypes(), FloatingPointTypes(), options); + + options.promote_integer_to_float = true; + options.promote_numeric_width = false; + CheckUnifyFailsTypeError(int8(), {int16(), int32(), int64()}, options); + CheckUnifyFailsTypeError(int16(), {int32(), int64()}, options); + CheckUnifyFailsTypeError(int32(), {int64()}, options); + CheckUnifyFailsTypeError(int32(), {float16(), float32()}, options); + CheckPromoteTo(int32(), {float64()}, options); + CheckPromoteTo(int64(), {float64()}, options); + + CheckPromoteTo(uint8(), int8(), int16(), options); + CheckPromoteTo(uint16(), int8(), int32(), options); + CheckPromoteTo(uint32(), int8(), int64(), options); + CheckPromoteTo(uint32(), int32(), int64(), options); +} + +TEST_F(TestUnifySchemas, Decimal) { + auto options = Field::MergeOptions::Defaults(); + + options.promote_decimal_to_float = true; + CheckPromoteTo(decimal128(3, 2), {float32(), float64()}, options); + CheckPromoteTo(decimal256(3, 2), {float32(), float64()}, options); + + options.promote_integer_to_decimal = true; + CheckPromoteTo(int32(), decimal128(3, 2), decimal128(12, 2), options); + CheckPromoteTo(int32(), decimal128(3, -2), 
decimal128(10, 0), options); + + options.promote_decimal = true; + CheckPromoteTo(decimal128(3, 2), decimal128(5, 2), decimal128(5, 2), options); + CheckPromoteTo(decimal128(3, 2), decimal128(5, 3), decimal128(5, 3), options); + CheckPromoteTo(decimal128(3, 2), decimal128(5, 1), decimal128(6, 2), options); + CheckPromoteTo(decimal128(3, 2), decimal128(5, -2), decimal128(9, 2), options); + CheckPromoteTo(decimal128(3, -2), decimal128(5, -2), decimal128(5, -2), options); + CheckPromoteTo(decimal128(38, 10), decimal128(38, 5), decimal256(43, 10), options); + + CheckPromoteTo(decimal256(3, 2), decimal256(5, 2), decimal256(5, 2), options); + CheckPromoteTo(decimal256(3, 2), decimal256(5, 3), decimal256(5, 3), options); + CheckPromoteTo(decimal256(3, 2), decimal256(5, 1), decimal256(6, 2), options); + CheckPromoteTo(decimal256(3, 2), decimal256(5, -2), decimal256(9, 2), options); + CheckPromoteTo(decimal256(3, -2), decimal256(5, -2), decimal256(5, -2), options); + + // int32() is essentially decimal128(10, 0) + CheckPromoteTo(int32(), decimal128(3, 2), decimal128(12, 2), options); + CheckPromoteTo(int32(), decimal128(3, -2), decimal128(10, 0), options); + CheckPromoteTo(int64(), decimal128(38, 37), decimal256(56, 37), options); + + CheckUnifyFailsTypeError(decimal256(1, 0), decimal128(1, 0), options); + + options.promote_numeric_width = true; + CheckPromoteTo(decimal128(3, 2), decimal256(5, 2), decimal256(5, 2), options); + CheckPromoteTo(int32(), decimal128(38, 37), decimal256(47, 37), options); + CheckUnifyFailsInvalid(decimal128(38, 10), decimal256(76, 5), options); + + CheckUnifyFailsInvalid(int64(), decimal256(76, 75), options); +} + +TEST_F(TestUnifySchemas, Temporal) { + auto options = Field::MergeOptions::Defaults(); + + options.promote_temporal_unit = true; + CheckPromoteTo(date32(), {date64()}, options); + + CheckPromoteTo( + time32(TimeUnit::SECOND), + {time32(TimeUnit::MILLI), time64(TimeUnit::MICRO), time64(TimeUnit::NANO)}, + options); + 
CheckPromoteTo(time32(TimeUnit::MILLI), + {time64(TimeUnit::MICRO), time64(TimeUnit::NANO)}, options); + CheckPromoteTo(time64(TimeUnit::MICRO), {time64(TimeUnit::NANO)}, options); + + CheckPromoteTo( + duration(TimeUnit::SECOND), + {duration(TimeUnit::MILLI), duration(TimeUnit::MICRO), duration(TimeUnit::NANO)}, + options); + CheckPromoteTo(duration(TimeUnit::MILLI), + {duration(TimeUnit::MICRO), duration(TimeUnit::NANO)}, options); + CheckPromoteTo(duration(TimeUnit::MICRO), {duration(TimeUnit::NANO)}, options); + + CheckPromoteTo( + timestamp(TimeUnit::SECOND), + {timestamp(TimeUnit::MILLI), timestamp(TimeUnit::MICRO), timestamp(TimeUnit::NANO)}, + options); + CheckPromoteTo(timestamp(TimeUnit::MILLI), + {timestamp(TimeUnit::MICRO), timestamp(TimeUnit::NANO)}, options); + CheckPromoteTo(timestamp(TimeUnit::MICRO), {timestamp(TimeUnit::NANO)}, options); + + CheckUnifyFailsTypeError(timestamp(TimeUnit::SECOND), + timestamp(TimeUnit::SECOND, "UTC"), options); + CheckUnifyFailsTypeError(timestamp(TimeUnit::SECOND, "America/New_York"), + timestamp(TimeUnit::SECOND, "UTC"), options); + + options.promote_temporal_unit = false; + CheckUnifyFailsTypeError(timestamp(TimeUnit::MICRO), timestamp(TimeUnit::NANO), + options); +} + +TEST_F(TestUnifySchemas, Binary) { + auto options = Field::MergeOptions::Defaults(); + options.promote_binary = true; + CheckPromoteTo(utf8(), {large_utf8(), binary(), large_binary()}, options); + CheckPromoteTo(binary(), {large_binary()}, options); + CheckPromoteTo(fixed_size_binary(2), {fixed_size_binary(2), binary(), large_binary()}, + options); + CheckPromoteTo(fixed_size_binary(2), fixed_size_binary(4), binary(), options); + + options.promote_binary = false; + CheckUnifyFailsTypeError({utf8(), binary()}, {large_utf8(), large_binary()}); + CheckUnifyFailsTypeError(fixed_size_binary(2), BaseBinaryTypes()); + CheckUnifyFailsTypeError(utf8(), {binary(), large_binary(), fixed_size_binary(2)}); +} + +TEST_F(TestUnifySchemas, List) { + auto options = 
Field::MergeOptions::Defaults(); + options.promote_list = true; + + CheckPromoteTo(fixed_size_list(int8(), 2), fixed_size_list(int8(), 3), list(int8())); + + CheckPromoteTo(list(int8()), {large_list(int8())}, options); + CheckPromoteTo(fixed_size_list(int8(), 2), {list(int8()), large_list(int8())}, options); + + options.promote_numeric_width = true; + CheckPromoteTo(list(int8()), {list(int16()), list(int32()), list(int64())}, options); + CheckPromoteTo( + fixed_size_list(int8(), 2), + {fixed_size_list(int16(), 2), list(int16()), list(int32()), list(int64())}, + options); + CheckPromoteTo(fixed_size_list(int16(), 2), list(int8()), list(int16()), options); + + auto ty = list(field("foo", int8(), /*nullable=*/false)); + CheckUnifyAsymmetric(ty, list(int8()), list(field("foo", int8(), /*nullable=*/true)), + options); + CheckUnifyAsymmetric(ty, list(field("bar", int16(), /*nullable=*/false)), + list(field("foo", int16(), /*nullable=*/false)), options); + + options.promote_list = false; + CheckUnifyFailsTypeError(list(int8()), large_list(int8())); +} + +TEST_F(TestUnifySchemas, Map) { + auto options = Field::MergeOptions::Defaults(); + options.promote_numeric_width = true; + + CheckPromoteTo(map(int8(), int32()), + {map(int8(), int64()), map(int16(), int32()), map(int64(), int64())}, + options); + + // Do not test field names, since MapType intentionally ignores them in comparisons + // See ARROW-7173, ARROW-14999 + auto ty = map(int8(), field("value", int32(), /*nullable=*/false)); + CheckPromoteTo(ty, map(int8(), int32()), + map(int8(), field("value", int32(), /*nullable=*/true)), options); + CheckPromoteTo(ty, map(int16(), field("value", int64(), /*nullable=*/false)), + map(int16(), field("value", int64(), /*nullable=*/false)), options); +} + +TEST_F(TestUnifySchemas, Struct) { + auto options = Field::MergeOptions::Defaults(); + options.promote_numeric_width = true; + options.promote_binary = true; + + CheckPromoteTo(struct_({}), struct_({field("a", int8())}), + 
struct_({field("a", int8())}), options); + + CheckUnifyAsymmetric(struct_({field("b", utf8())}), struct_({field("a", int8())}), + struct_({field("b", utf8()), field("a", int8())}), options); + CheckUnifyAsymmetric(struct_({field("a", int8())}), struct_({field("b", utf8())}), + struct_({field("a", int8()), field("b", utf8())}), options); + + CheckPromoteTo(struct_({field("b", utf8())}), struct_({field("b", binary())}), + struct_({field("b", binary())}), options); + + CheckUnifyAsymmetric( + struct_({field("a", int8()), field("b", utf8()), field("a", int64())}), + struct_({field("b", binary())}), + struct_({field("a", int8()), field("b", binary()), field("a", int64())}), options); + + ASSERT_RAISES( + Invalid, + field("foo", struct_({field("a", int8()), field("b", utf8()), field("a", int64())})) + ->MergeWith(field("foo", struct_({field("a", int64())})), options)); +} + +TEST_F(TestUnifySchemas, Dictionary) { + auto options = Field::MergeOptions::Defaults(); + options.promote_dictionary = true; + options.promote_binary = true; + + CheckPromoteTo(dictionary(int8(), utf8()), + { + dictionary(int64(), utf8()), + dictionary(int8(), large_utf8()), + }, + options); + CheckPromoteTo(dictionary(int64(), utf8()), dictionary(int8(), large_utf8()), + dictionary(int64(), large_utf8()), options); + CheckPromoteTo(dictionary(int8(), utf8(), /*ordered=*/true), + { + dictionary(int64(), utf8(), /*ordered=*/true), + dictionary(int8(), large_utf8(), /*ordered=*/true), + }, + options); + CheckUnifyFailsTypeError(dictionary(int8(), utf8()), + dictionary(int8(), utf8(), /*ordered=*/true), options); + + options.promote_dictionary_ordered = true; + CheckPromoteTo(dictionary(int8(), utf8()), dictionary(int8(), utf8(), /*ordered=*/true), + dictionary(int8(), utf8(), /*ordered=*/false), options); +} + TEST_F(TestUnifySchemas, IncompatibleTypes) { auto int32_field = field("f", int32()); auto uint8_field = field("f", uint8(), false); @@ -1015,7 +1414,7 @@ TEST_F(TestUnifySchemas, 
IncompatibleTypes) { auto schema1 = schema({int32_field}); auto schema2 = schema({uint8_field}); - ASSERT_RAISES(Invalid, UnifySchemas({schema1, schema2})); + ASSERT_RAISES(TypeError, UnifySchemas({schema1, schema2})); } TEST_F(TestUnifySchemas, DuplicateFieldNames) { diff --git a/cpp/src/arrow/type_traits.h b/cpp/src/arrow/type_traits.h index fdceca00a3f39..bcbde23ae4a4b 100644 --- a/cpp/src/arrow/type_traits.h +++ b/cpp/src/arrow/type_traits.h @@ -1132,6 +1132,36 @@ constexpr bool is_temporal(Type::type type_id) { return false; } +/// \brief Check for a time type +/// +/// \param[in] type_id the type-id to check +/// \return whether type-id is a time type one +constexpr bool is_time(Type::type type_id) { + switch (type_id) { + case Type::TIME32: + case Type::TIME64: + return true; + default: + break; + } + return false; +} + +/// \brief Check for a date type +/// +/// \param[in] type_id the type-id to check +/// \return whether type-id is a date type one +constexpr bool is_date(Type::type type_id) { + switch (type_id) { + case Type::DATE32: + case Type::DATE64: + return true; + default: + break; + } + return false; +} + /// \brief Check for an interval type /// /// \param[in] type_id the type-id to check @@ -1195,6 +1225,22 @@ constexpr bool is_var_length_list(Type::type type_id) { return false; } +/// \brief Check for a list type +/// +/// \param[in] type_id the type-id to check +/// \return whether type-id is a list type one +constexpr bool is_list(Type::type type_id) { + switch (type_id) { + case Type::LIST: + case Type::LARGE_LIST: + case Type::FIXED_SIZE_LIST: + return true; + default: + break; + } + return false; +} + /// \brief Check for a list-like type /// /// \param[in] type_id the type-id to check diff --git a/python/pyarrow/includes/libarrow.pxd b/python/pyarrow/includes/libarrow.pxd index 482a6e91ba929..ad79c0edcd8a1 100644 --- a/python/pyarrow/includes/libarrow.pxd +++ b/python/pyarrow/includes/libarrow.pxd @@ -409,12 +409,16 @@ cdef 
extern from "arrow/api.h" namespace "arrow" nogil: const shared_ptr[CDataType]& value_type() cdef cppclass CField" arrow::Field": - cppclass CMergeOptions "arrow::Field::MergeOptions": + cppclass CMergeOptions "MergeOptions": + CMergeOptions() c_bool promote_nullability @staticmethod CMergeOptions Defaults() + @staticmethod + CMergeOptions Permissive() + const c_string& name() shared_ptr[CDataType] type() c_bool nullable() @@ -514,7 +518,8 @@ cdef extern from "arrow/api.h" namespace "arrow" nogil: shared_ptr[CSchema] RemoveMetadata() CResult[shared_ptr[CSchema]] UnifySchemas( - const vector[shared_ptr[CSchema]]& schemas) + const vector[shared_ptr[CSchema]]& schemas, + CField.CMergeOptions field_merge_options) cdef cppclass PrettyPrintOptions: PrettyPrintOptions() diff --git a/python/pyarrow/table.pxi b/python/pyarrow/table.pxi index 36601130b3f12..72af5a2deea9c 100644 --- a/python/pyarrow/table.pxi +++ b/python/pyarrow/table.pxi @@ -5024,16 +5024,16 @@ def table(data, names=None, schema=None, metadata=None, nthreads=None): "Expected pandas DataFrame, python dictionary or list of arrays") -def concat_tables(tables, c_bool promote=False, MemoryPool memory_pool=None): +def concat_tables(tables, MemoryPool memory_pool=None, str promote_options="none", **kwargs): """ Concatenate pyarrow.Table objects. - If promote==False, a zero-copy concatenation will be performed. The schemas + If promote_options="none", a zero-copy concatenation will be performed. The schemas of all the Tables must be the same (except the metadata), otherwise an exception will be raised. The result Table will share the metadata with the first table. - If promote==True, any null type arrays will be casted to the type of other + If promote_options="default", any null type arrays will be casted to the type of other arrays in the column of the same name. If a table is missing a particular field, null values of the appropriate type will be generated to take the place of the missing field. 
The new schema will share the metadata with the @@ -5041,14 +5041,18 @@ def concat_tables(tables, c_bool promote=False, MemoryPool memory_pool=None): first table which has the field defined. Note that type promotions may involve additional allocations on the given ``memory_pool``. + If promote_options="permissive", the behavior is the same as "default", and in addition + types will be promoted to the common denominator that fits all the fields. + Parameters ---------- tables : iterable of pyarrow.Table objects Pyarrow tables to concatenate into a single Table. - promote : bool, default False - If True, concatenate tables with null-filling and null type promotion. memory_pool : MemoryPool, default None For memory allocations, if required, otherwise use default pool. + promote_options : str, default none + Accepts strings "none", "default" and "permissive". + **kwargs : dict, optional Examples -------- @@ -5078,11 +5082,24 @@ def concat_tables(tables, c_bool promote=False, MemoryPool memory_pool=None): CConcatenateTablesOptions options = ( CConcatenateTablesOptions.Defaults()) + if "promote" in kwargs: + warnings.warn( + "promote has been superseded by mode='default'.", FutureWarning, stacklevel=2) + if kwargs['promote'] is True: + promote_options = "default" + for table in tables: c_tables.push_back(table.sp_table) + if promote_options == "permissive": + options.field_merge_options = CField.CMergeOptions.Permissive() + elif promote_options in {"default", "none"}: + options.field_merge_options = CField.CMergeOptions.Defaults() + else: + raise ValueError(f"Invalid promote options: {promote_options}") + with nogil: - options.unify_schemas = promote + options.unify_schemas = promote_options != "none" c_result_table = GetResultValue( ConcatenateTables(c_tables, options, pool)) diff --git a/python/pyarrow/tests/test_dataset.py b/python/pyarrow/tests/test_dataset.py index 671405d1ee6a0..6f3b54b0cd681 100644 --- a/python/pyarrow/tests/test_dataset.py +++ b/python/pyarrow/tests/test_dataset.py @@ 
-2921,7 +2921,7 @@ def test_union_dataset_from_other_datasets(tempdir, multisourcefs): _, path = _create_single_file(tempdir, table=table) child4 = ds.dataset(path) - with pytest.raises(pa.ArrowInvalid, match='Unable to merge'): + with pytest.raises(pa.ArrowTypeError, match='Unable to merge'): ds.dataset([child1, child4]) diff --git a/python/pyarrow/tests/test_schema.py b/python/pyarrow/tests/test_schema.py index e28e0ac445ef8..fa75fcea30db7 100644 --- a/python/pyarrow/tests/test_schema.py +++ b/python/pyarrow/tests/test_schema.py @@ -717,13 +717,16 @@ def test_schema_merge(): ]) assert result.equals(expected) - with pytest.raises(pa.ArrowInvalid): + with pytest.raises(pa.ArrowTypeError): pa.unify_schemas([b, d]) # ARROW-14002: Try with tuple instead of list result = pa.unify_schemas((a, b, c)) assert result.equals(expected) + result = pa.unify_schemas([b, d], promote_options="permissive") + assert result.equals(d) + # raise proper error when passing a non-Schema value with pytest.raises(TypeError): pa.unify_schemas([a, 1]) diff --git a/python/pyarrow/tests/test_table.py b/python/pyarrow/tests/test_table.py index b9e0d692196fb..6b48633b91f8e 100644 --- a/python/pyarrow/tests/test_table.py +++ b/python/pyarrow/tests/test_table.py @@ -1330,6 +1330,23 @@ def test_concat_tables(): assert result.equals(expected) +def test_concat_tables_permissive(): + t1 = pa.Table.from_arrays([list(range(10))], names=('a',)) + t2 = pa.Table.from_arrays([list(('a', 'b', 'c'))], names=('a',)) + + with pytest.raises( + pa.ArrowTypeError, + match="Unable to merge: Field a has incompatible types: int64 vs string"): + _ = pa.concat_tables([t1, t2], promote_options="permissive") + + +def test_concat_tables_invalid_option(): + t = pa.Table.from_arrays([list(range(10))], names=('a',)) + + with pytest.raises(ValueError, match="Invalid promote options: invalid"): + pa.concat_tables([t, t], promote_options="invalid") + + def test_concat_tables_none_table(): # ARROW-11997 with 
pytest.raises(AttributeError): @@ -1359,19 +1376,51 @@ def test_concat_tables_with_different_schema_metadata(): assert table2.schema.equals(table3.schema) +def test_concat_tables_with_promote_option(): + t1 = pa.Table.from_arrays( + [pa.array([1, 2], type=pa.int64())], ["int64_field"]) + t2 = pa.Table.from_arrays( + [pa.array([1.0, 2.0], type=pa.float32())], ["float_field"]) + + with pytest.warns(FutureWarning): + result = pa.concat_tables([t1, t2], promote=True) + + assert result.equals(pa.Table.from_arrays([ + pa.array([1, 2, None, None], type=pa.int64()), + pa.array([None, None, 1.0, 2.0], type=pa.float32()), + ], ["int64_field", "float_field"])) + + t1 = pa.Table.from_arrays( + [pa.array([1, 2], type=pa.int64())], ["f"]) + t2 = pa.Table.from_arrays( + [pa.array([1, 2], type=pa.float32())], ["f"]) + + with pytest.raises(pa.ArrowInvalid, match="Schema at index 1 was different:"): + with pytest.warns(FutureWarning): + pa.concat_tables([t1, t2], promote=False) + + def test_concat_tables_with_promotion(): t1 = pa.Table.from_arrays( [pa.array([1, 2], type=pa.int64())], ["int64_field"]) t2 = pa.Table.from_arrays( [pa.array([1.0, 2.0], type=pa.float32())], ["float_field"]) - result = pa.concat_tables([t1, t2], promote=True) + result = pa.concat_tables([t1, t2], promote_options="default") assert result.equals(pa.Table.from_arrays([ pa.array([1, 2, None, None], type=pa.int64()), pa.array([None, None, 1.0, 2.0], type=pa.float32()), ], ["int64_field", "float_field"])) + t3 = pa.Table.from_arrays( + [pa.array([1, 2], type=pa.int32())], ["int64_field"]) + result = pa.concat_tables( + [t1, t3], promote_options="permissive") + assert result.equals(pa.Table.from_arrays([ + pa.array([1, 2, 1, 2], type=pa.int64()), + ], ["int64_field"])) + def test_concat_tables_with_promotion_error(): t1 = pa.Table.from_arrays( @@ -1379,8 +1428,8 @@ def test_concat_tables_with_promotion_error(): t2 = pa.Table.from_arrays( [pa.array([1, 2], type=pa.float32())], ["f"]) - with 
pytest.raises(pa.ArrowInvalid): - pa.concat_tables([t1, t2], promote=True) + with pytest.raises(pa.ArrowTypeError, match="Unable to merge:"): + pa.concat_tables([t1, t2], promote_options="default") def test_table_negative_indexing(): diff --git a/python/pyarrow/types.pxi b/python/pyarrow/types.pxi index bd34726adb0db..764cb8e7b5d8b 100644 --- a/python/pyarrow/types.pxi +++ b/python/pyarrow/types.pxi @@ -3154,13 +3154,13 @@ cdef class Schema(_Weakrefable): return self.__str__() -def unify_schemas(schemas): +def unify_schemas(schemas, *, promote_options="default"): """ Unify schemas by merging fields by name. The resulting schema will contain the union of fields from all schemas. Fields with the same name will be merged. Note that two fields with - different types will fail merging. + different types will fail merging by default. - The unified field will inherit the metadata from the schema where that field is first defined. @@ -3174,6 +3174,10 @@ def unify_schemas(schemas): ---------- schemas : list of Schema Schemas to merge into a single one. + promote_options : str, default default + Accepts strings "default" and "permissive". + Default: null and only null can be unified with another type. + Permissive: types are promoted to the greater common denominator. 
Returns ------- @@ -3187,12 +3191,22 @@ def unify_schemas(schemas): """ cdef: Schema schema + CField.CMergeOptions c_options vector[shared_ptr[CSchema]] c_schemas for schema in schemas: if not isinstance(schema, Schema): raise TypeError("Expected Schema, got {}".format(type(schema))) c_schemas.push_back(pyarrow_unwrap_schema(schema)) - return pyarrow_wrap_schema(GetResultValue(UnifySchemas(c_schemas))) + + if promote_options == "default": + c_options = CField.CMergeOptions.Defaults() + elif promote_options == "permissive": + c_options = CField.CMergeOptions.Permissive() + else: + raise ValueError(f"Invalid merge mode: {promote_options}") + + return pyarrow_wrap_schema( + GetResultValue(UnifySchemas(c_schemas, c_options))) cdef dict _type_cache = {}