Skip to content

Commit

Permalink
Enable serial grid reduction for split-K in matmul scheduler (#1510)
Browse files Browse the repository at this point in the history
Stacked on #1456.

This simply enables serial reduction in the matmul scheduler when
split-K is used.
  • Loading branch information
jacobhinkle authored Jan 23, 2024
1 parent da7c4e9 commit 4e2ff18
Show file tree
Hide file tree
Showing 4 changed files with 20 additions and 19 deletions.
6 changes: 1 addition & 5 deletions benchmark/matmul.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -685,11 +685,7 @@ static void MatmulShapeWarpStageAutoSplitK(benchmark::internal::Benchmark* b) {

ForAllLayouts(EagerModeBenchmark);
ForAllLayouts(NvfuserMatmulBenchmark);
// Disable split-K benchmarks due to slow compilation.
// See https://github.com/NVIDIA/Fuser/issues/1389.
// These benchmarks should be enabled again after merging
// https://github.com/NVIDIA/Fuser/pull/1510
// ForAllLayouts(AutoSplitKBenchmark);
ForAllLayouts(AutoSplitKBenchmark);
ForAllLayouts(AutoPartitionedKBenchmark);

// Note: SplitK Reduction benchmarks are parametrized only by M, N. The splitk
Expand Down
28 changes: 15 additions & 13 deletions csrc/ops/arith.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -2559,19 +2559,6 @@ TensorView* fusedMultiplySum(
TensorView* tv_b,
const std::vector<int>& axes,
Val* init) {
if (init == nullptr) {
init = IrBuilder::create<Val>(0.0);
}

// TODO:
// We will want to support initialize and rfactor with
// mma as well, for maybe fusing bias in prolog.
// TODO: check init type if given a tv,
// not supported currently though.
NVF_CHECK(
init->isConstScalar(),
"Cannot create a reduction operation where the initial value is not a const scalar.");

// TODO:
// Validate axis relationships between a and b
NVF_CHECK(tv_a->nDims() > 0, "Tried to reduce a 0-dim tensor");
Expand All @@ -2596,6 +2583,21 @@ TensorView* fusedMultiplySum(
canonicalizeAxes(axes, tv_a->domain()->noReductions().size());

TensorView* out = newForMma(tv_a, tv_b, uint_axes);

if (init == nullptr) {
init = IrBuilder::create<Val>(0.0, out->dtype());
}

// TODO:
// We will want to support initialize and rfactor with
// mma as well, for maybe fusing bias in prolog.
NVF_CHECK(
init->isConstScalar(),
"Cannot create a reduction operation where the initial value is not a const scalar.");
NVF_CHECK(
init->dtype() == out->dtype(),
"Init value dtype for fusedMultiplySum must match output.");

IrBuilder::create<MmaOp>(out, tv_a, tv_b, init);

return out;
Expand Down
2 changes: 2 additions & 0 deletions csrc/scheduler/matmul.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -899,6 +899,8 @@ void scheduleMatmul(Fusion* fusion, const MatmulParams& params) {
splitk_sum = mma_result;
mma_result = splitk_sum->rFactor({-4, -1});

splitk_sum->definition()->as<ReductionOp>()->requestSerialGridReduction();

num_splitk_dims = 1;
}

Expand Down
3 changes: 2 additions & 1 deletion csrc/scheduler/mma_utils.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -1288,8 +1288,9 @@ RolesMapOpt getTensorsRoles(Fusion* fusion) {
namespace {

void addMMAOp(Fusion* fusion_, std::vector<MulSumProperties>& props) {
auto* init = IrBuilder::create<Val>(0.0);
for (auto prop : props) {
auto* init =
IrBuilder::create<Val>(0.0, prop.insouts.out->getDataType().value());
IrBuilder::create<MmaOp>(
prop.insouts.out, prop.insouts.a, prop.insouts.b, init);
}
Expand Down

0 comments on commit 4e2ff18

Please sign in to comment.