[Unity] Support cumsum with pure int32 (#16439)
This PR fixes a bug in AttrStmt handling in the index data type rewriter and enforces int32 buffers in the cumsum function definition, which ensures that cumsum can run on a machine that supports int32 but not int64.
jinhongyii authored and junrushao committed Jan 21, 2024
1 parent 6bab1b8 commit 94889f3
Showing 3 changed files with 24 additions and 12 deletions.
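For orientation, the change is exercised end to end whenever the CUDA cumsum kernel is built for a target whose index math must stay in 32 bits. The sketch below is illustrative only, assuming a unity-era checkout where topi.cuda.cumsum, topi.cuda.schedule_scan, and tvm.lower keep their usual signatures; it is not a test added by this commit.

    import tvm
    from tvm import te, topi

    # 2-D input scanned along its last axis -- the code path patched in scan.py below.
    data = te.placeholder((4, 1024), name="data", dtype="float32")

    with tvm.target.Target("cuda"):
        out = topi.cuda.cumsum(data, axis=1)   # builds the IR-builder scan kernel
        sch = topi.cuda.schedule_scan([out])
        mod = tvm.lower(sch, [data, out])      # after this patch, the sizes feeding the
                                               # kernel's index math are cast to int32
    print(mod)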
1 change: 1 addition & 0 deletions include/tvm/tir/data_type_rewriter.h
@@ -104,6 +104,7 @@ class IndexDataTypeRewriter : public DataTypeLegalizer {
Stmt VisitStmt_(const BlockRealizeNode* op) override;
Stmt VisitStmt_(const BlockNode* op) override;
Stmt VisitStmt_(const BufferStoreNode* op) override;
+ Stmt VisitStmt_(const AttrStmtNode* op) override;
PrimExpr VisitExpr_(const BufferLoadNode* op) override;
Array<PrimExpr> VisitIndices(Array<PrimExpr> indices);
Stmt VisitStmt_(const IfThenElseNode* op) override;
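The reason AttrStmtNode needs its own override is that GPU kernels such as the scan below carry their launch dimensions as thread_extent attribute statements, whose extent values are index expressions the rewriter previously skipped. A minimal tvm.tir.ir_builder sketch (illustrative; it uses the same calls as scan.py but is not part of this commit) of where such an AttrStmt comes from:

    import tvm
    from tvm import te

    ib = tvm.tir.ir_builder.create()
    tx = te.thread_axis("threadIdx.x")
    # scope_attr wraps the rest of the scope in an AttrStmtNode whose attr_key is
    # "thread_extent"; its value is an index expression, so the rewriter must
    # visit it when narrowing indices from int64 to int32.
    ib.scope_attr(tx, "thread_extent", 128)
    out = ib.allocate("int32", (1,), name="out", scope="local")
    out[0] = 0
    stmt = ib.get()
    assert isinstance(stmt, tvm.tir.AttrStmt) and stmt.attr_key == "thread_extent"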
24 changes: 12 additions & 12 deletions python/tvm/topi/cuda/scan.py
@@ -60,8 +60,8 @@ def exclusive_scan_ir(data, output, reduction=None, binop=tvm.tir.generic.add, i
your operation.
"""

- batch_size = prod(data.shape[:-1])
- scan_axis_size = data.shape[-1]
+ batch_size = cast(prod(data.shape[:-1]), "int32")
+ scan_axis_size = cast(data.shape[-1], "int32")

ib = tvm.tir.ir_builder.create()

@@ -105,7 +105,7 @@ def exclusive_scan_ir(data, output, reduction=None, binop=tvm.tir.generic.add, i
# Up Sweep of exclusive scan
lim = ceil_log2(scan_axis_size)

with ib.for_range(0, cast(lim, "int64"), dtype="int64") as l2_width:
with ib.for_range(0, cast(lim, "int32"), dtype="int32") as l2_width:
width = 2 << l2_width

with ib.new_scope():
@@ -121,9 +121,9 @@ def exclusive_scan_ir(data, output, reduction=None, binop=tvm.tir.generic.add, i

by = te.thread_axis("blockIdx.y")
ib.scope_attr(by, "thread_extent", nthread_by)
start = ib.allocate("int64", (1,), name="start", scope="local")
middle = ib.allocate("int64", (1,), name="middle", scope="local")
end = ib.allocate("int64", (1,), name="end", scope="local")
start = ib.allocate("int32", (1,), name="start", scope="local")
middle = ib.allocate("int32", (1,), name="middle", scope="local")
end = ib.allocate("int32", (1,), name="end", scope="local")
start[0] = width * tid
with ib.if_scope(start[0] < scan_axis_size):
middle[0] = start[0] + tvm.tir.indexdiv(width, 2)
@@ -143,7 +143,7 @@ def exclusive_scan_ir(data, output, reduction=None, binop=tvm.tir.generic.add, i
reduction[bx] = output[(bx + 1) * scan_axis_size - 1]
output[(bx + 1) * scan_axis_size - 1] = cast(identity_value, out_dtype)

with ib.for_range(0, cast(lim, "int64"), dtype="int64") as l2_width:
with ib.for_range(0, cast(lim, "int32"), dtype="int32") as l2_width:
width = 2 << (lim - l2_width - 1)

with ib.new_scope():
@@ -159,9 +159,9 @@ def exclusive_scan_ir(data, output, reduction=None, binop=tvm.tir.generic.add, i

by = te.thread_axis("blockIdx.y")
ib.scope_attr(by, "thread_extent", nthread_by)
start = ib.allocate("int64", (1,), name="start", scope="local")
middle = ib.allocate("int64", (1,), name="middle", scope="local")
end = ib.allocate("int64", (1,), name="end", scope="local")
start = ib.allocate("int32", (1,), name="start", scope="local")
middle = ib.allocate("int32", (1,), name="middle", scope="local")
end = ib.allocate("int32", (1,), name="end", scope="local")
tmp = ib.allocate(out_dtype, (1,), name="end", scope="local")
start[0] = width * tid
with ib.if_scope(tvm.tir.all(start[0] < scan_axis_size)):
@@ -206,8 +206,8 @@ def get_reduction_from_exclusive_scan(data, ex_scan_output, binop=tvm.tir.generi
ex_scan_output = expand_dims(ex_scan_output, axis=0)

def ir(data, data_ex_scan, reduction):
- batch_size = prod(data.shape[:-1])
- scan_axis_size = data.shape[-1]
+ batch_size = cast(prod(data.shape[:-1]), "int32")
+ scan_axis_size = cast(data.shape[-1], "int32")

ib = tvm.tir.ir_builder.create()

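For readers unfamiliar with the kernel being retyped above: it is a work-efficient (Blelloch-style) exclusive scan, and every index it computes (start, middle, end, width) is bounded by roughly twice the scan-axis length, which is why int32 indices are safe whenever the axis length itself fits in int32. A plain-Python model of the same up-sweep/down-sweep indexing, written for illustration with the variable names used in the IR:

    import math

    def exclusive_scan_1d(values, identity=0, binop=lambda a, b: a + b):
        """Illustrative single-row model of the up-sweep/down-sweep above."""
        out = list(values)
        n = len(out)
        lim = math.ceil(math.log2(n)) if n > 1 else 0  # ceil_log2(scan_axis_size)

        # Up sweep: each "thread" reduces one block of `width` elements.
        for l2_width in range(lim):
            width = 2 << l2_width
            for tid in range((n + width - 1) // width):
                start = width * tid
                middle = start + width // 2
                end = min(start + width, n)
                if middle < n:
                    out[end - 1] = binop(out[end - 1], out[middle - 1])

        total = out[n - 1]       # stored into `reduction` in the IR
        out[n - 1] = identity    # seed the down sweep with the identity value

        # Down sweep: push the exclusive prefixes back down the tree.
        for l2_width in range(lim):
            width = 2 << (lim - l2_width - 1)
            for tid in range((n + width - 1) // width):
                start = width * tid
                middle = start + width // 2
                end = min(start + width, n)
                if middle < n:
                    tmp = out[middle - 1]
                    out[middle - 1] = out[end - 1]
                    out[end - 1] = binop(out[end - 1], tmp)
        return out, total

    print(exclusive_scan_1d([1, 2, 3, 4, 5]))   # ([0, 1, 3, 6, 10], 15)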
11 changes: 11 additions & 0 deletions src/tir/ir/data_type_rewriter.cc
@@ -258,6 +258,17 @@ Stmt IndexDataTypeRewriter::VisitStmt_(const AllocateNode* op) {
}
}

+ Stmt IndexDataTypeRewriter::VisitStmt_(const AttrStmtNode* op) {
+   if (op->attr_key == attr::thread_extent || op->attr_key == attr::virtual_thread) {
+     bool is_enabled = is_enabled_;
+     is_enabled_ = true;
+     auto stmt = DataTypeLegalizer::VisitStmt_(op);
+     is_enabled_ = is_enabled;
+     return stmt;
+   }
+   return DataTypeLegalizer::VisitStmt_(op);
+ }

Stmt IndexDataTypeRewriter::VisitStmt_(const DeclBufferNode* op) {
Buffer new_buffer = VisitBuffer(op->buffer);
DeclBuffer decl_buffer = Downcast<DeclBuffer>(StmtExprMutator::VisitStmt_(op));
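The override follows the same save/force-enable/restore discipline as the rewriter's other visitors: for thread_extent and virtual_thread attributes it switches rewriting on so the extent expression is narrowed along with the body, then restores the previous state. A toy Python sketch of that pattern (illustrative only, not TVM API):

    class ToyIndexRewriter:
        """Toy model of the save/force-enable/restore pattern above."""

        def __init__(self):
            self.enabled = False                    # mirrors is_enabled_

        def narrow(self, value):
            # Stand-in for the int64 -> int32 index narrowing.
            return ("int32", value) if self.enabled else value

        def visit_attr_stmt(self, attr_key, value):
            if attr_key in ("thread_extent", "virtual_thread"):
                saved, self.enabled = self.enabled, True  # force-enable for the extent
                try:
                    return (attr_key, self.narrow(value))
                finally:
                    self.enabled = saved                  # restore the previous state
            return (attr_key, self.narrow(value))         # other attrs: keep current state

    r = ToyIndexRewriter()
    print(r.visit_attr_stmt("thread_extent", 128))   # ('thread_extent', ('int32', 128))
    print(r.visit_attr_stmt("other_attr", 128))      # ('other_attr', 128)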
