Skip to content

Commit

Permalink
Rewrote matrix multiply
Browse files Browse the repository at this point in the history
  • Loading branch information
tbetcke committed Sep 30, 2023
1 parent c74a9d5 commit c4099a3
Show file tree
Hide file tree
Showing 8 changed files with 597 additions and 314 deletions.
8 changes: 4 additions & 4 deletions blis/src/interface.rs
Original file line number Diff line number Diff line change
Expand Up @@ -4,16 +4,16 @@ pub mod types;
use rlst_common::types::Scalar;

/// Compute expected size of a data slice from stride and shape.
pub fn get_expected_data_size(stride: (usize, usize), shape: (usize, usize)) -> usize {
if shape.0 == 0 || shape.1 == 0 {
pub fn get_expected_data_size(stride: [usize; 2], shape: [usize; 2]) -> usize {
if shape[0] == 0 || shape[1] == 0 {
return 0;
}

1 + (shape.0 - 1) * stride.0 + (shape.1 - 1) * stride.1
1 + (shape[0] - 1) * stride[0] + (shape[1] - 1) * stride[1]
}

/// Panic if expected data size is not identical to actual data size.
pub fn assert_data_size<T: Scalar>(data: &[T], stride: (usize, usize), shape: (usize, usize)) {
pub fn assert_data_size<T: Scalar>(data: &[T], stride: [usize; 2], shape: [usize; 2]) {
let expected = get_expected_data_size(stride, shape);

assert_eq!(
Expand Down
18 changes: 9 additions & 9 deletions blis/src/interface/gemm.rs
Original file line number Diff line number Diff line change
Expand Up @@ -48,20 +48,20 @@ macro_rules! impl_gemm {
csc: usize,
) {
match transa {
TransMode::NoTrans => assert_data_size(a, (rsa, csa), (m, k)),
TransMode::ConjNoTrans => assert_data_size(a, (rsa, csa), (m, k)),
TransMode::Trans => assert_data_size(a, (rsa, csa), (k, m)),
TransMode::ConjTrans => assert_data_size(a, (rsa, csa), (k, m)),
TransMode::NoTrans => assert_data_size(a, [rsa, csa], [m, k]),
TransMode::ConjNoTrans => assert_data_size(a, [rsa, csa], [m, k]),
TransMode::Trans => assert_data_size(a, [rsa, csa], [k, m]),
TransMode::ConjTrans => assert_data_size(a, [rsa, csa], [k, m]),
}

match transb {
TransMode::NoTrans => assert_data_size(b, (rsb, csb), (k, n)),
TransMode::ConjNoTrans => assert_data_size(b, (rsb, csb), (k, n)),
TransMode::Trans => assert_data_size(b, (rsb, csb), (n, k)),
TransMode::ConjTrans => assert_data_size(b, (rsb, csb), (n, k)),
TransMode::NoTrans => assert_data_size(b, [rsb, csb], [k, n]),
TransMode::ConjNoTrans => assert_data_size(b, [rsb, csb], [k, n]),
TransMode::Trans => assert_data_size(b, [rsb, csb], [n, k]),
TransMode::ConjTrans => assert_data_size(b, [rsb, csb], [n, k]),
}

assert_data_size(c, (rsc, csc), (m, n));
assert_data_size(c, [rsc, csc], [m, n]);

unsafe {
raw::$blis_gemm(
Expand Down
6 changes: 3 additions & 3 deletions common/src/traits/accessors.rs
Original file line number Diff line number Diff line change
Expand Up @@ -76,16 +76,16 @@ pub trait ChunkedAccess<const N: usize> {

/// Get raw access to the underlying data.
pub trait RawAccess {
type T: Scalar;
type Item: Scalar;

/// Get a slice of the whole data.
fn data(&self) -> &[Self::T];
fn data(&self) -> &[Self::Item];
}

/// Get mutable raw access to the underlying data.
pub trait RawAccessMut: RawAccess {
/// Get a mutable slice of the whole data.
fn data_mut(&mut self) -> &mut [Self::T];
fn data_mut(&mut self) -> &mut [Self::Item];
}

/// Check if `multi_index` not out of bounds with respect to `shape`.
Expand Down
51 changes: 51 additions & 0 deletions dense/src/array.rs
Original file line number Diff line number Diff line change
Expand Up @@ -3,16 +3,21 @@
use crate::base_array::BaseArray;
use crate::data_container::VectorContainer;
use rlst_common::traits::ChunkedAccess;
use rlst_common::traits::NumberOfElements;
use rlst_common::traits::RandomAccessByRef;
use rlst_common::traits::RandomAccessMut;
use rlst_common::traits::RawAccess;
use rlst_common::traits::RawAccessMut;
use rlst_common::traits::Shape;
use rlst_common::traits::Stride;
use rlst_common::traits::UnsafeRandomAccessByRef;
use rlst_common::traits::UnsafeRandomAccessByValue;
use rlst_common::traits::UnsafeRandomAccessMut;
use rlst_common::types::DataChunk;
use rlst_common::types::Scalar;

pub mod iterators;
pub mod multiply;
pub mod operations;
pub mod operators;
pub mod random;
Expand Down Expand Up @@ -174,3 +179,49 @@ pub(crate) fn empty_chunk<const N: usize, Item: Scalar>(
valid_entries,
})
}

impl<
Item: Scalar,
ArrayImpl: UnsafeRandomAccessByValue<NDIM, Item = Item> + Shape<NDIM> + RawAccess<Item = Item>,
const NDIM: usize,
> RawAccess for Array<Item, ArrayImpl, NDIM>
{
type Item = Item;

fn data(&self) -> &[Self::Item] {
self.0.data()
}
}

impl<
Item: Scalar,
ArrayImpl: UnsafeRandomAccessByValue<NDIM, Item = Item> + Shape<NDIM> + RawAccessMut<Item = Item>,
const NDIM: usize,
> RawAccessMut for Array<Item, ArrayImpl, NDIM>
{
fn data_mut(&mut self) -> &mut [Self::Item] {
self.0.data_mut()
}
}

impl<
Item: Scalar,
ArrayImpl: UnsafeRandomAccessByValue<NDIM, Item = Item> + Shape<NDIM>,
const NDIM: usize,
> NumberOfElements for Array<Item, ArrayImpl, NDIM>
{
fn number_of_elements(&self) -> usize {
self.shape().iter().product()
}
}

impl<
Item: Scalar,
ArrayImpl: UnsafeRandomAccessByValue<NDIM, Item = Item> + Shape<NDIM> + Stride<NDIM>,
const NDIM: usize,
> Stride<NDIM> for Array<Item, ArrayImpl, NDIM>
{
fn stride(&self) -> [usize; NDIM] {
self.0.stride()
}
}
1 change: 1 addition & 0 deletions dense/src/array/multiply.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1 @@
//! Multiplication of Arrays
31 changes: 30 additions & 1 deletion dense/src/base_array.rs
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,10 @@ use crate::data_container::{DataContainer, DataContainerMut};
use crate::layout::{
check_multi_index_in_bounds, convert_1d_nd_from_shape, convert_nd_raw, stride_from_shape,
};
use rlst_common::traits::{ChunkedAccess, UnsafeRandomAccessByValue, UnsafeRandomAccessMut};
use rlst_common::traits::{
ChunkedAccess, RawAccess, RawAccessMut, Stride, UnsafeRandomAccessByValue,
UnsafeRandomAccessMut,
};
use rlst_common::{
traits::{Shape, UnsafeRandomAccessByRef},
types::Scalar,
Expand Down Expand Up @@ -109,3 +112,29 @@ impl<Item: Scalar, Data: DataContainerMut<Item = Item>, const N: usize, const ND
}
}
}

impl<Item: Scalar, Data: DataContainer<Item = Item>, const NDIM: usize> RawAccess
for BaseArray<Item, Data, NDIM>
{
type Item = Item;

fn data(&self) -> &[Self::Item] {
self.data.data()
}
}

impl<Item: Scalar, Data: DataContainerMut<Item = Item>, const NDIM: usize> RawAccessMut
for BaseArray<Item, Data, NDIM>
{
fn data_mut(&mut self) -> &mut [Self::Item] {
self.data.data_mut()
}
}

impl<Item: Scalar, Data: DataContainer<Item = Item>, const NDIM: usize> Stride<NDIM>
for BaseArray<Item, Data, NDIM>
{
fn stride(&self) -> [usize; NDIM] {
self.stride
}
}
2 changes: 1 addition & 1 deletion dense/src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -46,7 +46,7 @@ pub mod layout;
pub mod macros;
pub mod traits;
// pub mod matrix;
// pub mod matrix_multiply;
pub mod matrix_multiply;
// pub mod matrix_ref;
// pub mod matrix_view;
// pub mod op_containers;
Expand Down
Loading

0 comments on commit c4099a3

Please sign in to comment.