[Bugfix][Relay][Strategy] Enable compile time transformation of weights matrix for arm_cpu NHWC quantized conv2d

Fixed an arm_cpu strategy bug that was causing tensorization errors when using the `AlterOpLayout` pass with the quantized NHWC conv2d schedules, as discovered in apache#10724. As a result, `AlterOpLayout` can now also be enabled for these schedules, so the weight matrix is transformed at compile time instead of at runtime as before.
I also modified the padding in `Conv2DGemmWeightTransformRel` and `interleave_transpose_weights` to reflect the changes made in apache#13669 and updated the AlterOpLayout tests accordingly.
Anndrey24 committed Aug 17, 2023
1 parent f45ed30 commit 0741bdd
Showing 4 changed files with 19 additions and 19 deletions.
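To make the compile-time transform concrete, the following is a minimal NumPy sketch (not TVM code) of the interleaved weight layout that AlterOpLayout now bakes into the graph, using the new padding rule from this commit: N is padded to a multiple of tile_rows_B and K to a multiple of tile_cols_B * 4. The tile sizes 4 and 16 and the example kernel shape are assumptions for illustration only.

import numpy as np

def interleave_transpose(kernel_hwio, tile_rows_B=4, tile_cols_B=16):
    # Flatten the HWIO kernel into a (K, N) matrix, K = KH*KW*IC, N = OC.
    KH, KW, IC, OC = kernel_hwio.shape
    K, N = KH * KW * IC, OC
    flat = kernel_hwio.reshape(K, N)

    # Pad N to a multiple of tile_rows_B and K to a multiple of tile_cols_B * 4
    # (the factor of 4 mirrors apache#13669: tensorization consumes 4 column tiles at once).
    pad_N = -N % tile_rows_B
    pad_K = -K % (tile_cols_B * 4)
    padded = np.pad(flat, ((0, pad_K), (0, pad_N)))

    # Transpose to (N_padded, K_padded) and split into tile_rows_B x tile_cols_B tiles.
    t = padded.T
    return t.reshape(
        t.shape[0] // tile_rows_B, tile_rows_B, t.shape[1] // tile_cols_B, tile_cols_B
    ).transpose(0, 2, 1, 3)

# Hypothetical 3x3x64x80 HWIO kernel -> transformed shape (20, 36, 4, 16)
print(interleave_transpose(np.zeros((3, 3, 64, 80), "int8")).shape)

Because this work now happens during compilation, the per-inference cost of padding and reshuffling the weights disappears; the diffs below wire this up and fix the strategy so the transformed-weight schedules can actually be selected.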
python/tvm/relay/op/strategy/arm_cpu.py: 9 changes (7 additions & 2 deletions)
@@ -468,18 +468,23 @@ def conv2d_gemm_without_weight_transform_strategy_arm_cpu(attrs, inputs, out_typ
layout = attrs.data_layout
data = inputs[0]
strategy = _op.OpStrategy()
+is_aarch64 = target.features.is_aarch64
+has_asimd = target.features.has_asimd
+has_dot_prod = target.features.has_dotprod

interleaved_compute = topi.arm_cpu.compute_conv2d_NHWC_quantized_interleaved_without_transform
native_compute = topi.arm_cpu.compute_conv2d_NHWC_quantized_native_without_transform
if layout == "NHWC" and data.dtype in ["int8", "uint8"]:
-strategy.add_implementation(
+if has_dot_prod:
+strategy.add_implementation(
wrap_compute_conv2d_gemm(native_compute),
wrap_topi_schedule(
topi.arm_cpu.schedule_conv2d_NHWC_quantized_native_without_transform
),
name="conv2d_NHWC_quantized_native_without_transform.arm_cpu",
)
-strategy.add_implementation(
+if is_aarch64 and has_asimd:
+strategy.add_implementation(
wrap_compute_conv2d_gemm(interleaved_compute),
wrap_topi_schedule(
topi.arm_cpu.schedule_conv2d_NHWC_quantized_interleaved_without_transform
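The gating above relies on the target's CPU feature flags rather than on schedule internals. A small hedged example of how these flags look for an AArch64 target with the dot-product extension (the target string is illustrative and not taken from this commit):

import tvm

# Illustrative target string; any AArch64 target with +dotprod should behave the same way.
target = tvm.target.Target("llvm -mtriple=aarch64-linux-gnu -mattr=+v8.2a,+dotprod")
print(target.features.is_aarch64)   # expected True for an AArch64 triple
print(target.features.has_asimd)    # expected True, Advanced SIMD is available
print(target.features.has_dotprod)  # expected True, dot-product instructions enabled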
python/tvm/topi/arm_cpu/conv2d_alter_op.py: 23 changes (9 additions & 14 deletions)
@@ -77,8 +77,15 @@ def interleave_transpose_weights(inputs, data, kernel, interleave_A):

if N % tile_rows_B != 0:
pad_N = tile_rows_B - (N % tile_rows_B)
-if K % tile_cols_B != 0:
-pad_K = tile_cols_B - (K % tile_cols_B)
+
+# Tensorize will later make use of 4 tiles at once across the columns so make sure we pad such
+# that the columns is multiple of 4
+column_multiplier = 4
+tile_cols_multiplied = tile_cols_B * column_multiplier
+K_misalignment = K % tile_cols_multiplied
+
+if K_misalignment != 0:
+pad_K = tile_cols_multiplied - K_misalignment

N_padded = N + pad_N
K_padded = K + pad_K
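To make the new rule concrete with hypothetical numbers (K = 650 is illustrative, not from a real workload): the old code padded K up to the next multiple of tile_cols_B, while the new code pads up to the next multiple of tile_cols_B * 4.

K, tile_cols_B, column_multiplier = 650, 16, 4

# -K % m equals the amount added by the if-branches above (0 when already aligned).
old_pad_K = -K % tile_cols_B                        # 6  -> K_padded = 656
new_pad_K = -K % (tile_cols_B * column_multiplier)  # 54 -> K_padded = 704

print(K + old_pad_K, K + new_pad_K)  # 656 704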
@@ -434,12 +441,6 @@ def _alter_conv2d_layout(attrs, inputs, tinfos, out_type):
return relay.nn.contrib_conv2d_nchwc(*inputs, **new_attrs)

if topi_tmpl == "conv2d_NHWC_quantized_interleaved.arm_cpu":
-# TODO(masahi): This schedule can easily result in a tensorization error
-# if used in the fallback mode
-if cfg.is_fallback: # if is fallback, clear query cache and return None
-autotvm.task.clear_fallback_cache(target, workload)
-return None
-
assert data_layout == "NHWC" and kernel_layout == "HWIO"
KH, KW, _, OC = get_const_tuple(kernel.shape)
new_workload_name = "conv2d_NHWC_quantized_interleaved_without_transform.arm_cpu"
@@ -456,12 +457,6 @@ def _alter_conv2d_layout(attrs, inputs, tinfos, out_type):
inputs[0], new_kernel_expr, **new_attrs
)
if topi_tmpl == "conv2d_NHWC_quantized_native.arm_cpu":
-# TODO(masahi): This schedule can easily result in a tensorization error
-# if used in the fallback mode
-if cfg.is_fallback: # if is fallback, clear query cache and return None
-autotvm.task.clear_fallback_cache(target, workload)
-return None
-
assert data_layout == "NHWC" and kernel_layout == "HWIO"
KH, KW, _, OC = get_const_tuple(kernel.shape)
new_workload_name = "conv2d_NHWC_quantized_native_without_transform.arm_cpu"
src/relay/op/nn/convolution.cc: 4 changes (2 additions & 2 deletions)
@@ -1510,10 +1510,10 @@ bool Conv2DGemmWeightTransformRel(const Array<Type>& types, int num_inputs, cons
const auto K = weight->shape[0] * weight->shape[1] * weight->shape[2];
const auto N = weight->shape[3];

-auto K_mod_k = indexmod(K, k);
+auto K_mod_k = indexmod(K, k*4);
auto N_mod_n = indexmod(N, n);

-auto pad_K = tvm::if_then_else(K_mod_k != 0, k - K_mod_k, tir::make_zero(DataType::Int(32)));
+auto pad_K = tvm::if_then_else(K_mod_k != 0, k*4 - K_mod_k, tir::make_zero(DataType::Int(32)));
auto pad_N = tvm::if_then_else(N_mod_n != 0, n - N_mod_n, tir::make_zero(DataType::Int(32)));

const auto N_padded = N + pad_N;
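This relation has to report the same padded shape that the Python-side transform produces, otherwise type inference fails on the graph rewritten by AlterOpLayout. A hedged sketch, assuming the transformed weight layout is (N_padded // n, K_padded // k, n, k) with n = 4 and k = 16 (an assumption, but it matches the expected shape in the updated test below):

def transformed_weight_shape(K, N, n=4, k=16, column_multiplier=4):
    # Pad K to a multiple of k * column_multiplier and N to a multiple of n,
    # then tile into n x k blocks; the K and N values below are illustrative.
    K_padded = K + (-K % (k * column_multiplier))
    N_padded = N + (-N % n)
    return (N_padded // n, K_padded // k, n, k)

print(transformed_weight_shape(K=672, N=80))  # (20, 44, 4, 16)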
tests/python/relay/test_pass_alter_op_layout.py: 2 changes (1 addition & 1 deletion)
@@ -1230,7 +1230,7 @@ def test_alter_layout_nhwc_int8_aarch64():
"""Check that AlterOplayout does not alter NHWC data layout."""
from tvm import autotvm

-expected_workload_shape = (20, 42, 4, 16)
+expected_workload_shape = (20, 44, 4, 16)

# We use Int8Fallback to disable the fallback flag
# and to test the new workload produced during the pass
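The expected shape changes because the second axis of the transformed weight, multiplied by the tile width of 16, is K_padded: 42 * 16 = 672 is a multiple of 16 but not of 16 * 4 = 64, so under the new padding rule the same kernel now pads up to 704 = 44 * 16. The tile sizes here are inferred from the shape itself rather than stated in the test, so treat them as assumptions.

# Old expectation: K padded to a multiple of 16 only.
assert 42 * 16 == 672 and 672 % 64 != 0
# New expectation: K padded to a multiple of 16 * 4 = 64.
assert 44 * 16 == 704 and 704 % 64 == 0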
