Skip to content

Commit

Permalink
Add tensorize step for loop_state (apache#31)
Browse files Browse the repository at this point in the history
* Add tensorize step
  • Loading branch information
jcf94 authored and merrymercy committed Jun 20, 2020
1 parent 4ea6712 commit 6126cdb
Show file tree
Hide file tree
Showing 11 changed files with 246 additions and 28 deletions.
25 changes: 22 additions & 3 deletions python/tvm/ansor/loop_state.py
Original file line number Diff line number Diff line change
Expand Up @@ -411,14 +411,33 @@ def storage_align(self, stage_id, it, factor, offset):
it : Iterator
factor : Int
offset : Int
"""
self.state_object = _ffi_api.StateStorageAlign(self.state_object, stage_id, it, factor, offset)
self.clear_cache()

def tensorize(self, stage_id, it, ti_func_name):
    """ Tensorize the target iterator with a registered intrinsic.

    The `ti_func_name` corresponds to a global registered function
    that returns a TensorIntrin.

    Parameters
    ----------
    stage_id : Int
        The index of the stage to be tensorized
    it : Iterator
        The target iterator
    ti_func_name : Str
        Tensorize intrinsic function name

    Returns
    -------
    res_it : Iterator
        The tensorized Iterator
    """
    # NOTE(review): removed a stray StateStorageAlign call that referenced
    # `factor` and `offset` — names undefined in this method (copied from
    # storage_align); it would have raised NameError on every call.
    self.state_object, res = _ffi_api.StateTensorize(self.state_object,
                                                     stage_id, it,
                                                     ti_func_name)
    # Cached stage/iterator lookups are stale after a transform step.
    self.clear_cache()
    return res

def __str__(self):
return str(self.state_object)
Expand Down
2 changes: 2 additions & 0 deletions python/tvm/ansor/task_scheduler.py
Original file line number Diff line number Diff line change
Expand Up @@ -248,6 +248,8 @@ def tune(self, tune_option: TuneOption, search_policy: Union[str, List[SearchPol
else:
raise ValueError("Invalid strategy: " + self.strategy)

if self.verbose >= 1:
print("Next tuning task: %d" % task_idx)
self.tune_task(task_idx)

def tune_task(self, task_idx):
Expand Down
5 changes: 4 additions & 1 deletion src/ansor/compute_dag.cc
Original file line number Diff line number Diff line change
Expand Up @@ -1086,7 +1086,8 @@ void ComputeDAG::InferBoundCommon(StateNode* pstate) const {
new_iters.push_back(IteratorNode::make(iter->name, (*find_res).second,
iter->iter_type,
iter->annotation,
&iter->ori_iters));
&iter->ori_iters,
iter->attr));
} else {
LOG(FATAL) << "Infer bound fails";
}
Expand Down Expand Up @@ -1161,6 +1162,8 @@ std::pair<te::Schedule, Array<te::Tensor> > ComputeDAG::ReplaySteps(
ps->ApplyToSchedule(stages, stage_to_axes, &schedule);
} else if (auto ps = step.as<StorageAlignStepNode>()) {
ps->ApplyToSchedule(stages, stage_to_axes);
} else if (auto ps = step.as<TensorizeStepNode>()) {
ps->ApplyToSchedule(stages, stage_to_axes);
} else {
LOG(FATAL) << "Invalid Step";
}
Expand Down
59 changes: 51 additions & 8 deletions src/ansor/loop_state.cc
Original file line number Diff line number Diff line change
Expand Up @@ -39,7 +39,8 @@ TVM_REGISTER_NODE_TYPE(IteratorNode);
// Maker for other classes
Iterator IteratorNode::make(std::string name, Range range,
IteratorType iter_type, IteratorAnnotation annotation,
const std::vector<Iterator>* ori_iters) {
const std::vector<Iterator>* ori_iters,
std::string attr) {
auto node = make_object<IteratorNode>();
node->name = std::move(name);
node->range = std::move(range);
Expand All @@ -48,6 +49,7 @@ Iterator IteratorNode::make(std::string name, Range range,
if (ori_iters != nullptr) {
node->ori_iters = *ori_iters;
}
node->attr = std::move(attr);
return Iterator(node);
}

Expand Down Expand Up @@ -310,6 +312,15 @@ void State::storage_align(int stage_id, const Iterator& it, int factor,
return DoStorageAlignStep(step);
}

Iterator State::tensorize(int stage_id, const Iterator& it,
                          std::string ti_func_name) {
  // Record a tensorize transform step in the history, then apply it
  // immediately so the in-memory loop state stays consistent.
  const Stage& target_stage = operator->()->stages[stage_id];
  const int iter_idx = GetIndex(target_stage->iters, it);
  TensorizeStep new_step =
      TensorizeStepNode::make(stage_id, iter_idx, ti_func_name);
  CopyOnWrite()->transform_steps.push_back(new_step);
  return DoTensorizeStep(new_step);
}

// Steps' implementations
void State::DoReorderStep(const ReorderStep& step) {
const Stage& stage = operator->()->stages[step->stage_id];
Expand Down Expand Up @@ -509,8 +520,10 @@ Iterator State::DoAnnotationStep(const AnnotationStep& step) {
const Stage& stage = operator->()->stages[step->stage_id];
Iterator it = stage->iters[step->iter_id];

CHECK_EQ(it->annotation, IteratorAnnotation::kNone);
Iterator new_it = IteratorNode::make(it->name, it->range, it->iter_type,
step->annotation, &it->ori_iters);
step->annotation, &it->ori_iters,
it->attr);
Stage new_stage = stage;
new_stage.CopyOnWrite()->iters[step->iter_id] = new_it;
StateNode* pstate = CopyOnWrite();
Expand Down Expand Up @@ -538,7 +551,8 @@ void State::DoComputeAtStep(const ComputeAtStep& step) {
new_iters.push_back(it);
} else {
new_iters.push_back(IteratorNode::make(it->name, Range(), it->iter_type,
it->annotation, &it->ori_iters));
it->annotation, &it->ori_iters,
it->attr));
}
}

Expand All @@ -559,7 +573,8 @@ void State::DoComputeRootStep(const ComputeRootStep& step) {
std::vector<Iterator> new_iters;
for (const Iterator& it : stage->iters) {
new_iters.push_back(IteratorNode::make(it->name, Range(), it->iter_type,
it->annotation, &it->ori_iters));
it->annotation, &it->ori_iters,
it->attr));
}

// update attach map
Expand Down Expand Up @@ -747,6 +762,18 @@ void State::DoStorageAlignStep(const StorageAlignStep& step) {
stage->storage_offset = step->offset;
}

Iterator State::DoTensorizeStep(const TensorizeStep& step) {
  // Replace the target iterator with a kTensorized copy; the intrinsic
  // function name travels in the iterator's `attr` field.
  const Stage& old_stage = operator->()->stages[step->stage_id];
  Iterator old_it = old_stage->iters[step->iter_id];
  Iterator tensorized_it = IteratorNode::make(
      old_it->name, old_it->range, old_it->iter_type,
      IteratorAnnotation::kTensorized, &old_it->ori_iters,
      step->ti_func_name);
  Stage updated_stage = old_stage;
  updated_stage.CopyOnWrite()->iters[step->iter_id] = tensorized_it;
  CopyOnWrite()->stages[step->stage_id] = std::move(updated_stage);
  return tensorized_it;
}

void State::DoStep(const Step& step, const ComputeDAG& dag) {
if (auto ps = step.as<ReorderStepNode>()) {
DoReorderStep(GetRef<ReorderStep>(ps));
Expand Down Expand Up @@ -776,6 +803,8 @@ void State::DoStep(const Step& step, const ComputeDAG& dag) {
DoRfactorStep(GetRef<RfactorStep>(ps), dag);
} else if (auto ps = step.as<StorageAlignStepNode>()) {
DoStorageAlignStep(GetRef<StorageAlignStep>(ps));
} else if (auto ps = step.as<TensorizeStepNode>()) {
DoTensorizeStep(GetRef<TensorizeStep>(ps));
} else {
LOG(FATAL) << "Invalid step: " << step;
}
Expand Down Expand Up @@ -854,15 +883,22 @@ void PrintStage(std::ostream* os, int stage_id, const StateNode* state,
case kThreadY:
*os << "gpu.threadIdx.y ";
break;
case kTensorized:
*os << "tensorize ";
break;
default:
LOG(FATAL) << "Invalid Annotation " << iter->annotation; break;
}
if (iter->range.defined()) {
*os << iter->name << " (" << iter->range->min << ","
<< iter->range->extent << ")"
<< "\n";
<< iter->range->extent << ")";
} else {
*os << iter->name << " (None)"
<< "\n";
*os << iter->name << " (None)";
}
if (!iter->attr.empty()) {
*os << " " << iter->attr;
}
*os << "\n";

indent += 2;
}
Expand Down Expand Up @@ -1174,6 +1210,13 @@ TVM_REGISTER_GLOBAL("ansor.StateStorageAlign")
return state;
});

TVM_REGISTER_GLOBAL("ansor.StateTensorize")
.set_body_typed([](State state, int stage_id, const Iterator& it,
                   std::string ti_func) {
  // Returns both the updated state and the new tensorized iterator,
  // mirroring the other in-place transform FFI entry points above.
  Iterator tensorized_iter = state.tensorize(stage_id, it, ti_func);
  return Array<ObjectRef>{state, tensorized_iter};
});

TVM_REGISTER_GLOBAL("ansor.StateEqual")
.set_body_typed([](State state1, State state2) {
return std::equal_to<State>()(state1, state2);
Expand Down
20 changes: 14 additions & 6 deletions src/ansor/loop_state.h
Original file line number Diff line number Diff line change
Expand Up @@ -74,7 +74,8 @@ enum IteratorType {
/*! \brief The type of an iterator's annotation */
enum IteratorAnnotation {
kNone, kUnroll, kVectorize, kParallel,
kVThread, kBlockX, kThreadX, kBlockY, kThreadY
kVThread, kBlockX, kThreadX, kBlockY, kThreadY,
kTensorized
};

class Iterator;
Expand All @@ -90,14 +91,17 @@ class IteratorNode : public Object {
IteratorType iter_type;
IteratorAnnotation annotation;
std::vector<Iterator> ori_iters; // The original iterators before fusion
std::string attr;

static Iterator make(std::string name, Range range,
IteratorType iter_type, IteratorAnnotation annotation,
const std::vector<Iterator>* ori_iters = nullptr);
const std::vector<Iterator>* ori_iters = nullptr,
std::string attr = "");

void VisitAttrs(tvm::AttrVisitor* v) {
v->Visit("name", &name);
v->Visit("range", &range);
v->Visit("attr", &attr);
}

static constexpr const char *_type_key = "ansor.Iterator";
Expand All @@ -115,6 +119,7 @@ class FuseStep; class AnnotationStep;
class ComputeAtStep; class ComputeRootStep; class ComputeInlineStep;
class CacheReadStep; class CacheWriteStep;
class PragmaStep; class RfactorStep; class StorageAlignStep;
class TensorizeStep;

/*!
* \brief A stage in the compute declaration
Expand Down Expand Up @@ -254,19 +259,21 @@ class State : public ObjectRef {
Iterator unroll(int stage_id, const Iterator& it, int max_unroll = -1);
Iterator bind_thread(int stage_id, const Iterator& it,
IteratorAnnotation thread_type);
Iterator tensorize(int stage_id, const Iterator& it,
std::string ti_func_name);
void compute_at(int stage_id, int target_stage_id,
const Iterator& target_iter);
void compute_root(int stage_id);
void compute_inline(int stage_id);
void pragma(int stage_id, const Iterator& it, const std::string& pragma_type);
void storage_align(int stage_id, const Iterator& it, int factor, int offset);
int cache_read(int stage_id, const std::string& scope_name,
const std::vector<int>& reader_stage_ids,
const ComputeDAG& task_dag);
int cache_write(int stage_id, const std::string& scope_name,
const ComputeDAG& task_dag);
void pragma(int stage_id, const Iterator& it, const std::string& pragma_type);
int rfactor(int stage_id, const Iterator& it, int factor_iter_id,
const ComputeDAG& task_dag);
void storage_align(int stage_id, const Iterator& it, int factor, int offset);

/* Do transform steps
* Note: The following functions only change loop state but do not change transform_history.
Expand All @@ -278,14 +285,15 @@ class State : public ObjectRef {
std::vector<Iterator> DoFollowFusedSplitStep(const FollowFusedSplitStep& step);
Iterator DoFuseStep(const FuseStep& step);
Iterator DoAnnotationStep(const AnnotationStep& step);
Iterator DoTensorizeStep(const TensorizeStep& step);
void DoComputeAtStep(const ComputeAtStep& step);
void DoComputeRootStep(const ComputeRootStep& step);
void DoComputeInlineStep(const ComputeInlineStep& step);
void DoPragmaStep(const PragmaStep& step);
void DoStorageAlignStep(const StorageAlignStep& step);
int DoCacheReadStep(const CacheReadStep& step, const ComputeDAG& dag);
int DoCacheWriteStep(const CacheWriteStep& step, const ComputeDAG& dag);
void DoPragmaStep(const PragmaStep& step);
int DoRfactorStep(const RfactorStep& step, const ComputeDAG& dag);
void DoStorageAlignStep(const StorageAlignStep& step);

// General do step functions with a runtime dynamic dispatcher
void DoStep(const Step& step, const ComputeDAG& dag);
Expand Down
20 changes: 18 additions & 2 deletions src/ansor/search_policy/meta_tile_rewrite_policy.cc
Original file line number Diff line number Diff line change
Expand Up @@ -751,6 +751,11 @@ int InitPopulationThreadBind(const MetaTileRewritePolicyNode* policy,
continue;
}

if (HasAnnotationIter(stage, IteratorAnnotation::kThreadX)) {
// Skip if this stage has already done thread bind
continue;
}

std::vector<Iterator> to_fuse;

// This stage has not been tiled, but in GPU schedule, we must tile it
Expand Down Expand Up @@ -861,10 +866,16 @@ int InitPopulationCooperativeFetching(const MetaTileRewritePolicyNode* policy,
!HasCacheWriteStage((*state), stage_id - 1)) ||
(stage_id > 1 && HasCacheReadStage((*state), stage_id - 2) &&
HasCacheWriteStage((*state), stage_id - 2))) {
const Stage& target_stage = (*state)->stages[stage_id];
if (HasAnnotationIter(target_stage, IteratorAnnotation::kThreadX) ||
HasAnnotationIter(target_stage, IteratorAnnotation::kTensorized)) {
// Skip if this stage has already done thread bind or has been
// tensorized
continue;
}
// Get spatial_split_step_ids from the root stage
std::unordered_set<te::Operation, ObjectHash, ObjectEqual> consumers;
std::vector<int> spatial_split_step_ids;
const Stage& target_stage = (*state)->stages[stage_id];
GetConsumers(policy->cur_task_, (*state), target_stage->op, &consumers);
CHECK_EQ(consumers.size(), 1);
int target_stage_id = OperationToStage(*consumers.begin(), (*state));
Expand Down Expand Up @@ -1129,6 +1140,11 @@ int InitPopulationVectorization(const MetaTileRewritePolicyNode* policy,
continue;
}

if (HasAnnotationIter(stage, IteratorAnnotation::kTensorized)) {
// Skip if this stage has been tensorized
continue;
}

// try to fuse and vectorize the space iterators in the inner most tile
int cum_length_prod = 1;

Expand Down Expand Up @@ -1224,7 +1240,7 @@ int InitPopulationUnroll(const MetaTileRewritePolicyNode* policy,

n--;
}
} else if (stage->op->attrs.count(policy->always_unroll_key)) {
} else if (stage->op->attrs.count(policy->always_unroll_key)) {
// Special unroll policy
auto to_unroll_name_set = GetIterNameSetParam(stage->op->attrs,
policy->always_unroll_key);
Expand Down
10 changes: 10 additions & 0 deletions src/ansor/search_policy/utils.h
Original file line number Diff line number Diff line change
Expand Up @@ -143,6 +143,16 @@ inline bool HasReduceIter(const Stage& stage) {
return false;
}

// Return whether the stage has specific annotated iterators
inline bool HasAnnotationIter(const Stage& stage, IteratorAnnotation type) {
for (const auto& iter : stage->iters) {
if (iter->annotation == type) {
return true;
}
}
return false;
}

// Return whether an op needs multi level tiling
inline bool NeedsMultilevelTiling(const SearchTask& task,
const State& state, const te::Operation& op) {
Expand Down
16 changes: 15 additions & 1 deletion src/ansor/serialization.cc
Original file line number Diff line number Diff line change
Expand Up @@ -167,6 +167,11 @@ struct Handler<std::vector<::tvm::ansor::Step> > {
writer->WriteArrayItem(ps->iter_id);
writer->WriteArrayItem(ps->factor);
writer->WriteArrayItem(ps->offset);
} else if (auto ps = data[i].as<::tvm::ansor::TensorizeStepNode>()) {
writer->WriteArrayItem(std::string("TS"));
writer->WriteArrayItem(ps->stage_id);
writer->WriteArrayItem(ps->iter_id);
writer->WriteArrayItem(ps->ti_func_name);
} else {
LOG(FATAL) << "Invalid step: " << data[i];
}
Expand All @@ -179,7 +184,7 @@ struct Handler<std::vector<::tvm::ansor::Step> > {
std::vector<::tvm::ansor::Step> * data) {
std::vector<int> int_list;
bool s, inner_to_outer, factor_or_nparts;
std::string name, scope_name, pragma_type;
std::string name, scope_name, pragma_type, ti_func_name;
int stage_id, target_stage_id, iter_id, src_step_id, n_split, ann, extent;
int level, factor_iter_id, factor, offset;

Expand Down Expand Up @@ -311,6 +316,15 @@ struct Handler<std::vector<::tvm::ansor::Step> > {
reader->Read(&offset);
data->push_back(::tvm::ansor::StorageAlignStepNode::make(
stage_id, iter_id, factor, offset));
} else if (name == "TS") {
s = reader->NextArrayItem(); CHECK(s);
reader->Read(&stage_id);
s = reader->NextArrayItem(); CHECK(s);
reader->Read(&iter_id);
s = reader->NextArrayItem(); CHECK(s);
reader->Read(&ti_func_name);
data->push_back(::tvm::ansor::TensorizeStepNode::make(
stage_id, iter_id, ti_func_name));
} else {
LOG(FATAL) << "Invalid step format";
}
Expand Down
Loading

0 comments on commit 6126cdb

Please sign in to comment.