diff --git a/include/tvm/tir/data_type_rewriter.h b/include/tvm/tir/data_type_rewriter.h
index 8bdcc097a2..846cda74c6 100644
--- a/include/tvm/tir/data_type_rewriter.h
+++ b/include/tvm/tir/data_type_rewriter.h
@@ -104,6 +104,7 @@ class IndexDataTypeRewriter : public DataTypeLegalizer {
   Stmt VisitStmt_(const BlockRealizeNode* op) override;
   Stmt VisitStmt_(const BlockNode* op) override;
   Stmt VisitStmt_(const BufferStoreNode* op) override;
+  Stmt VisitStmt_(const AttrStmtNode* op) override;
   PrimExpr VisitExpr_(const BufferLoadNode* op) override;
   Array<PrimExpr> VisitIndices(Array<PrimExpr> indices);
   Stmt VisitStmt_(const IfThenElseNode* op) override;
diff --git a/python/tvm/topi/cuda/scan.py b/python/tvm/topi/cuda/scan.py
index f697302961..238163722f 100644
--- a/python/tvm/topi/cuda/scan.py
+++ b/python/tvm/topi/cuda/scan.py
@@ -60,8 +60,8 @@ def exclusive_scan_ir(data, output, reduction=None, binop=tvm.tir.generic.add, i
         your operation.
     """
 
-    batch_size = prod(data.shape[:-1])
-    scan_axis_size = data.shape[-1]
+    batch_size = cast(prod(data.shape[:-1]), "int32")
+    scan_axis_size = cast(data.shape[-1], "int32")
 
     ib = tvm.tir.ir_builder.create()
 
@@ -105,7 +105,7 @@ def exclusive_scan_ir(data, output, reduction=None, binop=tvm.tir.generic.add, i
         # Up Sweep of exclusive scan
         lim = ceil_log2(scan_axis_size)
 
-        with ib.for_range(0, cast(lim, "int64"), dtype="int64") as l2_width:
+        with ib.for_range(0, cast(lim, "int32"), dtype="int32") as l2_width:
             width = 2 << l2_width
 
             with ib.new_scope():
@@ -121,9 +121,9 @@ def exclusive_scan_ir(data, output, reduction=None, binop=tvm.tir.generic.add, i
 
                 by = te.thread_axis("blockIdx.y")
                 ib.scope_attr(by, "thread_extent", nthread_by)
-                start = ib.allocate("int64", (1,), name="start", scope="local")
-                middle = ib.allocate("int64", (1,), name="middle", scope="local")
-                end = ib.allocate("int64", (1,), name="end", scope="local")
+                start = ib.allocate("int32", (1,), name="start", scope="local")
+                middle = ib.allocate("int32", (1,), name="middle", scope="local")
+                end = ib.allocate("int32", (1,), name="end", scope="local")
                 start[0] = width * tid
                 with ib.if_scope(start[0] < scan_axis_size):
                     middle[0] = start[0] + tvm.tir.indexdiv(width, 2)
@@ -143,7 +143,7 @@ def exclusive_scan_ir(data, output, reduction=None, binop=tvm.tir.generic.add, i
                     reduction[bx] = output[(bx + 1) * scan_axis_size - 1]
                 output[(bx + 1) * scan_axis_size - 1] = cast(identity_value, out_dtype)
 
-        with ib.for_range(0, cast(lim, "int64"), dtype="int64") as l2_width:
+        with ib.for_range(0, cast(lim, "int32"), dtype="int32") as l2_width:
             width = 2 << (lim - l2_width - 1)
 
             with ib.new_scope():
@@ -159,9 +159,9 @@ def exclusive_scan_ir(data, output, reduction=None, binop=tvm.tir.generic.add, i
 
                 by = te.thread_axis("blockIdx.y")
                 ib.scope_attr(by, "thread_extent", nthread_by)
-                start = ib.allocate("int64", (1,), name="start", scope="local")
-                middle = ib.allocate("int64", (1,), name="middle", scope="local")
-                end = ib.allocate("int64", (1,), name="end", scope="local")
+                start = ib.allocate("int32", (1,), name="start", scope="local")
+                middle = ib.allocate("int32", (1,), name="middle", scope="local")
+                end = ib.allocate("int32", (1,), name="end", scope="local")
                 tmp = ib.allocate(out_dtype, (1,), name="end", scope="local")
                 start[0] = width * tid
                 with ib.if_scope(tvm.tir.all(start[0] < scan_axis_size)):
@@ -206,8 +206,8 @@ def get_reduction_from_exclusive_scan(data, ex_scan_output, binop=tvm.tir.generi
         ex_scan_output = expand_dims(ex_scan_output, axis=0)
 
     def ir(data, data_ex_scan, reduction):
-        batch_size = prod(data.shape[:-1])
-        scan_axis_size = data.shape[-1]
+        batch_size = cast(prod(data.shape[:-1]), "int32")
+        scan_axis_size = cast(data.shape[-1], "int32")
         ib = tvm.tir.ir_builder.create()
 
diff --git a/src/tir/ir/data_type_rewriter.cc b/src/tir/ir/data_type_rewriter.cc
index e68d085b44..2bd1e06083 100644
--- a/src/tir/ir/data_type_rewriter.cc
+++ b/src/tir/ir/data_type_rewriter.cc
@@ -258,6 +258,17 @@ Stmt IndexDataTypeRewriter::VisitStmt_(const AllocateNode* op) {
   }
 }
 
+Stmt IndexDataTypeRewriter::VisitStmt_(const AttrStmtNode* op) {
+  if (op->attr_key == attr::thread_extent || op->attr_key == attr::virtual_thread) {
+    bool is_enabled = is_enabled_;
+    is_enabled_ = true;
+    auto stmt = DataTypeLegalizer::VisitStmt_(op);
+    is_enabled_ = is_enabled;
+    return stmt;
+  }
+  return DataTypeLegalizer::VisitStmt_(op);
+}
+
 Stmt IndexDataTypeRewriter::VisitStmt_(const DeclBufferNode* op) {
   Buffer new_buffer = VisitBuffer(op->buffer);
   DeclBuffer decl_buffer = Downcast<DeclBuffer>(StmtExprMutator::VisitStmt_(op));