Skip to content

Commit

Permalink
apacheGH-44923: [MATLAB] Add IPC RecordBatchStreamReader MATLAB cla…
Browse files Browse the repository at this point in the history
…ss (apache#45068)

### Rationale for this change

To enable support for the IPC Streaming format in the MATLAB interface, we should add a `RecordBatchStreamReader` class.

This is a followup to apache#44922 

### What changes are included in this PR?

1. Added a new `arrow.io.ipc.RecordBatchStreamReader` MATLAB class.

### Are these changes tested?

Yes.

1. Added new MATLAB test suite `arrow/matlab/test/arrow/io/ipc/tRecordBatchStreamReader.m`.

### Are there any user-facing changes?

Yes.

1. Users can now create `arrow.io.ipc.RecordBatchStreamReader` objects to read `RecordBatch` objects incrementally from an Arrow IPC Stream file.

### Notes

1. Thank you @sgilmore10 for your help with this pull request!
* GitHub Issue: apache#44923

Lead-authored-by: Kevin Gurney <[email protected]>
Co-authored-by: Kevin Gurney <[email protected]>
Co-authored-by: Sarah Gilmore <[email protected]>
Signed-off-by: Kevin Gurney <[email protected]>
  • Loading branch information
3 people authored Dec 23, 2024
1 parent 035e331 commit 8b75373
Show file tree
Hide file tree
Showing 7 changed files with 622 additions and 0 deletions.
2 changes: 2 additions & 0 deletions matlab/src/cpp/arrow/matlab/error/error.h
Original file line number Diff line number Diff line change
Expand Up @@ -249,5 +249,7 @@ static const char* IPC_RECORD_BATCH_READER_OPEN_FAILED =
"arrow:io:ipc:FailedToOpenRecordBatchReader";
// MATLAB error identifier for an invalid record batch index.
// NOTE(review): not referenced by the stream reader code; presumably raised by
// an index-based reader -- confirm against its callers.
static const char* IPC_RECORD_BATCH_READ_INVALID_INDEX = "arrow:io:ipc:InvalidIndex";
// MATLAB error identifier used when reading a record batch fails
// (e.g. RecordBatchStreamReader::readRecordBatch when reader->Next() fails).
static const char* IPC_RECORD_BATCH_READ_FAILED = "arrow:io:ipc:ReadFailed";
// MATLAB error identifier used when reading the stream into a Table fails
// (RecordBatchStreamReader::readTable).
static const char* IPC_TABLE_READ_FAILED = "arrow:io:ipc:TableReadFailed";
// MATLAB error identifier used by RecordBatchStreamReader::readRecordBatch when
// the IPC stream has no more record batches to return.
static const char* IPC_END_OF_STREAM = "arrow:io:ipc:EndOfStream";

} // namespace arrow::matlab::error
154 changes: 154 additions & 0 deletions matlab/src/cpp/arrow/matlab/io/ipc/proxy/record_batch_stream_reader.cc
Original file line number Diff line number Diff line change
@@ -0,0 +1,154 @@
// Licensed to the Apache Software Foundation (ASF) under one
// or more contributor license agreements. See the NOTICE file
// distributed with this work for additional information
// regarding copyright ownership. The ASF licenses this file
// to you under the Apache License, Version 2.0 (the
// "License"); you may not use this file except in compliance
// with the License. You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing,
// software distributed under the License is distributed on an
// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
// KIND, either express or implied. See the License for the
// specific language governing permissions and limitations
// under the License.

#include "arrow/matlab/io/ipc/proxy/record_batch_stream_reader.h"
#include "arrow/io/file.h"
#include "arrow/matlab/error/error.h"
#include "arrow/matlab/tabular/proxy/record_batch.h"
#include "arrow/matlab/tabular/proxy/schema.h"
#include "arrow/matlab/tabular/proxy/table.h"
#include "arrow/util/utf8.h"

#include "libmexclass/proxy/ProxyManager.h"

