Skip to content

Commit

Permalink
basic impl
Browse files Browse the repository at this point in the history
  • Loading branch information
mapleFU committed Sep 8, 2024
1 parent 4e8325e commit 72705a9
Show file tree
Hide file tree
Showing 3 changed files with 299 additions and 14 deletions.
8 changes: 4 additions & 4 deletions cpp/src/arrow/compute/row/row_encoder_internal.cc
Original file line number Diff line number Diff line change
Expand Up @@ -77,12 +77,12 @@ Result<std::shared_ptr<KeyEncoder>> MakeKeyEncoder(
ARROW_ASSIGN_OR_RAISE(auto element_encoder,
MakeKeyEncoder(element_type, &element_extension_type, pool));
if (type.id() == Type::LIST) {
return std::make_shared<ListKeyEncoder<ListType>>(std::move(element_type),
std::move(element_encoder));
return std::make_shared<ListKeyEncoder<ListType>>(
type.type->GetSharedPtr(), std::move(element_type), std::move(element_encoder));
}
ARROW_CHECK(type.id() == Type::LARGE_LIST);
return std::make_shared<ListKeyEncoder<LargeListType>>(std::move(element_type),
std::move(element_encoder));
return std::make_shared<ListKeyEncoder<LargeListType>>(
type.type->GetSharedPtr(), std::move(element_type), std::move(element_encoder));
}

return Status::NotImplemented("Unsupported type for row encoder", type.ToString());
Expand Down
77 changes: 67 additions & 10 deletions cpp/src/arrow/compute/row/row_encoder_internal.h
Original file line number Diff line number Diff line change
Expand Up @@ -18,10 +18,13 @@
#pragma once

#include <cstdint>
#include <iostream>

#include "arrow/compute/kernels/codegen_internal.h"
#include "arrow/visit_data_inline.h"

#include <arrow/array/concatenate.h>

