From 14bd53dc1240003f171c8655863eae188cd0880f Mon Sep 17 00:00:00 2001
From: Dan Harris <1327726+thinkharderdev@users.noreply.github.com>
Date: Thu, 14 Mar 2024 22:47:48 -0400
Subject: [PATCH] Support dictionary encoding in structures for
 `FlightDataEncoder`, add documentation for
 `arrow_flight::encode::DictionaryHandling` (#5488)

* Add more detailed documentation for arrow_flight::encode::DictionaryHandling

* fix doc link

* Fix handling of nested dictionary arrays with DictionaryHandling::Hydrate

* clippy

* Handle large list and sparse unions

* use top-level fields

* PR comments
---
 arrow-flight/src/encode.rs | 478 +++++++++++++++++++++++++++++++++----
 1 file changed, 435 insertions(+), 43 deletions(-)

diff --git a/arrow-flight/src/encode.rs b/arrow-flight/src/encode.rs
index bb043681620..efd68812948 100644
--- a/arrow-flight/src/encode.rs
+++ b/arrow-flight/src/encode.rs
@@ -18,9 +18,11 @@
 use std::{collections::VecDeque, fmt::Debug, pin::Pin, sync::Arc, task::Poll};
 
 use crate::{error::Result, FlightData, FlightDescriptor, SchemaAsIpc};
-use arrow_array::{ArrayRef, RecordBatch, RecordBatchOptions};
+
+use arrow_array::{Array, ArrayRef, RecordBatch, RecordBatchOptions, UnionArray};
 use arrow_ipc::writer::{DictionaryTracker, IpcDataGenerator, IpcWriteOptions};
-use arrow_schema::{DataType, Field, Fields, Schema, SchemaRef};
+
+use arrow_schema::{DataType, Field, FieldRef, Fields, Schema, SchemaRef, UnionMode};
 use bytes::Bytes;
 use futures::{ready, stream::BoxStream, Stream, StreamExt};
 
@@ -323,9 +325,10 @@ impl FlightDataEncoder {
             None => self.encode_schema(batch.schema_ref()),
         };
 
-        // encode the batch
-        let send_dictionaries = self.dictionary_handling == DictionaryHandling::Resend;
-        let batch = prepare_batch_for_flight(&batch, schema, send_dictionaries)?;
+        let batch = match self.dictionary_handling {
+            DictionaryHandling::Resend => batch,
+            DictionaryHandling::Hydrate => hydrate_dictionaries(&batch, schema)?,
+        };
 
         for batch in split_batch_for_grpc_response(batch, self.max_flight_data_size) {
             let (flight_dictionaries, flight_batch) = self.encoder.encode_batch(&batch)?;
@@ -388,6 +391,31 @@ impl Stream for FlightDataEncoder {
 /// Defines how a [`FlightDataEncoder`] encodes [`DictionaryArray`]s
 ///
 /// [`DictionaryArray`]: arrow_array::DictionaryArray
+///
+/// In the arrow flight protocol dictionary values and keys are sent as two separate messages.
+/// When a sender is encoding a [`RecordBatch`] containing [`DictionaryArray`] columns, it will
+/// first send a dictionary batch (a batch with header `MessageHeader::DictionaryBatch`) containing
+/// the dictionary values. The receiver is responsible for reading this batch and maintaining state that associates
+/// those dictionary values with the corresponding array using the `dict_id` as a key.
+///
+/// After sending the dictionary batch the sender will send the array data in a batch with header `MessageHeader::RecordBatch`.
+/// For any dictionary array batches in this message, the encoded flight message will only contain the dictionary keys. The receiver
+/// is then responsible for rebuilding the `DictionaryArray` on the client side using the dictionary values from the DictionaryBatch message
+/// and the keys from the RecordBatch message.
+///
+/// For example, if we have a batch with a `TypedDictionaryArray<'_, UInt32Type, Utf8Type>` (a dictionary array where the keys are `u32` and the
+/// values are `String`), then the DictionaryBatch will contain a `StringArray` and the RecordBatch will contain a `UInt32Array`.
+///
+/// Note that since the `dict_id` defined in the `Schema` is used as a key to associate dictionary values to their arrays, it is required that each
+/// `DictionaryArray` in a `RecordBatch` have a unique `dict_id`.
+///
+/// The current implementation does not support "delta" dictionaries, so a new dictionary batch will be sent each time the encoder sees a
+/// dictionary which is not pointer-equal to the previously observed dictionary for a given `dict_id`.
+///
+/// For clients which may not support `DictionaryEncoding`, the `DictionaryHandling::Hydrate` method will bypass the process defined above
+/// and "hydrate" any `DictionaryArray` in the batch to its underlying value type (e.g. `TypedDictionaryArray<'_, UInt32Type, Utf8Type>` will
+/// be sent as a `StringArray`). With this method all data will be sent in `MessageHeader::RecordBatch` messages and the batch schema
+/// will be adjusted so that all dictionary encoded fields are changed to fields of the dictionary value type.
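+///
+/// A minimal usage sketch (illustrative only; it assumes the encoder builder exposes a
+/// `with_dictionary_handling` setter for this enum, which is not shown in this diff):
+///
+/// ```ignore
+/// use arrow_flight::encode::{DictionaryHandling, FlightDataEncoderBuilder};
+///
+/// // `batches` is assumed to be a Stream of Result<RecordBatch, FlightError>.
+/// // Resend keeps columns dictionary encoded and emits separate DictionaryBatch
+/// // messages; the default, Hydrate, casts dictionary columns to their value types.
+/// let flight_data = FlightDataEncoderBuilder::new()
+///     .with_dictionary_handling(DictionaryHandling::Resend)
+///     .build(batches);
+/// ```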
 #[derive(Debug, PartialEq)]
 pub enum DictionaryHandling {
     /// Expands to the underlying type (default). This likely sends more data
@@ -395,13 +423,6 @@ pub enum DictionaryHandling {
     /// and is more compatible with other arrow flight client implementations
     /// that may not support `DictionaryEncoding`
     ///
-    /// An IPC response, streaming or otherwise, defines its schema up front
-    /// which defines the mapping from dictionary IDs. It then sends these
-    /// dictionaries over the wire.
-    ///
-    /// This requires identifying the different dictionaries in use, assigning
-    /// them IDs, and sending new dictionaries, delta or otherwise, when needed
-    ///
     /// See also:
     /// *
     Hydrate,
@@ -411,9 +432,52 @@
     /// twice.
     ///
     /// [`DictionaryArray`]: arrow_array::DictionaryArray
+    ///
+    /// This requires identifying the different dictionaries in use and assigning
+    /// them unique IDs
     Resend,
 }
 
+fn prepare_field_for_flight(field: &FieldRef, send_dictionaries: bool) -> Field {
+    match field.data_type() {
+        DataType::List(inner) => Field::new_list(
+            field.name(),
+            prepare_field_for_flight(inner, send_dictionaries),
+            field.is_nullable(),
+        )
+        .with_metadata(field.metadata().clone()),
+        DataType::LargeList(inner) => Field::new_list(
+            field.name(),
+            prepare_field_for_flight(inner, send_dictionaries),
+            field.is_nullable(),
+        )
+        .with_metadata(field.metadata().clone()),
+        DataType::Struct(fields) => {
+            let new_fields: Vec<Field> = fields
+                .iter()
+                .map(|f| prepare_field_for_flight(f, send_dictionaries))
+                .collect();
+            Field::new_struct(field.name(), new_fields, field.is_nullable())
+                .with_metadata(field.metadata().clone())
+        }
+        DataType::Union(fields, mode) => {
+            let (type_ids, new_fields): (Vec<i8>, Vec<Field>) = fields
+                .iter()
+                .map(|(type_id, f)| (type_id, prepare_field_for_flight(f, send_dictionaries)))
+                .unzip();
+
+            Field::new_union(field.name(), type_ids, new_fields, *mode)
+        }
+        DataType::Dictionary(_, value_type) if !send_dictionaries => Field::new(
+            field.name(),
+            value_type.as_ref().clone(),
+            field.is_nullable(),
+        )
+        .with_metadata(field.metadata().clone()),
+        _ => field.as_ref().clone(),
+    }
+}
+
 /// Prepare an arrow Schema for transport over the Arrow Flight protocol
 ///
 /// Convert dictionary types to underlying types
@@ -430,6 +494,7 @@ fn prepare_schema_for_flight(schema: &Schema, send_dictionaries: bool) -> Schema
                 field.is_nullable(),
             )
             .with_metadata(field.metadata().clone()),
+            tpe if tpe.is_nested() => prepare_field_for_flight(field, send_dictionaries),
             _ => field.as_ref().clone(),
         })
         .collect();
@@ -509,22 +574,14 @@ impl FlightIpcEncoder {
     }
 }
 
-/// Prepares a RecordBatch for transport over the Arrow Flight protocol
-///
-/// This means:
-///
-/// 1. Hydrates any dictionaries to its underlying type. See
+/// Hydrates any dictionary arrays in `batch` to their underlying types. See
 /// hydrate_dictionary for more information.
-///
-fn prepare_batch_for_flight(
-    batch: &RecordBatch,
-    schema: SchemaRef,
-    send_dictionaries: bool,
-) -> Result<RecordBatch> {
-    let columns = batch
-        .columns()
+fn hydrate_dictionaries(batch: &RecordBatch, schema: SchemaRef) -> Result<RecordBatch> {
+    let columns = schema
+        .fields()
         .iter()
-        .map(|c| hydrate_dictionary(c, send_dictionaries))
+        .zip(batch.columns())
+        .map(|(field, c)| hydrate_dictionary(c, field.data_type()))
         .collect::<Result<Vec<_>>>()?;
 
     let options = RecordBatchOptions::new().with_row_count(Some(batch.num_rows()));
@@ -534,22 +591,43 @@ fn prepare_batch_for_flight(
     )?)
 }
 
-/// Hydrates a dictionary to its underlying type if send_dictionaries is false. If send_dictionaries
-/// is true, dictionaries are sent with every batch which is not as optimal as described in [DictionaryHandling::Hydrate] above,
-/// but does enable sending DictionaryArray's via Flight.
-fn hydrate_dictionary(array: &ArrayRef, send_dictionaries: bool) -> Result<ArrayRef> {
-    let arr = match array.data_type() {
-        DataType::Dictionary(_, value) if !send_dictionaries => arrow_cast::cast(array, value)?,
-        _ => Arc::clone(array),
+/// Hydrates a dictionary to its underlying type.
+fn hydrate_dictionary(array: &ArrayRef, data_type: &DataType) -> Result<ArrayRef> {
+    let arr = match (array.data_type(), data_type) {
+        (DataType::Union(_, UnionMode::Sparse), DataType::Union(fields, UnionMode::Sparse)) => {
+            let union_arr = array.as_any().downcast_ref::<UnionArray>().unwrap();
+
+            let (type_ids, fields): (Vec<i8>, Vec<&FieldRef>) = fields.iter().unzip();
+
+            Arc::new(UnionArray::try_new(
+                &type_ids,
+                union_arr.type_ids().inner().clone(),
+                None,
+                fields
+                    .iter()
+                    .enumerate()
+                    .map(|(col, field)| {
+                        Ok((
+                            field.as_ref().clone(),
+                            arrow_cast::cast(union_arr.child(col as i8), field.data_type())?,
+                        ))
+                    })
+                    .collect::<Result<Vec<_>>>()?,
+            )?)
+        }
+        (_, data_type) => arrow_cast::cast(array, data_type)?,
     };
     Ok(arr)
 }
 
 #[cfg(test)]
 mod tests {
+    use arrow_array::builder::StringDictionaryBuilder;
     use arrow_array::*;
     use arrow_array::{cast::downcast_array, types::*};
+    use arrow_buffer::Buffer;
     use arrow_cast::pretty::pretty_format_batches;
+    use arrow_schema::UnionMode;
     use std::collections::HashMap;
 
     use crate::decode::{DecodedPayload, FlightDataDecoder};
@@ -570,8 +648,8 @@
         let (_, baseline_flight_batch) = make_flight_data(&batch, &options);
 
         let big_batch = batch.slice(0, batch.num_rows() - 1);
-        let optimized_big_batch = prepare_batch_for_flight(&big_batch, Arc::clone(schema), false)
-            .expect("failed to optimize");
+        let optimized_big_batch =
+            hydrate_dictionaries(&big_batch, Arc::clone(schema)).expect("failed to optimize");
         let (_, optimized_big_flight_batch) = make_flight_data(&optimized_big_batch, &options);
 
         assert_eq!(
@@ -581,8 +659,7 @@
 
         let small_batch = batch.slice(0, 1);
         let optimized_small_batch =
-            prepare_batch_for_flight(&small_batch, Arc::clone(schema), false)
-                .expect("failed to optimize");
+            hydrate_dictionaries(&small_batch, Arc::clone(schema)).expect("failed to optimize");
         let (_, optimized_small_flight_batch) = make_flight_data(&optimized_small_batch, &options);
 
         assert!(
@@ -592,19 +669,29 @@
 
     #[tokio::test]
     async fn test_dictionary_hydration() {
-        let arr: DictionaryArray<UInt16Type> = vec!["a", "a", "b"].into_iter().collect();
+        let arr1: DictionaryArray<UInt16Type> = vec!["a", "a", "b"].into_iter().collect();
+        let arr2: DictionaryArray<UInt16Type> = vec!["c", "c", "d"].into_iter().collect();
+
         let schema = Arc::new(Schema::new(vec![Field::new_dictionary(
             "dict",
             DataType::UInt16,
             DataType::Utf8,
             false,
         )]));
-        let batch = RecordBatch::try_new(schema, vec![Arc::new(arr)]).unwrap();
-        let encoder =
-            FlightDataEncoderBuilder::default().build(futures::stream::once(async { Ok(batch) }));
+        let batch1 = RecordBatch::try_new(schema.clone(), vec![Arc::new(arr1)]).unwrap();
+        let batch2 = RecordBatch::try_new(schema, vec![Arc::new(arr2)]).unwrap();
+
+        let stream = futures::stream::iter(vec![Ok(batch1), Ok(batch2)]);
+
+        let encoder = FlightDataEncoderBuilder::default().build(stream);
+
         let mut decoder = FlightDataDecoder::new(encoder);
         let expected_schema = Schema::new(vec![Field::new("dict", DataType::Utf8, false)]);
         let expected_schema = Arc::new(expected_schema);
+
+        let mut expected_arrays = vec![
+            StringArray::from(vec!["a", "a", "b"]),
+            StringArray::from(vec!["c", "c", "d"]),
+        ]
+        .into_iter();
+
         while let Some(decoded) = decoder.next().await {
             let decoded = decoded.unwrap();
             match decoded.payload {
@@ -612,7 +699,7 @@
                 DecodedPayload::None => {}
                 DecodedPayload::Schema(s) => assert_eq!(s, expected_schema),
                 DecodedPayload::RecordBatch(b) => {
                     assert_eq!(b.schema(), expected_schema);
-                    let expected_array = StringArray::from(vec!["a", "a", "b"]);
+                    let expected_array = expected_arrays.next().unwrap();
                     let actual_array = b.column_by_name("dict").unwrap();
                     let actual_array = downcast_array::<StringArray>(actual_array);
 
@@ -622,6 +709,311 @@ mod tests {
         }
     }
 
+    #[tokio::test]
+    async fn test_dictionary_list_hydration() {
+        let mut builder = builder::ListBuilder::new(StringDictionaryBuilder::<UInt16Type>::new());
+
+        builder.append_value(vec![Some("a"), None, Some("b")]);
+
+        let arr1 = builder.finish();
+
+        builder.append_value(vec![Some("c"), None, Some("d")]);
+
+        let arr2 = builder.finish();
+
+        let schema = Arc::new(Schema::new(vec![Field::new_list(
+            "dict_list",
+            Field::new_dictionary("item", DataType::UInt16, DataType::Utf8, true),
+            true,
+        )]));
+
+        let batch1 = RecordBatch::try_new(schema.clone(), vec![Arc::new(arr1)]).unwrap();
+        let batch2 = RecordBatch::try_new(schema.clone(), vec![Arc::new(arr2)]).unwrap();
+
+        let stream = futures::stream::iter(vec![Ok(batch1), Ok(batch2)]);
+
+        let encoder = FlightDataEncoderBuilder::default().build(stream);
+
+        let mut decoder = FlightDataDecoder::new(encoder);
+        let expected_schema = Schema::new(vec![Field::new_list(
+            "dict_list",
+            Field::new("item", DataType::Utf8, true),
+            true,
+        )]);
+
+        let expected_schema = Arc::new(expected_schema);
+
+        let mut expected_arrays = vec![
+            StringArray::from_iter(vec![Some("a"), None, Some("b")]),
+            StringArray::from_iter(vec![Some("c"), None, Some("d")]),
+        ]
+        .into_iter();
+
+        while let Some(decoded) = decoder.next().await {
+            let decoded = decoded.unwrap();
+            match decoded.payload {
+                DecodedPayload::None => {}
+                DecodedPayload::Schema(s) => assert_eq!(s, expected_schema),
+                DecodedPayload::RecordBatch(b) => {
+                    assert_eq!(b.schema(), expected_schema);
+                    let expected_array = expected_arrays.next().unwrap();
+                    let list_array =
+                        downcast_array::<ListArray>(b.column_by_name("dict_list").unwrap());
+                    let elem_array = downcast_array::<StringArray>(list_array.value(0).as_ref());
+
+                    assert_eq!(elem_array, expected_array);
+                }
+            }
+        }
+    }
+
+    #[tokio::test]
+    async fn test_dictionary_struct_hydration() {
+        let struct_fields = vec![Field::new_list(
+            "dict_list",
+            Field::new_dictionary("item", DataType::UInt16, DataType::Utf8, true),
+            true,
+        )];
+
+        let mut builder = builder::ListBuilder::new(StringDictionaryBuilder::<UInt16Type>::new());
+
+        builder.append_value(vec![Some("a"), None, Some("b")]);
+
+        let arr1 = Arc::new(builder.finish());
+        let arr1 = StructArray::new(struct_fields.clone().into(), vec![arr1], None);
+
+        builder.append_value(vec![Some("c"), None, Some("d")]);
+
+        let arr2 = Arc::new(builder.finish());
+        let arr2 = StructArray::new(struct_fields.clone().into(), vec![arr2], None);
+
+        let schema = Arc::new(Schema::new(vec![Field::new_struct(
+            "struct",
+            struct_fields.clone(),
+            true,
+        )]));
+
+        let batch1 = RecordBatch::try_new(schema.clone(), vec![Arc::new(arr1)]).unwrap();
+        let batch2 = RecordBatch::try_new(schema.clone(), vec![Arc::new(arr2)]).unwrap();
+
+        let stream = futures::stream::iter(vec![Ok(batch1), Ok(batch2)]);
+
+        let encoder = FlightDataEncoderBuilder::default().build(stream);
+
+        let mut decoder = FlightDataDecoder::new(encoder);
+        let expected_schema = Schema::new(vec![Field::new_struct(
+            "struct",
+            vec![Field::new_list(
+                "dict_list",
+                Field::new("item", DataType::Utf8, true),
+                true,
+            )],
+            true,
+        )]);
+
+        let expected_schema = Arc::new(expected_schema);
+
+        let mut expected_arrays = vec![
+            StringArray::from_iter(vec![Some("a"), None, Some("b")]),
+            StringArray::from_iter(vec![Some("c"), None, Some("d")]),
+        ]
+        .into_iter();
+
+        while let Some(decoded) = decoder.next().await {
+            let decoded = decoded.unwrap();
+            match decoded.payload {
+                DecodedPayload::None => {}
+                DecodedPayload::Schema(s) => assert_eq!(s, expected_schema),
+                DecodedPayload::RecordBatch(b) => {
+                    assert_eq!(b.schema(), expected_schema);
+                    let expected_array = expected_arrays.next().unwrap();
+                    let struct_array =
+                        downcast_array::<StructArray>(b.column_by_name("struct").unwrap());
+                    let list_array = downcast_array::<ListArray>(struct_array.column(0));
+
+                    let elem_array = downcast_array::<StringArray>(list_array.value(0).as_ref());
+
+                    assert_eq!(elem_array, expected_array);
+                }
+            }
+        }
+    }
+
+    #[tokio::test]
+    async fn test_dictionary_union_hydration() {
+        let struct_fields = vec![Field::new_list(
"dict_list", + Field::new_dictionary("item", DataType::UInt16, DataType::Utf8, true), + true, + )]; + + let type_ids = vec![0, 1, 2]; + let union_fields = vec![ + Field::new_list( + "dict_list", + Field::new_dictionary("item", DataType::UInt16, DataType::Utf8, true), + true, + ), + Field::new_struct("struct", struct_fields.clone(), true), + Field::new("string", DataType::Utf8, true), + ]; + + let struct_fields = vec![Field::new_list( + "dict_list", + Field::new_dictionary("item", DataType::UInt16, DataType::Utf8, true), + true, + )]; + + let mut builder = builder::ListBuilder::new(StringDictionaryBuilder::::new()); + + builder.append_value(vec![Some("a"), None, Some("b")]); + + let arr1 = builder.finish(); + + let type_id_buffer = Buffer::from_slice_ref([0_i8]); + let arr1 = UnionArray::try_new( + &type_ids, + type_id_buffer, + None, + vec![ + (union_fields[0].clone(), Arc::new(arr1)), + ( + union_fields[1].clone(), + new_null_array(union_fields[1].data_type(), 1), + ), + ( + union_fields[2].clone(), + new_null_array(union_fields[2].data_type(), 1), + ), + ], + ) + .unwrap(); + + builder.append_value(vec![Some("c"), None, Some("d")]); + + let arr2 = Arc::new(builder.finish()); + let arr2 = StructArray::new(struct_fields.clone().into(), vec![arr2], None); + + let type_id_buffer = Buffer::from_slice_ref([1_i8]); + let arr2 = UnionArray::try_new( + &type_ids, + type_id_buffer, + None, + vec![ + ( + union_fields[0].clone(), + new_null_array(union_fields[0].data_type(), 1), + ), + (union_fields[1].clone(), Arc::new(arr2)), + ( + union_fields[2].clone(), + new_null_array(union_fields[2].data_type(), 1), + ), + ], + ) + .unwrap(); + + let type_id_buffer = Buffer::from_slice_ref([2_i8]); + let arr3 = UnionArray::try_new( + &type_ids, + type_id_buffer, + None, + vec![ + ( + union_fields[0].clone(), + new_null_array(union_fields[0].data_type(), 1), + ), + ( + union_fields[1].clone(), + new_null_array(union_fields[1].data_type(), 1), + ), + ( + union_fields[2].clone(), + Arc::new(StringArray::from(vec!["e"])), + ), + ], + ) + .unwrap(); + + let schema = Arc::new(Schema::new(vec![Field::new_union( + "union", + type_ids.clone(), + union_fields.clone(), + UnionMode::Sparse, + )])); + + let batch1 = RecordBatch::try_new(schema.clone(), vec![Arc::new(arr1)]).unwrap(); + let batch2 = RecordBatch::try_new(schema.clone(), vec![Arc::new(arr2)]).unwrap(); + let batch3 = RecordBatch::try_new(schema.clone(), vec![Arc::new(arr3)]).unwrap(); + + let stream = futures::stream::iter(vec![Ok(batch1), Ok(batch2), Ok(batch3)]); + + let encoder = FlightDataEncoderBuilder::default().build(stream); + + let mut decoder = FlightDataDecoder::new(encoder); + + let hydrated_struct_fields = vec![Field::new_list( + "dict_list", + Field::new("item", DataType::Utf8, true), + true, + )]; + + let hydrated_union_fields = vec![ + Field::new_list("dict_list", Field::new("item", DataType::Utf8, true), true), + Field::new_struct("struct", hydrated_struct_fields.clone(), true), + Field::new("string", DataType::Utf8, true), + ]; + + let expected_schema = Schema::new(vec![Field::new_union( + "union", + type_ids.clone(), + hydrated_union_fields, + UnionMode::Sparse, + )]); + + let expected_schema = Arc::new(expected_schema); + + let mut expected_arrays = vec![ + StringArray::from_iter(vec![Some("a"), None, Some("b")]), + StringArray::from_iter(vec![Some("c"), None, Some("d")]), + StringArray::from(vec!["e"]), + ] + .into_iter(); + + let mut batch = 0; + while let Some(decoded) = decoder.next().await { + let decoded = decoded.unwrap(); + match 
+            match decoded.payload {
+                DecodedPayload::None => {}
+                DecodedPayload::Schema(s) => assert_eq!(s, expected_schema),
+                DecodedPayload::RecordBatch(b) => {
+                    assert_eq!(b.schema(), expected_schema);
+                    let expected_array = expected_arrays.next().unwrap();
+                    let union_arr =
+                        downcast_array::<UnionArray>(b.column_by_name("union").unwrap());
+
+                    let elem_array = match batch {
+                        0 => {
+                            let list_array = downcast_array::<ListArray>(union_arr.child(0));
+                            downcast_array::<StringArray>(list_array.value(0).as_ref())
+                        }
+                        1 => {
+                            let struct_array = downcast_array::<StructArray>(union_arr.child(1));
+                            let list_array = downcast_array::<ListArray>(struct_array.column(0));
+
+                            downcast_array::<StringArray>(list_array.value(0).as_ref())
+                        }
+                        _ => downcast_array::<StringArray>(union_arr.child(2)),
+                    };
+
+                    batch += 1;
+
+                    assert_eq!(elem_array, expected_array);
+                }
+            }
+        }
+    }
+
     #[tokio::test]
     async fn test_send_dictionaries() {
         let schema = Arc::new(Schema::new(vec![Field::new_dictionary(
@@ -683,7 +1075,7 @@ mod tests {
         )
         .expect("cannot create record batch");
 
-        prepare_batch_for_flight(&batch, batch.schema(), false).expect("failed to optimize");
+        hydrate_dictionaries(&batch, batch.schema()).expect("failed to optimize");
     }
 
     pub fn make_flight_data(