diff --git a/python/tvm/relay/backend/contrib/tidl_reduce_subgraph_size.py b/python/tvm/relay/backend/contrib/tidl_reduce_subgraph_size.py old mode 100644 new mode 100755 index a4cebab3d478..f8519effb59f --- a/python/tvm/relay/backend/contrib/tidl_reduce_subgraph_size.py +++ b/python/tvm/relay/backend/contrib/tidl_reduce_subgraph_size.py @@ -195,70 +195,13 @@ def visit_call(self, call): last_op_args = [] if isinstance(last_op, tvm.relay.expr.Tuple): # Subgraph has multiple outputs! - ancestor, distances = find_common_ancestor(last_op) - - def get_field(field): - """Get field as it is, unless it is a TupleGetItem which we will remove.""" - if isinstance(field, tvm.relay.expr.Call): - # Handle concat - if isinstance(field.args[0], tvm.relay.expr.Tuple): - args = [] - for f in field.args[0].fields: - args.append(f) - return args - return [field] - if isinstance(field, tvm.relay.expr.TupleGetItem): - args = [] - for arg in field.tuple_value.args: - args.append(arg) - return args - if isinstance(field, tvm.relay.expr.Tuple): - args = [] - for arg in field.fields: - args.append(arg) - return args - raise ValueError("New output of subgraph must be Call node.") - - def get_args(field): - """Gather args from field, excluding exclude node""" - args = [] - if isinstance(field, tvm.relay.expr.Call): - for arg in field.args: - # Handle concat - if isinstance(arg, tvm.relay.expr.Tuple): - for f in arg.fields: - args.append(f) - else: - args.append(arg) - elif isinstance(field, tvm.relay.expr.TupleGetItem): - for arg in field.tuple_value.args: - args.append(arg) - elif isinstance(field, tvm.relay.expr.Tuple): - for arg in field.fields: - args.append(arg) - else: - raise ValueError("New output of subgraph must be Call node.") - return args - - # All nodes come from same parent. - if all([dist == 0 for dist in distances]): - last_op_args = ancestor.args - else: - # Remove node with longest path - index_to_remove = np.argmax(distances) - # field[index_to_remove] is further from LCA, remove it - # by replacing it with its args. - last_op_args = [] - for i in range(0, len(last_op.fields)): - if i == index_to_remove: - last_op_args += get_args(last_op.fields[i]) - else: - last_op_args += get_field(last_op.fields[i]) - - # Remove duplicates. - seen = set() - seen_add = seen.add - last_op_args = [x for x in last_op_args if not (x in seen or seen_add(x))] + ancestor, _ = find_common_ancestor(last_op) + # Removing only op furthest from LCA greatly increase time taken for this pass. + # Instead, always delete all the way up to the lower common ancestor. This may + # cause more ops to be removed than is required, but it is much faster. + # TODO(trevmorr): Consider rewriting in C++ to improve speed. + # last_op_args = self._remove_op_furthest_from_lca(last_op, ancestor, distances) + last_op_args = ancestor.args elif isinstance(last_op, tvm.relay.expr.Call): last_op_args = last_op.args elif isinstance(last_op, tvm.relay.expr.TupleGetItem): @@ -268,8 +211,7 @@ def get_args(field): else: raise ValueError("Last op is not Call, Tuple, or TupleGetItem") # Gather new outputs of the subgraph - from removed op's inputs - # This map will map Expr to index in new_outputs tuple - #print('last_op_args', last_op_args) + # This map will map Expr to index in new_outputs tuplea new_outputs = [] last_op_input_to_new_output_map = {} if len(last_op_args) > 1: @@ -318,6 +260,73 @@ def get_args(field): return subgraph_gv(*args) return super().visit_call(call) + def _remove_op_furthest_from_lca(self, last_op, ancestor, distances): + """For subgraph with multiple outputs, pick output with logest path to least common + ancestor. Returns list of new outputs. + """ + def get_field(field): + """Get field as it is, unless it is a TupleGetItem which we will remove.""" + if isinstance(field, tvm.relay.expr.Call): + # Handle concat + if isinstance(field.args[0], tvm.relay.expr.Tuple): + args = [] + for f in field.args[0].fields: + args.append(f) + return args + return [field] + if isinstance(field, tvm.relay.expr.TupleGetItem): + args = [] + for arg in field.tuple_value.args: + args.append(arg) + return args + if isinstance(field, tvm.relay.expr.Tuple): + args = [] + for arg in field.fields: + args.append(arg) + return args + raise ValueError("New output of subgraph must be Call node.") + + def get_args(field): + """Gather args from field, excluding exclude node""" + args = [] + if isinstance(field, tvm.relay.expr.Call): + for arg in field.args: + # Handle concat + if isinstance(arg, tvm.relay.expr.Tuple): + for f in arg.fields: + args.append(f) + else: + args.append(arg) + elif isinstance(field, tvm.relay.expr.TupleGetItem): + for arg in field.tuple_value.args: + args.append(arg) + elif isinstance(field, tvm.relay.expr.Tuple): + for arg in field.fields: + args.append(arg) + else: + raise ValueError("New output of subgraph must be Call node.") + return args + + # All nodes come from same parent. + if all([dist == 0 for dist in distances]): + return ancestor.args + # Remove node with longest path + index_to_remove = np.argmax(distances) + # field[index_to_remove] is further from LCA, remove it + # by replacing it with its args. + last_op_args = [] + for i in range(0, len(last_op.fields)): + if i == index_to_remove: + last_op_args += get_args(last_op.fields[i]) + else: + last_op_args += get_field(last_op.fields[i]) + + # Remove duplicates. + seen = set() + seen_add = seen.add + last_op_args = [x for x in last_op_args if not (x in seen or seen_add(x))] + return last_op_args + def reduce_subgraph_size(mod, max_num_layers=256, max_total_memory_mb=512): """ Reduces size of subgraph to fit limitations.