diff --git a/src/tir/transforms/lower_sparse_tir.cc b/src/tir/transforms/lower_sparse_tir.cc index 42fb1ba12638..0ca725339ba3 100644 --- a/src/tir/transforms/lower_sparse_tir.cc +++ b/src/tir/transforms/lower_sparse_tir.cc @@ -85,10 +85,11 @@ Map UpdateBufferMap(PrimFunc f) { } /*! - * \brief Compupte the partially lowered index. + * \brief Aggregate offset on previous axes with the index on current axis. * \param prev_offset The lowered index accumulated over all axis prior to current axis. * \param axis Current axis. * \param index The sparse index on current axis. + * \param ana_ The analyzer used for simplifying expressions. TODO(zihao): make it more cleaner. * \return The lowered index. */ PrimExpr AggregateOffset(PrimExpr prev_offset, const Axis& axis, PrimExpr index, @@ -167,8 +168,14 @@ class SparseBlockCtx { offset_[nullptr] = Integer(0); } - Optional GetSparseIterVar(const VarNode* var_node) const { - auto it = sp_iter_var_map_.find(var_node); + /*! + * \brief Get sparse iter var corresponding to given variable node in the current scope. + * \param var The variable node in AST. + * \return A optional wrapper of sparse iter var. If var is not a sparse iter var, return + * NullOpt. + */ + Optional GetSparseIterVar(const VarNode* var) const { + auto it = sp_iter_var_map_.find(var); if (it != sp_iter_var_map_.end()) { return it->second; } else { @@ -177,8 +184,9 @@ class SparseBlockCtx { } /*! - * \brief Get coordinate of corresding sparse iter var. + * \brief Get coordinate of corresding sparse iter var in the current scope. * \param sp_iter_var The compressed iterator. + * \return A PrimExpr representing the coordinate. */ PrimExpr GetCoordinate(const SpIterVarNode* sp_iter_var) { const Axis& axis = sp_iter_var->axis; @@ -196,7 +204,10 @@ class SparseBlockCtx { } } - /*! \brief TODO + /*! + * \brief Get the real offset in compressed buffer of given sparse iter var. + * \param sp_iter_var The sparse iter var to lookup. + * \return A PrimExpr representing the offset. */ PrimExpr GetOffset(const SpIterVarNode* sp_iter_var) { auto it = offset_.find(sp_iter_var); @@ -210,7 +221,11 @@ class SparseBlockCtx { } } - /*! \brief TODO + /*! + * \brief Get the indices range in compressed buffer of given sparse iter var. + * \param sp_iter_var The sparse iter var to lookup. + * \return A tuple of PrimExpr, the first elements refers to the start position, and the second + * elements refers the end position. */ std::tuple GetIndicesRange(const SpIterVarNode* sp_iter_var) { PrimExpr prev_off = GetOffset(parent_[sp_iter_var]); @@ -219,7 +234,8 @@ class SparseBlockCtx { AggregateOffset(add(prev_off, 1), axis, Integer(0), &ana_)}; } - /*! \brief TODO + /*! + * \brief Get the current block name. */ const String GetBlockName() const { return blk_name_; } @@ -231,31 +247,39 @@ class SparseBlockCtx { String blk_name_; }; + /*! \brief default constructor */ explicit SparseBlockCtx(AxisTree tree) : tree_(std::move(tree)) {} + /*! \brief enter new scope */ void EnterScope(const SparseBlockNode* sp_block) { stack_.emplace_back(sp_block->name, sp_block->sp_iter_vars, tree_); } + /*! \brief exit current scope */ void ExitScope() { stack_.pop_back(); } - Optional GetSparseIterVar(const VarNode* var_node) const { - return local()->GetSparseIterVar(var_node); + /*! \brief call GetSparseIterVar in the top scope. */ + Optional GetSparseIterVar(const VarNode* node) const { + return top()->GetSparseIterVar(node); } - PrimExpr GetCoordinate(const SpIterVarNode* node) { return local()->GetCoordinate(node); } + /*! \brief call GetCoordinate in the top scope. */ + PrimExpr GetCoordinate(const SpIterVarNode* node) { return top()->GetCoordinate(node); } + /*! \brief call GetIndicesRange in the top scope. */ std::tuple GetIndicesRange(const SpIterVarNode* sp_iter_var) { - return local()->GetIndicesRange(sp_iter_var); + return top()->GetIndicesRange(sp_iter_var); } - const String GetBlockName() const { return local()->GetBlockName(); } + /*! \brief call GetBlockName in the top scope. */ + const String GetBlockName() const { return top()->GetBlockName(); } private: std::vector stack_; AxisTree tree_; - inline Scope* local() const { return const_cast(&stack_.back()); } + /*! \brief the top scope in the sparse block stack. */ + inline Scope* top() const { return const_cast(&stack_.back()); } }; /*! \brief Storing the context information of a sparse buffer. */ @@ -263,7 +287,7 @@ class SparseBufferCtx { public: class Scope { public: - // move constructor + /*! \brief move constructor */ explicit Scope(Scope&& other) : buf_name_(std::move(other.buf_name_)), axes_(std::move(other.axes_)), @@ -271,17 +295,18 @@ class SparseBufferCtx { matches_(std::move(other.matches_)), sp_blk_ctx_(std::move(other.sp_blk_ctx_)) {} - // default constructor + /*! \brief default constructor */ explicit Scope(String buf_name, Array axes, const SparseBlockCtx* sp_blk_ctx) : buf_name_(std::move(buf_name)), axes_(std::move(axes)), sp_blk_ctx_(sp_blk_ctx) { offsets_.emplace_back(Integer(0)); matches_.emplace_back(true); } - void Register(int idx, PrimExpr coordinate, PrimExpr orig_idx) { - ICHECK(idx + 1 == int(offsets_.size())) - << "Cannot register coordinate of index " << std::to_string(idx) << " at this time"; - const Axis& axis = GetAxis(idx); + /*! \brief register the coordinate of a new dimension of current buffer. */ + void Register(int dim, PrimExpr coordinate, PrimExpr orig_idx) { + ICHECK(dim + 1 == int(offsets_.size())) + << "Cannot register coordinate of index " << std::to_string(dim) << " at this time"; + const Axis& axis = GetAxis(dim); // update matches boolean array if (!matches_.back()) { @@ -305,17 +330,20 @@ class SparseBufferCtx { offsets_.emplace_back(std::move(new_offset)); } - const Axis& GetAxis(int idx) const { - auto && ret = axes_[idx]; + /*! \brief get the axis given dimension index of current buffer. */ + const Axis& GetAxis(int dim) const { + auto&& ret = axes_[dim]; return ret; } + /*! \brief whether the index access pattern of current buffer aligns with current block */ const inline bool MatchWithSpBlock() const { return matches_.back(); } - std::tuple GetIndicesRange(int idx) { - const Axis& axis = axes_[idx]; - return {AggregateOffset(offsets_[idx], axis, Integer(0), &ana_), - AggregateOffset(add(offsets_[idx], 1), axis, Integer(0), &ana_)}; + /*! \brief return the indices range of the given dimension in current buffer. */ + std::tuple GetIndicesRange(int dim) { + const Axis& axis = axes_[dim]; + return {AggregateOffset(offsets_[dim], axis, Integer(0), &ana_), + AggregateOffset(add(offsets_[dim], 1), axis, Integer(0), &ana_)}; } private: @@ -327,31 +355,41 @@ class SparseBufferCtx { const SparseBlockCtx* sp_blk_ctx_; }; + /*! \brief default constructor */ explicit SparseBufferCtx(AxisTree tree) : tree_(std::move(tree)) {} + /*! \brief enter new scope */ void EnterScope(const SparseBuffer& sp_buf, const SparseBlockCtx* sp_blk_ctx) { stack_.emplace_back(sp_buf->name, sp_buf->axes, sp_blk_ctx); } + /*! \brief exit current scope */ void ExitScope() { stack_.pop_back(); } - const Axis& GetAxis(int idx) const { - auto&& ret = local()->GetAxis(idx); + /*! \brief call GetAxis in top scope. */ + const Axis& GetAxis(int dim) const { + auto&& ret = top()->GetAxis(dim); return ret; } - const inline bool MatchWithSpBlock() const { return local()->MatchWithSpBlock(); } + /*! \brief call MatchWithSpBlock in top scope. */ + const inline bool MatchWithSpBlock() const { return top()->MatchWithSpBlock(); } - std::tuple GetIndicesRange(int idx) { return local()->GetIndicesRange(idx); } + /*! \brief call GetIndicesRange in top scope. */ + std::tuple GetIndicesRange(int dim) { return top()->GetIndicesRange(dim); } - void Register(int idx, PrimExpr coordinate, PrimExpr orig_idx) { local()->Register(idx, std::move(coordinate), std::move(orig_idx)); } + /*! \brief call Register in top scope. */ + void Register(int dim, PrimExpr coordinate, PrimExpr orig_idx) { + top()->Register(dim, std::move(coordinate), std::move(orig_idx)); + } private: AxisTree tree_; arith::Analyzer ana_; std::vector stack_; - inline Scope* local() const { return const_cast(&stack_.back()); } + /*! \brief the top scope in the sparse buffer stack. */ + inline Scope* top() const { return const_cast(&stack_.back()); } }; /*! @@ -361,7 +399,7 @@ class SparseBufferCtx { class IndexTransformer : public StmtExprMutator { public: explicit IndexTransformer(const AxisTree& axis_tree) - : axis_tree_(axis_tree), sp_blk_ctx_(axis_tree), sp_buf_ctx_(axis_tree) {} + : sp_blk_ctx_(axis_tree), sp_buf_ctx_(axis_tree), axis_tree_(axis_tree) {} private: // Sparse block context stack; @@ -369,13 +407,18 @@ class IndexTransformer : public StmtExprMutator { // Sparse buffer context stack; SparseBufferCtx sp_buf_ctx_; - PrimExpr ViewIndexInAxis(int idx, PrimExpr index) { + /*! + * \brief Return the offset of index on given dimension. + * \param dim The dimension index. + * \param index The PrimExpr representing the index on this dimension. + */ + PrimExpr ViewIndexInAxis(int dim, PrimExpr index) { // decompress index to coordinate on iterator axis. // the index might not be a single var node, use visitor to recursive construct the coordinate. PrimExpr coordinate = ExprMutator::VisitExpr(index); - const Axis& axis = sp_buf_ctx_.GetAxis(idx); + const Axis& axis = sp_buf_ctx_.GetAxis(dim); // register to sparse buffer scope - sp_buf_ctx_.Register(idx, coordinate, index); + sp_buf_ctx_.Register(dim, coordinate, index); PrimExpr offset = index; // compress coordinate to index on sparse buffer axis. @@ -388,14 +431,14 @@ class IndexTransformer : public StmtExprMutator { case AxisKind::kSparseFixed: { auto sf_axis = axis.as(); PrimExpr l, r; - std::tie(l, r) = sp_buf_ctx_.GetIndicesRange(idx); + std::tie(l, r) = sp_buf_ctx_.GetIndicesRange(dim); offset = lower_bound(sf_axis->indices->data, coordinate, l, r); break; } case AxisKind::kSparseVariable: auto sv_axis = axis.as(); PrimExpr l, r; - std::tie(l, r) = sp_buf_ctx_.GetIndicesRange(idx); + std::tie(l, r) = sp_buf_ctx_.GetIndicesRange(dim); offset = lower_bound(sv_axis->indices->data, coordinate, l, r); break; } @@ -404,6 +447,11 @@ class IndexTransformer : public StmtExprMutator { return offset; } + /*! + * \brief Compute the offset of given indices in compressed sparse buffer layout. + * \param sp_buffer The sparse buffer to access. + * \param indices The array of indices. + */ PrimExpr ComputeOffset(SparseBuffer sp_buffer, const Array& indices) { int num_lowered_indices = static_cast(indices.size()); ICHECK_LE(num_lowered_indices, sp_buffer->ndim()); @@ -426,7 +474,7 @@ class IndexTransformer : public StmtExprMutator { auto it = sp_blk_ctx_.GetSparseIterVar(v); if (it.defined()) { return sp_blk_ctx_.GetCoordinate(it.value().get()); - } else{ + } else { return GetRef(v); } }