From 222909a1b979b7f1a463a5611af0999ad3f40bb4 Mon Sep 17 00:00:00 2001 From: mwish Date: Sat, 7 Sep 2024 00:22:57 +0800 Subject: [PATCH] basic impl --- .../arrow/compute/row/row_encoder_internal.h | 77 ++++++++++++++++++- 1 file changed, 73 insertions(+), 4 deletions(-) diff --git a/cpp/src/arrow/compute/row/row_encoder_internal.h b/cpp/src/arrow/compute/row/row_encoder_internal.h index cfc6d221c8ef9..c7918e6853d6e 100644 --- a/cpp/src/arrow/compute/row/row_encoder_internal.h +++ b/cpp/src/arrow/compute/row/row_encoder_internal.h @@ -280,11 +280,34 @@ struct ARROW_EXPORT ListKeyEncoder : KeyEncoder { void AddLength(const ExecValue& data, int64_t batch_length, int32_t* lengths) override { if (data.is_array()) { - int64_t i = 0; ARROW_DCHECK_EQ(data.array.length, batch_length); - // TODO(mwish): implement me + const uint8_t* validity = data.array.buffers[0].data; + const auto* offsets = data.array.GetValues(1); + std::vector child_lengthes; + int32_t index{0}; + ArraySpan tmp_child_data(data.array.child_data[0]); + VisitBitBlocksVoid( + validity, data.array.offset, data.array.length, + [&](int64_t i) { + ARROW_UNUSED(i); + child_lengthes.clear(); + Offset list_length = offsets[i + 1] - offsets[i]; + child_lengthes.resize(list_length, 0); + tmp_child_data.SetSlice(offsets[i], list_length); + this->element_encoder_->AddLength(ExecValue{tmp_child_data}, batch_length, + child_lengthes.data()); + lengths[index] += kExtraByteForNull + sizeof(Offset); + for (int32_t j = 0; j < batch_length; j++) { + lengths[index] += child_lengthes[j]; + } + ++index; + }, + [&]() { + lengths[index] = kExtraByteForNull + sizeof(Offset); + ++index; + }); } else { - const auto& list_scalar = checked_cast(data.scalar); + const auto& list_scalar = data.scalar_as(); int32_t accum_length = 0; if (list_scalar.is_valid) { auto element_count = static_cast(list_scalar.value->length()); @@ -308,7 +331,53 @@ struct ARROW_EXPORT ListKeyEncoder : KeyEncoder { Status Encode(const ExecValue& data, int64_t batch_length, uint8_t** encoded_bytes) override { - return Status::NotImplemented("ListKeyEncoder::Decode"); + auto handle_null_value = [&encoded_bytes]() { + auto& encoded_ptr = *encoded_bytes++; + *encoded_ptr++ = kNullByte; + util::SafeStore(encoded_ptr, static_cast(0)); + encoded_ptr += sizeof(Offset); + }; + auto handle_valid_value = [&encoded_bytes, + this](const ArraySpan& child_array) -> Status { + auto& encoded_ptr = *encoded_bytes++; + *encoded_ptr++ = kValidByte; + util::SafeStore(encoded_ptr, static_cast(child_array.length)); + encoded_ptr += sizeof(Offset); + // handling the child data + return element_encoder_->Encode(ExecValue{child_array}, child_array.length, + encoded_bytes); + }; + if (data.is_array()) { + ARROW_DCHECK_EQ(data.array.length, batch_length); + const uint8_t* validity = data.array.buffers[0].data; + const auto* offsets = data.array.GetValues(1); + ArraySpan tmp_child_data(data.array.child_data[0]); + RETURN_NOT_OK(VisitBitBlocks( + validity, data.array.offset, data.array.length, + [&](int64_t i) { + ARROW_UNUSED(i); + Offset list_length = offsets[i + 1] - offsets[i]; + tmp_child_data.SetSlice(offsets[i], list_length); + return handle_valid_value(tmp_child_data); + }, + [&]() { + handle_null_value(); + return Status::OK(); + })); + } else { + const auto& list_scalar = data.scalar_as(); + ArraySpan span(*list_scalar.value->data()); + if (list_scalar.is_valid) { + for (int64_t i = 0; i < batch_length; i++) { + RETURN_NOT_OK(handle_valid_value(span)); + } + } else { + for (int64_t i = 0; i < batch_length; i++) { + handle_null_value(); + } + } + } + return Status::OK(); } void EncodeNull(uint8_t** encoded_bytes) override {