Skip to content

Commit

Permalink
Add pragma/storage_align/rfactor step
Browse files Browse the repository at this point in the history
  • Loading branch information
jcf94 committed Jul 27, 2020
1 parent b8f8b8d commit aeb8c7b
Show file tree
Hide file tree
Showing 8 changed files with 747 additions and 8 deletions.
31 changes: 28 additions & 3 deletions include/tvm/auto_scheduler/loop_state.h
Original file line number Diff line number Diff line change
Expand Up @@ -340,6 +340,13 @@ class State : public ObjectRef {
* result will become the new attach point.
*/
TVM_DLL Iterator fuse(int stage_id, const Array<Iterator>& iters);
/*!
* \brief Schedule primitive corresponds to `te.Stage.pragma`.
* \param stage_id The index of the stage to add pragma.
* \param it The iterator to add pragma.
* \param pragma_type The pragma string.
*/
TVM_DLL void pragma(int stage_id, const Iterator& it, const String& pragma_type);
/*!
* \brief Schedule primitive corresponds to `te::Stage::reorder`.
* \param stage_id The index of the stage to be reordered.
Expand All @@ -359,6 +366,14 @@ class State : public ObjectRef {
TVM_DLL Array<Iterator> split(int stage_id, const Iterator& it,
const Array<Optional<Integer>>& lengths,
bool inner_to_outer = true);
/*!
* \brief Schedule primitive corresponds to `te.Stage.storage_align`.
* \param stage_id The index of the stage to be aligned.
* \param it The iterator to be aligned.
* \param factor The factor in alignment specification.
* \param offset The offset in the alignment specification.
*/
TVM_DLL void storage_align(int stage_id, const Iterator& it, int factor, int offset);

/********** Step APIs working on multiple stages **********/

Expand Down Expand Up @@ -399,8 +414,8 @@ class State : public ObjectRef {
* \note Cache read step will add an extra stage to the original ComputeDAG (at the back of the
* target stage), a up-to-date ComputeDAG is stored in State's `current_compute_dag`.
*/
int cache_read(int stage_id, const String& scope_name, const Array<Integer>& reader_stage_ids,
const ComputeDAG& dag);
TVM_DLL int cache_read(int stage_id, const String& scope_name,
const Array<Integer>& reader_stage_ids, const ComputeDAG& dag);
/*!
* \brief Schedule primitive corresponds to `te::Schedule::cache_write`.
* \param stage_id The index of the stage to be cache write.
Expand All @@ -410,7 +425,17 @@ class State : public ObjectRef {
* target stage), a up-to-date ComputeDAG is stored in State's `current_compute_dag`.
* This step will cache write all output tensors of the target stage.
*/
int cache_write(int stage_id, const String& scope_name, const ComputeDAG& dag);
TVM_DLL int cache_write(int stage_id, const String& scope_name, const ComputeDAG& dag);
/*!
* \brief Schedule primitive corresponds to `te::Schedule::rfactor`.
* \param stage_id The index of the iterator to be factored.
* \param iter_id The iterator to be factored.
* \param factor_iter_id The position where the new iterator is placed.
* \param dag The original ComputeDAG of this state.
* \note Rfactor step will add an extra stage to the original ComputeDAG, a up-to-date
* ComputeDAG is stored in State's `current_compute_dag`.
*/
TVM_DLL int rfactor(int stage_id, const Iterator& it, int factor_iter_id, const ComputeDAG& dag);

TVM_DEFINE_OBJECT_REF_METHODS(State, ObjectRef, StateNode);
TVM_DEFINE_OBJECT_REF_COW_METHOD(StateNode);
Expand Down
195 changes: 194 additions & 1 deletion include/tvm/auto_scheduler/transform_step.h
Original file line number Diff line number Diff line change
Expand Up @@ -347,6 +347,67 @@ class FuseStep : public Step {
TVM_DEFINE_OBJECT_REF_METHODS(FuseStep, Step, FuseStepNode);
};

/*! \brief Pragma step that corresponds to te::Stage::pragma */
class PragmaStepNode : public StepNode {
public:
/*! \brief The index of the iterator to add pragma. */
int iter_id;
/*! \brief The pragma string. */
String pragma_type;

void WriteToRecord(dmlc::JSONWriter* writer) const final;

/*!
* \brief Apply the current step to State.
* \param state A mutable pointer to State.
*/
void ApplyToState(State* state) const;

/*!
* \brief Apply the current step to tvm.schedule.
* \param stages A pointer to a `te::Stage` Array.
* \param stage_to_axes A pointer to a StageToAxesMap.
*/
void ApplyToSchedule(Array<te::Stage>* stages, StageToAxesMap* stage_to_axes) const;

/*!
* \brief Print the current step as equivalent python schedule API.
* \param stages A pointer to a `te::Stage` Array.
* \param stage_to_axes A pointer to a StageToAxesMap.
* \return Python schedule code.
*/
String PrintAsPythonAPI(Array<te::Stage>* stages, StageToAxesMap* stage_to_axes) const;

static constexpr const char* record_prefix_str = "PR";

static constexpr const char* _type_key = "auto_scheduler.PragmaStep";
TVM_DECLARE_FINAL_OBJECT_INFO(PragmaStepNode, Object);
};

/*!
* \brief Managed reference to PragmaStepNode.
* \sa PragmaStepNode
*/
class PragmaStep : public Step {
public:
/*!
* \brief The constructor.
* \param stage_id The index of the stage to be fused.
* \param iter_id The index of the iterator to add pragma.
* \param pragma_type The pragma string.
*/
PragmaStep(int stage_id, int iter_id, String pragma_type);

/*!
* \brief The constructor used to read a step record from JSONReader and create the
* corresponding step.
* \param reader The input JSONReader.
*/
explicit PragmaStep(dmlc::JSONReader* reader);

TVM_DEFINE_OBJECT_REF_METHODS(PragmaStep, Step, PragmaStepNode);
};

/*! \brief Reorder step that corresponds to te::Stage::reorder */
class ReorderStepNode : public StepNode {
public:
Expand Down Expand Up @@ -487,6 +548,70 @@ class SplitStep : public Step {
TVM_DEFINE_OBJECT_REF_METHODS(SplitStep, Step, SplitStepNode);
};

/*! \brief Storage align step that corresponds to te::Stage::storage_align */
class StorageAlignStepNode : public StepNode {
public:
/*! \brief The iterator to be aligned. */
int iter_id;
/*! \brief The factor in alignment specification. */
int factor;
/*! \brief The offset in the alignment specification. */
int offset;

void WriteToRecord(dmlc::JSONWriter* writer) const final;

/*!
* \brief Apply the current step to State.
* \param state A mutable pointer to State.
*/
void ApplyToState(State* state) const;

/*!
* \brief Apply the current step to tvm.schedule.
* \param stages A pointer to a `te::Stage` Array.
* \param stage_to_axes A pointer to a StageToAxesMap.
*/
void ApplyToSchedule(Array<te::Stage>* stages, StageToAxesMap* stage_to_axes) const;

/*!
* \brief Print the current step as equivalent python schedule API.
* \param stages A pointer to a `te::Stage` Array.
* \param stage_to_axes A pointer to a StageToAxesMap.
* \return Python schedule code.
*/
String PrintAsPythonAPI(Array<te::Stage>* stages, StageToAxesMap* stage_to_axes) const;

static constexpr const char* record_prefix_str = "SA";

static constexpr const char* _type_key = "auto_scheduler.StorageAlignStep";
TVM_DECLARE_FINAL_OBJECT_INFO(StorageAlignStepNode, Object);
};

/*!
* \brief Managed reference to StorageAlignStepNode.
* \sa StorageAlignStepNode
*/
class StorageAlignStep : public Step {
public:
/*!
* \brief The constructor.
* \param stage_id The index of the stage to be aligned.
* \param iter_id The index of the iterator to be aligned.
* \param factor The factor in alignment specification.
* \param offset The offset in the alignment specification.
*/
StorageAlignStep(int stage_id, int iter_id, int factor, int offset);

/*!
* \brief The constructor used to read a step record from JSONReader and create the
* corresponding step.
* \param reader The input JSONReader.
*/
explicit StorageAlignStep(dmlc::JSONReader* reader);

TVM_DEFINE_OBJECT_REF_METHODS(StorageAlignStep, Step, StorageAlignStepNode);
};

/********** Steps working on multiple stages **********/

/*! \brief Compute at step that corresponds to te::Stage::compute_at */
Expand Down Expand Up @@ -668,7 +793,7 @@ class ComputeRootStep : public Step {
TVM_DEFINE_OBJECT_REF_METHODS(ComputeRootStep, Step, ComputeRootStepNode);
};

/********** Primitives adding new stages **********/
/********** Steps adding new stages **********/

/*!
* \brief Cache read step that corresponds to te::Schedule::cache_read.
Expand Down Expand Up @@ -812,6 +937,74 @@ class CacheWriteStep : public Step {
TVM_DEFINE_OBJECT_REF_METHODS(CacheWriteStep, Step, CacheWriteStepNode);
};

/*! \brief Reduction factor step that corresponds to te::Schedule::rfactor */
class RfactorStepNode : public StepNode {
public:
/*! \brief The index of the iterator to be factored. */
int iter_id;
/*! \brief The position where the new iterator is placed. */
int factor_iter_id;

void WriteToRecord(dmlc::JSONWriter* writer) const final;

/*!
* \brief Apply the current step to State.
* \param state A mutable pointer to State.
* \param dag The original ComputeDAG of this state.
* \return The index of the new added stage.
*/
int ApplyToState(State* state, const ComputeDAG& dag) const;

/*!
* \brief Apply the current step to tvm.schedule.
* \param stages A mutable pointer to a `te::Stage` Array.
* \param stage_to_axes A mutable pointer to a StageToAxesMap.
* \param schedule A mutable pointer to a te::Schedule.
* \return The output Tensors of the new added stage.
*/
Array<te::Tensor> ApplyToSchedule(Array<te::Stage>* stages, StageToAxesMap* stage_to_axes,
te::Schedule* schedule) const;

/*!
* \brief Print the current step as equivalent python schedule API.
* \param stages A mutable pointer to a `te::Stage` Array.
* \param stage_to_axes A mutable pointer to a StageToAxesMap.
* \param schedule A mutable pointer to a te::Schedule.
* \return Python schedule code.
*/
String PrintAsPythonAPI(Array<te::Stage>* stages, StageToAxesMap* stage_to_axes,
te::Schedule* schedule) const;

static constexpr const char* record_prefix_str = "RF";

static constexpr const char* _type_key = "auto_scheduler.RfactorStep";
TVM_DECLARE_FINAL_OBJECT_INFO(RfactorStepNode, Object);
};

/*!
* \brief Managed reference to RfactorStepNode.
* \sa RfactorStepNode
*/
class RfactorStep : public Step {
public:
/*!
* \brief The constructor.
* \param stage_id The index of the iterator to be factored.
* \param iter_id The index of the iterator to be factored.
* \param factor_iter_id The position where the new iterator is placed.
*/
RfactorStep(int stage_id, int iter_id, int factor_iter_id);

/*!
* \brief The constructor used to read a step record from JSONReader and create the
* corresponding step.
* \param reader The input JSONReader.
*/
explicit RfactorStep(dmlc::JSONReader* reader);

TVM_DEFINE_OBJECT_REF_METHODS(RfactorStep, Step, RfactorStepNode);
};

} // namespace auto_scheduler
} // namespace tvm

Expand Down
68 changes: 68 additions & 0 deletions python/tvm/auto_scheduler/loop_state.py
Original file line number Diff line number Diff line change
Expand Up @@ -252,6 +252,22 @@ def fuse(self, stage, iters):
self._resolve_stage_id(stage), iters)
return res

