diff --git a/src/meta_schedule/schedule_rule/multi_level_tiling.cc b/src/meta_schedule/schedule_rule/multi_level_tiling.cc index fb936e8d5eba..173660074c6f 100644 --- a/src/meta_schedule/schedule_rule/multi_level_tiling.cc +++ b/src/meta_schedule/schedule_rule/multi_level_tiling.cc @@ -75,60 +75,53 @@ std::vector MultiLevelTilingNode::AddWriteReuse(State state) const { if (state.tensor_core_is_used) state = TensorCoreStore(state); return {std::move(state)}; } - // Case 1. If the write cache is already there, we don't need to add another. - if (config.req == ReuseType::kMayReuse) { + std::vector levels = config.levels; + ReuseType req = config.req; + if (Optional> ann = tir::GetAnn>( + state.sch->GetSRef(state.block_rv), "meta_schedule.write_cache_level")) { + req = ReuseType::kMustReuse; + levels = std::vector(ann.value().begin(), ann.value().end()); + } + std::vector results; + if (req == ReuseType::kMayReuse) { + // Case 1. If the write cache is already there, we don't need to add another. Array consumer_rvs = state.sch->GetConsumers(state.block_rv); if (consumer_rvs.size() == 1 && IsWriteCache(state.sch->GetSRef(consumer_rvs[0]))) { - state.write_cache = consumer_rvs[0]; - state.write_cache_is_added = false; std::vector results; results.push_back(state); - BlockRV consumer = state.write_cache.value(); - // Enumerate the level of tile to be fused at - for (int level : config.levels) { + for (int level : levels) { State new_state = state; new_state.sch = state.sch->Copy(); new_state.sch->Seed(state.sch->ForkSeed()); if (new_state.tensor_core_is_used) { new_state = TensorCoreStore(new_state); + } else { + const LoopRV& loop_rv = new_state.tiles[level - 1].back(); + new_state.sch->ReverseComputeAt(consumer_rvs[0], loop_rv, true); } - const LoopRV& loop_rv = new_state.tiles[level - 1].back(); - new_state.sch->ReverseComputeAt(consumer, loop_rv, true); results.push_back(std::move(new_state)); } return results; + } else { + // Case 2. No write cache is added + State new_state(/*sch=*/state.sch->Copy(), /*block_rv=*/state.block_rv); + new_state.sch->Seed(state.sch->ForkSeed()); + if (new_state.tensor_core_is_used) new_state = TensorCoreStore(new_state); + results.emplace_back(std::move(new_state)); } } - std::vector results; - // Case 2. No write cache is added - if (config.req == ReuseType::kMayReuse) { - State new_state(/*sch=*/state.sch->Copy(), /*block_rv=*/state.block_rv, - /*write_cache=*/NullOpt, - /*write_cache_is_added=*/false); - new_state.sch->Seed(state.sch->ForkSeed()); - if (new_state.tensor_core_is_used) new_state = TensorCoreStore(new_state); - results.emplace_back(std::move(new_state)); - } + // Case 3. Add one write cache + BlockRV write_cache = state.sch->CacheWrite(/*block_rv=*/state.block_rv, /*read_buffer_index=*/0, + /*storage_scope=*/config.scope); for (int level : config.levels) { State new_state = state; new_state.sch = state.sch->Copy(); new_state.sch->Seed(state.sch->ForkSeed()); - if (new_state.tensor_core_is_used) { - new_state = TensorCoreStore(new_state); - } + ICHECK(!new_state.tensor_core_is_used) << "not supported"; // FIXME const LoopRV& loop_rv = new_state.tiles[level - 1].back(); - BlockRV write_cache = - new_state.sch->WriteAt(/*loop_rv=*/loop_rv, /*block_rv=*/new_state.block_rv, - /*write_buffer_index=*/0, - /*storage_scope=*/config.scope); - new_state.write_cache = write_cache; - { - tir::Annotate(new_state.sch->state(), new_state.sch->GetSRef(write_cache), // - tir::attr::meta_schedule_cache_type, // - Integer(tir::attr::meta_schedule_cache_type_write)); - } + new_state.sch->ReverseComputeAt(write_cache, loop_rv, true); results.push_back(std::move(new_state)); } return results; @@ -220,21 +213,37 @@ std::vector MultiLevelTilingNode::AddReadReuse(State state) const { continue; } // Do cache_read - BlockRV cache_read_block = sch->ReadAt(loop_rv, block_rv, i, config.scope); - runtime::StorageScope scope = runtime::StorageScope::Create(config.scope); - Array probs(3, FloatImm(DataType::Float(64), 1.0 / 3)); - PrimExpr ann_val = sch->SampleCategorical({4, 8, 16}, probs); - sch->Annotate(cache_read_block, tir::attr::vector_bytes, ann_val); - if (scope.rank == runtime::StorageRank::kShared && add_local_stage) { - sch->Annotate(cache_read_block, tir::attr::local_stage, Integer(1)); - } - if (scope.rank == runtime::StorageRank::kShared) { - sch->Annotate(cache_read_block, tir::attr::double_buffer_scope, Integer(0)); - } - { - tir::Annotate(sch->state(), sch->GetSRef(cache_read_block), // - tir::attr::meta_schedule_cache_type, - Integer(tir::attr::meta_schedule_cache_type_read)); + BlockRV cache_read_block; + if (state.tensor_core_is_used) { + cache_read_block = sch->ReadAt(loop_rv, block_rv, i, config.scope); + runtime::StorageScope scope = runtime::StorageScope::Create(config.scope); + Array probs(3, FloatImm(DataType::Float(64), 1.0 / 3)); + PrimExpr ann_val = sch->SampleCategorical({4, 8, 16}, probs); + sch->Annotate(cache_read_block, tir::attr::vector_bytes, ann_val); + if (scope.rank == runtime::StorageRank::kShared && add_local_stage) { + sch->Annotate(cache_read_block, tir::attr::local_stage, Integer(1)); + } + if (scope.rank == runtime::StorageRank::kShared) { + sch->Annotate(cache_read_block, tir::attr::double_buffer_scope, Integer(0)); + } + } else { + cache_read_block = sch->CacheRead(block_rv, i, config.scope); + // Insert cache_read block to the proper place + sch->ComputeAt(cache_read_block, loop_rv, true); + // Fuse the iterators of the cache_read + Array buffer_loops = sch->GetLoops(cache_read_block); + LoopRV fused = sch->Fuse(Array{buffer_loops.end() - buffer_ndim, // + buffer_loops.end()}); + // Annotate cooperative fetching + if (!vector_load_lens.empty()) { + int n = vector_load_lens.size(); + double prob = 1.0 / n; + tir::ExprRV vector_load_len = + sch->SampleCategorical(support::AsArray(vector_load_lens), + Array(n, FloatImm(DataType::Float(64), prob))); + sch->Annotate(cache_read_block, tir::attr::meta_schedule_cooperative_fetch, + vector_load_len); + } } } State new_state = state; @@ -246,7 +255,7 @@ std::vector MultiLevelTilingNode::AddReadReuse(State state) const { Array order; if (tir::IsCacheReadSharedPattern(loop)) { stage = {0, 0, 0, 0, 0, 1, 1}; - order = {0, 1, 3, 4, 5, 2, 6}; + order = {0, 3, 1, 4, 5, 2, 6}; } else { tir::FallbackRule(loop, &stage, &order); } @@ -287,6 +296,7 @@ Optional MultiLevelTilingNode::TransformWithTensorIntrin(State& state, c Optional opt_layout_info = GetTensorizeLayoutInfo(state.sch->state(), state.sch->GetSRef(block_rv), tir::TensorIntrin::Get(intrin_name)->desc); + ICHECK(opt_layout_info.defined()); if (!opt_layout_info) return NullOpt; const tir::LayoutInfoNode* info = opt_layout_info.value().get(); @@ -419,6 +429,7 @@ inline std::vector MultiLevelTilingNode::SeekForTensorCore(State state) c if (!use_tensor_core) return {state}; // Do block & buffer layout transform to match Tensor Core wmma sync intrin Optional transformed_loop_rv = TransformWithTensorIntrin(state, "wmma_sync"); + ICHECK(transformed_loop_rv.defined()); if (!transformed_loop_rv.defined()) return {state}; // Do tiling to match Tensor Core wmma sync intrin BlockRV block_rv = state.block_rv;