Skip to content

Commit

Permalink
Compute: ListKeyEncoder interfaces
Browse files Browse the repository at this point in the history
  • Loading branch information
mapleFU committed Sep 2, 2024
1 parent 44d3f76 commit f5f0f33
Show file tree
Hide file tree
Showing 2 changed files with 98 additions and 42 deletions.
107 changes: 66 additions & 41 deletions cpp/src/arrow/compute/row/row_encoder_internal.cc
Original file line number Diff line number Diff line change
Expand Up @@ -29,6 +29,51 @@ using internal::FirstTimeBitmapWriter;
namespace compute {
namespace internal {

/// Create the KeyEncoder matching a single column type.
///
/// Extension types are encoded through their storage type.
///
/// \param[in] column_type the (possibly extension) type of the column
/// \param[out] extension_type receives the column's extension type when
///     `column_type` is an extension type; left untouched otherwise.
///     Must not be null because it is dereferenced for extension columns.
/// \param[in] pool memory pool handed to encoders that need one
/// \return the encoder, or Status::NotImplemented when the type (or, for lists,
///     the element type) is not supported
Result<std::shared_ptr<KeyEncoder>> MakeKeyEncoder(
    const TypeHolder& column_type, std::shared_ptr<ExtensionType>* extension_type,
    MemoryPool* pool) {
  const bool is_extension = column_type.id() == Type::EXTENSION;
  const TypeHolder& type =
      is_extension
          ? arrow::internal::checked_cast<const ExtensionType*>(column_type.type)
                ->storage_type()
          : column_type;

  if (is_extension) {
    *extension_type = arrow::internal::checked_pointer_cast<ExtensionType>(
        column_type.GetSharedPtr());
  }
  if (type.id() == Type::BOOL) {
    return std::make_shared<BooleanKeyEncoder>();
  }

  if (type.id() == Type::DICTIONARY) {
    return std::make_shared<DictionaryKeyEncoder>(type.GetSharedPtr(), pool);
  }

  if (is_fixed_width(type.id())) {
    return std::make_shared<FixedWidthKeyEncoder>(type.GetSharedPtr());
  }

  if (is_binary_like(type.id())) {
    return std::make_shared<VarLengthKeyEncoder<BinaryType>>(type.GetSharedPtr());
  }

  if (is_large_binary_like(type.id())) {
    return std::make_shared<VarLengthKeyEncoder<LargeBinaryType>>(type.GetSharedPtr());
  }

  if (is_list(type.id())) {
    // `type.type` is a pointer-to-const, so the downcast must stay const.
    auto element_type =
        arrow::internal::checked_cast<const BaseListType*>(type.type)->value_type();

    // The recursive call below unwraps extension elements to their storage
    // type, so the nested-type check has to look at the storage type as well.
    const DataType* element_storage = element_type.get();
    if (element_storage->id() == Type::EXTENSION) {
      element_storage =
          arrow::internal::checked_cast<const ExtensionType*>(element_storage)
              ->storage_type()
              .get();
    }
    if (is_nested(element_storage->id())) {
      return Status::NotImplemented("Unsupported nested type in List for row encoder: ",
                                    type.ToString());
    }

    std::shared_ptr<ExtensionType> element_extension_type;
    ARROW_ASSIGN_OR_RAISE(
        auto element_encoder,
        MakeKeyEncoder(element_type, &element_extension_type, pool));
    auto list_encoder = std::make_shared<ListKeyEncoder>(std::move(element_type),
                                                         std::move(element_encoder));
    // Keep the element's extension type (nullptr for non-extension elements)
    // instead of silently discarding it.
    list_encoder->extension_type_ = std::move(element_extension_type);
    return list_encoder;
  }

  return Status::NotImplemented("Unsupported type for row encoder: ", type.ToString());
}

// extract the null bitmap from the leading nullity bytes of encoded keys
Status KeyEncoder::DecodeNulls(MemoryPool* pool, int32_t length, uint8_t** encoded_bytes,
std::shared_ptr<Buffer>* null_bitmap,
Expand Down Expand Up @@ -256,53 +301,32 @@ Result<std::shared_ptr<ArrayData>> DictionaryKeyEncoder::Decode(uint8_t** encode
return data;
}

void RowEncoder::Init(const std::vector<TypeHolder>& column_types, ExecContext* ctx) {
ctx_ = ctx;
encoders_.resize(column_types.size());
extension_types_.resize(column_types.size());
ListKeyEncoder::ListKeyEncoder(std::shared_ptr<DataType> element_type, std::shared_ptr<KeyEncoder> element_encoder)
: element_type_(std::move(element_type)), element_encoder_(std::move(element_encoder)) {}

for (size_t i = 0; i < column_types.size(); ++i) {
const bool is_extension = column_types[i].id() == Type::EXTENSION;
const TypeHolder& type =
is_extension
? arrow::internal::checked_cast<const ExtensionType*>(column_types[i].type)
->storage_type()
: column_types[i];

if (is_extension) {
extension_types_[i] = arrow::internal::checked_pointer_cast<ExtensionType>(
column_types[i].GetSharedPtr());
}
if (type.id() == Type::BOOL) {
encoders_[i] = std::make_shared<BooleanKeyEncoder>();
continue;
}
/// Currently a no-op: nothing is added to `lengths`.
/// NOTE(review): looks like a placeholder until list encoding is implemented.
void ListKeyEncoder::AddLength(const ExecValue& /*exec_value*/, int64_t /*batch_length*/,
                               int32_t* /*lengths*/) {
}

if (type.id() == Type::DICTIONARY) {
encoders_[i] =
std::make_shared<DictionaryKeyEncoder>(type.GetSharedPtr(), ctx->memory_pool());
continue;
}
/// Currently a no-op: `*length` is left unchanged.
/// NOTE(review): looks like a placeholder until list encoding is implemented.
void ListKeyEncoder::AddLengthNull(int32_t* /*length*/) {
}

if (is_fixed_width(type.id())) {
encoders_[i] = std::make_shared<FixedWidthKeyEncoder>(type.GetSharedPtr());
continue;
}
/// Encoding of list values is not available yet; always reports NotImplemented
/// and writes nothing through `encoded_bytes`.
Status ListKeyEncoder::Encode(const ExecValue& /*data*/, int64_t /*batch_length*/,
                              uint8_t** /*encoded_bytes*/) {
  return Status::NotImplemented("ListKeyEncoder::Encode");
}

if (is_binary_like(type.id())) {
encoders_[i] =
std::make_shared<VarLengthKeyEncoder<BinaryType>>(type.GetSharedPtr());
continue;
}
/// Currently a no-op: no bytes are written and the cursor is not advanced.
/// NOTE(review): looks like a placeholder until list encoding is implemented.
void ListKeyEncoder::EncodeNull(uint8_t** /*encoded_bytes*/) {
}

if (is_large_binary_like(type.id())) {
encoders_[i] =
std::make_shared<VarLengthKeyEncoder<LargeBinaryType>>(type.GetSharedPtr());
continue;
}
/// Decoding of list values is not available yet.
///
/// Reports NotImplemented, in line with ListKeyEncoder::Encode, rather than
/// returning a successful Result that wraps a null ArrayData; callers would
/// otherwise dereference a null pointer on what looks like a successful decode.
Result<std::shared_ptr<ArrayData>> ListKeyEncoder::Decode(uint8_t** /*encoded_bytes*/,
                                                          int32_t /*length*/,
                                                          MemoryPool* /*pool*/) {
  return Status::NotImplemented("ListKeyEncoder::Decode");
}

Status RowEncoder::Init(const std::vector<TypeHolder>& column_types, ExecContext* ctx) {
ctx_ = ctx;
encoders_.resize(column_types.size());
extension_types_.resize(column_types.size());

// We should not get here
ARROW_DCHECK(false);
for (size_t i = 0; i < column_types.size(); ++i) {
ARROW_ASSIGN_OR_RAISE(encoders_[i], MakeKeyEncoder(column_types[i], &extension_types_[i], ctx->memory_pool()));
}

int32_t total_length = 0;
Expand All @@ -314,6 +338,7 @@ void RowEncoder::Init(const std::vector<TypeHolder>& column_types, ExecContext*
for (size_t i = 0; i < column_types.size(); ++i) {
encoders_[i]->EncodeNull(&buf_ptr);
}
return Status::OK();
}

void RowEncoder::Clear() {
Expand Down
33 changes: 32 additions & 1 deletion cpp/src/arrow/compute/row/row_encoder_internal.h
Original file line number Diff line number Diff line change
Expand Up @@ -269,6 +269,28 @@ struct ARROW_EXPORT NullKeyEncoder : KeyEncoder {
}
};

/// KeyEncoder for list-typed columns.
///
/// Holds the list's element type together with the KeyEncoder created for
/// that element type.
struct ARROW_EXPORT ListKeyEncoder : KeyEncoder {
/// \param element_type type of the list's elements
/// \param element_encoder encoder used for the list's elements
explicit ListKeyEncoder(std::shared_ptr<DataType> element_type, std::shared_ptr<KeyEncoder> element_encoder);

/// Overrides KeyEncoder::AddLength for list values.
void AddLength(const ExecValue&, int64_t batch_length, int32_t* lengths) override;

/// Overrides KeyEncoder::AddLengthNull for a null list.
void AddLengthNull(int32_t* length) override;

/// Overrides KeyEncoder::Encode for list values.
Status Encode(const ExecValue& data, int64_t batch_length,
uint8_t** encoded_bytes) override;

/// Overrides KeyEncoder::EncodeNull for a null list.
void EncodeNull(uint8_t** encoded_bytes) override;

/// Overrides KeyEncoder::Decode for list values.
Result<std::shared_ptr<ArrayData>> Decode(uint8_t** encoded_bytes, int32_t length,
MemoryPool* pool) override;

// Type of the list's elements, as passed to the constructor.
std::shared_ptr<DataType> element_type_;
// Encoder for the list's elements, as passed to the constructor.
std::shared_ptr<KeyEncoder> element_encoder_;
// extension_type_ is used to store the extension type of the list element.
// It would be nullptr if the list element is not an extension type.
// It is not a constructor argument and is default-initialized (null).
std::shared_ptr<ExtensionType> extension_type_;
};

/// RowEncoder encodes ExecSpan to a variable length byte sequence
/// created by concatenating the encoded form of each column. The encoding
/// for each column depends on its data type.
Expand Down Expand Up @@ -328,14 +350,23 @@ struct ARROW_EXPORT NullKeyEncoder : KeyEncoder {
/// Null string Would be encoded as:
/// 1 ( 1 byte for null) + 0 ( 4 bytes for length )
///
/// ## List Type
///
/// List Type is encoded as:
/// [null byte, list element count, [element 1, element 2, ...]]
/// Element count uses 4 bytes.
///
/// Currently, only primitive and dictionary types are supported as list
/// elements; nested lists are not supported.
///
/// # Row Encoding
///
/// The row format is the concatenation of the encodings of each column.
class ARROW_EXPORT RowEncoder {
public:
static constexpr int kRowIdForNulls() { return -1; }

void Init(const std::vector<TypeHolder>& column_types, ExecContext* ctx);
Status Init(const std::vector<TypeHolder>& column_types, ExecContext* ctx);
void Clear();
Status EncodeAndAppend(const ExecSpan& batch);
Result<ExecBatch> Decode(int64_t num_rows, const int32_t* row_ids);
Expand Down

0 comments on commit f5f0f33

Please sign in to comment.