-
Notifications
You must be signed in to change notification settings - Fork 803
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Improve performance of casting DictionaryArray
to StringViewArray
#5871
Changes from 4 commits
63db275
0308dd4
c27d548
fdebc0e
ca11706
52c28c1
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -85,10 +85,69 @@ pub(crate) fn dictionary_cast<K: ArrowDictionaryKeyType>( | |
|
||
Ok(new_array) | ||
} | ||
Utf8View => { | ||
// `unpack_dictionary` can handle Utf8View/BinaryView types, but incurs unnecessary data copy of the value buffer. | ||
alamb marked this conversation as resolved.
Show resolved
Hide resolved
|
||
// we handle it here to avoid the copy. | ||
let dict_array = array | ||
.as_dictionary::<K>() | ||
.downcast_dict::<StringArray>() | ||
.unwrap(); | ||
|
||
let string_view = view_from_dict_values::<K, StringViewType, GenericStringType<i32>>( | ||
dict_array.values(), | ||
dict_array.keys(), | ||
); | ||
Ok(Arc::new(string_view)) | ||
} | ||
BinaryView => { | ||
// `unpack_dictionary` can handle Utf8View/BinaryView types, but incurs unnecessary data copy of the value buffer. | ||
// we handle it here to avoid the copy. | ||
let dict_array = array | ||
.as_dictionary::<K>() | ||
.downcast_dict::<BinaryArray>() | ||
.unwrap(); | ||
|
||
let binary_view = view_from_dict_values::<K, BinaryViewType, BinaryType>( | ||
dict_array.values(), | ||
dict_array.keys(), | ||
); | ||
Ok(Arc::new(binary_view)) | ||
} | ||
_ => unpack_dictionary::<K>(array, to_type, cast_options), | ||
} | ||
} | ||
|
||
fn view_from_dict_values<K: ArrowDictionaryKeyType, T: ByteViewType, V: ByteArrayType>( | ||
array: &GenericByteArray<V>, | ||
keys: &PrimitiveArray<K>, | ||
) -> GenericByteViewArray<T> { | ||
let value_buffer = array.values(); | ||
let value_offsets = array.value_offsets(); | ||
let mut builder = GenericByteViewBuilder::<T>::with_capacity(keys.len()); | ||
builder.append_block(value_buffer.clone()); | ||
for i in keys.iter() { | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Another potential optimization is a separate loop if there are no nulls in keys (so we can avoid the branch) Another potental idea is to use However, I think we should merge this basic PR in as is, and then add a bencmark and optimize this kenrnel as a follow on PR (if we care). I can file a ticket if @tustvold agrees There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. ops, I checked in a But if we do There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. I can also move |
||
match i { | ||
Some(v) => { | ||
let idx = v.to_usize().unwrap(); | ||
|
||
// Safety | ||
// (1) The index is within bounds as they are offsets | ||
// (2) The append_view is safe | ||
unsafe { | ||
let offset = value_offsets.get_unchecked(idx).as_usize(); | ||
let end = value_offsets.get_unchecked(idx + 1).as_usize(); | ||
let length = end - offset; | ||
builder.append_view_unchecked(0, offset as u32, length as u32) | ||
} | ||
} | ||
None => { | ||
builder.append_null(); | ||
} | ||
} | ||
} | ||
builder.finish() | ||
} | ||
|
||
// Unpack a dictionary where the keys are of type <K> into a flattened array of type to_type | ||
pub(crate) fn unpack_dictionary<K>( | ||
array: &dyn Array, | ||
|
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -5203,19 +5203,19 @@ mod tests { | |
_test_string_to_view::<i64>(); | ||
} | ||
|
||
const VIEW_TEST_DATA: [Option<&str>; 5] = [ | ||
Some("hello"), | ||
Some("world"), | ||
None, | ||
Some("large payload over 12 bytes"), | ||
Some("lulu"), | ||
]; | ||
|
||
fn _test_string_to_view<O>() | ||
where | ||
O: OffsetSizeTrait, | ||
{ | ||
let data = vec![ | ||
Some("hello"), | ||
Some("world"), | ||
None, | ||
Some("large payload over 12 bytes"), | ||
Some("lulu"), | ||
]; | ||
|
||
let string_array = GenericStringArray::<O>::from(data.clone()); | ||
let string_array = GenericStringArray::<O>::from_iter(VIEW_TEST_DATA); | ||
|
||
assert!(can_cast_types( | ||
string_array.data_type(), | ||
|
@@ -5225,7 +5225,7 @@ mod tests { | |
let string_view_array = cast(&string_array, &DataType::Utf8View).unwrap(); | ||
assert_eq!(string_view_array.data_type(), &DataType::Utf8View); | ||
|
||
let expect_string_view_array = StringViewArray::from(data); | ||
let expect_string_view_array = StringViewArray::from_iter(VIEW_TEST_DATA); | ||
assert_eq!(string_view_array.as_ref(), &expect_string_view_array); | ||
} | ||
|
||
|
@@ -5239,15 +5239,7 @@ mod tests { | |
where | ||
O: OffsetSizeTrait, | ||
{ | ||
let data: Vec<Option<&[u8]>> = vec![ | ||
Some(b"hello"), | ||
Some(b"world"), | ||
None, | ||
Some(b"large payload over 12 bytes"), | ||
Some(b"lulu"), | ||
]; | ||
|
||
let binary_array = GenericBinaryArray::<O>::from(data.clone()); | ||
let binary_array = GenericBinaryArray::<O>::from_iter(VIEW_TEST_DATA); | ||
|
||
assert!(can_cast_types( | ||
binary_array.data_type(), | ||
|
@@ -5257,10 +5249,48 @@ mod tests { | |
let binary_view_array = cast(&binary_array, &DataType::BinaryView).unwrap(); | ||
assert_eq!(binary_view_array.data_type(), &DataType::BinaryView); | ||
|
||
let expect_binary_view_array = BinaryViewArray::from(data); | ||
let expect_binary_view_array = BinaryViewArray::from_iter(VIEW_TEST_DATA); | ||
assert_eq!(binary_view_array.as_ref(), &expect_binary_view_array); | ||
} | ||
|
||
#[test] | ||
fn test_dict_to_view() { | ||
let values = StringArray::from_iter(VIEW_TEST_DATA); | ||
let keys = Int8Array::from_iter([Some(1), Some(0), None, Some(3), None, Some(1), Some(4)]); | ||
let string_dict_array = | ||
DictionaryArray::<Int8Type>::try_new(keys, Arc::new(values)).unwrap(); | ||
let typed_dict = string_dict_array.downcast_dict::<StringArray>().unwrap(); | ||
|
||
let string_view_array = { | ||
let mut builder = StringViewBuilder::new().with_block_size(8); // multiple buffers. | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. 👍 |
||
for v in typed_dict.into_iter() { | ||
builder.append_option(v); | ||
} | ||
builder.finish() | ||
}; | ||
let expected_string_array_type = string_view_array.data_type(); | ||
let casted_string_array = cast(&string_dict_array, expected_string_array_type).unwrap(); | ||
assert_eq!(casted_string_array.data_type(), expected_string_array_type); | ||
assert_eq!(casted_string_array.as_ref(), &string_view_array); | ||
|
||
let binary_buffer = cast(&typed_dict.values(), &DataType::Binary).unwrap(); | ||
let binary_dict_array = | ||
DictionaryArray::<Int8Type>::new(typed_dict.keys().clone(), binary_buffer); | ||
let typed_binary_dict = binary_dict_array.downcast_dict::<BinaryArray>().unwrap(); | ||
|
||
let binary_view_array = { | ||
let mut builder = BinaryViewBuilder::new().with_block_size(8); // multiple buffers. | ||
for v in typed_binary_dict.into_iter() { | ||
builder.append_option(v); | ||
} | ||
builder.finish() | ||
}; | ||
let expected_binary_array_type = binary_view_array.data_type(); | ||
let casted_binary_array = cast(&binary_dict_array, expected_binary_array_type).unwrap(); | ||
assert_eq!(casted_binary_array.data_type(), expected_binary_array_type); | ||
assert_eq!(casted_binary_array.as_ref(), &binary_view_array); | ||
} | ||
|
||
#[test] | ||
fn test_view_to_string() { | ||
_test_view_to_string::<i32>(); | ||
|
@@ -5271,24 +5301,15 @@ mod tests { | |
where | ||
O: OffsetSizeTrait, | ||
{ | ||
let data: Vec<Option<&str>> = vec![ | ||
Some("hello"), | ||
Some("world"), | ||
None, | ||
Some("large payload over 12 bytes"), | ||
Some("lulu"), | ||
]; | ||
|
||
let view_array = { | ||
// ["hello", "world", null, "large payload over 12 bytes", "lulu"] | ||
let mut builder = StringViewBuilder::new().with_block_size(8); // multiple buffers. | ||
for s in data.iter() { | ||
for s in VIEW_TEST_DATA.iter() { | ||
builder.append_option(*s); | ||
} | ||
builder.finish() | ||
}; | ||
|
||
let expected_string_array = GenericStringArray::<O>::from(data); | ||
let expected_string_array = GenericStringArray::<O>::from_iter(VIEW_TEST_DATA); | ||
let expected_type = expected_string_array.data_type(); | ||
|
||
assert!(can_cast_types(view_array.data_type(), expected_type)); | ||
|
@@ -5318,7 +5339,6 @@ mod tests { | |
]; | ||
|
||
let view_array = { | ||
// ["hello", "world", null, "large payload over 12 bytes", "lulu"] | ||
let mut builder = BinaryViewBuilder::new().with_block_size(8); // multiple buffers. | ||
for s in data.iter() { | ||
builder.append_option(*s); | ||
|
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
💯