Vectorized serial grid reduction #1528

Merged
merged 41 commits into from
Feb 6, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
41 commits
Select commit Hold shift + click to select a range
aa6598d
Start lowering serial grid reduction
jacobhinkle Dec 6, 2023
68fe6e8
Add grid_serialization.{cpp,h}
jacobhinkle Dec 8, 2023
e27440c
Add requestSerialGridReduction
jacobhinkle Dec 8, 2023
e04cea5
Disable previous changes to indexing pass.
jacobhinkle Dec 8, 2023
ebec0e5
Call insertGridSerializationSyncs pass
jacobhinkle Dec 8, 2023
2530e51
Remove file added by mistake
jacobhinkle Dec 8, 2023
f395b57
Fix formatting lintrunner messed up
jacobhinkle Dec 8, 2023
a6f6611
Bump num_reduction_op_attr
jacobhinkle Dec 8, 2023
109e878
Add test
jacobhinkle Dec 8, 2023
e96a57f
Fix sync insertion in lowering pass.
jacobhinkle Dec 11, 2023
43234b9
Fix missing allocation of sync flag buffer
jacobhinkle Dec 11, 2023
c9b126e
Allocate global work buffer. Index is zero for now
jacobhinkle Dec 12, 2023
a2aec20
Taking a stab at replay/indexing of intermediate
jacobhinkle Dec 13, 2023
651a6ae
Use fullSelfReplay and getGlobalConsumerStridedIndices
jacobhinkle Dec 13, 2023
76f55b9
Infer shape using allocation domain instead of root
jacobhinkle Dec 13, 2023
a92c0bd
Update comments
jacobhinkle Dec 13, 2023
b24f3cb
Hoist index scalar.
jacobhinkle Dec 13, 2023
10872c1
Clean up comments.
jacobhinkle Dec 13, 2023
2ca845a
Update NVFuserTest.Pipeline_CUDA
jacobhinkle Dec 13, 2023
d60ba14
Clean up sum val computation
jacobhinkle Jan 4, 2024
6e3e55f
Clean up comments and reset sync pattern properly
jacobhinkle Jan 10, 2024
7afca1b
Fix compile error
jacobhinkle Jan 11, 2024
50b59f9
Re-use TensorDomain instead of replaying
jacobhinkle Jan 16, 2024
83c60a5
Copy domains to create new TensorDomain instead of reusing
jacobhinkle Jan 18, 2024
5f9c1cf
Allocate work buffer like leaf of output
jacobhinkle Jan 19, 2024
c6d3f10
Use serial grid reduction in split-K
jacobhinkle Dec 13, 2023
d90fbb0
Set proper dtype for init in MmaOp
jacobhinkle Dec 13, 2023
9286006
Restore split-k benchmarks
jacobhinkle Dec 20, 2023
491280d
Fix after rebase
jacobhinkle Jan 19, 2024
a9cd7b1
Vectorized serial grid reduction
jacobhinkle Dec 14, 2023
5171f41
Remove debug prints
jacobhinkle Dec 14, 2023
a24a552
Use loadGlobalToLocal instead of loadGenericVolatile
jacobhinkle Dec 18, 2023
25c8d77
Restore main
jacobhinkle Jan 19, 2024
c9e27a7
Restore changes from branch
jacobhinkle Jan 19, 2024
e40addb
Delint test
jacobhinkle Feb 2, 2024
b705d06
Restore check for SerialGridReduction in validateAndCollectVectorizeInfo
jacobhinkle Feb 2, 2024
68a268b
Merge remote-tracking branch 'origin/main' into vectorized_serial_red…
jacobhinkle Feb 2, 2024
dd9405b
Fix typo
jacobhinkle Feb 2, 2024
f93ddc7
Vectorize in scheduler
jacobhinkle Feb 2, 2024
7dcacf3
Remove obsolete CodegenNodes test
jacobhinkle Feb 3, 2024
99263d6
Merge branch 'main' into vectorized_serial_reduction
jacobhinkle Feb 6, 2024
13 changes: 9 additions & 4 deletions csrc/codegen.cpp
@@ -1693,11 +1693,16 @@ class CudaKernelGenerator : private kir::ConstIrVisitor {
         block_flags,
         ArgumentBuilder().arg("gridDim"));
 
+    int64_t vectorize_size = ir_utils::getVectorizeSize(out->view());
+
+    ArgumentBuilder template_args;
+    template_args.arg("/*vec_size=*/").append(std::to_string(vectorize_size));
+
     ArgumentBuilder func_args(block_nest_level_ + 1, kTab);
-    func_args.arg(gen(out));
-    func_args.arg(gen(grop->in()));
+    func_args.arg("&").append(gen(out));
+    func_args.arg("&").append(gen(grop->in()));
     func_args.arg(gen(grop->init()));
-    func_args.arg(gen(grop->serialReductionTensor()));
+    func_args.arg("&").append(gen(grop->serialReductionTensor()));
     func_args.arg(genReductionOp(op_type, out->dtype()));
 
     // Whether this is the first or last step
@@ -1720,7 +1725,7 @@ class CudaKernelGenerator : private kir::ConstIrVisitor {
       func_args.arg(read_pred);
     }
 
-    indent() << "reduction::serialReductionStep(\n";
+    indent() << "reduction::serialReductionStep<" << template_args << ">(\n";
     indent() << kTab << func_args << ");\n";
   }
 
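To make the effect of this codegen change concrete, here is a small host-side sketch that prints the rough shape of the call now emitted into the kernel, with the vectorization width passed as a template argument and the operands passed by address. This is not nvFuser code: SimpleArgBuilder is a simplified stand-in for ArgumentBuilder, and the tensor names are hypothetical; the trailing arguments are elided.

#include <cstdint>
#include <iostream>
#include <string>
#include <vector>

// Simplified stand-in for nvFuser's ArgumentBuilder (illustrative only).
struct SimpleArgBuilder {
  std::vector<std::string> args;
  SimpleArgBuilder& arg(const std::string& s) {
    args.push_back(s);
    return *this;
  }
  SimpleArgBuilder& append(const std::string& s) {
    args.back() += s;
    return *this;
  }
  std::string join(const std::string& sep) const {
    std::string out;
    for (size_t i = 0; i < args.size(); ++i) {
      if (i != 0) {
        out += sep;
      }
      out += args[i];
    }
    return out;
  }
};

int main() {
  // Would come from ir_utils::getVectorizeSize(out->view()) in codegen.cpp.
  int64_t vectorize_size = 4;

  SimpleArgBuilder template_args;
  template_args.arg("/*vec_size=*/").append(std::to_string(vectorize_size));

  SimpleArgBuilder func_args;
  func_args.arg("&").append("T9[0]");        // out (hypothetical name)
  func_args.arg("&").append("T8[0]");        // grop->in()
  func_args.arg("0.0f");                     // grop->init()
  func_args.arg("&").append("T10[i_work]");  // serial reduction work buffer

  // Remaining arguments (reduction op, first/last step flags, predicates) are
  // elided; see the generator code above for how they are appended.
  std::cout << "reduction::serialReductionStep<" << template_args.join(", ")
            << ">(\n    " << func_args.join(",\n    ") << ",\n    ...);\n";
  return 0;
}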
6 changes: 4 additions & 2 deletions csrc/device_lower/validation.cpp
@@ -584,9 +584,11 @@ void validateAndCollectVectorizeInfo(Fusion* fusion) {
       }
     }
     if (has_vectorize_dim) {
+      Expr* def = tv->definition();
       NVF_ERROR(
-          tv->definition() == nullptr || tv->definition()->isA<LoadStoreOp>() ||
-              tv->definition()->isA<SliceOp>(),
+          def == nullptr || def->isA<LoadStoreOp>() || def->isA<SliceOp>() ||
+              (def->isA<ReductionOp>() &&
+               def->as<ReductionOp>()->serialGridReductionRequested()),
           "Vectorized accesses cannot be inline with computation, they are only supported with a Set operation.",
           "TensorView: ",
           tv);
5 changes: 3 additions & 2 deletions csrc/scheduler/matmul.cpp
@@ -1100,8 +1100,7 @@ void scheduleMatmul(Fusion* fusion, const MatmulParams& params) {
     // (iS) iBx iBy iTz iTy iS iS iS iTx iS rBz
     //
     // This reordering step lets us inline all but the last dim MNi3 (position
-    // nbatch + 7) which might be vectorized for the epilogue but which we
-    // can't vectorize for the gridReduce.
+    // nbatch + 7) which might be vectorized.
     //
     // NOTE: we need to do this reorder after the propagation above so that it
     // doesn't get reset.
@@ -1115,6 +1114,8 @@ void scheduleMatmul(Fusion* fusion, const MatmulParams& params) {
         {num_batch_dims + 8, num_batch_dims + 7},
         {num_batch_dims + 9, num_batch_dims + 8},
     });
+    // Vectorize inner-most dimension
+    splitk_sum->axis(-1)->parallelize(ParallelType::Vectorize);
   }
 
   // auto inline for all tensors except register tensors
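For orientation, below is a minimal sketch of the split-K scheduling idiom this change completes. It assumes requestSerialGridReduction() is the ReductionOp member implied by the serialGridReductionRequested() check in the validation change and by the commit "Add requestSerialGridReduction"; the producer tensor and the elided scheduling steps are illustrative, not the actual scheduleMatmul code.

#include <fusion.h>
#include <ops/all_ops.h>

namespace nvfuser {

// Sketch only: mark the split-K sum as a serial grid reduction and vectorize
// its inner-most axis so serialReductionStep can move the work buffer with
// vectorized global loads/stores.
void scheduleSplitKSumSketch(TensorView* mma_result) {
  TensorView* splitk_sum = sum(mma_result, {-1});
  // Assumed API (see the serialGridReductionRequested() check above).
  splitk_sum->definition()->as<ReductionOp>()->requestSerialGridReduction();
  // ... split / reorder / parallelize the remaining axes as scheduleMatmul
  // does, then vectorize the inner-most dimension:
  splitk_sum->axis(-1)->parallelize(ParallelType::Vectorize);
}

} // namespace nvfuser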
28 changes: 20 additions & 8 deletions runtime/grid_reduction.cu
@@ -625,19 +625,19 @@ __device__ void gridReduceGroup(
 
 // This performs a single reduction step, combining a single element "in" with
 // a previous value "work". For a serial grid reduction, "work" resides in
-// global memory.
+// global memory, while "in" and "out" are in registers.
 //
 // If the write predicate is false, this function returns early (noop). If the
 // read predicate is false, "init" is used in place of "in".
 //
 // If first_step is false, "work" will be read and reduction_op will be called.
 // The result will be written back to "work" unless last_step is true.
-template <typename T, typename Func>
+template <int64_t vec_size, typename T, typename Func>
 __device__ void serialReductionStep(
-    T& out,
-    T in,
+    T* out,
+    T* in,
     T init,
-    volatile T& work,
+    volatile T* work,
     Func reduction_op,
     bool first_step,
     bool last_step,
@@ ... @@
   if (!write_pred) {
     return;
   }
-  out = read_pred ? in : init;
+  if (read_pred) {
+    loadGeneric<T, vec_size>(out, in);
+  } else {
+#pragma unroll
+    for (int i = 0; i < vec_size; ++i) {
+      out[i] = init;
+    }
+  }
   if (!first_step) {
-    reduction_op(out, work);
+    T work_reg[vec_size];
+    loadGlobalToLocal<T, vec_size, true, CacheOp::Global>(work_reg, work);
+#pragma unroll
+    for (int i = 0; i < vec_size; ++i) {
+      reduction_op(out[i], work_reg[i]);
+    }
   }
   if (!last_step) {
-    work = out;
+    loadLocalToGlobal<T, vec_size, true>(work, out);
   }
 }
 
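To see the data movement in isolation, here is a minimal standalone CUDA sketch of the vectorized step, with plain unrolled loops standing in for the loadGeneric / loadGlobalToLocal / loadLocalToGlobal helpers used above. The AddOp functor, the toy kernel, and the single-thread stepping are illustrative only; in the real kernel the steps are taken by successive blocks, serialized through the sync buffer.

struct AddOp {
  __device__ void operator()(float& a, float b) const {
    a += b;
  }
};

template <int vec_size, typename T, typename Func>
__device__ void serialReductionStepSketch(
    T* out,            // vec_size values in registers
    const T* in,       // vec_size values in registers
    T init,
    volatile T* work,  // vec_size values in the global work buffer
    Func reduction_op,
    bool first_step,
    bool last_step,
    bool read_pred,
    bool write_pred) {
  if (!write_pred) {
    return;
  }
  // Initialize the output registers from the input or from init
#pragma unroll
  for (int i = 0; i < vec_size; ++i) {
    out[i] = read_pred ? in[i] : init;
  }
  // Fold in the partial result left in global memory by the previous step
  if (!first_step) {
#pragma unroll
    for (int i = 0; i < vec_size; ++i) {
      T w = work[i];
      reduction_op(out[i], w);
    }
  }
  // Publish the updated partial result for the next step, unless done
  if (!last_step) {
#pragma unroll
    for (int i = 0; i < vec_size; ++i) {
      work[i] = out[i];
    }
  }
}

// Toy single-thread kernel: "steps" chunks of 4 floats are serially reduced
// into out[0..3] through the global work buffer.
__global__ void toySerialReduce(
    const float* in, float* out, float* work, int steps) {
  constexpr int kVec = 4;
  float in_reg[kVec];
  float out_reg[kVec];
  for (int s = 0; s < steps; ++s) {
    for (int i = 0; i < kVec; ++i) {
      in_reg[i] = in[s * kVec + i];
    }
    serialReductionStepSketch<kVec>(
        out_reg,
        in_reg,
        0.0f,
        work,
        AddOp{},
        /*first_step=*/s == 0,
        /*last_step=*/s == steps - 1,
        /*read_pred=*/true,
        /*write_pred=*/true);
  }
  for (int i = 0; i < kVec; ++i) {
    out[i] = out_reg[i];
  }
}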
221 changes: 0 additions & 221 deletions test/test_serial_gridreduce.cpp
@@ -35,227 +35,6 @@ namespace nvfuser {

using SerialGridReductionTest = NVFuserTest;

// Test that we are able to generate code for a serial reduction
// TODO: remove this test once lowering of serial grid reductions is implemented
TEST_F(SerialGridReductionTest, CodegenNodes) {
for (bool serial : {true, false}) {
for (int64_t num_warps : {4, 8}) {
// B is size of inner serial loop. Outer loop is hardcoded at A=4
// Here we set B to a small value of 8 instead of 32 (i.e. 128 elements
// per thread), so that the non-serial compilation does not take too
// long.
for (int64_t B : {8}) {
std::unique_ptr<Fusion> fusion_ptr = std::make_unique<Fusion>();
auto fusion = fusion_ptr.get();
FusionGuard fg(fusion);

int64_t blocks_x = 8;
int64_t blocks_y = 8;
int64_t blocks_z = 5;
int64_t A = 4; // Size of outer serial loop
int64_t H = blocks_z;
int64_t W = A * B * blocks_x * blocks_y * num_warps * 32;

// Unreduced dimensions should be concrete. Reduced dimension could be
// symbolic, but is concrete here so that we can read tv0 to registers
TensorView* tv0 = TensorViewBuilder()
.shape({H, W})
.dtype(DataType::Float)
.contiguity(true)
.build();
fusion->addInput(tv0);

auto tv1 = sum(tv0, {0});
fusion->addOutput(tv1);

// Start with
// [ rS{H}, iS{W} ]
// We are grid reducing the H dimension and we want to coalesce
// accesses in the W dimension. So we first reorder to
// [ iS{W}, rS{H} ]
// then schedule as
// [ iBIDx{blocks_x}, iBIDy{blocks_y}, iS{A}, iS{B}, iTIDy{num_warps},
// iTIDx{32}, rBIDz{blocks_z} ]
auto tv2 = tv0->cacheAfter();
auto tv3 = tv1->cacheBefore();

tv3->reorder({{1, 0}, {0, 1}}); // blocks_x*blocks_y*A*B*num_warps*32, H
tv3->split(0, 32); // blocks_x*blocks_y*A*B*num_warps, 32, H
tv3->split(0, num_warps); // blocks_x*blocks_y*A*B, num_warps, 32, H
tv3->split(0, B); // blocks_x*blocks_y*A, B, num_warps, 32, H
tv3->split(0, A); // blocks_x*blocks_y, A, B, num_warps, 32, H
tv3->split(0, blocks_y); // blocks_x, blocks_y, A, B, num_warps, 32, H
tv3->axis(0)->parallelize(ParallelType::BIDx);
tv3->axis(1)->parallelize(ParallelType::BIDy);
tv3->axis(4)->parallelize(ParallelType::TIDy);
tv3->axis(5)->parallelize(ParallelType::TIDx);
tv3->axis(6)->parallelize(ParallelType::BIDz);
// Reorder to put parallel dims first for better inlining
tv3->reorder({
{4, 2},
{5, 3},
{2, 4},
{3, 5},
});

TransformPropagator propagator(tv3);
MaxRootDomainInfoSpanningTree(tv3).traverse(&propagator);
scheduler_utils::parallelizeAllLike(tv3);

// Here we just transpose A and B in tv2, so that it will be partially
// inlined with tv3, resulting in a separate loop to load tv0 into
// registers (tv2).
tv2->reorder({
{-2, -3},
{-3, -2},
});

inlineMost();

FusionExecutor fe;
if (serial) {
fe.registerPostLoweringHook([](kir::Kernel* kernel) {
FusionGuard fg(kernel);

std::vector<Expr*>& top_level_exprs =
const_cast<std::vector<Expr*>&>(kernel->topLevelExprs());
kir::KernelSummary& summary =
const_cast<kir::KernelSummary&>(kernel->summary());
std::vector<const kir::Allocate*>& global_allocations =
summary.global_allocations;
// There should be a work buffer and a sync buffer allocated
ASSERT_EQ(global_allocations.size(), 2);

// Find the position of the last outer loop
size_t top_level_loop_pos = -1;
for (size_t i : c10::irange(top_level_exprs.size())) {
Expr* expr = top_level_exprs.at(i);
if (expr->isA<kir::ForLoop>()) {
top_level_loop_pos = i;
}
}

// This is a poor approximation of a traversal that would appear in
// a lowering pass to both set the isSerial() flag on grid
// reductions and insert wait/release syncs.
//
// tidx_scope is the inner-most fully parallelized scope. It is
// "top-level" in that its loops appear as top-level in the
// generated kernel
kir::Scope& tidx_scope = top_level_exprs.at(top_level_loop_pos)
->as<kir::ForLoop>()
->body() // BIDx
.at(0)
->as<kir::ForLoop>()
->body() // BIDy
.at(0)
->as<kir::ForLoop>()
->body() // TIDy
.at(0)
->as<kir::ForLoop>()
->body(); // TIDx
kir::Scope& B_scope = tidx_scope.exprs()
.at(5)
->as<kir::ForLoop>()
->body() // A (reduction loop)
.exprs()
.back()
->as<kir::ForLoop>()
->body(); // B
// We will need the store op output TensorIndex
LoadStoreOp* output_store_expr = B_scope.exprs()
.back()
->as<kir::IfThenElse>()
->thenBody()
.at(0)
->as<LoadStoreOp>();
// bidz_scope is the scope containing the GridReduction expression
kir::Scope& bidz_scope =
B_scope.exprs().at(4)->as<kir::ForLoop>()->body(); // BIDz
auto old_grop = bidz_scope.at(0)->as<kir::GridReduction>();
// Store the TensorIndex for the output tensor T1_g, so that we can
// re-use its index
auto t1_idx = output_store_expr->output(0)->as<kir::TensorIndex>();

// Create new TensorView and Allocate
auto output = kernel->outputs().at(0)->as<TensorView>();
Val* i0 = output->getRootDomain().at(0)->extent();
auto new_work_buf_tv =
TensorViewBuilder().shape(std::vector<Val*>{i0}).build();
new_work_buf_tv->setMemoryType(MemoryType::Global);
// associate the index of the output tensor with the work buffer
// NOTE: in actual lowering we would generate an index ourselves
// here, but this works for this test since the T1 store is inlined
// fully with the serial grid reduction.
Val* idx = t1_idx->index();

auto new_work_buf_idx =
IrBuilder::create<kir::TensorIndex>(new_work_buf_tv, idx);
auto new_work_buf_alloc = IrBuilder::create<kir::Allocate>(
new_work_buf_tv, MemoryType::Global, std::vector<Val*>{i0});
const kir::Allocate* orig_work_buf_alloc = global_allocations[0];
global_allocations[0] = new_work_buf_alloc;
// replace work buf alloc expr in top_level_exprs
for (auto i : c10::irange(top_level_exprs.size())) {
if (top_level_exprs[i] == orig_work_buf_alloc) {
top_level_exprs[i] = new_work_buf_alloc;
}
}
// replace work buf in kernel->parameters()
std::vector<Val*>& params =
const_cast<std::vector<Val*>&>(kernel->parameters());
for (auto i : c10::irange(params.size())) {
if (params[i] == orig_work_buf_alloc->buffer()) {
params[i] = new_work_buf_tv;
}
}
// replace the grid reduction Expr
auto new_grop = IrBuilder::create<kir::GridReduction>(
old_grop->getReductionOpType(),
old_grop->init(),
old_grop->out(),
old_grop->in(),
new_work_buf_alloc,
old_grop->sync_buffer(),
old_grop->entrance_index(),
old_grop->entrances(),
old_grop->isAllreduce(),
new_work_buf_idx);
new_grop = new_grop->withPredicate(old_grop->predicate())
->as<kir::GridReduction>();
new_grop = new_grop->withWritePredicate(old_grop->writePredicate())
->as<kir::GridReduction>();
bidz_scope.at(0) = new_grop;

auto sync_buf = global_allocations.at(1)->buffer();

std::vector<Expr*>& nonpar_top_level_exprs =
const_cast<std::vector<Expr*>&>(tidx_scope.exprs());
nonpar_top_level_exprs.insert(
nonpar_top_level_exprs.end() - 2,
IrBuilder::create<kir::BlockSerializeWait>(
ParallelTypeBitmap(ParallelType::BIDz), sync_buf));

nonpar_top_level_exprs.insert(
nonpar_top_level_exprs.end() - 1,
IrBuilder::create<kir::BlockSerializeRelease>(
ParallelTypeBitmap(ParallelType::BIDz), sync_buf));
});
}
fe.compileFusion(fusion);

auto input = at::randn(
{H, W}, at::TensorOptions().dtype(at::kFloat).device(at::kCUDA, 0));
auto outputs = fe.runFusion({input});

if (serial) {
testValidate(fusion, outputs, {input}, __LINE__, __FILE__);
}
}
}
}
}

TEST_F(SerialGridReductionTest, Scheduling) {
for (bool serial : {true, false}) {
for (int64_t num_warps : {4, 8}) {