
[TIR] Migrate VTA TIR passes to the new pass manager. #5397

Merged: 1 commit, Apr 21, 2020
5 changes: 3 additions & 2 deletions include/tvm/target/target.h
@@ -27,6 +27,7 @@
 #include <tvm/support/with.h>
 #include <tvm/node/container.h>
 #include <tvm/ir/expr.h>
+#include <tvm/ir/transform.h>

 #include <string>
 #include <vector>
@@ -225,8 +226,8 @@ class BuildConfigNode : public Object {
   /*! \brief Whether to partition const loop */
   bool partition_const_loop = false;

-  /*! \brief Whether to dump the IR of each pass (only when building from python) */
-  std::vector< std::pair<int, runtime::PackedFunc> > add_lower_pass;
+  /*! \brief List of passes to be injected into the low-level pipeline. */
+  std::vector<std::pair<int, transform::Pass>> add_lower_pass;

   /*! \brief Whether to dump the IR of each pass (only when building from python) */
   bool dump_pass_ir = false;
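With this change, stages registered through add_lower_pass carry typed Pass objects instead of raw PackedFuncs. From Python, a custom stage is now built with tvm.tir.transform.prim_func_pass and registered as a (phase, Pass) pair. A minimal sketch, assuming a TVM build containing this PR; the schedule and the pass name are made up for illustration:

    import tvm
    from tvm import te

    # A trivial schedule to lower.
    n = 16
    A = te.placeholder((n,), name="A")
    B = te.compute((n,), lambda i: A[i] + 1.0, name="B")
    s = te.create_schedule(B.op)

    # The pass callback now receives the PrimFunc, the enclosing IRModule,
    # and the PassContext, and returns the (possibly rewritten) PrimFunc.
    @tvm.tir.transform.prim_func_pass(opt_level=0, name="PrintIR")
    def print_ir(f, mod, ctx):
        print(f)  # inspect the function between lowering phases
        return f

    # Phase 1 runs after the standard Phase-1 passes (see lower() below).
    with tvm.target.build_config(add_lower_pass=[(1, print_ir)]):
        tvm.lower(s, [A, B], simple_mode=True)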
8 changes: 4 additions & 4 deletions python/tvm/autotvm/measure/measure_methods.py
@@ -615,9 +615,9 @@ def gpu_verify_pass(**kwargs):
     """Verify the validity of a gpu kernel.
     This pass will check memory usage and number of threads per block.
     """
-    def verify_pass(stmt):
-        valid = ir_pass.VerifyGPUCode(stmt, kwargs)
+    def verify_pass(f, *_):
+        valid = ir_pass.VerifyGPUCode(f.body, kwargs)
         if not valid:
             raise InstantiationError("Skipped because of invalid gpu kernel")
-        return stmt
-    return verify_pass
+        return f
+    return tvm.tir.transform.prim_func_pass(verify_pass, opt_level=0)
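Since gpu_verify_pass now returns a Pass object, it plugs into the lowering pipeline directly. A hedged usage sketch; the constraint values and the phase number are illustrative, and the keywords are ones VerifyGPUCode accepts, such as max_threads_per_block:

    import tvm

    check = gpu_verify_pass(max_shared_memory_per_block=48 * 1024,
                            max_threads_per_block=1024)
    # Registered like any other custom stage; an InstantiationError raised
    # by the callback aborts measurement of an invalid candidate.
    with tvm.target.build_config(add_lower_pass=[(2, check)]):
        pass  # tvm.build(...) of the candidate kernel goes here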
29 changes: 5 additions & 24 deletions python/tvm/driver/build_module.py
@@ -123,25 +123,6 @@ def form_irmodule(sch, args, name, binds):
     return tvm.IRModule({name: func})


-def _wrap_as_prim_func_pass(flist, name):
-    """Wrap flist as a function pass.
-
-    This is an temporary adapter before we fully
-    migrate to the new pass manager.
-    """
-    def _transform(func, *_):
-        stmt = func.body
-        for f in flist:
-            stmt = f(stmt)
-        # create a new function with updated body.
-        return tvm.tir.PrimFunc(func.params,
-                                stmt,
-                                func.ret_type,
-                                func.buffer_map,
-                                func.attrs)
-    return tvm.tir.transform.prim_func_pass(_transform, opt_level=0, name=name)
-
-
 def lower(sch,
           args,
           name="main",
@@ -190,15 +171,15 @@ def lower(sch,
     else:
         mod = sch

+    pass_list = lower_phase0
     # Phase 1
-    pass_list = [
-        _wrap_as_prim_func_pass(lower_phase0, "Custom-Phase0"),
+    pass_list += [
         tvm.tir.transform.InjectPrefetch(),
         tvm.tir.transform.StorageFlatten(64, cfg.instrument_bound_checkers),
         tvm.tir.transform.NarrowDataType(32),
         tvm.tir.transform.Simplify(),
-        _wrap_as_prim_func_pass(lower_phase1, "Custom-Phase1"),
     ]
+    pass_list += lower_phase1

     # Phase 2
     if not simple_mode:
@@ -214,8 +195,8 @@ def lower(sch,
                 cfg.auto_unroll_max_depth,
                 cfg.auto_unroll_max_extent,
                 cfg.unroll_explicit),
-        _wrap_as_prim_func_pass(lower_phase2, "Custom-Phase2"),
         ]
+    pass_list += lower_phase2

     # Phase 3
     pass_list += [
@@ -225,7 +206,7 @@ def lower(sch,

     if not cfg.disable_select_rewriting:
         pass_list += [tvm.tir.transform.RewriteUnsafeSelect()]
-    pass_list += [_wrap_as_prim_func_pass(lower_phase3, "Custom-Phase3")]
+    pass_list += lower_phase3

     # Instrument BoundCheckers
     if cfg.instrument_bound_checkers:
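With the adapter gone, user-supplied stages and built-in stages are the same kind of object, so the collected pass_list composes uniformly. A sketch of that uniformity, under the assumption that lower() returns an IRModule as in this PR; the schedule is made up for illustration:

    import tvm
    from tvm import te

    n = 8
    A = te.placeholder((n,), name="A")
    B = te.compute((n,), lambda i: A[i] * 2.0, name="B")
    s = te.create_schedule(B.op)
    mod = tvm.lower(s, [A, B], simple_mode=True)  # an IRModule of one PrimFunc

    # Built-in TIR passes are constructed the same way lower() builds them
    # above, and any mix of passes can be chained with Sequential.
    pipeline = tvm.transform.Sequential([
        tvm.tir.transform.Simplify(),
        tvm.tir.transform.RewriteUnsafeSelect(),
    ])
    print(pipeline(mod))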
16 changes: 16 additions & 0 deletions python/tvm/tir/function.py
@@ -67,3 +67,19 @@ def __init__(self,

         self.__init_handle_by_constructor__(
             _ffi_api.PrimFunc, param_list, body, ret_type, buffer_map, attrs)
+
+    def with_body(self, new_body):
+        """Create a new PrimFunc with the same set signatures but a new body.
+
+        Parameters
+        ----------
+        new_body : Stmt
+            The new body.
+
+        Returns
+        -------
+        new_func : PrimFunc
+            The created new function.
+        """
+        return PrimFunc(
+            self.params, new_body, self.ret_type, self.buffer_map, self.attrs)
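with_body makes it easy to lift existing Stmt-level rewrites into PrimFunc passes, which is the pattern used throughout this PR. A minimal sketch of the intended use:

    import tvm

    @tvm.tir.transform.prim_func_pass(opt_level=0)
    def prepend_noop(f, mod, ctx):
        # Rewrite only the body; params, ret_type, buffer_map and attrs
        # are carried over unchanged by with_body.
        noop = tvm.tir.Evaluate(tvm.tir.const(0, "int32"))
        return f.with_body(tvm.tir.stmt_seq(noop, f.body))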
4 changes: 2 additions & 2 deletions src/target/target.cc
@@ -434,12 +434,12 @@ TVM_REGISTER_GLOBAL("target.ExitBuildConfigScope")
 TVM_REGISTER_GLOBAL("target.BuildConfigSetAddLowerPass")
 .set_body([](TVMArgs args, TVMRetValue* ret) {
     BuildConfig cfg = args[0];
-    std::vector< std::pair<int, PackedFunc> > add_lower_pass;
+    std::vector<std::pair<int, transform::Pass>> add_lower_pass;
     CHECK_EQ(args.size() % 2, 1);
     for (int i = 1; i < args.size(); i += 2) {
       add_lower_pass.push_back(std::make_pair(
           args[i].operator int(),
-          args[i + 1].operator tvm::runtime::PackedFunc()));
+          args[i + 1].operator transform::Pass()));
     }
     cfg->add_lower_pass = add_lower_pass;
 });
8 changes: 5 additions & 3 deletions tests/python/relay/test_pass_fold_constant.py
@@ -51,11 +51,13 @@ def expected():
         z = relay.add(y, relay.const(c_data))
         return relay.Function([x], z)

-    def fail(x):
-        raise RuntimeError()
+    def FailPass():
+        def _transform(m, *args):
+            raise RuntimeError()
+        return tvm.transform.module_pass(_transform, opt_level=0)

     # the fold constant should work on any context.
-    with tvm.target.build_config(add_lower_pass=[(0, fail)]):
+    with tvm.target.build_config(add_lower_pass=[(0, FailPass())]):
         with tvm.target.create("cuda"):
             zz = run_opt_pass(before(), transform.FoldConstant())
     zexpected = run_opt_pass(expected(), transform.InferType())
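For comparison, a decorator-based sketch equivalent to FailPass above; with the decorator form, the callback receives the IRModule and the PassContext:

    import tvm

    @tvm.transform.module_pass(opt_level=0)
    def fail_pass(mod, ctx):
        # FoldConstant must never reach lowering, so this should not fire.
        raise RuntimeError()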
10 changes: 7 additions & 3 deletions tests/python/unittest/test_target_codegen_cuda.py
@@ -182,7 +182,7 @@ def test_cuda_shuffle():
     sch[c].bind(xo, thrx)
     sch[c].vectorize(xi)

-    def my_vectorize(stmt):
+    def MyVectorize():
         def vectorizer(op):
             if op.for_type == tvm.tir.For.Vectorized:
                 four = tvm.tir.const(4, 'int32')
@@ -198,9 +198,13 @@ def vectorizer(op):
                 new_b = tvm.tir.Shuffle(bs, ids)
                 return tvm.tir.Store(store.buffer_var, new_a + new_b, idx, all_ones)
             return None
-        return tvm.tir.ir_pass.IRTransform(stmt, None, vectorizer, ['For'])

-    with tvm.target.build_config(add_lower_pass=[(1, my_vectorize)]):
+        def _transform(f, *_):
+            return f.with_body(
+                tvm.tir.ir_pass.IRTransform(f.body, None, vectorizer, ['For']))
+        return tvm.tir.transform.prim_func_pass(_transform, opt_level=0, name="MyVectorize")
+
+    with tvm.target.build_config(add_lower_pass=[(1, MyVectorize())]):
         module = tvm.build(sch, [a, b, c], target='cuda')
         a_ = np.array(list(range(64)), dtype='int32')
         b_ = np.array((list(range(4))[::-1]) * 16, dtype='int32')
11 changes: 7 additions & 4 deletions tests/python/unittest/test_target_codegen_llvm.py
@@ -671,8 +671,7 @@ def test_llvm_shuffle():
     c = te.compute((8, ), lambda x: a[x] + b[7-x])
     sch = te.create_schedule(c.op)

-    def my_vectorize(stmt):
-
+    def my_vectorize():
         def vectorizer(op):
             store = op.body
             idx = tvm.tir.Ramp(tvm.tir.const(0, 'int32'), tvm.tir.const(1, 'int32'), 8)
@@ -684,9 +683,13 @@ def vectorizer(op):
             value = new_a + new_b
             return tvm.tir.Store(store.buffer_var, new_a + new_b, idx, all_ones)

-        return tvm.tir.ir_pass.IRTransform(stmt, None, vectorizer, ['For'])
+        def _transform(f, *_):
+            return f.with_body(
+                tvm.tir.ir_pass.IRTransform(f.body, None, vectorizer, ['For']))
+
+        return tvm.tir.transform.prim_func_pass(_transform, opt_level=0, name="my_vectorize")

-    with tvm.target.build_config(add_lower_pass=[(1, my_vectorize)]):
+    with tvm.target.build_config(add_lower_pass=[(1, my_vectorize())]):
         ir = tvm.lower(sch, [a, b, c], simple_mode=True)
         module = tvm.build(sch, [a, b, c])
         a_ = tvm.nd.array(np.arange(1, 9, dtype='int32'))
8 changes: 4 additions & 4 deletions tests/python/unittest/test_tir_pass_verify_gpu_code.py
@@ -19,10 +19,10 @@
 from tvm import te

 def get_verify_pass(valid, **kwargs):
-    def verify_pass(stmt):
-        valid[0] = tvm.tir.ir_pass.VerifyGPUCode(stmt, kwargs)
-        return stmt
-    return verify_pass
+    def _fverify(f, *_):
+        valid[0] = tvm.tir.ir_pass.VerifyGPUCode(f.body, kwargs)
+        return f
+    return tvm.tir.transform.prim_func_pass(_fverify, opt_level=0)

 def test_shared_memory():
     def check_shared_memory(dtype):
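Because get_verify_pass now returns a Pass, it can also be applied directly to a lowered IRModule instead of going through add_lower_pass. A hedged sketch; the schedule is made up for illustration, and assumes tvm.lower returns an IRModule as in this PR:

    import tvm
    from tvm import te

    valid = [None]
    check = get_verify_pass(valid, max_threads_per_block=1024)

    n = 8
    A = te.placeholder((n,), name="A")
    B = te.compute((n,), lambda i: A[i] + 1.0, name="B")
    s = te.create_schedule(B.op)
    check(tvm.lower(s, [A, B]))  # a Pass is callable on an IRModule
    print(valid[0])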
11 changes: 6 additions & 5 deletions tutorials/dev/low_level_custom_pass.py
@@ -117,19 +117,20 @@ def vectorize8(op):
         return body
     return None

-def vectorize(stmt):
+@tvm.tir.transform.prim_func_pass(opt_level=0)
+def vectorize(f, mod, ctx):
     global loops

-    tvm.tir.ir_pass.PostOrderVisit(stmt, find_width8)
+    tvm.tir.ir_pass.PostOrderVisit(f.body, find_width8)

     if not loops:
-        return stmt
+        return f

     # The last list argument indicates what kinds of nodes will be transformed.
     # Thus, in this case only `For` nodes will call `vectorize8`
-    stmt = tvm.tir.ir_pass.IRTransform(stmt, None, vectorize8, ['For'])
+    return f.with_body(
+        tvm.tir.ir_pass.IRTransform(f.body, None, vectorize8, ['For']))

-    return stmt

#####################################################################
# Glue to Lowering
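Under the new manager the decorated vectorize is already a Pass object, so the glue reduces to registering it at a phase. A sketch of that registration, reusing the tutorial's sch, a, b and c, which are defined earlier in the tutorial and not shown in this diff:

    import tvm

    # The decorated pass object is registered directly; phase 1 places it
    # after the standard Phase-1 lowering passes.
    with tvm.target.build_config(add_lower_pass=[(1, vectorize)]):
        print(tvm.lower(sch, [a, b, c], simple_mode=True))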
56 changes: 29 additions & 27 deletions vta/python/vta/build_module.py
@@ -14,25 +14,22 @@
 # KIND, either express or implied.  See the License for the
 # specific language governing permissions and limitations
 # under the License.
-# pylint: disable=unused-argument
+# pylint: disable=unused-argument, invalid-name
 """VTA specific buildin for runtime."""
 import tvm
-from . import ir_pass
+from . import transform
 from .environment import get_env


-def lift_coproc_scope(x):
-    """Lift coprocessings cope to the """
-    x = ir_pass.lift_alloc_to_scope_begin(x)
-    x = tvm.tir.ir_pass.LiftAttrScope(x, "coproc_scope", False)
-    return x
-
-def early_rewrite(stmt):
+def EarlyRewrite():
     """Try to do storage rewrite in early pass."""
-    try:
-        return tvm.tir.ir_pass.StorageRewrite(stmt)
-    except tvm.error.TVMError:
-        return stmt
+    def _transform(mod, ctx):
+        try:
+            return tvm.tir.transform.StorageRewrite()(mod)
+        except tvm.error.TVMError:
+            return mod
+    return tvm.transform.module_pass(
+        _transform, opt_level=0, name="tir.vta.EarlyRewrite")


 def build_config(debug_flag=0, **kwargs):
@@ -60,27 +57,32 @@ def build_config(debug_flag=0, **kwargs):
         vta_module = tvm.build(s, ...)
     """
     env = get_env()
-    def add_debug(stmt):
+
+    @tvm.tir.transform.prim_func_pass(opt_level=0)
+    def add_debug(f, *_):
         debug = tvm.tir.call_extern(
             "int32", "VTASetDebugMode",
             env.dev.command_handle,
             debug_flag)

-        return tvm.tir.stmt_seq(debug, stmt)
-    pass_list = [(0, ir_pass.inject_conv2d_transpose_skip),
-                 (1, ir_pass.inject_dma_intrin),
-                 (1, ir_pass.inject_skip_copy),
-                 (1, ir_pass.annotate_alu_coproc_scope),
-                 (1, lambda x: tvm.tir.ir_pass.LiftAttrScope(x, "coproc_uop_scope", True)),
-                 (1, lift_coproc_scope),
-                 (1, ir_pass.inject_coproc_sync),
-                 (1, early_rewrite)]
+        return f.with_body(tvm.tir.stmt_seq(debug, f.body))
+
+
+    pass_list = [(0, transform.InjectConv2DTransposeSkip()),
+                 (1, transform.InjectDMAIntrin()),
+                 (1, transform.InjectSkipCopy()),
+                 (1, transform.AnnotateALUCoProcScope()),
+                 (1, tvm.tir.transform.LiftAttrScope("coproc_uop_scope")),
+                 (1, transform.LiftAllocToScopeBegin()),
+                 (1, tvm.tir.transform.LiftAttrScope("coproc_scope")),
+                 (1, transform.InjectCoProcSync()),
+                 (1, EarlyRewrite())]
     if debug_flag:
         pass_list.append((1, add_debug))
-    pass_list.append((2, ir_pass.inject_alu_intrin))
-    pass_list.append((3, tvm.tir.ir_pass.LowerStorageAccessInfo))
-    pass_list.append((3, ir_pass.fold_uop_loop))
-    pass_list.append((3, ir_pass.cpu_access_rewrite))
+    pass_list.append((2, transform.InjectALUIntrin()))
+    pass_list.append((3, tvm.tir.transform.LowerDeviceStorageAccessInfo()))
+    pass_list.append((3, transform.FoldUopLoop()))
+    pass_list.append((3, transform.CPUAccessRewrite()))
     return tvm.target.build_config(add_lower_pass=pass_list, **kwargs)
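The phase numbers follow the stages of lower() shown earlier: phase 0 runs before the standard Phase-1 passes, phase 1 right after them, and so on through phase 3. A hedged usage sketch mirroring the docstring above, where s stands for a VTA schedule built elsewhere:

    import tvm
    import vta

    # debug_flag=1 additionally appends add_debug at phase 1, alongside
    # the other phase-1 VTA passes in pass_list.
    with vta.build_config(debug_flag=1):
        pass  # vta_module = tvm.build(s, ...)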