Skip to content

Commit

Permalink
enforce default storage scope of global
Browse files Browse the repository at this point in the history
  • Loading branch information
Masahiro Masuda committed Jul 1, 2021
1 parent c443506 commit 8e20003
Show file tree
Hide file tree
Showing 10 changed files with 37 additions and 16 deletions.
2 changes: 1 addition & 1 deletion include/tvm/ir/type.h
Original file line number Diff line number Diff line change
Expand Up @@ -184,7 +184,7 @@ class PointerType : public Type {
* \param element_type The type of the element which the pointer points to.
* \param storage_scope The storage scope into which the pointer addresses
*/
TVM_DLL explicit PointerType(Type element_type, String storage_scope = "");
TVM_DLL explicit PointerType(Type element_type, String storage_scope = "global");

TVM_DEFINE_OBJECT_REF_METHODS(PointerType, Type, PointerTypeNode);
};
Expand Down
3 changes: 2 additions & 1 deletion include/tvm/tir/buffer.h
Original file line number Diff line number Diff line change
Expand Up @@ -196,14 +196,15 @@ class Buffer : public ObjectRef {
* \sa Buffer for complete constructor.
*/
TVM_DLL Buffer decl_buffer(Array<PrimExpr> shape, DataType dtype = DataType::Float(32),
String name = "buffer", String storage_scope = "", Span span = Span());
String name = "buffer", String storage_scope = "global", Span span = Span());

/*!
* \brief Return the storage scope associated with a buffer variable.
* \param buffer_var The input buffer variable.
* \return A string representing the storage scope of this buffer variable.
*/
TVM_DLL String GetStorageScope(Var buffer_var);
TVM_DLL Var UpdateStorageScope(Var buffer_var, String storage_scope);

