Skip to content

Commit

Permalink
Frontend update, demo scripts. (apache#10)
Browse files Browse the repository at this point in the history
* Format and Buffer data structure (apache#1)

* [SparseTIR] Constructors and Python Interface for `Axis` and `SparseBuffer` (apache#2)

* add methods for Object

* axis constructors

* methods for SparseBuffer

* put into registry

* python interface

* [CherryPick][Intrinsic] lower_bound and upper_bound for binary search in Sparse TIR. (apache#483) (apache#4)

* upd

* upd

* fix

* upd

* upd

* upd

* upd

* upd

* fix

* upd

* upd

* upd

* upd

* upd

* upd

* upd

* codegen-rule

* upd

* upd

* test

* upd

* fix

* two arguments

Co-authored-by: Zihao Ye <[email protected]>

* Fix AxisTree (apache#3)

* fix axis tree

* upd

* Format and Buffer data structure (apache#1)

* [SparseTIR] Constructors and Python Interface for `Axis` and `SparseBuffer` (apache#2)

* add methods for Object

* axis constructors

* methods for SparseBuffer

* put into registry

* python interface

* fix axis tree

* upd

* Format and Buffer data structure (apache#1)

* [SparseTIR] Constructors and Python Interface for `Axis` and `SparseBuffer` (apache#2)

* add methods for Object

* axis constructors

* methods for SparseBuffer

* put into registry

* python interface

* [CherryPick][Intrinsic] lower_bound and upper_bound for binary search in Sparse TIR. (apache#483) (apache#4)

* upd

* upd

* fix

* upd

* upd

* upd

* upd

* upd

* fix

* upd

* upd

* upd

* upd

* upd

* upd

* upd

* codegen-rule

* upd

* upd

* test

* upd

* fix

* two arguments

Co-authored-by: Zihao Ye <[email protected]>

* Fix AxisTree (apache#3)

* fix axis tree

* upd

* [SparseTIR] Add SparseBufferLoad/SparseBufferStore (apache#5)

* Add dtype for SparseBuffer

* Add name for SparseBuffer. Remove `ndim`

* Remove namespace sparse

* Add SparseBufferLoad/Store

* Add method `ndim()`

* Format and Buffer data structure (apache#1)

* [SparseTIR] Constructors and Python Interface for `Axis` and `SparseBuffer` (apache#2)

* add methods for Object

* axis constructors

* methods for SparseBuffer

* put into registry

* python interface

* [CherryPick][Intrinsic] lower_bound and upper_bound for binary search in Sparse TIR. (apache#483) (apache#4)

* upd

* upd

* fix

* upd

* upd

* upd

* upd

* upd

* fix

* upd

* upd

* upd

* upd

* upd

* upd

* upd

* codegen-rule

* upd

* upd

* test

* upd

* fix

* two arguments

Co-authored-by: Zihao Ye <[email protected]>

* Fix AxisTree (apache#3)

* fix axis tree

* upd

* [SparseTIR] Add SparseBufferLoad/SparseBufferStore (apache#5)

* Add dtype for SparseBuffer

* Add name for SparseBuffer. Remove `ndim`

* Remove namespace sparse

* Add SparseBufferLoad/Store

* Add method `ndim()`

* [SparseTIR] Introduce SpIterVar (apache#6)

* [SparseTIR] Introduce SpIterVar

* Add conversion to PrimExpr

* [BugFix] Fix binary search & SpIterVar (apache#7)

* [BugFix] Add field `is_reduction` for SpIterVar (apache#9)

* [BugFix] Add field `is_reduction` for SpIterVar

* Formatting

* upd

* upd

Co-authored-by: Ruihang Lai <[email protected]>
  • Loading branch information
yzh119 and MasterJH5574 committed Nov 22, 2021
1 parent 002f7af commit 4c80d9e
Show file tree
Hide file tree
Showing 8 changed files with 451 additions and 99 deletions.
72 changes: 39 additions & 33 deletions include/tvm/tir/sparse.h
Original file line number Diff line number Diff line change
Expand Up @@ -44,6 +44,10 @@ class AxisNode : public Object {
* the current axis. */
PrimExpr length;

String GetName() const { return name; }
PrimExpr GetLength() const { return length; }
DataType GetIndexType() const { return length->dtype; }

static constexpr const char* _type_key = "tir.sparse.Axis";
static constexpr const bool _type_has_method_sequal_reduce = true;
static constexpr const bool _type_has_method_shash_reduce = true;
Expand Down Expand Up @@ -139,8 +143,10 @@ class DenseVariableAxisNode : public DenseAxisNode {
v->Visit("indptr", &indptr);
}

bool SEqualReduce(const DenseVariableAxisNode* other, SEqualReducer equal) const {
return equal(name, other->name) && equal(length, other->length) && equal(indptr, other->indptr);
bool SEqualReduce(const DenseVariableAxisNode* other,
SEqualReducer equal) const {
return equal(name, other->name) && equal(length, other->length) &&
equal(indptr, other->indptr);
}

void SHashReduce(SHashReducer hash_reduce) const {
Expand All @@ -159,9 +165,11 @@ class DenseVariableAxisNode : public DenseAxisNode {
*/
class DenseVariableAxis : public DenseAxis {
public:
TVM_DLL explicit DenseVariableAxis(String name, PrimExpr length, Buffer indptr);
TVM_DLL explicit DenseVariableAxis(String name, PrimExpr length,
Buffer indptr);

TVM_DEFINE_OBJECT_REF_METHODS(DenseVariableAxis, DenseAxis, DenseVariableAxisNode);
TVM_DEFINE_OBJECT_REF_METHODS(DenseVariableAxis, DenseAxis,
DenseVariableAxisNode);
};

/*!
Expand Down Expand Up @@ -198,7 +206,8 @@ class SparseFixedAxisNode : public SparseAxisNode {
v->Visit("num_cols", &num_cols);
}

bool SEqualReduce(const SparseFixedAxisNode* other, SEqualReducer equal) const {
bool SEqualReduce(const SparseFixedAxisNode* other,
SEqualReducer equal) const {
return equal(name, other->name) && equal(length, other->length) &&
equal(indices, other->indices) && equal(num_cols, other->num_cols);
}
Expand All @@ -220,9 +229,11 @@ class SparseFixedAxisNode : public SparseAxisNode {
*/
class SparseFixedAxis : public SparseAxis {
public:
TVM_DLL explicit SparseFixedAxis(String name, PrimExpr length, Buffer indices, PrimExpr num_cols);
TVM_DLL explicit SparseFixedAxis(String name, PrimExpr length, Buffer indices,
PrimExpr num_cols);

TVM_DEFINE_OBJECT_REF_METHODS(SparseFixedAxis, SparseAxis, SparseFixedAxisNode);
TVM_DEFINE_OBJECT_REF_METHODS(SparseFixedAxis, SparseAxis,
SparseFixedAxisNode);
};

/*!
Expand All @@ -240,7 +251,8 @@ class SparseVariableAxisNode : public SparseAxisNode {
v->Visit("indices", &indices);
}

bool SEqualReduce(const SparseVariableAxisNode* other, SEqualReducer equal) const {
bool SEqualReduce(const SparseVariableAxisNode* other,
SEqualReducer equal) const {
return equal(name, other->name) && equal(length, other->length) &&
equal(indptr, other->indptr) && equal(indices, other->indices);
}
Expand All @@ -262,24 +274,25 @@ class SparseVariableAxisNode : public SparseAxisNode {
*/
class SparseVariableAxis : public SparseAxis {
public:
TVM_DLL explicit SparseVariableAxis(String name, PrimExpr length, Buffer indptr, Buffer indices);
TVM_DLL explicit SparseVariableAxis(String name, PrimExpr length,
Buffer indptr, Buffer indices);

TVM_DEFINE_OBJECT_REF_METHODS(SparseVariableAxis, SparseAxis, SparseVariableAxisNode);
TVM_DEFINE_OBJECT_REF_METHODS(SparseVariableAxis, SparseAxis,
SparseVariableAxisNode);
};

/*!
* \brief Axis Dependency Tree.
*/
class AxisTreeNode : public Object {
public:
// mapping from names to axes.
std::unordered_map<String, Axis> axis_map;
// unordered map that stores the parent relationship between axes.
std::unordered_map<Axis, Axis, ObjectPtrHash, ObjectPtrEqual> parent;
std::unordered_map<String, Optional<String>, ObjectPtrHash, ObjectPtrEqual>
parent;
// unordered map that stores the children relationship between axes.
std::unordered_map<Axis, Array<Axis>, ObjectPtrHash, ObjectPtrEqual> children;
// The root axis.
Axis root;
std::unordered_map<Optional<String>, Array<String>, ObjectPtrHash,
ObjectPtrEqual>
children;

void VisitAttrs(AttrVisitor* v) {}

Expand All @@ -293,7 +306,9 @@ class AxisTreeNode : public Object {
*/
class AxisTree : public ObjectRef {
public:
TVM_DLL AxisTree(Array<Axis> axes, Array<Optional<String>> axis_parent_names);
TVM_DLL AxisTree(Array<String> axis_names,
Array<Optional<String>> axis_parent_names);

TVM_DEFINE_OBJECT_REF_METHODS(AxisTree, ObjectRef, AxisTreeNode);
};

Expand All @@ -302,38 +317,30 @@ class AxisTree : public ObjectRef {
*/
class SparseBufferNode : public Object {
public:
/* Root of Axis Dependency Tree. */
AxisTree tree;
/* Axes */
Array<Axis> axes;
/* Buffer corresponding to flattened value */
Buffer data;
/* Buffer Name */
String name;
/* Data type */
runtime::DataType dtype;

inline int ndim() const { return static_cast<int>(axes.size()); }

void VisitAttrs(AttrVisitor* v) {
v->Visit("name", &tree);
v->Visit("length", &axes);
v->Visit("num_cols", &data);
v->Visit("name", &name);
v->Visit("dtype", &dtype);
}

bool SEqualReduce(const SparseBufferNode* other, SEqualReducer equal) const {
return equal(tree, other->tree) && equal(axes, other->axes) && equal(data, other->data) &&
equal(name, other->name) && equal(dtype, other->dtype);
return equal(axes, other->axes) && equal(data, other->data) &&
equal(name, other->name);
}

void SHashReduce(SHashReducer hash_reduce) const {
hash_reduce(tree);
hash_reduce(axes);
hash_reduce(data);
hash_reduce(name);
hash_reduce(dtype);
}

static constexpr const char* _type_key = "tir.sparse.SparseBuffer";
Expand All @@ -346,8 +353,7 @@ class SparseBufferNode : public Object {
*/
class SparseBuffer : public ObjectRef {
public:
TVM_DLL explicit SparseBuffer(AxisTree tree, Array<Axis> axes, Buffer data, String name,
DataType dtype);
TVM_DLL explicit SparseBuffer(Array<Axis> axes, Buffer data, String name);

TVM_DEFINE_OBJECT_REF_METHODS(SparseBuffer, ObjectRef, SparseBufferNode);
};
Expand Down Expand Up @@ -380,8 +386,8 @@ class SpIterVarNode : public Object {

bool SEqualReduce(const SpIterVarNode* other, SEqualReducer equal) const {
return equal(var, other->var) && equal(max_extent, other->max_extent) &&
equal(axis, other->axis) && equal(is_reduction, other->is_reduction) &&
equal(kind, other->kind);
equal(axis, other->axis) &&
equal(is_reduction, other->is_reduction) && equal(kind, other->kind);
}

void SHashReduce(SHashReducer hash_reduce) const {
Expand All @@ -400,8 +406,8 @@ class SpIterVarNode : public Object {

class SpIterVar : public ObjectRef {
public:
TVM_DLL explicit SpIterVar(String name, PrimExpr max_extent, SpIterKind kind, bool is_reduction,
Optional<Axis> axis = NullOpt);
TVM_DLL explicit SpIterVar(String name, PrimExpr max_extent, SpIterKind kind,
bool is_reduction, Optional<Axis> axis = NullOpt);

/*!
* \return the corresponding var in the IterVar.
Expand Down
22 changes: 22 additions & 0 deletions include/tvm/tir/stmt.h
Original file line number Diff line number Diff line change
Expand Up @@ -327,6 +327,28 @@ class BufferStore : public Stmt {
TVM_DEFINE_OBJECT_REF_COW_METHOD(BufferStoreNode);
};

/*!
* \brief Sparse Block node.
*/
class SparseBlockNode : public StmtNode {
public:
/*! \brief The sparse iteration variables of the block. */
Array<SpIterVar> sp_iter_vars;
/*! \brief The sparse buffers defined in the block. */
Array<SparseBuffer> sp_buffers;
/*! \brief The body of the block */
Stmt body;

static constexpr const char* _type_key = "tir.SparseBlock";
TVM_DECLARE_FINAL_OBJECT_INFO(SparseBlockNode, StmtNode);
};

class SparseBlock : public Stmt {
public:
TVM_DEFINE_OBJECT_REF_METHODS(SparseBlock, Stmt, SparseBlockNode);
};


/*!
* \brief Store value to the high dimension sparse buffer.
*
Expand Down
17 changes: 13 additions & 4 deletions python/tvm/script/context_maintainer.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,9 +23,11 @@
import tvm
from tvm.ir import Span
from tvm.ir.expr import Range
from tvm.script.tir.sparse import MatchSparseBuffer
from tvm.tir import Var, Buffer, PrimExpr, Stmt, MatchBufferRegion
from tvm.runtime import Object
from tvm.tir.expr import IterVar
from tvm.tir.sparse import Axis, SparseBuffer
from .tir.node import BufferSlice


Expand Down Expand Up @@ -74,6 +76,10 @@ def example_func(a: T.handle, b: T.handle, c: T.handle) -> None:
"""List[Buffer]: list of T.alloc_buffer statements in the block signature"""
match_buffers: List[MatchBufferRegion] = []
"""List[MatchBufferRegion]: list of T.match_buffer statements in the block signature"""
axes: List[Axis] = []
"""List[Axis]: list of sparse axis created in the block signature."""
match_sparse_buffers: List[MatchSparseBuffer]
"""List[MatchSparseBuffer]: list of T.match_sparse_buffer statements in the block signature."""
iter_values: List[PrimExpr] = []
"""List[PrimExpr]: list of binding values for iter vars"""
iter_vars: List[IterVar] = []
Expand Down Expand Up @@ -119,14 +125,16 @@ class ContextMaintainer:
"""List[BlockInfo]: The block info for the current block scope"""
loop_stack: Dict[Var, Range] = {}
"""Dict[Var, Range]: The dict from loop var to its domain outside the block"""
symbols: List[Dict[str, Union[Var, Buffer]]] = []
symbols: List[Dict[str, Union[Var, Buffer, SparseBuffer, Axis]]] = []
"""List[Dict[str, Union[Var, Buffer]]]: Symbol map from name to object for the current scope"""

# function context
func_params: List[Var] = []
"""List[Var]: The function parameters"""
func_buffer_map: Mapping[Var, Buffer] = {}
"""Mapping[Var, Buffer]: The function buffer map"""
func_sparse_buffer_map: Mapping[Var, SparseBuffer] = {}
"""Mapping[Var, SparseBuffer]: The function sparse buffer map"""
func_dict_attr: Mapping[str, Object] = {}
"""Mapping[str, Object]: The function attrs"""
func_var_env_dict: Mapping[Var, str] = {}
Expand All @@ -151,6 +159,7 @@ def __init__(self, _report_error: Callable[[str, Union[Span, synr.ast.Span]], No
# function context
self.func_params = []
self.func_buffer_map = {}
self.func_sparse_buffer_map = {}
self.func_dict_attr = {}
self.func_var_env_dict = {}
# parser and analyzer
Expand Down Expand Up @@ -208,9 +217,9 @@ def exit_block_scope(self):
# Pop block_info
self.block_info_stack.pop()

def update_symbol(self, name: str, symbol: Union[Buffer, Var], node: synr.ast.Node):
def update_symbol(self, name: str, symbol: Union[Buffer, Var, SparseBuffer, Axis], node: synr.ast.Node):
"""Append a symbol into current scope"""
if isinstance(symbol, Buffer):
if isinstance(symbol, (Buffer, Var, SparseBuffer, Axis)):
if name in self.symbols[0]:
self.report_error("Duplicate Buffer name: " + symbol.name, node.span)
self.symbols[0][name] = symbol
Expand All @@ -225,7 +234,7 @@ def remove_symbol(self, name: str):
return
raise RuntimeError("Internal error of tvm script parser: no symbol named " + name)

def lookup_symbol(self, name: str) -> Optional[Union[Buffer, Var]]:
def lookup_symbol(self, name: str) -> Optional[Union[Buffer, Var, SparseBuffer, Axis]]:
"""Look up symbol by name"""
for symbols in reversed(self.symbols):
if name in symbols:
Expand Down
Loading

0 comments on commit 4c80d9e

Please sign in to comment.