From a0661076458ea72f075ebeeb5ca05d55ba58ebd9 Mon Sep 17 00:00:00 2001
From: Animesh Jain
Date: Wed, 13 Nov 2019 06:54:33 +0000
Subject: [PATCH] [Relay][WIP] ConvertLayout pass.

---
 include/tvm/relay/transform.h                 |   2 +-
 python/tvm/relay/op/__init__.py               |   2 +-
 python/tvm/relay/op/op.py                     |  17 +
 python/tvm/relay/transform.py                 |  69 +++-
 src/relay/backend/build_module.cc             |   2 +
 src/relay/ir/op.cc                            |  15 +-
 src/relay/pass/alter_op_layout.cc             |  40 ++-
 tests/python/frontend/tflite/test_forward.py  |   5 +-
 .../python/relay/test_pass_convert_layout.py  | 329 ++++++++++++++++++
 9 files changed, 459 insertions(+), 22 deletions(-)
 create mode 100644 tests/python/relay/test_pass_convert_layout.py

diff --git a/include/tvm/relay/transform.h b/include/tvm/relay/transform.h
index 10de08710fbe3..ce6375550f6ef 100644
--- a/include/tvm/relay/transform.h
+++ b/include/tvm/relay/transform.h
@@ -530,7 +530,7 @@ TVM_DLL Pass CanonicalizeOps();
  *
  * \return The pass.
  */
-TVM_DLL Pass AlterOpLayout();
+TVM_DLL Pass AlterOpLayout(const std::string& alter_map_attr_name = "FTVMAlterOpLayout");
 
 /*!
  * \brief Legalizes an expr with another expression.
diff --git a/python/tvm/relay/op/__init__.py b/python/tvm/relay/op/__init__.py
index a089cab669c92..2d9af8a741b10 100644
--- a/python/tvm/relay/op/__init__.py
+++ b/python/tvm/relay/op/__init__.py
@@ -18,7 +18,7 @@
 """Relay core operators."""
 # operator defs
 from .op import get, register, register_schedule, register_compute, register_gradient, \
-    register_pattern, register_alter_op_layout, register_legalize, \
+    register_pattern, register_alter_op_layout, register_convert_op_layout, register_legalize, \
     schedule_injective, Op, OpPattern, debug
 
 # Operators
diff --git a/python/tvm/relay/op/op.py b/python/tvm/relay/op/op.py
index fcbc3fd544479..e7060e31b3761 100644
--- a/python/tvm/relay/op/op.py
+++ b/python/tvm/relay/op/op.py
@@ -186,6 +186,23 @@ def register_alter_op_layout(op_name, alter_layout=None, level=10):
     return register(op_name, "FTVMAlterOpLayout", alter_layout, level)
 
 
+def register_convert_op_layout(op_name, convert_layout=None, level=10):
+    """Register convert op layout function for an op
+
+    Parameters
+    ----------
+    op_name : str
+        The name of the operator
+
+    convert_layout: function (attrs: Attrs, inputs: List[Expr]) -> new_expr: Expr
+        The function for changing the layout or replacing the operator
+
+    level : int
+        The priority level
+    """
+    return register(op_name, "FTVMConvertOpLayout", convert_layout, level)
+
+
 def register_legalize(op_name, legal_op=None, level=10):
     """Register legal transformation function for an op
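For reference, the FTVMConvertOpLayout hook registered above can also be used directly, outside of the ConvertLayout pass introduced later in this patch. Below is a minimal sketch that mirrors the conv2d rule from this patch; the function name and the choice to return None for unhandled layouts are illustrative assumptions, not part of the registered API surface.

    # Sketch only (assumes this patch is applied): register a convert-layout rule
    # for nn.conv2d that rewrites NHWC/HWIO convolutions to NCHW/OIHW.
    from tvm import relay
    from tvm.relay.op import register_convert_op_layout

    @register_convert_op_layout("nn.conv2d")
    def convert_conv2d_to_nchw(attrs, inputs, tinfos):
        data, weight = inputs
        if attrs['data_layout'] == 'NHWC' and attrs['kernel_layout'] == 'HWIO':
            new_attrs = dict(attrs)
            new_attrs['data_layout'] = 'NCHW'    # desired data layout
            new_attrs['kernel_layout'] = 'OIHW'  # matching kernel layout
            return relay.nn.conv2d(data, weight, **new_attrs)
        return None  # returning None keeps the original call untouched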
""" - return _transform.AlterOpLayout() + return _transform.AlterOpLayout(alter_map_attr_name) def Legalize(legalize_map_attr_name="FTVMLegalize"): @@ -979,3 +984,63 @@ def visit_var(self, var): else: return var return ChangeBatchMutator().visit(func) + +@function_pass(opt_level=1) +class ConvertLayout: + """ + Given a dest layout, this pass transforms the expr such that most of the ops input data layout + is changed to the dest layout. In ideal situation, there are only 2 layout transforms, one at + the start and one at the end. + + This pass is not a part of relay.build and is expected to be called between framework-relay + parser and relay.build call. This is very helpful for hardware backends that support/prefer only + type of data layout. + + RFC - https://discuss.tvm.ai/t/layout-conversion-pass/4009 + + This pass uses most of the AlterOpLayout and InferCorrectLayout infrastructure. We can just + overwrite conv2d alter op layout. Most of the other operators try to adapt to their input layout + using the InferCorrectLayout infrastructure. + + Parameters + ---------- + dst_layout: str + The desired layout for the transformed expr. + + Returns + ------- + pass: FunctionPass + The pass. + """ + + def __init__(self, dst_layout): + assert dst_layout == 'NCHW', \ + 'Currently, only NCHW layout conversion is supported.' + + # Register convert layout function for conv2d op. + from tvm.relay.op import register_convert_op_layout + @register_convert_op_layout("nn.conv2d") + def alter_conv2d(attrs, inputs, tinfos): + data_layout = attrs['data_layout'] + kernel_layout = attrs['kernel_layout'] + data, weight = inputs + if dst_layout == 'NCHW': + new_attrs = dict(attrs) + new_attrs['data_layout'] = dst_layout + new_attrs['kernel_layout'] = 'OIHW' + + if data_layout == 'NHWC' and kernel_layout == 'HWIO': + # Convert (NHWC, HWIO) to (NCHW, OIHW) + return relay.nn.conv2d(data, weight, **new_attrs) + elif data_layout == 'NHWC' and kernel_layout == 'HWOI': + # Convert (NHWC, HWOI) to (NCHW, OIHW). Depthwise conv2d. + return relay.nn.conv2d(data, weight, **new_attrs) + return None + return None + + def transform_function(self, func, mod, ctx): + cur_mod = relay.Module.from_expr(func) + cur_mod = CanonicalizeOps()(cur_mod) + cur_mod = AlterOpLayout("FTVMConvertOpLayout")(cur_mod) + cur_mod = FoldConstant()(cur_mod) + return cur_mod['main'] diff --git a/src/relay/backend/build_module.cc b/src/relay/backend/build_module.cc index 9254c7e3e7b9a..6c5d7e0cfffbb 100644 --- a/src/relay/backend/build_module.cc +++ b/src/relay/backend/build_module.cc @@ -345,6 +345,8 @@ class RelayBuildModule : public runtime::ModuleNode { } else { relay_module = seq(relay_module); } + auto func1 = relay_module->Lookup("main"); + LOG(INFO) << AsText(func1, false); // Handle heterogeneous compilation. 
diff --git a/src/relay/backend/build_module.cc b/src/relay/backend/build_module.cc
index 9254c7e3e7b9a..6c5d7e0cfffbb 100644
--- a/src/relay/backend/build_module.cc
+++ b/src/relay/backend/build_module.cc
@@ -345,6 +345,8 @@ class RelayBuildModule : public runtime::ModuleNode {
     } else {
       relay_module = seq(relay_module);
     }
+    auto func1 = relay_module->Lookup("main");
+    LOG(INFO) << AsText(func1, false);
 
     // Handle heterogeneous compilation.
     transform::PassContext pass_ctx = PassContext::Current();
diff --git a/src/relay/ir/op.cc b/src/relay/ir/op.cc
index c4557ac16ad5d..4cbd7ea21bd0e 100644
--- a/src/relay/ir/op.cc
+++ b/src/relay/ir/op.cc
@@ -109,13 +109,14 @@ void OpRegistry::UpdateAttr(const std::string& key,
   if (op_map->data_.size() <= index) {
     op_map->data_.resize(index + 1, std::make_pair(TVMRetValue(), 0));
   }
-  std::pair<TVMRetValue, int>& p = op_map->data_[index];
-  CHECK(p.second != plevel)
-      << "Attribute " << key << " of operator " << this->name
-      << " is already registered with same plevel=" << plevel;
-  if (p.second < plevel) {
-    op_map->data_[index] = std::make_pair(value, plevel);
-  }
+  op_map->data_[index] = std::make_pair(value, plevel);
+  // std::pair<TVMRetValue, int>& p = op_map->data_[index];
+  // CHECK(p.second != plevel)
+  //     << "Attribute " << key << " of operator " << this->name
+  //     << " is already registered with same plevel=" << plevel;
+  // if (p.second < plevel) {
+  //   op_map->data_[index] = std::make_pair(value, plevel);
+  // }
 }
 
 // Frontend APIs
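The UpdateAttr change above is what allows ConvertLayout to (re)register its conv2d rule every time the pass object is constructed. A small Python sketch of the behaviour this enables, assuming this patch is applied; the rule bodies are placeholders.

    # Sketch only: with the relaxed UpdateAttr, re-registering the same attribute
    # at the same priority level overwrites the earlier entry instead of aborting.
    from tvm.relay.op import register_convert_op_layout

    def first_rule(attrs, inputs, tinfos):
        return None   # placeholder: keep the call unchanged

    def second_rule(attrs, inputs, tinfos):
        return None   # placeholder: keep the call unchanged

    register_convert_op_layout("nn.conv2d", first_rule)
    # Before this patch, a second registration at the same default level=10 would
    # trip CHECK(p.second != plevel); now the later rule simply wins.
    register_convert_op_layout("nn.conv2d", second_rule)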
diff --git a/src/relay/pass/alter_op_layout.cc b/src/relay/pass/alter_op_layout.cc
index bbfb97c56dc2f..9936ec36495aa 100644
--- a/src/relay/pass/alter_op_layout.cc
+++ b/src/relay/pass/alter_op_layout.cc
@@ -185,8 +185,9 @@ std::tuple<Array<Layout>, Array<Layout>, bool> CallInfer(
 // Call registered FTVMAlterOpLayout of an op
 // Returns the altered expression
 Call CallAlter(const Call& ref_call,
-               const std::vector<Expr>& new_args) {
-  static auto falter_layout = Op::GetAttr<FTVMAlterOpLayout>("FTVMAlterOpLayout");
+               const std::vector<Expr>& new_args,
+               const std::string& alter_map_attr_name) {
+  auto falter_layout = Op::GetAttr<FTVMAlterOpLayout>(alter_map_attr_name);
   Op op = Downcast<Op>(ref_call->op);
 
   Expr new_e;
@@ -213,9 +214,10 @@ Call CallAlter(const Call& ref_call,
   return GetRef<Call>(new_call);
 }
 
-Expr AlterOpLayoutRewrite(const Call &ref_call,
-                          const Array<Expr> &new_args,
-                          const NodeRef& ctx) {
+Expr Rewriter(const Call &ref_call,
+              const Array<Expr> &new_args,
+              const NodeRef& ctx,
+              const std::string& alter_map_attr_name) {
   std::vector<Expr> inputs;
   std::vector<Expr> normal_new_args;
   Array<Array<IndexExpr> > input_shapes;
@@ -289,7 +291,7 @@ Expr AlterOpLayoutRewrite(const Call &ref_call,
   }
 
   // new_op = alter(op)
-  Call new_call = CallAlter(ref_call, normal_new_args);
+  Call new_call = CallAlter(ref_call, normal_new_args, alter_map_attr_name);
 
   // new_in2, new_out = op.infer(new_in)
   if (new_call->op->IsInstance<OpNode>()) {
@@ -353,26 +355,44 @@ Expr AlterOpLayoutRewrite(const Call &ref_call,
   }
 }
 
+Expr AlterOpLayoutRewrite(const Call &ref_call,
+                          const Array<Expr> &new_args,
+                          const NodeRef& ctx) {
+  return Rewriter(ref_call, new_args, ctx, "FTVMAlterOpLayout");
+}
+
+Expr ConvertOpLayoutRewrite(const Call &ref_call,
+                            const Array<Expr> &new_args,
+                            const NodeRef& ctx) {
+  return Rewriter(ref_call, new_args, ctx, "FTVMConvertOpLayout");
+}
+
 // Limiations:
 // 1. the altered op should have the same number of arguments as the previous one
 // 2. do not support nested tuple arguments
-Expr AlterOpLayout(const Expr& expr) {
+Expr AlterOpLayout(const Expr& expr, const std::string& alter_map_attr_name) {
   TransformMemorizer transformMemorizer(make_node<TransformMemorizerNode>());
   auto fcontext = [&](const Call& call) -> NodeRef{
     return transformMemorizer;
   };
 
-  return ForwardRewrite(expr, AlterOpLayoutRewrite, fcontext);
+  if (alter_map_attr_name == "FTVMAlterOpLayout") {
+    return ForwardRewrite(expr, AlterOpLayoutRewrite, fcontext);
+  } else if (alter_map_attr_name == "FTVMConvertOpLayout") {
+    return ForwardRewrite(expr, ConvertOpLayoutRewrite, fcontext);
+  }
+  LOG(FATAL) << "AlterOpLayout supports only 2 attr names - FTVMAlterOpLayout/FTVMConvertOpLayout";
+  return Expr();
 }
 
 }  // namespace alter_op_layout
 
 namespace transform {
 
-Pass AlterOpLayout() {
+Pass AlterOpLayout(const std::string& alter_map_attr_name) {
   runtime::TypedPackedFunc<Function(Function, Module, PassContext)> pass_func =
     [=](Function f, Module m, PassContext pc) {
-      return Downcast<Function>(relay::alter_op_layout::AlterOpLayout(f));
+      return Downcast<Function>(relay::alter_op_layout::AlterOpLayout(f, alter_map_attr_name));
   };
   return CreateFunctionPass(pass_func, 3, "AlterOpLayout",
                             {ir::StringImm::make("InferType")});
diff --git a/tests/python/frontend/tflite/test_forward.py b/tests/python/frontend/tflite/test_forward.py
index 83a0730d74f87..1943c7edce3bf 100644
--- a/tests/python/frontend/tflite/test_forward.py
+++ b/tests/python/frontend/tflite/test_forward.py
@@ -72,6 +72,9 @@ def run_tvm_graph(tflite_model_buf, input_data, input_node, num_output=1, target
     mod, params = relay.frontend.from_tflite(tflite_model,
                                              shape_dict=shape_dict,
                                              dtype_dict=dtype_dict)
+    # Convert to NCHW layout.
+    mod = relay.transform.ConvertLayout('NCHW')(mod)
+
     with relay.build_config(opt_level=3):
         graph, lib, params = relay.build(mod, target, params=params)
 
@@ -151,7 +154,7 @@ def compare_tflite_with_tvm(in_data, in_name, input_tensors,
         tflite_model_buffer = converter.convert()
     tflite_output = run_tflite_graph(tflite_model_buffer, in_data)
 
-    for device in ["llvm"]:
+    for device in ["llvm", "cuda"]:
         ctx = tvm.context(device, 0)
         if not ctx.exist:
             print("Skip because %s is not enabled" % device)
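Since AlterOpLayout is now parameterized by the attribute name, the pipeline that ConvertLayout composes can also be spelled out explicitly. The following is a sketch under the assumption that this patch is applied; note that the conv2d convert rule is only registered once a ConvertLayout object has been constructed.

    # Sketch only: the pass sequence ConvertLayout.transform_function runs under the hood.
    from tvm import relay
    from tvm.relay import transform

    x = relay.var("x", shape=(1, 56, 56, 64))
    w = relay.var("w", shape=(3, 3, 64, 64))
    y = relay.nn.conv2d(x, w, channels=64, kernel_size=(3, 3), padding=(1, 1),
                        data_layout='NHWC', kernel_layout='HWIO')
    mod = relay.Module.from_expr(relay.Function([x, w], y))

    transform.ConvertLayout('NCHW')  # constructing the pass registers the conv2d convert rule
    seq = transform.Sequential([
        transform.CanonicalizeOps(),                     # e.g. bias_add -> expand_dims + add
        transform.AlterOpLayout("FTVMConvertOpLayout"),  # layout rewrite driven by the convert hooks
        transform.FoldConstant(),                        # fold layout_transforms applied to constants
    ])
    with transform.PassContext(opt_level=3):
        mod = seq(mod)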
diff --git a/tests/python/relay/test_pass_convert_layout.py b/tests/python/relay/test_pass_convert_layout.py
new file mode 100644
index 0000000000000..3ac918f69a003
--- /dev/null
+++ b/tests/python/relay/test_pass_convert_layout.py
@@ -0,0 +1,329 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements. See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership. The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License. You may obtain a copy of the License at
+#
+#   http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied. See the License for the
+# specific language governing permissions and limitations
+# under the License.
+"""Test convert layout pass"""
+import tvm
+
+from tvm import relay
+from tvm.relay.op import register_alter_op_layout
+from tvm.relay import transform, analysis
+
+
+def run_opt_pass(expr, passes):
+    passes = passes if isinstance(passes, list) else [passes]
+    mod = relay.Module.from_expr(expr)
+    seq = transform.Sequential(passes)
+    with transform.PassContext(opt_level=3):
+        mod = seq(mod)
+    entry = mod["main"]
+    return entry if isinstance(expr, relay.Function) else entry.body
+
+
+def test_no_convert_layout():
+    def before():
+        x = relay.var("x", shape=(1, 64, 56, 56))
+        weight = relay.var('weight', shape=(64, 64, 3, 3))
+        y = relay.nn.conv2d(x, weight,
+                            channels=64,
+                            kernel_size=(3, 3),
+                            padding=(1, 1))
+        y = relay.nn.relu(y)
+        y = relay.Function([x, weight], y)
+        return y
+
+    def expected():
+        return before()
+
+    a = before()
+    a = run_opt_pass(a, transform.ConvertLayout('NCHW'))
+    b = run_opt_pass(expected(), transform.InferType())
+
+    assert analysis.alpha_equal(a, b), "Actual = \n" + str(a)
+
+
+def test_conv_convert_layout():
+    def before():
+        x = relay.var("x", shape=(1, 56, 56, 64))
+        weight = relay.var('weight', shape=(3, 3, 64, 64))
+        y = relay.nn.conv2d(x, weight,
+                            channels=64,
+                            kernel_size=(3, 3),
+                            padding=(1, 1),
+                            data_layout='NHWC',
+                            kernel_layout='HWIO')
+        y = relay.nn.relu(y)
+        y = relay.Function([x, weight], y)
+        return y
+
+    def expected():
+        x = relay.var("x", shape=(1, 56, 56, 64))
+        weight = relay.var('weight', shape=(3, 3, 64, 64))
+        x = relay.layout_transform(x, 'NHWC', 'NCHW')
+        weight = relay.layout_transform(weight, 'HWIO', 'OIHW')
+        y = relay.nn.conv2d(x, weight,
+                            channels=64,
+                            kernel_size=(3, 3),
+                            padding=(1, 1))
+        y = relay.nn.relu(y)
+        y = relay.layout_transform(y, 'NCHW', 'NHWC')
+        y = relay.Function(relay.analysis.free_vars(y), y)
+        return y
+
+    a = before()
+    a = run_opt_pass(a, transform.ConvertLayout('NCHW'))
+    b = run_opt_pass(expected(), transform.InferType())
+
+    assert analysis.alpha_equal(a, b), "Actual = \n" + str(a)
+
+
+def test_conv_bias_pool_convert_layout():
+    def before():
+        x = relay.var("x", shape=(1, 56, 56, 64))
+        bias = relay.var("bias", shape=(64,))
+        weight = relay.var("weight", shape=(3, 3, 64, 64))
+        y = relay.nn.conv2d(x, weight, channels=64, kernel_size=(3, 3), padding=(1, 1),
+                            data_layout='NHWC', kernel_layout='HWIO')
+        y = relay.nn.bias_add(y, bias, axis=3)
+        # a useless tuple, which will be eliminated
+        y = relay.Tuple([y])[0]
+        y = relay.nn.relu(y)
+        y = relay.nn.max_pool2d(y, pool_size=(2, 2), layout='NHWC')
+        y = relay.cast(y, 'int32')
+        y = relay.nn.batch_flatten(y)
+        y = relay.Function(analysis.free_vars(y), y)
+        return y
+
+    def expected():
+        x = relay.var("x", shape=(1, 56, 56, 64))
+        bias = relay.var("bias", shape=(64,))
+        weight = relay.var("weight", shape=(3, 3, 64, 64))
+        x = relay.layout_transform(x, 'NHWC', 'NCHW')
+        weight = relay.layout_transform(weight, 'HWIO', 'OIHW')
+        y = relay.nn.conv2d(x, weight, channels=64, kernel_size=(3, 3), padding=(1, 1))
+
+        bias = relay.expand_dims(bias, axis=0, num_newaxis=3)
+        bias = relay.layout_transform(bias, 'NHWC', 'NCHW')
+        y = relay.add(y, bias)
+        # a useless tuple, which will be eliminated
+        y = relay.Tuple([y])[0]
+        y = relay.nn.relu(y)
+        y = relay.nn.max_pool2d(y, pool_size=(2, 2))
+        y = relay.cast(y, 'int32')
+        y = relay.layout_transform(y, 'NCHW', 'NHWC')
+        y = relay.nn.batch_flatten(y)
+        y = relay.Function(analysis.free_vars(y), y)
+        return y
+
+    a = before()
+    a = run_opt_pass(a, transform.ConvertLayout('NCHW'))
+    b = run_opt_pass(expected(), transform.InferType())
+
+    assert analysis.alpha_equal(a, b), "Actual = \n" + str(a)
+
+
+def test_conv_concat_convert_layout():
+    def before():
+        x = relay.var("x", shape=(1, 56, 56, 64))
+        weight1 = relay.var('weight1', shape=(3, 3, 64, 64))
+        weight2 = relay.var('weight2', shape=(3, 3, 64, 64))
+        y = relay.nn.conv2d(x, weight1,
+                            channels=64,
+                            kernel_size=(3, 3),
+                            padding=(1, 1),
+                            data_layout='NHWC',
+                            kernel_layout='HWIO')
+        y1 = relay.nn.conv2d(y, weight2,
+                             channels=64,
+                             kernel_size=(3, 3),
+                             padding=(1, 1),
+                             data_layout='NHWC',
+                             kernel_layout='HWIO')
+        ret = relay.concatenate([y, y1], axis=3)
+        y = relay.Function(analysis.free_vars(ret), ret)
+        return y
+
+    def expected():
+        x = relay.var("x", shape=(1, 56, 56, 64))
+        weight1 = relay.var('weight1', shape=(3, 3, 64, 64))
+        weight2 = relay.var('weight2', shape=(3, 3, 64, 64))
+        weight1 = relay.layout_transform(weight1, 'HWIO', 'OIHW')
+        weight2 = relay.layout_transform(weight2, 'HWIO', 'OIHW')
+        y = relay.layout_transform(x, "NHWC", "NCHW")
+        y = relay.nn.conv2d(y, weight1,
+                            channels=64,
+                            kernel_size=(3, 3),
+                            padding=(1, 1))
+        y1 = relay.nn.conv2d(y, weight2,
+                             channels=64,
+                             kernel_size=(3, 3),
+                             padding=(1, 1))
+        ret = relay.concatenate([y, y1], axis=1)
+        ret = relay.layout_transform(ret, "NCHW", "NHWC")
+        y = relay.Function(analysis.free_vars(ret), ret)
+        return y
+
+    a = before()
+    a = run_opt_pass(a, transform.ConvertLayout('NCHW'))
+    b = run_opt_pass(expected(), transform.InferType())
+
+    assert analysis.alpha_equal(a, b), "Actual = \n" + str(a)
+
+
+def test_dual_path_convert_layout():
+    def before():
+        x = relay.var("x", shape=(1, 56, 56, 64))
+        weight1 = relay.var('weight1', shape=(3, 3, 64, 32))
+        weight2 = relay.var('weight2', shape=(3, 3, 32, 32))
+        y = relay.nn.conv2d(x, weight1,
+                            channels=32,
+                            kernel_size=(3, 3),
+                            padding=(1, 1),
+                            data_layout='NHWC',
+                            kernel_layout='HWIO')
+        y = relay.nn.relu(y)
+        y1 = relay.nn.conv2d(y, weight2,
+                             channels=32,
+                             kernel_size=(3, 3),
+                             padding=(1, 1),
+                             data_layout='NHWC',
+                             kernel_layout='HWIO')
+        y1 = relay.nn.relu(y1)
+        y2 = relay.nn.batch_flatten(y)
+        ret = relay.Tuple([y1, y2])
+        y = relay.Function(analysis.free_vars(ret), ret)
+        return y
+
+    def expected():
+        x = relay.var("x", shape=(1, 56, 56, 64))
+        weight1 = relay.var('weight1', shape=(3, 3, 64, 32))
+        weight2 = relay.var('weight2', shape=(3, 3, 32, 32))
+        weight1 = relay.layout_transform(weight1, 'HWIO', 'OIHW')
+        weight2 = relay.layout_transform(weight2, 'HWIO', 'OIHW')
+        y = relay.layout_transform(x, "NHWC", "NCHW")
+        y = relay.nn.conv2d(y, weight1,
+                            channels=32,
+                            kernel_size=(3, 3),
+                            padding=(1, 1))
+        y = relay.nn.relu(y)
+        y1 = relay.nn.conv2d(y, weight2,
+                             channels=32,
+                             kernel_size=(3, 3),
+                             padding=(1, 1))
+        y1 = relay.nn.relu(y1)
+        y1 = relay.layout_transform(y1, "NCHW", "NHWC")
+        y2 = relay.layout_transform(y, "NCHW", "NHWC")
+        y2 = relay.nn.batch_flatten(y2)
+        ret = relay.Tuple([y1, y2])
+        y = relay.Function(analysis.free_vars(ret), ret)
+        return y
+
+    a = before()
+    a = run_opt_pass(a, transform.ConvertLayout('NCHW'))
+    b = run_opt_pass(expected(), transform.InferType())
+
+    assert analysis.alpha_equal(a, b), "Actual = \n" + str(a)
+
+
+def test_resnet_convert_layout():
+    def before():
+        x = relay.var("x", shape=(1, 56, 56, 64))
+        weight1 = relay.var('weight1', shape=(3, 3, 64, 32))
+        weight2 = relay.var('weight2', shape=(1, 1, 64, 32))
+        y = relay.nn.conv2d(x, weight1,
+                            channels=32,
+                            kernel_size=(3, 3),
+                            padding=(1, 1),
+                            data_layout='NHWC',
+                            kernel_layout='HWIO')
+        y = relay.nn.relu(y)
+        y2 = relay.nn.conv2d(x, weight2,
+                             channels=32,
+                             kernel_size=(1, 1),
+                             data_layout='NHWC',
+                             kernel_layout='HWIO')
+        y2 = relay.nn.relu(y2)
+        y = y + y2
+        y = relay.nn.global_max_pool2d(y, layout='NHWC')
+        return relay.Function(analysis.free_vars(y), y)
+
+    def expected():
+        x = relay.var("x", shape=(1, 56, 56, 64))
+        weight1 = relay.var('weight1', shape=(3, 3, 64, 32))
+        weight2 = relay.var('weight2', shape=(1, 1, 64, 32))
+        weight1 = relay.layout_transform(weight1, 'HWIO', 'OIHW')
+        weight2 = relay.layout_transform(weight2, 'HWIO', 'OIHW')
+        x = relay.layout_transform(x, "NHWC", "NCHW")
+        y = relay.nn.conv2d(x, weight1,
+                            channels=32,
+                            kernel_size=(3, 3),
+                            padding=(1, 1))
+        y = relay.nn.relu(y)
+        y2 = relay.nn.conv2d(x, weight2,
+                             channels=32,
+                             kernel_size=(1, 1))
+        y2 = relay.nn.relu(y2)
+        y = y + y2
+        y = relay.nn.global_max_pool2d(y)
+        y = relay.layout_transform(y, "NCHW", "NHWC")
+        return relay.Function(analysis.free_vars(y), y)
+
+    a = before()
+    a = run_opt_pass(a, transform.ConvertLayout('NCHW'))
+    b = run_opt_pass(expected(), transform.InferType())
+
+    assert analysis.alpha_equal(a, b), "Actual = \n" + str(a)
+
+
+def test_scalar_convert_layout():
+    def before():
+        x = relay.var("x", shape=(1, 56, 56, 64))
+        weight = relay.var("weight", shape=(3, 3, 64, 64))
+        y = relay.nn.conv2d(x, weight, channels=64, kernel_size=(3, 3), padding=(1, 1),
+                            data_layout='NHWC', kernel_layout='HWIO')
+        y = relay.add(y, relay.const(1, "float32"))
+        y = relay.Function(analysis.free_vars(y), y)
+        return y
+
+    def expected():
+        x = relay.var("x", shape=(1, 56, 56, 64))
+        w = relay.var("weight", shape=(3, 3, 64, 64))
+        x = relay.layout_transform(x, 'NHWC', 'NCHW')
+        w = relay.layout_transform(w, 'HWIO', 'OIHW')
+        y = relay.nn.conv2d(x, w,
+                            channels=64,
+                            kernel_size=(3, 3),
+                            padding=(1, 1))
+        y = relay.add(y, relay.const(1.0, "float32"))
+
+        y = relay.layout_transform(y, "NCHW", "NHWC")
+        y = relay.Function(analysis.free_vars(y), y)
+        return y
+
+    a = before()
+    a = run_opt_pass(a, transform.ConvertLayout('NCHW'))
+    b = run_opt_pass(expected(), transform.InferType())
+
+    assert analysis.alpha_equal(a, b), "Actual = \n" + str(a)
+
+
+if __name__ == "__main__":
+    test_no_convert_layout()
+    test_conv_convert_layout()
+    test_conv_bias_pool_convert_layout()
+    test_conv_concat_convert_layout()
+    test_dual_path_convert_layout()
+    test_resnet_convert_layout()
+    test_scalar_convert_layout()
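The tests above check the rewritten structure with alpha_equal. A numerical sanity check is not part of this patch, but a quick way to confirm that the converted module computes the same values might look like the sketch below; relay.create_executor and the random conv2d workload are assumptions for illustration, not code from this patch.

    # Sketch only: compare outputs of an NHWC module before and after ConvertLayout.
    import numpy as np
    import tvm
    from tvm import relay

    def run(mod, args):
        # Build and execute mod["main"] on CPU with the graph executor.
        with relay.build_config(opt_level=3):
            ex = relay.create_executor("graph", mod=mod, ctx=tvm.cpu(0), target="llvm")
            return ex.evaluate()(*args).asnumpy()

    x = relay.var("x", shape=(1, 56, 56, 64))
    w = relay.var("w", shape=(3, 3, 64, 64))
    y = relay.nn.conv2d(x, w, channels=64, kernel_size=(3, 3), padding=(1, 1),
                        data_layout='NHWC', kernel_layout='HWIO')
    mod = relay.Module.from_expr(relay.Function([x, w], y))

    data = [np.random.uniform(size=(1, 56, 56, 64)).astype("float32"),
            np.random.uniform(size=(3, 3, 64, 64)).astype("float32")]
    ref = run(mod, data)
    out = run(relay.transform.ConvertLayout('NCHW')(mod), data)
    np.testing.assert_allclose(ref, out, rtol=1e-5, atol=1e-5)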