Skip to content

Commit

Permalink
Replace Tensor::buffer_mut
Browse files Browse the repository at this point in the history
This refactors the data accessor functions in `Tensor` to be more
consistent with conventions elsewhere (e.g., `get_`). It also checks a
bit more robustly whether the underlying pointer can in fact be casted
to the type we expect.
  • Loading branch information
abrown committed Jun 25, 2024
1 parent b6eacef commit 7515dfa
Show file tree
Hide file tree
Showing 5 changed files with 77 additions and 33 deletions.
4 changes: 2 additions & 2 deletions crates/openvino/src/prepostprocess.rs
Original file line number Diff line number Diff line change
Expand Up @@ -13,9 +13,9 @@
//! # let data = fs::read("tests/fixtures/inception/tensor-1x3x299x299-f32.bgr").expect("to read the tensor from file");
//! # let input_shape = Shape::new(&vec![1, 299, 299, 3]).expect("to create a new shape");
//! # let mut tensor = Tensor::new(ElementType::F32, &input_shape).expect("to create a new tensor");
//! # let buffer = tensor.buffer_mut().unwrap();
//! # let buffer = tensor.get_raw_data_mut().unwrap();
//! # buffer.copy_from_slice(&data);
//! // Insantiate a new core, read in a model, and set up a tensor with input data before performing pre/post processing
//! // Instantiate a new core, read in a model, and set up a tensor with input data before performing pre/post processing
//! // Pre-process the input by:
//! // - converting NHWC to NCHW
//! // - resizing the input image
Expand Down
100 changes: 72 additions & 28 deletions crates/openvino/src/tensor.rs
Original file line number Diff line number Diff line change
Expand Up @@ -86,49 +86,79 @@ impl Tensor {
Ok(byte_size)
}

