Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account-related emails.

Already on GitHub? Sign in to your account

Fix projection in IPC reader #1736

Merged
merged 5 commits into from
May 26, 2022
Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
325 changes: 309 additions & 16 deletions arrow/src/ipc/reader.rs
Original file line number Diff line number Diff line change
Expand Up @@ -275,6 +275,120 @@ fn create_array(
Ok((array, node_index, buffer_index))
}

/// Skip fields based on data types to advance `node_index` and `buffer_index`.
/// This function should be called when doing projection in fn `read_record_batch`.
/// The advancement logic references fn `create_array`.
/// Advance `node_index` and `buffer_index` past one (possibly nested) field
/// without materializing an array.
///
/// Used by `read_record_batch` when a projection excludes a field: the IPC
/// message still carries the field's nodes and buffers, so the cursors must
/// be moved exactly as `create_array` would have moved them.
fn skip_field(
    nodes: &[ipc::FieldNode],
    field: &Field,
    data: &[u8],
    buffers: &[ipc::Buffer],
    dictionaries_by_id: &HashMap<i64, ArrayRef>,
    mut node_index: usize,
    mut buffer_index: usize,
) -> Result<(usize, usize)> {
    use DataType::*;
    match field.data_type() {
        // Variable-width types: one node; validity + offsets + values buffers.
        Utf8 | Binary | LargeBinary | LargeUtf8 => {
            node_index += 1;
            buffer_index += 3;
        }
        // One node; validity + values buffers.
        FixedSizeBinary(_) => {
            node_index += 1;
            buffer_index += 2;
        }
        // One node with validity + offsets buffers, followed by the child field.
        List(ref child) | LargeList(ref child) | Map(ref child, _) => {
            node_index += 1;
            buffer_index += 2;
            let (skipped_nodes, skipped_buffers) = skip_field(
                nodes,
                child,
                data,
                buffers,
                dictionaries_by_id,
                node_index,
                buffer_index,
            )?;
            node_index = skipped_nodes;
            buffer_index = skipped_buffers;
        }
        // One node with a validity buffer only (no offsets), then the child field.
        FixedSizeList(ref child, _) => {
            node_index += 1;
            buffer_index += 1;
            let (skipped_nodes, skipped_buffers) = skip_field(
                nodes,
                child,
                data,
                buffers,
                dictionaries_by_id,
                node_index,
                buffer_index,
            )?;
            node_index = skipped_nodes;
            buffer_index = skipped_buffers;
        }
        // One node with a validity buffer, then every child field in order.
        Struct(struct_fields) => {
            node_index += 1;
            buffer_index += 1;
            for struct_field in struct_fields {
                let (skipped_nodes, skipped_buffers) = skip_field(
                    nodes,
                    struct_field,
                    data,
                    buffers,
                    dictionaries_by_id,
                    node_index,
                    buffer_index,
                )?;
                node_index = skipped_nodes;
                buffer_index = skipped_buffers;
            }
        }
        // One node; validity + keys buffers (dictionary values live elsewhere).
        Dictionary(_, _) => {
            node_index += 1;
            buffer_index += 2;
        }
        Union(fields, _field_type_ids, mode) => {
            node_index += 1;
            // Type-ids buffer always present; dense unions add an offsets buffer.
            buffer_index += match mode {
                UnionMode::Dense => 2,
                UnionMode::Sparse => 1,
            };
            for child in fields {
                let (skipped_nodes, skipped_buffers) = skip_field(
                    nodes,
                    child,
                    data,
                    buffers,
                    dictionaries_by_id,
                    node_index,
                    buffer_index,
                )?;
                node_index = skipped_nodes;
                buffer_index = skipped_buffers;
            }
        }
        // Null arrays carry a node but no buffers.
        Null => {
            node_index += 1;
        }
        // All remaining (primitive) types: one node; validity + values buffers.
        _ => {
            node_index += 1;
            buffer_index += 2;
        }
    };
    Ok((node_index, buffer_index))
}

