From 4e579c5de2cd32eedf97b275df19979a9eb52b48 Mon Sep 17 00:00:00 2001 From: Xu Jun Date: Mon, 4 Nov 2024 16:45:51 +0800 Subject: [PATCH] support kblock for x32 GIO packing --- cmake/gen/avx_microkernels.cmake | 2 + gen/avx_microkernels.bzl | 2 + scripts/generate-x32-packw.sh | 7 +- .../gen/x32-packw-x8-gemm-gio-avx-prfm.c | 5 +- .../gen/x32-packw-x8-gemm-gio-avx-u8-prfm.c | 129 ++++++++++++++++++ .../gen/x32-packw-x8-gemm-gio-avx-u8.c | 128 +++++++++++++++++ src/x32-packw/gen/x32-packw-x8-gemm-gio-avx.c | 5 +- src/x32-packw/gio-avx.c.in | 18 ++- src/x32-packw/x32-packw.h | 2 + 9 files changed, 289 insertions(+), 9 deletions(-) create mode 100644 src/x32-packw/gen/x32-packw-x8-gemm-gio-avx-u8-prfm.c create mode 100644 src/x32-packw/gen/x32-packw-x8-gemm-gio-avx-u8.c diff --git a/cmake/gen/avx_microkernels.cmake b/cmake/gen/avx_microkernels.cmake index 9b0e8ce527a..17e8a18c240 100644 --- a/cmake/gen/avx_microkernels.cmake +++ b/cmake/gen/avx_microkernels.cmake @@ -472,6 +472,8 @@ SET(NON_PROD_AVX_MICROKERNEL_SRCS src/x8-lut/gen/x8-lut-avx-u32.c src/x8-lut/gen/x8-lut-avx-u48.c src/x32-packw/gen/x32-packw-x8-gemm-gio-avx-prfm.c + src/x32-packw/gen/x32-packw-x8-gemm-gio-avx-u8-prfm.c + src/x32-packw/gen/x32-packw-x8-gemm-gio-avx-u8.c src/x32-packw/gen/x32-packw-x8-gemm-gio-avx.c src/x32-packw/gen/x32-packw-x8-gemm-goi-avx-u4-prfm.c src/x32-packw/gen/x32-packw-x8-gemm-goi-avx-u4.c diff --git a/gen/avx_microkernels.bzl b/gen/avx_microkernels.bzl index 3f204a24b79..727ab00ffa9 100644 --- a/gen/avx_microkernels.bzl +++ b/gen/avx_microkernels.bzl @@ -469,6 +469,8 @@ NON_PROD_AVX_MICROKERNEL_SRCS = [ "src/x8-lut/gen/x8-lut-avx-u32.c", "src/x8-lut/gen/x8-lut-avx-u48.c", "src/x32-packw/gen/x32-packw-x8-gemm-gio-avx-prfm.c", + "src/x32-packw/gen/x32-packw-x8-gemm-gio-avx-u8-prfm.c", + "src/x32-packw/gen/x32-packw-x8-gemm-gio-avx-u8.c", "src/x32-packw/gen/x32-packw-x8-gemm-gio-avx.c", "src/x32-packw/gen/x32-packw-x8-gemm-goi-avx-u4-prfm.c", "src/x32-packw/gen/x32-packw-x8-gemm-goi-avx-u4.c", diff --git a/scripts/generate-x32-packw.sh b/scripts/generate-x32-packw.sh index b122b5571f2..32d25560d30 100755 --- a/scripts/generate-x32-packw.sh +++ b/scripts/generate-x32-packw.sh @@ -81,8 +81,11 @@ tools/xngen src/x32-packw/s4-avx.c.in -D NR=8 -D SR=4 -D PREFETCH=1 -D KBLOCK=4 tools/xngen src/x32-packw/s4-avx.c.in -D NR=16 -D SR=4 -D PREFETCH=1 -D KBLOCK=4 -o src/x32-packw/gen/x32-packw-x16s4-gemm-goi-avx-u4-prfm.c & ### GUI NR multiple of 8 -tools/xngen src/x32-packw/gio-avx.c.in -D NR=8 -D PREFETCH=0 -o src/x32-packw/gen/x32-packw-x8-gemm-gio-avx.c & -tools/xngen src/x32-packw/gio-avx.c.in -D NR=8 -D PREFETCH=1 -o src/x32-packw/gen/x32-packw-x8-gemm-gio-avx-prfm.c & +tools/xngen src/x32-packw/gio-avx.c.in -D NR=8 -D PREFETCH=0 -D KBLOCK=0 -o src/x32-packw/gen/x32-packw-x8-gemm-gio-avx.c & +tools/xngen src/x32-packw/gio-avx.c.in -D NR=8 -D PREFETCH=1 -D KBLOCK=0 -o src/x32-packw/gen/x32-packw-x8-gemm-gio-avx-prfm.c & + +tools/xngen src/x32-packw/gio-avx.c.in -D NR=8 -D PREFETCH=0 -D KBLOCK=8 -o src/x32-packw/gen/x32-packw-x8-gemm-gio-avx-u8.c & +tools/xngen src/x32-packw/gio-avx.c.in -D NR=8 -D PREFETCH=1 -D KBLOCK=8 -o src/x32-packw/gen/x32-packw-x8-gemm-gio-avx-u8-prfm.c & ################################### x86 AVX512 ################################## ### NR multiple of 16 diff --git a/src/x32-packw/gen/x32-packw-x8-gemm-gio-avx-prfm.c b/src/x32-packw/gen/x32-packw-x8-gemm-gio-avx-prfm.c index 8cf38ee5534..ab8c5898f04 100644 --- a/src/x32-packw/gen/x32-packw-x8-gemm-gio-avx-prfm.c +++ b/src/x32-packw/gen/x32-packw-x8-gemm-gio-avx-prfm.c @@ -67,8 +67,9 @@ void xnn_x32_packw_gemm_gio_ukernel_x8__avx_prfm( packed_w += 8; // KC main loop - // todo: KBLOCK rows at a time - for (size_t k = kc; k > 0; --k) { + size_t k = kc; + + for (; k > 0; --k) { const __m256 v0 = _mm256_loadu_ps(w + 0); _mm256_store_ps(packed_w + 0, v0); w += k_stride; diff --git a/src/x32-packw/gen/x32-packw-x8-gemm-gio-avx-u8-prfm.c b/src/x32-packw/gen/x32-packw-x8-gemm-gio-avx-u8-prfm.c new file mode 100644 index 00000000000..456a8eea4f7 --- /dev/null +++ b/src/x32-packw/gen/x32-packw-x8-gemm-gio-avx-u8-prfm.c @@ -0,0 +1,129 @@ +// Auto-generated file. Do not edit! +// Template: src/x32-packw/gio-avx.c.in +// Generator: tools/xngen +// +// Copyright 2024 Google LLC +// +// This source code is licensed under the BSD-style license found in the +// LICENSE file in the root directory of this source tree. + + +#include +#include +#include + +#include + +#include "xnnpack/intrinsics-polyfill.h" +#include "xnnpack/packw.h" +#include "xnnpack/prefetch.h" + + +void xnn_x32_packw_gemm_gio_ukernel_x8__avx_u8_prfm( + size_t g, + size_t nc, + size_t kc, + size_t nr, + size_t kr, + size_t sr, + size_t k_stride, + const uint32_t* weights, + const uint32_t* bias, + const void* scale, + uint32_t* packed_weights, + size_t extra_bytes, + const void* params) +{ + assert(g != 0); + assert(nc != 0); + assert(kc != 0); + assert(nr == 8); // This kernel is for NR=8 + assert(kr == 1); + assert(sr == 1); + assert(k_stride != 0); + assert(weights != NULL); + assert(packed_weights != NULL); + + static const int32_t mask_table[16] = { + -1, -1, -1, -1, -1, -1, -1, -1, + 0, 0, 0, 0, 0, 0, 0, 0}; + + const float* b = (const float*) bias; + float* packed_w = (float*) packed_weights; + do { + // NC main loop multiple of 8 + const float* w = (const float*) weights; + size_t n = nc; + + for (; n >= 8; n -= 8) { + if XNN_LIKELY(b != NULL) { + const __m256 vb0 = _mm256_loadu_ps(b + 0); + _mm256_store_ps(packed_w + 0, vb0); + b += 8; + } else { + const __m256 vzero = _mm256_setzero_ps(); + _mm256_store_ps(packed_w + 0, vzero); + } + packed_w += 8; + + // KC main loop + size_t k = kc; + for (; k >= 8; k -= 8) { + const __m256 v0_0 = _mm256_loadu_ps(w + 0 + 0 * k_stride); + const __m256 v0_1 = _mm256_loadu_ps(w + 0 + 1 * k_stride); + const __m256 v0_2 = _mm256_loadu_ps(w + 0 + 2 * k_stride); + const __m256 v0_3 = _mm256_loadu_ps(w + 0 + 3 * k_stride); + const __m256 v0_4 = _mm256_loadu_ps(w + 0 + 4 * k_stride); + const __m256 v0_5 = _mm256_loadu_ps(w + 0 + 5 * k_stride); + const __m256 v0_6 = _mm256_loadu_ps(w + 0 + 6 * k_stride); + const __m256 v0_7 = _mm256_loadu_ps(w + 0 + 7 * k_stride); + _mm256_store_ps(packed_w + 0 + 0 * 8, v0_0); + _mm256_store_ps(packed_w + 0 + 1 * 8, v0_1); + _mm256_store_ps(packed_w + 0 + 2 * 8, v0_2); + _mm256_store_ps(packed_w + 0 + 3 * 8, v0_3); + _mm256_store_ps(packed_w + 0 + 4 * 8, v0_4); + _mm256_store_ps(packed_w + 0 + 5 * 8, v0_5); + _mm256_store_ps(packed_w + 0 + 6 * 8, v0_6); + _mm256_store_ps(packed_w + 0 + 7 * 8, v0_7); + w += k_stride * 8; + packed_w += 8 * 8; + } + + for (; k > 0; --k) { + const __m256 v0 = _mm256_loadu_ps(w + 0); + _mm256_store_ps(packed_w + 0, v0); + w += k_stride; + packed_w += 8; + } + w = w - kc * k_stride + 8; // Advance to next column of 8 floats + } + + // NC remainder (1..7) + if XNN_UNLIKELY(n != 0) { + assert(n >= 1); + assert(n <= 7); + if XNN_LIKELY(b != NULL) { + size_t nb = n; + do { + *packed_w++ = *b++; + } while (--nb != 0); + packed_w += (8 - n); + } else { + const __m256 vzero = _mm256_setzero_ps(); + _mm256_store_ps(packed_w, vzero); + packed_w += 8; + } + + const __m256i vmask0 = _mm256_loadu_si256((const __m256i*) &mask_table[8 -n]); + + // KC main loop + for (size_t k = kc; k > 0; --k) { + const __m256 v0 = _mm256_maskload_ps(w + 0, vmask0); + _mm256_maskstore_ps(packed_w + 0, vmask0, v0); + w += k_stride; + packed_w += 8; + } + } + weights += nc * kc; + } while (--g != 0); +} diff --git a/src/x32-packw/gen/x32-packw-x8-gemm-gio-avx-u8.c b/src/x32-packw/gen/x32-packw-x8-gemm-gio-avx-u8.c new file mode 100644 index 00000000000..a0c08baa110 --- /dev/null +++ b/src/x32-packw/gen/x32-packw-x8-gemm-gio-avx-u8.c @@ -0,0 +1,128 @@ +// Auto-generated file. Do not edit! +// Template: src/x32-packw/gio-avx.c.in +// Generator: tools/xngen +// +// Copyright 2024 Google LLC +// +// This source code is licensed under the BSD-style license found in the +// LICENSE file in the root directory of this source tree. + + +#include +#include +#include + +#include + +#include "xnnpack/intrinsics-polyfill.h" +#include "xnnpack/packw.h" + + +void xnn_x32_packw_gemm_gio_ukernel_x8__avx_u8( + size_t g, + size_t nc, + size_t kc, + size_t nr, + size_t kr, + size_t sr, + size_t k_stride, + const uint32_t* weights, + const uint32_t* bias, + const void* scale, + uint32_t* packed_weights, + size_t extra_bytes, + const void* params) +{ + assert(g != 0); + assert(nc != 0); + assert(kc != 0); + assert(nr == 8); // This kernel is for NR=8 + assert(kr == 1); + assert(sr == 1); + assert(k_stride != 0); + assert(weights != NULL); + assert(packed_weights != NULL); + + static const int32_t mask_table[16] = { + -1, -1, -1, -1, -1, -1, -1, -1, + 0, 0, 0, 0, 0, 0, 0, 0}; + + const float* b = (const float*) bias; + float* packed_w = (float*) packed_weights; + do { + // NC main loop multiple of 8 + const float* w = (const float*) weights; + size_t n = nc; + + for (; n >= 8; n -= 8) { + if XNN_LIKELY(b != NULL) { + const __m256 vb0 = _mm256_loadu_ps(b + 0); + _mm256_store_ps(packed_w + 0, vb0); + b += 8; + } else { + const __m256 vzero = _mm256_setzero_ps(); + _mm256_store_ps(packed_w + 0, vzero); + } + packed_w += 8; + + // KC main loop + size_t k = kc; + for (; k >= 8; k -= 8) { + const __m256 v0_0 = _mm256_loadu_ps(w + 0 + 0 * k_stride); + const __m256 v0_1 = _mm256_loadu_ps(w + 0 + 1 * k_stride); + const __m256 v0_2 = _mm256_loadu_ps(w + 0 + 2 * k_stride); + const __m256 v0_3 = _mm256_loadu_ps(w + 0 + 3 * k_stride); + const __m256 v0_4 = _mm256_loadu_ps(w + 0 + 4 * k_stride); + const __m256 v0_5 = _mm256_loadu_ps(w + 0 + 5 * k_stride); + const __m256 v0_6 = _mm256_loadu_ps(w + 0 + 6 * k_stride); + const __m256 v0_7 = _mm256_loadu_ps(w + 0 + 7 * k_stride); + _mm256_store_ps(packed_w + 0 + 0 * 8, v0_0); + _mm256_store_ps(packed_w + 0 + 1 * 8, v0_1); + _mm256_store_ps(packed_w + 0 + 2 * 8, v0_2); + _mm256_store_ps(packed_w + 0 + 3 * 8, v0_3); + _mm256_store_ps(packed_w + 0 + 4 * 8, v0_4); + _mm256_store_ps(packed_w + 0 + 5 * 8, v0_5); + _mm256_store_ps(packed_w + 0 + 6 * 8, v0_6); + _mm256_store_ps(packed_w + 0 + 7 * 8, v0_7); + w += k_stride * 8; + packed_w += 8 * 8; + } + + for (; k > 0; --k) { + const __m256 v0 = _mm256_loadu_ps(w + 0); + _mm256_store_ps(packed_w + 0, v0); + w += k_stride; + packed_w += 8; + } + w = w - kc * k_stride + 8; // Advance to next column of 8 floats + } + + // NC remainder (1..7) + if XNN_UNLIKELY(n != 0) { + assert(n >= 1); + assert(n <= 7); + if XNN_LIKELY(b != NULL) { + size_t nb = n; + do { + *packed_w++ = *b++; + } while (--nb != 0); + packed_w += (8 - n); + } else { + const __m256 vzero = _mm256_setzero_ps(); + _mm256_store_ps(packed_w, vzero); + packed_w += 8; + } + + const __m256i vmask0 = _mm256_loadu_si256((const __m256i*) &mask_table[8 -n]); + + // KC main loop + for (size_t k = kc; k > 0; --k) { + const __m256 v0 = _mm256_maskload_ps(w + 0, vmask0); + _mm256_maskstore_ps(packed_w + 0, vmask0, v0); + w += k_stride; + packed_w += 8; + } + } + weights += nc * kc; + } while (--g != 0); +} diff --git a/src/x32-packw/gen/x32-packw-x8-gemm-gio-avx.c b/src/x32-packw/gen/x32-packw-x8-gemm-gio-avx.c index c83f10c7663..6e87958201c 100644 --- a/src/x32-packw/gen/x32-packw-x8-gemm-gio-avx.c +++ b/src/x32-packw/gen/x32-packw-x8-gemm-gio-avx.c @@ -66,8 +66,9 @@ void xnn_x32_packw_gemm_gio_ukernel_x8__avx( packed_w += 8; // KC main loop - // todo: KBLOCK rows at a time - for (size_t k = kc; k > 0; --k) { + size_t k = kc; + + for (; k > 0; --k) { const __m256 v0 = _mm256_loadu_ps(w + 0); _mm256_store_ps(packed_w + 0, v0); w += k_stride; diff --git a/src/x32-packw/gio-avx.c.in b/src/x32-packw/gio-avx.c.in index f4952971e8b..53930efbde1 100644 --- a/src/x32-packw/gio-avx.c.in +++ b/src/x32-packw/gio-avx.c.in @@ -19,7 +19,7 @@ $if PREFETCH: #include "xnnpack/prefetch.h" -void xnn_x32_packw_gemm_gio_ukernel_x${NR}__avx${"_prfm" if PREFETCH else ""}( +void xnn_x32_packw_gemm_gio_ukernel_x${NR}__avx${"_u" if KBLOCK else ""}${KBLOCK if KBLOCK else ""}${"_prfm" if PREFETCH else ""}( size_t g, size_t nc, size_t kc, @@ -71,8 +71,20 @@ void xnn_x32_packw_gemm_gio_ukernel_x${NR}__avx${"_prfm" if PREFETCH else ""}( packed_w += ${NR}; // KC main loop - // todo: KBLOCK rows at a time - for (size_t k = kc; k > 0; --k) { + size_t k = kc; + $if KBLOCK: + for (; k >= ${KBLOCK}; k -= ${KBLOCK}) { + $for N in range(0,NR,8): + $for K in range(0,KBLOCK): + const __m256 v${N}_${K} = _mm256_loadu_ps(w + ${N} + ${K} * k_stride); + $for N in range(0,NR,8): + $for K in range(0,KBLOCK): + _mm256_store_ps(packed_w + ${N} + ${K} * ${NR}, v${N}_${K}); + w += k_stride * ${KBLOCK}; + packed_w += ${NR} * ${KBLOCK}; + } + + for (; k > 0; --k) { $for N in range(0,NR,8): const __m256 v${N} = _mm256_loadu_ps(w + ${N}); $for N in range(0,NR,8): diff --git a/src/x32-packw/x32-packw.h b/src/x32-packw/x32-packw.h index cb82ac2a2e0..34941b2fc7d 100644 --- a/src/x32-packw/x32-packw.h +++ b/src/x32-packw/x32-packw.h @@ -68,6 +68,8 @@ XNN_UKERNEL(xnn_arch_x86_avx, xnn_x32_packw_gemm_goi_ukernel_x16s4__avx_u4_prfm, XNN_GIO_UKERNEL(xnn_arch_x86_avx, xnn_x32_packw_gemm_gio_ukernel_x8__avx, 8, 1, 1, 4, 1) XNN_GIO_UKERNEL(xnn_arch_x86_avx, xnn_x32_packw_gemm_gio_ukernel_x8__avx_prfm, 8, 1, 1, 4, 1) +XNN_GIO_UKERNEL(xnn_arch_x86_avx, xnn_x32_packw_gemm_gio_ukernel_x8__avx_u8, 8, 1, 1, 4, 1) +XNN_GIO_UKERNEL(xnn_arch_x86_avx, xnn_x32_packw_gemm_gio_ukernel_x8__avx_u8_prfm, 8, 1, 1, 4, 1) #endif // XNN_ARCH_X86 || XNN_ARCH_X86_64 #if XNN_ENABLE_AVX512F && (XNN_ARCH_X86_64 || XNN_ARCH_X86)