/*!
* \brief Base node for data producers.
Expand Down
4 changes: 2 additions & 2 deletions python/tvm/tir/buffer.py
Original file line number Diff line number Diff line change
Expand Up @@ -140,7 +140,7 @@ def decl_buffer(
data=None,
strides=None,
elem_offset=None,
scope="",
scope="global",
data_alignment=-1,
offset_factor=0,
buffer_type="",
Expand Down Expand Up @@ -250,7 +250,7 @@ def decl_buffer(
# Bool is represented as uint1 in the IR, but stored as int8
storage_type = PrimType(dtype)
storage_type = PrimType("int8") if storage_type.dtype == "bool" else storage_type
data = Var(name, PointerType(storage_type), span)
data = Var(name, PointerType(storage_type, scope), span)
return _ffi_api.Buffer(
data,
dtype,
Expand Down
3 changes: 2 additions & 1 deletion src/ir/type.cc
Original file line number Diff line number Diff line change
Expand Up @@ -44,6 +44,7 @@ TVM_STATIC_IR_FUNCTOR(ReprPrinter, vtable)
});

PointerType::PointerType(Type element_type, String storage_scope) {
ICHECK(storage_scope != "");
ObjectPtr<PointerTypeNode> n = make_object<PointerTypeNode>();
n->element_type = std::move(element_type);
n->storage_scope = std::move(storage_scope);
Expand All @@ -53,7 +54,7 @@ PointerType::PointerType(Type element_type, String storage_scope) {
TVM_REGISTER_NODE_TYPE(PointerTypeNode);

// FFI binding for constructing a PointerType from Python; the "global" default
// mirrors the C++ constructor default (see include/tvm/ir/type.h).
TVM_REGISTER_GLOBAL("ir.PointerType")
    .set_body_typed([](Type element_type, String storage_scope = "global") {
      // NOTE(review): the default argument on this lambda is presumably never
      // consulted — PackedFunc dispatch appears to pass both arguments
      // explicitly; confirm against the FFI call convention.
      return PointerType(element_type, storage_scope);
    });

Expand Down
4 changes: 2 additions & 2 deletions src/te/operation/cross_thread_reduction.cc
Original file line number Diff line number Diff line change
Expand Up @@ -146,7 +146,7 @@ Stmt MakeCrossThreadReduction(const ComputeOpNode* self, const Stage& stage,
for (size_t i = 0; i < size; ++i) {
DataType t = reduces[i]->dtype;
normal_res_handles.emplace_back("normal_reduce_temp" + std::to_string(i),
PointerType(PrimType(t)));
PointerType(PrimType(t), "local"));
lhs.push_back(Load(t, normal_res_handles[i], 0, const_true(t.lanes())));
}
Array<PrimExpr> init_value = combiner->identity_element;
Expand Down Expand Up @@ -177,7 +177,7 @@ Stmt MakeCrossThreadReduction(const ComputeOpNode* self, const Stage& stage,
std::vector<Var> res_handles(size);
for (size_t idx = 0; idx < size; ++idx) {
DataType dtype = reduces[idx]->dtype;
res_handles[idx] = Var("reduce_temp" + std::to_string(idx), PointerType(PrimType(dtype)));
res_handles[idx] = Var("reduce_temp" + std::to_string(idx), PointerType(PrimType(dtype), "local"));
freduce_args.push_back(res_handles[idx]);
}

Expand Down
7 changes: 4 additions & 3 deletions src/te/schedule/schedule_postproc_to_primfunc.cc
Original file line number Diff line number Diff line change
Expand Up @@ -49,7 +49,7 @@ namespace tvm {
namespace te {

// create a buffer for tensor.
Buffer CreateBufferFor(const Tensor& tensor, String storage_scope = "") {
Buffer CreateBufferFor(const Tensor& tensor, String storage_scope = "global") {
std::string name = tensor->op->name;
if (tensor->op->num_outputs() != 1) {
name += ".v" + std::to_string(tensor->value_index);
Expand Down Expand Up @@ -122,11 +122,12 @@ class TensorToBufferMapper : public StmtExprMutator {
}

private:
Buffer GetOrAllocBuffer(const Tensor& tensor, String storage_scope = "") {
Buffer GetOrAllocBuffer(const Tensor& tensor, String storage_scope = "global") {
return GetBuffer(tensor, storage_scope, true);
}

Buffer GetBuffer(const Tensor& tensor, String storage_scope = "", bool allow_alloc = false) {
Buffer GetBuffer(const Tensor& tensor, String storage_scope = "global",
bool allow_alloc = false) {
auto it = buffer_map_.find(tensor);
if (it != buffer_map_.end()) return it->second;
ICHECK(allow_alloc) << "Cannot find the Realization point of tensor " << tensor;
Expand Down
8 changes: 8 additions & 0 deletions src/tir/ir/buffer.cc
Original file line number Diff line number Diff line change
Expand Up @@ -48,6 +48,7 @@ Array<PrimExpr> SimplifyArray(arith::Analyzer* ana, Array<PrimExpr> array) {
/*!
 * \brief Declare a buffer of the given shape/dtype backed by a freshly created
 *        pointer Var carrying the storage scope in its type annotation.
 * \param shape Shape of the buffer.
 * \param dtype Element data type; bool is stored as int8 in the IR.
 * \param name Name of the buffer (also used as the backing Var's name hint).
 * \param storage_scope Storage scope; an empty string is normalized to "global".
 * \param span Source span for diagnostics.
 * \return The declared Buffer.
 */
Buffer decl_buffer(Array<PrimExpr> shape, DataType dtype, String name, String storage_scope,
                   Span span) {
  // Bool has no storage representation; it is held as int8 in memory.
  DataType storage_dtype = (dtype == DataType::Bool() ? DataType::Int(8) : dtype);
  // Normalize the legacy "" scope so every PointerType carries an explicit scope.
  if (storage_scope == "") storage_scope = "global";
  return Buffer(Var(name, PointerType(PrimType(storage_dtype), storage_scope), span), dtype, shape,
                Array<PrimExpr>(), PrimExpr(), name, "", 0, 0, kDefault, span);
}
Expand All @@ -59,6 +60,13 @@ String GetStorageScope(Var buffer_var) {
return ptr_type->storage_scope;
}

/*!
 * \brief Produce a copy of a buffer variable whose pointer type annotation
 *        carries the given storage scope; name hint and span are preserved.
 * \param buffer_var The buffer variable; must be annotated with a PointerType.
 * \param storage_scope The storage scope to install on the new variable.
 * \return A new Var with the updated type annotation.
 */
Var UpdateStorageScope(Var buffer_var, String storage_scope) {
  const auto* pointer_type = buffer_var->type_annotation.as<PointerTypeNode>();
  ICHECK(pointer_type) << "The provided variable is not of pointer type";
  // Rebuild the annotation with the same element type but the new scope.
  Type updated_annotation = PointerType(pointer_type->element_type, storage_scope);
  return Var(buffer_var->name_hint, updated_annotation, buffer_var->span);
}

