Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Combined inner–outer reduction: add a simple test case #2400

Open
wants to merge 14 commits into
base: devel
Choose a base branch
from
85 changes: 57 additions & 28 deletions third_party/nvfuser/csrc/scheduler/normalization.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -50,11 +50,10 @@ std::shared_ptr<ReductionParams> innerOuterPersistentHeuristic(

// Parameters for outer reduction:
// Reduction dim: bdimy
// Iteration dim: outer_unroll_factor, bdimx, gdimy
// Iteration dim: tmp_gmem_read_vect, bdimx, gdimy
struct InnerOuterParams {
int64_t inner_vect = -1;
int64_t inner_batch = -1;
int64_t outer_unroll_factor = -1;
int64_t bdimx = -1;
int64_t bdimy = -1;
int64_t gdimy = -1;
Expand All @@ -64,8 +63,6 @@ std::shared_ptr<ReductionParams> innerOuterPersistentHeuristic(
void verify() {
TORCH_INTERNAL_ASSERT(inner_vect != -1, "inner_vect is not set.");
TORCH_INTERNAL_ASSERT(inner_batch != -1, "inner_batch is not set.");
TORCH_INTERNAL_ASSERT(
outer_unroll_factor != -1, "outer_unroll_factor is not set.");
TORCH_INTERNAL_ASSERT(bdimx != -1, "bdimx is not set.");
TORCH_INTERNAL_ASSERT(bdimy != -1, "bdimy is not set.");
TORCH_INTERNAL_ASSERT(gdimy != -1, "gdimy is not set.");
Expand Down Expand Up @@ -178,11 +175,23 @@ std::shared_ptr<ReductionParams> innerOuterPersistentHeuristic(
iop.gdimy = blocks_per_sm * device_multiprocessor_count;
}

// Step-3, set OuterParams Iteration dim: outer_unroll_factor, bdimx, gdimy
// set the vectorization factor for the write to tmp gmem, may be different
// from inner_vect due to different data types, e.g. input is half and
// tmp_gmem is float
constexpr int64_t max_gmem_vect_access_bytes = 16;
const int64_t max_tmp_gmem_vect_factor =
max_gmem_vect_access_bytes / tmp_gmem_dtype_size;
iop.tmp_gmem_write_vect = std::min(max_tmp_gmem_vect_factor, iop.inner_vect);

// Step-3, set OuterParams Iteration dim: tmp_gmem_read_vect, bdimx, gdimy
// (already done)
iop.outer_unroll_factor = inner_dim_numel >= 4096 ? 4 : 2;
// The partial outer reduction result is stored in tmp gmem, set the
// vectorization factor for write and read
const int64_t workload_per_thread = inner_dim_numel >= 4096 ? 4 : 2;
iop.tmp_gmem_read_vect =
std::min(workload_per_thread, max_tmp_gmem_vect_factor);
iop.bdimx = scheduler_utils::roundUpPow2(
ceilDiv(inner_dim_numel / iop.outer_unroll_factor, iop.gdimy));
ceilDiv(inner_dim_numel / iop.tmp_gmem_read_vect, iop.gdimy));

// Step-4, set OuterParams Reduction dim: bdimy.
iop.bdimy = ceilDiv(threads_per_block, iop.bdimx);
Expand All @@ -208,12 +217,11 @@ std::shared_ptr<ReductionParams> innerOuterPersistentHeuristic(
threads_per_sm, threads_per_block_mrpb, dev_prop->warpSize);
iop.gdimy = blocks_per_sm * device_multiprocessor_count;

// Step-3, OuterParams, Iteration dim: outer_unroll_factor, bdimy, gdimy (in
// previous step). unroll_factor is set to 2 as a small unroll_factor is
// preferred for small sizes and we only process vectorized cases.
iop.outer_unroll_factor = 2;
// Step-3, OuterParams, Iteration dim: tmp_gmem_read_vect(reuse), bdimy, gdimy (in
// previous step). tmp_gmem_read_vect is set to 2 as a small workload per
// thread is preferred for small sizes and we only process vectorized cases.
iop.bdimy = std::min(
ceilDiv(inner_dim_numel / iop.outer_unroll_factor, iop.gdimy),
ceilDiv(inner_dim_numel / iop.tmp_gmem_read_vect, iop.gdimy),
scheduler_utils::safeDiv(threads_per_block_mrpb, iop.bdimx));
iop.bdimy = iop.bdimy;

Expand All @@ -228,15 +236,6 @@ std::shared_ptr<ReductionParams> innerOuterPersistentHeuristic(
rparams->block_dim_inner_reduction_extra = ParallelType::TIDy;
}

// The partial outer reduction result is stored in tmp gmem, set the
// vectorization factor for write and read
constexpr int64_t max_gmem_vect_access_bytes = 16;
const int64_t max_tmp_gmem_vect_factor =
max_gmem_vect_access_bytes / tmp_gmem_dtype_size;
iop.tmp_gmem_write_vect = std::min(max_tmp_gmem_vect_factor, iop.inner_vect);
iop.tmp_gmem_read_vect =
std::min(max_tmp_gmem_vect_factor, iop.outer_unroll_factor);

// check all the parameters in InnerOuterParams are set.
iop.verify();

Expand All @@ -245,7 +244,7 @@ std::shared_ptr<ReductionParams> innerOuterPersistentHeuristic(
rparams->combined_inner_outer = true;
// tmp_gmem is the intermediate result of outer reduction, its dtype is float,
// so the maximum vectorization factor is 4.
rparams->vectorization_factor_tmp_gmem_read = iop.tmp_gmem_write_vect;
rparams->vectorization_factor_tmp_gmem_read = iop.tmp_gmem_read_vect;
rparams->vectorization_factor_tmp_gmem_write = iop.tmp_gmem_write_vect;
rparams->cparams.maxrregcount =
getRegPerThreadGivenThreadsPerSM(iop.bdimx * iop.bdimy * blocks_per_sm);
Expand All @@ -267,11 +266,12 @@ std::shared_ptr<ReductionParams> innerOuterPersistentHeuristic(
rparams->tag = "InnerOuter Persistent Heuristic.\n";

if (isDebugDumpEnabled(DebugDumpOption::SchedulerDebug)) {
std::cerr << "\n===== Combined Reduction Stats ========\n"
std::cerr << "\n===== Combined InnerOuter Reduction Stats ========\n"
<< "outer_dim_numel: " << outer_dim_numel << "\n"
<< "inner_dim_numel: " << inner_dim_numel << "\n"
<< "vectorize_factor: " << iop.inner_vect << "\n"
<< "outer_unroll_factor: " << iop.outer_unroll_factor << "\n"
<< "vectorize_factor_input: " << iop.inner_vect << "\n"
<< "vectorization_factor_tmp_gmem_write: " << iop.tmp_gmem_write_vect << "\n"
<< "vectorization_factor_tmp_gmem_read: " << iop.tmp_gmem_read_vect << "\n"
<< "multiple_reds_per_blk: " << rparams->multiple_reds_per_blk
<< "\n"
<< "threads_per_sm: " << threads_per_sm << "\n"
Expand Down Expand Up @@ -1623,6 +1623,7 @@ void scheduleReductionCombinedOuter(
TensorView* partialResult = outer_reduction_tv->rFactor({1});
partialResult->cacheBefore();
partialResult->setMemoryType(MemoryType::Global);
std::cout << "partialResult= " << partialResult->toString() << std::endl;
TensorView* partialResultReload = partialResult->cacheAfter();

boundaryNodesSet.insert(partialResultReload);
Expand All @@ -1643,6 +1644,7 @@ void scheduleReductionCombinedOuter(
0, NamedScalar::getParallelDim(ParallelType::TIDy));
outer_reduction_tv->axis(1)->parallelize(ParallelType::TIDy);
}
std::cout << "outer_reduction_tv= " << outer_reduction_tv->toString()<< std::endl;;

// iteration domain
int axisID = -1;
Expand All @@ -1651,6 +1653,7 @@ void scheduleReductionCombinedOuter(
axisID, rparams.vectorization_factor_tmp_gmem_read);
outer_reduction_tv->axis(axisID--)->parallelize(ParallelType::Unroll);
}
std::cout << "outer_reduction_tv= " << outer_reduction_tv->toString()<< std::endl;;

if (rparams.tidx_for_outer_reduction) {
outer_reduction_tv->split(
Expand All @@ -1661,36 +1664,58 @@ void scheduleReductionCombinedOuter(
axisID, NamedScalar::getParallelDim(ParallelType::TIDx));
outer_reduction_tv->axis(axisID--)->parallelize(ParallelType::TIDx);
}
std::cout << "outer_reduction_tv= " << outer_reduction_tv->toString()<< std::endl;;

outer_reduction_tv->split(
axisID, NamedScalar::getParallelDim(ParallelType::BIDy));
outer_reduction_tv->axis(axisID--)->parallelize(ParallelType::BIDy);

std::cout << "outer_reduction_tv= " << outer_reduction_tv->toString()<< std::endl;;


} else {
// reduction domain
outer_reduction_tv->split(
0, NamedScalar::getParallelDim(ParallelType::TIDy));
outer_reduction_tv->axis(1)->parallelize(ParallelType::TIDy);
std::cout << "outer_reduction_tv= " << outer_reduction_tv->toString()
<< std::endl;

// iteration domain
int axisID = -1;
if (rparams.vectorization_factor_tmp_gmem_read > 1) {
outer_reduction_tv->split(
axisID, rparams.vectorization_factor_tmp_gmem_read);
outer_reduction_tv->axis(axisID--)->parallelize(ParallelType::Unroll);
}

std::cout << "outer_reduction_tv= " << outer_reduction_tv->toString()
<< std::endl;

if (rparams.lparams.bdimx() > 1) {
outer_reduction_tv->split(
axisID, NamedScalar::getParallelDim(ParallelType::TIDx));
outer_reduction_tv->axis(axisID--)->parallelize(ParallelType::TIDx);
}
std::cout << "outer_reduction_tv= " << outer_reduction_tv->toString()
<< std::endl;

outer_reduction_tv->split(
axisID, NamedScalar::getParallelDim(ParallelType::BIDy));

std::cout << "outer_reduction_tv= " << outer_reduction_tv->toString()
<< std::endl;

outer_reduction_tv->axis(axisID--)->parallelize(ParallelType::BIDy);

std::cout << "outer_reduction_tv= " << outer_reduction_tv->toString()
<< std::endl;
}

auto outer_reference_tv =
reduction_scheduler_utils::sortAndRFactor(outer_reduction_tv);
outer_reference_tvs.emplace_back(outer_reference_tv);

}
}

Expand Down Expand Up @@ -1761,7 +1786,7 @@ void schedulePersistentKernelInnerOuter(
rparams.cross_grid_inner_reduction && !rparams.fastest_dim;

// Propagate inner reduction. There is a cutoff at boundaryNodesSet, so this
// propagation will not propagate to the outer reduction.
// propagation will not propagate to the final outer reduction.
reduction_scheduler_utils::propagateTransformation(
inner_reference_tv, boundaryNodesSet);
reduction_scheduler_utils::propagateRFactor(
Expand All @@ -1776,7 +1801,7 @@ void schedulePersistentKernelInnerOuter(
inner_reduction_tvs,
cached_inputs,
cached_outputs);

std::cout << " Propagate inner reduction "<< std::endl;
// Propagate outer reduction. Each outer reduction is connected with its
// cached_gmem and output, since we added all the cached_gmem to the
// boundaryNodesSet, the transformation from one outer reduction can't
Expand All @@ -1796,6 +1821,7 @@ void schedulePersistentKernelInnerOuter(
cached_inputs,
cached_outputs);
}
std::cout << " Propagate outer reduction "<< std::endl;

// special vectorization of temp gmem, vectorization_factor_tmp_gmem_write is
// guaranteed to be smaller or equal to input vectorization factor.
Expand All @@ -1809,11 +1835,13 @@ void schedulePersistentKernelInnerOuter(
rparams.unroll_factor_inner_reduction) {
tv->split(-1, rparams.vectorization_factor_tmp_gmem_write);
}
std::cout << " cached_gmem " << tv->toString() << std::endl;
tv->axis(-1)->parallelize(ParallelType::Vectorize);
}
}
if (rparams.vectorization_factor_tmp_gmem_read > 1) {
for (auto tv : cached_gmem_reload) {
std::cout << " cached_gmem_reload " << tv->toString() << std::endl;
tv->axis(-2)->parallelize(ParallelType::Vectorize);
naoyam marked this conversation as resolved.
Show resolved Hide resolved
}
}
Expand All @@ -1825,7 +1853,8 @@ void schedulePersistentKernelInnerOuter(
if (rparams.vectorization_factor_tmp_gmem_read > 1) {
for (auto tv_pair : cached_outputs) {
if (tv_pair.second->axis(-1)->getParallelType() !=
ParallelType::Vectorize) {
ParallelType::Vectorize && tv_pair.second->axis(-1)->start()->isZeroInt() && tv_pair.second->axis(-1)->extent()->isConstScalar()) {
std::cout << " cached_outputs " << tv_pair.second->toString() << std::endl;
tv_pair.second->axis(-1)->parallelize(ParallelType::Vectorize);
}
}
Expand Down
63 changes: 47 additions & 16 deletions third_party/nvfuser/csrc/scheduler/normalization_utils.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@
#include <scheduler/normalization_utils.h>
#include <scheduler/registry.h>
#include <utils.h>

#include <queue>
namespace nvfuser {
namespace normalization_scheduler_utils {

Expand Down Expand Up @@ -568,26 +568,57 @@ bool hasSharedInput(
bool hasSharedConsumer(
const std::vector<TensorView*>& inner_reduction_tvs,
const std::vector<TensorView*>& outer_reduction_tvs) {
// check inner reduction tvs
auto contvs = ir_utils::consumerTvsOf(inner_reduction_tvs[0]);
std::set<TensorView*> contvs_set(contvs.begin(), contvs.end());
for (int i = 1; i < (int)inner_reduction_tvs.size(); i++) {
for (auto tv : ir_utils::consumerTvsOf(inner_reduction_tvs[i])) {
if (contvs_set.find(tv) != contvs_set.end()) {
return true;
} else {
contvs_set.emplace(tv);
auto getAllConsumerTvsOf = [](TensorView* tv0) {
naoyam marked this conversation as resolved.
Show resolved Hide resolved
std::unordered_set<TensorView*> all_consumer_tvs;
std::queue<TensorView*> unvisited;
// start search from tv0, save its direct consumers
for (auto tv : ir_utils::consumerTvsOf(tv0)) {
if (all_consumer_tvs.find(tv) == all_consumer_tvs.end()) {
all_consumer_tvs.emplace(tv);
unvisited.push(tv);
}
}
// search for indirect consumers from tv0's direct consumers
while (!unvisited.empty()) {
auto next_tv = unvisited.front();
unvisited.pop();
for (auto tv : ir_utils::consumerTvsOf(next_tv)) {
if (all_consumer_tvs.find(tv) == all_consumer_tvs.end()) {
all_consumer_tvs.emplace(tv);
unvisited.push(tv);
}
}
}
std::cout << "tv0= " << tv0->toString() << std::endl;
for (auto entry : all_consumer_tvs) {
std::cout << "all_consumer_tvs of tv0= " << entry->toString()
<< std::endl;
}

return all_consumer_tvs;
};

// get all consumers of the inner_reduction_tvs
std::unordered_set<TensorView*> all_consumer_tvs_inner;
for (auto itv : inner_reduction_tvs) {
auto my_all_consumer_tvs = getAllConsumerTvsOf(itv);
all_consumer_tvs_inner.merge(my_all_consumer_tvs);
}

// check outer reduction tvs
for (int i = 0; i < (int)outer_reduction_tvs.size(); i++) {
for (auto tv : ir_utils::consumerTvsOf(outer_reduction_tvs[i])) {
if (contvs_set.find(tv) != contvs_set.end()) {
// check if outer reduction tvs have any shared consumer with inner reduction
// tvs and other outer reduction tvs
std::unordered_set<TensorView*> all_consumer_tvs_outer;
for (auto otv : outer_reduction_tvs) {
for (auto tv : getAllConsumerTvsOf(otv)) {
// check shared consumer with inner reduction tvs
if (all_consumer_tvs_inner.find(tv) != all_consumer_tvs_inner.end()) {
return true;
} else {
contvs_set.emplace(tv);
}
// check shared consumer with other outer reduction tvs
if(all_consumer_tvs_outer.find(tv) != all_consumer_tvs_outer.end()) {
return true;
}else{
all_consumer_tvs_outer.emplace(tv);
}
}
}
Expand Down
4 changes: 2 additions & 2 deletions third_party/nvfuser/csrc/scheduler/normalization_utils.h
Original file line number Diff line number Diff line change
Expand Up @@ -158,8 +158,8 @@ bool hasSharedInput(
const std::vector<TensorView*>& inner_reduction_tvs,
const std::vector<TensorView*>& outer_reduction_tvs);

//! check if the inner reduction tvs have shared consumer
//! check if the outer reduction tvs have shared consumer
//! check if outer reduction tvs have any shared consumer with inner reduction
//! tvs and other outer reduction tvs
bool hasSharedConsumer(
const std::vector<TensorView*>& inner_reduction_tvs,
const std::vector<TensorView*>& outer_reduction_tvs);
Expand Down
Loading