Skip to content

Commit

Permalink
New doc (#53)
Browse files Browse the repository at this point in the history
* WIP: Documentation

* WIP: New documentation

* Fixed cargo format error
  • Loading branch information
tbetcke authored Dec 18, 2023
1 parent f482441 commit 19f7610
Show file tree
Hide file tree
Showing 35 changed files with 876 additions and 588 deletions.
2 changes: 1 addition & 1 deletion blis/examples/threads.rs
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
//! Modify the threading behaviour in BLIS
use rlst_blis::threading;
use rlst_blis::interface::threading;

fn main() {
println!("Num threads: {}", threading::get_num_threads());
Expand Down
3 changes: 3 additions & 0 deletions blis/src/interface.rs
Original file line number Diff line number Diff line change
@@ -1,4 +1,7 @@
//! Public interface to Blis routines.
pub mod gemm;
pub mod threading;
pub mod types;

use cauchy::Scalar;
Expand Down
23 changes: 22 additions & 1 deletion blis/src/interface/gemm.rs
Original file line number Diff line number Diff line change
@@ -1,9 +1,30 @@
//! Computational BLIS routines
//! Matrix product in Blis.
use super::types::TransMode;
use crate::interface::assert_data_size;
use crate::raw;
use cauchy::{c32, c64, Scalar};

/// Safe interface to using the Blis Gemm routine.
///
/// Performs the matrix multiplication `C = alpha * A * B + beta * C`.
/// # Arguments
///
/// - `transa` - Transposition mode of A.
/// - `transb` - Transposition mode of B.
/// - `m` - Row dimension of C.
/// - `k` - Column dimension of C.
/// - `k` - Column dimension of A. Same as row dimension of B.
/// - `a` - Reference to data of A.
/// - `alpha` - Scalar `alpha parameter.
/// - `rsa` - Row stride of A.
/// - `csb` - Row stride of A.
/// - `b` - Reference to data of B.
/// - `rsb` - Row stride of B.
/// - `csb` - Column stride of B.
/// - `beta` - Scalar `beta` parameter.
/// - `c` - Reference to data of C.
/// - `rsc` - Row stride of C.
/// - `csc` - Column stride of C.
pub trait Gemm: Scalar {
#[allow(clippy::too_many_arguments)]
fn gemm(
Expand Down
4 changes: 2 additions & 2 deletions blis/src/threading.rs → blis/src/interface/threading.rs
Original file line number Diff line number Diff line change
@@ -1,9 +1,9 @@
//! Set global threading for BLIS
//! Set global threading for Blis.
use crate::raw::{self, bli_thread_set_num_threads};
use num_cpus;

/// Get the current number of threads used by BLIS.
/// Get the current number of threads used by Blis.
pub fn get_num_threads() -> usize {
let threads = unsafe { raw::bli_thread_get_num_threads() };

Expand Down
132 changes: 5 additions & 127 deletions blis/src/interface/types.rs
Original file line number Diff line number Diff line change
@@ -1,138 +1,16 @@
//! Interface to Blis types
use crate::raw;

/// Transposition Mode.
#[derive(Clone, Copy, PartialEq)]
#[repr(u32)]
pub enum TransMode {
/// Complex conjugate of matrix.
ConjNoTrans = raw::trans_t_BLIS_CONJ_NO_TRANSPOSE,
/// No modification of matrix.
NoTrans = raw::trans_t_BLIS_NO_TRANSPOSE,
/// Transposition of matrix.
Trans = raw::trans_t_BLIS_TRANSPOSE,
/// Conjugate transpose of matrix.
ConjTrans = raw::trans_t_BLIS_CONJ_TRANSPOSE,
}

// pub trait BlisIdentifier {
// const ID: u32;
// }

// impl BlisIdentifier for f32 {
// const ID: u32 = raw::num_t_BLIS_FLOAT;
// }

// impl BlisIdentifier for f64 {
// const ID: u32 = raw::num_t_BLIS_DOUBLE;
// }

// impl BlisIdentifier for c32 {
// const ID: u32 = raw::num_t_BLIS_SCOMPLEX;
// }

// impl BlisIdentifier for c64 {
// const ID: u32 = raw::num_t_BLIS_DCOMPLEX;
// }

// pub struct BlisObject {
// obj: raw::obj_t,
// requires_free: bool,
// }

// impl Drop for BlisObject {
// fn drop(&mut self) {
// if self.requires_free {
// unsafe {
// crate::raw::bli_obj_free(&mut self.obj);
// }
// }
// }
// }

// impl Default for raw::obj_t {
// fn default() -> Self {
// Self {
// root: std::ptr::null_mut(),
// off: [0, 0],
// dim: [0, 0],
// diag_off: 0,
// info: 0,
// info2: 0,
// elem_size: 0,
// buffer: std::ptr::null_mut(),
// rs: 0,
// cs: 0,
// is: 0,
// scalar: raw::dcomplex {
// real: 0.0,
// imag: 0.0,
// },
// m_padded: 0,
// n_padded: 0,
// ps: 0,
// pd: 0,
// m_panel: 0,
// n_panel: 0,
// pack_fn: None,
// pack_params: std::ptr::null_mut(),
// ker_fn: None,
// ker_params: std::ptr::null_mut(),
// }
// }
// }

// impl BlisObject {
// pub fn from_slice<T: Scalar + BlisIdentifier>(
// data: &mut [T],
// stride: (usize, usize),
// shape: (usize, usize),
// ) -> Self {
// // The maximum index that still needs to fit in the data slice.
// let max_index = stride.0 * (shape.0 - 1) + stride.1 * (shape.1 - 1);
// assert_eq!(
// data.len(),
// 1 + max_index,
// "Length of slice is {} but expected {}",
// data.len(),
// 1 + max_index
// );

// let mut obj = raw::obj_t::default();

// unsafe {
// raw::bli_obj_create_with_attached_buffer(
// T::ID,
// shape.0 as i64,
// shape.1 as i64,
// data.as_mut_ptr() as *mut std::ffi::c_void,
// stride.0 as i64,
// stride.1 as i64,
// &mut obj,
// )

// };

// BlisObject {
// obj,
// requires_free: false,
// }
// }

// pub fn from_scalar<T: Scalar + BlisIdentifier>(scalar: T) -> Self {
// let mut obj = raw::obj_t::default();
// unsafe { raw::bli_obj_create_1x1(T::ID, &mut obj) };

// unsafe {
// raw::bli_setsc(
// num::cast::<T::Real, f64>(scalar.re()).unwrap(),
// num::cast::<T::Real, f64>(scalar.im()).unwrap(),
// &obj,
// )
// };

// Self {
// obj,
// requires_free: true,
// }
// }

// pub fn get_obj(&self) -> &raw::obj_t {
// &self.obj
// }
// }
1 change: 0 additions & 1 deletion blis/src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -3,4 +3,3 @@

pub mod interface;
pub mod raw;
pub mod threading;
2 changes: 2 additions & 0 deletions blis/src/raw.rs
Original file line number Diff line number Diff line change
@@ -1,3 +1,5 @@
//! Raw interface to Blis routines.
#![allow(non_upper_case_globals)]
#![allow(non_camel_case_types)]
#![allow(non_snake_case)]
Expand Down
3 changes: 1 addition & 2 deletions common/src/lib.rs
Original file line number Diff line number Diff line change
@@ -1,5 +1,4 @@
//! Common RLST data structures
//! Common Rlst data structures
#![cfg_attr(feature = "strict", deny(warnings))]

pub mod traits;
pub mod types;
17 changes: 0 additions & 17 deletions common/src/traits.rs

This file was deleted.

1 change: 0 additions & 1 deletion common/src/traits/accessors.rs

This file was deleted.

1 change: 0 additions & 1 deletion common/src/traits/constructors.rs

This file was deleted.

1 change: 0 additions & 1 deletion common/src/traits/in_place_operations.rs

This file was deleted.

1 change: 0 additions & 1 deletion common/src/traits/iterators.rs

This file was deleted.

1 change: 0 additions & 1 deletion common/src/traits/operations.rs

This file was deleted.

1 change: 0 additions & 1 deletion common/src/traits/properties.rs

This file was deleted.

6 changes: 3 additions & 3 deletions common/src/types.rs
Original file line number Diff line number Diff line change
@@ -1,9 +1,9 @@
//! Basic types
//! Basic types.
pub use cauchy::{c32, c64, Scalar};
use thiserror::Error;

/// The RLST error type.
/// The Rlst error type.
#[derive(Error, Debug)]
pub enum RlstError {
#[error("Method {0} is not implemented.")]
Expand Down Expand Up @@ -34,7 +34,7 @@ pub enum RlstError {
MatrixNotHermitian,
}

/// Alias for an RLST Result type.
/// Alias for an Rlst Result type.
pub type RlstResult<T> = std::result::Result<T, RlstError>;

/// Data chunk of fixed size N.
Expand Down
70 changes: 66 additions & 4 deletions dense/src/array.rs
Original file line number Diff line number Diff line change
@@ -1,6 +1,12 @@
//! Basic Array type
//!
//! [Array] is the basic type for dense calculations in Rlst. The full definition
//! `Array<Item, ArrayImpl, NDIM>` represents a tensor with `NDIM` axes, `Item` as data type
//! (e.g. `f64`), and implemented through `ArrayImpl`.
use crate::base_array::BaseArray;
use crate::data_container::SliceContainer;
use crate::data_container::SliceContainerMut;
use crate::data_container::VectorContainer;
use crate::traits::*;
use rlst_common::types::DataChunk;
Expand All @@ -15,9 +21,18 @@ pub mod random;
pub mod slice;
pub mod views;

/// A basic dynamically allocated array.
pub type DynamicArray<Item, const NDIM: usize> =
Array<Item, BaseArray<Item, VectorContainer<Item>, NDIM>, NDIM>;

/// A dynamically allocated array from a data slice.
pub type SliceArray<'a, Item, const NDIM: usize> =
Array<Item, BaseArray<Item, SliceContainer<'a, Item>, NDIM>, NDIM>;

/// A mutable dynamically allocated array from a data slice.
pub type SliceArrayMut<'a, Item, const NDIM: usize> =
Array<Item, BaseArray<Item, SliceContainerMut<'a, Item>, NDIM>, NDIM>;

/// The basic tuple type defining an array.
pub struct Array<Item, ArrayImpl, const NDIM: usize>(ArrayImpl)
where
Expand All @@ -30,19 +45,18 @@ impl<
const NDIM: usize,
> Array<Item, ArrayImpl, NDIM>
{
/// Instantiate a new array from an `ArrayImpl` structure.
pub fn new(arr: ArrayImpl) -> Self {
Self(arr)
}

pub fn new_dynamic_like_self(&self) -> DynamicArray<Item, NDIM> {
DynamicArray::<Item, NDIM>::from_shape(self.shape())
}

/// Return the number of elements in the array.
pub fn number_of_elements(&self) -> usize {
self.0.shape().iter().product()
}
}

/// Create a new heap allocated array from a given shape.
impl<Item: Scalar, const NDIM: usize> DynamicArray<Item, NDIM> {
pub fn from_shape(shape: [usize; NDIM]) -> Self {
let size = shape.iter().product();
Expand Down Expand Up @@ -151,6 +165,7 @@ impl<
}
}

/// Create an empty chunk.
pub(crate) fn empty_chunk<const N: usize, Item: Scalar>(
chunk_index: usize,
nelements: usize,
Expand Down Expand Up @@ -230,8 +245,55 @@ impl<
}

/// Create an empty array of given type and dimension.
///
/// Empty arrays serve as convenient containers for input into functions that
/// resize an array before filling it with data.
pub fn empty_array<Item: Scalar, const NDIM: usize>() -> DynamicArray<Item, NDIM> {
let shape = [0; NDIM];
let container = VectorContainer::new(0);
Array::new(BaseArray::new(container, shape))
}

impl<'a, Item: Scalar, const NDIM: usize> SliceArray<'a, Item, NDIM> {
/// Create a new array from a slice with a given `shape`.
///
/// The `stride` is automatically assumed to be column major.
pub fn from_shape(slice: &'a [Item], shape: [usize; NDIM]) -> Self {
Array::new(BaseArray::new(SliceContainer::new(slice), shape))
}

/// Create a new array from a slice with a given `shape` and `stride`.
pub fn from_shape_with_stride(
slice: &'a [Item],
shape: [usize; NDIM],
stride: [usize; NDIM],
) -> Self {
Array::new(BaseArray::new_with_stride(
SliceContainer::new(slice),
shape,
stride,
))
}
}

impl<'a, Item: Scalar, const NDIM: usize> SliceArrayMut<'a, Item, NDIM> {
/// Create a new array from a slice with a given `shape`.
///
/// The `stride` is automatically assumed to be column major.
pub fn from_shape(slice: &'a mut [Item], shape: [usize; NDIM]) -> Self {
Array::new(BaseArray::new(SliceContainerMut::new(slice), shape))
}

/// Create a new array from a slice with a given `shape` and `stride`.
pub fn from_shape_with_stride(
slice: &'a mut [Item],
shape: [usize; NDIM],
stride: [usize; NDIM],
) -> Self {
Array::new(BaseArray::new_with_stride(
SliceContainerMut::new(slice),
shape,
stride,
))
}
}
Loading

0 comments on commit 19f7610

Please sign in to comment.