[Draft][TIR] Remove PrimFuncNode::preflattened_buffer_map
`PrimFuncNode::preflattened_buffer_map` was introduced in
apache#9727 to keep a record of the pre-flattened buffer shape
until it is needed in `MakePackedAPI`.  This commit instead keeps the
pre-flattened shapes in `PrimFuncNode::buffer_map`, while the body of
the function accesses the data through a flattened buffer alias.

Passes the LLVM tests in test_target_codegen_llvm.py as an initial
proof of concept.
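
To illustrate the resulting IR structure, here is a small Python sketch (illustrative only, not taken from the diff; the buffer names and shapes are made up). The parameter's buffer keeps its pre-flattened shape in `buffer_map`, and the body works through a flattened alias that shares the same backing data Var:

import tvm
from tvm import tir

# Pre-flattened buffer: this is what stays in PrimFunc::buffer_map.
pre_flattened = tir.decl_buffer((32, 32), "float32", name="A")
# Flattened alias over the same underlying data Var; the function body uses this.
flattened = tir.decl_buffer((1024,), "float32", name="A_flat", data=pre_flattened.data)

handle = tir.Var("A", "handle")
body = tir.Evaluate(tir.BufferLoad(flattened, [0]))  # the body only touches the alias
func = tir.PrimFunc([handle], body, buffer_map={handle: pre_flattened})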
Lunderberg committed Apr 19, 2022
1 parent 557fc6c commit 37d47ba
Showing 18 changed files with 154 additions and 303 deletions.
43 changes: 11 additions & 32 deletions include/tvm/tir/function.h
@@ -88,33 +88,22 @@ class PrimFuncNode : public BaseFuncNode {
* While we could express parameter unpacking and constraints using
* normal statements, making buffer_map a first-class citizen of PrimFunc
* makes program analysis much easier.
*/
Map<tir::Var, Buffer> buffer_map;

/*! \brief The buffer map prior to flattening.
*
* This contains the buffers as they exists prior to flattening, and
* is used for validating an input tensor passed into the packed
* API. Any buffer that is present in `buffer_map` but not present
* in `preflattened_buffer_map` is assumed to be the same before
* and after flattening (e.g. a 1-d tensor that is backed by 1-d
* flat memory).
*
* TODO(Lunderberg): Remove preflattened_buffer_map, and instead
* declare each flattened buffer as aliasing the original tensor
* shape. This should include improving the StmtExprMutator to
* provide easier interactions with Buffer objects, so that the
* bookkeeping of relationships between buffers doesn't need to be
* repeated across several transforms.
* Prior to buffer flattening, which is performed either in
* StorageFlatten for TE-based schedules or in FlattenBuffer for
* TIR-based schedules, these buffer objects are used directly in
* the body of the function. After buffer flattening, these buffer
* objects remain unflattened for use in argument validation, but
* all usage in the body of the function is done through a
* flattened alias of the buffer.
*/
Map<tir::Var, Buffer> preflattened_buffer_map;
Map<tir::Var, Buffer> buffer_map;

void VisitAttrs(tvm::AttrVisitor* v) {
v->Visit("params", &params);
v->Visit("body", &body);
v->Visit("ret_type", &ret_type);
v->Visit("buffer_map", &buffer_map);
v->Visit("preflattened_buffer_map", &preflattened_buffer_map);
v->Visit("attrs", &attrs);
v->Visit("span", &span);
v->Visit("_checked_type_", &checked_type_);
@@ -123,15 +112,13 @@ class PrimFuncNode : public BaseFuncNode {
bool SEqualReduce(const PrimFuncNode* other, SEqualReducer equal) const {
// visit params and buffer_map first as they contains defs.
return equal.DefEqual(params, other->params) && equal(buffer_map, other->buffer_map) &&
equal(preflattened_buffer_map, other->preflattened_buffer_map) &&
equal(ret_type, other->ret_type) && equal(body, other->body) &&
equal(attrs, other->attrs);
}

void SHashReduce(SHashReducer hash_reduce) const {
hash_reduce.DefHash(params);
hash_reduce(buffer_map);
hash_reduce(preflattened_buffer_map);
hash_reduce(ret_type);
hash_reduce(body);
hash_reduce(attrs);
@@ -169,21 +156,13 @@ class PrimFunc : public BaseFunc {
* PrimFunc. (e.g. a buffer of shape ``[1024]`` originally
* generated as a tensor of shape ``[32, 32]``)
*
* \param preflattened_buffer_map The buffer map for
* parameter buffer unpacking. This contains buffer
* objects as they are expected to be passed in by the
* callee. (e.g. a buffer of shape ``[32, 32]`` originally
* generated as a tensor of shape ``[32, 32]``)
*
* \param attrs Additional function attributes.
*
* \param span The location of this object in the source code.
*/
TVM_DLL PrimFunc(
Array<tir::Var> params, Stmt body, Type ret_type = VoidType(),
Map<tir::Var, Buffer> buffer_map = Map<tir::Var, Buffer>(),
Optional<Map<tir::Var, Buffer>> preflattened_buffer_map = Optional<Map<tir::Var, Buffer>>(),
DictAttrs attrs = NullValue<DictAttrs>(), Span span = Span());
TVM_DLL PrimFunc(Array<tir::Var> params, Stmt body, Type ret_type = VoidType(),
Map<tir::Var, Buffer> buffer_map = Map<tir::Var, Buffer>(),
DictAttrs attrs = NullValue<DictAttrs>(), Span span = Span());

TVM_DEFINE_OBJECT_REF_METHODS(PrimFunc, BaseFunc, PrimFuncNode);
TVM_DEFINE_OBJECT_REF_COW_METHOD(PrimFuncNode);
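For context on the argument-validation role described above, the following is a rough end-to-end sketch (assuming an LLVM-enabled build; not part of this diff). The packed function produced by `tvm.build` checks the caller's tensor against the pre-flattened shape recorded in `buffer_map`, even though the lowered body addresses flattened memory:

import numpy as np
import tvm
from tvm import te

A = te.placeholder((32, 32), name="A", dtype="float32")
B = te.compute((32, 32), lambda i, j: A[i, j] + 1.0, name="B")
func = tvm.build(te.create_schedule(B.op), [A, B], target="llvm")

a = tvm.nd.array(np.zeros((32, 32), dtype="float32"))
b = tvm.nd.array(np.zeros((32, 32), dtype="float32"))
func(a, b)  # shapes match the pre-flattened (32, 32) buffers

a_flat = tvm.nd.array(np.zeros((1024,), dtype="float32"))
# func(a_flat, b)  # would fail the packed-API shape check against (32, 32)
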
3 changes: 0 additions & 3 deletions python/tvm/script/context_maintainer.py
@@ -127,8 +127,6 @@ class ContextMaintainer:
"""List[Var]: The function parameters"""
func_buffer_map: Mapping[Var, Buffer] = {}
"""Mapping[Var, Buffer]: The function buffer map"""
func_preflattened_buffer_map: Mapping[Var, Buffer] = {}
"""Mapping[Var, Buffer]: The function buffer map, prior to any flattening."""
func_dict_attr: Mapping[str, Object] = {}
"""Mapping[str, Object]: The function attrs"""
func_var_env_dict: Mapping[Var, str] = {}
@@ -153,7 +151,6 @@ def __init__(self, _report_error: Callable[[str, Union[Span, synr.ast.Span]], No
# function context
self.func_params = []
self.func_buffer_map = {}
self.func_preflattened_buffer_map = {}
self.func_dict_attr = {}
self.func_var_env_dict = {}
# parser and analyzer
1 change: 0 additions & 1 deletion python/tvm/script/parser.py
@@ -484,7 +484,6 @@ def check_decorator(decorators: List[ast.Expr]) -> bool:
body,
ret_type,
buffer_map=self.context.func_buffer_map,
preflattened_buffer_map=self.context.func_preflattened_buffer_map,
attrs=tvm.ir.make_node("DictAttrs", **dict_attr) if dict_attr else None,
span=tvm_span_from_synr(node.span),
)
73 changes: 0 additions & 73 deletions python/tvm/script/tir/special_stmt.py
@@ -863,79 +863,6 @@ def func_attr(dict_attr, span):
super().__init__(func_attr, def_symbol=False)


@register
class PreflattenedBufferMap(SpecialStmt):
"""Special Stmt for declaring the PrimFunc::preflattened_buffer_map
Example
-------
.. code-block:: python
A0 = T.match_buffer(A, (48,), dtype="float32")
T.preflattened_buffer(A, (1, 4, 4, 3), elem_offset=1, align=4, dtype="float32")
"""

def __init__(self):
def preflattened_buffer(
postflattened,
shape,
dtype="float32",
data=None,
strides=None,
elem_offset=None,
scope="global",
align=-1,
offset_factor=0,
buffer_type="default",
span=None,
):

param = None
for key, value in self.context.func_buffer_map.items():
if value.same_as(postflattened):
param = key
break

assert (
param is not None
), f"Post-flatten buffer {postflattened.name} does not appear in the buffer map."

if data is None:
data = self.context.func_buffer_map[param].data

buffer_name: str = f"{postflattened.name}_preflatten"
if align != -1:
if isinstance(align, IntImm):
align = align.value
else:
assert isinstance(align, int), f"align: want int or IntImm, got {align!r}"

if offset_factor != 0:
if isinstance(offset_factor, IntImm):
offset_factor = offset_factor.value
else:
assert isinstance(
offset_factor, int
), f"offset_factor: want int or IntImm, got {offset_factor!r}"

preflattened = tvm.tir.decl_buffer(
shape,
dtype,
buffer_name,
data,
strides,
elem_offset,
scope,
align,
offset_factor,
buffer_type,
span=span,
)

self.context.func_preflattened_buffer_map[param] = preflattened

super().__init__(preflattened_buffer, def_symbol=False)


@register
class TargetAttrValue(SpecialStmt):
"""Special Stmt for target attr value.
7 changes: 0 additions & 7 deletions python/tvm/tir/function.py
@@ -46,9 +46,6 @@ class PrimFunc(BaseFunc):
buffer_map : Map[tvm.tir.Var, tvm.tir.Buffer]
The buffer binding map.
preflattened_buffer_map : Optional[Map[tvm.tir.Var, tvm.tir.Buffer]]
The buffer binding map, prior to any flattening.
attrs: Optional[tvm.Attrs]
Attributes of the function, can be None
@@ -62,14 +59,12 @@ def __init__(
body,
ret_type=None,
buffer_map=None,
preflattened_buffer_map=None,
attrs=None,
span=None,
):

param_list = []
buffer_map = {} if buffer_map is None else buffer_map
preflattened_buffer_map = {} if preflattened_buffer_map is None else preflattened_buffer_map
for x in params:
x = tvm.runtime.convert(x) if not isinstance(x, Object) else x
if isinstance(x, Buffer):
@@ -87,7 +82,6 @@ def __init__(
body,
ret_type,
buffer_map,
preflattened_buffer_map,
attrs,
span,
) # type: ignore
@@ -113,7 +107,6 @@ def with_body(self, new_body, span=None):
new_body,
self.ret_type,
self.buffer_map,
self.preflattened_buffer_map,
self.attrs,
span,
)
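A quick usage note on the constructor kept here (a sketch under the new signature; the buffer name is arbitrary): passing a Buffer directly in `params` still generates the handle Var and the `buffer_map` entry automatically, and `preflattened_buffer_map` is simply no longer accepted as a keyword argument:

import tvm
from tvm import tir

A = tir.decl_buffer((32, 32), "float32", name="A")
func = tir.PrimFunc([A], tir.Evaluate(tir.IntImm("int32", 0)))

print(func.params)      # one handle Var, generated for A
print(func.buffer_map)  # maps that Var to the (32, 32) buffer A
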
20 changes: 0 additions & 20 deletions src/printer/tvmscript_printer.cc
@@ -1533,26 +1533,6 @@ Doc TVMScriptPrinter::PrintPrimFunc(const PrimFunc& primFunc) {
body << Print((*it).first) << ", " << memo_buf_decl_[buf];
body << ")" << Doc::NewLine();
}
// print preflattened buffer map
for (const auto& param : op->params) {
auto pf_buf_it = op->preflattened_buffer_map.find(param);
if (pf_buf_it != op->preflattened_buffer_map.end()) {
const Buffer& preflattened = (*pf_buf_it).second;

auto buf_it = op->buffer_map.find(param);
ICHECK(buf_it != op->buffer_map.end()) << "Found pre-flattened buffer " << preflattened->name
<< " with no corresponding post-flatten buffer.";
const Buffer& postflattened = (*buf_it).second;

// Call Print() without assigning in order to fill memo_buf_decl_.
Print(preflattened);
buf_not_in_headers_.insert(preflattened.get());
ICHECK(memo_buf_decl_.count(preflattened));

body << tir_prefix_ << ".preflattened_buffer(" << Print(postflattened) << ", "
<< memo_buf_decl_.at(preflattened) << ")" << Doc::NewLine();
}
}
// print body
body << "# body" << Doc::NewLine();
if (op->body->IsInstance<BlockRealizeNode>() &&
2 changes: 1 addition & 1 deletion src/relay/backend/aot_executor_codegen.cc
@@ -770,7 +770,7 @@ class AOTExecutorCodegen : public MixedModeVisitor {
tir::Stmt final_body = tir::SeqStmt({device_activations, body, device_deactivations});

// Make the PrimFunc
return tir::PrimFunc(main_signature_, final_body, VoidType(), main_buffer_map_, {},
return tir::PrimFunc(main_signature_, final_body, VoidType(), main_buffer_map_,
DictAttrs(dict_attrs));
}

@@ -81,7 +81,7 @@ class ConvertAddToSubtract : public MixedModeMutator {
};

tir::PrimFunc replacement_func = tir::PrimFunc({x_var, y_var, out_var}, math_loop, VoidType(),
buffer_map, {}, DictAttrs(dict_attrs));
buffer_map, DictAttrs(dict_attrs));

// Switch to TIRToRuntime hook for testing
Bool tir_to_runtime = func->GetAttr<Bool>("tir_to_runtime").value_or(Bool(false));
4 changes: 3 additions & 1 deletion src/relay/transforms/fold_constant.cc
@@ -253,7 +253,9 @@ class ConstantFolder : public MixedModeMutator {

// Use a fresh build context in case we are already in a build context.
// needed for both execution and creation(due to JIT)
With<transform::PassContext> fresh_build_ctx(transform::PassContext::Create());
auto context = transform::PassContext::Create();
context->instruments = transform::PassContext::Current()->instruments;
With<transform::PassContext> fresh_build_ctx(context);

Map<String, ObjectRef> dict = (module_->attrs.defined())
? Map<String, ObjectRef>(module_->attrs.CopyOnWrite()->dict)
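The fold_constant.cc hunk above is an auxiliary change in this commit: the fresh PassContext used for constant folding now inherits the instruments of the enclosing context, so pass instrumentation keeps firing inside it. A hedged Python sketch of what this enables (the instrument class below is made up for illustration):

import tvm

@tvm.instrument.pass_instrument
class CountPasses:
    def __init__(self):
        self.count = 0

    def run_before_pass(self, mod, info):
        self.count += 1

counter = CountPasses()
with tvm.transform.PassContext(opt_level=3, instruments=[counter]):
    # Passes executed here are counted, including those run inside the nested,
    # fresh context that ConstantFolder creates for its JIT builds.
    pass
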
22 changes: 1 addition & 21 deletions src/tir/analysis/device_constraint_utils.cc
@@ -210,8 +210,6 @@ class ApplyDeviceConstraintsMutator : public StmtExprMutator {

// Start with a copy of the current prim_func buffer map.
Map<Var, Buffer> new_buffer_map(prim_func->buffer_map.begin(), prim_func->buffer_map.end());
Map<Var, Buffer> new_preflattened_buffer_map(prim_func->preflattened_buffer_map.begin(),
prim_func->preflattened_buffer_map.end());
bool any_change = false;

// For each constrained parameter...
@@ -225,23 +223,6 @@ class ApplyDeviceConstraintsMutator : public StmtExprMutator {
any_change = true;
}
new_buffer_map.Set(param, new_buffer);

// Rewrite the pre-flattened buffers to account for constraint.
// This only has an impact if the IRModule being analyzed has
// already been run through the StorageFlatten or FlattenBuffer
// passes.
if (auto opt = prim_func->preflattened_buffer_map.Get(param)) {
Buffer pf_buffer = opt.value();
if (pf_buffer.same_as(buffer)) {
new_preflattened_buffer_map.Set(param, new_buffer);
} else {
const Buffer new_buffer = RewriteBuffer(pf_buffer, virtual_device);
if (!new_buffer.same_as(pf_buffer)) {
any_change = true;
}
new_preflattened_buffer_map.Set(param, new_buffer);
}
}
}
// Make sure we have accounted for all prim_func parameters.
CheckNoRemainingPointerParams(prim_func, &current_primfunc_param_index);
@@ -259,8 +240,7 @@ class ApplyDeviceConstraintsMutator : public StmtExprMutator {

if (any_change) {
return PrimFunc(prim_func->params, std::move(new_body), prim_func->ret_type,
std::move(new_buffer_map), std::move(new_preflattened_buffer_map),
prim_func->attrs, prim_func->span);
std::move(new_buffer_map), prim_func->attrs, prim_func->span);
} else {
return prim_func;
}
5 changes: 2 additions & 3 deletions src/tir/contrib/ethosu/passes.cc
@@ -76,9 +76,8 @@ class HoistAllocatesMutator : public StmtExprMutator {
current_alloc->span);
}

PrimFunc new_main_func =
PrimFunc(main_func->params, new_main_func_body, main_func->ret_type, main_func->buffer_map,
main_func->preflattened_buffer_map, main_func->attrs);
PrimFunc new_main_func = PrimFunc(main_func->params, new_main_func_body, main_func->ret_type,
main_func->buffer_map, main_func->attrs);
return new_main_func;
}

10 changes: 3 additions & 7 deletions src/tir/ir/function.cc
@@ -29,9 +29,7 @@ namespace tvm {
namespace tir {
// Get the function type of a PrimFunc
PrimFunc::PrimFunc(Array<tir::Var> params, Stmt body, Type ret_type,
Map<tir::Var, Buffer> buffer_map,
Optional<Map<tir::Var, Buffer>> preflattened_buffer_map, DictAttrs attrs,
Span span) {
Map<tir::Var, Buffer> buffer_map, DictAttrs attrs, Span span) {
// Assume void-return type for now
// TODO(tvm-team) consider type deduction from body.
if (!ret_type.defined()) {
@@ -42,7 +40,6 @@ PrimFunc::PrimFunc(Array<tir::Var> params, Stmt body, Type ret_type,
n->body = std::move(body);
n->ret_type = std::move(ret_type);
n->buffer_map = std::move(buffer_map);
n->preflattened_buffer_map = preflattened_buffer_map.value_or(Map<tir::Var, Buffer>());
n->attrs = std::move(attrs);
n->checked_type_ = n->func_type_annotation();
n->span = std::move(span);
@@ -121,9 +118,8 @@ TVM_STATIC_IR_FUNCTOR(ReprPrinter, vtable)

TVM_REGISTER_GLOBAL("tir.PrimFunc")
.set_body_typed([](Array<tir::Var> params, Stmt body, Type ret_type,
Map<tir::Var, Buffer> buffer_map,
Map<tir::Var, Buffer> preflattened_buffer_map, DictAttrs attrs, Span span) {
return PrimFunc(params, body, ret_type, buffer_map, preflattened_buffer_map, attrs, span);
Map<tir::Var, Buffer> buffer_map, DictAttrs attrs, Span span) {
return PrimFunc(params, body, ret_type, buffer_map, attrs, span);
});

TVM_REGISTER_GLOBAL("tir.TensorIntrin")