Skip to content

Commit

Permalink
[VTA] Support network which have no unique operator as start/stop nam…
Browse files Browse the repository at this point in the history
…e for graph pack. (apache#4703)

* [VTA] Support network which have no unique operator as start/stop name
for graph pack.

[Issue]
  Current vta use 'start' and 'stop' name to define the pack start point
  and end point, but this method not work for these network which have
  no 2 unique operator as  start point and stop point.

[Solution]
  In this solution we give 2 addtional parameters start_name_indx and
  stop_name_indx to make vta pack logic work with the said network,
  for exampl for following networks which have no unique operator,

  %0 = nn.add
  %1 = nn.conv2d
  %2 = nn.batch_norm
  %3 = nn.leaky_relu
  %4 = nn.add
  %5 = nn.conv2d
  %6 = nn.batch_norm
  %7 = nn.leaky_relu
  %8 = nn.add

  with this solution we can use following parameter format to make
  vta work on it.

  relay_prog = graph_pack(
                //....
                start_name="nn.add",
                stop_name="nn.add",
                start_name_idx=0,
                stop_name_idx=4)

  to apply on new network, by printing the network we can get index information like following.

  print(mod.astext(show_meta_data=False))
  relay_prog = graph_pack(mod
                          ...
                          start_name="nn.add",
                          stop_name="nn.add",
                          start_name_idx=0,
                          stop_name_idx=4)

* address review comments and fix index count bug

issue:
when do print(mod), the output not only the Call is also have other type
like Var, need add logic to count all except meta.

solution:
add related logic

* address review comments.

* address review comments

* add more detail comments.
  • Loading branch information
huajsj authored and zhiics committed Mar 2, 2020
1 parent c37198f commit 9f29fba
Showing 1 changed file with 47 additions and 12 deletions.
59 changes: 47 additions & 12 deletions vta/python/vta/top/graphpack.py
Original file line number Diff line number Diff line change
Expand Up @@ -110,6 +110,15 @@ def _get_shape(node):
"""
return _to_shape(node.checked_type.shape)

def _operator_idx_inc(expr, count_meta, operator_current_idx):
"""Increase operator index
"""
if isinstance(expr, relay.expr.Constant):
operator_current_idx = operator_current_idx + 1 if count_meta else operator_current_idx
else:
operator_current_idx = operator_current_idx + 1
return operator_current_idx

class ExprPack(ExprMutator):
"""Visitor to perform graph packing on an AST.
"""
Expand Down Expand Up @@ -246,32 +255,40 @@ def visit_call(self, call):

class BT(Exception):
pass
def get_subgraph(expr, start_name, stop_name):
def get_subgraph(expr, start_name, stop_name, start_name_idx, stop_name_idx, count_meta):
""" We assume stop_name only appears once for simplicity.
This constraint will be lifted in the future.
bitpack_start and bitpack_end are both inclusive.
"""
bitpack_start = op.op.get('annotation.bitpack_start')
bitpack_end = op.op.get('annotation.bitpack_end')
anf = run_opt_pass(expr, transform.ToANormalForm())
def _recursion(anf, start_found, stop_found):
operator_current_idx = 0
def _recursion(anf, start_found, stop_found, operator_current_idx):
""" Helper to obtain the subgraph.
"""
if isinstance(anf, relay.expr.Function):
return relay.expr.Function(anf.params,
_recursion(anf.body, start_found, stop_found),
_recursion(anf.body, start_found, stop_found,
operator_current_idx),
anf.ret_type, anf.type_params, anf.attrs)
elif isinstance(anf, relay.expr.Let):
value = anf.value
if isinstance(value, relay.expr.Call):
if isinstance(value.op, relay.op.Op):
if value.op.name == start_name and not start_found:
value = relay.expr.Call(bitpack_start, [value])
start_found = True
if operator_current_idx == start_name_idx or start_name_idx is None:
value = relay.expr.Call(bitpack_start, [value])
start_found = True
elif value.op.name == stop_name:
raise BT()
if operator_current_idx == stop_name_idx or stop_name_idx is None:
raise BT()

operator_current_idx = _operator_idx_inc(value, count_meta, operator_current_idx)

try:
return relay.expr.Let(anf.var, value, _recursion(anf.body, start_found, stop_found))
return relay.expr.Let(anf.var, value, _recursion(anf.body, start_found, stop_found,
operator_current_idx))
except BT:
assert start_found
assert not stop_found
Expand All @@ -283,15 +300,18 @@ def _recursion(anf, start_found, stop_found):
assert start_found
assert stop_found
return anf
annotated = _recursion(anf, False, False)
annotated = _recursion(anf, False, False, operator_current_idx)
return run_opt_pass(annotated, transform.ToGraphNormalForm())

def graph_pack(expr,
bfactor,
cfactor,
weight_bits,
start_name="nn.max_pool2d",
stop_name="nn.global_avg_pool2d"):
stop_name="nn.global_avg_pool2d",
start_name_idx=None,
stop_name_idx=None,
count_meta=False):
"""Pack the graph into batch&channel packed format.
Parameters
Expand All @@ -309,18 +329,33 @@ def graph_pack(expr,
The bit-width of the weights.
start_name: str, optional
Start packing from certain known node.
Start packing from certain known node when start_name_idx is None.
stop_name: str, optional
Stop packing from certain known node.
Stop packing from certain known node when stop_name_idx is None.
start_name_idx: int, optional
When start_name_idx not None, start packing only when node name equal start_name
and node idx equals start_name_idx.
stop_name_idx: int, optional
When stop_name_idx not None, stop packing only when node name equal stop_name
and node index equals stop_name_idx.
count_meta:boolean, optional
When count_meta is False, the operator increase logic would not count the meta that have
the type 'relay.expr.Constant', start_name_idx and stop_name_idx follow the index from
'expr.astext(show_meta_data=False)'. When count_meta is True, the operator increase
logic would count the meta.
Returns
-------
expr : Expr
The transformed expression.
"""
assert isinstance(expr, relay.Function)
expr = get_subgraph(expr, start_name, stop_name)
assert ((start_name != stop_name) or (start_name_idx < stop_name_idx))
expr = get_subgraph(expr, start_name, stop_name, start_name_idx, stop_name_idx, count_meta)
expr = run_opt_pass(expr, transform.InferType())
packer = ExprPack(
bfactor, cfactor,
Expand Down

0 comments on commit 9f29fba

Please sign in to comment.