diff --git a/include/tvm/node/script_printer.h b/include/tvm/node/script_printer.h index eca302b395b3..c51d6a52b910 100644 --- a/include/tvm/node/script_printer.h +++ b/include/tvm/node/script_printer.h @@ -59,10 +59,16 @@ class PrinterConfigNode : public Object { bool print_line_numbers = false; /*! \brief Number of context lines to print around the underlined text */ int num_context_lines = -1; - /*! \brief Object path to be underlined */ - Optional path_to_underline = NullOpt; /*! \brief Whether to output with syntax sugar, set false for complete printing. */ bool syntax_sugar = true; + /* \brief Object path to be underlined */ + Array path_to_underline = Array(); + /*! \brief Object path to be annotated. */ + Map path_to_annotate = Map(); + /*! \brief Object to be underlined. */ + Array obj_to_underline = Array(); + /*! \brief Object to be annotated. */ + Map obj_to_annotate = Map(); void VisitAttrs(AttrVisitor* v) { v->Visit("ir_prefix", &ir_prefix); @@ -73,8 +79,11 @@ class PrinterConfigNode : public Object { v->Visit("indent_spaces", &indent_spaces); v->Visit("print_line_numbers", &print_line_numbers); v->Visit("num_context_lines", &num_context_lines); - v->Visit("path_to_underline", &path_to_underline); v->Visit("syntax_sugar", &syntax_sugar); + v->Visit("path_to_underline", &path_to_underline); + v->Visit("path_to_annotate", &path_to_annotate); + v->Visit("obj_to_underline", &obj_to_underline); + v->Visit("obj_to_annotate", &obj_to_annotate); } static constexpr const char* _type_key = "node.PrinterConfig"; diff --git a/include/tvm/script/printer/ir_docsifier.h b/include/tvm/script/printer/ir_docsifier.h index c41827fe9530..9225e7de3369 100644 --- a/include/tvm/script/printer/ir_docsifier.h +++ b/include/tvm/script/printer/ir_docsifier.h @@ -263,11 +263,50 @@ inline void FrameNode::ExitWithScope() { } } +template +inline static void AddDocDecoration(const Doc& d, const ObjectRef& obj, const ObjectPath& path, + const PrinterConfig& cfg) { + if (cfg->obj_to_annotate.count(obj)) { + if (const auto* stmt = d.as()) { + if (stmt->comment.defined()) { + stmt->comment = stmt->comment.value() + "\n" + cfg->obj_to_annotate.at(obj); + } else { + stmt->comment = cfg->obj_to_annotate.at(obj); + } + } else { + LOG(WARNING) << "Expect StmtDoc to be annotated for object " << obj << ", but got " + << Downcast(d)->_type_key; + } + } + for (const ObjectRef& o : cfg->obj_to_underline) { + if (o.same_as(obj)) { + cfg->path_to_underline.push_back(path); + } + } + for (const auto& pair : cfg->path_to_annotate) { + ObjectPath p = pair.first; + String attn = pair.second; + if (p->IsPrefixOf(path) && path->IsPrefixOf(p)) { + if (const auto* stmt = d.as()) { + if (stmt->comment.defined()) { + stmt->comment = stmt->comment.value() + "\n" + attn; + } else { + stmt->comment = attn; + } + } else { + LOG(WARNING) << "Expect StmtDoc to be annotated at object path " << p << ", but got " + << Downcast(d)->_type_key; + } + } + } +} + template inline TDoc IRDocsifierNode::AsDoc(const ObjectRef& obj, const ObjectPath& path) const { if (obj.defined()) { Doc d = IRDocsifier::vtable()(dispatch_tokens.back(), obj, path, GetRef(this)); d->source_paths.push_back(path); + AddDocDecoration(d, obj, path, cfg); return Downcast(d); } return Downcast(LiteralDoc::None(path)); diff --git a/python/tvm/runtime/object_path.py b/python/tvm/runtime/object_path.py index c4ec58a59697..ecca85d53da3 100644 --- a/python/tvm/runtime/object_path.py +++ b/python/tvm/runtime/object_path.py @@ -89,6 +89,8 @@ def map_value(self, key) -> "ObjectPath": def missing_map_entry(self) -> "ObjectPath": return _ffi_node_api.ObjectPathMissingMapEntry(self) + __hash__ = Object.__hash__ + @tvm._ffi.register_object("RootPath") class RootPath(ObjectPath): diff --git a/python/tvm/runtime/script_printer.py b/python/tvm/runtime/script_printer.py index 19d8e34ce85c..6838865490ad 100644 --- a/python/tvm/runtime/script_printer.py +++ b/python/tvm/runtime/script_printer.py @@ -15,7 +15,7 @@ # specific language governing permissions and limitations # under the License. """Configuration of TVMScript printer""" -from typing import Optional +from typing import List, Dict, Optional from tvm._ffi import register_object from tvm.runtime import Object @@ -38,8 +38,11 @@ class PrinterConfig(Object): indent_spaces: int print_line_numbers: bool num_context_lines: int - path_to_underline: Optional[ObjectPath] syntax_sugar: bool + path_to_underline: Optional[List[ObjectPath]] + path_to_annotate: Optional[Dict[ObjectPath, str]] + obj_to_underline: Optional[List[Object]] + obj_to_annotate: Optional[Dict[Object, str]] def __init__( self, @@ -54,8 +57,11 @@ def __init__( indent_spaces: int = 4, print_line_numbers: bool = False, num_context_lines: Optional[int] = None, - path_to_underline: Optional[ObjectPath] = None, syntax_sugar: bool = True, + path_to_underline: Optional[List[ObjectPath]] = None, + path_to_annotate: Optional[Dict[ObjectPath, str]] = None, + obj_to_underline: Optional[List[Object]] = None, + obj_to_annotate: Optional[Dict[Object, str]] = None, ) -> None: if num_context_lines is None: num_context_lines = -1 @@ -72,8 +78,11 @@ def __init__( "indent_spaces": indent_spaces, "print_line_numbers": print_line_numbers, "num_context_lines": num_context_lines, - "path_to_underline": path_to_underline, "syntax_sugar": syntax_sugar, + "path_to_underline": path_to_underline, + "path_to_annotate": path_to_annotate, + "obj_to_underline": obj_to_underline, + "obj_to_annotate": obj_to_annotate, }, ) @@ -98,8 +107,11 @@ def script( indent_spaces: int = 4, print_line_numbers: bool = False, num_context_lines: int = -1, - path_to_underline: Optional[ObjectPath] = None, syntax_sugar: bool = True, + path_to_underline: Optional[List[ObjectPath]] = None, + path_to_annotate: Optional[Dict[ObjectPath, str]] = None, + obj_to_underline: Optional[List[Object]] = None, + obj_to_annotate: Optional[Dict[Object, str]] = None, ) -> str: """Print TVM IR into TVMScript text format @@ -125,10 +137,16 @@ def script( Whether to print line numbers num_context_lines : int = -1 The number of lines of context to print before and after the line to underline. - path_to_underline : Optional[ObjectPath] = None - Object path to be underlined syntax_sugar: bool = True Whether to output with syntax sugar, set false for complete printing. + path_to_underline : Optional[List[ObjectPath]] = None + Object path to be underlined + path_to_annotate : Optional[Dict[ObjectPath, str]] = None + Object path to be annotated + obj_to_underline : Optional[List[Object]] = None + Object to be underlined + obj_to_annotate : Optional[Dict[Object, str]] = None + Object to be annotated Returns ------- @@ -148,8 +166,11 @@ def script( indent_spaces=indent_spaces, print_line_numbers=print_line_numbers, num_context_lines=num_context_lines, - path_to_underline=path_to_underline, syntax_sugar=syntax_sugar, + path_to_underline=path_to_underline, + path_to_annotate=path_to_annotate, + obj_to_underline=obj_to_underline, + obj_to_annotate=obj_to_annotate, ), ) @@ -168,8 +189,11 @@ def show( indent_spaces: int = 4, print_line_numbers: bool = False, num_context_lines: int = -1, - path_to_underline: Optional[ObjectPath] = None, syntax_sugar: bool = True, + path_to_underline: Optional[List[ObjectPath]] = None, + path_to_annotate: Optional[Dict[ObjectPath, str]] = None, + obj_to_underline: Optional[List[Object]] = None, + obj_to_annotate: Optional[Dict[Object, str]] = None, ) -> None: """A sugar for print highlighted TVM script. @@ -200,10 +224,16 @@ def show( Whether to print line numbers num_context_lines : int = -1 The number of lines of context to print before and after the line to underline. - path_to_underline : Optional[ObjectPath] = None - Object path to be underlined syntax_sugar: bool = True Whether to output with syntax sugar, set false for complete printing. + path_to_underline : Optional[List[ObjectPath]] = None + Object path to be underlined + path_to_annotate : Optional[Dict[ObjectPath, str]] = None + Object path to be annotated + obj_to_underline : Optional[List[Object]] = None + Object to be underlined + obj_to_annotate : Optional[Dict[Object, str]] = None + Object to be annotated """ from tvm.script.highlight import ( # pylint: disable=import-outside-toplevel cprint, @@ -221,8 +251,11 @@ def show( indent_spaces=indent_spaces, print_line_numbers=print_line_numbers, num_context_lines=num_context_lines, - path_to_underline=path_to_underline, syntax_sugar=syntax_sugar, + path_to_underline=path_to_underline, + path_to_annotate=path_to_annotate, + obj_to_underline=obj_to_underline, + obj_to_annotate=obj_to_annotate, ), style=style, black_format=black_format, diff --git a/python/tvm/script/printer/doc_printer.py b/python/tvm/script/printer/doc_printer.py index 137b71a77d9f..b43ca3b5333e 100644 --- a/python/tvm/script/printer/doc_printer.py +++ b/python/tvm/script/printer/doc_printer.py @@ -16,7 +16,7 @@ # under the License. """Functions to print doc into text format""" -from typing import Optional +from typing import List, Optional from tvm.runtime import ObjectPath from tvm.runtime.script_printer import PrinterConfig @@ -30,7 +30,7 @@ def to_python_script( indent_spaces: int = 4, print_line_numbers: bool = False, num_context_lines: Optional[int] = None, - path_to_underline: Optional[ObjectPath] = None, + path_to_underline: Optional[List[ObjectPath]] = None, ) -> str: """Convert Doc into Python script. diff --git a/src/node/script_printer.cc b/src/node/script_printer.cc index d8787259b50e..fcd3c53d026c 100644 --- a/src/node/script_printer.cc +++ b/src/node/script_printer.cc @@ -68,7 +68,18 @@ PrinterConfig::PrinterConfig(Map config_dict) { n->num_context_lines = Downcast(v)->value; } if (auto v = config_dict.Get("path_to_underline")) { - n->path_to_underline = Downcast(v); + n->path_to_underline = Downcast>>(v).value_or(Array()); + } + if (auto v = config_dict.Get("path_to_annotate")) { + n->path_to_annotate = + Downcast>>(v).value_or(Map()); + } + if (auto v = config_dict.Get("obj_to_underline")) { + n->obj_to_underline = Downcast>>(v).value_or(Array()); + } + if (auto v = config_dict.Get("obj_to_annotate")) { + n->obj_to_annotate = + Downcast>>(v).value_or(Map()); } if (auto v = config_dict.Get("syntax_sugar")) { n->syntax_sugar = Downcast(v)->value; diff --git a/src/node/structural_equal.cc b/src/node/structural_equal.cc index 788f3b7a1f3f..42726af9859a 100644 --- a/src/node/structural_equal.cc +++ b/src/node/structural_equal.cc @@ -326,9 +326,9 @@ class SEqualHandlerDefault::Impl { if (first_mismatch_->defined()) { oss << " at " << first_mismatch_->value()->lhs_path; if (root_lhs_.defined()) { - Map dict = {{"path_to_underline", first_mismatch_->value()->lhs_path}, - {"syntax_sugar", Bool(false)}}; - PrinterConfig cfg(dict); + PrinterConfig cfg; + cfg->syntax_sugar = false; + cfg->path_to_underline.push_back(first_mismatch_->value()->lhs_path); // The TVMScriptPrinter::Script will fallback to Repr printer, // if the root node to print is not supported yet, // e.g. Relay nodes, ArrayNode, MapNode, etc. @@ -341,9 +341,9 @@ class SEqualHandlerDefault::Impl { if (first_mismatch_->defined()) { oss << " at " << first_mismatch_->value()->rhs_path; if (root_rhs_.defined()) { - Map dict = {{"path_to_underline", first_mismatch_->value()->rhs_path}, - {"syntax_sugar", Bool(false)}}; - PrinterConfig cfg(dict); + PrinterConfig cfg; + cfg->syntax_sugar = false; + cfg->path_to_underline.push_back(first_mismatch_->value()->rhs_path); // The TVMScriptPrinter::Script will fallback to Repr printer, // if the root node to print is not supported yet, // e.g. Relay nodes, ArrayNode, MapNode, etc. diff --git a/src/script/printer/doc_printer/base_doc_printer.cc b/src/script/printer/doc_printer/base_doc_printer.cc index 8df599347f07..712796e7a1dd 100644 --- a/src/script/printer/doc_printer/base_doc_printer.cc +++ b/src/script/printer/doc_printer/base_doc_printer.cc @@ -25,20 +25,41 @@ namespace printer { namespace { -void SortAndMergeSpans(std::vector* spans) { - if (spans->empty()) { - return; +std::vector MergeAndExemptSpans(const std::vector& spans, + const std::vector& spans_exempted) { + // use prefix sum to merge and exempt spans + std::vector res; + std::vector> prefix_stamp; + for (ByteSpan span : spans) { + prefix_stamp.push_back({span.first, 1}); + prefix_stamp.push_back({span.second, -1}); } - std::sort(spans->begin(), spans->end()); - auto last = spans->begin(); - for (auto cur = spans->begin() + 1; cur != spans->end(); ++cur) { - if (cur->first > last->second) { - *++last = *cur; - } else if (cur->second > last->second) { - last->second = cur->second; + // at most spans.size() spans accumulated in prefix sum + // use spans.size() + 1 as stamp unit to exempt all positive spans + // with only one negative span + int max_n = spans.size() + 1; + for (ByteSpan span : spans_exempted) { + prefix_stamp.push_back({span.first, -max_n}); + prefix_stamp.push_back({span.second, max_n}); + } + std::sort(prefix_stamp.begin(), prefix_stamp.end()); + int prefix_sum = 0; + int n = prefix_stamp.size(); + for (int i = 0; i < n - 1; ++i) { + prefix_sum += prefix_stamp[i].second; + // positive prefix sum leads to spans without exemption + // different stamp positions guarantee the stamps in same position accumulated + if (prefix_sum > 0 && prefix_stamp[i].first < prefix_stamp[i + 1].first) { + if (res.size() && res.back().second == prefix_stamp[i].first) { + // merge to the last spans if it is successive + res.back().second = prefix_stamp[i + 1].first; + } else { + // add a new independent span + res.push_back({prefix_stamp[i].first, prefix_stamp[i + 1].first}); + } } } - spans->erase(++last, spans->end()); + return res; } size_t GetTextWidth(const std::string& text, const ByteSpan& span) { @@ -234,22 +255,24 @@ std::string DecorateText(const std::string& text, const std::vector& lin return ret; } -} // anonymous namespace +} // namespace DocPrinter::DocPrinter(const PrinterConfig& options) : options_(options) { line_starts_.push_back(0); } -void DocPrinter::Append(const Doc& doc) { Append(doc, NullOpt); } +void DocPrinter::Append(const Doc& doc) { Append(doc, PrinterConfig()); } -void DocPrinter::Append(const Doc& doc, Optional path_to_underline) { - path_to_underline_ = path_to_underline; - current_max_path_length_ = 0; - current_underline_candidates_.clear(); +void DocPrinter::Append(const Doc& doc, const PrinterConfig& cfg) { + for (const ObjectPath& p : cfg->path_to_underline) { + path_to_underline_.push_back(p); + current_max_path_length_.push_back(0); + current_underline_candidates_.push_back(std::vector()); + } PrintDoc(doc); - - underlines_.insert(underlines_.end(), current_underline_candidates_.begin(), - current_underline_candidates_.end()); + for (const auto& c : current_underline_candidates_) { + underlines_.insert(underlines_.end(), c.begin(), c.end()); + } } String DocPrinter::GetString() const { @@ -264,9 +287,8 @@ String DocPrinter::GetString() const { text.push_back('\n'); } - std::vector underlines = underlines_; - SortAndMergeSpans(&underlines); - return DecorateText(text, line_starts_, options_, underlines); + return DecorateText(text, line_starts_, options_, + MergeAndExemptSpans(underlines_, underlines_exempted_)); } void DocPrinter::PrintDoc(const Doc& doc) { @@ -332,14 +354,15 @@ void DocPrinter::PrintDoc(const Doc& doc) { } void DocPrinter::MarkSpan(const ByteSpan& span, const ObjectPath& path) { - if (path_to_underline_.defined()) { - if (path->Length() >= current_max_path_length_ && - path->IsPrefixOf(path_to_underline_.value())) { - if (path->Length() > current_max_path_length_) { - current_max_path_length_ = path->Length(); - current_underline_candidates_.clear(); + int n = path_to_underline_.size(); + for (int i = 0; i < n; ++i) { + ObjectPath p = path_to_underline_[i]; + if (path->Length() >= current_max_path_length_[i] && path->IsPrefixOf(p)) { + if (path->Length() > current_max_path_length_[i]) { + current_max_path_length_[i] = path->Length(); + current_underline_candidates_[i].clear(); } - current_underline_candidates_.push_back(span); + current_underline_candidates_[i].push_back(span); } } } diff --git a/src/script/printer/doc_printer/base_doc_printer.h b/src/script/printer/doc_printer/base_doc_printer.h index f5cf40a23357..aff587062d07 100644 --- a/src/script/printer/doc_printer/base_doc_printer.h +++ b/src/script/printer/doc_printer/base_doc_printer.h @@ -71,7 +71,7 @@ class DocPrinter { * * \sa GetString */ - void Append(const Doc& doc, Optional path_to_underline); + void Append(const Doc& doc, const PrinterConfig& cfg); /*! * \brief Get the printed string of all Doc appended @@ -232,9 +232,12 @@ class DocPrinter { * \sa output_ */ std::ostream& NewLine() { + size_t start_pos = output_.tellp(); output_ << "\n"; line_starts_.push_back(output_.tellp()); output_ << std::string(indent_, ' '); + size_t end_pos = output_.tellp(); + underlines_exempted_.push_back({start_pos, end_pos}); return output_; } @@ -248,6 +251,9 @@ class DocPrinter { */ std::ostringstream output_; + /*! \brief Spans that we have already committed to underline exemption. */ + std::vector underlines_exempted_; + private: void MarkSpan(const ByteSpan& span, const ObjectPath& path); @@ -261,16 +267,16 @@ class DocPrinter { std::vector line_starts_; /*! \brief Path of the object that we would like to underline */ - Optional path_to_underline_; + Array path_to_underline_; /*! * \brief Candidate spans to be underlined, until we find a better match. * (A better match is an object with a longer path that is still a prefix of path_to_underline_.) */ - std::vector current_underline_candidates_; + std::vector> current_underline_candidates_; /*! \brief Path length of the objects that are current candidates for underlining. */ - int current_max_path_length_; + std::vector current_max_path_length_; /*! \brief Spans that we have already committed to underline. */ std::vector underlines_; diff --git a/src/script/printer/doc_printer/python_doc_printer.cc b/src/script/printer/doc_printer/python_doc_printer.cc index 9d20afa148b5..e9a3b3567ec0 100644 --- a/src/script/printer/doc_printer/python_doc_printer.cc +++ b/src/script/printer/doc_printer/python_doc_printer.cc @@ -173,7 +173,12 @@ class PythonDocPrinter : public DocPrinter { void PrintTypedDoc(const DocStringDoc& doc) final; private: - void NewLineWithoutIndent() { output_ << "\n"; } + void NewLineWithoutIndent() { + size_t start_pos = output_.tellp(); + output_ << "\n"; + size_t end_pos = output_.tellp(); + underlines_exempted_.push_back({start_pos, end_pos}); + } template void PrintJoinedDocs(const Array& docs, const std::string& separator) { @@ -251,7 +256,10 @@ class PythonDocPrinter : public DocPrinter { bool has_newline = std::find(comment.begin(), comment.end(), '\n') != comment.end(); CHECK(!has_newline) << "ValueError: the comment string of " << stmt->GetTypeKey() << " cannot have newline."; + size_t start_pos = output_.tellp(); output_ << " # " << comment; + size_t end_pos = output_.tellp(); + underlines_exempted_.push_back({start_pos, end_pos}); } } @@ -259,6 +267,7 @@ class PythonDocPrinter : public DocPrinter { if (stmt->comment.defined()) { std::vector comment_lines = support::Split(stmt->comment.value(), '\n'); bool first_line = true; + size_t start_pos = output_.tellp(); for (const std::string& line : comment_lines) { if (first_line) { output_ << "# " << line; @@ -267,15 +276,17 @@ class PythonDocPrinter : public DocPrinter { NewLine() << "# " << line; } } + size_t end_pos = output_.tellp(); + underlines_exempted_.push_back({start_pos, end_pos}); if (new_line) { NewLine(); } } } - void PrintBlockComment(const String& comment) { - IncreaseIndent(); - NewLine() << "\"\"\""; + void PrintDocString(const String& comment) { + size_t start_pos = output_.tellp(); + output_ << "\"\"\""; std::vector comment_lines = support::Split(comment, '\n'); for (const std::string& line : comment_lines) { @@ -288,6 +299,14 @@ class PythonDocPrinter : public DocPrinter { } NewLine() << "\"\"\""; + size_t end_pos = output_.tellp(); + underlines_exempted_.push_back({start_pos, end_pos}); + } + + void PrintBlockComment(const String& comment) { + IncreaseIndent(); + NewLine(); + PrintDocString(comment); DecreaseIndent(); } }; @@ -662,7 +681,7 @@ void PythonDocPrinter::PrintTypedDoc(const CommentDoc& doc) { void PythonDocPrinter::PrintTypedDoc(const DocStringDoc& doc) { if (doc->comment.defined() && !doc->comment.value().empty()) { - output_ << "\"\"\"" << doc->comment.value() << "\"\"\""; + PrintDocString(doc->comment.value()); } } @@ -671,7 +690,7 @@ String DocToPythonScript(Doc doc, const PrinterConfig& cfg) { cfg->num_context_lines = std::numeric_limits::max(); } PythonDocPrinter printer(cfg); - printer.Append(doc, cfg->path_to_underline); + printer.Append(doc, cfg); std::string result = printer.GetString(); int last_space = result.size(); while (last_space > 0 && std::isspace(result[last_space - 1])) { diff --git a/src/script/printer/tir/block.cc b/src/script/printer/tir/block.cc index 979a27135cca..92a80eb36dba 100644 --- a/src/script/printer/tir/block.cc +++ b/src/script/printer/tir/block.cc @@ -215,7 +215,11 @@ Doc PrintBlock(IRDocsifier d, tir::Block block, ObjectPath block_p, // TVM_STATIC_IR_FUNCTOR(IRDocsifier, vtable) .set_dispatch( "", [](tir::BlockRealize realize, ObjectPath p, IRDocsifier d) -> Doc { - return PrintBlock(d, realize->block, p->Attr("block"), realize, p); + Doc doc = PrintBlock(d, realize->block, p->Attr("block"), realize, p); + // since we do not have d->AsDoc for realize->block, + // we should add possible doc decoration manually. + AddDocDecoration(doc, realize->block, p->Attr("block"), d->cfg); + return doc; }); TVM_STATIC_IR_FUNCTOR(IRDocsifier, vtable) diff --git a/src/tir/schedule/error.cc b/src/tir/schedule/error.cc index 1aae0202ac42..479fc34c75af 100644 --- a/src/tir/schedule/error.cc +++ b/src/tir/schedule/error.cc @@ -30,6 +30,8 @@ String ScheduleError::RenderReport(const String& primitive) const { std::unordered_map loc_obj_to_name; int n_locs = locs.size(); std::string msg = DetailRenderTemplate(); + PrinterConfig cfg; + cfg->syntax_sugar = false; if (n_locs > 0) { for (int i = 0; i < n_locs; ++i) { std::string name = locs[i]->GetTypeKey() + '#' + std::to_string(i); @@ -37,25 +39,13 @@ String ScheduleError::RenderReport(const String& primitive) const { for (size_t pos; (pos = msg.find(src)) != std::string::npos;) { msg.replace(pos, src.length(), name); } - loc_obj_to_name.emplace(locs[i], std::move(name)); + cfg->obj_to_annotate.Set(locs[i], name); + cfg->obj_to_underline.push_back(locs[i]); } } - - // print IR module - runtime::TypedPackedFunc annotate = - runtime::TypedPackedFunc( - [&loc_obj_to_name](const Stmt& expr) -> std::string { - auto it = loc_obj_to_name.find(Downcast(expr)); - if (it == loc_obj_to_name.end()) { - return ""; - } - return it->second; - }); - const auto* f = runtime::Registry::Get("script.AsTVMScriptWithDiagnostic"); - ICHECK(f != nullptr); os << "ScheduleError: An error occurred in the schedule primitive '" << primitive << "'.\n\nThe IR with diagnostic is:\n" - << ((*f)(mod, "T", false, annotate).operator String()); + << TVMScriptPrinter::Script(mod, cfg) << std::endl; // print error message os << "Error message: " << msg; diff --git a/tests/python/unittest/test_tvmscript_error_report.py b/tests/python/unittest/test_tvmscript_error_report.py index f542080f89f9..d2f275ac3d5f 100644 --- a/tests/python/unittest/test_tvmscript_error_report.py +++ b/tests/python/unittest/test_tvmscript_error_report.py @@ -459,9 +459,9 @@ def test_reorder_fail_block(): with pytest.raises(tvm.tir.ScheduleError) as execinfo: sch.reorder(l, i) expected_sub_error_message = ( - " # tir.Block#0\n" - ' with T.block("B"):\n' - " ^^^^^^^^^^^^^^^^^^\n" + " # tir.Block#0\n" + ' with T.block("B"):\n' + " ^^^^^^^^^^^^^^^^^^\n" ) assert expected_sub_error_message in str(execinfo.value) @@ -473,10 +473,10 @@ def test_reorder_fail_nested_loop_inner(): with pytest.raises(tvm.tir.ScheduleError) as execinfo: sch.reorder(k, i) expected_sub_error_message = ( - " for i in T.serial(128):\n" - " # tir.For#0\n" - " for j in T.serial(128):\n" - " ^^^^^^^^^^^^^^^^^^^^^^^\n" + " for i in range(128):\n" + " # tir.For#0\n" + " for j in range(128):\n" + " ^^^^^^^^^^^^^^^^^^^^\n" ) assert expected_sub_error_message in str(execinfo.value) @@ -488,10 +488,10 @@ def test_fuse_fail_nested_loop_outer(): with pytest.raises(tvm.tir.ScheduleError) as execinfo: sch.fuse(k, i) expected_sub_error_message = ( - " # tir.For#1\n" - " for i in T.serial(128):\n" - " ^^^^^^^^^^^^^^^^^^^^^^^\n" - " for j in T.serial(128):\n" + " # tir.For#1\n" + " for i in range(128):\n" + " ^^^^^^^^^^^^^^^^^^^^\n" + " for j in range(128):\n" ) assert expected_sub_error_message in str(execinfo.value) diff --git a/tests/python/unittest/test_tvmscript_printer_annotation.py b/tests/python/unittest/test_tvmscript_printer_annotation.py new file mode 100644 index 000000000000..70d5b655fb37 --- /dev/null +++ b/tests/python/unittest/test_tvmscript_printer_annotation.py @@ -0,0 +1,86 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. + +from typing import Optional + +import pytest +from tvm.runtime import ObjectPath +from tvm.script import tir as T + + +@T.prim_func +def _func(): + T.evaluate(0) + T.evaluate(1) + T.evaluate(2) + T.evaluate(3) + T.evaluate(4) + T.evaluate(5) + T.evaluate(6) + T.evaluate(7) + + +def test_annotation_multi_object_paths(): + result = _func.script( + path_to_annotate={ + ObjectPath.root().attr("body").attr("seq").array_index(1): "annotation 1", + ObjectPath.root().attr("body").attr("seq").array_index(3): "annotation 3", + ObjectPath.root().attr("body").attr("seq").array_index(5): "annotation 5", + ObjectPath.root().attr("body").attr("seq").array_index(7): "annotation 7", + } + ) + assert ( + result + == """# from tvm.script import tir as T + +@T.prim_func +def main(): + T.evaluate(0) + T.evaluate(1) # annotation 1 + T.evaluate(2) + T.evaluate(3) # annotation 3 + T.evaluate(4) + T.evaluate(5) # annotation 5 + T.evaluate(6) + T.evaluate(7) # annotation 7""" + ) + + +def test_annotate_from_multi_obj(): + result = _func.script( + obj_to_annotate={ + _func.body.seq[1]: "annotation 1", + _func.body.seq[3]: "annotation 3", + _func.body.seq[5]: "annotation 5", + _func.body.seq[7]: "annotation 7", + } + ) + assert ( + result + == """# from tvm.script import tir as T + +@T.prim_func +def main(): + T.evaluate(0) + T.evaluate(1) # annotation 1 + T.evaluate(2) + T.evaluate(3) # annotation 3 + T.evaluate(4) + T.evaluate(5) # annotation 5 + T.evaluate(6) + T.evaluate(7) # annotation 7""" + ) diff --git a/tests/python/unittest/test_tvmscript_printer_python_doc_printer.py b/tests/python/unittest/test_tvmscript_printer_python_doc_printer.py index 75beb59d02cf..d1eb34f1588d 100644 --- a/tests/python/unittest/test_tvmscript_printer_python_doc_printer.py +++ b/tests/python/unittest/test_tvmscript_printer_python_doc_printer.py @@ -924,13 +924,19 @@ def test_print_comment_doc(comment, expected): ), ( "test comment 1", - '"""test comment 1"""', + ''' + """ + test comment 1 + """ + ''', ), ( "test comment 1\ntest comment 2", ''' - """test comment 1 - test comment 2""" + """ + test comment 1 + test comment 2 + """ ''', ), ], diff --git a/tests/python/unittest/test_tvmscript_printer_structural_equal.py b/tests/python/unittest/test_tvmscript_printer_structural_equal.py index 4bd578eab768..1b9e0fa9beab 100644 --- a/tests/python/unittest/test_tvmscript_printer_structural_equal.py +++ b/tests/python/unittest/test_tvmscript_printer_structural_equal.py @@ -31,9 +31,9 @@ def _error_message(exception): def _expected_result(func1, func2, objpath1, objpath2): return f"""ValueError: StructuralEqual check failed, caused by lhs at {objpath1}: -{func1.script(path_to_underline=objpath1, syntax_sugar=False)} +{func1.script(path_to_underline=[objpath1], syntax_sugar=False)} and rhs at {objpath2}: -{func2.script(path_to_underline=objpath2, syntax_sugar=False)}""" +{func2.script(path_to_underline=[objpath2], syntax_sugar=False)}""" def test_prim_func_buffer_map(): diff --git a/tests/python/unittest/test_tvmscript_printer_underlining.py b/tests/python/unittest/test_tvmscript_printer_underlining.py index 467aad2df517..7230d4546a9f 100644 --- a/tests/python/unittest/test_tvmscript_printer_underlining.py +++ b/tests/python/unittest/test_tvmscript_printer_underlining.py @@ -27,6 +27,7 @@ StmtBlockDoc, ) from tvm.script.printer.doc_printer import to_python_script +from tvm.script import tir as T def make_path(name: str) -> ObjectPath: @@ -69,7 +70,7 @@ def test_underline_basic(): ExprStmtDoc(make_id_doc("qux")), ] ) - assert to_python_script(doc, path_to_underline=make_path("baz")) == format_script( + assert to_python_script(doc, path_to_underline=[make_path("baz")]) == format_script( """ foo bar + baz @@ -87,7 +88,7 @@ def test_underline_multiple_spans(): ExprStmtDoc(OperationDoc(OperationKind.Add, [make_id_doc("foo"), make_id_doc("foo")])), ] ) - assert to_python_script(doc, path_to_underline=make_path("foo")) == format_script( + assert to_python_script(doc, path_to_underline=[make_path("foo")]) == format_script( """ foo ^^^ @@ -107,7 +108,7 @@ def test_underline_multiple_spans_with_line_numbers(): ] ) assert to_python_script( - doc, print_line_numbers=True, path_to_underline=make_path("foo") + doc, print_line_numbers=True, path_to_underline=[make_path("foo")] ) == format_script( """ 1 foo @@ -128,7 +129,7 @@ def test_underline_multiline(): ) doc.source_paths = [make_path("whole_doc")] - assert to_python_script(doc, path_to_underline=make_path("whole_doc")) == format_script( + assert to_python_script(doc, path_to_underline=[make_path("whole_doc")]) == format_script( """ foo ^^^ @@ -282,13 +283,13 @@ def test_print_two_context_lines(to_underline, expected_text): doc = StmtBlockDoc( [ExprStmtDoc(make_id_doc(f"x{i}", "yes" if i in to_underline else "no")) for i in range(10)] ) - result = to_python_script(doc, num_context_lines=2, path_to_underline=make_path("yes")) + result = to_python_script(doc, num_context_lines=2, path_to_underline=[make_path("yes")]) assert result == format_script(expected_text) def test_underline_and_print_line_numbers(): doc = StmtBlockDoc([ExprStmtDoc(make_id_doc(f"line{i + 1}")) for i in range(12)]) - result = to_python_script(doc, print_line_numbers=True, path_to_underline=make_path("line6")) + result = to_python_script(doc, print_line_numbers=True, path_to_underline=[make_path("line6")]) assert ( result.strip() == format_script( @@ -311,10 +312,46 @@ def test_underline_and_print_line_numbers(): ) +def test_underline_multi_object_paths(): + doc = StmtBlockDoc([ExprStmtDoc(make_id_doc(f"line{i + 1}")) for i in range(10)]) + result = to_python_script( + doc, + path_to_underline=[ + make_path("line1"), + make_path("line3"), + make_path("line5"), + make_path("line7"), + make_path("line9"), + ], + ) + assert ( + result.strip() + == format_script( + """ + line1 + ^^^^^ + line2 + line3 + ^^^^^ + line4 + line5 + ^^^^^ + line6 + line7 + ^^^^^ + line8 + line9 + ^^^^^ + line10 + """ + ).strip() + ) + + def test_underline_and_print_line_numbers_with_context(): doc = StmtBlockDoc([ExprStmtDoc(make_id_doc(f"line{i + 1}")) for i in range(12)]) result = to_python_script( - doc, print_line_numbers=True, num_context_lines=2, path_to_underline=make_path("line8") + doc, print_line_numbers=True, num_context_lines=2, path_to_underline=[make_path("line8")] ) assert result == format_script( """ @@ -332,7 +369,7 @@ def test_underline_and_print_line_numbers_with_context(): def test_underline_based_on_path_prefix(): doc = StmtBlockDoc([ExprStmtDoc(make_id_doc("foo")), ExprStmtDoc(make_id_doc("bar"))]) - result = to_python_script(doc, path_to_underline=make_path("foo").attr("x").attr("y")) + result = to_python_script(doc, path_to_underline=[make_path("foo").attr("x").attr("y")]) # There is no document that matches the desired path exactly, # but path of "foo" is a prefix of the desired path, and thus should be underlined. assert result == format_script( @@ -351,7 +388,7 @@ def test_longer_prefix_must_win(): doc = StmtBlockDoc( [ExprStmtDoc(make_id_doc("foo")), ExprStmtDoc(make_id_doc("bar")), ExprStmtDoc(foo_x)] ) - result = to_python_script(doc, path_to_underline=make_path("foo").attr("x").attr("y")) + result = to_python_script(doc, path_to_underline=[make_path("foo").attr("x").attr("y")]) # "foo" should not be underlined because there is a document with a more specific path prefix assert result == format_script( """ @@ -361,3 +398,75 @@ def test_longer_prefix_must_win(): ^^^^^ """ ) + + +def test_underline_from_obj(): + @T.prim_func + def func(a: T.int32, b: T.int32): + T.evaluate(a) + T.evaluate(b) + T.evaluate(a) + T.evaluate(b) + T.evaluate(a) + T.evaluate(b) + + result = func.script(obj_to_underline=[func.params[0]]) + assert result == format_script( + """ + # from tvm.script import tir as T + + @T.prim_func + def main(a: T.int32, b: T.int32): + T.evaluate(a) + ^ + T.evaluate(b) + T.evaluate(a) + ^ + T.evaluate(b) + T.evaluate(a) + ^ + T.evaluate(b) + """ + ) + + +def test_underline_from_multi_obj(): + @T.prim_func + def func(): + T.evaluate(0) + T.evaluate(1) + T.evaluate(2) + T.evaluate(3) + T.evaluate(4) + T.evaluate(5) + T.evaluate(6) + T.evaluate(7) + + result = func.script( + obj_to_underline=[ + func.body.seq[1], + func.body.seq[3], + func.body.seq[5], + func.body.seq[7], + ] + ) + assert result == format_script( + """ + # from tvm.script import tir as T + + @T.prim_func + def main(): + T.evaluate(0) + T.evaluate(1) + ^^^^^^^^^^^^^ + T.evaluate(2) + T.evaluate(3) + ^^^^^^^^^^^^^ + T.evaluate(4) + T.evaluate(5) + ^^^^^^^^^^^^^ + T.evaluate(6) + T.evaluate(7) + ^^^^^^^^^^^^^ + """ + )