diff --git a/include/tvm/relay/transform.h b/include/tvm/relay/transform.h
index 10de08710fbe3..ce6375550f6ef 100644
--- a/include/tvm/relay/transform.h
+++ b/include/tvm/relay/transform.h
@@ -530,7 +530,7 @@ TVM_DLL Pass CanonicalizeOps();
  *
  * \return The pass.
  */
-TVM_DLL Pass AlterOpLayout();
+TVM_DLL Pass AlterOpLayout(const std::string& alter_map_attr_name = "FTVMAlterOpLayout");
 
 /*!
  * \brief Legalizes an expr with another expression.
diff --git a/python/tvm/relay/op/__init__.py b/python/tvm/relay/op/__init__.py
index a089cab669c92..2d9af8a741b10 100644
--- a/python/tvm/relay/op/__init__.py
+++ b/python/tvm/relay/op/__init__.py
@@ -18,7 +18,7 @@
 """Relay core operators."""
 # operator defs
 from .op import get, register, register_schedule, register_compute, register_gradient, \
-    register_pattern, register_alter_op_layout, register_legalize, \
+    register_pattern, register_alter_op_layout, register_convert_op_layout, register_legalize, \
     schedule_injective, Op, OpPattern, debug
 
 # Operators
diff --git a/python/tvm/relay/op/op.py b/python/tvm/relay/op/op.py
index fcbc3fd544479..e7060e31b3761 100644
--- a/python/tvm/relay/op/op.py
+++ b/python/tvm/relay/op/op.py
@@ -186,6 +186,23 @@ def register_alter_op_layout(op_name, alter_layout=None, level=10):
     return register(op_name, "FTVMAlterOpLayout", alter_layout, level)
 
 
+def register_convert_op_layout(op_name, convert_layout=None, level=10):
+    """Register convert op layout function for an op
+
+    Parameters
+    ----------
+    op_name : str
+        The name of the operator
+
+    convert_layout: function (attrs: Attrs, inputs: List[Expr]) -> new_expr: Expr
+        The function for changing the layout or replacing the operator
+
+    level : int
+        The priority level
+    """
+    return register(op_name, "FTVMConvertOpLayout", convert_layout, level)
+
+
 def register_legalize(op_name, legal_op=None, level=10):
     """Register legal transformation function for an op
 
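Note: the new register_convert_op_layout hook above can also be used directly. A minimal sketch, assuming this patch is applied; the NHWC-to-NCHW rule below simply mirrors the one that the ConvertLayout pass registers later in this diff, so a standalone registration like this would only be needed for ops the pass does not already cover.

    # Hypothetical standalone registration of an FTVMConvertOpLayout rule.
    # Returning a new expr replaces the call; returning None keeps the original op.
    from tvm import relay
    from tvm.relay.op import register_convert_op_layout

    @register_convert_op_layout("nn.conv2d")
    def convert_conv2d(attrs, inputs, tinfos):
        data, weight = inputs
        if attrs['data_layout'] == 'NHWC' and attrs['kernel_layout'] == 'HWIO':
            new_attrs = dict(attrs)
            new_attrs['data_layout'] = 'NCHW'
            new_attrs['kernel_layout'] = 'OIHW'
            return relay.nn.conv2d(data, weight, **new_attrs)
        return None  # keep the existing layout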
diff --git a/python/tvm/relay/transform.py b/python/tvm/relay/transform.py
index d3509dabddf99..efeda38a3a8aa 100644
--- a/python/tvm/relay/transform.py
+++ b/python/tvm/relay/transform.py
@@ -430,18 +430,23 @@ def CombineParallelDense(min_num_branches=3):
     return _transform.CombineParallelDense(min_num_branches)
 
 
-def AlterOpLayout():
+def AlterOpLayout(alter_map_attr_name="FTVMAlterOpLayout"):
     """Alternate the layouts of operators or replace primitive operators with other
     expressions.
     This pass can be used for computing convolution in custom layouts or
     other general weight pre-transformation.
 
+    Parameters
+    ----------
+    alter_map_attr_name : str
+        The Op's attr name which corresponds to the alter rule function.
+
     Returns
    -------
     ret : tvm.relay.Pass
         The registered pass that alters the layout of operators.
     """
-    return _transform.AlterOpLayout()
+    return _transform.AlterOpLayout(alter_map_attr_name)
 
 
 def Legalize(legalize_map_attr_name="FTVMLegalize"):
@@ -979,3 +984,63 @@ def visit_var(self, var):
             else:
                 return var
     return ChangeBatchMutator().visit(func)
+
+@function_pass(opt_level=1)
+class ConvertLayout:
+    """
+    Given a dst_layout, this pass transforms the expr so that the input data layout of most ops
+    is changed to dst_layout. In the ideal situation, there are only two layout transforms, one
+    at the start and one at the end.
+
+    This pass is not a part of relay.build and is expected to be called between the framework-relay
+    parser and the relay.build call. This is very helpful for hardware backends that support or
+    prefer only one type of data layout.
+
+    RFC - https://discuss.tvm.ai/t/layout-conversion-pass/4009
+
+    This pass reuses most of the AlterOpLayout and InferCorrectLayout infrastructure. We only need
+    to overwrite the conv2d layout conversion rule; most of the other operators adapt to their
+    input layouts through the InferCorrectLayout infrastructure.
+
+    Parameters
+    ----------
+    dst_layout: str
+        The desired layout for the transformed expr.
+
+    Returns
+    -------
+    pass: FunctionPass
+        The pass.
+    """
+
+    def __init__(self, dst_layout):
+        assert dst_layout == 'NCHW', \
+            'Currently, only NCHW layout conversion is supported.'
+
+        # Register the convert layout function for the conv2d op.
+        from tvm.relay.op import register_convert_op_layout
+        @register_convert_op_layout("nn.conv2d")
+        def alter_conv2d(attrs, inputs, tinfos):
+            data_layout = attrs['data_layout']
+            kernel_layout = attrs['kernel_layout']
+            data, weight = inputs
+            if dst_layout == 'NCHW':
+                new_attrs = dict(attrs)
+                new_attrs['data_layout'] = dst_layout
+                new_attrs['kernel_layout'] = 'OIHW'
+
+                if data_layout == 'NHWC' and kernel_layout == 'HWIO':
+                    # Convert (NHWC, HWIO) to (NCHW, OIHW)
+                    return relay.nn.conv2d(data, weight, **new_attrs)
+                elif data_layout == 'NHWC' and kernel_layout == 'HWOI':
+                    # Convert (NHWC, HWOI) to (NCHW, OIHW). Depthwise conv2d.
+                    return relay.nn.conv2d(data, weight, **new_attrs)
+                return None
+            return None
+
+    def transform_function(self, func, mod, ctx):
+        cur_mod = relay.Module.from_expr(func)
+        cur_mod = CanonicalizeOps()(cur_mod)
+        cur_mod = AlterOpLayout("FTVMConvertOpLayout")(cur_mod)
+        cur_mod = FoldConstant()(cur_mod)
+        return cur_mod['main']
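Note: a minimal self-contained sketch of the new pass, assuming this patch is applied (the shapes below are arbitrary). It builds a single NHWC conv2d, applies ConvertLayout('NCHW'), and prints the result so the layout_transform ops inserted at the boundaries can be inspected.

    from tvm import relay

    data = relay.var('data', shape=(1, 56, 56, 64), dtype='float32')     # NHWC input
    weight = relay.var('weight', shape=(3, 3, 64, 64), dtype='float32')  # HWIO kernel
    conv = relay.nn.conv2d(data, weight, padding=(1, 1),
                           data_layout='NHWC', kernel_layout='HWIO')
    mod = relay.Module.from_expr(relay.Function([data, weight], conv))

    # The conv2d now runs in NCHW/OIHW, wrapped by layout_transform ops.
    mod = relay.transform.ConvertLayout('NCHW')(mod)
    print(mod['main'])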
diff --git a/src/relay/backend/build_module.cc b/src/relay/backend/build_module.cc
index 9254c7e3e7b9a..6c5d7e0cfffbb 100644
--- a/src/relay/backend/build_module.cc
+++ b/src/relay/backend/build_module.cc
@@ -345,6 +345,8 @@ class RelayBuildModule : public runtime::ModuleNode {
     } else {
       relay_module = seq(relay_module);
     }
+    auto func1 = relay_module->Lookup("main");
+    LOG(INFO) << AsText(func1, false);
 
     // Handle heterogeneous compilation.
     transform::PassContext pass_ctx = PassContext::Current();
diff --git a/src/relay/pass/alter_op_layout.cc b/src/relay/pass/alter_op_layout.cc
index bbfb97c56dc2f..9936ec36495aa 100644
--- a/src/relay/pass/alter_op_layout.cc
+++ b/src/relay/pass/alter_op_layout.cc
@@ -185,8 +185,9 @@ std::tuple<Array<Layout>, Array<Layout>, bool> CallInfer(
 // Call registered FTVMAlterOpLayout of an op
 // Returns the altered expression
 Call CallAlter(const Call& ref_call,
-               const std::vector<Expr>& new_args) {
-  static auto falter_layout = Op::GetAttr<FTVMAlterOpLayout>("FTVMAlterOpLayout");
+               const std::vector<Expr>& new_args,
+               const std::string& alter_map_attr_name) {
+  auto falter_layout = Op::GetAttr<FTVMAlterOpLayout>(alter_map_attr_name);
   Op op = Downcast<Op>(ref_call->op);
 
   Expr new_e;
@@ -213,9 +214,10 @@ Call CallAlter(const Call& ref_call,
   return GetRef<Call>(new_call);
 }
 
-Expr AlterOpLayoutRewrite(const Call &ref_call,
-                          const Array<Expr> &new_args,
-                          const NodeRef& ctx) {
+Expr Rewriter(const Call &ref_call,
+              const Array<Expr> &new_args,
+              const NodeRef& ctx,
+              const std::string& alter_map_attr_name) {
   std::vector<LayoutAlternatedExpr> inputs;
   std::vector<Expr> normal_new_args;
   Array<Array<IndexExpr> > input_shapes;
@@ -289,7 +291,7 @@ Expr AlterOpLayoutRewrite(const Call &ref_call,
   }
 
   // new_op = alter(op)
-  Call new_call = CallAlter(ref_call, normal_new_args);
+  Call new_call = CallAlter(ref_call, normal_new_args, alter_map_attr_name);
 
   // new_in2, new_out = op.infer(new_in)
   if (new_call->op->IsInstance<OpNode>()) {
@@ -353,26 +355,44 @@ Expr AlterOpLayoutRewrite(const Call &ref_call,
   }
 }
 
+Expr AlterOpLayoutRewrite(const Call &ref_call,
+                          const Array<Expr> &new_args,
+                          const NodeRef& ctx) {
+  return Rewriter(ref_call, new_args, ctx, "FTVMAlterOpLayout");
+}
+
+Expr ConvertOpLayoutRewrite(const Call &ref_call,
+                            const Array<Expr> &new_args,
+                            const NodeRef& ctx) {
+  return Rewriter(ref_call, new_args, ctx, "FTVMConvertOpLayout");
+}
+
 // Limiations:
 // 1. the altered op should have the same number of arguments as the previous one
 // 2. do not support nested tuple arguments
-Expr AlterOpLayout(const Expr& expr) {
+Expr AlterOpLayout(const Expr& expr, const std::string& alter_map_attr_name) {
   TransformMemorizer transformMemorizer(make_node<TransformMemorizerNode>());
   auto fcontext = [&](const Call& call) -> NodeRef{
     return transformMemorizer;
   };
 
-  return ForwardRewrite(expr, AlterOpLayoutRewrite, fcontext);
+  if (alter_map_attr_name == "FTVMAlterOpLayout") {
+    return ForwardRewrite(expr, AlterOpLayoutRewrite, fcontext);
+  } else if (alter_map_attr_name == "FTVMConvertOpLayout") {
+    return ForwardRewrite(expr, ConvertOpLayoutRewrite, fcontext);
+  }
+  LOG(FATAL) << "AlterOpLayout supports only 2 attr names - FTVMAlterOpLayout/FTVMConvertOpLayout";
+  return Expr();
 }
 
 }  // namespace alter_op_layout
 
 namespace transform {
 
-Pass AlterOpLayout() {
+Pass AlterOpLayout(const std::string& alter_map_attr_name) {
   runtime::TypedPackedFunc<Function(Function, Module, PassContext)> pass_func =
     [=](Function f, Module m, PassContext pc) {
-      return Downcast<Function>(relay::alter_op_layout::AlterOpLayout(f));
+      return Downcast<Function>(relay::alter_op_layout::AlterOpLayout(f, alter_map_attr_name));
   };
   return CreateFunctionPass(pass_func, 3, "AlterOpLayout",
                             {ir::StringImm::make("InferType")});
diff --git a/tests/python/frontend/tflite/test_forward.py b/tests/python/frontend/tflite/test_forward.py
index c2c3cb5df48dd..484fea59d6f95 100644
--- a/tests/python/frontend/tflite/test_forward.py
+++ b/tests/python/frontend/tflite/test_forward.py
@@ -49,7 +49,8 @@ def convert_to_list(x):
         x = [x]
     return x
 
-def run_tvm_graph(tflite_model_buf, input_data, input_node, num_output=1, target='llvm',
+def run_tvm_graph(tflite_model_buf, input_data, input_node, num_output=1,
+                  target='llvm -mcpu=cascadelake',
                   out_names=None):
     """ Generic function to compile on relay and execute on tvm """
     try:
@@ -72,6 +73,8 @@ def run_tvm_graph(tflite_model_buf, input_data, input_node, num_output=1, target
     mod, params = relay.frontend.from_tflite(tflite_model,
                                              shape_dict=shape_dict,
                                              dtype_dict=dtype_dict)
+
+    mod = relay.transform.ConvertLayout('NCHW')(mod)
+
     import tvm
     with relay.build_config(opt_level=3):
         graph, lib, params = relay.build(mod, target, params=params)
@@ -1101,59 +1104,59 @@ def test_forward_ssd_mobilenet_v1():
 # Main
 # ----
 if __name__ == '__main__':
-    # BatchToSpaceND
-    test_forward_batch_to_space_nd()
+    # # BatchToSpaceND
+    # test_forward_batch_to_space_nd()
 
-    # SpaceToBatchND
-    test_forward_space_to_batch_nd()
+    # # SpaceToBatchND
+    # test_forward_space_to_batch_nd()
 
-    # Split
-    test_forward_split()
+    # # Split
+    # test_forward_split()
 
-    # Transpose
-    test_forward_transpose()
+    # # Transpose
+    # test_forward_transpose()
 
-    # Cast
-    test_forward_cast()
+    # # Cast
+    # test_forward_cast()
 
-    # Tile
-    test_forward_tile()
+    # # Tile
+    # test_forward_tile()
 
-    # Transforms
-    test_forward_concatenation()
-    test_forward_pad()
-    test_forward_pack()
-    test_forward_reshape()
-    test_all_resize()
-    test_forward_squeeze()
+    # # Transforms
+    # test_forward_concatenation()
+    # test_forward_pad()
+    # test_forward_pack()
+    # test_forward_reshape()
+    # test_all_resize()
+    # test_forward_squeeze()
 
-    # NN
-    test_forward_convolution()
-    test_forward_logistic()
-    test_forward_pooling()
-    test_forward_softmax()
-    test_forward_tanh()
-    test_forward_relu()
-    test_forward_prelu()
-    test_forward_fully_connected()
+    # # NN
+    # test_forward_convolution()
+    # test_forward_logistic()
+    # test_forward_pooling()
+    # test_forward_softmax()
+    # test_forward_tanh()
+    # test_forward_relu()
+    # test_forward_prelu()
+    # test_forward_fully_connected()
 
-    # Elemwise
-    test_all_elemwise()
+    # # Elemwise
+    # test_all_elemwise()
 
-    # Zeros Like
-    test_forward_zeros_like()
+    # # Zeros Like
+    # test_forward_zeros_like()
 
-    # Reduce
-    test_all_reduce()
+    # # Reduce
+    # test_all_reduce()
 
-    # End to End
+    # # End to End
     test_forward_mobilenet_v1()
-    test_forward_mobilenet_v2()
+    # test_forward_mobilenet_v2()
     test_forward_inception_v3_net()
-    test_forward_inception_v4_net()
-    test_forward_ssd_mobilenet_v1()
+    # test_forward_inception_v4_net()
+    # test_forward_ssd_mobilenet_v1()
 
-    # End to End quantized
-    test_forward_qnn_inception_v1_net()
-    test_forward_qnn_mobilenet_v1_net()
-    test_forward_qnn_mobilenet_v2_net()
+    # # End to End quantized
+    # test_forward_qnn_inception_v1_net()
+    # test_forward_qnn_mobilenet_v1_net()
+    # test_forward_qnn_mobilenet_v2_net()
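Note: the flow exercised by the modified test above is the intended usage pattern: parse an NHWC TFLite model, convert it to NCHW, then build. A hedged sketch, assuming this patch is applied; tflite_model_buf, the input name, and the shapes are placeholders for a real model.

    import tflite.Model
    from tvm import relay

    # tflite_model_buf is a placeholder for the raw bytes of a real .tflite file.
    tflite_model = tflite.Model.Model.GetRootAsModel(tflite_model_buf, 0)
    mod, params = relay.frontend.from_tflite(tflite_model,
                                             shape_dict={'input': (1, 224, 224, 3)},
                                             dtype_dict={'input': 'float32'})

    # Convert the NHWC TFLite graph to NCHW before compilation.
    mod = relay.transform.ConvertLayout('NCHW')(mod)

    with relay.build_config(opt_level=3):
        graph, lib, params = relay.build(mod, target='llvm', params=params)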