[Optimization] Implicit gemm rewrite #2545
base: main
Conversation
It looks awesome! Feels great to reuse a lot of components. There are still some improvements that we can make in our "design paradigm", especially in how we pass around the config. But this is beyond the scope of this PR.
I have a few comments, but it would also be great for @louisfd to review.
pub struct CmmaHalf<EG: Numeric, Stage: StageSize> {
    pub _eg: PhantomData<EG>,
    pub _stage: PhantomData<Stage>,
}
Could the Cmma struct be generic over the accumulation precision?
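One way the suggestion could look (a hypothetical sketch, not the PR's actual types): a single struct generic over both the global element type and an accumulation precision parameter, with the half-accumulation variant reduced to an alias. The marker types below stand in for the real numeric and stage-size traits.

```rust
use std::marker::PhantomData;

// Hypothetical sketch: the struct takes the accumulation precision EA as a
// type parameter instead of hard-coding half-precision accumulation.
pub struct Cmma<EG, EA, Stage> {
    pub _eg: PhantomData<EG>,
    pub _ea: PhantomData<EA>, // accumulation precision, e.g. f16 vs f32
    pub _stage: PhantomData<Stage>,
}

// Zero-sized marker types standing in for real numeric/stage types.
pub struct F16;
pub struct F32;
pub struct S4x4x2;

// The half-accumulation variant then becomes just an alias:
pub type CmmaHalf<EG, Stage> = Cmma<EG, F16, Stage>;
```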
        Self::LhsLoader::advance_view(&mut lhs_loader, k_step);
        Self::RhsLoader::advance_view(&mut rhs_loader, k_step);
    }
Somehow adding a sync_units after the for loop improved performance for the matmul. I think it makes sure all units in a plane are synchronized following the loop, which improves the execution of subsequent operations.
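The suggested placement can be sketched as follows. This is an illustration only: in the real kernel the loop body runs the stage matmul and sync_units is a CubeCL intrinsic; here it is stubbed out so the structure is self-contained.

```rust
// Stand-in for the plane-wide barrier (a CubeCL intrinsic in the real code).
fn sync_units() {
    // no-op stub for illustration
}

// Simplified shape of the main k-loop; returns the final k offset.
fn execute_main_loop(num_loops: u32, k_step: u32) -> u32 {
    let mut k = 0;
    for _ in 0..num_loops {
        // ... fill stages, run the stage matmul, advance lhs/rhs views ...
        k += k_step;
    }
    // Suggested addition: synchronize once after the loop so all units in
    // the plane are aligned before the accumulator read-out that follows.
    sync_units();
    k
}
```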
I'll benchmark it
///
///
Empty lines in comment block 😅
use crate::kernel::conv::homogeneous::base::config;

#[cube]
/// Input to the convolution, responsible for filling the stage and providing a reader for it.
/// Advances along the k-dimension to fill the stage with further data.
pub trait Loader<EG: Numeric, ES: Numeric, G: global::Config>:
    CubeType + 'static + Send + Sync
{
    /// The stage reader which matches the input of the underlying stage matmul.
    type StageReader: CubeType;

    /// Fills the stage at the current k offset and returns a reader for it.
    fn fill_stage(this: &mut Self, #[comptime] config: G) -> Self::StageReader;

    /// Move the k offset by k_offset
    fn advance_view(this: &mut Self, k_offset: u32);
}

#[cube]
impl<EG: Numeric, ES: Numeric, S: stage::Config, L: LoadingStrategy>
    Loader<EG, ES, config::Config<S>> for RhsLoader<EG, ES, S, L>
{
    type StageReader = RhsReader<ES>;

    fn fill_stage(this: &mut Self, #[comptime] config: config::Config<S>) -> Self::StageReader {
        CyclicLoading::load_to_slice::<EG, ES, config::Config<S>>(
            &this.tensor_view,
            &mut this.stage.as_slice_mut(),
            Ident::Rhs,
            config,
        );
        RhsReader::new(this.stage)
    }

    fn advance_view(this: &mut Self, k_offset: u32) {
        this.tensor_view.update_view(k_offset, Ident::Rhs);
Do you need to duplicate this? I don't believe we actually need to have the constraint G: global::Config in the trait; only G is good enough. The Algorithm trait can make the link between the two types.
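The point can be illustrated with a minimal sketch (hypothetical names, not the PR's code): the trait takes the config as an unconstrained type parameter, and the config-trait bound is only stated where it is actually used, on the concrete impl.

```rust
// The trait leaves G unconstrained; no global-config bound is needed here.
trait Loader<G> {
    fn advance_view(this: &mut Self, k_offset: u32, config: &G);
}

// A stand-in for the global config trait the impl actually relies on.
trait GlobalConfig {
    fn check_bounds(&self) -> bool;
}

struct SimpleConfig;
impl GlobalConfig for SimpleConfig {
    fn check_bounds(&self) -> bool {
        true
    }
}

struct DemoLoader {
    offset: u32,
}

// The bound appears only on the impl, where the config methods are used.
impl<G: GlobalConfig> Loader<G> for DemoLoader {
    fn advance_view(this: &mut Self, k_offset: u32, config: &G) {
        if config.check_bounds() {
            this.offset += k_offset;
        }
    }
}
```

Callers that need the constrained behavior pick it up from the impl, so the trait definition stays free of the duplicate bound.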
You're right, this used to be required when the config was a generic on the function, but it's now superfluous.
Pull Request Template
Checklist
The run-checks all script has been executed.
Related Issues/PRs
Requires tracel-ai/cubecl#309 to land first
Changes
Adds a brand new implicit GEMM implementation that uses the matmul primitives in cubecl. This is slower for small k sizes, but much faster for large ones, and more flexible. I'm keeping the current implementation because it's significantly faster for certain sizes and uses a significantly different loader strategy (loading only within each warp, which skips cross-warp syncs).
Adds a number of new convolution benchmarks to test performance with different sizes and characteristics.
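For context on the small-k vs. large-k trade-off: implicit GEMM conventionally maps a convolution onto a virtual matmul without materializing the im2col matrix. The mapping below is the standard one and is illustrative only, not taken from the PR's code.

```rust
// Standard implicit-GEMM dimension mapping for a convolution:
//   M = batch * out_h * out_w   (output positions)
//   N = out_channels            (filters)
//   K = in_channels * kh * kw   (reduction dimension)
fn implicit_gemm_dims(
    batch: u32,
    out_h: u32,
    out_w: u32,
    out_c: u32,
    in_c: u32,
    kh: u32,
    kw: u32,
) -> (u32, u32, u32) {
    (batch * out_h * out_w, out_c, in_c * kh * kw)
}
```

For example, a 3x3 convolution over 256 input channels yields K = 2304, where the matmul-primitive path amortizes well, while a 1x1 convolution over a handful of channels yields a tiny K where per-iteration overhead dominates.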
Testing
All non-group tests pass, and CRAFT has the expected output with all layers using the new implicit GEMM. This tests many different and relatively large layers. Adds two new regression tests for bugs discovered during implementation.