Skip to content

Commit

Permalink
Address review comments
Browse files Browse the repository at this point in the history
Change-Id: Ia1755a0af7b6d159072d9f0c93c932c481101e48
  • Loading branch information
Giuseppe Rossini committed Jun 16, 2020
1 parent e323d88 commit c151f90
Show file tree
Hide file tree
Showing 2 changed files with 5 additions and 1 deletion.
1 change: 0 additions & 1 deletion python/tvm/relay/op/nn/_nn.py
Original file line number Diff line number Diff line change
Expand Up @@ -452,7 +452,6 @@ def compute_mirror_pad(attrs, inputs, out_dtype):
reg.register_pattern("nn.contrib_conv2d_gemm_without_weight_transform",
OpPattern.OUT_ELEMWISE_FUSABLE)


@reg.register_compute("nn.contrib_conv2d_gemm_weight_transform")
def compute_contrib_conv2d_gemm_weight_transform(attrs, inputs, out_dtype):
"""Compute definition of contrib_conv2d_gemm_weight_transform"""
Expand Down
5 changes: 5 additions & 0 deletions topi/python/topi/arm_cpu/conv2d_alter_op.py
Original file line number Diff line number Diff line change
Expand Up @@ -91,6 +91,8 @@ def _alter_conv2d_layout(attrs, inputs, tinfos, out_type):
assert (data.dtype == 'int8' and kernel.dtype == 'int8' or
data.dtype == 'uint8' and kernel.dtype == 'uint8')

assert data_layout == "NHWC" and kernel_layout == "HWIO"

data_expr, kernel_expr = inputs

data_int16 = relay.cast(data_expr, dtype='int16')
Expand Down Expand Up @@ -242,6 +244,9 @@ def _alter_conv2d_layout(attrs, inputs, tinfos, out_type):
if topi_tmpl == "conv2d_NHWC_quantized.arm_cpu":
assert (data.dtype == 'int8' and kernel.dtype == 'int8' or
data.dtype == 'uint8' and kernel.dtype == 'uint8')

assert data_layout == "NHWC" and kernel_layout == "HWIO"

CO, IC, KH, KW = get_const_tuple(kernel.shape)

K = KH * KW * IC
Expand Down

0 comments on commit c151f90

Please sign in to comment.