From bbc39df3d72a3d37113e90add7f6a28c88c14d20 Mon Sep 17 00:00:00 2001 From: Junru Shao Date: Mon, 6 Mar 2023 12:57:39 -0800 Subject: [PATCH] [TVMScript] Sugar T.env_thread + T.launch_thread This PR introduces a syntactic sugar that combines T.env_thread and T.launch_thread. Previously, an AttrStmt that specifies thread extent or virtual thread is required to be written in two steps: ```python bx = T.env_thread("blockIdx.x") // creates an IterVar with T.launch_thread(bx, 128): // specify the iter domain ... ``` With this PR, now this behavior can be merged in a single line: ```python with T.launch_thread("blockIdx.x", 128) as bx: ... ``` --- include/tvm/script/ir_builder/tir/ir.h | 8 +++ python/tvm/script/ir_builder/tir/frame.py | 4 +- python/tvm/script/ir_builder/tir/ir.py | 13 ++-- src/script/ir_builder/tir/ir.cc | 17 ++++- src/script/printer/tir/stmt.cc | 75 +++++++++++++++++------ 5 files changed, 90 insertions(+), 27 deletions(-) diff --git a/include/tvm/script/ir_builder/tir/ir.h b/include/tvm/script/ir_builder/tir/ir.h index 2b89d0e736e89..8d8b0b42ba5c0 100644 --- a/include/tvm/script/ir_builder/tir/ir.h +++ b/include/tvm/script/ir_builder/tir/ir.h @@ -390,6 +390,14 @@ DeclBufferFrame DeclBuffer(Array shape, DataType dtype, String buffer_ */ LaunchThreadFrame LaunchThread(Var var, PrimExpr extent); +/*! + * \brief Launch a new thread. + * \param thread_tag The thread type tag. + * \param extent The extent of environment thread. + * \return The result LaunchThreadFrame. + */ +LaunchThreadFrame LaunchThread(String thread_tag, PrimExpr extent); + /*! * \brief Bind a var to thread env. * \param thread_tag The thread type tag. diff --git a/python/tvm/script/ir_builder/tir/frame.py b/python/tvm/script/ir_builder/tir/frame.py index 3e453f2e51833..b2229d503bfbd 100644 --- a/python/tvm/script/ir_builder/tir/frame.py +++ b/python/tvm/script/ir_builder/tir/frame.py @@ -115,4 +115,6 @@ def __enter__(self) -> Buffer: @_register_object("script.ir_builder.tir.LaunchThreadFrame") class LaunchThreadFrame(TIRFrame): - ... + def __enter__(self) -> Var: + super().__enter__() + return self.iter_var.var diff --git a/python/tvm/script/ir_builder/tir/ir.py b/python/tvm/script/ir_builder/tir/ir.py index 62a0aa8f32f77..e88597732c937 100644 --- a/python/tvm/script/ir_builder/tir/ir.py +++ b/python/tvm/script/ir_builder/tir/ir.py @@ -31,7 +31,7 @@ from tvm import tir from tvm.ir import Range, Type from tvm.ir.base import deprecated -from tvm.runtime import convert, ndarray +from tvm.runtime import String, convert, ndarray from tvm.target import Target # pylint: disable=unused-import @@ -1185,14 +1185,14 @@ def decl_buffer( def launch_thread( - iter_var: IterVar, # pylint: disable=redefined-outer-name + thread: Union[IterVar, str], # pylint: disable=redefined-outer-name extent: PrimExpr, ) -> frame.LaunchThreadFrame: """Launch a thread. Parameters ---------- - iter_var : IterVar + thread : Union[IterVar, str] The iteration variable. extent : PrimExpr @@ -1213,11 +1213,14 @@ def launch_thread( T.launch_thread(brow, 1) """ - return _ffi_api.LaunchThread(iter_var, extent) # type: ignore[attr-defined] # pylint: disable=no-member + + if isinstance(thread, str): + thread = String(thread) + return _ffi_api.LaunchThread(thread, extent) # type: ignore[attr-defined] # pylint: disable=no-member def env_thread(thread_tag: str) -> IterVar: - """Bind a var to thread env" + """Bind a var to thread env Parameters ---------- diff --git a/src/script/ir_builder/tir/ir.cc b/src/script/ir_builder/tir/ir.cc index a54f3d926fc95..487265bff29ab 100644 --- a/src/script/ir_builder/tir/ir.cc +++ b/src/script/ir_builder/tir/ir.cc @@ -427,6 +427,10 @@ LaunchThreadFrame LaunchThread(Var var, PrimExpr extent) { return LaunchThreadFrame(n); } +LaunchThreadFrame LaunchThread(String thread_tag, PrimExpr extent) { + return LaunchThread(EnvThread(thread_tag), extent); +} + RealizeFrame Realize(tvm::tir::BufferRegion buffer_slice, String storage_scope, PrimExpr condition) { ObjectPtr n = make_object(); @@ -658,7 +662,18 @@ TVM_REGISTER_GLOBAL("script.ir_builder.tir.If").set_body_typed(If); TVM_REGISTER_GLOBAL("script.ir_builder.tir.Then").set_body_typed(Then); TVM_REGISTER_GLOBAL("script.ir_builder.tir.Else").set_body_typed(Else); TVM_REGISTER_GLOBAL("script.ir_builder.tir.DeclBuffer").set_body_typed(DeclBuffer); -TVM_REGISTER_GLOBAL("script.ir_builder.tir.LaunchThread").set_body_typed(LaunchThread); +TVM_REGISTER_GLOBAL("script.ir_builder.tir.LaunchThread") + .set_body_typed([](ObjectRef thread_tag_or_var, PrimExpr extent) { + if (const auto* var = thread_tag_or_var.as()) { + return LaunchThread(GetRef(var), extent); + } else if (const auto* str = thread_tag_or_var.as()) { + return LaunchThread(GetRef(str), extent); + } else { + LOG(FATAL) << "ValueError: Unexpected type for TIR LaunchThread: " + << thread_tag_or_var->GetTypeKey(); + throw; + } + }); TVM_REGISTER_GLOBAL("script.ir_builder.tir.EnvThread").set_body_typed(EnvThread); TVM_REGISTER_GLOBAL("script.ir_builder.tir.BufferStore").set_body_typed(BufferStore); diff --git a/src/script/printer/tir/stmt.cc b/src/script/printer/tir/stmt.cc index 92ad41edc9d5a..591d1e3bc1da3 100644 --- a/src/script/printer/tir/stmt.cc +++ b/src/script/printer/tir/stmt.cc @@ -45,6 +45,19 @@ bool AllowConciseScoping(const IRDocsifier& d) { LOG(FATAL) << "NotImplementedError: fragment printing"; } +bool IsAncestorOfAllVarUse(const tir::Stmt& node, const ObjectRef& var, const IRDocsifier& d) { + if (!d->common_prefix.count(var.get())) { + return false; + } + const std::vector& path = d->common_prefix.at(var.get()); + for (auto it = path.rbegin(); it != path.rend(); ++it) { + if (*it == node.get()) { + return true; + } + } + return false; +} + TVM_STATIC_IR_FUNCTOR(IRDocsifier, vtable) .set_dispatch("", [](tir::Evaluate eval, ObjectPath p, IRDocsifier d) -> Doc { ExprDoc value = d->AsDoc(eval->value, p->Attr("value")); @@ -322,6 +335,39 @@ ExprDoc DocsifyBufferRealize(const tir::BufferRealizeNode* stmt, OptionalCall(args, kwargs_keys, kwargs_values); } +void InsertEnvThread(const tir::IterVar& iter_var, const ObjectPath& iter_var_p, + const IRDocsifier& d) { + Frame f = FindLowestVarDef(iter_var->var, d).value(); + DefineVar(iter_var->var, f, d); + ExprDoc rhs = TIR(d, "env_thread") + ->Call({LiteralDoc::Str(iter_var->thread_tag, // + iter_var_p->Attr("thread_tag"))}); + ExprDoc lhs = d->AsDoc(iter_var->var, iter_var_p->Attr("var")); + f->stmts.push_back(AssignDoc(lhs, rhs, NullOpt)); +} + +ExprDoc DocsifyLaunchThread(const tir::AttrStmt& attr_stmt, const ObjectPath& attr_stmt_p, + Optional* define_var, const IRDocsifier& d) { + tir::IterVar iter_var = Downcast(attr_stmt->node); + ObjectPath iter_var_p = attr_stmt_p->Attr("node"); + + ExprDoc var_doc{nullptr}; + if (d->IsVarDefined(iter_var->var)) { + var_doc = d->AsDoc(iter_var->var, iter_var_p->Attr("var")); + } else if (IsAncestorOfAllVarUse(attr_stmt, iter_var->var, d)) { + var_doc = LiteralDoc::Str(iter_var->thread_tag, iter_var_p->Attr("thread_tag")); + *define_var = iter_var->var; + } else { + InsertEnvThread(iter_var, iter_var_p, d); + var_doc = d->AsDoc(iter_var->var, iter_var_p->Attr("var")); + } + return TIR(d, "launch_thread") + ->Call({ + var_doc, + d->AsDoc(attr_stmt->value, attr_stmt_p->Attr("value")), + }); +} + TVM_STATIC_IR_FUNCTOR(IRDocsifier, vtable) .set_dispatch( // "", [](tir::BufferRealize stmt, ObjectPath p, IRDocsifier d) -> Doc { @@ -336,7 +382,9 @@ TVM_STATIC_IR_FUNCTOR(IRDocsifier, vtable) .set_dispatch( // "", [](tir::AttrStmt stmt, ObjectPath stmt_p, IRDocsifier d) -> Doc { bool concise = AllowConciseScoping(d); + Optional lhs = NullOpt; Optional rhs = NullOpt; + Optional define_var = NullOpt; tir::Stmt body = stmt->body; ObjectPath body_p = stmt_p->Attr("body"); if (stmt->attr_key == "realize_scope") { @@ -347,29 +395,13 @@ TVM_STATIC_IR_FUNCTOR(IRDocsifier, vtable) /*value=*/d->AsDoc(stmt->value, stmt_p->Attr("value")), /*p=*/stmt_p->Attr("body"), d); body = realize->body; - body_p = body_p->Attr("body"); + body_p = stmt_p->Attr("body")->Attr("body"); } } } if (stmt->attr_key == "thread_extent" || stmt->attr_key == "virtual_thread") { - if (const auto* iter_var = stmt->node.as()) { - if (!d->IsVarDefined(iter_var->var)) { - // `DefineVar` is not used here because a more specific name is desirable - ObjectPath iter_var_p = stmt_p->Attr("node"); - Frame f = FindLowestVarDef(iter_var->var, d).value(); - DefineVar(iter_var->var, f, d); - f->stmts.push_back( - AssignDoc(d->AsDoc(iter_var->var, iter_var_p->Attr("var")), - TIR(d, "env_thread") - ->Call({LiteralDoc::Str(iter_var->thread_tag, - iter_var_p->Attr("thread_tag"))}), // - NullOpt)); - } - rhs = TIR(d, "launch_thread") - ->Call({ - d->AsDoc(iter_var->var, stmt_p->Attr("node")), - d->AsDoc(stmt->value, stmt_p->Attr("value")), - }); + if (stmt->node->IsInstance()) { + rhs = DocsifyLaunchThread(stmt, stmt_p, &define_var, d); } } if (!rhs.defined()) { @@ -380,8 +412,11 @@ TVM_STATIC_IR_FUNCTOR(IRDocsifier, vtable) }); } With f(d, stmt); + if (define_var.defined()) { + lhs = DefineVar(define_var.value(), *f, d); + } AsDocBody(body, body_p, f->get(), d); - return DoConciseScoping(NullOpt, rhs.value(), &(*f)->stmts, concise); + return DoConciseScoping(lhs, rhs.value(), &(*f)->stmts, concise); }); TVM_STATIC_IR_FUNCTOR(IRDocsifier, vtable)