diff --git a/include/tvm/tir/sparse.h b/include/tvm/tir/sparse.h index 4a87f264b1ce..c083bb1e3efb 100644 --- a/include/tvm/tir/sparse.h +++ b/include/tvm/tir/sparse.h @@ -45,18 +45,35 @@ enum class AxisKind : int { */ class AxisNode : public Object { public: + void VisitAttrs(AttrVisitor* v) { + v->Visit("name", &name); + v->Visit("length", &length); + v->Visit("is_derived_axis", &is_derived_axis); + } + + bool SEqualReduce(const AxisNode* other, SEqualReducer equal) const { + return equal(name, other->name) && equal(length, other->length) && + equal(is_derived_axis, other->is_derived_axis); + } + + void SHashReduce(SHashReducer hash_reduce) const { + hash_reduce(name); + hash_reduce(length); + hash_reduce(is_derived_axis); + } + /* name of current axis. */ String name; /* length of current axis. For sparse axis, length refers to the upperbound of * the current axis. */ PrimExpr length; + /* indicates whether current axis is derived by dense(axis) or fuse(axis1, axis2, ...) */ + bool is_derived_axis = false; String GetName() const { return name; } PrimExpr GetLength() const { return length; } DataType GetIndexType() const { return length->dtype; } - - virtual bool is_fixed() const = 0; - + virtual AxisKind kind() const = 0; static constexpr const char* _type_key = "tir.sparse.Axis"; @@ -74,24 +91,6 @@ class Axis : public ObjectRef { TVM_DEFINE_OBJECT_REF_METHODS(Axis, ObjectRef, AxisNode); }; -/*! - * \brief Root of Axis Dependency Tree. - */ -class RootAxisNode : public Object { - public: - static constexpr const char* _type_key = "tir.sparse.RootAxis"; - TVM_DECLARE_FINAL_OBJECT_INFO(RootAxisNode, Object); -}; - -/*! - * \brief Managed reference to RootAxisNode. - * \sa RootAxisNode - */ -class RootAxis : public ObjectRef { - public: - TVM_DEFINE_OBJECT_REF_METHODS(RootAxis, ObjectRef, RootAxisNode); -}; - /*! * \brief Dense axis whose column indices are consecutive. 
*/ @@ -133,84 +132,134 @@ class SparseAxis : public Axis { */ class DenseFixedAxisNode : public DenseAxisNode { public: - Optional from_sparse; + AxisKind kind() const final { return AxisKind::kDenseFixed; } + + static constexpr const char* _type_key = "tir.sparse.DenseFixedAxis"; + TVM_DECLARE_BASE_OBJECT_INFO(DenseFixedAxisNode, DenseAxisNode); +}; + +/*! + * \brief Managed reference to DenseFixedAxisNode. + * \sa DenseFixedAxisNode + */ +class DenseFixedAxis : public DenseAxis { + public: + TVM_DLL explicit DenseFixedAxis(String name, PrimExpr length); + + TVM_DEFINE_OBJECT_REF_METHODS(DenseFixedAxis, DenseAxis, DenseFixedAxisNode); +}; +/*! \brief Derivation axis, constructed by T.dense(axis). */ +class DenseFromSparseAxisNode : public DenseFixedAxisNode { + public: void VisitAttrs(AttrVisitor* v) { - v->Visit("name", &name); - v->Visit("length", &length); - v->Visit("from_sparse", &from_sparse); + DenseFixedAxisNode::VisitAttrs(v); + v->Visit("base", &base); } - bool SEqualReduce(const DenseFixedAxisNode* other, SEqualReducer equal) const { - return equal(name, other->name) && equal(length, other->length) && - equal(from_sparse, other->from_sparse); + bool SEqualReduce(const DenseFromSparseAxisNode* other, SEqualReducer equal) const { + return DenseFixedAxisNode::SEqualReduce(other, equal) && equal(base, other->base); } void SHashReduce(SHashReducer hash_reduce) const { - hash_reduce(name); - hash_reduce(length); - hash_reduce(from_sparse); + DenseFixedAxisNode::SHashReduce(hash_reduce); + hash_reduce(base); + } + + /* The based sparse axis. */ + SparseAxis base; + + static constexpr const char* _type_key = "tir.sparse.DenseFromSparseAxis"; + TVM_DECLARE_FINAL_OBJECT_INFO(DenseFromSparseAxisNode, DenseFixedAxisNode); +}; + +/*! + * \brief Managed reference of DenseFromSparseAxisNode. 
+ * \sa DenseFromSparseAxisNode + */ +class DenseFromSparseAxis : public DenseFixedAxis { + public: + /* DenseFromSparseAxis could be constructed by specifying the base sparse axis. */ + TVM_DLL explicit DenseFromSparseAxis(SparseAxis base); + + TVM_DEFINE_OBJECT_REF_METHODS(DenseFromSparseAxis, DenseFixedAxis, DenseFromSparseAxisNode); +}; + +class FusedAxis; + +/*! \brief Derivation axis, constructed by T.fuse(axis1, axis2, ...) */ +class FusedAxisNode : public DenseFixedAxisNode { + public: + void VisitAttrs(AttrVisitor* v) { + DenseFixedAxisNode::VisitAttrs(v); + v->Visit("group", &group); + v->Visit("index", &index); } - bool is_fixed() const final{ - return true; + bool SEqualReduce(const FusedAxisNode* other, SEqualReducer equal) const { + return DenseFixedAxisNode::SEqualReduce(other, equal) && equal(group, other->group) && + equal(index, other->index); } - AxisKind kind() const final { - return AxisKind::kDenseFixed; + void SHashReduce(SHashReducer hash_reduce) const { + DenseFixedAxisNode::SHashReduce(hash_reduce); + hash_reduce(group); + hash_reduce(index); } - static constexpr const char* _type_key = "tir.sparse.DenseFixedAxis"; - TVM_DECLARE_FINAL_OBJECT_INFO(DenseFixedAxisNode, DenseAxisNode); + /* The group of axes to be fused. */ + Array group; + /* The index of current FusedAxis in the group. */ + int index; + + static constexpr const char* _type_key = "tir.sparse.FusedAxis"; + TVM_DECLARE_FINAL_OBJECT_INFO(FusedAxisNode, DenseFixedAxisNode); }; /*! - * \brief Managed reference to DenseFixedAxisNode. - * \sa DenseFixedAxisNode + * \brief Managed reference to FusedAxisNode. 
+ * \sa FusedAxisNode */ -class DenseFixedAxis : public DenseAxis { +class FusedAxis : public DenseFixedAxis { public: - TVM_DLL explicit DenseFixedAxis(String name, PrimExpr length, - Optional from_sparse = NullOpt); + /* Fused axis could be constructed by specifying a group of based axes and an index */ + TVM_DLL explicit FusedAxis(Array group, int index); - TVM_DEFINE_OBJECT_REF_METHODS(DenseFixedAxis, DenseAxis, DenseFixedAxisNode); + TVM_DEFINE_OBJECT_REF_METHODS(FusedAxis, DenseFixedAxis, FusedAxisNode); }; +/*! + * \brief Dense axis with variable length, such as ragged tensor. + */ class DenseVariableAxisNode : public DenseAxisNode { public: Buffer indptr; void VisitAttrs(AttrVisitor* v) { - v->Visit("name", &name); - v->Visit("length", &length); + DenseAxisNode::VisitAttrs(v); v->Visit("indptr", &indptr); } bool SEqualReduce(const DenseVariableAxisNode* other, SEqualReducer equal) const { - return equal(name, other->name) && equal(length, other->length) && equal(indptr, other->indptr); + return DenseAxisNode::SEqualReduce(other, equal) && equal(indptr, other->indptr); } void SHashReduce(SHashReducer hash_reduce) const { - hash_reduce(name); - hash_reduce(length); + DenseAxisNode::SHashReduce(hash_reduce); hash_reduce(indptr); } - bool is_fixed() const final { - return false; - } + PrimExpr nnz() const { return indptr->shape[0]; } - AxisKind kind() const final { - return AxisKind::kDenseVariable; - } + AxisKind kind() const final { return AxisKind::kDenseVariable; } static constexpr const char* _type_key = "tir.sparse.DenseVariableAxis"; TVM_DECLARE_FINAL_OBJECT_INFO(DenseVariableAxisNode, DenseAxisNode); }; /*! - * \brief Dense axis whose length is dependent on its predecessors on the axis - * dependency tree. + * \brief Managed reference to DenseVariableAxisNode. 
+ * \sa DenseVariableAxisNode */ class DenseVariableAxis : public DenseAxis { public: @@ -229,31 +278,23 @@ class SparseFixedAxisNode : public SparseAxisNode { PrimExpr nnz_cols; void VisitAttrs(AttrVisitor* v) { - v->Visit("name", &name); - v->Visit("length", &length); + SparseAxisNode::VisitAttrs(v); v->Visit("indptr", &indices); v->Visit("nnz_cols", &nnz_cols); } bool SEqualReduce(const SparseFixedAxisNode* other, SEqualReducer equal) const { - return equal(name, other->name) && equal(length, other->length) && - equal(indices, other->indices) && equal(nnz_cols, other->nnz_cols); + return SparseAxisNode::SEqualReduce(other, equal) && equal(indices, other->indices) && + equal(nnz_cols, other->nnz_cols); } void SHashReduce(SHashReducer hash_reduce) const { - hash_reduce(name); - hash_reduce(length); + SparseAxisNode::SHashReduce(hash_reduce); /* hash the base-class fields first */ hash_reduce(indices); hash_reduce(nnz_cols); } - bool is_fixed() const final { - return true; - } - - AxisKind kind() const final { - return AxisKind::kSparseFixed; - } + AxisKind kind() const final { return AxisKind::kSparseFixed; } static constexpr const char* _type_key = "tir.sparse.SparseFixedAxis"; TVM_DECLARE_FINAL_OBJECT_INFO(SparseFixedAxisNode, SparseAxisNode); @@ -279,31 +320,25 @@ class SparseVariableAxisNode : public SparseAxisNode { Buffer indices; void VisitAttrs(AttrVisitor* v) { - v->Visit("name", &name); - v->Visit("length", &length); + SparseAxisNode::VisitAttrs(v); v->Visit("indptr", &indptr); v->Visit("indices", &indices); } bool SEqualReduce(const SparseVariableAxisNode* other, SEqualReducer equal) const { - return equal(name, other->name) && equal(length, other->length) && - equal(indptr, other->indptr) && equal(indices, other->indices); + return SparseAxisNode::SEqualReduce(other, equal) && equal(indptr, other->indptr) && + equal(indices, other->indices); } void SHashReduce(SHashReducer hash_reduce) const { - hash_reduce(name); - hash_reduce(length); + SparseAxisNode::SHashReduce(hash_reduce); 
hash_reduce(indptr); hash_reduce(indices); } - bool is_fixed() const final { - return false; - } + PrimExpr nnz() const { return indices->shape[0]; } - AxisKind kind() const final { - return AxisKind::kSparseVariable; - } + AxisKind kind() const final { return AxisKind::kSparseVariable; } static constexpr const char* _type_key = "tir.sparse.SparseVariableAxis"; TVM_DECLARE_FINAL_OBJECT_INFO(SparseVariableAxisNode, SparseAxisNode); @@ -408,7 +443,6 @@ class SparseBuffer : public ObjectRef { TVM_DEFINE_OBJECT_REF_METHODS(SparseBuffer, ObjectRef, SparseBufferNode); }; - // overload printing of for type. TVM_DLL std::ostream& operator<<(std::ostream& os, AxisKind kind); diff --git a/python/tvm/script/tir/intrin.py b/python/tvm/script/tir/intrin.py index d3f000e4048d..75e8df1d4033 100644 --- a/python/tvm/script/tir/intrin.py +++ b/python/tvm/script/tir/intrin.py @@ -28,6 +28,8 @@ SpIterVar, SparseFixedAxis, SparseVariableAxis, + DenseFromSparseAxis, + FusedAxis ) from ..registry import register from ..utils import get_param_list, tvm_span_from_synr @@ -263,6 +265,11 @@ def comm_reducer(lambda_io, identities, span): @register def dense(axis: Axis, span: Optional[Span] = None): if isinstance(axis, (SparseFixedAxis, SparseVariableAxis)): - return DenseFixedAxis(axis.name + "_dense", axis.length, axis) + return DenseFromSparseAxis(axis) else: return axis + + +@register +def fuse(group: List[Axis], span: Optional[Span] = None): + return [FusedAxis(group, i) for i in range(len(group))] diff --git a/python/tvm/tir/sparse.py b/python/tvm/tir/sparse.py index 23dd3b98cb37..7c89f1d56672 100644 --- a/python/tvm/tir/sparse.py +++ b/python/tvm/tir/sparse.py @@ -65,17 +65,48 @@ class DenseFixedAxis(DenseAxis): length : PrimExpr The length of the axis - - from_sparse : Optional[SparseAxis] - The SparseAxis that this axis is created from """ name: str length: PrimExpr - from_sparse: Optional[SparseAxis] - def __init__(self, name, length, from_sparse=None): - 
self.__init_handle_by_constructor__(_ffi_api.DenseFixedAxis, name, length, from_sparse) # type: ignore + def __init__(self, name, length): + self.__init_handle_by_constructor__(_ffi_api.DenseFixedAxis, name, length) # type: ignore + + +@tvm._ffi.register_object("tir.sparse.DenseFromSparseAxis") +class DenseFromSparseAxis(DenseFixedAxis): + """DenseFromSparseAxis node + + Parameters + ---------- + base : Axis + The based sparse axis. + """ + + base: Axis + + def __init__(self, base): + self.__init_handle_by_constructor__(_ffi_api.DenseFromSparseAxis, base) # type: ignore + + +@tvm._ffi.register_object("tir.sparse.FusedAxis") +class FusedAxis(DenseFixedAxis): + """FusedAxis node + + Parameters + ---------- + group : List[Axis] + The axes group to be fused. + index : int + The index of current axis in the fused axes group. + """ + + group: List[Axis] + index: int + + def __init__(self, group, index): + self.__init_handle_by_constructor__(_ffi_api.FusedAxis, group, index) # type: ignore @tvm._ffi.register_object("tir.sparse.DenseVariableAxis") diff --git a/src/printer/tvmscript_printer.cc b/src/printer/tvmscript_printer.cc index 46bcb589e1d9..afdc0e6a34f1 100644 --- a/src/printer/tvmscript_printer.cc +++ b/src/printer/tvmscript_printer.cc @@ -469,8 +469,12 @@ Doc TVMScriptPrinter::AllocAxis(const Axis& axis) { Doc val; const auto* df_axis = axis.as(); - if (df_axis != nullptr && df_axis->from_sparse.defined()) { - val << tir_prefix_ << ".dense(" << Print(df_axis->from_sparse.value()) << ")"; + if (df_axis != nullptr && df_axis->is_derived_axis) { + if (const DenseFromSparseAxisNode* dfs_axis = axis.as()) { + val = Doc::Text(tir_prefix_ + ".dense(" + dfs_axis->base->name + ")"); + } else { + CHECK(false) << "Cannot allocate fused axis"; + } } else { std::string name = axis->name; if (name.length() == 0 || !std::isalnum(name[0])) { @@ -1315,9 +1319,28 @@ Doc TVMScriptPrinter::PrintSparseBlockName(const SparseBlockNode* op) { for (int i = 0; i < n_iter; ++i) { const 
SpIterVar& sp_iter = op->sp_iter_vars[i]; + const Axis& axis = sp_iter->axis; Doc iter_doc; - iter_doc << sp_iter->axis->name; - // TODO(zihao): fix expressions like T.dense(J) + + std::string axis_repr = sp_iter->axis->name; + if (axis->is_derived_axis) { + if (const DenseFromSparseAxisNode* dfs_axis = axis.as()) { + iter_doc << tir_prefix_ << ".dense(" << dfs_axis->base->name << ")"; + } else { + const FusedAxisNode* fused_axis = axis.as(); + std::string orig_axis_name = fused_axis->group[fused_axis->index]->name; + if (fused_axis->index == 0) { + iter_doc << tir_prefix_ << ".fuse(" << orig_axis_name; + } else if (fused_axis->index == fused_axis->group.size() - 1) { + iter_doc << orig_axis_name << ")"; + } else { + iter_doc << orig_axis_name; + } + } + } else { + iter_doc << axis->name; + } + var_not_in_headers_.insert(sp_iter->var.get()); sp_iter_docs.push_back(iter_doc); sp_iter_name_docs.push_back(Print(sp_iter->var)); diff --git a/src/tir/ir/sparse.cc b/src/tir/ir/sparse.cc index da1c8e0b7d7b..287d57608fc6 100644 --- a/src/tir/ir/sparse.cc +++ b/src/tir/ir/sparse.cc @@ -29,7 +29,8 @@ namespace tvm { namespace tir { -// Axis +/******** Attributes of sparse axis. ********/ + TVM_REGISTER_GLOBAL("tir.sparse.GetAxisName").set_body_typed([](Axis axis) { return axis->GetName(); }); @@ -42,33 +43,31 @@ TVM_REGISTER_GLOBAL("tir.sparse.GetAxisIndexType").set_body_typed([](Axis axis) return DLDataType2String(axis->GetIndexType()); }); -// DenseFixedAxis -DenseFixedAxis::DenseFixedAxis(String name, PrimExpr length, Optional from_sparse) { +/******** DenseFixedAxis ********/ + +/*! 
\brief Default constructor of DenseFixedAxis */ +DenseFixedAxis::DenseFixedAxis(String name, PrimExpr length) { ObjectPtr node = make_object(); node->name = std::move(name); node->length = std::move(length); - node->from_sparse = std::move(from_sparse); data_ = std::move(node); } TVM_REGISTER_NODE_TYPE(DenseFixedAxisNode); -TVM_REGISTER_GLOBAL("tir.sparse.DenseFixedAxis") - .set_body_typed([](String name, PrimExpr length, Optional from_sparse) { - return DenseFixedAxis(name, length, from_sparse); - }); +TVM_REGISTER_GLOBAL("tir.sparse.DenseFixedAxis").set_body_typed([](String name, PrimExpr length) { + return DenseFixedAxis(std::move(name), std::move(length)); +}); TVM_STATIC_IR_FUNCTOR(ReprPrinter, vtable) .set_dispatch([](const ObjectRef& node, ReprPrinter* p) { auto* op = static_cast(node.get()); - p->stream << "dense_fixed(" << op->name << ", " << op->length; - if (op->from_sparse.defined()) { - p->stream << ", from_sparse=" << op->from_sparse.value(); - } - p->stream << ")"; + p->stream << "dense_fixed(" << op->name << ", " << op->length << ")"; }); -// DenseVariableAxis +/******** DenseVariableAxis ********/ + +/*! \brief Default constructor of DenseVariableAxis */ DenseVariableAxis::DenseVariableAxis(String name, PrimExpr length, Buffer indptr) { ObjectPtr node = make_object(); node->name = std::move(name); @@ -87,10 +86,90 @@ TVM_REGISTER_GLOBAL("tir.sparse.DenseVariableAxis") TVM_STATIC_IR_FUNCTOR(ReprPrinter, vtable) .set_dispatch([](const ObjectRef& node, ReprPrinter* p) { auto* op = static_cast(node.get()); - p->stream << "dense_variable(" << op->name << ", " << op->length << ", " << op->indptr->name; + p->stream << "dense_variable(" << op->name << ", " << op->length << ", " << op->indptr->name + << ")"; + }); + +/******** DenseFromSparseAxis ********/ + +/*! 
\brief Default constructor of DenseFromSparseAxis */ +DenseFromSparseAxis::DenseFromSparseAxis(SparseAxis base) { + ObjectPtr node = make_object(); + node->name = base->name + "_dense"; + node->length = base->length; + node->is_derived_axis = true; + node->base = std::move(base); + data_ = std::move(node); +} + +TVM_REGISTER_NODE_TYPE(DenseFromSparseAxisNode); + +TVM_REGISTER_GLOBAL("tir.sparse.DenseFromSparseAxis").set_body_typed([](SparseAxis base) { + return DenseFromSparseAxis(std::move(base)); +}); + +TVM_STATIC_IR_FUNCTOR(ReprPrinter, vtable) + .set_dispatch([](const ObjectRef& node, ReprPrinter* p) { + auto* op = static_cast(node.get()); + p->stream << "dense_from_sparse(" << op->base->name << ")"; + }); + +/******** FusedAxis ********/ + +/*! \brief Default constructor of FusedAxis */ +FusedAxis::FusedAxis(Array group, int index) { + CHECK(index < int(group.size())) << "Index " << index << " exceeds the size of fused axes group."; + + // TODO(zihao): check whether it is valid to fuse axes in the group. + + ObjectPtr node = make_object(); + std::string fused_name = group[0]->name; + for (size_t i = 1; i < group.size(); ++i) { + fused_name += group[i]->name; + } + node->name = "fused_" + fused_name + "_" + group[index]->name; + + if (const auto* df_axis = group[index].as()) { + node->length = df_axis->length; + } else if (const auto* sf_axis = group[index].as()) { + // TODO(zihao): accumulate previous dimensions. 
+ } else if (const auto* dv_axis = group[index].as()) { + node->length = dv_axis->nnz(); + } else if (const auto* sv_axis = group[index].as()) { + node->length = sv_axis->nnz(); + } + + node->is_derived_axis = true; + node->group = std::move(group); + node->index = index; + data_ = std::move(node); +} + +TVM_REGISTER_NODE_TYPE(FusedAxisNode); + +TVM_REGISTER_GLOBAL("tir.sparse.FusedAxis").set_body_typed([](Array group, int index) { + return FusedAxis(std::move(group), index); +}); + +TVM_STATIC_IR_FUNCTOR(ReprPrinter, vtable) + .set_dispatch([](const ObjectRef& node, ReprPrinter* p) { + auto* op = static_cast(node.get()); + p->stream << "fused("; + bool first = true; + for (auto&& orig_axis : op->group) { + if (first) { + first = false; + } else { + p->stream << ", "; + } + p->stream << orig_axis->name; + } + p->stream << ")"; }); -// SparseFixedAxis +/******** SparseFixedAxis ********/ + +/*! \brief Default constructor of SparseFixedAxis */ SparseFixedAxis::SparseFixedAxis(String name, PrimExpr length, Buffer indices, PrimExpr nnz_cols) { ObjectPtr node = make_object(); node->name = std::move(name); @@ -114,7 +193,9 @@ TVM_STATIC_IR_FUNCTOR(ReprPrinter, vtable) << op->indices->name << ")"; }); -// SparseVariableAxis +/******** SparseVariableAxis ********/ + +/*! \brief Default constructor of SparseVariableAxis */ SparseVariableAxis::SparseVariableAxis(String name, PrimExpr length, Buffer indptr, Buffer indices) { ObjectPtr node = make_object(); @@ -139,7 +220,9 @@ TVM_STATIC_IR_FUNCTOR(ReprPrinter, vtable) << ", " << op->indices->name << ")"; }); -// AxisTree +/******** AxisTree ********/ + +/*! 
\brief Default constructor of AxisTree */ AxisTree::AxisTree(Array axis_names, Array> axis_parent_names) { CHECK_EQ(axis_names.size(), axis_parent_names.size()) << "ValueError: The axis_names array should have the same length as " @@ -179,7 +262,9 @@ TVM_REGISTER_GLOBAL("tir.sparse.AxisTree") return AxisTree(axis_names, axis_parent_names); }); -// SparseBuffer +/******** SparseBuffer ********/ + +/*! \brief Default constructor of SparseBuffer */ SparseBuffer::SparseBuffer(Array axes, Buffer data, String name) { ObjectPtr node = make_object(); CHECK_GT(static_cast(axes.size()), 0) @@ -211,7 +296,9 @@ TVM_STATIC_IR_FUNCTOR(ReprPrinter, vtable) p->stream << "], " << op->data << ")"; }); -// AxisKind +/******** AxisKind ********/ + +/*! \brief Printer function of Axiskind. */ std::ostream& operator<<(std::ostream& out, AxisKind type) { switch (type) { case AxisKind::kDenseFixed: @@ -232,7 +319,9 @@ std::ostream& operator<<(std::ostream& out, AxisKind type) { return out; } -// SpIterVar +/******** SpIterVar ********/ + +/*! \brief Default constructor of SpIterVar. */ SpIterVar::SpIterVar(Var var, PrimExpr max_extent, bool is_reduction, Axis axis) { ObjectPtr node = make_object(); @@ -256,8 +345,7 @@ TVM_STATIC_IR_FUNCTOR(ReprPrinter, vtable) .set_dispatch([](const ObjectRef& node, ReprPrinter* p) { auto* op = static_cast(node.get()); p->stream << "sp_iter_var(" << op->var->name_hint << ", " << op->max_extent << ", " - << (op->is_reduction ? "reduction" : "spatial") << ", " - << op->axis->name << ")"; + << (op->is_reduction ? "reduction" : "spatial") << ", " << op->axis->name << ")"; }); } // namespace tir diff --git a/src/tir/transforms/lower_sparse_tir.cc b/src/tir/transforms/lower_sparse_tir.cc index 0ca725339ba3..a9de88264546 100644 --- a/src/tir/transforms/lower_sparse_tir.cc +++ b/src/tir/transforms/lower_sparse_tir.cc @@ -93,7 +93,7 @@ Map UpdateBufferMap(PrimFunc f) { * \return The lowered index. 
*/ PrimExpr AggregateOffset(PrimExpr prev_offset, const Axis& axis, PrimExpr index, - arith::Analyzer* ana_) { + arith::Analyzer* ana_ = nullptr) { PrimExpr new_offset; switch (axis->kind()) { case AxisKind::kDenseFixed: { @@ -116,7 +116,11 @@ PrimExpr AggregateOffset(PrimExpr prev_offset, const Axis& axis, PrimExpr index, break; } } - return ana_->Simplify(new_offset); + if (ana_ == nullptr) { + return new_offset; + } else { + return ana_->Simplify(new_offset); + } } /*! \brief Storing the context information of a sparse block. */ @@ -129,11 +133,12 @@ class SparseBlockCtx { : sp_iter_var_map_(std::move(other.sp_iter_var_map_)), offset_(std::move(other.offset_)), - parent_(std::move(parent_)), - blk_name_(std::move(blk_name_)) {} + parent_(std::move(other.parent_)), + blk_name_(std::move(other.blk_name_)), + ana_(other.ana_) {} // default constructor - explicit Scope(String blk_name, Array sp_iter_vars, AxisTree tree) - : blk_name_(std::move(blk_name)) { + explicit Scope(String blk_name, Array sp_iter_vars, AxisTree tree, arith::Analyzer* ana) + : blk_name_(std::move(blk_name)), ana_(ana) { std::unordered_map axis_name_sp_iter_map_; // initialize sparse iter var dependency map. for (const SpIterVar& sp_iter_var : sp_iter_vars) { @@ -145,8 +150,8 @@ class SparseBlockCtx { for (const SpIterVar& sp_iter_var : sp_iter_vars) { String axis_name = sp_iter_var->axis->name; const SpIterVarNode* node = sp_iter_var.get(); - if (support::EndsWith(axis_name, "_dense")) { - // ends with "_dense", the axis is generated via to_dense + if (sp_iter_var->axis->is_derived_axis) { + // The axis is a derived axis. 
parent_[node] = nullptr; } else { auto opt = tree->parent.Get(axis_name); @@ -215,7 +220,7 @@ class SparseBlockCtx { return it->second; } else { PrimExpr prev_off = GetOffset(parent_[sp_iter_var]); - PrimExpr new_off = AggregateOffset(prev_off, sp_iter_var->axis, sp_iter_var->var, &ana_); + PrimExpr new_off = AggregateOffset(prev_off, sp_iter_var->axis, sp_iter_var->var, ana_); offset_[sp_iter_var] = new_off; return new_off; } @@ -230,8 +235,8 @@ class SparseBlockCtx { std::tuple GetIndicesRange(const SpIterVarNode* sp_iter_var) { PrimExpr prev_off = GetOffset(parent_[sp_iter_var]); const Axis& axis = sp_iter_var->axis; - return {AggregateOffset(prev_off, axis, Integer(0), &ana_), - AggregateOffset(add(prev_off, 1), axis, Integer(0), &ana_)}; + return {AggregateOffset(prev_off, axis, Integer(0), ana_), + AggregateOffset(add(prev_off, 1), axis, Integer(0), ana_)}; } /*! @@ -243,16 +248,16 @@ class SparseBlockCtx { std::unordered_map sp_iter_var_map_; std::unordered_map offset_; std::unordered_map parent_; - arith::Analyzer ana_; String blk_name_; + arith::Analyzer* ana_; }; /*! \brief default constructor */ - explicit SparseBlockCtx(AxisTree tree) : tree_(std::move(tree)) {} + explicit SparseBlockCtx(AxisTree tree, arith::Analyzer* ana) : tree_(std::move(tree)), ana_(ana) {} /*! \brief enter new scope */ void EnterScope(const SparseBlockNode* sp_block) { - stack_.emplace_back(sp_block->name, sp_block->sp_iter_vars, tree_); + stack_.emplace_back(sp_block->name, sp_block->sp_iter_vars, tree_, ana_); } /*! \brief exit current scope */ @@ -277,6 +282,7 @@ class SparseBlockCtx { private: std::vector stack_; AxisTree tree_; + arith::Analyzer* ana_; /*! \brief the top scope in the sparse block stack. 
*/ inline Scope* top() const { return const_cast(&stack_.back()); } @@ -293,11 +299,12 @@ class SparseBufferCtx { axes_(std::move(other.axes_)), offsets_(std::move(other.offsets_)), matches_(std::move(other.matches_)), - sp_blk_ctx_(std::move(other.sp_blk_ctx_)) {} + sp_blk_ctx_(std::move(other.sp_blk_ctx_)), + ana_(std::move(other.ana_)) {} /*! \brief default constructor */ - explicit Scope(String buf_name, Array axes, const SparseBlockCtx* sp_blk_ctx) - : buf_name_(std::move(buf_name)), axes_(std::move(axes)), sp_blk_ctx_(sp_blk_ctx) { + explicit Scope(String buf_name, Array axes, const SparseBlockCtx* sp_blk_ctx, arith::Analyzer* ana) + : buf_name_(std::move(buf_name)), axes_(std::move(axes)), sp_blk_ctx_(sp_blk_ctx), ana_(ana) { offsets_.emplace_back(Integer(0)); matches_.emplace_back(true); } @@ -326,14 +333,13 @@ class SparseBufferCtx { } // update offset - PrimExpr new_offset = AggregateOffset(offsets_.back(), axis, std::move(coordinate), &ana_); + PrimExpr new_offset = AggregateOffset(offsets_.back(), axis, std::move(coordinate), ana_); offsets_.emplace_back(std::move(new_offset)); } /*! \brief get the axis given dimension index of current buffer. */ - const Axis& GetAxis(int dim) const { - auto&& ret = axes_[dim]; - return ret; + Axis GetAxis(int dim) const { + return axes_[dim]; } /*! \brief whether the index access pattern of current buffer aligns with current block */ @@ -342,34 +348,33 @@ class SparseBufferCtx { /*! \brief return the indices range of the given dimension in current buffer. 
*/ std::tuple GetIndicesRange(int dim) { const Axis& axis = axes_[dim]; - return {AggregateOffset(offsets_[dim], axis, Integer(0), &ana_), - AggregateOffset(add(offsets_[dim], 1), axis, Integer(0), &ana_)}; + return {AggregateOffset(offsets_[dim], axis, Integer(0), ana_), + AggregateOffset(add(offsets_[dim], 1), axis, Integer(0), ana_)}; } private: String buf_name_; Array axes_; - arith::Analyzer ana_; std::vector offsets_; std::vector matches_; const SparseBlockCtx* sp_blk_ctx_; + arith::Analyzer* ana_; }; /*! \brief default constructor */ - explicit SparseBufferCtx(AxisTree tree) : tree_(std::move(tree)) {} + explicit SparseBufferCtx(AxisTree tree, arith::Analyzer* ana) : tree_(std::move(tree)), ana_(ana) {} /*! \brief enter new scope */ void EnterScope(const SparseBuffer& sp_buf, const SparseBlockCtx* sp_blk_ctx) { - stack_.emplace_back(sp_buf->name, sp_buf->axes, sp_blk_ctx); + stack_.emplace_back(sp_buf->name, sp_buf->axes, sp_blk_ctx, ana_); } /*! \brief exit current scope */ void ExitScope() { stack_.pop_back(); } /*! \brief call GetAxis in top scope. */ - const Axis& GetAxis(int dim) const { - auto&& ret = top()->GetAxis(dim); - return ret; + Axis GetAxis(int dim) const { + return top()->GetAxis(dim); } /*! \brief call MatchWithSpBlock in top scope. */ @@ -385,8 +390,8 @@ class SparseBufferCtx { private: AxisTree tree_; - arith::Analyzer ana_; std::vector stack_; + arith::Analyzer* ana_; /*! \brief the top scope in the sparse buffer stack. */ inline Scope* top() const { return const_cast(&stack_.back()); } @@ -399,14 +404,13 @@ class SparseBufferCtx { class IndexTransformer : public StmtExprMutator { public: explicit IndexTransformer(const AxisTree& axis_tree) - : sp_blk_ctx_(axis_tree), sp_buf_ctx_(axis_tree), axis_tree_(axis_tree) {} + : sp_blk_ctx_(axis_tree, &ana_), sp_buf_ctx_(axis_tree, &ana_), axis_tree_(axis_tree) {} private: // Sparse block context stack; SparseBlockCtx sp_blk_ctx_; // Sparse buffer context stack; SparseBufferCtx sp_buf_ctx_; - /*! 
* \brief Return the offset of index on given dimension. * \param dim The dimension index. @@ -559,10 +563,16 @@ class IndexTransformer : public StmtExprMutator { for (int i = 0; i < n_iter; ++i) { SpIterVar sp_it_var = sp_block->sp_iter_vars[i]; String axis_name = sp_it_var->axis->name; - auto&& parent_axis = axis_tree_->parent.Get(axis_name); - CHECK(parent_axis.defined()) << "Sparse IterVar not defined in Axis Tree."; - String parent_axis_name = parent_axis.value(); - bool is_fixed_axis = sp_it_var->axis->is_fixed(); + String parent_axis_name; + if (sp_it_var->axis->is_derived_axis) { + // derived axis doesn't appear in the axis tree. + parent_axis_name = "root"; + } else { + auto&& parent_axis = axis_tree_->parent.Get(axis_name); + CHECK(parent_axis.defined()) << "Sparse IterVar not defined in Axis Tree."; + parent_axis_name = parent_axis.value(); + } + bool is_fixed_axis = (sp_it_var->axis->kind() == AxisKind::kDenseFixed || sp_it_var->axis->kind() == AxisKind::kSparseFixed); /* Add itervar to current block when * - it's not used yet (not in stack) and * - it's parent axis was used in outer blocks or diff --git a/tests/python/sparsetir/test_tir_sparse_lower.py b/tests/python/sparsetir/test_tir_sparse_lower.py index fab08767387f..da58f0608a35 100644 --- a/tests/python/sparsetir/test_tir_sparse_lower.py +++ b/tests/python/sparsetir/test_tir_sparse_lower.py @@ -14,9 +14,6 @@ # KIND, either express or implied. See the License for the # specific language governing permissions and limitations # under the License. 
-from os import replace -from numpy.core.fromnumeric import size -from scipy.sparse import bsr import tvm import tvm.testing import tvm.tir as tir @@ -50,6 +47,30 @@ def csrmm( C[vi, vk] = C[vi, vk] + A[vi, vj] * B[vj, vk] +@T.prim_func +def csrmm_dense_iter( + a: T.handle, + b: T.handle, + c: T.handle, + indptr: T.handle, + indices: T.handle, + n: T.int32, + m: T.int32, + k: T.int32, + nnz: T.int32, +) -> None: + I = T.dense_fixed(n) + J = T.sparse_variable((m, n + 1, nnz), (indptr, indices), "int32") + K = T.dense_fixed(k) + A = T.match_sparse_buffer(a, (I, J), nnz, "float32") + B = T.match_sparse_buffer(b, (T.dense(J), K), m * k, "float32") + C = T.match_sparse_buffer(c, (I, K), n * k, "float32") + with T.iter([I, T.dense(J), K], "SRS", "csrmm") as [vi, vj, vk]: + with T.init(): + C[vi, vk] = 0.0 + C[vi, vk] = C[vi, vk] + A[vi, vj] * B[vj, vk] + + @T.prim_func def lowered_csrmm(a: T.handle, b: T.handle, c: T.handle, indptr: T.handle, indices: T.handle, n: T.int32, m: T.int32, k: T.int32, nnz: T.int32) -> None: A_data = T.match_buffer(a, [nnz], dtype="float32") @@ -233,6 +254,7 @@ def lowered_ellpack_mm(a: T.handle, b: T.handle, c: T.handle, indices: T.handle, C_data[(vi * blk + vbi) * feat_size + vf] = C_data[(vi * blk + vbi) * feat_size + vf] + A_data[((vi * col + vj) * blk + vbi) * blk + vbj] * B_data[(J_indices[vi * col + vj] * blk + vbj) * feat_size + vf] + @T.prim_func def csr_element_wise( a: T.handle, @@ -301,6 +323,18 @@ def test_csrmm(): tvm.testing.assert_allclose(y_ground_truth.reshape(-1), Y_nd.numpy(), rtol=1e-5, atol=1e-5) +def test_csrmm_dense_iter(): + mod = tvm.IRModule.from_expr(csrmm_dense_iter) + t = AxisTree({ + "J": "I", + "I": None, + "K": None + }) + mod = tvm.tir.transform.LowerSparseTIR(t)(mod) + print(mod["main"].script()) + # tvm.ir.assert_structural_equal(mod["main"], lowered_csrmm, True) + + def test_csr_reduce(): mod = tvm.IRModule.from_expr(csr_reduce) t = AxisTree({ @@ -456,6 +490,7 @@ def test_csr_element_wise(): if __name__ == 
"__main__": test_csrmm() + test_csrmm_dense_iter() test_csr_reduce() test_bsrmm() test_ellpack_mm() diff --git a/tests/python/sparsetir/test_tir_sparse_script_roundtrip.py b/tests/python/sparsetir/test_tir_sparse_script_roundtrip.py index 62feee2e32fa..5ea544470526 100644 --- a/tests/python/sparsetir/test_tir_sparse_script_roundtrip.py +++ b/tests/python/sparsetir/test_tir_sparse_script_roundtrip.py @@ -38,6 +38,24 @@ def csrmm(a: T.handle, b: T.handle, c: T.handle, indptr: T.handle, indices: T.ha C[vi, vk] = C[vi, vk] + A[vi, vj] * B[vj, vk] +@T.prim_func +def csrmm_dense_iter(a: T.handle, b: T.handle, c: T.handle, indptr: T.handle, indices: T.handle) -> None: + n = T.var("int32") + m = T.var("int32") + k = T.var("int32") + nnz = T.var("int32") + I = T.dense_fixed(n) + J = T.sparse_variable((m, n + 1, nnz), (indptr, indices), "int32") + K = T.dense_fixed(k) + A = T.match_sparse_buffer(a, (I, J), nnz, "float32") + B = T.match_sparse_buffer(b, (T.dense(J), K), m * k, "float32") + C = T.match_sparse_buffer(c, (I, K), n * k, "float32") + with T.iter([I, T.dense(J), K], "SRS", "csrmm") as [vi, vj, vk]: + with T.init(): + C[vi, vk] = 0.0 + C[vi, vk] = C[vi, vk] + A[vi, vj] * B[vj, vk] + + @T.prim_func def csr_reduce(a: T.handle, b: T.handle, indptr: T.handle, indices: T.handle) -> None: n = T.var("int32") @@ -130,6 +148,12 @@ def test_csrmm(): tvm.ir.assert_structural_equal(func, rt_func, True) +def test_csrmm_dense_iter(): + func = csrmm_dense_iter + rt_func = tvm.script.from_source(func.script(show_meta=True)) + tvm.ir.assert_structural_equal(func, rt_func, True) + + def test_csr_reduce(): func = csr_reduce rt_func = tvm.script.from_source(func.script(show_meta=True)) @@ -156,6 +180,7 @@ def test_csr_element_wise(): if __name__ == "__main__": test_csrmm() + test_csrmm_dense_iter() test_csr_reduce() test_bsrmm() test_ellpack_mm()