diff --git a/arrow/src/compute/kernels/sort.rs b/arrow/src/compute/kernels/sort.rs index 9287425bf126..7cd463d6ebb9 100644 --- a/arrow/src/compute/kernels/sort.rs +++ b/arrow/src/compute/kernels/sort.rs @@ -410,24 +410,27 @@ fn sort_boolean( len = limit.min(len); } if !descending { - sort_by(&mut valids, len - nulls_len, |a, b| cmp(a.1, b.1)); + sort_by(&mut valids, len.saturating_sub(nulls_len), |a, b| { + cmp(a.1, b.1) + }); } else { - sort_by(&mut valids, len - nulls_len, |a, b| cmp(a.1, b.1).reverse()); + sort_by(&mut valids, len.saturating_sub(nulls_len), |a, b| { + cmp(a.1, b.1).reverse() + }); // reverse to keep a stable ordering nulls.reverse(); } // collect results directly into a buffer instead of a vec to avoid another aligned allocation - let mut result = MutableBuffer::new(values.len() * std::mem::size_of::()); + let result_capacity = len * std::mem::size_of::(); + let mut result = MutableBuffer::new(result_capacity); // sets len to capacity so we can access the whole buffer as a typed slice - result.resize(values.len() * std::mem::size_of::(), 0); + result.resize(result_capacity, 0); let result_slice: &mut [u32] = result.typed_data_mut(); - debug_assert_eq!(result_slice.len(), nulls_len + valids_len); - if options.nulls_first { let size = nulls_len.min(len); - result_slice[0..nulls_len.min(len)].copy_from_slice(&nulls); + result_slice[0..size].copy_from_slice(&nulls[0..size]); if nulls_len < len { insert_valid_values(result_slice, nulls_len, &valids[0..len - size]); } @@ -626,9 +629,13 @@ where len = limit.min(len); } if !descending { - sort_by(&mut valids, len - nulls_len, |a, b| cmp(a.1, b.1)); + sort_by(&mut valids, len.saturating_sub(nulls_len), |a, b| { + cmp(a.1, b.1) + }); } else { - sort_by(&mut valids, len - nulls_len, |a, b| cmp(a.1, b.1).reverse()); + sort_by(&mut valids, len.saturating_sub(nulls_len), |a, b| { + cmp(a.1, b.1).reverse() + }); // reverse to keep a stable ordering nulls.reverse(); } @@ -689,11 +696,11 @@ where len = limit.min(len); } if !descending { - sort_by(&mut valids, len - nulls_len, |a, b| { + sort_by(&mut valids, len.saturating_sub(nulls_len), |a, b| { cmp_array(a.1.as_ref(), b.1.as_ref()) }); } else { - sort_by(&mut valids, len - nulls_len, |a, b| { + sort_by(&mut valids, len.saturating_sub(nulls_len), |a, b| { cmp_array(a.1.as_ref(), b.1.as_ref()).reverse() }); // reverse to keep a stable ordering @@ -1285,6 +1292,48 @@ mod tests { None, vec![5, 0, 2, 1, 4, 3], ); + + // valid values less than limit with extra nulls + test_sort_to_indices_primitive_arrays::( + vec![Some(2.0), None, None, Some(1.0)], + Some(SortOptions { + descending: false, + nulls_first: false, + }), + Some(3), + vec![3, 0, 1], + ); + + test_sort_to_indices_primitive_arrays::( + vec![Some(2.0), None, None, Some(1.0)], + Some(SortOptions { + descending: false, + nulls_first: true, + }), + Some(3), + vec![1, 2, 3], + ); + + // more nulls than limit + test_sort_to_indices_primitive_arrays::( + vec![Some(1.0), None, None, None], + Some(SortOptions { + descending: false, + nulls_first: true, + }), + Some(2), + vec![1, 2], + ); + + test_sort_to_indices_primitive_arrays::( + vec![Some(1.0), None, None, None], + Some(SortOptions { + descending: false, + nulls_first: false, + }), + Some(2), + vec![0, 1], + ); } #[test] @@ -1329,6 +1378,48 @@ mod tests { Some(3), vec![5, 0, 2], ); + + // valid values less than limit with extra nulls + test_sort_to_indices_boolean_arrays( + vec![Some(true), None, None, Some(false)], + Some(SortOptions { + descending: false, + nulls_first: false, + }), + Some(3), + vec![3, 0, 1], + ); + + test_sort_to_indices_boolean_arrays( + vec![Some(true), None, None, Some(false)], + Some(SortOptions { + descending: false, + nulls_first: true, + }), + Some(3), + vec![1, 2, 3], + ); + + // more nulls than limit + test_sort_to_indices_boolean_arrays( + vec![Some(true), None, None, None], + Some(SortOptions { + descending: false, + nulls_first: true, + }), + Some(2), + vec![1, 2], + ); + + test_sort_to_indices_boolean_arrays( + vec![Some(true), None, None, None], + Some(SortOptions { + descending: false, + nulls_first: false, + }), + Some(2), + vec![0, 1], + ); } #[test] @@ -1686,6 +1777,48 @@ mod tests { Some(3), vec![3, 0, 2], ); + + // valid values less than limit with extra nulls + test_sort_to_indices_string_arrays( + vec![Some("def"), None, None, Some("abc")], + Some(SortOptions { + descending: false, + nulls_first: false, + }), + Some(3), + vec![3, 0, 1], + ); + + test_sort_to_indices_string_arrays( + vec![Some("def"), None, None, Some("abc")], + Some(SortOptions { + descending: false, + nulls_first: true, + }), + Some(3), + vec![1, 2, 3], + ); + + // more nulls than limit + test_sort_to_indices_string_arrays( + vec![Some("def"), None, None, None], + Some(SortOptions { + descending: false, + nulls_first: true, + }), + Some(2), + vec![1, 2], + ); + + test_sort_to_indices_string_arrays( + vec![Some("def"), None, None, None], + Some(SortOptions { + descending: false, + nulls_first: false, + }), + Some(2), + vec![0, 1], + ); } #[test] @@ -1799,6 +1932,48 @@ mod tests { Some(3), vec![None, None, Some("sad")], ); + + // valid values less than limit with extra nulls + test_sort_string_arrays( + vec![Some("def"), None, None, Some("abc")], + Some(SortOptions { + descending: false, + nulls_first: false, + }), + Some(3), + vec![Some("abc"), Some("def"), None], + ); + + test_sort_string_arrays( + vec![Some("def"), None, None, Some("abc")], + Some(SortOptions { + descending: false, + nulls_first: true, + }), + Some(3), + vec![None, None, Some("abc")], + ); + + // more nulls than limit + test_sort_string_arrays( + vec![Some("def"), None, None, None], + Some(SortOptions { + descending: false, + nulls_first: true, + }), + Some(2), + vec![None, None], + ); + + test_sort_string_arrays( + vec![Some("def"), None, None, None], + Some(SortOptions { + descending: false, + nulls_first: false, + }), + Some(2), + vec![Some("def"), None], + ); } #[test] @@ -1912,6 +2087,48 @@ mod tests { Some(3), vec![None, None, Some("sad")], ); + + // valid values less than limit with extra nulls + test_sort_string_dict_arrays::( + vec![Some("def"), None, None, Some("abc")], + Some(SortOptions { + descending: false, + nulls_first: false, + }), + Some(3), + vec![Some("abc"), Some("def"), None], + ); + + test_sort_string_dict_arrays::( + vec![Some("def"), None, None, Some("abc")], + Some(SortOptions { + descending: false, + nulls_first: true, + }), + Some(3), + vec![None, None, Some("abc")], + ); + + // more nulls than limit + test_sort_string_dict_arrays::( + vec![Some("def"), None, None, None], + Some(SortOptions { + descending: false, + nulls_first: true, + }), + Some(2), + vec![None, None], + ); + + test_sort_string_dict_arrays::( + vec![Some("def"), None, None, None], + Some(SortOptions { + descending: false, + nulls_first: false, + }), + Some(2), + vec![Some("def"), None], + ); } #[test] @@ -1999,6 +2216,52 @@ mod tests { vec![Some(vec![Some(1), Some(0)]), Some(vec![Some(1), Some(1)])], None, ); + + // valid values less than limit with extra nulls + test_sort_list_arrays::( + vec![Some(vec![Some(1)]), None, None, Some(vec![Some(2)])], + Some(SortOptions { + descending: false, + nulls_first: false, + }), + Some(3), + vec![Some(vec![Some(1)]), Some(vec![Some(2)]), None], + None, + ); + + test_sort_list_arrays::( + vec![Some(vec![Some(1)]), None, None, Some(vec![Some(2)])], + Some(SortOptions { + descending: false, + nulls_first: true, + }), + Some(3), + vec![None, None, Some(vec![Some(2)])], + None, + ); + + // more nulls than limit + test_sort_list_arrays::( + vec![Some(vec![Some(1)]), None, None, None], + Some(SortOptions { + descending: false, + nulls_first: true, + }), + Some(2), + vec![None, None], + None, + ); + + test_sort_list_arrays::( + vec![Some(vec![Some(1)]), None, None, None], + Some(SortOptions { + descending: false, + nulls_first: false, + }), + Some(2), + vec![Some(vec![Some(1)]), None], + None, + ); } #[test]