
[Optimization] Implicit gemm rewrite #2545

Open · wants to merge 60 commits into main

Conversation

@wingertge (Contributor) commented Nov 26, 2024

Pull Request Template

Checklist

  • Confirmed that the run-checks all script has been executed.
  • Made sure the book is up to date with changes in this PR.

Related Issues/PRs

Requires tracel-ai/cubecl#309 to land first

Changes

Adds a brand-new implicit GEMM implementation that uses the matmul primitives in cubecl. It is slower for small k sizes, but much faster for large ones, and more flexible. I'm keeping the current implementation because it's significantly faster for certain sizes and uses a significantly different loader strategy (loading only within each warp, which skips cross-warp syncs).
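
For readers new to the approach, here is a rough, framework-free sketch of the virtual im2col view that implicit GEMM gathers from; it is illustrative only (scalar code, stride 1, dilation 1, NCHW layout), not code from this PR, and all names in it are made up for the example.

// Illustrative only, not code from this PR: a scalar index mapping for the virtual
// im2col matrix that implicit GEMM reads on the fly. The GEMM is
// C[m, n] = sum_k A[m, k] * B[k, n], where m indexes (batch, out_y, out_x),
// k indexes (in_channel, kernel_y, kernel_x), and A is never materialized in memory.
struct ConvShape {
    in_c: usize,
    in_h: usize,
    in_w: usize,
    kernel_h: usize,
    kernel_w: usize,
    out_h: usize,
    out_w: usize,
}

/// Returns the input offset backing the virtual im2col element A[m, k],
/// or None if that element falls into zero padding.
fn im2col_index(s: &ConvShape, m: usize, k: usize, pad: usize) -> Option<usize> {
    // Decompose the GEMM row index m into (batch, out_y, out_x).
    let b = m / (s.out_h * s.out_w);
    let oy = (m / s.out_w) % s.out_h;
    let ox = m % s.out_w;
    // Decompose the GEMM k index into (in_channel, kernel_y, kernel_x).
    let c = k / (s.kernel_h * s.kernel_w);
    let ky = (k / s.kernel_w) % s.kernel_h;
    let kx = k % s.kernel_w;
    // Map back to input coordinates; padding shifts the window up and to the left.
    let iy = (oy + ky).checked_sub(pad)?;
    let ix = (ox + kx).checked_sub(pad)?;
    if iy >= s.in_h || ix >= s.in_w {
        return None;
    }
    Some(((b * s.in_c + c) * s.in_h + iy) * s.in_w + ix)
}

The loader/reader pair in this PR plays roughly this role on the GPU, filling shared-memory stages from that view so the cubecl stage matmul can consume them like an ordinary matrix.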

Adds a number of new convolution benchmarks to test performance with different sizes and characteristics.

Testing

All non-group tests pass, and CRAFT produces the expected output with all layers using the new implicit GEMM; this exercises many different, relatively large layers. Adds two new regression tests for bugs discovered during implementation.

@wingertge marked this pull request as ready for review November 27, 2024 18:14

codecov bot commented Nov 27, 2024

Codecov Report

Attention: Patch coverage is 19.10230% with 775 lines in your changes missing coverage. Please review.

Project coverage is 81.96%. Comparing base (fba75d3) to head (5565f8e).
Report is 1 commit behind head on main.

Files with missing lines Patch % Lines
...it/src/kernel/conv/conv2d/gemm/homogeneous/base.rs 1.75% 280 Missing ⚠️
...tes/burn-jit/src/kernel/conv/conv2d/gemm/launch.rs 28.65% 127 Missing ⚠️
...n-jit/src/kernel/conv/conv2d/gemm/loader/im2col.rs 0.00% 103 Missing ⚠️
...urn-jit/src/kernel/conv/conv2d/gemm/loader/bias.rs 0.00% 83 Missing ⚠️
...n-jit/src/kernel/conv/conv2d/gemm/reader/im2col.rs 0.00% 64 Missing ⚠️
.../burn-jit/src/kernel/conv/conv2d/gemm/algorithm.rs 24.48% 37 Missing ⚠️
...s/burn-jit/src/kernel/conv/conv2d/implicit_gemm.rs 0.00% 23 Missing ⚠️
...urn-jit/src/kernel/conv/conv2d/gemm/reader/bias.rs 0.00% 17 Missing ⚠️
...rates/burn-jit/src/kernel/conv/conv2d/gemm/base.rs 0.00% 13 Missing ⚠️
...urn-jit/src/kernel/conv/conv2d/gemm/loader/base.rs 0.00% 13 Missing ⚠️
... and 7 more
Additional details and impacted files
@@            Coverage Diff             @@
##             main    #2545      +/-   ##
==========================================
- Coverage   82.51%   81.96%   -0.55%     
==========================================
  Files         828      837       +9     
  Lines      107123   108030     +907     
==========================================
+ Hits        88395    88550     +155     
- Misses      18728    19480     +752     

☔ View full report in Codecov by Sentry.

@nathanielsimard (Member) left a comment


It looks awesome! Feels great to reuse a lot of components. There are still some improvements that we can make in our "design paradigm", especially in how we pass around the config. But this is beyond the scope of this PR.

I have a few comments, but it would also be great for @louisfd to review.

Comment on lines +117 to +120
pub struct CmmaHalf<EG: Numeric, Stage: StageSize> {
    pub _eg: PhantomData<EG>,
    pub _stage: PhantomData<Stage>,
}
Member

Could the Cmma struct be generic over the accumulation precision?
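
A minimal sketch of that suggestion; the Numeric and StageSize stubs and the EA parameter name are placeholders for the example, not the actual cubecl definitions.

use std::marker::PhantomData;

// Stand-ins for the cubecl traits of the same names.
pub trait Numeric {}
pub trait StageSize {}

// One struct, generic over both the global element type EG and a hypothetical
// accumulation element type EA.
pub struct Cmma<EG: Numeric, EA: Numeric, Stage: StageSize> {
    pub _eg: PhantomData<EG>,
    pub _ea: PhantomData<EA>,
    pub _stage: PhantomData<Stage>,
}

// CmmaHalf could then collapse into an alias that fixes EA to a half-precision
// type (assuming such a type implements Numeric), instead of being a separate struct:
// pub type CmmaHalf<EG, Stage> = Cmma<EG, F16, Stage>;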

    Self::LhsLoader::advance_view(&mut lhs_loader, k_step);
    Self::RhsLoader::advance_view(&mut rhs_loader, k_step);
}

Member

Somehow, adding a sync_units after the for loop improved performance for the matmul. I think it makes sure all units in a plane are in sync after the loop, which improves the execution of the following operations.
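
A placement sketch of that suggestion, with the cubecl pieces stubbed out; only the position of the barrier relative to the k-loop matters here.

// Stub: in the real kernel this is cubecl's barrier for the units in a cube.
fn sync_units() {}

fn k_loop_shape(num_loops: u32) {
    for _ in 0..num_loops {
        // ... fill the lhs/rhs stages, run the stage matmul, advance the k views ...
    }
    // The suggestion: one extra barrier here, after the loop, so every unit has
    // finished its last iteration before the accumulators are read back out.
    sync_units();
}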

Contributor (author)

I'll benchmark it

Comment on lines +49 to +50
///
///
Member

Empty lines in comment block 😅

Comment on lines +13 to +48
use crate::kernel::conv::homogeneous::base::config;

#[cube]
/// Input to the convolution, responsible for filling the stage and providing a reader for it.
/// Advances along the k-dimension to fill the stage with further data.
pub trait Loader<EG: Numeric, ES: Numeric, G: global::Config>:
    CubeType + 'static + Send + Sync
{
    /// The stage reader which matches the input of the underlying stage matmul.
    type StageReader: CubeType;

    /// Fills the stage at the current k offset and returns a reader for it.
    fn fill_stage(this: &mut Self, #[comptime] config: G) -> Self::StageReader;

    /// Moves the k offset by k_offset.
    fn advance_view(this: &mut Self, k_offset: u32);
}

#[cube]
impl<EG: Numeric, ES: Numeric, S: stage::Config, L: LoadingStrategy>
    Loader<EG, ES, config::Config<S>> for RhsLoader<EG, ES, S, L>
{
    type StageReader = RhsReader<ES>;

    fn fill_stage(this: &mut Self, #[comptime] config: config::Config<S>) -> Self::StageReader {
        CyclicLoading::load_to_slice::<EG, ES, config::Config<S>>(
            &this.tensor_view,
            &mut this.stage.as_slice_mut(),
            Ident::Rhs,
            config,
        );
        RhsReader::new(this.stage)
    }

    fn advance_view(this: &mut Self, k_offset: u32) {
        this.tensor_view.update_view(k_offset, Ident::Rhs);
Member

Do you need to duplicate this? I don't believe we actually need to have the constraint G: global::Config in the trait, only G is good enough. The Algorithm trait can make the link between the two types.
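
A minimal sketch of the relaxed signature being suggested; the #[cube] and #[comptime] attributes are dropped and the traits stubbed here, so this only shows the shape of the change, not a drop-in replacement.

// Stand-ins for the cubecl traits; the real ones come from cubecl.
pub trait Numeric {}
pub trait CubeType {}

// G is left as a plain type parameter; the G: global::Config requirement can live
// on the implementations instead, with the Algorithm trait tying the types together.
pub trait Loader<EG: Numeric, ES: Numeric, G>: CubeType + 'static + Send + Sync {
    /// The stage reader which matches the input of the underlying stage matmul.
    type StageReader: CubeType;

    /// Fills the stage at the current k offset and returns a reader for it.
    fn fill_stage(this: &mut Self, config: G) -> Self::StageReader;

    /// Moves the k offset by k_offset.
    fn advance_view(this: &mut Self, k_offset: u32);
}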

Contributor (author)

You're right, this used to be required when the config was a generic on the function, but it's now superfluous.
