Skip to content

Commit

Permalink
Use slice::align_to
Browse files Browse the repository at this point in the history
When checking that tensor data can be converted to a slice of `T`, we
previously were checking the right things: is the length right? Is the
alignment correct? Turns out there is a `std` function that does this
for us: `slice::align_to`. Replacing the custom check with the `std`
version should have no effect on the code other than clarity.
  • Loading branch information
abrown committed Jul 19, 2024
1 parent 9242a8c commit 5ca8d14
Showing 1 changed file with 11 additions and 20 deletions.
31 changes: 11 additions & 20 deletions crates/openvino/src/tensor.rs
Original file line number Diff line number Diff line change
Expand Up @@ -116,8 +116,11 @@ impl Tensor {
/// underlying pointer's alignment.
pub fn get_data<T>(&self) -> Result<&[T]> {
let raw_data = self.get_raw_data()?;
let len = get_safe_len::<T>(raw_data);
let slice = unsafe { std::slice::from_raw_parts(raw_data.as_ptr().cast::<T>(), len) };
let (prefix, slice, suffix) = unsafe { raw_data.align_to::<T>() };
assert!(
prefix.is_empty() && suffix.is_empty(),
"raw data is not aligned to `T`'s alignment"
);
Ok(slice)
}

Expand All @@ -129,27 +132,15 @@ impl Tensor {
/// underlying pointer's alignment.
pub fn get_data_mut<T>(&mut self) -> Result<&mut [T]> {
let raw_data = self.get_raw_data_mut()?;
let len = get_safe_len::<T>(raw_data);
let slice =
unsafe { std::slice::from_raw_parts_mut(raw_data.as_mut_ptr().cast::<T>(), len) };
let (prefix, slice, suffix) = unsafe { raw_data.align_to_mut::<T>() };
assert!(
prefix.is_empty() && suffix.is_empty(),
"raw data is not aligned to `T`'s alignment"
);
Ok(slice)
}
}

/// Convenience function for checking that we can cast `data` to a slice of `T`, returning the
/// length of that slice.
fn get_safe_len<T>(data: &[u8]) -> usize {
assert!(
data.len() % std::mem::size_of::<T>() == 0,
"data size is not a multiple of the size of `T`"
);
assert!(
data.as_ptr() as usize % std::mem::align_of::<T>() == 0,
"raw data is not aligned to `T`'s alignment"
);
data.len() / std::mem::size_of::<T>()
}

#[cfg(test)]
mod tests {
use super::*;
Expand Down Expand Up @@ -208,7 +199,7 @@ mod tests {
}

#[test]
#[should_panic(expected = "data size is not a multiple of the size of `T`")]
#[should_panic(expected = "raw data is not aligned to `T`'s alignment")]
fn casting_check() {
openvino_sys::library::load().unwrap();
let shape = Shape::new(&[10, 10, 10]).unwrap();
Expand Down

0 comments on commit 5ca8d14

Please sign in to comment.