namespace arrow::matlab::io::ipc::proxy {

// Wraps an opened Arrow IPC stream reader in a proxy object and registers the
// proxy methods (getSchema, readRecordBatch, hasNextRecordBatch, readTable)
// so they can be invoked by name through the proxy.
//
// reader: shared handle to the underlying arrow::ipc::RecordBatchStreamReader;
//         ownership is shared with the `reader` data member.
RecordBatchStreamReader::RecordBatchStreamReader(
    std::shared_ptr<arrow::ipc::RecordBatchStreamReader> reader)
    // The by-value parameter is intentionally non-const: std::move applied to a
    // const shared_ptr silently falls back to a copy, costing an extra atomic
    // reference-count increment/decrement. Top-level const on a by-value
    // parameter is not part of the signature, so the declaration is unaffected.
    : reader{std::move(reader)} {
  REGISTER_METHOD(RecordBatchStreamReader, getSchema);
  REGISTER_METHOD(RecordBatchStreamReader, readRecordBatch);
  REGISTER_METHOD(RecordBatchStreamReader, hasNextRecordBatch);
  REGISTER_METHOD(RecordBatchStreamReader, readTable);
}

// Factory for RecordBatchStreamReader proxies.
//
// Expects the first constructor argument to be a struct with a "Filename"
// string field. The file name is converted from UTF-16 to UTF-8, the file is
// opened for reading, and an Arrow IPC stream reader is opened on top of it.
// Each step is paired with the error ID to use if that step fails.
libmexclass::proxy::MakeResult RecordBatchStreamReader::make(
    const libmexclass::proxy::FunctionArguments& constructor_arguments) {
  namespace mda = ::matlab::data;
  using ReaderProxy = arrow::matlab::io::ipc::proxy::RecordBatchStreamReader;

  // Pull the file name out of the options struct.
  const mda::StructArray options = constructor_arguments[0];
  const mda::StringArray filename_array = options[0]["Filename"];
  const auto utf16_path = std::u16string(filename_array[0]);

  // Arrow's file APIs take UTF-8 encoded paths.
  MATLAB_ASSIGN_OR_ERROR(const auto utf8_path,
                         arrow::util::UTF16StringToUTF8(utf16_path),
                         error::UNICODE_CONVERSION_ERROR_ID);

  MATLAB_ASSIGN_OR_ERROR(auto file, arrow::io::ReadableFile::Open(utf8_path),
                         error::FAILED_TO_OPEN_FILE_FOR_READ);

  MATLAB_ASSIGN_OR_ERROR(auto stream_reader,
                         arrow::ipc::RecordBatchStreamReader::Open(file),
                         error::IPC_RECORD_BATCH_READER_OPEN_FAILED);

  return std::make_shared<ReaderProxy>(std::move(stream_reader));
}

// Proxy method: wraps the stream's schema in a Schema proxy, hands that proxy
// to the ProxyManager, and returns the resulting proxy ID as a scalar in
// context.outputs[0].
void RecordBatchStreamReader::getSchema(libmexclass::proxy::method::Context& context) {
  namespace mda = ::matlab::data;
  using SchemaProxy = arrow::matlab::tabular::proxy::Schema;

  // Build the Schema proxy directly from the reader's schema and register it.
  const auto proxy_id = libmexclass::proxy::ProxyManager::manageProxy(
      std::make_shared<SchemaProxy>(reader->schema()));

  mda::ArrayFactory array_factory;
  context.outputs[0] = array_factory.createScalar(proxy_id);
}

// Proxy method: reads every record batch that has not yet been consumed into a
// single arrow::Table, wraps it in a Table proxy, and returns the proxy ID as a
// scalar in context.outputs[0].
//
// hasNextRecordBatch may already have pulled one batch off the stream and
// parked it in nextRecordBatch. That batch is no longer visible to the
// underlying reader, so it must be placed in front of the batches still in the
// stream; otherwise it would be silently missing from the resulting table.
//
// On failure, context.error is populated with error::IPC_TABLE_READ_FAILED.
void RecordBatchStreamReader::readTable(libmexclass::proxy::method::Context& context) {
  namespace mda = ::matlab::data;
  using TableProxy = arrow::matlab::tabular::proxy::Table;

  // Start with the read-ahead batch (if any) so row order is preserved.
  arrow::RecordBatchVector batches;
  if (nextRecordBatch) {
    batches.push_back(nextRecordBatch);
  }

  // Drain whatever is still left in the IPC stream.
  MATLAB_ASSIGN_OR_ERROR_WITH_CONTEXT(auto remaining_batches, reader->ToRecordBatches(),
                                      context, error::IPC_TABLE_READ_FAILED);
  batches.reserve(batches.size() + remaining_batches.size());
  for (auto& batch : remaining_batches) {
    batches.push_back(std::move(batch));
  }

  // An empty batch list yields an empty table that still carries the schema.
  MATLAB_ASSIGN_OR_ERROR_WITH_CONTEXT(
      auto table, arrow::Table::FromRecordBatches(reader->schema(), batches), context,
      error::IPC_TABLE_READ_FAILED);

  // The cached batch is now part of the table; drop it so that subsequent
  // hasNextRecordBatch / readRecordBatch calls observe the end of the stream.
  nextRecordBatch = nullptr;

  auto table_proxy = std::make_shared<TableProxy>(table);
  const auto table_proxy_id = libmexclass::proxy::ProxyManager::manageProxy(table_proxy);

  mda::ArrayFactory factory;
  const auto table_proxy_id_mda = factory.createScalar(table_proxy_id);
  context.outputs[0] = table_proxy_id_mda;
}

// Proxy method: returns the next record batch of the IPC stream as a
// RecordBatch proxy ID (scalar in context.outputs[0]).
//
// A batch previously cached by hasNextRecordBatch is returned first; otherwise
// a new batch is read from the stream. Sets context.error to
// error::IPC_RECORD_BATCH_READ_FAILED if the read fails, or to
// error::IPC_END_OF_STREAM if the stream has no more batches.
void RecordBatchStreamReader::readRecordBatch(
    libmexclass::proxy::method::Context& context) {
  namespace mda = ::matlab::data;
  using RecordBatchProxy = arrow::matlab::tabular::proxy::RecordBatch;
  using namespace libmexclass::error;

  if (nextRecordBatch == nullptr) {
    // Nothing cached: pull the next batch straight from the stream.
    MATLAB_ASSIGN_OR_ERROR_WITH_CONTEXT(nextRecordBatch, reader->Next(), context,
                                        error::IPC_RECORD_BATCH_READ_FAILED);

    // A successful read that yields a null batch marks the end of the stream.
    if (nextRecordBatch == nullptr) {
      context.error =
          Error{error::IPC_END_OF_STREAM,
                "Reached end of Arrow IPC Stream. No more record batches to read."};
      return;
    }
  }

  const auto batch_proxy_id = libmexclass::proxy::ProxyManager::manageProxy(
      std::make_shared<RecordBatchProxy>(nextRecordBatch));

  // The batch has been handed out; clear the cache so the next
  // hasNextRecordBatch call goes back to the stream.
  nextRecordBatch.reset();

  mda::ArrayFactory array_factory;
  context.outputs[0] = array_factory.createScalar(batch_proxy_id);
}

// Proxy method: reports (as a logical scalar in context.outputs[0]) whether
// another record batch is available.
//
// If no batch is cached, one is read ahead from the stream and stored in
// nextRecordBatch so that the following readRecordBatch call can return it.
// NOTE: a failed read is not surfaced as an error here; it is reported the
// same way as the end of the stream, i.e. as "no next record batch".
void RecordBatchStreamReader::hasNextRecordBatch(
    libmexclass::proxy::method::Context& context) {
  namespace mda = ::matlab::data;

  if (!nextRecordBatch) {
    // Attempt a read-ahead. On success the batch (possibly null, which signals
    // the end of the stream) becomes the cached batch; on failure the cache
    // simply stays empty.
    auto next_result = reader->Next();
    if (next_result.ok()) {
      nextRecordBatch = *next_result;
    }
  }

  // After the read-ahead attempt, a batch is available exactly when the cache
  // holds a non-null batch.
  const bool batch_available = (nextRecordBatch != nullptr);

  mda::ArrayFactory array_factory;
  context.outputs[0] = array_factory.createScalar(batch_available);
}

} // namespace arrow::matlab::io::ipc::proxy
Original file line number Diff line number Diff line change
@@ -0,0 +1,44 @@
// Licensed to the Apache Software Foundation (ASF) under one
// or more contributor license agreements. See the NOTICE file
// distributed with this work for additional information
// regarding copyright ownership. The ASF licenses this file
// to you under the Apache License, Version 2.0 (the
// "License"); you may not use this file except in compliance
// with the License. You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing,
// software distributed under the License is distributed on an
// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
// KIND, either express or implied. See the License for the
// specific language governing permissions and limitations
// under the License.

