Merge pull request neo-ai#7 from trevor-m/subgraph-size
[TIDL] Add reduce subgraph size pass
jianzhong-xu authored Jun 24, 2020
2 parents 8404726 + 332b27c commit c249527
Showing 4 changed files with 588 additions and 18 deletions.
275 changes: 268 additions & 7 deletions python/tvm/relay/backend/contrib/tidl.py
@@ -222,6 +222,9 @@ def tensor_quant_flatten(input_tensor, data_layout):
return output, scale, sign

class VarReplacer(ExprMutator):
"""
Replaces vars in expr according to var_map.
"""
def __init__(self, var_map):
ExprMutator.__init__(self)
self.var_map = var_map
@@ -247,10 +250,7 @@ def visit_call(self, call):
return super().visit_call(call)

for func in mod.get_global_vars():
name = func.name_hint
if not mod[name].attrs or mod[name].attrs["Compiler"] != compiler:
continue
mod[name] = Unpacker().visit(mod[name])
mod[func.name_hint] = Unpacker().visit(mod[func.name_hint])
return mod

class CalibrationGraphMutator(ExprMutator):
@@ -366,18 +366,42 @@ def generate_subgraph_tensors(tidl_target, mod, params, input_node, input_data):

return subgraph_tensors

class ExprReplacer(ExprMutator):
"""
Replaces call nodes in expr according to call_map
"""
def __init__(self, call_map):
ExprMutator.__init__(self)
self.call_map = call_map

def visit_call(self, call):
if call in self.call_map:
return self.call_map[call]
return super().visit_call(call)

class VarRenamer(ExprMutator):
"""
Renames vars to match the new subgraph name. Used when subgraphs are renamed starting from zero.
If a subgraph was originally "tidl_34", it would have inputs named like "tidl_34_i0".
If new_subgraph_name is "tidl_0", this pass renames that input to "tidl_0_i0".
"""
def __init__(self, new_subgraph_name):
ExprMutator.__init__(self)
self.new_subgraph_name = new_subgraph_name

def visit_var(self, var):
if "_".join(var.name_hint.split('_')[:2]) != self.new_subgraph_name:
# TODO: Make sure input isn't from a composite func.
# TODO: Doesn't account for tuple inputs (not possible due to PruneSubgraphsWithMoreThanOneInput)
if var.name_hint.startswith("tidl") and "_".join(var.name_hint.split('_')[:2]) != self.new_subgraph_name:
new_var_name = self.new_subgraph_name + "_" + var.name_hint.split('_')[2]
return relay.Var(new_var_name, var.checked_type)
return super().visit_var(var)
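
For illustration only (not part of the diff), a minimal sketch of the renaming, assuming the neo-ai TVM fork that carries this file; the shape and the "tidl_34"/"tidl_0" names simply follow the docstring's example:

import tvm
from tvm import relay
from tvm.relay.backend.contrib.tidl import VarRenamer

# Stand-in for a partitioned subgraph whose input was originally named "tidl_34_i0".
v = relay.var("tidl_34_i0", shape=(1, 3, 224, 224))
mod = tvm.IRModule.from_expr(relay.Function([v], relay.nn.relu(v)))
mod = relay.transform.InferType()(mod)  # visit_var reads var.checked_type

renamed = VarRenamer("tidl_0").visit(mod["main"])
print(renamed.params[0].name_hint)  # -> "tidl_0_i0"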

class SubgraphRemover(ExprMutator):
"""
Removes subgraphs that are in the list subgraphs_to_remove and returns them
to regular TVM compilation in the main function.
"""
def __init__(self, subgraphs_to_remove, mod, new_mod, rename_starting_from_0=True):
ExprVisitor.__init__(self)
self.subgraphs_to_remove = subgraphs_to_remove
@@ -417,6 +441,240 @@ def visit_call(self, call):
return subgraph_gv(*args)
return super().visit_call(call)

class SubgraphSizeCounter(ExprVisitor):
"""
Pass to count the size of a subgraph: both the number of layers and the estimated
total memory usage. Used by the SubgraphReducer pass.
"""
def __init__(self):
ExprVisitor.__init__(self)
self.num_layers = 0
self.total_memory = 0

def get_total_memory_mb(self):
return self.total_memory / (1024.0 * 1024.0)

def visit_call(self, call):
super().visit_call(call)
# Don't count composite ops twice
if not isinstance(call.op, Function):
self.num_layers += 1
# Add total size of weights (16 bits per element)
for arg in call.args:
if isinstance(arg, tvm.relay.expr.Constant):
self.total_memory += 2 * np.prod(list(map(int, arg.checked_type.shape)))
# Add activation size (8 bits per element)
if isinstance(call.checked_type, tvm.relay.TensorType):
self.total_memory += np.prod(list(map(int, call.checked_type.shape)))
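
A back-of-the-envelope check of the estimate above, using an assumed conv2d layer (the shapes are illustrative, not taken from the diff): constant weights are counted at 2 bytes per element and the call's output at 1 byte per element.

import numpy as np

weight_shape = [64, 3, 7, 7]          # hypothetical Constant argument
output_shape = [1, 64, 112, 112]      # hypothetical call output (TensorType)

weight_bytes = 2 * np.prod(weight_shape)        # 18,816 bytes (16-bit weights)
activation_bytes = 1 * np.prod(output_shape)    # 802,816 bytes (8-bit activations)
total_mb = (weight_bytes + activation_bytes) / (1024.0 * 1024.0)
print(round(total_mb, 3))  # ~0.784 MB counted towards max_total_memory_mb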

def FindCommonAncestor(expr):
"""
Find the closest common ancestor of expr.fields[0] and expr.fields[1].
Returns the ancestor and its distance from each output.
Used by SubgraphReducer pass.
"""
class CommonAncestor(ExprVisitor):
"""
Creates a map of nodes -> distance from expr
"""
def __init__(self, expr, ancestors_from_first_traversal=None):
"""
Parameters
----------
expr : tvm.relay.Expr
Output node
ancestors_from_first_traversal : Dict[tvm.relay.Expr, int]
CommonAncestor.ancestors_with_distance from previous traversal of a different
output of the same graph. Will be used to terminate traversal early to avoid
visiting nodes unnecessarily.
"""
ExprVisitor.__init__(self)
self.ancestors_with_distance = {expr: 0}
self.call_outputs = {expr: []}
self.ancestors_from_first_traversal = ancestors_from_first_traversal
super().visit(expr)

def _update(self, expr, expr_inputs):
for arg in expr_inputs:
if arg in self.call_outputs and expr not in self.call_outputs[arg]:
self.call_outputs[arg].append(expr)
else:
self.call_outputs[arg] = [expr]

if expr in self.call_outputs and len(self.call_outputs[expr]) > 0:
self.ancestors_with_distance[expr] = \
max([self.ancestors_with_distance[output] for output in self.call_outputs[expr]]) + 1
else:
# Op did not have any outputs that we have already visited.
self.ancestors_with_distance[expr] = 0

def _terminate_early(self, node):
# The second traversal (from fields[1]) can stop when it reaches any node
# already visited by the first traversal.
return self.ancestors_from_first_traversal and \
node in self.ancestors_from_first_traversal

def visit_tuple_getitem(self, tuplegetitem):
self._update(tuplegetitem, [tuplegetitem.tuple_value])
if not self._terminate_early(tuplegetitem):
super().visit_tuple_getitem(tuplegetitem)

def visit_tuple(self, tup):
self._update(tup, tup.fields)
if not self._terminate_early(tup):
super().visit_tuple(tup)

def visit_call(self, call):
self._update(call, call.args)
if not self._terminate_early(call):
# Don't visit function body
# We don't care what's inside composite functions, we will just remove the whole func.
for arg in call.args:
super().visit(arg)

assert len(expr.fields) == 2, "Only subgraphs with 1 or 2 outputs are supported by ReduceSubgraphSize"
common0 = CommonAncestor(expr.fields[0])
common1 = CommonAncestor(expr.fields[1], common0.ancestors_with_distance)
# Find common
first_common_ancestor = None
distance_to_0 = 999999
distance_to_1 = 999999
for node in common0.ancestors_with_distance:
if node in common1.ancestors_with_distance:
if common0.ancestors_with_distance[node] <= distance_to_0 and common1.ancestors_with_distance[node] <= distance_to_1:
first_common_ancestor = node
distance_to_0 = common0.ancestors_with_distance[node]
distance_to_1 = common1.ancestors_with_distance[node]
assert first_common_ancestor is not None
return first_common_ancestor, distance_to_0, distance_to_1
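
To make the return values concrete, a sketch for an assumed two-output subgraph body Tuple([relu(conv), conv]) (the graph and shapes are illustrative, not from the diff); the conv2d call is the closest common ancestor, one hop from the first output and zero hops from the second:

from tvm import relay
from tvm.relay.backend.contrib.tidl import FindCommonAncestor

x = relay.var("x", shape=(1, 3, 224, 224))
w = relay.var("w", shape=(16, 3, 3, 3))
conv = relay.nn.conv2d(x, w)
body = relay.Tuple([relay.nn.relu(conv), conv])

ancestor, dist0, dist1 = FindCommonAncestor(body)
assert ancestor.same_as(conv)
print(dist0, dist1)  # -> 1 0; SubgraphReducer would then drop the relu (the farther output)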

class SubgraphReducer(ExprMutator):
"""
Removes a single op from the end of any subgraph that exceeds max_num_layers or
max_total_memory_mb. If an op is removed, self.reduced is set to True.
"""
def __init__(self, mod, new_mod, max_num_layers=256, max_total_memory_mb=512, compiler="tidl"):
ExprMutator.__init__(self)
self.mod = mod
self.new_mod = new_mod
self.max_num_layers = max_num_layers
self.max_total_memory_mb = max_total_memory_mb
self.compiler = compiler
self.reduced = False

def visit_call(self, call):
if isinstance(call.op, GlobalVar):
name = call.op.name_hint
if not self.mod[name].attrs or self.mod[name].attrs["Compiler"] != self.compiler:
return super().visit_call(call)
# Compute size of subgraph to see if we need to reduce it.
counter = SubgraphSizeCounter()
counter.visit(self.mod[name])
if counter.num_layers > self.max_num_layers or counter.get_total_memory_mb() > self.max_total_memory_mb:
# Mark that we have reduced the subgraph size.
self.reduced = True
# "Inline" the last op only back into new main function.
original_func = self.mod[name]
# Get last_op
last_op = original_func.body
last_op_args = []
if isinstance(last_op, tvm.relay.expr.Tuple):
# Subgraph has multiple outputs!
ancestor, dist0, dist1 = FindCommonAncestor(last_op)

def get_args(field, exclude):
"""Gather args from field, excluding exclude node"""
args = []
assert isinstance(field, tvm.relay.expr.Call)
for arg in field.args:
if arg != exclude:
args.append(arg)
return args

# Both outputs are the common ancestor itself (same node); remove it and use its args as the new outputs.
if (dist0 == 0 and dist1 == 0):
last_op_args = ancestor.args
elif dist0 > dist1:
# field[0] is further from LCA, remove it by replacing it with its args.
last_op_args = get_args(last_op.fields[0], exclude=last_op.fields[1]) + [last_op.fields[1]]
elif dist1 >= dist0:
# field[1] is further from LCA, remove it by replacing it with its args.
last_op_args = [last_op.fields[0]] + get_args(last_op.fields[1], exclude=last_op.fields[0])
elif isinstance(last_op, tvm.relay.expr.Call):
last_op_args = last_op.args
else:
raise ValueError("Input to last op is not call or tuple")
# Gather the new outputs of the subgraph from the removed op's inputs.
# This map maps each Expr to its index in the new_outputs tuple.
new_outputs = []
last_op_input_to_new_output_map = {}
if len(last_op_args) > 1:
for arg in last_op_args:
# Skip weights
if not isinstance(arg, tvm.relay.expr.Constant):
new_outputs.append(arg)
last_op_input_to_new_output_map[arg] = len(new_outputs) - 1
if len(new_outputs) > 1:
new_outputs_expr = relay.Tuple(new_outputs)
elif len(new_outputs) == 1:
new_outputs_expr = new_outputs[0]
else:
raise ValueError("No ops left in subgraph after reducing size")
else:
new_outputs = [last_op_args[0]]
new_outputs_expr = new_outputs[0]
subgraph_gv = relay.GlobalVar(name)

# construct new func without last_op
new_func = relay.Function(original_func.params, new_outputs_expr)
new_func = new_func.with_attr("Primitive", tvm.tir.IntImm("int32", 1))
new_func = new_func.with_attr("Inline", tvm.tir.IntImm("int32", 1))
new_func = new_func.with_attr("Compiler", self.compiler)
new_func = new_func.with_attr("global_symbol", name)
self.new_mod[subgraph_gv] = new_func
args = []
for arg in call.args:
args.append(super().visit(arg))
new_expr = subgraph_gv(*args)
if len(new_outputs) > 1:
call_map = {arg: relay.TupleGetItem(new_expr, index) for arg, index in last_op_input_to_new_output_map.items()}
else:
call_map = {new_outputs[0]: new_expr}
new_expr = ExprReplacer(call_map).visit(last_op)

return new_expr
elif name != "main":
# Transfer the subgraph to the new module without modifying it.
args = []
for arg in call.args:
args.append(super().visit(arg))
subgraph_gv = relay.GlobalVar(name)
self.new_mod[subgraph_gv] = self.mod[name]
return subgraph_gv(*args)
return super().visit_call(call)

def ReduceSubgraphSize(mod, compiler="tidl", max_num_layers=256, max_total_memory_mb=512):
"""
Reduces the size of each subgraph until it fits within the given limits.
"""
# Sanity counter to avoid an infinite loop.
sanity_counter = 10000
# SubgraphReducer removes one op from each subgraph that is above the limits.
# Call SubgraphReducer repeatedly until no subgraphs are reduced.
while sanity_counter > 0:
new_mod = tvm.IRModule()
reducer = SubgraphReducer(mod, new_mod, max_num_layers, max_total_memory_mb)
# TODO(trevmorr): Models with a Prelude are not supported (multiple functions other than main).
new_mod['main'] = reducer.visit(mod["main"])
# If no subgraphs were reduced in size, we are done.
if not reducer.reduced:
return new_mod
mod = new_mod
# Avoid infinite loop.
sanity_counter -= 1
return mod
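
A hedged usage sketch (the tiny module below is illustrative, not from the diff): on a module with no "tidl" subgraphs the pass returns it unchanged, while on a partitioned module it keeps trimming offloaded subgraphs until each fits the limits.

import tvm
from tvm import relay
from tvm.relay.backend.contrib.tidl import ReduceSubgraphSize

x = relay.var("x", shape=(1, 3, 224, 224))
w = relay.var("w", shape=(16, 3, 3, 3))
mod = tvm.IRModule.from_expr(relay.Function([x, w], relay.nn.conv2d(x, w)))

# In enable() further down, this runs right after PartitionGraph() and before
# UnpackComposites(); here there is nothing offloaded, so the module passes through.
mod = ReduceSubgraphSize(mod, compiler="tidl",
                         max_num_layers=225, max_total_memory_mb=128)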

def PruneSubgraphsWithMoreThanOneInput(mod, compiler="tidl"):
subgraph_names_to_remove = []
# Remove subgraphs with more than 1 input or tuple inputs.
@@ -1282,12 +1540,14 @@ class TIDLCompiler:
Folder to hold TIDL artifacts
"""

def __init__(self, platform, version, **kwargs):
def __init__(self, platform, version, max_num_layers=225, max_total_memory_mb=128, **kwargs):
if platform == "AM57" and version >= (6,3):
for key in ('num_tidl_subgraphs', 'data_layout', 'artifacts_folder', 'tidl_tools_path'):
if key in kwargs:
setattr(self, key, kwargs[key])
self.tidl_target = "tidl"
self.max_num_layers = max_num_layers
self.max_total_memory_mb = max_total_memory_mb
else:
sys.exit("Unsupported TIDL platform or version!")
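
For reference, a hedged sketch of how the two new limits are passed in; the folder path and the values for the pre-existing keyword arguments are placeholders, not taken from the diff:

from tvm.relay.backend.contrib.tidl import TIDLCompiler

tidl_compiler = TIDLCompiler(
    "AM57", (6, 3),                  # platform/version combination accepted above
    max_num_layers=225,              # new in this PR (default 225)
    max_total_memory_mb=128,         # new in this PR (default 128)
    num_tidl_subgraphs=1,            # placeholder values for the existing kwargs
    data_layout="NCHW",
    artifacts_folder="./artifacts",
    tidl_tools_path="/path/to/tidl_tools",
)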

@@ -1337,8 +1597,9 @@ def enable(self, mod_orig, params, input):
mod = transform.AnnotateTarget(self.tidl_target)(mod)
mod = transform.MergeCompilerRegions()(mod)
mod = transform.PartitionGraph()(mod)
mod = UnpackComposites(mod, compiler=self.tidl_target)
mod = PruneSubgraphsWithMoreThanOneInput(mod, compiler=self.tidl_target)
mod = ReduceSubgraphSize(mod, max_num_layers=self.max_num_layers, max_total_memory_mb=self.max_total_memory_mb, compiler=self.tidl_target)
mod = UnpackComposites(mod, compiler=self.tidl_target)
mod = PruneSubgraphs(mod, compiler=self.tidl_target, num_subgraphs_to_keep=self.num_tidl_subgraphs)

#============= Generate subgraph boundary tensors ==============