diff --git a/include/tvm/tir/op_attr_types.h b/include/tvm/tir/op_attr_types.h index 858d89c2d551..b2a644f9546e 100644 --- a/include/tvm/tir/op_attr_types.h +++ b/include/tvm/tir/op_attr_types.h @@ -61,6 +61,27 @@ using FLegalize = runtime::TypedPackedFunc; */ using TScriptPrinterName = String; +/*! + * \brief Specifies that TVMScript printer prints the dtype as the first/last argument. + If not specified, dtype will not be printed. + */ +enum class ScriptDtypePrintLocation : int { + /*! + * \brief Do not print dtype as an argument. + */ + kNone = 0, + /*! + * \brief Print dtype as the first argument. + */ + kFirst = 1, + /*! + * \brief FPrint dtype as the last argument. + */ + kLast = 2, +}; + +using TScriptDtypePrintLocation = Integer; + /*! * \brief The effect type of the call. */ diff --git a/src/script/printer/tir/expr.cc b/src/script/printer/tir/expr.cc index d860eeb2a7da..f1435c487044 100644 --- a/src/script/printer/tir/expr.cc +++ b/src/script/printer/tir/expr.cc @@ -222,26 +222,9 @@ TVM_STATIC_IR_FUNCTOR(IRDocsifier, vtable) .set_dispatch("", [](tir::Call call, ObjectPath call_p, IRDocsifier d) -> Doc { static const OpAttrMap& op_names = Op::GetAttrMap("TScriptPrinterName"); - static const std::unordered_set dtype_first_arg = { - tir::builtin::reinterpret().get(), - tir::builtin::call_extern().get(), - tir::builtin::call_llvm_intrin().get(), // - tir::builtin::call_llvm_pure_intrin().get(), // - tir::builtin::call_pure_extern().get(), // - tir::builtin::ptx_mma().get(), - tir::builtin::ptx_mma_sp().get(), - tir::builtin::ptx_ldmatrix().get(), - tir::builtin::ptx_cp_async().get(), - tir::builtin::mma_store().get(), - tir::builtin::mma_fill().get(), - tir::builtin::vectorlow().get(), - tir::builtin::vectorhigh().get(), - tir::builtin::vectorcombine().get(), - Op::Get("tir.type_annotation").get(), - }; - static const std::unordered_set dtype_last_arg = { - tir::builtin::tvm_struct_get().get(), - }; + static const OpAttrMap dtype_locations = + Op::GetAttrMap("TScriptDtypePrintLocation"); + tir::ScriptDtypePrintLocation dtype_print_location = tir::ScriptDtypePrintLocation::kNone; ExprDoc prefix{nullptr}; if (const auto* op = call->op.as()) { String name = op_names.get(GetRef(op), op->name); @@ -249,6 +232,10 @@ TVM_STATIC_IR_FUNCTOR(IRDocsifier, vtable) LOG(WARNING) << "No TScriptPrinterName attribute for " << op->name; } prefix = TIR(d, name); + if (dtype_locations.count(GetRef(op))) { + dtype_print_location = static_cast( + dtype_locations[GetRef(op)].IntValue()); + } } else if (const auto* gv = call->op.as()) { prefix = LiteralDoc::Str(gv->name_hint, call_p->Attr("op")); } else { @@ -257,13 +244,13 @@ TVM_STATIC_IR_FUNCTOR(IRDocsifier, vtable) Array args; int n_args = call->args.size(); args.reserve(n_args + 1); - if (dtype_first_arg.count(call->op.get())) { + if (dtype_print_location == tir::ScriptDtypePrintLocation::kFirst) { args.push_back(LiteralDoc::DataType(call->dtype, call_p->Attr("dtype"))); } for (int i = 0; i < n_args; ++i) { args.push_back(d->AsDoc(call->args[i], call_p->Attr("args")->ArrayIndex(i))); } - if (dtype_last_arg.count(call->op.get())) { + if (dtype_print_location == tir::ScriptDtypePrintLocation::kLast) { args.push_back(LiteralDoc::DataType(call->dtype, call_p->Attr("dtype"))); } return prefix->Call(args); diff --git a/src/tir/ir/stmt.cc b/src/tir/ir/stmt.cc index 1652786cb510..fd2a98554da6 100644 --- a/src/tir/ir/stmt.cc +++ b/src/tir/ir/stmt.cc @@ -692,7 +692,9 @@ PrimExpr TypeAnnotation(DataType dtype, Span span) { } TVM_TIR_REGISTER_OP("type_annotation") - .set_attr("TCallEffectKind", Integer(CallEffectKind::kPure)); + .set_attr("TCallEffectKind", Integer(CallEffectKind::kPure)) + .set_attr("TScriptDtypePrintLocation", + Integer(ScriptDtypePrintLocation::kFirst)); } // namespace tir } // namespace tvm diff --git a/src/tir/op/builtin.cc b/src/tir/op/builtin.cc index 680202751f12..e240b7b701ba 100644 --- a/src/tir/op/builtin.cc +++ b/src/tir/op/builtin.cc @@ -40,6 +40,8 @@ namespace builtin { TIR_DEFINE_BUILTIN_FUNC(reinterpret) .set_attr("TCallEffectKind", Integer(CallEffectKind::kPure)) + .set_attr("TScriptDtypePrintLocation", + Integer(ScriptDtypePrintLocation::kFirst)) .set_num_inputs(1); TIR_DEFINE_BUILTIN_FUNC(ret) @@ -120,16 +122,24 @@ TIR_DEFINE_BUILTIN_FUNC(fma) .set_attr("TVectorizable", true); TIR_DEFINE_BUILTIN_FUNC(call_extern) - .set_attr("TCallEffectKind", Integer(CallEffectKind::kOpaque)); + .set_attr("TCallEffectKind", Integer(CallEffectKind::kOpaque)) + .set_attr("TScriptDtypePrintLocation", + Integer(ScriptDtypePrintLocation::kFirst)); TIR_DEFINE_BUILTIN_FUNC(call_pure_extern) - .set_attr("TCallEffectKind", Integer(CallEffectKind::kPure)); + .set_attr("TCallEffectKind", Integer(CallEffectKind::kPure)) + .set_attr("TScriptDtypePrintLocation", + Integer(ScriptDtypePrintLocation::kFirst)); TIR_DEFINE_BUILTIN_FUNC(call_llvm_intrin) - .set_attr("TCallEffectKind", Integer(CallEffectKind::kOpaque)); + .set_attr("TCallEffectKind", Integer(CallEffectKind::kOpaque)) + .set_attr("TScriptDtypePrintLocation", + Integer(ScriptDtypePrintLocation::kFirst)); TIR_DEFINE_BUILTIN_FUNC(call_llvm_pure_intrin) - .set_attr("TCallEffectKind", Integer(CallEffectKind::kPure)); + .set_attr("TCallEffectKind", Integer(CallEffectKind::kPure)) + .set_attr("TScriptDtypePrintLocation", + Integer(ScriptDtypePrintLocation::kFirst)); TIR_DEFINE_BUILTIN_FUNC(call_spirv_pure_glsl450) .set_attr("TCallEffectKind", Integer(CallEffectKind::kPure)); @@ -154,7 +164,9 @@ TIR_DEFINE_BUILTIN_FUNC(tvm_tuple).set_attr("TCallEffectKind", TIR_DEFINE_BUILTIN_FUNC(tvm_struct_get) .set_num_inputs(3) - .set_attr("TCallEffectKind", Integer(CallEffectKind::kReadState)); + .set_attr("TCallEffectKind", Integer(CallEffectKind::kReadState)) + .set_attr("TScriptDtypePrintLocation", + Integer(ScriptDtypePrintLocation::kLast)); TIR_DEFINE_BUILTIN_FUNC(tvm_struct_set) .set_num_inputs(4) @@ -249,19 +261,28 @@ TIR_DEFINE_BUILTIN_FUNC(tvm_fill_fragment) TIR_DEFINE_BUILTIN_FUNC(tvm_store_matrix_sync) .set_attr("TCallEffectKind", Integer(CallEffectKind::kOpaque)); -TIR_DEFINE_BUILTIN_FUNC(ptx_mma).set_attr("TCallEffectKind", - Integer(CallEffectKind::kOpaque)); +TIR_DEFINE_BUILTIN_FUNC(ptx_mma) + .set_attr("TCallEffectKind", Integer(CallEffectKind::kOpaque)) + .set_attr("TScriptDtypePrintLocation", + Integer(ScriptDtypePrintLocation::kFirst)); + TIR_DEFINE_BUILTIN_FUNC(ptx_ldg32).set_num_inputs(4).set_attr( "TCallEffectKind", Integer(CallEffectKind::kPure)); TIR_DEFINE_BUILTIN_FUNC(ptx_mma_sp) - .set_attr("TCallEffectKind", Integer(CallEffectKind::kOpaque)); + .set_attr("TCallEffectKind", Integer(CallEffectKind::kOpaque)) + .set_attr("TScriptDtypePrintLocation", + Integer(ScriptDtypePrintLocation::kFirst)); TIR_DEFINE_BUILTIN_FUNC(ptx_ldmatrix) - .set_attr("TCallEffectKind", Integer(CallEffectKind::kOpaque)); + .set_attr("TCallEffectKind", Integer(CallEffectKind::kOpaque)) + .set_attr("TScriptDtypePrintLocation", + Integer(ScriptDtypePrintLocation::kFirst)); TIR_DEFINE_BUILTIN_FUNC(ptx_cp_async) - .set_attr("TCallEffectKind", Integer(CallEffectKind::kOpaque)); + .set_attr("TCallEffectKind", Integer(CallEffectKind::kOpaque)) + .set_attr("TScriptDtypePrintLocation", + Integer(ScriptDtypePrintLocation::kFirst)); TIR_DEFINE_BUILTIN_FUNC(ptx_commit_group) .set_attr("TCallEffectKind", Integer(CallEffectKind::kOpaque)); @@ -269,20 +290,30 @@ TIR_DEFINE_BUILTIN_FUNC(ptx_commit_group) TIR_DEFINE_BUILTIN_FUNC(ptx_wait_group) .set_attr("TCallEffectKind", Integer(CallEffectKind::kOpaque)); -TIR_DEFINE_BUILTIN_FUNC(mma_store).set_attr("TCallEffectKind", - Integer(CallEffectKind::kOpaque)); +TIR_DEFINE_BUILTIN_FUNC(mma_store) + .set_attr("TCallEffectKind", Integer(CallEffectKind::kOpaque)) + .set_attr("TScriptDtypePrintLocation", + Integer(ScriptDtypePrintLocation::kFirst)); -TIR_DEFINE_BUILTIN_FUNC(mma_fill).set_attr("TCallEffectKind", - Integer(CallEffectKind::kOpaque)); +TIR_DEFINE_BUILTIN_FUNC(mma_fill) + .set_attr("TCallEffectKind", Integer(CallEffectKind::kOpaque)) + .set_attr("TScriptDtypePrintLocation", + Integer(ScriptDtypePrintLocation::kFirst)); TIR_DEFINE_BUILTIN_FUNC(vectorhigh) - .set_attr("TCallEffectKind", Integer(CallEffectKind::kPure)); + .set_attr("TCallEffectKind", Integer(CallEffectKind::kPure)) + .set_attr("TScriptDtypePrintLocation", + Integer(ScriptDtypePrintLocation::kFirst)); -TIR_DEFINE_BUILTIN_FUNC(vectorlow).set_attr("TCallEffectKind", - Integer(CallEffectKind::kPure)); +TIR_DEFINE_BUILTIN_FUNC(vectorlow) + .set_attr("TCallEffectKind", Integer(CallEffectKind::kPure)) + .set_attr("TScriptDtypePrintLocation", + Integer(ScriptDtypePrintLocation::kFirst)); TIR_DEFINE_BUILTIN_FUNC(vectorcombine) - .set_attr("TCallEffectKind", Integer(CallEffectKind::kPure)); + .set_attr("TCallEffectKind", Integer(CallEffectKind::kPure)) + .set_attr("TScriptDtypePrintLocation", + Integer(ScriptDtypePrintLocation::kFirst)); TIR_DEFINE_BUILTIN_FUNC(atomic_add) .set_attr("TCallEffectKind", Integer(CallEffectKind::kOpaque));