#pragma once

#include "arrow/ipc/reader.h"
#include "libmexclass/proxy/Proxy.h"

namespace arrow::matlab::io::ipc::proxy {

// Proxy class that exposes an arrow::ipc::RecordBatchStreamReader through the
// libmexclass proxy mechanism. It supports reading record batches one at a
// time (with a one-batch read-ahead used by hasNextRecordBatch) or reading the
// stream into a Table.
class RecordBatchStreamReader : public libmexclass::proxy::Proxy {
public:
// Wraps an already opened stream reader and registers the proxy methods below.
RecordBatchStreamReader(std::shared_ptr<arrow::ipc::RecordBatchStreamReader> reader);

~RecordBatchStreamReader() = default;

// Factory: builds a proxy from constructor arguments whose first element is a
// struct containing a "Filename" field.
static libmexclass::proxy::MakeResult make(
const libmexclass::proxy::FunctionArguments& constructor_arguments);

protected:
// Underlying Arrow IPC stream reader.
std::shared_ptr<arrow::ipc::RecordBatchStreamReader> reader;
// Record batch read ahead by hasNextRecordBatch and not yet consumed;
// null when nothing is cached.
std::shared_ptr<arrow::RecordBatch> nextRecordBatch;

// Outputs the proxy ID of a Schema proxy for the stream's schema.
void getSchema(libmexclass::proxy::method::Context& context);
// Outputs the proxy ID of the next RecordBatch, or sets an error when the read
// fails or the end of the stream has been reached.
void readRecordBatch(libmexclass::proxy::method::Context& context);
// Outputs a logical scalar indicating whether another record batch is available.
void hasNextRecordBatch(libmexclass::proxy::method::Context& context);
// Outputs the proxy ID of a Table built from the stream's record batches.
void readTable(libmexclass::proxy::method::Context& context);
};

} // namespace arrow::matlab::io::ipc::proxy
2 changes: 2 additions & 0 deletions matlab/src/cpp/arrow/matlab/proxy/factory.cc
Original file line number Diff line number Diff line change
Expand Up @@ -36,6 +36,7 @@
#include "arrow/matlab/io/feather/proxy/writer.h"
#include "arrow/matlab/io/ipc/proxy/record_batch_file_reader.h"
#include "arrow/matlab/io/ipc/proxy/record_batch_file_writer.h"
#include "arrow/matlab/io/ipc/proxy/record_batch_stream_reader.h"
#include "arrow/matlab/io/ipc/proxy/record_batch_stream_writer.h"
#include "arrow/matlab/tabular/proxy/record_batch.h"
#include "arrow/matlab/tabular/proxy/schema.h"
Expand Down Expand Up @@ -113,6 +114,7 @@ libmexclass::proxy::MakeResult Factory::make_proxy(
REGISTER_PROXY(arrow.io.ipc.proxy.RecordBatchFileReader , arrow::matlab::io::ipc::proxy::RecordBatchFileReader);
REGISTER_PROXY(arrow.io.ipc.proxy.RecordBatchFileWriter , arrow::matlab::io::ipc::proxy::RecordBatchFileWriter);
REGISTER_PROXY(arrow.io.ipc.proxy.RecordBatchStreamWriter , arrow::matlab::io::ipc::proxy::RecordBatchStreamWriter);
REGISTER_PROXY(arrow.io.ipc.proxy.RecordBatchStreamReader , arrow::matlab::io::ipc::proxy::RecordBatchStreamReader);

// clang-format on

Expand Down
83 changes: 83 additions & 0 deletions matlab/src/matlab/+arrow/+io/+ipc/RecordBatchStreamReader.m
Original file line number Diff line number Diff line change
@@ -0,0 +1,83 @@
%RECORDBATCHSTREAMREADER Class for reading Arrow record batches from the
% Arrow IPC Stream format.

% Licensed to the Apache Software Foundation (ASF) under one or more
% contributor license agreements. See the NOTICE file distributed with
% this work for additional information regarding copyright ownership.
% The ASF licenses this file to you under the Apache License, Version
% 2.0 (the "License"); you may not use this file except in compliance
% with the License. You may obtain a copy of the License at
%
% http://www.apache.org/licenses/LICENSE-2.0
%
% Unless required by applicable law or agreed to in writing, software
% distributed under the License is distributed on an "AS IS" BASIS,
% WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or
% implied. See the License for the specific language governing
% permissions and limitations under the License.

classdef RecordBatchStreamReader < matlab.mixin.Scalar

    properties(SetAccess=private, GetAccess=public, Hidden)
        % Handle to the underlying C++ proxy object.
        Proxy
    end

    properties (Dependent, SetAccess=private, GetAccess=public)
        % arrow.tabular.Schema of the record batches in the stream.
        Schema
    end

    methods
        function obj = RecordBatchStreamReader(filename)
            % Create a reader for the Arrow IPC Stream file named FILENAME.
            arguments
                filename(1, 1) string {mustBeNonzeroLengthText}
            end
            options = struct(Filename=filename);
            obj.Proxy = arrow.internal.proxy.create( ...
                "arrow.io.ipc.proxy.RecordBatchStreamReader", options);
        end

        function schema = get.Schema(obj)
            % Wrap the Schema proxy ID returned by the C++ layer.
            schemaID = obj.Proxy.getSchema();
            schemaProxy = libmexclass.proxy.Proxy(ID=schemaID, ...
                Name="arrow.tabular.proxy.Schema");
            schema = arrow.tabular.Schema(schemaProxy);
        end

        function tf = hasnext(obj)
            % True if another record batch can be read from the stream.
            tf = obj.Proxy.hasNextRecordBatch();
        end

        function tf = done(obj)
            % True if there are no more record batches to read; the logical
            % complement of hasnext.
            tf = not(obj.Proxy.hasNextRecordBatch());
        end

        function arrowRecordBatch = read(obj)
            % Read the next record batch from the stream.
            %
            % NOTE: This is a short-named convenience alias for
            % readRecordBatch with a deliberately duplicated body. Because
            % it may be called in a tight loop, it calls into the C++ layer
            % directly instead of going through obj.readRecordBatch,
            % accepting some code duplication in exchange for performance.
            batchID = obj.Proxy.readRecordBatch();
            batchProxy = libmexclass.proxy.Proxy(ID=batchID, ...
                Name="arrow.tabular.proxy.RecordBatch");
            arrowRecordBatch = arrow.tabular.RecordBatch(batchProxy);
        end

        function arrowRecordBatch = readRecordBatch(obj)
            % Read the next record batch from the stream.
            batchID = obj.Proxy.readRecordBatch();
            batchProxy = libmexclass.proxy.Proxy(ID=batchID, ...
                Name="arrow.tabular.proxy.RecordBatch");
            arrowRecordBatch = arrow.tabular.RecordBatch(batchProxy);
        end

        function arrowTable = readTable(obj)
            % Read record batches from the stream into an arrow.tabular.Table.
            tableID = obj.Proxy.readTable();
            tableProxy = libmexclass.proxy.Proxy(ID=tableID, ...
                Name="arrow.tabular.proxy.Table");
            arrowTable = arrow.tabular.Table(tableProxy);
        end

    end

end
Loading

0 comments on commit 8b75373

Please sign in to comment.