From 144fb491d0b27dd4119672577cf74e62c791969f Mon Sep 17 00:00:00 2001 From: LiangW <114222082+liangW-intellif@users.noreply.github.com> Date: Sun, 26 Feb 2023 03:57:27 +0800 Subject: [PATCH] [TVMScript] Use op attribute to control whether to print dtype in TVMScript (#14111) This PR adds an op attribute `TScriptDtypePrintLocation`, and modifies the dtype printing logic of the builtin op to check this attribute. So that user defined operators can use it to specify how there dtype argument are printed by appending attributes instead of appending members to `dtype_first_arg`/`dtype_last_arg`. --- include/tvm/tir/op_attr_types.h | 21 +++++++++++ src/script/printer/tir/expr.cc | 31 +++++---------- src/tir/ir/stmt.cc | 4 +- src/tir/op/builtin.cc | 67 ++++++++++++++++++++++++--------- 4 files changed, 82 insertions(+), 41 deletions(-) diff --git a/include/tvm/tir/op_attr_types.h b/include/tvm/tir/op_attr_types.h index 858d89c2d551..b2a644f9546e 100644 --- a/include/tvm/tir/op_attr_types.h +++ b/include/tvm/tir/op_attr_types.h @@ -61,6 +61,27 @@ using FLegalize = runtime::TypedPackedFunc; */ using TScriptPrinterName = String; +/*! + * \brief Specifies that TVMScript printer prints the dtype as the first/last argument. + If not specified, dtype will not be printed. + */ +enum class ScriptDtypePrintLocation : int { + /*! + * \brief Do not print dtype as an argument. + */ + kNone = 0, + /*! + * \brief Print dtype as the first argument. + */ + kFirst = 1, + /*! + * \brief FPrint dtype as the last argument. + */ + kLast = 2, +}; + +using TScriptDtypePrintLocation = Integer; + /*! * \brief The effect type of the call. */ diff --git a/src/script/printer/tir/expr.cc b/src/script/printer/tir/expr.cc index d860eeb2a7da..f1435c487044 100644 --- a/src/script/printer/tir/expr.cc +++ b/src/script/printer/tir/expr.cc @@ -222,26 +222,9 @@ TVM_STATIC_IR_FUNCTOR(IRDocsifier, vtable) .set_dispatch("", [](tir::Call call, ObjectPath call_p, IRDocsifier d) -> Doc { static const OpAttrMap& op_names = Op::GetAttrMap("TScriptPrinterName"); - static const std::unordered_set dtype_first_arg = { - tir::builtin::reinterpret().get(), - tir::builtin::call_extern().get(), - tir::builtin::call_llvm_intrin().get(), // - tir::builtin::call_llvm_pure_intrin().get(), // - tir::builtin::call_pure_extern().get(), // - tir::builtin::ptx_mma().get(), - tir::builtin::ptx_mma_sp().get(), - tir::builtin::ptx_ldmatrix().get(), - tir::builtin::ptx_cp_async().get(), - tir::builtin::mma_store().get(), - tir::builtin::mma_fill().get(), - tir::builtin::vectorlow().get(), - tir::builtin::vectorhigh().get(), - tir::builtin::vectorcombine().get(), - Op::Get("tir.type_annotation").get(), - }; - static const std::unordered_set dtype_last_arg = { - tir::builtin::tvm_struct_get().get(), - }; + static const OpAttrMap dtype_locations = + Op::GetAttrMap("TScriptDtypePrintLocation"); + tir::ScriptDtypePrintLocation dtype_print_location = tir::ScriptDtypePrintLocation::kNone; ExprDoc prefix{nullptr}; if (const auto* op = call->op.as()) { String name = op_names.get(GetRef(op), op->name); @@ -249,6 +232,10 @@ TVM_STATIC_IR_FUNCTOR(IRDocsifier, vtable) LOG(WARNING) << "No TScriptPrinterName attribute for " << op->name; } prefix = TIR(d, name); + if (dtype_locations.count(GetRef(op))) { + dtype_print_location = static_cast( + dtype_locations[GetRef(op)].IntValue()); + } } else if (const auto* gv = call->op.as()) { prefix = LiteralDoc::Str(gv->name_hint, call_p->Attr("op")); } else { @@ -257,13 +244,13 @@ TVM_STATIC_IR_FUNCTOR(IRDocsifier, vtable) Array args; int n_args = call->args.size(); args.reserve(n_args + 1); - if (dtype_first_arg.count(call->op.get())) { + if (dtype_print_location == tir::ScriptDtypePrintLocation::kFirst) { args.push_back(LiteralDoc::DataType(call->dtype, call_p->Attr("dtype"))); } for (int i = 0; i < n_args; ++i) { args.push_back(d->AsDoc(call->args[i], call_p->Attr("args")->ArrayIndex(i))); } - if (dtype_last_arg.count(call->op.get())) { + if (dtype_print_location == tir::ScriptDtypePrintLocation::kLast) { args.push_back(LiteralDoc::DataType(call->dtype, call_p->Attr("dtype"))); } return prefix->Call(args); diff --git a/src/tir/ir/stmt.cc b/src/tir/ir/stmt.cc index 1652786cb510..fd2a98554da6 100644 --- a/src/tir/ir/stmt.cc +++ b/src/tir/ir/stmt.cc @@ -692,7 +692,9 @@ PrimExpr TypeAnnotation(DataType dtype, Span span) { } TVM_TIR_REGISTER_OP("type_annotation") - .set_attr("TCallEffectKind", Integer(CallEffectKind::kPure)); + .set_attr("TCallEffectKind", Integer(CallEffectKind::kPure)) + .set_attr("TScriptDtypePrintLocation", + Integer(ScriptDtypePrintLocation::kFirst)); } // namespace tir } // namespace tvm diff --git a/src/tir/op/builtin.cc b/src/tir/op/builtin.cc index f9d522804260..746524570243 100644 --- a/src/tir/op/builtin.cc +++ b/src/tir/op/builtin.cc @@ -40,6 +40,8 @@ namespace builtin { TIR_DEFINE_BUILTIN_FUNC(reinterpret) .set_attr("TCallEffectKind", Integer(CallEffectKind::kPure)) + .set_attr("TScriptDtypePrintLocation", + Integer(ScriptDtypePrintLocation::kFirst)) .set_num_inputs(1); TIR_DEFINE_BUILTIN_FUNC(ret) @@ -120,16 +122,24 @@ TIR_DEFINE_BUILTIN_FUNC(fma) .set_attr("TVectorizable", true); TIR_DEFINE_BUILTIN_FUNC(call_extern) - .set_attr("TCallEffectKind", Integer(CallEffectKind::kOpaque)); + .set_attr("TCallEffectKind", Integer(CallEffectKind::kOpaque)) + .set_attr("TScriptDtypePrintLocation", + Integer(ScriptDtypePrintLocation::kFirst)); TIR_DEFINE_BUILTIN_FUNC(call_pure_extern) - .set_attr("TCallEffectKind", Integer(CallEffectKind::kPure)); + .set_attr("TCallEffectKind", Integer(CallEffectKind::kPure)) + .set_attr("TScriptDtypePrintLocation", + Integer(ScriptDtypePrintLocation::kFirst)); TIR_DEFINE_BUILTIN_FUNC(call_llvm_intrin) - .set_attr("TCallEffectKind", Integer(CallEffectKind::kOpaque)); + .set_attr("TCallEffectKind", Integer(CallEffectKind::kOpaque)) + .set_attr("TScriptDtypePrintLocation", + Integer(ScriptDtypePrintLocation::kFirst)); TIR_DEFINE_BUILTIN_FUNC(call_llvm_pure_intrin) - .set_attr("TCallEffectKind", Integer(CallEffectKind::kPure)); + .set_attr("TCallEffectKind", Integer(CallEffectKind::kPure)) + .set_attr("TScriptDtypePrintLocation", + Integer(ScriptDtypePrintLocation::kFirst)); TIR_DEFINE_BUILTIN_FUNC(call_spirv_pure_glsl450) .set_attr("TCallEffectKind", Integer(CallEffectKind::kPure)); @@ -154,7 +164,9 @@ TIR_DEFINE_BUILTIN_FUNC(tvm_tuple).set_attr("TCallEffectKind", TIR_DEFINE_BUILTIN_FUNC(tvm_struct_get) .set_num_inputs(3) - .set_attr("TCallEffectKind", Integer(CallEffectKind::kReadState)); + .set_attr("TCallEffectKind", Integer(CallEffectKind::kReadState)) + .set_attr("TScriptDtypePrintLocation", + Integer(ScriptDtypePrintLocation::kLast)); TIR_DEFINE_BUILTIN_FUNC(tvm_struct_set) .set_num_inputs(4) @@ -249,19 +261,28 @@ TIR_DEFINE_BUILTIN_FUNC(tvm_fill_fragment) TIR_DEFINE_BUILTIN_FUNC(tvm_store_matrix_sync) .set_attr("TCallEffectKind", Integer(CallEffectKind::kOpaque)); -TIR_DEFINE_BUILTIN_FUNC(ptx_mma).set_attr("TCallEffectKind", - Integer(CallEffectKind::kOpaque)); +TIR_DEFINE_BUILTIN_FUNC(ptx_mma) + .set_attr("TCallEffectKind", Integer(CallEffectKind::kOpaque)) + .set_attr("TScriptDtypePrintLocation", + Integer(ScriptDtypePrintLocation::kFirst)); + TIR_DEFINE_BUILTIN_FUNC(ptx_ldg32).set_num_inputs(4).set_attr( "TCallEffectKind", Integer(CallEffectKind::kPure)); TIR_DEFINE_BUILTIN_FUNC(ptx_mma_sp) - .set_attr("TCallEffectKind", Integer(CallEffectKind::kOpaque)); + .set_attr("TCallEffectKind", Integer(CallEffectKind::kOpaque)) + .set_attr("TScriptDtypePrintLocation", + Integer(ScriptDtypePrintLocation::kFirst)); TIR_DEFINE_BUILTIN_FUNC(ptx_ldmatrix) - .set_attr("TCallEffectKind", Integer(CallEffectKind::kOpaque)); + .set_attr("TCallEffectKind", Integer(CallEffectKind::kOpaque)) + .set_attr("TScriptDtypePrintLocation", + Integer(ScriptDtypePrintLocation::kFirst)); TIR_DEFINE_BUILTIN_FUNC(ptx_cp_async) - .set_attr("TCallEffectKind", Integer(CallEffectKind::kOpaque)); + .set_attr("TCallEffectKind", Integer(CallEffectKind::kOpaque)) + .set_attr("TScriptDtypePrintLocation", + Integer(ScriptDtypePrintLocation::kFirst)); TIR_DEFINE_BUILTIN_FUNC(ptx_commit_group) .set_attr("TCallEffectKind", Integer(CallEffectKind::kOpaque)); @@ -269,20 +290,30 @@ TIR_DEFINE_BUILTIN_FUNC(ptx_commit_group) TIR_DEFINE_BUILTIN_FUNC(ptx_wait_group) .set_attr("TCallEffectKind", Integer(CallEffectKind::kOpaque)); -TIR_DEFINE_BUILTIN_FUNC(mma_store).set_attr("TCallEffectKind", - Integer(CallEffectKind::kOpaque)); +TIR_DEFINE_BUILTIN_FUNC(mma_store) + .set_attr("TCallEffectKind", Integer(CallEffectKind::kOpaque)) + .set_attr("TScriptDtypePrintLocation", + Integer(ScriptDtypePrintLocation::kFirst)); -TIR_DEFINE_BUILTIN_FUNC(mma_fill).set_attr("TCallEffectKind", - Integer(CallEffectKind::kOpaque)); +TIR_DEFINE_BUILTIN_FUNC(mma_fill) + .set_attr("TCallEffectKind", Integer(CallEffectKind::kOpaque)) + .set_attr("TScriptDtypePrintLocation", + Integer(ScriptDtypePrintLocation::kFirst)); TIR_DEFINE_BUILTIN_FUNC(vectorhigh) - .set_attr("TCallEffectKind", Integer(CallEffectKind::kPure)); + .set_attr("TCallEffectKind", Integer(CallEffectKind::kPure)) + .set_attr("TScriptDtypePrintLocation", + Integer(ScriptDtypePrintLocation::kFirst)); -TIR_DEFINE_BUILTIN_FUNC(vectorlow).set_attr("TCallEffectKind", - Integer(CallEffectKind::kPure)); +TIR_DEFINE_BUILTIN_FUNC(vectorlow) + .set_attr("TCallEffectKind", Integer(CallEffectKind::kPure)) + .set_attr("TScriptDtypePrintLocation", + Integer(ScriptDtypePrintLocation::kFirst)); TIR_DEFINE_BUILTIN_FUNC(vectorcombine) - .set_attr("TCallEffectKind", Integer(CallEffectKind::kPure)); + .set_attr("TCallEffectKind", Integer(CallEffectKind::kPure)) + .set_attr("TScriptDtypePrintLocation", + Integer(ScriptDtypePrintLocation::kFirst)); TIR_DEFINE_BUILTIN_FUNC(atomic_add) .set_attr("TCallEffectKind", Integer(CallEffectKind::kOpaque));