Commit 72969b2

[Ansor] Support multiple output ops and fix Python API printing (#6584)

comaniac authored Oct 2, 2020
1 parent f9abf56 commit 72969b2
Showing 4 changed files with 81 additions and 54 deletions.
src/auto_scheduler/compute_dag.cc: 32 changes (17 additions, 15 deletions)
@@ -969,6 +969,9 @@ void ComputeDAG::RewriteLayout(const Array<Step>& transform_steps) {
       }
     }  // end for placeholder
   }  // end for stage
+  p_dag->access_analyzer = AccessAnalyzer(p_dag->tensors);
+  p_dag->ops = p_dag->access_analyzer->ops_topo_order;
+  p_dag->flop_ct = FlopEstimator().EstimateFlop(p_dag->ops);
 }

 std::pair<te::Schedule, Array<te::Tensor>> ComputeDAG::ApplySteps(
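The three added lines make RewriteLayout refresh the DAG's derived state (access analyzer, topological op order, FLOP estimate) after the layout rewrite replaces the tensors, instead of leaving it stale. A hedged Python-side sketch of the same invariant, with hypothetical placeholder/compute names, and assuming the flop_ct field is surfaced on the Python handle as registered object fields normally are: constructing a ComputeDAG from the current tensors recomputes all three.

    from tvm import te, auto_scheduler

    A = te.placeholder((128, 128), name="A")
    B = te.compute((128, 128), lambda i, j: A[i, j] * 2.0, name="B")
    # Building a ComputeDAG runs the access analyzer, orders the ops
    # topologically, and re-estimates the FLOP count, the same three
    # fields the patched RewriteLayout now refreshes on the C++ side.
    dag = auto_scheduler.ComputeDAG([A, B])
    print(dag.flop_ct)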
@@ -989,16 +992,15 @@ std::pair<te::Schedule, Array<te::Tensor>> ComputeDAG::ApplySteps(
   if (stage_to_axes == nullptr) {
     stage_to_axes = &temp_stage_to_axes;
   }
-  Array<te::Operation> ops;
+  Array<te::Operation> out_ops;
   for (const auto& op : operator->()->ops) {
-    if (!op->IsInstance<te::PlaceholderOpNode>()) {
-      ops.push_back(op);
+    if (operator->()->access_analyzer.IsOutput(op)) {
+      out_ops.push_back(op);
     }
   }

   // Create the initial schedule
-  // TODO(jcf94): Currently we only checked single output dag for TVM Auto-scheduler,
-  // update this after testing with multiple outputs.
-  te::Schedule schedule = te::create_schedule({ops.back()});
+  te::Schedule schedule = te::create_schedule(out_ops);

   // init axes
   for (const auto& x : operator->()->ops) {
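This change resolves the TODO the hunk deletes: rather than assuming ops.back() is the only output, the code collects every op the access analyzer marks as an output and hands the whole array to te::create_schedule, which accepts a list of output ops. A minimal Python sketch of the multi-output case this enables (A, B, C are hypothetical ops):

    from tvm import te

    A = te.placeholder((16,), name="A")
    B = te.compute((16,), lambda i: A[i] + 1.0, name="B")  # first output
    C = te.compute((16,), lambda i: A[i] * 2.0, name="C")  # second output
    # create_schedule takes a list of output ops, so a DAG with several
    # outputs no longer has to be squeezed through a single trailing op.
    s = te.create_schedule([B.op, C.op])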
@@ -1019,16 +1021,14 @@ std::pair<te::Schedule, Array<te::Tensor>> ComputeDAG::ApplySteps(
 String ComputeDAG::PrintStepsAsPython(const Array<Step>& transform_steps) const {
   Array<te::Stage> stages;
   StageToAxesMap stage_to_axes;
-  Array<te::Operation> ops;
+  Array<te::Operation> out_ops;
   for (const auto& op : operator->()->ops) {
-    if (!op->IsInstance<te::PlaceholderOpNode>()) {
-      ops.push_back(op);
+    if (operator->()->access_analyzer.IsOutput(op)) {
+      out_ops.push_back(op);
     }
   }
   // Create the initial schedule
-  // TODO(jcf94): Currently we only checked single output dag for TVM Auto-scheduler,
-  // update this after testing with multiple outputs.
-  te::Schedule schedule = te::create_schedule({ops.back()});
+  te::Schedule schedule = te::create_schedule(out_ops);

   // init axes
   for (const auto& x : operator->()->ops) {
@@ -1040,16 +1040,18 @@ String ComputeDAG::PrintStepsAsPython(const Array<Step>& transform_steps) const
   std::stringstream ss;
   for (const auto& stage : stages) {
     if (stage->op->IsInstance<te::ComputeOpNode>()) {
+      auto op_name = CleanName(stage->op->name);
+
       for (size_t i = 0; i < stage->leaf_iter_vars.size(); ++i) {
-        ss << stage->leaf_iter_vars[i]->var->name_hint;
+        ss << CleanName(stage->leaf_iter_vars[i]->var->name_hint, op_name);
         if (i != stage->leaf_iter_vars.size() - 1) {
           ss << ", ";
         }
       }
       ss << " = "
-         << "tuple(" << stage->op->name << ".op.axis)"
+         << "tuple(" << op_name << ".op.axis)"
          << " + "
-         << "tuple(" << stage->op->name << ".op.reduce_axis)\n";
+         << "tuple(" << op_name << ".op.reduce_axis)\n";
     }
   }
   // Call each step's PrintAsPythonAPI method
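The printing half of the fix: iterator names used to be emitted bare, so two stages that both printed their axes produced colliding Python variables in the generated program. Prefixing each iterator with the cleaned op name keeps every binding unique. A hypothetical excerpt of what PrintStepsAsPython emits after this change, for a matmul stage named C with axes i, j and reduction k:

    # Before the patch this line bound bare i, j, k, which the next stage
    # printing its own axes would silently shadow.
    C_i, C_j, C_k = tuple(C.op.axis) + tuple(C.op.reduce_axis)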
src/auto_scheduler/search_policy/sketch_policy_rules.h: 8 changes (8 additions, 0 deletions)
@@ -28,6 +28,7 @@
 #include <tvm/auto_scheduler/loop_state.h>
 #include <tvm/auto_scheduler/search_task.h>

+#include <string>
 #include <utility>
 #include <vector>

@@ -74,6 +75,12 @@ class SketchGenerationRule {
    */
   virtual std::vector<std::pair<State, int>> Apply(const SketchPolicyNode& policy,
                                                    const State& state, int stage_id) const = 0;
+
+  /*!
+   * \brief Get the name of this rule.
+   * \return A string of the rule name.
+   */
+  virtual std::string GetRuleName() const = 0;
 };

 #define DEFINE_SKETCH_GENERATION_RULE(rule_name)                                                 \
@@ -83,6 +90,7 @@ class SketchGenerationRule {
                                             int stage_id) const final;                           \
   std::vector<std::pair<State, int>> Apply(const SketchPolicyNode& policy, const State& state,   \
                                            int stage_id) const final;                            \
+  std::string GetRuleName() const final { return #rule_name; }                                   \
 };

 /*! \brief The rule that simply skips the current stage. It returns an unchanged state and move to
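GetRuleName is stamped out by the macro with the preprocessor stringizing operator (#rule_name), so every concrete rule reports its own class name, which is handy for logging which rule generated a sketch. A rough Python analogue of the pattern, not TVM code:

    class SketchGenerationRule:
        def get_rule_name(self) -> str:
            # Python classes carry their name at runtime, so no macro is needed.
            return type(self).__name__

    class RuleSkipStage(SketchGenerationRule):
        pass

    assert RuleSkipStage().get_rule_name() == "RuleSkipStage"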
src/auto_scheduler/transform_step.cc: 87 changes (50 additions, 37 deletions)
@@ -356,8 +356,9 @@ String AnnotationStepNode::PrintAsPythonAPI(Array<te::Stage>* stages,
   std::stringstream ss;
   const auto& stage = (*stages)[stage_id];
   const auto& iter = (*stage_to_axes)[stage][iter_id];
+  const auto& op_name = CleanName(stage->op->name);

-  ss << "s[" << CleanName(stage->op->name) << "].";
+  ss << "s[" << op_name << "].";
   switch (annotation) {
     case IteratorAnnotation::kUnroll:
       ss << "unroll(";
@@ -383,7 +384,7 @@ String AnnotationStepNode::PrintAsPythonAPI(Array<te::Stage>* stages,
       LOG(FATAL) << "Invalid annotation " << static_cast<int>(annotation);
       break;
   }
-  ss << CleanName(iter->var->name_hint);
+  ss << CleanName(iter->var->name_hint, op_name);
   switch (annotation) {
     case IteratorAnnotation::kVThread:
     case IteratorAnnotation::kBlockX:
@@ -392,7 +393,7 @@ String AnnotationStepNode::PrintAsPythonAPI(Array<te::Stage>* stages,
     case IteratorAnnotation::kThreadX:
     case IteratorAnnotation::kThreadY:
     case IteratorAnnotation::kThreadZ:
-      ss << ", tvm.thread_axis(\"" << IteratorAnnotationString[static_cast<int>(annotation)]
+      ss << ", te.thread_axis(\"" << IteratorAnnotationString[static_cast<int>(annotation)]
         << "\")";
       break;
     default:
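This hunk also fixes the emitted call itself: tvm.thread_axis predates TVM's te namespace split and is gone from the top-level namespace, so generated programs must call te.thread_axis. A runnable sketch of the kind of binding the printer now produces (the op and split factor are hypothetical):

    from tvm import te

    A = te.placeholder((1024,), name="A")
    B = te.compute((1024,), lambda i: A[i] + 1.0, name="B")
    s = te.create_schedule(B.op)
    bx, tx = s[B].split(B.op.axis[0], factor=64)
    # The printer now emits te.thread_axis(...), not tvm.thread_axis(...).
    s[B].bind(bx, te.thread_axis("blockIdx.x"))
    s[B].bind(tx, te.thread_axis("threadIdx.x"))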
@@ -541,10 +542,11 @@ IterVar FuseStepNode::ApplyToSchedule(Array<te::Stage>* stages,
 String FuseStepNode::PrintAsPythonAPI(Array<te::Stage>* stages,
                                       StageToAxesMap* stage_to_axes) const {
   const auto& stage = (*stages)[stage_id];
+  const auto& op_name = CleanName(stage->op->name);
   std::stringstream to_fuse;

   for (size_t i = 0; i < fused_ids.size(); ++i) {
-    to_fuse << CleanName(stage_to_axes->at(stage)[fused_ids[i]]->var->name_hint);
+    to_fuse << CleanName(stage_to_axes->at(stage)[fused_ids[i]]->var->name_hint, op_name);
     if (i != fused_ids.size() - 1) {
       to_fuse << ", ";
     }
@@ -553,7 +555,7 @@ String FuseStepNode::PrintAsPythonAPI(Array<te::Stage>* stages,
   std::stringstream ss;
   const auto& fused = ApplyToSchedule(stages, stage_to_axes);

-  ss << CleanName(fused->var->name_hint) << " = s[" << CleanName(stage->op->name) << "].fuse("
+  ss << CleanName(fused->var->name_hint, op_name) << " = s[" << op_name << "].fuse("
      << to_fuse.str() << ")\n";

   return ss.str();
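With the prefix applied to both the fused result and its inputs, a fuse step now prints along these lines (hypothetical stage C with axes i and j; the fused name hint "i.j.fused" cleans to "i_j_fused"):

    C_i_j_fused = s[C].fuse(C_i, C_j)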
@@ -640,6 +642,7 @@ String PragmaStepNode::PrintAsPythonAPI(Array<te::Stage>* stages,
                                         StageToAxesMap* stage_to_axes) const {
   std::stringstream ss;
   const auto& stage = (*stages)[stage_id];
+  const auto& op_name = CleanName(stage->op->name);

   if (StrStartsWith(pragma_type, "auto_unroll_max_step")) {
     size_t pos = 0;
@@ -650,16 +653,16 @@ String PragmaStepNode::PrintAsPythonAPI(Array<te::Stage>* stages,
     }
     CHECK_LT(pos, pragma_type.size()) << "max step value not found.";
     int value = atoi(pragma_type.c_str() + pos + 1);
-    ss << "s[" << CleanName(stage->op->name) << "].pragma("
-       << CleanName((*stage_to_axes)[stage][iter_id]->var->name_hint)
+    ss << "s[" << op_name << "].pragma("
+       << CleanName((*stage_to_axes)[stage][iter_id]->var->name_hint, op_name)
        << ", \"auto_unroll_max_step\", " << value << ")\n";
-    ss << "s[" << CleanName(stage->op->name) << "].pragma("
-       << CleanName((*stage_to_axes)[stage][iter_id]->var->name_hint)
+    ss << "s[" << op_name << "].pragma("
+       << CleanName((*stage_to_axes)[stage][iter_id]->var->name_hint, op_name)
        << ", \"unroll_explicit\", True)\n";
   } else {
-    ss << "s[" << CleanName(stage->op->name) << "].pragma("
-       << CleanName((*stage_to_axes)[stage][iter_id]->var->name_hint) << ", \"" << pragma_type
-       << "\")\n";
+    ss << "s[" << op_name << "].pragma("
+       << CleanName((*stage_to_axes)[stage][iter_id]->var->name_hint, op_name) << ", \""
+       << pragma_type << "\")\n";
   }

   ApplyToSchedule(stages, stage_to_axes);
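For reference, the lines this printer generates for an auto_unroll_max_step pragma look roughly like the following (hypothetical stage C, iterator C_i, extracted value 16):

    s[C].pragma(C_i, "auto_unroll_max_step", 16)
    s[C].pragma(C_i, "unroll_explicit", True)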
@@ -726,11 +729,12 @@ void ReorderStepNode::ApplyToSchedule(Array<te::Stage>* stages,
 String ReorderStepNode::PrintAsPythonAPI(Array<te::Stage>* stages,
                                          StageToAxesMap* stage_to_axes) const {
   const auto& stage = (*stages)[stage_id];
+  const auto& op_name = CleanName(stage->op->name);
   std::stringstream ss;

-  ss << "s[" << CleanName(stage->op->name) << "].reorder(";
+  ss << "s[" << op_name << "].reorder(";
   for (size_t i = 0; i < after_ids.size(); ++i) {
-    ss << CleanName((*stage_to_axes)[stage][after_ids[i]]->var->name_hint);
+    ss << CleanName((*stage_to_axes)[stage][after_ids[i]]->var->name_hint, op_name);
     if (i != after_ids.size() - 1) {
       ss << ", ";
     }
@@ -881,16 +885,17 @@ String PrintSplitAsPythonAPI(Array<te::Stage>* stages, StageToAxesMap* stage_to_
   int size = static_cast<int>(lengths.size());
   if (inner_to_outer) {
     for (int i = size - 1; i >= 0; i--) {
-      ss << CleanName(outs[size - i]->var->name_hint) << ", "
-         << CleanName(outs[size - i - 1]->var->name_hint) << " = s[" << func_name << "].split("
-         << CleanName(to_split->var->name_hint) << ", factor=" << lengths[i] << ")\n";
+      ss << CleanName(outs[size - i]->var->name_hint, func_name) << ", "
+         << CleanName(outs[size - i - 1]->var->name_hint, func_name) << " = s[" << func_name
+         << "].split(" << CleanName(to_split->var->name_hint, func_name)
+         << ", factor=" << lengths[i] << ")\n";
       to_split = outs[size - i];
     }
   } else {
     for (int i = 0; i < size; i++) {
-      ss << CleanName(outs[i]->var->name_hint) << ", " << CleanName(outs[i + 1]->var->name_hint)
-         << " = s[" << func_name << "].split(" << CleanName(to_split->var->name_hint)
-         << ", nparts=" << lengths[i] << ")\n";
+      ss << CleanName(outs[i]->var->name_hint, func_name) << ", "
+         << CleanName(outs[i + 1]->var->name_hint, func_name) << " = s[" << func_name << "].split("
+         << CleanName(to_split->var->name_hint, func_name) << ", nparts=" << lengths[i] << ")\n";
       to_split = outs[i + 1];
     }
   }
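Split results pick up the same prefix, and because CleanName shortens "outer" and "inner" to "o" and "i", a factor split on a hypothetical stage C prints as:

    C_i_o, C_i_i = s[C].split(C_i, factor=32)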
@@ -1195,9 +1200,10 @@ String StorageAlignStepNode::PrintAsPythonAPI(Array<te::Stage>* stages,
                                               StageToAxesMap* stage_to_axes) const {
   std::stringstream ss;
   const auto& stage = (*stages)[stage_id];
-  ss << "s[" << CleanName(stage->op->name) << "].storage_align("
-     << CleanName((*stage_to_axes)[stage][iter_id]->var->name_hint) << ", " << factor << ", "
-     << offset << ")\n";
+  const auto& op_name = CleanName(stage->op->name);
+  ss << "s[" << op_name << "].storage_align("
+     << CleanName((*stage_to_axes)[stage][iter_id]->var->name_hint, op_name) << ", " << factor
+     << ", " << offset << ")\n";

   ApplyToSchedule(stages, stage_to_axes);
   return ss.str();
@@ -1269,8 +1275,11 @@ String ComputeAtStepNode::PrintAsPythonAPI(Array<te::Stage>* stages,
   std::stringstream ss;
   const auto& stage = (*stages)[stage_id];
   const auto& target_stage = (*stages)[target_stage_id];
-  ss << "s[" << CleanName(stage->op->name) << "].compute_at(s[" << CleanName(target_stage->op->name)
-     << "], " << CleanName((*stage_to_axes)[target_stage][target_iter_id]->var->name_hint) << ")\n";
+  const auto& op_name = CleanName(stage->op->name);
+  const auto& target_op_name = CleanName(target_stage->op->name);
+  ss << "s[" << op_name << "].compute_at(s[" << target_op_name << "], "
+     << CleanName((*stage_to_axes)[target_stage][target_iter_id]->var->name_hint, target_op_name)
+     << ")\n";
   ApplyToSchedule(stages, stage_to_axes);
   return ss.str();
 }
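Note the asymmetry the new code encodes: the moved stage is prefixed with its own name, while the anchor iterator belongs to the target stage and therefore takes target_op_name. A hypothetical generated line, with B computed at an axis of C:

    s[B].compute_at(s[C], C_i)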
@@ -1516,7 +1525,8 @@ String CacheReadStepNode::PrintAsPythonAPI(Array<te::Stage>* stages, StageToAxes
   }
   auto out = ApplyToSchedule(stages, stage_to_axes, schedule);

-  ss << CleanName(out->op->name) << " = "
+  const auto& op_name = CleanName(out->op->name);
+  ss << op_name << " = "
      << "s.cache_read(" << CleanName(stage->op->name) << ", \"" << scope_name << "\", ["
      << CleanName(reader_stages[0]->op->name);
   for (size_t i = 1; i < reader_stage_ids.size(); ++i) {
@@ -1527,13 +1537,13 @@ String CacheReadStepNode::PrintAsPythonAPI(Array<te::Stage>* stages, StageToAxes
   // Print the iterators of the new added stage
   const auto& iters = out->op->root_iter_vars();
   for (size_t i = 0; i < iters.size(); ++i) {
-    ss << CleanName(iters[i]->var->name_hint);
+    ss << CleanName(iters[i]->var->name_hint, op_name);
     if (i != iters.size() - 1) {
       ss << ", ";
     }
   }
   ss << " = "
-     << "tuple(" << CleanName(out->op->name) << ".op.axis)\n";
+     << "tuple(" << op_name << ".op.axis)\n";

   return ss.str();
 }
@@ -1652,16 +1662,17 @@ String CacheWriteStepNode::PrintAsPythonAPI(Array<te::Stage>* stages, StageToAxe
   // Print the iterators of the new added stage
   for (const auto& out : outs) {
     const auto& iters = out->op->root_iter_vars();
+    const auto& op_name = CleanName(out->op->name);
     for (size_t i = 0; i < iters.size(); ++i) {
-      ss << CleanName(iters[i]->var->name_hint);
+      ss << CleanName(iters[i]->var->name_hint, op_name);
       if (i != iters.size() - 1) {
         ss << ", ";
       }
     }
     ss << " = "
-       << "tuple(" << CleanName(out->op->name) << ".op.axis)"
+       << "tuple(" << op_name << ".op.axis)"
        << " + "
-       << "tuple(" << CleanName(out->op->name) << ".op.reduce_axis)\n";
+       << "tuple(" << op_name << ".op.reduce_axis)\n";
   }

   return ss.str();
@@ -1764,30 +1775,32 @@ String RfactorStepNode::PrintAsPythonAPI(Array<te::Stage>* stages, StageToAxesMa

   for (const auto& out : outs) {
     const auto& iters = out->op->root_iter_vars();
+    const auto& op_name = CleanName(out->op->name);
     for (size_t i = 0; i < iters.size(); ++i) {
-      ss << CleanName(iters[i]->var->name_hint);
+      ss << CleanName(iters[i]->var->name_hint, op_name);
       if (i != iters.size() - 1) {
         ss << ", ";
       }
     }
     ss << " = "
-       << "tuple(" << CleanName(out->op->name) << ".op.axis)"
+       << "tuple(" << op_name << ".op.axis)"
        << " + "
-       << "tuple(" << CleanName(out->op->name) << ".op.reduce_axis)\n";
+       << "tuple(" << op_name << ".op.reduce_axis)\n";
   }

   const auto& output = (*stages)[stage_id + 1]->op.output(0);
   const auto& iters = output->op->root_iter_vars();
+  const auto& op_name = CleanName(output->op->name);
   for (size_t i = 0; i < iters.size(); ++i) {
-    ss << CleanName(iters[i]->var->name_hint);
+    ss << CleanName(iters[i]->var->name_hint, op_name);
     if (i != iters.size() - 1) {
       ss << ", ";
     }
   }
   ss << " = "
-     << "tuple(s[" << CleanName(output->op->name) << "].op.axis)"
+     << "tuple(s[" << op_name << "].op.axis)"
      << " + "
-     << "tuple(s[" << CleanName(output->op->name) << "].op.reduce_axis)\n";
+     << "tuple(s[" << op_name << "].op.reduce_axis)\n";

   return ss.str();
 }
src/auto_scheduler/utils.h: 8 changes (6 additions, 2 deletions)
@@ -209,16 +209,20 @@ inline int64_t AxisLengthProd(const Array<tir::IterVar>& axes) {
 }

 /*!
- * \brief Clean the name of an iterator to make it valid in python code.
+ * \brief Clean the name of an iterator or an op to make it valid in python code.
  * \param str The original name.
+ * \param prefix The name prefix to differentiate the same name (e.g., the same iterator names).
  * \return The cleaned name.
  */
-inline std::string CleanName(const std::string& str) {
+inline std::string CleanName(const std::string& str, const std::string& prefix = "") {
   std::string ret = str;
   StrReplace(&ret, ".", "_");
   StrReplace(&ret, "@", "_");
   StrReplace(&ret, "outer", "o");
   StrReplace(&ret, "inner", "i");
+  if (prefix != "") {
+    return prefix + "_" + ret;
+  }
   return ret;
 }
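For illustration, a Python re-implementation sketch of the updated helper (not part of TVM; behavior mirrored from the C++ above):

    def clean_name(name: str, prefix: str = "") -> str:
        # Mirror of CleanName in src/auto_scheduler/utils.h.
        ret = name.replace(".", "_").replace("@", "_")
        ret = ret.replace("outer", "o").replace("inner", "i")
        return prefix + "_" + ret if prefix else ret

    assert clean_name("i.outer") == "i_o"
    assert clean_name("i.inner", "C") == "C_i_i"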
