From d0c94d447ba438966c09a70a454f3ecf22fa5f55 Mon Sep 17 00:00:00 2001
From: Andrei Hutu
Date: Wed, 23 Aug 2023 09:12:27 +0100
Subject: [PATCH] [Bugfix][Relay][Strategy] Enable compile time transformation
 of weights matrix for arm_cpu NHWC quantized conv2d (#15584)

Fixed an arm_cpu strategy bug that caused tensorization errors when the
`AlterOpLayout` pass was used with the quantized NHWC conv2d schedules, as
discovered in #10724. With the bug fixed, `AlterOpLayout` can now be enabled
for these schedules, so the weight matrix is transformed at compile time
instead of at runtime as before.

I also modified the padding in `Conv2DGemmWeightTransformRel` and
`interleave_transpose_weights` to reflect the changes made in #13669, and
updated the `AlterOpLayout` tests accordingly.
---
 python/tvm/relay/op/strategy/arm_cpu.py       | 33 +++++++++++--------
 python/tvm/topi/arm_cpu/conv2d_alter_op.py    | 23 +++++--------
 src/relay/op/nn/convolution.cc                |  4 +--
 .../strategy/test_select_implementation.py    | 20 +++++++++--
 .../python/relay/test_pass_alter_op_layout.py |  2 +-
 5 files changed, 48 insertions(+), 34 deletions(-)

diff --git a/python/tvm/relay/op/strategy/arm_cpu.py b/python/tvm/relay/op/strategy/arm_cpu.py
index f81354466664..b64c541863f7 100644
--- a/python/tvm/relay/op/strategy/arm_cpu.py
+++ b/python/tvm/relay/op/strategy/arm_cpu.py
@@ -468,24 +468,29 @@ def conv2d_gemm_without_weight_transform_strategy_arm_cpu(attrs, inputs, out_typ
     layout = attrs.data_layout
     data = inputs[0]
     strategy = _op.OpStrategy()
+    is_aarch64 = target.features.is_aarch64
+    has_asimd = target.features.has_asimd
+    has_dot_prod = target.features.has_dotprod
     interleaved_compute = topi.arm_cpu.compute_conv2d_NHWC_quantized_interleaved_without_transform
     native_compute = topi.arm_cpu.compute_conv2d_NHWC_quantized_native_without_transform
     if layout == "NHWC" and data.dtype in ["int8", "uint8"]:
-        strategy.add_implementation(
-            wrap_compute_conv2d_gemm(native_compute),
-            wrap_topi_schedule(
-                topi.arm_cpu.schedule_conv2d_NHWC_quantized_native_without_transform
-            ),
-            name="conv2d_NHWC_quantized_native_without_transform.arm_cpu",
-        )
-        strategy.add_implementation(
-            wrap_compute_conv2d_gemm(interleaved_compute),
-            wrap_topi_schedule(
-                topi.arm_cpu.schedule_conv2d_NHWC_quantized_interleaved_without_transform
-            ),
-            name="conv2d_NHWC_quantized_interleaved_without_transform.arm_cpu",
-        )
+        if has_dot_prod:
+            strategy.add_implementation(
+                wrap_compute_conv2d_gemm(native_compute),
+                wrap_topi_schedule(
+                    topi.arm_cpu.schedule_conv2d_NHWC_quantized_native_without_transform
+                ),
+                name="conv2d_NHWC_quantized_native_without_transform.arm_cpu",
+            )
+        if is_aarch64 and has_asimd:
+            strategy.add_implementation(
+                wrap_compute_conv2d_gemm(interleaved_compute),
+                wrap_topi_schedule(
+                    topi.arm_cpu.schedule_conv2d_NHWC_quantized_interleaved_without_transform
+                ),
+                name="conv2d_NHWC_quantized_interleaved_without_transform.arm_cpu",
+            )
     else:
         raise RuntimeError(
             f"Unsupported conv2d_NHWC_quantized_without_transform layout {layout}"
         )
diff --git a/python/tvm/topi/arm_cpu/conv2d_alter_op.py b/python/tvm/topi/arm_cpu/conv2d_alter_op.py
index b0fdb99cbe33..8ddb591397e4 100644
--- a/python/tvm/topi/arm_cpu/conv2d_alter_op.py
+++ b/python/tvm/topi/arm_cpu/conv2d_alter_op.py
@@ -77,8 +77,15 @@ def interleave_transpose_weights(inputs, data, kernel, interleave_A):
 
     if N % tile_rows_B != 0:
         pad_N = tile_rows_B - (N % tile_rows_B)
-    if K % tile_cols_B != 0:
-        pad_K = tile_cols_B - (K % tile_cols_B)
+
+    # Tensorize will later make use of 4 tiles at once across the columns, so make sure we pad such
+    # that the number of columns is a multiple of 4
+    column_multiplier = 4
+    tile_cols_multiplied = tile_cols_B * column_multiplier
+    K_misalignment = K % tile_cols_multiplied
+
+    if K_misalignment != 0:
+        pad_K = tile_cols_multiplied - K_misalignment
 
     N_padded = N + pad_N
     K_padded = K + pad_K
@@ -434,12 +441,6 @@ def _alter_conv2d_layout(attrs, inputs, tinfos, out_type):
         return relay.nn.contrib_conv2d_nchwc(*inputs, **new_attrs)
 
     if topi_tmpl == "conv2d_NHWC_quantized_interleaved.arm_cpu":
-        # TODO(masahi): This schedule can easily result in a tensorization error
-        # if used in the fallback mode
-        if cfg.is_fallback:  # if is fallback, clear query cache and return None
-            autotvm.task.clear_fallback_cache(target, workload)
-            return None
-
         assert data_layout == "NHWC" and kernel_layout == "HWIO"
         KH, KW, _, OC = get_const_tuple(kernel.shape)
         new_workload_name = "conv2d_NHWC_quantized_interleaved_without_transform.arm_cpu"
@@ -456,12 +457,6 @@ def _alter_conv2d_layout(attrs, inputs, tinfos, out_type):
             inputs[0], new_kernel_expr, **new_attrs
         )
     if topi_tmpl == "conv2d_NHWC_quantized_native.arm_cpu":
-        # TODO(masahi): This schedule can easily result in a tensorization error
-        # if used in the fallback mode
-        if cfg.is_fallback:  # if is fallback, clear query cache and return None
-            autotvm.task.clear_fallback_cache(target, workload)
-            return None
-
         assert data_layout == "NHWC" and kernel_layout == "HWIO"
         KH, KW, _, OC = get_const_tuple(kernel.shape)
         new_workload_name = "conv2d_NHWC_quantized_native_without_transform.arm_cpu"
diff --git a/src/relay/op/nn/convolution.cc b/src/relay/op/nn/convolution.cc
index e44d03833e52..13c7f74c7ecd 100644
--- a/src/relay/op/nn/convolution.cc
+++ b/src/relay/op/nn/convolution.cc
@@ -1510,10 +1510,10 @@ bool Conv2DGemmWeightTransformRel(const Array<Type>& types, int num_inputs, cons
   const auto K = weight->shape[0] * weight->shape[1] * weight->shape[2];
   const auto N = weight->shape[3];
 
-  auto K_mod_k = indexmod(K, k);
+  auto K_mod_k = indexmod(K, k * 4);
   auto N_mod_n = indexmod(N, n);
 
-  auto pad_K = tvm::if_then_else(K_mod_k != 0, k - K_mod_k, tir::make_zero(DataType::Int(32)));
+  auto pad_K = tvm::if_then_else(K_mod_k != 0, k * 4 - K_mod_k, tir::make_zero(DataType::Int(32)));
   auto pad_N = tvm::if_then_else(N_mod_n != 0, n - N_mod_n, tir::make_zero(DataType::Int(32)));
 
   const auto N_padded = N + pad_N;
diff --git a/tests/python/relay/strategy/test_select_implementation.py b/tests/python/relay/strategy/test_select_implementation.py
index 2bf1548d41d8..906ef2d161b0 100644
--- a/tests/python/relay/strategy/test_select_implementation.py
+++ b/tests/python/relay/strategy/test_select_implementation.py
@@ -24,7 +24,7 @@
 import tvm
 from tvm import relay
 from tvm import te
-from tvm.relay.testing import run_infer_type
+from tvm.relay.testing import run_infer_type, run_opt_pass
 import tvm.testing
 from tvm import topi
 
@@ -63,12 +63,24 @@ def test_concatenate(target, expected_implementation):
         ("llvm -device=arm_cpu", "conv2d_nhwc_spatial_pack.arm_cpu"),
         (
             "llvm -device=arm_cpu -mtriple=aarch64-linux-gnu -mattr=+neon",
-            "conv2d_NHWC_quantized_interleaved.arm_cpu",
+            "conv2d_NHWC_quantized_interleaved_without_transform.arm_cpu",
         ),
         (
             "llvm -device=arm_cpu -mtriple=armv8l-linux-gnu -mattr=+neon",
             "conv2d_nhwc_spatial_pack.arm_cpu",
         ),
+        (
+            "llvm -device=arm_cpu -mtriple=aarch64-linux-gnu",
+            "conv2d_NHWC_quantized_interleaved_without_transform.arm_cpu",
+        ),
+        (
+            "llvm --device=arm_cpu --mtriple=aarch64-linux-gnu -mattr=+v8.2a,+dotprod",
"conv2d_NHWC_quantized_native_without_transform.arm_cpu", + ), + ( + "llvm --device=arm_cpu --mtriple=aarch64-linux-gnu -mattr=+v8.2a,+i8mm", + "conv2d_NHWC_quantized_interleaved_without_transform.arm_cpu", + ), ], ) def test_int8_conv2d(target, expected_impl): @@ -89,16 +101,18 @@ def test_int8_conv2d(target, expected_impl): channels=channels, data_layout=data_layout, kernel_layout=kernel_layout, + out_dtype=dtype, ) - out = run_infer_type(out) with target: + out = run_opt_pass(out, relay.transform.AlterOpLayout()) impl, _ = relay.backend.te_compiler.select_implementation( out.op, out.attrs, [te.placeholder(data_shape, dtype), te.placeholder(weight_shape, dtype)], out.checked_type, target, + use_autotvm=False, ) assert impl.name == expected_impl diff --git a/tests/python/relay/test_pass_alter_op_layout.py b/tests/python/relay/test_pass_alter_op_layout.py index 4caab0ea095b..829c1d6ae43f 100644 --- a/tests/python/relay/test_pass_alter_op_layout.py +++ b/tests/python/relay/test_pass_alter_op_layout.py @@ -1230,7 +1230,7 @@ def test_alter_layout_nhwc_int8_aarch64(): """Check that AlterOplayout does not alter NHWC data layout.""" from tvm import autotvm - expected_workload_shape = (20, 42, 4, 16) + expected_workload_shape = (20, 44, 4, 16) # We use Int8Fallback to disable the fallback flag # and to test the new workload produced during the pass