diff --git a/src/script/printer/tir/stmt.cc b/src/script/printer/tir/stmt.cc index 2820f9ba6384..7556f820df74 100644 --- a/src/script/printer/tir/stmt.cc +++ b/src/script/printer/tir/stmt.cc @@ -57,13 +57,12 @@ TVM_STATIC_IR_FUNCTOR(IRDocsifier, vtable) TVM_STATIC_IR_FUNCTOR(IRDocsifier, vtable) .set_dispatch("", [](tir::LetStmt stmt, ObjectPath p, IRDocsifier d) -> Doc { bool concise = AllowConciseScoping(d); - ExprDoc rhs = d->AsDoc(stmt->value, p->Attr("value")); - With f(d, stmt); - ExprDoc lhs = d->IsVarDefined(stmt->var) ? d->GetVarDoc(stmt->var).value() - : DefineVar(stmt->var, *f, d); - AsDocBody(stmt->body, p->Attr("body"), f->get(), d); - Array* stmts = &(*f)->stmts; - if (concise) { + if (concise && !d->IsVarDefined(stmt->var)) { + ExprDoc rhs = d->AsDoc(stmt->value, p->Attr("value")); + With f(d, stmt); + ExprDoc lhs = DefineVar(stmt->var, *f, d); + AsDocBody(stmt->body, p->Attr("body"), f->get(), d); + Array* stmts = &(*f)->stmts; Type type = stmt->var->type_annotation; Optional type_doc = d->AsDoc(type, p->Attr("var")->Attr("type_annotation")); @@ -75,6 +74,11 @@ TVM_STATIC_IR_FUNCTOR(IRDocsifier, vtable) stmts->insert(stmts->begin(), AssignDoc(lhs, rhs, type_doc)); return StmtBlockDoc(*stmts); } else { + ExprDoc lhs = d->AsDoc(stmt->var, p->Attr("var")); + ExprDoc rhs = d->AsDoc(stmt->value, p->Attr("value")); + With f(d, stmt); + AsDocBody(stmt->body, p->Attr("body"), f->get(), d); + Array* stmts = &(*f)->stmts; rhs = TIR(d, "let")->Call({lhs, rhs}); return ScopeDoc(NullOpt, rhs, *stmts); } diff --git a/tests/python/unittest/test_tvmscript_printer_tir.py b/tests/python/unittest/test_tvmscript_printer_tir.py index 49a33cd0f0e8..6f96b3a3dd31 100644 --- a/tests/python/unittest/test_tvmscript_printer_tir.py +++ b/tests/python/unittest/test_tvmscript_printer_tir.py @@ -254,6 +254,7 @@ def test_let_stmt(): _assert_print( obj, """ +v = T.var("float32") with T.let(v, T.float32(10)): T.evaluate(0) """, diff --git a/tests/python/unittest/test_tvmscript_roundtrip.py b/tests/python/unittest/test_tvmscript_roundtrip.py index 4300c4bbade9..f52b488fef6b 100644 --- a/tests/python/unittest/test_tvmscript_roundtrip.py +++ b/tests/python/unittest/test_tvmscript_roundtrip.py @@ -3543,6 +3543,32 @@ def func(): return func +def let_stmt_var(): + @T.prim_func + def func(): + x = T.var("int32") + y = T.var("int32") + with T.let(x, 0): + with T.let(y, 0): + T.evaluate(0) + T.evaluate(0) + + return func + + +def let_stmt_value(): + @T.prim_func + def func(): + x = T.var("int32") + y = T.var("int32") + with T.let(x, y): + with T.let(y, 0): + T.evaluate(0) + T.evaluate(0) + + return func + + ir_generator = tvm.testing.parameter( opt_gemm_normalize, opt_gemm_lower, @@ -3601,6 +3627,8 @@ def func(): *nested_boolean_expressions(), multi_env_threads, intrinsic_pow, + let_stmt_var, + let_stmt_value, )