[Refactor] Enforce attaching storage scope to PointerType #8366

Merged: 92 commits into main from storage_scope_refactor, Jul 13, 2021.
Changes shown below are from 90 of the 92 commits.

Commits (92)
a84db69
Add storage scope to ProducerRealize, always create a buffer with scope
Jun 30, 2021
0140631
update schedule_ops.cc
Jun 30, 2021
b6d8e6c
update schedule_postproc_to_primfunc.cc
Jun 30, 2021
bebcc50
restore more realize_scope
Jun 30, 2021
edeaed2
make the default scope be "" instead of None in ir builder
Jun 30, 2021
cd6167e
restore realize_scope visit in storage_flatten.cc
Jun 30, 2021
a33fb0d
update storage_access.cc
Jun 30, 2021
e878eae
make sure buffer var is of PointerType in ir builder
Jun 30, 2021
9f98dea
enforce default storage scope of global
Jul 1, 2021
2923f4d
added remap pass but does not work yet
masahi Jul 1, 2021
6a79354
fixed all reduce issue
masahi Jul 2, 2021
a566af9
simplify
masahi Jul 2, 2021
345573a
trying mitigation for aot test
masahi Jul 2, 2021
be2ab96
merge remaining changes from initial branch
masahi Jul 2, 2021
f83cba0
remove use of attr::storage_scope from codegen
Jul 2, 2021
f1b0f3c
restore a visit to AttrStmt with attr::storage_scope in storage_rewrite
Jul 2, 2021
30edcd6
disable check
Jul 2, 2021
9560f1b
lint fix
Jul 2, 2021
3de3edd
revert default scope to ""
masahi Jul 2, 2021
b703d8b
format
masahi Jul 2, 2021
b7abe5a
fix volatile access to shared mem in lower all reduce
Jul 4, 2021
62f818c
fixed gpu coorporative load/store test
Jul 4, 2021
d41bdc8
pass storage scope to PointerType in tvm script parser
Jul 5, 2021
594d34b
fixed tvmscript roundtrip test
Jul 6, 2021
151eb02
fixed tir flatten buffer test
Jul 6, 2021
1e88335
fixed test_tir_transform_hoist_if.py
Jul 6, 2021
7e8a7c1
use storage scope global by default in aot_executor_codegen.cc
Jul 6, 2021
d458187
add missing default storage scope in create_primfunc.cc
Jul 6, 2021
410fd45
restore StorageInfo struct in llvm backend
Jul 6, 2021
742c243
UpdateStorageScope -> WithStorageScope
Jul 6, 2021
aa90d42
fixed lower warp memory test
Jul 6, 2021
1e9a7a3
GetStorageScope -> GetPtrStorageScope
Jul 6, 2021
afbe7f1
Enable storage scope invariant check in AttrStmt constructor
Jul 6, 2021
ee3aa5d
remove GetPtrStorageScope and WithStorageScope from public header
Jul 6, 2021
accfff4
move RemapStorageScope to its own file
Jul 6, 2021
a74aece
add more method to RemapStorageScope
Jul 6, 2021
98c3c3c
update lower_thread_allreduce to use RemapStorageScope
Jul 6, 2021
bafdb47
RemapStorageScope -> UpdatePointerStorageScope
Jul 7, 2021
16c81b5
remove realize_scope from hybrid script
masahi Jul 7, 2021
a14c5fa
removed realize_scope in schedule_ops
masahi Jul 7, 2021
c2ea828
remove realize_scope from schedule_postproc_to_primfunc
masahi Jul 7, 2021
d401022
remove remaining realize_scope usage from schedule_ops.cc
masahi Jul 7, 2021
de83623
remove realize_scope usage from storage_flatten.cc
masahi Jul 7, 2021
426b573
fixed test_tir_transform_lower_warp_memory.py following realize_scope…
Jul 7, 2021
1964f74
Add storage scope to ProducerRealize, always create a buffer with scope
Jun 30, 2021
63e0c85
update schedule_ops.cc
Jun 30, 2021
24225e5
update schedule_postproc_to_primfunc.cc
Jun 30, 2021
9ddfb28
restore more realize_scope
Jun 30, 2021
416169c
make the default scope be "" instead of None in ir builder
Jun 30, 2021
da3053e
restore realize_scope visit in storage_flatten.cc
Jun 30, 2021
76ed1ab
update storage_access.cc
Jun 30, 2021
e95146e
make sure buffer var is of PointerType in ir builder
Jun 30, 2021
13b2efc
enforce default storage scope of global
Jul 1, 2021
ae8858a
added remap pass but does not work yet
masahi Jul 1, 2021
cec5a0a
fixed all reduce issue
masahi Jul 2, 2021
e4f0965
simplify
masahi Jul 2, 2021
fca791c
trying mitigation for aot test
masahi Jul 2, 2021
a80dfad
merge remaining changes from initial branch
masahi Jul 2, 2021
6fd0fb3
remove use of attr::storage_scope from codegen
Jul 2, 2021
b3fa275
restore a visit to AttrStmt with attr::storage_scope in storage_rewrite
Jul 2, 2021
b62d8a1
disable check
Jul 2, 2021
4a084d7
lint fix
Jul 2, 2021
73911a9
revert default scope to ""
masahi Jul 2, 2021
1c8b779
format
masahi Jul 2, 2021
9f54b62
fix volatile access to shared mem in lower all reduce
Jul 4, 2021
034fb72
fixed gpu coorporative load/store test
Jul 4, 2021
6081519
pass storage scope to PointerType in tvm script parser
Jul 5, 2021
2d92f6d
fixed tvmscript roundtrip test
Jul 6, 2021
35fb917
fixed tir flatten buffer test
Jul 6, 2021
46647e9
fixed test_tir_transform_hoist_if.py
Jul 6, 2021
db8eb9a
use storage scope global by default in aot_executor_codegen.cc
Jul 6, 2021
66c61ae
add missing default storage scope in create_primfunc.cc
Jul 6, 2021
bd76606
restore StorageInfo struct in llvm backend
Jul 6, 2021
4cae209
UpdateStorageScope -> WithStorageScope
Jul 6, 2021
01b94cc
fixed lower warp memory test
Jul 6, 2021
cbaa6e7
GetStorageScope -> GetPtrStorageScope
Jul 6, 2021
20b930b
Enable storage scope invariant check in AttrStmt constructor
Jul 6, 2021
a247d94
remove GetPtrStorageScope and WithStorageScope from public header
Jul 6, 2021
0d8c9bc
move RemapStorageScope to its own file
Jul 6, 2021
c295338
add more method to RemapStorageScope
Jul 6, 2021
ebaceb9
update lower_thread_allreduce to use RemapStorageScope
Jul 6, 2021
0bf5081
RemapStorageScope -> UpdatePointerStorageScope
Jul 7, 2021
c03c360
remove realize_scope from hybrid script
masahi Jul 7, 2021
21d4134
removed realize_scope in schedule_ops
masahi Jul 7, 2021
0ff503b
remove realize_scope from schedule_postproc_to_primfunc
masahi Jul 7, 2021
086d891
remove remaining realize_scope usage from schedule_ops.cc
masahi Jul 7, 2021
e60ad8d
remove realize_scope usage from storage_flatten.cc
masahi Jul 7, 2021
cb697d8
fixed test_tir_transform_lower_warp_memory.py following realize_scope…
Jul 7, 2021
6a11ab4
minor fix
masahi Jul 8, 2021
f459dbf
Address comments
masahi Jul 10, 2021
3e5e0ef
Remove blank line diff
masahi Jul 12, 2021
5710513
Merge remote-tracking branch 'upstream/main' into storage_scope_refactor
masahi Jul 13, 2021
2 changes: 0 additions & 2 deletions docs/dev/inferbound.rst
@@ -447,13 +447,11 @@ Here is the IR after ScheduleOps (note that loops with extent 1 have been preserved):

 ::

-    // attr [compute(D, 0x2c070b0)] realize_scope = ""
     realize D([0, 4], [0, 5], [0, 16]) {
       produce D {
         for (di, 0, 4) {
           for (dj, 0, 5) {
             for (dk, 0, 16) {
-              // attr [compute(C, 0x2c29990)] realize_scope = ""
               realize C([dj, 1], [dk, 1]) {
                 produce C {
                   for (i, 0, 1) {
15 changes: 8 additions & 7 deletions include/tvm/te/operation.h
@@ -125,11 +125,12 @@ class TVM_DLL OperationNode : public Object {
    * \param stage the op's stage.
    * \param realize_map The realization domain map of the operators.
    * \param body The body that is going to get
+   * \param storage_scope The storage scope associated with this realization
    * \return A realization statement that wraps body.
    */
   virtual Stmt BuildRealize(const Stage& stage,
-                            const std::unordered_map<IterVar, Range>& realize_map,
-                            const Stmt& body) const = 0;
+                            const std::unordered_map<IterVar, Range>& realize_map, const Stmt& body,
+                            String storage_scope = "") const = 0;
   /*!
    * \brief Build the statement that provide the output tensors.
    * \param stage The schedule stage of the op.
@@ -168,7 +169,7 @@ class PlaceholderOpNode : public OperationNode {
   void GatherBound(const Operation& self, const std::unordered_map<Tensor, TensorDom>& tensor_dom,
                    std::unordered_map<IterVar, Range>* out_dom_map) const final;
   Stmt BuildRealize(const Stage& stage, const std::unordered_map<IterVar, Range>& realize_map,
-                    const Stmt& body) const final;
+                    const Stmt& body, String storage_scope = "") const final;
   Stmt BuildProvide(const Stage& stage, const std::unordered_map<IterVar, Range>& dom_map,
                     bool debug_keep_trivial_loop) const final;

@@ -212,7 +213,7 @@ class TVM_DLL BaseComputeOpNode : public OperationNode {
   void GatherBound(const Operation& self, const std::unordered_map<Tensor, TensorDom>& tensor_dom,
                    std::unordered_map<IterVar, Range>* out_dom_map) const final;
   Stmt BuildRealize(const Stage& stage, const std::unordered_map<IterVar, Range>& realize_map,
-                    const Stmt& body) const final;
+                    const Stmt& body, String storage_scope = "") const final;
   virtual size_t num_schedulable_dims() const = 0;

   static constexpr const char* _type_key = "BaseComputeOp";
@@ -370,7 +371,7 @@ class ScanOpNode : public OperationNode {
   void GatherBound(const Operation& self, const std::unordered_map<Tensor, TensorDom>& tensor_dom,
                    std::unordered_map<IterVar, Range>* out_dom_map) const final;
   Stmt BuildRealize(const Stage& stage, const std::unordered_map<IterVar, Range>& realize_map,
-                    const Stmt& body) const final;
+                    const Stmt& body, String storage_scope = "") const final;
   Stmt BuildProvide(const Stage& stage, const std::unordered_map<IterVar, Range>& dom_map,
                     bool debug_keep_trivial_loop) const final;

@@ -433,7 +434,7 @@ class ExternOpNode : public OperationNode {
   void GatherBound(const Operation& self, const std::unordered_map<Tensor, TensorDom>& tensor_dom,
                    std::unordered_map<IterVar, Range>* out_dom_map) const final;
   Stmt BuildRealize(const Stage& stage, const std::unordered_map<IterVar, Range>& realize_map,
-                    const Stmt& body) const final;
+                    const Stmt& body, String storage_scope = "") const final;
   Stmt BuildProvide(const Stage& stage, const std::unordered_map<IterVar, Range>& dom_map,
                     bool debug_keep_trivial_loop) const final;

@@ -498,7 +499,7 @@ class HybridOpNode : public OperationNode {
   void GatherBound(const Operation& self, const std::unordered_map<Tensor, TensorDom>& tensor_dom,
                    std::unordered_map<IterVar, Range>* out_dom_map) const final;
   Stmt BuildRealize(const Stage& stage, const std::unordered_map<IterVar, Range>& realize_map,
-                    const Stmt& body) const final;
+                    const Stmt& body, String storage_scope = "") const final;
   Stmt BuildProvide(const Stage& stage, const std::unordered_map<IterVar, Range>& dom_map,
                     bool debug_keep_trivial_loop) const final;
3 changes: 2 additions & 1 deletion include/tvm/tir/buffer.h
@@ -191,12 +191,13 @@ class Buffer : public ObjectRef {
  * \param shape The shape of the buffer,
  * \param dtype The content data type.
  * \param name The name of the buffer
+ * \param storage_scope The storage scope associated with this buffer
  * \param span The location of this object in the source code.
  * \return The created buffer.
  * \sa Buffer for complete constructor.
  */
 TVM_DLL Buffer decl_buffer(Array<PrimExpr> shape, DataType dtype = DataType::Float(32),
-                           String name = "buffer", Span span = Span());
+                           String name = "buffer", String storage_scope = "", Span span = Span());

 /*!
  * \brief Base node for data producers.
9 changes: 7 additions & 2 deletions include/tvm/tir/stmt.h
@@ -464,25 +464,30 @@ class ProducerRealizeNode : public StmtNode {
   PrimExpr condition;
   /*! \brief The body of realization. */
   Stmt body;
+  /*! \brief The storage scope associated with this realization. */
+  String storage_scope;

   void VisitAttrs(AttrVisitor* v) {
     v->Visit("producer", &producer);
     v->Visit("bounds", &bounds);
     v->Visit("condition", &condition);
     v->Visit("body", &body);
+    v->Visit("storage_scope", &storage_scope);
     v->Visit("span", &span);
   }

   bool SEqualReduce(const ProducerRealizeNode* other, SEqualReducer equal) const {
     return equal(producer, other->producer) && equal(bounds, other->bounds) &&
-           equal(condition, other->condition) && equal(body, other->body);
+           equal(condition, other->condition) && equal(body, other->body) &&
+           equal(storage_scope, other->storage_scope);
   }

   void SHashReduce(SHashReducer hash_reduce) const {
     hash_reduce(producer);
     hash_reduce(bounds);
     hash_reduce(condition);
     hash_reduce(body);
     hash_reduce(storage_scope);
   }

   static constexpr const char* _type_key = "tir.ProducerRealize";
@@ -496,7 +501,7 @@ class ProducerRealizeNode : public StmtNode {
 class ProducerRealize : public Stmt {
  public:
   TVM_DLL ProducerRealize(DataProducer producer, Region bounds, PrimExpr condition, Stmt body,
-                          Span span = Span());
+                          String storage_scope = "", Span span = Span());

   TVM_DEFINE_OBJECT_REF_METHODS(ProducerRealize, Stmt, ProducerRealizeNode);
 };
2 changes: 1 addition & 1 deletion python/tvm/script/scope_handler.py
@@ -140,7 +140,7 @@ def enter_scope(

         def setup_buffer_var(extents, dtype, scope, condition=True, span: Span = None):
             """Setup buffer var for a given type."""
-            buffer_ptr_type = tvm.ir.PointerType(tvm.ir.PrimType(dtype))
+            buffer_ptr_type = tvm.ir.PointerType(tvm.ir.PrimType(dtype), scope)
             self.buffer_var = tvm.tir.Var(name, buffer_ptr_type, span)

         setup_buffer_var(*arg_list, span=tvm_span_from_synr(var_span))
16 changes: 16 additions & 0 deletions python/tvm/script/special_stmt.py
@@ -491,6 +491,22 @@ def var(dtype, span):
         super().__init__(var, def_symbol=True)


+@register
+class BufferVarDef(SpecialStmt):
+    """Special function for defining a variable of pointer type"""
+
+    def __init__(self):
+        def buffer_var(dtype, storage_scope, span):
+            assert isinstance(
+                self.node, ast.Assign
+            ), f"BufferVarDef expected ast.Assign but got {type(self.node)}"
+            ptr_type = tvm.ir.PointerType(tvm.ir.PrimType(dtype), storage_scope)
+            v = te.var(self.node.lhs.id.name, ptr_type, span=span)
+            self.context.update_symbol(v.name, v, self.node)
+
+        super().__init__(buffer_var, def_symbol=True)
+
+
 @register
 class EnvThread(SpecialStmt):
     """Bind a var to thread env"""

Review comment from masahi (author), on class BufferVarDef: cc @spectrometerHBH for tvmscript changes. This is for distinguishing a normal variable from a buffer variable. The latter one requires storage scope. See the change in test_tvmscript_roundtrip.py. The tvm script printer was also updated.
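Since tir.buffer_var is backed by nothing more than a Var whose type annotation is a scoped PointerType, the distinction the comment describes can be checked directly from Python. A minimal sketch (variable names are illustrative, not taken from the PR's tests):

    import tvm
    from tvm import te

    # A buffer variable carries its storage scope inside its PointerType;
    # a normal variable is just a scalar of some dtype.
    ptr_type = tvm.ir.PointerType(tvm.ir.PrimType("float32"), "shared")
    buf_var = te.var("A_shared", ptr_type)   # pointer-typed: a buffer variable
    plain_var = te.var("n", "int32")         # normal variable

    assert buf_var.type_annotation.storage_scope == "shared"
    assert plain_var.dtype == "int32"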
3 changes: 1 addition & 2 deletions python/tvm/te/hybrid/parser.py
@@ -207,8 +207,7 @@ def wrap_up_realize(self, node, body):
             _domain = [Range.from_min_extent(0, i) for i in _buf.shape]
             _dtype = _buf.dtype
             _true = tvm.runtime.convert(True)
-            body = tvm.tir.ProducerRealize(_buf, _domain, _true, body)
-            body = tvm.tir.AttrStmt(_buf.op, "realize_scope", tvm.runtime.convert(_scope), body)
+            body = tvm.tir.ProducerRealize(_buf, _domain, _true, body, tvm.runtime.convert(_scope))

         for elem in to_pop:
             self.symbols.pop(elem)
2 changes: 1 addition & 1 deletion python/tvm/tir/buffer.py
@@ -250,7 +250,7 @@ def decl_buffer(
     # Bool is represented as uint1 in the IR, but stored as int8
     storage_type = PrimType(dtype)
     storage_type = PrimType("int8") if storage_type.dtype == "bool" else storage_type
-    data = Var(name, PointerType(storage_type), span)
+    data = Var(name, PointerType(storage_type, scope), span)
     return _ffi_api.Buffer(
         data,
         dtype,
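The effect is visible from the public API: the scope argument that decl_buffer already accepted now lands in the buffer var's own type instead of waiting for a later storage_scope AttrStmt. A minimal sketch of the post-PR behavior:

    import tvm

    buf = tvm.tir.decl_buffer((16,), "float32", name="B", scope="shared")
    # The data var's PointerType now records the scope itself.
    assert buf.data.type_annotation.storage_scope == "shared"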
11 changes: 7 additions & 4 deletions python/tvm/tir/ir_builder.py
@@ -394,7 +394,7 @@ def let(self, var_name, value):
         self.emit(lambda x: _stmt.LetStmt(var, value, x))
         return var

-    def allocate(self, dtype, shape, name="buf", scope=None):
+    def allocate(self, dtype, shape, name="buf", scope=""):
         """Create a allocate statement.

         Parameters
@@ -416,15 +416,15 @@ def allocate(self, dtype, shape, name="buf", scope=None):
         buffer : BufferVar
             The buffer var representing the buffer.
         """
-        buffer_var = _expr.Var(name, PointerType(PrimType(dtype)))
+        buffer_var = _expr.Var(name, PointerType(PrimType(dtype), scope))
         if not isinstance(shape, (list, tuple, _container.Array)):
             shape = [shape]
-        if scope:
-            self.scope_attr(buffer_var, "storage_scope", scope)
         self.emit(lambda x: _stmt.Allocate(buffer_var, dtype, shape, const(1, dtype="uint1"), x))
         return BufferVar(self, buffer_var, shape, dtype)

-    def pointer(self, content_type, name="ptr"):
+    def pointer(self, content_type, name="ptr", scope=""):
         """Create pointer variable with content type.

         Parameters
@@ -435,12 +435,15 @@ def pointer(self, content_type, name="ptr"):
         name : str, optional
             The name of the pointer.

+        scope : str, optional
+            The scope of the pointer.
+
         Returns
         -------
         ptr : BufferVar
             The buffer var representing the buffer.
         """
-        buffer_var = _expr.Var(name, dtype="handle")
+        buffer_var = _expr.Var(name, PointerType(PrimType(content_type), scope))
         return BufferVar(self, buffer_var, None, content_type)

     def buffer_ptr(self, buf, shape=None):
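Putting the two ir_builder changes together: the scope is attached at allocation time and no separate scope_attr call is emitted. A minimal usage sketch (buffer names are illustrative):

    import tvm
    from tvm.tir import ir_builder

    ib = ir_builder.create()
    # The scope rides along in each buffer var's PointerType; previously
    # allocate() also emitted a storage_scope AttrStmt for non-empty scopes.
    red_buf = ib.allocate("float32", 32, name="red_buf", scope="shared")
    out_ptr = ib.pointer("float32", name="out_ptr", scope="global")
    red_buf[0] = tvm.tir.FloatImm("float32", 0.0)
    stmt = ib.get()  # an Allocate whose buffer var is typed with scope "shared"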
7 changes: 5 additions & 2 deletions python/tvm/tir/stmt.py
@@ -364,13 +364,16 @@ class ProducerRealize(Stmt):
     body : Stmt
         The realize body

+    storage_scope : str
+        The storage scope associated with this realization
+
     span : Optional[Span]
         The location of this itervar in the source code.
     """

-    def __init__(self, producer, bounds, condition, body, span=None):
+    def __init__(self, producer, bounds, condition, body, storage_scope="", span=None):
         self.__init_handle_by_constructor__(
-            _ffi_api.ProducerRealize, producer, bounds, condition, body, span
+            _ffi_api.ProducerRealize, producer, bounds, condition, body, storage_scope, span
         )
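For reference, a hand-constructed ProducerRealize mirroring what the hybrid parser above now builds (names and bounds are illustrative):

    import tvm
    from tvm import te
    from tvm.ir import Range

    A = te.placeholder((4,), dtype="float32", name="A")
    bounds = [Range.from_min_extent(0, 4)]
    body = tvm.tir.Evaluate(tvm.tir.IntImm("int32", 0))
    realize = tvm.tir.ProducerRealize(A, bounds, tvm.runtime.convert(True), body, "shared")
    assert realize.storage_scope == "shared"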
9 changes: 2 additions & 7 deletions src/contrib/hybrid/codegen_hybrid.cc
@@ -315,10 +315,6 @@ void CodeGenHybrid::VisitStmt_(const AttrStmtNode* op) {
     indent_ += tab_;
     PrintStmt(op->body);
     indent_ -= tab_;
-  } else if (op->attr_key == tir::attr::realize_scope) {
-    auto v = Downcast<Operation>(op->node);
-    alloc_storage_scope_[v] = op->value.as<StringImmNode>()->value;
-    PrintStmt(op->body);
   } else {
     // For now we ignore the unsupported AttrStmt
     PrintStmt(op->body);
@@ -327,8 +323,7 @@ void CodeGenHybrid::VisitStmt_(const AttrStmtNode* op) {

 void CodeGenHybrid::VisitStmt_(const ProducerRealizeNode* op) {
   auto tensor = Downcast<Tensor>(op->producer);
-  ICHECK(alloc_storage_scope_.count(tensor->op));
-  if (!alloc_storage_scope_[tensor->op].empty()) {
+  if (!op->storage_scope.empty()) {
     PrintIndent();
     stream << GetTensorID(tensor) << " = allocate((";
     for (size_t i = 0; i < op->bounds.size(); ++i) {
@@ -339,7 +334,7 @@ void CodeGenHybrid::VisitStmt_(const ProducerRealizeNode* op) {
     stream << "), '";
     PrintType(tensor->dtype, stream);
     stream << "', '";
-    stream << alloc_storage_scope_[tensor->op] << "')\n";
+    stream << op->storage_scope << "')\n";
   }
   PrintStmt(op->body);
 }
2 changes: 0 additions & 2 deletions src/contrib/hybrid/codegen_hybrid.h
@@ -168,8 +168,6 @@ class CodeGenHybrid : public ExprFunctor<void(const PrimExpr&, std::ostream&)>,
    * \param tensor The tensor to allocate a name.
    */
   std::string GetTensorID(const Tensor& tensor);
-  /*! \brief the storage scope of allocation */
-  std::map<Operation, std::string> alloc_storage_scope_;
 };

 }  // namespace contrib
14 changes: 12 additions & 2 deletions src/printer/tvmscript_printer.cc
@@ -26,6 +26,7 @@
 #include <tvm/ir/module.h>
 #include <tvm/node/serialization.h>
 #include <tvm/runtime/registry.h>
+#include <tvm/tir/buffer.h>
 #include <tvm/tir/expr.h>
 #include <tvm/tir/expr_functor.h>
 #include <tvm/tir/function.h>
@@ -1013,8 +1014,17 @@ Doc TVMScriptPrinter::PrintPrimFunc(const PrimFunc& primFunc) {
       return memo_var_[GetRef<Var>(a)].str() < memo_var_[GetRef<Var>(b)].str();
     });
     for (const auto& var : vars) {
-      header_var << Doc::NewLine() << Print(GetRef<Var>(var)) << " = tir.var(";
-      header_var << PrintDType(var->dtype) << ")";
+      auto type = GetRef<Var>(var)->type_annotation;
+      if (auto* ptr_type = type.as<PointerTypeNode>()) {
+        auto* prim_type = ptr_type->element_type.as<PrimTypeNode>();
+        ICHECK(prim_type);
+        header_var << Doc::NewLine() << Print(GetRef<Var>(var)) << " = tir.buffer_var(";
+        header_var << PrintDType(prim_type->dtype) << ", "
+                   << Doc::StrLiteral(ptr_type->storage_scope) << ")";
+      } else {
+        header_var << Doc::NewLine() << Print(GetRef<Var>(var)) << " = tir.var(";
+        header_var << PrintDType(var->dtype) << ")";
+      }
     }
   }
   doc << Doc::Indent(4, header_attr << header_var << header_buf << body);
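With that branch in place, the printed header distinguishes the two kinds of free vars, which is exactly what the new BufferVarDef special statement parses back in. A sketch of the expected round-trip forms (assumed output shape, not captured from a real run):

    import tvm

    ptr = tvm.tir.Var("ptr", tvm.ir.PointerType(tvm.ir.PrimType("float32"), "shared"))
    n = tvm.tir.Var("n", "int32")

    # The printer emits pointer-typed free vars as
    #     ptr = tir.buffer_var("float32", "shared")
    # and non-pointer vars, as before, as
    #     n = tir.var("int32")
    assert ptr.type_annotation.storage_scope == "shared"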
3 changes: 2 additions & 1 deletion src/relay/backend/aot_executor_codegen.cc
@@ -722,7 +722,8 @@ class AOTExecutorCodegen : public ExprVisitor {
     // Define the storage allocator ids
     for (auto kv : storage_device_map_) {
       for (auto sid : kv.second->storage_ids) {
-        te::Var buffer_var(MakeString("sid_", sid), PointerType(PrimType(DataType::Int(8))));
+        te::Var buffer_var(MakeString("sid_", sid),
+                           PointerType(PrimType(DataType::Int(8)), "global"));
         sids_table_[sid] = buffer_var;
       }
     }
4 changes: 3 additions & 1 deletion src/runtime/thread_storage_scope.h
@@ -118,7 +118,9 @@ struct StorageScope {
    */
   static StorageScope Create(const std::string& s) {
     StorageScope r;
-    if (s.compare(0, 6, "global") == 0) {
+    if (s.empty()) {
+      r.rank = StorageRank::kGlobal;
+    } else if (s.compare(0, 6, "global") == 0) {
       r.rank = StorageRank::kGlobal;
       r.tag = s.substr(6, std::string::npos);
     } else if (s.compare(0, 6, "shared") == 0) {
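The new branch is what keeps previously unscoped buffers working: an empty scope string now decodes to the global rank. A Python sketch of the decode semantics (an illustration of the C++ above, not an actual TVM API):

    def parse_storage_scope(s: str):
        """Illustrative mirror of StorageScope::Create: returns (rank, tag)."""
        if s == "":
            return ("global", "")      # new: empty scope defaults to global
        if s.startswith("global"):
            return ("global", s[6:])   # remainder of the string becomes the tag
        if s.startswith("shared"):
            return ("shared", s[6:])
        raise ValueError("unknown storage scope: " + s)

    assert parse_storage_scope("") == ("global", "")
    assert parse_storage_scope("shared.dyn") == ("shared", ".dyn")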
5 changes: 3 additions & 2 deletions src/target/llvm/codegen_amdgpu.cc
@@ -84,7 +84,8 @@ class CodeGenAMDGPU : public CodeGenLLVM {
       if (info.alignment > 16) {
         info.alignment = 16;
       }
-      if (info.scope.rank == runtime::StorageRank::kLocal) {
+      auto storage_scope = runtime::StorageScope::Create(GetPtrStorageScope(op->buffer_var));
+      if (storage_scope.rank == runtime::StorageRank::kLocal) {
         // const int local_address_space = 5;
         // TODO(tqchen): for higher version of LLVM, local address space can be set.
         llvm::AllocaInst* alloca = WithFunctionEntry([&]() {
@@ -99,7 +100,7 @@ class CodeGenAMDGPU : public CodeGenLLVM {
         }
         buf = alloca;
       } else {
-        ICHECK(info.scope.rank == runtime::StorageRank::kShared)
+        ICHECK(storage_scope.rank == runtime::StorageRank::kShared)
             << "Can only allocate shared or local memory inside kernel";
         // Shared memory: address space == 3
         const unsigned shared_address_space = 3;
8 changes: 2 additions & 6 deletions src/target/llvm/codegen_llvm.cc
@@ -501,7 +501,8 @@ void CodeGenLLVM::GetAlignment(DataType t, const VarNode* buf_var, const PrimExpr
   auto it = alloc_storage_info_.find(buf_var);
   if (it != alloc_storage_info_.end()) {
     const StorageInfo& info = it->second;
-    *p_native_bits = NativeVectorBits(info.scope);
+    *p_native_bits =
+        NativeVectorBits(runtime::StorageScope::Create(GetPtrStorageScope(GetRef<Var>(buf_var))));
     max_align_bits = info.alignment * 8;
   } else {
     *p_native_bits = native_vector_bits_;
@@ -1390,11 +1391,6 @@ void CodeGenLLVM::VisitStmt_(const AttrStmtNode* op) {
         analyzer_->Bind(iv->var, Range::FromMinExtent(0, op->value));
       }
     }
-  } else if (op->attr_key == tir::attr::storage_scope) {
-    const VarNode* v = op->node.as<VarNode>();
-    ICHECK(v);
-    alloc_storage_info_[v].scope =
-        runtime::StorageScope::Create(op->value.as<StringImmNode>()->value);
   } else if (op->attr_key == tir::attr::storage_alignment) {
     const VarNode* v = op->node.as<VarNode>();
     ICHECK(v);
2 changes: 0 additions & 2 deletions src/target/llvm/codegen_llvm.h
@@ -163,8 +163,6 @@ class CodeGenLLVM : public ExprFunctor<llvm::Value*(const PrimExpr&)>,
  protected:
   /*! \brief The storage information */
   struct StorageInfo {
-    /*! \brief The storage scope */
-    runtime::StorageScope scope;
     /*! \brief The alignment of allocation */
     int alignment{0};
   };