From c4950bcada31dae3a2a02d4a7e244217722e6d35 Mon Sep 17 00:00:00 2001 From: Trevor Morris Date: Thu, 11 Jun 2020 22:47:22 +0000 Subject: [PATCH 1/4] [TIDL] Add reduce subgraph size pass --- python/tvm/relay/backend/contrib/tidl.py | 256 +++++++++++++++++- tests/python/relay/test_tidl_gluon.py | 114 ++++++++ .../relay/test_tidl_reduce_subgraph_size.py | 178 ++++++++++++ 3 files changed, 541 insertions(+), 7 deletions(-) create mode 100644 tests/python/relay/test_tidl_gluon.py create mode 100644 tests/python/relay/test_tidl_reduce_subgraph_size.py diff --git a/python/tvm/relay/backend/contrib/tidl.py b/python/tvm/relay/backend/contrib/tidl.py index 5808cae83b4b..05175c24b959 100755 --- a/python/tvm/relay/backend/contrib/tidl.py +++ b/python/tvm/relay/backend/contrib/tidl.py @@ -222,6 +222,9 @@ def tensor_quant_flatten(input_tensor, data_layout): return output, scale, sign class VarReplacer(ExprMutator): + """ + Replaces vars in expr according to var_map. + """ def __init__(self, var_map): ExprMutator.__init__(self) self.var_map = var_map @@ -247,10 +250,7 @@ def visit_call(self, call): return super().visit_call(call) for func in mod.get_global_vars(): - name = func.name_hint - if not mod[name].attrs or mod[name].attrs["Compiler"] != compiler: - continue - mod[name] = Unpacker().visit(mod[name]) + mod[func.name_hint] = Unpacker().visit(mod[func.name_hint]) return mod class CalibrationGraphMutator(ExprMutator): @@ -366,18 +366,42 @@ def generate_subgraph_tensors(tidl_target, mod, params, input_node, input_data): return subgraph_tensors +class ExprReplacer(ExprMutator): + """ + Replaces call nodes in expr according to call_map + """ + def __init__(self, call_map): + ExprMutator.__init__(self) + self.call_map = call_map + + def visit_call(self, call): + if call in self.call_map: + return self.call_map[call] + return super().visit_call(call) + class VarRenamer(ExprMutator): + """ + Renames vars to match the new subgraph name. Used when subgraphs are renamed starting from zero. + If subgraph was originally "tidl_34", it would have inputs named like "tidl_34_i0". + IF new_subgraph_name is "tidl_0", pass will that input to "tidl_0_i0". + """ def __init__(self, new_subgraph_name): ExprMutator.__init__(self) self.new_subgraph_name = new_subgraph_name def visit_var(self, var): - if "_".join(var.name_hint.split('_')[:2]) != self.new_subgraph_name: + # TODO: Make sure input isn't from a composite func. + # TODO: Doesn't account for tuple inputs (not possible due to PruneSubgraphsWithMoreThanOneInput) + if var.name_hint.startswith("tidl") and "_".join(var.name_hint.split('_')[:2]) != self.new_subgraph_name: new_var_name = self.new_subgraph_name + "_" + var.name_hint.split('_')[2] return relay.Var(new_var_name, var.checked_type) return super().visit_var(var) class SubgraphRemover(ExprMutator): + """ + Removes subgraphs which are in the list subgraphs_to_remove and returns them back to regular + TVM compilation in main function. + """ def __init__(self, subgraphs_to_remove, mod, new_mod, rename_starting_from_0=True): ExprVisitor.__init__(self) self.subgraphs_to_remove = subgraphs_to_remove @@ -417,6 +441,223 @@ def visit_call(self, call): return subgraph_gv(*args) return super().visit_call(call) +class SubgraphSizeCounter(ExprVisitor): + """ + Pass to count size of subgraph, both number of layers and estimated total memory usage. + Used by SubgraphReducer pass. 
+ """ + def __init__(self): + ExprVisitor.__init__(self) + self.num_layers = 0 + self.total_memory = 0 + + def get_total_memory_mb(self): + return self.total_memory / (1024.0 * 1024.0) + + def visit_call(self, call): + # Don't visit composite function body. + for arg in call.args: + super().visit(arg) + self.num_layers += 1 + # Add total size of weights (16 bits per) + for arg in call.args: + if isinstance(arg, tvm.relay.expr.Constant): + self.total_memory += 2 * np.prod(list(map(int, arg.checked_type.shape))) + # Add activation size (8 bits per) + if isinstance(call.checked_type, tvm.relay.TensorType): + self.total_memory += np.prod(list(map(int, call.checked_type.shape))) + +def FindCommonAncestor(expr0, expr1): + """ + Find the closest common ancestor to expr0 and expr1. + Returns distance from both. + Used by SubgraphReducer pass. + """ + class CommonAncestor(ExprVisitor): + """ + Creates a map of node -> distance from expr + """ + def __init__(self): + ExprVisitor.__init__(self) + self.ancestors_with_distance = {} + self.call_outputs = {} + + def Find(self, expr): + self.ancestors_with_distance[expr] = 0 + self.call_outputs[expr] = [] + super().visit(expr) + + def _update(self, expr, expr_inputs): + for arg in expr_inputs: + if arg in self.call_outputs and expr not in self.call_outputs[arg]: + self.call_outputs[arg].append(expr) + else: + self.call_outputs[arg] = [expr] + + if expr in self.call_outputs and len(self.call_outputs[expr]) > 0: + self.ancestors_with_distance[expr] = max([self.ancestors_with_distance[output] for output in self.call_outputs[expr]]) + 1 + else: + # Op did not have any outputs that we have already visited. + self.ancestors_with_distance[expr] = 0 + + def visit_tuple_getitem(self, tuplegetitem): + self._update(tuplegetitem, [tuplegetitem.tuple_value]) + super().visit_tuple_getitem(tuplegetitem) + + def visit_tuple(self, tup): + self._update(tup, tup.fields) + super().visit_tuple(tup) + + def visit_call(self, call): + self._update(call, call.args) + # Don't visit function body + # We don't care what's inside composite functions, we will just remove the whole func. + for arg in call.args: + super().visit(arg) + + common0 = CommonAncestor() + common0.Find(expr0) + common1 = CommonAncestor() + common1.Find(expr1) + # Find common + first_common_ancestor = None + distance_to_0 = 999999 + distance_to_1 = 999999 + for node in common0.ancestors_with_distance: + if node in common1.ancestors_with_distance: + if common0.ancestors_with_distance[node] <= distance_to_0 and common1.ancestors_with_distance[node] <= distance_to_1: + first_common_ancestor = node + distance_to_0 = common0.ancestors_with_distance[node] + distance_to_1 = common1.ancestors_with_distance[node] + return first_common_ancestor, distance_to_0, distance_to_1 + +class SubgraphReducer(ExprMutator): + """ + Removes a single op from end of subgraphs which exceed max_num_layers or max_total_memory_mb. + If an op is removed, reduced will be set to True. 
+ """ + def __init__(self, mod, new_mod, max_num_layers=256, max_total_memory_mb=512, compiler="tidl"): + ExprVisitor.__init__(self) + self.mod = mod + self.new_mod = new_mod + self.max_num_layers = max_num_layers + self.max_total_memory_mb = max_total_memory_mb + self.compiler = compiler + self.reduced = False + + def visit_call(self, call): + if isinstance(call.op, GlobalVar): + name = call.op.name_hint + if not self.mod[name].attrs or self.mod[name].attrs["Compiler"] != self.compiler: + return super().visit_call(call) + # Compute size of subgraph to see if we need to reduce it. + counter = SubgraphSizeCounter() + counter.visit(self.mod[name]) + if counter.num_layers > self.max_num_layers or counter.get_total_memory_mb() > self.max_total_memory_mb: + # Mark that we have reduced the subgraph size. + self.reduced = True + # "Inline" the last op only back into new main function. + original_func = self.mod[name] + # Get last_op + last_op = original_func.body + last_op_args = [] + if isinstance(last_op, tvm.relay.expr.Tuple): + # Subgraph has multiple outputs! + assert len(last_op.fields) == 2 + ancestor, dist0, dist1 = FindCommonAncestor(last_op.fields[0], last_op.fields[1]) + if dist0 == 0 and dist1 == 0: + last_op_args = ancestor.args + elif dist0 > dist1: + # field[0] is further from LCA. + # Remove it by replacing it with its args. + # Keep field[1] + last_op_args = [] + for arg in last_op.fields[0].args: + if arg != last_op.fields[1]: + last_op_args.append(arg) + last_op_args.append(last_op.fields[1]) + elif dist1 >= dist0: + # field[1] is further from LCA. + # Remove it by replacing it with its args. + # Keep field[0] + last_op_args = [last_op.fields[0]] + for arg in last_op.fields[1].args: + if arg != last_op_args[0]: + last_op_args.append(arg) + elif isinstance(last_op, tvm.relay.expr.Call): + last_op_args = last_op.args + else: + raise ValueError("Input to last op is not call or tuple") + # Gather new outputs of the subgraph - from removed op's inputs + # This map will map Expr to index in new_outputs tuple + #print('last_op_args', last_op_args) + new_outputs = [] + last_op_input_to_new_output_map = {} + if len(last_op_args) > 1: + for arg in last_op_args: + # Skip weights + if not isinstance(arg, tvm.relay.expr.Constant): + new_outputs.append(arg) + last_op_input_to_new_output_map[arg] = len(new_outputs) - 1 + if len(new_outputs) > 1: + new_outputs_expr = relay.Tuple(new_outputs) + elif len(new_outputs) == 1: + new_outputs_expr = new_outputs[0] + else: + raise ValueError("No ops left in subgraph after reducing size") + else: + new_outputs = [last_op_args[0]] + new_outputs_expr = new_outputs[0] + subgraph_gv = relay.GlobalVar(name) + + # construct new func without last_op + new_func = relay.Function(original_func.params, new_outputs_expr) + new_func = new_func.with_attr("Primitive", tvm.tir.IntImm("int32", 1)) + new_func = new_func.with_attr("Inline", tvm.tir.IntImm("int32", 1)) + new_func = new_func.with_attr("Compiler", self.compiler) + new_func = new_func.with_attr("global_symbol", name) + self.new_mod[subgraph_gv] = new_func + args = [] + for arg in call.args: + args.append(super().visit(arg)) + new_expr = subgraph_gv(*args) + if len(new_outputs) > 1: + call_map = {arg: relay.TupleGetItem(new_expr, index) for arg, index in last_op_input_to_new_output_map.items()} + else: + call_map = {new_outputs[0]: new_expr} + new_expr = ExprReplacer(call_map).visit(last_op) + + return new_expr + elif name != "main": + # Transfer subgraph to new mod without modifying + args = [] + for arg in 
call.args: + args.append(super().visit(arg)) + subgraph_gv = relay.GlobalVar(name) + self.new_mod[subgraph_gv] = self.mod[name] + return subgraph_gv(*args) + return super().visit_call(call) + +def ReduceSubgraphSize(mod, compiler="tidl", max_num_layers=256, max_total_memory_mb=512): + """ + Reduces size of subgraph to fit limitations. + """ + # Counter just in case to avoid infinite loop. + sanity_counter = 10000 + # SubgraphReducer removes one op if the subgraph is above the limits. + # Repeated call SubgraphReducer until no subgraphs are reduced. + while sanity_counter > 0: + new_mod = tvm.IRModule() + reducer = SubgraphReducer(mod, new_mod, max_num_layers, max_total_memory_mb) + new_mod['main'] = reducer.visit(mod["main"]) + # If no subgraphs where reduced in size, we are done. + if not reducer.reduced: + return new_mod + mod = new_mod + # Avoid infinite loop. + sanity_counter -= 1 + return mod + def PruneSubgraphsWithMoreThanOneInput(mod, compiler="tidl"): subgraph_names_to_remove = [] # Remove subgraphs with more than 1 input or tuple inputs. @@ -1282,7 +1523,7 @@ class TIDLCompiler: Folder to hold TIDL artifacts """ - def __init__(self, platform, version, **kwargs): + def __init__(self, platform, version, max_num_layers=256, max_total_memory_mb=512, **kwargs): if platform == "AM57" and version >= (6,3): for key in ('num_tidl_subgraphs', 'data_layout', 'artifacts_folder', 'tidl_tools_path'): if key in kwargs: @@ -1337,8 +1578,9 @@ def enable(self, mod_orig, params, input): mod = transform.AnnotateTarget(self.tidl_target)(mod) mod = transform.MergeCompilerRegions()(mod) mod = transform.PartitionGraph()(mod) - mod = UnpackComposites(mod, compiler=self.tidl_target) mod = PruneSubgraphsWithMoreThanOneInput(mod, compiler=self.tidl_target) + mod = ReduceSubgraphSize(mod, max_num_layers=max_num_layers, max_total_memory_mb=max_total_memory_mb, compiler=self.tidl_target) + mod = UnpackComposites(mod, compiler=self.tidl_target) mod = PruneSubgraphs(mod, compiler=self.tidl_target, num_subgraphs_to_keep=self.num_tidl_subgraphs) #============= Generate subgraph boundary tensors ============== diff --git a/tests/python/relay/test_tidl_gluon.py b/tests/python/relay/test_tidl_gluon.py new file mode 100644 index 000000000000..4520d5d1420b --- /dev/null +++ b/tests/python/relay/test_tidl_gluon.py @@ -0,0 +1,114 @@ +import os +import sys +import numpy as np +from matplotlib import pyplot as plt + +import tvm +import tvm.relay.testing +from tvm import relay +from tvm.contrib import cc +from tvm.contrib import graph_runtime +from tvm.contrib.download import download_testdata +from tvm.relay.backend.contrib import tidl +import mxnet as mx +from mxnet import image +from mxnet.gluon.model_zoo.vision import get_model +import gluoncv + +def model_compile(model_name, mod_orig, params, input_data, num_tidl_subgraphs=4, + data_layout="NCHW", input_node="data"): + artifacts_folder = "./artifacts_" + model_name + if os.path.isdir(artifacts_folder): + filelist = [ f for f in os.listdir(artifacts_folder)] + for file in filelist: + os.remove(os.path.join(artifacts_folder, file)) + else: + os.mkdir(artifacts_folder) + + mod = tidl.EnableTIDL(mod_orig, params, num_tidl_subgraphs, + data_layout, input_node, input_data, + artifacts_folder, os.path.join(os.getenv("TIDL_TOOLS_PATH"), "eve_test_dl_algo_ref.out")) + # We expect somethign to be offloaded to TIDL. 
+ assert mod is not None + + target = "llvm -target=armv7l-linux-gnueabihf" + graph, lib, params = relay.build_module.build(mod, target=target, params=params) + path_lib = os.path.join(artifacts_folder, "deploy_lib.so") + path_graph = os.path.join(artifacts_folder, "deploy_graph.json") + path_params = os.path.join(artifacts_folder, "deploy_param.params") + cc_path = os.path.join(os.getenv("ARM_GCC_PATH"), "arm-linux-gnueabihf-g++") + lib.export_library(path_lib, cc=cc_path) + with open(path_graph, "w") as fo: + fo.write(graph) + with open(path_params, "wb") as fo: + fo.write(relay.save_param_dict(params)) + + +def test_tidl_gluon(): + x = np.random.normal(0, 1, (1, 3, 224, 224)) #np.load(os.path.join(os.getenv("TIDL_TOOLS_PATH"), 'dog.npy')) + input_data = x/np.amax(np.abs(x)) + + def test_model(model, input_shape, dtype, use_tidl=True, num_iteration=1): + block = get_model(model, pretrained=True) + mod, params = relay.frontend.from_mxnet(block, shape={'data': input_shape}, dtype=dtype) + model_compile(model, mod, params, input_node='data', input_data=input_data) + + latency = {} + models = [ + 'alexnet', + 'resnet18_v1', + 'resnet34_v1', + 'resnet50_v1', + 'resnet101_v1', + 'resnet152_v1', + 'resnet18_v2', + 'resnet34_v2', + 'resnet50_v2', + 'resnet101_v2', + 'resnet152_v2', + 'squeezenet1.0', + 'mobilenet0.25', + 'mobilenet0.5', + 'mobilenet0.75', + 'mobilenet1.0', + 'mobilenetv2_0.25', + 'mobilenetv2_0.5', + 'mobilenetv2_0.75', + 'mobilenetv2_1.0', + 'vgg11', + 'vgg16', + 'densenet121', + 'densenet169', + 'densenet201', + ] + + dtype = 'float32' + input_shape = (1, 3, 224, 224) + for model in models: + print('testing model:', model), + test_model(model, input_shape, dtype, use_tidl=True) + +def test_tidl_gluoncv(): + def test_model(model, input_shape, dtype, use_tidl=True, num_iteration=1): + block = gluoncv.model_zoo.get_model(model, pretrained=True) + mod, params = relay.frontend.from_mxnet(block, shape={'data': input_shape}, dtype=dtype) + + input_data = np.random.normal(0, 1, input_shape) + input_data = input_data/np.amax(np.abs(input_data)) + model_compile(model, mod, params, input_node='data', input_data=input_data) + + latency = {} + models = [ + ('deeplab_resnet50_ade', (1, 3, 480, 480)), + ('deeplab_resnet101_ade', (1, 3, 480, 480)), + ('yolo3_mobilenet1.0_coco', (1, 3, 224, 224)), + ] + + dtype = 'float32' + for model, input_shape in models: + print('testing model:', model) + test_model(model, input_shape, dtype, use_tidl=True) + +if __name__ == "__main__": + test_tidl_gluon() + test_tidl_gluoncv() diff --git a/tests/python/relay/test_tidl_reduce_subgraph_size.py b/tests/python/relay/test_tidl_reduce_subgraph_size.py new file mode 100644 index 000000000000..e9b3d42d9287 --- /dev/null +++ b/tests/python/relay/test_tidl_reduce_subgraph_size.py @@ -0,0 +1,178 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. 
See the License for the +# specific language governing permissions and limitations +# under the License. + +import numpy as np +import tvm +from tvm import relay +from tvm.relay.build_module import bind_params_by_name +from tvm.relay.backend.contrib.tidl import ReduceSubgraphSize +from test_pass_partition_graph import set_func_attr + +def test_reduce_subgraph_size_single_output(): + def create_graph(): + ishape = (1, 3, 12, 12) + x = relay.var('tidl_i0', shape=ishape, dtype='float32') + y = relay.nn.relu(x) + out = relay.nn.relu(y) + func = relay.Function([x], out) + func = set_func_attr(func, "tidl", "tidl_0") + gv = relay.GlobalVar("tidl_0") + + mod = tvm.IRModule() + mod[gv] = func + x_main = relay.var('x', shape=ishape, dtype='float32') + main_f = relay.Function([x_main], gv(x_main)) + mod['main'] = main_f + return mod + + def expected(): + ishape = (1, 3, 12, 12) + x = relay.var('tidl_i0', shape=ishape, dtype='float32') + out = relay.nn.relu(x) + func = relay.Function([x], out) + func = set_func_attr(func, "tidl", "tidl_0") + gv = relay.GlobalVar("tidl_0") + + mod = tvm.IRModule() + mod[gv] = func + x_main = relay.var('x', shape=ishape, dtype='float32') + call = gv(x_main) + out = relay.nn.relu(call) + main_f = relay.Function([x_main], out) + mod['main'] = main_f + return mod + + ref_mod = expected() + reduced = ReduceSubgraphSize(create_graph(), max_num_layers=1, compiler="tidl") + assert tvm.ir.structural_equal(reduced, ref_mod, map_free_vars=True) + +def test_reduce_subgraph_size_multiple_output(): + def create_graph(): + ishape = (1, 32, 14, 14) + w1shape = (32, 1, 3, 3) + dtype = "float32" + data0 = relay.var("tidl_0_i0", shape=(ishape), dtype=dtype) + input0 = relay.var("tidl_0_i1", shape=(w1shape), dtype=dtype) + input1 = relay.var("tidl_0_i2", shape=(w1shape), dtype=dtype) + params = {"tidl_0_i1": np.ones(w1shape, dtype="float32"), "tidl_0_i2": np.ones(w1shape, dtype="float32")} + depthwise_conv2d_1 = relay.nn.conv2d(data0, + input0, + kernel_size=(3, 3), + padding=(1, 1), + groups=32) + depthwise_conv2d_2 = relay.nn.conv2d(depthwise_conv2d_1, + input1, + kernel_size=(3, 3), + padding=(1, 1), + groups=32) + out = relay.add(depthwise_conv2d_1, depthwise_conv2d_2) + func = relay.Function([data0, input0, input1], out) + func = set_func_attr(func, "tidl", "tidl_0") + func = bind_params_by_name(func, params) + gv = relay.GlobalVar("tidl_0") + + mod = tvm.IRModule() + mod[gv] = func + x_main = relay.var('x', shape=ishape, dtype='float32') + main_f = relay.Function([x_main], gv(x_main)) + mod['main'] = main_f #bind_params_by_name(main_f, params) + return mod + + def expected_1(): + ishape = (1, 32, 14, 14) + w1shape = (32, 1, 3, 3) + dtype = "float32" + data0 = relay.var("tidl_0_i0", shape=(ishape), dtype=dtype) + input0 = relay.var("tidl_0_i1", shape=(w1shape), dtype=dtype) + input1 = relay.var("tidl_0_i2", shape=(w1shape), dtype=dtype) + params = {"tidl_0_i1": np.ones(w1shape, dtype="float32"), "tidl_0_i2": np.ones(w1shape, dtype="float32")} + depthwise_conv2d_1 = relay.nn.conv2d(data0, + input0, + kernel_size=(3, 3), + padding=(1, 1), + groups=32) + depthwise_conv2d_2 = relay.nn.conv2d(depthwise_conv2d_1, + input1, + kernel_size=(3, 3), + padding=(1, 1), + groups=32) + out = relay.Tuple([depthwise_conv2d_1, depthwise_conv2d_2]) + func = relay.Function([data0, input0, input1], out) + func = set_func_attr(func, "tidl", "tidl_0") + func = bind_params_by_name(func, params) + gv = relay.GlobalVar("tidl_0") + + mod = tvm.IRModule() + mod[gv] = func + x_main = relay.var('x', 
shape=ishape, dtype='float32') + call = gv(x_main) + get_output_0 = relay.TupleGetItem(call, 0) + get_output_1 = relay.TupleGetItem(call, 1) + out = relay.add(get_output_0, get_output_1) + main_f = relay.Function([x_main], out) + mod['main'] = bind_params_by_name(main_f, params) + return mod + + def expected_2(): + ishape = (1, 32, 14, 14) + w1shape = (32, 1, 3, 3) + dtype = "float32" + data0 = relay.var("tidl_0_i0", shape=(ishape), dtype=dtype) + input0 = relay.var("tidl_0_i1", shape=(w1shape), dtype=dtype) + input1 = relay.var("tidl_0_i2", shape=(w1shape), dtype=dtype) + params = {"tidl_0_i1": np.ones(w1shape, dtype="float32"), "tidl_0_i2": np.ones(w1shape, dtype="float32")} + depthwise_conv2d_1 = relay.nn.conv2d(data0, + input0, + kernel_size=(3, 3), + padding=(1, 1), + groups=32) + out = depthwise_conv2d_1 + func = relay.Function([data0, input0, input1], out) + func = set_func_attr(func, "tidl", "tidl_0") + func = bind_params_by_name(func, params) + gv = relay.GlobalVar("tidl_0") + + mod = tvm.IRModule() + mod[gv] = func + x_main = relay.var('x', shape=ishape, dtype='float32') + call = gv(x_main) + depthwise_conv2d_2 = relay.nn.conv2d(call, + input1, + kernel_size=(3, 3), + padding=(1, 1), + groups=32) + tup = relay.Tuple([call, depthwise_conv2d_2]) + get_output_0 = relay.TupleGetItem(tup, 0) + get_output_1 = relay.TupleGetItem(tup, 1) + out = relay.add(get_output_0, get_output_1) + main_f = relay.Function([x_main, input1], out) + mod['main'] = bind_params_by_name(main_f, params) + return mod + + # Will remove add. + ref_mod = expected_1() + reduced = ReduceSubgraphSize(create_graph(), max_num_layers=2, compiler="tidl") + assert tvm.ir.structural_equal(reduced, ref_mod, map_free_vars=True) + + # Will remove 2nd conv2d. + ref_mod = expected_2() + reduced = ReduceSubgraphSize(create_graph(), max_num_layers=1, compiler="tidl") + assert tvm.ir.structural_equal(reduced, ref_mod, map_free_vars=True) + +if __name__ == '__main__': + test_reduce_subgraph_size_single_output() + test_reduce_subgraph_size_multiple_output() From fe8b4536a2e008e8db607ce72c399eb0974835e6 Mon Sep 17 00:00:00 2001 From: Trevor Morris Date: Fri, 19 Jun 2020 20:20:19 +0000 Subject: [PATCH 2/4] Fix passing max num layers and max total mb --- python/tvm/relay/backend/contrib/tidl.py | 2 ++ 1 file changed, 2 insertions(+) diff --git a/python/tvm/relay/backend/contrib/tidl.py b/python/tvm/relay/backend/contrib/tidl.py index 05175c24b959..a2962dad8dc5 100755 --- a/python/tvm/relay/backend/contrib/tidl.py +++ b/python/tvm/relay/backend/contrib/tidl.py @@ -1529,6 +1529,8 @@ def __init__(self, platform, version, max_num_layers=256, max_total_memory_mb=51 if key in kwargs: setattr(self, key, kwargs[key]) self.tidl_target = "tidl" + self.max_num_layers = max_num_layers + self.max_total_memory_mb = max_total_memory_mb else: sys.exit("Unsupported TIDL platform or version!") From f432e214e59228675a3bd7f6d10bf03c73dcb2e9 Mon Sep 17 00:00:00 2001 From: Trevor Morris Date: Tue, 23 Jun 2020 21:43:54 +0000 Subject: [PATCH 3/4] Improve --- python/tvm/relay/backend/contrib/tidl.py | 119 ++++++++++-------- python/tvm/relay/op/contrib/tidl.py | 38 ++++-- .../relay/test_tidl_reduce_subgraph_size.py | 3 +- 3 files changed, 97 insertions(+), 63 deletions(-) diff --git a/python/tvm/relay/backend/contrib/tidl.py b/python/tvm/relay/backend/contrib/tidl.py index a2962dad8dc5..1fc71e605cb2 100755 --- a/python/tvm/relay/backend/contrib/tidl.py +++ b/python/tvm/relay/backend/contrib/tidl.py @@ -455,19 +455,19 @@ def get_total_memory_mb(self): return 
self.total_memory / (1024.0 * 1024.0) def visit_call(self, call): - # Don't visit composite function body. - for arg in call.args: - super().visit(arg) - self.num_layers += 1 - # Add total size of weights (16 bits per) - for arg in call.args: - if isinstance(arg, tvm.relay.expr.Constant): - self.total_memory += 2 * np.prod(list(map(int, arg.checked_type.shape))) - # Add activation size (8 bits per) - if isinstance(call.checked_type, tvm.relay.TensorType): - self.total_memory += np.prod(list(map(int, call.checked_type.shape))) - -def FindCommonAncestor(expr0, expr1): + super().visit_call(call) + # Don't count twice for composite op + if not isinstance(call.op, Function): + self.num_layers += 1 + # Add total size of weights (16 bits per) + for arg in call.args: + if isinstance(arg, tvm.relay.expr.Constant): + self.total_memory += 2 * np.prod(list(map(int, arg.checked_type.shape))) + # Add activation size (8 bits per) + if isinstance(call.checked_type, tvm.relay.TensorType): + self.total_memory += np.prod(list(map(int, call.checked_type.shape))) + +def FindCommonAncestor(expr): """ Find the closest common ancestor to expr0 and expr1. Returns distance from both. @@ -475,16 +475,23 @@ def FindCommonAncestor(expr0, expr1): """ class CommonAncestor(ExprVisitor): """ - Creates a map of node -> distance from expr + Creates a map of nodes -> distance from expr """ - def __init__(self): + def __init__(self, expr, ancestors_from_first_traversal=None): + """ + Parameters + ---------- + expr : tvm.relay.Expr + Output node + ancestors_from_first_traversal : Dict[tvm.relay.ir.expr, int] + CommonAncestor.ancestors_with_distance from previous traversal of a different + output of the same graph. Will be used to terminate traversal early to avoid + visiting nodes unnecessarily. + """ ExprVisitor.__init__(self) - self.ancestors_with_distance = {} - self.call_outputs = {} - - def Find(self, expr): - self.ancestors_with_distance[expr] = 0 - self.call_outputs[expr] = [] + self.ancestors_with_distance = {expr: 0} + self.call_outputs = {expr: []} + self.ancestors_from_first_traversal = ancestors_from_first_traversal super().visit(expr) def _update(self, expr, expr_inputs): @@ -495,30 +502,39 @@ def _update(self, expr, expr_inputs): self.call_outputs[arg] = [expr] if expr in self.call_outputs and len(self.call_outputs[expr]) > 0: - self.ancestors_with_distance[expr] = max([self.ancestors_with_distance[output] for output in self.call_outputs[expr]]) + 1 + self.ancestors_with_distance[expr] = \ + max([self.ancestors_with_distance[output] for output in self.call_outputs[expr]]) + 1 else: # Op did not have any outputs that we have already visited. self.ancestors_with_distance[expr] = 0 + def _terminate_early(self, node): + # Second traversal (from fields[1] can stop when it reaches any node already visited + # by first traversal). + return self.ancestors_from_first_traversal and \ + node in self.ancestors_from_first_traversal + def visit_tuple_getitem(self, tuplegetitem): self._update(tuplegetitem, [tuplegetitem.tuple_value]) - super().visit_tuple_getitem(tuplegetitem) + if not self._terminate_early(tuplegetitem): + super().visit_tuple_getitem(tuplegetitem) def visit_tuple(self, tup): self._update(tup, tup.fields) - super().visit_tuple(tup) + if not self._terminate_early(tup): + super().visit_tuple(tup) def visit_call(self, call): self._update(call, call.args) - # Don't visit function body - # We don't care what's inside composite functions, we will just remove the whole func. 
- for arg in call.args: - super().visit(arg) + if not self._terminate_early(call): + # Don't visit function body + # We don't care what's inside composite functions, we will just remove the whole func. + for arg in call.args: + super().visit(arg) - common0 = CommonAncestor() - common0.Find(expr0) - common1 = CommonAncestor() - common1.Find(expr1) + assert len(expr.fields) == 2, "Only subgraphs with 1 or 2 outputs are supported by ReduceSubgraphSize" + common0 = CommonAncestor(expr.fields[0]) + common1 = CommonAncestor(expr.fields[1], common0.ancestors_with_distance) # Find common first_common_ancestor = None distance_to_0 = 999999 @@ -529,6 +545,7 @@ def visit_call(self, call): first_common_ancestor = node distance_to_0 = common0.ancestors_with_distance[node] distance_to_1 = common1.ancestors_with_distance[node] + assert first_common_ancestor is not None return first_common_ancestor, distance_to_0, distance_to_1 class SubgraphReducer(ExprMutator): @@ -563,27 +580,26 @@ def visit_call(self, call): last_op_args = [] if isinstance(last_op, tvm.relay.expr.Tuple): # Subgraph has multiple outputs! - assert len(last_op.fields) == 2 - ancestor, dist0, dist1 = FindCommonAncestor(last_op.fields[0], last_op.fields[1]) - if dist0 == 0 and dist1 == 0: + ancestor, dist0, dist1 = FindCommonAncestor(last_op) + + def get_args(field, exclude): + """Gather args from field, excluding exclude node""" + args = [] + assert isinstance(field, tvm.relay.expr.Call) + for arg in field.args: + if arg != exclude: + args.append(arg) + return args + + # If all fields in tuple are not CallNodes, we will just remove all up to common ancestor. + if (dist0 == 0 and dist1 == 0): last_op_args = ancestor.args elif dist0 > dist1: - # field[0] is further from LCA. - # Remove it by replacing it with its args. - # Keep field[1] - last_op_args = [] - for arg in last_op.fields[0].args: - if arg != last_op.fields[1]: - last_op_args.append(arg) - last_op_args.append(last_op.fields[1]) + # field[0] is further from LCA, remove it by replacing it with its args. + last_op_args = get_args(last_op.fields[0], exclude=last_op.fields[1]) + [last_op.fields[1]] elif dist1 >= dist0: - # field[1] is further from LCA. - # Remove it by replacing it with its args. - # Keep field[0] - last_op_args = [last_op.fields[0]] - for arg in last_op.fields[1].args: - if arg != last_op_args[0]: - last_op_args.append(arg) + # field[1] is further from LCA, Remove it by replacing it with its args. + last_op_args = [last_op.fields[0]] + get_args(last_op.fields[1], exclude=last_op.fields[0]) elif isinstance(last_op, tvm.relay.expr.Call): last_op_args = last_op.args else: @@ -649,6 +665,7 @@ def ReduceSubgraphSize(mod, compiler="tidl", max_num_layers=256, max_total_memor while sanity_counter > 0: new_mod = tvm.IRModule() reducer = SubgraphReducer(mod, new_mod, max_num_layers, max_total_memory_mb) + # TODO(trevmorr): Models with Preclude not supported (multiple functions other than main). new_mod['main'] = reducer.visit(mod["main"]) # If no subgraphs where reduced in size, we are done. 
if not reducer.reduced: @@ -1523,7 +1540,7 @@ class TIDLCompiler: Folder to hold TIDL artifacts """ - def __init__(self, platform, version, max_num_layers=256, max_total_memory_mb=512, **kwargs): + def __init__(self, platform, version, max_num_layers=225, max_total_memory_mb=128, **kwargs): if platform == "AM57" and version >= (6,3): for key in ('num_tidl_subgraphs', 'data_layout', 'artifacts_folder', 'tidl_tools_path'): if key in kwargs: diff --git a/python/tvm/relay/op/contrib/tidl.py b/python/tvm/relay/op/contrib/tidl.py index 7712a1e765c4..d05f4d573662 100755 --- a/python/tvm/relay/op/contrib/tidl.py +++ b/python/tvm/relay/op/contrib/tidl.py @@ -135,6 +135,20 @@ def _dense_bias_pattern(): bias_out = is_op('nn.bias_add')(dense_out, is_constant()) return bias_out + def _bn_tuple_get_item(): + bn_out = is_op('nn.batch_norm')(wildcard(), is_constant(), is_constant(), is_constant(), is_constant()) + tuple_get_item_node = is_tuple_get_item(bn_out, 0) + return tuple_get_item_node + + def _bn_tuple_get_item_checker(extract): + bn_op = extract.tuple_value + data1 = infer_type(bn_op.args[1]) + if data1.checked_type.dtype != 'float32': + return False + elif bn_op.attrs.axis != 1 and bn_op.attrs.axis != 3: + return False + return True + pattern_table = [ ('tidl.squeeze_reshape', _squeeze_reshape_pattern()), ('tidl.transpose_reshape', _transpose_reshape_pattern()), @@ -153,6 +167,7 @@ def _dense_bias_pattern(): ('tidl.dense_relu', _dense_relu_pattern()), ('tidl.dense_bias_relu', _dense_bias_relu_pattern()), ('tidl.dense_bias', _dense_bias_pattern()), + ('tidl.bn_tuple_get_item', _bn_tuple_get_item(), _bn_tuple_get_item_checker), ] return relay.transform.MergeComposite(pattern_table)(mod) @@ -206,7 +221,12 @@ def _conv2d_bias_relu_whitelist_fn(attrs, args): @tvm.ir.register_op_attr("tidl.bn_relu", "target.tidl") def _bn_relu_whitelist_fn(attrs, args): bn_op = args[0] - return _batch_norm_whitelist_fn(bn_op.attrs, bn_op.args) + data1 = infer_type(bn_op.args[1]) + if data1.checked_type.dtype != 'float32': + return False + elif bn_op.attrs.axis != 1 and bn_op.attrs.axis != 3: + return False + return True @tvm.ir.register_op_attr("tidl.add_relu", "target.tidl") def _add_relu_whitelist_fn(attrs, args): @@ -240,6 +260,10 @@ def _conv2d_pad_whitelist_fn(attrs, args): conv2d_supported = _conv2d_whitelist_fn(conv2d_op.attrs, conv2d_op.args) return (pad_supported and conv2d_supported) +@tvm.ir.register_op_attr("tidl.bn_tuple_get_item", "target.tidl") +def _bn_tuple_get_item_whitelist_fn(attrs, args): + return True + @tvm.ir.register_op_attr("add", "target.tidl") def _add_whitelist_fn(attrs, args): supported = True @@ -273,16 +297,8 @@ def _batch_flatten_fn(attrs, args): @tvm.ir.register_op_attr("nn.batch_norm", "target.tidl") def _batch_norm_whitelist_fn(attrs, args): - #These are the relay arguments... look up the operator to get the actual name... - data1 = infer_type(args[1]) - supported = True - - if data1.checked_type.dtype != 'float32': - supported = False - elif attrs.axis != 1 and attrs.axis != 3: - supported = False - - return supported + # Standalone batch_norm is supported only as a pattern (bn_tuple_get_item). 
+ return False @tvm.ir.register_op_attr("nn.bias_add", "target.tidl") def _bias_add_whitelist_fn(attrs, args): diff --git a/tests/python/relay/test_tidl_reduce_subgraph_size.py b/tests/python/relay/test_tidl_reduce_subgraph_size.py index e9b3d42d9287..bf38763d53ef 100644 --- a/tests/python/relay/test_tidl_reduce_subgraph_size.py +++ b/tests/python/relay/test_tidl_reduce_subgraph_size.py @@ -171,8 +171,9 @@ def expected_2(): # Will remove 2nd conv2d. ref_mod = expected_2() reduced = ReduceSubgraphSize(create_graph(), max_num_layers=1, compiler="tidl") + print('reduced', reduced) assert tvm.ir.structural_equal(reduced, ref_mod, map_free_vars=True) if __name__ == '__main__': - test_reduce_subgraph_size_single_output() + #test_reduce_subgraph_size_single_output() test_reduce_subgraph_size_multiple_output() From 332b27c52c7ab697fdb156dbeaa9d7b51e000a16 Mon Sep 17 00:00:00 2001 From: Trevor Morris Date: Wed, 24 Jun 2020 19:38:54 +0000 Subject: [PATCH 4/4] Fix typo --- python/tvm/relay/backend/contrib/tidl.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/python/tvm/relay/backend/contrib/tidl.py b/python/tvm/relay/backend/contrib/tidl.py index 1fc71e605cb2..eb27edc4d522 100755 --- a/python/tvm/relay/backend/contrib/tidl.py +++ b/python/tvm/relay/backend/contrib/tidl.py @@ -1598,7 +1598,7 @@ def enable(self, mod_orig, params, input): mod = transform.MergeCompilerRegions()(mod) mod = transform.PartitionGraph()(mod) mod = PruneSubgraphsWithMoreThanOneInput(mod, compiler=self.tidl_target) - mod = ReduceSubgraphSize(mod, max_num_layers=max_num_layers, max_total_memory_mb=max_total_memory_mb, compiler=self.tidl_target) + mod = ReduceSubgraphSize(mod, max_num_layers=self.max_num_layers, max_total_memory_mb=self.max_total_memory_mb, compiler=self.tidl_target) mod = UnpackComposites(mod, compiler=self.tidl_target) mod = PruneSubgraphs(mod, compiler=self.tidl_target, num_subgraphs_to_keep=self.num_tidl_subgraphs)
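
A minimal sketch of exercising the new ReduceSubgraphSize pass on a hand-built partitioned module, mirroring test_reduce_subgraph_size_single_output above. The function attributes are set inline with the same with_attr calls that SubgraphReducer itself uses, standing in for the set_func_attr helper the real test imports from test_pass_partition_graph; the "tidl_0" naming follows the test.

    import tvm
    from tvm import relay
    from tvm.relay.backend.contrib.tidl import ReduceSubgraphSize

    # Build a two-layer "tidl_0" subgraph, as in test_reduce_subgraph_size_single_output.
    ishape = (1, 3, 12, 12)
    x = relay.var("tidl_0_i0", shape=ishape, dtype="float32")
    func = relay.Function([x], relay.nn.relu(relay.nn.relu(x)))
    # Same attributes that set_func_attr / SubgraphReducer use to mark an offloaded subgraph.
    func = func.with_attr("Primitive", tvm.tir.IntImm("int32", 1))
    func = func.with_attr("Inline", tvm.tir.IntImm("int32", 1))
    func = func.with_attr("Compiler", "tidl")
    func = func.with_attr("global_symbol", "tidl_0")

    mod = tvm.IRModule()
    gv = relay.GlobalVar("tidl_0")
    mod[gv] = func
    x_main = relay.var("x", shape=ishape, dtype="float32")
    mod["main"] = relay.Function([x_main], gv(x_main))

    # With max_num_layers=1 the outer relu is peeled back into main and the
    # tidl_0 subgraph is left with a single relu.
    reduced = ReduceSubgraphSize(mod, compiler="tidl", max_num_layers=1, max_total_memory_mb=512)
    print(reduced["main"])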
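
Patches 2-4 make the limits configurable from TIDLCompiler (defaults tightened to 225 layers / 128 MB in patch 3) and fix their plumbing into enable(). A minimal sketch of setting them explicitly, assuming an AM57 (6, 3) configuration like the gluon test; the artifacts folder and tools path below are placeholders.

    from tvm.relay.backend.contrib import tidl

    compiler = tidl.TIDLCompiler(
        "AM57", (6, 3),
        max_num_layers=225,
        max_total_memory_mb=128,
        num_tidl_subgraphs=4,
        data_layout="NCHW",
        artifacts_folder="./artifacts",          # placeholder
        tidl_tools_path="/path/to/tidl_tools")   # placeholder
    # compiler.enable(mod_orig, params, input) then applies PruneSubgraphsWithMoreThanOneInput,
    # ReduceSubgraphSize (with the limits above), UnpackComposites and PruneSubgraphs, in that
    # order, to the partitioned module.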