Skip to content

Commit

Permalink
[Fix]Fix function ObjectPath in IRModule SEqual (#14230)
Browse files Browse the repository at this point in the history
This PR fixes the `IRModuleNode::SEqualReduce` to produce the correct `ObjectPath` for functions in `IRModule`.

And this PR fixes some missing `Doc->source_paths`'s, when printer transforms `Doc` locally.
  • Loading branch information
cyx-6 authored Mar 8, 2023
1 parent 9d732d0 commit e8576db
Show file tree
Hide file tree
Showing 6 changed files with 143 additions and 16 deletions.
21 changes: 21 additions & 0 deletions src/ir/module.cc
Original file line number Diff line number Diff line change
Expand Up @@ -65,6 +65,27 @@ IRModule::IRModule(tvm::Map<GlobalVar, BaseFunc> functions,
bool IRModuleNode::SEqualReduce(const IRModuleNode* other, SEqualReducer equal) const {
if (functions.size() != other->functions.size()) return false;
if (!equal(this->attrs, other->attrs)) return false;
if (equal.IsPathTracingEnabled()) {
const ObjectPathPair& obj_path_pair = equal.GetCurrentObjectPaths();
for (const auto& kv : this->functions) {
if (!other->ContainGlobalVar(kv.first->name_hint)) return false;
ObjectPathPair func_paths = {obj_path_pair->lhs_path->Attr("functions")->MapValue(kv.first),
obj_path_pair->rhs_path->Attr("functions")
->MapValue(other->GetGlobalVar(kv.first->name_hint))};
if (!equal(kv.second, other->Lookup(kv.first->name_hint), func_paths)) return false;
}
if (type_definitions.size() != other->type_definitions.size()) return false;
for (const auto& kv : this->type_definitions) {
if (!other->ContainGlobalTypeVar(kv.first->name_hint)) return false;
ObjectPathPair type_def_paths = {
obj_path_pair->lhs_path->Attr("type_definitions")->MapValue(kv.first),
obj_path_pair->rhs_path->Attr("type_definitions")
->MapValue(other->GetGlobalTypeVar(kv.first->name_hint))};
if (!equal(kv.second, other->LookupTypeDef(kv.first->name_hint), type_def_paths))
return false;
}
return true;
}
for (const auto& kv : this->functions) {
if (!other->ContainGlobalVar(kv.first->name_hint)) return false;
if (!equal(kv.second, other->Lookup(kv.first->name_hint))) return false;
Expand Down
1 change: 0 additions & 1 deletion src/script/printer/doc_printer/python_doc_printer.cc
Original file line number Diff line number Diff line change
Expand Up @@ -674,7 +674,6 @@ void PythonDocPrinter::PrintTypedDoc(const ClassDoc& doc) {
PrintBlockComment(doc->comment.value());
}
PrintIndentedBlock(doc->body);
NewLineWithoutIndent();
}

void PythonDocPrinter::PrintTypedDoc(const CommentDoc& doc) {
Expand Down
1 change: 1 addition & 0 deletions src/script/printer/ir/ir.cc
Original file line number Diff line number Diff line change
Expand Up @@ -72,6 +72,7 @@ TVM_STATIC_IR_FUNCTOR(IRDocsifier, vtable)
d->cfg->binding_names.pop_back();
if (const auto* stmt_block = doc.as<StmtBlockDocNode>()) {
(*f)->stmts.push_back(stmt_block->stmts.back());
(*f)->stmts.back()->source_paths = std::move(doc->source_paths);
} else if (const auto* stmt = doc.as<StmtDocNode>()) {
(*f)->stmts.push_back(GetRef<StmtDoc>(stmt));
} else {
Expand Down
10 changes: 9 additions & 1 deletion src/script/printer/utils.h
Original file line number Diff line number Diff line change
Expand Up @@ -61,6 +61,7 @@ inline void RedirectedReprPrinterMethod(const ObjectRef& obj, ReprPrinter* p) {
inline std::string Docsify(const ObjectRef& obj, const IRDocsifier& d, const Frame& f,
const PrinterConfig& cfg) {
Doc doc = d->AsDoc(obj, ObjectPath::Root());
bool move_source_paths = false;
if (const auto* expr_doc = doc.as<ExprDocNode>()) {
if (!cfg->verbose_expr) {
f->stmts.clear();
Expand All @@ -72,6 +73,7 @@ inline std::string Docsify(const ObjectRef& obj, const IRDocsifier& d, const Fra
for (const StmtDoc& d : stmt_block->stmts) {
f->stmts.push_back(d);
}
move_source_paths = true;
} else {
LOG(FATAL) << "TypeError: Unexpected doc type: " << doc->GetTypeKey();
}
Expand All @@ -87,7 +89,13 @@ inline std::string Docsify(const ObjectRef& obj, const IRDocsifier& d, const Fra
CommentDoc("Metadata omitted. Use show_meta=True in script() method to show it."));
}
}
os << DocToPythonScript(StmtBlockDoc(f->stmts), cfg);
if (move_source_paths) {
StmtBlockDoc new_doc(f->stmts);
new_doc->source_paths = std::move(doc->source_paths);
os << DocToPythonScript(new_doc, cfg);
} else {
os << DocToPythonScript(StmtBlockDoc(f->stmts), cfg);
}
return os.str();
}

Expand Down
40 changes: 27 additions & 13 deletions tests/python/unittest/test_tvmscript_printer_structural_equal.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,7 +21,7 @@
from tvm.ir import assert_structural_equal
from tvm.relay.op.transform import split
from tvm.runtime import ObjectPath
from tvm.script import tir as T
from tvm.script import ir as I, tir as T


def _error_message(exception):
Expand Down Expand Up @@ -68,21 +68,35 @@ def func2(a: T.handle, b: T.handle):


def test_evaluate():
@T.prim_func
def func1():
T.evaluate(0)

@T.prim_func
def func2():
T.evaluate(1)
@I.ir_module
class module1:
@T.prim_func
def func():
T.evaluate(0)

@I.ir_module
class module2:
@T.prim_func
def func():
T.evaluate(1)

with pytest.raises(ValueError) as ve:
assert_structural_equal(func1, func2)
assert_structural_equal(module1, module2)
assert _error_message(ve.value) == _expected_result(
func1,
func2,
ObjectPath.root().attr("body").attr("value").attr("value"),
ObjectPath.root().attr("body").attr("value").attr("value"),
module1,
module2,
ObjectPath.root()
.attr("functions")
.map_value(module1.get_global_var("func"))
.attr("body")
.attr("value")
.attr("value"),
ObjectPath.root()
.attr("functions")
.map_value(module2.get_global_var("func"))
.attr("body")
.attr("value")
.attr("value"),
)


Expand Down
86 changes: 85 additions & 1 deletion tests/python/unittest/test_tvmscript_printer_underlining.py
Original file line number Diff line number Diff line change
Expand Up @@ -27,7 +27,7 @@
StmtBlockDoc,
)
from tvm.script.printer.doc_printer import to_python_script
from tvm.script import tir as T
from tvm.script import ir as I, tir as T


def make_path(name: str) -> ObjectPath:
Expand Down Expand Up @@ -470,3 +470,87 @@ def main():
^^^^^^^^^^^^^
"""
)


def test_underline_func():
@T.prim_func
def func():
T.evaluate(0)

result = func.script(
path_to_underline=[
ObjectPath.root(),
]
)
assert result == format_script(
"""
# from tvm.script import tir as T
@T.prim_func
^^^^^^^^^^^^
def main():
^^^^^^^^^^^
T.evaluate(0)
^^^^^^^^^^^^^
"""
)


def test_underline_func_in_irmodule():
@I.ir_module
class irmodule:
@T.prim_func
def func():
T.evaluate(0)

result = irmodule.script(
path_to_underline=[
ObjectPath.root().attr("functions").map_value(irmodule.get_global_var("func")),
]
)
assert result == format_script(
"""
# from tvm.script import ir as I
# from tvm.script import tir as T
@I.ir_module
class Module:
@T.prim_func
^^^^^^^^^^^^
def func():
^^^^^^^^^^^
T.evaluate(0)
^^^^^^^^^^^^^
"""
)


def test_underline_irmodule():
@I.ir_module
class irmodule:
@T.prim_func
def func():
T.evaluate(0)

result = irmodule.script(
path_to_underline=[
ObjectPath.root(),
]
)
assert result == format_script(
"""
# from tvm.script import ir as I
# from tvm.script import tir as T
@I.ir_module
^^^^^^^^^^^^
class Module:
^^^^^^^^^^^^^
@T.prim_func
^^^^^^^^^^^^
def func():
^^^^^^^^^^^
T.evaluate(0)
^^^^^^^^^^^^^
"""
)

0 comments on commit e8576db

Please sign in to comment.