From 3e4a30e02947452152617f8ae997230440137c40 Mon Sep 17 00:00:00 2001
From: Junru Shao <junrushao1994@gmail.com>
Date: Sat, 22 Jan 2022 05:23:03 -0800
Subject: [PATCH] Fix cooperative fetching (#17)

---
 include/tvm/meta_schedule/schedule_rule.h     |   4 +-
 .../schedule_rule/multi_level_tiling.py       |   6 +-
 .../meta_schedule/testing/schedule_rule.py    |   6 +-
 python/tvm/meta_schedule/tune.py              |   4 +-
 src/meta_schedule/mutator/mutate_tile_size.cc | 108 ++++++++++++++----
 .../schedule_rule/multi_level_tiling.cc       |  24 ++--
 ...hedule_schedule_rule_multi_level_tiling.py |  88 +++++++-------
 .../test_meta_schedule_sketch_cuda.py         |  80 ++++++-------
 .../unittest/test_meta_schedule_tune_tir.py   |   2 +-
 9 files changed, 193 insertions(+), 129 deletions(-)

diff --git a/include/tvm/meta_schedule/schedule_rule.h b/include/tvm/meta_schedule/schedule_rule.h
index eb22178ff2bd..449c6cf7e4cf 100644
--- a/include/tvm/meta_schedule/schedule_rule.h
+++ b/include/tvm/meta_schedule/schedule_rule.h
@@ -141,7 +141,7 @@ class ScheduleRule : public runtime::ObjectRef {
    * - [blockIdx.x, vthread.x, threadIdx.x] on GPU
    * \param use_tensor_core Whether to apply tensor core wmma intrinsic for the computation
    * \param max_innermost_factor The maximum size of the innermost factor. NullOpt means no limit
-   * \param vector_load_max_len The length of vector lane in vectorized cooperative fetching.
+   * \param vector_load_lens The candidate vector lane lengths in vectorized cooperative fetching.
    * NullOpt means disable vectorization
    * \param reuse_read Data reuse configuration for reading. NullOpt means no reuse.
    * \param reuse_write Data reuse configuration for writing. NullOpt means no reuse.
@@ -151,7 +151,7 @@ class ScheduleRule : public runtime::ObjectRef {
                                                Optional<Array<String>> tile_binds,           //
                                                bool use_tensor_core,                         //
                                                Optional<Integer> max_innermost_factor,       //
-                                               Optional<Integer> vector_load_max_len,        //
+                                               Optional<Array<Integer>> vector_load_lens,    //
                                                Optional<Map<String, ObjectRef>> reuse_read,  //
                                                Optional<Map<String, ObjectRef>> reuse_write);
   /*!
diff --git a/python/tvm/meta_schedule/schedule_rule/multi_level_tiling.py b/python/tvm/meta_schedule/schedule_rule/multi_level_tiling.py
index b9eba95b869e..9e030d8a425c 100644
--- a/python/tvm/meta_schedule/schedule_rule/multi_level_tiling.py
+++ b/python/tvm/meta_schedule/schedule_rule/multi_level_tiling.py
@@ -57,7 +57,7 @@ class MultiLevelTiling(ScheduleRule):
         Whether to apply tensor core wmma intrinsic for the computation
     max_innermost_factor : Optional[int]
         The maximum size of the innermost factor. None means no limit
-    vector_load_max_len : Optional[int]
+    vector_load_lens : Optional[List[int]]
         The length of vector lane in vectorized cooperative fetching.
         None means disable vectorization
     reuse_read : Optional[ReuseType]
@@ -72,7 +72,7 @@ def __init__(
         tile_binds: Optional[List[str]] = None,
         use_tensor_core: bool = False,
         max_innermost_factor: Optional[int] = None,
-        vector_load_max_len: Optional[int] = None,
+        vector_load_lens: Optional[List[int]] = None,
         reuse_read: Optional[ReuseType] = None,
         reuse_write: Optional[ReuseType] = None,
     ) -> None:
@@ -82,7 +82,7 @@ def __init__(
             tile_binds,
             use_tensor_core,
             max_innermost_factor,
-            vector_load_max_len,
+            vector_load_lens,
             reuse_read.as_dict() if reuse_read is not None else None,
             reuse_write.as_dict() if reuse_write is not None else None,
         )
diff --git a/python/tvm/meta_schedule/testing/schedule_rule.py b/python/tvm/meta_schedule/testing/schedule_rule.py
index d62a54bebac6..83434a123a03 100644
--- a/python/tvm/meta_schedule/testing/schedule_rule.py
+++ b/python/tvm/meta_schedule/testing/schedule_rule.py
@@ -111,7 +111,7 @@ def multi_level_tiling(target: Target) -> ScheduleRule:
             structure="SSRSRS",
             tile_binds=None,
             max_innermost_factor=64,
-            vector_load_max_len=None,
+            vector_load_lens=None,
             reuse_read=None,
             reuse_write=ReuseType(
                 req="may",
@@ -124,7 +124,7 @@ def multi_level_tiling(target: Target) -> ScheduleRule:
             structure="SSSRRSRS",
             tile_binds=["blockIdx.x", "vthread.x", "threadIdx.x"],
             max_innermost_factor=64,
-            vector_load_max_len=4,
+            vector_load_lens=[1, 2, 3, 4],
             reuse_read=ReuseType(
                 req="must",
                 levels=[4],
@@ -147,7 +147,7 @@ def multi_level_tiling_tensor_core(target: Target) -> ScheduleRule:
             tile_binds=["blockIdx.x", "blockIdx.y", "threadIdx.y"],
             use_tensor_core=True,
             max_innermost_factor=64,
-            vector_load_max_len=4,
+            vector_load_lens=[1, 2, 3, 4],
             reuse_read=ReuseType(
                 req="must",
                 levels=[4],
diff --git a/python/tvm/meta_schedule/tune.py b/python/tvm/meta_schedule/tune.py
index ee9198eb1def..4f38d7cc98be 100644
--- a/python/tvm/meta_schedule/tune.py
+++ b/python/tvm/meta_schedule/tune.py
@@ -101,7 +101,7 @@ def _sch_rules() -> List[ScheduleRule]:
                 structure="SSRSRS",
                 tile_binds=None,
                 max_innermost_factor=64,
-                vector_load_max_len=None,
+                vector_load_lens=None,
                 reuse_read=None,
                 reuse_write=M.ReuseType(
                     req="may",
@@ -158,7 +158,7 @@ def _sch_rules() -> List[ScheduleRule]:
                 structure="SSSRRSRS",
                 tile_binds=["blockIdx.x", "vthread.x", "threadIdx.x"],
                 max_innermost_factor=64,
-                vector_load_max_len=4,
+                vector_load_lens=[1, 2, 3, 4],
                 reuse_read=M.ReuseType(
                     req="must",
                     levels=[4],
diff --git a/src/meta_schedule/mutator/mutate_tile_size.cc b/src/meta_schedule/mutator/mutate_tile_size.cc
index 1daf1f265e70..02c418b3c2c4 100644
--- a/src/meta_schedule/mutator/mutate_tile_size.cc
+++ b/src/meta_schedule/mutator/mutate_tile_size.cc
@@ -33,7 +33,7 @@ using tir::Trace;
  * \param decision The decision of Sample-Perfect-Tile
  * \return The result of downcast
  */
-std::vector<int64_t> DowncastDecision(const ObjectRef& decision) {
+std::vector<int64_t> DowncastTilingDecision(const ObjectRef& decision) {
   const auto* arr = TVM_TYPE_AS(arr, decision, runtime::ArrayNode);
   return support::AsVector<ObjectRef, int64_t>(GetRef<Array<ObjectRef>>(arr));
 }
@@ -73,34 +73,62 @@ class MutateTileSizeNode : public MutatorNode {
  * \param decision The decision selected
  * \return Whether a decision is found
  */
-bool FindSamplePerfectTile(const Trace& trace, TRandState* rand_state, Instruction* inst,
-                           std::vector<int64_t>* decision) {
+void FindSamplePerfectTile(const Trace& trace, std::vector<Instruction>* inst,
+                           std::vector<std::vector<int64_t>>* decision) {
   static const InstructionKind& inst_sample_perfect_tile =
       InstructionKind::Get("SamplePerfectTile");
-  std::vector<Instruction> instructions;
-  std::vector<std::vector<int64_t>> decisions;
+  std::vector<Instruction>& instructions = *inst;
+  std::vector<std::vector<int64_t>>& decisions = *decision;
   instructions.reserve(trace->decisions.size());
   decisions.reserve(trace->decisions.size());
   for (const auto& kv : trace->decisions) {
     const Instruction& inst = kv.first;
     const ObjectRef& decision = kv.second;
-    if (!inst->kind.same_as(inst_sample_perfect_tile)) {
-      continue;
+    if (inst->kind.same_as(inst_sample_perfect_tile)) {
+      std::vector<int64_t> tiles = DowncastTilingDecision(decision);
+      if (tiles.size() >= 2 && Product(tiles) >= 2) {
+        instructions.push_back(inst);
+        decisions.push_back(tiles);
+      }
     }
-    std::vector<int64_t> tiles = DowncastDecision(decision);
-    if (tiles.size() >= 2 && Product(tiles) >= 2) {
-      instructions.push_back(inst);
-      decisions.push_back(tiles);
+  }
+}
+
+void FindSampleVectorize(const Trace& trace, std::vector<Instruction>* inst,
+                         std::vector<int64_t>* decision) {
+  static const InstructionKind& inst_sample_categorical = InstructionKind::Get("SampleCategorical");
+  static const InstructionKind& inst_annotate = InstructionKind::Get("Annotate");
+  std::vector<Instruction>& instructions = *inst;
+  std::vector<int64_t>& decisions = *decision;
+  std::unordered_set<const Object*> annotated;
+  instructions.reserve(trace->decisions.size());
+  decisions.reserve(trace->decisions.size());
+  annotated.reserve(trace->decisions.size());
+  // Find annotation with `meta_schedule_cooperative_fetch`
+  for (const Instruction& inst : trace->insts) {
+    if (inst->kind.same_as(inst_annotate)) {
+      ICHECK_EQ(inst->attrs.size(), 1);
+      ICHECK_EQ(inst->inputs.size(), 2);
+      if (Downcast<String>(inst->attrs[0]) == tir::attr::meta_schedule_cooperative_fetch) {
+        const auto* ann_val = inst->inputs[1].as<tir::ExprRVNode>();
+        ICHECK(ann_val);
+        annotated.insert(ann_val);
+      }
     }
   }
-  int n = instructions.size();
-  if (n > 0) {
-    int i = tir::SampleInt(rand_state, 0, n);
-    *inst = instructions[i];
-    *decision = decisions[i];
-    return true;
+  // Find sampling instruction that generates the annotation
+  for (const auto& kv : trace->decisions) {
+    const Instruction& inst = kv.first;
+    const ObjectRef& decision = kv.second;
+    if (inst->kind.same_as(inst_sample_categorical)) {
+      ICHECK_EQ(inst->outputs.size(), 1);
+      if (annotated.count(inst->outputs[0].get())) {
+        const auto* d = TVM_TYPE_AS(d, decision, IntImmNode);
+        instructions.push_back(inst);
+        decisions.push_back(d->value);
+      }
+    }
   }
-  return false;
 }
 
 struct FactorMemo {
@@ -146,12 +174,8 @@ struct FactorMemo {
   std::mutex mutex_;
 };
 
-Optional<Trace> MutateTileSizeNode::Apply(const Trace& trace, TRandState* rand_state) {
-  Instruction inst;
-  std::vector<int64_t> tiles;
-  if (!FindSamplePerfectTile(trace, rand_state, &inst, &tiles)) {
-    return NullOpt;
-  }
+Optional<Trace> MutateSampleTileSize(const Trace& trace, Instruction inst,
+                                     std::vector<int64_t> tiles, TRandState* rand_state) {
   int n_splits = tiles.size();
   // Step 1. Choose two loops, `x` and `y`
   int x, y;
@@ -194,6 +218,69 @@ Optional<Trace> MutateTileSizeNode::Apply(const Trace& trace, TRandState* rand_s
   }
 }
 
+/*!
+ * \brief Mutate the decision of a cooperative-fetch `SampleCategorical` by re-sampling among
+ * the candidates other than the one currently chosen
+ * \param trace The trace to be mutated
+ * \param inst The `SampleCategorical` instruction whose decision is mutated
+ * \param original_decision The index of the candidate currently chosen
+ * \param rand_state The random state
+ * \return The trace with the new decision, or NullOpt if there is no alternative candidate
+ */
+Optional<Trace> MutateSampleVectorize(const Trace& trace, Instruction inst,
+                                      int64_t original_decision, TRandState* rand_state) {
+  ICHECK_EQ(inst->attrs.size(), 2);
+  std::vector<double> probs =
+      support::AsVector<FloatImm, double>(Downcast<Array<FloatImm>>(inst->attrs[1]));
+  ICHECK(0 <= original_decision && original_decision < static_cast<int64_t>(probs.size()))
+      << "The decision " << original_decision << " is out of range for " << probs.size()
+      << " candidates";
+  if (probs.size() <= 1) {
+    // The current decision is the only candidate: removing it would leave an empty
+    // distribution to sample from, so report that no mutation is possible.
+    return NullOpt;
+  }
+  // Drop the current decision so that the re-sampled one is guaranteed to differ
+  probs.erase(probs.begin() + original_decision);
+  int result = tir::MakeMultinomialSampler(rand_state, probs)();
+  // Map the index in the shrunk distribution back to an index among all candidates
+  if (result >= original_decision) {
+    result += 1;
+  }
+  return trace->WithDecision(inst, Integer(result), /*remove_postproc=*/true);
+}
+
+/*!
+ * \brief Mutate one sampling decision of the trace, picked among the `SamplePerfectTile`
+ * decisions and the cooperative-fetch `SampleCategorical` decisions
+ * \param trace The trace to be mutated
+ * \param rand_state The random state
+ * \return The mutated trace, or NullOpt if there is nothing to mutate or the mutation fails
+ */
+Optional<Trace> MutateTileSizeNode::Apply(const Trace& trace, TRandState* rand_state) {
+  std::vector<Instruction> sample_perfect_tile_insts;
+  std::vector<Instruction> sample_vectorize_insts;
+  std::vector<std::vector<int64_t>> sample_perfect_tile_tiles;
+  std::vector<int64_t> sample_vectorize_decisions;
+  FindSamplePerfectTile(trace, &sample_perfect_tile_insts, &sample_perfect_tile_tiles);
+  FindSampleVectorize(trace, &sample_vectorize_insts, &sample_vectorize_decisions);
+  int size_a = sample_perfect_tile_insts.size();
+  int size_b = sample_vectorize_insts.size();
+  if (size_a == 0 && size_b == 0) {
+    return NullOpt;
+  }
+  // Draw an index over the concatenation of both candidate lists
+  int n = tir::SampleInt(rand_state, 0, size_a + size_b);
+  if (n < size_a) {
+    return MutateSampleTileSize(trace, sample_perfect_tile_insts[n], sample_perfect_tile_tiles[n],
+                                rand_state);
+  } else {
+    n -= size_a;
+    return MutateSampleVectorize(trace, sample_vectorize_insts[n], sample_vectorize_decisions[n],
+                                 rand_state);
+  }
+}
+
 Mutator Mutator::MutateTileSize() { return Mutator(make_object<MutateTileSizeNode>()); }
 
 TVM_REGISTER_NODE_TYPE(MutateTileSizeNode);
diff --git a/src/meta_schedule/schedule_rule/multi_level_tiling.cc b/src/meta_schedule/schedule_rule/multi_level_tiling.cc
index a0ffe7e00426..a5d677c5cdf2 100644
--- a/src/meta_schedule/schedule_rule/multi_level_tiling.cc
+++ b/src/meta_schedule/schedule_rule/multi_level_tiling.cc
@@ -322,7 +322,7 @@ class MultiLevelTilingNode : public ScheduleRuleNode {
   /*! \brief The maximum size of the innermost factor */
   int max_innermost_factor;
   /*! \brief The length of vector lane in vectorized cooperative fetching */
-  int vector_load_max_len;
+  std::vector<int> vector_load_lens;
   /*! \brief Data reuse configuration for reading */
   ReuseConfig reuse_read_;
   /*! \brief Data reuse configuration for writing */
@@ -337,7 +337,7 @@ class MultiLevelTilingNode : public ScheduleRuleNode {
     v->Visit("tile_binds", &tile_binds);
     v->Visit("use_tensor_core", &use_tensor_core);
     v->Visit("max_innermost_factor", &max_innermost_factor);
-    v->Visit("vector_load_max_len", &vector_load_max_len);
+    // `vector_load_lens` is not visited
     // `reuse_read_` is not visited
     // `reuse_write_` is not visited
     // `s_indices_` is not visited
@@ -491,12 +491,14 @@ inline std::vector<State> MultiLevelTilingNode::AddReadReuse(State state) const
       LoopRV fused = sch->Fuse(Array<LoopRV>{buffer_loops.end() - buffer_ndim,  //
                                              buffer_loops.end()});
       // Annotate cooperative fetching
-      if (vector_load_max_len > 0) {
-        // cooperative fetch + vectorized loading
-        // Split into inner and outer, vectorize the inner loop
-        Array<ExprRV> factors = sch->SamplePerfectTile(fused, 2, vector_load_max_len);
-        // Add cooperative fetching
-        sch->Annotate(cache_read_block, tir::attr::meta_schedule_cooperative_fetch, factors[1]);
+      if (!vector_load_lens.empty()) {
+        int n = vector_load_lens.size();
+        double prob = 1.0 / n;
+        ExprRV vector_load_len =
+            sch->SampleCategorical(support::AsArray<int, Integer>(vector_load_lens),
+                                   Array<FloatImm>(n, FloatImm(DataType::Float(64), prob)));
+        sch->Annotate(cache_read_block, tir::attr::meta_schedule_cooperative_fetch,
+                      vector_load_len);
       }
     }
     State new_state = state;
@@ -545,7 +547,7 @@ inline std::vector<State> MultiLevelTilingNode::FuseWriteReuse(State state) cons
 ScheduleRule ScheduleRule::MultiLevelTiling(String structure, Optional<Array<String>> tile_binds,
                                             bool use_tensor_core,
                                             Optional<Integer> max_innermost_factor,
-                                            Optional<Integer> vector_load_max_len,
+                                            Optional<Array<Integer>> vector_load_lens,
                                             Optional<Map<String, ObjectRef>> reuse_read,
                                             Optional<Map<String, ObjectRef>> reuse_write) {
   ObjectPtr<MultiLevelTilingNode> n = make_object<MultiLevelTilingNode>();
@@ -561,7 +563,9 @@ ScheduleRule ScheduleRule::MultiLevelTiling(String structure, Optional<Array<Str
     tir::TensorIntrin::Get("wmma_fill");
   }
   n->max_innermost_factor = max_innermost_factor.value_or(Integer(-1))->value;
-  n->vector_load_max_len = vector_load_max_len.value_or(Integer(-1))->value;
+  n->vector_load_lens = vector_load_lens.defined()
+                            ? support::AsVector<Integer, int>(vector_load_lens.value())
+                            : std::vector<int>();
   n->reuse_read_ = reuse_read.defined() ? ReuseConfig(reuse_read.value()) : ReuseConfig();
   n->reuse_write_ = reuse_write.defined() ? ReuseConfig(reuse_write.value()) : ReuseConfig();
   for (int i = 0, len = structure.size(); i < len; ++i) {
diff --git a/tests/python/unittest/test_meta_schedule_schedule_rule_multi_level_tiling.py b/tests/python/unittest/test_meta_schedule_schedule_rule_multi_level_tiling.py
index dd703e49ff0e..dba661bba03c 100644
--- a/tests/python/unittest/test_meta_schedule_schedule_rule_multi_level_tiling.py
+++ b/tests/python/unittest/test_meta_schedule_schedule_rule_multi_level_tiling.py
@@ -190,14 +190,14 @@ def test_cuda_matmul():
             "sch.compute_at(block=b34, loop=l28, preserve_unit_loops=True)",
             "l35, l36, l37, l38, l39, l40 = sch.get_loops(block=b34)",
             "l41 = sch.fuse(l39, l40)",
-            "v42, v43 = sch.sample_perfect_tile(loop=l41, n=2, max_innermost_factor=4)",
-            'sch.annotate(block_or_loop=b34, ann_key="meta_schedule.cooperative_fetch", ann_val=v43)',
-            'b44 = sch.cache_read(block=b0, read_buffer_index=2, storage_scope="shared")',
-            "sch.compute_at(block=b44, loop=l28, preserve_unit_loops=True)",
-            "l45, l46, l47, l48, l49, l50 = sch.get_loops(block=b44)",
-            "l51 = sch.fuse(l49, l50)",
-            "v52, v53 = sch.sample_perfect_tile(loop=l51, n=2, max_innermost_factor=4)",
-            'sch.annotate(block_or_loop=b44, ann_key="meta_schedule.cooperative_fetch", ann_val=v53)',
+            "v42 = sch.sample_categorical(candidates=[1, 2, 3, 4], probs=[0.25, 0.25, 0.25, 0.25])",
+            'sch.annotate(block_or_loop=b34, ann_key="meta_schedule.cooperative_fetch", ann_val=v42)',
+            'b43 = sch.cache_read(block=b0, read_buffer_index=2, storage_scope="shared")',
+            "sch.compute_at(block=b43, loop=l28, preserve_unit_loops=True)",
+            "l44, l45, l46, l47, l48, l49 = sch.get_loops(block=b43)",
+            "l50 = sch.fuse(l48, l49)",
+            "v51 = sch.sample_categorical(candidates=[1, 2, 3, 4], probs=[0.25, 0.25, 0.25, 0.25])",
+            'sch.annotate(block_or_loop=b43, ann_key="meta_schedule.cooperative_fetch", ann_val=v51)',
             "sch.reverse_compute_at(block=b1, loop=l33, preserve_unit_loops=True)",
         ]
     ]
@@ -244,14 +244,14 @@ def test_cuda_matmul_relu():
             "sch.compute_at(block=b34, loop=l28, preserve_unit_loops=True)",
             "l35, l36, l37, l38, l39, l40 = sch.get_loops(block=b34)",
             "l41 = sch.fuse(l39, l40)",
-            "v42, v43 = sch.sample_perfect_tile(loop=l41, n=2, max_innermost_factor=4)",
-            'sch.annotate(block_or_loop=b34, ann_key="meta_schedule.cooperative_fetch", ann_val=v43)',
-            'b44 = sch.cache_read(block=b0, read_buffer_index=2, storage_scope="shared")',
-            "sch.compute_at(block=b44, loop=l28, preserve_unit_loops=True)",
-            "l45, l46, l47, l48, l49, l50 = sch.get_loops(block=b44)",
-            "l51 = sch.fuse(l49, l50)",
-            "v52, v53 = sch.sample_perfect_tile(loop=l51, n=2, max_innermost_factor=4)",
-            'sch.annotate(block_or_loop=b44, ann_key="meta_schedule.cooperative_fetch", ann_val=v53)',
+            "v42 = sch.sample_categorical(candidates=[1, 2, 3, 4], probs=[0.25, 0.25, 0.25, 0.25])",
+            'sch.annotate(block_or_loop=b34, ann_key="meta_schedule.cooperative_fetch", ann_val=v42)',
+            'b43 = sch.cache_read(block=b0, read_buffer_index=2, storage_scope="shared")',
+            "sch.compute_at(block=b43, loop=l28, preserve_unit_loops=True)",
+            "l44, l45, l46, l47, l48, l49 = sch.get_loops(block=b43)",
+            "l50 = sch.fuse(l48, l49)",
+            "v51 = sch.sample_categorical(candidates=[1, 2, 3, 4], probs=[0.25, 0.25, 0.25, 0.25])",
+            'sch.annotate(block_or_loop=b43, ann_key="meta_schedule.cooperative_fetch", ann_val=v51)',
             "sch.reverse_compute_at(block=b1, loop=l33, preserve_unit_loops=True)",
         ]
     ]
@@ -310,20 +310,20 @@ def test_cuda_tensor_core_matmul():
             "sch.compute_at(block=b52, loop=l46, preserve_unit_loops=True)",
             "l53, l54, l55, l56, l57, l58 = sch.get_loops(block=b52)",
             "l59 = sch.fuse(l57, l58)",
-            "v60, v61 = sch.sample_perfect_tile(loop=l59, n=2, max_innermost_factor=4)",
-            'sch.annotate(block_or_loop=b52, ann_key="meta_schedule.cooperative_fetch", ann_val=v61)',
-            'b62 = sch.cache_read(block=b16, read_buffer_index=2, storage_scope="shared")',
-            "sch.compute_at(block=b62, loop=l46, preserve_unit_loops=True)",
-            "l63, l64, l65, l66, l67, l68 = sch.get_loops(block=b62)",
-            "l69 = sch.fuse(l67, l68)",
-            "v70, v71 = sch.sample_perfect_tile(loop=l69, n=2, max_innermost_factor=4)",
-            'sch.annotate(block_or_loop=b62, ann_key="meta_schedule.cooperative_fetch", ann_val=v71)',
-            'b72 = sch.cache_read(block=b16, read_buffer_index=1, storage_scope="wmma.matrix_a")',
-            'b73 = sch.cache_read(block=b16, read_buffer_index=2, storage_scope="wmma.matrix_b")',
-            "sch.compute_at(block=b72, loop=l48, preserve_unit_loops=True)",
-            "sch.compute_at(block=b73, loop=l48, preserve_unit_loops=True)",
-            'sch.annotate(block_or_loop=b72, ann_key="meta_schedule.auto_tensorize", ann_val="wmma_load_a")',
-            'sch.annotate(block_or_loop=b73, ann_key="meta_schedule.auto_tensorize", ann_val="wmma_load_b")',
+            "v60 = sch.sample_categorical(candidates=[1, 2, 3, 4], probs=[0.25, 0.25, 0.25, 0.25])",
+            'sch.annotate(block_or_loop=b52, ann_key="meta_schedule.cooperative_fetch", ann_val=v60)',
+            'b61 = sch.cache_read(block=b16, read_buffer_index=2, storage_scope="shared")',
+            "sch.compute_at(block=b61, loop=l46, preserve_unit_loops=True)",
+            "l62, l63, l64, l65, l66, l67 = sch.get_loops(block=b61)",
+            "l68 = sch.fuse(l66, l67)",
+            "v69 = sch.sample_categorical(candidates=[1, 2, 3, 4], probs=[0.25, 0.25, 0.25, 0.25])",
+            'sch.annotate(block_or_loop=b61, ann_key="meta_schedule.cooperative_fetch", ann_val=v69)',
+            'b70 = sch.cache_read(block=b16, read_buffer_index=1, storage_scope="wmma.matrix_a")',
+            'b71 = sch.cache_read(block=b16, read_buffer_index=2, storage_scope="wmma.matrix_b")',
+            "sch.compute_at(block=b70, loop=l48, preserve_unit_loops=True)",
+            "sch.compute_at(block=b71, loop=l48, preserve_unit_loops=True)",
+            'sch.annotate(block_or_loop=b70, ann_key="meta_schedule.auto_tensorize", ann_val="wmma_load_a")',
+            'sch.annotate(block_or_loop=b71, ann_key="meta_schedule.auto_tensorize", ann_val="wmma_load_b")',
             "sch.reverse_compute_at(block=b19, loop=l51, preserve_unit_loops=True)",
             "sch.reverse_compute_at(block=b18, loop=l51, preserve_unit_loops=True)",
         ]
@@ -382,20 +382,20 @@ def test_cuda_tensor_core_matmul_relu():
             "sch.compute_at(block=b52, loop=l46, preserve_unit_loops=True)",
             "l53, l54, l55, l56, l57, l58 = sch.get_loops(block=b52)",
             "l59 = sch.fuse(l57, l58)",
-            "v60, v61 = sch.sample_perfect_tile(loop=l59, n=2, max_innermost_factor=4)",
-            'sch.annotate(block_or_loop=b52, ann_key="meta_schedule.cooperative_fetch", ann_val=v61)',
-            'b62 = sch.cache_read(block=b16, read_buffer_index=2, storage_scope="shared")',
-            "sch.compute_at(block=b62, loop=l46, preserve_unit_loops=True)",
-            "l63, l64, l65, l66, l67, l68 = sch.get_loops(block=b62)",
-            "l69 = sch.fuse(l67, l68)",
-            "v70, v71 = sch.sample_perfect_tile(loop=l69, n=2, max_innermost_factor=4)",
-            'sch.annotate(block_or_loop=b62, ann_key="meta_schedule.cooperative_fetch", ann_val=v71)',
-            'b72 = sch.cache_read(block=b16, read_buffer_index=1, storage_scope="wmma.matrix_a")',
-            'b73 = sch.cache_read(block=b16, read_buffer_index=2, storage_scope="wmma.matrix_b")',
-            "sch.compute_at(block=b72, loop=l48, preserve_unit_loops=True)",
-            "sch.compute_at(block=b73, loop=l48, preserve_unit_loops=True)",
-            'sch.annotate(block_or_loop=b72, ann_key="meta_schedule.auto_tensorize", ann_val="wmma_load_a")',
-            'sch.annotate(block_or_loop=b73, ann_key="meta_schedule.auto_tensorize", ann_val="wmma_load_b")',
+            "v60 = sch.sample_categorical(candidates=[1, 2, 3, 4], probs=[0.25, 0.25, 0.25, 0.25])",
+            'sch.annotate(block_or_loop=b52, ann_key="meta_schedule.cooperative_fetch", ann_val=v60)',
+            'b61 = sch.cache_read(block=b16, read_buffer_index=2, storage_scope="shared")',
+            "sch.compute_at(block=b61, loop=l46, preserve_unit_loops=True)",
+            "l62, l63, l64, l65, l66, l67 = sch.get_loops(block=b61)",
+            "l68 = sch.fuse(l66, l67)",
+            "v69 = sch.sample_categorical(candidates=[1, 2, 3, 4], probs=[0.25, 0.25, 0.25, 0.25])",
+            'sch.annotate(block_or_loop=b61, ann_key="meta_schedule.cooperative_fetch", ann_val=v69)',
+            'b70 = sch.cache_read(block=b16, read_buffer_index=1, storage_scope="wmma.matrix_a")',
+            'b71 = sch.cache_read(block=b16, read_buffer_index=2, storage_scope="wmma.matrix_b")',
+            "sch.compute_at(block=b70, loop=l48, preserve_unit_loops=True)",
+            "sch.compute_at(block=b71, loop=l48, preserve_unit_loops=True)",
+            'sch.annotate(block_or_loop=b70, ann_key="meta_schedule.auto_tensorize", ann_val="wmma_load_a")',
+            'sch.annotate(block_or_loop=b71, ann_key="meta_schedule.auto_tensorize", ann_val="wmma_load_b")',
             "sch.reverse_compute_at(block=b19, loop=l51, preserve_unit_loops=True)",
             "sch.reverse_compute_at(block=b18, loop=l51, preserve_unit_loops=True)",
         ]
diff --git a/tests/python/unittest/test_meta_schedule_sketch_cuda.py b/tests/python/unittest/test_meta_schedule_sketch_cuda.py
index ff31db46351c..3255c958a575 100644
--- a/tests/python/unittest/test_meta_schedule_sketch_cuda.py
+++ b/tests/python/unittest/test_meta_schedule_sketch_cuda.py
@@ -56,17 +56,17 @@ def test_meta_schedule_cuda_sketch_matmul():
             "sch.compute_at(block=b35, loop=l29, preserve_unit_loops=True)",
             "l36, l37, l38, l39, l40, l41 = sch.get_loops(block=b35)",
             "l42 = sch.fuse(l40, l41)",
-            "v43, v44 = sch.sample_perfect_tile(loop=l42, n=2, max_innermost_factor=4)",
-            'sch.annotate(block_or_loop=b35, ann_key="meta_schedule.cooperative_fetch", ann_val=v44)',
-            'b45 = sch.cache_read(block=b0, read_buffer_index=2, storage_scope="shared")',
-            "sch.compute_at(block=b45, loop=l29, preserve_unit_loops=True)",
-            "l46, l47, l48, l49, l50, l51 = sch.get_loops(block=b45)",
-            "l52 = sch.fuse(l50, l51)",
-            "v53, v54 = sch.sample_perfect_tile(loop=l52, n=2, max_innermost_factor=4)",
-            'sch.annotate(block_or_loop=b45, ann_key="meta_schedule.cooperative_fetch", ann_val=v54)',
+            "v43 = sch.sample_categorical(candidates=[1, 2, 3, 4], probs=[0.25, 0.25, 0.25, 0.25])",
+            'sch.annotate(block_or_loop=b35, ann_key="meta_schedule.cooperative_fetch", ann_val=v43)',
+            'b44 = sch.cache_read(block=b0, read_buffer_index=2, storage_scope="shared")',
+            "sch.compute_at(block=b44, loop=l29, preserve_unit_loops=True)",
+            "l45, l46, l47, l48, l49, l50 = sch.get_loops(block=b44)",
+            "l51 = sch.fuse(l49, l50)",
+            "v52 = sch.sample_categorical(candidates=[1, 2, 3, 4], probs=[0.25, 0.25, 0.25, 0.25])",
+            'sch.annotate(block_or_loop=b44, ann_key="meta_schedule.cooperative_fetch", ann_val=v52)',
             "sch.reverse_compute_at(block=b2, loop=l34, preserve_unit_loops=True)",
-            "v55 = sch.sample_categorical(candidates=[0, 16, 64, 512, 1024], probs=[0.20000000000000001, 0.20000000000000001, 0.20000000000000001, 0.20000000000000001, 0.20000000000000001])",
-            'sch.annotate(block_or_loop=b1, ann_key="meta_schedule.unroll_explicit", ann_val=v55)',
+            "v53 = sch.sample_categorical(candidates=[0, 16, 64, 512, 1024], probs=[0.20000000000000001, 0.20000000000000001, 0.20000000000000001, 0.20000000000000001, 0.20000000000000001])",
+            'sch.annotate(block_or_loop=b1, ann_key="meta_schedule.unroll_explicit", ann_val=v53)',
         ]
     ]
     # pylint: enable=line-too-long
@@ -112,18 +112,18 @@ def test_meta_schedule_cuda_sketch_matmul_relu():
             "sch.compute_at(block=b36, loop=l30, preserve_unit_loops=True)",
             "l37, l38, l39, l40, l41, l42 = sch.get_loops(block=b36)",
             "l43 = sch.fuse(l41, l42)",
-            "v44, v45 = sch.sample_perfect_tile(loop=l43, n=2, max_innermost_factor=4)",
-            'sch.annotate(block_or_loop=b36, ann_key="meta_schedule.cooperative_fetch", ann_val=v45)',
-            'b46 = sch.cache_read(block=b0, read_buffer_index=2, storage_scope="shared")',
-            "sch.compute_at(block=b46, loop=l30, preserve_unit_loops=True)",
-            "l47, l48, l49, l50, l51, l52 = sch.get_loops(block=b46)",
-            "l53 = sch.fuse(l51, l52)",
-            "v54, v55 = sch.sample_perfect_tile(loop=l53, n=2, max_innermost_factor=4)",
-            'sch.annotate(block_or_loop=b46, ann_key="meta_schedule.cooperative_fetch", ann_val=v55)',
+            "v44 = sch.sample_categorical(candidates=[1, 2, 3, 4], probs=[0.25, 0.25, 0.25, 0.25])",
+            'sch.annotate(block_or_loop=b36, ann_key="meta_schedule.cooperative_fetch", ann_val=v44)',
+            'b45 = sch.cache_read(block=b0, read_buffer_index=2, storage_scope="shared")',
+            "sch.compute_at(block=b45, loop=l30, preserve_unit_loops=True)",
+            "l46, l47, l48, l49, l50, l51 = sch.get_loops(block=b45)",
+            "l52 = sch.fuse(l50, l51)",
+            "v53 = sch.sample_categorical(candidates=[1, 2, 3, 4], probs=[0.25, 0.25, 0.25, 0.25])",
+            'sch.annotate(block_or_loop=b45, ann_key="meta_schedule.cooperative_fetch", ann_val=v53)',
             "sch.reverse_compute_at(block=b3, loop=l35, preserve_unit_loops=True)",
             "sch.reverse_compute_inline(block=b1)",
-            "v56 = sch.sample_categorical(candidates=[0, 16, 64, 512, 1024], probs=[0.20000000000000001, 0.20000000000000001, 0.20000000000000001, 0.20000000000000001, 0.20000000000000001])",
-            'sch.annotate(block_or_loop=b2, ann_key="meta_schedule.unroll_explicit", ann_val=v56)',
+            "v54 = sch.sample_categorical(candidates=[0, 16, 64, 512, 1024], probs=[0.20000000000000001, 0.20000000000000001, 0.20000000000000001, 0.20000000000000001, 0.20000000000000001])",
+            'sch.annotate(block_or_loop=b2, ann_key="meta_schedule.unroll_explicit", ann_val=v54)',
         ]
     ]
     # pylint: enable=line-too-long
@@ -177,18 +177,18 @@ def test_meta_schedule_cuda_sketch_conv2d_nchw():
             "sch.compute_at(block=b72, loop=l66, preserve_unit_loops=True)",
             "l73, l74, l75, l76, l77, l78, l79, l80, l81, l82 = sch.get_loops(block=b72)",
             "l83 = sch.fuse(l79, l80, l81, l82)",
-            "v84, v85 = sch.sample_perfect_tile(loop=l83, n=2, max_innermost_factor=4)",
-            'sch.annotate(block_or_loop=b72, ann_key="meta_schedule.cooperative_fetch", ann_val=v85)',
-            'b86 = sch.cache_read(block=b1, read_buffer_index=2, storage_scope="shared")',
-            "sch.compute_at(block=b86, loop=l66, preserve_unit_loops=True)",
-            "l87, l88, l89, l90, l91, l92, l93, l94, l95, l96 = sch.get_loops(block=b86)",
-            "l97 = sch.fuse(l93, l94, l95, l96)",
-            "v98, v99 = sch.sample_perfect_tile(loop=l97, n=2, max_innermost_factor=4)",
-            'sch.annotate(block_or_loop=b86, ann_key="meta_schedule.cooperative_fetch", ann_val=v99)',
+            "v84 = sch.sample_categorical(candidates=[1, 2, 3, 4], probs=[0.25, 0.25, 0.25, 0.25])",
+            'sch.annotate(block_or_loop=b72, ann_key="meta_schedule.cooperative_fetch", ann_val=v84)',
+            'b85 = sch.cache_read(block=b1, read_buffer_index=2, storage_scope="shared")',
+            "sch.compute_at(block=b85, loop=l66, preserve_unit_loops=True)",
+            "l86, l87, l88, l89, l90, l91, l92, l93, l94, l95 = sch.get_loops(block=b85)",
+            "l96 = sch.fuse(l92, l93, l94, l95)",
+            "v97 = sch.sample_categorical(candidates=[1, 2, 3, 4], probs=[0.25, 0.25, 0.25, 0.25])",
+            'sch.annotate(block_or_loop=b85, ann_key="meta_schedule.cooperative_fetch", ann_val=v97)',
             "sch.reverse_compute_at(block=b3, loop=l71, preserve_unit_loops=True)",
             "sch.compute_inline(block=b0)",
-            "v100 = sch.sample_categorical(candidates=[0, 16, 64, 512, 1024], probs=[0.20000000000000001, 0.20000000000000001, 0.20000000000000001, 0.20000000000000001, 0.20000000000000001])",
-            'sch.annotate(block_or_loop=b2, ann_key="meta_schedule.unroll_explicit", ann_val=v100)',
+            "v98 = sch.sample_categorical(candidates=[0, 16, 64, 512, 1024], probs=[0.20000000000000001, 0.20000000000000001, 0.20000000000000001, 0.20000000000000001, 0.20000000000000001])",
+            'sch.annotate(block_or_loop=b2, ann_key="meta_schedule.unroll_explicit", ann_val=v98)',
         ]
     ]
     # pylint: enable=line-too-long
@@ -253,22 +253,22 @@ def test_meta_schedule_cuda_sketch_conv2d_nchw_bias_bn_relu():  # pylint: disabl
             "sch.compute_at(block=b76, loop=l70, preserve_unit_loops=True)",
             "l77, l78, l79, l80, l81, l82, l83, l84, l85, l86 = sch.get_loops(block=b76)",
             "l87 = sch.fuse(l83, l84, l85, l86)",
-            "v88, v89 = sch.sample_perfect_tile(loop=l87, n=2, max_innermost_factor=4)",
-            'sch.annotate(block_or_loop=b76, ann_key="meta_schedule.cooperative_fetch", ann_val=v89)',
-            'b90 = sch.cache_read(block=b1, read_buffer_index=2, storage_scope="shared")',
-            "sch.compute_at(block=b90, loop=l70, preserve_unit_loops=True)",
-            "l91, l92, l93, l94, l95, l96, l97, l98, l99, l100 = sch.get_loops(block=b90)",
-            "l101 = sch.fuse(l97, l98, l99, l100)",
-            "v102, v103 = sch.sample_perfect_tile(loop=l101, n=2, max_innermost_factor=4)",
-            'sch.annotate(block_or_loop=b90, ann_key="meta_schedule.cooperative_fetch", ann_val=v103)',
+            "v88 = sch.sample_categorical(candidates=[1, 2, 3, 4], probs=[0.25, 0.25, 0.25, 0.25])",
+            'sch.annotate(block_or_loop=b76, ann_key="meta_schedule.cooperative_fetch", ann_val=v88)',
+            'b89 = sch.cache_read(block=b1, read_buffer_index=2, storage_scope="shared")',
+            "sch.compute_at(block=b89, loop=l70, preserve_unit_loops=True)",
+            "l90, l91, l92, l93, l94, l95, l96, l97, l98, l99 = sch.get_loops(block=b89)",
+            "l100 = sch.fuse(l96, l97, l98, l99)",
+            "v101 = sch.sample_categorical(candidates=[1, 2, 3, 4], probs=[0.25, 0.25, 0.25, 0.25])",
+            'sch.annotate(block_or_loop=b89, ann_key="meta_schedule.cooperative_fetch", ann_val=v101)',
             "sch.reverse_compute_at(block=b7, loop=l75, preserve_unit_loops=True)",
             "sch.reverse_compute_inline(block=b5)",
             "sch.reverse_compute_inline(block=b4)",
             "sch.reverse_compute_inline(block=b3)",
             "sch.reverse_compute_inline(block=b2)",
             "sch.compute_inline(block=b0)",
-            "v104 = sch.sample_categorical(candidates=[0, 16, 64, 512, 1024], probs=[0.20000000000000001, 0.20000000000000001, 0.20000000000000001, 0.20000000000000001, 0.20000000000000001])",
-            'sch.annotate(block_or_loop=b6, ann_key="meta_schedule.unroll_explicit", ann_val=v104)',
+            "v102 = sch.sample_categorical(candidates=[0, 16, 64, 512, 1024], probs=[0.20000000000000001, 0.20000000000000001, 0.20000000000000001, 0.20000000000000001, 0.20000000000000001])",
+            'sch.annotate(block_or_loop=b6, ann_key="meta_schedule.unroll_explicit", ann_val=v102)',
         ]
     ]
     # pylint: enable=line-too-long
diff --git a/tests/python/unittest/test_meta_schedule_tune_tir.py b/tests/python/unittest/test_meta_schedule_tune_tir.py
index d99bfe4a86e5..6e80e5a69c11 100644
--- a/tests/python/unittest/test_meta_schedule_tune_tir.py
+++ b/tests/python/unittest/test_meta_schedule_tune_tir.py
@@ -124,7 +124,7 @@ def _sch_rules():
                     tile_binds=["blockIdx.x", "blockIdx.y", "threadIdx.y"],
                     use_tensor_core=True,
                     max_innermost_factor=64,
-                    vector_load_max_len=4,
+                    vector_load_lens=[1, 2, 3, 4],
                     reuse_read=schedule_rule.ReuseType(
                         req="must",
                         levels=[4],