Skip to content

Commit

Permalink
minor-refactor
Browse files Browse the repository at this point in the history
  • Loading branch information
mapleFU committed Sep 6, 2024
1 parent 9aaff83 commit 5055ef0
Show file tree
Hide file tree
Showing 2 changed files with 75 additions and 36 deletions.
56 changes: 26 additions & 30 deletions cpp/src/arrow/compute/row/row_encoder_internal.cc
Original file line number Diff line number Diff line change
Expand Up @@ -29,17 +29,18 @@ using internal::FirstTimeBitmapWriter;
namespace compute {
namespace internal {

Result<std::shared_ptr<KeyEncoder>> MakeKeyEncoder(const TypeHolder& column_type, std::shared_ptr<ExtensionType>* extension_type, MemoryPool* pool) {
Result<std::shared_ptr<KeyEncoder>> MakeKeyEncoder(
const TypeHolder& column_type, std::shared_ptr<ExtensionType>* extension_type,
MemoryPool* pool) {
const bool is_extension = column_type.id() == Type::EXTENSION;
const TypeHolder& type =
is_extension
? arrow::internal::checked_cast<const ExtensionType*>(column_type.type)
->storage_type()
: column_type;
is_extension ? arrow::internal::checked_cast<const ExtensionType*>(column_type.type)
->storage_type()
: column_type;

if (is_extension) {
*extension_type = arrow::internal::checked_pointer_cast<ExtensionType>(
column_type.GetSharedPtr());
*extension_type =
arrow::internal::checked_pointer_cast<ExtensionType>(column_type.GetSharedPtr());
}
if (type.id() == Type::BOOL) {
return std::make_shared<BooleanKeyEncoder>();
Expand All @@ -65,11 +66,23 @@ Result<std::shared_ptr<KeyEncoder>> MakeKeyEncoder(const TypeHolder& column_type
auto element_type =
::arrow::checked_cast<const BaseListType*>(type.type)->value_type();
if (is_nested(element_type->id())) {
return Status::NotImplemented("Unsupported nested type in List for row encoder", type.ToString());
return Status::NotImplemented("Unsupported nested type in List for row encoder",
type.ToString());
}
if (type.id() == Type::FIXED_SIZE_LIST) {
return Status::NotImplemented("Unsupported FixedSizeList for row encoder",
type.ToString());
}
std::shared_ptr<ExtensionType> element_extension_type;
ARROW_ASSIGN_OR_RAISE(auto element_encoder, MakeKeyEncoder(element_type, &element_extension_type, pool));
return std::make_shared<ListKeyEncoder>(std::move(element_type), std::move(element_encoder));
ARROW_ASSIGN_OR_RAISE(auto element_encoder,
MakeKeyEncoder(element_type, &element_extension_type, pool));
if (type.id() == Type::LIST) {
return std::make_shared<ListKeyEncoder<ListType>>(std::move(element_type),
std::move(element_encoder));
}
ARROW_CHECK(type.id() == Type::LARGE_LIST);
return std::make_shared<ListKeyEncoder<LargeListType>>(std::move(element_type),
std::move(element_encoder));
}

return Status::NotImplemented("Unsupported type for row encoder", type.ToString());
Expand Down Expand Up @@ -302,32 +315,15 @@ Result<std::shared_ptr<ArrayData>> DictionaryKeyEncoder::Decode(uint8_t** encode
return data;
}

ListKeyEncoder::ListKeyEncoder(std::shared_ptr<DataType> element_type, std::shared_ptr<KeyEncoder> element_encoder)
: element_type_(std::move(element_type)), element_encoder_(std::move(element_encoder)) {}

// Stub: the body is empty, so nothing is added to `lengths`.
void ListKeyEncoder::AddLength(const ExecValue& exec_value, int64_t batch_length,
                               int32_t* lengths) {
  // No-op.
}

// Stub: the body is empty, so `*length` is left unchanged.
void ListKeyEncoder::AddLengthNull(int32_t* length) {
  // No-op.
}

// Always reports NotImplemented; none of the arguments are read and nothing
// is written through `encoded_bytes`.
Status ListKeyEncoder::Encode(const ExecValue& data, int64_t batch_length,
                              uint8_t** encoded_bytes) {
  Status unsupported = Status::NotImplemented("ListKeyEncoder::Encode");
  return unsupported;
}

// Stub: the body is empty, so nothing is written and `*encoded_bytes` is not
// advanced.
void ListKeyEncoder::EncodeNull(uint8_t** encoded_bytes) {
  // No-op.
}

// Decoding is not supported: always returns Status::NotImplemented.
//
// This reports an error rather than a successful Result wrapping a null
// ArrayData, so a caller cannot mistake the stub for a real decode and
// dereference a null pointer. It also matches ListKeyEncoder::Encode, which
// reports NotImplemented in the same way.
Result<std::shared_ptr<ArrayData>> ListKeyEncoder::Decode(uint8_t** encoded_bytes,
                                                          int32_t length,
                                                          MemoryPool* pool) {
  return Status::NotImplemented("ListKeyEncoder::Decode");
}

Status RowEncoder::Init(const std::vector<TypeHolder>& column_types, ExecContext* ctx) {
ctx_ = ctx;
encoders_.resize(column_types.size());
extension_types_.resize(column_types.size());

for (size_t i = 0; i < column_types.size(); ++i) {
ARROW_ASSIGN_OR_RAISE(encoders_[i], MakeKeyEncoder(column_types[i], &extension_types_[i], ctx->memory_pool()));
ARROW_ASSIGN_OR_RAISE(
encoders_[i],
MakeKeyEncoder(column_types[i], &extension_types_[i], ctx->memory_pool()));
}

int32_t total_length = 0;
Expand Down
55 changes: 49 additions & 6 deletions cpp/src/arrow/compute/row/row_encoder_internal.h
Original file line number Diff line number Diff line change
Expand Up @@ -269,20 +269,59 @@ struct ARROW_EXPORT NullKeyEncoder : KeyEncoder {
}
};

template <typename ListType>
struct ARROW_EXPORT ListKeyEncoder : KeyEncoder {
explicit ListKeyEncoder(std::shared_ptr<DataType> element_type, std::shared_ptr<KeyEncoder> element_encoder);
using Offset = typename ListType::offset_type;

void AddLength(const ExecValue&, int64_t batch_length, int32_t* lengths) override;
/// Stores the list's element type and the encoder used for its elements;
/// both arguments are moved into the corresponding members.
ListKeyEncoder(std::shared_ptr<DataType> element_type,
               std::shared_ptr<KeyEncoder> element_encoder)
    : element_type_{std::move(element_type)},
      element_encoder_{std::move(element_encoder)} {}

void AddLengthNull(int32_t* length) override;
/// Adds this list column's encoded size to each row's running length.
///
/// \param data the list column, either an array or a scalar
/// \param batch_length number of rows in the batch
/// \param lengths per-row accumulated lengths, updated in place
///
/// Only the scalar case is implemented. For array input nothing is added to
/// `lengths` yet.
void AddLength(const ExecValue& data, int64_t batch_length, int32_t* lengths) override {
  if (data.is_array()) {
    ARROW_DCHECK_EQ(data.array.length, batch_length);
    // TODO(mwish): implement me
    return;
  }

  const auto& list_scalar = checked_cast<const BaseListScalar&>(data.scalar);
  int32_t accum_length = 0;
  if (list_scalar.is_valid) {
    const auto element_count = static_cast<int32_t>(list_scalar.value->length());
    // Ask the element encoder for the encoded size of every element, then sum
    // those sizes to get the size of the encoded list payload.
    std::vector<int32_t> child_lengths(element_count, 0);
    this->element_encoder_->AddLength(ExecValue{*list_scalar.value->data()},
                                      element_count, child_lengths.data());
    for (const int32_t child_length : child_lengths) {
      accum_length += child_length;
    }
  }
  // The same scalar applies to every row, so each row receives the same
  // contribution: the null byte, an Offset-sized field and the element bytes.
  for (int64_t i = 0; i < batch_length; i++) {
    lengths[i] += kExtraByteForNull + sizeof(Offset) + accum_length;
  }
}

/// Adds the size of an encoded null list to `*length`: the null byte plus an
/// Offset-sized field.
void AddLengthNull(int32_t* length) override {
  const auto null_entry_size = kExtraByteForNull + sizeof(Offset);
  *length += null_entry_size;
}

/// Encoding list values is not implemented yet: always returns
/// Status::NotImplemented and writes nothing through `encoded_bytes`.
Status Encode(const ExecValue& data, int64_t batch_length,
              uint8_t** encoded_bytes) override {
  // The message names this method so the error points at the right gap.
  return Status::NotImplemented("ListKeyEncoder::Encode");
}

void EncodeNull(uint8_t** encoded_bytes) override;
/// Writes a null list entry at `*encoded_bytes` and advances the pointer past
/// it: one null-marker byte followed by an Offset-sized zero.
void EncodeNull(uint8_t** encoded_bytes) override {
  uint8_t*& cursor = *encoded_bytes;
  // Null marker byte.
  *cursor = kNullByte;
  ++cursor;
  // Zero stored in the Offset-sized field that follows the marker.
  util::SafeStore(cursor, Offset{0});
  cursor += sizeof(Offset);
}

/// Decoding list values is not implemented yet: always returns
/// Status::NotImplemented and does not read from `encoded_bytes`.
Result<std::shared_ptr<ArrayData>> Decode(uint8_t** encoded_bytes, int32_t length,
MemoryPool* pool) override;
MemoryPool* pool) override {
return Status::NotImplemented("ListKeyEncoder::Decode");
}

std::shared_ptr<DataType> element_type_;
std::shared_ptr<KeyEncoder> element_encoder_;
Expand Down Expand Up @@ -350,6 +389,10 @@ struct ARROW_EXPORT ListKeyEncoder : KeyEncoder {
/// A null string would be encoded as:
/// 1 (1 byte for null) + 0 (4 bytes for length)
///
/// The size of the "fixed-width" part is defined by the `offset_type`
/// of the variable-width type. For example, it would be 4 bytes for
/// String/Binary type and 8 bytes for LargeString/LargeBinary type.
///
/// ## List Type
///
/// List Type is encoded as:
Expand Down

0 comments on commit 5055ef0

Please sign in to comment.