diff --git a/arrow/src/compute/kernels/take.rs b/arrow/src/compute/kernels/take.rs index 692de278974d..8d9b4cb3cf07 100644 --- a/arrow/src/compute/kernels/take.rs +++ b/arrow/src/compute/kernels/take.rs @@ -280,6 +280,17 @@ where .unwrap(); Ok(Arc::new(take_fixed_size_binary(values, indices)?)) } + DataType::Null => { + // Take applied to a null array produces a null array. + if values.len() >= indices.len() { + // If the existing null array is as big as the indices, we can use a slice of it + // to avoid allocating a new null array. + Ok(values.slice(0, indices.len())) + } else { + // If the existing null array isn't big enough, create a new one. + Ok(new_null_array(&DataType::Null, indices.len())) + } + } t => unimplemented!("Take not supported for data type {:?}", t), } } @@ -1813,6 +1824,38 @@ mod tests { .unwrap(); } + #[test] + fn test_null_array_smaller_than_indices() { + let values = NullArray::new(2); + let indices = UInt32Array::from(vec![Some(0), None, Some(15)]); + + let result = take(&values, &indices, None).unwrap(); + let expected: ArrayRef = Arc::new(NullArray::new(3)); + assert_eq!(&result, &expected); + } + + #[test] + fn test_null_array_larger_than_indices() { + let values = NullArray::new(5); + let indices = UInt32Array::from(vec![Some(0), None, Some(15)]); + + let result = take(&values, &indices, None).unwrap(); + let expected: ArrayRef = Arc::new(NullArray::new(3)); + assert_eq!(&result, &expected); + } + + #[test] + fn test_null_array_indices_out_of_bounds() { + let values = NullArray::new(5); + let indices = UInt32Array::from(vec![Some(0), None, Some(15)]); + + let result = take(&values, &indices, Some(TakeOptions { check_bounds: true })); + assert_eq!( + result.unwrap_err().to_string(), + "Compute error: Array index out of bounds, cannot get item at index 15 from 5 entries" + ); + } + #[test] fn test_take_dict() { let keys_builder = Int16Builder::new(8);