From e76a0b1ff0662822f98ae96570755a90e9b852d7 Mon Sep 17 00:00:00 2001 From: Ritchie Vink Date: Mon, 24 May 2021 14:44:27 +0200 Subject: [PATCH] respect offset in utf8 and list casts (#335) --- arrow/src/compute/kernels/cast.rs | 35 ++++++++++++++++++++++++++++++- 1 file changed, 34 insertions(+), 1 deletion(-) diff --git a/arrow/src/compute/kernels/cast.rs b/arrow/src/compute/kernels/cast.rs index de1516b0768b..49543b8c88fa 100644 --- a/arrow/src/compute/kernels/cast.rs +++ b/arrow/src/compute/kernels/cast.rs @@ -1687,6 +1687,7 @@ where }; let mut builder = ArrayData::builder(dtype) + .offset(array.offset()) .len(array.len()) .add_buffer(offset_buffer) .add_buffer(str_values_buf); @@ -1744,7 +1745,12 @@ where _ => unreachable!(), }; - let offsets = data.buffer::(0); + // Safety: + // The first buffer is the offsets and they are aligned to OffSetSizeFrom: (i64 or i32) + // Justification: + // The safe variant data.buffer:: take the offset into account and we + // cannot create a list array with offsets starting at non zero. + let offsets = unsafe { data.buffers()[0].as_slice().align_to::() }.1; let iter = offsets.iter().map(|idx| { let idx: OffsetSizeTo = NumCast::from(*idx).unwrap(); @@ -1757,6 +1763,7 @@ where // wrap up let mut builder = ArrayData::builder(out_dtype) + .offset(array.offset()) .len(array.len()) .add_buffer(offset_buffer) .add_child_data(value_data); @@ -3840,4 +3847,30 @@ mod tests { Dictionary(Box::new(DataType::UInt32), Box::new(DataType::Utf8)), ] } + + #[test] + fn test_utf8_cast_offsets() { + // test if offset of the array is taken into account during cast + let str_array = StringArray::from(vec!["a", "b", "c"]); + let str_array = str_array.slice(1, 2); + + let out = cast(&str_array, &DataType::LargeUtf8).unwrap(); + + let large_str_array = out.as_any().downcast_ref::().unwrap(); + let strs = large_str_array.into_iter().flatten().collect::>(); + assert_eq!(strs, &["b", "c"]) + } + + #[test] + fn test_list_cast_offsets() { + // test if offset of the array is taken into account during cast + let array1 = make_list_array().slice(1, 2); + let array2 = Arc::new(make_list_array()) as ArrayRef; + + let dt = DataType::LargeList(Box::new(Field::new("item", DataType::Int32, true))); + let out1 = cast(&array1, &dt).unwrap(); + let out2 = cast(&array2, &dt).unwrap(); + + assert_eq!(&out1, &out2.slice(1, 2)) + } }