def pragma(self, stage, iterator, pragma_type):
""" Schedule primitive corresponds to te.pragma.
Parameters
----------
stage : Union[int, Operation, Tensor]
The Stage to add pragma, which can be specified by the integer index, Operation,
or output tensor of the stage.
iterator : Iterator
The iterator to add pragma.
pragma_type : str
The pragma string.
"""
self.state_object = _ffi_api.StatePragma(self.state_object, self._resolve_stage_id(stage),
iterator, pragma_type)

def reorder(self, stage, order):
""" Schedule primitive corresponds to `te.Stage.reorder`, see also the `te.Stage` for more
details.
Expand Down Expand Up @@ -301,6 +317,27 @@ def split(self, stage, iterator, lengths, inner_to_outer=True):
iterator, lengths, inner_to_outer)
return res

def storage_align(self, stage, iterator, factor, offset):
""" Schedule primitive corresponds to te.storage_align.
See `te.schedule.Stage.storage_align` for more information.
Parameters
----------
stage : Union[int, Operation, Tensor]
The Stage to be storage aligned, which can be specified by the integer index,
Operation, or output tensor of the stage.
iterator : Iterator
The iterator to be aligned.
factor : int
The factor in alignment specification.
offset : int
The offset in the alignment specification.
"""
self.state_object = _ffi_api.StateStorageAlign(self.state_object,
self._resolve_stage_id(stage), iterator,
factor, offset)

def compute_at(self, stage, target_stage, target_iter):
""" Schedule primitive corresponds to `te.Stage.compute_at`, see also the `te.Stage` for
more details.
Expand Down Expand Up @@ -429,6 +466,37 @@ def cache_write(self, stage, scope_name):
self._update_stage_id_map()
return self.stages[int(new_stage_id)].op

def rfactor(self, stage, iterator, factor_iter_id):
""" Schedule primitive corresponds to te.schedule.rfactor.
See `te.schedule.Schedule.rfactor` for more information.
Parameters
----------
stage : Union[int, Operation, Tensor]
The Stage to be factored, which can be specified by the integer index, Operation,
or output tensor of the stage.
iterator : Iterator
The reduction iterator to be factored.
factor_iter_id : int
The position where the new iterator is placed.
Returns
-------
new_stage_op : Operator
The Operator of the new added stage.
Notes
-----
Rfactor step will insert an extra stage to the original ComputeDAG (in the front of the
target stage).
"""
self.state_object, new_stage_id = _ffi_api.StateRfactor(self.state_object,
self._resolve_stage_id(stage),
iterator, factor_iter_id,
self.compute_dag)
return self._insert_new_stage(int(new_stage_id))

def copy(self):
""" Do deep copy of this State. """
state = State(self.state_object, self.compute_dag)
Expand Down
Loading

0 comments on commit aeb8c7b

Please sign in to comment.