From 3e4a30e02947452152617f8ae997230440137c40 Mon Sep 17 00:00:00 2001 From: Junru Shao <junrushao1994@gmail.com> Date: Sat, 22 Jan 2022 05:23:03 -0800 Subject: [PATCH] Fix cooperative fetching (#17) --- include/tvm/meta_schedule/schedule_rule.h | 4 +- .../schedule_rule/multi_level_tiling.py | 6 +- .../meta_schedule/testing/schedule_rule.py | 6 +- python/tvm/meta_schedule/tune.py | 4 +- src/meta_schedule/mutator/mutate_tile_size.cc | 108 ++++++++++++++---- .../schedule_rule/multi_level_tiling.cc | 24 ++-- ...hedule_schedule_rule_multi_level_tiling.py | 88 +++++++------- .../test_meta_schedule_sketch_cuda.py | 80 ++++++------- .../unittest/test_meta_schedule_tune_tir.py | 2 +- 9 files changed, 193 insertions(+), 129 deletions(-) diff --git a/include/tvm/meta_schedule/schedule_rule.h b/include/tvm/meta_schedule/schedule_rule.h index eb22178ff2bd..449c6cf7e4cf 100644 --- a/include/tvm/meta_schedule/schedule_rule.h +++ b/include/tvm/meta_schedule/schedule_rule.h @@ -141,7 +141,7 @@ class ScheduleRule : public runtime::ObjectRef { * - [blockIdx.x, vthread.x, threadIdx.x] on GPU * \param use_tensor_core Whether to apply tensor core wmma intrinsic for the computation * \param max_innermost_factor The maximum size of the innermost factor. NullOpt means no limit - * \param vector_load_max_len The length of vector lane in vectorized cooperative fetching. + * \param vector_load_lens The length of vector lane in vectorized cooperative fetching. * NullOpt means disable vectorization * \param reuse_read Data reuse configuration for reading. NullOpt means no reuse. * \param reuse_write Data reuse configuration for writing. NullOpt means no reuse. @@ -151,7 +151,7 @@ class ScheduleRule : public runtime::ObjectRef { Optional<Array<String>> tile_binds, // bool use_tensor_core, // Optional<Integer> max_innermost_factor, // - Optional<Integer> vector_load_max_len, // + Optional<Array<Integer>> vector_load_lens, // Optional<Map<String, ObjectRef>> reuse_read, // Optional<Map<String, ObjectRef>> reuse_write); /*! diff --git a/python/tvm/meta_schedule/schedule_rule/multi_level_tiling.py b/python/tvm/meta_schedule/schedule_rule/multi_level_tiling.py index b9eba95b869e..9e030d8a425c 100644 --- a/python/tvm/meta_schedule/schedule_rule/multi_level_tiling.py +++ b/python/tvm/meta_schedule/schedule_rule/multi_level_tiling.py @@ -57,7 +57,7 @@ class MultiLevelTiling(ScheduleRule): Whether to apply tensor core wmma intrinsic for the computation max_innermost_factor : Optional[int] The maximum size of the innermost factor. None means no limit - vector_load_max_len : Optional[int] + vector_load_lens : Optional[List[int]] The length of vector lane in vectorized cooperative fetching. None means disable vectorization reuse_read : Optional[ReuseType] @@ -72,7 +72,7 @@ def __init__( tile_binds: Optional[List[str]] = None, use_tensor_core: bool = False, max_innermost_factor: Optional[int] = None, - vector_load_max_len: Optional[int] = None, + vector_load_lens: Optional[List[int]] = None, reuse_read: Optional[ReuseType] = None, reuse_write: Optional[ReuseType] = None, ) -> None: @@ -82,7 +82,7 @@ def __init__( tile_binds, use_tensor_core, max_innermost_factor, - vector_load_max_len, + vector_load_lens, reuse_read.as_dict() if reuse_read is not None else None, reuse_write.as_dict() if reuse_write is not None else None, ) diff --git a/python/tvm/meta_schedule/testing/schedule_rule.py b/python/tvm/meta_schedule/testing/schedule_rule.py index d62a54bebac6..83434a123a03 100644 --- a/python/tvm/meta_schedule/testing/schedule_rule.py +++ b/python/tvm/meta_schedule/testing/schedule_rule.py @@ -111,7 +111,7 @@ def multi_level_tiling(target: Target) -> ScheduleRule: structure="SSRSRS", tile_binds=None, max_innermost_factor=64, - vector_load_max_len=None, + vector_load_lens=None, reuse_read=None, reuse_write=ReuseType( req="may", @@ -124,7 +124,7 @@ def multi_level_tiling(target: Target) -> ScheduleRule: structure="SSSRRSRS", tile_binds=["blockIdx.x", "vthread.x", "threadIdx.x"], max_innermost_factor=64, - vector_load_max_len=4, + vector_load_lens=[1, 2, 3, 4], reuse_read=ReuseType( req="must", levels=[4], @@ -147,7 +147,7 @@ def multi_level_tiling_tensor_core(target: Target) -> ScheduleRule: tile_binds=["blockIdx.x", "blockIdx.y", "threadIdx.y"], use_tensor_core=True, max_innermost_factor=64, - vector_load_max_len=4, + vector_load_lens=[1, 2, 3, 4], reuse_read=ReuseType( req="must", levels=[4], diff --git a/python/tvm/meta_schedule/tune.py b/python/tvm/meta_schedule/tune.py index ee9198eb1def..4f38d7cc98be 100644 --- a/python/tvm/meta_schedule/tune.py +++ b/python/tvm/meta_schedule/tune.py @@ -101,7 +101,7 @@ def _sch_rules() -> List[ScheduleRule]: structure="SSRSRS", tile_binds=None, max_innermost_factor=64, - vector_load_max_len=None, + vector_load_lens=None, reuse_read=None, reuse_write=M.ReuseType( req="may", @@ -158,7 +158,7 @@ def _sch_rules() -> List[ScheduleRule]: structure="SSSRRSRS", tile_binds=["blockIdx.x", "vthread.x", "threadIdx.x"], max_innermost_factor=64, - vector_load_max_len=4, + vector_load_lens=[1, 2, 3, 4], reuse_read=M.ReuseType( req="must", levels=[4], diff --git a/src/meta_schedule/mutator/mutate_tile_size.cc b/src/meta_schedule/mutator/mutate_tile_size.cc index 1daf1f265e70..02c418b3c2c4 100644 --- a/src/meta_schedule/mutator/mutate_tile_size.cc +++ b/src/meta_schedule/mutator/mutate_tile_size.cc @@ -33,7 +33,7 @@ using tir::Trace; * \param decision The decision of Sample-Perfect-Tile * \return The result of downcast */ -std::vector<int64_t> DowncastDecision(const ObjectRef& decision) { +std::vector<int64_t> DowncastTilingDecision(const ObjectRef& decision) { const auto* arr = TVM_TYPE_AS(arr, decision, runtime::ArrayNode); return support::AsVector<ObjectRef, int64_t>(GetRef<Array<ObjectRef>>(arr)); } @@ -73,34 +73,62 @@ class MutateTileSizeNode : public MutatorNode { * \param decision The decision selected * \return Whether a decision is found */ -bool FindSamplePerfectTile(const Trace& trace, TRandState* rand_state, Instruction* inst, - std::vector<int64_t>* decision) { +void FindSamplePerfectTile(const Trace& trace, std::vector<Instruction>* inst, + std::vector<std::vector<int64_t>>* decision) { static const InstructionKind& inst_sample_perfect_tile = InstructionKind::Get("SamplePerfectTile"); - std::vector<Instruction> instructions; - std::vector<std::vector<int64_t>> decisions; + std::vector<Instruction>& instructions = *inst; + std::vector<std::vector<int64_t>>& decisions = *decision; instructions.reserve(trace->decisions.size()); decisions.reserve(trace->decisions.size()); for (const auto& kv : trace->decisions) { const Instruction& inst = kv.first; const ObjectRef& decision = kv.second; - if (!inst->kind.same_as(inst_sample_perfect_tile)) { - continue; + if (inst->kind.same_as(inst_sample_perfect_tile)) { + std::vector<int64_t> tiles = DowncastTilingDecision(decision); + if (tiles.size() >= 2 && Product(tiles) >= 2) { + instructions.push_back(inst); + decisions.push_back(tiles); + } } - std::vector<int64_t> tiles = DowncastDecision(decision); - if (tiles.size() >= 2 && Product(tiles) >= 2) { - instructions.push_back(inst); - decisions.push_back(tiles); + } +} + +void FindSampleVectorize(const Trace& trace, std::vector<Instruction>* inst, + std::vector<int64_t>* decision) { + static const InstructionKind& inst_sample_categorical = InstructionKind::Get("SampleCategorical"); + static const InstructionKind& inst_annotate = InstructionKind::Get("Annotate"); + std::vector<Instruction>& instructions = *inst; + std::vector<int64_t>& decisions = *decision; + std::unordered_set<const Object*> annotated; + instructions.reserve(trace->decisions.size()); + decisions.reserve(trace->decisions.size()); + annotated.reserve(trace->decisions.size()); + // Find annotation with `meta_schedule_cooperative_fetch` + for (const Instruction& inst : trace->insts) { + if (inst->kind.same_as(inst_annotate)) { + ICHECK_EQ(inst->attrs.size(), 1); + ICHECK_EQ(inst->inputs.size(), 2); + if (Downcast<String>(inst->attrs[0]) == tir::attr::meta_schedule_cooperative_fetch) { + const auto* ann_val = inst->inputs[1].as<tir::ExprRVNode>(); + ICHECK(ann_val); + annotated.insert(ann_val); + } } } - int n = instructions.size(); - if (n > 0) { - int i = tir::SampleInt(rand_state, 0, n); - *inst = instructions[i]; - *decision = decisions[i]; - return true; + // Find sampling instruction that generates the annotation + for (const auto& kv : trace->decisions) { + const Instruction& inst = kv.first; + const ObjectRef& decision = kv.second; + if (inst->kind.same_as(inst_sample_categorical)) { + ICHECK_EQ(inst->outputs.size(), 1); + if (annotated.count(inst->outputs[0].get())) { + const auto* d = TVM_TYPE_AS(d, decision, IntImmNode); + instructions.push_back(inst); + decisions.push_back(d->value); + } + } } - return false; } struct FactorMemo { @@ -146,12 +174,8 @@ struct FactorMemo { std::mutex mutex_; }; -Optional<Trace> MutateTileSizeNode::Apply(const Trace& trace, TRandState* rand_state) { - Instruction inst; - std::vector<int64_t> tiles; - if (!FindSamplePerfectTile(trace, rand_state, &inst, &tiles)) { - return NullOpt; - } +Optional<Trace> MutateSampleTileSize(const Trace& trace, Instruction inst, + std::vector<int64_t> tiles, TRandState* rand_state) { int n_splits = tiles.size(); // Step 1. Choose two loops, `x` and `y` int x, y; @@ -194,6 +218,42 @@ Optional<Trace> MutateTileSizeNode::Apply(const Trace& trace, TRandState* rand_s } } +Optional<Trace> MutateSampleVectorize(const Trace& trace, Instruction inst, + int64_t original_decision, TRandState* rand_state) { + ICHECK_EQ(inst->attrs.size(), 2); + std::vector<double> probs = + support::AsVector<FloatImm, double>(Downcast<Array<FloatImm>>(inst->attrs[1])); + probs.erase(probs.begin() + original_decision); + int result = tir::MakeMultinomialSampler(rand_state, probs)(); + if (result >= original_decision) { + result += 1; + } + return trace->WithDecision(inst, Integer(result), /*remove_postproc=*/true); +} + +Optional<Trace> MutateTileSizeNode::Apply(const Trace& trace, TRandState* rand_state) { + std::vector<Instruction> sample_perfect_tile_insts; + std::vector<Instruction> sample_vectorize_insts; + std::vector<std::vector<int64_t>> sample_perfect_tile_tiles; + std::vector<int64_t> sample_vectorize_decisions; + FindSamplePerfectTile(trace, &sample_perfect_tile_insts, &sample_perfect_tile_tiles); + FindSampleVectorize(trace, &sample_vectorize_insts, &sample_vectorize_decisions); + int size_a = sample_perfect_tile_insts.size(); + int size_b = sample_vectorize_insts.size(); + if (size_a == 0 && size_b == 0) { + return NullOpt; + } + int n = tir::SampleInt(rand_state, 0, size_a + size_b); + if (n < size_a) { + return MutateSampleTileSize(trace, sample_perfect_tile_insts[n], sample_perfect_tile_tiles[n], + rand_state); + } else { + n -= size_a; + return MutateSampleVectorize(trace, sample_vectorize_insts[n], sample_vectorize_decisions[n], + rand_state); + } +} + Mutator Mutator::MutateTileSize() { return Mutator(make_object<MutateTileSizeNode>()); } TVM_REGISTER_NODE_TYPE(MutateTileSizeNode); diff --git a/src/meta_schedule/schedule_rule/multi_level_tiling.cc b/src/meta_schedule/schedule_rule/multi_level_tiling.cc index a0ffe7e00426..a5d677c5cdf2 100644 --- a/src/meta_schedule/schedule_rule/multi_level_tiling.cc +++ b/src/meta_schedule/schedule_rule/multi_level_tiling.cc @@ -322,7 +322,7 @@ class MultiLevelTilingNode : public ScheduleRuleNode { /*! \brief The maximum size of the innermost factor */ int max_innermost_factor; /*! \brief The length of vector lane in vectorized cooperative fetching */ - int vector_load_max_len; + std::vector<int> vector_load_lens; /*! \brief Data reuse configuration for reading */ ReuseConfig reuse_read_; /*! \brief Data reuse configuration for writing */ @@ -337,7 +337,7 @@ class MultiLevelTilingNode : public ScheduleRuleNode { v->Visit("tile_binds", &tile_binds); v->Visit("use_tensor_core", &use_tensor_core); v->Visit("max_innermost_factor", &max_innermost_factor); - v->Visit("vector_load_max_len", &vector_load_max_len); + // `vector_load_lens` is not visited // `reuse_read_` is not visited // `reuse_write_` is not visited // `s_indices_` is not visited @@ -491,12 +491,14 @@ inline std::vector<State> MultiLevelTilingNode::AddReadReuse(State state) const LoopRV fused = sch->Fuse(Array<LoopRV>{buffer_loops.end() - buffer_ndim, // buffer_loops.end()}); // Annotate cooperative fetching - if (vector_load_max_len > 0) { - // cooperative fetch + vectorized loading - // Split into inner and outer, vectorize the inner loop - Array<ExprRV> factors = sch->SamplePerfectTile(fused, 2, vector_load_max_len); - // Add cooperative fetching - sch->Annotate(cache_read_block, tir::attr::meta_schedule_cooperative_fetch, factors[1]); + if (!vector_load_lens.empty()) { + int n = vector_load_lens.size(); + double prob = 1.0 / n; + ExprRV vector_load_len = + sch->SampleCategorical(support::AsArray<int, Integer>(vector_load_lens), + Array<FloatImm>(n, FloatImm(DataType::Float(64), prob))); + sch->Annotate(cache_read_block, tir::attr::meta_schedule_cooperative_fetch, + vector_load_len); } } State new_state = state; @@ -545,7 +547,7 @@ inline std::vector<State> MultiLevelTilingNode::FuseWriteReuse(State state) cons ScheduleRule ScheduleRule::MultiLevelTiling(String structure, Optional<Array<String>> tile_binds, bool use_tensor_core, Optional<Integer> max_innermost_factor, - Optional<Integer> vector_load_max_len, + Optional<Array<Integer>> vector_load_lens, Optional<Map<String, ObjectRef>> reuse_read, Optional<Map<String, ObjectRef>> reuse_write) { ObjectPtr<MultiLevelTilingNode> n = make_object<MultiLevelTilingNode>(); @@ -561,7 +563,9 @@ ScheduleRule ScheduleRule::MultiLevelTiling(String structure, Optional<Array<Str tir::TensorIntrin::Get("wmma_fill"); } n->max_innermost_factor = max_innermost_factor.value_or(Integer(-1))->value; - n->vector_load_max_len = vector_load_max_len.value_or(Integer(-1))->value; + n->vector_load_lens = vector_load_lens.defined() + ? support::AsVector<Integer, int>(vector_load_lens.value()) + : std::vector<int>(); n->reuse_read_ = reuse_read.defined() ? ReuseConfig(reuse_read.value()) : ReuseConfig(); n->reuse_write_ = reuse_write.defined() ? ReuseConfig(reuse_write.value()) : ReuseConfig(); for (int i = 0, len = structure.size(); i < len; ++i) { diff --git a/tests/python/unittest/test_meta_schedule_schedule_rule_multi_level_tiling.py b/tests/python/unittest/test_meta_schedule_schedule_rule_multi_level_tiling.py index dd703e49ff0e..dba661bba03c 100644 --- a/tests/python/unittest/test_meta_schedule_schedule_rule_multi_level_tiling.py +++ b/tests/python/unittest/test_meta_schedule_schedule_rule_multi_level_tiling.py @@ -190,14 +190,14 @@ def test_cuda_matmul(): "sch.compute_at(block=b34, loop=l28, preserve_unit_loops=True)", "l35, l36, l37, l38, l39, l40 = sch.get_loops(block=b34)", "l41 = sch.fuse(l39, l40)", - "v42, v43 = sch.sample_perfect_tile(loop=l41, n=2, max_innermost_factor=4)", - 'sch.annotate(block_or_loop=b34, ann_key="meta_schedule.cooperative_fetch", ann_val=v43)', - 'b44 = sch.cache_read(block=b0, read_buffer_index=2, storage_scope="shared")', - "sch.compute_at(block=b44, loop=l28, preserve_unit_loops=True)", - "l45, l46, l47, l48, l49, l50 = sch.get_loops(block=b44)", - "l51 = sch.fuse(l49, l50)", - "v52, v53 = sch.sample_perfect_tile(loop=l51, n=2, max_innermost_factor=4)", - 'sch.annotate(block_or_loop=b44, ann_key="meta_schedule.cooperative_fetch", ann_val=v53)', + "v42 = sch.sample_categorical(candidates=[1, 2, 3, 4], probs=[0.25, 0.25, 0.25, 0.25])", + 'sch.annotate(block_or_loop=b34, ann_key="meta_schedule.cooperative_fetch", ann_val=v42)', + 'b43 = sch.cache_read(block=b0, read_buffer_index=2, storage_scope="shared")', + "sch.compute_at(block=b43, loop=l28, preserve_unit_loops=True)", + "l44, l45, l46, l47, l48, l49 = sch.get_loops(block=b43)", + "l50 = sch.fuse(l48, l49)", + "v51 = sch.sample_categorical(candidates=[1, 2, 3, 4], probs=[0.25, 0.25, 0.25, 0.25])", + 'sch.annotate(block_or_loop=b43, ann_key="meta_schedule.cooperative_fetch", ann_val=v51)', "sch.reverse_compute_at(block=b1, loop=l33, preserve_unit_loops=True)", ] ] @@ -244,14 +244,14 @@ def test_cuda_matmul_relu(): "sch.compute_at(block=b34, loop=l28, preserve_unit_loops=True)", "l35, l36, l37, l38, l39, l40 = sch.get_loops(block=b34)", "l41 = sch.fuse(l39, l40)", - "v42, v43 = sch.sample_perfect_tile(loop=l41, n=2, max_innermost_factor=4)", - 'sch.annotate(block_or_loop=b34, ann_key="meta_schedule.cooperative_fetch", ann_val=v43)', - 'b44 = sch.cache_read(block=b0, read_buffer_index=2, storage_scope="shared")', - "sch.compute_at(block=b44, loop=l28, preserve_unit_loops=True)", - "l45, l46, l47, l48, l49, l50 = sch.get_loops(block=b44)", - "l51 = sch.fuse(l49, l50)", - "v52, v53 = sch.sample_perfect_tile(loop=l51, n=2, max_innermost_factor=4)", - 'sch.annotate(block_or_loop=b44, ann_key="meta_schedule.cooperative_fetch", ann_val=v53)', + "v42 = sch.sample_categorical(candidates=[1, 2, 3, 4], probs=[0.25, 0.25, 0.25, 0.25])", + 'sch.annotate(block_or_loop=b34, ann_key="meta_schedule.cooperative_fetch", ann_val=v42)', + 'b43 = sch.cache_read(block=b0, read_buffer_index=2, storage_scope="shared")', + "sch.compute_at(block=b43, loop=l28, preserve_unit_loops=True)", + "l44, l45, l46, l47, l48, l49 = sch.get_loops(block=b43)", + "l50 = sch.fuse(l48, l49)", + "v51 = sch.sample_categorical(candidates=[1, 2, 3, 4], probs=[0.25, 0.25, 0.25, 0.25])", + 'sch.annotate(block_or_loop=b43, ann_key="meta_schedule.cooperative_fetch", ann_val=v51)', "sch.reverse_compute_at(block=b1, loop=l33, preserve_unit_loops=True)", ] ] @@ -310,20 +310,20 @@ def test_cuda_tensor_core_matmul(): "sch.compute_at(block=b52, loop=l46, preserve_unit_loops=True)", "l53, l54, l55, l56, l57, l58 = sch.get_loops(block=b52)", "l59 = sch.fuse(l57, l58)", - "v60, v61 = sch.sample_perfect_tile(loop=l59, n=2, max_innermost_factor=4)", - 'sch.annotate(block_or_loop=b52, ann_key="meta_schedule.cooperative_fetch", ann_val=v61)', - 'b62 = sch.cache_read(block=b16, read_buffer_index=2, storage_scope="shared")', - "sch.compute_at(block=b62, loop=l46, preserve_unit_loops=True)", - "l63, l64, l65, l66, l67, l68 = sch.get_loops(block=b62)", - "l69 = sch.fuse(l67, l68)", - "v70, v71 = sch.sample_perfect_tile(loop=l69, n=2, max_innermost_factor=4)", - 'sch.annotate(block_or_loop=b62, ann_key="meta_schedule.cooperative_fetch", ann_val=v71)', - 'b72 = sch.cache_read(block=b16, read_buffer_index=1, storage_scope="wmma.matrix_a")', - 'b73 = sch.cache_read(block=b16, read_buffer_index=2, storage_scope="wmma.matrix_b")', - "sch.compute_at(block=b72, loop=l48, preserve_unit_loops=True)", - "sch.compute_at(block=b73, loop=l48, preserve_unit_loops=True)", - 'sch.annotate(block_or_loop=b72, ann_key="meta_schedule.auto_tensorize", ann_val="wmma_load_a")', - 'sch.annotate(block_or_loop=b73, ann_key="meta_schedule.auto_tensorize", ann_val="wmma_load_b")', + "v60 = sch.sample_categorical(candidates=[1, 2, 3, 4], probs=[0.25, 0.25, 0.25, 0.25])", + 'sch.annotate(block_or_loop=b52, ann_key="meta_schedule.cooperative_fetch", ann_val=v60)', + 'b61 = sch.cache_read(block=b16, read_buffer_index=2, storage_scope="shared")', + "sch.compute_at(block=b61, loop=l46, preserve_unit_loops=True)", + "l62, l63, l64, l65, l66, l67 = sch.get_loops(block=b61)", + "l68 = sch.fuse(l66, l67)", + "v69 = sch.sample_categorical(candidates=[1, 2, 3, 4], probs=[0.25, 0.25, 0.25, 0.25])", + 'sch.annotate(block_or_loop=b61, ann_key="meta_schedule.cooperative_fetch", ann_val=v69)', + 'b70 = sch.cache_read(block=b16, read_buffer_index=1, storage_scope="wmma.matrix_a")', + 'b71 = sch.cache_read(block=b16, read_buffer_index=2, storage_scope="wmma.matrix_b")', + "sch.compute_at(block=b70, loop=l48, preserve_unit_loops=True)", + "sch.compute_at(block=b71, loop=l48, preserve_unit_loops=True)", + 'sch.annotate(block_or_loop=b70, ann_key="meta_schedule.auto_tensorize", ann_val="wmma_load_a")', + 'sch.annotate(block_or_loop=b71, ann_key="meta_schedule.auto_tensorize", ann_val="wmma_load_b")', "sch.reverse_compute_at(block=b19, loop=l51, preserve_unit_loops=True)", "sch.reverse_compute_at(block=b18, loop=l51, preserve_unit_loops=True)", ] @@ -382,20 +382,20 @@ def test_cuda_tensor_core_matmul_relu(): "sch.compute_at(block=b52, loop=l46, preserve_unit_loops=True)", "l53, l54, l55, l56, l57, l58 = sch.get_loops(block=b52)", "l59 = sch.fuse(l57, l58)", - "v60, v61 = sch.sample_perfect_tile(loop=l59, n=2, max_innermost_factor=4)", - 'sch.annotate(block_or_loop=b52, ann_key="meta_schedule.cooperative_fetch", ann_val=v61)', - 'b62 = sch.cache_read(block=b16, read_buffer_index=2, storage_scope="shared")', - "sch.compute_at(block=b62, loop=l46, preserve_unit_loops=True)", - "l63, l64, l65, l66, l67, l68 = sch.get_loops(block=b62)", - "l69 = sch.fuse(l67, l68)", - "v70, v71 = sch.sample_perfect_tile(loop=l69, n=2, max_innermost_factor=4)", - 'sch.annotate(block_or_loop=b62, ann_key="meta_schedule.cooperative_fetch", ann_val=v71)', - 'b72 = sch.cache_read(block=b16, read_buffer_index=1, storage_scope="wmma.matrix_a")', - 'b73 = sch.cache_read(block=b16, read_buffer_index=2, storage_scope="wmma.matrix_b")', - "sch.compute_at(block=b72, loop=l48, preserve_unit_loops=True)", - "sch.compute_at(block=b73, loop=l48, preserve_unit_loops=True)", - 'sch.annotate(block_or_loop=b72, ann_key="meta_schedule.auto_tensorize", ann_val="wmma_load_a")', - 'sch.annotate(block_or_loop=b73, ann_key="meta_schedule.auto_tensorize", ann_val="wmma_load_b")', + "v60 = sch.sample_categorical(candidates=[1, 2, 3, 4], probs=[0.25, 0.25, 0.25, 0.25])", + 'sch.annotate(block_or_loop=b52, ann_key="meta_schedule.cooperative_fetch", ann_val=v60)', + 'b61 = sch.cache_read(block=b16, read_buffer_index=2, storage_scope="shared")', + "sch.compute_at(block=b61, loop=l46, preserve_unit_loops=True)", + "l62, l63, l64, l65, l66, l67 = sch.get_loops(block=b61)", + "l68 = sch.fuse(l66, l67)", + "v69 = sch.sample_categorical(candidates=[1, 2, 3, 4], probs=[0.25, 0.25, 0.25, 0.25])", + 'sch.annotate(block_or_loop=b61, ann_key="meta_schedule.cooperative_fetch", ann_val=v69)', + 'b70 = sch.cache_read(block=b16, read_buffer_index=1, storage_scope="wmma.matrix_a")', + 'b71 = sch.cache_read(block=b16, read_buffer_index=2, storage_scope="wmma.matrix_b")', + "sch.compute_at(block=b70, loop=l48, preserve_unit_loops=True)", + "sch.compute_at(block=b71, loop=l48, preserve_unit_loops=True)", + 'sch.annotate(block_or_loop=b70, ann_key="meta_schedule.auto_tensorize", ann_val="wmma_load_a")', + 'sch.annotate(block_or_loop=b71, ann_key="meta_schedule.auto_tensorize", ann_val="wmma_load_b")', "sch.reverse_compute_at(block=b19, loop=l51, preserve_unit_loops=True)", "sch.reverse_compute_at(block=b18, loop=l51, preserve_unit_loops=True)", ] diff --git a/tests/python/unittest/test_meta_schedule_sketch_cuda.py b/tests/python/unittest/test_meta_schedule_sketch_cuda.py index ff31db46351c..3255c958a575 100644 --- a/tests/python/unittest/test_meta_schedule_sketch_cuda.py +++ b/tests/python/unittest/test_meta_schedule_sketch_cuda.py @@ -56,17 +56,17 @@ def test_meta_schedule_cuda_sketch_matmul(): "sch.compute_at(block=b35, loop=l29, preserve_unit_loops=True)", "l36, l37, l38, l39, l40, l41 = sch.get_loops(block=b35)", "l42 = sch.fuse(l40, l41)", - "v43, v44 = sch.sample_perfect_tile(loop=l42, n=2, max_innermost_factor=4)", - 'sch.annotate(block_or_loop=b35, ann_key="meta_schedule.cooperative_fetch", ann_val=v44)', - 'b45 = sch.cache_read(block=b0, read_buffer_index=2, storage_scope="shared")', - "sch.compute_at(block=b45, loop=l29, preserve_unit_loops=True)", - "l46, l47, l48, l49, l50, l51 = sch.get_loops(block=b45)", - "l52 = sch.fuse(l50, l51)", - "v53, v54 = sch.sample_perfect_tile(loop=l52, n=2, max_innermost_factor=4)", - 'sch.annotate(block_or_loop=b45, ann_key="meta_schedule.cooperative_fetch", ann_val=v54)', + "v43 = sch.sample_categorical(candidates=[1, 2, 3, 4], probs=[0.25, 0.25, 0.25, 0.25])", + 'sch.annotate(block_or_loop=b35, ann_key="meta_schedule.cooperative_fetch", ann_val=v43)', + 'b44 = sch.cache_read(block=b0, read_buffer_index=2, storage_scope="shared")', + "sch.compute_at(block=b44, loop=l29, preserve_unit_loops=True)", + "l45, l46, l47, l48, l49, l50 = sch.get_loops(block=b44)", + "l51 = sch.fuse(l49, l50)", + "v52 = sch.sample_categorical(candidates=[1, 2, 3, 4], probs=[0.25, 0.25, 0.25, 0.25])", + 'sch.annotate(block_or_loop=b44, ann_key="meta_schedule.cooperative_fetch", ann_val=v52)', "sch.reverse_compute_at(block=b2, loop=l34, preserve_unit_loops=True)", - "v55 = sch.sample_categorical(candidates=[0, 16, 64, 512, 1024], probs=[0.20000000000000001, 0.20000000000000001, 0.20000000000000001, 0.20000000000000001, 0.20000000000000001])", - 'sch.annotate(block_or_loop=b1, ann_key="meta_schedule.unroll_explicit", ann_val=v55)', + "v53 = sch.sample_categorical(candidates=[0, 16, 64, 512, 1024], probs=[0.20000000000000001, 0.20000000000000001, 0.20000000000000001, 0.20000000000000001, 0.20000000000000001])", + 'sch.annotate(block_or_loop=b1, ann_key="meta_schedule.unroll_explicit", ann_val=v53)', ] ] # pylint: enable=line-too-long @@ -112,18 +112,18 @@ def test_meta_schedule_cuda_sketch_matmul_relu(): "sch.compute_at(block=b36, loop=l30, preserve_unit_loops=True)", "l37, l38, l39, l40, l41, l42 = sch.get_loops(block=b36)", "l43 = sch.fuse(l41, l42)", - "v44, v45 = sch.sample_perfect_tile(loop=l43, n=2, max_innermost_factor=4)", - 'sch.annotate(block_or_loop=b36, ann_key="meta_schedule.cooperative_fetch", ann_val=v45)', - 'b46 = sch.cache_read(block=b0, read_buffer_index=2, storage_scope="shared")', - "sch.compute_at(block=b46, loop=l30, preserve_unit_loops=True)", - "l47, l48, l49, l50, l51, l52 = sch.get_loops(block=b46)", - "l53 = sch.fuse(l51, l52)", - "v54, v55 = sch.sample_perfect_tile(loop=l53, n=2, max_innermost_factor=4)", - 'sch.annotate(block_or_loop=b46, ann_key="meta_schedule.cooperative_fetch", ann_val=v55)', + "v44 = sch.sample_categorical(candidates=[1, 2, 3, 4], probs=[0.25, 0.25, 0.25, 0.25])", + 'sch.annotate(block_or_loop=b36, ann_key="meta_schedule.cooperative_fetch", ann_val=v44)', + 'b45 = sch.cache_read(block=b0, read_buffer_index=2, storage_scope="shared")', + "sch.compute_at(block=b45, loop=l30, preserve_unit_loops=True)", + "l46, l47, l48, l49, l50, l51 = sch.get_loops(block=b45)", + "l52 = sch.fuse(l50, l51)", + "v53 = sch.sample_categorical(candidates=[1, 2, 3, 4], probs=[0.25, 0.25, 0.25, 0.25])", + 'sch.annotate(block_or_loop=b45, ann_key="meta_schedule.cooperative_fetch", ann_val=v53)', "sch.reverse_compute_at(block=b3, loop=l35, preserve_unit_loops=True)", "sch.reverse_compute_inline(block=b1)", - "v56 = sch.sample_categorical(candidates=[0, 16, 64, 512, 1024], probs=[0.20000000000000001, 0.20000000000000001, 0.20000000000000001, 0.20000000000000001, 0.20000000000000001])", - 'sch.annotate(block_or_loop=b2, ann_key="meta_schedule.unroll_explicit", ann_val=v56)', + "v54 = sch.sample_categorical(candidates=[0, 16, 64, 512, 1024], probs=[0.20000000000000001, 0.20000000000000001, 0.20000000000000001, 0.20000000000000001, 0.20000000000000001])", + 'sch.annotate(block_or_loop=b2, ann_key="meta_schedule.unroll_explicit", ann_val=v54)', ] ] # pylint: enable=line-too-long @@ -177,18 +177,18 @@ def test_meta_schedule_cuda_sketch_conv2d_nchw(): "sch.compute_at(block=b72, loop=l66, preserve_unit_loops=True)", "l73, l74, l75, l76, l77, l78, l79, l80, l81, l82 = sch.get_loops(block=b72)", "l83 = sch.fuse(l79, l80, l81, l82)", - "v84, v85 = sch.sample_perfect_tile(loop=l83, n=2, max_innermost_factor=4)", - 'sch.annotate(block_or_loop=b72, ann_key="meta_schedule.cooperative_fetch", ann_val=v85)', - 'b86 = sch.cache_read(block=b1, read_buffer_index=2, storage_scope="shared")', - "sch.compute_at(block=b86, loop=l66, preserve_unit_loops=True)", - "l87, l88, l89, l90, l91, l92, l93, l94, l95, l96 = sch.get_loops(block=b86)", - "l97 = sch.fuse(l93, l94, l95, l96)", - "v98, v99 = sch.sample_perfect_tile(loop=l97, n=2, max_innermost_factor=4)", - 'sch.annotate(block_or_loop=b86, ann_key="meta_schedule.cooperative_fetch", ann_val=v99)', + "v84 = sch.sample_categorical(candidates=[1, 2, 3, 4], probs=[0.25, 0.25, 0.25, 0.25])", + 'sch.annotate(block_or_loop=b72, ann_key="meta_schedule.cooperative_fetch", ann_val=v84)', + 'b85 = sch.cache_read(block=b1, read_buffer_index=2, storage_scope="shared")', + "sch.compute_at(block=b85, loop=l66, preserve_unit_loops=True)", + "l86, l87, l88, l89, l90, l91, l92, l93, l94, l95 = sch.get_loops(block=b85)", + "l96 = sch.fuse(l92, l93, l94, l95)", + "v97 = sch.sample_categorical(candidates=[1, 2, 3, 4], probs=[0.25, 0.25, 0.25, 0.25])", + 'sch.annotate(block_or_loop=b85, ann_key="meta_schedule.cooperative_fetch", ann_val=v97)', "sch.reverse_compute_at(block=b3, loop=l71, preserve_unit_loops=True)", "sch.compute_inline(block=b0)", - "v100 = sch.sample_categorical(candidates=[0, 16, 64, 512, 1024], probs=[0.20000000000000001, 0.20000000000000001, 0.20000000000000001, 0.20000000000000001, 0.20000000000000001])", - 'sch.annotate(block_or_loop=b2, ann_key="meta_schedule.unroll_explicit", ann_val=v100)', + "v98 = sch.sample_categorical(candidates=[0, 16, 64, 512, 1024], probs=[0.20000000000000001, 0.20000000000000001, 0.20000000000000001, 0.20000000000000001, 0.20000000000000001])", + 'sch.annotate(block_or_loop=b2, ann_key="meta_schedule.unroll_explicit", ann_val=v98)', ] ] # pylint: enable=line-too-long @@ -253,22 +253,22 @@ def test_meta_schedule_cuda_sketch_conv2d_nchw_bias_bn_relu(): # pylint: disabl "sch.compute_at(block=b76, loop=l70, preserve_unit_loops=True)", "l77, l78, l79, l80, l81, l82, l83, l84, l85, l86 = sch.get_loops(block=b76)", "l87 = sch.fuse(l83, l84, l85, l86)", - "v88, v89 = sch.sample_perfect_tile(loop=l87, n=2, max_innermost_factor=4)", - 'sch.annotate(block_or_loop=b76, ann_key="meta_schedule.cooperative_fetch", ann_val=v89)', - 'b90 = sch.cache_read(block=b1, read_buffer_index=2, storage_scope="shared")', - "sch.compute_at(block=b90, loop=l70, preserve_unit_loops=True)", - "l91, l92, l93, l94, l95, l96, l97, l98, l99, l100 = sch.get_loops(block=b90)", - "l101 = sch.fuse(l97, l98, l99, l100)", - "v102, v103 = sch.sample_perfect_tile(loop=l101, n=2, max_innermost_factor=4)", - 'sch.annotate(block_or_loop=b90, ann_key="meta_schedule.cooperative_fetch", ann_val=v103)', + "v88 = sch.sample_categorical(candidates=[1, 2, 3, 4], probs=[0.25, 0.25, 0.25, 0.25])", + 'sch.annotate(block_or_loop=b76, ann_key="meta_schedule.cooperative_fetch", ann_val=v88)', + 'b89 = sch.cache_read(block=b1, read_buffer_index=2, storage_scope="shared")', + "sch.compute_at(block=b89, loop=l70, preserve_unit_loops=True)", + "l90, l91, l92, l93, l94, l95, l96, l97, l98, l99 = sch.get_loops(block=b89)", + "l100 = sch.fuse(l96, l97, l98, l99)", + "v101 = sch.sample_categorical(candidates=[1, 2, 3, 4], probs=[0.25, 0.25, 0.25, 0.25])", + 'sch.annotate(block_or_loop=b89, ann_key="meta_schedule.cooperative_fetch", ann_val=v101)', "sch.reverse_compute_at(block=b7, loop=l75, preserve_unit_loops=True)", "sch.reverse_compute_inline(block=b5)", "sch.reverse_compute_inline(block=b4)", "sch.reverse_compute_inline(block=b3)", "sch.reverse_compute_inline(block=b2)", "sch.compute_inline(block=b0)", - "v104 = sch.sample_categorical(candidates=[0, 16, 64, 512, 1024], probs=[0.20000000000000001, 0.20000000000000001, 0.20000000000000001, 0.20000000000000001, 0.20000000000000001])", - 'sch.annotate(block_or_loop=b6, ann_key="meta_schedule.unroll_explicit", ann_val=v104)', + "v102 = sch.sample_categorical(candidates=[0, 16, 64, 512, 1024], probs=[0.20000000000000001, 0.20000000000000001, 0.20000000000000001, 0.20000000000000001, 0.20000000000000001])", + 'sch.annotate(block_or_loop=b6, ann_key="meta_schedule.unroll_explicit", ann_val=v102)', ] ] # pylint: enable=line-too-long diff --git a/tests/python/unittest/test_meta_schedule_tune_tir.py b/tests/python/unittest/test_meta_schedule_tune_tir.py index d99bfe4a86e5..6e80e5a69c11 100644 --- a/tests/python/unittest/test_meta_schedule_tune_tir.py +++ b/tests/python/unittest/test_meta_schedule_tune_tir.py @@ -124,7 +124,7 @@ def _sch_rules(): tile_binds=["blockIdx.x", "blockIdx.y", "threadIdx.y"], use_tensor_core=True, max_innermost_factor=64, - vector_load_max_len=4, + vector_load_lens=[1, 2, 3, 4], reuse_read=schedule_rule.ReuseType( req="must", levels=[4],