Commit ae0b2d5: cargo +nightly fmt

SunDoge committed Aug 21, 2023
1 parent f642c58 commit ae0b2d5

Showing 15 changed files with 161 additions and 119 deletions.
7 changes: 6 additions & 1 deletion examples/with_pyo3/src/lib.rs
@@ -25,7 +25,12 @@ pub fn tensordict(py: Python<'_>) -> PyResult<&PyDict> {
 
 #[pyfunction]
 pub fn print_tensor(tensor: ManagedTensor) {
-    dbg!(tensor.shape(), tensor.strides(), tensor.dtype(), tensor.device());
+    dbg!(
+        tensor.shape(),
+        tensor.strides(),
+        tensor.dtype(),
+        tensor.device()
+    );
     assert!(tensor.dtype() == DataType::F32);
     dbg!(tensor.as_slice::<f32>());
 }
19 changes: 19 additions & 0 deletions rustfmt.toml
@@ -0,0 +1,19 @@
+unstable_features = true
+
+version = "Two"
+
+group_imports = "StdExternalCrate"
+imports_granularity = "Crate"
+reorder_imports = true
+
+wrap_comments = true
+normalize_comments = true
+
+reorder_impl_items = true
+condense_wildcard_suffixes = true
+enum_discrim_align_threshold = 20
+use_field_init_shorthand = true
+
+format_strings = true
+format_code_in_doc_comments = true
+format_macro_matchers = true
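
These nightly-only settings drive the mechanical changes in the rest of the commit: `imports_granularity = "Crate"` merges separate `use crate::...` statements into one nested tree, `group_imports = "StdExternalCrate"` orders import blocks as std, then external crates, then crate-local, `wrap_comments` rewraps long comments, and `enum_discrim_align_threshold = 20` aligns enum discriminants. A minimal before/after sketch of the import settings (modeled on the src/manager_ctx.rs hunk below; not itself part of the diff):

// Before formatting: one `use` per path, std and crate imports mixed.
use crate::tensor::traits::{IntoDLPack, TensorView};
use std::ptr::NonNull;
use crate::ShapeAndStrides;

// After `cargo +nightly fmt` with the settings above: std imports first,
// then crate-local imports merged into a single nested tree.
use std::ptr::NonNull;

use crate::{tensor::traits::{IntoDLPack, TensorView}, ShapeAndStrides};
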
74 changes: 35 additions & 39 deletions src/data_type.rs
@@ -18,6 +18,18 @@ impl Default for DataType {
 }
 
 impl DataType {
+    // Bfloat
+    pub const BF16: Self = Self {
+        code: DataTypeCode::Bfloat,
+        bits: 16,
+        lanes: 1,
+    };
+    // Bool
+    pub const BOOL: Self = Self {
+        code: DataTypeCode::Bool,
+        bits: 8,
+        lanes: 1,
+    };
     // Float
     pub const F16: Self = Self {
         code: DataTypeCode::Float,
@@ -34,75 +34,59 @@ impl DataType {
         bits: 64,
         lanes: 1,
     };
-
-    // Uint
-    pub const U8: Self = Self {
-        code: DataTypeCode::UInt,
-        bits: 8,
+    pub const I128: Self = Self {
+        code: DataTypeCode::Int,
+        bits: 128,
         lanes: 1,
     };
-    pub const U16: Self = Self {
-        code: DataTypeCode::UInt,
+    pub const I16: Self = Self {
+        code: DataTypeCode::Int,
         bits: 16,
         lanes: 1,
     };
-    pub const U32: Self = Self {
-        code: DataTypeCode::UInt,
+    pub const I32: Self = Self {
+        code: DataTypeCode::Int,
         bits: 32,
         lanes: 1,
     };
-    pub const U64: Self = Self {
-        code: DataTypeCode::UInt,
+    pub const I64: Self = Self {
+        code: DataTypeCode::Int,
         bits: 64,
         lanes: 1,
     };
-    pub const U128: Self = Self {
-        code: DataTypeCode::UInt,
-        bits: 128,
-        lanes: 1,
-    };
-
     // Int
     pub const I8: Self = Self {
         code: DataTypeCode::Int,
         bits: 8,
         lanes: 1,
     };
-    pub const I16: Self = Self {
-        code: DataTypeCode::Int,
+    pub const U128: Self = Self {
+        code: DataTypeCode::UInt,
+        bits: 128,
+        lanes: 1,
+    };
+    pub const U16: Self = Self {
+        code: DataTypeCode::UInt,
         bits: 16,
         lanes: 1,
     };
-    pub const I32: Self = Self {
-        code: DataTypeCode::Int,
+    pub const U32: Self = Self {
+        code: DataTypeCode::UInt,
         bits: 32,
         lanes: 1,
     };
-    pub const I64: Self = Self {
-        code: DataTypeCode::Int,
+    pub const U64: Self = Self {
+        code: DataTypeCode::UInt,
         bits: 64,
         lanes: 1,
     };
-    pub const I128: Self = Self {
-        code: DataTypeCode::Int,
-        bits: 128,
-        lanes: 1,
-    };
-
-    // Bool
-    pub const BOOL: Self = Self {
-        code: DataTypeCode::Bool,
+    // Uint
+    pub const U8: Self = Self {
+        code: DataTypeCode::UInt,
        bits: 8,
         lanes: 1,
     };
-
-    // Bfloat
-    pub const BF16: Self = Self {
-        code: DataTypeCode::Bfloat,
-        bits: 16,
-        lanes: 1,
-    };
 
     /// Calculate `DataType` size as (bits * lanes + 7) // 8
     pub fn size(&self) -> usize {
         ((self.bits as u32 * self.lanes as u32 + 7) / 8) as usize
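
The `size` method kept at the end of this hunk is a ceiling division: adding 7 before dividing by 8 rounds bits * lanes up to whole bytes. A quick check against the constants above (illustrative snippet, assuming the published crate name `dlpark` and its prelude re-exports):

use dlpark::prelude::DataType;

fn main() {
    assert_eq!(DataType::F32.size(), 4); // (32 * 1 + 7) / 8 = 4
    assert_eq!(DataType::BF16.size(), 2); // (16 * 1 + 7) / 8 = 2
    assert_eq!(DataType::BOOL.size(), 1); // (8 * 1 + 7) / 8 = 1
}
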
3 changes: 1 addition & 2 deletions src/dl_managed_tensor.rs
@@ -1,5 +1,4 @@
-use crate::ffi;
-use crate::tensor::traits::TensorView;
+use crate::{ffi, tensor::traits::TensorView};
 
 impl TensorView for ffi::DLManagedTensor {
     fn data_ptr(&self) -> *mut std::ffi::c_void {
4 changes: 2 additions & 2 deletions src/dl_tensor.rs
@@ -1,5 +1,4 @@
-use crate::ffi;
-use crate::tensor::traits::TensorView;
+use crate::{ffi, tensor::traits::TensorView};
 
 impl TensorView for ffi::DLTensor {
     fn data_ptr(&self) -> *mut std::ffi::c_void {
@@ -29,6 +28,7 @@ impl TensorView for ffi::DLTensor {
     fn dtype(&self) -> ffi::DataType {
         self.dtype
     }
+
     fn byte_offset(&self) -> u64 {
         self.byte_offset
     }
69 changes: 36 additions & 33 deletions src/ffi.rs
@@ -20,37 +20,37 @@ pub struct PackVersion {
 #[derive(Debug, Copy, Clone, PartialEq, Eq)]
 pub enum DeviceType {
     /// CPU device
-    Cpu = 1,
+    Cpu         = 1,
     /// CUDA GPU device
-    Cuda = 2,
+    Cuda        = 2,
     /// Pinned CUDA CPU memory by cudaMallocHost
-    CudaHost = 3,
+    CudaHost    = 3,
     /// OpenCL devices.
-    OpenCl = 4,
+    OpenCl      = 4,
     /// Vulkan buffer for next generation graphics.
-    Vulkan = 7,
+    Vulkan      = 7,
     /// Metal for Apple GPU.
-    Metal = 8,
+    Metal       = 8,
     /// Verilog simulator buffer
-    Vpi = 9,
+    Vpi         = 9,
     /// ROCm GPUs for AMD GPUs
-    Rocm = 10,
+    Rocm        = 10,
     /// Pinned ROCm CPU memory allocated by hipMallocHost
-    RocmHost = 11,
+    RocmHost    = 11,
     /// Reserved extension device type,
     /// used for quickly test extension device
     /// The semantics can differ depending on the implementation.
-    ExtDev = 12,
+    ExtDev      = 12,
     /// CUDA managed/unified memory allocated by cudaMallocManaged
     CudaManaged = 13,
     /// Unified shared memory allocated on a oneAPI non-partititioned
     /// device. Call to oneAPI runtime is required to determine the device
     /// type, the USM allocation type and the sycl context it is bound to.
-    OneApi = 14,
+    OneApi      = 14,
     /// GPU support for next generation WebGPU standard.
-    WebGpu = 15,
+    WebGpu      = 15,
     /// Qualcomm Hexagon DSP
-    Hexagon = 16,
+    Hexagon     = 16,
 }
 
 impl From<i32> for DeviceType {
@@ -66,34 +66,36 @@ pub struct Device {
     /// The device type used in the device.
     pub device_type: DeviceType,
     /// The device index.
-    /// For vanilla CPU memory, pinned memory, or managed memory, this is set to 0.
+    /// For vanilla CPU memory, pinned memory, or managed memory, this is set to
+    /// 0.
     pub device_id: i32,
 }
 
 #[repr(u8)]
 #[derive(Debug, Copy, Clone, PartialEq, Eq)]
 pub enum DataTypeCode {
     /// signed integer
-    Int = 0,
+    Int          = 0,
     /// unsigned integer
-    UInt = 1,
+    UInt         = 1,
     /// IEEE floating point
-    Float = 2,
+    Float        = 2,
     /// Opaque handle type, reserved for testing purposes.
-    /// Frameworks need to agree on the handle data type for the exchange to be well-defined.
+    /// Frameworks need to agree on the handle data type for the exchange to be
+    /// well-defined.
     OpaqueHandle = 3,
     /// bfloat16
-    Bfloat = 4,
+    Bfloat       = 4,
     /// complex number
     /// (C/C++/Python layout: compact struct per complex number)
-    Complex = 5,
+    Complex      = 5,
     /// boolean
-    Bool = 6,
+    Bool         = 6,
 }
 
 /// The data type the tensor can hold. The data type is assumed to follow the
-/// native endian-ness. An explicit error message should be raised when attempting to
-/// export an array with non-native endianness
+/// native endian-ness. An explicit error message should be raised when
+/// attempting to export an array with non-native endianness
 /// Examples
 /// - float: type_code = 2, bits = 32, lanes=1
 /// - float4(vectorized 4 float): type_code = 2, bits = 32, lanes=4
@@ -115,9 +117,10 @@ pub struct DataType {
 #[derive(Debug, Clone, Copy)]
 pub struct DLTensor {
     /// The data pointer points to the allocated data. This will be CUDA
-    /// device pointer or cl_mem handle in OpenCL. It may be opaque on some device
-    /// types. This pointer is always aligned to 256 bytes as in CUDA. The
-    /// `byte_offset` field should be used to point to the beginning of the data.
+    /// device pointer or cl_mem handle in OpenCL. It may be opaque on some
+    /// device types. This pointer is always aligned to 256 bytes as in
+    /// CUDA. The `byte_offset` field should be used to point to the
+    /// beginning of the data.
     ///
     /// Note that as of Nov 2021, multiply libraries (CuPy, PyTorch, TensorFlow,
     /// TVM, perhaps others) do not adhere to this 256 byte aligment requirement
@@ -137,7 +140,6 @@ pub struct DLTensor {
     /// return size;
     /// }
     /// ```
-    ///
     pub data: *mut c_void,
     /// The device of the tensor
     pub device: Device,
@@ -169,9 +171,9 @@ pub struct DLManagedTensor {
 
 /// A versioned and managed C Tensor object, manage memory of DLTensor.
 /// This data structure is intended to facilitate the borrowing of DLTensor by
-/// another framework. It is not meant to transfer the tensor. When the borrowing
-/// framework doesn't need the tensor, it should call the deleter to notify the
-/// host that the resource is no longer needed.
+/// another framework. It is not meant to transfer the tensor. When the
+/// borrowing framework doesn't need the tensor, it should call the deleter to
+/// notify the host that the resource is no longer needed.
 ///
 /// This is the current standard DLPack exchange data structure.
 #[repr(C)]
@@ -185,9 +187,10 @@ pub struct DLManagedTensorVersioned {
     pub manager_ctx: *mut c_void,
 
     /// Destructor.
-    /// This should be called to destruct manager_ctx which holds the DLManagedTensorVersioned.
-    /// It can be NULL if there is no way for the caller to provide a reasonable
-    /// destructor. The destructors deletes the argument self as well.
+    /// This should be called to destruct manager_ctx which holds the
+    /// DLManagedTensorVersioned. It can be NULL if there is no way for the
+    /// caller to provide a reasonable destructor. The destructors deletes
+    /// the argument self as well.
     pub deleter: Option<unsafe extern "C" fn(*mut Self)>,
     /// Additional bitmask flags information about the tensor.
     /// By default the flags should be set to 0.
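
The doc examples on `DataType` above translate directly into struct literals. A hedged sketch (hand-constructed values for illustration only; it assumes the crate name `dlpark` and publicly constructible `DataType` fields, and the crate's own constants in src/data_type.rs serve the same purpose):

use dlpark::ffi;

fn main() {
    // float: type_code = 2 (Float), bits = 32, lanes = 1
    let float32 = ffi::DataType {
        code: ffi::DataTypeCode::Float,
        bits: 32,
        lanes: 1,
    };

    // float4 (vectorized 4 x float): same code and bits, lanes = 4
    let float4 = ffi::DataType {
        code: ffi::DataTypeCode::Float,
        bits: 32,
        lanes: 4,
    };

    assert_eq!(float32.size(), 4); // 32 bits -> 4 bytes
    assert_eq!(float4.size(), 16); // 4 lanes of f32 -> 16 bytes
}
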
15 changes: 10 additions & 5 deletions src/lib.rs
@@ -15,10 +15,15 @@ mod python;
 pub mod ffi;
 pub mod utils;
 
-/// Imports the structs and traits for you to implement [`IntoDLPack`] and [`FromDLPack`].
+/// Imports the structs and traits for you to implement [`IntoDLPack`] and
+/// [`FromDLPack`].
 pub mod prelude;
 
-pub use crate::manager_ctx::ManagerCtx;
-pub use crate::shape_and_strides::ShapeAndStrides;
-pub use crate::tensor::traits::{DLPack, FromDLPack, InferDtype, IntoDLPack, TensorView, ToTensor};
-pub use crate::tensor::ManagedTensor;
+pub use crate::{
+    manager_ctx::ManagerCtx,
+    shape_and_strides::ShapeAndStrides,
+    tensor::{
+        traits::{DLPack, FromDLPack, InferDtype, IntoDLPack, TensorView, ToTensor},
+        ManagedTensor,
+    },
+};
12 changes: 8 additions & 4 deletions src/manager_ctx.rs
@@ -1,8 +1,11 @@
 use std::ptr::NonNull;
 
-use crate::tensor::traits::{IntoDLPack, TensorView};
-use crate::ShapeAndStrides;
-use crate::{ffi, prelude::ToTensor};
+use crate::{
+    ffi,
+    prelude::ToTensor,
+    tensor::traits::{IntoDLPack, TensorView},
+    ShapeAndStrides,
+};
 
 unsafe extern "C" fn deleter_fn<T>(dl_managed_tensor: *mut ffi::DLManagedTensor) {
     // Reconstruct pointer and destroy it.
@@ -12,7 +15,8 @@ unsafe extern "C" fn deleter_fn<T>(dl_managed_tensor: *mut ffi::DLManagedTensor)
     unsafe { Box::from_raw(ctx) };
 }
 
-// TODO: should be ManagerCtx<T, M> where M is one of DLManagedTensor and DLManagedTensorVersioned
+// TODO: should be ManagerCtx<T, M> where M is one of DLManagedTensor and
+// DLManagedTensorVersioned
 /// The ManagerCtx holds the Tensor and its metadata.
 pub struct ManagerCtx<T> {
     inner: T,
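
The `deleter_fn` in this hunk is the producer half of the DLPack ownership contract described in the src/ffi.rs doc comments: the consumer calls `deleter` exactly once when it no longer needs the tensor, and `Box::from_raw` then drops the boxed `ManagerCtx`. A minimal sketch of the consumer half (assumes a valid pointer obtained from a producer, and that `DLManagedTensor` carries the same `deleter: Option<unsafe extern "C" fn(*mut Self)>` field shown for `DLManagedTensorVersioned` above):

use dlpark::ffi;

/// Release a borrowed DLPack tensor by invoking its producer-supplied
/// deleter, as the doc comments in src/ffi.rs prescribe.
unsafe fn release(tensor: *mut ffi::DLManagedTensor) {
    if let Some(deleter) = (*tensor).deleter {
        // Frees manager_ctx and the DLManagedTensor itself.
        deleter(tensor);
    }
}
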
2 changes: 0 additions & 2 deletions src/pack_version.rs
@@ -8,5 +8,3 @@ impl Default for PackVersion {
         }
     }
 }
-
-
8 changes: 5 additions & 3 deletions src/prelude.rs
@@ -1,3 +1,5 @@
-pub use crate::ffi::{DataType, Device, PackVersion};
-pub use crate::tensor::traits::{DLPack, FromDLPack, InferDtype, IntoDLPack, TensorView, ToTensor};
-pub use crate::{ManagedTensor, ManagerCtx, ShapeAndStrides};
+pub use crate::{
+    ffi::{DataType, Device, PackVersion},
+    tensor::traits::{DLPack, FromDLPack, InferDtype, IntoDLPack, TensorView, ToTensor},
+    ManagedTensor, ManagerCtx, ShapeAndStrides,
+};
(Diffs for the remaining 5 changed files were not loaded.)