diff --git a/cpp/src/arrow/compute/row/row_encoder_internal.cc b/cpp/src/arrow/compute/row/row_encoder_internal.cc index 186502dbc2366..fb58333b5ac0c 100644 --- a/cpp/src/arrow/compute/row/row_encoder_internal.cc +++ b/cpp/src/arrow/compute/row/row_encoder_internal.cc @@ -77,12 +77,12 @@ Result> MakeKeyEncoder( ARROW_ASSIGN_OR_RAISE(auto element_encoder, MakeKeyEncoder(element_type, &element_extension_type, pool)); if (type.id() == Type::LIST) { - return std::make_shared>(std::move(element_type), - std::move(element_encoder)); + return std::make_shared>( + type.type->GetSharedPtr(), std::move(element_type), std::move(element_encoder)); } ARROW_CHECK(type.id() == Type::LARGE_LIST); - return std::make_shared>(std::move(element_type), - std::move(element_encoder)); + return std::make_shared>( + type.type->GetSharedPtr(), std::move(element_type), std::move(element_encoder)); } return Status::NotImplemented("Unsupported type for row encoder", type.ToString()); diff --git a/cpp/src/arrow/compute/row/row_encoder_internal.h b/cpp/src/arrow/compute/row/row_encoder_internal.h index c7918e6853d6e..9632aabf817f2 100644 --- a/cpp/src/arrow/compute/row/row_encoder_internal.h +++ b/cpp/src/arrow/compute/row/row_encoder_internal.h @@ -18,10 +18,13 @@ #pragma once #include +#include #include "arrow/compute/kernels/codegen_internal.h" #include "arrow/visit_data_inline.h" +#include + namespace arrow { using internal::checked_cast; @@ -271,11 +274,14 @@ struct ARROW_EXPORT NullKeyEncoder : KeyEncoder { template struct ARROW_EXPORT ListKeyEncoder : KeyEncoder { + static_assert(is_list_like_type(), "ListKeyEncoder only supports ListType"); using Offset = typename ListType::offset_type; - ListKeyEncoder(std::shared_ptr element_type, + ListKeyEncoder(std::shared_ptr self_type, + std::shared_ptr element_type, std::shared_ptr element_encoder) - : element_type_(std::move(element_type)), + : self_type_(std::move(self_type)), + element_type_(std::move(element_type)), element_encoder_(std::move(element_encoder)) {} void AddLength(const ExecValue& data, int64_t batch_length, int32_t* lengths) override { @@ -283,6 +289,7 @@ struct ARROW_EXPORT ListKeyEncoder : KeyEncoder { ARROW_DCHECK_EQ(data.array.length, batch_length); const uint8_t* validity = data.array.buffers[0].data; const auto* offsets = data.array.GetValues(1); + // AddLength for each list std::vector child_lengthes; int32_t index{0}; ArraySpan tmp_child_data(data.array.child_data[0]); @@ -292,12 +299,17 @@ struct ARROW_EXPORT ListKeyEncoder : KeyEncoder { ARROW_UNUSED(i); child_lengthes.clear(); Offset list_length = offsets[i + 1] - offsets[i]; + if (list_length == 0) { + lengths[index] += kExtraByteForNull + sizeof(Offset); + ++index; + return; + } child_lengthes.resize(list_length, 0); tmp_child_data.SetSlice(offsets[i], list_length); - this->element_encoder_->AddLength(ExecValue{tmp_child_data}, batch_length, + this->element_encoder_->AddLength(ExecValue{tmp_child_data}, list_length, child_lengthes.data()); lengths[index] += kExtraByteForNull + sizeof(Offset); - for (int32_t j = 0; j < batch_length; j++) { + for (int32_t j = 0; j < list_length; j++) { lengths[index] += child_lengthes[j]; } ++index; @@ -309,7 +321,8 @@ struct ARROW_EXPORT ListKeyEncoder : KeyEncoder { } else { const auto& list_scalar = data.scalar_as(); int32_t accum_length = 0; - if (list_scalar.is_valid) { + // Counting the size of the encoded list if the list is valid + if (list_scalar.is_valid && list_scalar.value->length() > 0) { auto element_count = static_cast(list_scalar.value->length()); // Counting the size of the encoded list std::vector child_lengthes(element_count, 0); @@ -344,8 +357,13 @@ struct ARROW_EXPORT ListKeyEncoder : KeyEncoder { util::SafeStore(encoded_ptr, static_cast(child_array.length)); encoded_ptr += sizeof(Offset); // handling the child data - return element_encoder_->Encode(ExecValue{child_array}, child_array.length, - encoded_bytes); + for (int64_t i = 0; i < child_array.length; i++) { + ArraySpan tmp_child_data(child_array); + tmp_child_data.SetSlice(child_array.offset + i, 1); + RETURN_NOT_OK( + this->element_encoder_->Encode(ExecValue{tmp_child_data}, 1, &encoded_ptr)); + } + return Status::OK(); }; if (data.is_array()) { ARROW_DCHECK_EQ(data.array.length, batch_length); @@ -366,8 +384,8 @@ struct ARROW_EXPORT ListKeyEncoder : KeyEncoder { })); } else { const auto& list_scalar = data.scalar_as(); - ArraySpan span(*list_scalar.value->data()); if (list_scalar.is_valid) { + ArraySpan span(*list_scalar.value->data()); for (int64_t i = 0; i < batch_length; i++) { RETURN_NOT_OK(handle_valid_value(span)); } @@ -389,9 +407,48 @@ struct ARROW_EXPORT ListKeyEncoder : KeyEncoder { Result> Decode(uint8_t** encoded_bytes, int32_t length, MemoryPool* pool) override { - return Status::NotImplemented("ListKeyEncoder::Decode"); + std::shared_ptr null_buf; + int32_t null_count; + ARROW_RETURN_NOT_OK(DecodeNulls(pool, length, encoded_bytes, &null_buf, &null_count)); + + // Build the offsets buffer of the ListArray + ARROW_ASSIGN_OR_RAISE(auto offset_buf, + AllocateBuffer(sizeof(Offset) * (1 + length), pool)); + auto raw_offsets = offset_buf->mutable_span_as(); + Offset element_sum = 0; + raw_offsets[0] = 0; + std::vector> child_datas; + for (int32_t i = 0; i < length; ++i) { + Offset element_count = util::SafeLoadAs(encoded_bytes[i]); + element_sum += element_count; + raw_offsets[i + 1] = element_sum; + encoded_bytes[i] += sizeof(Offset); + for (Offset j = 0; j < element_count; ++j) { + ARROW_ASSIGN_OR_RAISE( + auto child_data, + element_encoder_->Decode(encoded_bytes + i, /*length=*/1, pool)); + ArraySpan array_span(*child_data); + child_datas.push_back(array_span.ToArray()); + } + } + std::shared_ptr element_data; + if (!child_datas.empty()) { + ARROW_ASSIGN_OR_RAISE(auto element_array, ::arrow::Concatenate(child_datas, pool)); + element_data = element_array->data(); + } else { + // If there are no elements, we need to create an empty array + std::unique_ptr tmp; + RETURN_NOT_OK(MakeBuilder(pool, element_type_, &tmp)); + std::shared_ptr array; + RETURN_NOT_OK(tmp->Finish(&array)); + element_data = array->data(); + } + return ArrayData::Make(self_type_, length, + {std::move(null_buf), std::move(offset_buf)}, {element_data}, + null_count); } + std::shared_ptr self_type_; std::shared_ptr element_type_; std::shared_ptr element_encoder_; // extension_type_ is used to store the extension type of the list element. @@ -466,7 +523,7 @@ struct ARROW_EXPORT ListKeyEncoder : KeyEncoder { /// /// List Type is encoded as: /// [null byte, list element count, [element 1, element 2, ...]] -/// Element count uses 4 bytes. +/// Element count uses `Offset` bytes. /// /// Currently, we only support encoding of primitive types, dictionary types /// in the list, the nested list is not supported. diff --git a/cpp/src/arrow/compute/row/row_encoder_internal_test.cc b/cpp/src/arrow/compute/row/row_encoder_internal_test.cc index 78839d1ead557..dbe9d7b4b532a 100644 --- a/cpp/src/arrow/compute/row/row_encoder_internal_test.cc +++ b/cpp/src/arrow/compute/row/row_encoder_internal_test.cc @@ -20,6 +20,8 @@ #include "arrow/compute/row/row_encoder_internal.h" +#include "arrow/array/builder_nested.h" +#include "arrow/array/builder_primitive.h" #include "arrow/array/validate.h" #include "arrow/testing/gtest_util.h" #include "arrow/type.h" @@ -65,4 +67,230 @@ TEST(TestKeyEncoder, BooleanScalar) { } } +TEST(TestKeyEncoder, ListScalar) { + // Test with a list of int32_t. + { + // Handle non null values. + auto element_encoder = std::make_shared(::arrow::int32()); + ListKeyEncoder key_encoder{::arrow::list(::arrow::int32()), + ::arrow::int32(), element_encoder}; + auto element_array = ::arrow::ArrayFromJSON(::arrow::int32(), "[1, 2, null, 4, 5]"); + ListScalar scalar{element_array}; + constexpr int64_t kBatchLength = 10; + std::vector lengths(kBatchLength); + key_encoder.AddLength(ExecValue{&scalar}, kBatchLength, lengths.data()); + // Check that the lengths are all 5 + 5 * 5 = 30. + constexpr int64_t kPayloadWidth = 30; + for (int i = 0; i < kBatchLength; ++i) { + ASSERT_EQ(kPayloadWidth, lengths[i]) << i; + } + std::array, kBatchLength> payloads{}; + std::array payload_ptrs{}; + // Reset the payload pointers to point to the beginning of each payload. + // This is necessary because the key encoder may have modified the pointers. + auto reset_payload_ptrs = [&payload_ptrs, &payloads]() { + std::transform(payloads.begin(), payloads.end(), payload_ptrs.begin(), + [](auto& payload) -> uint8_t* { return payload.data(); }); + }; + reset_payload_ptrs(); + ASSERT_OK(key_encoder.Encode(ExecValue{&scalar}, kBatchLength, payload_ptrs.data())); + reset_payload_ptrs(); + ASSERT_OK_AND_ASSIGN(auto array_data, + key_encoder.Decode(payload_ptrs.data(), kBatchLength, + ::arrow::default_memory_pool())); + auto list_array = std::make_shared(array_data); + ASSERT_OK(arrow::internal::ValidateArrayFull(*array_data)); + auto element_builder = std::make_shared(); + ::arrow::ListBuilder builder(default_memory_pool(), element_builder); + for (int i = 0; i < 10; ++i) { + ASSERT_OK(builder.Append()); + ASSERT_OK(element_builder->Append(1)); + ASSERT_OK(element_builder->Append(2)); + ASSERT_OK(element_builder->AppendNull()); + ASSERT_OK(element_builder->Append(4)); + ASSERT_OK(element_builder->Append(5)); + } + std::shared_ptr expected_array; + ASSERT_OK(builder.Finish(&expected_array)); + + // Expect the list array to be equal to the expected array. + AssertArraysEqual(*expected_array, *list_array); + } + { + // Handle non null values. + auto element_encoder = std::make_shared(::arrow::int32()); + ListKeyEncoder key_encoder{::arrow::list(::arrow::int32()), + ::arrow::int32(), element_encoder}; + auto element_array = ::arrow::ArrayFromJSON(::arrow::int32(), "[1, 2, null, 4, 5]"); + ListScalar scalar{element_array, /*is_valid=*/false}; + constexpr int64_t kBatchLength = 10; + std::vector lengths(kBatchLength); + key_encoder.AddLength(ExecValue{&scalar}, kBatchLength, lengths.data()); + // Check that the lengths are all 5. + constexpr int64_t kPayloadWidth = 5; + for (int i = 0; i < kBatchLength; ++i) { + ASSERT_EQ(kPayloadWidth, lengths[i]) << i; + } + std::array, kBatchLength> payloads{}; + std::array payload_ptrs{}; + // Reset the payload pointers to point to the beginning of each payload. + // This is necessary because the key encoder may have modified the pointers. + auto reset_payload_ptrs = [&payload_ptrs, &payloads]() { + std::transform(payloads.begin(), payloads.end(), payload_ptrs.begin(), + [](auto& payload) -> uint8_t* { return payload.data(); }); + }; + reset_payload_ptrs(); + ASSERT_OK(key_encoder.Encode(ExecValue{&scalar}, kBatchLength, payload_ptrs.data())); + reset_payload_ptrs(); + ASSERT_OK_AND_ASSIGN(auto array_data, + key_encoder.Decode(payload_ptrs.data(), kBatchLength, + ::arrow::default_memory_pool())); + auto list_array = std::make_shared(array_data); + ASSERT_OK(arrow::internal::ValidateArrayFull(*array_data)); + auto element_builder = std::make_shared(); + ::arrow::ListBuilder builder(default_memory_pool(), element_builder); + for (int i = 0; i < 10; ++i) { + ASSERT_OK(builder.AppendNull()); + } + std::shared_ptr expected_array; + ASSERT_OK(builder.Finish(&expected_array)); + + // Expect the list array to be equal to the expected array. + AssertArraysEqual(*expected_array, *list_array); + } + { + // Handle non null values. + auto element_encoder = + std::make_shared>(::arrow::utf8()); + ListKeyEncoder key_encoder{::arrow::list(::arrow::utf8()), ::arrow::utf8(), + element_encoder}; + auto element_array = ::arrow::ArrayFromJSON(::arrow::utf8(), R"(["a", "bcd", null])"); + ListScalar scalar{element_array}; + constexpr int64_t kBatchLength = 10; + std::vector lengths(kBatchLength); + key_encoder.AddLength(ExecValue{&scalar}, kBatchLength, lengths.data()); + // Check that the lengths are all 5 + 5 * 3 + 4 = 24. + constexpr int64_t kPayloadWidth = 24; + for (int i = 0; i < kBatchLength; ++i) { + ASSERT_EQ(kPayloadWidth, lengths[i]) << i; + } + std::array, kBatchLength> payloads{}; + std::array payload_ptrs{}; + // Reset the payload pointers to point to the beginning of each payload. + // This is necessary because the key encoder may have modified the pointers. + auto reset_payload_ptrs = [&payload_ptrs, &payloads]() { + std::transform(payloads.begin(), payloads.end(), payload_ptrs.begin(), + [](auto& payload) -> uint8_t* { return payload.data(); }); + }; + reset_payload_ptrs(); + ASSERT_OK(key_encoder.Encode(ExecValue{&scalar}, kBatchLength, payload_ptrs.data())); + reset_payload_ptrs(); + ASSERT_OK_AND_ASSIGN(auto array_data, + key_encoder.Decode(payload_ptrs.data(), kBatchLength, + ::arrow::default_memory_pool())); + auto list_array = std::make_shared(array_data); + ASSERT_OK(arrow::internal::ValidateArrayFull(*array_data)); + auto element_builder = std::make_shared(); + ::arrow::ListBuilder builder(default_memory_pool(), element_builder); + for (int i = 0; i < 10; ++i) { + ASSERT_OK(builder.Append()); + ASSERT_OK(element_builder->Append("a")); + ASSERT_OK(element_builder->Append("bcd")); + ASSERT_OK(element_builder->AppendNull()); + } + std::shared_ptr expected_array; + ASSERT_OK(builder.Finish(&expected_array)); + + // Expect the list array to be equal to the expected array. + AssertArraysEqual(*expected_array, *list_array); + } +} + +TEST(TestKeyEncoder, ListArray) { + auto element_type = ::arrow::int32(); + auto list_type = ::arrow::list(element_type); + auto list_array = + ::arrow::ArrayFromJSON(list_type, "[[], [1, 2], [3], null, [4, 5, 6, 7]]"); + auto element_encoder = std::make_shared(element_type); + ListKeyEncoder key_encoder{list_type, element_type, element_encoder}; + std::vector lengths(list_array->length(), 0); + // Add the lengths of the list array. + key_encoder.AddLength(ExecValue{*list_array->data()}, list_array->length(), + lengths.data()); + std::vector> payloads(list_array->length()); + for (int i = 0; i < list_array->length(); ++i) { + payloads[i].resize(lengths[i]); + } + std::vector payload_ptrs(list_array->length()); + auto reset_payload_ptrs = [&payload_ptrs, &payloads]() { + std::transform(payloads.begin(), payloads.end(), payload_ptrs.begin(), + [](auto& payload) -> uint8_t* { return payload.data(); }); + }; + reset_payload_ptrs(); + ASSERT_OK(key_encoder.Encode(ExecValue{*list_array->data()}, list_array->length(), + payload_ptrs.data())); + reset_payload_ptrs(); + ASSERT_OK_AND_ASSIGN( + auto array_data, + key_encoder.Decode(payload_ptrs.data(), static_cast(list_array->length()), + ::arrow::default_memory_pool())); + auto list_array_decoded = std::make_shared(array_data); + ASSERT_OK(arrow::internal::ValidateArrayFull(*array_data)); + // check that the decoded list array is equal to the original list array. + AssertArraysEqual(*list_array, *list_array_decoded); +} + +TEST(TestRowEncoder, SupportedTypes) { + ExecContext context; + for (const auto& fixed_sized_types : + {::arrow::int8(), ::arrow::int16(), ::arrow::int32(), ::arrow::int64(), + ::arrow::uint8(), ::arrow::uint16(), ::arrow::uint32(), ::arrow::uint64(), + ::arrow::float32(), ::arrow::float64(), ::arrow::fixed_size_binary(10)}) { + RowEncoder encoder; + ASSERT_OK(encoder.Init({fixed_sized_types}, &context)); + } + for (const auto& var_len_binary_type : {::arrow::binary(), ::arrow::large_binary(), + ::arrow::utf8(), ::arrow::large_utf8()}) { + RowEncoder encoder; + ASSERT_OK(encoder.Init({var_len_binary_type}, &context)); + } + + for (const auto& dictionary_type : + {::arrow::dictionary(::arrow::int8(), ::arrow::utf8()), + ::arrow::dictionary(::arrow::int8(), ::arrow::int8())}) { + RowEncoder encoder; + ASSERT_OK(encoder.Init({dictionary_type}, &context)); + } +} + +TEST(TestRowEncoder, UnsupportedTypes) { + ExecContext context; + for (const auto& binary_view_type : {utf8_view(), binary_view()}) { + RowEncoder encoder; + ASSERT_NOT_OK(encoder.Init({binary_view_type}, &context)); + } + for (const auto& list_type : + {list_view(::arrow::int8()), large_list_view(::arrow::int8()), + fixed_size_list(::arrow::int8(), 10)}) { + RowEncoder encoder; + ASSERT_NOT_OK(encoder.Init({list_type}, &context)); + } + { + RowEncoder encoder; + auto struct_type = struct_({field("a", ::arrow::int8())}); + ASSERT_NOT_OK(encoder.Init({struct_type}, &context)); + } + { + RowEncoder encoder; + auto map_type = map(::arrow::int8(), ::arrow::int8()); + ASSERT_NOT_OK(encoder.Init({map_type}, &context)); + } + // Nested list type is unsupported currently + { + RowEncoder encoder; + auto nested_list_type = list(list(::arrow::int8())); + ASSERT_NOT_OK(encoder.Init({nested_list_type}, &context)); + } +} + } // namespace arrow::compute::internal