diff --git a/arrow/src/compute/kernels/take.rs b/arrow/src/compute/kernels/take.rs index d325ce44a185..5b22f550c108 100644 --- a/arrow/src/compute/kernels/take.rs +++ b/arrow/src/compute/kernels/take.rs @@ -357,7 +357,13 @@ where // Soundness: `slice.map` is `TrustedLen`. let buffer = unsafe { Buffer::try_from_trusted_len_iter(values)? }; - Ok((buffer, indices.data_ref().null_buffer().cloned())) + Ok(( + buffer, + indices + .data_ref() + .null_buffer() + .map(|b| b.bit_slice(indices.offset(), indices.len())), + )) } // take implementation when both values and indices contain nulls @@ -516,7 +522,7 @@ where nulls = match indices.data_ref().null_buffer() { Some(buffer) => Some(buffer_bin_and( buffer, - 0, + indices.offset(), &null_buf.into(), 0, indices.len(), @@ -805,6 +811,24 @@ mod tests { Ok(()) } + fn test_take_primitive_arrays_non_null( + data: Vec, + index: &UInt32Array, + options: Option, + expected_data: Vec>, + ) -> Result<()> + where + T: ArrowPrimitiveType, + PrimitiveArray: From>, + PrimitiveArray: From>>, + { + let output = PrimitiveArray::::from(data); + let expected = Arc::new(PrimitiveArray::::from(expected_data)) as ArrayRef; + let output = take(&output, index, options)?; + assert_eq!(&output, &expected); + Ok(()) + } + fn test_take_impl_primitive_arrays( data: Vec>, index: &PrimitiveArray, @@ -876,6 +900,48 @@ mod tests { .unwrap(); } + #[test] + fn test_take_primitive_nullable_indices_non_null_values_with_offset() { + let index = + UInt32Array::from(vec![Some(0), Some(1), Some(2), Some(3), None, None]); + let index = index.slice(2, 4); + let index = index.as_any().downcast_ref::().unwrap(); + + assert_eq!( + index, + &UInt32Array::from(vec![Some(2), Some(3), None, None]) + ); + + test_take_primitive_arrays_non_null::( + vec![0, 10, 20, 30, 40, 50], + &index, + None, + vec![Some(20), Some(30), None, None], + ) + .unwrap(); + } + + #[test] + fn test_take_primitive_nullable_indices_nullable_values_with_offset() { + let index = + UInt32Array::from(vec![Some(0), Some(1), Some(2), Some(3), None, None]); + let index = index.slice(2, 4); + let index = index.as_any().downcast_ref::().unwrap(); + + assert_eq!( + index, + &UInt32Array::from(vec![Some(2), Some(3), None, None]) + ); + + test_take_primitive_arrays::( + vec![None, None, Some(20), Some(30), Some(40), Some(50)], + &index, + None, + vec![Some(20), Some(30), None, None], + ) + .unwrap(); + } + #[test] fn test_take_primitive() { let index = UInt32Array::from(vec![Some(3), None, Some(1), Some(3), Some(2)]); @@ -1100,7 +1166,7 @@ mod tests { } #[test] - fn test_take_primitive_bool() { + fn test_take_bool() { let index = UInt32Array::from(vec![Some(3), None, Some(1), Some(3), Some(2)]); // boolean test_take_boolean_arrays( @@ -1111,6 +1177,25 @@ mod tests { ); } + #[test] + fn test_take_bool_with_offset() { + let index = + UInt32Array::from(vec![Some(3), None, Some(1), Some(3), Some(2), None]); + let index = index.slice(2, 4); + let index = index + .as_any() + .downcast_ref::>() + .unwrap(); + + // boolean + test_take_boolean_arrays( + vec![Some(false), None, Some(true), Some(false), None], + &index, + None, + vec![None, Some(false), Some(true), None], + ); + } + fn _test_take_string<'a, K: 'static>() where K: Array + PartialEq + From>>,