Skip to content

Commit

Permalink
[TOPI] bitserial_conv2d move to autotvm template and updates (apache#…
Browse files Browse the repository at this point in the history
  • Loading branch information
Meghan Cowan authored and wweic committed Apr 7, 2019
1 parent af17a02 commit ce7094f
Show file tree
Hide file tree
Showing 7 changed files with 609 additions and 572 deletions.
2 changes: 1 addition & 1 deletion python/tvm/autotvm/task/task.py
Original file line number Diff line number Diff line change
Expand Up @@ -205,7 +205,7 @@ def args_to_workload(x, topi_compute_func=None):
workload = tuple([args_to_workload(a) for a in x])
elif isinstance(x, (str, int, float, np.int, np.float)):
workload = x
elif isinstance(x, (expr.StringImm, expr.IntImm, expr.FloatImm)):
elif isinstance(x, (expr.StringImm, expr.UIntImm, expr.IntImm, expr.FloatImm)):
workload = x.value
elif x is None:
workload = 0
Expand Down
22 changes: 22 additions & 0 deletions python/tvm/autotvm/task/topi_integration.py
Original file line number Diff line number Diff line change
Expand Up @@ -68,6 +68,8 @@ def __init__(self):
topi.nn.group_conv2d_nchw: "topi_nn_group_conv2d_nchw",
topi.nn.conv2d_transpose_nchw: "topi_nn_conv2d_transpose_nchw",
topi.nn.dense: "topi_nn_dense",
topi.nn.bitserial_conv2d_nchw: "topi_nn_bitserial_conv2d_nchw",
topi.nn.bitserial_conv2d_nhwc: "topi_nn_bitserial_conv2d_nhwc",
topi.nn.deformable_conv2d_nchw: "topi_nn_deformable_conv2d_nchw",
}

Expand All @@ -79,6 +81,8 @@ def __init__(self):
topi.nn.group_conv2d_nchw: [topi.generic.schedule_group_conv2d_nchw],
topi.nn.conv2d_transpose_nchw: [topi.generic.schedule_conv2d_transpose_nchw],
topi.nn.dense: [topi.generic.schedule_dense],
topi.nn.bitserial_conv2d_nchw: [topi.generic.schedule_bitserial_conv2d_nchw],
topi.nn.bitserial_conv2d_nhwc: [topi.generic.schedule_bitserial_conv2d_nhwc],
topi.nn.deformable_conv2d_nchw: [topi.generic.schedule_deformable_conv2d_nchw],
}

Expand Down Expand Up @@ -174,6 +178,24 @@ def _topi_nn_dense(*args, **kwargs):
return s, [data, weight, bias, C]
return s, [data, weight, C]

@register("topi_nn_bitserial_conv2d_nhwc")
def _topi_bitserial_conv2d_nhwc(*args, **kwargs):
args = deserialize_args(args)
C = topi.nn.bitserial_conv2d_nhwc(*args, **kwargs)
s = topi.generic.nn.schedule_bitserial_conv2d_nhwc([C])
data = args[0]
kernel = args[1]
return s, [data, kernel, C]

@register("topi_nn_bitserial_conv2d_nchw")
def _topi_bitserial_conv2d_nchw(*args, **kwargs):
args = deserialize_args(args)
C = topi.nn.bitserial_conv2d_nchw(*args, **kwargs)
s = topi.generic.nn.schedule_bitserial_conv2d_nchw([C])
data = args[0]
kernel = args[1]
return s, [data, kernel, C]

@register("topi_nn_deformable_conv2d_nchw")
def _topi_nn_deformable_conv2d_nchw(*args, **kwargs):
assert not kwargs, "Do not support kwargs in template function call"
Expand Down
Loading

0 comments on commit ce7094f

Please sign in to comment.