Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[TVMScript] Sugar T.env_thread + T.launch_thread #14217

Merged
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
8 changes: 8 additions & 0 deletions include/tvm/script/ir_builder/tir/ir.h
Original file line number Diff line number Diff line change
Expand Up @@ -390,6 +390,14 @@ DeclBufferFrame DeclBuffer(Array<PrimExpr> shape, DataType dtype, String buffer_
*/
LaunchThreadFrame LaunchThread(Var var, PrimExpr extent);

/*!
* \brief Launch a new thread identified by a thread tag.
*
* Overload of LaunchThread that takes a thread tag instead of an existing variable.
* \param thread_tag The thread type tag.
* \param extent The extent of environment thread.
* \return The result LaunchThreadFrame.
*/
LaunchThreadFrame LaunchThread(String thread_tag, PrimExpr extent);

/*!
* \brief Bind a var to thread env.
* \param thread_tag The thread type tag.
Expand Down
4 changes: 3 additions & 1 deletion python/tvm/script/ir_builder/tir/frame.py
Original file line number Diff line number Diff line change
Expand Up @@ -115,4 +115,6 @@ def __enter__(self) -> Buffer:

@_register_object("script.ir_builder.tir.LaunchThreadFrame")
class LaunchThreadFrame(TIRFrame):
...
# Entering the frame first runs the base-class `__enter__`, then yields the
# `var` of this frame's `iter_var`.
def __enter__(self) -> Var:
super().__enter__()
return self.iter_var.var
13 changes: 8 additions & 5 deletions python/tvm/script/ir_builder/tir/ir.py
Original file line number Diff line number Diff line change
Expand Up @@ -31,7 +31,7 @@
from tvm import tir
from tvm.ir import Range, Type
from tvm.ir.base import deprecated
from tvm.runtime import convert, ndarray
from tvm.runtime import String, convert, ndarray
from tvm.target import Target

# pylint: disable=unused-import
Expand Down Expand Up @@ -1185,14 +1185,14 @@ def decl_buffer(


def launch_thread(
iter_var: IterVar, # pylint: disable=redefined-outer-name
thread: Union[IterVar, str], # pylint: disable=redefined-outer-name
extent: PrimExpr,
) -> frame.LaunchThreadFrame:
"""Launch a thread.

Parameters
----------
iter_var : IterVar
thread : Union[IterVar, str]
The iteration variable or the thread tag.

extent : PrimExpr
Expand All @@ -1213,11 +1213,14 @@ def launch_thread(
T.launch_thread(brow, 1)

"""
return _ffi_api.LaunchThread(iter_var, extent) # type: ignore[attr-defined] # pylint: disable=no-member

if isinstance(thread, str):
thread = String(thread)
return _ffi_api.LaunchThread(thread, extent) # type: ignore[attr-defined] # pylint: disable=no-member


def env_thread(thread_tag: str) -> IterVar:
"""Bind a var to thread env"
"""Bind a var to thread env

Parameters
----------
Expand Down
17 changes: 16 additions & 1 deletion src/script/ir_builder/tir/ir.cc
Original file line number Diff line number Diff line change
Expand Up @@ -427,6 +427,10 @@ LaunchThreadFrame LaunchThread(Var var, PrimExpr extent) {
return LaunchThreadFrame(n);
}

/*!
 * \brief Launch a new thread identified by a thread tag.
 *
 * The tag is first bound through EnvThread; the result is then forwarded,
 * together with the extent, to the other LaunchThread overload.
 * \param thread_tag The thread type tag.
 * \param extent The extent of environment thread.
 * \return The result LaunchThreadFrame.
 */
LaunchThreadFrame LaunchThread(String thread_tag, PrimExpr extent) {
  auto env_var = EnvThread(thread_tag);
  return LaunchThread(env_var, extent);
}

RealizeFrame Realize(tvm::tir::BufferRegion buffer_slice, String storage_scope,
PrimExpr condition) {
ObjectPtr<RealizeFrameNode> n = make_object<RealizeFrameNode>();
Expand Down Expand Up @@ -658,7 +662,18 @@ TVM_REGISTER_GLOBAL("script.ir_builder.tir.If").set_body_typed(If);
TVM_REGISTER_GLOBAL("script.ir_builder.tir.Then").set_body_typed(Then);
TVM_REGISTER_GLOBAL("script.ir_builder.tir.Else").set_body_typed(Else);
TVM_REGISTER_GLOBAL("script.ir_builder.tir.DeclBuffer").set_body_typed(DeclBuffer);
TVM_REGISTER_GLOBAL("script.ir_builder.tir.LaunchThread").set_body_typed(LaunchThread);
// FFI entry point for LaunchThread: dispatches on the runtime type of the first
// argument, accepting either a TIR variable or a thread tag string.
TVM_REGISTER_GLOBAL("script.ir_builder.tir.LaunchThread")
    .set_body_typed([](ObjectRef thread_tag_or_var, PrimExpr extent) {
      // Case 1: an existing variable was passed.
      if (const auto* var_node = thread_tag_or_var.as<tvm::tir::VarNode>()) {
        return LaunchThread(GetRef<tvm::tir::Var>(var_node), extent);
      }
      // Case 2: a thread tag string was passed.
      if (const auto* tag_node = thread_tag_or_var.as<StringObj>()) {
        return LaunchThread(GetRef<String>(tag_node), extent);
      }
      // Anything else is rejected with the offending type key in the message.
      LOG(FATAL) << "ValueError: Unexpected type for TIR LaunchThread: "
                 << thread_tag_or_var->GetTypeKey();
      // Kept after LOG(FATAL) as in the original; presumably never reached.
      throw;
    });
TVM_REGISTER_GLOBAL("script.ir_builder.tir.EnvThread").set_body_typed(EnvThread);

TVM_REGISTER_GLOBAL("script.ir_builder.tir.BufferStore").set_body_typed(BufferStore);
Expand Down
75 changes: 55 additions & 20 deletions src/script/printer/tir/stmt.cc
Original file line number Diff line number Diff line change
Expand Up @@ -45,6 +45,19 @@ bool AllowConciseScoping(const IRDocsifier& d) {
LOG(FATAL) << "NotImplementedError: fragment printing";
}

/*!
 * \brief Check whether `node` appears on the path recorded for `var` in `d->common_prefix`.
 * \param node The statement to look for.
 * \param var The variable whose recorded path is searched.
 * \param d The IRDocsifier holding the `common_prefix` table.
 * \return true if `var` has an entry in `d->common_prefix` and `node` is one of its
 *         elements; false otherwise (including when `var` has no entry).
 */
bool IsAncestorOfAllVarUse(const tir::Stmt& node, const ObjectRef& var, const IRDocsifier& d) {
  auto entry = d->common_prefix.find(var.get());
  if (entry == d->common_prefix.end()) {
    return false;
  }
  const Object* target = node.get();
  for (const Object* ancestor : entry->second) {
    if (ancestor == target) {
      return true;
    }
  }
  return false;
}

TVM_STATIC_IR_FUNCTOR(IRDocsifier, vtable)
.set_dispatch<tir::Evaluate>("", [](tir::Evaluate eval, ObjectPath p, IRDocsifier d) -> Doc {
ExprDoc value = d->AsDoc<ExprDoc>(eval->value, p->Attr("value"));
Expand Down Expand Up @@ -322,6 +335,39 @@ ExprDoc DocsifyBufferRealize(const tir::BufferRealizeNode* stmt, Optional<ExprDo
return TIR(d, "realize")->Call(args, kwargs_keys, kwargs_values);
}

/*!
 * \brief Define the thread variable of `iter_var` and emit its `env_thread` assignment.
 *
 * The variable is defined in the frame returned by FindLowestVarDef, and an
 * assignment of the form `<var> = env_thread("<thread_tag>")` is appended to that
 * frame's statements.
 * \param iter_var The iteration variable whose `var` and `thread_tag` are used.
 * \param iter_var_p The object path of `iter_var`.
 * \param d The IRDocsifier.
 */
void InsertEnvThread(const tir::IterVar& iter_var, const ObjectPath& iter_var_p,
                     const IRDocsifier& d) {
  Frame def_frame = FindLowestVarDef(iter_var->var, d).value();
  // The variable must be defined before it is docsified below.
  DefineVar(iter_var->var, def_frame, d);
  auto env_thread_fn = TIR(d, "env_thread");
  ExprDoc tag_doc = LiteralDoc::Str(iter_var->thread_tag, iter_var_p->Attr("thread_tag"));
  ExprDoc env_thread_call = env_thread_fn->Call({tag_doc});
  ExprDoc var_doc = d->AsDoc<ExprDoc>(iter_var->var, iter_var_p->Attr("var"));
  def_frame->stmts.push_back(AssignDoc(var_doc, env_thread_call, NullOpt));
}

/*!
 * \brief Build the `launch_thread(...)` call doc for a thread-launching AttrStmt.
 *
 * The first argument of the call is chosen as follows:
 *  - if the thread variable is already defined, the variable itself is used;
 *  - otherwise, if IsAncestorOfAllVarUse holds for `attr_stmt`, the thread tag string
 *    literal is used and the variable is reported back through `define_var`;
 *  - otherwise, an `env_thread` assignment is inserted via InsertEnvThread and the
 *    variable is used.
 * \param attr_stmt The AttrStmt whose `node` is a tir::IterVar.
 * \param attr_stmt_p The object path of `attr_stmt`.
 * \param define_var Output; set to the thread variable only in the string-literal case.
 * \param d The IRDocsifier.
 * \return The doc of the `launch_thread` call.
 */
ExprDoc DocsifyLaunchThread(const tir::AttrStmt& attr_stmt, const ObjectPath& attr_stmt_p,
                            Optional<tir::Var>* define_var, const IRDocsifier& d) {
  tir::IterVar iter_var = Downcast<tir::IterVar>(attr_stmt->node);
  ObjectPath iter_var_p = attr_stmt_p->Attr("node");

  const bool already_defined = d->IsVarDefined(iter_var->var);
  ExprDoc thread_arg{nullptr};
  if (!already_defined && IsAncestorOfAllVarUse(attr_stmt, iter_var->var, d)) {
    // Sugar form: pass the tag string and let the caller define the variable.
    thread_arg = LiteralDoc::Str(iter_var->thread_tag, iter_var_p->Attr("thread_tag"));
    *define_var = iter_var->var;
  } else {
    if (!already_defined) {
      InsertEnvThread(iter_var, iter_var_p, d);
    }
    thread_arg = d->AsDoc<ExprDoc>(iter_var->var, iter_var_p->Attr("var"));
  }

  // Callee is built before the extent doc, matching the original evaluation order.
  auto launch_thread_fn = TIR(d, "launch_thread");
  ExprDoc extent_doc = d->AsDoc<ExprDoc>(attr_stmt->value, attr_stmt_p->Attr("value"));
  return launch_thread_fn->Call({thread_arg, extent_doc});
}

TVM_STATIC_IR_FUNCTOR(IRDocsifier, vtable)
.set_dispatch<tir::BufferRealize>( //
"", [](tir::BufferRealize stmt, ObjectPath p, IRDocsifier d) -> Doc {
Expand All @@ -336,7 +382,9 @@ TVM_STATIC_IR_FUNCTOR(IRDocsifier, vtable)
.set_dispatch<tir::AttrStmt>( //
"", [](tir::AttrStmt stmt, ObjectPath stmt_p, IRDocsifier d) -> Doc {
bool concise = AllowConciseScoping(d);
Optional<ExprDoc> lhs = NullOpt;
Optional<ExprDoc> rhs = NullOpt;
Optional<tir::Var> define_var = NullOpt;
tir::Stmt body = stmt->body;
ObjectPath body_p = stmt_p->Attr("body");
if (stmt->attr_key == "realize_scope") {
Expand All @@ -347,29 +395,13 @@ TVM_STATIC_IR_FUNCTOR(IRDocsifier, vtable)
/*value=*/d->AsDoc<ExprDoc>(stmt->value, stmt_p->Attr("value")),
/*p=*/stmt_p->Attr("body"), d);
body = realize->body;
body_p = body_p->Attr("body");
body_p = stmt_p->Attr("body")->Attr("body");
}
}
}
if (stmt->attr_key == "thread_extent" || stmt->attr_key == "virtual_thread") {
if (const auto* iter_var = stmt->node.as<tir::IterVarNode>()) {
if (!d->IsVarDefined(iter_var->var)) {
// `DefineVar` is not used here because a more specific name is desirable
ObjectPath iter_var_p = stmt_p->Attr("node");
Frame f = FindLowestVarDef(iter_var->var, d).value();
DefineVar(iter_var->var, f, d);
f->stmts.push_back(
AssignDoc(d->AsDoc<ExprDoc>(iter_var->var, iter_var_p->Attr("var")),
TIR(d, "env_thread")
->Call({LiteralDoc::Str(iter_var->thread_tag,
iter_var_p->Attr("thread_tag"))}), //
NullOpt));
}
rhs = TIR(d, "launch_thread")
->Call({
d->AsDoc<ExprDoc>(iter_var->var, stmt_p->Attr("node")),
d->AsDoc<ExprDoc>(stmt->value, stmt_p->Attr("value")),
});
if (stmt->node->IsInstance<tir::IterVarNode>()) {
rhs = DocsifyLaunchThread(stmt, stmt_p, &define_var, d);
}
}
if (!rhs.defined()) {
Expand All @@ -380,8 +412,11 @@ TVM_STATIC_IR_FUNCTOR(IRDocsifier, vtable)
});
}
With<TIRFrame> f(d, stmt);
if (define_var.defined()) {
lhs = DefineVar(define_var.value(), *f, d);
}
AsDocBody(body, body_p, f->get(), d);
return DoConciseScoping(NullOpt, rhs.value(), &(*f)->stmts, concise);
return DoConciseScoping(lhs, rhs.value(), &(*f)->stmts, concise);
});

TVM_STATIC_IR_FUNCTOR(IRDocsifier, vtable)
Expand Down
11 changes: 11 additions & 0 deletions tests/python/unittest/test_tvmscript_roundtrip.py
Original file line number Diff line number Diff line change
Expand Up @@ -171,6 +171,16 @@ def mmult(A: T.handle, B: T.handle, C: T.handle) -> None:
return Module


# Round-trip test generator for the `T.launch_thread` sugar: the launch is given
# the thread tag string "blockIdx.x" directly, its result is bound to `bx`, and
# `bx` is then used to index `inputs`. Returns the resulting prim_func.
def launch_env_thread():
@T.prim_func
def main(inputs: T.Buffer((64, 2, 4), "float32")) -> None:
bx = T.launch_thread("blockIdx.x", 64)
for i, j in T.grid(2, 4):
T.evaluate(inputs[bx, i, j])

return main


def opt_gemm_mod_host():
@tvm.script.ir_module
class Module:
Expand Down Expand Up @@ -3563,6 +3573,7 @@ def func():


ir_generator = tvm.testing.parameter(
launch_env_thread,
opt_gemm_normalize,
opt_gemm_lower,
opt_gemm_mod_host,
Expand Down