diff --git a/src/printer/text_printer.h b/src/printer/text_printer.h index 65b4a812ce38..7baa3878ff72 100644 --- a/src/printer/text_printer.h +++ b/src/printer/text_printer.h @@ -320,6 +320,7 @@ class TIRTextPrinter : public StmtFunctor, Doc PrintIterVar(const IterVarNode* op); Doc PrintRange(const RangeNode* op); Doc PrintBuffer(const BufferNode* op); + Doc BufferNode2Doc(const BufferNode* op, Doc doc); Doc PrintString(const StringObj* op) { return Doc::StrLiteral(op->data); } /*! diff --git a/src/printer/tir_text_printer.cc b/src/printer/tir_text_printer.cc index 0de7215ac71e..8e86fb55f226 100644 --- a/src/printer/tir_text_printer.cc +++ b/src/printer/tir_text_printer.cc @@ -116,25 +116,7 @@ Doc TIRTextPrinter::PrintPrimFunc(const PrimFunc& prim_func) { std::vector buffer_docs; for (const auto& it : memo_buf_) { const auto& buf = it.first; - buffer_docs.push_back(Print(buf) << Doc::Text(": Buffer(") << Print(buf->data) << ", " - << PrintDType(buf->dtype) << ", " << Print(buf->shape) - << ", " << Print(buf->strides)); - if (!is_zero(buf->elem_offset)) { - buffer_docs.back() << ", elem_offset=" << Print(buf->elem_offset); - } - if (buf->scope != "global") { - buffer_docs.back() << ", scope=" << Doc::StrLiteral(buf->scope); - } - if (buf->data_alignment != 128) { - buffer_docs.back() << ", align=" << buf->data_alignment; - } - if (buf->offset_factor != 1) { - buffer_docs.back() << ", offset_factor=" << buf->offset_factor; - } - if (buf->buffer_type != 1) { - buffer_docs.back() << ", type=" << Doc::StrLiteral("auto"); - } - buffer_docs.back() << ")"; + buffer_docs.push_back(BufferNode2Doc(buf.get(), Print(buf))); } buffer_doc << Doc::NewLine() << "buffers = {"; buffer_doc << PrintSep(buffer_docs, Doc::Indent(11, Doc::Text(",") << Doc::NewLine())); @@ -203,8 +185,36 @@ Doc TIRTextPrinter::PrintRange(const RangeNode* op) { Doc TIRTextPrinter::PrintBuffer(const BufferNode* op) { const Buffer& buffer = GetRef(op); - CHECK_GT(memo_buf_.count(buffer), 0); - return meta_->InMeta(buffer) ? meta_->GetMetaNode(buffer) : memo_buf_[buffer]; + + if (meta_->InMeta(buffer)) { + return meta_->GetMetaNode(buffer); + } else if (memo_buf_.count(buffer)) { + return memo_buf_[buffer]; + } else { + memo_buf_[buffer] = AllocBuf(buffer); + return BufferNode2Doc(op, memo_buf_[buffer]); + } +} + +Doc TIRTextPrinter::BufferNode2Doc(const BufferNode* buf, Doc doc) { + doc << Doc::Text(": Buffer(") << Print(buf->data) << ", " << PrintDType(buf->dtype) << ", " + << Print(buf->shape) << ", " << Print(buf->strides); + if (!is_zero(buf->elem_offset)) { + doc << ", elem_offset=" << Print(buf->elem_offset); + } + if (buf->scope != "global") { + doc << ", scope=" << Doc::StrLiteral(buf->scope); + } + if (buf->data_alignment != 128) { + doc << ", align=" << buf->data_alignment; + } + if (buf->offset_factor != 1) { + doc << ", offset_factor=" << buf->offset_factor; + } + if (buf->buffer_type != 1) { + doc << ", type=" << Doc::StrLiteral("auto"); + } + return doc << ")"; } Doc TIRTextPrinter::VisitExprDefault_(const Object* op) {