Skip to content

Commit

Permalink
Fix merge_dictionary_values in selection kernels (apache#4833)
Browse files Browse the repository at this point in the history
  • Loading branch information
tustvold authored Sep 19, 2023
1 parent 47e8a8d commit f7464bc
Showing 1 changed file with 23 additions and 9 deletions.
32 changes: 23 additions & 9 deletions arrow-select/src/dictionary.rs
Original file line number Diff line number Diff line change
Expand Up @@ -152,7 +152,7 @@ pub fn merge_dictionary_values<K: ArrowDictionaryKeyType>(
) -> Result<MergedDictionaries<K>, ArrowError> {
let mut num_values = 0;

let mut values = Vec::with_capacity(dictionaries.len());
let mut values_arrays = Vec::with_capacity(dictionaries.len());
let mut value_slices = Vec::with_capacity(dictionaries.len());

for (idx, dictionary) in dictionaries.iter().enumerate() {
Expand All @@ -164,11 +164,13 @@ pub fn merge_dictionary_values<K: ArrowDictionaryKeyType>(
(None, None) => None,
};
let keys = dictionary.keys().values();
let values_mask = compute_values_mask(keys, key_mask.as_ref());
let v = dictionary.values().as_ref();
num_values += v.len();
value_slices.push(get_masked_values(v, &values_mask));
values.push(v)
let values = dictionary.values().as_ref();
let values_mask = compute_values_mask(keys, key_mask.as_ref(), values.len());

let masked_values = get_masked_values(values, &values_mask);
num_values += masked_values.len();
value_slices.push(masked_values);
values_arrays.push(values)
}

// Map from value to new index
Expand Down Expand Up @@ -202,7 +204,7 @@ pub fn merge_dictionary_values<K: ArrowDictionaryKeyType>(

Ok(MergedDictionaries {
key_mappings,
values: interleave(&values, &indices)?,
values: interleave(&values_arrays, &indices)?,
})
}

Expand All @@ -211,9 +213,10 @@ pub fn merge_dictionary_values<K: ArrowDictionaryKeyType>(
fn compute_values_mask<K: ArrowNativeType>(
keys: &ScalarBuffer<K>,
mask: Option<&BooleanBuffer>,
max_key: usize,
) -> BooleanBuffer {
let mut builder = BooleanBufferBuilder::new(keys.len());
builder.advance(keys.len());
let mut builder = BooleanBufferBuilder::new(max_key);
builder.advance(max_key);

match mask {
Some(n) => n
Expand Down Expand Up @@ -330,4 +333,15 @@ mod tests {
assert_eq!(&merged.key_mappings[0], &[0, 0, 0, 1, 0]);
assert_eq!(&merged.key_mappings[1], &[]);
}

/// Merges a single dictionary whose keys buffer (one key) is shorter than its
/// values array (two values) and checks that only the referenced value, "b",
/// is present in the merged values.
#[test]
fn test_merge_keys_smaller() {
    let dict_values = StringArray::from_iter_values(["a", "b"]);
    // A single key pointing at index 1, so value "a" is never referenced.
    let dict_keys = Int32Array::from_iter_values([1]);
    let dictionary = DictionaryArray::new(dict_keys, Arc::new(dict_values));

    let merged = merge_dictionary_values(&[&dictionary], None).unwrap();

    let expected = StringArray::from(vec!["b"]);
    assert_eq!(merged.values.as_ref(), &expected);
}
}

0 comments on commit f7464bc

Please sign in to comment.