Skip to content

Commit

Permalink
respect offset in utf8 and list casts (#335)
Browse files Browse the repository at this point in the history
  • Loading branch information
ritchie46 authored May 24, 2021
1 parent 4c17ac8 commit 5ac771a
Showing 1 changed file with 34 additions and 1 deletion.
35 changes: 34 additions & 1 deletion arrow/src/compute/kernels/cast.rs
Original file line number Diff line number Diff line change
Expand Up @@ -1687,6 +1687,7 @@ where
};

let mut builder = ArrayData::builder(dtype)
.offset(array.offset())
.len(array.len())
.add_buffer(offset_buffer)
.add_buffer(str_values_buf);
Expand Down Expand Up @@ -1744,7 +1745,12 @@ where
_ => unreachable!(),
};

let offsets = data.buffer::<OffsetSizeFrom>(0);
// Safety:
// The first buffer is the offsets and they are aligned to OffSetSizeFrom: (i64 or i32)
// Justification:
// The safe variant data.buffer::<OffsetSizeFrom> take the offset into account and we
// cannot create a list array with offsets starting at non zero.
let offsets = unsafe { data.buffers()[0].as_slice().align_to::<OffsetSizeFrom>() }.1;

let iter = offsets.iter().map(|idx| {
let idx: OffsetSizeTo = NumCast::from(*idx).unwrap();
Expand All @@ -1757,6 +1763,7 @@ where

// wrap up
let mut builder = ArrayData::builder(out_dtype)
.offset(array.offset())
.len(array.len())
.add_buffer(offset_buffer)
.add_child_data(value_data);
Expand Down Expand Up @@ -3841,4 +3848,30 @@ mod tests {
Dictionary(Box::new(DataType::UInt32), Box::new(DataType::Utf8)),
]
}

#[test]
fn test_utf8_cast_offsets() {
// test if offset of the array is taken into account during cast
let str_array = StringArray::from(vec!["a", "b", "c"]);
let str_array = str_array.slice(1, 2);

let out = cast(&str_array, &DataType::LargeUtf8).unwrap();

let large_str_array = out.as_any().downcast_ref::<LargeStringArray>().unwrap();
let strs = large_str_array.into_iter().flatten().collect::<Vec<_>>();
assert_eq!(strs, &["b", "c"])
}

#[test]
fn test_list_cast_offsets() {
// test if offset of the array is taken into account during cast
let array1 = make_list_array().slice(1, 2);
let array2 = Arc::new(make_list_array()) as ArrayRef;

let dt = DataType::LargeList(Box::new(Field::new("item", DataType::Int32, true)));
let out1 = cast(&array1, &dt).unwrap();
let out2 = cast(&array2, &dt).unwrap();

assert_eq!(&out1, &out2.slice(1, 2))
}
}

0 comments on commit 5ac771a

Please sign in to comment.