diff --git a/vta/python/vta/top/graphpack.py b/vta/python/vta/top/graphpack.py index a4c054880ac2..ba139a8b5ace 100644 --- a/vta/python/vta/top/graphpack.py +++ b/vta/python/vta/top/graphpack.py @@ -110,6 +110,15 @@ def _get_shape(node): """ return _to_shape(node.checked_type.shape) +def _operator_idx_inc(expr, count_meta, operator_current_idx): + """Increase operator index + """ + if isinstance(expr, relay.expr.Constant): + operator_current_idx = operator_current_idx + 1 if count_meta else operator_current_idx + else: + operator_current_idx = operator_current_idx + 1 + return operator_current_idx + class ExprPack(ExprMutator): """Visitor to perform graph packing on an AST. """ @@ -246,7 +255,7 @@ def visit_call(self, call): class BT(Exception): pass -def get_subgraph(expr, start_name, stop_name): +def get_subgraph(expr, start_name, stop_name, start_name_idx, stop_name_idx, count_meta): """ We assume stop_name only appears once for simplicity. This constraint will be lifted in the future. bitpack_start and bitpack_end are both inclusive. @@ -254,24 +263,32 @@ def get_subgraph(expr, start_name, stop_name): bitpack_start = op.op.get('annotation.bitpack_start') bitpack_end = op.op.get('annotation.bitpack_end') anf = run_opt_pass(expr, transform.ToANormalForm()) - def _recursion(anf, start_found, stop_found): + operator_current_idx = 0 + def _recursion(anf, start_found, stop_found, operator_current_idx): """ Helper to obtain the subgraph. """ if isinstance(anf, relay.expr.Function): return relay.expr.Function(anf.params, - _recursion(anf.body, start_found, stop_found), + _recursion(anf.body, start_found, stop_found, + operator_current_idx), anf.ret_type, anf.type_params, anf.attrs) elif isinstance(anf, relay.expr.Let): value = anf.value if isinstance(value, relay.expr.Call): if isinstance(value.op, relay.op.Op): if value.op.name == start_name and not start_found: - value = relay.expr.Call(bitpack_start, [value]) - start_found = True + if operator_current_idx == start_name_idx or start_name_idx is None: + value = relay.expr.Call(bitpack_start, [value]) + start_found = True elif value.op.name == stop_name: - raise BT() + if operator_current_idx == stop_name_idx or stop_name_idx is None: + raise BT() + + operator_current_idx = _operator_idx_inc(value, count_meta, operator_current_idx) + try: - return relay.expr.Let(anf.var, value, _recursion(anf.body, start_found, stop_found)) + return relay.expr.Let(anf.var, value, _recursion(anf.body, start_found, stop_found, + operator_current_idx)) except BT: assert start_found assert not stop_found @@ -283,7 +300,7 @@ def _recursion(anf, start_found, stop_found): assert start_found assert stop_found return anf - annotated = _recursion(anf, False, False) + annotated = _recursion(anf, False, False, operator_current_idx) return run_opt_pass(annotated, transform.ToGraphNormalForm()) def graph_pack(expr, @@ -291,7 +308,10 @@ def graph_pack(expr, cfactor, weight_bits, start_name="nn.max_pool2d", - stop_name="nn.global_avg_pool2d"): + stop_name="nn.global_avg_pool2d", + start_name_idx=None, + stop_name_idx=None, + count_meta=False): """Pack the graph into batch&channel packed format. Parameters @@ -309,10 +329,24 @@ def graph_pack(expr, The bit-width of the weights. start_name: str, optional - Start packing from certain known node. + Start packing from certain known node when start_name_idx is None. stop_name: str, optional - Stop packing from certain known node. + Stop packing from certain known node when stop_name_idx is None. + + start_name_idx: int, optional + When start_name_idx not None, start packing only when node name equal start_name + and node idx equals start_name_idx. + + stop_name_idx: int, optional + When stop_name_idx not None, stop packing only when node name equal stop_name + and node index equals stop_name_idx. + + count_meta:boolean, optional + When count_meta is False, the operator increase logic would not count the meta that have + the type 'relay.expr.Constant', start_name_idx and stop_name_idx follow the index from + 'expr.astext(show_meta_data=False)'. When count_meta is True, the operator increase + logic would count the meta. Returns ------- @@ -320,7 +354,8 @@ def graph_pack(expr, The transformed expression. """ assert isinstance(expr, relay.Function) - expr = get_subgraph(expr, start_name, stop_name) + assert ((start_name != stop_name) or (start_name_idx < stop_name_idx)) + expr = get_subgraph(expr, start_name, stop_name, start_name_idx, stop_name_idx, count_meta) expr = run_opt_pass(expr, transform.InferType()) packer = ExprPack( bfactor, cfactor,