From aeb8c7b2700712e3a62edb40530ea0230120a045 Mon Sep 17 00:00:00 2001
From: "chengfan.jcf" <chengfan.jcf@alibaba-inc.com>
Date: Wed, 22 Jul 2020 15:18:20 +0800
Subject: [PATCH] Add pragma/storage_align/rfactor step
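
Add three new transform steps to the auto_scheduler, together with their
State APIs, record (de)serialization and Python bindings:

* PragmaStep ("PR" record) for te::Stage::pragma
* StorageAlignStep ("SA" record) for te::Stage::storage_align
* RfactorStep ("RF" record) for te::Schedule::rfactor

RfactorStep adds a new stage to the ComputeDAG, so
GetFormerStageModifiableSteps is extended to also replay the fuse/split
steps an rfactor depends on. Unit tests cover the new loop state
primitives and their record serialization.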

---
 include/tvm/auto_scheduler/loop_state.h       |  31 +-
 include/tvm/auto_scheduler/transform_step.h   | 195 +++++++++-
 python/tvm/auto_scheduler/loop_state.py       |  68 ++++
 src/auto_scheduler/loop_state.cc              |  40 ++
 src/auto_scheduler/transform_step.cc          | 350 +++++++++++++++++-
 src/auto_scheduler/utils.h                    |   6 +
 .../test_auto_scheduler_loop_state.py         |  54 +++
 .../unittest/test_auto_scheduler_measure.py   |  11 +-
 8 files changed, 747 insertions(+), 8 deletions(-)

diff --git a/include/tvm/auto_scheduler/loop_state.h b/include/tvm/auto_scheduler/loop_state.h
index 1c8ea770e2f84..059184b721585 100644
--- a/include/tvm/auto_scheduler/loop_state.h
+++ b/include/tvm/auto_scheduler/loop_state.h
@@ -340,6 +340,13 @@ class State : public ObjectRef {
    * result will become the new attach point.
    */
   TVM_DLL Iterator fuse(int stage_id, const Array<Iterator>& iters);
+  /*!
+   * \brief Schedule primitive corresponds to `te::Stage::pragma`.
+   * \param stage_id The index of the stage to add the pragma to.
+   * \param it The iterator to add the pragma to.
+   * \param pragma_type The pragma string.
+   */
+  TVM_DLL void pragma(int stage_id, const Iterator& it, const String& pragma_type);
   /*!
    * \brief Schedule primitive corresponds to `te::Stage::reorder`.
    * \param stage_id The index of the stage to be reordered.
@@ -359,6 +366,14 @@ class State : public ObjectRef {
   TVM_DLL Array<Iterator> split(int stage_id, const Iterator& it,
                                 const Array<Optional<Integer>>& lengths,
                                 bool inner_to_outer = true);
+  /*!
+   * \brief Schedule primitive corresponds to `te::Stage::storage_align`.
+   * \param stage_id The index of the stage to be aligned.
+   * \param it The iterator to be aligned.
+   * \param factor The factor in the alignment specification.
+   * \param offset The offset in the alignment specification.
+   */
+  TVM_DLL void storage_align(int stage_id, const Iterator& it, int factor, int offset);
 
   /********** Step APIs working on multiple stages **********/
 
@@ -399,8 +414,8 @@ class State : public ObjectRef {
    * \note Cache read step will add an extra stage to the original ComputeDAG (at the back of the
    * target stage), a up-to-date ComputeDAG is stored in State's `current_compute_dag`.
    */
-  int cache_read(int stage_id, const String& scope_name, const Array<Integer>& reader_stage_ids,
-                 const ComputeDAG& dag);
+  TVM_DLL int cache_read(int stage_id, const String& scope_name,
+                         const Array<Integer>& reader_stage_ids, const ComputeDAG& dag);
   /*!
    * \brief Schedule primitive corresponds to `te::Schedule::cache_write`.
    * \param stage_id The index of the stage to be cache write.
@@ -410,7 +425,17 @@ class State : public ObjectRef {
    * target stage), a up-to-date ComputeDAG is stored in State's `current_compute_dag`.
    * This step will cache write all output tensors of the target stage.
    */
-  int cache_write(int stage_id, const String& scope_name, const ComputeDAG& dag);
+  TVM_DLL int cache_write(int stage_id, const String& scope_name, const ComputeDAG& dag);
+  /*!
+   * \brief Schedule primitive corresponds to `te::Schedule::rfactor`.
+   * \param stage_id The index of the stage to be factored.
+   * \param it The reduction iterator to be factored.
+   * \param factor_iter_id The position where the new iterator is placed.
+   * \param dag The original ComputeDAG of this state.
+   * \note The rfactor step will add an extra stage to the original ComputeDAG; an up-to-date
+   * ComputeDAG is stored in State's `current_compute_dag`.
+   */
+  TVM_DLL int rfactor(int stage_id, const Iterator& it, int factor_iter_id, const ComputeDAG& dag);
 
   TVM_DEFINE_OBJECT_REF_METHODS(State, ObjectRef, StateNode);
   TVM_DEFINE_OBJECT_REF_COW_METHOD(StateNode);
diff --git a/include/tvm/auto_scheduler/transform_step.h b/include/tvm/auto_scheduler/transform_step.h
index 83d6e298a7d7d..57dc6ac6dcf77 100644
--- a/include/tvm/auto_scheduler/transform_step.h
+++ b/include/tvm/auto_scheduler/transform_step.h
@@ -347,6 +347,67 @@ class FuseStep : public Step {
   TVM_DEFINE_OBJECT_REF_METHODS(FuseStep, Step, FuseStepNode);
 };
 
+/*! \brief Pragma step that corresponds to te::Stage::pragma */
+class PragmaStepNode : public StepNode {
+ public:
+  /*! \brief The index of the iterator to add the pragma to. */
+  int iter_id;
+  /*! \brief The pragma string. */
+  String pragma_type;
+
+  void WriteToRecord(dmlc::JSONWriter* writer) const final;
+
+  /*!
+   * \brief Apply the current step to State.
+   * \param state A mutable pointer to State.
+   */
+  void ApplyToState(State* state) const;
+
+  /*!
+   * \brief Apply the current step to tvm.schedule.
+   * \param stages A pointer to a `te::Stage` Array.
+   * \param stage_to_axes A pointer to a StageToAxesMap.
+   */
+  void ApplyToSchedule(Array<te::Stage>* stages, StageToAxesMap* stage_to_axes) const;
+
+  /*!
+   * \brief Print the current step as equivalent python schedule API.
+   * \param stages A pointer to a `te::Stage` Array.
+   * \param stage_to_axes A pointer to a StageToAxesMap.
+   * \return Python schedule code.
+   */
+  String PrintAsPythonAPI(Array<te::Stage>* stages, StageToAxesMap* stage_to_axes) const;
+
+  static constexpr const char* record_prefix_str = "PR";
+
+  static constexpr const char* _type_key = "auto_scheduler.PragmaStep";
+  TVM_DECLARE_FINAL_OBJECT_INFO(PragmaStepNode, Object);
+};
+
+/*!
+ * \brief Managed reference to PragmaStepNode.
+ * \sa PragmaStepNode
+ */
+class PragmaStep : public Step {
+ public:
+  /*!
+   * \brief The constructor.
+   * \param stage_id The index of the stage to add the pragma to.
+   * \param iter_id The index of the iterator to add the pragma to.
+   * \param pragma_type The pragma string.
+   */
+  PragmaStep(int stage_id, int iter_id, String pragma_type);
+
+  /*!
+   * \brief The constructor used to read a step record from JSONReader and create the
+   * corresponding step.
+   * \param reader The input JSONReader.
+   */
+  explicit PragmaStep(dmlc::JSONReader* reader);
+
+  TVM_DEFINE_OBJECT_REF_METHODS(PragmaStep, Step, PragmaStepNode);
+};
+
 /*! \brief Reorder step that corresponds to te::Stage::reorder */
 class ReorderStepNode : public StepNode {
  public:
@@ -487,6 +548,70 @@ class SplitStep : public Step {
   TVM_DEFINE_OBJECT_REF_METHODS(SplitStep, Step, SplitStepNode);
 };
 
+/*! \brief Storage align step that corresponds to te::Stage::storage_align */
+class StorageAlignStepNode : public StepNode {
+ public:
+  /*! \brief The index of the iterator to be aligned. */
+  int iter_id;
+  /*! \brief The factor in the alignment specification. */
+  int factor;
+  /*! \brief The offset in the alignment specification. */
+  int offset;
+
+  void WriteToRecord(dmlc::JSONWriter* writer) const final;
+
+  /*!
+   * \brief Apply the current step to State.
+   * \param state A mutable pointer to State.
+   */
+  void ApplyToState(State* state) const;
+
+  /*!
+   * \brief Apply the current step to tvm.schedule.
+   * \param stages A pointer to a `te::Stage` Array.
+   * \param stage_to_axes A pointer to a StageToAxesMap.
+   */
+  void ApplyToSchedule(Array<te::Stage>* stages, StageToAxesMap* stage_to_axes) const;
+
+  /*!
+   * \brief Print the current step as equivalent python schedule API.
+   * \param stages A pointer to a `te::Stage` Array.
+   * \param stage_to_axes A pointer to a StageToAxesMap.
+   * \return Python schedule code.
+   */
+  String PrintAsPythonAPI(Array<te::Stage>* stages, StageToAxesMap* stage_to_axes) const;
+
+  static constexpr const char* record_prefix_str = "SA";
+
+  static constexpr const char* _type_key = "auto_scheduler.StorageAlignStep";
+  TVM_DECLARE_FINAL_OBJECT_INFO(StorageAlignStepNode, Object);
+};
+
+/*!
+ * \brief Managed reference to StorageAlignStepNode.
+ * \sa StorageAlignStepNode
+ */
+class StorageAlignStep : public Step {
+ public:
+  /*!
+   * \brief The constructor.
+   * \param stage_id The index of the stage to be aligned.
+   * \param iter_id The index of the iterator to be aligned.
+   * \param factor The factor in the alignment specification.
+   * \param offset The offset in the alignment specification.
+   */
+  StorageAlignStep(int stage_id, int iter_id, int factor, int offset);
+
+  /*!
+   * \brief The constructor used to read a step record from JSONReader and create the
+   * corresponding step.
+   * \param reader The input JSONReader.
+   */
+  explicit StorageAlignStep(dmlc::JSONReader* reader);
+
+  TVM_DEFINE_OBJECT_REF_METHODS(StorageAlignStep, Step, StorageAlignStepNode);
+};
+
 /********** Steps working on multiple stages **********/
 
 /*! \brief Compute at step that corresponds to te::Stage::compute_at */
@@ -668,7 +793,7 @@ class ComputeRootStep : public Step {
   TVM_DEFINE_OBJECT_REF_METHODS(ComputeRootStep, Step, ComputeRootStepNode);
 };
 
-/********** Primitives adding new stages **********/
+/********** Steps adding new stages **********/
 
 /*!
  * \brief Cache read step that corresponds to te::Schedule::cache_read.
@@ -812,6 +937,74 @@ class CacheWriteStep : public Step {
   TVM_DEFINE_OBJECT_REF_METHODS(CacheWriteStep, Step, CacheWriteStepNode);
 };
 
+/*! \brief Reduction factor step that corresponds to te::Schedule::rfactor */
+class RfactorStepNode : public StepNode {
+ public:
+  /*! \brief The index of the iterator to be factored. */
+  int iter_id;
+  /*! \brief The position where the new iterator is placed. */
+  int factor_iter_id;
+
+  void WriteToRecord(dmlc::JSONWriter* writer) const final;
+
+  /*!
+   * \brief Apply the current step to State.
+   * \param state A mutable pointer to State.
+   * \param dag The original ComputeDAG of this state.
+   * \return The index of the newly added stage.
+   */
+  int ApplyToState(State* state, const ComputeDAG& dag) const;
+
+  /*!
+   * \brief Apply the current step to tvm.schedule.
+   * \param stages A mutable pointer to a `te::Stage` Array.
+   * \param stage_to_axes A mutable pointer to a StageToAxesMap.
+   * \param schedule A mutable pointer to a te::Schedule.
+   * \return The output Tensors of the newly added stage.
+   */
+  Array<te::Tensor> ApplyToSchedule(Array<te::Stage>* stages, StageToAxesMap* stage_to_axes,
+                                    te::Schedule* schedule) const;
+
+  /*!
+   * \brief Print the current step as equivalent python schedule API.
+   * \param stages A mutable pointer to a `te::Stage` Array.
+   * \param stage_to_axes A mutable pointer to a StageToAxesMap.
+   * \param schedule A mutable pointer to a te::Schedule.
+   * \return Python schedule code.
+   */
+  String PrintAsPythonAPI(Array<te::Stage>* stages, StageToAxesMap* stage_to_axes,
+                          te::Schedule* schedule) const;
+
+  static constexpr const char* record_prefix_str = "RF";
+
+  static constexpr const char* _type_key = "auto_scheduler.RfactorStep";
+  TVM_DECLARE_FINAL_OBJECT_INFO(RfactorStepNode, Object);
+};
+
+/*!
+ * \brief Managed reference to RfactorStepNode.
+ * \sa RfactorStepNode
+ */
+class RfactorStep : public Step {
+ public:
+  /*!
+   * \brief The constructor.
+   * \param stage_id The index of the stage to be factored.
+   * \param iter_id The index of the iterator to be factored.
+   * \param factor_iter_id The position where the new iterator is placed.
+   */
+  RfactorStep(int stage_id, int iter_id, int factor_iter_id);
+
+  /*!
+   * \brief The constructor used to read a step record from JSONReader and create the
+   * corresponding step.
+   * \param reader The input JSONReader.
+   */
+  explicit RfactorStep(dmlc::JSONReader* reader);
+
+  TVM_DEFINE_OBJECT_REF_METHODS(RfactorStep, Step, RfactorStepNode);
+};
+
 }  // namespace auto_scheduler
 }  // namespace tvm
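
Each step node above declares a record_prefix_str ("PR", "SA", "RF") that tags
its entry in serialized measure records; StepReadFromRecord dispatches on this
prefix when reading a record back. Judging from the WriteToRecord
implementations further below, the step entries are JSON arrays of the
following shape (the values here are illustrative only):

    ["PR", 2, 0, "auto_unroll_max_step$16"]  # PragmaStep: stage_id, iter_id, pragma_type
    ["SA", 2, 2, 8, 4]                       # StorageAlignStep: stage_id, iter_id, factor, offset
    ["RF", 2, 2, 2]                          # RfactorStep: stage_id, iter_id, factor_iter_id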
 
diff --git a/python/tvm/auto_scheduler/loop_state.py b/python/tvm/auto_scheduler/loop_state.py
index 8c3a936ccf0cc..c5f512b721f25 100644
--- a/python/tvm/auto_scheduler/loop_state.py
+++ b/python/tvm/auto_scheduler/loop_state.py
@@ -252,6 +252,22 @@ def fuse(self, stage, iters):
                                                     self._resolve_stage_id(stage), iters)
         return res
 
+    def pragma(self, stage, iterator, pragma_type):
+        """ Schedule primitive corresponds to te.pragma.
+
+        Parameters
+        ----------
+        stage : Union[int, Operation, Tensor]
+            The Stage to add the pragma to, which can be specified by the integer index,
+            Operation, or output tensor of the stage.
+        iterator : Iterator
+            The iterator to add the pragma to.
+        pragma_type : str
+            The pragma string.
+        """
+        self.state_object = _ffi_api.StatePragma(self.state_object, self._resolve_stage_id(stage),
+                                                 iterator, pragma_type)
+
     def reorder(self, stage, order):
         """ Schedule primitive corresponds to `te.Stage.reorder`, see also the `te.Stage` for more
         details.
@@ -301,6 +317,27 @@ def split(self, stage, iterator, lengths, inner_to_outer=True):
                                                      iterator, lengths, inner_to_outer)
         return res
 
+    def storage_align(self, stage, iterator, factor, offset):
+        """ Schedule primitive corresponds to te.storage_align.
+
+        See `te.schedule.Stage.storage_align` for more information.
+
+        Parameters
+        ----------
+        stage : Union[int, Operation, Tensor]
+            The Stage to be storage aligned, which can be specified by the integer index,
+            Operation, or output tensor of the stage.
+        iterator : Iterator
+            The iterator to be aligned.
+        factor : int
+            The factor in the alignment specification.
+        offset : int
+            The offset in the alignment specification.
+        """
+        self.state_object = _ffi_api.StateStorageAlign(self.state_object,
+                                                       self._resolve_stage_id(stage), iterator,
+                                                       factor, offset)
+
     def compute_at(self, stage, target_stage, target_iter):
         """ Schedule primitive corresponds to `te.Stage.compute_at`, see also the `te.Stage` for
         more details.
@@ -429,6 +466,37 @@ def cache_write(self, stage, scope_name):
         self._update_stage_id_map()
         return self.stages[int(new_stage_id)].op
 
+    def rfactor(self, stage, iterator, factor_iter_id):
+        """ Schedule primitive corresponds to te.schedule.rfactor.
+
+        See `te.schedule.Schedule.rfactor` for more information.
+
+        Parameters
+        ----------
+        stage : Union[int, Operation, Tensor]
+            The Stage to be factored, which can be specified by the integer index, Operation,
+            or output tensor of the stage.
+        iterator : Iterator
+            The reduction iterator to be factored.
+        factor_iter_id : int
+            The position where the new iterator is placed.
+
+        Returns
+        -------
+        new_stage_op : Operation
+            The Operation of the newly added stage.
+
+        Notes
+        -----
+        The rfactor step will insert an extra stage into the original ComputeDAG (in front
+        of the target stage).
+        """
+        self.state_object, new_stage_id = _ffi_api.StateRfactor(self.state_object,
+                                                                self._resolve_stage_id(stage),
+                                                                iterator, factor_iter_id,
+                                                                self.compute_dag)
+        self._update_stage_id_map()
+        return self.stages[int(new_stage_id)].op
+
     def copy(self):
         """ Do deep copy of this State. """
         state = State(self.state_object, self.compute_dag)
diff --git a/src/auto_scheduler/loop_state.cc b/src/auto_scheduler/loop_state.cc
index 67c6b38845c32..481ca0f762419 100644
--- a/src/auto_scheduler/loop_state.cc
+++ b/src/auto_scheduler/loop_state.cc
@@ -247,6 +247,13 @@ Iterator State::fuse(int stage_id, const Array<Iterator>& iters) {
   return step->ApplyToState(this);
 }
 
+void State::pragma(int stage_id, const Iterator& it, const String& pragma_type) {
+  const Stage& stage = operator->()->stages[stage_id];
+  PragmaStep step = PragmaStep(stage_id, GetIndex(stage->iters, it), pragma_type);
+  CopyOnWrite()->transform_steps.push_back(step);
+  return step->ApplyToState(this);
+}
+
 void State::reorder(int stage_id, const Array<Iterator>& order) {
   const Stage& stage = operator->()->stages[stage_id];
   CHECK_EQ(order.size(), stage->iters.size()) << "The order of all iterators "
@@ -268,6 +275,13 @@ Array<Iterator> State::split(int stage_id, const Iterator& it,
   return step->ApplyToState(this);
 }
 
+void State::storage_align(int stage_id, const Iterator& it, int factor, int offset) {
+  const Stage& stage = operator->()->stages[stage_id];
+  StorageAlignStep step = StorageAlignStep(stage_id, GetIndex(stage->iters, it), factor, offset);
+  CopyOnWrite()->transform_steps.push_back(step);
+  return step->ApplyToState(this);
+}
+
 void State::compute_at(int stage_id, int target_stage_id, const Iterator& target_iter) {
   const Stage& target_stage = operator->()->stages[target_stage_id];
   ComputeAtStep step =
@@ -301,6 +315,13 @@ int State::cache_write(int stage_id, const String& scope_name, const ComputeDAG&
   return step->ApplyToState(this, dag);
 }
 
+int State::rfactor(int stage_id, const Iterator& it, int factor_iter_id, const ComputeDAG& dag) {
+  const Stage& stage = operator->()->stages[stage_id];
+  RfactorStep step = RfactorStep(stage_id, GetIndex(stage->iters, it), factor_iter_id);
+  CopyOnWrite()->transform_steps.push_back(step);
+  return step->ApplyToState(this, dag);
+}
+
 void State::ApplySteps(const ComputeDAG& dag) {
   CHECK(operator->()->stages.size()) << "Invalid State with empty operation stages.";
 
@@ -441,6 +462,12 @@ TVM_REGISTER_GLOBAL("auto_scheduler.StateFuse")
       return Array<ObjectRef>{state, res};
     });
 
+TVM_REGISTER_GLOBAL("auto_scheduler.StatePragma")
+    .set_body_typed([](State state, int stage_id, const Iterator& it, const String& pragma_type) {
+      state.pragma(stage_id, it, pragma_type);
+      return state;
+    });
+
 TVM_REGISTER_GLOBAL("auto_scheduler.StateReorder")
     .set_body_typed([](State state, int stage_id, const Array<Iterator>& order) {
       state.reorder(stage_id, order);
@@ -454,6 +481,12 @@ TVM_REGISTER_GLOBAL("auto_scheduler.StateSplit")
       return Array<ObjectRef>{state, res};
     });
 
+TVM_REGISTER_GLOBAL("auto_scheduler.StateStorageAlign")
+    .set_body_typed([](State state, int stage_id, const Iterator& it, int factor, int offset) {
+      state.storage_align(stage_id, it, factor, offset);
+      return state;
+    });
+
 TVM_REGISTER_GLOBAL("auto_scheduler.StateComputeAt")
     .set_body_typed([](State state, int stage_id, int target_stage_id,
                        const Iterator& target_iter) {
@@ -487,6 +520,13 @@ TVM_REGISTER_GLOBAL("auto_scheduler.StateCacheWrite")
       return Array<ObjectRef>{state, Integer(res)};
     });
 
+TVM_REGISTER_GLOBAL("auto_scheduler.StateRfactor")
+    .set_body_typed([](State state, int stage_id, const Iterator& it, int factor_iter_id,
+                       const ComputeDAG& dag) {
+      int res = state.rfactor(stage_id, it, factor_iter_id, dag);
+      return Array<ObjectRef>{state, Integer(res)};
+    });
+
 TVM_REGISTER_GLOBAL("auto_scheduler.StateEqual").set_body_typed([](State state1, State state2) {
   return std::equal_to<State>()(state1, state2);
 });
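
Putting the layers together, a minimal end-to-end sketch of the new Python
State API (modeled on the unit tests below; the matmul is inlined here instead
of using the shared matmul_auto_scheduler_test helper):

    import tvm
    from tvm import te, auto_scheduler

    # An 8x8x512 matmul, equivalent to matmul_auto_scheduler_test(8, 8, 512).
    A = te.placeholder((8, 512), name="A")
    B = te.placeholder((512, 8), name="B")
    k = te.reduce_axis((0, 512), name="k")
    C = te.compute((8, 8), lambda i, j: te.sum(A[i][k] * B[k][j], axis=[k]), name="C")

    dag = auto_scheduler.ComputeDAG([A, B, C])
    s = dag.get_init_state()

    # Attach a pragma to the outer loop of C; the step value is encoded after '$'.
    s.pragma(C, s[C].iters[0], "auto_unroll_max_step$16")
    # Request an alignment of factor 8 with offset 4 on the innermost axis.
    s.storage_align(C, s[C].iters[-1], 8, 4)
    # rfactor requires a preceding split of the reduction axis.
    ko, ki = s.split(C, s[C].iters[2], [16])
    C_rf = s.rfactor(C, ko, 2)  # place the factored iterator at position 2
    print(s)  # shows the loop structure with the new C.rf / C.repl stages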
diff --git a/src/auto_scheduler/transform_step.cc b/src/auto_scheduler/transform_step.cc
index 5c5cc4b2e760f..2eae04e828633 100644
--- a/src/auto_scheduler/transform_step.cc
+++ b/src/auto_scheduler/transform_step.cc
@@ -81,10 +81,14 @@ Step StepReadFromRecord(dmlc::JSONReader* reader) {
     return AnnotationStep(reader);
   } else if (name == FuseStepNode::record_prefix_str) {
     return FuseStep(reader);
+  } else if (name == PragmaStepNode::record_prefix_str) {
+    return PragmaStep(reader);
   } else if (name == ReorderStepNode::record_prefix_str) {
     return ReorderStep(reader);
   } else if (name == SplitStepNode::record_prefix_str) {
     return SplitStep(reader);
+  } else if (name == StorageAlignStepNode::record_prefix_str) {
+    return StorageAlignStep(reader);
   } else if (name == ComputeAtStepNode::record_prefix_str) {
     return ComputeAtStep(reader);
   } else if (name == ComputeInlineStepNode::record_prefix_str) {
@@ -95,6 +99,8 @@ Step StepReadFromRecord(dmlc::JSONReader* reader) {
     return CacheReadStep(reader);
   } else if (name == CacheWriteStepNode::record_prefix_str) {
     return CacheWriteStep(reader);
+  } else if (name == RfactorStepNode::record_prefix_str) {
+    return RfactorStep(reader);
   } else {
     LOG(FATAL) << "Invalid step format: " << name;
   }
@@ -107,10 +113,14 @@ void StepApplyToState(const Step& step, State* state, const ComputeDAG& dag) {
     ps->ApplyToState(state);
   } else if (auto ps = step.as<FuseStepNode>()) {
     ps->ApplyToState(state);
+  } else if (auto ps = step.as<PragmaStepNode>()) {
+    ps->ApplyToState(state);
   } else if (auto ps = step.as<ReorderStepNode>()) {
     ps->ApplyToState(state);
   } else if (auto ps = step.as<SplitStepNode>()) {
     ps->ApplyToState(state);
+  } else if (auto ps = step.as<StorageAlignStepNode>()) {
+    ps->ApplyToState(state);
   } else if (auto ps = step.as<ComputeAtStepNode>()) {
     ps->ApplyToState(state);
   } else if (auto ps = step.as<ComputeInlineStepNode>()) {
@@ -121,6 +131,8 @@ void StepApplyToState(const Step& step, State* state, const ComputeDAG& dag) {
     ps->ApplyToState(state, dag);
   } else if (auto ps = step.as<CacheWriteStepNode>()) {
     ps->ApplyToState(state, dag);
+  } else if (auto ps = step.as<RfactorStepNode>()) {
+    ps->ApplyToState(state, dag);
   } else {
     LOG(FATAL) << "Invalid step: " << step;
   }
@@ -132,10 +144,14 @@ void StepApplyToSchedule(const Step& step, Array<te::Stage>* stages, StageToAxes
     ps->ApplyToSchedule(stages, stage_to_axes);
   } else if (auto ps = step.as<FuseStepNode>()) {
     ps->ApplyToSchedule(stages, stage_to_axes);
+  } else if (auto ps = step.as<PragmaStepNode>()) {
+    ps->ApplyToSchedule(stages, stage_to_axes);
   } else if (auto ps = step.as<ReorderStepNode>()) {
     ps->ApplyToSchedule(stages, stage_to_axes);
   } else if (auto ps = step.as<SplitStepNode>()) {
     ps->ApplyToSchedule(stages, stage_to_axes);
+  } else if (auto ps = step.as<StorageAlignStepNode>()) {
+    ps->ApplyToSchedule(stages, stage_to_axes);
   } else if (auto ps = step.as<ComputeAtStepNode>()) {
     ps->ApplyToSchedule(stages, stage_to_axes);
   } else if (auto ps = step.as<ComputeInlineStepNode>()) {
@@ -146,6 +162,8 @@ void StepApplyToSchedule(const Step& step, Array<te::Stage>* stages, StageToAxes
     ps->ApplyToSchedule(stages, stage_to_axes, schedule);
   } else if (auto ps = step.as<CacheWriteStepNode>()) {
     ps->ApplyToSchedule(stages, stage_to_axes, schedule);
+  } else if (auto ps = step.as<RfactorStepNode>()) {
+    ps->ApplyToSchedule(stages, stage_to_axes, schedule);
   } else {
     LOG(FATAL) << "Invalid Step: " << step;
   }
@@ -157,10 +175,14 @@ String StepPrintAsPythonAPI(const Step& step, Array<te::Stage>* stages,
     return ps->PrintAsPythonAPI(stages, stage_to_axes);
   } else if (auto ps = step.as<FuseStepNode>()) {
     return ps->PrintAsPythonAPI(stages, stage_to_axes);
+  } else if (auto ps = step.as<PragmaStepNode>()) {
+    return ps->PrintAsPythonAPI(stages, stage_to_axes);
   } else if (auto ps = step.as<ReorderStepNode>()) {
     return ps->PrintAsPythonAPI(stages, stage_to_axes);
   } else if (auto ps = step.as<SplitStepNode>()) {
     return ps->PrintAsPythonAPI(stages, stage_to_axes);
+  } else if (auto ps = step.as<StorageAlignStepNode>()) {
+    return ps->PrintAsPythonAPI(stages, stage_to_axes);
   } else if (auto ps = step.as<ComputeAtStepNode>()) {
     return ps->PrintAsPythonAPI(stages, stage_to_axes);
   } else if (auto ps = step.as<ComputeInlineStepNode>()) {
@@ -171,6 +193,8 @@ String StepPrintAsPythonAPI(const Step& step, Array<te::Stage>* stages,
     return ps->PrintAsPythonAPI(stages, stage_to_axes, schedule);
   } else if (auto ps = step.as<CacheWriteStepNode>()) {
     return ps->PrintAsPythonAPI(stages, stage_to_axes, schedule);
+  } else if (auto ps = step.as<RfactorStepNode>()) {
+    return ps->PrintAsPythonAPI(stages, stage_to_axes, schedule);
   } else {
     LOG(FATAL) << "Invalid Step: " << step;
   }
@@ -471,6 +495,115 @@ String FuseStepNode::PrintAsPythonAPI(Array<te::Stage>* stages,
   return ss.str();
 }
 
+/********** Pragma **********/
+PragmaStep::PragmaStep(int stage_id, int iter_id, String pragma_type) {
+  auto node = make_object<PragmaStepNode>();
+  node->stage_id = stage_id;
+  node->iter_id = iter_id;
+  node->pragma_type = std::move(pragma_type);
+  data_ = std::move(node);
+}
+
+PragmaStep::PragmaStep(dmlc::JSONReader* reader) {
+  auto node = make_object<PragmaStepNode>();
+  bool s;
+  s = reader->NextArrayItem();
+  CHECK(s);
+  reader->Read(&node->stage_id);
+  s = reader->NextArrayItem();
+  CHECK(s);
+  reader->Read(&node->iter_id);
+  s = reader->NextArrayItem();
+  CHECK(s);
+  std::string string_value;
+  reader->Read(&string_value);
+  node->pragma_type = std::move(string_value);
+  data_ = std::move(node);
+}
+
+void PragmaStepNode::WriteToRecord(dmlc::JSONWriter* writer) const {
+  writer->WriteArraySeperator();
+  writer->WriteString(record_prefix_str);
+  writer->WriteArrayItem(stage_id);
+  writer->WriteArrayItem(iter_id);
+  writer->WriteArraySeperator();
+  writer->WriteString(pragma_type);
+}
+
+void PragmaStepNode::ApplyToState(State* state) const {
+  if (pragma_type == "debug_skip_region") {
+    StateNode* pstate = state->CopyOnWrite();
+    pstate->attach_map.DeleteStage(stage_id);
+  } else if (StrStartsWith(pragma_type, "auto_unroll_max_step")) {
+    StateNode* pstate = state->CopyOnWrite();
+    Stage stage = pstate->stages[stage_id];
+    size_t pos = 0;
+    for (; pos < pragma_type.size(); ++pos) {
+      if ((*(pragma_type.c_str() + pos)) == '$') {
+        break;
+      }
+    }
+    CHECK_LT(pos, pragma_type.size()) << "max step value not found.";
+    stage.CopyOnWrite()->attrs.auto_unroll_max_step = atoi(pragma_type.c_str() + pos + 1);
+    pstate->stages.Set(stage_id, std::move(stage));
+  } else if (pragma_type == "tensor_core") {
+    // Nothing needs to be done here
+  } else {
+    LOG(FATAL) << "Invalid pragma: " << pragma_type;
+  }
+}
+
+void PragmaStepNode::ApplyToSchedule(Array<te::Stage>* stages,
+                                     StageToAxesMap* stage_to_axes) const {
+  te::Stage stage = (*stages)[stage_id];
+  const Array<IterVar>& axes = (*stage_to_axes)[stage];
+  if (StrStartsWith(pragma_type, "auto_unroll_max_step")) {
+    size_t pos = 0;
+    for (; pos < pragma_type.size(); ++pos) {
+      if ((*(pragma_type.c_str() + pos)) == '$') {
+        break;
+      }
+    }
+    CHECK_LT(pos, pragma_type.size()) << "max step value not found.";
+    int value = atoi(pragma_type.c_str() + pos + 1);
+    stage.pragma(axes[iter_id], "auto_unroll_max_step", value);
+    stage.pragma(axes[iter_id], "unroll_explicit", true);
+  } else {
+    stage.pragma(axes[iter_id], pragma_type);
+  }
+  stages->Set(stage_id, std::move(stage));
+}
+
+String PragmaStepNode::PrintAsPythonAPI(Array<te::Stage>* stages,
+                                        StageToAxesMap* stage_to_axes) const {
+  std::stringstream ss;
+  const auto& stage = (*stages)[stage_id];
+
+  if (StrStartsWith(pragma_type, "auto_unroll_max_step")) {
+    size_t pos = 0;
+    for (; pos < pragma_type.size(); ++pos) {
+      if ((*(pragma_type.c_str() + pos)) == '$') {
+        break;
+      }
+    }
+    CHECK_LT(pos, pragma_type.size()) << "max step value not found.";
+    int value = atoi(pragma_type.c_str() + pos + 1);
+    ss << "s[" << CleanName(stage->op->name) << "].pragma("
+       << CleanName((*stage_to_axes)[stage][iter_id]->var->name_hint)
+       << ", \"auto_unroll_max_step\", " << value << ")\n";
+    ss << "s[" << CleanName(stage->op->name) << "].pragma("
+       << CleanName((*stage_to_axes)[stage][iter_id]->var->name_hint)
+       << ", \"unroll_explicit\", True)\n";
+  } else {
+    ss << "s[" << CleanName(stage->op->name) << "].pragma("
+       << CleanName((*stage_to_axes)[stage][iter_id]->var->name_hint) << ", \"" << pragma_type
+       << "\")\n";
+  }
+
+  ApplyToSchedule(stages, stage_to_axes);
+  return ss.str();
+}
+
 /********** Reorder **********/
 ReorderStep::ReorderStep(int stage_id, const Array<Integer>& after_ids) {
   auto node = make_object<ReorderStepNode>();
@@ -776,6 +909,70 @@ String SplitStepNode::PrintAsPythonAPI(Array<te::Stage>* stages,
   return PrintSplitAsPythonAPI(stages, stage_to_axes, stage_id, iter_id, lengths, inner_to_outer);
 }
 
+/********** Storage Align **********/
+StorageAlignStep::StorageAlignStep(int stage_id, int iter_id, int factor, int offset) {
+  auto node = make_object<StorageAlignStepNode>();
+  node->stage_id = stage_id;
+  node->iter_id = iter_id;
+  node->factor = factor;
+  node->offset = offset;
+  data_ = std::move(node);
+}
+
+StorageAlignStep::StorageAlignStep(dmlc::JSONReader* reader) {
+  auto node = make_object<StorageAlignStepNode>();
+  bool s;
+  s = reader->NextArrayItem();
+  CHECK(s);
+  reader->Read(&node->stage_id);
+  s = reader->NextArrayItem();
+  CHECK(s);
+  reader->Read(&node->iter_id);
+  s = reader->NextArrayItem();
+  CHECK(s);
+  reader->Read(&node->factor);
+  s = reader->NextArrayItem();
+  CHECK(s);
+  reader->Read(&node->offset);
+  data_ = std::move(node);
+}
+
+void StorageAlignStepNode::WriteToRecord(dmlc::JSONWriter* writer) const {
+  writer->WriteArraySeperator();
+  writer->WriteString(record_prefix_str);
+  writer->WriteArrayItem(stage_id);
+  writer->WriteArrayItem(iter_id);
+  writer->WriteArrayItem(factor);
+  writer->WriteArrayItem(offset);
+}
+
+void StorageAlignStepNode::ApplyToState(State* state) const {
+  StateNode* pstate = state->CopyOnWrite();
+  Stage stage = pstate->stages[stage_id];
+  stage.CopyOnWrite()->attrs.storage_offset = offset;
+  pstate->stages.Set(stage_id, std::move(stage));
+}
+
+void StorageAlignStepNode::ApplyToSchedule(Array<te::Stage>* stages,
+                                           StageToAxesMap* stage_to_axes) const {
+  te::Stage stage = (*stages)[stage_id];
+  const Array<IterVar>& axes = (*stage_to_axes)[stage];
+  stage.storage_align(axes[iter_id], factor, offset);
+  stages->Set(stage_id, std::move(stage));
+}
+
+String StorageAlignStepNode::PrintAsPythonAPI(Array<te::Stage>* stages,
+                                              StageToAxesMap* stage_to_axes) const {
+  std::stringstream ss;
+  const auto& stage = (*stages)[stage_id];
+  ss << "s[" << CleanName(stage->op->name) << "].storage_align("
+     << CleanName((*stage_to_axes)[stage][iter_id]->var->name_hint) << ", " << factor << ", "
+     << offset << ")\n";
+
+  ApplyToSchedule(stages, stage_to_axes);
+  return ss.str();
+}
+
 /********** Steps working on multiple stages **********/
 
 /********** Compute At **********/
@@ -958,7 +1155,7 @@ String ComputeRootStepNode::PrintAsPythonAPI(Array<te::Stage>* stages,
   return ss.str();
 }
 
-/********** Primitives adding new stages **********/
+/********** Steps adding new stages **********/
 
 /*!
  * \brief Common part for steps that add new stages(e.g. CacheReadStep, CacheWriteStep,
@@ -967,11 +1164,27 @@ String ComputeRootStepNode::PrintAsPythonAPI(Array<te::Stage>* stages,
  */
 Array<Step> GetFormerStageModifiableSteps(Step current_step, const Array<Step>& transform_steps) {
   Array<Step> ret_steps;
-  for (const Step& step : transform_steps) {
+  for (size_t i = 0; i < transform_steps.size(); ++i) {
+    const Step& step = transform_steps[i];
     if (step->IsInstance<CacheWriteStepNode>() || step->IsInstance<CacheReadStepNode>()) {
       ret_steps.push_back(step);
+    } else if (step->IsInstance<RfactorStepNode>()) {
+      // Add the FuseStep this rfactor may depend on
+      if (i >= 2 && transform_steps[i - 2]->IsInstance<FuseStepNode>()) {
+        const Step& fuse_step = transform_steps[i - 2];
+        if (fuse_step->stage_id == step->stage_id) {
+          ret_steps.push_back(fuse_step);
+        }
+      }
+      // Add the SplitStep this rfactor depends on
+      CHECK_GE(i, 1);
+      CHECK(transform_steps[i - 1]->IsInstance<SplitStepNode>());
+      const Step& split_step = transform_steps[i - 1];
+      CHECK_EQ(split_step->stage_id, step->stage_id);
+      ret_steps.push_back(split_step);
+      // Add the RfactorStep itself
+      ret_steps.push_back(step);
     }
-    // TODO(jcf94): add rfactor support
     // A state may have multiple stage modifiable steps, stop by the current step to avoid
     // replaying excess steps
     if (step.same_as(current_step)) {
@@ -1228,5 +1441,136 @@ String CacheWriteStepNode::PrintAsPythonAPI(Array<te::Stage>* stages, StageToAxe
   return ss.str();
 }
 
+/********** Rfactor **********/
+RfactorStep::RfactorStep(int stage_id, int iter_id, int factor_iter_id) {
+  auto node = make_object<RfactorStepNode>();
+  node->stage_id = stage_id;
+  node->iter_id = iter_id;
+  node->factor_iter_id = factor_iter_id;
+  data_ = std::move(node);
+}
+
+RfactorStep::RfactorStep(dmlc::JSONReader* reader) {
+  auto node = make_object<RfactorStepNode>();
+  bool s;
+  s = reader->NextArrayItem();
+  CHECK(s);
+  reader->Read(&node->stage_id);
+  s = reader->NextArrayItem();
+  CHECK(s);
+  reader->Read(&node->iter_id);
+  s = reader->NextArrayItem();
+  CHECK(s);
+  reader->Read(&node->factor_iter_id);
+  data_ = std::move(node);
+}
+
+void RfactorStepNode::WriteToRecord(dmlc::JSONWriter* writer) const {
+  writer->WriteArraySeperator();
+  writer->WriteString(record_prefix_str);
+  writer->WriteArrayItem(stage_id);
+  writer->WriteArrayItem(iter_id);
+  writer->WriteArrayItem(factor_iter_id);
+}
+
+int RfactorStepNode::ApplyToState(State* state, const ComputeDAG& dag) const {
+  StateNode* pstate = state->CopyOnWrite();
+  const auto& compute_at_type = pstate->stages[stage_id]->compute_at;
+  Array<Step> replay_steps;
+  for (size_t i = 0; i < pstate->transform_steps.size(); ++i) {
+    AddStageModificationSteps(i, pstate->transform_steps, &replay_steps);
+    if (pstate->transform_steps[i].same_as(GetRef<Step>(this))) {
+      break;
+    }
+  }
+  const ComputeDAG& current_compute_dag = dag.ReplayAndGetDAG(replay_steps);
+
+  // target -> target_compute + target
+  // Should insert new stage, update target stage, update the later stage's op
+  pstate->stages.insert(pstate->stages.begin() + stage_id,
+                        Stage(current_compute_dag->ops[stage_id]));
+  // maintain the compute_at type of target stage
+  Stage target_stage = Stage(current_compute_dag->ops[stage_id + 1]);
+  target_stage.CopyOnWrite()->compute_at = compute_at_type;
+  pstate->stages.Set(stage_id + 1, std::move(target_stage));
+
+  for (size_t i = stage_id + 2; i < pstate->stages.size(); ++i) {
+    Stage stage = pstate->stages[i];
+    stage.CopyOnWrite()->op = current_compute_dag->ops[i];
+    pstate->stages.Set(i, std::move(stage));
+  }
+  pstate->attach_map = pstate->attach_map.ApplyStageIdOfffset(stage_id, 1);
+
+  return stage_id;
+}
+
+Array<te::Tensor> RfactorStepNode::ApplyToSchedule(Array<te::Stage>* stages,
+                                                   StageToAxesMap* stage_to_axes,
+                                                   te::Schedule* schedule) const {
+  const auto& stage = (*stages)[stage_id];
+  const Array<IterVar>& axes = (*stage_to_axes)[stage];
+
+  const te::Tensor& tensor = stage->origin_op.output(0);
+  const IterVar& axis = axes[iter_id];
+  auto outs = schedule->rfactor(tensor, axis, factor_iter_id);
+
+  UpdateStageToAxesMap(stage, stage_to_axes);
+
+  const auto& new_stage = (*schedule)[outs[0]->op];
+  UpdateStageToAxesMap(new_stage, stage_to_axes);
+  stages->insert(stages->begin() + stage_id, new_stage);
+
+  return outs;
+}
+
+String RfactorStepNode::PrintAsPythonAPI(Array<te::Stage>* stages, StageToAxesMap* stage_to_axes,
+                                         te::Schedule* schedule) const {
+  std::stringstream ss;
+  const auto& stage = (*stages)[stage_id];
+
+  const auto& tensor_name = CleanName(stage->origin_op.output(0)->op->name);
+  const auto& axis_name = CleanName((*stage_to_axes)[stage][iter_id]->var->name_hint);
+
+  const auto& outs = ApplyToSchedule(stages, stage_to_axes, schedule);
+
+  for (size_t i = 0; i < outs.size(); ++i) {
+    ss << CleanName(outs[i]->op->name);
+    if (i != outs.size() - 1) {
+      ss << ", ";
+    }
+  }
+  ss << " = "
+     << "s.rfactor(" << tensor_name << ", " << axis_name << ", " << factor_iter_id << ")\n";
+
+  for (const auto& out : outs) {
+    const auto& iters = out->op->root_iter_vars();
+    for (size_t i = 0; i < iters.size(); ++i) {
+      ss << CleanName(iters[i]->var->name_hint);
+      if (i != iters.size() - 1) {
+        ss << ", ";
+      }
+    }
+    ss << " = "
+       << "tuple(" << CleanName(out->op->name) << ".op.axis)"
+       << " + "
+       << "tuple(" << CleanName(out->op->name) << ".op.reduce_axis)\n";
+  }
+
+  const auto& output = (*stages)[stage_id + 1]->op.output(0);
+  const auto& iters = output->op->root_iter_vars();
+  for (size_t i = 0; i < iters.size(); ++i) {
+    ss << CleanName(iters[i]->var->name_hint);
+    if (i != iters.size() - 1) {
+      ss << ", ";
+    }
+  }
+  ss << " = "
+     << "tuple(s[" << CleanName(output->op->name) << "].op.axis)"
+     << " + "
+     << "tuple(s[" << CleanName(output->op->name) << "].op.reduce_axis)\n";
+
+  return ss.str();
+}
+
 }  // namespace auto_scheduler
 }  // namespace tvm
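
The "auto_unroll_max_step" pragma carries its value inline after a '$'
separator (e.g. "auto_unroll_max_step$64"); the code above scans for the
separator and converts the remainder with atoi. A Python sketch of the same
decoding, for reference (parse_auto_unroll_max_step is a hypothetical helper,
not part of this patch):

    def parse_auto_unroll_max_step(pragma_type):
        """Decode the value of an "auto_unroll_max_step$<value>" pragma."""
        assert pragma_type.startswith("auto_unroll_max_step")
        pos = pragma_type.find("$")
        assert pos != -1, "max step value not found"
        return int(pragma_type[pos + 1:])

    assert parse_auto_unroll_max_step("auto_unroll_max_step$64") == 64

When printed back as Python schedule code, PrintAsPythonAPI expands this
single step into two te pragmas, e.g. s[C].pragma(i, "auto_unroll_max_step",
64) followed by s[C].pragma(i, "unroll_explicit", True).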
diff --git a/src/auto_scheduler/utils.h b/src/auto_scheduler/utils.h
index da5032e11c97a..aacdcf4265f9e 100644
--- a/src/auto_scheduler/utils.h
+++ b/src/auto_scheduler/utils.h
@@ -162,6 +162,12 @@ inline double FloatArrayMean(const Array<PrimExpr>& float_array) {
   return sum / float_array.size();
 }
 
+/*! \brief Return whether a string starts with a given prefix */
+inline bool StrStartsWith(const String& a, const String& b) {
+  if (b.size() > a.size()) return false;
+  return std::equal(a.c_str(), a.c_str() + b.size(), b.c_str());
+}
+
 /********** Other Utilities **********/
 /*! \brief Get an int value from an Expr */
 inline int64_t GetIntImm(const PrimExpr& expr) {
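
StrStartsWith is the prefix test used by the PragmaStep code above; it behaves
like Python's str.startswith, e.g. StrStartsWith("auto_unroll_max_step$64",
"auto_unroll_max_step") is true, while StrStartsWith("abc", "abcd") is false.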
diff --git a/tests/python/unittest/test_auto_scheduler_loop_state.py b/tests/python/unittest/test_auto_scheduler_loop_state.py
index 8282d4a40e5ef..255acf1dcc68b 100644
--- a/tests/python/unittest/test_auto_scheduler_loop_state.py
+++ b/tests/python/unittest/test_auto_scheduler_loop_state.py
@@ -417,7 +417,61 @@ def test_cache_read_write():
     for it0, it1 in zip(s0[kernel_split].iters, s0[kernel_split_global].iters):
         assert it0.range == it1.range
 
+
+def test_rfactor():
+    A, B, C = matmul_auto_scheduler_test(8, 8, 512)
+    dag = auto_scheduler.ComputeDAG([A, B, C])
+    s0 = dag.get_init_state()
+
+    ko, ki = s0.split(C, s0[C].iters[2], [16])
+
+    s1 = s0.copy()
+    C_r = s1.rfactor(C, ko, 2)
+    """
+        Placeholder: A, B
+        for i (0,8)
+          for j (0,8)
+            for k_o (0,32)
+              for k_i (0,16)
+                C.rf = ...
+        for ax0 (0,8)
+          for ax1 (0,8)
+            for k_o_v (0,32)
+              C.repl = ...
+    """
+    assert s1[C_r].iters[0].range.extent == 8
+    assert s1[C_r].iters[1].range.extent == 8
+    assert s1[C_r].iters[2].range.extent == 32
+    assert s1[C_r].iters[3].range.extent == 16
+    assert s1[C].iters[0].range.extent == 8
+    assert s1[C].iters[1].range.extent == 8
+    assert s1[C].iters[2].range.extent == 32
+
+    s2 = s0.copy()
+    C_r = s2.rfactor(C, ki, 2)
+    """
+        Placeholder: A, B
+        for i (0,8)
+          for j (0,8)
+            for k_i (0,16)
+              for k_o (0,32)
+                C.rf = ...
+        for ax0 (0,8)
+          for ax1 (0,8)
+            for k_i_v (0,16)
+              C.repl = ...
+    """
+    assert s2[C_r].iters[0].range.extent == 8
+    assert s2[C_r].iters[1].range.extent == 8
+    assert s2[C_r].iters[2].range.extent == 16
+    assert s2[C_r].iters[3].range.extent == 32
+    assert s2[C].iters[0].range.extent == 8
+    assert s2[C].iters[1].range.extent == 8
+    assert s2[C].iters[2].range.extent == 16
+
+
 if __name__ == "__main__":
     test_split_fuse_reorder_annotation()
     test_compute_at_root_inline()
     test_cache_read_write()
+    test_rfactor()
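
A quick check on the extents asserted in test_rfactor: k has extent 512 and
split([16]) yields k_o = 512 / 16 = 32 and k_i = 16. Factoring k_o out (s1)
leaves C.rf with extents (8, 8, 32, 16) and places the factored iterator
k_o_v at position 2 of C.repl, giving (8, 8, 32); factoring k_i out (s2)
symmetrically gives (8, 8, 16, 32) and (8, 8, 16).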
diff --git a/tests/python/unittest/test_auto_scheduler_measure.py b/tests/python/unittest/test_auto_scheduler_measure.py
index 5f2f87ad9baa2..a963fa2b16add 100644
--- a/tests/python/unittest/test_auto_scheduler_measure.py
+++ b/tests/python/unittest/test_auto_scheduler_measure.py
@@ -37,8 +37,10 @@ def test_record():
     k = te.reduce_axis((0, 512), name='k')
     E = te.compute((512, 512), lambda i, j: te.sum(A[i][k] * D[k][j], axis=[k]), name='E')
     F = topi.nn.relu(E)
+    k = te.reduce_axis((0, 512), name='k')
+    G = te.compute((512, 512), lambda i, j: te.sum(A[i][k] * F[k][j], axis=[k]), name='G')
 
-    dag = auto_scheduler.ComputeDAG([A, B, F])
+    dag = auto_scheduler.ComputeDAG([A, B, G])
     s = dag.get_init_state()
 
     # Split
@@ -71,6 +73,13 @@ def test_record():
     s.compute_at(D_global, E, s[E].iters[2])
     # Cache Write
     s.cache_write(D, "shared")
+    # Pragma
+    s.pragma(C, s[C].iters[0], "auto_unroll_max_step$64")
+    # StorageAlign
+    s.storage_align(E, s[E].iters[-1], 8, 4)
+    # Rfactor
+    ko, _ = s.split(G, s[G].iters[2], [16])
+    s.rfactor(G, ko, 2)
 
     target = tvm.target.create("llvm")
     task = auto_scheduler.SearchTask(dag, "test", target)