From cdea1af0ac304f98a23c0d3a71f3e85e602f6887 Mon Sep 17 00:00:00 2001 From: zhixingheyi-tian Date: Wed, 12 Jan 2022 15:05:08 +0800 Subject: [PATCH] [NSE-602] Fix Array type shuffle split segmentation fault (#623) * [NSE-602] Fix Array type shuffle split segmentation fault * Fix clang code format --- .../ColumnarShuffleExchangeExec.scala | 2 +- native-sql-engine/cpp/src/shuffle/splitter.cc | 8 ++-- .../cpp/src/tests/shuffle_split_test.cc | 45 +++++++++++++++++++ 3 files changed, 51 insertions(+), 4 deletions(-) diff --git a/native-sql-engine/core/src/main/scala/org/apache/spark/sql/execution/ColumnarShuffleExchangeExec.scala b/native-sql-engine/core/src/main/scala/org/apache/spark/sql/execution/ColumnarShuffleExchangeExec.scala index 75bdf371b..57ebc0a60 100644 --- a/native-sql-engine/core/src/main/scala/org/apache/spark/sql/execution/ColumnarShuffleExchangeExec.scala +++ b/native-sql-engine/core/src/main/scala/org/apache/spark/sql/execution/ColumnarShuffleExchangeExec.scala @@ -86,7 +86,7 @@ case class ColumnarShuffleExchangeExec( // check input datatype for (attr <- child.output) { try { - ConverterUtils.checkIfTypeSupported(attr.dataType) + ConverterUtils.createArrowField(attr) } catch { case e: UnsupportedOperationException => throw new UnsupportedOperationException( diff --git a/native-sql-engine/cpp/src/shuffle/splitter.cc b/native-sql-engine/cpp/src/shuffle/splitter.cc index a640ee958..40fec7d13 100644 --- a/native-sql-engine/cpp/src/shuffle/splitter.cc +++ b/native-sql-engine/cpp/src/shuffle/splitter.cc @@ -1315,9 +1315,11 @@ arrow::Status Splitter::AppendList( using ValueBuilderType = typename arrow::TypeTraits::BuilderType; using ValueArrayType = typename arrow::TypeTraits::ArrayType; std::vector dst_values_builders; - for (auto builder : dst_builders) { - dst_values_builders.push_back( - checked_cast(builder->value_builder())); + dst_values_builders.resize(dst_builders.size()); + for (auto i = 0; i < dst_builders.size(); ++i) { + if (dst_builders[i] != nullptr) + dst_values_builders[i] = + checked_cast(dst_builders[i]->value_builder()); } auto src_arr_values = std::dynamic_pointer_cast(src_arr->values()); diff --git a/native-sql-engine/cpp/src/tests/shuffle_split_test.cc b/native-sql-engine/cpp/src/tests/shuffle_split_test.cc index d58152b36..4fefdbc7b 100644 --- a/native-sql-engine/cpp/src/tests/shuffle_split_test.cc +++ b/native-sql-engine/cpp/src/tests/shuffle_split_test.cc @@ -525,6 +525,51 @@ TEST_F(SplitterTest, TestRoundRobinListArraySplitter) { } } +TEST_F(SplitterTest, TestHashListArraySplitterWithMorePartitions) { + int32_t num_partitions = 5; + split_options_.buffer_size = 4; + + auto f_uint64 = field("f_uint64", arrow::uint64()); + auto f_arr_str = field("f_arr", arrow::list(arrow::utf8())); + + auto rb_schema = arrow::schema({f_uint64, f_arr_str}); + + const std::vector input_batch_1_data = { + R"([1, 2])", R"([["alice0", "bob1"], ["alice2"]])"}; + std::shared_ptr input_batch_arr; + MakeInputBatch(input_batch_1_data, rb_schema, &input_batch_arr); + + auto f_2 = TreeExprBuilder::MakeField(f_uint64); + auto expr_1 = TreeExprBuilder::MakeExpression(f_2, field("f_uint64", uint64())); + + ARROW_ASSIGN_OR_THROW(splitter_, Splitter::Make("hash", rb_schema, num_partitions, + {expr_1}, split_options_)); + + ASSERT_NOT_OK(splitter_->Split(*input_batch_arr)); + + ASSERT_NOT_OK(splitter_->Stop()); + + const auto& lengths = splitter_->PartitionLengths(); + ASSERT_EQ(lengths.size(), 5); + + CheckFileExsists(splitter_->DataFile()); + + std::shared_ptr file_reader; + ARROW_ASSIGN_OR_THROW(file_reader, GetRecordBatchStreamReader(splitter_->DataFile())); + + ASSERT_EQ(*file_reader->schema(), *rb_schema); + + std::vector> batches; + ASSERT_NOT_OK(file_reader->ReadAll(&batches)); + + for (const auto& rb : batches) { + ASSERT_EQ(rb->num_columns(), rb_schema->num_fields()); + for (auto i = 0; i < rb->num_columns(); ++i) { + ASSERT_EQ(rb->column(i)->length(), rb->num_rows()); + } + } +} + TEST_F(SplitterTest, TestRoundRobinListArraySplitterwithCompression) { auto f_arr_str = field("f_arr", arrow::list(arrow::utf8())); auto f_arr_bool = field("f_bool", arrow::list(arrow::boolean()));