diff --git a/CMakeLists.txt b/CMakeLists.txt index 260adf90dd831..a9d9fc3f4989b 100644 --- a/CMakeLists.txt +++ b/CMakeLists.txt @@ -132,6 +132,7 @@ file(GLOB_RECURSE COMPILER_SRCS src/autotvm/*.cc src/tir/*.cc src/driver/*.cc + src/printer/*.cc src/api/*.cc ) diff --git a/include/tvm/ir/module.h b/include/tvm/ir/module.h index 8f922c0d42f72..0bb68912bfbda 100644 --- a/include/tvm/ir/module.h +++ b/include/tvm/ir/module.h @@ -308,5 +308,20 @@ class IRModule : public ObjectRef { TVM_DLL static IRModule FromText(const std::string& text, const std::string& source_path); }; +/*! + * \brief Render the node as a string in the text format. + * + * \param node The node to be rendered. + * \param show_meta_data Whether to print meta data section. + * \param annotate An optional callback function for attaching + * additional comment block to an expr. + * + * \note We support a limited set of IR nodes that are part of + * relay IR and + * \return The text representation. + */ +TVM_DLL std::string AsText(const ObjectRef& node, + bool show_meta_data = true, + runtime::TypedPackedFunc annotate = nullptr); } // namespace tvm #endif // TVM_IR_MODULE_H_ diff --git a/include/tvm/relay/expr.h b/include/tvm/relay/expr.h index 1062c20bb4f91..17ea7aa8c9a52 100644 --- a/include/tvm/relay/expr.h +++ b/include/tvm/relay/expr.h @@ -542,18 +542,6 @@ class TempExpr : public Expr { /*! \brief Pretty print a Relay node, producing a fragment of the Relay text format. */ std::string PrettyPrint(const ObjectRef& node); -/*! - * \brief Render the node as a string in the Relay text format. - * \param node The node to be rendered. - * \param show_meta_data Whether to print meta data section. - * \param annotate An optional callback function for attaching - * additional comment block to an expr. - * \return The text representation. - */ -std::string AsText(const ObjectRef& node, - bool show_meta_data = true, - runtime::TypedPackedFunc annotate = nullptr); - /*! \brief namespace of the attributes that are attached to a function. */ namespace attr { /*! \brief Mark the function as a primitive function. */ diff --git a/src/ir/error.cc b/src/ir/error.cc index 62faf502e1cad..9d498288d2ba2 100644 --- a/src/ir/error.cc +++ b/src/ir/error.cc @@ -111,7 +111,7 @@ void ErrorReporter::RenderErrors(const IRModule& module, bool use_color) { // // The annotation callback will annotate the error messages // contained in the map. - annotated_prog << relay::AsText(func, false, [&err_map](tvm::relay::Expr expr) { + annotated_prog << AsText(func, false, [&err_map](const ObjectRef& expr) { auto it = err_map.find(expr); if (it != err_map.end()) { CHECK_NE(it->second.size(), 0); diff --git a/src/printer/doc.cc b/src/printer/doc.cc new file mode 100644 index 0000000000000..9072fd6bda338 --- /dev/null +++ b/src/printer/doc.cc @@ -0,0 +1,173 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +/*! + * \file src/tvm/relay/doc.cc + * \brief Doc ADT used for pretty printing. + * + * Reference: Philip Wadler. A Prettier Printer. Journal of Functional Programming'98 + */ +#include +#include +#include +#include "doc.h" + +namespace tvm { + +/*! + * \brief Represent a piece of text in the doc. + */ +class DocTextNode : public DocAtomNode { + public: + /*! \brief The str content in the text. */ + std::string str; + + explicit DocTextNode(std::string str_val) + : str(str_val) { + if (str.find_first_of("\t\n") != str.npos) { + LOG(WARNING) << "text node: '" << str << "' should not has tab or newline."; + } + } + + static constexpr const char* _type_key = "printer.DocText"; + TVM_DECLARE_FINAL_OBJECT_INFO(DocTextNode, DocAtomNode); +}; + +TVM_REGISTER_OBJECT_TYPE(DocTextNode); + +class DocText : public DocAtom { + public: + explicit DocText(std::string str) { + data_ = runtime::make_object(str); + } + + TVM_DEFINE_OBJECT_REF_METHODS(DocText, DocAtom, DocTextNode); +}; + +/*! + * \brief Represent a line breaker in the doc. + */ +class DocLineNode : public DocAtomNode { + public: + /*! \brief The amount of indent in newline. */ + int indent; + + explicit DocLineNode(int indent) + : indent(indent) {} + + static constexpr const char* _type_key = "printer.DocLine"; + TVM_DECLARE_FINAL_OBJECT_INFO(DocLineNode, DocAtomNode); +}; + +TVM_REGISTER_OBJECT_TYPE(DocLineNode); + +class DocLine : public DocAtom { + public: + explicit DocLine(int indent) { + data_ = runtime::make_object(indent); + } + + TVM_DEFINE_OBJECT_REF_METHODS(DocLine, DocAtom, DocLineNode); +}; + +// DSL function implementations +Doc& Doc::operator<<(const Doc& right) { + CHECK(this != &right); + this->stream_.insert( + this->stream_.end(), right.stream_.begin(), right.stream_.end()); + return *this; +} + +Doc& Doc::operator<<(std::string right) { + return *this << DocText(right); +} + +Doc& Doc::operator<<(const DocAtom& right) { + this->stream_.push_back(right); + return *this; +} + +std::string Doc::str() { + std::ostringstream os; + for (auto atom : this->stream_) { + if (auto* text = atom.as()) { + os << text->str; + } else if (auto* line = atom.as()) { + os << "\n" << std::string(line->indent, ' '); + } else { + LOG(FATAL) << "do not expect type " << atom->GetTypeKey(); + } + } + return os.str(); +} + +Doc Doc::NewLine(int indent) { + return Doc() << DocLine(indent); +} + +Doc Doc::Text(std::string text) { + return Doc() << DocText(text); +} + +Doc Doc::Indent(int indent, Doc doc) { + for (size_t i = 0; i < doc.stream_.size(); ++i) { + if (auto* line = doc.stream_[i].as()) { + doc.stream_[i] = DocLine(indent + line->indent); + } + } + return doc; +} + +Doc Doc::StrLiteral(const std::string& value, std::string quote) { + // TODO(M.K.): add escape. + Doc doc; + return doc << quote << value << quote; +} + +Doc Doc::PyBoolLiteral(bool value) { + if (value) { + return Doc::Text("True"); + } else { + return Doc::Text("False"); + } +} + +Doc Doc::Brace(std::string open, + const Doc& body, + std::string close, + int indent) { + Doc doc; + doc << open; + doc << Indent(indent, NewLine() << body) << NewLine(); + doc << close; + return doc; +} + +Doc Doc::Concat(const std::vector& vec, const Doc& sep) { + Doc seq; + if (vec.size() != 0) { + if (vec.size() == 1) return vec[0]; + seq << vec[0]; + for (size_t i = 1; i < vec.size(); ++i) { + seq << sep << vec[i]; + } + } + return seq; +} +} // namespace tvm diff --git a/src/printer/doc.h b/src/printer/doc.h new file mode 100644 index 0000000000000..34a284b0f116a --- /dev/null +++ b/src/printer/doc.h @@ -0,0 +1,165 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +/*! + * \file tvm/printer/doc.h + * \brief Doc ADT used for pretty printing. + * + * Reference: Philip Wadler. A Prettier Printer. Journal of Functional Programming'98 + */ +#ifndef TVM_PRINTER_DOC_H_ +#define TVM_PRINTER_DOC_H_ + +#include +#include +#include +#include +#include +#include + +namespace tvm { + +/*! + * \brief Doc atom node for the ADT. + * \sa DocAtom + */ +class DocAtomNode : public Object { + public: + static constexpr const char* _type_key = "printer.DocAtom"; + TVM_DECLARE_BASE_OBJECT_INFO(DocAtomNode, Object); +}; + +/*! + * \brief Managed reference to DocAtomNode. + * \sa DocAtomNode. +*/ +class DocAtom : public ObjectRef { + public: + TVM_DEFINE_OBJECT_REF_METHODS(DocAtom, ObjectRef, DocAtomNode); +}; + +/*! + * \brief Stream-like interface for Doc DSL. + * + * The Doc DSL de-couples the layout decision from the printing decision. + * + * The layout(code formating) decisions include: + * - Change indentation. + * - Break single line into multiple ones(subjected to future improvements). + */ +class Doc { + public: + /*! \brief default constructor */ + Doc() {} + /*! + * \brief Append right to the end of the current doc stream. + * \param right The doc to be appended. + * \return reference to self. + */ + Doc& operator<<(const Doc& right); + /*! + * \brief Append right to the end of the current doc stream. + * \param right The doc to be appended. + * \return reference to self. + * \note pass by value to allow copy elison optimization. + */ + Doc& operator<<(std::string right); + /*! + * \brief Append right to the end of the current doc stream. + * \param right The doc to be appended. + * \return reference to self. + */ + Doc& operator<<(const DocAtom& right); + /*! + * \brief Convert value to string via std::ostreamstream + * the append to the current doc stream. + * \param right The doc to be appended. + * \tparam T the type of the value. + * \return reference to self. + */ + template::value>::type> + Doc& operator<<(const T& value) { + std::ostringstream os; + os << value; + return *this << os.str(); + } + /*! + * \brief Convert the doc stream into string. + * \return The string representation. + */ + std::string str(); + /*! + * \brief Create a doc that represents text content. + * \return The created doc. + */ + static Doc Text(std::string value); + /*! + * \brief Create a doc that represents a new line. + * \return The created doc. + */ + static Doc NewLine(int indent = 0); + /*! + * \brief Create a new doc that adds indentation to everyline of the doc. + * \param indent The indent to be added. + * \param doc The doc to be indented. + * \return The created doc. + * \note pass by value to allow copy elison optimization. + */ + static Doc Indent(int indent, Doc doc); + /*! + * \brief Create a Doc that represents a string literal. + * \param value The content of the string literal. + * \param quote The quote in the literal. + * \return The created doc. + */ + static Doc StrLiteral(const std::string& value, std::string quote = "\""); + /*! + * \brief Create a Doc that represents a boolean literal in python syntax. + * \param value The bool value. + * \return The created doc. + */ + static Doc PyBoolLiteral(bool value); + /*! + * \brief Enclose body by brace and add indent. + * \param body The body + * \param open The open brace. + * \param close The close brace. + * \param indent amount of indentation. + * \return The created doc. + */ + static Doc Brace(std::string open, + const Doc& body, + std::string close, + int indent = 2); + /*! + * \brief Create a doc by concatenating together with separator. + * \param vec The docs to be concatenated. + * \param sep The seperator. + * \return The created doc. + */ + static Doc Concat(const std::vector& vec, const Doc& sep = Text(", ")); + + private: + /*! \brief Internal doc stream. */ + std::vector stream_; +}; + +} // namespace tvm +#endif // TVM_PRINTER_DOC_H_ diff --git a/src/printer/meta_data.h b/src/printer/meta_data.h new file mode 100644 index 0000000000000..6c300fd851760 --- /dev/null +++ b/src/printer/meta_data.h @@ -0,0 +1,140 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +/*! + * \file tvm/printer/meta_data.h + * \brief Meta data context for printers. + */ +#ifndef TVM_PRINTER_META_DATA_H_ +#define TVM_PRINTER_META_DATA_H_ + +#include +#include +#include +#include +#include "doc.h" + +namespace tvm { +/*! + * \brief Meta data context for Printers + * + * This is an important part to enable bi-directional serializability. + * We use tvm's Node system to build the current IR. + * It can be hard to design a text format for all the possible nodes + * as the set of nodes can grow when we do more extensions. + * + * Instead of trying to design readable text format for every node, + * we support a meta data section in the text format. + * We allow the text format to refer to a node in the meta data section. + * + * The meta data section is a json serialized string of an Map>. + * Each element in the meta data section can be referenced by the text format. + * Each meta data node is printed in the following format. + * + * meta[type-key-of-node>][] + * + * Specifically, consider the following IR(constructed by python). + * + * \code + * + * n = tvm.var("n") + * x = tvm.relay.var("x", shape=(n, 1)) + * f = tvm.relay.Function([x], x) + * print(f.astext()) + * + * \endcode + * + * The corresponding text format is shown in the following code block. + * + * \code + * + * fn (%x: Tensor[(meta[Variable][0],), float32]) { + * %x + * } + * # Meta data section is a json-serialized string + * # of the following array. + * # [tvm.var("n")] + * + * \endcode + * + * Note that we store tvm.var("n") in the meta data section. + * Since it is stored in the index-0 in the meta data section, + * we print it as meta[Variable][0]. + * + * The text parser can recover this object by loading from the corresponding + * location in the meta data section. + * + * This is is a design trade-off. + * It allows us to embedded any meta data in the text format, + * while still being able to tweak the text part of the printed IR easily. + */ +class TextMetaDataContext { + public: + /*! + * \brief Get text representation of meta node. + * \param node The node to be converted to meta node. + * \return A string representation of the meta node. + */ + Doc GetMetaNode(const ObjectRef& node) { + auto it = meta_repr_.find(node); + if (it != meta_repr_.end()) { + return it->second; + } + std::string type_key = node->GetTypeKey(); + CHECK(!type_key.empty()); + Array& mvector = + meta_data_[type_key]; + int64_t index = static_cast(mvector.size()); + mvector.push_back(node); + Doc doc; + doc << "meta[" << type_key << "][" << index << "]"; + meta_repr_[node] = doc; + return meta_repr_[node]; + } + + /*! + * \brief Print a key value pair + */ + Doc PrintKeyValue(const std::string& str, const Doc& v) const { + return Doc() << "\"" << str << "\": " << v; + } + + /*! + * \brief Get the metadata section in json format. + * \return the meta data string. + */ + Doc GetMetaSection() const { + if (meta_data_.size() == 0) return Doc(); + return Doc::Text( + SaveJSON(Map(meta_data_.begin(), meta_data_.end()))); + } + + /*! \return whether the meta data context is empty. */ + bool empty() const { + return meta_data_.empty(); + } + + private: + /*! \brief additional metadata stored in TVM json format */ + std::unordered_map > meta_data_; + /*! \brief map from meta data into its string representation */ + std::unordered_map meta_repr_; +}; +} // namespace tvm +#endif // TVM_PRINTER_META_DATA_H_ diff --git a/src/relay/ir/pretty_printer.cc b/src/printer/relay_text_printer.cc similarity index 74% rename from src/relay/ir/pretty_printer.cc rename to src/printer/relay_text_printer.cc index c21f565f430c5..74979e9a4e691 100644 --- a/src/relay/ir/pretty_printer.cc +++ b/src/printer/relay_text_printer.cc @@ -18,9 +18,11 @@ */ /*! - * \file pretty_printer.cc - * \brief Pretty printer for Relay programs - * Supports ANF, GNF, and metadata. + * \file text_format_printer.cc + * \brief Printer to print out the IR text format + * that can be parsed by a parser. + * + * Supports ANF, GNF in relay and metadata. * * Inlining heuristics: * - Always inline: @@ -31,142 +33,27 @@ * - Otherwise, inline if the node is at the end of a scope and is used at most once. */ #include -#include -#include #include +#include #include #include "doc.h" -#include "../pass/dependency_graph.h" -#include "../../ir/attr_functor.h" +#include "meta_data.h" +#include "../relay/pass/dependency_graph.h" +#include "../ir/attr_functor.h" namespace tvm { namespace relay { -static const char* kSemVer = "v0.0.4"; - -Doc Brace(const Doc& d, - const std::string& open = "{", - const std::string& close = "}", - int indent = 2) { - Doc doc; - doc << open; - doc << Indent(indent, PrintNewLine() << d) << PrintNewLine(); - doc << close; - return doc; -} - -/*! - * \brief Meta data context for PrettyPrinter. - * - * This is an important part to enable bi-directional serializability. - * We use tvm's Node system to build the current IR. - * It can be hard to design a text format for all the possible nodes - * as the set of nodes can grow when we do more extensions. - * - * Instead of trying to design readable text format for every node, - * we support a meta data section in the text format. - * We allow the text format to refer to a node in the meta data section. - * - * The meta data section is a json serialized string of an Map>. - * Each element in the meta data section can be referenced by the text format. - * Each meta data node is printed in the following format. - * - * meta[type-key-of-node>][] - * - * Specifically, consider the following IR(constructed by python). - * - * \code - * - * n = tvm.var("n") - * x = tvm.relay.var("x", shape=(n, 1)) - * f = tvm.relay.Function([x], x) - * print(f.astext()) - * - * \endcode - * - * The corresponding text format is shown in the following code block. - * - * \code - * - * fn (%x: Tensor[(meta[Variable][0],), float32]) { - * %x - * } - * # Meta data section is a json-serialized string - * # of the following array. - * # [tvm.var("n")] - * - * \endcode - * - * Note that we store tvm.var("n") in the meta data section. - * Since it is stored in the index-0 in the meta data section, - * we print it as meta[Variable][0]. - * - * The text parser can recover this object by loading from the corresponding - * location in the meta data section. - * - * This is is a design trade-off. - * It allows us to embedded any meta data in the text format, - * while still being able to tweak the text part of the printed IR easily. - */ -class TextMetaDataContext { - public: - /*! - * \brief Get text representation of meta node. - * \param node The node to be converted to meta node. - * \return A string representation of the meta node. - */ - Doc GetMetaNode(const ObjectRef& node) { - auto it = meta_repr_.find(node); - if (it != meta_repr_.end()) { - return it->second; - } - std::string type_key = node->GetTypeKey(); - CHECK(!type_key.empty()); - Array& mvector = - meta_data_[type_key]; - int64_t index = static_cast(mvector.size()); - mvector.push_back(node); - Doc doc; - doc << "meta[" << type_key << "][" << index << "]"; - meta_repr_[node] = doc; - return meta_repr_[node]; - } - - Doc PrintKeyValue(const std::string& str, const Doc& v) const { - return Doc("\"") << str << "\": " << v; - } - - /*! - * \brief Get the metadata section in json format. - * \return the meta data string. - */ - Doc GetMetaSection() const { - if (meta_data_.size() == 0) return Doc(); - return Doc(SaveJSON(Map(meta_data_.begin(), meta_data_.end()))); - } - - /*! \return whether the meta data context is empty. */ - bool empty() const { - return meta_data_.empty(); - } - - private: - /*! \brief additional metadata stored in TVM json format */ - std::unordered_map > meta_data_; - /*! \brief map from meta data into its string representation */ - std::unordered_map meta_repr_; -}; - -class PrettyPrinter : +class RelayTextPrinter : public ExprFunctor, public PatternFunctor, public TypeFunctor, public AttrFunctor { public: - explicit PrettyPrinter(bool show_meta_data, - runtime::TypedPackedFunc annotate) : - show_meta_data_(show_meta_data), - annotate_(annotate) {} + explicit RelayTextPrinter(bool show_meta_data, + runtime::TypedPackedFunc annotate) + : show_meta_data_(show_meta_data), + annotate_(annotate) {} /*! * \brief Print additional info about expr in comment. @@ -194,7 +81,7 @@ class PrettyPrinter : Doc doc; Doc body; doc << "{"; - doc << Indent(indent, body << PrintNewLine() << PrintScope(node)) << PrintNewLine(); + doc << Doc::Indent(indent, body << Doc::NewLine() << PrintScope(node)) << Doc::NewLine(); doc << "}"; return doc; } @@ -220,10 +107,10 @@ class PrettyPrinter : Doc doc; doc << PrintScope(node); if (!meta_.empty()) { - doc << PrintNewLine(); + doc << Doc::NewLine(); if (show_meta_data_) { // append meta data in the end. - doc << "METADATA:" << PrintNewLine() << meta_.GetMetaSection(); + doc << "METADATA:" << Doc::NewLine() << meta_.GetMetaSection(); } else { doc << "// meta data omitted. you can use show_meta_data=True to include meta data"; } @@ -244,8 +131,9 @@ class PrettyPrinter : } else if (node.as()) { return PrintMod(Downcast(node)); } else { - Doc doc; - return doc << node; + std::ostringstream os; + os << node; + return Doc() << os.str(); } } @@ -278,23 +166,23 @@ class PrettyPrinter : } } name_alloc_map_[unique_prefix] = 0; - return Doc(unique_prefix); + return Doc::Text(unique_prefix); } Doc Print(Kind k) { switch (k) { case kType: - return Doc("Type"); + return Doc::Text("Type"); case kShapeVar: - return Doc("Shape"); + return Doc::Text("Shape"); case kBaseType: - return Doc("BaseType"); + return Doc::Text("BaseType"); case kConstraint: - return Doc("Constraint"); + return Doc::Text("Constraint"); case kAdtHandle: - return Doc("AdtHandle"); + return Doc::Text("AdtHandle"); case kTypeData: - return Doc("TypeData"); + return Doc::Text("TypeData"); default: LOG(ERROR) << "Unknown Kind"; throw; @@ -387,7 +275,7 @@ class PrettyPrinter : // wrap GNFed let in brackets Doc body; printed_expr << "("; - printed_expr << Indent(2, body << PrintNewLine() << VisitExpr(expr)) << PrintNewLine(); + printed_expr << Doc::Indent(2, body << Doc::NewLine() << VisitExpr(expr)) << Doc::NewLine(); printed_expr << ")"; } else { printed_expr = VisitExpr(expr); @@ -399,7 +287,7 @@ class PrettyPrinter : if (expr.as()) { // This is our first time visiting the var and we hit the VarNode case // in the visitor. Thus the variable is free. - doc_stack_.back() << "free_var " << printed_expr << PrintNewLine(); + doc_stack_.back() << "free_var " << printed_expr << Doc::NewLine(); // Memoization is done in AllocVar. return memo_[expr]; } else if (inline_expr) { @@ -408,7 +296,7 @@ class PrettyPrinter : } else { Doc temp_var = AllocTemp(); memo_[expr] = temp_var; - doc_stack_.back() << temp_var << " = " << printed_expr << ";" << PrintNewLine(); + doc_stack_.back() << temp_var << " = " << printed_expr << ";" << Doc::NewLine(); return temp_var; } } @@ -419,6 +307,28 @@ class PrettyPrinter : return AllocVar(GetRef(op)); } + /*! + * \brief special method to print out const scalar + * \param dtype The data type + * \param value The value to be printed. + */ + template + static Doc ScalarLiteral(DataType dtype, const T& value) { + std::ostringstream os; + if (dtype == DataType::Int(32)) { + os << value; + } else if (dtype == DataType::Float(32)) { + os << value << 'f'; + } else if (dtype == DataType::Float(64)) { + os << value; + } else if (dtype == DataType::Bool()) { + return Doc::PyBoolLiteral(value != 0); + } else { + os << value; + } + return Doc::Text(os.str()); + } + Doc VisitExpr_(const ConstantNode* op) final { // Print out simple scalars directly. if (op->is_scalar()) { @@ -426,15 +336,15 @@ class PrettyPrinter : DataType dtype = DataType(op->data->dtype); CHECK_EQ(op->data->ctx.device_type, kDLCPU); if (dtype == DataType::Int(32)) { - return PrintConstScalar(dtype, static_cast(op->data->data)); + return ScalarLiteral(dtype, static_cast(op->data->data)[0]); } else if (dtype == DataType::Int(64)) { - return PrintConstScalar(dtype, static_cast(op->data->data)); + return ScalarLiteral(dtype, static_cast(op->data->data)[0]); } else if (dtype == DataType::Float(32)) { - return PrintConstScalar(dtype, static_cast(op->data->data)); + return ScalarLiteral(dtype, static_cast(op->data->data)[0]); } else if (dtype == DataType::Float(64)) { - return PrintConstScalar(dtype, static_cast(op->data->data)); + return ScalarLiteral(dtype, static_cast(op->data->data)[0]); } else if (dtype == DataType::Bool()) { - return PrintConstScalar(dtype, static_cast(op->data->data)); + return ScalarLiteral(dtype, static_cast(op->data->data)[0]); } } // default fall-back, record it as meta node. @@ -448,7 +358,7 @@ class PrettyPrinter : fields.push_back(Print(field)); } Doc doc; - doc << "(" << PrintSep(fields); + doc << "(" << Doc::Concat(fields); // conform to python tuple format (1,) if (op->fields.size() == 1) { doc << ","; @@ -478,7 +388,7 @@ class PrettyPrinter : << " = " << Print(op->value, false, true) << ";" - << PrintNewLine(); + << Doc::NewLine(); // we use a scope here so GNF hoisting doesn't escape too far // and nested, unique lets are not hoisted doc << PrintScope(op->body); @@ -492,9 +402,9 @@ class PrettyPrinter : doc << "["; std::vector type_params; for (const TypeVar& tv : fn->type_params) { - type_params.push_back(Doc(tv->name_hint)); + type_params.push_back(Doc::Text(tv->name_hint)); } - doc << PrintSep(type_params); + doc << Doc::Concat(type_params); doc << "]"; } doc << "("; @@ -505,7 +415,7 @@ class PrettyPrinter : for (const Doc& d : PrintFuncAttrs(fn->attrs)) { params.push_back(d); } - doc << PrintSep(params) << ") "; + doc << Doc::Concat(params) << ") "; if (fn->ret_type.defined()) { doc << "-> " << Print(fn->ret_type) << " "; } @@ -530,36 +440,36 @@ class PrettyPrinter : // type definitions for (const auto& kv : mod->type_definitions) { if (counter++ != 0) { - doc << PrintNewLine(); + doc << Doc::NewLine(); } doc << Print(kv.second); - doc << PrintNewLine(); + doc << Doc::NewLine(); } // functions for (const auto& kv : mod->functions) { dg_ = DependencyGraph::Create(&arena_, kv.second); if (counter++ != 0) { - doc << PrintNewLine(); + doc << Doc::NewLine(); } std::ostringstream os; os << "def @" << kv.first->name_hint; - doc << PrintFunc(Doc(os.str()), kv.second); - doc << PrintNewLine(); + doc << PrintFunc(Doc::Text(os.str()), kv.second); + doc << Doc::NewLine(); } return doc; } Doc VisitExpr_(const FunctionNode* op) final { - return PrintFunc(Doc("fn "), GetRef(op)); + return PrintFunc(Doc::Text("fn "), GetRef(op)); } Doc VisitExpr_(const GlobalVarNode* op) final { - return Doc('@' + op->name_hint); + return Doc::Text('@' + op->name_hint); } Doc VisitExpr_(const OpNode* op) final { - return Doc(op->name); + return Doc::Text(op->name); } Doc VisitExpr_(const CallNode* op) final { @@ -584,7 +494,7 @@ class PrettyPrinter : // don't print as a call if it's a 0-arity cons return doc; } else { - return doc << "(" << PrintSep(args) << ")"; + return doc << "(" << Doc::Concat(args) << ")"; } } @@ -619,13 +529,13 @@ class PrettyPrinter : Doc rhs_doc = PrintScope(clause->rhs); if (clause->rhs.as()) { // only add braces if there are multiple lines on the rhs - rhs_doc = Brace(rhs_doc); + rhs_doc = Doc::Brace("{", rhs_doc, "}"); } clause_doc << rhs_doc << ","; clause_docs.push_back(clause_doc); } - doc << Indent(2, body << PrintNewLine() << PrintSep(clause_docs, PrintNewLine())) - << PrintNewLine() << "}"; + doc << Doc::Indent(2, body << Doc::NewLine() << Doc::Concat(clause_docs, Doc::NewLine())) + << Doc::NewLine() << "}"; return doc; } @@ -651,7 +561,7 @@ class PrettyPrinter : for (const auto& pat : p->patterns) { pats.push_back(Print(pat)); } - doc << PrintSep(pats) << ")"; + doc << Doc::Concat(pats) << ")"; } return doc; } @@ -663,12 +573,12 @@ class PrettyPrinter : for (const auto& pat : pt->patterns) { pats.push_back(Print(pat)); } - doc << PrintSep(pats) << ")"; + doc << Doc::Concat(pats) << ")"; return doc; } Doc VisitPattern_(const PatternWildcardNode* pw) final { - return Doc("_"); + return Doc::Text("_"); } Doc VisitPattern_(const PatternVarNode* pv) final { @@ -684,7 +594,7 @@ class PrettyPrinter : for (Type input : n->inputs) { inputs.push_back(Print(input)); } - doc << PrintSep(inputs) << ")"; + doc << Doc::Concat(inputs) << ")"; } return doc; } @@ -711,11 +621,11 @@ class PrettyPrinter : } Doc VisitType_(const TypeVarNode* node) final { - return Doc(node->name_hint); + return Doc::Text(node->name_hint); } Doc VisitType_(const GlobalTypeVarNode* node) final { - return Doc(node->name_hint); + return Doc::Text(node->name_hint); } Doc VisitType_(const TypeCallNode* node) final { @@ -725,11 +635,15 @@ class PrettyPrinter : args.push_back(PrintType(t, false)); } doc << "["; - doc << PrintSep(args); + doc << Doc::Concat(args); doc << "]"; return doc; } + Doc PrintDType(DataType dtype) { + return Doc::Text(runtime::DLDataType2String(dtype)); + } + Doc VisitType_(const TensorTypeNode* node) final { // scalar type if (node->shape.size() == 0) { @@ -741,7 +655,7 @@ class PrettyPrinter : for (ObjectRef shape : node->shape) { shapes.push_back(PrintAttr(shape)); } - doc << PrintSep(shapes); + doc << Doc::Concat(shapes); return doc << "), " << PrintDType(node->dtype) << "]"; } @@ -751,7 +665,7 @@ class PrettyPrinter : fields.push_back(Print(field)); } Doc doc; - doc << "(" << PrintSep(fields); + doc << "(" << Doc::Concat(fields); // conform to python tuple format (1,) if (node->fields.size() == 1) { doc << ","; @@ -768,14 +682,14 @@ class PrettyPrinter : for (Type type_param : node->type_params) { type_params.push_back(Print(type_param)); } - doc << PrintSep(type_params); + doc << Doc::Concat(type_params); doc << "]"; } std::vector arg_types; for (Type arg_type : node->arg_types) { arg_types.push_back(Print(arg_type)); } - return doc << "(" << PrintSep(arg_types) << ") -> " << Print(node->ret_type); + return doc << "(" << Doc::Concat(arg_types) << ") -> " << Print(node->ret_type); } Doc VisitType_(const RelayRefTypeNode* node) final { @@ -795,7 +709,7 @@ class PrettyPrinter : for (Type type_var : node->type_vars) { type_vars.push_back(Print(type_var)); } - doc << PrintSep(type_vars) << "]"; + doc << Doc::Concat(type_vars) << "]"; } doc << " "; @@ -804,14 +718,14 @@ class PrettyPrinter : constructor_docs.push_back(Print(constructor, /* meta */ false, /* try_inline */ true)); } Doc separator; - separator << "," << PrintNewLine(); + separator << "," << Doc::NewLine(); Doc adt_body; - adt_body << PrintSep(constructor_docs, separator); + adt_body << Doc::Concat(constructor_docs, separator); // add trailing comma if there are any constructors if (!constructor_docs.empty()) { adt_body << ","; } - doc << Brace(adt_body); + doc << Doc::Brace("{", adt_body, "}"); in_adt_def_ = false; return doc; } @@ -832,7 +746,7 @@ class PrettyPrinter : } return printed_attr; } else { - return Doc("None"); + return Doc::Text("None"); } } @@ -847,28 +761,28 @@ class PrettyPrinter : for (auto val : op->data) { arr_vals.push_back(PrintAttr(val)); } - doc << PrintSep(arr_vals); + doc << Doc::Concat(arr_vals); doc << "]"; return doc; } Doc VisitAttr_(const tir::IntImmNode* op) final { - return PrintConstScalar(op->dtype, &(op->value)); + return ScalarLiteral(op->dtype, op->value); } Doc VisitAttr_(const tir::FloatImmNode* op) final { - return PrintConstScalar(op->dtype, &(op->value)); + return ScalarLiteral(op->dtype, op->value); } Doc VisitAttr_(const tir::StringImmNode* op) final { - return PrintString(op->value); + return Doc::StrLiteral(op->value); } private: /*! \brief Whether to print meta data. */ bool show_meta_data_; /*! \brief additional comment function */ - runtime::TypedPackedFunc annotate_; + runtime::TypedPackedFunc annotate_; /*! \brief Stack of docs to implement scoped GNFing. */ std::vector doc_stack_{}; /*! \brief Map from Expr to Doc */ @@ -896,9 +810,11 @@ class PrettyPrinter : /*! * \brief Attribute printer which prints the attributes in the call. */ -class PrettyPrinter::AttrPrinter : public AttrVisitor { +class RelayTextPrinter::AttrPrinter : + public AttrVisitor { public: - AttrPrinter(std::vector* doc, PrettyPrinter* parent) : docs(doc), parent_(parent) {} + AttrPrinter(std::vector* doc, RelayTextPrinter* parent) + : docs(doc), parent_(parent) {} template void PrintKV(const char* key, const T& value) { @@ -922,16 +838,16 @@ class PrettyPrinter::AttrPrinter : public AttrVisitor { PrintKV(key, *value); } void Visit(const char* key, bool* value) final { - PrintKV(key, PrintBool(*value)); + PrintKV(key, Doc::PyBoolLiteral(*value)); } void Visit(const char* key, std::string* value) final { - PrintKV(key, PrintString(*value)); + PrintKV(key, Doc::StrLiteral(*value)); } void Visit(const char* key, void** value) final { LOG(FATAL) << "do not allow void as argument"; } void Visit(const char* key, DataType* value) final { - PrintKV(key, PrintString(runtime::DLDataType2String(*value))); + PrintKV(key, Doc::StrLiteral(runtime::DLDataType2String(*value))); } void Visit(const char* key, runtime::NDArray* value) final { LOG(FATAL) << "do not allow NDarray as argument"; @@ -942,10 +858,11 @@ class PrettyPrinter::AttrPrinter : public AttrVisitor { private: std::vector* docs; - PrettyPrinter* parent_; + RelayTextPrinter* parent_; }; -std::vector PrettyPrinter::PrintCallAttrs(const Attrs& attrs, const Expr& op) { +std::vector RelayTextPrinter::PrintCallAttrs( + const Attrs& attrs, const Expr& op) { std::vector docs; if (!attrs.defined()) return docs; const auto* op_node = op.as(); @@ -962,7 +879,7 @@ std::vector PrettyPrinter::PrintCallAttrs(const Attrs& attrs, const Expr& o } } -std::vector PrettyPrinter::PrintFuncAttrs(const Attrs& attrs) { +std::vector RelayTextPrinter::PrintFuncAttrs(const Attrs& attrs) { std::vector docs; if (!attrs.defined()) return docs; const auto* dict_attrs = attrs.as(); @@ -974,30 +891,27 @@ std::vector PrettyPrinter::PrintFuncAttrs(const Attrs& attrs) { } return docs; } +} // namespace relay -std::string PrettyPrint_(const ObjectRef& node, - bool show_meta_data, - runtime::TypedPackedFunc annotate) { - Doc doc; - doc << kSemVer << PrintNewLine() - << PrettyPrinter(show_meta_data, annotate).PrintFinal(node); - return doc.str(); -} - -std::string PrettyPrint(const ObjectRef& node) { +static const char* kSemVer = "v0.0.4"; +// TODO(tvm-team): split into files, related: arith/analyzer.h +// +// - text_printer.h (common header) +// - text_printer.cc (prints modules dispatch into relay and tir files) +// - type_text_printer.cc(specific printing logics for types, +// can also consider put under type_text_printer) +// - Implements AsText +// - relay_text_printer.cc (specific printing logics for relay) +// - tir_text_printer.cc (specific printing logics for TIR) +std::string AsText(const ObjectRef& node, + bool show_meta_data, + runtime::TypedPackedFunc annotate) { Doc doc; - doc << PrettyPrinter(false, runtime::TypedPackedFunc()).PrintFinal(node); + doc << kSemVer << Doc::NewLine() + << relay::RelayTextPrinter(show_meta_data, annotate).PrintFinal(node); return doc.str(); } -std::string AsText(const ObjectRef& node, - bool show_meta_data, - runtime::TypedPackedFunc annotate) { - return PrettyPrint_(node, show_meta_data, annotate); -} - TVM_REGISTER_GLOBAL("relay._expr.AsText") .set_body_typed(AsText); - -} // namespace relay } // namespace tvm diff --git a/src/relay/ir/doc.cc b/src/relay/ir/doc.cc deleted file mode 100644 index 26aec39e5282b..0000000000000 --- a/src/relay/ir/doc.cc +++ /dev/null @@ -1,126 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one - * or more contributor license agreements. See the NOTICE file - * distributed with this work for additional information - * regarding copyright ownership. The ASF licenses this file - * to you under the Apache License, Version 2.0 (the - * "License"); you may not use this file except in compliance - * with the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, - * software distributed under the License is distributed on an - * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY - * KIND, either express or implied. See the License for the - * specific language governing permissions and limitations - * under the License. - */ - -/*! - * \file src/tvm/relay/doc.cc - * \brief Doc ADT used for pretty printing. - * Based on Section 1 of https://homepages.inf.ed.ac.uk/wadler/papers/prettier/prettier.pdf. - */ -#include -#include -#include "doc.h" - -namespace tvm { -namespace relay { - -// Text constructor -DocAtom Text(const std::string& str) { - return std::make_shared(str); -} - -// Line constructor -DocAtom Line(int indent = 0) { - return std::make_shared(indent); -} - -Doc::Doc(const std::string& str) { - if (str == "\n") { - this->stream_ = {Line()}; - } else { - this->stream_ = {Text(str)}; - } -} - -// DSL function implementations - -Doc& Doc::operator<<(const Doc& right) { - CHECK(this != &right); - this->stream_.insert(this->stream_.end(), right.stream_.begin(), right.stream_.end()); - return *this; -} - -Doc& Doc::operator<<(const std::string& right) { - return *this << Doc(right); -} - -Doc& Doc::operator<<(const DocAtom& right) { - this->stream_.push_back(right); - return *this; -} - -Doc Indent(int indent, const Doc& doc) { - Doc ret; - for (auto atom : doc.stream_) { - if (auto text = std::dynamic_pointer_cast(atom)) { - ret.stream_.push_back(text); - } else if (auto line = std::dynamic_pointer_cast(atom)) { - ret.stream_.push_back(Line(indent + line->indent)); - } else {CHECK(false);} - } - return ret; -} - -std::string Doc::str() { - std::ostringstream os; - for (auto atom : this->stream_) { - if (auto text = std::dynamic_pointer_cast(atom)) { - os << text->str; - } else if (auto line = std::dynamic_pointer_cast(atom)) { - os << "\n" << std::string(line->indent, ' '); - } else {CHECK(false);} - } - return os.str(); -} - -Doc PrintSep(const std::vector& vec, const Doc& sep) { - Doc seq; - if (vec.size() != 0) { - seq = vec[0]; - for (size_t i = 1; i < vec.size(); i++) { - seq << sep << vec[i]; - } - } - return seq; -} - -Doc PrintBool(bool value) { - if (value) { - return Doc("True"); - } else { - return Doc("False"); - } -} - -Doc PrintDType(DataType dtype) { - return Doc(runtime::DLDataType2String(dtype)); -} - -Doc PrintString(const std::string& value) { - // TODO(M.K.): add escape. - Doc doc; - return doc << "\"" << value << "\""; -} - -Doc PrintNewLine(int ident) { - Doc doc; - return doc << Line(ident); -} - -} // namespace relay -} // namespace tvm diff --git a/src/relay/ir/doc.h b/src/relay/ir/doc.h deleted file mode 100644 index a41fd6145d26b..0000000000000 --- a/src/relay/ir/doc.h +++ /dev/null @@ -1,130 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one - * or more contributor license agreements. See the NOTICE file - * distributed with this work for additional information - * regarding copyright ownership. The ASF licenses this file - * to you under the Apache License, Version 2.0 (the - * "License"); you may not use this file except in compliance - * with the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, - * software distributed under the License is distributed on an - * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY - * KIND, either express or implied. See the License for the - * specific language governing permissions and limitations - * under the License. - */ - -/*! - * \file tvm/relay/doc.h - * \brief Doc ADT used for pretty printing. - * Based on Section 1 of - * https://homepages.inf.ed.ac.uk/wadler/papers/prettier/prettier.pdf, but with - * a vector instead of an implicitly linked list. - */ -#ifndef TVM_RELAY_IR_DOC_H_ -#define TVM_RELAY_IR_DOC_H_ - -#include -#include -#include -#include - -namespace tvm { -namespace relay { - -// Doc Atom ADT -struct DocAtomNode { - virtual ~DocAtomNode() = default; -}; - -using DocAtom = std::shared_ptr; - -struct TextNode : DocAtomNode { - std::string str; - - explicit TextNode(const std::string& str) : str(str) { - if (str.find_first_of("\t\n") != str.npos) { - LOG(WARNING) << "text node: '" << str << "' should not has tab or newline."; - } - } -}; - -struct LineNode : DocAtomNode { - int indent; - - explicit LineNode(int indent) : indent(indent) {} -}; - -// Doc is a stream-like interface -class Doc { - public: - Doc() {} - explicit Doc(const std::string& str); - template - explicit Doc(const T& str) { - (*this) << str; - } - - // Append right to this. - Doc& operator<<(const Doc& right); - // Like above. - Doc& operator<<(const std::string& right); - // Like above. - Doc& operator<<(const DocAtom& right); - // Like above, but converts right to a string first. - template - Doc& operator<<(const T& right) { - std::ostringstream os; - os << right; - return *this << os.str(); - } - - // Indent a doc stream. - friend Doc Indent(int indent, const Doc& doc); - - // Wadler's `layout` - std::string str(); - - private: - std::vector stream_; -}; - -// DSL functions - -// Render vectors of docs with a separator. e.g. PrintSep([1, 2, 3], f) -> 1f2f3 -Doc PrintSep(const std::vector& vec, const Doc& sep = Doc(", ")); -// Print a constant bool value. -Doc PrintBool(bool value); -// Print a data type. -Doc PrintDType(DataType dtype); -// Print a string. -Doc PrintString(const std::string& value); -// Print a newline. -Doc PrintNewLine(int indent = 0); -/*! - * \brief special method to print out const scalar - * \param dtype The data type - * \param data The pointer to hold the data. - */ -template -Doc PrintConstScalar(DataType dtype, const T* data) { - std::ostringstream os; - if (dtype == DataType::Int(32)) { - os << data[0]; - } else if (dtype == DataType::Float(32)) { - os << data[0] << 'f'; - } else if (dtype == DataType::Bool()) { - return PrintBool(data[0] != 0); - } else { - // todo(@M.K.) this is unsafe. fix. - os << data[0]; - } - return Doc(os.str()); -} - -} // namespace relay -} // namespace tvm -#endif // TVM_RELAY_IR_DOC_H_ diff --git a/src/relay/ir/expr.cc b/src/relay/ir/expr.cc index 3d8cc3a85b2bd..8151db416b748 100644 --- a/src/relay/ir/expr.cc +++ b/src/relay/ir/expr.cc @@ -21,6 +21,7 @@ * \file src/tvm/relay/ir/expr.cc * \brief The expression AST nodes of Relay. */ +#include #include namespace tvm { @@ -362,5 +363,9 @@ TVM_REGISTER_GLOBAL("relay._expr.FunctionSetAttr") TVM_REGISTER_GLOBAL("relay._make.Any") .set_body_typed([]() { return Any::make(); }); +std::string PrettyPrint(const ObjectRef& node) { + return AsText(node, false, nullptr); +} + } // namespace relay } // namespace tvm diff --git a/src/relay/pass/fuse_ops.cc b/src/relay/pass/fuse_ops.cc index 3d37d61448f0e..fd8e6fd0f59f5 100644 --- a/src/relay/pass/fuse_ops.cc +++ b/src/relay/pass/fuse_ops.cc @@ -957,7 +957,7 @@ class FuseMutator : private ExprMutator { // Debug function, dump the group assignment in text. void DebugDumpGroup(const Expr& body) { - std::string text = AsText(body, false, [this](const Expr& expr) -> std::string { + std::string text = AsText(body, false, [this](const ObjectRef& expr) -> std::string { auto it = gmap_.find(expr.get()); if (it == gmap_.end()) return ""; std::ostringstream os;