Skip to content

Commit

Permalink
Mark typed buffer APIs safe (apache#996) (apache#1027)
Browse files Browse the repository at this point in the history
  • Loading branch information
tustvold committed Jun 13, 2022
1 parent fcf655e commit 7d64fc6
Show file tree
Hide file tree
Showing 14 changed files with 39 additions and 55 deletions.
8 changes: 3 additions & 5 deletions arrow/src/array/array_union.rs
Original file line number Diff line number Diff line change
Expand Up @@ -185,7 +185,7 @@ impl UnionArray {
}

// Check the type_ids
let type_id_slice: &[i8] = unsafe { type_ids.typed_data() };
let type_id_slice: &[i8] = type_ids.typed_data();
let invalid_type_ids = type_id_slice
.iter()
.filter(|i| *i < &0)
Expand All @@ -201,7 +201,7 @@ impl UnionArray {
// Check the value offsets if provided
if let Some(offset_buffer) = &value_offsets {
let max_len = type_ids.len() as i32;
let offsets_slice: &[i32] = unsafe { offset_buffer.typed_data() };
let offsets_slice: &[i32] = offset_buffer.typed_data();
let invalid_offsets = offsets_slice
.iter()
.filter(|i| *i < &0 || *i > &max_len)
Expand Down Expand Up @@ -255,9 +255,7 @@ impl UnionArray {
pub fn value_offset(&self, index: usize) -> i32 {
assert!(index - self.offset() < self.len());
if self.is_dense() {
// safety: reinterpreting is safe since the offset buffer contains `i32` values and is
// properly aligned.
unsafe { self.data().buffers()[1].typed_data::<i32>()[index] }
self.data().buffers()[1].typed_data::<i32>()[index]
} else {
index as i32
}
Expand Down
4 changes: 2 additions & 2 deletions arrow/src/array/builder.rs
Original file line number Diff line number Diff line change
Expand Up @@ -76,7 +76,7 @@ pub(crate) fn builder_to_mutable_buffer<T: ArrowNativeType>(
/// builder.append(45);
/// let buffer = builder.finish();
///
/// assert_eq!(unsafe { buffer.typed_data::<u8>() }, &[42, 43, 44, 45]);
/// assert_eq!(buffer.typed_data::<u8>(), &[42, 43, 44, 45]);
/// # Ok(())
/// # }
/// ```
Expand Down Expand Up @@ -291,7 +291,7 @@ impl<T: ArrowNativeType> BufferBuilder<T> {
///
/// let buffer = builder.finish();
///
/// assert_eq!(unsafe { buffer.typed_data::<u8>() }, &[42, 44, 46]);
/// assert_eq!(buffer.typed_data::<u8>(), &[42, 44, 46]);
/// ```
#[inline]
pub fn finish(&mut self) -> Buffer {
Expand Down
5 changes: 2 additions & 3 deletions arrow/src/array/data.rs
Original file line number Diff line number Diff line change
Expand Up @@ -767,8 +767,7 @@ impl ArrayData {
)));
}

// SAFETY: Bounds checked above
Ok(unsafe { &(buffer.typed_data::<T>()[self.offset..self.offset + len]) })
Ok(&buffer.typed_data::<T>()[self.offset..self.offset + len])
}

/// Does a cheap sanity check that the `self.len` values in `buffer` are valid
Expand Down Expand Up @@ -1161,7 +1160,7 @@ impl ArrayData {

// Justification: buffer size was validated above
let indexes: &[T] =
unsafe { &(buffer.typed_data::<T>()[self.offset..self.offset + self.len]) };
&buffer.typed_data::<T>()[self.offset..self.offset + self.len];

indexes.iter().enumerate().try_for_each(|(i, &dict_index)| {
// Do not check the value is null (value can be arbitrary)
Expand Down
25 changes: 10 additions & 15 deletions arrow/src/buffer/immutable.rs
Original file line number Diff line number Diff line change
Expand Up @@ -181,19 +181,14 @@ impl Buffer {

/// View buffer as typed slice.
///
/// # Safety
/// # Panics
///
/// `ArrowNativeType` is public so that it can be used as a trait bound for other public
/// components, such as the `ToByteSlice` trait. However, this means that it can be
/// implemented by user defined types, which it is not intended for.
pub unsafe fn typed_data<T: ArrowNativeType + num::Num>(&self) -> &[T] {
// JUSTIFICATION
// Benefit
// Many of the buffers represent specific types, and consumers of `Buffer` often need to re-interpret them.
// Soundness
// * The pointer is non-null by construction
// * alignment asserted below.
let (prefix, offsets, suffix) = self.as_slice().align_to::<T>();
/// This function panics if the underlying buffer is not aligned
/// correctly for type `T`.
pub fn typed_data<T: ArrowNativeType>(&self) -> &[T] {
// SAFETY
// ArrowNativeType are trivially transmutable, and this method checks alignment
let (prefix, offsets, suffix) = unsafe { self.as_slice().align_to::<T>() };
assert!(prefix.is_empty() && suffix.is_empty());
offsets
}
Expand Down Expand Up @@ -451,7 +446,7 @@ mod tests {
macro_rules! check_as_typed_data {
($input: expr, $native_t: ty) => {{
let buffer = Buffer::from_slice_ref($input);
let slice: &[$native_t] = unsafe { buffer.typed_data::<$native_t>() };
let slice: &[$native_t] = buffer.typed_data::<$native_t>();
assert_eq!($input, slice);
}};
}
Expand Down Expand Up @@ -573,12 +568,12 @@ mod tests {
)
};

let slice = unsafe { buffer.typed_data::<i32>() };
let slice = buffer.typed_data::<i32>();
assert_eq!(slice, &[1, 2, 3, 4, 5]);

let buffer = buffer.slice(std::mem::size_of::<i32>());

let slice = unsafe { buffer.typed_data::<i32>() };
let slice = buffer.typed_data::<i32>();
assert_eq!(slice, &[2, 3, 4, 5]);
}
}
13 changes: 5 additions & 8 deletions arrow/src/buffer/mutable.rs
Original file line number Diff line number Diff line change
Expand Up @@ -275,17 +275,14 @@ impl MutableBuffer {

/// View this buffer asa slice of a specific type.
///
/// # Safety
///
/// This function must only be used with buffers which are treated
/// as type `T` (e.g. extended with items of type `T`).
///
/// # Panics
///
/// This function panics if the underlying buffer is not aligned
/// correctly for type `T`.
pub unsafe fn typed_data_mut<T: ArrowNativeType>(&mut self) -> &mut [T] {
let (prefix, offsets, suffix) = self.as_slice_mut().align_to_mut::<T>();
pub fn typed_data_mut<T: ArrowNativeType>(&mut self) -> &mut [T] {
// SAFETY
// ArrowNativeType are trivially transmutable, and this method checks alignment
let (prefix, offsets, suffix) = unsafe { self.as_slice_mut().align_to_mut::<T>() };
assert!(prefix.is_empty() && suffix.is_empty());
offsets
}
Expand All @@ -299,7 +296,7 @@ impl MutableBuffer {
/// assert_eq!(buffer.len(), 8) // u32 has 4 bytes
/// ```
#[inline]
pub fn extend_from_slice<T: ToByteSlice>(&mut self, items: &[T]) {
pub fn extend_from_slice<T: ArrowNativeType>(&mut self, items: &[T]) {
let len = items.len();
let additional = len * std::mem::size_of::<T>();
self.reserve(additional);
Expand Down
4 changes: 1 addition & 3 deletions arrow/src/buffer/ops.rs
Original file line number Diff line number Diff line change
Expand Up @@ -68,9 +68,7 @@ where

let left_chunks = left.bit_chunks(offset_in_bits, len_in_bits);

// Safety: buffer is always treated as type `u64` in the code
// below.
let result_chunks = unsafe { result.typed_data_mut::<u64>().iter_mut() };
let result_chunks = result.typed_data_mut::<u64>().iter_mut();

result_chunks
.zip(left_chunks.iter())
Expand Down
2 changes: 1 addition & 1 deletion arrow/src/compute/kernels/cast.rs
Original file line number Diff line number Diff line change
Expand Up @@ -2084,7 +2084,7 @@ where
let list_data = array.data();
let str_values_buf = str_array.value_data();

let offsets = unsafe { list_data.buffers()[0].typed_data::<OffsetSizeFrom>() };
let offsets = list_data.buffers()[0].typed_data::<OffsetSizeFrom>();

let mut offset_builder = BufferBuilder::<OffsetSizeTo>::new(offsets.len());
offsets.iter().try_for_each::<_, Result<_>>(|offset| {
Expand Down
6 changes: 2 additions & 4 deletions arrow/src/compute/kernels/sort.rs
Original file line number Diff line number Diff line change
Expand Up @@ -452,8 +452,7 @@ fn sort_boolean(
let mut result = MutableBuffer::new(result_capacity);
// sets len to capacity so we can access the whole buffer as a typed slice
result.resize(result_capacity, 0);
// Safety: the buffer is always treated as `u32` in the code below
let result_slice: &mut [u32] = unsafe { result.typed_data_mut() };
let result_slice: &mut [u32] = result.typed_data_mut();

if options.nulls_first {
let size = nulls_len.min(len);
Expand Down Expand Up @@ -565,8 +564,7 @@ where
let mut result = MutableBuffer::new(result_capacity);
// sets len to capacity so we can access the whole buffer as a typed slice
result.resize(result_capacity, 0);
// Safety: the buffer is always treated as `u32` in the code below
let result_slice: &mut [u32] = unsafe { result.typed_data_mut() };
let result_slice: &mut [u32] = result.typed_data_mut();

if options.nulls_first {
let size = nulls_len.min(len);
Expand Down
3 changes: 1 addition & 2 deletions arrow/src/compute/kernels/take.rs
Original file line number Diff line number Diff line change
Expand Up @@ -688,8 +688,7 @@ where
let bytes_offset = (data_len + 1) * std::mem::size_of::<OffsetSize>();
let mut offsets_buffer = MutableBuffer::from_len_zeroed(bytes_offset);

// Safety: the buffer is always treated as as a type of `OffsetSize` in the code below
let offsets = unsafe { offsets_buffer.typed_data_mut() };
let offsets = offsets_buffer.typed_data_mut();
let mut values = MutableBuffer::new(0);
let mut length_so_far = OffsetSize::zero();
offsets[0] = length_so_far;
Expand Down
4 changes: 2 additions & 2 deletions parquet/src/arrow/array_reader/byte_array.rs
Original file line number Diff line number Diff line change
Expand Up @@ -125,13 +125,13 @@ impl<I: OffsetSizeTrait + ScalarValue> ArrayReader for ByteArrayReader<I> {
fn get_def_levels(&self) -> Option<&[i16]> {
self.def_levels_buffer
.as_ref()
.map(|buf| unsafe { buf.typed_data() })
.map(|buf| buf.typed_data())
}

fn get_rep_levels(&self) -> Option<&[i16]> {
self.rep_levels_buffer
.as_ref()
.map(|buf| unsafe { buf.typed_data() })
.map(|buf| buf.typed_data())
}
}

Expand Down
6 changes: 3 additions & 3 deletions parquet/src/arrow/array_reader/byte_array_dictionary.rs
Original file line number Diff line number Diff line change
Expand Up @@ -187,13 +187,13 @@ where
fn get_def_levels(&self) -> Option<&[i16]> {
self.def_levels_buffer
.as_ref()
.map(|buf| unsafe { buf.typed_data() })
.map(|buf| buf.typed_data())
}

fn get_rep_levels(&self) -> Option<&[i16]> {
self.rep_levels_buffer
.as_ref()
.map(|buf| unsafe { buf.typed_data() })
.map(|buf| buf.typed_data())
}
}

Expand Down Expand Up @@ -356,7 +356,7 @@ where
assert_eq!(dict.data_type(), &self.value_type);

let dict_buffers = dict.data().buffers();
let dict_offsets = unsafe { dict_buffers[0].typed_data::<V>() };
let dict_offsets = dict_buffers[0].typed_data::<V>();
let dict_values = dict_buffers[1].as_slice();

values.extend_from_dictionary(
Expand Down
8 changes: 4 additions & 4 deletions parquet/src/arrow/array_reader/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -226,13 +226,13 @@ where
fn get_def_levels(&self) -> Option<&[i16]> {
self.def_levels_buffer
.as_ref()
.map(|buf| unsafe { buf.typed_data() })
.map(|buf| buf.typed_data())
}

fn get_rep_levels(&self) -> Option<&[i16]> {
self.rep_levels_buffer
.as_ref()
.map(|buf| unsafe { buf.typed_data() })
.map(|buf| buf.typed_data())
}
}

Expand Down Expand Up @@ -447,13 +447,13 @@ where
fn get_def_levels(&self) -> Option<&[i16]> {
self.def_levels_buffer
.as_ref()
.map(|buf| unsafe { buf.typed_data() })
.map(|buf| buf.typed_data())
}

fn get_rep_levels(&self) -> Option<&[i16]> {
self.rep_levels_buffer
.as_ref()
.map(|buf| unsafe { buf.typed_data() })
.map(|buf| buf.typed_data())
}
}

Expand Down
2 changes: 1 addition & 1 deletion parquet/src/arrow/buffer/dictionary_buffer.rs
Original file line number Diff line number Diff line change
Expand Up @@ -106,7 +106,7 @@ impl<K: ScalarValue + ArrowNativeType + Ord, V: ScalarValue + OffsetSizeTrait>
Self::Dict { keys, values } => {
let mut spilled = OffsetBuffer::default();
let dict_buffers = values.data().buffers();
let dict_offsets = unsafe { dict_buffers[0].typed_data::<V>() };
let dict_offsets = dict_buffers[0].typed_data::<V>();
let dict_values = dict_buffers[1].as_slice();

if values.is_empty() {
Expand Down
4 changes: 2 additions & 2 deletions parquet/src/arrow/record_reader/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -573,7 +573,7 @@ mod tests {

// Verify result record data
let actual = record_reader.consume_record_data().unwrap();
let actual_values = unsafe { actual.typed_data::<i32>() };
let actual_values = actual.typed_data::<i32>();

let expected = &[0, 7, 0, 6, 3, 0, 8];
assert_eq!(actual_values.len(), expected.len());
Expand Down Expand Up @@ -687,7 +687,7 @@ mod tests {

// Verify result record data
let actual = record_reader.consume_record_data().unwrap();
let actual_values = unsafe { actual.typed_data::<i32>() };
let actual_values = actual.typed_data::<i32>();
let expected = &[4, 0, 0, 7, 6, 3, 2, 8, 9];
assert_eq!(actual_values.len(), expected.len());

Expand Down

0 comments on commit 7d64fc6

Please sign in to comment.