Blockize & Tensorize (apache#514)
* Blockize & Tensorize

* Update tensor intrin

* Fix blockized & Recalculate affine flags

* Cleanup utils.cc

* Add test cases of blockize

* Re-enable affine flag checking
vinx13 authored Nov 18, 2021
1 parent 6c7763c commit 87dbb2c
Showing 19 changed files with 2,306 additions and 3 deletions.
2 changes: 2 additions & 0 deletions include/tvm/arith/iter_affine_map.h
@@ -346,6 +346,8 @@ Array<Array<IterMark>> SubspaceDivide(const Array<PrimExpr>& bindings,
const Array<Var>& sub_iters, const PrimExpr& predicate,
bool require_bijective, arith::Analyzer* analyzer);

/*! \brief Convert an IterMapExpr to a normal PrimExpr */
PrimExpr NormalizeIterMapToExpr(const IterMapExpr& expr);

} // namespace arith
} // namespace tvm
#endif // TVM_ARITH_ITER_AFFINE_MAP_H_
30 changes: 30 additions & 0 deletions include/tvm/tir/function.h
@@ -187,6 +187,36 @@ class LinkedParam : public ObjectRef {
TVM_DEFINE_OBJECT_REF_COW_METHOD(LinkedParamNode);
};

/*!
* \brief A tensor intrinsic for tensorization, defined by a description function and an implementation function
*/
class TensorIntrinNode : public Object {
public:
/*! \brief The function to describe the computation. */
PrimFunc description;
/*! \brief The intrinsic function for the lower-level implementation. */
PrimFunc implementation;

void VisitAttrs(AttrVisitor* v) {
v->Visit("description", &description);
v->Visit("implementation", &implementation);
}

static constexpr const char* _type_key = "tir.TensorIntrin";
TVM_DECLARE_FINAL_OBJECT_INFO(TensorIntrinNode, Object);
};

/*! \brief Managed reference to TensorIntrinNode. */
class TensorIntrin : public ObjectRef {
public:
/*! \brief Construct a tensor intrinsic from a description function and an implementation function. */
TVM_DLL explicit TensorIntrin(PrimFunc desc_func, PrimFunc intrin_func);

/*! \brief Register a tensor intrinsic under a global name. */
TVM_DLL static TensorIntrin Register(String name, PrimFunc desc_func, PrimFunc intrin_func);

/*! \brief Look up a previously registered tensor intrinsic by name. */
TVM_DLL static TensorIntrin Get(String name);

TVM_DEFINE_OBJECT_REF_METHODS(TensorIntrin, ObjectRef, TensorIntrinNode)
};

/*!
* \brief Specialize parameters of PrimFunc.
* \param func The PrimFunc to be specialized.
19 changes: 19 additions & 0 deletions include/tvm/tir/schedule/schedule.h
@@ -460,6 +460,25 @@ class ScheduleNode : public runtime::Object {
virtual void StorageAlign(const BlockRV& block_rv, int buffer_index, int axis, int factor,
int offset) = 0;
/******** Schedule: Blockize & Tensorize ********/
/*!
* \brief Convert the subtree rooted at a specific loop into a block
* \param loop_rv The root of the subtree
* \return The new block
*/
virtual BlockRV Blockize(const LoopRV& loop_rv) = 0;
/*!
* \brief Tensorize the computation enclosed by the loop with the given tensor intrinsic
* \param loop_rv The loop to be tensorized
* \param intrin The tensor intrinsic
*/
virtual void Tensorize(const LoopRV& loop_rv, const TensorIntrin& intrin) = 0;
/*!
* \brief Tensorize the computation enclosed by the loop with a tensor intrinsic registered under the given name
* \param loop_rv The loop to be tensorized
* \param intrin_name Name of the tensor intrinsic
*/
virtual void Tensorize(const LoopRV& loop_rv, const String& intrin_name) = 0;

/******** Schedule: Annotation ********/
/*!
* \brief Annotate a loop with a key value pair
2 changes: 1 addition & 1 deletion python/tvm/tir/__init__.py
@@ -33,7 +33,7 @@
from .stmt import IfThenElse, Evaluate, Prefetch, stmt_seq, stmt_list
from .stmt import BufferRegion, MatchBufferRegion, Block, BlockRealize

from .function import PrimFunc
from .function import PrimFunc, TensorIntrin

from .op import call_packed, call_intrin, call_pure_extern, call_extern
from .op import call_llvm_intrin, call_llvm_pure_intrin, ret, all, any, min_value, max_value, trace
26 changes: 26 additions & 0 deletions python/tvm/tir/function.py
@@ -162,3 +162,29 @@ def script(self, tir_prefix: str = "tir", show_meta: bool = False) -> str:
return tvm._ffi.get_global_func("script.AsTVMScript")(
self, tir_prefix, show_meta
) # type: ignore

@tvm._ffi.register_object("tir.TensorIntrin")
class TensorIntrin(Object):
"""A function declaration expression.
Parameters
----------
desc_func: PrimFunc
The function to describe the computation
intrin_func: PrimFunc
The function for execution
"""

def __init__(self, desc_func, intrin_func):
self.__init_handle_by_constructor__(_ffi_api.TensorIntrin, desc_func, intrin_func)

@staticmethod
def register(name: str, desc_func: PrimFunc, intrin_func: PrimFunc):
"""Register a tensor intrinsic under a global name."""
return _ffi_api.TensorIntrinRegister(  # pylint: disable=no-member
name, desc_func, intrin_func
)

@staticmethod
def get(name: str):
"""Look up a registered tensor intrinsic by name."""
return _ffi_api.TensorIntrinGet(name)  # pylint: disable=no-member
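
To make the new registration API concrete, here is a minimal usage sketch. It assumes `desc` and `impl` are `PrimFunc`s (e.g. written in TVMScript) whose bodies each consist of a single block, as required by the `TensorIntrin` constructor in `src/tir/ir/function.cc`; their definitions and the name `"demo.add16"` are hypothetical placeholders, not part of this commit.

```python
from tvm import tir

# Register the (description, implementation) pair under a global name.
# `desc` describes the computation pattern to be matched during tensorization;
# `impl` provides the lower-level implementation that replaces it.
intrin = tir.TensorIntrin.register("demo.add16", desc, impl)

# The intrinsic can later be retrieved by name, e.g. when Schedule.tensorize
# is called with a string instead of a TensorIntrin object.
same_intrin = tir.TensorIntrin.get("demo.add16")
```

Registering the same name twice, or looking up an unregistered name, fails the corresponding `ICHECK` in `TensorIntrin::Register` / `TensorIntrin::Get` and surfaces as a `TVMError` on the Python side.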
10 changes: 9 additions & 1 deletion python/tvm/tir/schedule/schedule.py
@@ -22,7 +22,7 @@
from tvm.error import TVMError, register_error
from tvm.ir import IRModule, PrimExpr
from tvm.runtime import Object, String
from tvm.tir import Block, For, IntImm, PrimFunc
from tvm.tir import Block, For, IntImm, PrimFunc, TensorIntrin
from tvm.tir.expr import FloatImm

from . import _ffi_api
@@ -1656,6 +1656,14 @@ def after_storage_align(a: T.handle, c: T.handle) -> None:

########## Schedule: Blockize & Tensorize ##########

def blockize(self, loop: LoopRV) -> BlockRV:
"""Convert the subtree rooted at the given loop into a block and return the new block."""
return _ffi_api.ScheduleBlockize(self, loop)  # pylint: disable=no-member

def tensorize(self, loop: LoopRV, intrin: Union[str, TensorIntrin]) -> None:
"""Tensorize the computation enclosed by the loop with a tensor intrinsic.

The intrinsic may be given as a TensorIntrin object or as the name of a registered intrinsic.
"""
if isinstance(intrin, str):
intrin = String(intrin)
_ffi_api.ScheduleTensorize(self, loop, intrin)  # pylint: disable=no-member
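
To show how the two new primitives fit together, here is a hedged sketch. It assumes `mod` is an `IRModule` whose `"main"` `PrimFunc` is a 1-D elementwise computation over 64 elements with a block named `"B"`, and that `"demo.add16"` is a 16-element tensor intrinsic registered beforehand; all of these names are placeholders, not part of this commit.

```python
from tvm import tir

sch = tir.Schedule(mod)                    # `mod` is assumed to be defined elsewhere
block = sch.get_block("B")
(i,) = sch.get_loops(block)
io, ii = sch.split(i, factors=[None, 16])  # tile the loop into 4 x 16

# Option 1: group the 16-element inner tile into its own block.
outer_block = sch.blockize(ii)

# Option 2 (on a fresh schedule): replace the tile with a registered intrinsic;
# tensorize accepts either a TensorIntrin object or its registered name.
# sch.tensorize(ii, "demo.add16")

print(sch.mod["main"].script())            # inspect the transformed PrimFunc
```

If the block bindings under the chosen loop cannot be divided at that point, blockize is expected to fail with a schedule error.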

########## Schedule: Annotation ##########

def annotate(
67 changes: 67 additions & 0 deletions src/tir/ir/function.cc
@@ -64,6 +64,65 @@ FuncType PrimFuncNode::func_type_annotation() const {

TVM_REGISTER_NODE_TYPE(PrimFuncNode);

TensorIntrin::TensorIntrin(PrimFunc desc_func, PrimFunc intrin_func) {
// Check that the two functions have the same number of parameters and buffers
CHECK_EQ(desc_func->params.size(), intrin_func->params.size());
CHECK_EQ(desc_func->buffer_map.size(), intrin_func->buffer_map.size());

// Check that the body of each function consists of a single block
const auto* desc_realize = Downcast<BlockRealize>(desc_func->body)->block->body.as<BlockRealizeNode>();
const auto* intrin_realize = Downcast<BlockRealize>(intrin_func->body)->block->body.as<BlockRealizeNode>();
CHECK(desc_realize != nullptr) << "The body of the description function is expected to be a single block";
CHECK(intrin_realize != nullptr) << "The body of the intrinsic function is expected to be a single block";

const Block& desc_block = desc_realize->block;
const Block& intrin_block = intrin_realize->block;

// Check that the blocks have the same number of block vars and matching iter types
CHECK_EQ(desc_block->iter_vars.size(), intrin_block->iter_vars.size())
<< "Two blocks should have the same number of block vars";
for (size_t i = 0; i < desc_block->iter_vars.size(); i++) {
const IterVar& desc_var = desc_block->iter_vars[i];
const IterVar& intrin_var = intrin_block->iter_vars[i];
CHECK(desc_var->iter_type == intrin_var->iter_type)
<< "Block iter_type mismatch between " << desc_var->iter_type << " and "
<< intrin_var->iter_type;
}

auto n = make_object<TensorIntrinNode>();
n->description = std::move(desc_func);
n->implementation = std::move(intrin_func);
data_ = std::move(n);
}

/*! \brief A process-wide registry of tensor intrinsics, keyed by name */
class TensorIntrinManager {
public:
Map<String, tir::TensorIntrin> reg;

static TensorIntrinManager* Global() {
static TensorIntrinManager* inst = new TensorIntrinManager();
return inst;
}
};

TensorIntrin TensorIntrin::Register(String name, PrimFunc desc_func, PrimFunc intrin_func) {
TensorIntrinManager* manager = TensorIntrinManager::Global();
ICHECK_EQ(manager->reg.count(name), 0)
<< "ValueError: TensorIntrin '" << name << "' has already been registered";
TensorIntrin intrin(desc_func, intrin_func);
manager->reg.Set(name, intrin);
return intrin;
}

TensorIntrin TensorIntrin::Get(String name) {
const TensorIntrinManager* manager = TensorIntrinManager::Global();
ICHECK_EQ(manager->reg.count(name), 1)
<< "ValueError: TensorIntrin '" << name << "' is not registered";
return manager->reg.at(name);
}

TVM_REGISTER_NODE_TYPE(TensorIntrinNode);

TVM_STATIC_IR_FUNCTOR(ReprPrinter, vtable)
.set_dispatch<PrimFuncNode>([](const ObjectRef& ref, ReprPrinter* p) {
// TODO(tvm-team) redirect to Text printer once we have a good text format.
@@ -85,5 +144,13 @@ TVM_REGISTER_GLOBAL("tir.PrimFunc")
return PrimFunc(params, body, ret_type, buffer_map, attrs, span);
});

TVM_REGISTER_GLOBAL("tir.TensorIntrin")
.set_body_typed([](PrimFunc desc_func, PrimFunc intrin_func) {
return TensorIntrin(desc_func, intrin_func);
});

TVM_REGISTER_GLOBAL("tir.TensorIntrinRegister").set_body_typed(TensorIntrin::Register);
TVM_REGISTER_GLOBAL("tir.TensorIntrinGet").set_body_typed(TensorIntrin::Get);

} // namespace tir
} // namespace tvm
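
As a quick, illustrative sanity check that the new FFI endpoints are exposed (assuming a TVM build that includes this change), the globals registered above can be looked up directly from Python:

```python
import tvm

# The names correspond to the TVM_REGISTER_GLOBAL calls above; get_global_func
# returns a PackedFunc handle, or None when allow_missing=True and the name is absent.
ctor = tvm.get_global_func("tir.TensorIntrin", allow_missing=True)
register_fn = tvm.get_global_func("tir.TensorIntrinRegister", allow_missing=True)
get_fn = tvm.get_global_func("tir.TensorIntrinGet", allow_missing=True)
print(ctor is not None, register_fn is not None, get_fn is not None)
```

These are the same endpoints the `TensorIntrin` wrapper in `python/tvm/tir/function.py` calls through `_ffi_api`.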
2 changes: 2 additions & 0 deletions src/tir/schedule/analysis.h
@@ -452,6 +452,8 @@ bool CanComputeInline(const ScheduleState& self, const StmtSRef& block_sref);
*/
bool CanReverseComputeInline(const ScheduleState& self, const StmtSRef& block_sref);

/*! \brief Check that no SeqStmt appears before the first BlockRealize in a post-order traversal of the statement */
bool CheckOneLine(const Stmt& s);

} // namespace tir
} // namespace tvm

12 changes: 12 additions & 0 deletions src/tir/schedule/analysis/analysis.cc
@@ -1457,5 +1457,17 @@ bool HasIfThenElse(const Stmt& stmt) {
return has_branch;
}

bool CheckOneLine(const Stmt& s) {
bool legal = true, meet_block = false;
PostOrderVisit(s, [&legal, &meet_block](const ObjectRef& obj) {
// A SeqStmt visited before any BlockRealize means the statements around the block branch.
if (obj->IsInstance<SeqStmtNode>() && !meet_block) {
legal = false;
} else if (obj->IsInstance<BlockRealizeNode>()) {
meet_block = true;
}
});
return legal;
}

} // namespace tir
} // namespace tvm
23 changes: 22 additions & 1 deletion src/tir/schedule/concrete_schedule.cc
@@ -586,7 +586,28 @@ BlockRV ConcreteScheduleNode::RFactor(const LoopRV& loop_rv, int factor_axis) {
}

/******** Schedule: Blockize & Tensorize ********/
BlockRV ConcreteScheduleNode::Blockize(const LoopRV& loop_rv) {
StmtSRef result{nullptr};
TVM_TIR_SCHEDULE_BEGIN();
result = tir::Blockize(state_, this->GetSRef(loop_rv));
this->state_->DebugVerify();
TVM_TIR_SCHEDULE_END("blockize", this->error_render_level_);
return CreateRV<BlockRV>(result);
}

void ConcreteScheduleNode::Tensorize(const LoopRV& loop_rv, const TensorIntrin& intrin) {
TVM_TIR_SCHEDULE_BEGIN();
tir::Tensorize(state_, this->GetSRef(loop_rv), intrin);
this->state_->DebugVerify();
TVM_TIR_SCHEDULE_END("tensorize", this->error_render_level_);
}

void ConcreteScheduleNode::Tensorize(const LoopRV& loop_rv, const String& intrin_name) {
TVM_TIR_SCHEDULE_BEGIN();
tir::Tensorize(state_, this->GetSRef(loop_rv), tir::TensorIntrin::Get(intrin_name));
this->state_->DebugVerify();
TVM_TIR_SCHEDULE_END("tensorize", this->error_render_level_);
}

void ConcreteScheduleNode::Annotate(const LoopRV& loop_rv, const String& ann_key,
const ObjectRef& ann_val) {
4 changes: 4 additions & 0 deletions src/tir/schedule/concrete_schedule.h
@@ -125,6 +125,10 @@ class ConcreteScheduleNode : public ScheduleNode {
void StorageAlign(const BlockRV& block_rv, int buffer_index, int axis, int factor,
int offset) override;
/******** Schedule: Blockize & Tensorize ********/
BlockRV Blockize(const LoopRV& loop_rv) override;
void Tensorize(const LoopRV& loop_rv, const TensorIntrin& intrin) override;
void Tensorize(const LoopRV& loop_rv, const String& intrin_name) override;

/******** Schedule: Annotation ********/
void Annotate(const LoopRV& loop_rv, const String& ann_key, const ObjectRef& ann_val) override;
void Unannotate(const LoopRV& loop_rv, const String& ann_key) override;
5 changes: 5 additions & 0 deletions src/tir/schedule/primitive.h
@@ -350,6 +350,11 @@ TVM_DLL void StorageAlign(ScheduleState self, const StmtSRef& block_sref, int bu
int axis, int factor, int offset);

/******** Schedule: Blockize & Tensorize ********/

/*! \brief Convert the subtree rooted at a specific loop into a block */
TVM_DLL StmtSRef Blockize(ScheduleState self, const StmtSRef& loop_sref);
/*! \brief Tensorize the computation enclosed by the loop with the given tensor intrinsic */
TVM_DLL void Tensorize(ScheduleState self, const StmtSRef& loop_sref,
const TensorIntrin& intrinsic);

/******** Schedule: Annotation ********/
/*!
* \brief Annotate a block/loop with a key value pair