Recover MultiLevelTiling rule for non-tensor-core workloads
vinx13 committed Jun 2, 2022
1 parent d1c1b0f commit a80128b
Showing 1 changed file with 59 additions and 48 deletions.
107 changes: 59 additions & 48 deletions src/meta_schedule/schedule_rule/multi_level_tiling.cc
@@ -75,60 +75,53 @@ std::vector<State> MultiLevelTilingNode::AddWriteReuse(State state) const {
if (state.tensor_core_is_used) state = TensorCoreStore(state);
return {std::move(state)};
}
// Case 1. If the write cache is already there, we don't need to add another.
if (config.req == ReuseType::kMayReuse) {
std::vector<int> levels = config.levels;
ReuseType req = config.req;
if (Optional<Array<Integer>> ann = tir::GetAnn<Array<Integer>>(
state.sch->GetSRef(state.block_rv), "meta_schedule.write_cache_level")) {
req = ReuseType::kMustReuse;
levels = std::vector<int>(ann.value().begin(), ann.value().end());
}
std::vector<State> results;
if (req == ReuseType::kMayReuse) {
// Case 1. If the write cache is already there, we don't need to add another.
Array<BlockRV> consumer_rvs = state.sch->GetConsumers(state.block_rv);
if (consumer_rvs.size() == 1 && IsWriteCache(state.sch->GetSRef(consumer_rvs[0]))) {
state.write_cache = consumer_rvs[0];
state.write_cache_is_added = false;
std::vector<State> results;
results.push_back(state);

BlockRV consumer = state.write_cache.value();
// Enumerate the level of tile to be fused at
for (int level : config.levels) {
for (int level : levels) {
State new_state = state;
new_state.sch = state.sch->Copy();
new_state.sch->Seed(state.sch->ForkSeed());
if (new_state.tensor_core_is_used) {
new_state = TensorCoreStore(new_state);
} else {
const LoopRV& loop_rv = new_state.tiles[level - 1].back();
new_state.sch->ReverseComputeAt(consumer_rvs[0], loop_rv, true);
}
const LoopRV& loop_rv = new_state.tiles[level - 1].back();
new_state.sch->ReverseComputeAt(consumer, loop_rv, true);
results.push_back(std::move(new_state));
}
return results;
} else {
// Case 2. No write cache is added
State new_state(/*sch=*/state.sch->Copy(), /*block_rv=*/state.block_rv);
new_state.sch->Seed(state.sch->ForkSeed());
if (new_state.tensor_core_is_used) new_state = TensorCoreStore(new_state);
results.emplace_back(std::move(new_state));
}
}
std::vector<State> results;
// Case 2. No write cache is added
if (config.req == ReuseType::kMayReuse) {
State new_state(/*sch=*/state.sch->Copy(), /*block_rv=*/state.block_rv,
/*write_cache=*/NullOpt,
/*write_cache_is_added=*/false);
new_state.sch->Seed(state.sch->ForkSeed());
if (new_state.tensor_core_is_used) new_state = TensorCoreStore(new_state);
results.emplace_back(std::move(new_state));
}

// Case 3. Add one write cache
BlockRV write_cache = state.sch->CacheWrite(/*block_rv=*/state.block_rv, /*read_buffer_index=*/0,
/*storage_scope=*/config.scope);
for (int level : config.levels) {
State new_state = state;
new_state.sch = state.sch->Copy();
new_state.sch->Seed(state.sch->ForkSeed());
if (new_state.tensor_core_is_used) {
new_state = TensorCoreStore(new_state);
}
ICHECK(!new_state.tensor_core_is_used) << "not supported"; // FIXME
const LoopRV& loop_rv = new_state.tiles[level - 1].back();
BlockRV write_cache =
new_state.sch->WriteAt(/*loop_rv=*/loop_rv, /*block_rv=*/new_state.block_rv,
/*write_buffer_index=*/0,
/*storage_scope=*/config.scope);
new_state.write_cache = write_cache;
{
tir::Annotate(new_state.sch->state(), new_state.sch->GetSRef(write_cache), //
tir::attr::meta_schedule_cache_type, //
Integer(tir::attr::meta_schedule_cache_type_write));
}
new_state.sch->ReverseComputeAt(write_cache, loop_rv, true);
results.push_back(std::move(new_state));
}
return results;
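
For context, the AddWriteReuse hunk above reads an optional per-block "meta_schedule.write_cache_level" annotation; when it is present, the rule switches to ReuseType::kMustReuse and replaces its configured reuse levels with the annotated ones. A minimal standalone sketch of that override pattern (plain C++17 with illustrative names, not the TVM API):

// Standalone sketch, not TVM code: an optional per-block annotation overrides the
// rule's default write-reuse configuration, mirroring the
// "meta_schedule.write_cache_level" handling in the hunk above.
#include <iostream>
#include <optional>
#include <vector>

enum class ReuseType { kMayReuse, kMustReuse };

struct ReuseDecision {
  ReuseType req;
  std::vector<int> levels;
};

// `annotation` stands in for the result of querying "meta_schedule.write_cache_level".
ReuseDecision DecideWriteReuse(ReuseType default_req, std::vector<int> default_levels,
                               const std::optional<std::vector<int>>& annotation) {
  if (annotation) {
    // An explicit annotation turns "may reuse" into "must reuse" at the given levels.
    return {ReuseType::kMustReuse, *annotation};
  }
  return {default_req, std::move(default_levels)};
}

int main() {
  ReuseDecision d = DecideWriteReuse(ReuseType::kMayReuse, {2, 3}, std::vector<int>{3});
  std::cout << (d.req == ReuseType::kMustReuse) << " level=" << d.levels[0] << "\n";  // 1 level=3
}
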
@@ -220,21 +213,37 @@ std::vector<State> MultiLevelTilingNode::AddReadReuse(State state) const {
continue;
}
// Do cache_read
BlockRV cache_read_block = sch->ReadAt(loop_rv, block_rv, i, config.scope);
runtime::StorageScope scope = runtime::StorageScope::Create(config.scope);
Array<FloatImm> probs(3, FloatImm(DataType::Float(64), 1.0 / 3));
PrimExpr ann_val = sch->SampleCategorical({4, 8, 16}, probs);
sch->Annotate(cache_read_block, tir::attr::vector_bytes, ann_val);
if (scope.rank == runtime::StorageRank::kShared && add_local_stage) {
sch->Annotate(cache_read_block, tir::attr::local_stage, Integer(1));
}
if (scope.rank == runtime::StorageRank::kShared) {
sch->Annotate(cache_read_block, tir::attr::double_buffer_scope, Integer(0));
}
{
tir::Annotate(sch->state(), sch->GetSRef(cache_read_block), //
tir::attr::meta_schedule_cache_type,
Integer(tir::attr::meta_schedule_cache_type_read));
BlockRV cache_read_block;
if (state.tensor_core_is_used) {
cache_read_block = sch->ReadAt(loop_rv, block_rv, i, config.scope);
runtime::StorageScope scope = runtime::StorageScope::Create(config.scope);
Array<FloatImm> probs(3, FloatImm(DataType::Float(64), 1.0 / 3));
PrimExpr ann_val = sch->SampleCategorical({4, 8, 16}, probs);
sch->Annotate(cache_read_block, tir::attr::vector_bytes, ann_val);
if (scope.rank == runtime::StorageRank::kShared && add_local_stage) {
sch->Annotate(cache_read_block, tir::attr::local_stage, Integer(1));
}
if (scope.rank == runtime::StorageRank::kShared) {
sch->Annotate(cache_read_block, tir::attr::double_buffer_scope, Integer(0));
}
} else {
cache_read_block = sch->CacheRead(block_rv, i, config.scope);
// Insert cache_read block to the proper place
sch->ComputeAt(cache_read_block, loop_rv, true);
// Fuse the iterators of the cache_read
Array<LoopRV> buffer_loops = sch->GetLoops(cache_read_block);
LoopRV fused = sch->Fuse(Array<LoopRV>{buffer_loops.end() - buffer_ndim, //
buffer_loops.end()});
// Annotate cooperative fetching
if (!vector_load_lens.empty()) {
int n = vector_load_lens.size();
double prob = 1.0 / n;
tir::ExprRV vector_load_len =
sch->SampleCategorical(support::AsArray<int, Integer>(vector_load_lens),
Array<FloatImm>(n, FloatImm(DataType::Float(64), prob)));
sch->Annotate(cache_read_block, tir::attr::meta_schedule_cooperative_fetch,
vector_load_len);
}
}
}
State new_state = state;
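
Both read-reuse branches above choose a vectorization factor by uniform categorical sampling: the ReadAt path samples vector_bytes from {4, 8, 16} with probability 1/3 each, and the recovered CacheRead path samples a cooperative-fetch length from vector_load_lens with probability 1/n. A standalone sketch of that sampling step (plain C++ with illustrative names, not TVM's SampleCategorical):

// Standalone sketch, not TVM code: draw one candidate with equal probability 1/n,
// as the uniform FloatImm probability arrays above do.
#include <cstddef>
#include <iostream>
#include <random>
#include <vector>

int SampleUniformCategorical(const std::vector<int>& candidates, std::mt19937& rng) {
  std::vector<double> probs(candidates.size(), 1.0 / candidates.size());
  std::discrete_distribution<std::size_t> dist(probs.begin(), probs.end());
  return candidates[dist(rng)];
}

int main() {
  std::mt19937 rng(/*seed=*/42);
  std::cout << "vector_bytes = " << SampleUniformCategorical({4, 8, 16}, rng) << "\n";
  // Example candidates only; the real values come from the rule's vector_load_lens config.
  std::cout << "cooperative fetch len = " << SampleUniformCategorical({1, 2, 4}, rng) << "\n";
}
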
@@ -246,7 +255,7 @@ std::vector<State> MultiLevelTilingNode::AddReadReuse(State state) const {
Array<Integer> order;
if (tir::IsCacheReadSharedPattern(loop)) {
stage = {0, 0, 0, 0, 0, 1, 1};
order = {0, 1, 3, 4, 5, 2, 6};
order = {0, 3, 1, 4, 5, 2, 6};
} else {
tir::FallbackRule(loop, &stage, &order);
}
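
The one-line change above only swaps the positions assigned to the second and third statements ({0, 1, 3, ...} becomes {0, 3, 1, ...}). Assuming stage and order feed TIR's software-pipeline stage/order annotations (an inference from context, not stated in this diff), a standalone sketch of how such an order array rearranges a sequence of statements:

// Standalone sketch, not TVM code: statement i is assigned position order[i];
// sorting by that position yields the rearranged sequence. (This reading of the
// order array is an assumption for illustration.)
#include <algorithm>
#include <cstddef>
#include <iostream>
#include <numeric>
#include <string>
#include <vector>

std::vector<std::string> Rearrange(const std::vector<std::string>& stmts,
                                   const std::vector<int>& order) {
  std::vector<std::size_t> idx(stmts.size());
  std::iota(idx.begin(), idx.end(), 0);
  std::sort(idx.begin(), idx.end(),
            [&](std::size_t a, std::size_t b) { return order[a] < order[b]; });
  std::vector<std::string> out;
  for (std::size_t i : idx) out.push_back(stmts[i]);
  return out;
}

int main() {
  std::vector<std::string> stmts = {"S0", "S1", "S2", "S3", "S4", "S5", "S6"};
  for (const std::string& s : Rearrange(stmts, {0, 3, 1, 4, 5, 2, 6})) std::cout << s << ' ';
  std::cout << "\n";  // S0 S2 S5 S1 S3 S4 S6
}
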
@@ -287,6 +296,7 @@ Optional<LoopRV> MultiLevelTilingNode::TransformWithTensorIntrin(State& state, c
Optional<tir::LayoutInfo> opt_layout_info =
GetTensorizeLayoutInfo(state.sch->state(), state.sch->GetSRef(block_rv),
tir::TensorIntrin::Get(intrin_name)->desc);
ICHECK(opt_layout_info.defined());
if (!opt_layout_info) return NullOpt;
const tir::LayoutInfoNode* info = opt_layout_info.value().get();

@@ -419,6 +429,7 @@ inline std::vector<State> MultiLevelTilingNode::SeekForTensorCore(State state) c
if (!use_tensor_core) return {state};
// Do block & buffer layout transform to match Tensor Core wmma sync intrin
Optional<LoopRV> transformed_loop_rv = TransformWithTensorIntrin(state, "wmma_sync");
ICHECK(transformed_loop_rv.defined());
if (!transformed_loop_rv.defined()) return {state};
// Do tiling to match Tensor Core wmma sync intrin
BlockRV block_rv = state.block_rv;
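
Taken together, the fragments above show how the rule degrades gracefully for non-tensor-core workloads: TransformWithTensorIntrin returns NullOpt when no tensorize layout info is found, and SeekForTensorCore keeps the untouched state in that case, so such workloads still go through the plain multi-level tiling path. A standalone sketch of that control flow (plain C++17 with illustrative types, not the TVM API):

// Standalone sketch, not TVM code: try to specialize a state for tensor cores and
// fall back to the original state when the transform does not apply.
#include <iostream>
#include <optional>
#include <string>
#include <vector>

struct State { std::string desc; };

// Stands in for TransformWithTensorIntrin; an empty result means the block could
// not be matched to the intrinsic's layout.
std::optional<State> TryTensorize(const State& s, bool matches_intrin) {
  if (!matches_intrin) return std::nullopt;
  return State{s.desc + " + wmma_sync layout"};
}

std::vector<State> SeekForTensorCoreSketch(State state, bool use_tensor_core,
                                           bool matches_intrin) {
  if (!use_tensor_core) return {state};  // rule disabled: keep the plain state
  std::optional<State> transformed = TryTensorize(state, matches_intrin);
  if (!transformed) return {state};      // transform failed: fall back instead of aborting
  return {*transformed};                 // otherwise continue with the specialized state
}

int main() {
  std::vector<State> out = SeekForTensorCoreSketch(State{"matmul"}, /*use_tensor_core=*/true,
                                                   /*matches_intrin=*/false);
  std::cout << out[0].desc << "\n";  // prints "matmul": the untouched fallback state
}
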