/// Reads the correct number of buffers based on data type and null_count, and creates a
/// primitive array ref
fn create_primitive_array(
Expand Down Expand Up @@ -493,21 +607,37 @@ pub fn read_record_batch(
let mut arrays = vec![];

if let Some(projection) = projection {
let fields = schema.fields();
for &index in projection {
let field = &fields[index];
let triple = create_array(
field_nodes,
field,
buf,
buffers,
dictionaries_by_id,
node_index,
buffer_index,
)?;
node_index = triple.1;
buffer_index = triple.2;
arrays.push(triple.0);
// project fields
for (idx, field) in schema.fields().iter().enumerate() {
// Create array for projected field
if projection.contains(&idx) {
let triple = create_array(
field_nodes,
field,
buf,
buffers,
dictionaries_by_id,
node_index,
buffer_index,
)?;
node_index = triple.1;
buffer_index = triple.2;
arrays.push(triple.0);
} else {
// Skip field.
// This must be called to advance `node_index` and `buffer_index`.
let tuple = skip_field(
field_nodes,
field,
buf,
buffers,
dictionaries_by_id,
node_index,
buffer_index,
)?;
node_index = tuple.0;
buffer_index = tuple.1;
}
}

RecordBatch::try_new(Arc::new(schema.project(projection)?), arrays)
Expand Down Expand Up @@ -1032,7 +1162,7 @@ mod tests {

use flate2::read::GzDecoder;

use crate::datatypes::{ArrowNativeType, Int8Type};
use crate::datatypes::{ArrowNativeType, Float64Type, Int32Type, Int8Type};
use crate::{datatypes, util::integration_util::*};

#[test]
Expand Down Expand Up @@ -1260,6 +1390,169 @@ mod tests {
});
}

/// Build a schema covering every skip path in the IPC reader's projection
/// logic: primitives, variable-width, list/fixed-size-list, union, struct,
/// dictionary, and null fields.
fn create_test_projection_schema() -> Schema {
    // Nested and parameterized field types used below.
    let list_type =
        DataType::List(Box::new(Field::new("item", DataType::Int32, true)));

    let fixed_size_list_type = DataType::FixedSizeList(
        Box::new(Field::new("item", DataType::Int32, false)),
        3,
    );

    let dict_type = DataType::Dictionary(
        Box::new(DataType::Int8),
        Box::new(DataType::Utf8),
    );

    let union_fields = vec![
        Field::new("a", DataType::Int32, false),
        Field::new("b", DataType::Float64, false),
    ];
    let union_type = DataType::Union(union_fields, vec![0, 1], UnionMode::Dense);

    let struct_type = DataType::Struct(vec![
        Field::new("id", DataType::Int32, false),
        Field::new(
            "list",
            DataType::List(Box::new(Field::new("item", DataType::Int8, true))),
            false,
        ),
    ]);

    // One field per type family so each projection test skips a mix of kinds.
    Schema::new(vec![
        Field::new("f0", DataType::UInt32, false),
        Field::new("f1", DataType::Utf8, false),
        Field::new("f2", DataType::Boolean, false),
        Field::new("f3", union_type, true),
        Field::new("f4", DataType::Null, true),
        Field::new("f5", DataType::Float64, true),
        Field::new("f6", list_type, false),
        Field::new("f7", DataType::FixedSizeBinary(3), true),
        Field::new("f8", fixed_size_list_type, false),
        Field::new("f9", struct_type, false),
        Field::new("f10", DataType::Boolean, false),
        Field::new("f11", dict_type, false),
        Field::new("f12", DataType::Utf8, false),
    ])
}

/// Build a three-row `RecordBatch` matching `create_test_projection_schema`,
/// with distinct values per column so a mis-aligned projection is detectable.
fn create_test_projection_batch_data(schema: &Schema) -> RecordBatch {
    // set test data for each column
    let array0 = UInt32Array::from(vec![1, 2, 3]);
    let array1 = StringArray::from(vec!["foo", "bar", "baz"]);
    let array2 = BooleanArray::from(vec![true, false, true]);

    // Dense union: append order fixes the type-ids and offsets buffers,
    // so it must not be reordered.
    let mut union_builder = UnionBuilder::new_dense(3);
    union_builder.append::<Int32Type>("a", 1).unwrap();
    union_builder.append::<Float64Type>("b", 10.1).unwrap();
    union_builder.append_null::<Float64Type>("b").unwrap();
    let array3 = union_builder.build().unwrap();

    let array4 = NullArray::new(3);
    let array5 = Float64Array::from(vec![Some(1.1), None, Some(3.3)]);
    // Variable-length lists (last row is shorter on purpose).
    let array6_values = vec![
        Some(vec![Some(10), Some(10), Some(10)]),
        Some(vec![Some(20), Some(20), Some(20)]),
        Some(vec![Some(30), Some(30)]),
    ];
    let array6 = ListArray::from_iter_primitive::<Int32Type, _, _>(array6_values);
    let array7_values = vec![vec![11, 12, 13], vec![22, 23, 24], vec![33, 34, 35]];
    let array7 =
        FixedSizeBinaryArray::try_from_iter(array7_values.into_iter()).unwrap();

    // FixedSizeList(3): 9 child values back 3 list rows.
    let array8_values = ArrayData::builder(DataType::Int32)
        .len(9)
        .add_buffer(Buffer::from_slice_ref(&[
            40, 41, 42, 43, 44, 45, 46, 47, 48,
        ]))
        .build()
        .unwrap();
    let array8_data = ArrayData::builder(schema.field(8).data_type().clone())
        .len(3)
        .add_child_data(array8_values)
        .build()
        .unwrap();
    let array8 = FixedSizeListArray::from(array8_data);

    // Struct column: child-data order must match the struct's field order
    // ("id" first, then "list").
    let array9_id: ArrayRef = Arc::new(Int32Array::from(vec![1001, 1002, 1003]));
    let array9_list: ArrayRef =
        Arc::new(ListArray::from_iter_primitive::<Int8Type, _, _>(vec![
            Some(vec![Some(-10)]),
            Some(vec![Some(-20), Some(-20), Some(-20)]),
            Some(vec![Some(-30)]),
        ]));
    let array9 = ArrayDataBuilder::new(schema.field(9).data_type().clone())
        .add_child_data(array9_id.data().clone())
        .add_child_data(array9_list.data().clone())
        .len(3)
        .build()
        .unwrap();
    let array9: ArrayRef = Arc::new(StructArray::from(array9));

    let array10 = BooleanArray::from(vec![false, false, true]);

    // Dictionary column: Int8 keys indexing into the string values.
    let array11_values = StringArray::from(vec!["x", "yy", "zzz"]);
    let array11_keys = Int8Array::from_iter_values([1, 1, 2]);
    let array11 =
        DictionaryArray::<Int8Type>::try_new(&array11_keys, &array11_values).unwrap();

    let array12 = StringArray::from(vec!["a", "bb", "ccc"]);

    // create record batch
    RecordBatch::try_new(
        Arc::new(schema.clone()),
        vec![
            Arc::new(array0),
            Arc::new(array1),
            Arc::new(array2),
            Arc::new(array3),
            Arc::new(array4),
            Arc::new(array5),
            Arc::new(array6),
            Arc::new(array7),
            Arc::new(array8),
            Arc::new(array9),
            Arc::new(array10),
            Arc::new(array11),
            Arc::new(array12),
        ],
    )
    .unwrap()
}

#[test]
fn test_projection_array_values() {
    // define schema
    let schema = create_test_projection_schema();

    // create record batch with test data
    let batch = create_test_projection_batch_data(&schema);

    // write record batch in IPC format
    let mut buf = Vec::new();
    {
        let mut writer = ipc::writer::FileWriter::try_new(&mut buf, &schema).unwrap();
        writer.write(&batch).unwrap();
        writer.finish().unwrap();
    }

    // Read the batch back projecting each column in turn. Iterating over
    // every field (the previous `0..12` bound left the final field f12
    // untested, so skipping a trailing dictionary column was never exercised)
    // verifies that `skip_field` advances the cursors correctly for all
    // preceding and following field kinds.
    for index in 0..schema.fields().len() {
        let projection = vec![index];
        let reader =
            FileReader::try_new(std::io::Cursor::new(buf.clone()), Some(projection));
        let read_batch = reader.unwrap().next().unwrap().unwrap();
        let projected_column = read_batch.column(0);
        let expected_column = batch.column(index);

        // check the projected column equals the expected column
        assert_eq!(projected_column.as_ref(), expected_column.as_ref());
    }
}

#[test]
fn test_arrow_single_float_row() {
let schema = Schema::new(vec![
Expand Down