From f5f0f334cb0677929dff7c19275535f675245a55 Mon Sep 17 00:00:00 2001 From: mwish Date: Mon, 2 Sep 2024 23:44:28 +0800 Subject: [PATCH] Compute: ListKeyEncoder interfaces --- .../arrow/compute/row/row_encoder_internal.cc | 107 +++++++++++------- .../arrow/compute/row/row_encoder_internal.h | 33 +++++- 2 files changed, 98 insertions(+), 42 deletions(-) diff --git a/cpp/src/arrow/compute/row/row_encoder_internal.cc b/cpp/src/arrow/compute/row/row_encoder_internal.cc index 0965e4e8f9571..56076f8940819 100644 --- a/cpp/src/arrow/compute/row/row_encoder_internal.cc +++ b/cpp/src/arrow/compute/row/row_encoder_internal.cc @@ -29,6 +29,51 @@ using internal::FirstTimeBitmapWriter; namespace compute { namespace internal { +Result> MakeKeyEncoder(const TypeHolder& column_type, std::shared_ptr* extension_type, MemoryPool* pool) { + const bool is_extension = column_type.id() == Type::EXTENSION; + const TypeHolder& type = + is_extension + ? arrow::internal::checked_cast(column_type.type) + ->storage_type() + : column_type; + + if (is_extension) { + *extension_type = arrow::internal::checked_pointer_cast( + column_type.GetSharedPtr()); + } + if (type.id() == Type::BOOL) { + return std::make_shared(); + } + + if (type.id() == Type::DICTIONARY) { + return std::make_shared(type.GetSharedPtr(), pool); + } + + if (is_fixed_width(type.id())) { + return std::make_shared(type.GetSharedPtr()); + } + + if (is_binary_like(type.id())) { + return std::make_shared>(type.GetSharedPtr()); + } + + if (is_large_binary_like(type.id())) { + return std::make_shared>(type.GetSharedPtr()); + } + + if (is_list(type.id())) { + auto element_type = ::arrow::checked_cast(type.type)->value_type(); + if (is_nested(element_type->id())) { + return Status::NotImplemented("Unsupported nested type in List for row encoder", type.ToString()); + } + std::shared_ptr element_extension_type; + ARROW_ASSIGN_OR_RAISE(auto element_encoder, MakeKeyEncoder(element_type, &element_extension_type, pool)); + return std::make_shared(std::move(element_type), std::move(element_encoder)); + } + + return Status::NotImplemented("Unsupported type for row encoder", type.ToString()); +} + // extract the null bitmap from the leading nullity bytes of encoded keys Status KeyEncoder::DecodeNulls(MemoryPool* pool, int32_t length, uint8_t** encoded_bytes, std::shared_ptr* null_bitmap, @@ -256,53 +301,32 @@ Result> DictionaryKeyEncoder::Decode(uint8_t** encode return data; } -void RowEncoder::Init(const std::vector& column_types, ExecContext* ctx) { - ctx_ = ctx; - encoders_.resize(column_types.size()); - extension_types_.resize(column_types.size()); +ListKeyEncoder::ListKeyEncoder(std::shared_ptr element_type, std::shared_ptr element_encoder) + : element_type_(std::move(element_type)), element_encoder_(std::move(element_encoder)) {} - for (size_t i = 0; i < column_types.size(); ++i) { - const bool is_extension = column_types[i].id() == Type::EXTENSION; - const TypeHolder& type = - is_extension - ? arrow::internal::checked_cast(column_types[i].type) - ->storage_type() - : column_types[i]; - - if (is_extension) { - extension_types_[i] = arrow::internal::checked_pointer_cast( - column_types[i].GetSharedPtr()); - } - if (type.id() == Type::BOOL) { - encoders_[i] = std::make_shared(); - continue; - } +void ListKeyEncoder::AddLength(const ExecValue& exec_value, int64_t batch_length, int32_t* lengths) {} - if (type.id() == Type::DICTIONARY) { - encoders_[i] = - std::make_shared(type.GetSharedPtr(), ctx->memory_pool()); - continue; - } +void ListKeyEncoder::AddLengthNull(int32_t* length) {} - if (is_fixed_width(type.id())) { - encoders_[i] = std::make_shared(type.GetSharedPtr()); - continue; - } +Status ListKeyEncoder::Encode(const ExecValue& data, int64_t batch_length, + uint8_t** encoded_bytes) { + return Status::NotImplemented("ListKeyEncoder::Encode"); +} - if (is_binary_like(type.id())) { - encoders_[i] = - std::make_shared>(type.GetSharedPtr()); - continue; - } +void ListKeyEncoder::EncodeNull(uint8_t** encoded_bytes) {} - if (is_large_binary_like(type.id())) { - encoders_[i] = - std::make_shared>(type.GetSharedPtr()); - continue; - } +Result> ListKeyEncoder::Decode(uint8_t** encoded_bytes, int32_t length, + MemoryPool* pool) { + return std::shared_ptr(nullptr); +} + +Status RowEncoder::Init(const std::vector& column_types, ExecContext* ctx) { + ctx_ = ctx; + encoders_.resize(column_types.size()); + extension_types_.resize(column_types.size()); - // We should not get here - ARROW_DCHECK(false); + for (size_t i = 0; i < column_types.size(); ++i) { + ARROW_ASSIGN_OR_RAISE(encoders_[i], MakeKeyEncoder(column_types[i], &extension_types_[i], ctx->memory_pool())); } int32_t total_length = 0; @@ -314,6 +338,7 @@ void RowEncoder::Init(const std::vector& column_types, ExecContext* for (size_t i = 0; i < column_types.size(); ++i) { encoders_[i]->EncodeNull(&buf_ptr); } + return Status::OK(); } void RowEncoder::Clear() { diff --git a/cpp/src/arrow/compute/row/row_encoder_internal.h b/cpp/src/arrow/compute/row/row_encoder_internal.h index 4d6cc34af2342..9b9256befc8eb 100644 --- a/cpp/src/arrow/compute/row/row_encoder_internal.h +++ b/cpp/src/arrow/compute/row/row_encoder_internal.h @@ -269,6 +269,28 @@ struct ARROW_EXPORT NullKeyEncoder : KeyEncoder { } }; +struct ARROW_EXPORT ListKeyEncoder : KeyEncoder { + explicit ListKeyEncoder(std::shared_ptr element_type, std::shared_ptr element_encoder); + + void AddLength(const ExecValue&, int64_t batch_length, int32_t* lengths) override; + + void AddLengthNull(int32_t* length) override; + + Status Encode(const ExecValue& data, int64_t batch_length, + uint8_t** encoded_bytes) override; + + void EncodeNull(uint8_t** encoded_bytes) override; + + Result> Decode(uint8_t** encoded_bytes, int32_t length, + MemoryPool* pool) override; + + std::shared_ptr element_type_; + std::shared_ptr element_encoder_; + // extension_type_ is used to store the extension type of the list element. + // It would be nullptr if the list element is not an extension type. + std::shared_ptr extension_type_; +}; + /// RowEncoder encodes ExecSpan to a variable length byte sequence /// created by concatenating the encoded form of each column. The encoding /// for each column depends on its data type. @@ -328,6 +350,15 @@ struct ARROW_EXPORT NullKeyEncoder : KeyEncoder { /// Null string Would be encoded as: /// 1 ( 1 byte for null) + 0 ( 4 bytes for length ) /// +/// ## List Type +/// +/// List Type is encoded as: +/// [null byte, list element count, [element 1, element 2, ...]] +/// Element count uses 4 bytes. +/// +/// Currently, we only support encoding of primitive types, dictionary types +/// in the list, the nested list is not supported. +/// /// # Row Encoding /// /// The row format is the concatenation of the encodings of each column. @@ -335,7 +366,7 @@ class ARROW_EXPORT RowEncoder { public: static constexpr int kRowIdForNulls() { return -1; } - void Init(const std::vector& column_types, ExecContext* ctx); + Status Init(const std::vector& column_types, ExecContext* ctx); void Clear(); Status EncodeAndAppend(const ExecSpan& batch); Result Decode(int64_t num_rows, const int32_t* row_ids);