diff --git a/src/auto_scheduler/compute_dag.cc b/src/auto_scheduler/compute_dag.cc index 7c9ce4cb368d..23b38173c8c2 100755 --- a/src/auto_scheduler/compute_dag.cc +++ b/src/auto_scheduler/compute_dag.cc @@ -969,6 +969,9 @@ void ComputeDAG::RewriteLayout(const Array& transform_steps) { } } // end for placeholder } // end for stage + p_dag->access_analyzer = AccessAnalyzer(p_dag->tensors); + p_dag->ops = p_dag->access_analyzer->ops_topo_order; + p_dag->flop_ct = FlopEstimator().EstimateFlop(p_dag->ops); } std::pair> ComputeDAG::ApplySteps( @@ -989,16 +992,15 @@ std::pair> ComputeDAG::ApplySteps( if (stage_to_axes == nullptr) { stage_to_axes = &temp_stage_to_axes; } - Array ops; + Array out_ops; for (const auto& op : operator->()->ops) { - if (!op->IsInstance()) { - ops.push_back(op); + if (operator->()->access_analyzer.IsOutput(op)) { + out_ops.push_back(op); } } + // Create the initial schedule - // TODO(jcf94): Currently we only checked single output dag for TVM Auto-scheduler, - // update this after testing with multiple outputs. - te::Schedule schedule = te::create_schedule({ops.back()}); + te::Schedule schedule = te::create_schedule(out_ops); // init axes for (const auto& x : operator->()->ops) { @@ -1019,16 +1021,14 @@ std::pair> ComputeDAG::ApplySteps( String ComputeDAG::PrintStepsAsPython(const Array& transform_steps) const { Array stages; StageToAxesMap stage_to_axes; - Array ops; + Array out_ops; for (const auto& op : operator->()->ops) { - if (!op->IsInstance()) { - ops.push_back(op); + if (operator->()->access_analyzer.IsOutput(op)) { + out_ops.push_back(op); } } // Create the initial schedule - // TODO(jcf94): Currently we only checked single output dag for TVM Auto-scheduler, - // update this after testing with multiple outputs. - te::Schedule schedule = te::create_schedule({ops.back()}); + te::Schedule schedule = te::create_schedule(out_ops); // init axes for (const auto& x : operator->()->ops) { @@ -1040,16 +1040,18 @@ String ComputeDAG::PrintStepsAsPython(const Array& transform_steps) const std::stringstream ss; for (const auto& stage : stages) { if (stage->op->IsInstance()) { + auto op_name = CleanName(stage->op->name); + for (size_t i = 0; i < stage->leaf_iter_vars.size(); ++i) { - ss << stage->leaf_iter_vars[i]->var->name_hint; + ss << CleanName(stage->leaf_iter_vars[i]->var->name_hint, op_name); if (i != stage->leaf_iter_vars.size() - 1) { ss << ", "; } } ss << " = " - << "tuple(" << stage->op->name << ".op.axis)" + << "tuple(" << op_name << ".op.axis)" << " + " - << "tuple(" << stage->op->name << ".op.reduce_axis)\n"; + << "tuple(" << op_name << ".op.reduce_axis)\n"; } } // Call each step's PrintAsPythonAPI method diff --git a/src/auto_scheduler/search_policy/sketch_policy_rules.h b/src/auto_scheduler/search_policy/sketch_policy_rules.h index 928efc518827..035dc897d3da 100644 --- a/src/auto_scheduler/search_policy/sketch_policy_rules.h +++ b/src/auto_scheduler/search_policy/sketch_policy_rules.h @@ -28,6 +28,7 @@ #include #include +#include #include #include @@ -74,6 +75,12 @@ class SketchGenerationRule { */ virtual std::vector> Apply(const SketchPolicyNode& policy, const State& state, int stage_id) const = 0; + + /*! + * \brief Get the name of this rule. + * \return A string of the rule name. + */ + virtual std::string GetRuleName() const = 0; }; #define DEFINE_SKETCH_GENERATION_RULE(rule_name) \ @@ -83,6 +90,7 @@ class SketchGenerationRule { int stage_id) const final; \ std::vector> Apply(const SketchPolicyNode& policy, const State& state, \ int stage_id) const final; \ + std::string GetRuleName() const final { return #rule_name; } \ }; /*! \brief The rule that simply skips the current stage. It returns an unchanged state and move to diff --git a/src/auto_scheduler/transform_step.cc b/src/auto_scheduler/transform_step.cc index 2a9349739752..73f673421378 100755 --- a/src/auto_scheduler/transform_step.cc +++ b/src/auto_scheduler/transform_step.cc @@ -356,8 +356,9 @@ String AnnotationStepNode::PrintAsPythonAPI(Array* stages, std::stringstream ss; const auto& stage = (*stages)[stage_id]; const auto& iter = (*stage_to_axes)[stage][iter_id]; + const auto& op_name = CleanName(stage->op->name); - ss << "s[" << CleanName(stage->op->name) << "]."; + ss << "s[" << op_name << "]."; switch (annotation) { case IteratorAnnotation::kUnroll: ss << "unroll("; @@ -383,7 +384,7 @@ String AnnotationStepNode::PrintAsPythonAPI(Array* stages, LOG(FATAL) << "Invalid annotation " << static_cast(annotation); break; } - ss << CleanName(iter->var->name_hint); + ss << CleanName(iter->var->name_hint, op_name); switch (annotation) { case IteratorAnnotation::kVThread: case IteratorAnnotation::kBlockX: @@ -392,7 +393,7 @@ String AnnotationStepNode::PrintAsPythonAPI(Array* stages, case IteratorAnnotation::kThreadX: case IteratorAnnotation::kThreadY: case IteratorAnnotation::kThreadZ: - ss << ", tvm.thread_axis(\"" << IteratorAnnotationString[static_cast(annotation)] + ss << ", te.thread_axis(\"" << IteratorAnnotationString[static_cast(annotation)] << "\")"; break; default: @@ -541,10 +542,11 @@ IterVar FuseStepNode::ApplyToSchedule(Array* stages, String FuseStepNode::PrintAsPythonAPI(Array* stages, StageToAxesMap* stage_to_axes) const { const auto& stage = (*stages)[stage_id]; + const auto& op_name = CleanName(stage->op->name); std::stringstream to_fuse; for (size_t i = 0; i < fused_ids.size(); ++i) { - to_fuse << CleanName(stage_to_axes->at(stage)[fused_ids[i]]->var->name_hint); + to_fuse << CleanName(stage_to_axes->at(stage)[fused_ids[i]]->var->name_hint, op_name); if (i != fused_ids.size() - 1) { to_fuse << ", "; } @@ -553,7 +555,7 @@ String FuseStepNode::PrintAsPythonAPI(Array* stages, std::stringstream ss; const auto& fused = ApplyToSchedule(stages, stage_to_axes); - ss << CleanName(fused->var->name_hint) << " = s[" << CleanName(stage->op->name) << "].fuse(" + ss << CleanName(fused->var->name_hint, op_name) << " = s[" << op_name << "].fuse(" << to_fuse.str() << ")\n"; return ss.str(); @@ -640,6 +642,7 @@ String PragmaStepNode::PrintAsPythonAPI(Array* stages, StageToAxesMap* stage_to_axes) const { std::stringstream ss; const auto& stage = (*stages)[stage_id]; + const auto& op_name = CleanName(stage->op->name); if (StrStartsWith(pragma_type, "auto_unroll_max_step")) { size_t pos = 0; @@ -650,16 +653,16 @@ String PragmaStepNode::PrintAsPythonAPI(Array* stages, } CHECK_LT(pos, pragma_type.size()) << "max step value not found."; int value = atoi(pragma_type.c_str() + pos + 1); - ss << "s[" << CleanName(stage->op->name) << "].pragma(" - << CleanName((*stage_to_axes)[stage][iter_id]->var->name_hint) + ss << "s[" << op_name << "].pragma(" + << CleanName((*stage_to_axes)[stage][iter_id]->var->name_hint, op_name) << ", \"auto_unroll_max_step\", " << value << ")\n"; - ss << "s[" << CleanName(stage->op->name) << "].pragma(" - << CleanName((*stage_to_axes)[stage][iter_id]->var->name_hint) + ss << "s[" << op_name << "].pragma(" + << CleanName((*stage_to_axes)[stage][iter_id]->var->name_hint, op_name) << ", \"unroll_explicit\", True)\n"; } else { - ss << "s[" << CleanName(stage->op->name) << "].pragma(" - << CleanName((*stage_to_axes)[stage][iter_id]->var->name_hint) << ", \"" << pragma_type - << "\")\n"; + ss << "s[" << op_name << "].pragma(" + << CleanName((*stage_to_axes)[stage][iter_id]->var->name_hint, op_name) << ", \"" + << pragma_type << "\")\n"; } ApplyToSchedule(stages, stage_to_axes); @@ -726,11 +729,12 @@ void ReorderStepNode::ApplyToSchedule(Array* stages, String ReorderStepNode::PrintAsPythonAPI(Array* stages, StageToAxesMap* stage_to_axes) const { const auto& stage = (*stages)[stage_id]; + const auto& op_name = CleanName(stage->op->name); std::stringstream ss; - ss << "s[" << CleanName(stage->op->name) << "].reorder("; + ss << "s[" << op_name << "].reorder("; for (size_t i = 0; i < after_ids.size(); ++i) { - ss << CleanName((*stage_to_axes)[stage][after_ids[i]]->var->name_hint); + ss << CleanName((*stage_to_axes)[stage][after_ids[i]]->var->name_hint, op_name); if (i != after_ids.size() - 1) { ss << ", "; } @@ -881,16 +885,17 @@ String PrintSplitAsPythonAPI(Array* stages, StageToAxesMap* stage_to_ int size = static_cast(lengths.size()); if (inner_to_outer) { for (int i = size - 1; i >= 0; i--) { - ss << CleanName(outs[size - i]->var->name_hint) << ", " - << CleanName(outs[size - i - 1]->var->name_hint) << " = s[" << func_name << "].split(" - << CleanName(to_split->var->name_hint) << ", factor=" << lengths[i] << ")\n"; + ss << CleanName(outs[size - i]->var->name_hint, func_name) << ", " + << CleanName(outs[size - i - 1]->var->name_hint, func_name) << " = s[" << func_name + << "].split(" << CleanName(to_split->var->name_hint, func_name) + << ", factor=" << lengths[i] << ")\n"; to_split = outs[size - i]; } } else { for (int i = 0; i < size; i++) { - ss << CleanName(outs[i]->var->name_hint) << ", " << CleanName(outs[i + 1]->var->name_hint) - << " = s[" << func_name << "].split(" << CleanName(to_split->var->name_hint) - << ", nparts=" << lengths[i] << ")\n"; + ss << CleanName(outs[i]->var->name_hint, func_name) << ", " + << CleanName(outs[i + 1]->var->name_hint, func_name) << " = s[" << func_name << "].split(" + << CleanName(to_split->var->name_hint, func_name) << ", nparts=" << lengths[i] << ")\n"; to_split = outs[i + 1]; } } @@ -1195,9 +1200,10 @@ String StorageAlignStepNode::PrintAsPythonAPI(Array* stages, StageToAxesMap* stage_to_axes) const { std::stringstream ss; const auto& stage = (*stages)[stage_id]; - ss << "s[" << CleanName(stage->op->name) << "].storage_align(" - << CleanName((*stage_to_axes)[stage][iter_id]->var->name_hint) << ", " << factor << ", " - << offset << ")\n"; + const auto& op_name = CleanName(stage->op->name); + ss << "s[" << op_name << "].storage_align(" + << CleanName((*stage_to_axes)[stage][iter_id]->var->name_hint, op_name) << ", " << factor + << ", " << offset << ")\n"; ApplyToSchedule(stages, stage_to_axes); return ss.str(); @@ -1269,8 +1275,11 @@ String ComputeAtStepNode::PrintAsPythonAPI(Array* stages, std::stringstream ss; const auto& stage = (*stages)[stage_id]; const auto& target_stage = (*stages)[target_stage_id]; - ss << "s[" << CleanName(stage->op->name) << "].compute_at(s[" << CleanName(target_stage->op->name) - << "], " << CleanName((*stage_to_axes)[target_stage][target_iter_id]->var->name_hint) << ")\n"; + const auto& op_name = CleanName(stage->op->name); + const auto& target_op_name = CleanName(target_stage->op->name); + ss << "s[" << op_name << "].compute_at(s[" << target_op_name << "], " + << CleanName((*stage_to_axes)[target_stage][target_iter_id]->var->name_hint, target_op_name) + << ")\n"; ApplyToSchedule(stages, stage_to_axes); return ss.str(); } @@ -1516,7 +1525,8 @@ String CacheReadStepNode::PrintAsPythonAPI(Array* stages, StageToAxes } auto out = ApplyToSchedule(stages, stage_to_axes, schedule); - ss << CleanName(out->op->name) << " = " + const auto& op_name = CleanName(out->op->name); + ss << op_name << " = " << "s.cache_read(" << CleanName(stage->op->name) << ", \"" << scope_name << "\", [" << CleanName(reader_stages[0]->op->name); for (size_t i = 1; i < reader_stage_ids.size(); ++i) { @@ -1527,13 +1537,13 @@ String CacheReadStepNode::PrintAsPythonAPI(Array* stages, StageToAxes // Print the iterators of the new added stage const auto& iters = out->op->root_iter_vars(); for (size_t i = 0; i < iters.size(); ++i) { - ss << CleanName(iters[i]->var->name_hint); + ss << CleanName(iters[i]->var->name_hint, op_name); if (i != iters.size() - 1) { ss << ", "; } } ss << " = " - << "tuple(" << CleanName(out->op->name) << ".op.axis)\n"; + << "tuple(" << op_name << ".op.axis)\n"; return ss.str(); } @@ -1652,16 +1662,17 @@ String CacheWriteStepNode::PrintAsPythonAPI(Array* stages, StageToAxe // Print the iterators of the new added stage for (const auto& out : outs) { const auto& iters = out->op->root_iter_vars(); + const auto& op_name = CleanName(out->op->name); for (size_t i = 0; i < iters.size(); ++i) { - ss << CleanName(iters[i]->var->name_hint); + ss << CleanName(iters[i]->var->name_hint, op_name); if (i != iters.size() - 1) { ss << ", "; } } ss << " = " - << "tuple(" << CleanName(out->op->name) << ".op.axis)" + << "tuple(" << op_name << ".op.axis)" << " + " - << "tuple(" << CleanName(out->op->name) << ".op.reduce_axis)\n"; + << "tuple(" << op_name << ".op.reduce_axis)\n"; } return ss.str(); @@ -1764,30 +1775,32 @@ String RfactorStepNode::PrintAsPythonAPI(Array* stages, StageToAxesMa for (const auto& out : outs) { const auto& iters = out->op->root_iter_vars(); + const auto& op_name = CleanName(out->op->name); for (size_t i = 0; i < iters.size(); ++i) { - ss << CleanName(iters[i]->var->name_hint); + ss << CleanName(iters[i]->var->name_hint, op_name); if (i != iters.size() - 1) { ss << ", "; } } ss << " = " - << "tuple(" << CleanName(out->op->name) << ".op.axis)" + << "tuple(" << op_name << ".op.axis)" << " + " - << "tuple(" << CleanName(out->op->name) << ".op.reduce_axis)\n"; + << "tuple(" << op_name << ".op.reduce_axis)\n"; } const auto& output = (*stages)[stage_id + 1]->op.output(0); const auto& iters = output->op->root_iter_vars(); + const auto& op_name = CleanName(output->op->name); for (size_t i = 0; i < iters.size(); ++i) { - ss << CleanName(iters[i]->var->name_hint); + ss << CleanName(iters[i]->var->name_hint, op_name); if (i != iters.size() - 1) { ss << ", "; } } ss << " = " - << "tuple(s[" << CleanName(output->op->name) << "].op.axis)" + << "tuple(s[" << op_name << "].op.axis)" << " + " - << "tuple(s[" << CleanName(output->op->name) << "].op.reduce_axis)\n"; + << "tuple(s[" << op_name << "].op.reduce_axis)\n"; return ss.str(); } diff --git a/src/auto_scheduler/utils.h b/src/auto_scheduler/utils.h index d036743c7b8b..610fec96617a 100755 --- a/src/auto_scheduler/utils.h +++ b/src/auto_scheduler/utils.h @@ -209,16 +209,20 @@ inline int64_t AxisLengthProd(const Array& axes) { } /*! - * \brief Clean the name of an iterator to make it valid in python code. + * \brief Clean the name of an iterator or an op to make it valid in python code. * \param str The original name. + * \param prefix The name prefix to differentiate the same name (e.g., the same iterator names). * \return The cleaned name. */ -inline std::string CleanName(const std::string& str) { +inline std::string CleanName(const std::string& str, const std::string& prefix = "") { std::string ret = str; StrReplace(&ret, ".", "_"); StrReplace(&ret, "@", "_"); StrReplace(&ret, "outer", "o"); StrReplace(&ret, "inner", "i"); + if (prefix != "") { + return prefix + "_" + ret; + } return ret; }