Skip to content

Commit

Permalink
Refactor derived axis, frontend support of fusion. (apache#32)
Browse files Browse the repository at this point in the history
* upd

* upd

* fix
  • Loading branch information
yzh119 authored and MasterJH5574 committed Dec 22, 2021
1 parent f407846 commit 0afcfda
Show file tree
Hide file tree
Showing 8 changed files with 407 additions and 154 deletions.
198 changes: 116 additions & 82 deletions include/tvm/tir/sparse.h
Original file line number Diff line number Diff line change
Expand Up @@ -45,18 +45,35 @@ enum class AxisKind : int {
*/
class AxisNode : public Object {
public:
void VisitAttrs(AttrVisitor* v) {
v->Visit("name", &name);
v->Visit("length", &length);
v->Visit("is_derived_axis", &is_derived_axis);
}

bool SEqualReduce(const AxisNode* other, SEqualReducer equal) const {
return equal(name, other->name) && equal(length, other->length) &&
equal(is_derived_axis, other->is_derived_axis);
}

void SHashReduce(SHashReducer hash_reduce) const {
hash_reduce(name);
hash_reduce(length);
hash_reduce(is_derived_axis);
}

/* name of current axis. */
String name;
/* length of current axis. For sparse axis, length refers to the upperbound of
* the current axis. */
PrimExpr length;
/* indicates whether current axis is derived by dense(axis) or fuse(axis1, axis2, ...) */
bool is_derived_axis = false;

String GetName() const { return name; }
PrimExpr GetLength() const { return length; }
DataType GetIndexType() const { return length->dtype; }

virtual bool is_fixed() const = 0;


virtual AxisKind kind() const = 0;

static constexpr const char* _type_key = "tir.sparse.Axis";
Expand All @@ -74,24 +91,6 @@ class Axis : public ObjectRef {
TVM_DEFINE_OBJECT_REF_METHODS(Axis, ObjectRef, AxisNode);
};

/*!
* \brief Root of Axis Dependency Tree.
*/
class RootAxisNode : public Object {
public:
static constexpr const char* _type_key = "tir.sparse.RootAxis";
TVM_DECLARE_FINAL_OBJECT_INFO(RootAxisNode, Object);
};

/*!
* \brief Managed reference to RootAxisNode.
* \sa RootAxisNode
*/
class RootAxis : public ObjectRef {
public:
TVM_DEFINE_OBJECT_REF_METHODS(RootAxis, ObjectRef, RootAxisNode);
};

/*!
* \brief Dense axis whose column indices are consecutive.
*/
Expand Down Expand Up @@ -133,84 +132,134 @@ class SparseAxis : public Axis {
*/
class DenseFixedAxisNode : public DenseAxisNode {
public:
Optional<SparseAxis> from_sparse;
AxisKind kind() const final { return AxisKind::kDenseFixed; }

static constexpr const char* _type_key = "tir.sparse.DenseFixedAxis";
TVM_DECLARE_BASE_OBJECT_INFO(DenseFixedAxisNode, DenseAxisNode);
};

/*!
* \brief Managed reference to DenseFixedAxisNode.
* \sa DenseFixedAxisNode
*/
class DenseFixedAxis : public DenseAxis {
public:
TVM_DLL explicit DenseFixedAxis(String name, PrimExpr length);

TVM_DEFINE_OBJECT_REF_METHODS(DenseFixedAxis, DenseAxis, DenseFixedAxisNode);
};

/*! \brief Derivation axis, constructed by T.dense(axis). */
class DenseFromSparseAxisNode : public DenseFixedAxisNode {
public:
void VisitAttrs(AttrVisitor* v) {
v->Visit("name", &name);
v->Visit("length", &length);
v->Visit("from_sparse", &from_sparse);
DenseFixedAxisNode::VisitAttrs(v);
v->Visit("base", &base);
}

bool SEqualReduce(const DenseFixedAxisNode* other, SEqualReducer equal) const {
return equal(name, other->name) && equal(length, other->length) &&
equal(from_sparse, other->from_sparse);
bool SEqualReduce(const DenseFromSparseAxisNode* other, SEqualReducer equal) const {
return DenseFixedAxisNode::SEqualReduce(other, equal) && equal(base, other->base);
}

void SHashReduce(SHashReducer hash_reduce) const {
hash_reduce(name);
hash_reduce(length);
hash_reduce(from_sparse);
DenseFixedAxisNode::SHashReduce(hash_reduce);
hash_reduce(base);
}

/* The based sparse axis. */
SparseAxis base;

static constexpr const char* _type_key = "tir.sparse.DenseFromSparseAxis";
TVM_DECLARE_FINAL_OBJECT_INFO(DenseFromSparseAxisNode, DenseFixedAxisNode);
};

/*!
* \brief Managed reference of DenseFromSparseAxisNode.
* \sa DenseFromSparseAxisNode
*/
class DenseFromSparseAxis : public DenseFixedAxis {
public:
/* DenseFromSparseAxis could be constructed by specifying the based sparse axis. */
TVM_DLL explicit DenseFromSparseAxis(SparseAxis base);

TVM_DEFINE_OBJECT_REF_METHODS(DenseFromSparseAxis, DenseFixedAxis, DenseFromSparseAxisNode);
};

class FusedAxis;

/*! \brief Derivation axis, constructed by T.fuse(axis1, axis2, ...) */
class FusedAxisNode : public DenseFixedAxisNode {
public:
void VisitAttrs(AttrVisitor* v) {
DenseFixedAxisNode::VisitAttrs(v);
v->Visit("group", &group);
v->Visit("index", &index);
}

bool is_fixed() const final{
return true;
bool SEqualReduce(const FusedAxisNode* other, SEqualReducer equal) const {
return DenseFixedAxisNode::SEqualReduce(other, equal) && equal(group, other->group) &&
equal(index, other->index);
}

AxisKind kind() const final {
return AxisKind::kDenseFixed;
void SHashReduce(SHashReducer hash_reduce) const {
DenseFixedAxisNode::SHashReduce(hash_reduce);
hash_reduce(group);
hash_reduce(index);
}

static constexpr const char* _type_key = "tir.sparse.DenseFixedAxis";
TVM_DECLARE_FINAL_OBJECT_INFO(DenseFixedAxisNode, DenseAxisNode);
/* The group of axes to be fused. */
Array<Axis> group;
/* The index of current FusedAxis in the group. */
int index;

static constexpr const char* _type_key = "tir.sparse.FusedAxis";
TVM_DECLARE_FINAL_OBJECT_INFO(FusedAxisNode, DenseFixedAxisNode);
};

/*!
* \brief Managed reference to DenseFixedAxisNode.
* \sa DenseFixedAxisNode
* \brief Managed refenrence to FusedAxisNode.
* \sa FusedAxisNode
*/
class DenseFixedAxis : public DenseAxis {
class FusedAxis : public DenseFixedAxis {
public:
TVM_DLL explicit DenseFixedAxis(String name, PrimExpr length,
Optional<SparseAxis> from_sparse = NullOpt);
/* Fused axis could be constructed by specifying a group of based axes and an index */
TVM_DLL explicit FusedAxis(Array<Axis> group, int index);

TVM_DEFINE_OBJECT_REF_METHODS(DenseFixedAxis, DenseAxis, DenseFixedAxisNode);
TVM_DEFINE_OBJECT_REF_METHODS(FusedAxis, DenseFixedAxis, FusedAxisNode);
};

/*!
* \brief Dense axis with variable length, such as ragged tensor.
*/
class DenseVariableAxisNode : public DenseAxisNode {
public:
Buffer indptr;

void VisitAttrs(AttrVisitor* v) {
v->Visit("name", &name);
v->Visit("length", &length);
DenseAxisNode::VisitAttrs(v);
v->Visit("indptr", &indptr);
}

bool SEqualReduce(const DenseVariableAxisNode* other, SEqualReducer equal) const {
return equal(name, other->name) && equal(length, other->length) && equal(indptr, other->indptr);
return DenseAxisNode::SEqualReduce(other, equal) && equal(indptr, other->indptr);
}

void SHashReduce(SHashReducer hash_reduce) const {
hash_reduce(name);
hash_reduce(length);
DenseAxisNode::SHashReduce(hash_reduce);
hash_reduce(indptr);
}

bool is_fixed() const final {
return false;
}
PrimExpr nnz() const { return indptr->shape[0]; }

AxisKind kind() const final {
return AxisKind::kDenseVariable;
}
AxisKind kind() const final { return AxisKind::kDenseVariable; }

static constexpr const char* _type_key = "tir.sparse.DenseVariableAxis";
TVM_DECLARE_FINAL_OBJECT_INFO(DenseVariableAxisNode, DenseAxisNode);
};

/*!
* \brief Dense axis whose length is dependent on its predecessors on the axis
* dependency tree.
* \brief Managed reference to DenseVariableAxisNode.
* \sa DenseVariableAxisNode
*/
class DenseVariableAxis : public DenseAxis {
public:
Expand All @@ -229,31 +278,23 @@ class SparseFixedAxisNode : public SparseAxisNode {
PrimExpr nnz_cols;

void VisitAttrs(AttrVisitor* v) {
v->Visit("name", &name);
v->Visit("length", &length);
SparseAxisNode::VisitAttrs(v);
v->Visit("indptr", &indices);
v->Visit("nnz_cols", &nnz_cols);
}

bool SEqualReduce(const SparseFixedAxisNode* other, SEqualReducer equal) const {
return equal(name, other->name) && equal(length, other->length) &&
equal(indices, other->indices) && equal(nnz_cols, other->nnz_cols);
return SparseAxisNode::SEqualReduce(other, equal) && equal(indices, other->indices) &&
equal(nnz_cols, other->nnz_cols);
}

void SHashReduce(SHashReducer hash_reduce) const {
hash_reduce(name);
hash_reduce(length);
SparseFixedAxisNode::SHashReduce(hash_reduce);
hash_reduce(indices);
hash_reduce(nnz_cols);
}

bool is_fixed() const final {
return true;
}

AxisKind kind() const final {
return AxisKind::kSparseFixed;
}
AxisKind kind() const final { return AxisKind::kSparseFixed; }

static constexpr const char* _type_key = "tir.sparse.SparseFixedAxis";
TVM_DECLARE_FINAL_OBJECT_INFO(SparseFixedAxisNode, SparseAxisNode);
Expand All @@ -279,31 +320,25 @@ class SparseVariableAxisNode : public SparseAxisNode {
Buffer indices;

void VisitAttrs(AttrVisitor* v) {
v->Visit("name", &name);
v->Visit("length", &length);
SparseAxisNode::VisitAttrs(v);
v->Visit("indptr", &indptr);
v->Visit("indices", &indices);
}

bool SEqualReduce(const SparseVariableAxisNode* other, SEqualReducer equal) const {
return equal(name, other->name) && equal(length, other->length) &&
equal(indptr, other->indptr) && equal(indices, other->indices);
return SparseAxisNode::SEqualReduce(other, equal) && equal(indptr, other->indptr) &&
equal(indices, other->indices);
}

void SHashReduce(SHashReducer hash_reduce) const {
hash_reduce(name);
hash_reduce(length);
SparseAxisNode::SHashReduce(hash_reduce);
hash_reduce(indptr);
hash_reduce(indices);
}

bool is_fixed() const final {
return false;
}
PrimExpr nnz() const { return indptr->shape[0]; }

AxisKind kind() const final {
return AxisKind::kSparseVariable;
}
AxisKind kind() const final { return AxisKind::kSparseVariable; }

static constexpr const char* _type_key = "tir.sparse.SparseVariableAxis";
TVM_DECLARE_FINAL_OBJECT_INFO(SparseVariableAxisNode, SparseAxisNode);
Expand Down Expand Up @@ -408,7 +443,6 @@ class SparseBuffer : public ObjectRef {
TVM_DEFINE_OBJECT_REF_METHODS(SparseBuffer, ObjectRef, SparseBufferNode);
};


// overload printing of for type.
TVM_DLL std::ostream& operator<<(std::ostream& os, AxisKind kind);

Expand Down
9 changes: 8 additions & 1 deletion python/tvm/script/tir/intrin.py
Original file line number Diff line number Diff line change
Expand Up @@ -28,6 +28,8 @@
SpIterVar,
SparseFixedAxis,
SparseVariableAxis,
DenseFromSparseAxis,
FusedAxis
)
from ..registry import register
from ..utils import get_param_list, tvm_span_from_synr
Expand Down Expand Up @@ -263,6 +265,11 @@ def comm_reducer(lambda_io, identities, span):
@register
def dense(axis: Axis, span: Optional[Span] = None):
if isinstance(axis, (SparseFixedAxis, SparseVariableAxis)):
return DenseFixedAxis(axis.name + "_dense", axis.length, axis)
return DenseFromSparseAxis(axis)
else:
return axis


@register
def fuse(group: List[Axis], span: Optional[Span] = None):
return [FusedAxis(group, _) for _ in range(len(group))]
43 changes: 37 additions & 6 deletions python/tvm/tir/sparse.py
Original file line number Diff line number Diff line change
Expand Up @@ -65,17 +65,48 @@ class DenseFixedAxis(DenseAxis):
length : PrimExpr
The length of the axis
from_sparse : Optional[SparseAxis]
The SparseAxis that this axis is created from
"""

name: str
length: PrimExpr
from_sparse: Optional[SparseAxis]

def __init__(self, name, length, from_sparse=None):
self.__init_handle_by_constructor__(_ffi_api.DenseFixedAxis, name, length, from_sparse) # type: ignore
def __init__(self, name, length):
self.__init_handle_by_constructor__(_ffi_api.DenseFixedAxis, name, length) # type: ignore


@tvm._ffi.register_object("tir.sparse.DenseFromSparseAxis")
class DenseFromSparseAxis(DenseFixedAxis):
"""DenseFromSparseAxis node
Parameters
----------
base : Axis
The based sparse axis.
"""

base: Axis

def __init__(self, base):
self.__init_handle_by_constructor__(_ffi_api.DenseFromSparseAxis, base) # type: ignore


@tvm._ffi.register_object("tir.sparse.FusedAxis")
class FusedAxis(DenseFixedAxis):
"""FusedAxis node
Parameters
----------
group : List[Axis]
The axes group to be fused.
index : int
The index of current axis in the fused axes group.
"""

group: List[Axis]
index: int

def __init__(self, group, index):
self.__init_handle_by_constructor__(_ffi_api.FusedAxis, group, index) # type: ignore


@tvm._ffi.register_object("tir.sparse.DenseVariableAxis")
Expand Down
Loading

0 comments on commit 0afcfda

Please sign in to comment.