Skip to content

Commit

Permalink
[TVMScript] Sugar T.env_thread + T.launch_thread (#14217)
Browse files Browse the repository at this point in the history
This PR introduces syntactic sugar that combines T.env_thread and
T.launch_thread into a single call.

Previously, an AttrStmt that specifies a thread extent or a virtual thread
had to be written in two steps:

```python
bx = T.env_thread("blockIdx.x")  # creates an IterVar
with T.launch_thread(bx, 128):   # specifies the iter domain
  ...
```

With this PR, the two steps can be merged into a single line:

```python
with T.launch_thread("blockIdx.x", 128) as bx:
  ...
```
  • Loading branch information
junrushao authored Mar 7, 2023
1 parent 2f2d5d4 commit be66a7e
Show file tree
Hide file tree
Showing 6 changed files with 101 additions and 27 deletions.
8 changes: 8 additions & 0 deletions include/tvm/script/ir_builder/tir/ir.h
Original file line number Diff line number Diff line change
Expand Up @@ -390,6 +390,14 @@ DeclBufferFrame DeclBuffer(Array<PrimExpr> shape, DataType dtype, String buffer_
*/
LaunchThreadFrame LaunchThread(Var var, PrimExpr extent);

/*!
* \brief Launch a new thread, creating the environment thread variable from a thread tag.
* \param thread_tag The thread type tag, e.g. "blockIdx.x".
* \param extent The extent of environment thread.
* \return The result LaunchThreadFrame.
*/
LaunchThreadFrame LaunchThread(String thread_tag, PrimExpr extent);

/*!
* \brief Bind a var to thread env.
* \param thread_tag The thread type tag.
Expand Down
4 changes: 3 additions & 1 deletion python/tvm/script/ir_builder/tir/frame.py
Original file line number Diff line number Diff line change
Expand Up @@ -115,4 +115,6 @@ def __enter__(self) -> Buffer:

@_register_object("script.ir_builder.tir.LaunchThreadFrame")
class LaunchThreadFrame(TIRFrame):
    """Frame produced by launching a thread; usable as a context manager."""

    def __enter__(self) -> Var:
        """Enter the frame and hand back the variable of its iteration variable.

        Returns
        -------
        Var
            ``self.iter_var.var``, so that ``with ... as bx`` binds the thread variable.
        """
        # Enter the underlying frame first, then expose the bound variable.
        super().__enter__()
        thread_var = self.iter_var.var
        return thread_var
13 changes: 8 additions & 5 deletions python/tvm/script/ir_builder/tir/ir.py
Original file line number Diff line number Diff line change
Expand Up @@ -31,7 +31,7 @@
from tvm import tir
from tvm.ir import Range, Type
from tvm.ir.base import deprecated
from tvm.runtime import convert, ndarray
from tvm.runtime import String, convert, ndarray
from tvm.target import Target

# pylint: disable=unused-import
Expand Down Expand Up @@ -1185,14 +1185,14 @@ def decl_buffer(


def launch_thread(
iter_var: IterVar, # pylint: disable=redefined-outer-name
thread: Union[IterVar, str], # pylint: disable=redefined-outer-name
extent: PrimExpr,
) -> frame.LaunchThreadFrame:
"""Launch a thread.
Parameters
----------
iter_var : IterVar
thread : Union[IterVar, str]
The iteration variable, or the thread tag string (e.g. "blockIdx.x") from which one is created.
extent : PrimExpr
Expand All @@ -1213,11 +1213,14 @@ def launch_thread(
T.launch_thread(brow, 1)
"""
return _ffi_api.LaunchThread(iter_var, extent) # type: ignore[attr-defined] # pylint: disable=no-member

if isinstance(thread, str):
thread = String(thread)
return _ffi_api.LaunchThread(thread, extent) # type: ignore[attr-defined] # pylint: disable=no-member


def env_thread(thread_tag: str) -> IterVar:
"""Bind a var to thread env"
"""Bind a var to thread env
Parameters
----------
Expand Down
17 changes: 16 additions & 1 deletion src/script/ir_builder/tir/ir.cc
Original file line number Diff line number Diff line change
Expand Up @@ -427,6 +427,10 @@ LaunchThreadFrame LaunchThread(Var var, PrimExpr extent) {
return LaunchThreadFrame(n);
}

// Overload taking a thread tag: first binds the tag through EnvThread, then
// forwards the result together with `extent` to the variable-based LaunchThread.
LaunchThreadFrame LaunchThread(String thread_tag, PrimExpr extent) {
  auto env_thread_var = EnvThread(thread_tag);
  return LaunchThread(env_thread_var, extent);
}

RealizeFrame Realize(tvm::tir::BufferRegion buffer_slice, String storage_scope,
PrimExpr condition) {
ObjectPtr<RealizeFrameNode> n = make_object<RealizeFrameNode>();
Expand Down Expand Up @@ -658,7 +662,18 @@ TVM_REGISTER_GLOBAL("script.ir_builder.tir.If").set_body_typed(If);
TVM_REGISTER_GLOBAL("script.ir_builder.tir.Then").set_body_typed(Then);
TVM_REGISTER_GLOBAL("script.ir_builder.tir.Else").set_body_typed(Else);
TVM_REGISTER_GLOBAL("script.ir_builder.tir.DeclBuffer").set_body_typed(DeclBuffer);
TVM_REGISTER_GLOBAL("script.ir_builder.tir.LaunchThread").set_body_typed(LaunchThread);
// FFI entry point for LaunchThread. The first argument is dispatched on its
// runtime type: a TIR Var goes to the variable overload, a String goes to the
// thread-tag overload, and anything else is reported as a fatal error.
TVM_REGISTER_GLOBAL("script.ir_builder.tir.LaunchThread")
    .set_body_typed([](ObjectRef thread_tag_or_var, PrimExpr extent) {
      if (const auto* var_node = thread_tag_or_var.as<tvm::tir::VarNode>()) {
        return LaunchThread(GetRef<tvm::tir::Var>(var_node), extent);
      }
      if (const auto* tag_node = thread_tag_or_var.as<StringObj>()) {
        return LaunchThread(GetRef<String>(tag_node), extent);
      }
      LOG(FATAL) << "ValueError: Unexpected type for TIR LaunchThread: "
                 << thread_tag_or_var->GetTypeKey();
      // Kept after the fatal log so that every path of the lambda ends in a
      // return or a throw.
      throw;
    });
TVM_REGISTER_GLOBAL("script.ir_builder.tir.EnvThread").set_body_typed(EnvThread);

TVM_REGISTER_GLOBAL("script.ir_builder.tir.BufferStore").set_body_typed(BufferStore);
Expand Down
75 changes: 55 additions & 20 deletions src/script/printer/tir/stmt.cc
Original file line number Diff line number Diff line change
Expand Up @@ -45,6 +45,19 @@ bool AllowConciseScoping(const IRDocsifier& d) {
LOG(FATAL) << "NotImplementedError: fragment printing";
}

/*!
 * \brief Check whether `node` appears on the path recorded for `var` in `d->common_prefix`.
 * \param node The statement to look for.
 * \param var The object whose recorded path is searched.
 * \param d The IRDocsifier holding the `common_prefix` table.
 * \return true if `common_prefix` has an entry for `var` and that entry contains `node`;
 *         false otherwise, including when `var` has no entry at all.
 */
bool IsAncestorOfAllVarUse(const tir::Stmt& node, const ObjectRef& var, const IRDocsifier& d) {
  auto entry = d->common_prefix.find(var.get());
  if (entry == d->common_prefix.end()) {
    return false;
  }
  const std::vector<const Object*>& path = entry->second;
  const Object* target = node.get();
  // Scan from the last path element towards the first.
  for (size_t i = path.size(); i > 0; --i) {
    if (path[i - 1] == target) {
      return true;
    }
  }
  return false;
}

TVM_STATIC_IR_FUNCTOR(IRDocsifier, vtable)
.set_dispatch<tir::Evaluate>("", [](tir::Evaluate eval, ObjectPath p, IRDocsifier d) -> Doc {
ExprDoc value = d->AsDoc<ExprDoc>(eval->value, p->Attr("value"));
Expand Down Expand Up @@ -322,6 +335,39 @@ ExprDoc DocsifyBufferRealize(const tir::BufferRealizeNode* stmt, Optional<ExprDo
return TIR(d, "realize")->Call(args, kwargs_keys, kwargs_values);
}

/*!
 * \brief Emit an `env_thread` assignment for the variable of `iter_var`.
 *
 * The variable is defined in the frame returned by FindLowestVarDef, and an
 * assignment `<var> = env_thread("<thread_tag>")` is appended to that frame's statements.
 *
 * \param iter_var The iteration variable whose `var` and `thread_tag` are printed.
 * \param iter_var_p The object path of `iter_var`.
 * \param d The IRDocsifier.
 */
void InsertEnvThread(const tir::IterVar& iter_var, const ObjectPath& iter_var_p,
                     const IRDocsifier& d) {
  const tir::Var& thread_var = iter_var->var;
  Frame def_frame = FindLowestVarDef(thread_var, d).value();
  // The variable must be defined before it is converted to a doc below.
  DefineVar(thread_var, def_frame, d);
  ExprDoc env_thread_call =
      TIR(d, "env_thread")
          ->Call({LiteralDoc::Str(iter_var->thread_tag, iter_var_p->Attr("thread_tag"))});
  ExprDoc var_doc = d->AsDoc<ExprDoc>(thread_var, iter_var_p->Attr("var"));
  def_frame->stmts.push_back(AssignDoc(var_doc, env_thread_call, NullOpt));
}

/*!
 * \brief Build the `launch_thread(...)` call doc for an AttrStmt whose node is an IterVar.
 *
 * The first argument of the call is chosen as follows:
 *  - if the thread variable is already defined, the variable itself is printed;
 *  - otherwise, if IsAncestorOfAllVarUse holds for this statement, the thread tag is
 *    printed as a string literal and the variable is handed back through `define_var`;
 *  - otherwise, an `env_thread` assignment is inserted via InsertEnvThread and the
 *    variable is printed.
 *
 * \param attr_stmt The attribute statement; its `node` is downcast to tir::IterVar.
 * \param attr_stmt_p The object path of `attr_stmt`.
 * \param define_var Output; set to the thread variable only in the string-literal case.
 * \param d The IRDocsifier.
 * \return The call doc `launch_thread(<thread>, <value>)`.
 */
ExprDoc DocsifyLaunchThread(const tir::AttrStmt& attr_stmt, const ObjectPath& attr_stmt_p,
                            Optional<tir::Var>* define_var, const IRDocsifier& d) {
  tir::IterVar iter_var = Downcast<tir::IterVar>(attr_stmt->node);
  ObjectPath iter_var_p = attr_stmt_p->Attr("node");
  const tir::Var& thread_var = iter_var->var;

  const bool already_defined = d->IsVarDefined(thread_var);
  ExprDoc thread_doc{nullptr};
  if (!already_defined && IsAncestorOfAllVarUse(attr_stmt, thread_var, d)) {
    // Sugar form: pass the tag string and leave the definition to the caller.
    thread_doc = LiteralDoc::Str(iter_var->thread_tag, iter_var_p->Attr("thread_tag"));
    *define_var = thread_var;
  } else {
    if (!already_defined) {
      // Two-step form: emit the `env_thread` assignment before referring to the variable.
      InsertEnvThread(iter_var, iter_var_p, d);
    }
    thread_doc = d->AsDoc<ExprDoc>(thread_var, iter_var_p->Attr("var"));
  }
  return TIR(d, "launch_thread")
      ->Call({
          thread_doc,
          d->AsDoc<ExprDoc>(attr_stmt->value, attr_stmt_p->Attr("value")),
      });
}

TVM_STATIC_IR_FUNCTOR(IRDocsifier, vtable)
.set_dispatch<tir::BufferRealize>( //
"", [](tir::BufferRealize stmt, ObjectPath p, IRDocsifier d) -> Doc {
Expand All @@ -336,7 +382,9 @@ TVM_STATIC_IR_FUNCTOR(IRDocsifier, vtable)
.set_dispatch<tir::AttrStmt>( //
"", [](tir::AttrStmt stmt, ObjectPath stmt_p, IRDocsifier d) -> Doc {
bool concise = AllowConciseScoping(d);
Optional<ExprDoc> lhs = NullOpt;
Optional<ExprDoc> rhs = NullOpt;
Optional<tir::Var> define_var = NullOpt;
tir::Stmt body = stmt->body;
ObjectPath body_p = stmt_p->Attr("body");
if (stmt->attr_key == "realize_scope") {
Expand All @@ -347,29 +395,13 @@ TVM_STATIC_IR_FUNCTOR(IRDocsifier, vtable)
/*value=*/d->AsDoc<ExprDoc>(stmt->value, stmt_p->Attr("value")),
/*p=*/stmt_p->Attr("body"), d);
body = realize->body;
body_p = body_p->Attr("body");
body_p = stmt_p->Attr("body")->Attr("body");
}
}
}
if (stmt->attr_key == "thread_extent" || stmt->attr_key == "virtual_thread") {
if (const auto* iter_var = stmt->node.as<tir::IterVarNode>()) {
if (!d->IsVarDefined(iter_var->var)) {
// `DefineVar` is not used here because a more specific name is desirable
ObjectPath iter_var_p = stmt_p->Attr("node");
Frame f = FindLowestVarDef(iter_var->var, d).value();
DefineVar(iter_var->var, f, d);
f->stmts.push_back(
AssignDoc(d->AsDoc<ExprDoc>(iter_var->var, iter_var_p->Attr("var")),
TIR(d, "env_thread")
->Call({LiteralDoc::Str(iter_var->thread_tag,
iter_var_p->Attr("thread_tag"))}), //
NullOpt));
}
rhs = TIR(d, "launch_thread")
->Call({
d->AsDoc<ExprDoc>(iter_var->var, stmt_p->Attr("node")),
d->AsDoc<ExprDoc>(stmt->value, stmt_p->Attr("value")),
});
if (stmt->node->IsInstance<tir::IterVarNode>()) {
rhs = DocsifyLaunchThread(stmt, stmt_p, &define_var, d);
}
}
if (!rhs.defined()) {
Expand All @@ -380,8 +412,11 @@ TVM_STATIC_IR_FUNCTOR(IRDocsifier, vtable)
});
}
With<TIRFrame> f(d, stmt);
if (define_var.defined()) {
lhs = DefineVar(define_var.value(), *f, d);
}
AsDocBody(body, body_p, f->get(), d);
return DoConciseScoping(NullOpt, rhs.value(), &(*f)->stmts, concise);
return DoConciseScoping(lhs, rhs.value(), &(*f)->stmts, concise);
});

TVM_STATIC_IR_FUNCTOR(IRDocsifier, vtable)
Expand Down
11 changes: 11 additions & 0 deletions tests/python/unittest/test_tvmscript_roundtrip.py
Original file line number Diff line number Diff line change
Expand Up @@ -171,6 +171,16 @@ def mmult(A: T.handle, B: T.handle, C: T.handle) -> None:
return Module


# Round-trip fixture: returns a PrimFunc that passes the thread tag string
# "blockIdx.x" directly to T.launch_thread and indexes `inputs` with the result.
def launch_env_thread():
@T.prim_func
def main(inputs: T.Buffer((64, 2, 4), "float32")) -> None:
bx = T.launch_thread("blockIdx.x", 64)
for i, j in T.grid(2, 4):
T.evaluate(inputs[bx, i, j])

return main


def opt_gemm_mod_host():
@tvm.script.ir_module
class Module:
Expand Down Expand Up @@ -3563,6 +3573,7 @@ def func():


ir_generator = tvm.testing.parameter(
launch_env_thread,
opt_gemm_normalize,
opt_gemm_lower,
opt_gemm_mod_host,
Expand Down

0 comments on commit be66a7e

Please sign in to comment.