Skip to content

Commit

Permalink
GH-36845: [C++][Python] Allow type promotion on pa.concat_tables (#…
Browse files Browse the repository at this point in the history
…36846)

Revival of #12000

### Rationale for this change

It would be useful to be able to promote column types when concatenating tables, for example:

```python
def test_concat_tables_with_promotion_int():
    import pyarrow as pa
    t1 = pa.Table.from_arrays(
        [pa.array([1, 2], type=pa.int64())], ["int"])
    t2 = pa.Table.from_arrays(
        [pa.array([3, 4], type=pa.int32())], ["int"])

    result = pa.concat_tables([t1, t2], promote=True)

    assert result.equals(pa.Table.from_arrays([
        pa.array([1, 2, 3, 4], type=pa.int64())
    ], ["int"]))
```

### What changes are included in this PR?

### Are these changes tested?

### Are there any user-facing changes?

* Closes: #36845

Lead-authored-by: Fokko Driesprong <[email protected]>
Co-authored-by: David Li <[email protected]>
Co-authored-by: Joris Van den Bossche <[email protected]>
Signed-off-by: Joris Van den Bossche <[email protected]>
  • Loading branch information
3 people authored Oct 10, 2023
1 parent 0f7c569 commit 5f57219
Show file tree
Hide file tree
Showing 17 changed files with 1,220 additions and 71 deletions.
4 changes: 2 additions & 2 deletions c_glib/test/dataset/test-file-system-dataset-factory.rb
Original file line number Diff line number Diff line change
Expand Up @@ -109,9 +109,9 @@ def test_validate_fragments
point: Arrow::Int16DataType.new)
options.validate_fragments = true
message = "[file-system-dataset-factory][finish]: " +
"Invalid: Unable to merge: " +
"Type error: Unable to merge: " +
"Field point has incompatible types: int16 vs int32"
error = assert_raise(Arrow::Error::Invalid) do
error = assert_raise(Arrow::Error::Type) do
@factory.finish(options)
end
assert_equal(message, error.message.lines(chomp: true).first)
Expand Down
2 changes: 1 addition & 1 deletion cpp/src/arrow/dataset/discovery.cc
Original file line number Diff line number Diff line change
Expand Up @@ -47,7 +47,7 @@ Result<std::shared_ptr<Schema>> DatasetFactory::Inspect(InspectOptions options)
return arrow::schema({});
}

return UnifySchemas(schemas);
return UnifySchemas(schemas, options.field_merge_options);
}

