apacheGH-37429: [C++] Add arrow::ipc::StreamDecoder::Reset() (apache#37970)

### Rationale for this change

With this, we can reuse the same `StreamDecoder` to read multiple streams instead of creating a new decoder for each one.
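For illustration, a minimal usage sketch (not part of this patch): `ReadTwoStreams` is a hypothetical helper, and `stream1`/`stream2` are assumed to each hold one complete IPC stream in an `arrow::Buffer`.

```cpp
#include <memory>

#include "arrow/buffer.h"
#include "arrow/ipc/reader.h"
#include "arrow/status.h"

// Hypothetical helper: decode two independent IPC streams with one decoder,
// calling Reset() between them instead of constructing a second decoder.
arrow::Status ReadTwoStreams(std::shared_ptr<arrow::Buffer> stream1,
                             std::shared_ptr<arrow::Buffer> stream2) {
  auto listener = std::make_shared<arrow::ipc::CollectListener>();
  arrow::ipc::StreamDecoder decoder(listener);

  ARROW_RETURN_NOT_OK(decoder.Consume(std::move(stream1)));  // first stream (ends with EOS)
  ARROW_RETURN_NOT_OK(decoder.Reset());                      // reuse the same decoder
  ARROW_RETURN_NOT_OK(decoder.Consume(std::move(stream2)));  // second stream

  // listener now holds the record batches collected from both streams.
  return arrow::Status::OK();
}
```

Since `Reset()` rebuilds the internal state from the existing listener and read options, the same `CollectListener` keeps accumulating batches across streams.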

### What changes are included in this PR?

Add `StreamDecoder::Reset()`. Both `StreamDecoder::Consume()` overloads are also reworked to feed data to the internal decoder chunk by chunk, so a reset issued from a listener callback (for example at end-of-stream) applies to the bytes that follow.

### Are these changes tested?

Yes.

### Are there any user-facing changes?

Yes.
* Closes: apache#37429

Authored-by: Sutou Kouhei <[email protected]>
Signed-off-by: Sutou Kouhei <[email protected]>
kou authored Nov 2, 2023
1 parent ff762a5 commit cead3dd
Showing 3 changed files with 120 additions and 4 deletions.
37 changes: 37 additions & 0 deletions cpp/src/arrow/ipc/read_write_test.cc
@@ -2164,6 +2164,43 @@ TEST(TestRecordBatchStreamReader, MalformedInput) {
ASSERT_RAISES(Invalid, RecordBatchStreamReader::Open(&garbage_reader));
}

namespace {
class EndlessCollectListener : public CollectListener {
public:
EndlessCollectListener() : CollectListener(), decoder_(nullptr) {}

void SetDecoder(StreamDecoder* decoder) { decoder_ = decoder; }

arrow::Status OnEOS() override { return decoder_->Reset(); }

private:
StreamDecoder* decoder_;
};
}  // namespace

TEST(TestStreamDecoder, Reset) {
auto listener = std::make_shared<EndlessCollectListener>();
StreamDecoder decoder(listener);
listener->SetDecoder(&decoder);

std::shared_ptr<RecordBatch> batch;
ASSERT_OK(MakeIntRecordBatch(&batch));
StreamWriterHelper writer_helper;
ASSERT_OK(writer_helper.Init(batch->schema(), IpcWriteOptions::Defaults()));
ASSERT_OK(writer_helper.WriteBatch(batch));
ASSERT_OK(writer_helper.Finish());

ASSERT_OK_AND_ASSIGN(auto all_buffer, ConcatenateBuffers({writer_helper.buffer_,
writer_helper.buffer_}));
// Consume by Buffer
ASSERT_OK(decoder.Consume(all_buffer));
ASSERT_EQ(2, listener->num_record_batches());

// Consume by raw data
ASSERT_OK(decoder.Consume(all_buffer->data(), all_buffer->size()));
ASSERT_EQ(4, listener->num_record_batches());
}

TEST(TestStreamDecoder, NextRequiredSize) {
auto listener = std::make_shared<CollectListener>();
StreamDecoder decoder(listener);
79 changes: 75 additions & 4 deletions cpp/src/arrow/ipc/reader.cc
@@ -932,14 +932,18 @@ class StreamDecoderInternal : public MessageDecoderListener {
return listener_->OnEOS();
}

std::shared_ptr<Listener> listener() const { return listener_; }

Listener* raw_listener() const { return listener_.get(); }

IpcReadOptions options() const { return options_; }

State state() const { return state_; }

std::shared_ptr<Schema> schema() const { return filtered_schema_; }

ReadStats stats() const { return stats_; }

State state() const { return state_; }

int num_required_initial_dictionaries() const {
return num_required_initial_dictionaries_;
}
@@ -2039,6 +2043,8 @@ class StreamDecoder::StreamDecoderImpl : public StreamDecoderInternal {

int64_t next_required_size() const { return message_decoder_.next_required_size(); }

const MessageDecoder* message_decoder() const { return &message_decoder_; }

private:
MessageDecoder message_decoder_;
};
@@ -2050,10 +2056,75 @@ StreamDecoder::StreamDecoder(std::shared_ptr<Listener> listener, IpcReadOptions
StreamDecoder::~StreamDecoder() {}

Status StreamDecoder::Consume(const uint8_t* data, int64_t size) {
return impl_->Consume(data, size);
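// Feed the decoder at most next_required_size() bytes at a time so that a
// Reset() issued from a listener callback (e.g. OnEOS) takes effect for the
// bytes that follow; impl_ is re-read on every iteration.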
while (size > 0) {
const auto next_required_size = impl_->next_required_size();
if (next_required_size == 0) {
break;
}
if (size < next_required_size) {
break;
}
ARROW_RETURN_NOT_OK(impl_->Consume(data, next_required_size));
data += next_required_size;
size -= next_required_size;
}
if (size > 0) {
return impl_->Consume(data, size);
} else {
return arrow::Status::OK();
}
}

Status StreamDecoder::Consume(std::shared_ptr<Buffer> buffer) {
return impl_->Consume(std::move(buffer));
if (buffer->size() == 0) {
return arrow::Status::OK();
}
if (impl_->next_required_size() == 0 || buffer->size() <= impl_->next_required_size()) {
return impl_->Consume(std::move(buffer));
} else {
int64_t offset = 0;
while (true) {
const auto next_required_size = impl_->next_required_size();
if (next_required_size == 0) {
break;
}
if (buffer->size() - offset <= next_required_size) {
break;
}
if (buffer->is_cpu()) {
switch (impl_->message_decoder()->state()) {
case MessageDecoder::State::INITIAL:
case MessageDecoder::State::METADATA_LENGTH:
// We don't need to pass a sliced buffer because
// MessageDecoder doesn't keep a reference to the given
// buffer in these states.
ARROW_RETURN_NOT_OK(
impl_->Consume(buffer->data() + offset, next_required_size));
break;
default:
ARROW_RETURN_NOT_OK(
impl_->Consume(SliceBuffer(buffer, offset, next_required_size)));
break;
}
} else {
ARROW_RETURN_NOT_OK(
impl_->Consume(SliceBuffer(buffer, offset, next_required_size)));
}
offset += next_required_size;
}
if (buffer->size() - offset == 0) {
return arrow::Status::OK();
} else if (offset == 0) {
return impl_->Consume(std::move(buffer));
} else {
return impl_->Consume(SliceBuffer(std::move(buffer), offset));
}
}
}

Status StreamDecoder::Reset() {
impl_ = std::make_unique<StreamDecoderImpl>(impl_->listener(), impl_->options());
return Status::OK();
}

std::shared_ptr<Schema> StreamDecoder::schema() const { return impl_->schema(); }
8 changes: 8 additions & 0 deletions cpp/src/arrow/ipc/reader.h
@@ -425,6 +425,14 @@ class ARROW_EXPORT StreamDecoder {
/// \return Status
Status Consume(std::shared_ptr<Buffer> buffer);

/// \brief Reset the internal state.
///
/// You can reuse this decoder for a new stream after calling
/// this.
///
/// \return Status
Status Reset();

/// \return the shared schema of the record batches in the stream
std::shared_ptr<Schema> schema() const;