/// Get a mutable reference to the data of the tensor.
pub fn get_data<T>(&mut self) -> Result<&mut [T]> {
let mut data = std::ptr::null_mut();
try_unsafe!(ov_tensor_data(self.ptr, std::ptr::addr_of_mut!(data),))?;
let size = self.get_byte_size()? / std::mem::size_of::<T>();
let slice = unsafe { std::slice::from_raw_parts_mut(data.cast::<T>(), size) };
/// Get the underlying data for the tensor.
pub fn get_raw_data(&self) -> Result<&[u8]> {
let mut buffer = std::ptr::null_mut();
try_unsafe!(ov_tensor_data(self.ptr, std::ptr::addr_of_mut!(buffer)))?;
let size = self.get_byte_size()?;
let slice = unsafe { std::slice::from_raw_parts(buffer.cast::<u8>(), size) };
Ok(slice)
}

/// Get a mutable reference to the buffer of the tensor.
///
/// # Returns
///
/// A mutable reference to the buffer of the tensor.
pub fn buffer_mut(&mut self) -> Result<&mut [u8]> {
/// Get a mutable reference to the underlying data for the tensor.
pub fn get_raw_data_mut(&mut self) -> Result<&mut [u8]> {
let mut buffer = std::ptr::null_mut();
try_unsafe!(ov_tensor_data(self.ptr, std::ptr::addr_of_mut!(buffer)))?;
let size = self.get_byte_size()?;
let slice = unsafe { std::slice::from_raw_parts_mut(buffer.cast::<u8>(), size) };
Ok(slice)
}

/// Get a `T`-casted slice of the underlying data for the tensor.
///
/// # Panics
///
/// This method will panic if it can't cast the data to `T` due to the type size or the
/// underlying pointer's alignment.
pub fn get_data<T>(&self) -> Result<&[T]> {
let raw_data = self.get_raw_data()?;
let len = get_safe_len::<T>(raw_data);
let slice = unsafe { std::slice::from_raw_parts(raw_data.as_ptr().cast::<T>(), len) };
Ok(slice)
}

/// Get a mutable `T`-casted slice of the underlying data for the tensor.
///
/// # Panics
///
/// This method will panic if it can't cast the data to `T` due to the type size or the
/// underlying pointer's alignment.
pub fn get_data_mut<T>(&mut self) -> Result<&mut [T]> {
let raw_data = self.get_raw_data_mut()?;
let len = get_safe_len::<T>(raw_data);
let slice =
unsafe { std::slice::from_raw_parts_mut(raw_data.as_mut_ptr().cast::<T>(), len) };
Ok(slice)
}
}

/// Convenience function for checking that we can cast `data` to a slice of `T`, returning the
/// length of that slice.
fn get_safe_len<T>(data: &[u8]) -> usize {
if data.len() % std::mem::size_of::<T>() != 0 {
panic!("data size is not a multiple of the size of `T`");
}
if data.as_ptr() as usize % std::mem::align_of::<T>() != 0 {
panic!("raw data is not aligned to `T`'s alignment");
}
data.len() / std::mem::size_of::<T>()
}

#[cfg(test)]
mod tests {
use super::*;
use crate::{ElementType, LoadingError, Shape};

#[test]
fn test_create_tensor() {
openvino_sys::library::load()
.map_err(LoadingError::SystemFailure)
.unwrap();
openvino_sys::library::load().unwrap();
let shape = Shape::new(&vec![1, 3, 227, 227]).unwrap();
let tensor = Tensor::new(ElementType::F32, &shape).unwrap();
assert!(!tensor.ptr.is_null());
}

#[test]
fn test_get_shape() {
openvino_sys::library::load()
.map_err(LoadingError::SystemFailure)
.unwrap();
openvino_sys::library::load().unwrap();
let tensor = Tensor::new(
ElementType::F32,
&Shape::new(&vec![1, 3, 227, 227]).unwrap(),
Expand All @@ -140,9 +170,7 @@ mod tests {

#[test]
fn test_get_element_type() {
openvino_sys::library::load()
.map_err(LoadingError::SystemFailure)
.unwrap();
openvino_sys::library::load().unwrap();
let tensor = Tensor::new(
ElementType::F32,
&Shape::new(&vec![1, 3, 227, 227]).unwrap(),
Expand All @@ -154,9 +182,7 @@ mod tests {

#[test]
fn test_get_size() {
openvino_sys::library::load()
.map_err(LoadingError::SystemFailure)
.unwrap();
openvino_sys::library::load().unwrap();
let tensor = Tensor::new(
ElementType::F32,
&Shape::new(&vec![1, 3, 227, 227]).unwrap(),
Expand All @@ -168,9 +194,7 @@ mod tests {

#[test]
fn test_get_byte_size() {
openvino_sys::library::load()
.map_err(LoadingError::SystemFailure)
.unwrap();
openvino_sys::library::load().unwrap();
let tensor = Tensor::new(
ElementType::F32,
&Shape::new(&vec![1, 3, 227, 227]).unwrap(),
Expand All @@ -182,4 +206,24 @@ mod tests {
1 * 3 * 227 * 227 * std::mem::size_of::<f32>() as usize
);
}

#[test]
fn casting() {
openvino_sys::library::load().unwrap();
let shape = Shape::new(&vec![10, 10, 10]).unwrap();
let tensor = Tensor::new(ElementType::F32, &shape).unwrap();
let data = tensor.get_data::<f32>().unwrap();
assert_eq!(data.len(), 10 * 10 * 10);
}

#[test]
#[should_panic(expected = "data size is not a multiple of the size of `T`")]
fn casting_check() {
openvino_sys::library::load().unwrap();
let shape = Shape::new(&vec![10, 10, 10]).unwrap();
let tensor = Tensor::new(ElementType::F32, &shape).unwrap();
#[allow(dead_code)]
struct LargeOddType([u8; 1061]);
tensor.get_data::<LargeOddType>().unwrap();
}
}
2 changes: 1 addition & 1 deletion crates/openvino/tests/classify-alexnet.rs
Original file line number Diff line number Diff line change
Expand Up @@ -28,7 +28,7 @@ fn classify_alexnet() -> anyhow::Result<()> {
let input_shape = Shape::new(&vec![1, 227, 227, 3])?;
let element_type = ElementType::F32;
let mut tensor = Tensor::new(element_type, &input_shape)?;
let buffer = tensor.buffer_mut()?;
let buffer = tensor.get_raw_data_mut()?;
buffer.copy_from_slice(&data);

// Pre-process the input by:
Expand Down
2 changes: 1 addition & 1 deletion crates/openvino/tests/classify-inception.rs
Original file line number Diff line number Diff line change
Expand Up @@ -28,7 +28,7 @@ fn classify_inception() -> anyhow::Result<()> {
let input_shape = Shape::new(&vec![1, 299, 299, 3])?;
let element_type = ElementType::F32;
let mut tensor = Tensor::new(element_type, &input_shape)?;
let buffer = tensor.buffer_mut()?;
let buffer = tensor.get_raw_data_mut()?;
buffer.copy_from_slice(&data);

// Pre-process the input by:
Expand Down
2 changes: 1 addition & 1 deletion crates/openvino/tests/classify-mobilenet.rs
Original file line number Diff line number Diff line change
Expand Up @@ -28,7 +28,7 @@ fn classify_mobilenet() -> anyhow::Result<()> {
let input_shape = Shape::new(&vec![1, 224, 224, 3])?;
let element_type = ElementType::F32;
let mut tensor = Tensor::new(element_type, &input_shape)?;
let buffer = tensor.buffer_mut()?;
let buffer = tensor.get_raw_data_mut()?;
buffer.copy_from_slice(&data);

// Pre-process the input by:
Expand Down

0 comments on commit 7515dfa

Please sign in to comment.