Result<std::shared_ptr<Dataset>> DatasetFactory::Finish() {
Expand Down
4 changes: 4 additions & 0 deletions cpp/src/arrow/dataset/discovery.h
Original file line number Diff line number Diff line change
Expand Up @@ -58,6 +58,10 @@ struct InspectOptions {
/// `kInspectAllFragments`. A value of `0` disables inspection of fragments
/// altogether so only the partitioning schema will be inspected.
int fragments = 1;

/// Control how to unify types. By default, types are merged strictly (the
/// type must match exactly, except nulls can be merged with other types).
Field::MergeOptions field_merge_options = Field::MergeOptions::Defaults();
};

struct FinishOptions {
Expand Down
20 changes: 16 additions & 4 deletions cpp/src/arrow/dataset/discovery_test.cc
Original file line number Diff line number Diff line change
Expand Up @@ -117,9 +117,15 @@ TEST_F(MockDatasetFactoryTest, UnifySchemas) {

MakeFactory({schema({i32, f64}), schema({f64, i32_fake})});
// Unification fails when fields with the same name have clashing types.
ASSERT_RAISES(Invalid, factory_->Inspect());
ASSERT_RAISES(TypeError, factory_->Inspect());
// Return the individual schema for closer inspection should not fail.
AssertInspectSchemas({schema({i32, f64}), schema({f64, i32_fake})});

MakeFactory({schema({field("num", int32())}), schema({field("num", float64())})});
ASSERT_RAISES(TypeError, factory_->Inspect());
InspectOptions permissive_options;
permissive_options.field_merge_options = Field::MergeOptions::Permissive();
AssertInspect(schema({field("num", float64())}), permissive_options);
}

class FileSystemDatasetFactoryTest : public DatasetFactoryTest {
Expand Down Expand Up @@ -335,7 +341,7 @@ TEST_F(FileSystemDatasetFactoryTest, FinishWithIncompatibleSchemaShouldFail) {
ASSERT_OK_AND_ASSIGN(auto dataset, factory_->Finish(options));

MakeFactory({fs::File("test")});
ASSERT_RAISES(Invalid, factory_->Finish(options));
ASSERT_RAISES(TypeError, factory_->Finish(options));

// Disable validation
options.validate_fragments = false;
Expand Down Expand Up @@ -463,8 +469,8 @@ TEST(UnionDatasetFactoryTest, ConflictingSchemas) {
{dataset_factory_1, dataset_factory_2, dataset_factory_3}));

// schema_3 conflicts with other, Inspect/Finish should not work
ASSERT_RAISES(Invalid, factory->Inspect());
ASSERT_RAISES(Invalid, factory->Finish());
ASSERT_RAISES(TypeError, factory->Inspect());
ASSERT_RAISES(TypeError, factory->Finish());

// The user can inspect without error
ASSERT_OK_AND_ASSIGN(auto schemas, factory->InspectSchemas({}));
Expand All @@ -474,6 +480,12 @@ TEST(UnionDatasetFactoryTest, ConflictingSchemas) {
auto i32_schema = schema({i32});
ASSERT_OK_AND_ASSIGN(auto dataset, factory->Finish(i32_schema));
EXPECT_EQ(*dataset->schema(), *i32_schema);

// The user decided to allow merging the types.
FinishOptions options;
options.inspect_options.field_merge_options = Field::MergeOptions::Permissive();
ASSERT_OK_AND_ASSIGN(dataset, factory->Finish(options));
EXPECT_EQ(*dataset->schema(), *schema({f64, i32}));
}

} // namespace dataset
Expand Down
24 changes: 19 additions & 5 deletions cpp/src/arrow/table.cc
Original file line number Diff line number Diff line change
Expand Up @@ -30,6 +30,7 @@
#include "arrow/array/concatenate.h"
#include "arrow/array/util.h"
#include "arrow/chunked_array.h"
#include "arrow/compute/cast.h"
#include "arrow/pretty_print.h"
#include "arrow/record_batch.h"
#include "arrow/result.h"
Expand Down Expand Up @@ -450,6 +451,13 @@ Result<std::shared_ptr<Table>> ConcatenateTables(
// Promotes `table` to conform to `schema` without explicit cast options.
//
// Convenience overload: forwards to the overload that takes a
// compute::CastOptions, passing compute::CastOptions::Safe() together with
// the caller-supplied memory pool.
Result<std::shared_ptr<Table>> PromoteTableToSchema(const std::shared_ptr<Table>& table,
const std::shared_ptr<Schema>& schema,
MemoryPool* pool) {
return PromoteTableToSchema(table, schema, compute::CastOptions::Safe(), pool);
}

Result<std::shared_ptr<Table>> PromoteTableToSchema(const std::shared_ptr<Table>& table,
const std::shared_ptr<Schema>& schema,
const compute::CastOptions& options,
MemoryPool* pool) {
const std::shared_ptr<Schema> current_schema = table->schema();
if (current_schema->Equals(*schema, /*check_metadata=*/false)) {
return table->ReplaceSchemaMetadata(schema->metadata());
Expand Down Expand Up @@ -487,8 +495,8 @@ Result<std::shared_ptr<Table>> PromoteTableToSchema(const std::shared_ptr<Table>
const int field_index = field_indices[0];
const auto& current_field = current_schema->field(field_index);
if (!field->nullable() && current_field->nullable()) {
return Status::Invalid("Unable to promote field ", current_field->name(),
": it was nullable but the target schema was not.");
return Status::TypeError("Unable to promote field ", current_field->name(),
": it was nullable but the target schema was not.");
}

fields_seen[field_index] = true;
Expand All @@ -502,9 +510,15 @@ Result<std::shared_ptr<Table>> PromoteTableToSchema(const std::shared_ptr<Table>
continue;
}

return Status::Invalid("Unable to promote field ", field->name(),
": incompatible types: ", field->type()->ToString(), " vs ",
current_field->type()->ToString());
if (!compute::CanCast(*current_field->type(), *field->type())) {
return Status::TypeError("Unable to promote field ", field->name(),
": incompatible types: ", field->type()->ToString(),
" vs ", current_field->type()->ToString());
}
compute::ExecContext ctx(pool);
ARROW_ASSIGN_OR_RAISE(auto casted, compute::Cast(table->column(field_index),
field->type(), options, &ctx));
columns.push_back(casted.chunked_array());
}

auto unseen_field_iter = std::find(fields_seen.begin(), fields_seen.end(), false);
Expand Down
39 changes: 35 additions & 4 deletions cpp/src/arrow/table.h
Original file line number Diff line number Diff line change
Expand Up @@ -313,16 +313,23 @@ Result<std::shared_ptr<Table>> ConcatenateTables(
ConcatenateTablesOptions options = ConcatenateTablesOptions::Defaults(),
MemoryPool* memory_pool = default_memory_pool());

// Forward declaration: compute::CastOptions is only taken by const reference
// in the PromoteTableToSchema declaration below, so the full definition is
// not needed here.
namespace compute {
class CastOptions;
}

/// \brief Promotes a table to conform to the given schema.
///
/// If a field in the schema does not have a corresponding column in the
/// table, a column of nulls will be added to the resulting table.
/// If the corresponding column is of type Null, it will be promoted to
/// the type specified by schema, with null values filled.
/// If a field in the schema does not have a corresponding column in
/// the table, a column of nulls will be added to the resulting table.
/// If the corresponding column is of type Null, it will be promoted
/// to the type specified by schema, with null values filled. The
/// column will be casted to the type specified by the schema.
///
/// Returns an error:
/// - if the corresponding column's type is not compatible with the
/// schema.
/// - if there is a column in the table that does not exist in the schema.
/// - if the cast fails or casting would be required but is not available.
///
/// \param[in] table the input Table
/// \param[in] schema the target schema to promote to
Expand All @@ -333,4 +340,28 @@ Result<std::shared_ptr<Table>> PromoteTableToSchema(
const std::shared_ptr<Table>& table, const std::shared_ptr<Schema>& schema,
MemoryPool* pool = default_memory_pool());

/// \brief Promotes a table to conform to the given schema.
///
/// If a field in the schema does not have a corresponding column in
/// the table, a column of nulls will be added to the resulting table.
/// If the corresponding column is of type Null, it will be promoted
/// to the type specified by schema, with null values filled. Otherwise,
/// a column whose type differs from the schema's type will be cast to
/// the type specified by the schema, using the given cast options.
///
/// Returns an error:
/// - if the corresponding column's type is not compatible with the
///   schema.
/// - if a field is nullable in the table but not nullable in the schema.
/// - if there is a column in the table that does not exist in the schema.
/// - if the cast fails or casting would be required but is not available.
///
/// \param[in] table the input Table
/// \param[in] schema the target schema to promote to
/// \param[in] options The cast options applied when a column has to be
///   cast to the type specified by the schema
/// \param[in] pool The memory pool to be used if null-filled arrays need to
///   be created or columns need to be cast.
ARROW_EXPORT
Result<std::shared_ptr<Table>> PromoteTableToSchema(
const std::shared_ptr<Table>& table, const std::shared_ptr<Schema>& schema,
const compute::CastOptions& options, MemoryPool* pool = default_memory_pool());

} // namespace arrow
40 changes: 36 additions & 4 deletions cpp/src/arrow/table_test.cc
Original file line number Diff line number Diff line change
Expand Up @@ -29,6 +29,7 @@
#include "arrow/array/data.h"
#include "arrow/array/util.h"
#include "arrow/chunked_array.h"
#include "arrow/compute/cast.h"
#include "arrow/record_batch.h"
#include "arrow/status.h"
#include "arrow/testing/gtest_util.h"
Expand Down Expand Up @@ -418,16 +419,17 @@ TEST_F(TestPromoteTableToSchema, IncompatibleTypes) {
auto table = MakeTableWithOneNullFilledColumn("field", int32(), length);

// Invalid promotion: int32 to null.
ASSERT_RAISES(Invalid, PromoteTableToSchema(table, schema({field("field", null())})));
ASSERT_RAISES(TypeError, PromoteTableToSchema(table, schema({field("field", null())})));

// Invalid promotion: int32 to uint32.
ASSERT_RAISES(Invalid, PromoteTableToSchema(table, schema({field("field", uint32())})));
// Invalid promotion: int32 to list.
ASSERT_RAISES(TypeError,
PromoteTableToSchema(table, schema({field("field", list(int32()))})));
}

TEST_F(TestPromoteTableToSchema, IncompatibleNullity) {
const int length = 10;
auto table = MakeTableWithOneNullFilledColumn("field", int32(), length);
ASSERT_RAISES(Invalid,
ASSERT_RAISES(TypeError,
PromoteTableToSchema(
table, schema({field("field", uint32())->WithNullable(false)})));
}
Expand Down Expand Up @@ -520,6 +522,36 @@ TEST_F(ConcatenateTablesWithPromotionTest, Simple) {
AssertTablesEqualUnorderedFields(*expected, *result);
}

// Exercises ConcatenateTables on tables that share a single field name ("f0")
// but disagree on its type, under three progressively more permissive option
// settings.
TEST_F(ConcatenateTablesWithPromotionTest, Unify) {
// Inputs: the same field "f0" typed as int32, int64 and null respectively.
auto t_i32 = TableFromJSON(schema({field("f0", int32())}), {"[[0], [1]]"});
auto t_i64 = TableFromJSON(schema({field("f0", int64())}), {"[[2], [3]]"});
auto t_null = TableFromJSON(schema({field("f0", null())}), {"[[null], [null]]"});

// Expected results once the corresponding merge is permitted:
// int32 + int64 -> int64, and int32 + null -> int32 with nulls appended.
auto expected_int64 =
TableFromJSON(schema({field("f0", int64())}), {"[[0], [1], [2], [3]]"});
auto expected_null =
TableFromJSON(schema({field("f0", int32())}), {"[[0], [1], [null], [null]]"});

// 1) Default-constructed options: any schema mismatch is rejected as Invalid,
// including the int32/null pair.
ConcatenateTablesOptions options;
EXPECT_RAISES_WITH_MESSAGE_THAT(Invalid,
::testing::HasSubstr("Schema at index 1 was different"),
ConcatenateTables({t_i32, t_i64}, options));
EXPECT_RAISES_WITH_MESSAGE_THAT(Invalid,
::testing::HasSubstr("Schema at index 1 was different"),
ConcatenateTables({t_i32, t_null}, options));

// 2) Schema unification enabled, field merge options left untouched:
// int64 vs int32 is still a TypeError, but null merges with int32.
options.unify_schemas = true;
EXPECT_RAISES_WITH_MESSAGE_THAT(TypeError,
::testing::HasSubstr("Field f0 has incompatible types"),
ConcatenateTables({t_i64, t_i32}, options));
ASSERT_OK_AND_ASSIGN(auto actual, ConcatenateTables({t_i32, t_null}, options));
AssertTablesEqual(*expected_null, *actual, /*same_chunk_layout=*/false);

// 3) Numeric width promotion additionally enabled: int32 and int64 now
// concatenate into a single int64 column.
options.field_merge_options.promote_numeric_width = true;
ASSERT_OK_AND_ASSIGN(actual, ConcatenateTables({t_i32, t_i64}, options));
AssertTablesEqual(*expected_int64, *actual, /*same_chunk_layout=*/false);
}

TEST_F(TestTable, Slice) {
const int64_t length = 10;

Expand Down
Loading

0 comments on commit 5f57219

Please sign in to comment.