apacheGH-38418: [MATLAB] Add method for extracting one row of an `arrow.tabular.Table` as a string (apache#38463)

### Rationale for this change

We would like to modify the display of the `arrow.tabular.Table` and `arrow.tabular.RecordBatch` classes to be more "MATLAB-like". To do this, we need to add a method to their respective C++ Proxy classes that returns a single row of the Table/RecordBatch as a MATLAB `string` array.

### What changes are included in this PR?

Added a new function template:
```cpp
template <typename TabularType>
arrow::Result<std::string> arrow::matlab::tabular::get_row_as_string(const std::shared_ptr<TabularType>& tabular_object, const int64_t matlab_row_index)
```
This function template returns a string representation of the specified row in `tabular_object`. The row index is 1-based, following MATLAB conventions.

Added a new proxy method called `getRowAsString` to both the `Table` and `RecordBatch` C++ proxy classes. These methods invoke `get_row_as_string` to return a string representation of one row in the `Table`/`RecordBatch`. Neither MATLAB class, `arrow.tabular.Table` nor `arrow.tabular.RecordBatch`, exposes these methods directly because they will only be used internally for display.

Below is an example of the output of `getRowAsString()`:

```matlab
>> matlabTable = table([1; 2; 3], ["ABC"; "DE"; "FGH"], datetime(2023, 10, 25) + days(0:2)');
>> arrowTable = arrow.table(matlabTable);
>> rowOneAsString = arrowTable.Proxy.getRowAsString(struct(Index=int64(1)))

rowOneAsString = 

    "1 | "ABC" | 2023-10-25 00:00:00.000000"
```
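
The row layout in the example above (cells separated by ` | `, with no trailing delimiter) can be sketched in standard C++ as follows. This is a simplified illustration, not the code from this commit: `join_row_cells` is a hypothetical helper, and it assumes each cell has already been formatted as text, whereas the actual change formats cells with Arrow's pretty-printer.

```cpp
#include <cstddef>
#include <string>
#include <vector>

// Hypothetical helper (not part of Arrow): joins pre-formatted cell
// strings with " | ". The delimiter is appended only when another cell
// follows, so the result never ends with a dangling separator.
std::string join_row_cells(const std::vector<std::string>& cells) {
    std::string row;
    for (std::size_t i = 0; i < cells.size(); ++i) {
        row += cells[i];
        if (i + 1 < cells.size()) {
            row += " | ";
        }
    }
    return row;
}
```

For instance, `join_row_cells({"1", "\"ABC\"", "2023-10-25 00:00:00.000000"})` yields `1 | "ABC" | 2023-10-25 00:00:00.000000`, and a single-cell row is returned without any delimiter.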

### Are these changes tested?

Yes, added a new test class called `tTabularInternal.m`. Because `getRowAsString()` is not a method on the MATLAB classes `arrow.tabular.Table` and `arrow.tabular.RecordBatch`, this test class calls `getRowAsString()` on their `Proxy` properties, which are public but hidden.

### Are there any user-facing changes?

No.

* Closes: apache#38418

Lead-authored-by: Sarah Gilmore <[email protected]>
Co-authored-by: sgilmore10 <[email protected]>
Co-authored-by: Kevin Gurney <[email protected]>
Signed-off-by: Kevin Gurney <[email protected]>
2 people authored and loicalleyne committed Nov 13, 2023
1 parent 468c180 commit 476b1df
Showing 8 changed files with 230 additions and 2 deletions.
1 change: 1 addition & 0 deletions matlab/src/cpp/arrow/matlab/error/error.h
@@ -202,4 +202,5 @@ namespace arrow::matlab::error {
static const char* INDEX_OUT_OF_RANGE = "arrow:index:OutOfRange";
static const char* BUFFER_VIEW_OR_COPY_FAILED = "arrow:buffer:ViewOrCopyFailed";
static const char* ARRAY_PRETTY_PRINT_FAILED = "arrow:array:PrettyPrintFailed";
static const char* TABULAR_GET_ROW_AS_STRING_FAILED = "arrow:tabular:GetRowAsStringFailed";
}
77 changes: 77 additions & 0 deletions matlab/src/cpp/arrow/matlab/tabular/get_row_as_string.h
@@ -0,0 +1,77 @@
// Licensed to the Apache Software Foundation (ASF) under one
// or more contributor license agreements. See the NOTICE file
// distributed with this work for additional information
// regarding copyright ownership. The ASF licenses this file
// to you under the Apache License, Version 2.0 (the
// "License"); you may not use this file except in compliance
// with the License. You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing,
// software distributed under the License is distributed on an
// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
// KIND, either express or implied. See the License for the
// specific language governing permissions and limitations
// under the License.

#pragma once

#include "arrow/pretty_print.h"

#include <sstream>

namespace arrow::matlab::tabular {

namespace {
arrow::PrettyPrintOptions make_pretty_print_options() {
auto opts = arrow::PrettyPrintOptions::Defaults();
opts.skip_new_lines = true;
opts.array_delimiters.open = "";
opts.array_delimiters.close = "";
opts.chunked_array_delimiters.open = "";
opts.chunked_array_delimiters.close = "";
return opts;
}
}

template <typename TabularType>
arrow::Result<std::string> get_row_as_string(const std::shared_ptr<TabularType>& tabular_object, const int64_t matlab_row_index) {
std::stringstream ss;
const int64_t row_index = matlab_row_index - 1;
if (row_index >= tabular_object->num_rows() || row_index < 0) {
ss << "Invalid Row Index: " << matlab_row_index;
return arrow::Status::Invalid(ss.str());
}

const auto opts = make_pretty_print_options();
const auto num_columns = tabular_object->num_columns();
const auto& columns = tabular_object->columns();

for (int32_t i = 0; i < num_columns; ++i) {
const auto& column = columns[i];
const auto type_id = column->type()->id();
if (arrow::is_primitive(type_id) || arrow::is_string(type_id)) {
auto slice = column->Slice(row_index, 1);
ARROW_RETURN_NOT_OK(arrow::PrettyPrint(*slice, opts, &ss));
} else if (type_id == arrow::Type::type::STRUCT) {
// Use <Struct> as a placeholder since we don't have a good
// way to display StructArray elements horizontally on screen.
ss << "<Struct>";
} else if (type_id == arrow::Type::type::LIST) {
// Use <List> as a placeholder since we don't have a good
// way to display ListArray elements horizontally on screen.
ss << "<List>";
} else {
return arrow::Status::NotImplemented("Datatype " + column->type()->ToString() + " is not currently supported for display.");
}

if (i + 1 < num_columns) {
// Only add the delimiter if there is at least
// one more element to print.
ss << " | ";
}
}
return ss.str();
}
}
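
The bounds check at the top of `get_row_as_string` converts MATLAB's 1-based row index to the 0-based index Arrow expects. A minimal standalone sketch of that conversion, assuming C++17 for `std::optional`, might look like this. `to_zero_based_row_index` is a hypothetical name, and it signals failure with an empty optional rather than an `arrow::Status`.

```cpp
#include <cstdint>
#include <optional>

// Hypothetical helper (not part of Arrow): maps a 1-based MATLAB row
// index onto a 0-based index. Returns std::nullopt when the index does
// not refer to an existing row. Comparing before subtracting avoids
// signed overflow for extreme input values.
std::optional<std::int64_t> to_zero_based_row_index(std::int64_t matlab_row_index,
                                                    std::int64_t num_rows) {
    if (matlab_row_index < 1 || matlab_row_index > num_rows) {
        return std::nullopt;
    }
    return matlab_row_index - 1;
}
```

With three rows, indices 1 through 3 map to 0 through 2, while 0 and 4 are rejected, mirroring the cases exercised in `tTabularInternal.m`.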
18 changes: 18 additions & 0 deletions matlab/src/cpp/arrow/matlab/tabular/proxy/record_batch.cc
@@ -23,6 +23,7 @@
#include "arrow/matlab/error/error.h"
#include "arrow/matlab/tabular/proxy/record_batch.h"
#include "arrow/matlab/tabular/proxy/schema.h"
#include "arrow/matlab/tabular/get_row_as_string.h"
#include "arrow/type.h"
#include "arrow/util/utf8.h"

@@ -58,6 +59,7 @@ namespace arrow::matlab::tabular::proxy {
REGISTER_METHOD(RecordBatch, getColumnByIndex);
REGISTER_METHOD(RecordBatch, getColumnByName);
REGISTER_METHOD(RecordBatch, getSchema);
REGISTER_METHOD(RecordBatch, getRowAsString);
}

std::shared_ptr<arrow::RecordBatch> RecordBatch::unwrap() {
@@ -218,4 +220,20 @@ namespace arrow::matlab::tabular::proxy {
context.outputs[0] = schema_proxy_id_mda;
}

void RecordBatch::getRowAsString(libmexclass::proxy::method::Context& context) {
namespace mda = ::matlab::data;
using namespace libmexclass::proxy;
mda::ArrayFactory factory;

mda::StructArray args = context.inputs[0];
const mda::TypedArray<int64_t> index_mda = args[0]["Index"];
const auto matlab_row_index = int64_t(index_mda[0]);

MATLAB_ASSIGN_OR_ERROR_WITH_CONTEXT(auto row_str_utf8, arrow::matlab::tabular::get_row_as_string(record_batch, matlab_row_index),
context, error::TABULAR_GET_ROW_AS_STRING_FAILED);
MATLAB_ASSIGN_OR_ERROR_WITH_CONTEXT(auto row_str_utf16, arrow::util::UTF8StringToUTF16(row_str_utf8),
context, error::UNICODE_CONVERSION_ERROR_ID);
context.outputs[0] = factory.createScalar(row_str_utf16);
}

}
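
The proxy method above converts the UTF-8 row string to UTF-16 before returning it to MATLAB. The sketch below illustrates what such a conversion involves using only the standard library. It is not Arrow's `UTF8StringToUTF16` implementation: the name `utf8_to_utf16_sketch` is hypothetical, errors are reported with exceptions rather than `arrow::Result`, and the decoder assumes mostly well-formed input (it does not reject overlong encodings or encoded surrogates).

```cpp
#include <cstddef>
#include <cstdint>
#include <stdexcept>
#include <string>

// Hypothetical, simplified UTF-8 to UTF-16 decoder for illustration only.
std::u16string utf8_to_utf16_sketch(const std::string& utf8) {
    std::u16string utf16;
    std::size_t i = 0;
    while (i < utf8.size()) {
        const auto lead = static_cast<unsigned char>(utf8[i]);
        std::uint32_t code_point = 0;
        std::size_t length = 0;
        // The high bits of the lead byte give the sequence length.
        if (lead < 0x80) { code_point = lead; length = 1; }
        else if ((lead >> 5) == 0x06) { code_point = lead & 0x1F; length = 2; }
        else if ((lead >> 4) == 0x0E) { code_point = lead & 0x0F; length = 3; }
        else if ((lead >> 3) == 0x1E) { code_point = lead & 0x07; length = 4; }
        else { throw std::invalid_argument("invalid UTF-8 lead byte"); }
        if (i + length > utf8.size()) {
            throw std::invalid_argument("truncated UTF-8 sequence");
        }
        // Each continuation byte contributes its low six bits.
        for (std::size_t k = 1; k < length; ++k) {
            const auto next = static_cast<unsigned char>(utf8[i + k]);
            if ((next & 0xC0) != 0x80) {
                throw std::invalid_argument("invalid UTF-8 continuation byte");
            }
            code_point = (code_point << 6) | (next & 0x3F);
        }
        i += length;
        if (code_point > 0x10FFFF) {
            throw std::invalid_argument("code point out of range");
        }
        if (code_point < 0x10000) {
            utf16.push_back(static_cast<char16_t>(code_point));
        } else {
            // Code points beyond the BMP become a surrogate pair.
            code_point -= 0x10000;
            utf16.push_back(static_cast<char16_t>(0xD800 + (code_point >> 10)));
            utf16.push_back(static_cast<char16_t>(0xDC00 + (code_point & 0x3FF)));
        }
    }
    return utf16;
}
```

ASCII-only rows such as `1 | "A"` map one byte to one UTF-16 code unit, so the conversion only changes the length of the string when a column contains non-ASCII text.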
1 change: 1 addition & 0 deletions matlab/src/cpp/arrow/matlab/tabular/proxy/record_batch.h
@@ -41,6 +41,7 @@ namespace arrow::matlab::tabular::proxy {
void getColumnByIndex(libmexclass::proxy::method::Context& context);
void getColumnByName(libmexclass::proxy::method::Context& context);
void getSchema(libmexclass::proxy::method::Context& context);
void getRowAsString(libmexclass::proxy::method::Context& context);

std::shared_ptr<arrow::RecordBatch> record_batch;
};
19 changes: 19 additions & 0 deletions matlab/src/cpp/arrow/matlab/tabular/proxy/table.cc
@@ -24,6 +24,8 @@
#include "arrow/matlab/error/error.h"
#include "arrow/matlab/tabular/proxy/table.h"
#include "arrow/matlab/tabular/proxy/schema.h"
#include "arrow/matlab/tabular/get_row_as_string.h"

#include "arrow/type.h"
#include "arrow/util/utf8.h"

@@ -57,6 +59,7 @@ namespace arrow::matlab::tabular::proxy {
REGISTER_METHOD(Table, getSchema);
REGISTER_METHOD(Table, getColumnByIndex);
REGISTER_METHOD(Table, getColumnByName);
REGISTER_METHOD(Table, getRowAsString);
}

std::shared_ptr<arrow::Table> Table::unwrap() {
@@ -212,4 +215,20 @@ namespace arrow::matlab::tabular::proxy {
context.outputs[0] = chunked_array_proxy_id_mda;
}

void Table::getRowAsString(libmexclass::proxy::method::Context& context) {
namespace mda = ::matlab::data;
using namespace libmexclass::proxy;
mda::ArrayFactory factory;

mda::StructArray args = context.inputs[0];
const mda::TypedArray<int64_t> index_mda = args[0]["Index"];
const auto matlab_row_index = int64_t(index_mda[0]);

MATLAB_ASSIGN_OR_ERROR_WITH_CONTEXT(auto row_str_utf8, arrow::matlab::tabular::get_row_as_string(table, matlab_row_index),
context, error::TABULAR_GET_ROW_AS_STRING_FAILED);
MATLAB_ASSIGN_OR_ERROR_WITH_CONTEXT(auto row_str_utf16, arrow::util::UTF8StringToUTF16(row_str_utf8),
context, error::UNICODE_CONVERSION_ERROR_ID);
context.outputs[0] = factory.createScalar(row_str_utf16);
}

}
1 change: 1 addition & 0 deletions matlab/src/cpp/arrow/matlab/tabular/proxy/table.h
@@ -41,6 +41,7 @@ namespace arrow::matlab::tabular::proxy {
void getSchema(libmexclass::proxy::method::Context& context);
void getColumnByIndex(libmexclass::proxy::method::Context& context);
void getColumnByName(libmexclass::proxy::method::Context& context);
void getRowAsString(libmexclass::proxy::method::Context& context);

std::shared_ptr<arrow::Table> table;
};
@@ -24,8 +24,8 @@
end

% Seed the random number generator to ensure
- % reproducible results in tests.
- rng(1);
+ % reproducible results in tests across MATLAB sessions.
+ rng(1, "twister");

import arrow.type.ID
import arrow.array.*
@@ -101,6 +101,7 @@

% Return the class names as a string array
classes = string({metaClass.Name});
classes = sort(classes);
end

function dict = getNumericArrayToMatlabDictionary()
110 changes: 110 additions & 0 deletions matlab/test/arrow/tabular/tTabularInternal.m
@@ -0,0 +1,110 @@
%TTABULARINTERNAL Unit tests for internal functionality of tabular types.

% Licensed to the Apache Software Foundation (ASF) under one or more
% contributor license agreements. See the NOTICE file distributed with
% this work for additional information regarding copyright ownership.
% The ASF licenses this file to you under the Apache License, Version
% 2.0 (the "License"); you may not use this file except in compliance
% with the License. You may obtain a copy of the License at
%
% http://www.apache.org/licenses/LICENSE-2.0
%
% Unless required by applicable law or agreed to in writing, software
% distributed under the License is distributed on an "AS IS" BASIS,
% WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or
% implied. See the License for the specific language governing
% permissions and limitations under the License.

classdef tTabularInternal < matlab.unittest.TestCase

properties(TestParameter)
TabularObjectWithAllTypes

TabularObjectWithOneColumn

TabularObjectWithThreeRows
end

methods (TestParameterDefinition, Static)
function TabularObjectWithAllTypes = initializeTabularObjectWithAllTypes()
arrays = arrow.internal.test.tabular.createAllSupportedArrayTypes(NumRows=1);
arrowTable = arrow.tabular.Table.fromArrays(arrays{:});
arrowRecordBatch = arrow.tabular.RecordBatch.fromArrays(arrays{:});
TabularObjectWithAllTypes = struct(Table=arrowTable, ...
RecordBatch=arrowRecordBatch);
end

function TabularObjectWithOneColumn = initializeTabularObjectWithOneColumn()
t = table((1:3)');
arrowTable = arrow.table(t);
arrowRecordBatch = arrow.recordBatch(t);
TabularObjectWithOneColumn = struct(Table=arrowTable, ...
RecordBatch=arrowRecordBatch);
end

function TabularObjectWithThreeRows = initializeTabularObjectWithThreeRows()
t = table((1:3)', ["A"; "B"; "C"]);
arrowTable = arrow.table(t);
arrowRecordBatch = arrow.recordBatch(t);
TabularObjectWithThreeRows = struct(Table=arrowTable, ...
RecordBatch=arrowRecordBatch);
end
end

methods (Test)
function RowWithAllTypes(testCase, TabularObjectWithAllTypes)
% Verify getRowAsString successfully returns the expected string
% when called on a Table/RecordBatch that contains all
% supported array types.
proxy = TabularObjectWithAllTypes.Proxy;
columnStrs = ["false", "2024-02-23", "2023-08-24", "78", "38", ...
"24", "48", "89", "102", "<List>", """107""", "<Struct>", ...
"00:03:44", "00:00:07.000000", "2024-02-10 00:00:00.000000", ...
"107", "143", "36", "51"];
expectedString = strjoin(columnStrs, " | ");
actualString = proxy.getRowAsString(struct(Index=int64(1)));
testCase.verifyEqual(actualString, expectedString);
end

function RowWithOneColumn(testCase, TabularObjectWithOneColumn)
% Verify getRowAsString successfully returns the expected string
% when called on a Table/RecordBatch with one column.
proxy = TabularObjectWithOneColumn.Proxy;
expectedString = "1";
actualString = proxy.getRowAsString(struct(Index=int64(1)));
testCase.verifyEqual(actualString, expectedString);
end

function RowIndex(testCase, TabularObjectWithThreeRows)
% Verify getRowAsString returns the expected string for
% the provided row index.
proxy = TabularObjectWithThreeRows.Proxy;

actualString = proxy.getRowAsString(struct(Index=int64(1)));
expectedString = "1 | ""A""";
testCase.verifyEqual(actualString, expectedString);

actualString = proxy.getRowAsString(struct(Index=int64(2)));
expectedString = "2 | ""B""";
testCase.verifyEqual(actualString, expectedString);

actualString = proxy.getRowAsString(struct(Index=int64(3)));
expectedString = "3 | ""C""";
testCase.verifyEqual(actualString, expectedString);
end

function GetRowAsStringFailed(testCase, TabularObjectWithThreeRows)
% Verify getRowAsString throws an error with the ID
% arrow:tabular:GetRowAsStringFailed if provided invalid index
% values.
proxy = TabularObjectWithThreeRows.Proxy;
fcn = @() proxy.getRowAsString(struct(Index=int64(0)));
testCase.verifyError(fcn, "arrow:tabular:GetRowAsStringFailed");

fcn = @() proxy.getRowAsString(struct(Index=int64(4)));
testCase.verifyError(fcn, "arrow:tabular:GetRowAsStringFailed");
end

end

end