namespace arrow {

using internal::checked_cast;
Expand Down Expand Up @@ -271,18 +274,22 @@ struct ARROW_EXPORT NullKeyEncoder : KeyEncoder {

template <typename ListType>
struct ARROW_EXPORT ListKeyEncoder : KeyEncoder {
static_assert(is_list_like_type<ListType>(), "ListKeyEncoder only supports ListType");
using Offset = typename ListType::offset_type;

ListKeyEncoder(std::shared_ptr<DataType> element_type,
ListKeyEncoder(std::shared_ptr<DataType> self_type,
std::shared_ptr<DataType> element_type,
std::shared_ptr<KeyEncoder> element_encoder)
: element_type_(std::move(element_type)),
: self_type_(std::move(self_type)),
element_type_(std::move(element_type)),
element_encoder_(std::move(element_encoder)) {}

void AddLength(const ExecValue& data, int64_t batch_length, int32_t* lengths) override {
if (data.is_array()) {
ARROW_DCHECK_EQ(data.array.length, batch_length);
const uint8_t* validity = data.array.buffers[0].data;
const auto* offsets = data.array.GetValues<Offset>(1);
// AddLength for each list
std::vector<int32_t> child_lengthes;
int32_t index{0};
ArraySpan tmp_child_data(data.array.child_data[0]);
Expand All @@ -292,12 +299,17 @@ struct ARROW_EXPORT ListKeyEncoder : KeyEncoder {
ARROW_UNUSED(i);
child_lengthes.clear();
Offset list_length = offsets[i + 1] - offsets[i];
if (list_length == 0) {
lengths[index] += kExtraByteForNull + sizeof(Offset);
++index;
return;
}
child_lengthes.resize(list_length, 0);
tmp_child_data.SetSlice(offsets[i], list_length);
this->element_encoder_->AddLength(ExecValue{tmp_child_data}, batch_length,
this->element_encoder_->AddLength(ExecValue{tmp_child_data}, list_length,
child_lengthes.data());
lengths[index] += kExtraByteForNull + sizeof(Offset);
for (int32_t j = 0; j < batch_length; j++) {
for (int32_t j = 0; j < list_length; j++) {
lengths[index] += child_lengthes[j];
}
++index;
Expand All @@ -309,7 +321,8 @@ struct ARROW_EXPORT ListKeyEncoder : KeyEncoder {
} else {
const auto& list_scalar = data.scalar_as<BaseListScalar>();
int32_t accum_length = 0;
if (list_scalar.is_valid) {
// Counting the size of the encoded list if the list is valid
if (list_scalar.is_valid && list_scalar.value->length() > 0) {
auto element_count = static_cast<int32_t>(list_scalar.value->length());
// Counting the size of the encoded list
std::vector<int32_t> child_lengthes(element_count, 0);
Expand Down Expand Up @@ -344,8 +357,13 @@ struct ARROW_EXPORT ListKeyEncoder : KeyEncoder {
util::SafeStore(encoded_ptr, static_cast<Offset>(child_array.length));
encoded_ptr += sizeof(Offset);
// handling the child data
return element_encoder_->Encode(ExecValue{child_array}, child_array.length,
encoded_bytes);
for (int64_t i = 0; i < child_array.length; i++) {
ArraySpan tmp_child_data(child_array);
tmp_child_data.SetSlice(child_array.offset + i, 1);
RETURN_NOT_OK(
this->element_encoder_->Encode(ExecValue{tmp_child_data}, 1, &encoded_ptr));
}
return Status::OK();
};
if (data.is_array()) {
ARROW_DCHECK_EQ(data.array.length, batch_length);
Expand All @@ -366,8 +384,8 @@ struct ARROW_EXPORT ListKeyEncoder : KeyEncoder {
}));
} else {
const auto& list_scalar = data.scalar_as<BaseListScalar>();
ArraySpan span(*list_scalar.value->data());
if (list_scalar.is_valid) {
ArraySpan span(*list_scalar.value->data());
for (int64_t i = 0; i < batch_length; i++) {
RETURN_NOT_OK(handle_valid_value(span));
}
Expand All @@ -389,9 +407,48 @@ struct ARROW_EXPORT ListKeyEncoder : KeyEncoder {

Result<std::shared_ptr<ArrayData>> Decode(uint8_t** encoded_bytes, int32_t length,
MemoryPool* pool) override {
return Status::NotImplemented("ListKeyEncoder::Decode");
std::shared_ptr<Buffer> null_buf;
int32_t null_count;
ARROW_RETURN_NOT_OK(DecodeNulls(pool, length, encoded_bytes, &null_buf, &null_count));

// Build the offsets buffer of the ListArray
ARROW_ASSIGN_OR_RAISE(auto offset_buf,
AllocateBuffer(sizeof(Offset) * (1 + length), pool));
auto raw_offsets = offset_buf->mutable_span_as<Offset>();
Offset element_sum = 0;
raw_offsets[0] = 0;
std::vector<std::shared_ptr<Array>> child_datas;
for (int32_t i = 0; i < length; ++i) {
Offset element_count = util::SafeLoadAs<Offset>(encoded_bytes[i]);
element_sum += element_count;
raw_offsets[i + 1] = element_sum;
encoded_bytes[i] += sizeof(Offset);
for (Offset j = 0; j < element_count; ++j) {
ARROW_ASSIGN_OR_RAISE(
auto child_data,
element_encoder_->Decode(encoded_bytes + i, /*length=*/1, pool));
ArraySpan array_span(*child_data);
child_datas.push_back(array_span.ToArray());
}
}
std::shared_ptr<ArrayData> element_data;
if (!child_datas.empty()) {
ARROW_ASSIGN_OR_RAISE(auto element_array, ::arrow::Concatenate(child_datas, pool));
element_data = element_array->data();
} else {
// If there are no elements, we need to create an empty array
std::unique_ptr<ArrayBuilder> tmp;
RETURN_NOT_OK(MakeBuilder(pool, element_type_, &tmp));
std::shared_ptr<Array> array;
RETURN_NOT_OK(tmp->Finish(&array));
element_data = array->data();
}
return ArrayData::Make(self_type_, length,
{std::move(null_buf), std::move(offset_buf)}, {element_data},
null_count);
}

std::shared_ptr<DataType> self_type_;
std::shared_ptr<DataType> element_type_;
std::shared_ptr<KeyEncoder> element_encoder_;
// extension_type_ is used to store the extension type of the list element.
Expand Down Expand Up @@ -466,7 +523,7 @@ struct ARROW_EXPORT ListKeyEncoder : KeyEncoder {
///
/// List Type is encoded as:
/// [null byte, list element count, [element 1, element 2, ...]]
/// Element count uses 4 bytes.
/// Element count uses `Offset` bytes.
///
/// Currently, we only support encoding of primitive types, dictionary types
/// in the list, the nested list is not supported.
Expand Down
Loading

0 comments on commit 72705a9

Please sign in to comment.