Skip to content

Commit

Permalink
Migrate to transform pass
Browse files Browse the repository at this point in the history
  • Loading branch information
meta-project-ci committed Apr 2, 2020
1 parent 6cee7b7 commit 458f0b3
Show file tree
Hide file tree
Showing 5 changed files with 64 additions and 9 deletions.
3 changes: 2 additions & 1 deletion include/tvm/tir/ir_pass.h
Original file line number Diff line number Diff line change
Expand Up @@ -387,7 +387,8 @@ Stmt DecorateDeviceScope(Stmt stmt);
Stmt HoistIfThenElse(Stmt stmt);

/*!
* \brief Narrow down PrimExpr datatype in stmt
* \brief Narrow down PrimExpr datatype in stmt to target_bits.
* \note Run this pass after StorageFlatten.
* \param stmt The stmt to do datatype rewrite
* \param target_bits the bit of target datatype
* \return Transformed stmt.
Expand Down
10 changes: 10 additions & 0 deletions include/tvm/tir/transform.h
Original file line number Diff line number Diff line change
Expand Up @@ -87,6 +87,16 @@ TVM_DLL Pass LowerDeviceStorageAccessInfo();
*/
TVM_DLL Pass LowerWarpMemory();


/*!
* \brief Narrow down PrimExpr datatype in stmt to target_bits.
*
* \note Run this pass after StorageFlatten.
*
* \return The pass.
*/
TVM_DLL Pass NarrowDataType();

} // namespace transform
} // namespace tir
} // namespace tvm
Expand Down
15 changes: 15 additions & 0 deletions python/tvm/tir/transform/transform.py
Original file line number Diff line number Diff line change
Expand Up @@ -66,3 +66,18 @@ def LowerWarpMemory():
The result pass
"""
return _ffi_api.LowerWarpMemory()


def NarrowDataType():
"""Narrow down PrimExpr datatype in stmt to target_bits.
Returns
-------
fpass : tvm.ir.transform.Pass
The result pass
Note
----
Run this pass after StorageFlatten.
"""
return _ffi_api.NarrowDataType()
Original file line number Diff line number Diff line change
Expand Up @@ -24,6 +24,8 @@

#include <tvm/tir/ir_pass.h>
#include <tvm/tir/op.h>
#include <tvm/tir/transform.h>
#include <tvm/runtime/registry.h>
#include "../../arith/ir_mutator_with_analyzer.h"
#include "../../arith/ir_visitor_with_analyzer.h"

Expand Down Expand Up @@ -387,5 +389,25 @@ Stmt NarrowDataType(Stmt stmt, int target_bits) {
return DataTypeRewriter(target_bits)(stmt);
}

namespace transform {

Pass NarrowDataType() {
auto pass_func = [](PrimFunc f, IRModule m, PassContext ctx) {
auto* n = f.CopyOnWrite();
// TODO(@hzfan): should Target be Attr here, with target_bits inferred from it?
IntImm target_bits = f->GetAttr<IntImm>("target_bits");
CHECK(target_bits.defined())
<< "NarrowDataType: Require the target_bits";
n->body = DataTypeRewriter(target_bits->value)(std::move(n->body));
return f;
};
return CreatePrimFuncPass(
pass_func, 0, "tir.LowerDeviceStorageAccessInfo", {});
}

TVM_REGISTER_GLOBAL("tir.transform.NarrowDataType")
.set_body_typed(NarrowDataType);

} // namespace transform
} // namespace tir
} // namespace tvm
Original file line number Diff line number Diff line change
Expand Up @@ -19,7 +19,15 @@
from tvm.tir import const


def lower(sch, args, target_bits):
def lower_stmt(params, stmt, target_bits):
func = tvm.tir.PrimFunc(params, stmt).with_attr(
"target_bits", target_bits)
func = tvm.tir.transform.NarrowDataType()(tvm.IRModule.from_expr(func))["main"]
stmt = func.body
return stmt


def lower_sch(sch, args, target_bits):
binds = {}
arg_list = []
for x in args:
Expand All @@ -33,8 +41,7 @@ def lower(sch, args, target_bits):
bounds = te.schedule.InferBound(sch)
stmt = te.schedule.ScheduleOps(sch, bounds)
stmt = tvm.tir.ir_pass.StorageFlatten(stmt, binds, 64, False)
stmt = tvm.tir.ir_pass.NarrowDataType(stmt, target_bits)
return stmt
return lower_stmt(arg_list, stmt, target_bits)


def test_basic():
Expand All @@ -48,7 +55,7 @@ def check(m, n, target_bits, target_dtype):
with ib.for_range(0, n, name='j') as j:
B[i * n + j] = A[i * n + j] + 1
stmt = ib.get()
stmt = tvm.tir.ir_pass.NarrowDataType(stmt, target_bits)
stmt = lower_stmt([Ab, Bb], stmt, target_bits)
assert stmt.loop_var.dtype == target_dtype
assert stmt.body.loop_var.dtype == target_dtype

Expand Down Expand Up @@ -81,7 +88,7 @@ def check(m, n, target_bits, target_dtype):
ib.scope_attr(tx, "thread_extent", n)
B[bx * n + tx] = A[bx * n + tx] + 1
stmt = ib.get()
stmt = tvm.tir.ir_pass.NarrowDataType(stmt, target_bits)
stmt = lower_stmt([Ab, Bb], stmt, target_bits)
assert stmt.node.var.dtype == target_dtype
assert stmt.body.node.var.dtype == target_dtype

Expand Down Expand Up @@ -114,7 +121,7 @@ def check(m, lanes, target_bits, target_dtype):
with ib.for_range(0, m, name='i', dtype=m.dtype) as i:
B[i] = A[i] + 1
stmt = ib.get()
stmt = tvm.tir.ir_pass.NarrowDataType(stmt, target_bits)
stmt = lower_stmt([Ab, Bb], stmt, target_bits)
assert stmt.loop_var.dtype == target_dtype

# i32 -> i32
Expand All @@ -140,7 +147,7 @@ def check(m, target_bits, target_dtype):
k = te.reduce_axis((0, m), "k")
B = te.compute((), lambda *idx: te.sum(A[k], axis=k), name='B')
s = te.create_schedule(B.op)
stmt = lower(s, [A, B], target_bits)
stmt = lower_sch(s, [A, B], target_bits)
assert stmt.body[1].loop_var.dtype == target_dtype

# i32 -> i32
Expand All @@ -167,7 +174,7 @@ def check(m, n, target_bits, target_dtype):
with ib.for_range(0, n, name='j') as j:
A[i * n + j] = B[i * 2 * n + 2 * j] + 1
stmt = ib.get()
stmt = tvm.tir.ir_pass.NarrowDataType(stmt, target_bits)
stmt = lower_stmt([Ab, Bb], stmt, target_bits)
assert stmt.loop_var.dtype == target_dtype
assert stmt.body.loop_var.dtype == target_dtype

Expand Down

0 comments on commit 458f0b3

Please sign in to comment.