Support dictionaries in nested types over IPC (#587)
jorgecarleitao authored Nov 8, 2021
1 parent 6e9ea35 · commit 37a9c75
Showing 34 changed files with 515 additions and 196 deletions.
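The core of this commit is a change of bookkeeping: dictionaries read over IPC used to be stored in a Vec<Option<Arc<dyn Array>>> indexed by top-level field position, which has no slot for a dictionary hanging off a field nested inside a struct, list, or map. The commit replaces it with a HashMap keyed by the IPC dictionary id. A minimal sketch of the two representations (the Array trait below is a stand-in, not arrow2's):

use std::collections::HashMap;
use std::sync::Arc;

// Stand-in for arrow2's `dyn Array`; the real trait lives in `src/array`.
trait Array {}
struct Utf8Values;
impl Array for Utf8Values {}

fn main() {
    // Before: one optional dictionary slot per top-level field, so a
    // dictionary nested inside a struct, list, or map had nowhere to live.
    let mut dictionaries_by_field: Vec<Option<Arc<dyn Array>>> = vec![None; 3];
    dictionaries_by_field[1] = Some(Arc::new(Utf8Values));

    // After: dictionaries are keyed by their IPC dictionary id, which is
    // independent of where the dictionary-encoded field sits in the schema.
    let mut dictionaries: HashMap<usize, Arc<dyn Array>> = HashMap::new();
    dictionaries.insert(42, Arc::new(Utf8Values));

    assert!(dictionaries.get(&42).is_some());
}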
@@ -33,7 +33,7 @@ use arrow_format::ipc::Message::MessageHeader;
 use futures::{channel::mpsc, sink::SinkExt, stream, StreamExt};
 use tonic::{Request, Streaming};
 
-use std::sync::Arc;
+use std::{collections::HashMap, sync::Arc};
 
 type ArrayRef = Arc<dyn Array>;
 type SchemaRef = Arc<Schema>;
@@ -199,10 +199,10 @@ async fn consume_flight_location(
     // first FlightData. Ignore this one.
     let _schema_again = resp.next().await.unwrap();
 
-    let mut dictionaries_by_field = vec![None; schema.fields().len()];
+    let mut dictionaries = Default::default();
 
     for (counter, expected_batch) in expected_data.iter().enumerate() {
-        let data = receive_batch_flight_data(&mut resp, schema.clone(), &mut dictionaries_by_field)
+        let data = receive_batch_flight_data(&mut resp, schema.clone(), &mut dictionaries)
             .await
             .unwrap_or_else(|| {
                 panic!(
@@ -215,7 +215,7 @@ async fn consume_flight_location(
         let metadata = counter.to_string().into_bytes();
         assert_eq!(metadata, data.app_metadata);
 
-        let actual_batch = deserialize_batch(&data, schema.clone(), true, &dictionaries_by_field)
+        let actual_batch = deserialize_batch(&data, schema.clone(), true, &dictionaries)
             .expect("Unable to convert flight data to Arrow batch");
 
         assert_eq!(expected_batch.schema(), actual_batch.schema());
@@ -245,7 +245,7 @@ async fn consume_flight_location(
 async fn receive_batch_flight_data(
     resp: &mut Streaming<FlightData>,
     schema: SchemaRef,
-    dictionaries_by_field: &mut [Option<ArrayRef>],
+    dictionaries: &mut HashMap<usize, Arc<dyn Array>>,
 ) -> Option<FlightData> {
     let mut data = resp.next().await?.ok()?;
     let mut message =
@@ -259,7 +259,7 @@ async fn receive_batch_flight_data(
                 .expect("Error parsing dictionary"),
             &schema,
             true,
-            dictionaries_by_field,
+            dictionaries,
             &mut reader,
             0,
         )
@@ -275,7 +275,7 @@ async fn record_batch_from_message(
     message: Message<'_>,
     data_body: &[u8],
     schema_ref: Arc<Schema>,
-    dictionaries_by_field: &[Option<Arc<dyn Array>>],
+    dictionaries: &mut HashMap<usize, Arc<dyn Array>>,
 ) -> Result<RecordBatch, Status> {
     let ipc_batch = message
         .header_as_record_batch()
@@ -288,7 +288,7 @@ async fn record_batch_from_message(
         schema_ref,
         None,
         true,
-        dictionaries_by_field,
+        dictionaries,
         ArrowSchema::MetadataVersion::V5,
         &mut reader,
         0,
@@ -302,22 +302,16 @@ async fn dictionary_from_message(
     message: Message<'_>,
     data_body: &[u8],
     schema_ref: Arc<Schema>,
-    dictionaries_by_field: &mut [Option<Arc<dyn Array>>],
+    dictionaries: &mut HashMap<usize, Arc<dyn Array>>,
 ) -> Result<(), Status> {
     let ipc_batch = message
         .header_as_dictionary_batch()
         .ok_or_else(|| Status::internal("Could not parse message header as dictionary batch"))?;
 
     let mut reader = std::io::Cursor::new(data_body);
 
-    let dictionary_batch_result = ipc::read::read_dictionary(
-        ipc_batch,
-        &schema_ref,
-        true,
-        dictionaries_by_field,
-        &mut reader,
-        0,
-    );
+    let dictionary_batch_result =
+        ipc::read::read_dictionary(ipc_batch, &schema_ref, true, dictionaries, &mut reader, 0);
     dictionary_batch_result
         .map_err(|e| Status::internal(format!("Could not convert to Dictionary: {:?}", e)))
 }
@@ -333,7 +327,7 @@ async fn save_uploaded_chunks(
     let mut chunks = vec![];
     let mut uploaded_chunks = uploaded_chunks.lock().await;
 
-    let mut dictionaries_by_field = vec![None; schema_ref.fields().len()];
+    let mut dictionaries = Default::default();
 
     while let Some(Ok(data)) = input_stream.next().await {
         let message = root_as_message(&data.data_header[..])
@@ -352,7 +346,7 @@ async fn save_uploaded_chunks(
                         message,
                         &data.data_body,
                         schema_ref.clone(),
-                        &dictionaries_by_field,
+                        &mut dictionaries,
                     )
                     .await?;
 
@@ -363,7 +357,7 @@ async fn save_uploaded_chunks(
                         message,
                         &data.data_body,
                         schema_ref.clone(),
-                        &mut dictionaries_by_field,
+                        &mut dictionaries,
                     )
                     .await?;
                 }
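For context on the loop above: in the Arrow IPC stream protocol, a dictionary batch must arrive before any record batch that references it, so the server folds each one into the id-keyed map as it goes. A self-contained sketch of that accumulation, with a simplified Message enum standing in for the real IPC message header:

use std::collections::HashMap;
use std::sync::Arc;

// Simplified stand-in for the IPC message header; not arrow2's type.
enum Message {
    Dictionary { id: usize, values: Arc<Vec<String>> },
    RecordBatch { referenced_ids: Vec<usize> },
}

fn main() {
    let stream = vec![
        Message::Dictionary { id: 1, values: Arc::new(vec!["a".into(), "b".into()]) },
        Message::RecordBatch { referenced_ids: vec![1] },
    ];

    let mut dictionaries: HashMap<usize, Arc<Vec<String>>> = HashMap::new();
    for message in stream {
        match message {
            // Dictionary batches are folded into the map as they arrive...
            Message::Dictionary { id, values } => {
                dictionaries.insert(id, values);
            }
            // ...so any batch that follows can resolve its ids, however
            // deeply nested the dictionary-encoded column is.
            Message::RecordBatch { referenced_ids } => {
                for id in referenced_ids {
                    assert!(dictionaries.contains_key(&id));
                }
            }
        }
    }
}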
17 changes: 12 additions & 5 deletions integration-testing/unskip.patch
@@ -1,8 +1,8 @@
 diff --git a/dev/archery/archery/integration/datagen.py b/dev/archery/archery/integration/datagen.py
-index 2d90d6c86..d5a0bc833 100644
+index 6a077a893..cab6ecd37 100644
 --- a/dev/archery/archery/integration/datagen.py
 +++ b/dev/archery/archery/integration/datagen.py
-@@ -1569,8 +1569,7 @@ def get_generated_json_files(tempdir=None):
+@@ -1561,8 +1561,7 @@ def get_generated_json_files(tempdir=None):
          .skip_category('C#')
          .skip_category('JS'), # TODO(ARROW-7900)
 
@@ -12,7 +12,7 @@ index 2d90d6c86..d5a0bc833 100644
 
      generate_decimal256_case()
      .skip_category('Go') # TODO(ARROW-7948): Decimal + Go
-@@ -1582,18 +1581,15 @@ def get_generated_json_files(tempdir=None):
+@@ -1574,18 +1573,15 @@ def get_generated_json_files(tempdir=None):
 
      generate_interval_case()
      .skip_category('C#')
@@ -34,7 +34,7 @@ index 2d90d6c86..d5a0bc833 100644
 
      generate_non_canonical_map_case()
      .skip_category('C#')
-@@ -1611,14 +1607,12 @@ def get_generated_json_files(tempdir=None):
+@@ -1602,14 +1598,12 @@ def get_generated_json_files(tempdir=None):
      generate_nested_large_offsets_case()
      .skip_category('C#')
      .skip_category('Go')
@@ -51,7 +51,14 @@ index 2d90d6c86..d5a0bc833 100644
 
      generate_custom_metadata_case()
      .skip_category('C#')
-@@ -1649,8 +1643,7 @@ def get_generated_json_files(tempdir=None):
+@@ -1634,14 +1628,12 @@ def get_generated_json_files(tempdir=None):
      .skip_category('C#')
      .skip_category('Go')
      .skip_category('Java') # TODO(ARROW-7779)
 -    .skip_category('JS')
 -    .skip_category('Rust'),
 +    .skip_category('JS'),
 
      generate_extension_case()
      .skip_category('C#')
      .skip_category('Go') # TODO(ARROW-3039): requires dictionaries
12 changes: 7 additions & 5 deletions src/array/dictionary/mod.rs
@@ -13,7 +13,8 @@ mod mutable;
 pub use iterator::*;
 pub use mutable::*;
 
-use super::{new_empty_array, primitive::PrimitiveArray, Array};
+use super::display::get_value_display;
+use super::{display_fmt, new_empty_array, primitive::PrimitiveArray, Array};
 use crate::scalar::NullScalar;
 
 /// Trait denoting [`NativeType`]s that can be used as keys of a dictionary.
@@ -196,9 +197,10 @@ where
     PrimitiveArray<K>: std::fmt::Display,
 {
     fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
-        writeln!(f, "{:?}{{", self.data_type())?;
-        writeln!(f, "keys: {},", self.keys())?;
-        writeln!(f, "values: {},", self.values())?;
-        write!(f, "}}")
+        let display = get_value_display(self);
+        let new_lines = false;
+        let head = &format!("{}", self.data_type());
+        let iter = self.iter().enumerate().map(|(i, x)| x.map(|_| display(i)));
+        display_fmt(iter, head, f, new_lines)
     }
 }
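The effect of the new fmt body: instead of dumping the raw keys and values arrays, a dictionary array now prints its logical values, each key resolved through get_value_display, like any other array. A toy model of that behavior (Dict below is hypothetical, not arrow2's DictionaryArray):

use std::fmt;

// `Dict` is a hypothetical stand-in for a dictionary array: keys index
// into values, and a None key encodes a null slot.
struct Dict {
    keys: Vec<Option<usize>>,
    values: Vec<String>,
}

impl fmt::Display for Dict {
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        // Mirrors the new impl: resolve each key through a value-display
        // closure and render the resulting logical values in one list.
        let display = |i: usize| self.values[i].clone();
        write!(f, "dict[")?;
        for (i, key) in self.keys.iter().enumerate() {
            if i > 0 {
                write!(f, ", ")?;
            }
            match key {
                Some(k) => write!(f, "{}", display(*k))?,
                None => write!(f, "None")?,
            }
        }
        write!(f, "]")
    }
}

fn main() {
    let a = Dict {
        keys: vec![Some(0), None, Some(1)],
        values: vec!["a".into(), "b".into()],
    };
    println!("{}", a); // prints: dict[a, None, b]
}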
8 changes: 7 additions & 1 deletion src/array/display.rs
@@ -165,7 +165,13 @@ pub fn get_value_display<'a>(array: &'a dyn Array) -> Box<dyn Fn(usize) -> String + 'a>
                 .unwrap();
             let keys = a.keys();
             let display = get_display(a.values().as_ref());
-            Box::new(move |row: usize| display(keys.value(row) as usize))
+            Box::new(move |row: usize| {
+                if keys.is_null(row) {
+                    "".to_string()
+                } else {
+                    display(keys.value(row) as usize)
+                }
+            })
         }),
         Map(_, _) => todo!(),
         Struct(_) => {
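The is_null guard added here matters for correctness, not just cosmetics: on a null slot the key is arbitrary, so using it as an index into the values could read out of bounds or panic. A minimal model of the fixed closure, with plain Vecs standing in for the keys buffer, validity bitmap, and values:

fn main() {
    // On a null slot the key bytes are arbitrary (here 99), so the null
    // check must run before the key is used as an index.
    let keys: Vec<usize> = vec![0, 99, 1];
    let validity = vec![true, false, true];
    let values = vec!["x".to_string(), "y".to_string()];

    let display = |row: usize| -> String {
        if !validity[row] {
            "".to_string() // null renders as empty, as in the patch
        } else {
            values[keys[row]].clone()
        }
    };

    assert_eq!(display(0), "x");
    assert_eq!(display(1), ""); // no out-of-bounds panic on the null slot
    assert_eq!(display(2), "y");
}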
22 changes: 21 additions & 1 deletion src/datatypes/field.rs
@@ -23,7 +23,7 @@ use super::DataType;
 
 /// A logical [`DataType`] and its associated metadata per
 /// [Arrow specification](https://arrow.apache.org/docs/cpp/api/datatype.html)
-#[derive(Debug, Clone, PartialEq, Eq, Hash)]
+#[derive(Debug, Clone, Eq)]
 pub struct Field {
     /// Its name
     pub name: String,
@@ -39,6 +39,26 @@ pub struct Field {
     pub metadata: Option<BTreeMap<String, String>>,
 }
 
+impl std::hash::Hash for Field {
+    fn hash<H: std::hash::Hasher>(&self, state: &mut H) {
+        self.name.hash(state);
+        self.data_type.hash(state);
+        self.nullable.hash(state);
+        self.dict_is_ordered.hash(state);
+        self.metadata.hash(state);
+    }
+}
+
+impl PartialEq for Field {
+    fn eq(&self, other: &Self) -> bool {
+        self.name == other.name
+            && self.data_type == other.data_type
+            && self.nullable == other.nullable
+            && self.dict_is_ordered == other.dict_is_ordered
+            && self.metadata == other.metadata
+    }
+}
+
 impl Field {
     /// Creates a new field
     pub fn new(name: &str, data_type: DataType, nullable: bool) -> Self {
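Note what the hand-written impls leave out: unlike the old derive, they skip dict_id, so two fields that differ only in their dictionary id hash and compare as equal. The discipline when doing this by hand is that Hash and PartialEq must walk exactly the same members, or the Eq/Hash contract breaks for any map keyed on Field. A reduced sketch of the pattern, using a hypothetical two-member struct:

use std::collections::hash_map::DefaultHasher;
use std::hash::{Hash, Hasher};

// Manual `Hash` and `PartialEq` that walk exactly the same members, so the
// `Eq`/`Hash` contract (a == b implies hash(a) == hash(b)) still holds.
#[derive(Debug, Clone, Eq)]
struct Field {
    name: String,
    nullable: bool,
}

impl Hash for Field {
    fn hash<H: Hasher>(&self, state: &mut H) {
        self.name.hash(state);
        self.nullable.hash(state);
    }
}

impl PartialEq for Field {
    fn eq(&self, other: &Self) -> bool {
        self.name == other.name && self.nullable == other.nullable
    }
}

fn hash_of<T: Hash>(t: &T) -> u64 {
    let mut h = DefaultHasher::new();
    t.hash(&mut h);
    h.finish()
}

fn main() {
    let a = Field { name: "id".into(), nullable: false };
    let b = a.clone();
    assert_eq!(a, b);
    assert_eq!(hash_of(&a), hash_of(&b));
}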
8 changes: 0 additions & 8 deletions src/datatypes/schema.rs
@@ -164,14 +164,6 @@ impl Schema {
         Ok(&self.fields[self.index_of(name)?])
     }
 
-    /// Returns all [`Field`]s with dictionary id `dict_id`.
-    pub fn fields_with_dict_id(&self, dict_id: i64) -> Vec<&Field> {
-        self.fields
-            .iter()
-            .filter(|f| f.dict_id() == Some(dict_id))
-            .collect()
-    }
-
    /// Find the index of the column with the given name.
     pub fn index_of(&self, name: &str) -> Result<usize> {
         for i in 0..self.fields.len() {
5 changes: 3 additions & 2 deletions src/io/flight/mod.rs
@@ -1,3 +1,4 @@
+use std::collections::HashMap;
 use std::convert::TryFrom;
 use std::sync::Arc;
 
@@ -122,7 +123,7 @@ pub fn deserialize_batch(
     data: &FlightData,
     schema: Arc<Schema>,
     is_little_endian: bool,
-    dictionaries_by_field: &[Option<Arc<dyn Array>>],
+    dictionaries: &HashMap<usize, Arc<dyn Array>>,
 ) -> Result<RecordBatch> {
     // check that the data_header is a record batch message
     let message = ipc::Message::root_as_message(&data.data_header[..])
@@ -141,7 +142,7 @@ pub fn deserialize_batch(
         schema.clone(),
         None,
         is_little_endian,
-        dictionaries_by_field,
+        dictionaries,
         ipc::Schema::MetadataVersion::V5,
         &mut reader,
         0,
20 changes: 2 additions & 18 deletions src/io/ipc/convert.rs
@@ -663,24 +663,8 @@ pub(crate) fn get_fb_field_type<'a>(
             }
         }
         Struct(fields) => {
-            // struct's fields are children
-            let mut children = vec![];
-            for field in fields {
-                let inner_types = get_fb_field_type(field.data_type(), field.is_nullable(), fbb);
-                let field_name = fbb.create_string(field.name());
-                children.push(ipc::Field::create(
-                    fbb,
-                    &ipc::FieldArgs {
-                        name: Some(field_name),
-                        nullable: field.is_nullable(),
-                        type_type: inner_types.type_type,
-                        type_: Some(inner_types.type_),
-                        dictionary: None,
-                        children: inner_types.children,
-                        custom_metadata: None,
-                    },
-                ));
-            }
+            let children: Vec<_> = fields.iter().map(|field| build_field(fbb, field)).collect();
 
             FbFieldType {
                 type_type,
                 type_: ipc::Struct_Builder::new(fbb).finish().as_union_value(),
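This hunk is a behavior-relevant refactor: the old inline loop hard-coded dictionary: None for struct children, which is precisely what blocked dictionaries in nested types; routing children through the same build_field helper as top-level fields presumably gives them full dictionary handling. The shape of the refactor, with illustrative names rather than arrow2's flatbuffer API:

// Illustrative names, not arrow2's flatbuffer API.
#[derive(Debug)]
struct FbField {
    name: String,
    nullable: bool,
}

fn build_field(name: &str, nullable: bool) -> FbField {
    FbField { name: name.to_string(), nullable }
}

fn main() {
    let fields = [("a", true), ("b", false)];

    // Before: a mutable vec and an imperative push-loop.
    let mut children = Vec::new();
    for (name, nullable) in fields {
        children.push(build_field(name, nullable));
    }

    // After: one map/collect through the shared helper.
    let children2: Vec<FbField> = fields.iter().map(|(n, nu)| build_field(n, *nu)).collect();

    assert_eq!(children.len(), children2.len());
    println!("{:?}", children2);
}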
2 changes: 1 addition & 1 deletion src/io/ipc/read/array/binary.rs
@@ -25,7 +25,7 @@ pub fn read_binary<O: Offset, R: Read + Seek>(
 where
     Vec<u8>: TryInto<O::Bytes> + TryInto<<u8 as NativeType>::Bytes>,
 {
-    let field_node = field_nodes.pop_front().unwrap().0;
+    let field_node = field_nodes.pop_front().unwrap();
 
     let validity = read_validity(
         buffers,
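The disappearing .0 here (and in boolean.rs and fixed_size_binary.rs below) reveals what Node used to be: a pair of a field node and its optional dictionary, which is also visible in the old field_nodes.front().unwrap().1 in dictionary.rs. With dictionaries resolved through the id-keyed map, the queue can hold bare field nodes. Roughly, under hypothetical stand-in types:

use std::collections::VecDeque;
use std::sync::Arc;

// Hypothetical stand-ins for the reader's internal types.
struct FieldNode {
    length: usize,
}
struct ArrayStub;

fn main() {
    // Before: each queued node dragged an (almost always absent) dictionary
    // along as the second element of a tuple, hence the `.0` accessors.
    let mut old_nodes: VecDeque<(FieldNode, Option<Arc<ArrayStub>>)> =
        VecDeque::from(vec![(FieldNode { length: 3 }, None)]);
    let field_node = old_nodes.pop_front().unwrap().0;

    // After: dictionaries live in the id-keyed HashMap instead, so the
    // queue holds bare field nodes and the `.0` disappears.
    let mut new_nodes: VecDeque<FieldNode> = VecDeque::from(vec![FieldNode { length: 3 }]);
    let field_node2 = new_nodes.pop_front().unwrap();

    assert_eq!(field_node.length, field_node2.length);
}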
2 changes: 1 addition & 1 deletion src/io/ipc/read/array/boolean.rs
@@ -19,7 +19,7 @@ pub fn read_boolean<R: Read + Seek>(
     is_little_endian: bool,
     compression: Option<ipc::Message::BodyCompression>,
 ) -> Result<BooleanArray> {
-    let field_node = field_nodes.pop_front().unwrap().0;
+    let field_node = field_nodes.pop_front().unwrap();
 
     let length = field_node.length() as usize;
     let validity = read_validity(
25 changes: 20 additions & 5 deletions src/io/ipc/read/array/dictionary.rs
@@ -1,27 +1,42 @@
-use std::collections::VecDeque;
+use std::collections::{HashMap, HashSet, VecDeque};
 use std::convert::TryInto;
 use std::io::{Read, Seek};
+use std::sync::Arc;
 
 use arrow_format::ipc;
 
-use crate::array::{DictionaryArray, DictionaryKey};
-use crate::error::Result;
+use crate::array::{Array, DictionaryArray, DictionaryKey};
+use crate::datatypes::Field;
+use crate::error::{ArrowError, Result};
 
 use super::super::deserialize::Node;
 use super::{read_primitive, skip_primitive};
 
+#[allow(clippy::too_many_arguments)]
 pub fn read_dictionary<T: DictionaryKey, R: Read + Seek>(
     field_nodes: &mut VecDeque<Node>,
+    field: &Field,
     buffers: &mut VecDeque<&ipc::Schema::Buffer>,
     reader: &mut R,
+    dictionaries: &HashMap<usize, Arc<dyn Array>>,
     block_offset: u64,
     compression: Option<ipc::Message::BodyCompression>,
     is_little_endian: bool,
 ) -> Result<DictionaryArray<T>>
 where
     Vec<u8>: TryInto<T::Bytes>,
 {
-    let values = field_nodes.front().unwrap().1.as_ref().unwrap();
+    let id = field.dict_id().unwrap() as usize;
+    let values = dictionaries
+        .get(&id)
+        .ok_or_else(|| {
+            let valid_ids = dictionaries.keys().collect::<HashSet<_>>();
+            ArrowError::Ipc(format!(
+                "Dictionary id {} not found. Valid ids: {:?}",
+                id, valid_ids
+            ))
+        })?
+        .clone();
 
     let keys = read_primitive(
         field_nodes,
@@ -33,7 +48,7 @@ where
         compression,
     )?;
 
-    Ok(DictionaryArray::<T>::from_data(keys, values.clone()))
+    Ok(DictionaryArray::<T>::from_data(keys, values))
 }
 
 pub fn skip_dictionary(
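The replacement lookup reads well in isolation: resolve the field's dictionary id against the map, and fail with the set of known ids rather than panicking when a malformed stream never delivered that dictionary. A self-contained sketch under stand-in types (String errors for ArrowError, Vec<String> for the array values):

use std::collections::{HashMap, HashSet};
use std::sync::Arc;

// String errors and Vec<String> values stand in for the real error and
// `Arc<dyn Array>` dictionary types.
fn lookup(
    dictionaries: &HashMap<usize, Arc<Vec<String>>>,
    id: usize,
) -> Result<Arc<Vec<String>>, String> {
    dictionaries.get(&id).cloned().ok_or_else(|| {
        // Listing the known ids turns a malformed stream into a
        // debuggable error instead of a panic.
        let valid_ids = dictionaries.keys().collect::<HashSet<_>>();
        format!("Dictionary id {} not found. Valid ids: {:?}", id, valid_ids)
    })
}

fn main() {
    let mut dictionaries = HashMap::new();
    dictionaries.insert(7, Arc::new(vec!["a".to_string(), "b".to_string()]));

    assert!(lookup(&dictionaries, 7).is_ok());
    assert!(lookup(&dictionaries, 8).unwrap_err().contains("not found"));
}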
2 changes: 1 addition & 1 deletion src/io/ipc/read/array/fixed_size_binary.rs
@@ -19,7 +19,7 @@ pub fn read_fixed_size_binary<R: Read + Seek>(
     is_little_endian: bool,
     compression: Option<ipc::Message::BodyCompression>,
 ) -> Result<FixedSizeBinaryArray> {
-    let field_node = field_nodes.pop_front().unwrap().0;
+    let field_node = field_nodes.pop_front().unwrap();
 
     let validity = read_validity(
         buffers,