Skip to content

Commit

Permalink
[Fix][TVMScript] Fix LetStmt printing logic (#13900)
Browse files Browse the repository at this point in the history
This PR is the bug fix reported in #13892. Initially, we mix the logic of `LetStmt` docsifying method with and without concise scoping. For example, in
```python
x = T.var("int32")
with T.let(x, 0):
```
`x` in the `LetStmt` works as a right value, while in
```python
x: T.int32 = 0
```
`x` in the `LetStmt` works as a left value as result.
Our old logic mixed them together to generate the wrong code for the first case.
Meanwhile, during the fix, we found another bug in concise scoping check. For example, we have
```python
x = T.var("int32")
y = T.var("int32")
with T.let(x, y):
  with T.let(y, 0):
```
here we should not output
```python
x = T.var("int32")
y = T.var("int32")
with T.let(x, y):
  y: int32 = 0
```
becase this will define a new `y_1: int32 = 0` indeed, due the the variable shadowing logic of the parser, which is different from the `y` we define and refer to.
Our concise scoping `v: ... = ...` should launch if and only if the `v` is never defined before.
Otherwise, we use `with T.let(v, ...):` instead.
  • Loading branch information
cyx-6 authored Feb 3, 2023
1 parent e34506c commit 98008c2
Show file tree
Hide file tree
Showing 3 changed files with 40 additions and 7 deletions.
18 changes: 11 additions & 7 deletions src/script/printer/tir/stmt.cc
Original file line number Diff line number Diff line change
Expand Up @@ -57,13 +57,12 @@ TVM_STATIC_IR_FUNCTOR(IRDocsifier, vtable)
TVM_STATIC_IR_FUNCTOR(IRDocsifier, vtable)
.set_dispatch<tir::LetStmt>("", [](tir::LetStmt stmt, ObjectPath p, IRDocsifier d) -> Doc {
bool concise = AllowConciseScoping(d);
ExprDoc rhs = d->AsDoc<ExprDoc>(stmt->value, p->Attr("value"));
With<TIRFrame> f(d, stmt);
ExprDoc lhs = d->IsVarDefined(stmt->var) ? d->GetVarDoc(stmt->var).value()
: DefineVar(stmt->var, *f, d);
AsDocBody(stmt->body, p->Attr("body"), f->get(), d);
Array<StmtDoc>* stmts = &(*f)->stmts;
if (concise) {
if (concise && !d->IsVarDefined(stmt->var)) {
ExprDoc rhs = d->AsDoc<ExprDoc>(stmt->value, p->Attr("value"));
With<TIRFrame> f(d, stmt);
ExprDoc lhs = DefineVar(stmt->var, *f, d);
AsDocBody(stmt->body, p->Attr("body"), f->get(), d);
Array<StmtDoc>* stmts = &(*f)->stmts;
Type type = stmt->var->type_annotation;
Optional<ExprDoc> type_doc =
d->AsDoc<ExprDoc>(type, p->Attr("var")->Attr("type_annotation"));
Expand All @@ -75,6 +74,11 @@ TVM_STATIC_IR_FUNCTOR(IRDocsifier, vtable)
stmts->insert(stmts->begin(), AssignDoc(lhs, rhs, type_doc));
return StmtBlockDoc(*stmts);
} else {
ExprDoc lhs = d->AsDoc<ExprDoc>(stmt->var, p->Attr("var"));
ExprDoc rhs = d->AsDoc<ExprDoc>(stmt->value, p->Attr("value"));
With<TIRFrame> f(d, stmt);
AsDocBody(stmt->body, p->Attr("body"), f->get(), d);
Array<StmtDoc>* stmts = &(*f)->stmts;
rhs = TIR(d, "let")->Call({lhs, rhs});
return ScopeDoc(NullOpt, rhs, *stmts);
}
Expand Down
1 change: 1 addition & 0 deletions tests/python/unittest/test_tvmscript_printer_tir.py
Original file line number Diff line number Diff line change
Expand Up @@ -254,6 +254,7 @@ def test_let_stmt():
_assert_print(
obj,
"""
v = T.var("float32")
with T.let(v, T.float32(10)):
T.evaluate(0)
""",
Expand Down
28 changes: 28 additions & 0 deletions tests/python/unittest/test_tvmscript_roundtrip.py
Original file line number Diff line number Diff line change
Expand Up @@ -3543,6 +3543,32 @@ def func():
return func


def let_stmt_var():
@T.prim_func
def func():
x = T.var("int32")
y = T.var("int32")
with T.let(x, 0):
with T.let(y, 0):
T.evaluate(0)
T.evaluate(0)

return func


def let_stmt_value():
@T.prim_func
def func():
x = T.var("int32")
y = T.var("int32")
with T.let(x, y):
with T.let(y, 0):
T.evaluate(0)
T.evaluate(0)

return func


ir_generator = tvm.testing.parameter(
opt_gemm_normalize,
opt_gemm_lower,
Expand Down Expand Up @@ -3601,6 +3627,8 @@ def func():
*nested_boolean_expressions(),
multi_env_threads,
intrinsic_pow,
let_stmt_var,
let_stmt_value,
)


Expand Down

0 comments on commit 98008c2

Please sign in to comment.