From ddafa6ccc192fbac7768cc4210b2a1ece18c947a Mon Sep 17 00:00:00 2001 From: Timo Betcke Date: Tue, 26 Mar 2024 20:45:17 +0000 Subject: [PATCH] Batched dgemm (#71) * Added batched dgemm * Cleaned up blas dependency * Fixed warnings * Fixed clippy errors --- src/dense.rs | 1 + src/dense/array.rs | 8 ++ src/dense/array/mult_into.rs | 1 + src/dense/batched_gemm.rs | 207 +++++++++++++++++++++++++++++++++++ src/lib.rs | 7 ++ 5 files changed, 224 insertions(+) create mode 100644 src/dense/batched_gemm.rs diff --git a/src/dense.rs b/src/dense.rs index 9c4e52d4..7fcc064e 100644 --- a/src/dense.rs +++ b/src/dense.rs @@ -8,6 +8,7 @@ pub mod tools; pub mod traits; pub mod array; +pub mod batched_gemm; pub mod gemm; pub mod layout; pub mod macros; diff --git a/src/dense/array.rs b/src/dense/array.rs index 4e0c9f7e..22e86704 100644 --- a/src/dense/array.rs +++ b/src/dense/array.rs @@ -37,6 +37,14 @@ pub type SliceArray<'a, Item, const NDIM: usize> = pub type SliceArrayMut<'a, Item, const NDIM: usize> = Array, NDIM>, NDIM>; +/// A view onto a matrix +pub type ViewArray<'a, Item, ArrayImpl, const NDIM: usize> = + Array, NDIM>; + +/// A mutable view onto a matrix +pub type ViewArrayMut<'a, Item, ArrayImpl, const NDIM: usize> = + Array, NDIM>; + /// The basic tuple type defining an array. pub struct Array(ArrayImpl) where diff --git a/src/dense/array/mult_into.rs b/src/dense/array/mult_into.rs index 50a1a9a3..2c9a8463 100644 --- a/src/dense/array/mult_into.rs +++ b/src/dense/array/mult_into.rs @@ -148,6 +148,7 @@ impl< let shapeb = new_shape(arr_b.shape(), transb); let expected_shape = [shapea[0], shapeb[1]]; + if self.shape() != expected_shape { self.resize_in_place(expected_shape); } diff --git a/src/dense/batched_gemm.rs b/src/dense/batched_gemm.rs new file mode 100644 index 00000000..8af4c492 --- /dev/null +++ b/src/dense/batched_gemm.rs @@ -0,0 +1,207 @@ +//! Interface to batched gemm operations + +use crate::dense::array::{DynamicArray, ViewArray, ViewArrayMut}; +use crate::dense::base_array::BaseArray; +use crate::dense::data_container::VectorContainer; +use crate::dense::traits::{Shape, UnsafeRandomAccessByValue}; +use crate::dense::types::RlstScalar; +use crate::dense::types::TransMode; +use crate::{rlst_dynamic_array2, MultInto, RlstResult, UnsafeRandomAccessMut}; + +/// Batched matrix-matrix products. +/// +/// Implementations of this trait allow batched matrix-matrix products. +pub trait BatchedGemm { + /// The scalar type. + type Item: RlstScalar; + /// Array implementation type. + type ArrayImpl: UnsafeRandomAccessByValue<2, Item = Self::Item> + Shape<2>; + /// Mutable array implementation type. + type ArrayImplMut: UnsafeRandomAccessByValue<2, Item = Self::Item> + + UnsafeRandomAccessMut<2, Item = Self::Item> + + Shape<2>; + + /// Instantiate a batched matrix-matrix product. + fn with( + left_dim: (usize, usize), + right_dim: (usize, usize), + number_of_matrices: usize, + alpha: Self::Item, + beta: Self::Item, + ) -> Self; + + /// Access the left matrix with given index. + fn left_matrix(&self, index: usize) -> Option>; + + /// Mutably access the left matrix with given index/ + fn left_matrix_mut( + &mut self, + index: usize, + ) -> Option>; + + /// Access the right matrix with given index. + fn right_matrix(&self, index: usize) -> Option>; + + /// Mutably access the right matrix with given index. + fn right_matrix_mut( + &mut self, + index: usize, + ) -> Option>; + + /// Access the result matrix with given index. + fn result_matrix(&self, index: usize) -> Option>; + + /// Mutably access the result matrix with given index. + fn result_matrix_mut( + &mut self, + index: usize, + ) -> Option>; + + /// Evaluate the batched matrix product. + fn evaluate(&mut self) -> RlstResult<()>; +} + +struct DefaultCpuBatchedGemm { + left_matrices: Vec>, + right_matrices: Vec>, + result_matrices: Vec>, + number_of_matrices: usize, + alpha: Item, + beta: Item, +} + +impl BatchedGemm for DefaultCpuBatchedGemm { + type Item = Item; + + type ArrayImpl = BaseArray, 2>; + + type ArrayImplMut = BaseArray, 2>; + + fn with( + left_dim: (usize, usize), + right_dim: (usize, usize), + number_of_matrices: usize, + alpha: Self::Item, + beta: Self::Item, + ) -> Self { + assert_eq!(left_dim.1, right_dim.0); + + let mut left_matrices = Vec::>::with_capacity(number_of_matrices); + let mut right_matrices = Vec::>::with_capacity(number_of_matrices); + let mut result_matrices = Vec::>::with_capacity(number_of_matrices); + + for _ in 0..number_of_matrices { + left_matrices.push(rlst_dynamic_array2!(Item, [left_dim.0, left_dim.1])); + right_matrices.push(rlst_dynamic_array2!(Item, [right_dim.0, right_dim.1])); + result_matrices.push(rlst_dynamic_array2!(Item, [left_dim.0, right_dim.1])); + } + + Self { + left_matrices, + right_matrices, + result_matrices, + number_of_matrices, + alpha, + beta, + } + } + + fn left_matrix(&self, index: usize) -> Option> { + self.left_matrices.get(index).map(|mat| mat.view()) + } + + fn left_matrix_mut( + &mut self, + index: usize, + ) -> Option> { + self.left_matrices.get_mut(index).map(|mat| mat.view_mut()) + } + + fn right_matrix(&self, index: usize) -> Option> { + self.right_matrices.get(index).map(|mat| mat.view()) + } + + fn right_matrix_mut( + &mut self, + index: usize, + ) -> Option> { + self.right_matrices.get_mut(index).map(|mat| mat.view_mut()) + } + + fn result_matrix(&self, index: usize) -> Option> { + self.result_matrices.get(index).map(|mat| mat.view()) + } + + fn result_matrix_mut( + &mut self, + index: usize, + ) -> Option> { + self.result_matrices + .get_mut(index) + .map(|mat| mat.view_mut()) + } + + fn evaluate(&mut self) -> RlstResult<()> { + for index in 0..self.number_of_matrices { + let left_matrix = self.left_matrices[index].view(); + let right_matrix = self.right_matrices[index].view(); + let result_matrix = self.result_matrices[index].view_mut(); + result_matrix.mult_into( + TransMode::NoTrans, + TransMode::NoTrans, + self.alpha, + left_matrix, + right_matrix, + self.beta, + ); + } + Ok(()) + } +} + +#[cfg(test)] +mod test { + + use super::*; + + use crate::dense::traits::DefaultIterator; + use crate::dense::traits::MultIntoResize; + + #[test] + pub fn test_batched_cpu_gemm() { + let mut batched_matmul = DefaultCpuBatchedGemm::::with((2, 3), (3, 5), 2, 1.0, 0.0); + + batched_matmul + .left_matrix_mut(0) + .unwrap() + .fill_from_seed_equally_distributed(0); + batched_matmul + .left_matrix_mut(1) + .unwrap() + .fill_from_seed_equally_distributed(1); + + batched_matmul + .right_matrix_mut(0) + .unwrap() + .fill_from_seed_equally_distributed(2); + batched_matmul + .right_matrix_mut(1) + .unwrap() + .fill_from_seed_equally_distributed(3); + + batched_matmul.evaluate().unwrap(); + + for index in 0..2 { + let expected = crate::dense::array::empty_array().simple_mult_into_resize( + batched_matmul.left_matrix(index).unwrap(), + batched_matmul.right_matrix(index).unwrap(), + ); + + crate::assert_array_relative_eq!( + expected, + batched_matmul.result_matrix(index).unwrap(), + 1E-12 + ); + } + } +} diff --git a/src/lib.rs b/src/lib.rs index 38ca6e04..61335ea8 100644 --- a/src/lib.rs +++ b/src/lib.rs @@ -12,3 +12,10 @@ pub mod threading; pub mod operator; pub use prelude::*; + +#[cfg(test)] +mod test { + + extern crate blas_src; + extern crate lapack_src; +}