// Split the given expression w.r.t the add operator
inline std::vector<const PrimExpr*> ExprSplitAddition(const PrimExpr& expr) {
using namespace tir;
Expand Down
6 changes: 6 additions & 0 deletions src/tir/ir/stmt.cc
Original file line number Diff line number Diff line change
Expand Up @@ -61,6 +61,12 @@ TVM_STATIC_IR_FUNCTOR(ReprPrinter, vtable)

// AttrStmt
AttrStmt::AttrStmt(ObjectRef node, String attr_key, PrimExpr value, Stmt body, Span span) {
if (attr_key == attr::storage_scope) {
const VarNode* buf = node.as<VarNode>();
CHECK(buf);
CHECK(value.as<StringImmNode>()->value == GetStorageScope(GetRef<Var>(buf)))
<< value.as<StringImmNode>()->value << ", " << GetStorageScope(GetRef<Var>(buf));
}
auto n = make_object<AttrStmtNode>();
n->node = node;
n->attr_key = std::move(attr_key);
Expand Down
12 changes: 6 additions & 6 deletions src/tir/transforms/lower_thread_allreduce.cc
Original file line number Diff line number Diff line change
Expand Up @@ -85,13 +85,13 @@ class ThreadAllreduceBuilder final : public StmtExprMutator {
if (it != alloc_remap_.end()) {
const AllocateNode* repl = it->second.as<AllocateNode>();
if (warp_allocs_.count(repl)) {
stmt = Allocate(repl->buffer_var, repl->dtype, repl->extents, repl->condition, op->body);
stmt = AttrStmt(repl->buffer_var, attr::storage_scope, StringImm("local"), stmt);
stmt = Allocate(UpdateStorageScope(repl->buffer_var, "local"), repl->dtype, repl->extents,
repl->condition, op->body);
} else {
// use volatile access to shared buffer.
stmt = AttrStmt(repl->buffer_var, attr::volatile_scope, 1, op->body);
stmt = Allocate(repl->buffer_var, repl->dtype, repl->extents, repl->condition, stmt);
stmt = AttrStmt(repl->buffer_var, attr::storage_scope, StringImm("shared"), stmt);
stmt = Allocate(UpdateStorageScope(repl->buffer_var, "shared"), repl->dtype, repl->extents,
repl->condition, stmt);
}
return stmt;
} else {
Expand Down Expand Up @@ -365,8 +365,8 @@ class ThreadAllreduceBuilder final : public StmtExprMutator {
for (auto var : local_vars) {
const AllocateNode* repl = var.as<AllocateNode>();
if (repl) {
body = Allocate(repl->buffer_var, repl->dtype, repl->extents, repl->condition, body);
body = AttrStmt(repl->buffer_var, attr::storage_scope, StringImm("local"), body);
body = Allocate(UpdateStorageScope(repl->buffer_var, "local"), repl->dtype, repl->extents,
repl->condition, body);
}
}

Expand Down
4 changes: 4 additions & 0 deletions src/tir/transforms/storage_flatten.cc
Original file line number Diff line number Diff line change
Expand Up @@ -197,6 +197,7 @@ class StorageFlattener : public StmtExprMutator {
strides = Array<PrimExpr>(rstrides.rbegin(), rstrides.rend());
}

LOG(INFO) << "skey: " << skey.to_string();
e.buffer = Buffer(Var(op->buffer->data->name_hint, op->buffer->data->type_annotation),
op->buffer->dtype, shape, strides, PrimExpr(), op->buffer->name,
skey.to_string(), align, 0, kDefault);
Expand Down Expand Up @@ -225,6 +226,9 @@ class StorageFlattener : public StmtExprMutator {
ret = Allocate(e.buffer->data, storage_type, shape,
make_const(DataType::Bool(e.buffer->dtype.lanes()), true), body);
}
CHECK(e.buffer->scope == GetStorageScope(e.buffer->data))
<< e.buffer->scope << ", " << GetStorageScope(e.buffer->data) << ", "
<< GetStorageScope(op->buffer->data);
ret = AttrStmt(e.buffer->data, attr::storage_scope, StringImm(e.buffer->scope), ret);

if (create_bound_attributes_ && ShapeIsValid(e.buffer->shape)) {
Expand Down

0 comments on commit 8e20003

Please sign in to comment.