From e8576db88bad824a9997b8d499a9730059d0348c Mon Sep 17 00:00:00 2001 From: Yaxing Cai Date: Wed, 8 Mar 2023 01:34:18 -0800 Subject: [PATCH] [Fix]Fix function ObjectPath in IRModule SEqual (#14230) This PR fixes the `IRModuleNode::SEqualReduce` to produce the correct `ObjectPath` for functions in `IRModule`. And this PR fixes some missing `Doc->source_paths`'s, when printer transforms `Doc` locally. --- src/ir/module.cc | 21 +++++ .../printer/doc_printer/python_doc_printer.cc | 1 - src/script/printer/ir/ir.cc | 1 + src/script/printer/utils.h | 10 ++- ...test_tvmscript_printer_structural_equal.py | 40 ++++++--- .../test_tvmscript_printer_underlining.py | 86 ++++++++++++++++++- 6 files changed, 143 insertions(+), 16 deletions(-) diff --git a/src/ir/module.cc b/src/ir/module.cc index 22c6faf3d69d..42ced9612045 100644 --- a/src/ir/module.cc +++ b/src/ir/module.cc @@ -65,6 +65,27 @@ IRModule::IRModule(tvm::Map functions, bool IRModuleNode::SEqualReduce(const IRModuleNode* other, SEqualReducer equal) const { if (functions.size() != other->functions.size()) return false; if (!equal(this->attrs, other->attrs)) return false; + if (equal.IsPathTracingEnabled()) { + const ObjectPathPair& obj_path_pair = equal.GetCurrentObjectPaths(); + for (const auto& kv : this->functions) { + if (!other->ContainGlobalVar(kv.first->name_hint)) return false; + ObjectPathPair func_paths = {obj_path_pair->lhs_path->Attr("functions")->MapValue(kv.first), + obj_path_pair->rhs_path->Attr("functions") + ->MapValue(other->GetGlobalVar(kv.first->name_hint))}; + if (!equal(kv.second, other->Lookup(kv.first->name_hint), func_paths)) return false; + } + if (type_definitions.size() != other->type_definitions.size()) return false; + for (const auto& kv : this->type_definitions) { + if (!other->ContainGlobalTypeVar(kv.first->name_hint)) return false; + ObjectPathPair type_def_paths = { + obj_path_pair->lhs_path->Attr("type_definitions")->MapValue(kv.first), + obj_path_pair->rhs_path->Attr("type_definitions") + ->MapValue(other->GetGlobalTypeVar(kv.first->name_hint))}; + if (!equal(kv.second, other->LookupTypeDef(kv.first->name_hint), type_def_paths)) + return false; + } + return true; + } for (const auto& kv : this->functions) { if (!other->ContainGlobalVar(kv.first->name_hint)) return false; if (!equal(kv.second, other->Lookup(kv.first->name_hint))) return false; diff --git a/src/script/printer/doc_printer/python_doc_printer.cc b/src/script/printer/doc_printer/python_doc_printer.cc index 994d048a2e07..e726cd42a241 100644 --- a/src/script/printer/doc_printer/python_doc_printer.cc +++ b/src/script/printer/doc_printer/python_doc_printer.cc @@ -674,7 +674,6 @@ void PythonDocPrinter::PrintTypedDoc(const ClassDoc& doc) { PrintBlockComment(doc->comment.value()); } PrintIndentedBlock(doc->body); - NewLineWithoutIndent(); } void PythonDocPrinter::PrintTypedDoc(const CommentDoc& doc) { diff --git a/src/script/printer/ir/ir.cc b/src/script/printer/ir/ir.cc index e6f4a1eaee2c..065cfe5168ad 100644 --- a/src/script/printer/ir/ir.cc +++ b/src/script/printer/ir/ir.cc @@ -72,6 +72,7 @@ TVM_STATIC_IR_FUNCTOR(IRDocsifier, vtable) d->cfg->binding_names.pop_back(); if (const auto* stmt_block = doc.as()) { (*f)->stmts.push_back(stmt_block->stmts.back()); + (*f)->stmts.back()->source_paths = std::move(doc->source_paths); } else if (const auto* stmt = doc.as()) { (*f)->stmts.push_back(GetRef(stmt)); } else { diff --git a/src/script/printer/utils.h b/src/script/printer/utils.h index 10c7aaf4f2bb..ec0f0eaf72b0 100644 --- a/src/script/printer/utils.h +++ b/src/script/printer/utils.h @@ -61,6 +61,7 @@ inline void RedirectedReprPrinterMethod(const ObjectRef& obj, ReprPrinter* p) { inline std::string Docsify(const ObjectRef& obj, const IRDocsifier& d, const Frame& f, const PrinterConfig& cfg) { Doc doc = d->AsDoc(obj, ObjectPath::Root()); + bool move_source_paths = false; if (const auto* expr_doc = doc.as()) { if (!cfg->verbose_expr) { f->stmts.clear(); @@ -72,6 +73,7 @@ inline std::string Docsify(const ObjectRef& obj, const IRDocsifier& d, const Fra for (const StmtDoc& d : stmt_block->stmts) { f->stmts.push_back(d); } + move_source_paths = true; } else { LOG(FATAL) << "TypeError: Unexpected doc type: " << doc->GetTypeKey(); } @@ -87,7 +89,13 @@ inline std::string Docsify(const ObjectRef& obj, const IRDocsifier& d, const Fra CommentDoc("Metadata omitted. Use show_meta=True in script() method to show it.")); } } - os << DocToPythonScript(StmtBlockDoc(f->stmts), cfg); + if (move_source_paths) { + StmtBlockDoc new_doc(f->stmts); + new_doc->source_paths = std::move(doc->source_paths); + os << DocToPythonScript(new_doc, cfg); + } else { + os << DocToPythonScript(StmtBlockDoc(f->stmts), cfg); + } return os.str(); } diff --git a/tests/python/unittest/test_tvmscript_printer_structural_equal.py b/tests/python/unittest/test_tvmscript_printer_structural_equal.py index 1b9e0fa9beab..5c587354cc3f 100644 --- a/tests/python/unittest/test_tvmscript_printer_structural_equal.py +++ b/tests/python/unittest/test_tvmscript_printer_structural_equal.py @@ -21,7 +21,7 @@ from tvm.ir import assert_structural_equal from tvm.relay.op.transform import split from tvm.runtime import ObjectPath -from tvm.script import tir as T +from tvm.script import ir as I, tir as T def _error_message(exception): @@ -68,21 +68,35 @@ def func2(a: T.handle, b: T.handle): def test_evaluate(): - @T.prim_func - def func1(): - T.evaluate(0) - - @T.prim_func - def func2(): - T.evaluate(1) + @I.ir_module + class module1: + @T.prim_func + def func(): + T.evaluate(0) + + @I.ir_module + class module2: + @T.prim_func + def func(): + T.evaluate(1) with pytest.raises(ValueError) as ve: - assert_structural_equal(func1, func2) + assert_structural_equal(module1, module2) assert _error_message(ve.value) == _expected_result( - func1, - func2, - ObjectPath.root().attr("body").attr("value").attr("value"), - ObjectPath.root().attr("body").attr("value").attr("value"), + module1, + module2, + ObjectPath.root() + .attr("functions") + .map_value(module1.get_global_var("func")) + .attr("body") + .attr("value") + .attr("value"), + ObjectPath.root() + .attr("functions") + .map_value(module2.get_global_var("func")) + .attr("body") + .attr("value") + .attr("value"), ) diff --git a/tests/python/unittest/test_tvmscript_printer_underlining.py b/tests/python/unittest/test_tvmscript_printer_underlining.py index 7230d4546a9f..4a4d17d0d89b 100644 --- a/tests/python/unittest/test_tvmscript_printer_underlining.py +++ b/tests/python/unittest/test_tvmscript_printer_underlining.py @@ -27,7 +27,7 @@ StmtBlockDoc, ) from tvm.script.printer.doc_printer import to_python_script -from tvm.script import tir as T +from tvm.script import ir as I, tir as T def make_path(name: str) -> ObjectPath: @@ -470,3 +470,87 @@ def main(): ^^^^^^^^^^^^^ """ ) + + +def test_underline_func(): + @T.prim_func + def func(): + T.evaluate(0) + + result = func.script( + path_to_underline=[ + ObjectPath.root(), + ] + ) + assert result == format_script( + """ + # from tvm.script import tir as T + + @T.prim_func + ^^^^^^^^^^^^ + def main(): + ^^^^^^^^^^^ + T.evaluate(0) + ^^^^^^^^^^^^^ + """ + ) + + +def test_underline_func_in_irmodule(): + @I.ir_module + class irmodule: + @T.prim_func + def func(): + T.evaluate(0) + + result = irmodule.script( + path_to_underline=[ + ObjectPath.root().attr("functions").map_value(irmodule.get_global_var("func")), + ] + ) + assert result == format_script( + """ + # from tvm.script import ir as I + # from tvm.script import tir as T + + @I.ir_module + class Module: + @T.prim_func + ^^^^^^^^^^^^ + def func(): + ^^^^^^^^^^^ + T.evaluate(0) + ^^^^^^^^^^^^^ + """ + ) + + +def test_underline_irmodule(): + @I.ir_module + class irmodule: + @T.prim_func + def func(): + T.evaluate(0) + + result = irmodule.script( + path_to_underline=[ + ObjectPath.root(), + ] + ) + assert result == format_script( + """ + # from tvm.script import ir as I + # from tvm.script import tir as T + + @I.ir_module + ^^^^^^^^^^^^ + class Module: + ^^^^^^^^^^^^^ + @T.prim_func + ^^^^^^^^^^^^ + def func(): + ^^^^^^^^^^^ + T.evaluate(0) + ^^^^^^^^^^^^^ + """ + )