From e432404c170dad6f5b4d3ad8591d90057412f90c Mon Sep 17 00:00:00 2001 From: Ruihang Lai Date: Thu, 11 Nov 2021 15:02:16 +0800 Subject: [PATCH] [SparseTIR] ReprPrinter for Axis and SpIterVar (#16) --- include/tvm/tir/sparse.h | 19 +++++++++++++ src/tir/ir/sparse.cc | 59 ++++++++++++++++++++++++++++++++++++++++ 2 files changed, 78 insertions(+) diff --git a/include/tvm/tir/sparse.h b/include/tvm/tir/sparse.h index a0dd5db19107..8621c2b572e7 100644 --- a/include/tvm/tir/sparse.h +++ b/include/tvm/tir/sparse.h @@ -382,6 +382,9 @@ enum class SpIterKind : int { kSparseVariable = 3 }; +// overload printing of for type. +TVM_DLL std::ostream& operator<<(std::ostream& os, SpIterKind kind); + /*! * \brief Iterator variables in SparseTIR */ @@ -437,6 +440,22 @@ class SpIterVar : public ObjectRef { // inline implementations inline SpIterVar::operator PrimExpr() const { return (*this)->var; } +// inline implementations +inline const char* SpIterKind2String(SpIterKind t) { + switch (t) { + case SpIterKind::kDenseFixed: + return "dense_fixed"; + case SpIterKind::kDenseVariable: + return "dense_variable"; + case SpIterKind::kSparseFixed: + return "sparse_fixed"; + case SpIterKind::kSparseVariable: + return "sparse_variable"; + } + LOG(FATAL) << "Unknown SpIterKind" << t; + throw; +} + } // namespace tir } // namespace tvm diff --git a/src/tir/ir/sparse.cc b/src/tir/ir/sparse.cc index 6a59dd0a5e5b..f782eea32e74 100644 --- a/src/tir/ir/sparse.cc +++ b/src/tir/ir/sparse.cc @@ -58,6 +58,16 @@ TVM_REGISTER_GLOBAL("tir.sparse.DenseFixedAxis") return DenseFixedAxis(name, length, from_sparse); }); +TVM_STATIC_IR_FUNCTOR(ReprPrinter, vtable) + .set_dispatch([](const ObjectRef& node, ReprPrinter* p) { + auto* op = static_cast(node.get()); + p->stream << "dense_fixed(" << op->name << ", " << op->length; + if (op->from_sparse.defined()) { + p->stream << ", from_sparse=" << op->from_sparse.value(); + } + p->stream << ")"; + }); + // DenseVariableAxis DenseVariableAxis::DenseVariableAxis(String name, PrimExpr length, Buffer indptr) { ObjectPtr node = make_object(); @@ -74,6 +84,12 @@ TVM_REGISTER_GLOBAL("tir.sparse.DenseVariableAxis") return DenseVariableAxis(name, length, indptr); }); +TVM_STATIC_IR_FUNCTOR(ReprPrinter, vtable) + .set_dispatch([](const ObjectRef& node, ReprPrinter* p) { + auto* op = static_cast(node.get()); + p->stream << "dense_variable(" << op->name << ", " << op->length << ", " << op->indptr->name; + }); + // SparseFixedAxis SparseFixedAxis::SparseFixedAxis(String name, PrimExpr length, Buffer indices, PrimExpr num_cols) { ObjectPtr node = make_object(); @@ -91,6 +107,13 @@ TVM_REGISTER_GLOBAL("tir.sparse.SparseFixedAxis") return SparseFixedAxis(name, length, indices, num_cols); }); +TVM_STATIC_IR_FUNCTOR(ReprPrinter, vtable) + .set_dispatch([](const ObjectRef& node, ReprPrinter* p) { + auto* op = static_cast(node.get()); + p->stream << "sparse_fixed(" << op->name << ", " << op->length << ", " << op->num_cols << ", " + << op->indices->name << ")"; + }); + // SparseVariableAxis SparseVariableAxis::SparseVariableAxis(String name, PrimExpr length, Buffer indptr, Buffer indices) { @@ -109,6 +132,13 @@ TVM_REGISTER_GLOBAL("tir.sparse.SparseVariableAxis") return SparseVariableAxis(name, length, indptr, indices); }); +TVM_STATIC_IR_FUNCTOR(ReprPrinter, vtable) + .set_dispatch([](const ObjectRef& node, ReprPrinter* p) { + auto* op = static_cast(node.get()); + p->stream << "sparse_variable(" << op->name << ", " << op->length << ", " << op->indptr->name + << ", " << op->indices->name << ")"; + }); + // AxisTree AxisTree::AxisTree(Array axis_names, Array> axis_parent_names) { CHECK_EQ(axis_names.size(), axis_parent_names.size()) @@ -178,6 +208,27 @@ TVM_STATIC_IR_FUNCTOR(ReprPrinter, vtable) p->stream << "], " << op->data << ")"; }); +// SpIterKind +std::ostream& operator<<(std::ostream& out, SpIterKind type) { + switch (type) { + case SpIterKind::kDenseFixed: + out << "dense-fixed"; + break; + case SpIterKind::kDenseVariable: + out << "dense-variable"; + break; + case SpIterKind::kSparseFixed: + out << "sparse-fixed"; + break; + case SpIterKind::kSparseVariable: + out << "sparse-variable"; + break; + default: + LOG(FATAL) << "Cannot reach here"; + } + return out; +} + // SpIterVar SpIterVar::SpIterVar(Var var, PrimExpr max_extent, SpIterKind kind, bool is_reduction, Axis axis) { ObjectPtr node = make_object(); @@ -210,5 +261,13 @@ TVM_REGISTER_GLOBAL("tir.sparse.SpIterVar") return SpIterVar(var, max_extent, SpIterKind(kind), is_reduction, axis); }); +TVM_STATIC_IR_FUNCTOR(ReprPrinter, vtable) + .set_dispatch([](const ObjectRef& node, ReprPrinter* p) { + auto* op = static_cast(node.get()); + p->stream << "sp_iter_var(" << op->var->name_hint << ", " << op->max_extent << ", " + << op->kind << ", " << (op->is_reduction ? "reduction" : "spatial") << ", " + << op->axis->name << ")"; + }); + } // namespace tir } // namespace tvm