Skip to content

Commit

Permalink
[PTYTHON] Migrate VTA TIR passes to the new pass manager. (apache#5397)
Browse files Browse the repository at this point in the history
  • Loading branch information
tqchen authored and Trevor Morris committed Jun 18, 2020
1 parent 1b3ddd6 commit aa603c5
Show file tree
Hide file tree
Showing 13 changed files with 1,050 additions and 1,073 deletions.
5 changes: 3 additions & 2 deletions include/tvm/target/target.h
Original file line number Diff line number Diff line change
Expand Up @@ -27,6 +27,7 @@
#include <tvm/support/with.h>
#include <tvm/node/container.h>
#include <tvm/ir/expr.h>
#include <tvm/ir/transform.h>

#include <string>
#include <vector>
Expand Down Expand Up @@ -225,8 +226,8 @@ class BuildConfigNode : public Object {
/*! \brief Whether to partition const loop */
bool partition_const_loop = false;

/*! \brief Whether to dump the IR of each pass (only when building from python) */
std::vector< std::pair<int, runtime::PackedFunc> > add_lower_pass;
/*! \brief List of passes to be injected into the low-level pipeline. */
std::vector<std::pair<int, transform::Pass>> add_lower_pass;

/*! \brief Whether to dump the IR of each pass (only when building from python) */
bool dump_pass_ir = false;
Expand Down
8 changes: 4 additions & 4 deletions python/tvm/autotvm/measure/measure_methods.py
Original file line number Diff line number Diff line change
Expand Up @@ -615,9 +615,9 @@ def gpu_verify_pass(**kwargs):
"""Verify the validity of a gpu kernel.
This pass will check memory usage and number of threads per block.
"""
def verify_pass(stmt):
valid = ir_pass.VerifyGPUCode(stmt, kwargs)
def verify_pass(f, *_):
valid = ir_pass.VerifyGPUCode(f.body, kwargs)
if not valid:
raise InstantiationError("Skipped because of invalid gpu kernel")
return stmt
return verify_pass
return f
return tvm.tir.transform.prim_func_pass(verify_pass, opt_level=0)
29 changes: 5 additions & 24 deletions python/tvm/driver/build_module.py
Original file line number Diff line number Diff line change
Expand Up @@ -123,25 +123,6 @@ def form_irmodule(sch, args, name, binds):
return tvm.IRModule({name: func})


def _wrap_as_prim_func_pass(flist, name):
"""Wrap flist as a function pass.
This is an temporary adapter before we fully
migrate to the new pass manager.
"""
def _transform(func, *_):
stmt = func.body
for f in flist:
stmt = f(stmt)
# create a new function with updated body.
return tvm.tir.PrimFunc(func.params,
stmt,
func.ret_type,
func.buffer_map,
func.attrs)
return tvm.tir.transform.prim_func_pass(_transform, opt_level=0, name=name)


def lower(sch,
args,
name="main",
Expand Down Expand Up @@ -190,15 +171,15 @@ def lower(sch,
else:
mod = sch

pass_list = lower_phase0
# Phase 1
pass_list = [
_wrap_as_prim_func_pass(lower_phase0, "Custom-Phase0"),
pass_list += [
tvm.tir.transform.InjectPrefetch(),
tvm.tir.transform.StorageFlatten(64, cfg.instrument_bound_checkers),
tvm.tir.transform.NarrowDataType(32),
tvm.tir.transform.Simplify(),
_wrap_as_prim_func_pass(lower_phase1, "Custom-Phase1"),
]
pass_list += lower_phase1

# Phase 2
if not simple_mode:
Expand All @@ -214,8 +195,8 @@ def lower(sch,
cfg.auto_unroll_max_depth,
cfg.auto_unroll_max_extent,
cfg.unroll_explicit),
_wrap_as_prim_func_pass(lower_phase2, "Custom-Phase2"),
]
pass_list += lower_phase2

# Phase 3
pass_list += [
Expand All @@ -225,7 +206,7 @@ def lower(sch,

if not cfg.disable_select_rewriting:
pass_list += [tvm.tir.transform.RewriteUnsafeSelect()]
pass_list += [_wrap_as_prim_func_pass(lower_phase3, "Custom-Phase3")]
pass_list += lower_phase3

# Instrument BoundCheckers
if cfg.instrument_bound_checkers:
Expand Down
16 changes: 16 additions & 0 deletions python/tvm/tir/function.py
Original file line number Diff line number Diff line change
Expand Up @@ -67,3 +67,19 @@ def __init__(self,

self.__init_handle_by_constructor__(
_ffi_api.PrimFunc, param_list, body, ret_type, buffer_map, attrs)

def with_body(self, new_body):
"""Create a new PrimFunc with the same set signatures but a new body.
Parameters
----------
new_body : Stmt
The new body.
Returns
-------
new_func : PrimFunc
The created new function.
"""
return PrimFunc(
self.params, new_body, self.ret_type, self.buffer_map, self.attrs)
4 changes: 2 additions & 2 deletions src/target/target.cc
Original file line number Diff line number Diff line change
Expand Up @@ -434,12 +434,12 @@ TVM_REGISTER_GLOBAL("target.ExitBuildConfigScope")
TVM_REGISTER_GLOBAL("target.BuildConfigSetAddLowerPass")
.set_body([](TVMArgs args, TVMRetValue* ret) {
BuildConfig cfg = args[0];
std::vector< std::pair<int, PackedFunc> > add_lower_pass;
std::vector<std::pair<int, transform::Pass>> add_lower_pass;
CHECK_EQ(args.size() % 2, 1);
for (int i = 1; i < args.size(); i += 2) {
add_lower_pass.push_back(std::make_pair(
args[i].operator int(),
args[i + 1].operator tvm::runtime::PackedFunc()));
args[i + 1].operator transform::Pass()));
}
cfg->add_lower_pass = add_lower_pass;
});
Expand Down
8 changes: 5 additions & 3 deletions tests/python/relay/test_pass_fold_constant.py
Original file line number Diff line number Diff line change
Expand Up @@ -51,11 +51,13 @@ def expected():
z = relay.add(y, relay.const(c_data))
return relay.Function([x], z)

def fail(x):
raise RuntimeError()
def FailPass():
def _transform(m, *args):
raise RuntimeError()
return tvm.transform.module_pass(_transform, opt_level=0)

# the fold constant should work on any context.
with tvm.target.build_config(add_lower_pass=[(0, fail)]):
with tvm.target.build_config(add_lower_pass=[(0, FailPass())]):
with tvm.target.create("cuda"):
zz = run_opt_pass(before(), transform.FoldConstant())
zexpected = run_opt_pass(expected(), transform.InferType())
Expand Down
10 changes: 7 additions & 3 deletions tests/python/unittest/test_target_codegen_cuda.py
Original file line number Diff line number Diff line change
Expand Up @@ -182,7 +182,7 @@ def test_cuda_shuffle():
sch[c].bind(xo, thrx)
sch[c].vectorize(xi)

def my_vectorize(stmt):
def MyVectorize():
def vectorizer(op):
if op.for_type == tvm.tir.For.Vectorized:
four = tvm.tir.const(4, 'int32')
Expand All @@ -198,9 +198,13 @@ def vectorizer(op):
new_b = tvm.tir.Shuffle(bs, ids)
return tvm.tir.Store(store.buffer_var, new_a + new_b, idx, all_ones)
return None
return tvm.tir.ir_pass.IRTransform(stmt, None, vectorizer, ['For'])

with tvm.target.build_config(add_lower_pass=[(1, my_vectorize)]):
def _transform(f, *_):
return f.with_body(
tvm.tir.ir_pass.IRTransform(f.body, None, vectorizer, ['For']))
return tvm.tir.transform.prim_func_pass(_transform, opt_level=0, name="MyVectorize")

with tvm.target.build_config(add_lower_pass=[(1, MyVectorize())]):
module = tvm.build(sch, [a, b, c], target='cuda')
a_ = np.array(list(range(64)), dtype='int32')
b_ = np.array((list(range(4))[::-1]) * 16, dtype='int32')
Expand Down
11 changes: 7 additions & 4 deletions tests/python/unittest/test_target_codegen_llvm.py
Original file line number Diff line number Diff line change
Expand Up @@ -671,8 +671,7 @@ def test_llvm_shuffle():
c = te.compute((8, ), lambda x: a[x] + b[7-x])
sch = te.create_schedule(c.op)

def my_vectorize(stmt):

def my_vectorize():
def vectorizer(op):
store = op.body
idx = tvm.tir.Ramp(tvm.tir.const(0, 'int32'), tvm.tir.const(1, 'int32'), 8)
Expand All @@ -684,9 +683,13 @@ def vectorizer(op):
value = new_a + new_b
return tvm.tir.Store(store.buffer_var, new_a + new_b, idx, all_ones)

return tvm.tir.ir_pass.IRTransform(stmt, None, vectorizer, ['For'])
def _transform(f, *_):
return f.with_body(
tvm.tir.ir_pass.IRTransform(f.body, None, vectorizer, ['For']))

return tvm.tir.transform.prim_func_pass(_transform, opt_level=0, name="my_vectorize")

with tvm.target.build_config(add_lower_pass=[(1, my_vectorize)]):
with tvm.target.build_config(add_lower_pass=[(1, my_vectorize())]):
ir = tvm.lower(sch, [a, b, c], simple_mode=True)
module = tvm.build(sch, [a, b, c])
a_ = tvm.nd.array(np.arange(1, 9, dtype='int32'))
Expand Down
8 changes: 4 additions & 4 deletions tests/python/unittest/test_tir_pass_verify_gpu_code.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,10 +19,10 @@
from tvm import te

def get_verify_pass(valid, **kwargs):
def verify_pass(stmt):
valid[0] = tvm.tir.ir_pass.VerifyGPUCode(stmt, kwargs)
return stmt
return verify_pass
def _fverify(f, *_):
valid[0] = tvm.tir.ir_pass.VerifyGPUCode(f.body, kwargs)
return f
return tvm.tir.transform.prim_func_pass(_fverify, opt_level=0)

def test_shared_memory():
def check_shared_memory(dtype):
Expand Down
11 changes: 6 additions & 5 deletions tutorials/dev/low_level_custom_pass.py
Original file line number Diff line number Diff line change
Expand Up @@ -117,19 +117,20 @@ def vectorize8(op):
return body
return None

def vectorize(stmt):
@tvm.tir.transform.prim_func_pass(opt_level=0)
def vectorize(f, mod, ctx):
global loops

tvm.tir.ir_pass.PostOrderVisit(stmt, find_width8)
tvm.tir.ir_pass.PostOrderVisit(f.body, find_width8)

if not loops:
return stmt
return sf

# The last list arugment indicates what kinds of nodes will be transformed.
# Thus, in this case only `For` nodes will call `vectorize8`
stmt = tvm.tir.ir_pass.IRTransform(stmt, None, vectorize8, ['For'])
return f.with_body(
tvm.tir.ir_pass.IRTransform(f.body, None, vectorize8, ['For']))

return stmt

#####################################################################
# Glue to Lowering
Expand Down
56 changes: 29 additions & 27 deletions vta/python/vta/build_module.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,25 +14,22 @@
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
# pylint: disable=unused-argument
# pylint: disable=unused-argument, invalid-name
"""VTA specific buildin for runtime."""
import tvm
from . import ir_pass
from . import transform
from .environment import get_env


def lift_coproc_scope(x):
"""Lift coprocessings cope to the """
x = ir_pass.lift_alloc_to_scope_begin(x)
x = tvm.tir.ir_pass.LiftAttrScope(x, "coproc_scope", False)
return x

def early_rewrite(stmt):
def EarlyRewrite():
"""Try to do storage rewrite in early pass."""
try:
return tvm.tir.ir_pass.StorageRewrite(stmt)
except tvm.error.TVMError:
return stmt
def _transform(mod, ctx):
try:
return tvm.tir.transform.StorageRewrite()(mod)
except tvm.error.TVMError:
return mod
return tvm.transform.module_pass(
_transform, opt_level=0, name="tir.vta.EarlyRewrite")


def build_config(debug_flag=0, **kwargs):
Expand Down Expand Up @@ -60,27 +57,32 @@ def build_config(debug_flag=0, **kwargs):
vta_module = tvm.build(s, ...)
"""
env = get_env()
def add_debug(stmt):

@tvm.tir.transform.prim_func_pass(opt_level=0)
def add_debug(f, *_):
debug = tvm.tir.call_extern(
"int32", "VTASetDebugMode",
env.dev.command_handle,
debug_flag)

return tvm.tir.stmt_seq(debug, stmt)
pass_list = [(0, ir_pass.inject_conv2d_transpose_skip),
(1, ir_pass.inject_dma_intrin),
(1, ir_pass.inject_skip_copy),
(1, ir_pass.annotate_alu_coproc_scope),
(1, lambda x: tvm.tir.ir_pass.LiftAttrScope(x, "coproc_uop_scope", True)),
(1, lift_coproc_scope),
(1, ir_pass.inject_coproc_sync),
(1, early_rewrite)]
return f.with_body(tvm.tir.stmt_seq(debug, f.body))


pass_list = [(0, transform.InjectConv2DTransposeSkip()),
(1, transform.InjectDMAIntrin()),
(1, transform.InjectSkipCopy()),
(1, transform.AnnotateALUCoProcScope()),
(1, tvm.tir.transform.LiftAttrScope("coproc_uop_scope")),
(1, transform.LiftAllocToScopeBegin()),
(1, tvm.tir.transform.LiftAttrScope("coproc_scope")),
(1, transform.InjectCoProcSync()),
(1, EarlyRewrite())]
if debug_flag:
pass_list.append((1, add_debug))
pass_list.append((2, ir_pass.inject_alu_intrin))
pass_list.append((3, tvm.tir.ir_pass.LowerStorageAccessInfo))
pass_list.append((3, ir_pass.fold_uop_loop))
pass_list.append((3, ir_pass.cpu_access_rewrite))
pass_list.append((2, transform.InjectALUIntrin()))
pass_list.append((3, tvm.tir.transform.LowerDeviceStorageAccessInfo()))
pass_list.append((3, transform.FoldUopLoop()))
pass_list.append((3, transform.CPUAccessRewrite()))
return tvm.target.build_config(add_lower_pass=pass_list, **kwargs)


Expand Down
Loading

0 comments on commit aa603c5

Please sign in to comment.