From 76227e2948b1f847abef003de5f6d49ea0dd3171 Mon Sep 17 00:00:00 2001 From: Chip Kerchner Date: Fri, 6 Sep 2024 14:03:31 -0500 Subject: [PATCH 01/24] Initial commit for vectorized BF16 GEMV. Added GEMM_GEMV_FORWARD_BF16 to enable using BF16 GEMV for one dimension matrices. Updated unit test to support inc_x != 1 or inc_y for GEMV. --- Makefile.system | 6 +- cmake/system.cmake | 3 + interface/gemm.c | 2 +- kernel/power/KERNEL.POWER10 | 2 + kernel/power/KERNEL.POWER8 | 2 + kernel/power/KERNEL.POWER9 | 2 + kernel/power/sbgemv_common.c | 285 ++++++++++++++++++++++++++++++ kernel/power/sbgemv_n.c | 189 ++++++++++++++++++++ kernel/power/sbgemv_n_power10.c | 33 ++++ kernel/power/sbgemv_n_vsx.c | 303 ++++++++++++++++++++++++++++++++ kernel/power/sbgemv_t.c | 117 ++++++++++++ kernel/power/sbgemv_t_power10.c | 32 ++++ kernel/power/sbgemv_t_vsx.c | 286 ++++++++++++++++++++++++++++++ test/compare_sgemm_sbgemm.c | 31 ++-- 14 files changed, 1277 insertions(+), 16 deletions(-) create mode 100644 kernel/power/sbgemv_common.c create mode 100644 kernel/power/sbgemv_n.c create mode 100644 kernel/power/sbgemv_n_power10.c create mode 100644 kernel/power/sbgemv_n_vsx.c create mode 100644 kernel/power/sbgemv_t.c create mode 100644 kernel/power/sbgemv_t_power10.c create mode 100644 kernel/power/sbgemv_t_vsx.c diff --git a/Makefile.system b/Makefile.system index b065f9a981..8c030842a4 100644 --- a/Makefile.system +++ b/Makefile.system @@ -282,15 +282,19 @@ GEMM_GEMV_FORWARD = 1 endif ifeq ($(ARCH), power) GEMM_GEMV_FORWARD = 1 +GEMM_GEMV_FORWARD_BF16 = 1 endif ifeq ($(SMALL_MATRIX_OPT), 1) CCOMMON_OPT += -DSMALL_MATRIX_OPT endif -ifeq ($(GEMM_GEMV_FORWARD), 1) ifneq ($(ONLY_CBLAS), 1) +ifeq ($(GEMM_GEMV_FORWARD), 1) CCOMMON_OPT += -DGEMM_GEMV_FORWARD endif +ifeq ($(GEMM_GEMV_FORWARD_BF16), 1) +CCOMMON_OPT += -DGEMM_GEMV_FORWARD_BF16 +endif endif # This operation is expensive, so execution should be once. 
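For orientation: GEMM_GEMV_FORWARD_BF16 only controls whether the BFLOAT16 path may take the existing GEMM-to-GEMV shortcut that the interface/gemm.c hunk below gates on. A minimal sketch of that shape check, assuming the column-major m/n/k arguments used in interface/gemm.c (can_forward_to_gemv is a hypothetical helper named here for illustration only, not part of the patch):

static int can_forward_to_gemv(BLASLONG m, BLASLONG n, BLASLONG k)
{
  /* k == 0 degenerates to a pure scaling of C and is not forwarded */
  if (k == 0) return 0;
  /* an output with a single row or column makes the GEMM a GEMV */
  return (m == 1) || (n == 1);
}

When the check passes, the BF16 call is serviced by the new SBGEMV kernels instead of the SBGEMM path.
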
diff --git a/cmake/system.cmake b/cmake/system.cmake index a0b73ddae0..fb2d350abb 100644 --- a/cmake/system.cmake +++ b/cmake/system.cmake @@ -398,6 +398,9 @@ endif () if (GEMM_GEMV_FORWARD AND NOT ONLY_CBLAS) set(CCOMMON_OPT "${CCOMMON_OPT} -DGEMM_GEMV_FORWARD") endif () +if (GEMM_GEMV_FORWARD_BF16 AND NOT ONLY_CBLAS) + set(CCOMMON_OPT "${CCOMMON_OPT} -DGEMM_GEMV_FORWARD_BF16") +endif () if (SMALL_MATRIX_OPT) set(CCOMMON_OPT "${CCOMMON_OPT} -DSMALL_MATRIX_OPT") endif () diff --git a/interface/gemm.c b/interface/gemm.c index 64b8b620cf..7cd0884fad 100644 --- a/interface/gemm.c +++ b/interface/gemm.c @@ -498,7 +498,7 @@ void CNAME(enum CBLAS_ORDER order, enum CBLAS_TRANSPOSE TransA, enum CBLAS_TRANS args.m, args.n, args.k, args.lda, args.ldb, args.ldc); #endif -#if defined(GEMM_GEMV_FORWARD) && !defined(GEMM3M) && !defined(COMPLEX) && !defined(BFLOAT16) +#if defined(GEMM_GEMV_FORWARD) && !defined(GEMM3M) && !defined(COMPLEX) && (!defined(BFLOAT16) || defined(GEMM_GEMV_FORWARD_BF16)) // Check if we can convert GEMM -> GEMV if (args.k != 0) { if (args.n == 1) { diff --git a/kernel/power/KERNEL.POWER10 b/kernel/power/KERNEL.POWER10 index c84cd91d2a..956b401fb2 100644 --- a/kernel/power/KERNEL.POWER10 +++ b/kernel/power/KERNEL.POWER10 @@ -228,11 +228,13 @@ ZSWAPKERNEL = zswap.c # SGEMVNKERNEL = sgemv_n.c +SBGEMVNKERNEL = sbgemv_n_power10.c DGEMVNKERNEL = dgemv_n_power10.c CGEMVNKERNEL = cgemv_n.c ZGEMVNKERNEL = zgemv_n_power10.c # SGEMVTKERNEL = sgemv_t.c +SBGEMVTKERNEL = sbgemv_t_power10.c DGEMVTKERNEL = dgemv_t_power10.c CGEMVTKERNEL = cgemv_t.c ZGEMVTKERNEL = zgemv_t_4.c diff --git a/kernel/power/KERNEL.POWER8 b/kernel/power/KERNEL.POWER8 index 700a68e447..001401d532 100644 --- a/kernel/power/KERNEL.POWER8 +++ b/kernel/power/KERNEL.POWER8 @@ -257,11 +257,13 @@ ZSWAPKERNEL = zswap.c # SGEMVNKERNEL = sgemv_n.c +SBGEMVNKERNEL = sbgemv_n_vsx.c DGEMVNKERNEL = dgemv_n.c CGEMVNKERNEL = cgemv_n.c ZGEMVNKERNEL = zgemv_n_4.c # SGEMVTKERNEL = sgemv_t.c +SBGEMVTKERNEL = sbgemv_t_vsx.c DGEMVTKERNEL = dgemv_t.c CGEMVTKERNEL = cgemv_t.c ZGEMVTKERNEL = zgemv_t_4.c diff --git a/kernel/power/KERNEL.POWER9 b/kernel/power/KERNEL.POWER9 index 7d007d1a2b..a18c31a2e9 100644 --- a/kernel/power/KERNEL.POWER9 +++ b/kernel/power/KERNEL.POWER9 @@ -181,11 +181,13 @@ ZSWAPKERNEL = zswap.c # SGEMVNKERNEL = sgemv_n.c +SBGEMVNKERNEL = sbgemv_n_vsx.c DGEMVNKERNEL = dgemv_n.c CGEMVNKERNEL = cgemv_n.c ZGEMVNKERNEL = zgemv_n_4.c # SGEMVTKERNEL = sgemv_t.c +SBGEMVTKERNEL = sbgemv_t_vsx.c DGEMVTKERNEL = dgemv_t.c CGEMVTKERNEL = cgemv_t.c ZGEMVTKERNEL = zgemv_t_4.c diff --git a/kernel/power/sbgemv_common.c b/kernel/power/sbgemv_common.c new file mode 100644 index 0000000000..2aadcca6ff --- /dev/null +++ b/kernel/power/sbgemv_common.c @@ -0,0 +1,285 @@ +/*************************************************************************** +Copyright (c) 2024, The OpenBLAS Project +All rights reserved. +Redistribution and use in source and binary forms, with or without +modification, are permitted provided that the following conditions are +met: +1. Redistributions of source code must retain the above copyright +notice, this list of conditions and the following disclaimer. +2. Redistributions in binary form must reproduce the above copyright +notice, this list of conditions and the following disclaimer in +the documentation and/or other materials provided with the +distribution. +3. 
Neither the name of the OpenBLAS project nor the names of +its contributors may be used to endorse or promote products +derived from this software without specific prior written permission. +THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" +AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE +IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE +ARE DISCLAIMED. IN NO EVENT SHALL THE OPENBLAS PROJECT OR CONTRIBUTORS BE +LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL +DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR +SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER +CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, +OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE +USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. +*****************************************************************************/ + +#ifndef SBGEMV_COMMON_C +#define SBGEMV_COMMON_C +#include "common.h" + +#include + +#define FORCEINLINE inline __attribute__((always_inline)) + +#ifdef __clang__ +#define uint16_t unsigned short +#define uint32_t unsigned int +#define uint64_t unsigned long long +#endif + +#ifdef _ARCH_PWR10 +#ifdef __has_builtin +#if !__has_builtin(__builtin_vsx_assemble_pair) +#define __builtin_vsx_assemble_pair __builtin_mma_assemble_pair +#endif +#if !__has_builtin(__builtin_vsx_disassemble_pair) +#define __builtin_vsx_disassemble_pair __builtin_mma_disassemble_pair +#endif +#endif + +#if __BYTE_ORDER__ == __ORDER_BIG_ENDIAN__ +#define __builtin_vsx_assemble_pair2(vp0, v0, v1) __builtin_vsx_assemble_pair(vp0, v1, v0) +#else +#define __builtin_vsx_assemble_pair2(vp0, v0, v1) __builtin_vsx_assemble_pair(vp0, v0, v1) +#endif + +#define USE_VECTOR_PAIRS +#endif + +typedef __vector IFLOAT vec_bf16; +typedef __vector FLOAT vec_f32; +typedef __vector unsigned char vec_uc8; + +#if __BYTE_ORDER__ == __ORDER_BIG_ENDIAN__ +#define BF16_HI(data, zero) (vec_f32)vec_mergeh(data, zero) +#define BF16_LO(data, zero) (vec_f32)vec_mergel(data, zero) +#else +#define BF16_HI(data, zero) (vec_f32)vec_mergeh(zero, data) +#define BF16_LO(data, zero) (vec_f32)vec_mergel(zero, data) +#endif + +FORCEINLINE vec_uc8 vec_load_vec(void *src) +{ + return vec_xl(0, (unsigned char *)(src)); +} + +FORCEINLINE void vec_load_pair(vec_f32 *dst, vec_f32 *src) +{ +#ifdef USE_VECTOR_PAIRS + __vector_pair vy0p; + vy0p = *(__vector_pair *)(src); + __builtin_vsx_disassemble_pair((void *)(dst), &vy0p); +#else + dst[0] = src[0]; + dst[1] = src[1]; +#endif +} + +FORCEINLINE void vec_store_pair(vec_f32 *dst, vec_f32 *src) +{ +#ifdef USE_VECTOR_PAIRS + __vector_pair vy0p; + __builtin_vsx_assemble_pair2(&vy0p, (vec_uc8)src[1], (vec_uc8)src[0]); + *(__vector_pair *)(dst) = vy0p; +#else + dst[0] = src[0]; + dst[1] = src[1]; +#endif +} + +FORCEINLINE vec_bf16 vec_loadN(void *src, BLASLONG n) +{ + IFLOAT *src2 = (IFLOAT *)(src); +#ifdef _ARCH_PWR9 + return vec_xl_len(src2, n * sizeof(IFLOAT)); +#else + __attribute__((aligned(16))) IFLOAT data[sizeof(vec_bf16) / sizeof(IFLOAT)]; + memset(data, 0, sizeof(vec_bf16)); + if (n & 4) { + memcpy(data, src2, sizeof(uint64_t)); + } + if (n & 2) { + BLASLONG n4 = n & 4; + memcpy(data + n4, src2 + n4, sizeof(uint32_t)); + } + if (n & 1) { + BLASLONG n6 = n & 6; + data[n6] = src2[n6]; + } + return (vec_bf16)vec_load_vec(data); +#endif +} + +FORCEINLINE vec_f32 vec_loadNHi(void *src, BLASLONG n, 
vec_bf16 zero) +{ + vec_bf16 data = vec_loadN(src, n); + return BF16_HI(data, zero); +} + +FORCEINLINE vec_f32 vec_loadN_f32(void *src, BLASLONG n) +{ +#ifndef _ARCH_PWR9 + if (n & 4) { + return (vec_f32)vec_load_vec(src); + } +#endif + return (vec_f32)vec_loadN(src, n * (sizeof(FLOAT) / sizeof(IFLOAT))); +} + +FORCEINLINE void vec_storeN_f32(vec_f32 data, void *dst, BLASLONG n) +{ + FLOAT *dst2 = (FLOAT *)(dst); +#ifdef _ARCH_PWR9 + vec_xst_len(data, dst2, n * sizeof(FLOAT)); +#else + if (n & 4) { + vec_xst(data, 0, dst2); + return; + } + __attribute__((aligned(16))) FLOAT data2[sizeof(vec_f32) / sizeof(FLOAT)]; + vec_xst(data, 0, data2); + if (n & 2) { + memcpy(dst2, data2, sizeof(uint64_t)); + } + if (n & 1) { + BLASLONG n2 = n & 2; + dst2[n2] = data2[n2]; + } +#endif +} + +FORCEINLINE vec_f32 vec_mult(vec_f32 *inp, vec_bf16 in0, vec_bf16 zero) +{ + vec_f32 v_in00 = BF16_HI(in0, zero); + vec_f32 v_in01 = BF16_LO(in0, zero); + + return (inp[0] * v_in00) + (inp[1] * v_in01); +} + +FORCEINLINE vec_f32 vec_load_mult(vec_bf16 *in, vec_f32 *inp, vec_bf16 zero) +{ + vec_bf16 in0 = (vec_bf16)vec_load_vec(in); + + return vec_mult(inp, in0, zero); +} + +FORCEINLINE void vec_load_vec2(vec_bf16 *in, BLASLONG i, vec_f32 *v_x0, vec_bf16 zero) +{ + vec_bf16 inp = (vec_bf16)vec_load_vec(&in[i]); + + v_x0[0] = BF16_HI(inp, zero); + v_x0[1] = BF16_LO(inp, zero); +} + +FORCEINLINE void vec_mult2(vec_f32 v_x0, vec_bf16 in0, vec_bf16 zero, vec_f32 *vy0) +{ + vec_f32 v_in00 = BF16_HI(in0, zero); + vec_f32 v_in01 = BF16_LO(in0, zero); + + vy0[0] += (v_x0 * v_in00); + vy0[1] += (v_x0 * v_in01); +} + +FORCEINLINE void vec_load_mult2(vec_f32 v_x0, vec_bf16 *in, vec_bf16 zero, vec_f32 *vy0) +{ + vec_bf16 in0 = (vec_bf16)vec_load_vec(in); + + vec_mult2(v_x0, in0, zero, vy0); +} + +FORCEINLINE vec_f32 vec_loadN_mult(vec_bf16 *in, vec_f32 *inp, BLASLONG n, vec_bf16 zero) +{ + vec_bf16 in0 = vec_loadN(in, n); + + return vec_mult(inp, in0, zero); +} + +FORCEINLINE void vec_loadN_vec2(vec_bf16 *in, BLASLONG i, vec_f32 *v_x0, BLASLONG n, vec_bf16 zero) +{ + vec_bf16 inp = vec_loadN(&in[i], n); + + v_x0[0] = BF16_HI(inp, zero); + v_x0[1] = BF16_LO(inp, zero); +} + +FORCEINLINE void vec_loadN_mult2(vec_f32 v_x0, vec_bf16 *in, BLASLONG n, vec_bf16 zero, vec_f32 *vy0) +{ + vec_bf16 in0 = vec_loadN(in, n); + + vec_mult2(v_x0, in0, zero, vy0); +} + +FORCEINLINE vec_f32 vec_loadNHi_mult(vec_bf16 *in, vec_f32 v_inp0, BLASLONG n, vec_bf16 zero) +{ + vec_f32 v_in00 = vec_loadNHi(in, n, zero); + + return (v_inp0 * v_in00); +} + +FORCEINLINE vec_f32 vec_loadNHi_multi2(vec_f32 v_x0, vec_bf16 *in, BLASLONG n, vec_bf16 zero) +{ + vec_f32 v_in00 = vec_loadNHi(in, n, zero); + + return (v_x0 * v_in00); +} + +FORCEINLINE vec_f32 vec_loadNHi_vec(vec_bf16 *in, BLASLONG i, BLASLONG n, vec_bf16 zero) +{ + return vec_loadNHi(&in[i], n, zero); +} + +FORCEINLINE void copy_x(BLASLONG n, IFLOAT *src, IFLOAT *dest, BLASLONG inc_src) +{ + for (BLASLONG i = 0; i < n; i++) { + *dest++ = *src; + src += inc_src; + } +} + +FORCEINLINE void copy_y_beta(BLASLONG n, FLOAT *src, FLOAT *dest, BLASLONG inc_src, FLOAT beta) +{ + if (beta == 0) { + memset(dest, 0, sizeof(FLOAT) * n); + } else { + for (BLASLONG i = 0; i < n; i++) { + *dest++ = *src * beta; + src += inc_src; + } + } +} + +FORCEINLINE void copy_y(BLASLONG n, FLOAT *src, FLOAT *dest, BLASLONG inc_src, FLOAT beta) +{ + if (beta == 0) { + for (BLASLONG i = 0; i < n; i++) { + *dest = *src++; + dest += inc_src; + } + } else { + for (BLASLONG i = 0; i < n; i++) { + *dest = *src++ + (beta * *dest); + dest 
+= inc_src; + } + } +} + +FORCEINLINE void add_y(BLASLONG n, FLOAT *src, FLOAT *dest, BLASLONG inc_dest) +{ + for (BLASLONG i = 0; i < n; i++) { + *dest += *src++; + dest += inc_dest; + } +} +#endif diff --git a/kernel/power/sbgemv_n.c b/kernel/power/sbgemv_n.c new file mode 100644 index 0000000000..854ad93ee2 --- /dev/null +++ b/kernel/power/sbgemv_n.c @@ -0,0 +1,189 @@ +/*************************************************************************** +Copyright (c) 2024, The OpenBLAS Project +All rights reserved. +Redistribution and use in source and binary forms, with or without +modification, are permitted provided that the following conditions are +met: +1. Redistributions of source code must retain the above copyright +notice, this list of conditions and the following disclaimer. +2. Redistributions in binary form must reproduce the above copyright +notice, this list of conditions and the following disclaimer in +the documentation and/or other materials provided with the +distribution. +3. Neither the name of the OpenBLAS project nor the names of +its contributors may be used to endorse or promote products +derived from this software without specific prior written permission. +THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" +AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE +IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE +ARE DISCLAIMED. IN NO EVENT SHALL THE OPENBLAS PROJECT OR CONTRIBUTORS BE +LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL +DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR +SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER +CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, +OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE +USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. 
+*****************************************************************************/ + +#ifndef SBGEMV_N_COMMON_C +#define SBGEMV_N_COMMON_C +static void BF16GEMV_N_beta(BLASLONG n, FLOAT *output_vector, FLOAT *input_vector, FLOAT beta) +{ + if (beta == 0) { + memset(output_vector, 0, sizeof(FLOAT) * n); + } else { + vec_f32 b = { beta, beta, beta, beta }; + + vec_f32 *in = (vec_f32 *)input_vector; + vec_f32 *out = (vec_f32 *)output_vector; + + BLASLONG n8 = n / 8; + BLASLONG i = 0; + vec_f32 v_inp0[2]; + + for (; i + 4 <= n8; i += 4) { + vec_f32 v_inp1[2], v_inp2[2], v_inp3[2]; + vec_load_pair(v_inp0, &in[(i * 2) + 0]); + vec_load_pair(v_inp1, &in[(i * 2) + 2]); + vec_load_pair(v_inp2, &in[(i * 2) + 4]); + vec_load_pair(v_inp3, &in[(i * 2) + 6]); + v_inp0[0] *= b; + v_inp0[1] *= b; + v_inp1[0] *= b; + v_inp1[1] *= b; + v_inp2[0] *= b; + v_inp2[1] *= b; + v_inp3[0] *= b; + v_inp3[1] *= b; + vec_store_pair(&out[(i * 2) + 0], v_inp0); + vec_store_pair(&out[(i * 2) + 2], v_inp1); + vec_store_pair(&out[(i * 2) + 4], v_inp2); + vec_store_pair(&out[(i * 2) + 6], v_inp3); + } + + for (; i < n8; i++) { + vec_load_pair(v_inp0, &in[(i * 2) + 0]); + v_inp0[0] *= b; + v_inp0[1] *= b; + vec_store_pair(&out[(i * 2) + 0], v_inp0); + } + + n &= 7; + if (n > 4) { + BLASLONG n3 = n & 3; + v_inp0[0] = in[(i * 2) + 0]; + v_inp0[1] = vec_loadN_f32(&in[(i * 2) + 1], n3); + v_inp0[0] *= b; + v_inp0[1] *= b; + out[(i * 2) + 0] = v_inp0[0]; + vec_storeN_f32(v_inp0[1], &out[(i * 2) + 1], n3); + } else if (n) { + v_inp0[0] = vec_loadN_f32(&in[(i * 2) + 0], n); + v_inp0[0] *= b; + vec_storeN_f32(v_inp0[0], &out[(i * 2) + 0], n); + } + } +} + +int CNAME(BLASLONG m, BLASLONG n, FLOAT alpha, IFLOAT *a, BLASLONG lda, IFLOAT *x, BLASLONG inc_x, FLOAT beta, FLOAT *y, BLASLONG inc_y) +{ + IFLOAT *x_ptr, *ap[4]; + IFLOAT xbuffer[8] __attribute__((aligned(16))); + FLOAT *y_ptr, *ybuffer; + FLOAT buffer[NBMAX] __attribute__((aligned(16))); + + if ((m < 1) || (n < 1)) return 0; + + ybuffer = buffer; + y_ptr = y; + + BLASLONG lda4 = lda << 2; + BLASLONG lda8 = lda << 3; + BLASLONG NB = NBMAX; + BLASLONG m2 = (m & (NBMAX - 1)); + + while (NB == NBMAX) { + m -= NB; + if (m < 0) { + if (m2 == 0) break; + NB = m2; + } + + if (inc_y != 1) { + copy_y_beta(NB, y_ptr, ybuffer, inc_y, beta); + } else { + ybuffer = y_ptr; + BF16GEMV_N_beta(NB, ybuffer, ybuffer, beta); + } + + x_ptr = x; + + ap[0] = a; + ap[1] = a + lda; + ap[2] = ap[1] + lda; + ap[3] = ap[2] + lda; + + if (inc_x == 1) { + for (BLASLONG j = 0; j + 8 <= n; j += 8) { + BF16GEMV_N_8(NB, ap, x_ptr, ybuffer, lda4, alpha); + ap[0] += lda8; + ap[1] += lda8; + ap[2] += lda8; + ap[3] += lda8; + x_ptr += 8; + } + if (n & 4) { + BF16GEMV_N_4(NB, ap, x_ptr, ybuffer, alpha); + ap[0] += lda4; + ap[1] += lda4; + x_ptr += 4; + } + if (n & 2) { + BF16GEMV_N_2(NB, ap, x_ptr, ybuffer, alpha); + ap[0] += (lda * 2); + x_ptr += 2; + } + if (n & 1) { + BF16GEMV_N_1(NB, ap, x_ptr, ybuffer, alpha); + } + } else { + for (BLASLONG j = 0; j + 8 <= n; j += 8) { + copy_x(8, x_ptr, xbuffer, inc_x); + BF16GEMV_N_8(NB, ap, xbuffer, ybuffer, lda4, alpha); + ap[0] += lda8; + ap[1] += lda8; + ap[2] += lda8; + ap[3] += lda8; + x_ptr += 8 * inc_x; + } + if (n & 4) { + copy_x(4, x_ptr, xbuffer, inc_x); + BF16GEMV_N_4(NB, ap, xbuffer, ybuffer, alpha); + ap[0] += lda4; + ap[1] += lda4; + x_ptr += 4 * inc_x; + } + if (n & 2) { + copy_x(2, x_ptr, xbuffer, inc_x); + BF16GEMV_N_2(NB, ap, xbuffer, ybuffer, alpha); + ap[0] += (lda * 2); + x_ptr += 2 * inc_x; + } + if (n & 1) { + copy_x(1, x_ptr, xbuffer, inc_x); + 
BF16GEMV_N_1(NB, ap, xbuffer, ybuffer, alpha); + } + } + + a += NB; + if (inc_y != 1) { + add_y(NB, ybuffer, y_ptr, inc_y); + y_ptr += (NB * inc_y); + } else { + y_ptr += NB; + } + } + + return 0; +} +#endif diff --git a/kernel/power/sbgemv_n_power10.c b/kernel/power/sbgemv_n_power10.c new file mode 100644 index 0000000000..fc83b38c37 --- /dev/null +++ b/kernel/power/sbgemv_n_power10.c @@ -0,0 +1,33 @@ +/*************************************************************************** +Copyright (c) 2024, The OpenBLAS Project +All rights reserved. +Redistribution and use in source and binary forms, with or without +modification, are permitted provided that the following conditions are +met: +1. Redistributions of source code must retain the above copyright +notice, this list of conditions and the following disclaimer. +2. Redistributions in binary form must reproduce the above copyright +notice, this list of conditions and the following disclaimer in +the documentation and/or other materials provided with the +distribution. +3. Neither the name of the OpenBLAS project nor the names of +its contributors may be used to endorse or promote products +derived from this software without specific prior written permission. +THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" +AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE +IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE +ARE DISCLAIMED. IN NO EVENT SHALL THE OPENBLAS PROJECT OR CONTRIBUTORS BE +LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL +DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR +SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER +CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, +OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE +USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. +*****************************************************************************/ + +//#include "sbgemv_common.c" + +#include "sbgemv_n_vsx.c" + +//#include "sbgemv_n.c" + diff --git a/kernel/power/sbgemv_n_vsx.c b/kernel/power/sbgemv_n_vsx.c new file mode 100644 index 0000000000..ddbf908b3f --- /dev/null +++ b/kernel/power/sbgemv_n_vsx.c @@ -0,0 +1,303 @@ +/*************************************************************************** +Copyright (c) 2024, The OpenBLAS Project +All rights reserved. +Redistribution and use in source and binary forms, with or without +modification, are permitted provided that the following conditions are +met: +1. Redistributions of source code must retain the above copyright +notice, this list of conditions and the following disclaimer. +2. Redistributions in binary form must reproduce the above copyright +notice, this list of conditions and the following disclaimer in +the documentation and/or other materials provided with the +distribution. +3. Neither the name of the OpenBLAS project nor the names of +its contributors may be used to endorse or promote products +derived from this software without specific prior written permission. +THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" +AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE +IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE +ARE DISCLAIMED. 
IN NO EVENT SHALL THE OPENBLAS PROJECT OR CONTRIBUTORS BE +LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL +DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR +SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER +CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, +OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE +USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. +*****************************************************************************/ + +#ifndef SBGEMV_N_VSX +#define SBGEMV_N_VSX + +#include "sbgemv_common.c" + +#define NBMAX 4096 + +static void BF16GEMV_N_VSX_1(BLASLONG n, IFLOAT **ap, IFLOAT *xo, FLOAT *y, FLOAT alpha) +{ + IFLOAT *a0; + vec_bf16 zero = { 0, 0, 0, 0, 0, 0, 0, 0 }; + vec_f32 v_alpha = { alpha, alpha, alpha, alpha }; + + a0 = ap[0]; + + vec_bf16 *va0 = (vec_bf16 *)a0; + + vec_bf16 *x_bf = (vec_bf16 *)(xo); + vec_f32 x_0 = vec_loadNHi(x_bf, 1, zero); + x_0 *= v_alpha; + + vec_f32 v_x0 = vec_splat(x_0, 0); + + vec_f32 *v_y = (vec_f32 *)y; + BLASLONG n8 = n / 8; + BLASLONG i = 0; + vec_f32 vy0[2]; + + for (; i < n8; i++) { + vec_load_pair(vy0, &v_y[(i * 2) + 0]); + + vec_load_mult2(v_x0, &va0[i], zero, vy0); + + vec_store_pair(&v_y[(i * 2) + 0], vy0); + } + + n &= 7; + if (n > 4) { + BLASLONG n3 = n & 3; + vy0[0] = v_y[(i * 2) + 0]; + vy0[1] = vec_loadN_f32(&v_y[(i * 2) + 1], n3); + + vec_loadN_mult2(v_x0, &va0[i], n, zero, vy0); + + v_y[(i * 2) + 0] = vy0[0]; + vec_storeN_f32(vy0[1], &v_y[(i * 2) + 1], n3); + } else if (n) { + vec_f32 vy0 = vec_loadN_f32(&v_y[(i * 2) + 0], n); + + vy0 += vec_loadNHi_multi2(v_x0, &va0[i], n, zero); + + vec_storeN_f32(vy0, &v_y[(i * 2) + 0], n); + } +} + +static void BF16GEMV_N_VSX_2(BLASLONG n, IFLOAT **ap, IFLOAT *xo, FLOAT *y, FLOAT alpha) +{ + IFLOAT *a0, *a1; + vec_bf16 zero = { 0, 0, 0, 0, 0, 0, 0, 0 }; + vec_f32 v_alpha = { alpha, alpha, alpha, alpha }; + + a0 = ap[0]; + a1 = ap[1]; + + vec_bf16 *va0 = (vec_bf16 *)a0; + vec_bf16 *va1 = (vec_bf16 *)a1; + + vec_bf16 *x_bf = (vec_bf16 *)(xo); + vec_f32 x_0 = vec_loadNHi(x_bf, 2, zero); + x_0 *= v_alpha; + + vec_f32 v_x0 = vec_splat(x_0, 0); + vec_f32 v_x1 = vec_splat(x_0, 1); + + vec_f32 *v_y = (vec_f32 *)y; + BLASLONG n8 = n / 8; + BLASLONG i = 0; + vec_f32 vy0[2]; + + for (; i < n8; i++) { + vec_load_pair(vy0, &v_y[(i * 2) + 0]); + + vec_load_mult2(v_x0, &va0[i], zero, vy0); + vec_load_mult2(v_x1, &va1[i], zero, vy0); + + vec_store_pair(&v_y[(i * 2) + 0], vy0); + } + + n &= 7; + if (n > 4) { + BLASLONG n3 = n & 3; + vy0[0] = v_y[(i * 2) + 0]; + vy0[1] = vec_loadN_f32(&v_y[(i * 2) + 1], n3); + + vec_loadN_mult2(v_x0, &va0[i], n, zero, vy0); + vec_loadN_mult2(v_x1, &va1[i], n, zero, vy0); + + v_y[(i * 2) + 0] = vy0[0]; + vec_storeN_f32(vy0[1], &v_y[(i * 2) + 1], n3); + } else if (n) { + vec_f32 vy0 = vec_loadN_f32(&v_y[(i * 2) + 0], n); + + vy0 += vec_loadNHi_multi2(v_x0, &va0[i], n, zero); + vy0 += vec_loadNHi_multi2(v_x1, &va1[i], n, zero); + + vec_storeN_f32(vy0, &v_y[(i * 2) + 0], n); + } +} + +static void BF16GEMV_N_VSX_4(BLASLONG n, IFLOAT **ap, IFLOAT *xo, FLOAT *y, FLOAT alpha) +{ + IFLOAT *a0, *a1, *a2, *a3; + vec_bf16 zero = { 0, 0, 0, 0, 0, 0, 0, 0 }; + vec_f32 v_alpha = { alpha, alpha, alpha, alpha }; + + a0 = ap[0]; + a1 = ap[1]; + a2 = ap[2]; + a3 = ap[3]; + + vec_bf16 *va0 = (vec_bf16 *)a0; + vec_bf16 *va1 = (vec_bf16 *)a1; + vec_bf16 *va2 = (vec_bf16 *)a2; + vec_bf16 *va3 = (vec_bf16 *)a3; + + vec_bf16 *x_bf = 
(vec_bf16 *)(xo); + vec_f32 x_0 = vec_loadNHi(x_bf, 4, zero); + x_0 *= v_alpha; + + vec_f32 v_x0 = vec_splat(x_0, 0); + vec_f32 v_x1 = vec_splat(x_0, 1); + vec_f32 v_x2 = vec_splat(x_0, 2); + vec_f32 v_x3 = vec_splat(x_0, 3); + + vec_f32 *v_y = (vec_f32 *)y; + BLASLONG n8 = n / 8; + BLASLONG i = 0; + vec_f32 vy0[2]; + + for (; i < n8; i++) { + vec_load_pair(vy0, &v_y[(i * 2) + 0]); + + vec_load_mult2(v_x0, &va0[i], zero, vy0); + vec_load_mult2(v_x1, &va1[i], zero, vy0); + vec_load_mult2(v_x2, &va2[i], zero, vy0); + vec_load_mult2(v_x3, &va3[i], zero, vy0); + + vec_store_pair(&v_y[(i * 2) + 0], vy0); + } + + n &= 7; + if (n > 4) { + BLASLONG n3 = n & 3; + vy0[0] = v_y[(i * 2) + 0]; + vy0[1] = vec_loadN_f32(&v_y[(i * 2) + 1], n3); + + vec_loadN_mult2(v_x0, &va0[i], n, zero, vy0); + vec_loadN_mult2(v_x1, &va1[i], n, zero, vy0); + vec_loadN_mult2(v_x2, &va2[i], n, zero, vy0); + vec_loadN_mult2(v_x3, &va3[i], n, zero, vy0); + + v_y[(i * 2) + 0] = vy0[0]; + vec_storeN_f32(vy0[1], &v_y[(i * 2) + 1], n3); + } else if (n) { + vec_f32 vy0 = vec_loadN_f32(&v_y[(i * 2) + 0], n); + + vy0 += vec_loadNHi_multi2(v_x0, &va0[i], n, zero); + vy0 += vec_loadNHi_multi2(v_x1, &va1[i], n, zero); + vy0 += vec_loadNHi_multi2(v_x2, &va2[i], n, zero); + vy0 += vec_loadNHi_multi2(v_x3, &va3[i], n, zero); + + vec_storeN_f32(vy0, &v_y[(i * 2) + 0], n); + } +} + +static void BF16GEMV_N_VSX_8(BLASLONG n, IFLOAT **ap, IFLOAT *xo, FLOAT *y, BLASLONG lda4, FLOAT alpha) +{ + IFLOAT *a0, *a1, *a2, *a3, *b0, *b1, *b2, *b3; + vec_bf16 zero = { 0, 0, 0, 0, 0, 0, 0, 0 }; + vec_f32 v_alpha = { alpha, alpha, alpha, alpha }; + + a0 = ap[0]; + a1 = ap[1]; + a2 = ap[2]; + a3 = ap[3]; + b0 = a0 + lda4; + b1 = a1 + lda4; + b2 = a2 + lda4; + b3 = a3 + lda4; + + vec_bf16 *va0 = (vec_bf16 *)a0; + vec_bf16 *va1 = (vec_bf16 *)a1; + vec_bf16 *va2 = (vec_bf16 *)a2; + vec_bf16 *va3 = (vec_bf16 *)a3; + vec_bf16 *vb0 = (vec_bf16 *)b0; + vec_bf16 *vb1 = (vec_bf16 *)b1; + vec_bf16 *vb2 = (vec_bf16 *)b2; + vec_bf16 *vb3 = (vec_bf16 *)b3; + + vec_bf16 *x_bf = (vec_bf16 *)(xo); + vec_bf16 x_in = (vec_bf16)vec_load_vec(x_bf); + vec_f32 x_0 = BF16_HI(x_in, zero); + vec_f32 x_1 = BF16_LO(x_in, zero); + x_0 *= v_alpha; + x_1 *= v_alpha; + + vec_f32 v_x0 = vec_splat(x_0, 0); + vec_f32 v_x1 = vec_splat(x_0, 1); + vec_f32 v_x2 = vec_splat(x_0, 2); + vec_f32 v_x3 = vec_splat(x_0, 3); + vec_f32 v_x4 = vec_splat(x_1, 0); + vec_f32 v_x5 = vec_splat(x_1, 1); + vec_f32 v_x6 = vec_splat(x_1, 2); + vec_f32 v_x7 = vec_splat(x_1, 3); + + vec_f32 *v_y = (vec_f32 *)y; + BLASLONG n8 = n / 8; + BLASLONG i = 0; + vec_f32 vy0[2]; + + for (; i < n8; i++) { + vec_load_pair(vy0, &v_y[(i * 2) + 0]); + + vec_load_mult2(v_x0, &va0[i], zero, vy0); + vec_load_mult2(v_x1, &va1[i], zero, vy0); + vec_load_mult2(v_x2, &va2[i], zero, vy0); + vec_load_mult2(v_x3, &va3[i], zero, vy0); + vec_load_mult2(v_x4, &vb0[i], zero, vy0); + vec_load_mult2(v_x5, &vb1[i], zero, vy0); + vec_load_mult2(v_x6, &vb2[i], zero, vy0); + vec_load_mult2(v_x7, &vb3[i], zero, vy0); + + vec_store_pair(&v_y[(i * 2) + 0], vy0); + } + + n &= 7; + if (n > 4) { + BLASLONG n3 = n & 3; + vy0[0] = v_y[(i * 2) + 0]; + vy0[1] = vec_loadN_f32(&v_y[(i * 2) + 1], n3); + + vec_loadN_mult2(v_x0, &va0[i], n, zero, vy0); + vec_loadN_mult2(v_x1, &va1[i], n, zero, vy0); + vec_loadN_mult2(v_x2, &va2[i], n, zero, vy0); + vec_loadN_mult2(v_x3, &va3[i], n, zero, vy0); + vec_loadN_mult2(v_x4, &vb0[i], n, zero, vy0); + vec_loadN_mult2(v_x5, &vb1[i], n, zero, vy0); + vec_loadN_mult2(v_x6, &vb2[i], n, zero, vy0); + vec_loadN_mult2(v_x7, 
&vb3[i], n, zero, vy0); + + v_y[(i * 2) + 0] = vy0[0]; + vec_storeN_f32(vy0[1], &v_y[(i * 2) + 1], n3); + } else + if (n) { + vec_f32 vy0 = vec_loadN_f32(&v_y[(i * 2) + 0], n); + + vy0 += vec_loadNHi_multi2(v_x0, &va0[i], n, zero); + vy0 += vec_loadNHi_multi2(v_x1, &va1[i], n, zero); + vy0 += vec_loadNHi_multi2(v_x2, &va2[i], n, zero); + vy0 += vec_loadNHi_multi2(v_x3, &va3[i], n, zero); + vy0 += vec_loadNHi_multi2(v_x4, &vb0[i], n, zero); + vy0 += vec_loadNHi_multi2(v_x5, &vb1[i], n, zero); + vy0 += vec_loadNHi_multi2(v_x6, &vb2[i], n, zero); + vy0 += vec_loadNHi_multi2(v_x7, &vb3[i], n, zero); + + vec_storeN_f32(vy0, &v_y[(i * 2) + 0], n); + } +} + +#define BF16GEMV_N_8 BF16GEMV_N_VSX_8 +#define BF16GEMV_N_4 BF16GEMV_N_VSX_4 +#define BF16GEMV_N_2 BF16GEMV_N_VSX_2 +#define BF16GEMV_N_1 BF16GEMV_N_VSX_1 + +#include "sbgemv_n.c" +#endif diff --git a/kernel/power/sbgemv_t.c b/kernel/power/sbgemv_t.c new file mode 100644 index 0000000000..f0c79fe77a --- /dev/null +++ b/kernel/power/sbgemv_t.c @@ -0,0 +1,117 @@ +/*************************************************************************** +Copyright (c) 2024, The OpenBLAS Project +All rights reserved. +Redistribution and use in source and binary forms, with or without +modification, are permitted provided that the following conditions are +met: +1. Redistributions of source code must retain the above copyright +notice, this list of conditions and the following disclaimer. +2. Redistributions in binary form must reproduce the above copyright +notice, this list of conditions and the following disclaimer in +the documentation and/or other materials provided with the +distribution. +3. Neither the name of the OpenBLAS project nor the names of +its contributors may be used to endorse or promote products +derived from this software without specific prior written permission. +THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" +AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE +IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE +ARE DISCLAIMED. IN NO EVENT SHALL THE OPENBLAS PROJECT OR CONTRIBUTORS BE +LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL +DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR +SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER +CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, +OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE +USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. 
+*****************************************************************************/ + +#ifndef SBGEMV_T_COMMON_C +#define SBGEMV_T_COMMON_C +int CNAME(BLASLONG m, BLASLONG n, FLOAT alpha, IFLOAT *a, BLASLONG lda, IFLOAT *x, BLASLONG inc_x, FLOAT beta, FLOAT *y, BLASLONG inc_y) +{ + IFLOAT *xbuffer, *a_ptr; + IFLOAT buffer[NBMAX] __attribute__((aligned(16))); + FLOAT ybuffer[8] __attribute__((aligned(16))); + FLOAT *y_ptr; + + if ((m < 1) || (n < 1)) return 0; + + xbuffer = buffer; + + BLASLONG lda4 = lda << 2; + BLASLONG lda8 = lda << 3; + BLASLONG NB = NBMAX; + BLASLONG m2 = (m & (NBMAX - 1)); + + while (NB == NBMAX) { + m -= NB; + if (m < 0) { + if (m2 == 0) break; + NB = m2; + } + + a_ptr = a; + y_ptr = y; + + if (inc_x != 1) { + copy_x(NB, x, xbuffer, inc_x); + } else { + xbuffer = x; + } + + if (inc_y == 1) { + for (BLASLONG j = 0; j + 8 <= n; j += 8) { + BF16GEMV_T_8(NB, lda, a_ptr, xbuffer, y_ptr, alpha, beta); + y_ptr += 8; + a_ptr += lda8; + } + if (n & 4) { + BF16GEMV_T_4(NB, lda, a_ptr, xbuffer, y_ptr, alpha, beta); + y_ptr += 4; + a_ptr += lda4; + } + if (n & 2) { + BF16GEMV_T_2(NB, lda, a_ptr, xbuffer, y_ptr, alpha, beta); + y_ptr += 2; + a_ptr += (lda * 2); + } + if (n & 1) { + BF16GEMV_T_1(NB, lda, a_ptr, xbuffer, y_ptr, alpha, beta); + } + } else { + for (BLASLONG j = 0; j + 8 <= n; j += 8) { + memset(ybuffer, 0, sizeof(FLOAT) * 8); + BF16GEMV_T_8(NB, lda, a_ptr, xbuffer, ybuffer, alpha, beta); + copy_y(8, ybuffer, y_ptr, inc_y, beta); + y_ptr += 8 * inc_y; + a_ptr += lda8; + } + if (n & 4) { + memset(ybuffer, 0, sizeof(FLOAT) * 4); + BF16GEMV_T_4(NB, lda, a_ptr, xbuffer, ybuffer, alpha, beta); + copy_y(4, ybuffer, y_ptr, inc_y, beta); + y_ptr += 4 * inc_y; + a_ptr += lda4; + } + if (n & 2) { + memset(ybuffer, 0, sizeof(FLOAT) * 4); + BF16GEMV_T_2(NB, lda, a_ptr, xbuffer, ybuffer, alpha, beta); + copy_y(2, ybuffer, y_ptr, inc_y, beta); + y_ptr += 2 * inc_y; + a_ptr += (lda * 2); + } + if (n & 1) { + memset(ybuffer, 0, sizeof(FLOAT) * 4); + BF16GEMV_T_1(NB, lda, a_ptr, xbuffer, ybuffer, alpha, beta); + copy_y(1, ybuffer, y_ptr, inc_y, beta); + } + } + + a += NB; + x += NB * inc_x; + } + + return 0; +} +#endif + diff --git a/kernel/power/sbgemv_t_power10.c b/kernel/power/sbgemv_t_power10.c new file mode 100644 index 0000000000..08bc4237c7 --- /dev/null +++ b/kernel/power/sbgemv_t_power10.c @@ -0,0 +1,32 @@ +/*************************************************************************** +Copyright (c) 2024, The OpenBLAS Project +All rights reserved. +Redistribution and use in source and binary forms, with or without +modification, are permitted provided that the following conditions are +met: +1. Redistributions of source code must retain the above copyright +notice, this list of conditions and the following disclaimer. +2. Redistributions in binary form must reproduce the above copyright +notice, this list of conditions and the following disclaimer in +the documentation and/or other materials provided with the +distribution. +3. Neither the name of the OpenBLAS project nor the names of +its contributors may be used to endorse or promote products +derived from this software without specific prior written permission. +THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" +AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE +IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE +ARE DISCLAIMED. 
IN NO EVENT SHALL THE OPENBLAS PROJECT OR CONTRIBUTORS BE +LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL +DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR +SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER +CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, +OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE +USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. +*****************************************************************************/ + +//#include "sbgemv_common.c" + +#include "sbgemv_t_vsx.c" + +//#include "sbgemv_t.c" diff --git a/kernel/power/sbgemv_t_vsx.c b/kernel/power/sbgemv_t_vsx.c new file mode 100644 index 0000000000..7da894109b --- /dev/null +++ b/kernel/power/sbgemv_t_vsx.c @@ -0,0 +1,286 @@ +/*************************************************************************** +Copyright (c) 2024, The OpenBLAS Project +All rights reserved. +Redistribution and use in source and binary forms, with or without +modification, are permitted provided that the following conditions are +met: +1. Redistributions of source code must retain the above copyright +notice, this list of conditions and the following disclaimer. +2. Redistributions in binary form must reproduce the above copyright +notice, this list of conditions and the following disclaimer in +the documentation and/or other materials provided with the +distribution. +3. Neither the name of the OpenBLAS project nor the names of +its contributors may be used to endorse or promote products +derived from this software without specific prior written permission. +THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" +AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE +IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE +ARE DISCLAIMED. IN NO EVENT SHALL THE OPENBLAS PROJECT OR CONTRIBUTORS BE +LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL +DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR +SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER +CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, +OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE +USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. 
+*****************************************************************************/ + +#ifndef SBGEMV_T_VSX +#define SBGEMV_T_VSX + +#include "sbgemv_common.c" + +#define NBMAX 4096 + +static void BF16GEMV_T_VSX_1(BLASLONG n, BLASLONG lda, IFLOAT *ap, IFLOAT *x, FLOAT *y, FLOAT alpha, FLOAT beta) +{ + IFLOAT *a0; + vec_bf16 *va0, *v_x; + vec_f32 temp0 = { 0, 0, 0, 0 }; + vec_bf16 zero = { 0, 0, 0, 0, 0, 0, 0, 0 }; + vec_f32 inp[2]; + + a0 = ap; + va0 = (vec_bf16 *)a0; + v_x = (vec_bf16 *)x; + BLASLONG n8 = n / 8; + BLASLONG i = 0; + + for (; i < n8; i++) { + vec_load_vec2(v_x, i, inp, zero); + + temp0 += vec_load_mult(&va0[i], inp, zero); + } + + n &= 7; + if (n > 4) { + vec_loadN_vec2(v_x, i, inp, n, zero); + + temp0 += vec_loadN_mult(&va0[i], inp, n, zero); + } else if (n) { + vec_f32 v_inp0 = vec_loadNHi_vec(v_x, i, n, zero); + + temp0 += vec_loadNHi_mult(&va0[i], v_inp0, n, zero); + } + + y[0] = (alpha * (temp0[0] + temp0[1] + temp0[2] + temp0[3])) + (beta * y[0]); +} + +static void BF16GEMV_T_VSX_2(BLASLONG n, BLASLONG lda, IFLOAT *ap, IFLOAT *x, FLOAT *y, FLOAT alpha, FLOAT beta) +{ + IFLOAT *a0, *a1; + vec_bf16 *va0, *va1, *v_x; + vec_f32 temp0 = { 0, 0, 0, 0 }; + vec_f32 temp1 = { 0, 0, 0, 0 }; + vec_bf16 zero = { 0, 0, 0, 0, 0, 0, 0, 0 }; + vec_f32 inp[2]; + + a0 = ap; + a1 = ap + lda; + va0 = (vec_bf16 *)a0; + va1 = (vec_bf16 *)a1; + v_x = (vec_bf16 *)x; + BLASLONG n8 = n / 8; + BLASLONG i = 0; + + for (; i < n8; i++) { + vec_load_vec2(v_x, i, inp, zero); + + temp0 += vec_load_mult(&va0[i], inp, zero); + temp1 += vec_load_mult(&va1[i], inp, zero); + } + + n &= 7; + if (n > 4) { + vec_loadN_vec2(v_x, i, inp, n, zero); + + temp0 += vec_loadN_mult(&va0[i], inp, n, zero); + temp1 += vec_loadN_mult(&va1[i], inp, n, zero); + } else if (n) { + vec_f32 v_inp0 = vec_loadNHi_vec(v_x, i, n, zero); + + temp0 += vec_loadNHi_mult(&va0[i], v_inp0, n, zero); + temp1 += vec_loadNHi_mult(&va1[i], v_inp0, n, zero); + } + + y[0] = (alpha * (temp0[0] + temp0[1] + temp0[2] + temp0[3])) + (beta * y[0]); + y[1] = (alpha * (temp1[0] + temp1[1] + temp1[2] + temp1[3])) + (beta * y[1]); +} + +static void BF16GEMV_T_VSX_4(BLASLONG n, BLASLONG lda, IFLOAT *ap, IFLOAT *x, FLOAT *y, FLOAT alpha, FLOAT beta) +{ + IFLOAT *a0, *a1, *a2, *a3; + vec_bf16 *va0, *va1, *va2, *va3, *v_x; + vec_f32 temp0 = { 0, 0, 0, 0 }; + vec_f32 temp1 = { 0, 0, 0, 0 }; + vec_f32 temp2 = { 0, 0, 0, 0 }; + vec_f32 temp3 = { 0, 0, 0, 0 }; + vec_bf16 zero = { 0, 0, 0, 0, 0, 0, 0, 0 }; + vec_f32 inp[2]; + + a0 = ap; + a1 = ap + lda; + a2 = a1 + lda; + a3 = a2 + lda; + va0 = (vec_bf16 *)a0; + va1 = (vec_bf16 *)a1; + va2 = (vec_bf16 *)a2; + va3 = (vec_bf16 *)a3; + v_x = (vec_bf16 *)x; + BLASLONG n8 = n / 8; + BLASLONG i = 0; + + for (; i < n8; i++) { + vec_load_vec2(v_x, i, inp, zero); + + temp0 += vec_load_mult(&va0[i], inp, zero); + temp1 += vec_load_mult(&va1[i], inp, zero); + temp2 += vec_load_mult(&va2[i], inp, zero); + temp3 += vec_load_mult(&va3[i], inp, zero); + } + + n &= 7; + if (n > 4) { + vec_loadN_vec2(v_x, i, inp, n, zero); + + temp0 += vec_loadN_mult(&va0[i], inp, n, zero); + temp1 += vec_loadN_mult(&va1[i], inp, n, zero); + temp2 += vec_loadN_mult(&va2[i], inp, n, zero); + temp3 += vec_loadN_mult(&va3[i], inp, n, zero); + } else if (n) { + vec_f32 v_inp0 = vec_loadNHi_vec(v_x, i, n, zero); + + temp0 += vec_loadNHi_mult(&va0[i], v_inp0, n, zero); + temp1 += vec_loadNHi_mult(&va1[i], v_inp0, n, zero); + temp2 += vec_loadNHi_mult(&va2[i], v_inp0, n, zero); + temp3 += vec_loadNHi_mult(&va3[i], v_inp0, n, zero); + } + + vec_f32 t0, 
t1, t2, t3; + vec_f32 a = { alpha, alpha, alpha, alpha }; + vec_f32 b = { beta, beta, beta, beta }; + vec_f32 *v_y = (vec_f32 *) y; + + t0 = vec_mergeh(temp0, temp2); + t1 = vec_mergel(temp0, temp2); + t2 = vec_mergeh(temp1, temp3); + t3 = vec_mergel(temp1, temp3); + temp0 = vec_mergeh(t0, t2); + temp1 = vec_mergel(t0, t2); + temp2 = vec_mergeh(t1, t3); + temp3 = vec_mergel(t1, t3); + temp0 += temp1 + temp2 + temp3; + + v_y[0] = (a * temp0) + (b * v_y[0]); +} + +static void BF16GEMV_T_VSX_8(BLASLONG n, BLASLONG lda, IFLOAT *ap, IFLOAT *x, FLOAT *y, FLOAT alpha, FLOAT beta) +{ + IFLOAT *a0, *a1, *a2, *a3, *a4, *a5, *a6, *a7; + vec_bf16 *va0, *va1, *va2, *va3, *va4, *va5, *va6, *va7, *v_x; + vec_f32 temp0 = { 0, 0, 0, 0 }; + vec_f32 temp1 = { 0, 0, 0, 0 }; + vec_f32 temp2 = { 0, 0, 0, 0 }; + vec_f32 temp3 = { 0, 0, 0, 0 }; + vec_f32 temp4 = { 0, 0, 0, 0 }; + vec_f32 temp5 = { 0, 0, 0, 0 }; + vec_f32 temp6 = { 0, 0, 0, 0 }; + vec_f32 temp7 = { 0, 0, 0, 0 }; + vec_bf16 zero = { 0, 0, 0, 0, 0, 0, 0, 0 }; + vec_f32 inp[2]; + + a0 = ap; + a1 = ap + lda; + a2 = a1 + lda; + a3 = a2 + lda; + a4 = a3 + lda; + a5 = a4 + lda; + a6 = a5 + lda; + a7 = a6 + lda; + va0 = (vec_bf16 *)a0; + va1 = (vec_bf16 *)a1; + va2 = (vec_bf16 *)a2; + va3 = (vec_bf16 *)a3; + va4 = (vec_bf16 *)a4; + va5 = (vec_bf16 *)a5; + va6 = (vec_bf16 *)a6; + va7 = (vec_bf16 *)a7; + v_x = (vec_bf16 *)x; + BLASLONG n8 = n / 8; + BLASLONG i = 0; + + for (; i < n8; i++) { + vec_load_vec2(v_x, i, inp, zero); + + temp0 += vec_load_mult(&va0[i], inp, zero); + temp1 += vec_load_mult(&va1[i], inp, zero); + temp2 += vec_load_mult(&va2[i], inp, zero); + temp3 += vec_load_mult(&va3[i], inp, zero); + temp4 += vec_load_mult(&va4[i], inp, zero); + temp5 += vec_load_mult(&va5[i], inp, zero); + temp6 += vec_load_mult(&va6[i], inp, zero); + temp7 += vec_load_mult(&va7[i], inp, zero); + } + + n &= 7; + if (n > 4) { + vec_loadN_vec2(v_x, i, inp, n, zero); + + temp0 += vec_loadN_mult(&va0[i], inp, n, zero); + temp1 += vec_loadN_mult(&va1[i], inp, n, zero); + temp2 += vec_loadN_mult(&va2[i], inp, n, zero); + temp3 += vec_loadN_mult(&va3[i], inp, n, zero); + temp4 += vec_loadN_mult(&va4[i], inp, n, zero); + temp5 += vec_loadN_mult(&va5[i], inp, n, zero); + temp6 += vec_loadN_mult(&va6[i], inp, n, zero); + temp7 += vec_loadN_mult(&va7[i], inp, n, zero); + } else if (n) { + vec_f32 v_inp0 = vec_loadNHi_vec(v_x, i, n, zero); + + temp0 += vec_loadNHi_mult(&va0[i], v_inp0, n, zero); + temp1 += vec_loadNHi_mult(&va1[i], v_inp0, n, zero); + temp2 += vec_loadNHi_mult(&va2[i], v_inp0, n, zero); + temp3 += vec_loadNHi_mult(&va3[i], v_inp0, n, zero); + temp4 += vec_loadNHi_mult(&va4[i], v_inp0, n, zero); + temp5 += vec_loadNHi_mult(&va5[i], v_inp0, n, zero); + temp6 += vec_loadNHi_mult(&va6[i], v_inp0, n, zero); + temp7 += vec_loadNHi_mult(&va7[i], v_inp0, n, zero); + } + + vec_f32 t0, t1, t2, t3; + vec_f32 a = { alpha, alpha, alpha, alpha }; + vec_f32 b = { beta, beta, beta, beta }; + vec_f32 *v_y = (vec_f32 *) y; + + t0 = vec_mergeh(temp0, temp2); + t1 = vec_mergel(temp0, temp2); + t2 = vec_mergeh(temp1, temp3); + t3 = vec_mergel(temp1, temp3); + temp0 = vec_mergeh(t0, t2); + temp1 = vec_mergel(t0, t2); + temp2 = vec_mergeh(t1, t3); + temp3 = vec_mergel(t1, t3); + temp0 += temp1 + temp2 + temp3; + + t0 = vec_mergeh(temp4, temp6); + t1 = vec_mergel(temp4, temp6); + t2 = vec_mergeh(temp5, temp7); + t3 = vec_mergel(temp5, temp7); + temp4 = vec_mergeh(t0, t2); + temp5 = vec_mergel(t0, t2); + temp6 = vec_mergeh(t1, t3); + temp7 = vec_mergel(t1, t3); + temp4 += temp5 + 
temp6 + temp7; + + v_y[0] = (a * temp0) + (b * v_y[0]); + v_y[1] = (a * temp4) + (b * v_y[1]); +} + +#define BF16GEMV_T_8 BF16GEMV_T_VSX_8 +#define BF16GEMV_T_4 BF16GEMV_T_VSX_4 +#define BF16GEMV_T_2 BF16GEMV_T_VSX_2 +#define BF16GEMV_T_1 BF16GEMV_T_VSX_1 + +#include "sbgemv_t.c" +#endif + diff --git a/test/compare_sgemm_sbgemm.c b/test/compare_sgemm_sbgemm.c index b8aaee8be3..a86c73d1c5 100644 --- a/test/compare_sgemm_sbgemm.c +++ b/test/compare_sgemm_sbgemm.c @@ -202,16 +202,18 @@ main (int argc, char *argv[]) return ret; } + for (l = 0; l < 2; l++) { // l = 1 to test inc_x & inc_y not equal to one. for (x = 1; x <= loop; x++) { - k = (x == 0) ? 0 : 1; + m = l + 1; + k = (x == 0) ? 0 : m; float *A = (float *)malloc_safe(x * x * sizeof(FLOAT)); - float *B = (float *)malloc_safe(x * sizeof(FLOAT)); - float *C = (float *)malloc_safe(x * sizeof(FLOAT)); + float *B = (float *)malloc_safe(x * sizeof(FLOAT) * m); + float *C = (float *)malloc_safe(x * sizeof(FLOAT) * m); bfloat16_bits *AA = (bfloat16_bits *)malloc_safe(x * x * sizeof(bfloat16_bits)); - bfloat16_bits *BB = (bfloat16_bits *)malloc_safe(x * sizeof(bfloat16_bits)); + bfloat16_bits *BB = (bfloat16_bits *)malloc_safe(x * sizeof(bfloat16_bits) * m); float *DD = (float *)malloc_safe(x * sizeof(FLOAT)); - float *CC = (float *)malloc_safe(x * sizeof(FLOAT)); + float *CC = (float *)malloc_safe(x * sizeof(FLOAT) * m); if ((A == NULL) || (B == NULL) || (C == NULL) || (AA == NULL) || (BB == NULL) || (DD == NULL) || (CC == NULL)) return 1; @@ -226,9 +228,9 @@ main (int argc, char *argv[]) sbstobf16_(&one, &A[j*x+i], &one, &atmp, &one); AA[j * x + i].v = atmp; } - B[j] = ((FLOAT) rand () / (FLOAT) RAND_MAX) + 0.5; - sbstobf16_(&one, &B[j], &one, &btmp, &one); - BB[j].v = btmp; + B[j*m] = ((FLOAT) rand () / (FLOAT) RAND_MAX) + 0.5; + sbstobf16_(&one, &B[j*m], &one, &btmp, &one); + BB[j*m].v = btmp; } for (y = 0; y < 2; y++) { @@ -238,9 +240,9 @@ main (int argc, char *argv[]) transA = 'T'; } - memset(CC, 0, x * sizeof(FLOAT)); + memset(CC, 0, x * m * sizeof(FLOAT)); memset(DD, 0, x * sizeof(FLOAT)); - memset(C, 0, x * sizeof(FLOAT)); + memset(C, 0, x * m * sizeof(FLOAT)); SGEMV (&transA, &x, &x, &alpha, A, &x, B, &k, &beta, C, &k); SBGEMV (&transA, &x, &x, &alpha, (bfloat16*) AA, &x, (bfloat16*) BB, &k, &beta, CC, &k); @@ -248,15 +250,15 @@ main (int argc, char *argv[]) for (j = 0; j < x; j++) for (i = 0; i < x; i++) if (transA == 'N') { - DD[i] += float16to32 (AA[j * x + i]) * float16to32 (BB[j]); + DD[i] += float16to32 (AA[j * x + i]) * float16to32 (BB[j*m]); } else if (transA == 'T') { - DD[j] += float16to32 (AA[j * x + i]) * float16to32 (BB[i]); + DD[j] += float16to32 (AA[j * x + i]) * float16to32 (BB[i*m]); } for (j = 0; j < x; j++) { - if (fabs (CC[j] - C[j]) > 1.0) + if (fabs (CC[j*m] - C[j*m]) > 1.0) ret++; - if (fabs (CC[j] - DD[j]) > 1.0) + if (fabs (CC[j*m] - DD[j]) > 1.0) ret++; } } @@ -268,6 +270,7 @@ main (int argc, char *argv[]) free(DD); free(CC); } + } if (ret != 0) fprintf (stderr, "FATAL ERROR SBGEMV - Return code: %d\n", ret); From 8541b25e1d755f7e05594547184bd88bda23a5af Mon Sep 17 00:00:00 2001 From: Chip Kerchner Date: Fri, 6 Sep 2024 14:48:48 -0500 Subject: [PATCH 02/24] Special case beta is one. 
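
The fast paths added below to copy_y_beta() and copy_y() exploit that the general update y[i] = s[i] + beta * y[i] needs no multiply when beta is exactly 0 or 1. A scalar sketch mirroring the copy_y() case, for illustration only (scale_and_add is a hypothetical name, not part of the patch):

static void scale_and_add(BLASLONG n, const FLOAT *s, FLOAT *y, FLOAT beta)
{
  if (beta == 0) {
    for (BLASLONG i = 0; i < n; i++) y[i] = s[i];               /* plain copy */
  } else if (beta == 1) {
    for (BLASLONG i = 0; i < n; i++) y[i] += s[i];              /* plain accumulate */
  } else {
    for (BLASLONG i = 0; i < n; i++) y[i] = s[i] + beta * y[i]; /* general case */
  }
}
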
--- kernel/power/sbgemv_common.c | 10 ++++++++++ kernel/power/sbgemv_n.c | 2 ++ 2 files changed, 12 insertions(+) diff --git a/kernel/power/sbgemv_common.c b/kernel/power/sbgemv_common.c index 2aadcca6ff..b11ab59de8 100644 --- a/kernel/power/sbgemv_common.c +++ b/kernel/power/sbgemv_common.c @@ -252,6 +252,11 @@ FORCEINLINE void copy_y_beta(BLASLONG n, FLOAT *src, FLOAT *dest, BLASLONG inc_s { if (beta == 0) { memset(dest, 0, sizeof(FLOAT) * n); + } else if (beta == 1) { + for (BLASLONG i = 0; i < n; i++) { + *dest++ = *src; + src += inc_src; + } } else { for (BLASLONG i = 0; i < n; i++) { *dest++ = *src * beta; @@ -267,6 +272,11 @@ FORCEINLINE void copy_y(BLASLONG n, FLOAT *src, FLOAT *dest, BLASLONG inc_src, F *dest = *src++; dest += inc_src; } + } else if (beta == 1) { + for (BLASLONG i = 0; i < n; i++) { + *dest += *src++; + dest += inc_src; + } } else { for (BLASLONG i = 0; i < n; i++) { *dest = *src++ + (beta * *dest); diff --git a/kernel/power/sbgemv_n.c b/kernel/power/sbgemv_n.c index 854ad93ee2..db64915e05 100644 --- a/kernel/power/sbgemv_n.c +++ b/kernel/power/sbgemv_n.c @@ -31,6 +31,8 @@ static void BF16GEMV_N_beta(BLASLONG n, FLOAT *output_vector, FLOAT *input_vecto { if (beta == 0) { memset(output_vector, 0, sizeof(FLOAT) * n); + } else if ((output_vector != input_vector) && (beta == 1)) { + memcpy(output_vector, input_vector, sizeof(FLOAT) * n); } else { vec_f32 b = { beta, beta, beta, beta }; From 39fd29f1de36763c77c8bfe5acb8a6337046f748 Mon Sep 17 00:00:00 2001 From: Chip Kerchner Date: Sun, 8 Sep 2024 18:28:31 -0500 Subject: [PATCH 03/24] Minor improvement and turn off BF16 GEMV forwarding by default. --- Makefile.system | 1 - kernel/power/sbgemv_n.c | 6 ++++-- test/compare_sgemm_sbgemm.c | 29 ++++++++++++++--------------- 3 files changed, 18 insertions(+), 18 deletions(-) diff --git a/Makefile.system b/Makefile.system index 8c030842a4..2c5ca96906 100644 --- a/Makefile.system +++ b/Makefile.system @@ -282,7 +282,6 @@ GEMM_GEMV_FORWARD = 1 endif ifeq ($(ARCH), power) GEMM_GEMV_FORWARD = 1 -GEMM_GEMV_FORWARD_BF16 = 1 endif ifeq ($(SMALL_MATRIX_OPT), 1) diff --git a/kernel/power/sbgemv_n.c b/kernel/power/sbgemv_n.c index db64915e05..fa7df858f8 100644 --- a/kernel/power/sbgemv_n.c +++ b/kernel/power/sbgemv_n.c @@ -31,8 +31,10 @@ static void BF16GEMV_N_beta(BLASLONG n, FLOAT *output_vector, FLOAT *input_vecto { if (beta == 0) { memset(output_vector, 0, sizeof(FLOAT) * n); - } else if ((output_vector != input_vector) && (beta == 1)) { - memcpy(output_vector, input_vector, sizeof(FLOAT) * n); + } else if (beta == 1) { + if (output_vector != input_vector) { + memcpy(output_vector, input_vector, sizeof(FLOAT) * n); + } } else { vec_f32 b = { beta, beta, beta, beta }; diff --git a/test/compare_sgemm_sbgemm.c b/test/compare_sgemm_sbgemm.c index a86c73d1c5..05d9b33aba 100644 --- a/test/compare_sgemm_sbgemm.c +++ b/test/compare_sgemm_sbgemm.c @@ -205,15 +205,14 @@ main (int argc, char *argv[]) for (l = 0; l < 2; l++) { // l = 1 to test inc_x & inc_y not equal to one. for (x = 1; x <= loop; x++) { - m = l + 1; - k = (x == 0) ? 0 : m; + k = (x == 0) ? 
0 : l + 1; float *A = (float *)malloc_safe(x * x * sizeof(FLOAT)); - float *B = (float *)malloc_safe(x * sizeof(FLOAT) * m); - float *C = (float *)malloc_safe(x * sizeof(FLOAT) * m); + float *B = (float *)malloc_safe(x * sizeof(FLOAT) << l); + float *C = (float *)malloc_safe(x * sizeof(FLOAT) << l); bfloat16_bits *AA = (bfloat16_bits *)malloc_safe(x * x * sizeof(bfloat16_bits)); - bfloat16_bits *BB = (bfloat16_bits *)malloc_safe(x * sizeof(bfloat16_bits) * m); + bfloat16_bits *BB = (bfloat16_bits *)malloc_safe(x * sizeof(bfloat16_bits) << l); float *DD = (float *)malloc_safe(x * sizeof(FLOAT)); - float *CC = (float *)malloc_safe(x * sizeof(FLOAT) * m); + float *CC = (float *)malloc_safe(x * sizeof(FLOAT) << l); if ((A == NULL) || (B == NULL) || (C == NULL) || (AA == NULL) || (BB == NULL) || (DD == NULL) || (CC == NULL)) return 1; @@ -228,9 +227,9 @@ main (int argc, char *argv[]) sbstobf16_(&one, &A[j*x+i], &one, &atmp, &one); AA[j * x + i].v = atmp; } - B[j*m] = ((FLOAT) rand () / (FLOAT) RAND_MAX) + 0.5; - sbstobf16_(&one, &B[j*m], &one, &btmp, &one); - BB[j*m].v = btmp; + B[j << l] = ((FLOAT) rand () / (FLOAT) RAND_MAX) + 0.5; + sbstobf16_(&one, &B[j << l], &one, &btmp, &one); + BB[j << l].v = btmp; } for (y = 0; y < 2; y++) { @@ -240,9 +239,9 @@ main (int argc, char *argv[]) transA = 'T'; } - memset(CC, 0, x * m * sizeof(FLOAT)); + memset(CC, 0, x * sizeof(FLOAT) << l); memset(DD, 0, x * sizeof(FLOAT)); - memset(C, 0, x * m * sizeof(FLOAT)); + memset(C, 0, x * sizeof(FLOAT) << l); SGEMV (&transA, &x, &x, &alpha, A, &x, B, &k, &beta, C, &k); SBGEMV (&transA, &x, &x, &alpha, (bfloat16*) AA, &x, (bfloat16*) BB, &k, &beta, CC, &k); @@ -250,15 +249,15 @@ main (int argc, char *argv[]) for (j = 0; j < x; j++) for (i = 0; i < x; i++) if (transA == 'N') { - DD[i] += float16to32 (AA[j * x + i]) * float16to32 (BB[j*m]); + DD[i] += float16to32 (AA[j * x + i]) * float16to32 (BB[j << l]); } else if (transA == 'T') { - DD[j] += float16to32 (AA[j * x + i]) * float16to32 (BB[i*m]); + DD[j] += float16to32 (AA[j * x + i]) * float16to32 (BB[i << l]); } for (j = 0; j < x; j++) { - if (fabs (CC[j*m] - C[j*m]) > 1.0) + if (fabs (CC[j << l] - C[j << l]) > 1.0) ret++; - if (fabs (CC[j*m] - DD[j]) > 1.0) + if (fabs (CC[j << l] - DD[j]) > 1.0) ret++; } } From 2f142ee857e2c04118401a83f62bcba365a8f537 Mon Sep 17 00:00:00 2001 From: Chip Kerchner Date: Mon, 9 Sep 2024 14:41:55 -0500 Subject: [PATCH 04/24] More common code. 
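
The two helpers added below, vec_loadN2_f32() and vec_storeN2_f32(), factor out the full-vector-plus-partial-vector pattern that every 5-to-7 element tail in the N kernels previously open-coded. A rough scalar equivalent of the load side, for illustration only (loadN2_f32_scalar is a hypothetical name, not part of the patch):

static void loadN2_f32_scalar(FLOAT dst[8], const FLOAT *src, BLASLONG n)
{
  for (BLASLONG i = 0; i < 4; i++)
    dst[i] = src[i];            /* one full 4-float vector */
  for (BLASLONG i = 0; i < n; i++)
    dst[4 + i] = src[4 + i];    /* partial vector of the remaining n = tail & 3 floats */
  for (BLASLONG i = 4 + n; i < 8; i++)
    dst[i] = 0;                 /* unused lanes come back zero-filled */
}
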
--- kernel/power/sbgemv_common.c | 12 ++++++++++++ kernel/power/sbgemv_n.c | 6 ++---- kernel/power/sbgemv_n_vsx.c | 24 ++++++++---------------- 3 files changed, 22 insertions(+), 20 deletions(-) diff --git a/kernel/power/sbgemv_common.c b/kernel/power/sbgemv_common.c index b11ab59de8..1893eba516 100644 --- a/kernel/power/sbgemv_common.c +++ b/kernel/power/sbgemv_common.c @@ -138,6 +138,12 @@ FORCEINLINE vec_f32 vec_loadN_f32(void *src, BLASLONG n) return (vec_f32)vec_loadN(src, n * (sizeof(FLOAT) / sizeof(IFLOAT))); } +FORCEINLINE void vec_loadN2_f32(vec_f32 *data, vec_f32 *src, BLASLONG n) +{ + data[0] = src[0]; + data[1] = vec_loadN_f32(&src[1], n); +} + FORCEINLINE void vec_storeN_f32(vec_f32 data, void *dst, BLASLONG n) { FLOAT *dst2 = (FLOAT *)(dst); @@ -160,6 +166,12 @@ FORCEINLINE void vec_storeN_f32(vec_f32 data, void *dst, BLASLONG n) #endif } +FORCEINLINE void vec_storeN2_f32(vec_f32 *data, vec_f32 *dst, BLASLONG n) +{ + dst[0] = data[0]; + vec_storeN_f32(data[1], &dst[1], n); +} + FORCEINLINE vec_f32 vec_mult(vec_f32 *inp, vec_bf16 in0, vec_bf16 zero) { vec_f32 v_in00 = BF16_HI(in0, zero); diff --git a/kernel/power/sbgemv_n.c b/kernel/power/sbgemv_n.c index fa7df858f8..05c02a0068 100644 --- a/kernel/power/sbgemv_n.c +++ b/kernel/power/sbgemv_n.c @@ -75,12 +75,10 @@ static void BF16GEMV_N_beta(BLASLONG n, FLOAT *output_vector, FLOAT *input_vecto n &= 7; if (n > 4) { BLASLONG n3 = n & 3; - v_inp0[0] = in[(i * 2) + 0]; - v_inp0[1] = vec_loadN_f32(&in[(i * 2) + 1], n3); + vec_loadN2_f32(v_inp0, &in[(i * 2) + 0], n3); v_inp0[0] *= b; v_inp0[1] *= b; - out[(i * 2) + 0] = v_inp0[0]; - vec_storeN_f32(v_inp0[1], &out[(i * 2) + 1], n3); + vec_storeN2_f32(v_inp0, &out[(i * 2) + 0], n3); } else if (n) { v_inp0[0] = vec_loadN_f32(&in[(i * 2) + 0], n); v_inp0[0] *= b; diff --git a/kernel/power/sbgemv_n_vsx.c b/kernel/power/sbgemv_n_vsx.c index ddbf908b3f..45570950ea 100644 --- a/kernel/power/sbgemv_n_vsx.c +++ b/kernel/power/sbgemv_n_vsx.c @@ -64,13 +64,11 @@ static void BF16GEMV_N_VSX_1(BLASLONG n, IFLOAT **ap, IFLOAT *xo, FLOAT *y, FLOA n &= 7; if (n > 4) { BLASLONG n3 = n & 3; - vy0[0] = v_y[(i * 2) + 0]; - vy0[1] = vec_loadN_f32(&v_y[(i * 2) + 1], n3); + vec_loadN2_f32(vy0, &v_y[(i * 2) + 0], n3); vec_loadN_mult2(v_x0, &va0[i], n, zero, vy0); - v_y[(i * 2) + 0] = vy0[0]; - vec_storeN_f32(vy0[1], &v_y[(i * 2) + 1], n3); + vec_storeN2_f32(vy0, &v_y[(i * 2) + 0], n3); } else if (n) { vec_f32 vy0 = vec_loadN_f32(&v_y[(i * 2) + 0], n); @@ -116,14 +114,12 @@ static void BF16GEMV_N_VSX_2(BLASLONG n, IFLOAT **ap, IFLOAT *xo, FLOAT *y, FLOA n &= 7; if (n > 4) { BLASLONG n3 = n & 3; - vy0[0] = v_y[(i * 2) + 0]; - vy0[1] = vec_loadN_f32(&v_y[(i * 2) + 1], n3); + vec_loadN2_f32(vy0, &v_y[(i * 2) + 0], n3); vec_loadN_mult2(v_x0, &va0[i], n, zero, vy0); vec_loadN_mult2(v_x1, &va1[i], n, zero, vy0); - v_y[(i * 2) + 0] = vy0[0]; - vec_storeN_f32(vy0[1], &v_y[(i * 2) + 1], n3); + vec_storeN2_f32(vy0, &v_y[(i * 2) + 0], n3); } else if (n) { vec_f32 vy0 = vec_loadN_f32(&v_y[(i * 2) + 0], n); @@ -178,16 +174,14 @@ static void BF16GEMV_N_VSX_4(BLASLONG n, IFLOAT **ap, IFLOAT *xo, FLOAT *y, FLOA n &= 7; if (n > 4) { BLASLONG n3 = n & 3; - vy0[0] = v_y[(i * 2) + 0]; - vy0[1] = vec_loadN_f32(&v_y[(i * 2) + 1], n3); + vec_loadN2_f32(vy0, &v_y[(i * 2) + 0], n3); vec_loadN_mult2(v_x0, &va0[i], n, zero, vy0); vec_loadN_mult2(v_x1, &va1[i], n, zero, vy0); vec_loadN_mult2(v_x2, &va2[i], n, zero, vy0); vec_loadN_mult2(v_x3, &va3[i], n, zero, vy0); - v_y[(i * 2) + 0] = vy0[0]; - vec_storeN_f32(vy0[1], &v_y[(i * 2) + 1], 
n3); + vec_storeN2_f32(vy0, &v_y[(i * 2) + 0], n3); } else if (n) { vec_f32 vy0 = vec_loadN_f32(&v_y[(i * 2) + 0], n); @@ -263,8 +257,7 @@ static void BF16GEMV_N_VSX_8(BLASLONG n, IFLOAT **ap, IFLOAT *xo, FLOAT *y, BLAS n &= 7; if (n > 4) { BLASLONG n3 = n & 3; - vy0[0] = v_y[(i * 2) + 0]; - vy0[1] = vec_loadN_f32(&v_y[(i * 2) + 1], n3); + vec_loadN2_f32(vy0, &v_y[(i * 2) + 0], n3); vec_loadN_mult2(v_x0, &va0[i], n, zero, vy0); vec_loadN_mult2(v_x1, &va1[i], n, zero, vy0); @@ -275,8 +268,7 @@ static void BF16GEMV_N_VSX_8(BLASLONG n, IFLOAT **ap, IFLOAT *xo, FLOAT *y, BLAS vec_loadN_mult2(v_x6, &vb2[i], n, zero, vy0); vec_loadN_mult2(v_x7, &vb3[i], n, zero, vy0); - v_y[(i * 2) + 0] = vy0[0]; - vec_storeN_f32(vy0[1], &v_y[(i * 2) + 1], n3); + vec_storeN2_f32(vy0, &v_y[(i * 2) + 0], n3); } else if (n) { vec_f32 vy0 = vec_loadN_f32(&v_y[(i * 2) + 0], n); From 72216d28c256087363435642b6ec3d497902033d Mon Sep 17 00:00:00 2001 From: Chip Kerchner Date: Wed, 11 Sep 2024 08:47:32 -0500 Subject: [PATCH 05/24] Fix bug with inc_y adding results twice. --- kernel/power/sbgemv_common.c | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/kernel/power/sbgemv_common.c b/kernel/power/sbgemv_common.c index 1893eba516..07f75d3183 100644 --- a/kernel/power/sbgemv_common.c +++ b/kernel/power/sbgemv_common.c @@ -300,7 +300,7 @@ FORCEINLINE void copy_y(BLASLONG n, FLOAT *src, FLOAT *dest, BLASLONG inc_src, F FORCEINLINE void add_y(BLASLONG n, FLOAT *src, FLOAT *dest, BLASLONG inc_dest) { for (BLASLONG i = 0; i < n; i++) { - *dest += *src++; + *dest = *src++; dest += inc_dest; } } From 7947970f9d5d88a9399c691a0911689c592f5d37 Mon Sep 17 00:00:00 2001 From: Chip Kerchner Date: Fri, 13 Sep 2024 06:22:13 -0500 Subject: [PATCH 06/24] Move common code. --- kernel/power/gemm_common.c | 148 +++++++++++++++++++++++++++++++++++ kernel/power/sbgemv_common.c | 133 +------------------------------ kernel/power/sbgemv_n.c | 2 +- kernel/power/sbgemv_n_vsx.c | 3 +- 4 files changed, 152 insertions(+), 134 deletions(-) create mode 100644 kernel/power/gemm_common.c diff --git a/kernel/power/gemm_common.c b/kernel/power/gemm_common.c new file mode 100644 index 0000000000..c33faffe0e --- /dev/null +++ b/kernel/power/gemm_common.c @@ -0,0 +1,148 @@ +#ifndef GEMM_COMMON_C +#define GEMM_COMMON_C +#include "common.h" + +#include + +#define FORCEINLINE inline __attribute__((always_inline)) + +#ifdef __clang__ +#define uint16_t unsigned short +#define uint32_t unsigned int +#define uint64_t unsigned long long +#endif + +#ifdef _ARCH_PWR10 +#ifdef __has_builtin +#if !__has_builtin(__builtin_vsx_assemble_pair) +#define __builtin_vsx_assemble_pair __builtin_mma_assemble_pair +#endif +#if !__has_builtin(__builtin_vsx_disassemble_pair) +#define __builtin_vsx_disassemble_pair __builtin_mma_disassemble_pair +#endif +#endif + +#if __BYTE_ORDER__ == __ORDER_BIG_ENDIAN__ +#define __builtin_vsx_assemble_pair2(vp0, v0, v1) __builtin_vsx_assemble_pair(vp0, v1, v0) +#else +#define __builtin_vsx_assemble_pair2(vp0, v0, v1) __builtin_vsx_assemble_pair(vp0, v0, v1) +#endif + +#define USE_VECTOR_PAIRS +#endif + +typedef __vector IFLOAT vec_bf16; +typedef __vector FLOAT vec_f32; +typedef __vector unsigned char vec_uc8; + +FORCEINLINE vec_uc8 vec_load_vec(void *src) +{ + return vec_xl(0, (unsigned char *)(src)); +} + +FORCEINLINE void vec_load_pair(vec_f32 *dst, vec_f32 *src) +{ +#ifdef USE_VECTOR_PAIRS + __vector_pair vy0p; + vy0p = *(__vector_pair *)(src); + __builtin_vsx_disassemble_pair((void *)(dst), &vy0p); +#else + dst[0] = src[0]; + 
dst[1] = src[1]; +#endif +} + +FORCEINLINE void vec_store_pair(vec_f32 *dst, vec_f32 *src) +{ +#ifdef USE_VECTOR_PAIRS + __vector_pair vy0p; + __builtin_vsx_assemble_pair2(&vy0p, (vec_uc8)src[1], (vec_uc8)src[0]); + *(__vector_pair *)(dst) = vy0p; +#else + dst[0] = src[0]; + dst[1] = src[1]; +#endif +} + +FORCEINLINE vec_bf16 vec_loadN(void *src, BLASLONG n) +{ + IFLOAT *src2 = (IFLOAT *)(src); +#ifdef _ARCH_PWR9 + return vec_xl_len(src2, n * sizeof(IFLOAT)); +#else + __attribute__((aligned(16))) IFLOAT data[sizeof(vec_bf16) / sizeof(IFLOAT)]; + memset(data, 0, sizeof(vec_bf16)); + if (n & 4) { + memcpy(data, src2, sizeof(uint64_t)); + } + if (n & 2) { + BLASLONG n4 = n & 4; + memcpy(data + n4, src2 + n4, sizeof(uint32_t)); + } + if (n & 1) { + BLASLONG n6 = n & 6; + data[n6] = src2[n6]; + } + return (vec_bf16)vec_load_vec(data); +#endif +} + +FORCEINLINE vec_f32 vec_loadN_f32(void *src, BLASLONG n) +{ +#ifndef _ARCH_PWR9 + if (n & 4) { + return (vec_f32)vec_load_vec(src); + } +#endif + return (vec_f32)vec_loadN(src, n * (sizeof(FLOAT) / sizeof(IFLOAT))); +} + +FORCEINLINE void vec_loadN2_f32(vec_f32 *data, vec_f32 *src, BLASLONG n) +{ + data[0] = src[0]; + data[1] = vec_loadN_f32(&src[1], n); +} + +FORCEINLINE void vec_storeN(vec_bf16 data, void *dst, BLASLONG n) +{ + IFLOAT *dst2 = (IFLOAT *)(dst); +#ifdef _ARCH_PWR9 + vec_xst_len(data, dst2, n * sizeof(IFLOAT)); +#else + if (n & 8) { + vec_xst(data, 0, dst2); + return; + } + __attribute__((aligned(16))) IFLOAT data2[sizeof(vec_f32) / sizeof(IFLOAT)]; + vec_xst(data, 0, data2); + if (n & 4) { + memcpy(dst2, data2, sizeof(uint64_t)); + } + if (n & 2) { + BLASLONG n4 = n & 4; + memcpy(dst2 + n4, data2 + n4, sizeof(uint32_t)); + } + if (n & 1) { + BLASLONG n6 = n & 6; + dst2[n6] = data2[n6]; + } +#endif +} + +FORCEINLINE void vec_storeN_f32(vec_f32 data, void *dst, BLASLONG n) +{ +#ifndef _ARCH_PWR9 + if (n & 4) { + vec_xst(data, 0, (FLOAT *)dst); + return; + } +#endif + return vec_storeN((vec_bf16)data, dst, n * (sizeof(FLOAT) / sizeof(IFLOAT))); +} + +FORCEINLINE void vec_storeN2_f32(vec_f32 *data, vec_f32 *dst, BLASLONG n) +{ + dst[0] = data[0]; + vec_storeN_f32(data[1], &dst[1], n); +} +#endif diff --git a/kernel/power/sbgemv_common.c b/kernel/power/sbgemv_common.c index 07f75d3183..46dee74c3e 100644 --- a/kernel/power/sbgemv_common.c +++ b/kernel/power/sbgemv_common.c @@ -27,40 +27,7 @@ USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. 
#ifndef SBGEMV_COMMON_C #define SBGEMV_COMMON_C -#include "common.h" - -#include - -#define FORCEINLINE inline __attribute__((always_inline)) - -#ifdef __clang__ -#define uint16_t unsigned short -#define uint32_t unsigned int -#define uint64_t unsigned long long -#endif - -#ifdef _ARCH_PWR10 -#ifdef __has_builtin -#if !__has_builtin(__builtin_vsx_assemble_pair) -#define __builtin_vsx_assemble_pair __builtin_mma_assemble_pair -#endif -#if !__has_builtin(__builtin_vsx_disassemble_pair) -#define __builtin_vsx_disassemble_pair __builtin_mma_disassemble_pair -#endif -#endif - -#if __BYTE_ORDER__ == __ORDER_BIG_ENDIAN__ -#define __builtin_vsx_assemble_pair2(vp0, v0, v1) __builtin_vsx_assemble_pair(vp0, v1, v0) -#else -#define __builtin_vsx_assemble_pair2(vp0, v0, v1) __builtin_vsx_assemble_pair(vp0, v0, v1) -#endif - -#define USE_VECTOR_PAIRS -#endif - -typedef __vector IFLOAT vec_bf16; -typedef __vector FLOAT vec_f32; -typedef __vector unsigned char vec_uc8; +#include "gemm_common.c" #if __BYTE_ORDER__ == __ORDER_BIG_ENDIAN__ #define BF16_HI(data, zero) (vec_f32)vec_mergeh(data, zero) @@ -70,108 +37,12 @@ typedef __vector unsigned char vec_uc8; #define BF16_LO(data, zero) (vec_f32)vec_mergel(zero, data) #endif -FORCEINLINE vec_uc8 vec_load_vec(void *src) -{ - return vec_xl(0, (unsigned char *)(src)); -} - -FORCEINLINE void vec_load_pair(vec_f32 *dst, vec_f32 *src) -{ -#ifdef USE_VECTOR_PAIRS - __vector_pair vy0p; - vy0p = *(__vector_pair *)(src); - __builtin_vsx_disassemble_pair((void *)(dst), &vy0p); -#else - dst[0] = src[0]; - dst[1] = src[1]; -#endif -} - -FORCEINLINE void vec_store_pair(vec_f32 *dst, vec_f32 *src) -{ -#ifdef USE_VECTOR_PAIRS - __vector_pair vy0p; - __builtin_vsx_assemble_pair2(&vy0p, (vec_uc8)src[1], (vec_uc8)src[0]); - *(__vector_pair *)(dst) = vy0p; -#else - dst[0] = src[0]; - dst[1] = src[1]; -#endif -} - -FORCEINLINE vec_bf16 vec_loadN(void *src, BLASLONG n) -{ - IFLOAT *src2 = (IFLOAT *)(src); -#ifdef _ARCH_PWR9 - return vec_xl_len(src2, n * sizeof(IFLOAT)); -#else - __attribute__((aligned(16))) IFLOAT data[sizeof(vec_bf16) / sizeof(IFLOAT)]; - memset(data, 0, sizeof(vec_bf16)); - if (n & 4) { - memcpy(data, src2, sizeof(uint64_t)); - } - if (n & 2) { - BLASLONG n4 = n & 4; - memcpy(data + n4, src2 + n4, sizeof(uint32_t)); - } - if (n & 1) { - BLASLONG n6 = n & 6; - data[n6] = src2[n6]; - } - return (vec_bf16)vec_load_vec(data); -#endif -} - FORCEINLINE vec_f32 vec_loadNHi(void *src, BLASLONG n, vec_bf16 zero) { vec_bf16 data = vec_loadN(src, n); return BF16_HI(data, zero); } -FORCEINLINE vec_f32 vec_loadN_f32(void *src, BLASLONG n) -{ -#ifndef _ARCH_PWR9 - if (n & 4) { - return (vec_f32)vec_load_vec(src); - } -#endif - return (vec_f32)vec_loadN(src, n * (sizeof(FLOAT) / sizeof(IFLOAT))); -} - -FORCEINLINE void vec_loadN2_f32(vec_f32 *data, vec_f32 *src, BLASLONG n) -{ - data[0] = src[0]; - data[1] = vec_loadN_f32(&src[1], n); -} - -FORCEINLINE void vec_storeN_f32(vec_f32 data, void *dst, BLASLONG n) -{ - FLOAT *dst2 = (FLOAT *)(dst); -#ifdef _ARCH_PWR9 - vec_xst_len(data, dst2, n * sizeof(FLOAT)); -#else - if (n & 4) { - vec_xst(data, 0, dst2); - return; - } - __attribute__((aligned(16))) FLOAT data2[sizeof(vec_f32) / sizeof(FLOAT)]; - vec_xst(data, 0, data2); - if (n & 2) { - memcpy(dst2, data2, sizeof(uint64_t)); - } - if (n & 1) { - BLASLONG n2 = n & 2; - dst2[n2] = data2[n2]; - } -#endif -} - -FORCEINLINE void vec_storeN2_f32(vec_f32 *data, vec_f32 *dst, BLASLONG n) -{ - dst[0] = data[0]; - vec_storeN_f32(data[1], &dst[1], n); -} - FORCEINLINE vec_f32 
vec_mult(vec_f32 *inp, vec_bf16 in0, vec_bf16 zero) { vec_f32 v_in00 = BF16_HI(in0, zero); @@ -297,7 +168,7 @@ FORCEINLINE void copy_y(BLASLONG n, FLOAT *src, FLOAT *dest, BLASLONG inc_src, F } } -FORCEINLINE void add_y(BLASLONG n, FLOAT *src, FLOAT *dest, BLASLONG inc_dest) +FORCEINLINE void move_y(BLASLONG n, FLOAT *src, FLOAT *dest, BLASLONG inc_dest) { for (BLASLONG i = 0; i < n; i++) { *dest = *src++; diff --git a/kernel/power/sbgemv_n.c b/kernel/power/sbgemv_n.c index 05c02a0068..c7559a47c4 100644 --- a/kernel/power/sbgemv_n.c +++ b/kernel/power/sbgemv_n.c @@ -179,7 +179,7 @@ int CNAME(BLASLONG m, BLASLONG n, FLOAT alpha, IFLOAT *a, BLASLONG lda, IFLOAT * a += NB; if (inc_y != 1) { - add_y(NB, ybuffer, y_ptr, inc_y); + move_y(NB, ybuffer, y_ptr, inc_y); y_ptr += (NB * inc_y); } else { y_ptr += NB; diff --git a/kernel/power/sbgemv_n_vsx.c b/kernel/power/sbgemv_n_vsx.c index 45570950ea..cab2316d4e 100644 --- a/kernel/power/sbgemv_n_vsx.c +++ b/kernel/power/sbgemv_n_vsx.c @@ -269,8 +269,7 @@ static void BF16GEMV_N_VSX_8(BLASLONG n, IFLOAT **ap, IFLOAT *xo, FLOAT *y, BLAS vec_loadN_mult2(v_x7, &vb3[i], n, zero, vy0); vec_storeN2_f32(vy0, &v_y[(i * 2) + 0], n3); - } else - if (n) { + } else if (n) { vec_f32 vy0 = vec_loadN_f32(&v_y[(i * 2) + 0], n); vy0 += vec_loadNHi_multi2(v_x0, &va0[i], n, zero); From 89a12fa08352da73b03225cf1743dead75669043 Mon Sep 17 00:00:00 2001 From: Chip Kerchner Date: Mon, 23 Sep 2024 06:32:14 -0500 Subject: [PATCH 07/24] MMA BF16 GEMV code. --- kernel/power/gemm_common.c | 2 + kernel/power/sbgemv_common.c | 2 +- kernel/power/sbgemv_common_power10.c | 265 +++++++++++++++++++++ kernel/power/sbgemv_n.c | 22 ++ kernel/power/sbgemv_n_power10.c | 306 ++++++++++++++++++++++++- kernel/power/sbgemv_n_vsx.c | 71 +++--- kernel/power/sbgemv_t.c | 15 ++ kernel/power/sbgemv_t_power10.c | 330 ++++++++++++++++++++++++++- kernel/power/sbgemv_t_vsx.c | 67 +++--- 9 files changed, 1011 insertions(+), 69 deletions(-) create mode 100644 kernel/power/sbgemv_common_power10.c diff --git a/kernel/power/gemm_common.c b/kernel/power/gemm_common.c index c33faffe0e..0611ebc2a9 100644 --- a/kernel/power/gemm_common.c +++ b/kernel/power/gemm_common.c @@ -4,6 +4,8 @@ #include +#define NBMAX 4096 + #define FORCEINLINE inline __attribute__((always_inline)) #ifdef __clang__ diff --git a/kernel/power/sbgemv_common.c b/kernel/power/sbgemv_common.c index 46dee74c3e..ab50f430af 100644 --- a/kernel/power/sbgemv_common.c +++ b/kernel/power/sbgemv_common.c @@ -111,7 +111,7 @@ FORCEINLINE vec_f32 vec_loadNHi_mult(vec_bf16 *in, vec_f32 v_inp0, BLASLONG n, v return (v_inp0 * v_in00); } -FORCEINLINE vec_f32 vec_loadNHi_multi2(vec_f32 v_x0, vec_bf16 *in, BLASLONG n, vec_bf16 zero) +FORCEINLINE vec_f32 vec_loadNHi_mult2(vec_f32 v_x0, vec_bf16 *in, BLASLONG n, vec_bf16 zero) { vec_f32 v_in00 = vec_loadNHi(in, n, zero); diff --git a/kernel/power/sbgemv_common_power10.c b/kernel/power/sbgemv_common_power10.c new file mode 100644 index 0000000000..da088014b0 --- /dev/null +++ b/kernel/power/sbgemv_common_power10.c @@ -0,0 +1,265 @@ +/*************************************************************************** +Copyright (c) 2024, The OpenBLAS Project +All rights reserved. +Redistribution and use in source and binary forms, with or without +modification, are permitted provided that the following conditions are +met: +1. Redistributions of source code must retain the above copyright +notice, this list of conditions and the following disclaimer. +2. 
Redistributions in binary form must reproduce the above copyright +notice, this list of conditions and the following disclaimer in +the documentation and/or other materials provided with the +distribution. +3. Neither the name of the OpenBLAS project nor the names of +its contributors may be used to endorse or promote products +derived from this software without specific prior written permission. +THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" +AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE +IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE +ARE DISCLAIMED. IN NO EVENT SHALL THE OPENBLAS PROJECT OR CONTRIBUTORS BE +LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL +DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR +SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER +CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, +OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE +USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. +*****************************************************************************/ + +#ifndef SBGEMV_COMMON_MMA_C +#define SBGEMV_COMMON_MMA_C +#include "sbgemv_common.c" + +FORCEINLINE void vec_load_mult_mma(__vector_quad *out, vec_bf16 *in, vec_bf16 inp) +{ + vec_bf16 in0 = (vec_bf16)vec_load_vec(in); + + __builtin_mma_xvbf16ger2pp(out, (vec_uc8)in0, (vec_uc8)inp); +} + +FORCEINLINE void vec_load_mult2_mma(__vector_quad *out, vec_bf16 *in, vec_bf16 *inp) +{ + vec_bf16 in0[2]; + + vec_load_pair((vec_f32 *)in0, (vec_f32 *)in); + + __builtin_mma_xvbf16ger2pp(out, (vec_uc8)in0[0], (vec_uc8)inp[0]); + __builtin_mma_xvbf16ger2pp(out, (vec_uc8)in0[1], (vec_uc8)inp[1]); +} + +FORCEINLINE void vec_loadN_mult_mma(__vector_quad *out, vec_bf16 *in, vec_bf16 inp, BLASLONG n) +{ + vec_bf16 in0 = vec_loadN(in, n); + + __builtin_mma_xvbf16ger2pp(out, (vec_uc8)in0, (vec_uc8)inp); +} + +FORCEINLINE void vec_mult1_mma(__vector_quad *out, vec_bf16 in0, vec_bf16 inp) +{ + vec_bf16 in00 = vec_mergeh(in0, in0); + + __builtin_mma_xvbf16ger2(out, (vec_uc8)inp, (vec_uc8)in00); +} + +FORCEINLINE void vec_mult2_mma(__vector_quad *out, vec_bf16 in0, vec_bf16 inp) +{ + vec_bf16 in01 = vec_mergel(in0, in0); + + vec_mult1_mma(&out[0], in0, inp); + + __builtin_mma_xvbf16ger2(&out[1], (vec_uc8)inp, (vec_uc8)in01); +} + +FORCEINLINE void vec_mult4_mma(__vector_quad *out, vec_bf16 *in0, vec_bf16 inp) +{ + vec_mult2_mma(out + 0, in0[0], inp); + vec_mult2_mma(out + 2, in0[1], inp); +} + +FORCEINLINE void vec_loadN_mult11_mma(__vector_quad *out, vec_bf16 *in, vec_bf16 inp, BLASLONG n) +{ + vec_bf16 in0 = vec_loadN(in, n); + + vec_mult1_mma(out, in0, inp); +} + +FORCEINLINE void vec_loadN_mult12_mma(__vector_quad *out, vec_bf16 *in, vec_bf16 inp, BLASLONG n) +{ + vec_bf16 in0 = vec_loadN(in, n); + + vec_mult2_mma(out, in0, inp); +} + +FORCEINLINE void vec_load_mult12_mma(__vector_quad *out, vec_bf16 *in, vec_bf16 inp) +{ + vec_bf16 in0 = (vec_bf16)vec_load_vec(in); + + vec_mult2_mma(out, in0, inp); +} + +FORCEINLINE void vec_load_mult18_mma(__vector_quad *out, vec_bf16 *in, vec_bf16 inp) +{ + vec_bf16 in0[4]; + + vec_load_pair((vec_f32 *)(in0 + 0), (vec_f32 *)(in + 0)); + vec_load_pair((vec_f32 *)(in0 + 2), (vec_f32 *)(in + 2)); + + vec_mult4_mma(&out[0], in0 + 0, inp); + vec_mult4_mma(&out[4], in0 + 2, inp); +} + +FORCEINLINE void vec_reduce1_mma(__vector_quad *out, vec_f32 *temp, 
vec_f32 v_alpha, vec_f32 *vy0) +{ + __builtin_mma_disassemble_acc((void*)temp, &out[0]); + + vy0[0] += (temp[0] * v_alpha); +} + +FORCEINLINE void vec_reduce2_mma(__vector_quad *out, vec_f32 *temp, vec_f32 v_alpha, vec_f32 *vy0) +{ + vec_reduce1_mma(&out[0], &temp[0], v_alpha, &vy0[0]); + vec_reduce1_mma(&out[1], &temp[4], v_alpha, &vy0[1]); +} + +FORCEINLINE void vec_reduce8_mma(__vector_quad *out, vec_f32 *temp, vec_f32 v_alpha, vec_f32 *vy0) +{ + vec_reduce2_mma(&out[0], &temp[0], v_alpha, vy0 + 0); + vec_reduce2_mma(&out[2], &temp[8], v_alpha, vy0 + 2); + vec_reduce2_mma(&out[4], &temp[16], v_alpha, vy0 + 4); + vec_reduce2_mma(&out[6], &temp[24], v_alpha, vy0 + 6); +} + +FORCEINLINE void vec_mult11a_mma(__vector_quad *out, vec_bf16 in0, vec_bf16 in1, vec_bf16 inp) +{ + vec_bf16 in00 = vec_mergeh(in0, in1); + + __builtin_mma_xvbf16ger2(out, (vec_uc8)inp, (vec_uc8)in00); +} + +FORCEINLINE void vec_mult2a_mma(__vector_quad *out, vec_bf16 in0, vec_bf16 in1, vec_bf16 inp) +{ + vec_bf16 in01 = vec_mergel(in0, in1); + + vec_mult11a_mma(&out[0], in0, in1, inp); + + __builtin_mma_xvbf16ger2(&out[1], (vec_uc8)inp, (vec_uc8)in01); +} + +FORCEINLINE void vec_mult4a_mma(__vector_quad *out, vec_bf16 *in0, vec_bf16 *in1, vec_bf16 inp) +{ + vec_mult2a_mma(out + 0, in0[0], in1[0], inp); + vec_mult2a_mma(out + 2, in0[1], in1[1], inp); +} + +FORCEINLINE void vec_loadN_mult11a_mma(__vector_quad *out, vec_bf16 *ina, vec_bf16 *inb, vec_bf16 inp, BLASLONG n) +{ + vec_bf16 in0 = vec_loadN(ina, n); + vec_bf16 in1 = vec_loadN(inb, n); + + vec_mult11a_mma(out, in0, in1, inp); +} + +FORCEINLINE void vec_load_mult22a_mma(__vector_quad *out, vec_bf16 *ina, vec_bf16 *inb, vec_bf16 inp) +{ + vec_bf16 in0 = (vec_bf16)vec_load_vec(ina); + vec_bf16 in1 = (vec_bf16)vec_load_vec(inb); + + vec_mult2a_mma(out, in0, in1, inp); +} + +FORCEINLINE void vec_load_mult28a_mma(__vector_quad *out, vec_bf16 *ina, vec_bf16 *inb, vec_bf16 inp) +{ + vec_bf16 in0[4], in1[4]; + + vec_load_pair((vec_f32 *)(in0 + 0), (vec_f32 *)(ina + 0)); + vec_load_pair((vec_f32 *)(in1 + 0), (vec_f32 *)(inb + 0)); + vec_load_pair((vec_f32 *)(in0 + 2), (vec_f32 *)(ina + 2)); + vec_load_pair((vec_f32 *)(in1 + 2), (vec_f32 *)(inb + 2)); + + vec_mult4a_mma(&out[0], in0 + 0, in1 + 0, inp); + vec_mult4a_mma(&out[4], in0 + 2, in1 + 2, inp); +} + +FORCEINLINE void vec_loadN_mult22a_mma(__vector_quad *out, vec_bf16 *ina, vec_bf16 *inb, vec_bf16 inp, BLASLONG n) +{ + vec_bf16 in0 = vec_loadN(ina, n); + vec_bf16 in1 = vec_loadN(inb, n); + + vec_mult2a_mma(out, in0, in1, inp); +} + +FORCEINLINE void vec_mult11b_mma(__vector_quad *out, vec_bf16 in0, vec_bf16 in1, vec_bf16 inp) +{ + vec_bf16 in00 = vec_mergeh(in0, in1); + + __builtin_mma_xvbf16ger2pp(out, (vec_uc8)inp, (vec_uc8)in00); +} + +FORCEINLINE void vec_mult2b_mma(__vector_quad *out, vec_bf16 in0, vec_bf16 in1, vec_bf16 inp) +{ + vec_bf16 in01 = vec_mergel(in0, in1); + + vec_mult11b_mma(&out[0], in0, in1, inp); + + __builtin_mma_xvbf16ger2pp(&out[1], (vec_uc8)inp, (vec_uc8)in01); +} + +FORCEINLINE void vec_mult4b_mma(__vector_quad *out, vec_bf16 *in0, vec_bf16 *in1, vec_bf16 inp) +{ + vec_mult2b_mma(out + 0, in0[0], in1[0], inp); + vec_mult2b_mma(out + 2, in0[1], in1[1], inp); +} + +FORCEINLINE void vec_loadN_mult11b_mma(__vector_quad *out, vec_bf16 *ina, vec_bf16 *inb, vec_bf16 inp, BLASLONG n) +{ + vec_bf16 in0 = vec_loadN(ina, n); + vec_bf16 in1 = vec_loadN(inb, n); + + vec_mult11b_mma(out, in0, in1, inp); +} + +FORCEINLINE void vec_load_mult22b_mma(__vector_quad *out, vec_bf16 *ina, vec_bf16 *inb, vec_bf16 
inp) +{ + vec_bf16 in0 = (vec_bf16)vec_load_vec(ina); + vec_bf16 in1 = (vec_bf16)vec_load_vec(inb); + + vec_mult2b_mma(out, in0, in1, inp); +} + +FORCEINLINE void vec_load_mult28b_mma(__vector_quad *out, vec_bf16 *ina, vec_bf16 *inb, vec_bf16 inp) +{ + vec_bf16 in0[4], in1[4]; + + vec_load_pair((vec_f32 *)(in0 + 0), (vec_f32 *)(ina + 0)); + vec_load_pair((vec_f32 *)(in1 + 0), (vec_f32 *)(inb + 0)); + vec_load_pair((vec_f32 *)(in0 + 2), (vec_f32 *)(ina + 2)); + vec_load_pair((vec_f32 *)(in1 + 2), (vec_f32 *)(inb + 2)); + + vec_mult4b_mma(&out[0], in0 + 0, in1 + 0, inp); + vec_mult4b_mma(&out[4], in0 + 2, in1 + 2, inp); +} + +FORCEINLINE void vec_loadN_mult22b_mma(__vector_quad *out, vec_bf16 *ina, vec_bf16 *inb, vec_bf16 inp, BLASLONG n) +{ + vec_bf16 in0 = vec_loadN(ina, n); + vec_bf16 in1 = vec_loadN(inb, n); + + vec_mult2b_mma(out, in0, in1, inp); +} + +FORCEINLINE void vec_load4_pair(vec_f32 *vy0, vec_f32 *v_y) +{ + vec_load_pair(vy0 + 0, v_y + 0); + vec_load_pair(vy0 + 2, v_y + 2); + vec_load_pair(vy0 + 4, v_y + 4); + vec_load_pair(vy0 + 6, v_y + 6); +} + +FORCEINLINE void vec_store4_pair(vec_f32 *v_y, vec_f32 *vy0) +{ + vec_store_pair(v_y + 0, vy0 + 0); + vec_store_pair(v_y + 2, vy0 + 2); + vec_store_pair(v_y + 4, vy0 + 4); + vec_store_pair(v_y + 6, vy0 + 6); +} + +#endif diff --git a/kernel/power/sbgemv_n.c b/kernel/power/sbgemv_n.c index c7559a47c4..4768be31fa 100644 --- a/kernel/power/sbgemv_n.c +++ b/kernel/power/sbgemv_n.c @@ -87,6 +87,10 @@ static void BF16GEMV_N_beta(BLASLONG n, FLOAT *output_vector, FLOAT *input_vecto } } +#if (defined(_ARCH_PWR10) && (defined(USE_BFGEMV_8_N_MMA) || (!defined(USE_BFGEMV_N_MMA) && defined(USE_BFGEMV_8_N_VSX)))) || (!defined(_ARCH_PWR10) && defined(USE_BFGEMV_8_N_VSX)) +#define USE_N_8 +#endif + int CNAME(BLASLONG m, BLASLONG n, FLOAT alpha, IFLOAT *a, BLASLONG lda, IFLOAT *x, BLASLONG inc_x, FLOAT beta, FLOAT *y, BLASLONG inc_y) { IFLOAT *x_ptr, *ap[4]; @@ -100,7 +104,9 @@ int CNAME(BLASLONG m, BLASLONG n, FLOAT alpha, IFLOAT *a, BLASLONG lda, IFLOAT * y_ptr = y; BLASLONG lda4 = lda << 2; +#ifdef USE_N_8 BLASLONG lda8 = lda << 3; +#endif BLASLONG NB = NBMAX; BLASLONG m2 = (m & (NBMAX - 1)); @@ -126,6 +132,7 @@ int CNAME(BLASLONG m, BLASLONG n, FLOAT alpha, IFLOAT *a, BLASLONG lda, IFLOAT * ap[3] = ap[2] + lda; if (inc_x == 1) { +#ifdef USE_N_8 for (BLASLONG j = 0; j + 8 <= n; j += 8) { BF16GEMV_N_8(NB, ap, x_ptr, ybuffer, lda4, alpha); ap[0] += lda8; @@ -135,9 +142,16 @@ int CNAME(BLASLONG m, BLASLONG n, FLOAT alpha, IFLOAT *a, BLASLONG lda, IFLOAT * x_ptr += 8; } if (n & 4) { +#else + for (BLASLONG j = 0; j + 4 <= n; j += 4) { +#endif BF16GEMV_N_4(NB, ap, x_ptr, ybuffer, alpha); ap[0] += lda4; ap[1] += lda4; +#ifndef USE_N_8 + ap[2] += lda4; + ap[3] += lda4; +#endif x_ptr += 4; } if (n & 2) { @@ -149,6 +163,7 @@ int CNAME(BLASLONG m, BLASLONG n, FLOAT alpha, IFLOAT *a, BLASLONG lda, IFLOAT * BF16GEMV_N_1(NB, ap, x_ptr, ybuffer, alpha); } } else { +#ifdef USE_N_8 for (BLASLONG j = 0; j + 8 <= n; j += 8) { copy_x(8, x_ptr, xbuffer, inc_x); BF16GEMV_N_8(NB, ap, xbuffer, ybuffer, lda4, alpha); @@ -159,10 +174,17 @@ int CNAME(BLASLONG m, BLASLONG n, FLOAT alpha, IFLOAT *a, BLASLONG lda, IFLOAT * x_ptr += 8 * inc_x; } if (n & 4) { +#else + for (BLASLONG j = 0; j + 4 <= n; j += 4) { +#endif copy_x(4, x_ptr, xbuffer, inc_x); BF16GEMV_N_4(NB, ap, xbuffer, ybuffer, alpha); ap[0] += lda4; ap[1] += lda4; +#ifndef USE_N_8 + ap[2] += lda4; + ap[3] += lda4; +#endif x_ptr += 4 * inc_x; } if (n & 2) { diff --git a/kernel/power/sbgemv_n_power10.c 
b/kernel/power/sbgemv_n_power10.c index fc83b38c37..7b2beb0c7b 100644 --- a/kernel/power/sbgemv_n_power10.c +++ b/kernel/power/sbgemv_n_power10.c @@ -25,9 +25,309 @@ OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. *****************************************************************************/ -//#include "sbgemv_common.c" +#ifndef SBGEMV_N_MMA_C +#define SBGEMV_N_MMA_C -#include "sbgemv_n_vsx.c" +#define USE_BFGEMV_N_MMA + +#ifdef USE_BFGEMV_N_MMA +#include "sbgemv_common_power10.c" + +#ifndef BF16GEMV_N_X +#define BF16GEMV_N_X +#define BF16GEMV_N_8 BF16GEMV_N_MMA_8 +#define BF16GEMV_N_4 BF16GEMV_N_MMA_4 +#define BF16GEMV_N_2 BF16GEMV_N_MMA_2 +#define BF16GEMV_N_1 BF16GEMV_N_MMA_1 +#endif + +#define USE_BFGEMV_8_N_MMA + +static void BF16GEMV_N_MMA_1(BLASLONG n, IFLOAT **ap, IFLOAT *xo, FLOAT *y, FLOAT alpha) +{ + IFLOAT *a0; + __vector_quad temp[2*4]; + vec_f32 temp0[8*4], vy0[2*4]; + vec_f32 v_alpha = { alpha, alpha, alpha, alpha }; + + a0 = ap[0]; + + vec_bf16 *va0 = (vec_bf16 *)a0; + + vec_bf16 *x_bf = (vec_bf16 *)(xo); + vec_bf16 v_x0 = vec_loadN(x_bf, 1); + + vec_f32 *v_y = (vec_f32 *)y; + BLASLONG n8 = n / 8; + BLASLONG i = 0; + + for (; i + 4 <= n8; i += 4) { + vec_load4_pair(vy0, &v_y[(i * 2) + 0]); + + vec_load_mult18_mma(&temp[0], &va0[i + 0], v_x0); + + vec_reduce8_mma(&temp[0], temp0, v_alpha, vy0); + + vec_store4_pair(&v_y[(i * 2) + 0], vy0); + } + + for (; i < n8; i++) { + vec_load_pair(vy0, &v_y[(i * 2) + 0]); + + vec_load_mult12_mma(&temp[0], &va0[i], v_x0); + + vec_reduce2_mma(&temp[0], temp0, v_alpha, vy0); + + vec_store_pair(&v_y[(i * 2) + 0], vy0); + } + + n &= 7; + if (n > 4) { + BLASLONG n3 = n & 3; + vec_loadN2_f32(vy0, &v_y[(i * 2) + 0], n3); + + vec_loadN_mult12_mma(&temp[0], &va0[i], v_x0, n); + + vec_reduce2_mma(&temp[0], temp0, v_alpha, vy0); + + vec_storeN2_f32(vy0, &v_y[(i * 2) + 0], n3); + } else if (n) { + vy0[0] = vec_loadN_f32(&v_y[(i * 2) + 0], n); + + vec_loadN_mult11_mma(&temp[0], &va0[i], v_x0, n); + + vec_reduce1_mma(&temp[0], temp0, v_alpha, vy0); + + vec_storeN_f32(vy0[0], &v_y[(i * 2) + 0], n); + } +} + +static void BF16GEMV_N_MMA_2(BLASLONG n, IFLOAT **ap, IFLOAT *xo, FLOAT *y, FLOAT alpha) +{ + IFLOAT *a0, *a1; + __vector_quad temp[2*4]; + vec_f32 temp0[8*4], vy0[2*4]; + vec_f32 v_alpha = { alpha, alpha, alpha, alpha }; + + a0 = ap[0]; + a1 = ap[1]; + + vec_bf16 *va0 = (vec_bf16 *)a0; + vec_bf16 *va1 = (vec_bf16 *)a1; + + vec_bf16 *x_bf = (vec_bf16 *)(xo); + vec_bf16 v_x0 = vec_loadN(x_bf, 2); + + vec_f32 *v_y = (vec_f32 *)y; + BLASLONG n8 = n / 8; + BLASLONG i = 0; + + for (; i + 4 <= n8; i += 4) { + vec_load4_pair(vy0, &v_y[(i * 2) + 0]); + + vec_load_mult28a_mma(&temp[0], &va0[i + 0], &va1[i + 0], v_x0); + + vec_reduce8_mma(&temp[0], temp0, v_alpha, vy0); + + vec_store4_pair(&v_y[(i * 2) + 0], vy0); + } + + for (; i < n8; i++) { + vec_load_pair(vy0, &v_y[(i * 2) + 0]); + + vec_load_mult22a_mma(&temp[0], &va0[i], &va1[i], v_x0); + + vec_reduce2_mma(&temp[0], temp0, v_alpha, vy0); + + vec_store_pair(&v_y[(i * 2) + 0], vy0); + } + + n &= 7; + if (n > 4) { + BLASLONG n3 = n & 3; + vec_loadN2_f32(vy0, &v_y[(i * 2) + 0], n3); + + vec_loadN_mult22a_mma(&temp[0], &va0[i], &va1[i], v_x0, n); + + vec_reduce2_mma(&temp[0], temp0, v_alpha, vy0); + + vec_storeN2_f32(vy0, &v_y[(i * 2) + 0], n3); + } else if (n) { + vy0[0] = vec_loadN_f32(&v_y[(i * 2) + 0], n); + + vec_loadN_mult11a_mma(&temp[0], &va0[i], &va1[i], v_x0, n); + + vec_reduce1_mma(&temp[0], temp0, 
v_alpha, vy0); + + vec_storeN_f32(vy0[0], &v_y[(i * 2) + 0], n); + } +} -//#include "sbgemv_n.c" +static void BF16GEMV_N_MMA_4(BLASLONG n, IFLOAT **ap, IFLOAT *xo, FLOAT *y, FLOAT alpha) +{ + IFLOAT *a0, *a1, *a2, *a3; + __vector_quad temp[2*4]; + vec_f32 temp0[8*4], vy0[2*4]; + vec_f32 v_alpha = { alpha, alpha, alpha, alpha }; + + a0 = ap[0]; + a1 = ap[1]; + a2 = ap[2]; + a3 = ap[3]; + + vec_bf16 *va0 = (vec_bf16 *)a0; + vec_bf16 *va1 = (vec_bf16 *)a1; + vec_bf16 *va2 = (vec_bf16 *)a2; + vec_bf16 *va3 = (vec_bf16 *)a3; + + vec_bf16 *x_bf = (vec_bf16 *)(xo); + vec_bf16 v_x00 = vec_loadN(x_bf, 4); + + vec_bf16 v_x01 = (vec_bf16)vec_splat((vec_f32)v_x00, 1); + + vec_f32 *v_y = (vec_f32 *)y; + BLASLONG n8 = n / 8; + BLASLONG i = 0; + + for (; i + 4 <= n8; i += 4) { + vec_load4_pair(vy0, &v_y[(i * 2) + 0]); + + vec_load_mult28a_mma(&temp[0], &va0[i + 0], &va1[i + 0], v_x00); + vec_load_mult28b_mma(&temp[0], &va2[i + 0], &va3[i + 0], v_x01); + + vec_reduce8_mma(&temp[0], temp0, v_alpha, vy0); + + vec_store4_pair(&v_y[(i * 2) + 0], vy0); + } + + for (; i < n8; i++) { + vec_load_pair(vy0, &v_y[(i * 2) + 0]); + + vec_load_mult22a_mma(&temp[0], &va0[i], &va1[i], v_x00); + vec_load_mult22b_mma(&temp[0], &va2[i], &va3[i], v_x01); + + vec_reduce2_mma(&temp[0], temp0, v_alpha, vy0); + + vec_store_pair(&v_y[(i * 2) + 0], vy0); + } + + n &= 7; + if (n > 4) { + BLASLONG n3 = n & 3; + vec_loadN2_f32(vy0, &v_y[(i * 2) + 0], n3); + + vec_loadN_mult22a_mma(&temp[0], &va0[i], &va1[i], v_x00, n); + vec_loadN_mult22b_mma(&temp[0], &va2[i], &va3[i], v_x01, n); + + vec_reduce2_mma(&temp[0], temp0, v_alpha, vy0); + + vec_storeN2_f32(vy0, &v_y[(i * 2) + 0], n3); + } else if (n) { + vy0[0] = vec_loadN_f32(&v_y[(i * 2) + 0], n); + + vec_loadN_mult11a_mma(&temp[0], &va0[i], &va1[i], v_x00, n); + vec_loadN_mult11b_mma(&temp[0], &va2[i], &va3[i], v_x01, n); + + vec_reduce1_mma(&temp[0], temp0, v_alpha, vy0); + + vec_storeN_f32(vy0[0], &v_y[(i * 2) + 0], n); + } +} + +#ifdef USE_BFGEMV_8_N_MMA +static void BF16GEMV_N_MMA_8(BLASLONG n, IFLOAT **ap, IFLOAT *xo, FLOAT *y, BLASLONG lda4, FLOAT alpha) +{ + IFLOAT *a0, *a1, *a2, *a3, *b0, *b1, *b2, *b3; + __vector_quad temp[2*4]; + vec_f32 temp0[8*4], vy0[2*4]; + vec_f32 v_alpha = { alpha, alpha, alpha, alpha }; + + a0 = ap[0]; + a1 = ap[1]; + a2 = ap[2]; + a3 = ap[3]; + b0 = a0 + lda4; + b1 = a1 + lda4; + b2 = a2 + lda4; + b3 = a3 + lda4; + + vec_bf16 *va0 = (vec_bf16 *)a0; + vec_bf16 *va1 = (vec_bf16 *)a1; + vec_bf16 *va2 = (vec_bf16 *)a2; + vec_bf16 *va3 = (vec_bf16 *)a3; + vec_bf16 *vb0 = (vec_bf16 *)b0; + vec_bf16 *vb1 = (vec_bf16 *)b1; + vec_bf16 *vb2 = (vec_bf16 *)b2; + vec_bf16 *vb3 = (vec_bf16 *)b3; + + vec_bf16 *x_bf = (vec_bf16 *)(xo); + vec_bf16 v_x00 = (vec_bf16)vec_load_vec(x_bf); + + vec_bf16 v_x01 = (vec_bf16)vec_splat((vec_f32)v_x00, 1); + vec_bf16 v_x02 = (vec_bf16)vec_splat((vec_f32)v_x00, 2); + vec_bf16 v_x03 = (vec_bf16)vec_splat((vec_f32)v_x00, 3); + + vec_f32 *v_y = (vec_f32 *)y; + BLASLONG n8 = n / 8; + BLASLONG i = 0; + + for (; i + 4 <= n8; i += 4) { + vec_load4_pair(vy0, &v_y[(i * 2) + 0]); + + vec_load_mult28a_mma(&temp[0], &va0[i + 0], &va1[i + 0], v_x00); + vec_load_mult28b_mma(&temp[0], &va2[i + 0], &va3[i + 0], v_x01); + vec_load_mult28b_mma(&temp[0], &vb0[i + 0], &vb1[i + 0], v_x02); + vec_load_mult28b_mma(&temp[0], &vb2[i + 0], &vb3[i + 0], v_x03); + + vec_reduce8_mma(&temp[0], temp0, v_alpha, vy0); + + vec_store4_pair(&v_y[(i * 2) + 0], vy0); + } + + for (; i < n8; i++) { + vec_load_pair(vy0, &v_y[(i * 2) + 0]); + + 
vec_load_mult22a_mma(&temp[0], &va0[i], &va1[i], v_x00); + vec_load_mult22b_mma(&temp[0], &va2[i], &va3[i], v_x01); + vec_load_mult22b_mma(&temp[0], &vb0[i], &vb1[i], v_x02); + vec_load_mult22b_mma(&temp[0], &vb2[i], &vb3[i], v_x03); + + vec_reduce2_mma(&temp[0], temp0, v_alpha, vy0); + + vec_store_pair(&v_y[(i * 2) + 0], vy0); + } + + n &= 7; + if (n > 4) { + BLASLONG n3 = n & 3; + vec_loadN2_f32(vy0, &v_y[(i * 2) + 0], n3); + + vec_loadN_mult22a_mma(&temp[0], &va0[i], &va1[i], v_x00, n); + vec_loadN_mult22b_mma(&temp[0], &va2[i], &va3[i], v_x01, n); + vec_loadN_mult22b_mma(&temp[0], &vb0[i], &vb1[i], v_x02, n); + vec_loadN_mult22b_mma(&temp[0], &vb2[i], &vb3[i], v_x03, n); + + vec_reduce2_mma(&temp[0], temp0, v_alpha, vy0); + + vec_storeN2_f32(vy0, &v_y[(i * 2) + 0], n3); + } else if (n) { + vy0[0] = vec_loadN_f32(&v_y[(i * 2) + 0], n); + + vec_loadN_mult11a_mma(&temp[0], &va0[i], &va1[i], v_x00, n); + vec_loadN_mult11b_mma(&temp[0], &va2[i], &va3[i], v_x01, n); + vec_loadN_mult11b_mma(&temp[0], &vb0[i], &vb1[i], v_x02, n); + vec_loadN_mult11b_mma(&temp[0], &vb2[i], &vb3[i], v_x03, n); + + vec_reduce1_mma(&temp[0], temp0, v_alpha, vy0); + + vec_storeN_f32(vy0[0], &v_y[(i * 2) + 0], n); + } +} +#endif + +#include "sbgemv_n.c" +#else +#include "sbgemv_n_vsx.c" +#endif +#endif diff --git a/kernel/power/sbgemv_n_vsx.c b/kernel/power/sbgemv_n_vsx.c index cab2316d4e..e8f6dca9fc 100644 --- a/kernel/power/sbgemv_n_vsx.c +++ b/kernel/power/sbgemv_n_vsx.c @@ -25,12 +25,20 @@ OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. *****************************************************************************/ -#ifndef SBGEMV_N_VSX -#define SBGEMV_N_VSX +#ifndef SBGEMV_N_VSX_C +#define SBGEMV_N_VSX_C #include "sbgemv_common.c" -#define NBMAX 4096 +#ifndef BF16GEMV_N_X +#define BF16GEMV_N_X +#define BF16GEMV_N_8 BF16GEMV_N_VSX_8 +#define BF16GEMV_N_4 BF16GEMV_N_VSX_4 +#define BF16GEMV_N_2 BF16GEMV_N_VSX_2 +#define BF16GEMV_N_1 BF16GEMV_N_VSX_1 +#endif + +#define USE_BFGEMV_8_N_VSX static void BF16GEMV_N_VSX_1(BLASLONG n, IFLOAT **ap, IFLOAT *xo, FLOAT *y, FLOAT alpha) { @@ -70,11 +78,11 @@ static void BF16GEMV_N_VSX_1(BLASLONG n, IFLOAT **ap, IFLOAT *xo, FLOAT *y, FLOA vec_storeN2_f32(vy0, &v_y[(i * 2) + 0], n3); } else if (n) { - vec_f32 vy0 = vec_loadN_f32(&v_y[(i * 2) + 0], n); + vy0[0] = vec_loadN_f32(&v_y[(i * 2) + 0], n); - vy0 += vec_loadNHi_multi2(v_x0, &va0[i], n, zero); + vy0[0] += vec_loadNHi_mult2(v_x0, &va0[i], n, zero); - vec_storeN_f32(vy0, &v_y[(i * 2) + 0], n); + vec_storeN_f32(vy0[0], &v_y[(i * 2) + 0], n); } } @@ -121,12 +129,12 @@ static void BF16GEMV_N_VSX_2(BLASLONG n, IFLOAT **ap, IFLOAT *xo, FLOAT *y, FLOA vec_storeN2_f32(vy0, &v_y[(i * 2) + 0], n3); } else if (n) { - vec_f32 vy0 = vec_loadN_f32(&v_y[(i * 2) + 0], n); + vy0[0] = vec_loadN_f32(&v_y[(i * 2) + 0], n); - vy0 += vec_loadNHi_multi2(v_x0, &va0[i], n, zero); - vy0 += vec_loadNHi_multi2(v_x1, &va1[i], n, zero); + vy0[0] += vec_loadNHi_mult2(v_x0, &va0[i], n, zero); + vy0[0] += vec_loadNHi_mult2(v_x1, &va1[i], n, zero); - vec_storeN_f32(vy0, &v_y[(i * 2) + 0], n); + vec_storeN_f32(vy0[0], &v_y[(i * 2) + 0], n); } } @@ -183,17 +191,18 @@ static void BF16GEMV_N_VSX_4(BLASLONG n, IFLOAT **ap, IFLOAT *xo, FLOAT *y, FLOA vec_storeN2_f32(vy0, &v_y[(i * 2) + 0], n3); } else if (n) { - vec_f32 vy0 = vec_loadN_f32(&v_y[(i * 2) + 0], n); + vy0[0] = vec_loadN_f32(&v_y[(i * 2) + 0], n); - vy0 += vec_loadNHi_multi2(v_x0, &va0[i], n, zero); - 
vy0 += vec_loadNHi_multi2(v_x1, &va1[i], n, zero); - vy0 += vec_loadNHi_multi2(v_x2, &va2[i], n, zero); - vy0 += vec_loadNHi_multi2(v_x3, &va3[i], n, zero); + vy0[0] += vec_loadNHi_mult2(v_x0, &va0[i], n, zero); + vy0[0] += vec_loadNHi_mult2(v_x1, &va1[i], n, zero); + vy0[0] += vec_loadNHi_mult2(v_x2, &va2[i], n, zero); + vy0[0] += vec_loadNHi_mult2(v_x3, &va3[i], n, zero); - vec_storeN_f32(vy0, &v_y[(i * 2) + 0], n); + vec_storeN_f32(vy0[0], &v_y[(i * 2) + 0], n); } } +#ifdef USE_BFGEMV_8_N_VSX static void BF16GEMV_N_VSX_8(BLASLONG n, IFLOAT **ap, IFLOAT *xo, FLOAT *y, BLASLONG lda4, FLOAT alpha) { IFLOAT *a0, *a1, *a2, *a3, *b0, *b1, *b2, *b3; @@ -270,25 +279,21 @@ static void BF16GEMV_N_VSX_8(BLASLONG n, IFLOAT **ap, IFLOAT *xo, FLOAT *y, BLAS vec_storeN2_f32(vy0, &v_y[(i * 2) + 0], n3); } else if (n) { - vec_f32 vy0 = vec_loadN_f32(&v_y[(i * 2) + 0], n); - - vy0 += vec_loadNHi_multi2(v_x0, &va0[i], n, zero); - vy0 += vec_loadNHi_multi2(v_x1, &va1[i], n, zero); - vy0 += vec_loadNHi_multi2(v_x2, &va2[i], n, zero); - vy0 += vec_loadNHi_multi2(v_x3, &va3[i], n, zero); - vy0 += vec_loadNHi_multi2(v_x4, &vb0[i], n, zero); - vy0 += vec_loadNHi_multi2(v_x5, &vb1[i], n, zero); - vy0 += vec_loadNHi_multi2(v_x6, &vb2[i], n, zero); - vy0 += vec_loadNHi_multi2(v_x7, &vb3[i], n, zero); - - vec_storeN_f32(vy0, &v_y[(i * 2) + 0], n); + vy0[0] = vec_loadN_f32(&v_y[(i * 2) + 0], n); + + vy0[0] += vec_loadNHi_mult2(v_x0, &va0[i], n, zero); + vy0[0] += vec_loadNHi_mult2(v_x1, &va1[i], n, zero); + vy0[0] += vec_loadNHi_mult2(v_x2, &va2[i], n, zero); + vy0[0] += vec_loadNHi_mult2(v_x3, &va3[i], n, zero); + vy0[0] += vec_loadNHi_mult2(v_x4, &vb0[i], n, zero); + vy0[0] += vec_loadNHi_mult2(v_x5, &vb1[i], n, zero); + vy0[0] += vec_loadNHi_mult2(v_x6, &vb2[i], n, zero); + vy0[0] += vec_loadNHi_mult2(v_x7, &vb3[i], n, zero); + + vec_storeN_f32(vy0[0], &v_y[(i * 2) + 0], n); } } - -#define BF16GEMV_N_8 BF16GEMV_N_VSX_8 -#define BF16GEMV_N_4 BF16GEMV_N_VSX_4 -#define BF16GEMV_N_2 BF16GEMV_N_VSX_2 -#define BF16GEMV_N_1 BF16GEMV_N_VSX_1 +#endif #include "sbgemv_n.c" #endif diff --git a/kernel/power/sbgemv_t.c b/kernel/power/sbgemv_t.c index f0c79fe77a..4cc8f060e9 100644 --- a/kernel/power/sbgemv_t.c +++ b/kernel/power/sbgemv_t.c @@ -27,6 +27,11 @@ USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. 
#ifndef SBGEMV_T_COMMON_C #define SBGEMV_T_COMMON_C + +#if (defined(_ARCH_PWR10) && (defined(USE_BFGEMV_8_T_MMA) || (!defined(USE_BFGEMV_N_MMA) && defined(USE_BFGEMV_8_T_VSX)))) || (!defined(_ARCH_PWR10) && defined(USE_BFGEMV_8_T_VSX)) +#define USE_T_8 +#endif + int CNAME(BLASLONG m, BLASLONG n, FLOAT alpha, IFLOAT *a, BLASLONG lda, IFLOAT *x, BLASLONG inc_x, FLOAT beta, FLOAT *y, BLASLONG inc_y) { IFLOAT *xbuffer, *a_ptr; @@ -39,7 +44,9 @@ int CNAME(BLASLONG m, BLASLONG n, FLOAT alpha, IFLOAT *a, BLASLONG lda, IFLOAT * xbuffer = buffer; BLASLONG lda4 = lda << 2; +#ifdef USE_T_8 BLASLONG lda8 = lda << 3; +#endif BLASLONG NB = NBMAX; BLASLONG m2 = (m & (NBMAX - 1)); @@ -60,12 +67,16 @@ int CNAME(BLASLONG m, BLASLONG n, FLOAT alpha, IFLOAT *a, BLASLONG lda, IFLOAT * } if (inc_y == 1) { +#ifdef USE_T_8 for (BLASLONG j = 0; j + 8 <= n; j += 8) { BF16GEMV_T_8(NB, lda, a_ptr, xbuffer, y_ptr, alpha, beta); y_ptr += 8; a_ptr += lda8; } if (n & 4) { +#else + for (BLASLONG j = 0; j + 4 <= n; j += 4) { +#endif BF16GEMV_T_4(NB, lda, a_ptr, xbuffer, y_ptr, alpha, beta); y_ptr += 4; a_ptr += lda4; @@ -79,6 +90,7 @@ int CNAME(BLASLONG m, BLASLONG n, FLOAT alpha, IFLOAT *a, BLASLONG lda, IFLOAT * BF16GEMV_T_1(NB, lda, a_ptr, xbuffer, y_ptr, alpha, beta); } } else { +#ifdef USE_T_8 for (BLASLONG j = 0; j + 8 <= n; j += 8) { memset(ybuffer, 0, sizeof(FLOAT) * 8); BF16GEMV_T_8(NB, lda, a_ptr, xbuffer, ybuffer, alpha, beta); @@ -87,6 +99,9 @@ int CNAME(BLASLONG m, BLASLONG n, FLOAT alpha, IFLOAT *a, BLASLONG lda, IFLOAT * a_ptr += lda8; } if (n & 4) { +#else + for (BLASLONG j = 0; j + 4 <= n; j += 4) { +#endif memset(ybuffer, 0, sizeof(FLOAT) * 4); BF16GEMV_T_4(NB, lda, a_ptr, xbuffer, ybuffer, alpha, beta); copy_y(4, ybuffer, y_ptr, inc_y, beta); diff --git a/kernel/power/sbgemv_t_power10.c b/kernel/power/sbgemv_t_power10.c index 08bc4237c7..810287e89a 100644 --- a/kernel/power/sbgemv_t_power10.c +++ b/kernel/power/sbgemv_t_power10.c @@ -25,8 +25,334 @@ OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. 
*****************************************************************************/ -//#include "sbgemv_common.c" +#ifndef SBGEMV_T_MMA_C +#define SBGEMV_T_MMA_C +#define USE_BFGEMV_T_MMA + +#ifdef USE_BFGEMV_T_MMA +#include "sbgemv_common_power10.c" + +#ifndef BF16GEMV_T_X +#define BF16GEMV_T_X +#define BF16GEMV_T_8 BF16GEMV_T_MMA_8 +#define BF16GEMV_T_4 BF16GEMV_T_MMA_4 +#define BF16GEMV_T_2 BF16GEMV_T_MMA_2 +#define BF16GEMV_T_1 BF16GEMV_T_MMA_1 +#endif + +#define USE_BFGEMV_8_T_MMA + +static void BF16GEMV_T_MMA_1(BLASLONG n, BLASLONG lda, IFLOAT *ap, IFLOAT *x, FLOAT *y, FLOAT alpha, FLOAT beta) +{ + IFLOAT *a0; + vec_bf16 *va0, *v_x; + __vector_quad temp0; + vec_f32 temp00[4]; + vec_bf16 inp[2]; + + __builtin_mma_xxsetaccz(&temp0); + + a0 = ap; + va0 = (vec_bf16 *)a0; + v_x = (vec_bf16 *)x; + BLASLONG n8 = n / 8; + BLASLONG i = 0; + + for (; i + 2 <= n8; i += 2) { + vec_load_pair((vec_f32 *)inp, (vec_f32 *)&v_x[i]); + + vec_load_mult2_mma(&temp0, &va0[i + 0], inp); + } + + if (n8 & 1) { + inp[0] = (vec_bf16)vec_load_vec(&v_x[i]); + + vec_load_mult_mma(&temp0, &va0[i], inp[0]); + + i++; + } + + n &= 7; + if (n) { + inp[0] = vec_loadN(&v_x[i], n); + + vec_loadN_mult_mma(&temp0, &va0[i], inp[0], n); + } + + __builtin_mma_disassemble_acc((void*)temp00, &temp0); + + y[0] = (alpha * (temp00[0][0] + temp00[1][1] + temp00[2][2] + temp00[3][3])) + (beta * y[0]); +} + +static void BF16GEMV_T_MMA_2(BLASLONG n, BLASLONG lda, IFLOAT *ap, IFLOAT *x, FLOAT *y, FLOAT alpha, FLOAT beta) +{ + IFLOAT *a0, *a1; + vec_bf16 *va0, *va1, *v_x; + __vector_quad temp0, temp1; + vec_f32 temp00[4], temp01[4]; + vec_bf16 inp[2]; + + __builtin_mma_xxsetaccz(&temp0); + __builtin_mma_xxsetaccz(&temp1); + + a0 = ap; + a1 = ap + lda; + va0 = (vec_bf16 *)a0; + va1 = (vec_bf16 *)a1; + v_x = (vec_bf16 *)x; + BLASLONG n8 = n / 8; + BLASLONG i = 0; + + for (; i + 2 <= n8; i += 2) { + vec_load_pair((vec_f32 *)inp, (vec_f32 *)&v_x[i]); + + vec_load_mult2_mma(&temp0, &va0[i + 0], inp); + vec_load_mult2_mma(&temp1, &va1[i + 0], inp); + } + + if (n8 & 1) { + inp[0] = (vec_bf16)vec_load_vec(&v_x[i]); + + vec_load_mult_mma(&temp0, &va0[i], inp[0]); + vec_load_mult_mma(&temp1, &va1[i], inp[0]); + + i++; + } + + n &= 7; + if (n) { + inp[0] = vec_loadN(&v_x[i], n); + + vec_loadN_mult_mma(&temp0, &va0[i], inp[0], n); + vec_loadN_mult_mma(&temp1, &va1[i], inp[0], n); + } + + __builtin_mma_disassemble_acc((void*)temp00, &temp0); + __builtin_mma_disassemble_acc((void*)temp01, &temp1); + + y[0] = (alpha * (temp00[0][0] + temp00[1][1] + temp00[2][2] + temp00[3][3])) + (beta * y[0]); + y[1] = (alpha * (temp01[0][0] + temp01[1][1] + temp01[2][2] + temp01[3][3])) + (beta * y[1]); +} + +static void BF16GEMV_T_MMA_4(BLASLONG n, BLASLONG lda, IFLOAT *ap, IFLOAT *x, FLOAT *y, FLOAT alpha, FLOAT beta) +{ + IFLOAT *a0, *a1, *a2, *a3; + vec_bf16 *va0, *va1, *va2, *va3, *v_x; + __vector_quad temp0, temp1, temp2, temp3; + vec_f32 temp00[4], temp01[4], temp02[4], temp03[4]; + vec_bf16 inp[2]; + + __builtin_mma_xxsetaccz(&temp0); + __builtin_mma_xxsetaccz(&temp1); + __builtin_mma_xxsetaccz(&temp2); + __builtin_mma_xxsetaccz(&temp3); + + a0 = ap; + a1 = ap + lda; + a2 = a1 + lda; + a3 = a2 + lda; + va0 = (vec_bf16 *)a0; + va1 = (vec_bf16 *)a1; + va2 = (vec_bf16 *)a2; + va3 = (vec_bf16 *)a3; + v_x = (vec_bf16 *)x; + BLASLONG n8 = n / 8; + BLASLONG i = 0; + + for (; i + 2 <= n8; i += 2) { + vec_load_pair((vec_f32 *)inp, (vec_f32 *)&v_x[i]); + + vec_load_mult2_mma(&temp0, &va0[i + 0], inp); + vec_load_mult2_mma(&temp1, &va1[i + 0], inp); + 
vec_load_mult2_mma(&temp2, &va2[i + 0], inp); + vec_load_mult2_mma(&temp3, &va3[i + 0], inp); + } + + if (n8 & 1) { + inp[0] = (vec_bf16)vec_load_vec(&v_x[i]); + + vec_load_mult_mma(&temp0, &va0[i], inp[0]); + vec_load_mult_mma(&temp1, &va1[i], inp[0]); + vec_load_mult_mma(&temp2, &va2[i], inp[0]); + vec_load_mult_mma(&temp3, &va3[i], inp[0]); + + i++; + } + + n &= 7; + if (n) { + inp[0] = vec_loadN(&v_x[i], n); + + vec_loadN_mult_mma(&temp0, &va0[i], inp[0], n); + vec_loadN_mult_mma(&temp1, &va1[i], inp[0], n); + vec_loadN_mult_mma(&temp2, &va2[i], inp[0], n); + vec_loadN_mult_mma(&temp3, &va3[i], inp[0], n); + } + + __builtin_mma_disassemble_acc((void*)temp00, &temp0); + __builtin_mma_disassemble_acc((void*)temp01, &temp1); + __builtin_mma_disassemble_acc((void*)temp02, &temp2); + __builtin_mma_disassemble_acc((void*)temp03, &temp3); + + vec_f32 t0, t1, t2, t3, t4, t5, t6, t7; + vec_f32 a = { alpha, alpha, alpha, alpha }; + vec_f32 b = { beta, beta, beta, beta }; + vec_f32 *v_y = (vec_f32 *) y; + + t0 = vec_mergeh(temp00[0], temp01[0]); + t1 = vec_mergeh(temp02[0], temp03[0]); + t2 = vec_mergeo(temp00[1], temp01[1]); + t3 = vec_mergeo(temp02[1], temp03[1]); + t4 = vec_mergel(temp00[2], temp01[2]); + t5 = vec_mergel(temp02[2], temp03[2]); + t6 = vec_mergeo(temp00[3], temp01[3]); + t7 = vec_mergeo(temp02[3], temp03[3]); + t0 = vec_xxpermdi(t0, t1, 0); + t2 = vec_xxpermdi(t2, t3, 0); + t4 = vec_xxpermdi(t4, t5, 0); + t6 = vec_xxpermdi(t6, t7, 3); + + t0 += t2 + t4 + t6; + + v_y[0] = (a * t0) + (b * v_y[0]); +} + +#ifdef USE_BFGEMV_8_T_MMA +static void BF16GEMV_T_MMA_8(BLASLONG n, BLASLONG lda, IFLOAT *ap, IFLOAT *x, FLOAT *y, FLOAT alpha, FLOAT beta) +{ + IFLOAT *a0, *a1, *a2, *a3, *a4, *a5, *a6, *a7; + vec_bf16 *va0, *va1, *va2, *va3, *va4, *va5, *va6, *va7, *v_x; + __vector_quad temp0, temp1, temp2, temp3, temp4, temp5, temp6, temp7; + vec_f32 temp00[4], temp01[4], temp02[4], temp03[4], temp04[4], temp05[4], temp06[4], temp07[4]; + vec_bf16 inp[2]; + + __builtin_mma_xxsetaccz(&temp0); + __builtin_mma_xxsetaccz(&temp1); + __builtin_mma_xxsetaccz(&temp2); + __builtin_mma_xxsetaccz(&temp3); + __builtin_mma_xxsetaccz(&temp4); + __builtin_mma_xxsetaccz(&temp5); + __builtin_mma_xxsetaccz(&temp6); + __builtin_mma_xxsetaccz(&temp7); + + a0 = ap; + a1 = ap + lda; + a2 = a1 + lda; + a3 = a2 + lda; + a4 = a3 + lda; + a5 = a4 + lda; + a6 = a5 + lda; + a7 = a6 + lda; + va0 = (vec_bf16 *)a0; + va1 = (vec_bf16 *)a1; + va2 = (vec_bf16 *)a2; + va3 = (vec_bf16 *)a3; + va4 = (vec_bf16 *)a4; + va5 = (vec_bf16 *)a5; + va6 = (vec_bf16 *)a6; + va7 = (vec_bf16 *)a7; + v_x = (vec_bf16 *)x; + BLASLONG n8 = n / 8; + BLASLONG i = 0; + + for (; i + 2 <= n8; i += 2) { + vec_load_pair((vec_f32 *)inp, (vec_f32 *)&v_x[i]); + + vec_load_mult2_mma(&temp0, &va0[i + 0], inp); + vec_load_mult2_mma(&temp1, &va1[i + 0], inp); + vec_load_mult2_mma(&temp2, &va2[i + 0], inp); + vec_load_mult2_mma(&temp3, &va3[i + 0], inp); + vec_load_mult2_mma(&temp4, &va4[i + 0], inp); + vec_load_mult2_mma(&temp5, &va5[i + 0], inp); + vec_load_mult2_mma(&temp6, &va6[i + 0], inp); + vec_load_mult2_mma(&temp7, &va7[i + 0], inp); + } + + if (n8 & 1) { + inp[0] = (vec_bf16)vec_load_vec(&v_x[i]); + + vec_load_mult_mma(&temp0, &va0[i], inp[0]); + vec_load_mult_mma(&temp1, &va1[i], inp[0]); + vec_load_mult_mma(&temp2, &va2[i], inp[0]); + vec_load_mult_mma(&temp3, &va3[i], inp[0]); + vec_load_mult_mma(&temp4, &va4[i], inp[0]); + vec_load_mult_mma(&temp5, &va5[i], inp[0]); + vec_load_mult_mma(&temp6, &va6[i], inp[0]); + vec_load_mult_mma(&temp7, &va7[i], 
inp[0]); + + i++; + } + + n &= 7; + if (n) { + inp[0] = vec_loadN(&v_x[i], n); + + vec_loadN_mult_mma(&temp0, &va0[i], inp[0], n); + vec_loadN_mult_mma(&temp1, &va1[i], inp[0], n); + vec_loadN_mult_mma(&temp2, &va2[i], inp[0], n); + vec_loadN_mult_mma(&temp3, &va3[i], inp[0], n); + vec_loadN_mult_mma(&temp4, &va4[i], inp[0], n); + vec_loadN_mult_mma(&temp5, &va5[i], inp[0], n); + vec_loadN_mult_mma(&temp6, &va6[i], inp[0], n); + vec_loadN_mult_mma(&temp7, &va7[i], inp[0], n); + } + + __builtin_mma_disassemble_acc((void*)temp00, &temp0); + __builtin_mma_disassemble_acc((void*)temp01, &temp1); + __builtin_mma_disassemble_acc((void*)temp02, &temp2); + __builtin_mma_disassemble_acc((void*)temp03, &temp3); + __builtin_mma_disassemble_acc((void*)temp04, &temp4); + __builtin_mma_disassemble_acc((void*)temp05, &temp5); + __builtin_mma_disassemble_acc((void*)temp06, &temp6); + __builtin_mma_disassemble_acc((void*)temp07, &temp7); + + vec_f32 t0, t1, t2, t3, t4, t5, t6, t7, t10, t11, t12, t13, t14, t15, t16, t17; + vec_f32 a = { alpha, alpha, alpha, alpha }; + vec_f32 b = { beta, beta, beta, beta }; + vec_f32 *v_y = (vec_f32 *) y; + + t0 = vec_mergeh(temp00[0], temp01[0]); + t1 = vec_mergeh(temp02[0], temp03[0]); + t2 = vec_mergeo(temp00[1], temp01[1]); + t3 = vec_mergeo(temp02[1], temp03[1]); + t4 = vec_mergel(temp00[2], temp01[2]); + t5 = vec_mergel(temp02[2], temp03[2]); + t6 = vec_mergeo(temp00[3], temp01[3]); + t7 = vec_mergeo(temp02[3], temp03[3]); + t0 = vec_xxpermdi(t0, t1, 0); + t2 = vec_xxpermdi(t2, t3, 0); + t4 = vec_xxpermdi(t4, t5, 0); + t6 = vec_xxpermdi(t6, t7, 3); + + t0 += t2 + t4 + t6; + + t10 = vec_mergeh(temp04[0], temp05[0]); + t11 = vec_mergeh(temp06[0], temp07[0]); + t12 = vec_mergeo(temp04[1], temp05[1]); + t13 = vec_mergeo(temp06[1], temp07[1]); + t14 = vec_mergel(temp04[2], temp05[2]); + t15 = vec_mergel(temp06[2], temp07[2]); + t16 = vec_mergeo(temp04[3], temp05[3]); + t17 = vec_mergeo(temp06[3], temp07[3]); + t10 = vec_xxpermdi(t10, t11, 0); + t12 = vec_xxpermdi(t12, t13, 0); + t14 = vec_xxpermdi(t14, t15, 0); + t16 = vec_xxpermdi(t16, t17, 3); + + t10 += t12 + t14 + t16; + + vec_f32 inp2[2]; + vec_load_pair(inp2, v_y); + inp2[0] = (a * t0) + (b * inp2[0]); + inp2[1] = (a * t10) + (b * inp2[1]); + vec_store_pair(v_y, inp2); +} +#endif + +#include "sbgemv_t.c" +#else #include "sbgemv_t_vsx.c" +#endif +#endif -//#include "sbgemv_t.c" diff --git a/kernel/power/sbgemv_t_vsx.c b/kernel/power/sbgemv_t_vsx.c index 7da894109b..399989bb52 100644 --- a/kernel/power/sbgemv_t_vsx.c +++ b/kernel/power/sbgemv_t_vsx.c @@ -25,12 +25,20 @@ OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. 
*****************************************************************************/ -#ifndef SBGEMV_T_VSX -#define SBGEMV_T_VSX +#ifndef SBGEMV_T_VSX_C +#define SBGEMV_T_VSX_C #include "sbgemv_common.c" -#define NBMAX 4096 +#ifndef BF16GEMV_T_X +#define BF16GEMV_T_X +#define BF16GEMV_T_8 BF16GEMV_T_VSX_8 +#define BF16GEMV_T_4 BF16GEMV_T_VSX_4 +#define BF16GEMV_T_2 BF16GEMV_T_VSX_2 +#define BF16GEMV_T_1 BF16GEMV_T_VSX_1 +#endif + +#define USE_BFGEMV_8_T_VSX static void BF16GEMV_T_VSX_1(BLASLONG n, BLASLONG lda, IFLOAT *ap, IFLOAT *x, FLOAT *y, FLOAT alpha, FLOAT beta) { @@ -58,9 +66,9 @@ static void BF16GEMV_T_VSX_1(BLASLONG n, BLASLONG lda, IFLOAT *ap, IFLOAT *x, FL temp0 += vec_loadN_mult(&va0[i], inp, n, zero); } else if (n) { - vec_f32 v_inp0 = vec_loadNHi_vec(v_x, i, n, zero); + inp[0] = vec_loadNHi_vec(v_x, i, n, zero); - temp0 += vec_loadNHi_mult(&va0[i], v_inp0, n, zero); + temp0 += vec_loadNHi_mult(&va0[i], inp[0], n, zero); } y[0] = (alpha * (temp0[0] + temp0[1] + temp0[2] + temp0[3])) + (beta * y[0]); @@ -97,10 +105,10 @@ static void BF16GEMV_T_VSX_2(BLASLONG n, BLASLONG lda, IFLOAT *ap, IFLOAT *x, FL temp0 += vec_loadN_mult(&va0[i], inp, n, zero); temp1 += vec_loadN_mult(&va1[i], inp, n, zero); } else if (n) { - vec_f32 v_inp0 = vec_loadNHi_vec(v_x, i, n, zero); + inp[0] = vec_loadNHi_vec(v_x, i, n, zero); - temp0 += vec_loadNHi_mult(&va0[i], v_inp0, n, zero); - temp1 += vec_loadNHi_mult(&va1[i], v_inp0, n, zero); + temp0 += vec_loadNHi_mult(&va0[i], inp[0], n, zero); + temp1 += vec_loadNHi_mult(&va1[i], inp[0], n, zero); } y[0] = (alpha * (temp0[0] + temp0[1] + temp0[2] + temp0[3])) + (beta * y[0]); @@ -148,12 +156,12 @@ static void BF16GEMV_T_VSX_4(BLASLONG n, BLASLONG lda, IFLOAT *ap, IFLOAT *x, FL temp2 += vec_loadN_mult(&va2[i], inp, n, zero); temp3 += vec_loadN_mult(&va3[i], inp, n, zero); } else if (n) { - vec_f32 v_inp0 = vec_loadNHi_vec(v_x, i, n, zero); + inp[0] = vec_loadNHi_vec(v_x, i, n, zero); - temp0 += vec_loadNHi_mult(&va0[i], v_inp0, n, zero); - temp1 += vec_loadNHi_mult(&va1[i], v_inp0, n, zero); - temp2 += vec_loadNHi_mult(&va2[i], v_inp0, n, zero); - temp3 += vec_loadNHi_mult(&va3[i], v_inp0, n, zero); + temp0 += vec_loadNHi_mult(&va0[i], inp[0], n, zero); + temp1 += vec_loadNHi_mult(&va1[i], inp[0], n, zero); + temp2 += vec_loadNHi_mult(&va2[i], inp[0], n, zero); + temp3 += vec_loadNHi_mult(&va3[i], inp[0], n, zero); } vec_f32 t0, t1, t2, t3; @@ -174,6 +182,7 @@ static void BF16GEMV_T_VSX_4(BLASLONG n, BLASLONG lda, IFLOAT *ap, IFLOAT *x, FL v_y[0] = (a * temp0) + (b * v_y[0]); } +#ifdef USE_BFGEMV_8_T_VSX static void BF16GEMV_T_VSX_8(BLASLONG n, BLASLONG lda, IFLOAT *ap, IFLOAT *x, FLOAT *y, FLOAT alpha, FLOAT beta) { IFLOAT *a0, *a1, *a2, *a3, *a4, *a5, *a6, *a7; @@ -235,16 +244,16 @@ static void BF16GEMV_T_VSX_8(BLASLONG n, BLASLONG lda, IFLOAT *ap, IFLOAT *x, FL temp6 += vec_loadN_mult(&va6[i], inp, n, zero); temp7 += vec_loadN_mult(&va7[i], inp, n, zero); } else if (n) { - vec_f32 v_inp0 = vec_loadNHi_vec(v_x, i, n, zero); - - temp0 += vec_loadNHi_mult(&va0[i], v_inp0, n, zero); - temp1 += vec_loadNHi_mult(&va1[i], v_inp0, n, zero); - temp2 += vec_loadNHi_mult(&va2[i], v_inp0, n, zero); - temp3 += vec_loadNHi_mult(&va3[i], v_inp0, n, zero); - temp4 += vec_loadNHi_mult(&va4[i], v_inp0, n, zero); - temp5 += vec_loadNHi_mult(&va5[i], v_inp0, n, zero); - temp6 += vec_loadNHi_mult(&va6[i], v_inp0, n, zero); - temp7 += vec_loadNHi_mult(&va7[i], v_inp0, n, zero); + inp[0] = vec_loadNHi_vec(v_x, i, n, zero); + + temp0 += vec_loadNHi_mult(&va0[i], inp[0], n, zero); 
+ temp1 += vec_loadNHi_mult(&va1[i], inp[0], n, zero); + temp2 += vec_loadNHi_mult(&va2[i], inp[0], n, zero); + temp3 += vec_loadNHi_mult(&va3[i], inp[0], n, zero); + temp4 += vec_loadNHi_mult(&va4[i], inp[0], n, zero); + temp5 += vec_loadNHi_mult(&va5[i], inp[0], n, zero); + temp6 += vec_loadNHi_mult(&va6[i], inp[0], n, zero); + temp7 += vec_loadNHi_mult(&va7[i], inp[0], n, zero); } vec_f32 t0, t1, t2, t3; @@ -272,14 +281,12 @@ static void BF16GEMV_T_VSX_8(BLASLONG n, BLASLONG lda, IFLOAT *ap, IFLOAT *x, FL temp7 = vec_mergel(t1, t3); temp4 += temp5 + temp6 + temp7; - v_y[0] = (a * temp0) + (b * v_y[0]); - v_y[1] = (a * temp4) + (b * v_y[1]); + vec_load_pair(inp, v_y); + inp[0] = (a * temp0) + (b * inp[0]); + inp[1] = (a * temp4) + (b * inp[1]); + vec_store_pair(v_y, inp); } - -#define BF16GEMV_T_8 BF16GEMV_T_VSX_8 -#define BF16GEMV_T_4 BF16GEMV_T_VSX_4 -#define BF16GEMV_T_2 BF16GEMV_T_VSX_2 -#define BF16GEMV_T_1 BF16GEMV_T_VSX_1 +#endif #include "sbgemv_t.c" #endif From c9ce37d527311145be20210fe6cef792aca7a6f5 Mon Sep 17 00:00:00 2001 From: Chip Kerchner Date: Mon, 23 Sep 2024 08:43:58 -0500 Subject: [PATCH 08/24] Force vector pairs in clang. --- kernel/power/gemm_common.c | 8 ++++++++ 1 file changed, 8 insertions(+) diff --git a/kernel/power/gemm_common.c b/kernel/power/gemm_common.c index 0611ebc2a9..ed00de95b0 100644 --- a/kernel/power/gemm_common.c +++ b/kernel/power/gemm_common.c @@ -46,7 +46,11 @@ FORCEINLINE void vec_load_pair(vec_f32 *dst, vec_f32 *src) { #ifdef USE_VECTOR_PAIRS __vector_pair vy0p; +#ifdef __clang__ + vy0p = __builtin_vsx_lxvp(0L, (const __vector_pair *)(src)); +#else vy0p = *(__vector_pair *)(src); +#endif __builtin_vsx_disassemble_pair((void *)(dst), &vy0p); #else dst[0] = src[0]; @@ -59,7 +63,11 @@ FORCEINLINE void vec_store_pair(vec_f32 *dst, vec_f32 *src) #ifdef USE_VECTOR_PAIRS __vector_pair vy0p; __builtin_vsx_assemble_pair2(&vy0p, (vec_uc8)src[1], (vec_uc8)src[0]); +#ifdef __clang__ + __builtin_vsx_stxvp(vy0p, 0L, (__vector_pair *)(dst)); +#else *(__vector_pair *)(dst) = vy0p; +#endif #else dst[0] = src[0]; dst[1] = src[1]; From 05aa63e738edeb06ad2697a04f8889f6c0746067 Mon Sep 17 00:00:00 2001 From: Chip Kerchner Date: Tue, 24 Sep 2024 12:54:02 -0500 Subject: [PATCH 09/24] More MMA BF16 GEMV code. --- kernel/power/sbgemv_common_power10.c | 179 ++++++++++++++++++- kernel/power/sbgemv_n_power10.c | 246 ++++++++++++++++++++++----- 2 files changed, 372 insertions(+), 53 deletions(-) diff --git a/kernel/power/sbgemv_common_power10.c b/kernel/power/sbgemv_common_power10.c index da088014b0..2ee912b9d2 100644 --- a/kernel/power/sbgemv_common_power10.c +++ b/kernel/power/sbgemv_common_power10.c @@ -29,6 +29,10 @@ USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. 
#define SBGEMV_COMMON_MMA_C #include "sbgemv_common.c" +#if defined(_AIX) || defined(__clang__) +#define USE_MERGE_MMA +#endif + FORCEINLINE void vec_load_mult_mma(__vector_quad *out, vec_bf16 *in, vec_bf16 inp) { vec_bf16 in0 = (vec_bf16)vec_load_vec(in); @@ -69,11 +73,13 @@ FORCEINLINE void vec_mult2_mma(__vector_quad *out, vec_bf16 in0, vec_bf16 inp) __builtin_mma_xvbf16ger2(&out[1], (vec_uc8)inp, (vec_uc8)in01); } +#ifndef USE_MERGE_MMA FORCEINLINE void vec_mult4_mma(__vector_quad *out, vec_bf16 *in0, vec_bf16 inp) { vec_mult2_mma(out + 0, in0[0], inp); vec_mult2_mma(out + 2, in0[1], inp); } +#endif FORCEINLINE void vec_loadN_mult11_mma(__vector_quad *out, vec_bf16 *in, vec_bf16 inp, BLASLONG n) { @@ -96,6 +102,7 @@ FORCEINLINE void vec_load_mult12_mma(__vector_quad *out, vec_bf16 *in, vec_bf16 vec_mult2_mma(out, in0, inp); } +#ifndef USE_MERGE_MMA FORCEINLINE void vec_load_mult18_mma(__vector_quad *out, vec_bf16 *in, vec_bf16 inp) { vec_bf16 in0[4]; @@ -106,6 +113,7 @@ FORCEINLINE void vec_load_mult18_mma(__vector_quad *out, vec_bf16 *in, vec_bf16 vec_mult4_mma(&out[0], in0 + 0, inp); vec_mult4_mma(&out[4], in0 + 2, inp); } +#endif FORCEINLINE void vec_reduce1_mma(__vector_quad *out, vec_f32 *temp, vec_f32 v_alpha, vec_f32 *vy0) { @@ -120,6 +128,7 @@ FORCEINLINE void vec_reduce2_mma(__vector_quad *out, vec_f32 *temp, vec_f32 v_al vec_reduce1_mma(&out[1], &temp[4], v_alpha, &vy0[1]); } +#ifndef USE_MERGE_MMA FORCEINLINE void vec_reduce8_mma(__vector_quad *out, vec_f32 *temp, vec_f32 v_alpha, vec_f32 *vy0) { vec_reduce2_mma(&out[0], &temp[0], v_alpha, vy0 + 0); @@ -127,6 +136,23 @@ FORCEINLINE void vec_reduce8_mma(__vector_quad *out, vec_f32 *temp, vec_f32 v_al vec_reduce2_mma(&out[4], &temp[16], v_alpha, vy0 + 4); vec_reduce2_mma(&out[6], &temp[24], v_alpha, vy0 + 6); } +#else +FORCEINLINE void vec_reduce44_mma(__vector_quad *out, vec_f32 *temp, vec_f32 v_alpha, vec_f32 *vy0) +{ + __builtin_mma_disassemble_acc((void*)temp, &out[0]); + + vy0[0] += (temp[0] * v_alpha); + vy0[2] += (temp[1] * v_alpha); + vy0[4] += (temp[2] * v_alpha); + vy0[6] += (temp[3] * v_alpha); +} + +FORCEINLINE void vec_reduce84_mma(__vector_quad *out, vec_f32 *temp, vec_f32 v_alpha, vec_f32 *vy0) +{ + vec_reduce44_mma(&out[0], &temp[0], v_alpha, vy0 + 0); + vec_reduce44_mma(&out[1], &temp[4], v_alpha, vy0 + 1); +} +#endif FORCEINLINE void vec_mult11a_mma(__vector_quad *out, vec_bf16 in0, vec_bf16 in1, vec_bf16 inp) { @@ -166,18 +192,25 @@ FORCEINLINE void vec_load_mult22a_mma(__vector_quad *out, vec_bf16 *ina, vec_bf1 vec_mult2a_mma(out, in0, in1, inp); } -FORCEINLINE void vec_load_mult28a_mma(__vector_quad *out, vec_bf16 *ina, vec_bf16 *inb, vec_bf16 inp) +FORCEINLINE void vec_load4_mma(vec_bf16 *in0, vec_bf16 *in1, vec_bf16 *ina, vec_bf16 *inb) { - vec_bf16 in0[4], in1[4]; - vec_load_pair((vec_f32 *)(in0 + 0), (vec_f32 *)(ina + 0)); vec_load_pair((vec_f32 *)(in1 + 0), (vec_f32 *)(inb + 0)); vec_load_pair((vec_f32 *)(in0 + 2), (vec_f32 *)(ina + 2)); vec_load_pair((vec_f32 *)(in1 + 2), (vec_f32 *)(inb + 2)); +} + +#ifndef USE_MERGE_MMA +FORCEINLINE void vec_load_mult28a_mma(__vector_quad *out, vec_bf16 *ina, vec_bf16 *inb, vec_bf16 inp) +{ + vec_bf16 in0[4], in1[4]; + + vec_load4_mma(in0, in1, ina, inb); vec_mult4a_mma(&out[0], in0 + 0, in1 + 0, inp); vec_mult4a_mma(&out[4], in0 + 2, in1 + 2, inp); } +#endif FORCEINLINE void vec_loadN_mult22a_mma(__vector_quad *out, vec_bf16 *ina, vec_bf16 *inb, vec_bf16 inp, BLASLONG n) { @@ -209,6 +242,48 @@ FORCEINLINE void vec_mult4b_mma(__vector_quad *out, vec_bf16 *in0, 
vec_bf16 *in1 vec_mult2b_mma(out + 2, in0[1], in1[1], inp); } +#ifdef USE_MERGE_MMA +FORCEINLINE void vec_mult1c_mma(__vector_quad *out, vec_bf16 in0, vec_bf16 inp) +{ + vec_bf16 in00 = vec_mergeh(in0, in0); + + __builtin_mma_xvbf16ger2pp(out, (vec_uc8)inp, (vec_uc8)in00); +} + +FORCEINLINE void vec_mult2c_mma(__vector_quad *out, vec_bf16 in0, vec_bf16 inp) +{ + vec_bf16 in01 = vec_mergel(in0, in0); + + vec_mult1c_mma(&out[0], in0, inp); + + __builtin_mma_xvbf16ger2pp(&out[1], (vec_uc8)inp, (vec_uc8)in01); +} + +FORCEINLINE void vec_mult44_mma(__vector_quad *out, vec_bf16 *in, vec_bf16 *inp) +{ + vec_mult2_mma(out, in[0], inp[0]); + vec_mult2c_mma(out, in[1], inp[1]); +} + +FORCEINLINE void vec_mult44c_mma(__vector_quad *out, vec_bf16 *in, vec_bf16 *inp) +{ + vec_mult2c_mma(out, in[0], inp[0]); + vec_mult2c_mma(out, in[1], inp[1]); +} + +FORCEINLINE void vec_mult44a_mma(__vector_quad *out, vec_bf16 *in0, vec_bf16 *in1, vec_bf16 *inp) +{ + vec_mult2a_mma(out, in0[0], in1[0], inp[0]); + vec_mult2b_mma(out, in0[1], in1[1], inp[1]); +} + +FORCEINLINE void vec_mult44b_mma(__vector_quad *out, vec_bf16 *in0, vec_bf16 *in1, vec_bf16 *inp) +{ + vec_mult2b_mma(out, in0[0], in1[0], inp[0]); + vec_mult2b_mma(out, in0[1], in1[1], inp[1]); +} +#endif + FORCEINLINE void vec_loadN_mult11b_mma(__vector_quad *out, vec_bf16 *ina, vec_bf16 *inb, vec_bf16 inp, BLASLONG n) { vec_bf16 in0 = vec_loadN(ina, n); @@ -225,18 +300,48 @@ FORCEINLINE void vec_load_mult22b_mma(__vector_quad *out, vec_bf16 *ina, vec_bf1 vec_mult2b_mma(out, in0, in1, inp); } +#ifndef USE_MERGE_MMA FORCEINLINE void vec_load_mult28b_mma(__vector_quad *out, vec_bf16 *ina, vec_bf16 *inb, vec_bf16 inp) { vec_bf16 in0[4], in1[4]; - vec_load_pair((vec_f32 *)(in0 + 0), (vec_f32 *)(ina + 0)); - vec_load_pair((vec_f32 *)(in1 + 0), (vec_f32 *)(inb + 0)); - vec_load_pair((vec_f32 *)(in0 + 2), (vec_f32 *)(ina + 2)); - vec_load_pair((vec_f32 *)(in1 + 2), (vec_f32 *)(inb + 2)); + vec_load4_mma(in0, in1, ina, inb); vec_mult4b_mma(&out[0], in0 + 0, in1 + 0, inp); vec_mult4b_mma(&out[4], in0 + 2, in1 + 2, inp); } +#else +FORCEINLINE void vec_load_mult184_mma(__vector_quad *out, vec_bf16 *in, vec_bf16 *inp) +{ + vec_bf16 in0[4]; + + vec_load_pair((vec_f32 *)(in0 + 0), (vec_f32 *)(in + 0)); + vec_load_pair((vec_f32 *)(in0 + 2), (vec_f32 *)(in + 2)); + + vec_mult44_mma(out, in0 + 0, inp + 0); + vec_mult44c_mma(out, in0 + 2, inp + 2); +} + +FORCEINLINE void vec_load_mult284a_mma(__vector_quad *out, vec_bf16 *ina, vec_bf16 *inb, vec_bf16 *inp) +{ + vec_bf16 in0[4], in1[4]; + + vec_load4_mma(in0, in1, ina, inb); + + vec_mult44a_mma(out, in0 + 0, in1 + 0, inp + 0); + vec_mult44b_mma(out, in0 + 2, in1 + 2, inp + 2); +} + +FORCEINLINE void vec_load_mult284b_mma(__vector_quad *out, vec_bf16 *ina, vec_bf16 *inb, vec_bf16 *inp) +{ + vec_bf16 in0[4], in1[4]; + + vec_load4_mma(in0, in1, ina, inb); + + vec_mult44b_mma(out, in0 + 0, in1 + 0, inp + 0); + vec_mult44b_mma(out, in0 + 2, in1 + 2, inp + 2); +} +#endif FORCEINLINE void vec_loadN_mult22b_mma(__vector_quad *out, vec_bf16 *ina, vec_bf16 *inb, vec_bf16 inp, BLASLONG n) { @@ -262,4 +367,64 @@ FORCEINLINE void vec_store4_pair(vec_f32 *v_y, vec_f32 *vy0) vec_store_pair(v_y + 6, vy0 + 6); } +#ifdef USE_MERGE_MMA +FORCEINLINE void vec_load8_pair(vec_f32 *vy0, vec_f32 *v_y) +{ + vec_load4_pair(vy0 + 0, v_y + 0); + vec_load4_pair(vy0 + 8, v_y + 8); +} + +FORCEINLINE void vec_store8_pair(vec_f32 *v_y, vec_f32 *vy0) +{ + vec_store4_pair(v_y + 0, vy0 + 0); + vec_store4_pair(v_y + 8, vy0 + 8); +} + +#if __BYTE_ORDER__ == 
__ORDER_BIG_ENDIAN__ +#define VEC_SHIFT(data, shift) vec_sld(data, data, 16 - shift) +#else +#define VEC_SHIFT(data, shift) vec_sld(data, data, shift) +#endif + +typedef __vector unsigned int vec_ui32; + +static vec_ui32 mask_0 = { 0xffffffff, 0x00000000, 0x00000000, 0x00000000 }; +static vec_ui32 mask_1 = { 0x00000000, 0xffffffff, 0x00000000, 0x00000000 }; +static vec_ui32 mask_2 = { 0x00000000, 0x00000000, 0xffffffff, 0x00000000 }; +static vec_ui32 mask_3 = { 0x00000000, 0x00000000, 0x00000000, 0xffffffff }; + +FORCEINLINE void vec_make_mult1(vec_bf16 *v_x0) +{ + v_x0[ 0] = vec_and(v_x0[0], (vec_bf16)mask_0); + + v_x0[ 1] = VEC_SHIFT(v_x0[ 0], 4); + v_x0[ 2] = VEC_SHIFT(v_x0[ 0], 8); + v_x0[ 3] = VEC_SHIFT(v_x0[ 0], 12); +} + +FORCEINLINE void vec_make_mult2(vec_bf16 *v_x0) +{ + v_x0[ 5] = vec_and(v_x0[0], (vec_bf16)mask_1); + vec_make_mult1(v_x0); + + v_x0[ 4] = VEC_SHIFT(v_x0[ 5], 12); + v_x0[ 6] = VEC_SHIFT(v_x0[ 5], 4); + v_x0[ 7] = VEC_SHIFT(v_x0[ 5], 8); +} + +FORCEINLINE void vec_make_mult4(vec_bf16 *v_x0) +{ + v_x0[10] = vec_and(v_x0[0], (vec_bf16)mask_2); + v_x0[15] = vec_and(v_x0[0], (vec_bf16)mask_3); + vec_make_mult2(v_x0); + + v_x0[ 8] = VEC_SHIFT(v_x0[10], 8); + v_x0[ 9] = VEC_SHIFT(v_x0[10], 12); + v_x0[11] = VEC_SHIFT(v_x0[10], 4); + v_x0[12] = VEC_SHIFT(v_x0[15], 4); + v_x0[13] = VEC_SHIFT(v_x0[15], 8); + v_x0[14] = VEC_SHIFT(v_x0[15], 12); +} +#endif + #endif diff --git a/kernel/power/sbgemv_n_power10.c b/kernel/power/sbgemv_n_power10.c index 7b2beb0c7b..f2ed6bf9a8 100644 --- a/kernel/power/sbgemv_n_power10.c +++ b/kernel/power/sbgemv_n_power10.c @@ -28,7 +28,9 @@ USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. #ifndef SBGEMV_N_MMA_C #define SBGEMV_N_MMA_C +#if !defined(_AIX) || defined(__clang__) #define USE_BFGEMV_N_MMA +#endif #ifdef USE_BFGEMV_N_MMA #include "sbgemv_common_power10.c" @@ -47,7 +49,7 @@ static void BF16GEMV_N_MMA_1(BLASLONG n, IFLOAT **ap, IFLOAT *xo, FLOAT *y, FLOA { IFLOAT *a0; __vector_quad temp[2*4]; - vec_f32 temp0[8*4], vy0[2*4]; + vec_f32 temp0[8*4]; vec_f32 v_alpha = { alpha, alpha, alpha, alpha }; a0 = ap[0]; @@ -55,26 +57,61 @@ static void BF16GEMV_N_MMA_1(BLASLONG n, IFLOAT **ap, IFLOAT *xo, FLOAT *y, FLOA vec_bf16 *va0 = (vec_bf16 *)a0; vec_bf16 *x_bf = (vec_bf16 *)(xo); - vec_bf16 v_x0 = vec_loadN(x_bf, 1); vec_f32 *v_y = (vec_f32 *)y; BLASLONG n8 = n / 8; BLASLONG i = 0; +#ifdef USE_MERGE_MMA + vec_bf16 v_x0[4]; + v_x0[0] = vec_loadN(x_bf, 1); + vec_f32 vy0[2*4*2]; + + vec_make_mult1(v_x0); + + for (; i + 8 <= n8; i += 8) { + vec_load8_pair(vy0, &v_y[(i * 2) + 0]); + + vec_load_mult184_mma(&temp[0], &va0[i + 0], &v_x0[ 0]); + vec_load_mult184_mma(&temp[2], &va0[i + 4], &v_x0[ 0]); + + vec_reduce84_mma(&temp[0], temp0 + 0, v_alpha, vy0 + 0); + vec_reduce84_mma(&temp[2], temp0 + 8, v_alpha, vy0 + 8); + + vec_store8_pair(&v_y[(i * 2) + 0], vy0); + } + + if (n8 & 4) { + vec_load4_pair(vy0, &v_y[(i * 2) + 0]); + + vec_load_mult184_mma(&temp[0], &va0[i + 0], &v_x0[ 0]); + + vec_reduce84_mma(&temp[0], temp0, v_alpha, vy0); + + vec_store4_pair(&v_y[(i * 2) + 0], vy0); + + i += 4; + } +#else + vec_bf16 v_x0[1]; + v_x0[0] = vec_loadN(x_bf, 1); + vec_f32 vy0[2*4]; + for (; i + 4 <= n8; i += 4) { vec_load4_pair(vy0, &v_y[(i * 2) + 0]); - vec_load_mult18_mma(&temp[0], &va0[i + 0], v_x0); + vec_load_mult18_mma(&temp[0], &va0[i + 0], v_x0[ 0]); vec_reduce8_mma(&temp[0], temp0, v_alpha, vy0); vec_store4_pair(&v_y[(i * 2) + 0], vy0); } +#endif for (; i < n8; i++) { vec_load_pair(vy0, &v_y[(i * 2) + 0]); - 
vec_load_mult12_mma(&temp[0], &va0[i], v_x0); + vec_load_mult12_mma(&temp[0], &va0[i], v_x0[ 0]); vec_reduce2_mma(&temp[0], temp0, v_alpha, vy0); @@ -86,7 +123,7 @@ static void BF16GEMV_N_MMA_1(BLASLONG n, IFLOAT **ap, IFLOAT *xo, FLOAT *y, FLOA BLASLONG n3 = n & 3; vec_loadN2_f32(vy0, &v_y[(i * 2) + 0], n3); - vec_loadN_mult12_mma(&temp[0], &va0[i], v_x0, n); + vec_loadN_mult12_mma(&temp[0], &va0[i], v_x0[ 0], n); vec_reduce2_mma(&temp[0], temp0, v_alpha, vy0); @@ -94,7 +131,7 @@ static void BF16GEMV_N_MMA_1(BLASLONG n, IFLOAT **ap, IFLOAT *xo, FLOAT *y, FLOA } else if (n) { vy0[0] = vec_loadN_f32(&v_y[(i * 2) + 0], n); - vec_loadN_mult11_mma(&temp[0], &va0[i], v_x0, n); + vec_loadN_mult11_mma(&temp[0], &va0[i], v_x0[ 0], n); vec_reduce1_mma(&temp[0], temp0, v_alpha, vy0); @@ -106,7 +143,7 @@ static void BF16GEMV_N_MMA_2(BLASLONG n, IFLOAT **ap, IFLOAT *xo, FLOAT *y, FLOA { IFLOAT *a0, *a1; __vector_quad temp[2*4]; - vec_f32 temp0[8*4], vy0[2*4]; + vec_f32 temp0[8*4]; vec_f32 v_alpha = { alpha, alpha, alpha, alpha }; a0 = ap[0]; @@ -116,26 +153,61 @@ static void BF16GEMV_N_MMA_2(BLASLONG n, IFLOAT **ap, IFLOAT *xo, FLOAT *y, FLOA vec_bf16 *va1 = (vec_bf16 *)a1; vec_bf16 *x_bf = (vec_bf16 *)(xo); - vec_bf16 v_x0 = vec_loadN(x_bf, 2); vec_f32 *v_y = (vec_f32 *)y; BLASLONG n8 = n / 8; BLASLONG i = 0; +#ifdef USE_MERGE_MMA + vec_bf16 v_x0[4]; + vec_f32 vy0[2*4*2]; + v_x0[0] = vec_loadN(x_bf, 2); + + vec_make_mult1(v_x0); + + for (; i + 8 <= n8; i += 8) { + vec_load8_pair(vy0, &v_y[(i * 2) + 0]); + + vec_load_mult284a_mma(&temp[0], &va0[i + 0], &va1[i + 0], &v_x0[ 0]); + vec_load_mult284a_mma(&temp[2], &va0[i + 4], &va1[i + 4], &v_x0[ 0]); + + vec_reduce84_mma(&temp[0], temp0 + 0, v_alpha, vy0 + 0); + vec_reduce84_mma(&temp[2], temp0 + 8, v_alpha, vy0 + 8); + + vec_store8_pair(&v_y[(i * 2) + 0], vy0); + } + + if (n8 & 4) { + vec_load4_pair(vy0, &v_y[(i * 2) + 0]); + + vec_load_mult284a_mma(&temp[0], &va0[i + 0], &va1[i + 0], &v_x0[ 0]); + + vec_reduce84_mma(&temp[0], temp0, v_alpha, vy0); + + vec_store4_pair(&v_y[(i * 2) + 0], vy0); + + i += 4; + } +#else + vec_bf16 v_x0[1]; + vec_f32 vy0[2*4]; + v_x0[0] = vec_loadN(x_bf, 2); + for (; i + 4 <= n8; i += 4) { vec_load4_pair(vy0, &v_y[(i * 2) + 0]); - vec_load_mult28a_mma(&temp[0], &va0[i + 0], &va1[i + 0], v_x0); + vec_load_mult28a_mma(&temp[0], &va0[i + 0], &va1[i + 0], v_x0[ 0]); vec_reduce8_mma(&temp[0], temp0, v_alpha, vy0); vec_store4_pair(&v_y[(i * 2) + 0], vy0); } +#endif for (; i < n8; i++) { vec_load_pair(vy0, &v_y[(i * 2) + 0]); - vec_load_mult22a_mma(&temp[0], &va0[i], &va1[i], v_x0); + vec_load_mult22a_mma(&temp[0], &va0[i], &va1[i], v_x0[ 0]); vec_reduce2_mma(&temp[0], temp0, v_alpha, vy0); @@ -147,7 +219,7 @@ static void BF16GEMV_N_MMA_2(BLASLONG n, IFLOAT **ap, IFLOAT *xo, FLOAT *y, FLOA BLASLONG n3 = n & 3; vec_loadN2_f32(vy0, &v_y[(i * 2) + 0], n3); - vec_loadN_mult22a_mma(&temp[0], &va0[i], &va1[i], v_x0, n); + vec_loadN_mult22a_mma(&temp[0], &va0[i], &va1[i], v_x0[ 0], n); vec_reduce2_mma(&temp[0], temp0, v_alpha, vy0); @@ -155,7 +227,7 @@ static void BF16GEMV_N_MMA_2(BLASLONG n, IFLOAT **ap, IFLOAT *xo, FLOAT *y, FLOA } else if (n) { vy0[0] = vec_loadN_f32(&v_y[(i * 2) + 0], n); - vec_loadN_mult11a_mma(&temp[0], &va0[i], &va1[i], v_x0, n); + vec_loadN_mult11a_mma(&temp[0], &va0[i], &va1[i], v_x0[ 0], n); vec_reduce1_mma(&temp[0], temp0, v_alpha, vy0); @@ -167,7 +239,7 @@ static void BF16GEMV_N_MMA_4(BLASLONG n, IFLOAT **ap, IFLOAT *xo, FLOAT *y, FLOA { IFLOAT *a0, *a1, *a2, *a3; __vector_quad temp[2*4]; - vec_f32 temp0[8*4], 
vy0[2*4]; + vec_f32 temp0[8*4]; vec_f32 v_alpha = { alpha, alpha, alpha, alpha }; a0 = ap[0]; @@ -181,30 +253,68 @@ static void BF16GEMV_N_MMA_4(BLASLONG n, IFLOAT **ap, IFLOAT *xo, FLOAT *y, FLOA vec_bf16 *va3 = (vec_bf16 *)a3; vec_bf16 *x_bf = (vec_bf16 *)(xo); - vec_bf16 v_x00 = vec_loadN(x_bf, 4); - - vec_bf16 v_x01 = (vec_bf16)vec_splat((vec_f32)v_x00, 1); vec_f32 *v_y = (vec_f32 *)y; BLASLONG n8 = n / 8; BLASLONG i = 0; +#ifdef USE_MERGE_MMA + vec_bf16 v_x0[8]; + vec_f32 vy0[2*4*2]; + v_x0[0] = vec_loadN(x_bf, 4); + + vec_make_mult2(v_x0); + + for (; i + 8 <= n8; i += 8) { + vec_load8_pair(vy0, &v_y[(i * 2) + 0]); + + vec_load_mult284a_mma(&temp[0], &va0[i + 0], &va1[i + 0], &v_x0[ 0]); + vec_load_mult284b_mma(&temp[0], &va2[i + 0], &va3[i + 0], &v_x0[ 4]); + vec_load_mult284a_mma(&temp[2], &va0[i + 4], &va1[i + 4], &v_x0[ 0]); + vec_load_mult284b_mma(&temp[2], &va2[i + 4], &va3[i + 4], &v_x0[ 4]); + + vec_reduce84_mma(&temp[0], temp0 + 0, v_alpha, vy0 + 0); + vec_reduce84_mma(&temp[2], temp0 + 8, v_alpha, vy0 + 8); + + vec_store8_pair(&v_y[(i * 2) + 0], vy0); + } + + if (n8 & 4) { + vec_load4_pair(vy0, &v_y[(i * 2) + 0]); + + vec_load_mult284a_mma(&temp[0], &va0[i + 0], &va1[i + 0], &v_x0[ 0]); + vec_load_mult284b_mma(&temp[0], &va2[i + 0], &va3[i + 0], &v_x0[ 4]); + + vec_reduce84_mma(&temp[0], temp0, v_alpha, vy0); + + vec_store4_pair(&v_y[(i * 2) + 0], vy0); + + i += 4; + } +#else + vec_bf16 v_x0[5]; + vec_f32 vy0[2*4]; + v_x0[0] = vec_loadN(x_bf, 4); + + v_x0[ 4] = (vec_bf16)vec_splat((vec_f32)v_x0[0], 1); + for (; i + 4 <= n8; i += 4) { vec_load4_pair(vy0, &v_y[(i * 2) + 0]); - vec_load_mult28a_mma(&temp[0], &va0[i + 0], &va1[i + 0], v_x00); - vec_load_mult28b_mma(&temp[0], &va2[i + 0], &va3[i + 0], v_x01); + vec_load_mult28a_mma(&temp[0], &va0[i + 0], &va1[i + 0], v_x0[ 0]); + vec_load_mult28b_mma(&temp[0], &va2[i + 0], &va3[i + 0], v_x0[ 4]); vec_reduce8_mma(&temp[0], temp0, v_alpha, vy0); vec_store4_pair(&v_y[(i * 2) + 0], vy0); } +#endif for (; i < n8; i++) { vec_load_pair(vy0, &v_y[(i * 2) + 0]); - vec_load_mult22a_mma(&temp[0], &va0[i], &va1[i], v_x00); - vec_load_mult22b_mma(&temp[0], &va2[i], &va3[i], v_x01); + vec_load_mult22a_mma(&temp[0], &va0[i], &va1[i], v_x0[ 0]); + vec_load_mult22b_mma(&temp[0], &va2[i], &va3[i], v_x0[ 4]); vec_reduce2_mma(&temp[0], temp0, v_alpha, vy0); @@ -216,8 +326,8 @@ static void BF16GEMV_N_MMA_4(BLASLONG n, IFLOAT **ap, IFLOAT *xo, FLOAT *y, FLOA BLASLONG n3 = n & 3; vec_loadN2_f32(vy0, &v_y[(i * 2) + 0], n3); - vec_loadN_mult22a_mma(&temp[0], &va0[i], &va1[i], v_x00, n); - vec_loadN_mult22b_mma(&temp[0], &va2[i], &va3[i], v_x01, n); + vec_loadN_mult22a_mma(&temp[0], &va0[i], &va1[i], v_x0[ 0], n); + vec_loadN_mult22b_mma(&temp[0], &va2[i], &va3[i], v_x0[ 4], n); vec_reduce2_mma(&temp[0], temp0, v_alpha, vy0); @@ -225,8 +335,8 @@ static void BF16GEMV_N_MMA_4(BLASLONG n, IFLOAT **ap, IFLOAT *xo, FLOAT *y, FLOA } else if (n) { vy0[0] = vec_loadN_f32(&v_y[(i * 2) + 0], n); - vec_loadN_mult11a_mma(&temp[0], &va0[i], &va1[i], v_x00, n); - vec_loadN_mult11b_mma(&temp[0], &va2[i], &va3[i], v_x01, n); + vec_loadN_mult11a_mma(&temp[0], &va0[i], &va1[i], v_x0[ 0], n); + vec_loadN_mult11b_mma(&temp[0], &va2[i], &va3[i], v_x0[ 4], n); vec_reduce1_mma(&temp[0], temp0, v_alpha, vy0); @@ -239,7 +349,7 @@ static void BF16GEMV_N_MMA_8(BLASLONG n, IFLOAT **ap, IFLOAT *xo, FLOAT *y, BLAS { IFLOAT *a0, *a1, *a2, *a3, *b0, *b1, *b2, *b3; __vector_quad temp[2*4]; - vec_f32 temp0[8*4], vy0[2*4]; + vec_f32 temp0[8*4]; vec_f32 v_alpha = { alpha, alpha, alpha, 
alpha }; a0 = ap[0]; @@ -261,36 +371,80 @@ static void BF16GEMV_N_MMA_8(BLASLONG n, IFLOAT **ap, IFLOAT *xo, FLOAT *y, BLAS vec_bf16 *vb3 = (vec_bf16 *)b3; vec_bf16 *x_bf = (vec_bf16 *)(xo); - vec_bf16 v_x00 = (vec_bf16)vec_load_vec(x_bf); - - vec_bf16 v_x01 = (vec_bf16)vec_splat((vec_f32)v_x00, 1); - vec_bf16 v_x02 = (vec_bf16)vec_splat((vec_f32)v_x00, 2); - vec_bf16 v_x03 = (vec_bf16)vec_splat((vec_f32)v_x00, 3); vec_f32 *v_y = (vec_f32 *)y; BLASLONG n8 = n / 8; BLASLONG i = 0; +#ifdef USE_MERGE_MMA + vec_bf16 v_x0[16]; + vec_f32 vy0[2*4*2]; + v_x0[0] = (vec_bf16)vec_load_vec(x_bf); + + vec_make_mult4(v_x0); + + for (; i + 8 <= n8; i += 8) { + vec_load8_pair(vy0, &v_y[(i * 2) + 0]); + + vec_load_mult284a_mma(&temp[0], &va0[i + 0], &va1[i + 0], &v_x0[ 0]); + vec_load_mult284b_mma(&temp[0], &va2[i + 0], &va3[i + 0], &v_x0[ 4]); + vec_load_mult284b_mma(&temp[0], &vb0[i + 0], &vb1[i + 0], &v_x0[ 8]); + vec_load_mult284b_mma(&temp[0], &vb2[i + 0], &vb3[i + 0], &v_x0[12]); + vec_load_mult284a_mma(&temp[2], &va0[i + 4], &va1[i + 4], &v_x0[ 0]); + vec_load_mult284b_mma(&temp[2], &va2[i + 4], &va3[i + 4], &v_x0[ 4]); + vec_load_mult284b_mma(&temp[2], &vb0[i + 4], &vb1[i + 4], &v_x0[ 8]); + vec_load_mult284b_mma(&temp[2], &vb2[i + 4], &vb3[i + 4], &v_x0[12]); + + vec_reduce84_mma(&temp[0], temp0 + 0, v_alpha, vy0 + 0); + vec_reduce84_mma(&temp[2], temp0 + 8, v_alpha, vy0 + 8); + + vec_store8_pair(&v_y[(i * 2) + 0], vy0); + } + + if (n8 & 4) { + vec_load4_pair(vy0, &v_y[(i * 2) + 0]); + + vec_load_mult284a_mma(&temp[0], &va0[i + 0], &va1[i + 0], &v_x0[ 0]); + vec_load_mult284b_mma(&temp[0], &va2[i + 0], &va3[i + 0], &v_x0[ 4]); + vec_load_mult284b_mma(&temp[0], &vb0[i + 0], &vb1[i + 0], &v_x0[ 8]); + vec_load_mult284b_mma(&temp[0], &vb2[i + 0], &vb3[i + 0], &v_x0[12]); + + vec_reduce84_mma(&temp[0], temp0, v_alpha, vy0); + + vec_store4_pair(&v_y[(i * 2) + 0], vy0); + + i += 4; + } +#else + vec_bf16 v_x0[13]; + vec_f32 vy0[2*4]; + v_x0[0] = (vec_bf16)vec_load_vec(x_bf); + + v_x0[ 4] = (vec_bf16)vec_splat((vec_f32)v_x0[0], 1); + v_x0[ 8] = (vec_bf16)vec_splat((vec_f32)v_x0[0], 2); + v_x0[12] = (vec_bf16)vec_splat((vec_f32)v_x0[0], 3); + for (; i + 4 <= n8; i += 4) { vec_load4_pair(vy0, &v_y[(i * 2) + 0]); - vec_load_mult28a_mma(&temp[0], &va0[i + 0], &va1[i + 0], v_x00); - vec_load_mult28b_mma(&temp[0], &va2[i + 0], &va3[i + 0], v_x01); - vec_load_mult28b_mma(&temp[0], &vb0[i + 0], &vb1[i + 0], v_x02); - vec_load_mult28b_mma(&temp[0], &vb2[i + 0], &vb3[i + 0], v_x03); + vec_load_mult28a_mma(&temp[0], &va0[i + 0], &va1[i + 0], v_x0[ 0]); + vec_load_mult28b_mma(&temp[0], &va2[i + 0], &va3[i + 0], v_x0[ 4]); + vec_load_mult28b_mma(&temp[0], &vb0[i + 0], &vb1[i + 0], v_x0[ 8]); + vec_load_mult28b_mma(&temp[0], &vb2[i + 0], &vb3[i + 0], v_x0[12]); vec_reduce8_mma(&temp[0], temp0, v_alpha, vy0); vec_store4_pair(&v_y[(i * 2) + 0], vy0); } +#endif for (; i < n8; i++) { vec_load_pair(vy0, &v_y[(i * 2) + 0]); - vec_load_mult22a_mma(&temp[0], &va0[i], &va1[i], v_x00); - vec_load_mult22b_mma(&temp[0], &va2[i], &va3[i], v_x01); - vec_load_mult22b_mma(&temp[0], &vb0[i], &vb1[i], v_x02); - vec_load_mult22b_mma(&temp[0], &vb2[i], &vb3[i], v_x03); + vec_load_mult22a_mma(&temp[0], &va0[i], &va1[i], v_x0[ 0]); + vec_load_mult22b_mma(&temp[0], &va2[i], &va3[i], v_x0[ 4]); + vec_load_mult22b_mma(&temp[0], &vb0[i], &vb1[i], v_x0[ 8]); + vec_load_mult22b_mma(&temp[0], &vb2[i], &vb3[i], v_x0[12]); vec_reduce2_mma(&temp[0], temp0, v_alpha, vy0); @@ -302,10 +456,10 @@ static void BF16GEMV_N_MMA_8(BLASLONG n, IFLOAT **ap, 
IFLOAT *xo, FLOAT *y, BLAS BLASLONG n3 = n & 3; vec_loadN2_f32(vy0, &v_y[(i * 2) + 0], n3); - vec_loadN_mult22a_mma(&temp[0], &va0[i], &va1[i], v_x00, n); - vec_loadN_mult22b_mma(&temp[0], &va2[i], &va3[i], v_x01, n); - vec_loadN_mult22b_mma(&temp[0], &vb0[i], &vb1[i], v_x02, n); - vec_loadN_mult22b_mma(&temp[0], &vb2[i], &vb3[i], v_x03, n); + vec_loadN_mult22a_mma(&temp[0], &va0[i], &va1[i], v_x0[ 0], n); + vec_loadN_mult22b_mma(&temp[0], &va2[i], &va3[i], v_x0[ 4], n); + vec_loadN_mult22b_mma(&temp[0], &vb0[i], &vb1[i], v_x0[ 8], n); + vec_loadN_mult22b_mma(&temp[0], &vb2[i], &vb3[i], v_x0[12], n); vec_reduce2_mma(&temp[0], temp0, v_alpha, vy0); @@ -313,10 +467,10 @@ static void BF16GEMV_N_MMA_8(BLASLONG n, IFLOAT **ap, IFLOAT *xo, FLOAT *y, BLAS } else if (n) { vy0[0] = vec_loadN_f32(&v_y[(i * 2) + 0], n); - vec_loadN_mult11a_mma(&temp[0], &va0[i], &va1[i], v_x00, n); - vec_loadN_mult11b_mma(&temp[0], &va2[i], &va3[i], v_x01, n); - vec_loadN_mult11b_mma(&temp[0], &vb0[i], &vb1[i], v_x02, n); - vec_loadN_mult11b_mma(&temp[0], &vb2[i], &vb3[i], v_x03, n); + vec_loadN_mult11a_mma(&temp[0], &va0[i], &va1[i], v_x0[ 0], n); + vec_loadN_mult11b_mma(&temp[0], &va2[i], &va3[i], v_x0[ 4], n); + vec_loadN_mult11b_mma(&temp[0], &vb0[i], &vb1[i], v_x0[ 8], n); + vec_loadN_mult11b_mma(&temp[0], &vb2[i], &vb3[i], v_x0[12], n); vec_reduce1_mma(&temp[0], temp0, v_alpha, vy0); From df19375560641fa213b3332d8fa775efa77f7756 Mon Sep 17 00:00:00 2001 From: Chip Kerchner Date: Tue, 24 Sep 2024 16:30:01 -0500 Subject: [PATCH 10/24] Almost final code for MMA. --- kernel/power/sbgemv_common_power10.c | 93 ++++++++++++++++++++-------- kernel/power/sbgemv_n_power10.c | 41 +++++------- 2 files changed, 80 insertions(+), 54 deletions(-) diff --git a/kernel/power/sbgemv_common_power10.c b/kernel/power/sbgemv_common_power10.c index 2ee912b9d2..d24a98418f 100644 --- a/kernel/power/sbgemv_common_power10.c +++ b/kernel/power/sbgemv_common_power10.c @@ -152,6 +152,14 @@ FORCEINLINE void vec_reduce84_mma(__vector_quad *out, vec_f32 *temp, vec_f32 v_a vec_reduce44_mma(&out[0], &temp[0], v_alpha, vy0 + 0); vec_reduce44_mma(&out[1], &temp[4], v_alpha, vy0 + 1); } + +FORCEINLINE void vec_reduce88_mma(__vector_quad *out, vec_f32 *temp, vec_f32 v_alpha, vec_f32 *vy0) +{ + vec_reduce44_mma(&out[0], &temp[ 0], v_alpha, vy0 + 0); + vec_reduce44_mma(&out[1], &temp[ 4], v_alpha, vy0 + 1); + vec_reduce44_mma(&out[2], &temp[ 8], v_alpha, vy0 + 8); + vec_reduce44_mma(&out[3], &temp[12], v_alpha, vy0 + 9); +} #endif FORCEINLINE void vec_mult11a_mma(__vector_quad *out, vec_bf16 in0, vec_bf16 in1, vec_bf16 inp) @@ -341,6 +349,32 @@ FORCEINLINE void vec_load_mult284b_mma(__vector_quad *out, vec_bf16 *ina, vec_bf vec_mult44b_mma(out, in0 + 0, in1 + 0, inp + 0); vec_mult44b_mma(out, in0 + 2, in1 + 2, inp + 2); } + +FORCEINLINE void vec_load_mult288a_mma(__vector_quad *out, vec_bf16 *ina, vec_bf16 *inb, vec_bf16 *inp) +{ + vec_bf16 in0[8], in1[8]; + + vec_load4_mma(in0 + 0, in1 + 0, ina + 0, inb + 0); + vec_load4_mma(in0 + 4, in1 + 4, ina + 4, inb + 4); + + vec_mult44a_mma(out + 0, in0 + 0, in1 + 0, inp + 0); + vec_mult44a_mma(out + 2, in0 + 4, in1 + 4, inp + 0); + vec_mult44b_mma(out + 0, in0 + 2, in1 + 2, inp + 2); + vec_mult44b_mma(out + 2, in0 + 6, in1 + 6, inp + 2); +} + +FORCEINLINE void vec_load_mult288b_mma(__vector_quad *out, vec_bf16 *ina, vec_bf16 *inb, vec_bf16 *inp) +{ + vec_bf16 in0[8], in1[8]; + + vec_load4_mma(in0 + 0, in1 + 0, ina + 0, inb + 0); + vec_load4_mma(in0 + 4, in1 + 4, ina + 4, inb + 4); + + vec_mult44b_mma(out + 0, 
in0 + 0, in1 + 0, inp + 0); + vec_mult44b_mma(out + 2, in0 + 4, in1 + 4, inp + 0); + vec_mult44b_mma(out + 0, in0 + 2, in1 + 2, inp + 2); + vec_mult44b_mma(out + 2, in0 + 6, in1 + 6, inp + 2); +} #endif FORCEINLINE void vec_loadN_mult22b_mma(__vector_quad *out, vec_bf16 *ina, vec_bf16 *inb, vec_bf16 inp, BLASLONG n) @@ -381,49 +415,54 @@ FORCEINLINE void vec_store8_pair(vec_f32 *v_y, vec_f32 *vy0) } #if __BYTE_ORDER__ == __ORDER_BIG_ENDIAN__ -#define VEC_SHIFT(data, shift) vec_sld(data, data, 16 - shift) -#else -#define VEC_SHIFT(data, shift) vec_sld(data, data, shift) -#endif +#define VEC_SHIFT(data, shift) vec_sldw(data, data, 4 - shift) -typedef __vector unsigned int vec_ui32; +#define MASK_0 0xf000 +#define MASK_1 0x0f00 +#define MASK_2 0x00f0 +#define MASK_3 0x000f +#else +#define VEC_SHIFT(data, shift) vec_sldw(data, data, shift) -static vec_ui32 mask_0 = { 0xffffffff, 0x00000000, 0x00000000, 0x00000000 }; -static vec_ui32 mask_1 = { 0x00000000, 0xffffffff, 0x00000000, 0x00000000 }; -static vec_ui32 mask_2 = { 0x00000000, 0x00000000, 0xffffffff, 0x00000000 }; -static vec_ui32 mask_3 = { 0x00000000, 0x00000000, 0x00000000, 0xffffffff }; +#define MASK_0 0x000f +#define MASK_1 0x00f0 +#define MASK_2 0x0f00 +#define MASK_3 0xf000 +#endif -FORCEINLINE void vec_make_mult1(vec_bf16 *v_x0) +FORCEINLINE void vec_make_mult1(vec_bf16 *v_x0, const bool mask) { - v_x0[ 0] = vec_and(v_x0[0], (vec_bf16)mask_0); + if (mask) { + v_x0[ 0] = vec_and(v_x0[0], (vec_bf16)vec_genbm(MASK_0)); + } - v_x0[ 1] = VEC_SHIFT(v_x0[ 0], 4); - v_x0[ 2] = VEC_SHIFT(v_x0[ 0], 8); - v_x0[ 3] = VEC_SHIFT(v_x0[ 0], 12); + v_x0[ 1] = VEC_SHIFT(v_x0[ 0], 1); + v_x0[ 2] = VEC_SHIFT(v_x0[ 0], 2); + v_x0[ 3] = VEC_SHIFT(v_x0[ 0], 3); } FORCEINLINE void vec_make_mult2(vec_bf16 *v_x0) { - v_x0[ 5] = vec_and(v_x0[0], (vec_bf16)mask_1); - vec_make_mult1(v_x0); + v_x0[ 5] = vec_and(v_x0[0], (vec_bf16)vec_genbm(MASK_1)); + vec_make_mult1(v_x0, true); - v_x0[ 4] = VEC_SHIFT(v_x0[ 5], 12); - v_x0[ 6] = VEC_SHIFT(v_x0[ 5], 4); - v_x0[ 7] = VEC_SHIFT(v_x0[ 5], 8); + v_x0[ 4] = VEC_SHIFT(v_x0[ 5], 3); + v_x0[ 6] = VEC_SHIFT(v_x0[ 5], 1); + v_x0[ 7] = VEC_SHIFT(v_x0[ 5], 2); } FORCEINLINE void vec_make_mult4(vec_bf16 *v_x0) { - v_x0[10] = vec_and(v_x0[0], (vec_bf16)mask_2); - v_x0[15] = vec_and(v_x0[0], (vec_bf16)mask_3); + v_x0[10] = vec_and(v_x0[0], (vec_bf16)vec_genbm(MASK_2)); + v_x0[15] = vec_and(v_x0[0], (vec_bf16)vec_genbm(MASK_3)); vec_make_mult2(v_x0); - v_x0[ 8] = VEC_SHIFT(v_x0[10], 8); - v_x0[ 9] = VEC_SHIFT(v_x0[10], 12); - v_x0[11] = VEC_SHIFT(v_x0[10], 4); - v_x0[12] = VEC_SHIFT(v_x0[15], 4); - v_x0[13] = VEC_SHIFT(v_x0[15], 8); - v_x0[14] = VEC_SHIFT(v_x0[15], 12); + v_x0[ 8] = VEC_SHIFT(v_x0[10], 2); + v_x0[ 9] = VEC_SHIFT(v_x0[10], 3); + v_x0[11] = VEC_SHIFT(v_x0[10], 1); + v_x0[12] = VEC_SHIFT(v_x0[15], 1); + v_x0[13] = VEC_SHIFT(v_x0[15], 2); + v_x0[14] = VEC_SHIFT(v_x0[15], 3); } #endif diff --git a/kernel/power/sbgemv_n_power10.c b/kernel/power/sbgemv_n_power10.c index f2ed6bf9a8..e75f394e72 100644 --- a/kernel/power/sbgemv_n_power10.c +++ b/kernel/power/sbgemv_n_power10.c @@ -28,9 +28,7 @@ USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. 
#ifndef SBGEMV_N_MMA_C #define SBGEMV_N_MMA_C -#if !defined(_AIX) || defined(__clang__) #define USE_BFGEMV_N_MMA -#endif #ifdef USE_BFGEMV_N_MMA #include "sbgemv_common_power10.c" @@ -67,7 +65,7 @@ static void BF16GEMV_N_MMA_1(BLASLONG n, IFLOAT **ap, IFLOAT *xo, FLOAT *y, FLOA v_x0[0] = vec_loadN(x_bf, 1); vec_f32 vy0[2*4*2]; - vec_make_mult1(v_x0); + vec_make_mult1(v_x0, false); for (; i + 8 <= n8; i += 8) { vec_load8_pair(vy0, &v_y[(i * 2) + 0]); @@ -75,8 +73,7 @@ static void BF16GEMV_N_MMA_1(BLASLONG n, IFLOAT **ap, IFLOAT *xo, FLOAT *y, FLOA vec_load_mult184_mma(&temp[0], &va0[i + 0], &v_x0[ 0]); vec_load_mult184_mma(&temp[2], &va0[i + 4], &v_x0[ 0]); - vec_reduce84_mma(&temp[0], temp0 + 0, v_alpha, vy0 + 0); - vec_reduce84_mma(&temp[2], temp0 + 8, v_alpha, vy0 + 8); + vec_reduce88_mma(&temp[0], temp0 + 0, v_alpha, vy0 + 0); vec_store8_pair(&v_y[(i * 2) + 0], vy0); } @@ -163,16 +160,14 @@ static void BF16GEMV_N_MMA_2(BLASLONG n, IFLOAT **ap, IFLOAT *xo, FLOAT *y, FLOA vec_f32 vy0[2*4*2]; v_x0[0] = vec_loadN(x_bf, 2); - vec_make_mult1(v_x0); + vec_make_mult1(v_x0, false); for (; i + 8 <= n8; i += 8) { vec_load8_pair(vy0, &v_y[(i * 2) + 0]); - vec_load_mult284a_mma(&temp[0], &va0[i + 0], &va1[i + 0], &v_x0[ 0]); - vec_load_mult284a_mma(&temp[2], &va0[i + 4], &va1[i + 4], &v_x0[ 0]); + vec_load_mult288a_mma(&temp[0], &va0[i + 0], &va1[i + 0], &v_x0[ 0]); - vec_reduce84_mma(&temp[0], temp0 + 0, v_alpha, vy0 + 0); - vec_reduce84_mma(&temp[2], temp0 + 8, v_alpha, vy0 + 8); + vec_reduce88_mma(&temp[0], temp0 + 0, v_alpha, vy0 + 0); vec_store8_pair(&v_y[(i * 2) + 0], vy0); } @@ -268,13 +263,10 @@ static void BF16GEMV_N_MMA_4(BLASLONG n, IFLOAT **ap, IFLOAT *xo, FLOAT *y, FLOA for (; i + 8 <= n8; i += 8) { vec_load8_pair(vy0, &v_y[(i * 2) + 0]); - vec_load_mult284a_mma(&temp[0], &va0[i + 0], &va1[i + 0], &v_x0[ 0]); - vec_load_mult284b_mma(&temp[0], &va2[i + 0], &va3[i + 0], &v_x0[ 4]); - vec_load_mult284a_mma(&temp[2], &va0[i + 4], &va1[i + 4], &v_x0[ 0]); - vec_load_mult284b_mma(&temp[2], &va2[i + 4], &va3[i + 4], &v_x0[ 4]); + vec_load_mult288a_mma(&temp[0], &va0[i + 0], &va1[i + 0], &v_x0[ 0]); + vec_load_mult288b_mma(&temp[0], &va2[i + 0], &va3[i + 0], &v_x0[ 4]); - vec_reduce84_mma(&temp[0], temp0 + 0, v_alpha, vy0 + 0); - vec_reduce84_mma(&temp[2], temp0 + 8, v_alpha, vy0 + 8); + vec_reduce88_mma(&temp[0], temp0 + 0, v_alpha, vy0 + 0); vec_store8_pair(&v_y[(i * 2) + 0], vy0); } @@ -386,17 +378,12 @@ static void BF16GEMV_N_MMA_8(BLASLONG n, IFLOAT **ap, IFLOAT *xo, FLOAT *y, BLAS for (; i + 8 <= n8; i += 8) { vec_load8_pair(vy0, &v_y[(i * 2) + 0]); - vec_load_mult284a_mma(&temp[0], &va0[i + 0], &va1[i + 0], &v_x0[ 0]); - vec_load_mult284b_mma(&temp[0], &va2[i + 0], &va3[i + 0], &v_x0[ 4]); - vec_load_mult284b_mma(&temp[0], &vb0[i + 0], &vb1[i + 0], &v_x0[ 8]); - vec_load_mult284b_mma(&temp[0], &vb2[i + 0], &vb3[i + 0], &v_x0[12]); - vec_load_mult284a_mma(&temp[2], &va0[i + 4], &va1[i + 4], &v_x0[ 0]); - vec_load_mult284b_mma(&temp[2], &va2[i + 4], &va3[i + 4], &v_x0[ 4]); - vec_load_mult284b_mma(&temp[2], &vb0[i + 4], &vb1[i + 4], &v_x0[ 8]); - vec_load_mult284b_mma(&temp[2], &vb2[i + 4], &vb3[i + 4], &v_x0[12]); - - vec_reduce84_mma(&temp[0], temp0 + 0, v_alpha, vy0 + 0); - vec_reduce84_mma(&temp[2], temp0 + 8, v_alpha, vy0 + 8); + vec_load_mult288a_mma(&temp[0], &va0[i + 0], &va1[i + 0], &v_x0[ 0]); + vec_load_mult288b_mma(&temp[0], &va2[i + 0], &va3[i + 0], &v_x0[ 4]); + vec_load_mult288b_mma(&temp[0], &vb0[i + 0], &vb1[i + 0], &v_x0[ 8]); + vec_load_mult288b_mma(&temp[0], &vb2[i + 
0], &vb3[i + 0], &v_x0[12]); + + vec_reduce88_mma(&temp[0], temp0 + 0, v_alpha, vy0 + 0); vec_store8_pair(&v_y[(i * 2) + 0], vy0); } From 8ab6245771b9b8b52a937c35a447c218f41e3bbd Mon Sep 17 00:00:00 2001 From: Chip Kerchner Date: Tue, 24 Sep 2024 16:50:21 -0500 Subject: [PATCH 11/24] Small change. --- kernel/power/sbgemv_n_power10.c | 100 ++++++++++++++++---------------- 1 file changed, 50 insertions(+), 50 deletions(-) diff --git a/kernel/power/sbgemv_n_power10.c b/kernel/power/sbgemv_n_power10.c index e75f394e72..f33a246a99 100644 --- a/kernel/power/sbgemv_n_power10.c +++ b/kernel/power/sbgemv_n_power10.c @@ -68,21 +68,21 @@ static void BF16GEMV_N_MMA_1(BLASLONG n, IFLOAT **ap, IFLOAT *xo, FLOAT *y, FLOA vec_make_mult1(v_x0, false); for (; i + 8 <= n8; i += 8) { - vec_load8_pair(vy0, &v_y[(i * 2) + 0]); - vec_load_mult184_mma(&temp[0], &va0[i + 0], &v_x0[ 0]); vec_load_mult184_mma(&temp[2], &va0[i + 4], &v_x0[ 0]); + vec_load8_pair(vy0, &v_y[(i * 2) + 0]); + vec_reduce88_mma(&temp[0], temp0 + 0, v_alpha, vy0 + 0); vec_store8_pair(&v_y[(i * 2) + 0], vy0); } if (n8 & 4) { - vec_load4_pair(vy0, &v_y[(i * 2) + 0]); - vec_load_mult184_mma(&temp[0], &va0[i + 0], &v_x0[ 0]); + vec_load4_pair(vy0, &v_y[(i * 2) + 0]); + vec_reduce84_mma(&temp[0], temp0, v_alpha, vy0); vec_store4_pair(&v_y[(i * 2) + 0], vy0); @@ -95,10 +95,10 @@ static void BF16GEMV_N_MMA_1(BLASLONG n, IFLOAT **ap, IFLOAT *xo, FLOAT *y, FLOA vec_f32 vy0[2*4]; for (; i + 4 <= n8; i += 4) { - vec_load4_pair(vy0, &v_y[(i * 2) + 0]); - vec_load_mult18_mma(&temp[0], &va0[i + 0], v_x0[ 0]); + vec_load4_pair(vy0, &v_y[(i * 2) + 0]); + vec_reduce8_mma(&temp[0], temp0, v_alpha, vy0); vec_store4_pair(&v_y[(i * 2) + 0], vy0); @@ -106,10 +106,10 @@ static void BF16GEMV_N_MMA_1(BLASLONG n, IFLOAT **ap, IFLOAT *xo, FLOAT *y, FLOA #endif for (; i < n8; i++) { - vec_load_pair(vy0, &v_y[(i * 2) + 0]); - vec_load_mult12_mma(&temp[0], &va0[i], v_x0[ 0]); + vec_load_pair(vy0, &v_y[(i * 2) + 0]); + vec_reduce2_mma(&temp[0], temp0, v_alpha, vy0); vec_store_pair(&v_y[(i * 2) + 0], vy0); @@ -117,19 +117,19 @@ static void BF16GEMV_N_MMA_1(BLASLONG n, IFLOAT **ap, IFLOAT *xo, FLOAT *y, FLOA n &= 7; if (n > 4) { + vec_loadN_mult12_mma(&temp[0], &va0[i], v_x0[ 0], n); + BLASLONG n3 = n & 3; vec_loadN2_f32(vy0, &v_y[(i * 2) + 0], n3); - vec_loadN_mult12_mma(&temp[0], &va0[i], v_x0[ 0], n); - vec_reduce2_mma(&temp[0], temp0, v_alpha, vy0); vec_storeN2_f32(vy0, &v_y[(i * 2) + 0], n3); } else if (n) { - vy0[0] = vec_loadN_f32(&v_y[(i * 2) + 0], n); - vec_loadN_mult11_mma(&temp[0], &va0[i], v_x0[ 0], n); + vy0[0] = vec_loadN_f32(&v_y[(i * 2) + 0], n); + vec_reduce1_mma(&temp[0], temp0, v_alpha, vy0); vec_storeN_f32(vy0[0], &v_y[(i * 2) + 0], n); @@ -163,20 +163,20 @@ static void BF16GEMV_N_MMA_2(BLASLONG n, IFLOAT **ap, IFLOAT *xo, FLOAT *y, FLOA vec_make_mult1(v_x0, false); for (; i + 8 <= n8; i += 8) { - vec_load8_pair(vy0, &v_y[(i * 2) + 0]); - vec_load_mult288a_mma(&temp[0], &va0[i + 0], &va1[i + 0], &v_x0[ 0]); + vec_load8_pair(vy0, &v_y[(i * 2) + 0]); + vec_reduce88_mma(&temp[0], temp0 + 0, v_alpha, vy0 + 0); vec_store8_pair(&v_y[(i * 2) + 0], vy0); } if (n8 & 4) { - vec_load4_pair(vy0, &v_y[(i * 2) + 0]); - vec_load_mult284a_mma(&temp[0], &va0[i + 0], &va1[i + 0], &v_x0[ 0]); + vec_load4_pair(vy0, &v_y[(i * 2) + 0]); + vec_reduce84_mma(&temp[0], temp0, v_alpha, vy0); vec_store4_pair(&v_y[(i * 2) + 0], vy0); @@ -189,10 +189,10 @@ static void BF16GEMV_N_MMA_2(BLASLONG n, IFLOAT **ap, IFLOAT *xo, FLOAT *y, FLOA v_x0[0] = vec_loadN(x_bf, 2); for (; i + 4 <= 
n8; i += 4) { - vec_load4_pair(vy0, &v_y[(i * 2) + 0]); - vec_load_mult28a_mma(&temp[0], &va0[i + 0], &va1[i + 0], v_x0[ 0]); + vec_load4_pair(vy0, &v_y[(i * 2) + 0]); + vec_reduce8_mma(&temp[0], temp0, v_alpha, vy0); vec_store4_pair(&v_y[(i * 2) + 0], vy0); @@ -200,10 +200,10 @@ static void BF16GEMV_N_MMA_2(BLASLONG n, IFLOAT **ap, IFLOAT *xo, FLOAT *y, FLOA #endif for (; i < n8; i++) { - vec_load_pair(vy0, &v_y[(i * 2) + 0]); - vec_load_mult22a_mma(&temp[0], &va0[i], &va1[i], v_x0[ 0]); + vec_load_pair(vy0, &v_y[(i * 2) + 0]); + vec_reduce2_mma(&temp[0], temp0, v_alpha, vy0); vec_store_pair(&v_y[(i * 2) + 0], vy0); @@ -211,19 +211,19 @@ static void BF16GEMV_N_MMA_2(BLASLONG n, IFLOAT **ap, IFLOAT *xo, FLOAT *y, FLOA n &= 7; if (n > 4) { + vec_loadN_mult22a_mma(&temp[0], &va0[i], &va1[i], v_x0[ 0], n); + BLASLONG n3 = n & 3; vec_loadN2_f32(vy0, &v_y[(i * 2) + 0], n3); - vec_loadN_mult22a_mma(&temp[0], &va0[i], &va1[i], v_x0[ 0], n); - vec_reduce2_mma(&temp[0], temp0, v_alpha, vy0); vec_storeN2_f32(vy0, &v_y[(i * 2) + 0], n3); } else if (n) { - vy0[0] = vec_loadN_f32(&v_y[(i * 2) + 0], n); - vec_loadN_mult11a_mma(&temp[0], &va0[i], &va1[i], v_x0[ 0], n); + vy0[0] = vec_loadN_f32(&v_y[(i * 2) + 0], n); + vec_reduce1_mma(&temp[0], temp0, v_alpha, vy0); vec_storeN_f32(vy0[0], &v_y[(i * 2) + 0], n); @@ -261,22 +261,22 @@ static void BF16GEMV_N_MMA_4(BLASLONG n, IFLOAT **ap, IFLOAT *xo, FLOAT *y, FLOA vec_make_mult2(v_x0); for (; i + 8 <= n8; i += 8) { - vec_load8_pair(vy0, &v_y[(i * 2) + 0]); - vec_load_mult288a_mma(&temp[0], &va0[i + 0], &va1[i + 0], &v_x0[ 0]); vec_load_mult288b_mma(&temp[0], &va2[i + 0], &va3[i + 0], &v_x0[ 4]); + vec_load8_pair(vy0, &v_y[(i * 2) + 0]); + vec_reduce88_mma(&temp[0], temp0 + 0, v_alpha, vy0 + 0); vec_store8_pair(&v_y[(i * 2) + 0], vy0); } if (n8 & 4) { - vec_load4_pair(vy0, &v_y[(i * 2) + 0]); - vec_load_mult284a_mma(&temp[0], &va0[i + 0], &va1[i + 0], &v_x0[ 0]); vec_load_mult284b_mma(&temp[0], &va2[i + 0], &va3[i + 0], &v_x0[ 4]); + vec_load4_pair(vy0, &v_y[(i * 2) + 0]); + vec_reduce84_mma(&temp[0], temp0, v_alpha, vy0); vec_store4_pair(&v_y[(i * 2) + 0], vy0); @@ -291,11 +291,11 @@ static void BF16GEMV_N_MMA_4(BLASLONG n, IFLOAT **ap, IFLOAT *xo, FLOAT *y, FLOA v_x0[ 4] = (vec_bf16)vec_splat((vec_f32)v_x0[0], 1); for (; i + 4 <= n8; i += 4) { - vec_load4_pair(vy0, &v_y[(i * 2) + 0]); - vec_load_mult28a_mma(&temp[0], &va0[i + 0], &va1[i + 0], v_x0[ 0]); vec_load_mult28b_mma(&temp[0], &va2[i + 0], &va3[i + 0], v_x0[ 4]); + vec_load4_pair(vy0, &v_y[(i * 2) + 0]); + vec_reduce8_mma(&temp[0], temp0, v_alpha, vy0); vec_store4_pair(&v_y[(i * 2) + 0], vy0); @@ -303,11 +303,11 @@ static void BF16GEMV_N_MMA_4(BLASLONG n, IFLOAT **ap, IFLOAT *xo, FLOAT *y, FLOA #endif for (; i < n8; i++) { - vec_load_pair(vy0, &v_y[(i * 2) + 0]); - vec_load_mult22a_mma(&temp[0], &va0[i], &va1[i], v_x0[ 0]); vec_load_mult22b_mma(&temp[0], &va2[i], &va3[i], v_x0[ 4]); + vec_load_pair(vy0, &v_y[(i * 2) + 0]); + vec_reduce2_mma(&temp[0], temp0, v_alpha, vy0); vec_store_pair(&v_y[(i * 2) + 0], vy0); @@ -315,21 +315,21 @@ static void BF16GEMV_N_MMA_4(BLASLONG n, IFLOAT **ap, IFLOAT *xo, FLOAT *y, FLOA n &= 7; if (n > 4) { - BLASLONG n3 = n & 3; - vec_loadN2_f32(vy0, &v_y[(i * 2) + 0], n3); - vec_loadN_mult22a_mma(&temp[0], &va0[i], &va1[i], v_x0[ 0], n); vec_loadN_mult22b_mma(&temp[0], &va2[i], &va3[i], v_x0[ 4], n); + BLASLONG n3 = n & 3; + vec_loadN2_f32(vy0, &v_y[(i * 2) + 0], n3); + vec_reduce2_mma(&temp[0], temp0, v_alpha, vy0); vec_storeN2_f32(vy0, &v_y[(i * 2) + 0], n3); } else if 
(n) { - vy0[0] = vec_loadN_f32(&v_y[(i * 2) + 0], n); - vec_loadN_mult11a_mma(&temp[0], &va0[i], &va1[i], v_x0[ 0], n); vec_loadN_mult11b_mma(&temp[0], &va2[i], &va3[i], v_x0[ 4], n); + vy0[0] = vec_loadN_f32(&v_y[(i * 2) + 0], n); + vec_reduce1_mma(&temp[0], temp0, v_alpha, vy0); vec_storeN_f32(vy0[0], &v_y[(i * 2) + 0], n); @@ -376,26 +376,26 @@ static void BF16GEMV_N_MMA_8(BLASLONG n, IFLOAT **ap, IFLOAT *xo, FLOAT *y, BLAS vec_make_mult4(v_x0); for (; i + 8 <= n8; i += 8) { - vec_load8_pair(vy0, &v_y[(i * 2) + 0]); - vec_load_mult288a_mma(&temp[0], &va0[i + 0], &va1[i + 0], &v_x0[ 0]); vec_load_mult288b_mma(&temp[0], &va2[i + 0], &va3[i + 0], &v_x0[ 4]); vec_load_mult288b_mma(&temp[0], &vb0[i + 0], &vb1[i + 0], &v_x0[ 8]); vec_load_mult288b_mma(&temp[0], &vb2[i + 0], &vb3[i + 0], &v_x0[12]); + vec_load8_pair(vy0, &v_y[(i * 2) + 0]); + vec_reduce88_mma(&temp[0], temp0 + 0, v_alpha, vy0 + 0); vec_store8_pair(&v_y[(i * 2) + 0], vy0); } if (n8 & 4) { - vec_load4_pair(vy0, &v_y[(i * 2) + 0]); - vec_load_mult284a_mma(&temp[0], &va0[i + 0], &va1[i + 0], &v_x0[ 0]); vec_load_mult284b_mma(&temp[0], &va2[i + 0], &va3[i + 0], &v_x0[ 4]); vec_load_mult284b_mma(&temp[0], &vb0[i + 0], &vb1[i + 0], &v_x0[ 8]); vec_load_mult284b_mma(&temp[0], &vb2[i + 0], &vb3[i + 0], &v_x0[12]); + vec_load4_pair(vy0, &v_y[(i * 2) + 0]); + vec_reduce84_mma(&temp[0], temp0, v_alpha, vy0); vec_store4_pair(&v_y[(i * 2) + 0], vy0); @@ -412,13 +412,13 @@ static void BF16GEMV_N_MMA_8(BLASLONG n, IFLOAT **ap, IFLOAT *xo, FLOAT *y, BLAS v_x0[12] = (vec_bf16)vec_splat((vec_f32)v_x0[0], 3); for (; i + 4 <= n8; i += 4) { - vec_load4_pair(vy0, &v_y[(i * 2) + 0]); - vec_load_mult28a_mma(&temp[0], &va0[i + 0], &va1[i + 0], v_x0[ 0]); vec_load_mult28b_mma(&temp[0], &va2[i + 0], &va3[i + 0], v_x0[ 4]); vec_load_mult28b_mma(&temp[0], &vb0[i + 0], &vb1[i + 0], v_x0[ 8]); vec_load_mult28b_mma(&temp[0], &vb2[i + 0], &vb3[i + 0], v_x0[12]); + vec_load4_pair(vy0, &v_y[(i * 2) + 0]); + vec_reduce8_mma(&temp[0], temp0, v_alpha, vy0); vec_store4_pair(&v_y[(i * 2) + 0], vy0); @@ -426,13 +426,13 @@ static void BF16GEMV_N_MMA_8(BLASLONG n, IFLOAT **ap, IFLOAT *xo, FLOAT *y, BLAS #endif for (; i < n8; i++) { - vec_load_pair(vy0, &v_y[(i * 2) + 0]); - vec_load_mult22a_mma(&temp[0], &va0[i], &va1[i], v_x0[ 0]); vec_load_mult22b_mma(&temp[0], &va2[i], &va3[i], v_x0[ 4]); vec_load_mult22b_mma(&temp[0], &vb0[i], &vb1[i], v_x0[ 8]); vec_load_mult22b_mma(&temp[0], &vb2[i], &vb3[i], v_x0[12]); + vec_load_pair(vy0, &v_y[(i * 2) + 0]); + vec_reduce2_mma(&temp[0], temp0, v_alpha, vy0); vec_store_pair(&v_y[(i * 2) + 0], vy0); @@ -440,25 +440,25 @@ static void BF16GEMV_N_MMA_8(BLASLONG n, IFLOAT **ap, IFLOAT *xo, FLOAT *y, BLAS n &= 7; if (n > 4) { - BLASLONG n3 = n & 3; - vec_loadN2_f32(vy0, &v_y[(i * 2) + 0], n3); - vec_loadN_mult22a_mma(&temp[0], &va0[i], &va1[i], v_x0[ 0], n); vec_loadN_mult22b_mma(&temp[0], &va2[i], &va3[i], v_x0[ 4], n); vec_loadN_mult22b_mma(&temp[0], &vb0[i], &vb1[i], v_x0[ 8], n); vec_loadN_mult22b_mma(&temp[0], &vb2[i], &vb3[i], v_x0[12], n); + BLASLONG n3 = n & 3; + vec_loadN2_f32(vy0, &v_y[(i * 2) + 0], n3); + vec_reduce2_mma(&temp[0], temp0, v_alpha, vy0); vec_storeN2_f32(vy0, &v_y[(i * 2) + 0], n3); } else if (n) { - vy0[0] = vec_loadN_f32(&v_y[(i * 2) + 0], n); - vec_loadN_mult11a_mma(&temp[0], &va0[i], &va1[i], v_x0[ 0], n); vec_loadN_mult11b_mma(&temp[0], &va2[i], &va3[i], v_x0[ 4], n); vec_loadN_mult11b_mma(&temp[0], &vb0[i], &vb1[i], v_x0[ 8], n); vec_loadN_mult11b_mma(&temp[0], &vb2[i], &vb3[i], v_x0[12], n); + vy0[0] = 
vec_loadN_f32(&v_y[(i * 2) + 0], n); + vec_reduce1_mma(&temp[0], temp0, v_alpha, vy0); vec_storeN_f32(vy0[0], &v_y[(i * 2) + 0], n); From fb287d17fc1a5f53920f0b8c29ba476b258950bc Mon Sep 17 00:00:00 2001 From: Chip Kerchner Date: Wed, 25 Sep 2024 16:31:36 -0500 Subject: [PATCH 12/24] Common code. --- kernel/power/sbgemv_common_power10.c | 138 +++++++++++++++++ kernel/power/sbgemv_n_power10.c | 24 +-- kernel/power/sbgemv_t_power10.c | 223 ++++++++++++++------------- 3 files changed, 263 insertions(+), 122 deletions(-) diff --git a/kernel/power/sbgemv_common_power10.c b/kernel/power/sbgemv_common_power10.c index d24a98418f..638e2655c0 100644 --- a/kernel/power/sbgemv_common_power10.c +++ b/kernel/power/sbgemv_common_power10.c @@ -33,6 +33,12 @@ USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. #define USE_MERGE_MMA #endif +FORCEINLINE void vec_load_pair2(vec_bf16 *in0, vec_bf16 *in) +{ + vec_load_pair((vec_f32 *)(in0 + 0), (vec_f32 *)(in + 0)); + vec_load_pair((vec_f32 *)(in0 + 2), (vec_f32 *)(in + 2)); +} + FORCEINLINE void vec_load_mult_mma(__vector_quad *out, vec_bf16 *in, vec_bf16 inp) { vec_bf16 in0 = (vec_bf16)vec_load_vec(in); @@ -40,6 +46,28 @@ FORCEINLINE void vec_load_mult_mma(__vector_quad *out, vec_bf16 *in, vec_bf16 in __builtin_mma_xvbf16ger2pp(out, (vec_uc8)in0, (vec_uc8)inp); } +FORCEINLINE void vec_load_mult12a_mma(__vector_quad *out, vec_bf16 *in0, vec_bf16 *in1, vec_bf16 inp) +{ + vec_bf16 in01 = (vec_bf16)vec_load_vec(in0); + vec_bf16 in11 = (vec_bf16)vec_load_vec(in1); + + __builtin_mma_xvbf16ger2pp(out + 0, (vec_uc8)in01, (vec_uc8)inp); + __builtin_mma_xvbf16ger2pp(out + 1, (vec_uc8)in11, (vec_uc8)inp); +} + +FORCEINLINE void vec_load_mult14_mma(__vector_quad *out, vec_bf16 *in0, vec_bf16 *in1, vec_bf16 *in2, vec_bf16 *in3, vec_bf16 inp) +{ + vec_bf16 in01 = (vec_bf16)vec_load_vec(in0); + vec_bf16 in11 = (vec_bf16)vec_load_vec(in1); + vec_bf16 in21 = (vec_bf16)vec_load_vec(in2); + vec_bf16 in31 = (vec_bf16)vec_load_vec(in3); + + __builtin_mma_xvbf16ger2pp(out + 0, (vec_uc8)in01, (vec_uc8)inp); + __builtin_mma_xvbf16ger2pp(out + 1, (vec_uc8)in11, (vec_uc8)inp); + __builtin_mma_xvbf16ger2pp(out + 2, (vec_uc8)in21, (vec_uc8)inp); + __builtin_mma_xvbf16ger2pp(out + 3, (vec_uc8)in31, (vec_uc8)inp); +} + FORCEINLINE void vec_load_mult2_mma(__vector_quad *out, vec_bf16 *in, vec_bf16 *inp) { vec_bf16 in0[2]; @@ -50,6 +78,94 @@ FORCEINLINE void vec_load_mult2_mma(__vector_quad *out, vec_bf16 *in, vec_bf16 * __builtin_mma_xvbf16ger2pp(out, (vec_uc8)in0[1], (vec_uc8)inp[1]); } +FORCEINLINE void vec_load_mult22_mma(__vector_quad *out, vec_bf16 *in0, vec_bf16 *in1, vec_bf16 *inp) +{ + vec_bf16 in01[2], in11[2]; + + vec_load_pair((vec_f32 *)in01, (vec_f32 *)in0); + vec_load_pair((vec_f32 *)in11, (vec_f32 *)in1); + + __builtin_mma_xvbf16ger2pp(out + 0, (vec_uc8)in01[0], (vec_uc8)inp[0]); + __builtin_mma_xvbf16ger2pp(out + 1, (vec_uc8)in11[0], (vec_uc8)inp[0]); + __builtin_mma_xvbf16ger2pp(out + 0, (vec_uc8)in01[1], (vec_uc8)inp[1]); + __builtin_mma_xvbf16ger2pp(out + 1, (vec_uc8)in11[1], (vec_uc8)inp[1]); +} + +FORCEINLINE void vec_load_mult24_mma(__vector_quad *out, vec_bf16 *in0, vec_bf16 *in1, vec_bf16 *in2, vec_bf16 *in3, vec_bf16 *inp) +{ + vec_bf16 in01[2], in11[2], in21[2], in31[2]; + + vec_load_pair((vec_f32 *)in01, (vec_f32 *)in0); + vec_load_pair((vec_f32 *)in11, (vec_f32 *)in1); + vec_load_pair((vec_f32 *)in21, (vec_f32 *)in2); + vec_load_pair((vec_f32 *)in31, (vec_f32 *)in3); + + __builtin_mma_xvbf16ger2pp(out + 0, (vec_uc8)in01[0], (vec_uc8)inp[0]); 
+ __builtin_mma_xvbf16ger2pp(out + 1, (vec_uc8)in11[0], (vec_uc8)inp[0]); + __builtin_mma_xvbf16ger2pp(out + 2, (vec_uc8)in21[0], (vec_uc8)inp[0]); + __builtin_mma_xvbf16ger2pp(out + 3, (vec_uc8)in31[0], (vec_uc8)inp[0]); + __builtin_mma_xvbf16ger2pp(out + 0, (vec_uc8)in01[1], (vec_uc8)inp[1]); + __builtin_mma_xvbf16ger2pp(out + 1, (vec_uc8)in11[1], (vec_uc8)inp[1]); + __builtin_mma_xvbf16ger2pp(out + 2, (vec_uc8)in21[1], (vec_uc8)inp[1]); + __builtin_mma_xvbf16ger2pp(out + 3, (vec_uc8)in31[1], (vec_uc8)inp[1]); +} + +FORCEINLINE void vec_load_mult4_mma(__vector_quad *out, vec_bf16 *in, vec_bf16 *inp) +{ + vec_bf16 in0[4]; + + vec_load_pair2(in0, in); + + __builtin_mma_xvbf16ger2pp(out, (vec_uc8)in0[0], (vec_uc8)inp[0]); + __builtin_mma_xvbf16ger2pp(out, (vec_uc8)in0[1], (vec_uc8)inp[1]); + __builtin_mma_xvbf16ger2pp(out, (vec_uc8)in0[2], (vec_uc8)inp[2]); + __builtin_mma_xvbf16ger2pp(out, (vec_uc8)in0[3], (vec_uc8)inp[3]); +} + +FORCEINLINE void vec_load_mult42_mma(__vector_quad *out, vec_bf16 *in0, vec_bf16 *in1, vec_bf16 *inp) +{ + vec_bf16 in01[4], in11[4]; + + vec_load_pair2(in01, in0); + vec_load_pair2(in11, in1); + + __builtin_mma_xvbf16ger2pp(out + 0, (vec_uc8)in01[0], (vec_uc8)inp[0]); + __builtin_mma_xvbf16ger2pp(out + 1, (vec_uc8)in11[0], (vec_uc8)inp[0]); + __builtin_mma_xvbf16ger2pp(out + 0, (vec_uc8)in01[1], (vec_uc8)inp[1]); + __builtin_mma_xvbf16ger2pp(out + 1, (vec_uc8)in11[1], (vec_uc8)inp[1]); + __builtin_mma_xvbf16ger2pp(out + 0, (vec_uc8)in01[2], (vec_uc8)inp[2]); + __builtin_mma_xvbf16ger2pp(out + 1, (vec_uc8)in11[2], (vec_uc8)inp[2]); + __builtin_mma_xvbf16ger2pp(out + 0, (vec_uc8)in01[3], (vec_uc8)inp[3]); + __builtin_mma_xvbf16ger2pp(out + 1, (vec_uc8)in11[3], (vec_uc8)inp[3]); +} + +FORCEINLINE void vec_load_mult44_mma(__vector_quad *out, vec_bf16 *in0, vec_bf16 *in1, vec_bf16 *in2, vec_bf16 *in3, vec_bf16 *inp) +{ + vec_bf16 in01[4], in11[4], in21[4], in31[4]; + + vec_load_pair2(in01, in0); + vec_load_pair2(in11, in1); + vec_load_pair2(in21, in2); + vec_load_pair2(in31, in3); + + __builtin_mma_xvbf16ger2pp(out + 0, (vec_uc8)in01[0], (vec_uc8)inp[0]); + __builtin_mma_xvbf16ger2pp(out + 1, (vec_uc8)in11[0], (vec_uc8)inp[0]); + __builtin_mma_xvbf16ger2pp(out + 2, (vec_uc8)in21[0], (vec_uc8)inp[0]); + __builtin_mma_xvbf16ger2pp(out + 3, (vec_uc8)in31[0], (vec_uc8)inp[0]); + __builtin_mma_xvbf16ger2pp(out + 0, (vec_uc8)in01[1], (vec_uc8)inp[1]); + __builtin_mma_xvbf16ger2pp(out + 1, (vec_uc8)in11[1], (vec_uc8)inp[1]); + __builtin_mma_xvbf16ger2pp(out + 2, (vec_uc8)in21[1], (vec_uc8)inp[1]); + __builtin_mma_xvbf16ger2pp(out + 3, (vec_uc8)in31[1], (vec_uc8)inp[1]); + __builtin_mma_xvbf16ger2pp(out + 0, (vec_uc8)in01[2], (vec_uc8)inp[2]); + __builtin_mma_xvbf16ger2pp(out + 1, (vec_uc8)in11[2], (vec_uc8)inp[2]); + __builtin_mma_xvbf16ger2pp(out + 2, (vec_uc8)in21[2], (vec_uc8)inp[2]); + __builtin_mma_xvbf16ger2pp(out + 3, (vec_uc8)in31[2], (vec_uc8)inp[2]); + __builtin_mma_xvbf16ger2pp(out + 0, (vec_uc8)in01[3], (vec_uc8)inp[3]); + __builtin_mma_xvbf16ger2pp(out + 1, (vec_uc8)in11[3], (vec_uc8)inp[3]); + __builtin_mma_xvbf16ger2pp(out + 2, (vec_uc8)in21[3], (vec_uc8)inp[3]); + __builtin_mma_xvbf16ger2pp(out + 3, (vec_uc8)in31[3], (vec_uc8)inp[3]); +} + FORCEINLINE void vec_loadN_mult_mma(__vector_quad *out, vec_bf16 *in, vec_bf16 inp, BLASLONG n) { vec_bf16 in0 = vec_loadN(in, n); @@ -57,6 +173,28 @@ FORCEINLINE void vec_loadN_mult_mma(__vector_quad *out, vec_bf16 *in, vec_bf16 i __builtin_mma_xvbf16ger2pp(out, (vec_uc8)in0, (vec_uc8)inp); } +FORCEINLINE void 
vec_loadN_mult12a_mma(__vector_quad *out, vec_bf16 *in0, vec_bf16 *in1, vec_bf16 inp, BLASLONG n) +{ + vec_bf16 in01 = (vec_bf16)vec_loadN(in0, n); + vec_bf16 in11 = (vec_bf16)vec_loadN(in1, n); + + __builtin_mma_xvbf16ger2pp(out + 0, (vec_uc8)in01, (vec_uc8)inp); + __builtin_mma_xvbf16ger2pp(out + 1, (vec_uc8)in11, (vec_uc8)inp); +} + +FORCEINLINE void vec_loadN_mult14_mma(__vector_quad *out, vec_bf16 *in0, vec_bf16 *in1, vec_bf16 *in2, vec_bf16 *in3, vec_bf16 inp, BLASLONG n) +{ + vec_bf16 in01 = (vec_bf16)vec_loadN(in0, n); + vec_bf16 in11 = (vec_bf16)vec_loadN(in1, n); + vec_bf16 in21 = (vec_bf16)vec_loadN(in2, n); + vec_bf16 in31 = (vec_bf16)vec_loadN(in3, n); + + __builtin_mma_xvbf16ger2pp(out + 0, (vec_uc8)in01, (vec_uc8)inp); + __builtin_mma_xvbf16ger2pp(out + 1, (vec_uc8)in11, (vec_uc8)inp); + __builtin_mma_xvbf16ger2pp(out + 2, (vec_uc8)in21, (vec_uc8)inp); + __builtin_mma_xvbf16ger2pp(out + 3, (vec_uc8)in31, (vec_uc8)inp); +} + FORCEINLINE void vec_mult1_mma(__vector_quad *out, vec_bf16 in0, vec_bf16 inp) { vec_bf16 in00 = vec_mergeh(in0, in0); diff --git a/kernel/power/sbgemv_n_power10.c b/kernel/power/sbgemv_n_power10.c index f33a246a99..b1dcb2fcc4 100644 --- a/kernel/power/sbgemv_n_power10.c +++ b/kernel/power/sbgemv_n_power10.c @@ -119,12 +119,12 @@ static void BF16GEMV_N_MMA_1(BLASLONG n, IFLOAT **ap, IFLOAT *xo, FLOAT *y, FLOA if (n > 4) { vec_loadN_mult12_mma(&temp[0], &va0[i], v_x0[ 0], n); - BLASLONG n3 = n & 3; - vec_loadN2_f32(vy0, &v_y[(i * 2) + 0], n3); + n &= 3; + vec_loadN2_f32(vy0, &v_y[(i * 2) + 0], n); vec_reduce2_mma(&temp[0], temp0, v_alpha, vy0); - vec_storeN2_f32(vy0, &v_y[(i * 2) + 0], n3); + vec_storeN2_f32(vy0, &v_y[(i * 2) + 0], n); } else if (n) { vec_loadN_mult11_mma(&temp[0], &va0[i], v_x0[ 0], n); @@ -213,12 +213,12 @@ static void BF16GEMV_N_MMA_2(BLASLONG n, IFLOAT **ap, IFLOAT *xo, FLOAT *y, FLOA if (n > 4) { vec_loadN_mult22a_mma(&temp[0], &va0[i], &va1[i], v_x0[ 0], n); - BLASLONG n3 = n & 3; - vec_loadN2_f32(vy0, &v_y[(i * 2) + 0], n3); + n &= 3; + vec_loadN2_f32(vy0, &v_y[(i * 2) + 0], n); vec_reduce2_mma(&temp[0], temp0, v_alpha, vy0); - vec_storeN2_f32(vy0, &v_y[(i * 2) + 0], n3); + vec_storeN2_f32(vy0, &v_y[(i * 2) + 0], n); } else if (n) { vec_loadN_mult11a_mma(&temp[0], &va0[i], &va1[i], v_x0[ 0], n); @@ -318,12 +318,12 @@ static void BF16GEMV_N_MMA_4(BLASLONG n, IFLOAT **ap, IFLOAT *xo, FLOAT *y, FLOA vec_loadN_mult22a_mma(&temp[0], &va0[i], &va1[i], v_x0[ 0], n); vec_loadN_mult22b_mma(&temp[0], &va2[i], &va3[i], v_x0[ 4], n); - BLASLONG n3 = n & 3; - vec_loadN2_f32(vy0, &v_y[(i * 2) + 0], n3); + n &= 3; + vec_loadN2_f32(vy0, &v_y[(i * 2) + 0], n); vec_reduce2_mma(&temp[0], temp0, v_alpha, vy0); - vec_storeN2_f32(vy0, &v_y[(i * 2) + 0], n3); + vec_storeN2_f32(vy0, &v_y[(i * 2) + 0], n); } else if (n) { vec_loadN_mult11a_mma(&temp[0], &va0[i], &va1[i], v_x0[ 0], n); vec_loadN_mult11b_mma(&temp[0], &va2[i], &va3[i], v_x0[ 4], n); @@ -445,12 +445,12 @@ static void BF16GEMV_N_MMA_8(BLASLONG n, IFLOAT **ap, IFLOAT *xo, FLOAT *y, BLAS vec_loadN_mult22b_mma(&temp[0], &vb0[i], &vb1[i], v_x0[ 8], n); vec_loadN_mult22b_mma(&temp[0], &vb2[i], &vb3[i], v_x0[12], n); - BLASLONG n3 = n & 3; - vec_loadN2_f32(vy0, &v_y[(i * 2) + 0], n3); + n &= 3; + vec_loadN2_f32(vy0, &v_y[(i * 2) + 0], n); vec_reduce2_mma(&temp[0], temp0, v_alpha, vy0); - vec_storeN2_f32(vy0, &v_y[(i * 2) + 0], n3); + vec_storeN2_f32(vy0, &v_y[(i * 2) + 0], n); } else if (n) { vec_loadN_mult11a_mma(&temp[0], &va0[i], &va1[i], v_x0[ 0], n); vec_loadN_mult11b_mma(&temp[0], &va2[i], 
&va3[i], v_x0[ 4], n); diff --git a/kernel/power/sbgemv_t_power10.c b/kernel/power/sbgemv_t_power10.c index 810287e89a..9a5c54f12f 100644 --- a/kernel/power/sbgemv_t_power10.c +++ b/kernel/power/sbgemv_t_power10.c @@ -49,7 +49,7 @@ static void BF16GEMV_T_MMA_1(BLASLONG n, BLASLONG lda, IFLOAT *ap, IFLOAT *x, FL vec_bf16 *va0, *v_x; __vector_quad temp0; vec_f32 temp00[4]; - vec_bf16 inp[2]; + vec_bf16 inp[4]; __builtin_mma_xxsetaccz(&temp0); @@ -59,10 +59,18 @@ static void BF16GEMV_T_MMA_1(BLASLONG n, BLASLONG lda, IFLOAT *ap, IFLOAT *x, FL BLASLONG n8 = n / 8; BLASLONG i = 0; - for (; i + 2 <= n8; i += 2) { + for (; i + 4 <= n8; i += 4) { + vec_load_pair2(inp, &v_x[i]); + + vec_load_mult4_mma(&temp0, &va0[i + 0], inp); + } + + if (n8 & 2) { vec_load_pair((vec_f32 *)inp, (vec_f32 *)&v_x[i]); vec_load_mult2_mma(&temp0, &va0[i + 0], inp); + + i += 2; } if (n8 & 1) { @@ -89,12 +97,12 @@ static void BF16GEMV_T_MMA_2(BLASLONG n, BLASLONG lda, IFLOAT *ap, IFLOAT *x, FL { IFLOAT *a0, *a1; vec_bf16 *va0, *va1, *v_x; - __vector_quad temp0, temp1; - vec_f32 temp00[4], temp01[4]; - vec_bf16 inp[2]; + __vector_quad temp0[2]; + vec_f32 temp00[4*2]; + vec_bf16 inp[4]; - __builtin_mma_xxsetaccz(&temp0); - __builtin_mma_xxsetaccz(&temp1); + __builtin_mma_xxsetaccz(&temp0[0]); + __builtin_mma_xxsetaccz(&temp0[1]); a0 = ap; a1 = ap + lda; @@ -104,18 +112,24 @@ static void BF16GEMV_T_MMA_2(BLASLONG n, BLASLONG lda, IFLOAT *ap, IFLOAT *x, FL BLASLONG n8 = n / 8; BLASLONG i = 0; - for (; i + 2 <= n8; i += 2) { + for (; i + 4 <= n8; i += 4) { + vec_load_pair2(inp, &v_x[i]); + + vec_load_mult42_mma(&temp0[0], &va0[i + 0], &va1[i + 0], inp); + } + + if (n8 & 2) { vec_load_pair((vec_f32 *)inp, (vec_f32 *)&v_x[i]); - vec_load_mult2_mma(&temp0, &va0[i + 0], inp); - vec_load_mult2_mma(&temp1, &va1[i + 0], inp); + vec_load_mult22_mma(&temp0[0], &va0[i + 0], &va1[i + 0], inp); + + i += 2; } if (n8 & 1) { inp[0] = (vec_bf16)vec_load_vec(&v_x[i]); - vec_load_mult_mma(&temp0, &va0[i], inp[0]); - vec_load_mult_mma(&temp1, &va1[i], inp[0]); + vec_load_mult12a_mma(&temp0[0], &va0[i], &va1[i], inp[0]); i++; } @@ -124,29 +138,28 @@ static void BF16GEMV_T_MMA_2(BLASLONG n, BLASLONG lda, IFLOAT *ap, IFLOAT *x, FL if (n) { inp[0] = vec_loadN(&v_x[i], n); - vec_loadN_mult_mma(&temp0, &va0[i], inp[0], n); - vec_loadN_mult_mma(&temp1, &va1[i], inp[0], n); + vec_loadN_mult12a_mma(&temp0[0], &va0[i], &va1[i], inp[0], n); } - __builtin_mma_disassemble_acc((void*)temp00, &temp0); - __builtin_mma_disassemble_acc((void*)temp01, &temp1); + __builtin_mma_disassemble_acc((void*)(temp00 + 0), &temp0[0]); + __builtin_mma_disassemble_acc((void*)(temp00 + 4), &temp0[1]); y[0] = (alpha * (temp00[0][0] + temp00[1][1] + temp00[2][2] + temp00[3][3])) + (beta * y[0]); - y[1] = (alpha * (temp01[0][0] + temp01[1][1] + temp01[2][2] + temp01[3][3])) + (beta * y[1]); + y[1] = (alpha * (temp00[4][0] + temp00[5][1] + temp00[6][2] + temp00[7][3])) + (beta * y[1]); } static void BF16GEMV_T_MMA_4(BLASLONG n, BLASLONG lda, IFLOAT *ap, IFLOAT *x, FLOAT *y, FLOAT alpha, FLOAT beta) { IFLOAT *a0, *a1, *a2, *a3; vec_bf16 *va0, *va1, *va2, *va3, *v_x; - __vector_quad temp0, temp1, temp2, temp3; - vec_f32 temp00[4], temp01[4], temp02[4], temp03[4]; - vec_bf16 inp[2]; + __vector_quad temp0[4]; + vec_f32 temp00[4*4]; + vec_bf16 inp[4]; - __builtin_mma_xxsetaccz(&temp0); - __builtin_mma_xxsetaccz(&temp1); - __builtin_mma_xxsetaccz(&temp2); - __builtin_mma_xxsetaccz(&temp3); + __builtin_mma_xxsetaccz(&temp0[0]); + __builtin_mma_xxsetaccz(&temp0[1]); + 
__builtin_mma_xxsetaccz(&temp0[2]); + __builtin_mma_xxsetaccz(&temp0[3]); a0 = ap; a1 = ap + lda; @@ -160,22 +173,24 @@ static void BF16GEMV_T_MMA_4(BLASLONG n, BLASLONG lda, IFLOAT *ap, IFLOAT *x, FL BLASLONG n8 = n / 8; BLASLONG i = 0; - for (; i + 2 <= n8; i += 2) { + for (; i + 4 <= n8; i += 4) { + vec_load_pair2(inp, &v_x[i]); + + vec_load_mult44_mma(&temp0[0], &va0[i + 0], &va1[i + 0], &va2[i + 0], &va3[i + 0], inp); + } + + if (n8 & 2) { vec_load_pair((vec_f32 *)inp, (vec_f32 *)&v_x[i]); - vec_load_mult2_mma(&temp0, &va0[i + 0], inp); - vec_load_mult2_mma(&temp1, &va1[i + 0], inp); - vec_load_mult2_mma(&temp2, &va2[i + 0], inp); - vec_load_mult2_mma(&temp3, &va3[i + 0], inp); + vec_load_mult24_mma(&temp0[0], &va0[i + 0], &va1[i + 0], &va2[i + 0], &va3[i + 0], inp); + + i += 2; } if (n8 & 1) { inp[0] = (vec_bf16)vec_load_vec(&v_x[i]); - vec_load_mult_mma(&temp0, &va0[i], inp[0]); - vec_load_mult_mma(&temp1, &va1[i], inp[0]); - vec_load_mult_mma(&temp2, &va2[i], inp[0]); - vec_load_mult_mma(&temp3, &va3[i], inp[0]); + vec_load_mult14_mma(&temp0[0], &va0[i], &va1[i], &va2[i], &va3[i], inp[0]); i++; } @@ -184,30 +199,27 @@ static void BF16GEMV_T_MMA_4(BLASLONG n, BLASLONG lda, IFLOAT *ap, IFLOAT *x, FL if (n) { inp[0] = vec_loadN(&v_x[i], n); - vec_loadN_mult_mma(&temp0, &va0[i], inp[0], n); - vec_loadN_mult_mma(&temp1, &va1[i], inp[0], n); - vec_loadN_mult_mma(&temp2, &va2[i], inp[0], n); - vec_loadN_mult_mma(&temp3, &va3[i], inp[0], n); + vec_loadN_mult14_mma(&temp0[0], &va0[i], &va1[i], &va2[i], &va3[i], inp[0], n); } - __builtin_mma_disassemble_acc((void*)temp00, &temp0); - __builtin_mma_disassemble_acc((void*)temp01, &temp1); - __builtin_mma_disassemble_acc((void*)temp02, &temp2); - __builtin_mma_disassemble_acc((void*)temp03, &temp3); + __builtin_mma_disassemble_acc((void*)(temp00 + 0), &temp0[0]); + __builtin_mma_disassemble_acc((void*)(temp00 + 4), &temp0[1]); + __builtin_mma_disassemble_acc((void*)(temp00 + 8), &temp0[2]); + __builtin_mma_disassemble_acc((void*)(temp00 + 12), &temp0[3]); vec_f32 t0, t1, t2, t3, t4, t5, t6, t7; vec_f32 a = { alpha, alpha, alpha, alpha }; vec_f32 b = { beta, beta, beta, beta }; vec_f32 *v_y = (vec_f32 *) y; - t0 = vec_mergeh(temp00[0], temp01[0]); - t1 = vec_mergeh(temp02[0], temp03[0]); - t2 = vec_mergeo(temp00[1], temp01[1]); - t3 = vec_mergeo(temp02[1], temp03[1]); - t4 = vec_mergel(temp00[2], temp01[2]); - t5 = vec_mergel(temp02[2], temp03[2]); - t6 = vec_mergeo(temp00[3], temp01[3]); - t7 = vec_mergeo(temp02[3], temp03[3]); + t0 = vec_mergeh(temp00[ 0], temp00[ 4]); + t1 = vec_mergeh(temp00[ 8], temp00[12]); + t2 = vec_mergeo(temp00[ 1], temp00[ 5]); + t3 = vec_mergeo(temp00[ 9], temp00[13]); + t4 = vec_mergel(temp00[ 2], temp00[ 6]); + t5 = vec_mergel(temp00[10], temp00[14]); + t6 = vec_mergeo(temp00[ 3], temp00[ 7]); + t7 = vec_mergeo(temp00[11], temp00[15]); t0 = vec_xxpermdi(t0, t1, 0); t2 = vec_xxpermdi(t2, t3, 0); t4 = vec_xxpermdi(t4, t5, 0); @@ -223,18 +235,18 @@ static void BF16GEMV_T_MMA_8(BLASLONG n, BLASLONG lda, IFLOAT *ap, IFLOAT *x, FL { IFLOAT *a0, *a1, *a2, *a3, *a4, *a5, *a6, *a7; vec_bf16 *va0, *va1, *va2, *va3, *va4, *va5, *va6, *va7, *v_x; - __vector_quad temp0, temp1, temp2, temp3, temp4, temp5, temp6, temp7; - vec_f32 temp00[4], temp01[4], temp02[4], temp03[4], temp04[4], temp05[4], temp06[4], temp07[4]; - vec_bf16 inp[2]; - - __builtin_mma_xxsetaccz(&temp0); - __builtin_mma_xxsetaccz(&temp1); - __builtin_mma_xxsetaccz(&temp2); - __builtin_mma_xxsetaccz(&temp3); - __builtin_mma_xxsetaccz(&temp4); - 
__builtin_mma_xxsetaccz(&temp5); - __builtin_mma_xxsetaccz(&temp6); - __builtin_mma_xxsetaccz(&temp7); + __vector_quad temp0[8]; + vec_f32 temp00[4*8]; + vec_bf16 inp[4]; + + __builtin_mma_xxsetaccz(&temp0[0]); + __builtin_mma_xxsetaccz(&temp0[1]); + __builtin_mma_xxsetaccz(&temp0[2]); + __builtin_mma_xxsetaccz(&temp0[3]); + __builtin_mma_xxsetaccz(&temp0[4]); + __builtin_mma_xxsetaccz(&temp0[5]); + __builtin_mma_xxsetaccz(&temp0[6]); + __builtin_mma_xxsetaccz(&temp0[7]); a0 = ap; a1 = ap + lda; @@ -256,30 +268,27 @@ static void BF16GEMV_T_MMA_8(BLASLONG n, BLASLONG lda, IFLOAT *ap, IFLOAT *x, FL BLASLONG n8 = n / 8; BLASLONG i = 0; - for (; i + 2 <= n8; i += 2) { + for (; i + 4 <= n8; i += 4) { + vec_load_pair2(inp, &v_x[i]); + + vec_load_mult44_mma(&temp0[0], &va0[i + 0], &va1[i + 0], &va2[i + 0], &va3[i + 0], inp); + vec_load_mult44_mma(&temp0[4], &va4[i + 0], &va5[i + 0], &va6[i + 0], &va7[i + 0], inp); + } + + if (n8 & 2) { vec_load_pair((vec_f32 *)inp, (vec_f32 *)&v_x[i]); - vec_load_mult2_mma(&temp0, &va0[i + 0], inp); - vec_load_mult2_mma(&temp1, &va1[i + 0], inp); - vec_load_mult2_mma(&temp2, &va2[i + 0], inp); - vec_load_mult2_mma(&temp3, &va3[i + 0], inp); - vec_load_mult2_mma(&temp4, &va4[i + 0], inp); - vec_load_mult2_mma(&temp5, &va5[i + 0], inp); - vec_load_mult2_mma(&temp6, &va6[i + 0], inp); - vec_load_mult2_mma(&temp7, &va7[i + 0], inp); + vec_load_mult24_mma(&temp0[0], &va0[i + 0], &va1[i + 0], &va2[i + 0], &va3[i + 0], inp); + vec_load_mult24_mma(&temp0[4], &va4[i + 0], &va5[i + 0], &va6[i + 0], &va7[i + 0], inp); + + i += 2; } if (n8 & 1) { inp[0] = (vec_bf16)vec_load_vec(&v_x[i]); - vec_load_mult_mma(&temp0, &va0[i], inp[0]); - vec_load_mult_mma(&temp1, &va1[i], inp[0]); - vec_load_mult_mma(&temp2, &va2[i], inp[0]); - vec_load_mult_mma(&temp3, &va3[i], inp[0]); - vec_load_mult_mma(&temp4, &va4[i], inp[0]); - vec_load_mult_mma(&temp5, &va5[i], inp[0]); - vec_load_mult_mma(&temp6, &va6[i], inp[0]); - vec_load_mult_mma(&temp7, &va7[i], inp[0]); + vec_load_mult14_mma(&temp0[0], &va0[i], &va1[i], &va2[i], &va3[i], inp[0]); + vec_load_mult14_mma(&temp0[4], &va4[i], &va5[i], &va6[i], &va7[i], inp[0]); i++; } @@ -288,38 +297,32 @@ static void BF16GEMV_T_MMA_8(BLASLONG n, BLASLONG lda, IFLOAT *ap, IFLOAT *x, FL if (n) { inp[0] = vec_loadN(&v_x[i], n); - vec_loadN_mult_mma(&temp0, &va0[i], inp[0], n); - vec_loadN_mult_mma(&temp1, &va1[i], inp[0], n); - vec_loadN_mult_mma(&temp2, &va2[i], inp[0], n); - vec_loadN_mult_mma(&temp3, &va3[i], inp[0], n); - vec_loadN_mult_mma(&temp4, &va4[i], inp[0], n); - vec_loadN_mult_mma(&temp5, &va5[i], inp[0], n); - vec_loadN_mult_mma(&temp6, &va6[i], inp[0], n); - vec_loadN_mult_mma(&temp7, &va7[i], inp[0], n); + vec_loadN_mult14_mma(&temp0[0], &va0[i], &va1[i], &va2[i], &va3[i], inp[0], n); + vec_loadN_mult14_mma(&temp0[4], &va4[i], &va5[i], &va6[i], &va7[i], inp[0], n); } - __builtin_mma_disassemble_acc((void*)temp00, &temp0); - __builtin_mma_disassemble_acc((void*)temp01, &temp1); - __builtin_mma_disassemble_acc((void*)temp02, &temp2); - __builtin_mma_disassemble_acc((void*)temp03, &temp3); - __builtin_mma_disassemble_acc((void*)temp04, &temp4); - __builtin_mma_disassemble_acc((void*)temp05, &temp5); - __builtin_mma_disassemble_acc((void*)temp06, &temp6); - __builtin_mma_disassemble_acc((void*)temp07, &temp7); + __builtin_mma_disassemble_acc((void*)(temp00 + 0), &temp0[0]); + __builtin_mma_disassemble_acc((void*)(temp00 + 4), &temp0[1]); + __builtin_mma_disassemble_acc((void*)(temp00 + 8), &temp0[2]); + 
__builtin_mma_disassemble_acc((void*)(temp00 + 12), &temp0[3]); + __builtin_mma_disassemble_acc((void*)(temp00 + 16), &temp0[4]); + __builtin_mma_disassemble_acc((void*)(temp00 + 20), &temp0[5]); + __builtin_mma_disassemble_acc((void*)(temp00 + 24), &temp0[6]); + __builtin_mma_disassemble_acc((void*)(temp00 + 28), &temp0[7]); vec_f32 t0, t1, t2, t3, t4, t5, t6, t7, t10, t11, t12, t13, t14, t15, t16, t17; vec_f32 a = { alpha, alpha, alpha, alpha }; vec_f32 b = { beta, beta, beta, beta }; vec_f32 *v_y = (vec_f32 *) y; - t0 = vec_mergeh(temp00[0], temp01[0]); - t1 = vec_mergeh(temp02[0], temp03[0]); - t2 = vec_mergeo(temp00[1], temp01[1]); - t3 = vec_mergeo(temp02[1], temp03[1]); - t4 = vec_mergel(temp00[2], temp01[2]); - t5 = vec_mergel(temp02[2], temp03[2]); - t6 = vec_mergeo(temp00[3], temp01[3]); - t7 = vec_mergeo(temp02[3], temp03[3]); + t0 = vec_mergeh(temp00[ 0], temp00[ 4]); + t1 = vec_mergeh(temp00[ 8], temp00[12]); + t2 = vec_mergeo(temp00[ 1], temp00[ 5]); + t3 = vec_mergeo(temp00[ 9], temp00[13]); + t4 = vec_mergel(temp00[ 2], temp00[ 6]); + t5 = vec_mergel(temp00[10], temp00[14]); + t6 = vec_mergeo(temp00[ 3], temp00[ 7]); + t7 = vec_mergeo(temp00[11], temp00[15]); t0 = vec_xxpermdi(t0, t1, 0); t2 = vec_xxpermdi(t2, t3, 0); t4 = vec_xxpermdi(t4, t5, 0); @@ -327,14 +330,14 @@ static void BF16GEMV_T_MMA_8(BLASLONG n, BLASLONG lda, IFLOAT *ap, IFLOAT *x, FL t0 += t2 + t4 + t6; - t10 = vec_mergeh(temp04[0], temp05[0]); - t11 = vec_mergeh(temp06[0], temp07[0]); - t12 = vec_mergeo(temp04[1], temp05[1]); - t13 = vec_mergeo(temp06[1], temp07[1]); - t14 = vec_mergel(temp04[2], temp05[2]); - t15 = vec_mergel(temp06[2], temp07[2]); - t16 = vec_mergeo(temp04[3], temp05[3]); - t17 = vec_mergeo(temp06[3], temp07[3]); + t10 = vec_mergeh(temp00[16], temp00[20]); + t11 = vec_mergeh(temp00[24], temp00[28]); + t12 = vec_mergeo(temp00[17], temp00[21]); + t13 = vec_mergeo(temp00[25], temp00[29]); + t14 = vec_mergel(temp00[18], temp00[22]); + t15 = vec_mergel(temp00[26], temp00[30]); + t16 = vec_mergeo(temp00[19], temp00[23]); + t17 = vec_mergeo(temp00[27], temp00[31]); t10 = vec_xxpermdi(t10, t11, 0); t12 = vec_xxpermdi(t12, t13, 0); t14 = vec_xxpermdi(t14, t15, 0); From eb6f3a05efb1a441c8920a2c4a7fa2e0fe7f6507 Mon Sep 17 00:00:00 2001 From: Chip Kerchner Date: Thu, 26 Sep 2024 09:28:56 -0500 Subject: [PATCH 13/24] Common MMA code. 
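Rather than repeating the same __builtin_mma_xvbf16ger2pp call sequence in every wide load-multiply helper, this patch factors the pattern into small building blocks (vec_mult2d_mma and vec_mult4d_mma) that the vec_load_mult* variants compose.

For reviewers new to the instruction these helpers wrap, below is a minimal plain-C sketch of what a single xvbf16ger2pp accumulation step models and why the transposed kernels sum the accumulator diagonal after __builtin_mma_disassemble_acc. The names in it (acc_model, ger2pp_model) are illustrative only, not kernel symbols, and the bf16 inputs are shown already widened to float.

/* Scalar model of one xvbf16ger2pp step (plain C, not the POWER10
 * intrinsic path).  The instruction views each 128-bit input as a
 * 4x2 matrix of bf16 values and accumulates the 4x4 fp32 product of
 * the first operand with the transpose of the second.  Diagonal
 * element i is then the partial dot product over elements 2i and
 * 2i+1, so summing the diagonal yields the 8-element dot product
 * that the BF16GEMV_T kernels scale by alpha. */
#include <stdio.h>

typedef struct { float v[4][4]; } acc_model;   /* stands in for __vector_quad */

static void ger2pp_model(acc_model *acc, const float a[8], const float x[8])
{
    for (int i = 0; i < 4; i++)
        for (int j = 0; j < 4; j++)
            acc->v[i][j] += a[2 * i] * x[2 * j] + a[2 * i + 1] * x[2 * j + 1];
}

int main(void)
{
    float a[8] = { 1, 2, 3, 4, 5, 6, 7, 8 };   /* eight elements of one row of A */
    float x[8] = { 1, 2, 1, 2, 1, 2, 1, 2 };   /* eight elements of x            */
    acc_model acc = { { { 0 } } };

    ger2pp_model(&acc, a, x);

    float diag = acc.v[0][0] + acc.v[1][1] + acc.v[2][2] + acc.v[3][3];

    float dot = 0.0f;
    for (int e = 0; e < 8; e++)
        dot += a[e] * x[e];

    printf("diagonal sum = %g, direct dot = %g\n", diag, dot);   /* both print 56 */
    return 0;
}
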
--- kernel/power/sbgemv_common_power10.c | 94 ++++++++++++---------------- 1 file changed, 40 insertions(+), 54 deletions(-) diff --git a/kernel/power/sbgemv_common_power10.c b/kernel/power/sbgemv_common_power10.c index 638e2655c0..0510088b23 100644 --- a/kernel/power/sbgemv_common_power10.c +++ b/kernel/power/sbgemv_common_power10.c @@ -48,22 +48,20 @@ FORCEINLINE void vec_load_mult_mma(__vector_quad *out, vec_bf16 *in, vec_bf16 in FORCEINLINE void vec_load_mult12a_mma(__vector_quad *out, vec_bf16 *in0, vec_bf16 *in1, vec_bf16 inp) { - vec_bf16 in01 = (vec_bf16)vec_load_vec(in0); vec_bf16 in11 = (vec_bf16)vec_load_vec(in1); - __builtin_mma_xvbf16ger2pp(out + 0, (vec_uc8)in01, (vec_uc8)inp); + vec_load_mult_mma(out, in0, inp); + __builtin_mma_xvbf16ger2pp(out + 1, (vec_uc8)in11, (vec_uc8)inp); } FORCEINLINE void vec_load_mult14_mma(__vector_quad *out, vec_bf16 *in0, vec_bf16 *in1, vec_bf16 *in2, vec_bf16 *in3, vec_bf16 inp) { - vec_bf16 in01 = (vec_bf16)vec_load_vec(in0); - vec_bf16 in11 = (vec_bf16)vec_load_vec(in1); vec_bf16 in21 = (vec_bf16)vec_load_vec(in2); vec_bf16 in31 = (vec_bf16)vec_load_vec(in3); - __builtin_mma_xvbf16ger2pp(out + 0, (vec_uc8)in01, (vec_uc8)inp); - __builtin_mma_xvbf16ger2pp(out + 1, (vec_uc8)in11, (vec_uc8)inp); + vec_load_mult12a_mma(out, in0, in1, inp); + __builtin_mma_xvbf16ger2pp(out + 2, (vec_uc8)in21, (vec_uc8)inp); __builtin_mma_xvbf16ger2pp(out + 3, (vec_uc8)in31, (vec_uc8)inp); } @@ -78,6 +76,12 @@ FORCEINLINE void vec_load_mult2_mma(__vector_quad *out, vec_bf16 *in, vec_bf16 * __builtin_mma_xvbf16ger2pp(out, (vec_uc8)in0[1], (vec_uc8)inp[1]); } +FORCEINLINE void vec_mult2d_mma(__vector_quad *out, vec_bf16 *in01, vec_bf16 *in11, vec_bf16 *inp) +{ + __builtin_mma_xvbf16ger2pp(out + 0, (vec_uc8)in01[0], (vec_uc8)inp[0]); + __builtin_mma_xvbf16ger2pp(out + 1, (vec_uc8)in11[0], (vec_uc8)inp[0]); +} + FORCEINLINE void vec_load_mult22_mma(__vector_quad *out, vec_bf16 *in0, vec_bf16 *in1, vec_bf16 *inp) { vec_bf16 in01[2], in11[2]; @@ -85,10 +89,8 @@ FORCEINLINE void vec_load_mult22_mma(__vector_quad *out, vec_bf16 *in0, vec_bf16 vec_load_pair((vec_f32 *)in01, (vec_f32 *)in0); vec_load_pair((vec_f32 *)in11, (vec_f32 *)in1); - __builtin_mma_xvbf16ger2pp(out + 0, (vec_uc8)in01[0], (vec_uc8)inp[0]); - __builtin_mma_xvbf16ger2pp(out + 1, (vec_uc8)in11[0], (vec_uc8)inp[0]); - __builtin_mma_xvbf16ger2pp(out + 0, (vec_uc8)in01[1], (vec_uc8)inp[1]); - __builtin_mma_xvbf16ger2pp(out + 1, (vec_uc8)in11[1], (vec_uc8)inp[1]); + vec_mult2d_mma(out, in01 + 0, in11 + 0, inp + 0); + vec_mult2d_mma(out, in01 + 1, in11 + 1, inp + 1); } FORCEINLINE void vec_load_mult24_mma(__vector_quad *out, vec_bf16 *in0, vec_bf16 *in1, vec_bf16 *in2, vec_bf16 *in3, vec_bf16 *inp) @@ -100,26 +102,22 @@ FORCEINLINE void vec_load_mult24_mma(__vector_quad *out, vec_bf16 *in0, vec_bf16 vec_load_pair((vec_f32 *)in21, (vec_f32 *)in2); vec_load_pair((vec_f32 *)in31, (vec_f32 *)in3); - __builtin_mma_xvbf16ger2pp(out + 0, (vec_uc8)in01[0], (vec_uc8)inp[0]); - __builtin_mma_xvbf16ger2pp(out + 1, (vec_uc8)in11[0], (vec_uc8)inp[0]); - __builtin_mma_xvbf16ger2pp(out + 2, (vec_uc8)in21[0], (vec_uc8)inp[0]); - __builtin_mma_xvbf16ger2pp(out + 3, (vec_uc8)in31[0], (vec_uc8)inp[0]); - __builtin_mma_xvbf16ger2pp(out + 0, (vec_uc8)in01[1], (vec_uc8)inp[1]); - __builtin_mma_xvbf16ger2pp(out + 1, (vec_uc8)in11[1], (vec_uc8)inp[1]); - __builtin_mma_xvbf16ger2pp(out + 2, (vec_uc8)in21[1], (vec_uc8)inp[1]); - __builtin_mma_xvbf16ger2pp(out + 3, (vec_uc8)in31[1], (vec_uc8)inp[1]); + vec_mult2d_mma(out + 0, in01 + 0, 
in11 + 0, inp + 0); + vec_mult2d_mma(out + 2, in21 + 0, in31 + 0, inp + 0); + vec_mult2d_mma(out + 0, in01 + 1, in11 + 1, inp + 1); + vec_mult2d_mma(out + 2, in21 + 1, in31 + 1, inp + 1); } FORCEINLINE void vec_load_mult4_mma(__vector_quad *out, vec_bf16 *in, vec_bf16 *inp) { - vec_bf16 in0[4]; + vec_bf16 in0[2]; - vec_load_pair2(in0, in); + vec_load_pair((vec_f32 *)(in0 + 0), (vec_f32 *)(in + 2)); - __builtin_mma_xvbf16ger2pp(out, (vec_uc8)in0[0], (vec_uc8)inp[0]); - __builtin_mma_xvbf16ger2pp(out, (vec_uc8)in0[1], (vec_uc8)inp[1]); - __builtin_mma_xvbf16ger2pp(out, (vec_uc8)in0[2], (vec_uc8)inp[2]); - __builtin_mma_xvbf16ger2pp(out, (vec_uc8)in0[3], (vec_uc8)inp[3]); + vec_load_mult2_mma(out, in + 0, inp + 0); + + __builtin_mma_xvbf16ger2pp(out, (vec_uc8)in0[0], (vec_uc8)inp[2]); + __builtin_mma_xvbf16ger2pp(out, (vec_uc8)in0[1], (vec_uc8)inp[3]); } FORCEINLINE void vec_load_mult42_mma(__vector_quad *out, vec_bf16 *in0, vec_bf16 *in1, vec_bf16 *inp) @@ -129,14 +127,16 @@ FORCEINLINE void vec_load_mult42_mma(__vector_quad *out, vec_bf16 *in0, vec_bf16 vec_load_pair2(in01, in0); vec_load_pair2(in11, in1); - __builtin_mma_xvbf16ger2pp(out + 0, (vec_uc8)in01[0], (vec_uc8)inp[0]); - __builtin_mma_xvbf16ger2pp(out + 1, (vec_uc8)in11[0], (vec_uc8)inp[0]); - __builtin_mma_xvbf16ger2pp(out + 0, (vec_uc8)in01[1], (vec_uc8)inp[1]); - __builtin_mma_xvbf16ger2pp(out + 1, (vec_uc8)in11[1], (vec_uc8)inp[1]); - __builtin_mma_xvbf16ger2pp(out + 0, (vec_uc8)in01[2], (vec_uc8)inp[2]); - __builtin_mma_xvbf16ger2pp(out + 1, (vec_uc8)in11[2], (vec_uc8)inp[2]); - __builtin_mma_xvbf16ger2pp(out + 0, (vec_uc8)in01[3], (vec_uc8)inp[3]); - __builtin_mma_xvbf16ger2pp(out + 1, (vec_uc8)in11[3], (vec_uc8)inp[3]); + vec_mult2d_mma(out, in01 + 0, in11 + 0, inp + 0); + vec_mult2d_mma(out, in01 + 1, in11 + 1, inp + 1); + vec_mult2d_mma(out, in01 + 2, in11 + 2, inp + 2); + vec_mult2d_mma(out, in01 + 3, in11 + 3, inp + 3); +} + +FORCEINLINE void vec_mult4d_mma(__vector_quad *out, vec_bf16 *in01, vec_bf16 *in11, vec_bf16 *in21, vec_bf16 *in31, vec_bf16 *inp) +{ + vec_mult2d_mma(out + 0, in01, in11, inp); + vec_mult2d_mma(out + 2, in21, in31, inp); } FORCEINLINE void vec_load_mult44_mma(__vector_quad *out, vec_bf16 *in0, vec_bf16 *in1, vec_bf16 *in2, vec_bf16 *in3, vec_bf16 *inp) @@ -148,22 +148,10 @@ FORCEINLINE void vec_load_mult44_mma(__vector_quad *out, vec_bf16 *in0, vec_bf16 vec_load_pair2(in21, in2); vec_load_pair2(in31, in3); - __builtin_mma_xvbf16ger2pp(out + 0, (vec_uc8)in01[0], (vec_uc8)inp[0]); - __builtin_mma_xvbf16ger2pp(out + 1, (vec_uc8)in11[0], (vec_uc8)inp[0]); - __builtin_mma_xvbf16ger2pp(out + 2, (vec_uc8)in21[0], (vec_uc8)inp[0]); - __builtin_mma_xvbf16ger2pp(out + 3, (vec_uc8)in31[0], (vec_uc8)inp[0]); - __builtin_mma_xvbf16ger2pp(out + 0, (vec_uc8)in01[1], (vec_uc8)inp[1]); - __builtin_mma_xvbf16ger2pp(out + 1, (vec_uc8)in11[1], (vec_uc8)inp[1]); - __builtin_mma_xvbf16ger2pp(out + 2, (vec_uc8)in21[1], (vec_uc8)inp[1]); - __builtin_mma_xvbf16ger2pp(out + 3, (vec_uc8)in31[1], (vec_uc8)inp[1]); - __builtin_mma_xvbf16ger2pp(out + 0, (vec_uc8)in01[2], (vec_uc8)inp[2]); - __builtin_mma_xvbf16ger2pp(out + 1, (vec_uc8)in11[2], (vec_uc8)inp[2]); - __builtin_mma_xvbf16ger2pp(out + 2, (vec_uc8)in21[2], (vec_uc8)inp[2]); - __builtin_mma_xvbf16ger2pp(out + 3, (vec_uc8)in31[2], (vec_uc8)inp[2]); - __builtin_mma_xvbf16ger2pp(out + 0, (vec_uc8)in01[3], (vec_uc8)inp[3]); - __builtin_mma_xvbf16ger2pp(out + 1, (vec_uc8)in11[3], (vec_uc8)inp[3]); - __builtin_mma_xvbf16ger2pp(out + 2, (vec_uc8)in21[3], (vec_uc8)inp[3]); - 
__builtin_mma_xvbf16ger2pp(out + 3, (vec_uc8)in31[3], (vec_uc8)inp[3]); + vec_mult4d_mma(out, in01 + 0, in11 + 0, in21 + 0, in31 + 0, inp + 0); + vec_mult4d_mma(out, in01 + 1, in11 + 1, in21 + 1, in31 + 1, inp + 1); + vec_mult4d_mma(out, in01 + 2, in11 + 2, in21 + 2, in31 + 2, inp + 2); + vec_mult4d_mma(out, in01 + 3, in11 + 3, in21 + 3, in31 + 3, inp + 3); } FORCEINLINE void vec_loadN_mult_mma(__vector_quad *out, vec_bf16 *in, vec_bf16 inp, BLASLONG n) @@ -175,22 +163,20 @@ FORCEINLINE void vec_loadN_mult_mma(__vector_quad *out, vec_bf16 *in, vec_bf16 i FORCEINLINE void vec_loadN_mult12a_mma(__vector_quad *out, vec_bf16 *in0, vec_bf16 *in1, vec_bf16 inp, BLASLONG n) { - vec_bf16 in01 = (vec_bf16)vec_loadN(in0, n); vec_bf16 in11 = (vec_bf16)vec_loadN(in1, n); - __builtin_mma_xvbf16ger2pp(out + 0, (vec_uc8)in01, (vec_uc8)inp); + vec_loadN_mult_mma(out, in0, inp, n); + __builtin_mma_xvbf16ger2pp(out + 1, (vec_uc8)in11, (vec_uc8)inp); } FORCEINLINE void vec_loadN_mult14_mma(__vector_quad *out, vec_bf16 *in0, vec_bf16 *in1, vec_bf16 *in2, vec_bf16 *in3, vec_bf16 inp, BLASLONG n) { - vec_bf16 in01 = (vec_bf16)vec_loadN(in0, n); - vec_bf16 in11 = (vec_bf16)vec_loadN(in1, n); vec_bf16 in21 = (vec_bf16)vec_loadN(in2, n); vec_bf16 in31 = (vec_bf16)vec_loadN(in3, n); - __builtin_mma_xvbf16ger2pp(out + 0, (vec_uc8)in01, (vec_uc8)inp); - __builtin_mma_xvbf16ger2pp(out + 1, (vec_uc8)in11, (vec_uc8)inp); + vec_loadN_mult12a_mma(out, in0, in1, inp, n); + __builtin_mma_xvbf16ger2pp(out + 2, (vec_uc8)in21, (vec_uc8)inp); __builtin_mma_xvbf16ger2pp(out + 3, (vec_uc8)in31, (vec_uc8)inp); } From d7c0d87cd1b961300a1d32a3a7ac74d030ad1faf Mon Sep 17 00:00:00 2001 From: Chip Kerchner Date: Thu, 26 Sep 2024 15:21:29 -0500 Subject: [PATCH 14/24] Small changes. --- kernel/power/sbgemv_common_power10.c | 36 +++++++++++++++++++++++ kernel/power/sbgemv_t_power10.c | 43 +++++++--------------------- kernel/power/sbgemv_t_vsx.c | 9 +++--- 3 files changed, 52 insertions(+), 36 deletions(-) diff --git a/kernel/power/sbgemv_common_power10.c b/kernel/power/sbgemv_common_power10.c index 0510088b23..b0e611cb68 100644 --- a/kernel/power/sbgemv_common_power10.c +++ b/kernel/power/sbgemv_common_power10.c @@ -525,6 +525,42 @@ FORCEINLINE void vec_store4_pair(vec_f32 *v_y, vec_f32 *vy0) vec_store_pair(v_y + 6, vy0 + 6); } +FORCEINLINE void vec_setzero_2(__vector_quad *temp0) +{ + __builtin_mma_xxsetaccz(&temp0[0]); + __builtin_mma_xxsetaccz(&temp0[1]); +} + +FORCEINLINE void vec_setzero_4(__vector_quad *temp0) +{ + vec_setzero_2(temp0 + 0); + vec_setzero_2(temp0 + 2); +} + +FORCEINLINE void vec_setzero_8(__vector_quad *temp0) +{ + vec_setzero_4(temp0 + 0); + vec_setzero_4(temp0 + 4); +} + +FORCEINLINE void vec_reduce_2(vec_f32 *temp00, __vector_quad *temp0) +{ + __builtin_mma_disassemble_acc((void*)(temp00 + 0), &temp0[0]); + __builtin_mma_disassemble_acc((void*)(temp00 + 4), &temp0[1]); +} + +FORCEINLINE void vec_reduce_4(vec_f32 *temp00, __vector_quad *temp0) +{ + vec_reduce_2(temp00 + 0, temp0 + 0); + vec_reduce_2(temp00 + 8, temp0 + 2); +} + +FORCEINLINE void vec_reduce_8(vec_f32 *temp00, __vector_quad *temp0) +{ + vec_reduce_4(temp00 + 0, temp0 + 0); + vec_reduce_4(temp00 + 16, temp0 + 4); +} + #ifdef USE_MERGE_MMA FORCEINLINE void vec_load8_pair(vec_f32 *vy0, vec_f32 *v_y) { diff --git a/kernel/power/sbgemv_t_power10.c b/kernel/power/sbgemv_t_power10.c index 9a5c54f12f..d2f6087f05 100644 --- a/kernel/power/sbgemv_t_power10.c +++ b/kernel/power/sbgemv_t_power10.c @@ -101,8 +101,7 @@ static void BF16GEMV_T_MMA_2(BLASLONG n, 
BLASLONG lda, IFLOAT *ap, IFLOAT *x, FL vec_f32 temp00[4*2]; vec_bf16 inp[4]; - __builtin_mma_xxsetaccz(&temp0[0]); - __builtin_mma_xxsetaccz(&temp0[1]); + vec_setzero_2(&temp0[0]); a0 = ap; a1 = ap + lda; @@ -141,8 +140,7 @@ static void BF16GEMV_T_MMA_2(BLASLONG n, BLASLONG lda, IFLOAT *ap, IFLOAT *x, FL vec_loadN_mult12a_mma(&temp0[0], &va0[i], &va1[i], inp[0], n); } - __builtin_mma_disassemble_acc((void*)(temp00 + 0), &temp0[0]); - __builtin_mma_disassemble_acc((void*)(temp00 + 4), &temp0[1]); + vec_reduce_2(temp00, &temp0[0]); y[0] = (alpha * (temp00[0][0] + temp00[1][1] + temp00[2][2] + temp00[3][3])) + (beta * y[0]); y[1] = (alpha * (temp00[4][0] + temp00[5][1] + temp00[6][2] + temp00[7][3])) + (beta * y[1]); @@ -156,10 +154,7 @@ static void BF16GEMV_T_MMA_4(BLASLONG n, BLASLONG lda, IFLOAT *ap, IFLOAT *x, FL vec_f32 temp00[4*4]; vec_bf16 inp[4]; - __builtin_mma_xxsetaccz(&temp0[0]); - __builtin_mma_xxsetaccz(&temp0[1]); - __builtin_mma_xxsetaccz(&temp0[2]); - __builtin_mma_xxsetaccz(&temp0[3]); + vec_setzero_4(&temp0[0]); a0 = ap; a1 = ap + lda; @@ -202,10 +197,7 @@ static void BF16GEMV_T_MMA_4(BLASLONG n, BLASLONG lda, IFLOAT *ap, IFLOAT *x, FL vec_loadN_mult14_mma(&temp0[0], &va0[i], &va1[i], &va2[i], &va3[i], inp[0], n); } - __builtin_mma_disassemble_acc((void*)(temp00 + 0), &temp0[0]); - __builtin_mma_disassemble_acc((void*)(temp00 + 4), &temp0[1]); - __builtin_mma_disassemble_acc((void*)(temp00 + 8), &temp0[2]); - __builtin_mma_disassemble_acc((void*)(temp00 + 12), &temp0[3]); + vec_reduce_4(temp00, &temp0[0]); vec_f32 t0, t1, t2, t3, t4, t5, t6, t7; vec_f32 a = { alpha, alpha, alpha, alpha }; @@ -239,23 +231,17 @@ static void BF16GEMV_T_MMA_8(BLASLONG n, BLASLONG lda, IFLOAT *ap, IFLOAT *x, FL vec_f32 temp00[4*8]; vec_bf16 inp[4]; - __builtin_mma_xxsetaccz(&temp0[0]); - __builtin_mma_xxsetaccz(&temp0[1]); - __builtin_mma_xxsetaccz(&temp0[2]); - __builtin_mma_xxsetaccz(&temp0[3]); - __builtin_mma_xxsetaccz(&temp0[4]); - __builtin_mma_xxsetaccz(&temp0[5]); - __builtin_mma_xxsetaccz(&temp0[6]); - __builtin_mma_xxsetaccz(&temp0[7]); + vec_setzero_8(&temp0[0]); + BLASLONG lda4 = lda << 2; a0 = ap; a1 = ap + lda; a2 = a1 + lda; a3 = a2 + lda; - a4 = a3 + lda; - a5 = a4 + lda; - a6 = a5 + lda; - a7 = a6 + lda; + a4 = a0 + lda4; + a5 = a1 + lda4; + a6 = a2 + lda4; + a7 = a3 + lda4; va0 = (vec_bf16 *)a0; va1 = (vec_bf16 *)a1; va2 = (vec_bf16 *)a2; @@ -301,14 +287,7 @@ static void BF16GEMV_T_MMA_8(BLASLONG n, BLASLONG lda, IFLOAT *ap, IFLOAT *x, FL vec_loadN_mult14_mma(&temp0[4], &va4[i], &va5[i], &va6[i], &va7[i], inp[0], n); } - __builtin_mma_disassemble_acc((void*)(temp00 + 0), &temp0[0]); - __builtin_mma_disassemble_acc((void*)(temp00 + 4), &temp0[1]); - __builtin_mma_disassemble_acc((void*)(temp00 + 8), &temp0[2]); - __builtin_mma_disassemble_acc((void*)(temp00 + 12), &temp0[3]); - __builtin_mma_disassemble_acc((void*)(temp00 + 16), &temp0[4]); - __builtin_mma_disassemble_acc((void*)(temp00 + 20), &temp0[5]); - __builtin_mma_disassemble_acc((void*)(temp00 + 24), &temp0[6]); - __builtin_mma_disassemble_acc((void*)(temp00 + 28), &temp0[7]); + vec_reduce_8(temp00, &temp0[0]); vec_f32 t0, t1, t2, t3, t4, t5, t6, t7, t10, t11, t12, t13, t14, t15, t16, t17; vec_f32 a = { alpha, alpha, alpha, alpha }; diff --git a/kernel/power/sbgemv_t_vsx.c b/kernel/power/sbgemv_t_vsx.c index 399989bb52..0750405031 100644 --- a/kernel/power/sbgemv_t_vsx.c +++ b/kernel/power/sbgemv_t_vsx.c @@ -198,14 +198,15 @@ static void BF16GEMV_T_VSX_8(BLASLONG n, BLASLONG lda, IFLOAT *ap, IFLOAT *x, FL vec_bf16 zero 
= { 0, 0, 0, 0, 0, 0, 0, 0 }; vec_f32 inp[2]; + BLASLONG lda4 = lda << 2; a0 = ap; a1 = ap + lda; a2 = a1 + lda; a3 = a2 + lda; - a4 = a3 + lda; - a5 = a4 + lda; - a6 = a5 + lda; - a7 = a6 + lda; + a4 = a0 + lda4; + a5 = a1 + lda4; + a6 = a2 + lda4; + a7 = a3 + lda4; va0 = (vec_bf16 *)a0; va1 = (vec_bf16 *)a1; va2 = (vec_bf16 *)a2; From c8788208c8bb135bb9b5b8af2476f296987b7cf5 Mon Sep 17 00:00:00 2001 From: Chip Kerchner Date: Fri, 27 Sep 2024 13:27:03 -0500 Subject: [PATCH 15/24] Fixing block issue with transpose version. --- kernel/power/sbgemv_n.c | 4 +--- kernel/power/sbgemv_t.c | 1 + 2 files changed, 2 insertions(+), 3 deletions(-) diff --git a/kernel/power/sbgemv_n.c b/kernel/power/sbgemv_n.c index 4768be31fa..eab0b4e33b 100644 --- a/kernel/power/sbgemv_n.c +++ b/kernel/power/sbgemv_n.c @@ -202,10 +202,8 @@ int CNAME(BLASLONG m, BLASLONG n, FLOAT alpha, IFLOAT *a, BLASLONG lda, IFLOAT * a += NB; if (inc_y != 1) { move_y(NB, ybuffer, y_ptr, inc_y); - y_ptr += (NB * inc_y); - } else { - y_ptr += NB; } + y_ptr += (NB * inc_y); } return 0; diff --git a/kernel/power/sbgemv_t.c b/kernel/power/sbgemv_t.c index 4cc8f060e9..c6fdb6b1ae 100644 --- a/kernel/power/sbgemv_t.c +++ b/kernel/power/sbgemv_t.c @@ -124,6 +124,7 @@ int CNAME(BLASLONG m, BLASLONG n, FLOAT alpha, IFLOAT *a, BLASLONG lda, IFLOAT * a += NB; x += NB * inc_x; + beta = (FLOAT)1; } return 0; From 32095b0cbbfbf2a9db382931cacbc400ae975603 Mon Sep 17 00:00:00 2001 From: Chip Kerchner Date: Tue, 1 Oct 2024 09:32:42 -0500 Subject: [PATCH 16/24] Remove parameter. --- kernel/power/sbgemv_common.c | 8 ++++---- kernel/power/sbgemv_t_vsx.c | 34 +++++++++++++++++----------------- 2 files changed, 21 insertions(+), 21 deletions(-) diff --git a/kernel/power/sbgemv_common.c b/kernel/power/sbgemv_common.c index ab50f430af..c9438b7e6d 100644 --- a/kernel/power/sbgemv_common.c +++ b/kernel/power/sbgemv_common.c @@ -58,9 +58,9 @@ FORCEINLINE vec_f32 vec_load_mult(vec_bf16 *in, vec_f32 *inp, vec_bf16 zero) return vec_mult(inp, in0, zero); } -FORCEINLINE void vec_load_vec2(vec_bf16 *in, BLASLONG i, vec_f32 *v_x0, vec_bf16 zero) +FORCEINLINE void vec_load_vec2(vec_bf16 *in, vec_f32 *v_x0, vec_bf16 zero) { - vec_bf16 inp = (vec_bf16)vec_load_vec(&in[i]); + vec_bf16 inp = (vec_bf16)vec_load_vec(in); v_x0[0] = BF16_HI(inp, zero); v_x0[1] = BF16_LO(inp, zero); @@ -89,9 +89,9 @@ FORCEINLINE vec_f32 vec_loadN_mult(vec_bf16 *in, vec_f32 *inp, BLASLONG n, vec_b return vec_mult(inp, in0, zero); } -FORCEINLINE void vec_loadN_vec2(vec_bf16 *in, BLASLONG i, vec_f32 *v_x0, BLASLONG n, vec_bf16 zero) +FORCEINLINE void vec_loadN_vec2(vec_bf16 *in, vec_f32 *v_x0, BLASLONG n, vec_bf16 zero) { - vec_bf16 inp = vec_loadN(&in[i], n); + vec_bf16 inp = vec_loadN(in, n); v_x0[0] = BF16_HI(inp, zero); v_x0[1] = BF16_LO(inp, zero); diff --git a/kernel/power/sbgemv_t_vsx.c b/kernel/power/sbgemv_t_vsx.c index 0750405031..272dccef76 100644 --- a/kernel/power/sbgemv_t_vsx.c +++ b/kernel/power/sbgemv_t_vsx.c @@ -55,14 +55,14 @@ static void BF16GEMV_T_VSX_1(BLASLONG n, BLASLONG lda, IFLOAT *ap, IFLOAT *x, FL BLASLONG i = 0; for (; i < n8; i++) { - vec_load_vec2(v_x, i, inp, zero); + vec_load_vec2(&v_x[i], inp, zero); temp0 += vec_load_mult(&va0[i], inp, zero); } n &= 7; if (n > 4) { - vec_loadN_vec2(v_x, i, inp, n, zero); + vec_loadN_vec2(&v_x[i], inp, n, zero); temp0 += vec_loadN_mult(&va0[i], inp, n, zero); } else if (n) { @@ -92,7 +92,7 @@ static void BF16GEMV_T_VSX_2(BLASLONG n, BLASLONG lda, IFLOAT *ap, IFLOAT *x, FL BLASLONG i = 0; for (; i < n8; i++) { - vec_load_vec2(v_x, 
i, inp, zero); + vec_load_vec2(&v_x[i], inp, zero); temp0 += vec_load_mult(&va0[i], inp, zero); temp1 += vec_load_mult(&va1[i], inp, zero); @@ -100,7 +100,7 @@ static void BF16GEMV_T_VSX_2(BLASLONG n, BLASLONG lda, IFLOAT *ap, IFLOAT *x, FL n &= 7; if (n > 4) { - vec_loadN_vec2(v_x, i, inp, n, zero); + vec_loadN_vec2(&v_x[i], inp, n, zero); temp0 += vec_loadN_mult(&va0[i], inp, n, zero); temp1 += vec_loadN_mult(&va1[i], inp, n, zero); @@ -139,7 +139,7 @@ static void BF16GEMV_T_VSX_4(BLASLONG n, BLASLONG lda, IFLOAT *ap, IFLOAT *x, FL BLASLONG i = 0; for (; i < n8; i++) { - vec_load_vec2(v_x, i, inp, zero); + vec_load_vec2(&v_x[i], inp, zero); temp0 += vec_load_mult(&va0[i], inp, zero); temp1 += vec_load_mult(&va1[i], inp, zero); @@ -149,7 +149,7 @@ static void BF16GEMV_T_VSX_4(BLASLONG n, BLASLONG lda, IFLOAT *ap, IFLOAT *x, FL n &= 7; if (n > 4) { - vec_loadN_vec2(v_x, i, inp, n, zero); + vec_loadN_vec2(&v_x[i], inp, n, zero); temp0 += vec_loadN_mult(&va0[i], inp, n, zero); temp1 += vec_loadN_mult(&va1[i], inp, n, zero); @@ -220,7 +220,7 @@ static void BF16GEMV_T_VSX_8(BLASLONG n, BLASLONG lda, IFLOAT *ap, IFLOAT *x, FL BLASLONG i = 0; for (; i < n8; i++) { - vec_load_vec2(v_x, i, inp, zero); + vec_load_vec2(&v_x[i], inp, zero); temp0 += vec_load_mult(&va0[i], inp, zero); temp1 += vec_load_mult(&va1[i], inp, zero); @@ -234,7 +234,7 @@ static void BF16GEMV_T_VSX_8(BLASLONG n, BLASLONG lda, IFLOAT *ap, IFLOAT *x, FL n &= 7; if (n > 4) { - vec_loadN_vec2(v_x, i, inp, n, zero); + vec_loadN_vec2(&v_x[i], inp, n, zero); temp0 += vec_loadN_mult(&va0[i], inp, n, zero); temp1 += vec_loadN_mult(&va1[i], inp, n, zero); @@ -257,7 +257,7 @@ static void BF16GEMV_T_VSX_8(BLASLONG n, BLASLONG lda, IFLOAT *ap, IFLOAT *x, FL temp7 += vec_loadNHi_mult(&va7[i], inp[0], n, zero); } - vec_f32 t0, t1, t2, t3; + vec_f32 t0, t1, t2, t3, t10, t11, t12, t13; vec_f32 a = { alpha, alpha, alpha, alpha }; vec_f32 b = { beta, beta, beta, beta }; vec_f32 *v_y = (vec_f32 *) y; @@ -272,14 +272,14 @@ static void BF16GEMV_T_VSX_8(BLASLONG n, BLASLONG lda, IFLOAT *ap, IFLOAT *x, FL temp3 = vec_mergel(t1, t3); temp0 += temp1 + temp2 + temp3; - t0 = vec_mergeh(temp4, temp6); - t1 = vec_mergel(temp4, temp6); - t2 = vec_mergeh(temp5, temp7); - t3 = vec_mergel(temp5, temp7); - temp4 = vec_mergeh(t0, t2); - temp5 = vec_mergel(t0, t2); - temp6 = vec_mergeh(t1, t3); - temp7 = vec_mergel(t1, t3); + t10 = vec_mergeh(temp4, temp6); + t11 = vec_mergel(temp4, temp6); + t12 = vec_mergeh(temp5, temp7); + t13 = vec_mergel(temp5, temp7); + temp4 = vec_mergeh(t10, t12); + temp5 = vec_mergel(t10, t12); + temp6 = vec_mergeh(t11, t13); + temp7 = vec_mergel(t11, t13); temp4 += temp5 + temp6 + temp7; vec_load_pair(inp, v_y); From e238a68c03db1fe808b4919ac2200089009b1382 Mon Sep 17 00:00:00 2001 From: Chip Kerchner Date: Tue, 1 Oct 2024 11:06:23 -0500 Subject: [PATCH 17/24] Remove duplicate. 
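vec_loadNHi_mult2(v_x, a, n, zero) and vec_loadNHi_mult(a, v_x, n, zero) computed the same elementwise product with the operands swapped, so the N kernels can use the surviving helper directly. A scalar model of what both computed is sketched below (illustrative only; loadNHi_mult_ref is a made-up name, and the bf16 column is shown as already widened to float):

    /* Both helpers multiplied the broadcast x values with the first four
       matrix elements of a partial load; elements at index >= n read as zero.
       Multiplication commutes, so one helper suffices. */
    static void loadNHi_mult_ref(float out[4], const float col[4],
                                 long n, const float x[4])
    {
        for (int j = 0; j < 4; j++)
            out[j] = x[j] * ((j < (int)n) ? col[j] : 0.0f);
    }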
--- kernel/power/sbgemv_common.c | 7 ------- kernel/power/sbgemv_n_vsx.c | 30 +++++++++++++++--------------- 2 files changed, 15 insertions(+), 22 deletions(-) diff --git a/kernel/power/sbgemv_common.c b/kernel/power/sbgemv_common.c index c9438b7e6d..ad040b3711 100644 --- a/kernel/power/sbgemv_common.c +++ b/kernel/power/sbgemv_common.c @@ -111,13 +111,6 @@ FORCEINLINE vec_f32 vec_loadNHi_mult(vec_bf16 *in, vec_f32 v_inp0, BLASLONG n, v return (v_inp0 * v_in00); } -FORCEINLINE vec_f32 vec_loadNHi_mult2(vec_f32 v_x0, vec_bf16 *in, BLASLONG n, vec_bf16 zero) -{ - vec_f32 v_in00 = vec_loadNHi(in, n, zero); - - return (v_x0 * v_in00); -} - FORCEINLINE vec_f32 vec_loadNHi_vec(vec_bf16 *in, BLASLONG i, BLASLONG n, vec_bf16 zero) { return vec_loadNHi(&in[i], n, zero); diff --git a/kernel/power/sbgemv_n_vsx.c b/kernel/power/sbgemv_n_vsx.c index e8f6dca9fc..390a87359d 100644 --- a/kernel/power/sbgemv_n_vsx.c +++ b/kernel/power/sbgemv_n_vsx.c @@ -80,7 +80,7 @@ static void BF16GEMV_N_VSX_1(BLASLONG n, IFLOAT **ap, IFLOAT *xo, FLOAT *y, FLOA } else if (n) { vy0[0] = vec_loadN_f32(&v_y[(i * 2) + 0], n); - vy0[0] += vec_loadNHi_mult2(v_x0, &va0[i], n, zero); + vy0[0] += vec_loadNHi_mult(&va0[i], v_x0, n, zero); vec_storeN_f32(vy0[0], &v_y[(i * 2) + 0], n); } @@ -131,8 +131,8 @@ static void BF16GEMV_N_VSX_2(BLASLONG n, IFLOAT **ap, IFLOAT *xo, FLOAT *y, FLOA } else if (n) { vy0[0] = vec_loadN_f32(&v_y[(i * 2) + 0], n); - vy0[0] += vec_loadNHi_mult2(v_x0, &va0[i], n, zero); - vy0[0] += vec_loadNHi_mult2(v_x1, &va1[i], n, zero); + vy0[0] += vec_loadNHi_mult(&va0[i], v_x0, n, zero); + vy0[0] += vec_loadNHi_mult(&va1[i], v_x1, n, zero); vec_storeN_f32(vy0[0], &v_y[(i * 2) + 0], n); } @@ -193,10 +193,10 @@ static void BF16GEMV_N_VSX_4(BLASLONG n, IFLOAT **ap, IFLOAT *xo, FLOAT *y, FLOA } else if (n) { vy0[0] = vec_loadN_f32(&v_y[(i * 2) + 0], n); - vy0[0] += vec_loadNHi_mult2(v_x0, &va0[i], n, zero); - vy0[0] += vec_loadNHi_mult2(v_x1, &va1[i], n, zero); - vy0[0] += vec_loadNHi_mult2(v_x2, &va2[i], n, zero); - vy0[0] += vec_loadNHi_mult2(v_x3, &va3[i], n, zero); + vy0[0] += vec_loadNHi_mult(&va0[i], v_x0, n, zero); + vy0[0] += vec_loadNHi_mult(&va1[i], v_x1, n, zero); + vy0[0] += vec_loadNHi_mult(&va2[i], v_x2, n, zero); + vy0[0] += vec_loadNHi_mult(&va3[i], v_x3, n, zero); vec_storeN_f32(vy0[0], &v_y[(i * 2) + 0], n); } @@ -281,14 +281,14 @@ static void BF16GEMV_N_VSX_8(BLASLONG n, IFLOAT **ap, IFLOAT *xo, FLOAT *y, BLAS } else if (n) { vy0[0] = vec_loadN_f32(&v_y[(i * 2) + 0], n); - vy0[0] += vec_loadNHi_mult2(v_x0, &va0[i], n, zero); - vy0[0] += vec_loadNHi_mult2(v_x1, &va1[i], n, zero); - vy0[0] += vec_loadNHi_mult2(v_x2, &va2[i], n, zero); - vy0[0] += vec_loadNHi_mult2(v_x3, &va3[i], n, zero); - vy0[0] += vec_loadNHi_mult2(v_x4, &vb0[i], n, zero); - vy0[0] += vec_loadNHi_mult2(v_x5, &vb1[i], n, zero); - vy0[0] += vec_loadNHi_mult2(v_x6, &vb2[i], n, zero); - vy0[0] += vec_loadNHi_mult2(v_x7, &vb3[i], n, zero); + vy0[0] += vec_loadNHi_mult(&va0[i], v_x0, n, zero); + vy0[0] += vec_loadNHi_mult(&va1[i], v_x1, n, zero); + vy0[0] += vec_loadNHi_mult(&va2[i], v_x2, n, zero); + vy0[0] += vec_loadNHi_mult(&va3[i], v_x3, n, zero); + vy0[0] += vec_loadNHi_mult(&vb0[i], v_x4, n, zero); + vy0[0] += vec_loadNHi_mult(&vb1[i], v_x5, n, zero); + vy0[0] += vec_loadNHi_mult(&vb2[i], v_x6, n, zero); + vy0[0] += vec_loadNHi_mult(&vb3[i], v_x7, n, zero); vec_storeN_f32(vy0[0], &v_y[(i * 2) + 0], n); } From 7cc00f68c999750b8e5da8ffc6faf76cbe4deb58 Mon Sep 17 00:00:00 2001 From: Chip Kerchner Date: Tue, 1 Oct 2024 11:23:32 
-0500 Subject: [PATCH 18/24] Remove more duplicate. --- kernel/power/sbgemv_common.c | 5 ----- kernel/power/sbgemv_t_vsx.c | 8 ++++---- 2 files changed, 4 insertions(+), 9 deletions(-) diff --git a/kernel/power/sbgemv_common.c b/kernel/power/sbgemv_common.c index ad040b3711..156eadce75 100644 --- a/kernel/power/sbgemv_common.c +++ b/kernel/power/sbgemv_common.c @@ -111,11 +111,6 @@ FORCEINLINE vec_f32 vec_loadNHi_mult(vec_bf16 *in, vec_f32 v_inp0, BLASLONG n, v return (v_inp0 * v_in00); } -FORCEINLINE vec_f32 vec_loadNHi_vec(vec_bf16 *in, BLASLONG i, BLASLONG n, vec_bf16 zero) -{ - return vec_loadNHi(&in[i], n, zero); -} - FORCEINLINE void copy_x(BLASLONG n, IFLOAT *src, IFLOAT *dest, BLASLONG inc_src) { for (BLASLONG i = 0; i < n; i++) { diff --git a/kernel/power/sbgemv_t_vsx.c b/kernel/power/sbgemv_t_vsx.c index 272dccef76..9d5e6d9976 100644 --- a/kernel/power/sbgemv_t_vsx.c +++ b/kernel/power/sbgemv_t_vsx.c @@ -66,7 +66,7 @@ static void BF16GEMV_T_VSX_1(BLASLONG n, BLASLONG lda, IFLOAT *ap, IFLOAT *x, FL temp0 += vec_loadN_mult(&va0[i], inp, n, zero); } else if (n) { - inp[0] = vec_loadNHi_vec(v_x, i, n, zero); + inp[0] = vec_loadNHi(&v_x[i], n, zero); temp0 += vec_loadNHi_mult(&va0[i], inp[0], n, zero); } @@ -105,7 +105,7 @@ static void BF16GEMV_T_VSX_2(BLASLONG n, BLASLONG lda, IFLOAT *ap, IFLOAT *x, FL temp0 += vec_loadN_mult(&va0[i], inp, n, zero); temp1 += vec_loadN_mult(&va1[i], inp, n, zero); } else if (n) { - inp[0] = vec_loadNHi_vec(v_x, i, n, zero); + inp[0] = vec_loadNHi(&v_x[i], n, zero); temp0 += vec_loadNHi_mult(&va0[i], inp[0], n, zero); temp1 += vec_loadNHi_mult(&va1[i], inp[0], n, zero); @@ -156,7 +156,7 @@ static void BF16GEMV_T_VSX_4(BLASLONG n, BLASLONG lda, IFLOAT *ap, IFLOAT *x, FL temp2 += vec_loadN_mult(&va2[i], inp, n, zero); temp3 += vec_loadN_mult(&va3[i], inp, n, zero); } else if (n) { - inp[0] = vec_loadNHi_vec(v_x, i, n, zero); + inp[0] = vec_loadNHi(&v_x[i], n, zero); temp0 += vec_loadNHi_mult(&va0[i], inp[0], n, zero); temp1 += vec_loadNHi_mult(&va1[i], inp[0], n, zero); @@ -245,7 +245,7 @@ static void BF16GEMV_T_VSX_8(BLASLONG n, BLASLONG lda, IFLOAT *ap, IFLOAT *x, FL temp6 += vec_loadN_mult(&va6[i], inp, n, zero); temp7 += vec_loadN_mult(&va7[i], inp, n, zero); } else if (n) { - inp[0] = vec_loadNHi_vec(v_x, i, n, zero); + inp[0] = vec_loadNHi(&v_x[i], n, zero); temp0 += vec_loadNHi_mult(&va0[i], inp[0], n, zero); temp1 += vec_loadNHi_mult(&va1[i], inp[0], n, zero); From 7ec3c16d822b18499c7276d3e7fe16bdf186f0a3 Mon Sep 17 00:00:00 2001 From: Chip Kerchner Date: Thu, 3 Oct 2024 13:27:33 -0500 Subject: [PATCH 19/24] Remove beta from optimized functions. 
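With this change the transpose kernels no longer apply beta themselves: for inc_y == 1 the whole y vector is scaled once by BF16GEMV_N_beta (now shared in sbgemv_common.c) before the column loops, the strided path keeps applying beta in copy_y, and beta is reset to 1 after the first NB block so later blocks only accumulate. A minimal scalar sketch of the resulting structure follows (illustrative; inputs shown as float rather than bf16, NB blocking and the VSX/MMA kernels omitted):

    /* y is scaled by beta exactly once, then every column dot product is
       accumulated with alpha only. */
    static void sbgemv_t_ref(long m, long n, float alpha, const float *a, long lda,
                             const float *x, float beta, float *y)
    {
        for (long j = 0; j < n; j++)          /* pre-scale y once ...        */
            y[j] = (beta == 0.0f) ? 0.0f : beta * y[j];

        for (long j = 0; j < n; j++) {        /* ... then pure accumulation  */
            float sum = 0.0f;
            for (long i = 0; i < m; i++)
                sum += a[i + j * lda] * x[i];
            y[j] += alpha * sum;
        }
    }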
--- Makefile.system | 1 + kernel/power/sbgemv_common.c | 65 ++++++++++++++++++++++++++++++++- kernel/power/sbgemv_n.c | 59 ------------------------------ kernel/power/sbgemv_t.c | 28 ++++++++------ kernel/power/sbgemv_t_power10.c | 22 +++++------ kernel/power/sbgemv_t_vsx.c | 22 +++++------ 6 files changed, 101 insertions(+), 96 deletions(-) diff --git a/Makefile.system b/Makefile.system index 2c5ca96906..8c030842a4 100644 --- a/Makefile.system +++ b/Makefile.system @@ -282,6 +282,7 @@ GEMM_GEMV_FORWARD = 1 endif ifeq ($(ARCH), power) GEMM_GEMV_FORWARD = 1 +GEMM_GEMV_FORWARD_BF16 = 1 endif ifeq ($(SMALL_MATRIX_OPT), 1) diff --git a/kernel/power/sbgemv_common.c b/kernel/power/sbgemv_common.c index 156eadce75..47de837cc5 100644 --- a/kernel/power/sbgemv_common.c +++ b/kernel/power/sbgemv_common.c @@ -122,7 +122,10 @@ FORCEINLINE void copy_x(BLASLONG n, IFLOAT *src, IFLOAT *dest, BLASLONG inc_src) FORCEINLINE void copy_y_beta(BLASLONG n, FLOAT *src, FLOAT *dest, BLASLONG inc_src, FLOAT beta) { if (beta == 0) { - memset(dest, 0, sizeof(FLOAT) * n); + for (BLASLONG i = 0; i < n; i++) { + *dest++ = (FLOAT)0; + src += inc_src; + } } else if (beta == 1) { for (BLASLONG i = 0; i < n; i++) { *dest++ = *src; @@ -163,4 +166,64 @@ FORCEINLINE void move_y(BLASLONG n, FLOAT *src, FLOAT *dest, BLASLONG inc_dest) dest += inc_dest; } } + +static void BF16GEMV_N_beta(BLASLONG n, FLOAT *output_vector, FLOAT *input_vector, FLOAT beta) +{ + if (beta == 0) { + memset(output_vector, 0, sizeof(FLOAT) * n); + } else if (beta == 1) { + if (output_vector != input_vector) { + memcpy(output_vector, input_vector, sizeof(FLOAT) * n); + } + } else { + vec_f32 b = { beta, beta, beta, beta }; + + vec_f32 *in = (vec_f32 *)input_vector; + vec_f32 *out = (vec_f32 *)output_vector; + + BLASLONG n8 = n / 8; + BLASLONG i = 0; + vec_f32 v_inp0[2]; + + for (; i + 4 <= n8; i += 4) { + vec_f32 v_inp1[2], v_inp2[2], v_inp3[2]; + vec_load_pair(v_inp0, &in[(i * 2) + 0]); + vec_load_pair(v_inp1, &in[(i * 2) + 2]); + vec_load_pair(v_inp2, &in[(i * 2) + 4]); + vec_load_pair(v_inp3, &in[(i * 2) + 6]); + v_inp0[0] *= b; + v_inp0[1] *= b; + v_inp1[0] *= b; + v_inp1[1] *= b; + v_inp2[0] *= b; + v_inp2[1] *= b; + v_inp3[0] *= b; + v_inp3[1] *= b; + vec_store_pair(&out[(i * 2) + 0], v_inp0); + vec_store_pair(&out[(i * 2) + 2], v_inp1); + vec_store_pair(&out[(i * 2) + 4], v_inp2); + vec_store_pair(&out[(i * 2) + 6], v_inp3); + } + + for (; i < n8; i++) { + vec_load_pair(v_inp0, &in[(i * 2) + 0]); + v_inp0[0] *= b; + v_inp0[1] *= b; + vec_store_pair(&out[(i * 2) + 0], v_inp0); + } + + n &= 7; + if (n > 4) { + BLASLONG n3 = n & 3; + vec_loadN2_f32(v_inp0, &in[(i * 2) + 0], n3); + v_inp0[0] *= b; + v_inp0[1] *= b; + vec_storeN2_f32(v_inp0, &out[(i * 2) + 0], n3); + } else if (n) { + v_inp0[0] = vec_loadN_f32(&in[(i * 2) + 0], n); + v_inp0[0] *= b; + vec_storeN_f32(v_inp0[0], &out[(i * 2) + 0], n); + } + } +} #endif diff --git a/kernel/power/sbgemv_n.c b/kernel/power/sbgemv_n.c index eab0b4e33b..e6f7f587e6 100644 --- a/kernel/power/sbgemv_n.c +++ b/kernel/power/sbgemv_n.c @@ -27,65 +27,6 @@ USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. 
#ifndef SBGEMV_N_COMMON_C #define SBGEMV_N_COMMON_C -static void BF16GEMV_N_beta(BLASLONG n, FLOAT *output_vector, FLOAT *input_vector, FLOAT beta) -{ - if (beta == 0) { - memset(output_vector, 0, sizeof(FLOAT) * n); - } else if (beta == 1) { - if (output_vector != input_vector) { - memcpy(output_vector, input_vector, sizeof(FLOAT) * n); - } - } else { - vec_f32 b = { beta, beta, beta, beta }; - - vec_f32 *in = (vec_f32 *)input_vector; - vec_f32 *out = (vec_f32 *)output_vector; - - BLASLONG n8 = n / 8; - BLASLONG i = 0; - vec_f32 v_inp0[2]; - - for (; i + 4 <= n8; i += 4) { - vec_f32 v_inp1[2], v_inp2[2], v_inp3[2]; - vec_load_pair(v_inp0, &in[(i * 2) + 0]); - vec_load_pair(v_inp1, &in[(i * 2) + 2]); - vec_load_pair(v_inp2, &in[(i * 2) + 4]); - vec_load_pair(v_inp3, &in[(i * 2) + 6]); - v_inp0[0] *= b; - v_inp0[1] *= b; - v_inp1[0] *= b; - v_inp1[1] *= b; - v_inp2[0] *= b; - v_inp2[1] *= b; - v_inp3[0] *= b; - v_inp3[1] *= b; - vec_store_pair(&out[(i * 2) + 0], v_inp0); - vec_store_pair(&out[(i * 2) + 2], v_inp1); - vec_store_pair(&out[(i * 2) + 4], v_inp2); - vec_store_pair(&out[(i * 2) + 6], v_inp3); - } - - for (; i < n8; i++) { - vec_load_pair(v_inp0, &in[(i * 2) + 0]); - v_inp0[0] *= b; - v_inp0[1] *= b; - vec_store_pair(&out[(i * 2) + 0], v_inp0); - } - - n &= 7; - if (n > 4) { - BLASLONG n3 = n & 3; - vec_loadN2_f32(v_inp0, &in[(i * 2) + 0], n3); - v_inp0[0] *= b; - v_inp0[1] *= b; - vec_storeN2_f32(v_inp0, &out[(i * 2) + 0], n3); - } else if (n) { - v_inp0[0] = vec_loadN_f32(&in[(i * 2) + 0], n); - v_inp0[0] *= b; - vec_storeN_f32(v_inp0[0], &out[(i * 2) + 0], n); - } - } -} #if (defined(_ARCH_PWR10) && (defined(USE_BFGEMV_8_N_MMA) || (!defined(USE_BFGEMV_N_MMA) && defined(USE_BFGEMV_8_N_VSX)))) || (!defined(_ARCH_PWR10) && defined(USE_BFGEMV_8_N_VSX)) #define USE_N_8 diff --git a/kernel/power/sbgemv_t.c b/kernel/power/sbgemv_t.c index c6fdb6b1ae..594b1fc57b 100644 --- a/kernel/power/sbgemv_t.c +++ b/kernel/power/sbgemv_t.c @@ -41,6 +41,10 @@ int CNAME(BLASLONG m, BLASLONG n, FLOAT alpha, IFLOAT *a, BLASLONG lda, IFLOAT * if ((m < 1) || (n < 1)) return 0; + if (inc_y == 1) { + BF16GEMV_N_beta(n, y, y, beta); + } + xbuffer = buffer; BLASLONG lda4 = lda << 2; @@ -58,18 +62,21 @@ int CNAME(BLASLONG m, BLASLONG n, FLOAT alpha, IFLOAT *a, BLASLONG lda, IFLOAT * } a_ptr = a; + a += NB; y_ptr = y; if (inc_x != 1) { copy_x(NB, x, xbuffer, inc_x); + x += NB * inc_x; } else { xbuffer = x; + x += NB; } if (inc_y == 1) { #ifdef USE_T_8 for (BLASLONG j = 0; j + 8 <= n; j += 8) { - BF16GEMV_T_8(NB, lda, a_ptr, xbuffer, y_ptr, alpha, beta); + BF16GEMV_T_8(NB, lda, a_ptr, xbuffer, y_ptr, alpha); y_ptr += 8; a_ptr += lda8; } @@ -77,23 +84,23 @@ int CNAME(BLASLONG m, BLASLONG n, FLOAT alpha, IFLOAT *a, BLASLONG lda, IFLOAT * #else for (BLASLONG j = 0; j + 4 <= n; j += 4) { #endif - BF16GEMV_T_4(NB, lda, a_ptr, xbuffer, y_ptr, alpha, beta); + BF16GEMV_T_4(NB, lda, a_ptr, xbuffer, y_ptr, alpha); y_ptr += 4; a_ptr += lda4; } if (n & 2) { - BF16GEMV_T_2(NB, lda, a_ptr, xbuffer, y_ptr, alpha, beta); + BF16GEMV_T_2(NB, lda, a_ptr, xbuffer, y_ptr, alpha); y_ptr += 2; a_ptr += (lda * 2); } if (n & 1) { - BF16GEMV_T_1(NB, lda, a_ptr, xbuffer, y_ptr, alpha, beta); + BF16GEMV_T_1(NB, lda, a_ptr, xbuffer, y_ptr, alpha); } } else { #ifdef USE_T_8 for (BLASLONG j = 0; j + 8 <= n; j += 8) { memset(ybuffer, 0, sizeof(FLOAT) * 8); - BF16GEMV_T_8(NB, lda, a_ptr, xbuffer, ybuffer, alpha, beta); + BF16GEMV_T_8(NB, lda, a_ptr, xbuffer, ybuffer, alpha); copy_y(8, ybuffer, y_ptr, inc_y, beta); y_ptr += 8 * inc_y; a_ptr += 
lda8; @@ -103,28 +110,25 @@ int CNAME(BLASLONG m, BLASLONG n, FLOAT alpha, IFLOAT *a, BLASLONG lda, IFLOAT * for (BLASLONG j = 0; j + 4 <= n; j += 4) { #endif memset(ybuffer, 0, sizeof(FLOAT) * 4); - BF16GEMV_T_4(NB, lda, a_ptr, xbuffer, ybuffer, alpha, beta); + BF16GEMV_T_4(NB, lda, a_ptr, xbuffer, ybuffer, alpha); copy_y(4, ybuffer, y_ptr, inc_y, beta); y_ptr += 4 * inc_y; a_ptr += lda4; } if (n & 2) { memset(ybuffer, 0, sizeof(FLOAT) * 4); - BF16GEMV_T_2(NB, lda, a_ptr, xbuffer, ybuffer, alpha, beta); + BF16GEMV_T_2(NB, lda, a_ptr, xbuffer, ybuffer, alpha); copy_y(2, ybuffer, y_ptr, inc_y, beta); y_ptr += 2 * inc_y; a_ptr += (lda * 2); } if (n & 1) { memset(ybuffer, 0, sizeof(FLOAT) * 4); - BF16GEMV_T_1(NB, lda, a_ptr, xbuffer, ybuffer, alpha, beta); + BF16GEMV_T_1(NB, lda, a_ptr, xbuffer, ybuffer, alpha); copy_y(1, ybuffer, y_ptr, inc_y, beta); } + beta = (FLOAT)1; } - - a += NB; - x += NB * inc_x; - beta = (FLOAT)1; } return 0; diff --git a/kernel/power/sbgemv_t_power10.c b/kernel/power/sbgemv_t_power10.c index d2f6087f05..40c166354b 100644 --- a/kernel/power/sbgemv_t_power10.c +++ b/kernel/power/sbgemv_t_power10.c @@ -43,7 +43,7 @@ USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. #define USE_BFGEMV_8_T_MMA -static void BF16GEMV_T_MMA_1(BLASLONG n, BLASLONG lda, IFLOAT *ap, IFLOAT *x, FLOAT *y, FLOAT alpha, FLOAT beta) +static void BF16GEMV_T_MMA_1(BLASLONG n, BLASLONG lda, IFLOAT *ap, IFLOAT *x, FLOAT *y, FLOAT alpha) { IFLOAT *a0; vec_bf16 *va0, *v_x; @@ -90,10 +90,10 @@ static void BF16GEMV_T_MMA_1(BLASLONG n, BLASLONG lda, IFLOAT *ap, IFLOAT *x, FL __builtin_mma_disassemble_acc((void*)temp00, &temp0); - y[0] = (alpha * (temp00[0][0] + temp00[1][1] + temp00[2][2] + temp00[3][3])) + (beta * y[0]); + y[0] += (alpha * (temp00[0][0] + temp00[1][1] + temp00[2][2] + temp00[3][3])); } -static void BF16GEMV_T_MMA_2(BLASLONG n, BLASLONG lda, IFLOAT *ap, IFLOAT *x, FLOAT *y, FLOAT alpha, FLOAT beta) +static void BF16GEMV_T_MMA_2(BLASLONG n, BLASLONG lda, IFLOAT *ap, IFLOAT *x, FLOAT *y, FLOAT alpha) { IFLOAT *a0, *a1; vec_bf16 *va0, *va1, *v_x; @@ -142,11 +142,11 @@ static void BF16GEMV_T_MMA_2(BLASLONG n, BLASLONG lda, IFLOAT *ap, IFLOAT *x, FL vec_reduce_2(temp00, &temp0[0]); - y[0] = (alpha * (temp00[0][0] + temp00[1][1] + temp00[2][2] + temp00[3][3])) + (beta * y[0]); - y[1] = (alpha * (temp00[4][0] + temp00[5][1] + temp00[6][2] + temp00[7][3])) + (beta * y[1]); + y[0] += (alpha * (temp00[0][0] + temp00[1][1] + temp00[2][2] + temp00[3][3])); + y[1] += (alpha * (temp00[4][0] + temp00[5][1] + temp00[6][2] + temp00[7][3])); } -static void BF16GEMV_T_MMA_4(BLASLONG n, BLASLONG lda, IFLOAT *ap, IFLOAT *x, FLOAT *y, FLOAT alpha, FLOAT beta) +static void BF16GEMV_T_MMA_4(BLASLONG n, BLASLONG lda, IFLOAT *ap, IFLOAT *x, FLOAT *y, FLOAT alpha) { IFLOAT *a0, *a1, *a2, *a3; vec_bf16 *va0, *va1, *va2, *va3, *v_x; @@ -201,7 +201,6 @@ static void BF16GEMV_T_MMA_4(BLASLONG n, BLASLONG lda, IFLOAT *ap, IFLOAT *x, FL vec_f32 t0, t1, t2, t3, t4, t5, t6, t7; vec_f32 a = { alpha, alpha, alpha, alpha }; - vec_f32 b = { beta, beta, beta, beta }; vec_f32 *v_y = (vec_f32 *) y; t0 = vec_mergeh(temp00[ 0], temp00[ 4]); @@ -219,11 +218,11 @@ static void BF16GEMV_T_MMA_4(BLASLONG n, BLASLONG lda, IFLOAT *ap, IFLOAT *x, FL t0 += t2 + t4 + t6; - v_y[0] = (a * t0) + (b * v_y[0]); + v_y[0] += (a * t0); } #ifdef USE_BFGEMV_8_T_MMA -static void BF16GEMV_T_MMA_8(BLASLONG n, BLASLONG lda, IFLOAT *ap, IFLOAT *x, FLOAT *y, FLOAT alpha, FLOAT beta) +static void BF16GEMV_T_MMA_8(BLASLONG n, BLASLONG 
lda, IFLOAT *ap, IFLOAT *x, FLOAT *y, FLOAT alpha) { IFLOAT *a0, *a1, *a2, *a3, *a4, *a5, *a6, *a7; vec_bf16 *va0, *va1, *va2, *va3, *va4, *va5, *va6, *va7, *v_x; @@ -291,7 +290,6 @@ static void BF16GEMV_T_MMA_8(BLASLONG n, BLASLONG lda, IFLOAT *ap, IFLOAT *x, FL vec_f32 t0, t1, t2, t3, t4, t5, t6, t7, t10, t11, t12, t13, t14, t15, t16, t17; vec_f32 a = { alpha, alpha, alpha, alpha }; - vec_f32 b = { beta, beta, beta, beta }; vec_f32 *v_y = (vec_f32 *) y; t0 = vec_mergeh(temp00[ 0], temp00[ 4]); @@ -326,8 +324,8 @@ static void BF16GEMV_T_MMA_8(BLASLONG n, BLASLONG lda, IFLOAT *ap, IFLOAT *x, FL vec_f32 inp2[2]; vec_load_pair(inp2, v_y); - inp2[0] = (a * t0) + (b * inp2[0]); - inp2[1] = (a * t10) + (b * inp2[1]); + inp2[0] += (a * t0); + inp2[1] += (a * t10); vec_store_pair(v_y, inp2); } #endif diff --git a/kernel/power/sbgemv_t_vsx.c b/kernel/power/sbgemv_t_vsx.c index 9d5e6d9976..e72d2f31e0 100644 --- a/kernel/power/sbgemv_t_vsx.c +++ b/kernel/power/sbgemv_t_vsx.c @@ -40,7 +40,7 @@ USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. #define USE_BFGEMV_8_T_VSX -static void BF16GEMV_T_VSX_1(BLASLONG n, BLASLONG lda, IFLOAT *ap, IFLOAT *x, FLOAT *y, FLOAT alpha, FLOAT beta) +static void BF16GEMV_T_VSX_1(BLASLONG n, BLASLONG lda, IFLOAT *ap, IFLOAT *x, FLOAT *y, FLOAT alpha) { IFLOAT *a0; vec_bf16 *va0, *v_x; @@ -71,10 +71,10 @@ static void BF16GEMV_T_VSX_1(BLASLONG n, BLASLONG lda, IFLOAT *ap, IFLOAT *x, FL temp0 += vec_loadNHi_mult(&va0[i], inp[0], n, zero); } - y[0] = (alpha * (temp0[0] + temp0[1] + temp0[2] + temp0[3])) + (beta * y[0]); + y[0] += (alpha * (temp0[0] + temp0[1] + temp0[2] + temp0[3])); } -static void BF16GEMV_T_VSX_2(BLASLONG n, BLASLONG lda, IFLOAT *ap, IFLOAT *x, FLOAT *y, FLOAT alpha, FLOAT beta) +static void BF16GEMV_T_VSX_2(BLASLONG n, BLASLONG lda, IFLOAT *ap, IFLOAT *x, FLOAT *y, FLOAT alpha) { IFLOAT *a0, *a1; vec_bf16 *va0, *va1, *v_x; @@ -111,11 +111,11 @@ static void BF16GEMV_T_VSX_2(BLASLONG n, BLASLONG lda, IFLOAT *ap, IFLOAT *x, FL temp1 += vec_loadNHi_mult(&va1[i], inp[0], n, zero); } - y[0] = (alpha * (temp0[0] + temp0[1] + temp0[2] + temp0[3])) + (beta * y[0]); - y[1] = (alpha * (temp1[0] + temp1[1] + temp1[2] + temp1[3])) + (beta * y[1]); + y[0] += (alpha * (temp0[0] + temp0[1] + temp0[2] + temp0[3])); + y[1] += (alpha * (temp1[0] + temp1[1] + temp1[2] + temp1[3])); } -static void BF16GEMV_T_VSX_4(BLASLONG n, BLASLONG lda, IFLOAT *ap, IFLOAT *x, FLOAT *y, FLOAT alpha, FLOAT beta) +static void BF16GEMV_T_VSX_4(BLASLONG n, BLASLONG lda, IFLOAT *ap, IFLOAT *x, FLOAT *y, FLOAT alpha) { IFLOAT *a0, *a1, *a2, *a3; vec_bf16 *va0, *va1, *va2, *va3, *v_x; @@ -166,7 +166,6 @@ static void BF16GEMV_T_VSX_4(BLASLONG n, BLASLONG lda, IFLOAT *ap, IFLOAT *x, FL vec_f32 t0, t1, t2, t3; vec_f32 a = { alpha, alpha, alpha, alpha }; - vec_f32 b = { beta, beta, beta, beta }; vec_f32 *v_y = (vec_f32 *) y; t0 = vec_mergeh(temp0, temp2); @@ -179,11 +178,11 @@ static void BF16GEMV_T_VSX_4(BLASLONG n, BLASLONG lda, IFLOAT *ap, IFLOAT *x, FL temp3 = vec_mergel(t1, t3); temp0 += temp1 + temp2 + temp3; - v_y[0] = (a * temp0) + (b * v_y[0]); + v_y[0] += (a * temp0); } #ifdef USE_BFGEMV_8_T_VSX -static void BF16GEMV_T_VSX_8(BLASLONG n, BLASLONG lda, IFLOAT *ap, IFLOAT *x, FLOAT *y, FLOAT alpha, FLOAT beta) +static void BF16GEMV_T_VSX_8(BLASLONG n, BLASLONG lda, IFLOAT *ap, IFLOAT *x, FLOAT *y, FLOAT alpha) { IFLOAT *a0, *a1, *a2, *a3, *a4, *a5, *a6, *a7; vec_bf16 *va0, *va1, *va2, *va3, *va4, *va5, *va6, *va7, *v_x; @@ -259,7 +258,6 @@ static void 
BF16GEMV_T_VSX_8(BLASLONG n, BLASLONG lda, IFLOAT *ap, IFLOAT *x, FL vec_f32 t0, t1, t2, t3, t10, t11, t12, t13; vec_f32 a = { alpha, alpha, alpha, alpha }; - vec_f32 b = { beta, beta, beta, beta }; vec_f32 *v_y = (vec_f32 *) y; t0 = vec_mergeh(temp0, temp2); @@ -283,8 +281,8 @@ static void BF16GEMV_T_VSX_8(BLASLONG n, BLASLONG lda, IFLOAT *ap, IFLOAT *x, FL temp4 += temp5 + temp6 + temp7; vec_load_pair(inp, v_y); - inp[0] = (a * temp0) + (b * inp[0]); - inp[1] = (a * temp4) + (b * inp[1]); + inp[0] += (a * temp0); + inp[1] += (a * temp4); vec_store_pair(v_y, inp); } #endif From 915a6d6e44b838e7618e56021bf8dee6163b6ff0 Mon Sep 17 00:00:00 2001 From: Chip Kerchner Date: Thu, 3 Oct 2024 14:08:21 -0500 Subject: [PATCH 20/24] Add casting. --- kernel/power/sbgemv_common.c | 12 ++++++------ 1 file changed, 6 insertions(+), 6 deletions(-) diff --git a/kernel/power/sbgemv_common.c b/kernel/power/sbgemv_common.c index 47de837cc5..8ad7f92e73 100644 --- a/kernel/power/sbgemv_common.c +++ b/kernel/power/sbgemv_common.c @@ -121,12 +121,12 @@ FORCEINLINE void copy_x(BLASLONG n, IFLOAT *src, IFLOAT *dest, BLASLONG inc_src) FORCEINLINE void copy_y_beta(BLASLONG n, FLOAT *src, FLOAT *dest, BLASLONG inc_src, FLOAT beta) { - if (beta == 0) { + if (beta == (FLOAT)0) { for (BLASLONG i = 0; i < n; i++) { *dest++ = (FLOAT)0; src += inc_src; } - } else if (beta == 1) { + } else if (beta == (FLOAT)1) { for (BLASLONG i = 0; i < n; i++) { *dest++ = *src; src += inc_src; @@ -141,12 +141,12 @@ FORCEINLINE void copy_y_beta(BLASLONG n, FLOAT *src, FLOAT *dest, BLASLONG inc_s FORCEINLINE void copy_y(BLASLONG n, FLOAT *src, FLOAT *dest, BLASLONG inc_src, FLOAT beta) { - if (beta == 0) { + if (beta == (FLOAT)0) { for (BLASLONG i = 0; i < n; i++) { *dest = *src++; dest += inc_src; } - } else if (beta == 1) { + } else if (beta == (FLOAT)1) { for (BLASLONG i = 0; i < n; i++) { *dest += *src++; dest += inc_src; @@ -169,9 +169,9 @@ FORCEINLINE void move_y(BLASLONG n, FLOAT *src, FLOAT *dest, BLASLONG inc_dest) static void BF16GEMV_N_beta(BLASLONG n, FLOAT *output_vector, FLOAT *input_vector, FLOAT beta) { - if (beta == 0) { + if (beta == (FLOAT)0) { memset(output_vector, 0, sizeof(FLOAT) * n); - } else if (beta == 1) { + } else if (beta == (FLOAT)1) { if (output_vector != input_vector) { memcpy(output_vector, input_vector, sizeof(FLOAT) * n); } From d6bb8dcfd1139037ec7538a64b9e143a05216740 Mon Sep 17 00:00:00 2001 From: Chip Kerchner Date: Sun, 6 Oct 2024 14:13:43 -0500 Subject: [PATCH 21/24] Common code. 
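copy_y_beta can go back to a plain memset for beta == 0 (its destination is the contiguous work buffer, so the source stride is irrelevant when the result is zero), and the beta == 0 branch of copy_y is exactly move_y, so move_y is hoisted above it and reused. For reference, a scalar sketch of the merge copy_y performs when writing the work buffer back into a strided y (illustrative; merge_y_ref is a made-up wrapper name):

    /* dest is strided by inc_dest, src is the contiguous ybuffer. */
    static void merge_y_ref(long n, const float *src, float *dest,
                            long inc_dest, float beta)
    {
        if (beta == 0.0f) {                 /* same as move_y: overwrite     */
            for (long i = 0; i < n; i++, dest += inc_dest)
                *dest = src[i];
        } else if (beta == 1.0f) {          /* accumulate into y             */
            for (long i = 0; i < n; i++, dest += inc_dest)
                *dest += src[i];
        } else {                            /* general axpby-style merge     */
            for (long i = 0; i < n; i++, dest += inc_dest)
                *dest = src[i] + beta * *dest;
        }
    }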
--- kernel/power/sbgemv_common.c | 26 ++++++++++---------------- 1 file changed, 10 insertions(+), 16 deletions(-) diff --git a/kernel/power/sbgemv_common.c b/kernel/power/sbgemv_common.c index 8ad7f92e73..830481fef3 100644 --- a/kernel/power/sbgemv_common.c +++ b/kernel/power/sbgemv_common.c @@ -122,10 +122,7 @@ FORCEINLINE void copy_x(BLASLONG n, IFLOAT *src, IFLOAT *dest, BLASLONG inc_src) FORCEINLINE void copy_y_beta(BLASLONG n, FLOAT *src, FLOAT *dest, BLASLONG inc_src, FLOAT beta) { if (beta == (FLOAT)0) { - for (BLASLONG i = 0; i < n; i++) { - *dest++ = (FLOAT)0; - src += inc_src; - } + memset(dest, 0, n * sizeof(FLOAT)); } else if (beta == (FLOAT)1) { for (BLASLONG i = 0; i < n; i++) { *dest++ = *src; @@ -139,13 +136,18 @@ FORCEINLINE void copy_y_beta(BLASLONG n, FLOAT *src, FLOAT *dest, BLASLONG inc_s } } +FORCEINLINE void move_y(BLASLONG n, FLOAT *src, FLOAT *dest, BLASLONG inc_dest) +{ + for (BLASLONG i = 0; i < n; i++) { + *dest = *src++; + dest += inc_dest; + } +} + FORCEINLINE void copy_y(BLASLONG n, FLOAT *src, FLOAT *dest, BLASLONG inc_src, FLOAT beta) { if (beta == (FLOAT)0) { - for (BLASLONG i = 0; i < n; i++) { - *dest = *src++; - dest += inc_src; - } + move_y(n, src, dest, inc_src); } else if (beta == (FLOAT)1) { for (BLASLONG i = 0; i < n; i++) { *dest += *src++; @@ -159,14 +161,6 @@ FORCEINLINE void copy_y(BLASLONG n, FLOAT *src, FLOAT *dest, BLASLONG inc_src, F } } -FORCEINLINE void move_y(BLASLONG n, FLOAT *src, FLOAT *dest, BLASLONG inc_dest) -{ - for (BLASLONG i = 0; i < n; i++) { - *dest = *src++; - dest += inc_dest; - } -} - static void BF16GEMV_N_beta(BLASLONG n, FLOAT *output_vector, FLOAT *input_vector, FLOAT beta) { if (beta == (FLOAT)0) { From f8e113f27b3a10911fe6b382148aeb846d4ade08 Mon Sep 17 00:00:00 2001 From: Chip Kerchner Date: Sun, 13 Oct 2024 10:55:03 -0500 Subject: [PATCH 22/24] Replace types with include file. --- kernel/power/gemm_common.c | 7 +------ 1 file changed, 1 insertion(+), 6 deletions(-) diff --git a/kernel/power/gemm_common.c b/kernel/power/gemm_common.c index ed00de95b0..88fa52de53 100644 --- a/kernel/power/gemm_common.c +++ b/kernel/power/gemm_common.c @@ -3,17 +3,12 @@ #include "common.h" #include +#include #define NBMAX 4096 #define FORCEINLINE inline __attribute__((always_inline)) -#ifdef __clang__ -#define uint16_t unsigned short -#define uint32_t unsigned int -#define uint64_t unsigned long long -#endif - #ifdef _ARCH_PWR10 #ifdef __has_builtin #if !__has_builtin(__builtin_vsx_assemble_pair) From 36bd3eeddfe2b21353789da39e67bc9523e22d5a Mon Sep 17 00:00:00 2001 From: Chip Kerchner Date: Sun, 13 Oct 2024 13:46:11 -0500 Subject: [PATCH 23/24] Vectorize BF16 GEMV (VSX & MMA). Use GEMM_GEMV_FORWARD_BF16 (for Power). 
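This patch carries the complete set of new kernel files (gemm_common.c, sbgemv_common*.c, sbgemv_n*.c, sbgemv_t*.c) together with the build and interface hooks. The user-visible part is the GEMM_GEMV_FORWARD_BF16 guard in interface/gemm.c: with it defined (the default on Power), a BFLOAT16 GEMM whose result is a single column or row is forwarded to the new SBGEMV kernels. A conceptual sketch of the forwarding test is given below; it is not the literal interface code (whose body is elided in this hunk), and the m == 1 case is assumed to mirror the existing GEMM_GEMV_FORWARD handling for the other types.

    /* Conceptual sketch only. */
    static int can_forward_sbgemm_to_sbgemv(long m, long n, long k)
    {
        if (k == 0)
            return 0;                 /* nothing to multiply, keep the GEMM path */
        return (n == 1) || (m == 1);  /* result C is a single column or row      */
    }

When the forward happens, the call is served by sbgemv_n / sbgemv_t, which the KERNEL files route to the VSX variants on POWER8/POWER9 and the MMA variants on POWER10.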
--- Makefile.system | 6 +- cmake/system.cmake | 3 + interface/gemm.c | 2 +- kernel/power/KERNEL.POWER10 | 2 + kernel/power/KERNEL.POWER8 | 2 + kernel/power/KERNEL.POWER9 | 2 + kernel/power/gemm_common.c | 153 +++++++ kernel/power/sbgemv_common.c | 223 ++++++++++ kernel/power/sbgemv_common_power10.c | 629 +++++++++++++++++++++++++++ kernel/power/sbgemv_n.c | 152 +++++++ kernel/power/sbgemv_n_power10.c | 474 ++++++++++++++++++++ kernel/power/sbgemv_n_vsx.c | 299 +++++++++++++ kernel/power/sbgemv_t.c | 137 ++++++ kernel/power/sbgemv_t_power10.c | 338 ++++++++++++++ kernel/power/sbgemv_t_vsx.c | 292 +++++++++++++ test/compare_sgemm_sbgemm.c | 30 +- 16 files changed, 2728 insertions(+), 16 deletions(-) create mode 100644 kernel/power/gemm_common.c create mode 100644 kernel/power/sbgemv_common.c create mode 100644 kernel/power/sbgemv_common_power10.c create mode 100644 kernel/power/sbgemv_n.c create mode 100644 kernel/power/sbgemv_n_power10.c create mode 100644 kernel/power/sbgemv_n_vsx.c create mode 100644 kernel/power/sbgemv_t.c create mode 100644 kernel/power/sbgemv_t_power10.c create mode 100644 kernel/power/sbgemv_t_vsx.c diff --git a/Makefile.system b/Makefile.system index 7bae728552..8351b8efb2 100644 --- a/Makefile.system +++ b/Makefile.system @@ -282,15 +282,19 @@ GEMM_GEMV_FORWARD = 1 endif ifeq ($(ARCH), power) GEMM_GEMV_FORWARD = 1 +GEMM_GEMV_FORWARD_BF16 = 1 endif ifeq ($(SMALL_MATRIX_OPT), 1) CCOMMON_OPT += -DSMALL_MATRIX_OPT endif -ifeq ($(GEMM_GEMV_FORWARD), 1) ifneq ($(ONLY_CBLAS), 1) +ifeq ($(GEMM_GEMV_FORWARD), 1) CCOMMON_OPT += -DGEMM_GEMV_FORWARD endif +ifeq ($(GEMM_GEMV_FORWARD_BF16), 1) +CCOMMON_OPT += -DGEMM_GEMV_FORWARD_BF16 +endif endif # This operation is expensive, so execution should be once. diff --git a/cmake/system.cmake b/cmake/system.cmake index d49d53449a..6b891ca0ef 100644 --- a/cmake/system.cmake +++ b/cmake/system.cmake @@ -398,6 +398,9 @@ endif () if (GEMM_GEMV_FORWARD AND NOT ONLY_CBLAS) set(CCOMMON_OPT "${CCOMMON_OPT} -DGEMM_GEMV_FORWARD") endif () +if (GEMM_GEMV_FORWARD_BF16 AND NOT ONLY_CBLAS) + set(CCOMMON_OPT "${CCOMMON_OPT} -DGEMM_GEMV_FORWARD_BF16") +endif () if (SMALL_MATRIX_OPT) set(CCOMMON_OPT "${CCOMMON_OPT} -DSMALL_MATRIX_OPT") endif () diff --git a/interface/gemm.c b/interface/gemm.c index c030947b6f..5742d36c4b 100644 --- a/interface/gemm.c +++ b/interface/gemm.c @@ -498,7 +498,7 @@ void CNAME(enum CBLAS_ORDER order, enum CBLAS_TRANSPOSE TransA, enum CBLAS_TRANS args.m, args.n, args.k, args.lda, args.ldb, args.ldc); #endif -#if defined(GEMM_GEMV_FORWARD) && !defined(GEMM3M) && !defined(COMPLEX) && !defined(BFLOAT16) +#if defined(GEMM_GEMV_FORWARD) && !defined(GEMM3M) && !defined(COMPLEX) && (!defined(BFLOAT16) || defined(GEMM_GEMV_FORWARD_BF16)) // Check if we can convert GEMM -> GEMV if (args.k != 0) { if (args.n == 1) { diff --git a/kernel/power/KERNEL.POWER10 b/kernel/power/KERNEL.POWER10 index 4d17944ae7..c009e33cf4 100644 --- a/kernel/power/KERNEL.POWER10 +++ b/kernel/power/KERNEL.POWER10 @@ -236,11 +236,13 @@ ZSWAPKERNEL = zswap.c # SGEMVNKERNEL = sgemv_n.c +SBGEMVNKERNEL = sbgemv_n_power10.c DGEMVNKERNEL = dgemv_n_power10.c CGEMVNKERNEL = cgemv_n.c ZGEMVNKERNEL = zgemv_n_power10.c # SGEMVTKERNEL = sgemv_t.c +SBGEMVTKERNEL = sbgemv_t_power10.c DGEMVTKERNEL = dgemv_t_power10.c CGEMVTKERNEL = cgemv_t.c ZGEMVTKERNEL = zgemv_t_4.c diff --git a/kernel/power/KERNEL.POWER8 b/kernel/power/KERNEL.POWER8 index 700a68e447..001401d532 100644 --- a/kernel/power/KERNEL.POWER8 +++ b/kernel/power/KERNEL.POWER8 @@ -257,11 +257,13 @@ ZSWAPKERNEL = 
zswap.c # SGEMVNKERNEL = sgemv_n.c +SBGEMVNKERNEL = sbgemv_n_vsx.c DGEMVNKERNEL = dgemv_n.c CGEMVNKERNEL = cgemv_n.c ZGEMVNKERNEL = zgemv_n_4.c # SGEMVTKERNEL = sgemv_t.c +SBGEMVTKERNEL = sbgemv_t_vsx.c DGEMVTKERNEL = dgemv_t.c CGEMVTKERNEL = cgemv_t.c ZGEMVTKERNEL = zgemv_t_4.c diff --git a/kernel/power/KERNEL.POWER9 b/kernel/power/KERNEL.POWER9 index 7d007d1a2b..a18c31a2e9 100644 --- a/kernel/power/KERNEL.POWER9 +++ b/kernel/power/KERNEL.POWER9 @@ -181,11 +181,13 @@ ZSWAPKERNEL = zswap.c # SGEMVNKERNEL = sgemv_n.c +SBGEMVNKERNEL = sbgemv_n_vsx.c DGEMVNKERNEL = dgemv_n.c CGEMVNKERNEL = cgemv_n.c ZGEMVNKERNEL = zgemv_n_4.c # SGEMVTKERNEL = sgemv_t.c +SBGEMVTKERNEL = sbgemv_t_vsx.c DGEMVTKERNEL = dgemv_t.c CGEMVTKERNEL = cgemv_t.c ZGEMVTKERNEL = zgemv_t_4.c diff --git a/kernel/power/gemm_common.c b/kernel/power/gemm_common.c new file mode 100644 index 0000000000..88fa52de53 --- /dev/null +++ b/kernel/power/gemm_common.c @@ -0,0 +1,153 @@ +#ifndef GEMM_COMMON_C +#define GEMM_COMMON_C +#include "common.h" + +#include +#include + +#define NBMAX 4096 + +#define FORCEINLINE inline __attribute__((always_inline)) + +#ifdef _ARCH_PWR10 +#ifdef __has_builtin +#if !__has_builtin(__builtin_vsx_assemble_pair) +#define __builtin_vsx_assemble_pair __builtin_mma_assemble_pair +#endif +#if !__has_builtin(__builtin_vsx_disassemble_pair) +#define __builtin_vsx_disassemble_pair __builtin_mma_disassemble_pair +#endif +#endif + +#if __BYTE_ORDER__ == __ORDER_BIG_ENDIAN__ +#define __builtin_vsx_assemble_pair2(vp0, v0, v1) __builtin_vsx_assemble_pair(vp0, v1, v0) +#else +#define __builtin_vsx_assemble_pair2(vp0, v0, v1) __builtin_vsx_assemble_pair(vp0, v0, v1) +#endif + +#define USE_VECTOR_PAIRS +#endif + +typedef __vector IFLOAT vec_bf16; +typedef __vector FLOAT vec_f32; +typedef __vector unsigned char vec_uc8; + +FORCEINLINE vec_uc8 vec_load_vec(void *src) +{ + return vec_xl(0, (unsigned char *)(src)); +} + +FORCEINLINE void vec_load_pair(vec_f32 *dst, vec_f32 *src) +{ +#ifdef USE_VECTOR_PAIRS + __vector_pair vy0p; +#ifdef __clang__ + vy0p = __builtin_vsx_lxvp(0L, (const __vector_pair *)(src)); +#else + vy0p = *(__vector_pair *)(src); +#endif + __builtin_vsx_disassemble_pair((void *)(dst), &vy0p); +#else + dst[0] = src[0]; + dst[1] = src[1]; +#endif +} + +FORCEINLINE void vec_store_pair(vec_f32 *dst, vec_f32 *src) +{ +#ifdef USE_VECTOR_PAIRS + __vector_pair vy0p; + __builtin_vsx_assemble_pair2(&vy0p, (vec_uc8)src[1], (vec_uc8)src[0]); +#ifdef __clang__ + __builtin_vsx_stxvp(vy0p, 0L, (__vector_pair *)(dst)); +#else + *(__vector_pair *)(dst) = vy0p; +#endif +#else + dst[0] = src[0]; + dst[1] = src[1]; +#endif +} + +FORCEINLINE vec_bf16 vec_loadN(void *src, BLASLONG n) +{ + IFLOAT *src2 = (IFLOAT *)(src); +#ifdef _ARCH_PWR9 + return vec_xl_len(src2, n * sizeof(IFLOAT)); +#else + __attribute__((aligned(16))) IFLOAT data[sizeof(vec_bf16) / sizeof(IFLOAT)]; + memset(data, 0, sizeof(vec_bf16)); + if (n & 4) { + memcpy(data, src2, sizeof(uint64_t)); + } + if (n & 2) { + BLASLONG n4 = n & 4; + memcpy(data + n4, src2 + n4, sizeof(uint32_t)); + } + if (n & 1) { + BLASLONG n6 = n & 6; + data[n6] = src2[n6]; + } + return (vec_bf16)vec_load_vec(data); +#endif +} + +FORCEINLINE vec_f32 vec_loadN_f32(void *src, BLASLONG n) +{ +#ifndef _ARCH_PWR9 + if (n & 4) { + return (vec_f32)vec_load_vec(src); + } +#endif + return (vec_f32)vec_loadN(src, n * (sizeof(FLOAT) / sizeof(IFLOAT))); +} + +FORCEINLINE void vec_loadN2_f32(vec_f32 *data, vec_f32 *src, BLASLONG n) +{ + data[0] = src[0]; + data[1] = vec_loadN_f32(&src[1], n); +} + 
+FORCEINLINE void vec_storeN(vec_bf16 data, void *dst, BLASLONG n) +{ + IFLOAT *dst2 = (IFLOAT *)(dst); +#ifdef _ARCH_PWR9 + vec_xst_len(data, dst2, n * sizeof(IFLOAT)); +#else + if (n & 8) { + vec_xst(data, 0, dst2); + return; + } + __attribute__((aligned(16))) IFLOAT data2[sizeof(vec_f32) / sizeof(IFLOAT)]; + vec_xst(data, 0, data2); + if (n & 4) { + memcpy(dst2, data2, sizeof(uint64_t)); + } + if (n & 2) { + BLASLONG n4 = n & 4; + memcpy(dst2 + n4, data2 + n4, sizeof(uint32_t)); + } + if (n & 1) { + BLASLONG n6 = n & 6; + dst2[n6] = data2[n6]; + } +#endif +} + +FORCEINLINE void vec_storeN_f32(vec_f32 data, void *dst, BLASLONG n) +{ +#ifndef _ARCH_PWR9 + if (n & 4) { + vec_xst(data, 0, (FLOAT *)dst); + return; + } +#endif + return vec_storeN((vec_bf16)data, dst, n * (sizeof(FLOAT) / sizeof(IFLOAT))); +} + +FORCEINLINE void vec_storeN2_f32(vec_f32 *data, vec_f32 *dst, BLASLONG n) +{ + dst[0] = data[0]; + vec_storeN_f32(data[1], &dst[1], n); +} +#endif diff --git a/kernel/power/sbgemv_common.c b/kernel/power/sbgemv_common.c new file mode 100644 index 0000000000..830481fef3 --- /dev/null +++ b/kernel/power/sbgemv_common.c @@ -0,0 +1,223 @@ +/*************************************************************************** +Copyright (c) 2024, The OpenBLAS Project +All rights reserved. +Redistribution and use in source and binary forms, with or without +modification, are permitted provided that the following conditions are +met: +1. Redistributions of source code must retain the above copyright +notice, this list of conditions and the following disclaimer. +2. Redistributions in binary form must reproduce the above copyright +notice, this list of conditions and the following disclaimer in +the documentation and/or other materials provided with the +distribution. +3. Neither the name of the OpenBLAS project nor the names of +its contributors may be used to endorse or promote products +derived from this software without specific prior written permission. +THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" +AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE +IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE +ARE DISCLAIMED. IN NO EVENT SHALL THE OPENBLAS PROJECT OR CONTRIBUTORS BE +LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL +DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR +SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER +CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, +OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE +USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. 
+*****************************************************************************/ + +#ifndef SBGEMV_COMMON_C +#define SBGEMV_COMMON_C +#include "gemm_common.c" + +#if __BYTE_ORDER__ == __ORDER_BIG_ENDIAN__ +#define BF16_HI(data, zero) (vec_f32)vec_mergeh(data, zero) +#define BF16_LO(data, zero) (vec_f32)vec_mergel(data, zero) +#else +#define BF16_HI(data, zero) (vec_f32)vec_mergeh(zero, data) +#define BF16_LO(data, zero) (vec_f32)vec_mergel(zero, data) +#endif + +FORCEINLINE vec_f32 vec_loadNHi(void *src, BLASLONG n, vec_bf16 zero) +{ + vec_bf16 data = vec_loadN(src, n); + return BF16_HI(data, zero); +} + +FORCEINLINE vec_f32 vec_mult(vec_f32 *inp, vec_bf16 in0, vec_bf16 zero) +{ + vec_f32 v_in00 = BF16_HI(in0, zero); + vec_f32 v_in01 = BF16_LO(in0, zero); + + return (inp[0] * v_in00) + (inp[1] * v_in01); +} + +FORCEINLINE vec_f32 vec_load_mult(vec_bf16 *in, vec_f32 *inp, vec_bf16 zero) +{ + vec_bf16 in0 = (vec_bf16)vec_load_vec(in); + + return vec_mult(inp, in0, zero); +} + +FORCEINLINE void vec_load_vec2(vec_bf16 *in, vec_f32 *v_x0, vec_bf16 zero) +{ + vec_bf16 inp = (vec_bf16)vec_load_vec(in); + + v_x0[0] = BF16_HI(inp, zero); + v_x0[1] = BF16_LO(inp, zero); +} + +FORCEINLINE void vec_mult2(vec_f32 v_x0, vec_bf16 in0, vec_bf16 zero, vec_f32 *vy0) +{ + vec_f32 v_in00 = BF16_HI(in0, zero); + vec_f32 v_in01 = BF16_LO(in0, zero); + + vy0[0] += (v_x0 * v_in00); + vy0[1] += (v_x0 * v_in01); +} + +FORCEINLINE void vec_load_mult2(vec_f32 v_x0, vec_bf16 *in, vec_bf16 zero, vec_f32 *vy0) +{ + vec_bf16 in0 = (vec_bf16)vec_load_vec(in); + + vec_mult2(v_x0, in0, zero, vy0); +} + +FORCEINLINE vec_f32 vec_loadN_mult(vec_bf16 *in, vec_f32 *inp, BLASLONG n, vec_bf16 zero) +{ + vec_bf16 in0 = vec_loadN(in, n); + + return vec_mult(inp, in0, zero); +} + +FORCEINLINE void vec_loadN_vec2(vec_bf16 *in, vec_f32 *v_x0, BLASLONG n, vec_bf16 zero) +{ + vec_bf16 inp = vec_loadN(in, n); + + v_x0[0] = BF16_HI(inp, zero); + v_x0[1] = BF16_LO(inp, zero); +} + +FORCEINLINE void vec_loadN_mult2(vec_f32 v_x0, vec_bf16 *in, BLASLONG n, vec_bf16 zero, vec_f32 *vy0) +{ + vec_bf16 in0 = vec_loadN(in, n); + + vec_mult2(v_x0, in0, zero, vy0); +} + +FORCEINLINE vec_f32 vec_loadNHi_mult(vec_bf16 *in, vec_f32 v_inp0, BLASLONG n, vec_bf16 zero) +{ + vec_f32 v_in00 = vec_loadNHi(in, n, zero); + + return (v_inp0 * v_in00); +} + +FORCEINLINE void copy_x(BLASLONG n, IFLOAT *src, IFLOAT *dest, BLASLONG inc_src) +{ + for (BLASLONG i = 0; i < n; i++) { + *dest++ = *src; + src += inc_src; + } +} + +FORCEINLINE void copy_y_beta(BLASLONG n, FLOAT *src, FLOAT *dest, BLASLONG inc_src, FLOAT beta) +{ + if (beta == (FLOAT)0) { + memset(dest, 0, n * sizeof(FLOAT)); + } else if (beta == (FLOAT)1) { + for (BLASLONG i = 0; i < n; i++) { + *dest++ = *src; + src += inc_src; + } + } else { + for (BLASLONG i = 0; i < n; i++) { + *dest++ = *src * beta; + src += inc_src; + } + } +} + +FORCEINLINE void move_y(BLASLONG n, FLOAT *src, FLOAT *dest, BLASLONG inc_dest) +{ + for (BLASLONG i = 0; i < n; i++) { + *dest = *src++; + dest += inc_dest; + } +} + +FORCEINLINE void copy_y(BLASLONG n, FLOAT *src, FLOAT *dest, BLASLONG inc_src, FLOAT beta) +{ + if (beta == (FLOAT)0) { + move_y(n, src, dest, inc_src); + } else if (beta == (FLOAT)1) { + for (BLASLONG i = 0; i < n; i++) { + *dest += *src++; + dest += inc_src; + } + } else { + for (BLASLONG i = 0; i < n; i++) { + *dest = *src++ + (beta * *dest); + dest += inc_src; + } + } +} + +static void BF16GEMV_N_beta(BLASLONG n, FLOAT *output_vector, FLOAT *input_vector, FLOAT beta) +{ + if (beta == (FLOAT)0) { + 
memset(output_vector, 0, sizeof(FLOAT) * n); + } else if (beta == (FLOAT)1) { + if (output_vector != input_vector) { + memcpy(output_vector, input_vector, sizeof(FLOAT) * n); + } + } else { + vec_f32 b = { beta, beta, beta, beta }; + + vec_f32 *in = (vec_f32 *)input_vector; + vec_f32 *out = (vec_f32 *)output_vector; + + BLASLONG n8 = n / 8; + BLASLONG i = 0; + vec_f32 v_inp0[2]; + + for (; i + 4 <= n8; i += 4) { + vec_f32 v_inp1[2], v_inp2[2], v_inp3[2]; + vec_load_pair(v_inp0, &in[(i * 2) + 0]); + vec_load_pair(v_inp1, &in[(i * 2) + 2]); + vec_load_pair(v_inp2, &in[(i * 2) + 4]); + vec_load_pair(v_inp3, &in[(i * 2) + 6]); + v_inp0[0] *= b; + v_inp0[1] *= b; + v_inp1[0] *= b; + v_inp1[1] *= b; + v_inp2[0] *= b; + v_inp2[1] *= b; + v_inp3[0] *= b; + v_inp3[1] *= b; + vec_store_pair(&out[(i * 2) + 0], v_inp0); + vec_store_pair(&out[(i * 2) + 2], v_inp1); + vec_store_pair(&out[(i * 2) + 4], v_inp2); + vec_store_pair(&out[(i * 2) + 6], v_inp3); + } + + for (; i < n8; i++) { + vec_load_pair(v_inp0, &in[(i * 2) + 0]); + v_inp0[0] *= b; + v_inp0[1] *= b; + vec_store_pair(&out[(i * 2) + 0], v_inp0); + } + + n &= 7; + if (n > 4) { + BLASLONG n3 = n & 3; + vec_loadN2_f32(v_inp0, &in[(i * 2) + 0], n3); + v_inp0[0] *= b; + v_inp0[1] *= b; + vec_storeN2_f32(v_inp0, &out[(i * 2) + 0], n3); + } else if (n) { + v_inp0[0] = vec_loadN_f32(&in[(i * 2) + 0], n); + v_inp0[0] *= b; + vec_storeN_f32(v_inp0[0], &out[(i * 2) + 0], n); + } + } +} +#endif diff --git a/kernel/power/sbgemv_common_power10.c b/kernel/power/sbgemv_common_power10.c new file mode 100644 index 0000000000..b0e611cb68 --- /dev/null +++ b/kernel/power/sbgemv_common_power10.c @@ -0,0 +1,629 @@ +/*************************************************************************** +Copyright (c) 2024, The OpenBLAS Project +All rights reserved. +Redistribution and use in source and binary forms, with or without +modification, are permitted provided that the following conditions are +met: +1. Redistributions of source code must retain the above copyright +notice, this list of conditions and the following disclaimer. +2. Redistributions in binary form must reproduce the above copyright +notice, this list of conditions and the following disclaimer in +the documentation and/or other materials provided with the +distribution. +3. Neither the name of the OpenBLAS project nor the names of +its contributors may be used to endorse or promote products +derived from this software without specific prior written permission. +THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" +AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE +IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE +ARE DISCLAIMED. IN NO EVENT SHALL THE OPENBLAS PROJECT OR CONTRIBUTORS BE +LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL +DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR +SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER +CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, +OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE +USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. 
+*****************************************************************************/ + +#ifndef SBGEMV_COMMON_MMA_C +#define SBGEMV_COMMON_MMA_C +#include "sbgemv_common.c" + +#if defined(_AIX) || defined(__clang__) +#define USE_MERGE_MMA +#endif + +FORCEINLINE void vec_load_pair2(vec_bf16 *in0, vec_bf16 *in) +{ + vec_load_pair((vec_f32 *)(in0 + 0), (vec_f32 *)(in + 0)); + vec_load_pair((vec_f32 *)(in0 + 2), (vec_f32 *)(in + 2)); +} + +FORCEINLINE void vec_load_mult_mma(__vector_quad *out, vec_bf16 *in, vec_bf16 inp) +{ + vec_bf16 in0 = (vec_bf16)vec_load_vec(in); + + __builtin_mma_xvbf16ger2pp(out, (vec_uc8)in0, (vec_uc8)inp); +} + +FORCEINLINE void vec_load_mult12a_mma(__vector_quad *out, vec_bf16 *in0, vec_bf16 *in1, vec_bf16 inp) +{ + vec_bf16 in11 = (vec_bf16)vec_load_vec(in1); + + vec_load_mult_mma(out, in0, inp); + + __builtin_mma_xvbf16ger2pp(out + 1, (vec_uc8)in11, (vec_uc8)inp); +} + +FORCEINLINE void vec_load_mult14_mma(__vector_quad *out, vec_bf16 *in0, vec_bf16 *in1, vec_bf16 *in2, vec_bf16 *in3, vec_bf16 inp) +{ + vec_bf16 in21 = (vec_bf16)vec_load_vec(in2); + vec_bf16 in31 = (vec_bf16)vec_load_vec(in3); + + vec_load_mult12a_mma(out, in0, in1, inp); + + __builtin_mma_xvbf16ger2pp(out + 2, (vec_uc8)in21, (vec_uc8)inp); + __builtin_mma_xvbf16ger2pp(out + 3, (vec_uc8)in31, (vec_uc8)inp); +} + +FORCEINLINE void vec_load_mult2_mma(__vector_quad *out, vec_bf16 *in, vec_bf16 *inp) +{ + vec_bf16 in0[2]; + + vec_load_pair((vec_f32 *)in0, (vec_f32 *)in); + + __builtin_mma_xvbf16ger2pp(out, (vec_uc8)in0[0], (vec_uc8)inp[0]); + __builtin_mma_xvbf16ger2pp(out, (vec_uc8)in0[1], (vec_uc8)inp[1]); +} + +FORCEINLINE void vec_mult2d_mma(__vector_quad *out, vec_bf16 *in01, vec_bf16 *in11, vec_bf16 *inp) +{ + __builtin_mma_xvbf16ger2pp(out + 0, (vec_uc8)in01[0], (vec_uc8)inp[0]); + __builtin_mma_xvbf16ger2pp(out + 1, (vec_uc8)in11[0], (vec_uc8)inp[0]); +} + +FORCEINLINE void vec_load_mult22_mma(__vector_quad *out, vec_bf16 *in0, vec_bf16 *in1, vec_bf16 *inp) +{ + vec_bf16 in01[2], in11[2]; + + vec_load_pair((vec_f32 *)in01, (vec_f32 *)in0); + vec_load_pair((vec_f32 *)in11, (vec_f32 *)in1); + + vec_mult2d_mma(out, in01 + 0, in11 + 0, inp + 0); + vec_mult2d_mma(out, in01 + 1, in11 + 1, inp + 1); +} + +FORCEINLINE void vec_load_mult24_mma(__vector_quad *out, vec_bf16 *in0, vec_bf16 *in1, vec_bf16 *in2, vec_bf16 *in3, vec_bf16 *inp) +{ + vec_bf16 in01[2], in11[2], in21[2], in31[2]; + + vec_load_pair((vec_f32 *)in01, (vec_f32 *)in0); + vec_load_pair((vec_f32 *)in11, (vec_f32 *)in1); + vec_load_pair((vec_f32 *)in21, (vec_f32 *)in2); + vec_load_pair((vec_f32 *)in31, (vec_f32 *)in3); + + vec_mult2d_mma(out + 0, in01 + 0, in11 + 0, inp + 0); + vec_mult2d_mma(out + 2, in21 + 0, in31 + 0, inp + 0); + vec_mult2d_mma(out + 0, in01 + 1, in11 + 1, inp + 1); + vec_mult2d_mma(out + 2, in21 + 1, in31 + 1, inp + 1); +} + +FORCEINLINE void vec_load_mult4_mma(__vector_quad *out, vec_bf16 *in, vec_bf16 *inp) +{ + vec_bf16 in0[2]; + + vec_load_pair((vec_f32 *)(in0 + 0), (vec_f32 *)(in + 2)); + + vec_load_mult2_mma(out, in + 0, inp + 0); + + __builtin_mma_xvbf16ger2pp(out, (vec_uc8)in0[0], (vec_uc8)inp[2]); + __builtin_mma_xvbf16ger2pp(out, (vec_uc8)in0[1], (vec_uc8)inp[3]); +} + +FORCEINLINE void vec_load_mult42_mma(__vector_quad *out, vec_bf16 *in0, vec_bf16 *in1, vec_bf16 *inp) +{ + vec_bf16 in01[4], in11[4]; + + vec_load_pair2(in01, in0); + vec_load_pair2(in11, in1); + + vec_mult2d_mma(out, in01 + 0, in11 + 0, inp + 0); + vec_mult2d_mma(out, in01 + 1, in11 + 1, inp + 1); + vec_mult2d_mma(out, in01 + 2, in11 + 2, inp + 
2); + vec_mult2d_mma(out, in01 + 3, in11 + 3, inp + 3); +} + +FORCEINLINE void vec_mult4d_mma(__vector_quad *out, vec_bf16 *in01, vec_bf16 *in11, vec_bf16 *in21, vec_bf16 *in31, vec_bf16 *inp) +{ + vec_mult2d_mma(out + 0, in01, in11, inp); + vec_mult2d_mma(out + 2, in21, in31, inp); +} + +FORCEINLINE void vec_load_mult44_mma(__vector_quad *out, vec_bf16 *in0, vec_bf16 *in1, vec_bf16 *in2, vec_bf16 *in3, vec_bf16 *inp) +{ + vec_bf16 in01[4], in11[4], in21[4], in31[4]; + + vec_load_pair2(in01, in0); + vec_load_pair2(in11, in1); + vec_load_pair2(in21, in2); + vec_load_pair2(in31, in3); + + vec_mult4d_mma(out, in01 + 0, in11 + 0, in21 + 0, in31 + 0, inp + 0); + vec_mult4d_mma(out, in01 + 1, in11 + 1, in21 + 1, in31 + 1, inp + 1); + vec_mult4d_mma(out, in01 + 2, in11 + 2, in21 + 2, in31 + 2, inp + 2); + vec_mult4d_mma(out, in01 + 3, in11 + 3, in21 + 3, in31 + 3, inp + 3); +} + +FORCEINLINE void vec_loadN_mult_mma(__vector_quad *out, vec_bf16 *in, vec_bf16 inp, BLASLONG n) +{ + vec_bf16 in0 = vec_loadN(in, n); + + __builtin_mma_xvbf16ger2pp(out, (vec_uc8)in0, (vec_uc8)inp); +} + +FORCEINLINE void vec_loadN_mult12a_mma(__vector_quad *out, vec_bf16 *in0, vec_bf16 *in1, vec_bf16 inp, BLASLONG n) +{ + vec_bf16 in11 = (vec_bf16)vec_loadN(in1, n); + + vec_loadN_mult_mma(out, in0, inp, n); + + __builtin_mma_xvbf16ger2pp(out + 1, (vec_uc8)in11, (vec_uc8)inp); +} + +FORCEINLINE void vec_loadN_mult14_mma(__vector_quad *out, vec_bf16 *in0, vec_bf16 *in1, vec_bf16 *in2, vec_bf16 *in3, vec_bf16 inp, BLASLONG n) +{ + vec_bf16 in21 = (vec_bf16)vec_loadN(in2, n); + vec_bf16 in31 = (vec_bf16)vec_loadN(in3, n); + + vec_loadN_mult12a_mma(out, in0, in1, inp, n); + + __builtin_mma_xvbf16ger2pp(out + 2, (vec_uc8)in21, (vec_uc8)inp); + __builtin_mma_xvbf16ger2pp(out + 3, (vec_uc8)in31, (vec_uc8)inp); +} + +FORCEINLINE void vec_mult1_mma(__vector_quad *out, vec_bf16 in0, vec_bf16 inp) +{ + vec_bf16 in00 = vec_mergeh(in0, in0); + + __builtin_mma_xvbf16ger2(out, (vec_uc8)inp, (vec_uc8)in00); +} + +FORCEINLINE void vec_mult2_mma(__vector_quad *out, vec_bf16 in0, vec_bf16 inp) +{ + vec_bf16 in01 = vec_mergel(in0, in0); + + vec_mult1_mma(&out[0], in0, inp); + + __builtin_mma_xvbf16ger2(&out[1], (vec_uc8)inp, (vec_uc8)in01); +} + +#ifndef USE_MERGE_MMA +FORCEINLINE void vec_mult4_mma(__vector_quad *out, vec_bf16 *in0, vec_bf16 inp) +{ + vec_mult2_mma(out + 0, in0[0], inp); + vec_mult2_mma(out + 2, in0[1], inp); +} +#endif + +FORCEINLINE void vec_loadN_mult11_mma(__vector_quad *out, vec_bf16 *in, vec_bf16 inp, BLASLONG n) +{ + vec_bf16 in0 = vec_loadN(in, n); + + vec_mult1_mma(out, in0, inp); +} + +FORCEINLINE void vec_loadN_mult12_mma(__vector_quad *out, vec_bf16 *in, vec_bf16 inp, BLASLONG n) +{ + vec_bf16 in0 = vec_loadN(in, n); + + vec_mult2_mma(out, in0, inp); +} + +FORCEINLINE void vec_load_mult12_mma(__vector_quad *out, vec_bf16 *in, vec_bf16 inp) +{ + vec_bf16 in0 = (vec_bf16)vec_load_vec(in); + + vec_mult2_mma(out, in0, inp); +} + +#ifndef USE_MERGE_MMA +FORCEINLINE void vec_load_mult18_mma(__vector_quad *out, vec_bf16 *in, vec_bf16 inp) +{ + vec_bf16 in0[4]; + + vec_load_pair((vec_f32 *)(in0 + 0), (vec_f32 *)(in + 0)); + vec_load_pair((vec_f32 *)(in0 + 2), (vec_f32 *)(in + 2)); + + vec_mult4_mma(&out[0], in0 + 0, inp); + vec_mult4_mma(&out[4], in0 + 2, inp); +} +#endif + +FORCEINLINE void vec_reduce1_mma(__vector_quad *out, vec_f32 *temp, vec_f32 v_alpha, vec_f32 *vy0) +{ + __builtin_mma_disassemble_acc((void*)temp, &out[0]); + + vy0[0] += (temp[0] * v_alpha); +} + +FORCEINLINE void vec_reduce2_mma(__vector_quad 
*out, vec_f32 *temp, vec_f32 v_alpha, vec_f32 *vy0) +{ + vec_reduce1_mma(&out[0], &temp[0], v_alpha, &vy0[0]); + vec_reduce1_mma(&out[1], &temp[4], v_alpha, &vy0[1]); +} + +#ifndef USE_MERGE_MMA +FORCEINLINE void vec_reduce8_mma(__vector_quad *out, vec_f32 *temp, vec_f32 v_alpha, vec_f32 *vy0) +{ + vec_reduce2_mma(&out[0], &temp[0], v_alpha, vy0 + 0); + vec_reduce2_mma(&out[2], &temp[8], v_alpha, vy0 + 2); + vec_reduce2_mma(&out[4], &temp[16], v_alpha, vy0 + 4); + vec_reduce2_mma(&out[6], &temp[24], v_alpha, vy0 + 6); +} +#else +FORCEINLINE void vec_reduce44_mma(__vector_quad *out, vec_f32 *temp, vec_f32 v_alpha, vec_f32 *vy0) +{ + __builtin_mma_disassemble_acc((void*)temp, &out[0]); + + vy0[0] += (temp[0] * v_alpha); + vy0[2] += (temp[1] * v_alpha); + vy0[4] += (temp[2] * v_alpha); + vy0[6] += (temp[3] * v_alpha); +} + +FORCEINLINE void vec_reduce84_mma(__vector_quad *out, vec_f32 *temp, vec_f32 v_alpha, vec_f32 *vy0) +{ + vec_reduce44_mma(&out[0], &temp[0], v_alpha, vy0 + 0); + vec_reduce44_mma(&out[1], &temp[4], v_alpha, vy0 + 1); +} + +FORCEINLINE void vec_reduce88_mma(__vector_quad *out, vec_f32 *temp, vec_f32 v_alpha, vec_f32 *vy0) +{ + vec_reduce44_mma(&out[0], &temp[ 0], v_alpha, vy0 + 0); + vec_reduce44_mma(&out[1], &temp[ 4], v_alpha, vy0 + 1); + vec_reduce44_mma(&out[2], &temp[ 8], v_alpha, vy0 + 8); + vec_reduce44_mma(&out[3], &temp[12], v_alpha, vy0 + 9); +} +#endif + +FORCEINLINE void vec_mult11a_mma(__vector_quad *out, vec_bf16 in0, vec_bf16 in1, vec_bf16 inp) +{ + vec_bf16 in00 = vec_mergeh(in0, in1); + + __builtin_mma_xvbf16ger2(out, (vec_uc8)inp, (vec_uc8)in00); +} + +FORCEINLINE void vec_mult2a_mma(__vector_quad *out, vec_bf16 in0, vec_bf16 in1, vec_bf16 inp) +{ + vec_bf16 in01 = vec_mergel(in0, in1); + + vec_mult11a_mma(&out[0], in0, in1, inp); + + __builtin_mma_xvbf16ger2(&out[1], (vec_uc8)inp, (vec_uc8)in01); +} + +FORCEINLINE void vec_mult4a_mma(__vector_quad *out, vec_bf16 *in0, vec_bf16 *in1, vec_bf16 inp) +{ + vec_mult2a_mma(out + 0, in0[0], in1[0], inp); + vec_mult2a_mma(out + 2, in0[1], in1[1], inp); +} + +FORCEINLINE void vec_loadN_mult11a_mma(__vector_quad *out, vec_bf16 *ina, vec_bf16 *inb, vec_bf16 inp, BLASLONG n) +{ + vec_bf16 in0 = vec_loadN(ina, n); + vec_bf16 in1 = vec_loadN(inb, n); + + vec_mult11a_mma(out, in0, in1, inp); +} + +FORCEINLINE void vec_load_mult22a_mma(__vector_quad *out, vec_bf16 *ina, vec_bf16 *inb, vec_bf16 inp) +{ + vec_bf16 in0 = (vec_bf16)vec_load_vec(ina); + vec_bf16 in1 = (vec_bf16)vec_load_vec(inb); + + vec_mult2a_mma(out, in0, in1, inp); +} + +FORCEINLINE void vec_load4_mma(vec_bf16 *in0, vec_bf16 *in1, vec_bf16 *ina, vec_bf16 *inb) +{ + vec_load_pair((vec_f32 *)(in0 + 0), (vec_f32 *)(ina + 0)); + vec_load_pair((vec_f32 *)(in1 + 0), (vec_f32 *)(inb + 0)); + vec_load_pair((vec_f32 *)(in0 + 2), (vec_f32 *)(ina + 2)); + vec_load_pair((vec_f32 *)(in1 + 2), (vec_f32 *)(inb + 2)); +} + +#ifndef USE_MERGE_MMA +FORCEINLINE void vec_load_mult28a_mma(__vector_quad *out, vec_bf16 *ina, vec_bf16 *inb, vec_bf16 inp) +{ + vec_bf16 in0[4], in1[4]; + + vec_load4_mma(in0, in1, ina, inb); + + vec_mult4a_mma(&out[0], in0 + 0, in1 + 0, inp); + vec_mult4a_mma(&out[4], in0 + 2, in1 + 2, inp); +} +#endif + +FORCEINLINE void vec_loadN_mult22a_mma(__vector_quad *out, vec_bf16 *ina, vec_bf16 *inb, vec_bf16 inp, BLASLONG n) +{ + vec_bf16 in0 = vec_loadN(ina, n); + vec_bf16 in1 = vec_loadN(inb, n); + + vec_mult2a_mma(out, in0, in1, inp); +} + +FORCEINLINE void vec_mult11b_mma(__vector_quad *out, vec_bf16 in0, vec_bf16 in1, vec_bf16 inp) +{ + vec_bf16 in00 
= vec_mergeh(in0, in1); + + __builtin_mma_xvbf16ger2pp(out, (vec_uc8)inp, (vec_uc8)in00); +} + +FORCEINLINE void vec_mult2b_mma(__vector_quad *out, vec_bf16 in0, vec_bf16 in1, vec_bf16 inp) +{ + vec_bf16 in01 = vec_mergel(in0, in1); + + vec_mult11b_mma(&out[0], in0, in1, inp); + + __builtin_mma_xvbf16ger2pp(&out[1], (vec_uc8)inp, (vec_uc8)in01); +} + +FORCEINLINE void vec_mult4b_mma(__vector_quad *out, vec_bf16 *in0, vec_bf16 *in1, vec_bf16 inp) +{ + vec_mult2b_mma(out + 0, in0[0], in1[0], inp); + vec_mult2b_mma(out + 2, in0[1], in1[1], inp); +} + +#ifdef USE_MERGE_MMA +FORCEINLINE void vec_mult1c_mma(__vector_quad *out, vec_bf16 in0, vec_bf16 inp) +{ + vec_bf16 in00 = vec_mergeh(in0, in0); + + __builtin_mma_xvbf16ger2pp(out, (vec_uc8)inp, (vec_uc8)in00); +} + +FORCEINLINE void vec_mult2c_mma(__vector_quad *out, vec_bf16 in0, vec_bf16 inp) +{ + vec_bf16 in01 = vec_mergel(in0, in0); + + vec_mult1c_mma(&out[0], in0, inp); + + __builtin_mma_xvbf16ger2pp(&out[1], (vec_uc8)inp, (vec_uc8)in01); +} + +FORCEINLINE void vec_mult44_mma(__vector_quad *out, vec_bf16 *in, vec_bf16 *inp) +{ + vec_mult2_mma(out, in[0], inp[0]); + vec_mult2c_mma(out, in[1], inp[1]); +} + +FORCEINLINE void vec_mult44c_mma(__vector_quad *out, vec_bf16 *in, vec_bf16 *inp) +{ + vec_mult2c_mma(out, in[0], inp[0]); + vec_mult2c_mma(out, in[1], inp[1]); +} + +FORCEINLINE void vec_mult44a_mma(__vector_quad *out, vec_bf16 *in0, vec_bf16 *in1, vec_bf16 *inp) +{ + vec_mult2a_mma(out, in0[0], in1[0], inp[0]); + vec_mult2b_mma(out, in0[1], in1[1], inp[1]); +} + +FORCEINLINE void vec_mult44b_mma(__vector_quad *out, vec_bf16 *in0, vec_bf16 *in1, vec_bf16 *inp) +{ + vec_mult2b_mma(out, in0[0], in1[0], inp[0]); + vec_mult2b_mma(out, in0[1], in1[1], inp[1]); +} +#endif + +FORCEINLINE void vec_loadN_mult11b_mma(__vector_quad *out, vec_bf16 *ina, vec_bf16 *inb, vec_bf16 inp, BLASLONG n) +{ + vec_bf16 in0 = vec_loadN(ina, n); + vec_bf16 in1 = vec_loadN(inb, n); + + vec_mult11b_mma(out, in0, in1, inp); +} + +FORCEINLINE void vec_load_mult22b_mma(__vector_quad *out, vec_bf16 *ina, vec_bf16 *inb, vec_bf16 inp) +{ + vec_bf16 in0 = (vec_bf16)vec_load_vec(ina); + vec_bf16 in1 = (vec_bf16)vec_load_vec(inb); + + vec_mult2b_mma(out, in0, in1, inp); +} + +#ifndef USE_MERGE_MMA +FORCEINLINE void vec_load_mult28b_mma(__vector_quad *out, vec_bf16 *ina, vec_bf16 *inb, vec_bf16 inp) +{ + vec_bf16 in0[4], in1[4]; + + vec_load4_mma(in0, in1, ina, inb); + + vec_mult4b_mma(&out[0], in0 + 0, in1 + 0, inp); + vec_mult4b_mma(&out[4], in0 + 2, in1 + 2, inp); +} +#else +FORCEINLINE void vec_load_mult184_mma(__vector_quad *out, vec_bf16 *in, vec_bf16 *inp) +{ + vec_bf16 in0[4]; + + vec_load_pair((vec_f32 *)(in0 + 0), (vec_f32 *)(in + 0)); + vec_load_pair((vec_f32 *)(in0 + 2), (vec_f32 *)(in + 2)); + + vec_mult44_mma(out, in0 + 0, inp + 0); + vec_mult44c_mma(out, in0 + 2, inp + 2); +} + +FORCEINLINE void vec_load_mult284a_mma(__vector_quad *out, vec_bf16 *ina, vec_bf16 *inb, vec_bf16 *inp) +{ + vec_bf16 in0[4], in1[4]; + + vec_load4_mma(in0, in1, ina, inb); + + vec_mult44a_mma(out, in0 + 0, in1 + 0, inp + 0); + vec_mult44b_mma(out, in0 + 2, in1 + 2, inp + 2); +} + +FORCEINLINE void vec_load_mult284b_mma(__vector_quad *out, vec_bf16 *ina, vec_bf16 *inb, vec_bf16 *inp) +{ + vec_bf16 in0[4], in1[4]; + + vec_load4_mma(in0, in1, ina, inb); + + vec_mult44b_mma(out, in0 + 0, in1 + 0, inp + 0); + vec_mult44b_mma(out, in0 + 2, in1 + 2, inp + 2); +} + +FORCEINLINE void vec_load_mult288a_mma(__vector_quad *out, vec_bf16 *ina, vec_bf16 *inb, vec_bf16 *inp) +{ + vec_bf16 in0[8], 
in1[8]; + + vec_load4_mma(in0 + 0, in1 + 0, ina + 0, inb + 0); + vec_load4_mma(in0 + 4, in1 + 4, ina + 4, inb + 4); + + vec_mult44a_mma(out + 0, in0 + 0, in1 + 0, inp + 0); + vec_mult44a_mma(out + 2, in0 + 4, in1 + 4, inp + 0); + vec_mult44b_mma(out + 0, in0 + 2, in1 + 2, inp + 2); + vec_mult44b_mma(out + 2, in0 + 6, in1 + 6, inp + 2); +} + +FORCEINLINE void vec_load_mult288b_mma(__vector_quad *out, vec_bf16 *ina, vec_bf16 *inb, vec_bf16 *inp) +{ + vec_bf16 in0[8], in1[8]; + + vec_load4_mma(in0 + 0, in1 + 0, ina + 0, inb + 0); + vec_load4_mma(in0 + 4, in1 + 4, ina + 4, inb + 4); + + vec_mult44b_mma(out + 0, in0 + 0, in1 + 0, inp + 0); + vec_mult44b_mma(out + 2, in0 + 4, in1 + 4, inp + 0); + vec_mult44b_mma(out + 0, in0 + 2, in1 + 2, inp + 2); + vec_mult44b_mma(out + 2, in0 + 6, in1 + 6, inp + 2); +} +#endif + +FORCEINLINE void vec_loadN_mult22b_mma(__vector_quad *out, vec_bf16 *ina, vec_bf16 *inb, vec_bf16 inp, BLASLONG n) +{ + vec_bf16 in0 = vec_loadN(ina, n); + vec_bf16 in1 = vec_loadN(inb, n); + + vec_mult2b_mma(out, in0, in1, inp); +} + +FORCEINLINE void vec_load4_pair(vec_f32 *vy0, vec_f32 *v_y) +{ + vec_load_pair(vy0 + 0, v_y + 0); + vec_load_pair(vy0 + 2, v_y + 2); + vec_load_pair(vy0 + 4, v_y + 4); + vec_load_pair(vy0 + 6, v_y + 6); +} + +FORCEINLINE void vec_store4_pair(vec_f32 *v_y, vec_f32 *vy0) +{ + vec_store_pair(v_y + 0, vy0 + 0); + vec_store_pair(v_y + 2, vy0 + 2); + vec_store_pair(v_y + 4, vy0 + 4); + vec_store_pair(v_y + 6, vy0 + 6); +} + +FORCEINLINE void vec_setzero_2(__vector_quad *temp0) +{ + __builtin_mma_xxsetaccz(&temp0[0]); + __builtin_mma_xxsetaccz(&temp0[1]); +} + +FORCEINLINE void vec_setzero_4(__vector_quad *temp0) +{ + vec_setzero_2(temp0 + 0); + vec_setzero_2(temp0 + 2); +} + +FORCEINLINE void vec_setzero_8(__vector_quad *temp0) +{ + vec_setzero_4(temp0 + 0); + vec_setzero_4(temp0 + 4); +} + +FORCEINLINE void vec_reduce_2(vec_f32 *temp00, __vector_quad *temp0) +{ + __builtin_mma_disassemble_acc((void*)(temp00 + 0), &temp0[0]); + __builtin_mma_disassemble_acc((void*)(temp00 + 4), &temp0[1]); +} + +FORCEINLINE void vec_reduce_4(vec_f32 *temp00, __vector_quad *temp0) +{ + vec_reduce_2(temp00 + 0, temp0 + 0); + vec_reduce_2(temp00 + 8, temp0 + 2); +} + +FORCEINLINE void vec_reduce_8(vec_f32 *temp00, __vector_quad *temp0) +{ + vec_reduce_4(temp00 + 0, temp0 + 0); + vec_reduce_4(temp00 + 16, temp0 + 4); +} + +#ifdef USE_MERGE_MMA +FORCEINLINE void vec_load8_pair(vec_f32 *vy0, vec_f32 *v_y) +{ + vec_load4_pair(vy0 + 0, v_y + 0); + vec_load4_pair(vy0 + 8, v_y + 8); +} + +FORCEINLINE void vec_store8_pair(vec_f32 *v_y, vec_f32 *vy0) +{ + vec_store4_pair(v_y + 0, vy0 + 0); + vec_store4_pair(v_y + 8, vy0 + 8); +} + +#if __BYTE_ORDER__ == __ORDER_BIG_ENDIAN__ +#define VEC_SHIFT(data, shift) vec_sldw(data, data, 4 - shift) + +#define MASK_0 0xf000 +#define MASK_1 0x0f00 +#define MASK_2 0x00f0 +#define MASK_3 0x000f +#else +#define VEC_SHIFT(data, shift) vec_sldw(data, data, shift) + +#define MASK_0 0x000f +#define MASK_1 0x00f0 +#define MASK_2 0x0f00 +#define MASK_3 0xf000 +#endif + +FORCEINLINE void vec_make_mult1(vec_bf16 *v_x0, const bool mask) +{ + if (mask) { + v_x0[ 0] = vec_and(v_x0[0], (vec_bf16)vec_genbm(MASK_0)); + } + + v_x0[ 1] = VEC_SHIFT(v_x0[ 0], 1); + v_x0[ 2] = VEC_SHIFT(v_x0[ 0], 2); + v_x0[ 3] = VEC_SHIFT(v_x0[ 0], 3); +} + +FORCEINLINE void vec_make_mult2(vec_bf16 *v_x0) +{ + v_x0[ 5] = vec_and(v_x0[0], (vec_bf16)vec_genbm(MASK_1)); + vec_make_mult1(v_x0, true); + + v_x0[ 4] = VEC_SHIFT(v_x0[ 5], 3); + v_x0[ 6] = VEC_SHIFT(v_x0[ 5], 1); + v_x0[ 7] = 
VEC_SHIFT(v_x0[ 5], 2); +} + +FORCEINLINE void vec_make_mult4(vec_bf16 *v_x0) +{ + v_x0[10] = vec_and(v_x0[0], (vec_bf16)vec_genbm(MASK_2)); + v_x0[15] = vec_and(v_x0[0], (vec_bf16)vec_genbm(MASK_3)); + vec_make_mult2(v_x0); + + v_x0[ 8] = VEC_SHIFT(v_x0[10], 2); + v_x0[ 9] = VEC_SHIFT(v_x0[10], 3); + v_x0[11] = VEC_SHIFT(v_x0[10], 1); + v_x0[12] = VEC_SHIFT(v_x0[15], 1); + v_x0[13] = VEC_SHIFT(v_x0[15], 2); + v_x0[14] = VEC_SHIFT(v_x0[15], 3); +} +#endif + +#endif diff --git a/kernel/power/sbgemv_n.c b/kernel/power/sbgemv_n.c new file mode 100644 index 0000000000..e6f7f587e6 --- /dev/null +++ b/kernel/power/sbgemv_n.c @@ -0,0 +1,152 @@ +/*************************************************************************** +Copyright (c) 2024, The OpenBLAS Project +All rights reserved. +Redistribution and use in source and binary forms, with or without +modification, are permitted provided that the following conditions are +met: +1. Redistributions of source code must retain the above copyright +notice, this list of conditions and the following disclaimer. +2. Redistributions in binary form must reproduce the above copyright +notice, this list of conditions and the following disclaimer in +the documentation and/or other materials provided with the +distribution. +3. Neither the name of the OpenBLAS project nor the names of +its contributors may be used to endorse or promote products +derived from this software without specific prior written permission. +THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" +AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE +IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE +ARE DISCLAIMED. IN NO EVENT SHALL THE OPENBLAS PROJECT OR CONTRIBUTORS BE +LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL +DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR +SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER +CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, +OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE +USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. 
+*****************************************************************************/ + +#ifndef SBGEMV_N_COMMON_C +#define SBGEMV_N_COMMON_C + +#if (defined(_ARCH_PWR10) && (defined(USE_BFGEMV_8_N_MMA) || (!defined(USE_BFGEMV_N_MMA) && defined(USE_BFGEMV_8_N_VSX)))) || (!defined(_ARCH_PWR10) && defined(USE_BFGEMV_8_N_VSX)) +#define USE_N_8 +#endif + +int CNAME(BLASLONG m, BLASLONG n, FLOAT alpha, IFLOAT *a, BLASLONG lda, IFLOAT *x, BLASLONG inc_x, FLOAT beta, FLOAT *y, BLASLONG inc_y) +{ + IFLOAT *x_ptr, *ap[4]; + IFLOAT xbuffer[8] __attribute__((aligned(16))); + FLOAT *y_ptr, *ybuffer; + FLOAT buffer[NBMAX] __attribute__((aligned(16))); + + if ((m < 1) || (n < 1)) return 0; + + ybuffer = buffer; + y_ptr = y; + + BLASLONG lda4 = lda << 2; +#ifdef USE_N_8 + BLASLONG lda8 = lda << 3; +#endif + BLASLONG NB = NBMAX; + BLASLONG m2 = (m & (NBMAX - 1)); + + while (NB == NBMAX) { + m -= NB; + if (m < 0) { + if (m2 == 0) break; + NB = m2; + } + + if (inc_y != 1) { + copy_y_beta(NB, y_ptr, ybuffer, inc_y, beta); + } else { + ybuffer = y_ptr; + BF16GEMV_N_beta(NB, ybuffer, ybuffer, beta); + } + + x_ptr = x; + + ap[0] = a; + ap[1] = a + lda; + ap[2] = ap[1] + lda; + ap[3] = ap[2] + lda; + + if (inc_x == 1) { +#ifdef USE_N_8 + for (BLASLONG j = 0; j + 8 <= n; j += 8) { + BF16GEMV_N_8(NB, ap, x_ptr, ybuffer, lda4, alpha); + ap[0] += lda8; + ap[1] += lda8; + ap[2] += lda8; + ap[3] += lda8; + x_ptr += 8; + } + if (n & 4) { +#else + for (BLASLONG j = 0; j + 4 <= n; j += 4) { +#endif + BF16GEMV_N_4(NB, ap, x_ptr, ybuffer, alpha); + ap[0] += lda4; + ap[1] += lda4; +#ifndef USE_N_8 + ap[2] += lda4; + ap[3] += lda4; +#endif + x_ptr += 4; + } + if (n & 2) { + BF16GEMV_N_2(NB, ap, x_ptr, ybuffer, alpha); + ap[0] += (lda * 2); + x_ptr += 2; + } + if (n & 1) { + BF16GEMV_N_1(NB, ap, x_ptr, ybuffer, alpha); + } + } else { +#ifdef USE_N_8 + for (BLASLONG j = 0; j + 8 <= n; j += 8) { + copy_x(8, x_ptr, xbuffer, inc_x); + BF16GEMV_N_8(NB, ap, xbuffer, ybuffer, lda4, alpha); + ap[0] += lda8; + ap[1] += lda8; + ap[2] += lda8; + ap[3] += lda8; + x_ptr += 8 * inc_x; + } + if (n & 4) { +#else + for (BLASLONG j = 0; j + 4 <= n; j += 4) { +#endif + copy_x(4, x_ptr, xbuffer, inc_x); + BF16GEMV_N_4(NB, ap, xbuffer, ybuffer, alpha); + ap[0] += lda4; + ap[1] += lda4; +#ifndef USE_N_8 + ap[2] += lda4; + ap[3] += lda4; +#endif + x_ptr += 4 * inc_x; + } + if (n & 2) { + copy_x(2, x_ptr, xbuffer, inc_x); + BF16GEMV_N_2(NB, ap, xbuffer, ybuffer, alpha); + ap[0] += (lda * 2); + x_ptr += 2 * inc_x; + } + if (n & 1) { + copy_x(1, x_ptr, xbuffer, inc_x); + BF16GEMV_N_1(NB, ap, xbuffer, ybuffer, alpha); + } + } + + a += NB; + if (inc_y != 1) { + move_y(NB, ybuffer, y_ptr, inc_y); + } + y_ptr += (NB * inc_y); + } + + return 0; +} +#endif diff --git a/kernel/power/sbgemv_n_power10.c b/kernel/power/sbgemv_n_power10.c new file mode 100644 index 0000000000..b1dcb2fcc4 --- /dev/null +++ b/kernel/power/sbgemv_n_power10.c @@ -0,0 +1,474 @@ +/*************************************************************************** +Copyright (c) 2024, The OpenBLAS Project +All rights reserved. +Redistribution and use in source and binary forms, with or without +modification, are permitted provided that the following conditions are +met: +1. Redistributions of source code must retain the above copyright +notice, this list of conditions and the following disclaimer. +2. 
Redistributions in binary form must reproduce the above copyright +notice, this list of conditions and the following disclaimer in +the documentation and/or other materials provided with the +distribution. +3. Neither the name of the OpenBLAS project nor the names of +its contributors may be used to endorse or promote products +derived from this software without specific prior written permission. +THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" +AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE +IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE +ARE DISCLAIMED. IN NO EVENT SHALL THE OPENBLAS PROJECT OR CONTRIBUTORS BE +LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL +DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR +SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER +CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, +OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE +USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. +*****************************************************************************/ + +#ifndef SBGEMV_N_MMA_C +#define SBGEMV_N_MMA_C + +#define USE_BFGEMV_N_MMA + +#ifdef USE_BFGEMV_N_MMA +#include "sbgemv_common_power10.c" + +#ifndef BF16GEMV_N_X +#define BF16GEMV_N_X +#define BF16GEMV_N_8 BF16GEMV_N_MMA_8 +#define BF16GEMV_N_4 BF16GEMV_N_MMA_4 +#define BF16GEMV_N_2 BF16GEMV_N_MMA_2 +#define BF16GEMV_N_1 BF16GEMV_N_MMA_1 +#endif + +#define USE_BFGEMV_8_N_MMA + +static void BF16GEMV_N_MMA_1(BLASLONG n, IFLOAT **ap, IFLOAT *xo, FLOAT *y, FLOAT alpha) +{ + IFLOAT *a0; + __vector_quad temp[2*4]; + vec_f32 temp0[8*4]; + vec_f32 v_alpha = { alpha, alpha, alpha, alpha }; + + a0 = ap[0]; + + vec_bf16 *va0 = (vec_bf16 *)a0; + + vec_bf16 *x_bf = (vec_bf16 *)(xo); + + vec_f32 *v_y = (vec_f32 *)y; + BLASLONG n8 = n / 8; + BLASLONG i = 0; + +#ifdef USE_MERGE_MMA + vec_bf16 v_x0[4]; + v_x0[0] = vec_loadN(x_bf, 1); + vec_f32 vy0[2*4*2]; + + vec_make_mult1(v_x0, false); + + for (; i + 8 <= n8; i += 8) { + vec_load_mult184_mma(&temp[0], &va0[i + 0], &v_x0[ 0]); + vec_load_mult184_mma(&temp[2], &va0[i + 4], &v_x0[ 0]); + + vec_load8_pair(vy0, &v_y[(i * 2) + 0]); + + vec_reduce88_mma(&temp[0], temp0 + 0, v_alpha, vy0 + 0); + + vec_store8_pair(&v_y[(i * 2) + 0], vy0); + } + + if (n8 & 4) { + vec_load_mult184_mma(&temp[0], &va0[i + 0], &v_x0[ 0]); + + vec_load4_pair(vy0, &v_y[(i * 2) + 0]); + + vec_reduce84_mma(&temp[0], temp0, v_alpha, vy0); + + vec_store4_pair(&v_y[(i * 2) + 0], vy0); + + i += 4; + } +#else + vec_bf16 v_x0[1]; + v_x0[0] = vec_loadN(x_bf, 1); + vec_f32 vy0[2*4]; + + for (; i + 4 <= n8; i += 4) { + vec_load_mult18_mma(&temp[0], &va0[i + 0], v_x0[ 0]); + + vec_load4_pair(vy0, &v_y[(i * 2) + 0]); + + vec_reduce8_mma(&temp[0], temp0, v_alpha, vy0); + + vec_store4_pair(&v_y[(i * 2) + 0], vy0); + } +#endif + + for (; i < n8; i++) { + vec_load_mult12_mma(&temp[0], &va0[i], v_x0[ 0]); + + vec_load_pair(vy0, &v_y[(i * 2) + 0]); + + vec_reduce2_mma(&temp[0], temp0, v_alpha, vy0); + + vec_store_pair(&v_y[(i * 2) + 0], vy0); + } + + n &= 7; + if (n > 4) { + vec_loadN_mult12_mma(&temp[0], &va0[i], v_x0[ 0], n); + + n &= 3; + vec_loadN2_f32(vy0, &v_y[(i * 2) + 0], n); + + vec_reduce2_mma(&temp[0], temp0, v_alpha, vy0); + + vec_storeN2_f32(vy0, &v_y[(i * 2) + 0], n); + } else if (n) { + vec_loadN_mult11_mma(&temp[0], &va0[i], v_x0[ 0], n); + + vy0[0] = 
vec_loadN_f32(&v_y[(i * 2) + 0], n); + + vec_reduce1_mma(&temp[0], temp0, v_alpha, vy0); + + vec_storeN_f32(vy0[0], &v_y[(i * 2) + 0], n); + } +} + +static void BF16GEMV_N_MMA_2(BLASLONG n, IFLOAT **ap, IFLOAT *xo, FLOAT *y, FLOAT alpha) +{ + IFLOAT *a0, *a1; + __vector_quad temp[2*4]; + vec_f32 temp0[8*4]; + vec_f32 v_alpha = { alpha, alpha, alpha, alpha }; + + a0 = ap[0]; + a1 = ap[1]; + + vec_bf16 *va0 = (vec_bf16 *)a0; + vec_bf16 *va1 = (vec_bf16 *)a1; + + vec_bf16 *x_bf = (vec_bf16 *)(xo); + + vec_f32 *v_y = (vec_f32 *)y; + BLASLONG n8 = n / 8; + BLASLONG i = 0; + +#ifdef USE_MERGE_MMA + vec_bf16 v_x0[4]; + vec_f32 vy0[2*4*2]; + v_x0[0] = vec_loadN(x_bf, 2); + + vec_make_mult1(v_x0, false); + + for (; i + 8 <= n8; i += 8) { + vec_load_mult288a_mma(&temp[0], &va0[i + 0], &va1[i + 0], &v_x0[ 0]); + + vec_load8_pair(vy0, &v_y[(i * 2) + 0]); + + vec_reduce88_mma(&temp[0], temp0 + 0, v_alpha, vy0 + 0); + + vec_store8_pair(&v_y[(i * 2) + 0], vy0); + } + + if (n8 & 4) { + vec_load_mult284a_mma(&temp[0], &va0[i + 0], &va1[i + 0], &v_x0[ 0]); + + vec_load4_pair(vy0, &v_y[(i * 2) + 0]); + + vec_reduce84_mma(&temp[0], temp0, v_alpha, vy0); + + vec_store4_pair(&v_y[(i * 2) + 0], vy0); + + i += 4; + } +#else + vec_bf16 v_x0[1]; + vec_f32 vy0[2*4]; + v_x0[0] = vec_loadN(x_bf, 2); + + for (; i + 4 <= n8; i += 4) { + vec_load_mult28a_mma(&temp[0], &va0[i + 0], &va1[i + 0], v_x0[ 0]); + + vec_load4_pair(vy0, &v_y[(i * 2) + 0]); + + vec_reduce8_mma(&temp[0], temp0, v_alpha, vy0); + + vec_store4_pair(&v_y[(i * 2) + 0], vy0); + } +#endif + + for (; i < n8; i++) { + vec_load_mult22a_mma(&temp[0], &va0[i], &va1[i], v_x0[ 0]); + + vec_load_pair(vy0, &v_y[(i * 2) + 0]); + + vec_reduce2_mma(&temp[0], temp0, v_alpha, vy0); + + vec_store_pair(&v_y[(i * 2) + 0], vy0); + } + + n &= 7; + if (n > 4) { + vec_loadN_mult22a_mma(&temp[0], &va0[i], &va1[i], v_x0[ 0], n); + + n &= 3; + vec_loadN2_f32(vy0, &v_y[(i * 2) + 0], n); + + vec_reduce2_mma(&temp[0], temp0, v_alpha, vy0); + + vec_storeN2_f32(vy0, &v_y[(i * 2) + 0], n); + } else if (n) { + vec_loadN_mult11a_mma(&temp[0], &va0[i], &va1[i], v_x0[ 0], n); + + vy0[0] = vec_loadN_f32(&v_y[(i * 2) + 0], n); + + vec_reduce1_mma(&temp[0], temp0, v_alpha, vy0); + + vec_storeN_f32(vy0[0], &v_y[(i * 2) + 0], n); + } +} + +static void BF16GEMV_N_MMA_4(BLASLONG n, IFLOAT **ap, IFLOAT *xo, FLOAT *y, FLOAT alpha) +{ + IFLOAT *a0, *a1, *a2, *a3; + __vector_quad temp[2*4]; + vec_f32 temp0[8*4]; + vec_f32 v_alpha = { alpha, alpha, alpha, alpha }; + + a0 = ap[0]; + a1 = ap[1]; + a2 = ap[2]; + a3 = ap[3]; + + vec_bf16 *va0 = (vec_bf16 *)a0; + vec_bf16 *va1 = (vec_bf16 *)a1; + vec_bf16 *va2 = (vec_bf16 *)a2; + vec_bf16 *va3 = (vec_bf16 *)a3; + + vec_bf16 *x_bf = (vec_bf16 *)(xo); + + vec_f32 *v_y = (vec_f32 *)y; + BLASLONG n8 = n / 8; + BLASLONG i = 0; + +#ifdef USE_MERGE_MMA + vec_bf16 v_x0[8]; + vec_f32 vy0[2*4*2]; + v_x0[0] = vec_loadN(x_bf, 4); + + vec_make_mult2(v_x0); + + for (; i + 8 <= n8; i += 8) { + vec_load_mult288a_mma(&temp[0], &va0[i + 0], &va1[i + 0], &v_x0[ 0]); + vec_load_mult288b_mma(&temp[0], &va2[i + 0], &va3[i + 0], &v_x0[ 4]); + + vec_load8_pair(vy0, &v_y[(i * 2) + 0]); + + vec_reduce88_mma(&temp[0], temp0 + 0, v_alpha, vy0 + 0); + + vec_store8_pair(&v_y[(i * 2) + 0], vy0); + } + + if (n8 & 4) { + vec_load_mult284a_mma(&temp[0], &va0[i + 0], &va1[i + 0], &v_x0[ 0]); + vec_load_mult284b_mma(&temp[0], &va2[i + 0], &va3[i + 0], &v_x0[ 4]); + + vec_load4_pair(vy0, &v_y[(i * 2) + 0]); + + vec_reduce84_mma(&temp[0], temp0, v_alpha, vy0); + + vec_store4_pair(&v_y[(i * 2) 
+ 0], vy0); + + i += 4; + } +#else + vec_bf16 v_x0[5]; + vec_f32 vy0[2*4]; + v_x0[0] = vec_loadN(x_bf, 4); + + v_x0[ 4] = (vec_bf16)vec_splat((vec_f32)v_x0[0], 1); + + for (; i + 4 <= n8; i += 4) { + vec_load_mult28a_mma(&temp[0], &va0[i + 0], &va1[i + 0], v_x0[ 0]); + vec_load_mult28b_mma(&temp[0], &va2[i + 0], &va3[i + 0], v_x0[ 4]); + + vec_load4_pair(vy0, &v_y[(i * 2) + 0]); + + vec_reduce8_mma(&temp[0], temp0, v_alpha, vy0); + + vec_store4_pair(&v_y[(i * 2) + 0], vy0); + } +#endif + + for (; i < n8; i++) { + vec_load_mult22a_mma(&temp[0], &va0[i], &va1[i], v_x0[ 0]); + vec_load_mult22b_mma(&temp[0], &va2[i], &va3[i], v_x0[ 4]); + + vec_load_pair(vy0, &v_y[(i * 2) + 0]); + + vec_reduce2_mma(&temp[0], temp0, v_alpha, vy0); + + vec_store_pair(&v_y[(i * 2) + 0], vy0); + } + + n &= 7; + if (n > 4) { + vec_loadN_mult22a_mma(&temp[0], &va0[i], &va1[i], v_x0[ 0], n); + vec_loadN_mult22b_mma(&temp[0], &va2[i], &va3[i], v_x0[ 4], n); + + n &= 3; + vec_loadN2_f32(vy0, &v_y[(i * 2) + 0], n); + + vec_reduce2_mma(&temp[0], temp0, v_alpha, vy0); + + vec_storeN2_f32(vy0, &v_y[(i * 2) + 0], n); + } else if (n) { + vec_loadN_mult11a_mma(&temp[0], &va0[i], &va1[i], v_x0[ 0], n); + vec_loadN_mult11b_mma(&temp[0], &va2[i], &va3[i], v_x0[ 4], n); + + vy0[0] = vec_loadN_f32(&v_y[(i * 2) + 0], n); + + vec_reduce1_mma(&temp[0], temp0, v_alpha, vy0); + + vec_storeN_f32(vy0[0], &v_y[(i * 2) + 0], n); + } +} + +#ifdef USE_BFGEMV_8_N_MMA +static void BF16GEMV_N_MMA_8(BLASLONG n, IFLOAT **ap, IFLOAT *xo, FLOAT *y, BLASLONG lda4, FLOAT alpha) +{ + IFLOAT *a0, *a1, *a2, *a3, *b0, *b1, *b2, *b3; + __vector_quad temp[2*4]; + vec_f32 temp0[8*4]; + vec_f32 v_alpha = { alpha, alpha, alpha, alpha }; + + a0 = ap[0]; + a1 = ap[1]; + a2 = ap[2]; + a3 = ap[3]; + b0 = a0 + lda4; + b1 = a1 + lda4; + b2 = a2 + lda4; + b3 = a3 + lda4; + + vec_bf16 *va0 = (vec_bf16 *)a0; + vec_bf16 *va1 = (vec_bf16 *)a1; + vec_bf16 *va2 = (vec_bf16 *)a2; + vec_bf16 *va3 = (vec_bf16 *)a3; + vec_bf16 *vb0 = (vec_bf16 *)b0; + vec_bf16 *vb1 = (vec_bf16 *)b1; + vec_bf16 *vb2 = (vec_bf16 *)b2; + vec_bf16 *vb3 = (vec_bf16 *)b3; + + vec_bf16 *x_bf = (vec_bf16 *)(xo); + + vec_f32 *v_y = (vec_f32 *)y; + BLASLONG n8 = n / 8; + BLASLONG i = 0; + +#ifdef USE_MERGE_MMA + vec_bf16 v_x0[16]; + vec_f32 vy0[2*4*2]; + v_x0[0] = (vec_bf16)vec_load_vec(x_bf); + + vec_make_mult4(v_x0); + + for (; i + 8 <= n8; i += 8) { + vec_load_mult288a_mma(&temp[0], &va0[i + 0], &va1[i + 0], &v_x0[ 0]); + vec_load_mult288b_mma(&temp[0], &va2[i + 0], &va3[i + 0], &v_x0[ 4]); + vec_load_mult288b_mma(&temp[0], &vb0[i + 0], &vb1[i + 0], &v_x0[ 8]); + vec_load_mult288b_mma(&temp[0], &vb2[i + 0], &vb3[i + 0], &v_x0[12]); + + vec_load8_pair(vy0, &v_y[(i * 2) + 0]); + + vec_reduce88_mma(&temp[0], temp0 + 0, v_alpha, vy0 + 0); + + vec_store8_pair(&v_y[(i * 2) + 0], vy0); + } + + if (n8 & 4) { + vec_load_mult284a_mma(&temp[0], &va0[i + 0], &va1[i + 0], &v_x0[ 0]); + vec_load_mult284b_mma(&temp[0], &va2[i + 0], &va3[i + 0], &v_x0[ 4]); + vec_load_mult284b_mma(&temp[0], &vb0[i + 0], &vb1[i + 0], &v_x0[ 8]); + vec_load_mult284b_mma(&temp[0], &vb2[i + 0], &vb3[i + 0], &v_x0[12]); + + vec_load4_pair(vy0, &v_y[(i * 2) + 0]); + + vec_reduce84_mma(&temp[0], temp0, v_alpha, vy0); + + vec_store4_pair(&v_y[(i * 2) + 0], vy0); + + i += 4; + } +#else + vec_bf16 v_x0[13]; + vec_f32 vy0[2*4]; + v_x0[0] = (vec_bf16)vec_load_vec(x_bf); + + v_x0[ 4] = (vec_bf16)vec_splat((vec_f32)v_x0[0], 1); + v_x0[ 8] = (vec_bf16)vec_splat((vec_f32)v_x0[0], 2); + v_x0[12] = (vec_bf16)vec_splat((vec_f32)v_x0[0], 3); + + for 
(; i + 4 <= n8; i += 4) { + vec_load_mult28a_mma(&temp[0], &va0[i + 0], &va1[i + 0], v_x0[ 0]); + vec_load_mult28b_mma(&temp[0], &va2[i + 0], &va3[i + 0], v_x0[ 4]); + vec_load_mult28b_mma(&temp[0], &vb0[i + 0], &vb1[i + 0], v_x0[ 8]); + vec_load_mult28b_mma(&temp[0], &vb2[i + 0], &vb3[i + 0], v_x0[12]); + + vec_load4_pair(vy0, &v_y[(i * 2) + 0]); + + vec_reduce8_mma(&temp[0], temp0, v_alpha, vy0); + + vec_store4_pair(&v_y[(i * 2) + 0], vy0); + } +#endif + + for (; i < n8; i++) { + vec_load_mult22a_mma(&temp[0], &va0[i], &va1[i], v_x0[ 0]); + vec_load_mult22b_mma(&temp[0], &va2[i], &va3[i], v_x0[ 4]); + vec_load_mult22b_mma(&temp[0], &vb0[i], &vb1[i], v_x0[ 8]); + vec_load_mult22b_mma(&temp[0], &vb2[i], &vb3[i], v_x0[12]); + + vec_load_pair(vy0, &v_y[(i * 2) + 0]); + + vec_reduce2_mma(&temp[0], temp0, v_alpha, vy0); + + vec_store_pair(&v_y[(i * 2) + 0], vy0); + } + + n &= 7; + if (n > 4) { + vec_loadN_mult22a_mma(&temp[0], &va0[i], &va1[i], v_x0[ 0], n); + vec_loadN_mult22b_mma(&temp[0], &va2[i], &va3[i], v_x0[ 4], n); + vec_loadN_mult22b_mma(&temp[0], &vb0[i], &vb1[i], v_x0[ 8], n); + vec_loadN_mult22b_mma(&temp[0], &vb2[i], &vb3[i], v_x0[12], n); + + n &= 3; + vec_loadN2_f32(vy0, &v_y[(i * 2) + 0], n); + + vec_reduce2_mma(&temp[0], temp0, v_alpha, vy0); + + vec_storeN2_f32(vy0, &v_y[(i * 2) + 0], n); + } else if (n) { + vec_loadN_mult11a_mma(&temp[0], &va0[i], &va1[i], v_x0[ 0], n); + vec_loadN_mult11b_mma(&temp[0], &va2[i], &va3[i], v_x0[ 4], n); + vec_loadN_mult11b_mma(&temp[0], &vb0[i], &vb1[i], v_x0[ 8], n); + vec_loadN_mult11b_mma(&temp[0], &vb2[i], &vb3[i], v_x0[12], n); + + vy0[0] = vec_loadN_f32(&v_y[(i * 2) + 0], n); + + vec_reduce1_mma(&temp[0], temp0, v_alpha, vy0); + + vec_storeN_f32(vy0[0], &v_y[(i * 2) + 0], n); + } +} +#endif + +#include "sbgemv_n.c" +#else +#include "sbgemv_n_vsx.c" +#endif +#endif + diff --git a/kernel/power/sbgemv_n_vsx.c b/kernel/power/sbgemv_n_vsx.c new file mode 100644 index 0000000000..390a87359d --- /dev/null +++ b/kernel/power/sbgemv_n_vsx.c @@ -0,0 +1,299 @@ +/*************************************************************************** +Copyright (c) 2024, The OpenBLAS Project +All rights reserved. +Redistribution and use in source and binary forms, with or without +modification, are permitted provided that the following conditions are +met: +1. Redistributions of source code must retain the above copyright +notice, this list of conditions and the following disclaimer. +2. Redistributions in binary form must reproduce the above copyright +notice, this list of conditions and the following disclaimer in +the documentation and/or other materials provided with the +distribution. +3. Neither the name of the OpenBLAS project nor the names of +its contributors may be used to endorse or promote products +derived from this software without specific prior written permission. +THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" +AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE +IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE +ARE DISCLAIMED. 
IN NO EVENT SHALL THE OPENBLAS PROJECT OR CONTRIBUTORS BE +LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL +DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR +SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER +CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, +OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE +USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. +*****************************************************************************/ + +#ifndef SBGEMV_N_VSX_C +#define SBGEMV_N_VSX_C + +#include "sbgemv_common.c" + +#ifndef BF16GEMV_N_X +#define BF16GEMV_N_X +#define BF16GEMV_N_8 BF16GEMV_N_VSX_8 +#define BF16GEMV_N_4 BF16GEMV_N_VSX_4 +#define BF16GEMV_N_2 BF16GEMV_N_VSX_2 +#define BF16GEMV_N_1 BF16GEMV_N_VSX_1 +#endif + +#define USE_BFGEMV_8_N_VSX + +static void BF16GEMV_N_VSX_1(BLASLONG n, IFLOAT **ap, IFLOAT *xo, FLOAT *y, FLOAT alpha) +{ + IFLOAT *a0; + vec_bf16 zero = { 0, 0, 0, 0, 0, 0, 0, 0 }; + vec_f32 v_alpha = { alpha, alpha, alpha, alpha }; + + a0 = ap[0]; + + vec_bf16 *va0 = (vec_bf16 *)a0; + + vec_bf16 *x_bf = (vec_bf16 *)(xo); + vec_f32 x_0 = vec_loadNHi(x_bf, 1, zero); + x_0 *= v_alpha; + + vec_f32 v_x0 = vec_splat(x_0, 0); + + vec_f32 *v_y = (vec_f32 *)y; + BLASLONG n8 = n / 8; + BLASLONG i = 0; + vec_f32 vy0[2]; + + for (; i < n8; i++) { + vec_load_pair(vy0, &v_y[(i * 2) + 0]); + + vec_load_mult2(v_x0, &va0[i], zero, vy0); + + vec_store_pair(&v_y[(i * 2) + 0], vy0); + } + + n &= 7; + if (n > 4) { + BLASLONG n3 = n & 3; + vec_loadN2_f32(vy0, &v_y[(i * 2) + 0], n3); + + vec_loadN_mult2(v_x0, &va0[i], n, zero, vy0); + + vec_storeN2_f32(vy0, &v_y[(i * 2) + 0], n3); + } else if (n) { + vy0[0] = vec_loadN_f32(&v_y[(i * 2) + 0], n); + + vy0[0] += vec_loadNHi_mult(&va0[i], v_x0, n, zero); + + vec_storeN_f32(vy0[0], &v_y[(i * 2) + 0], n); + } +} + +static void BF16GEMV_N_VSX_2(BLASLONG n, IFLOAT **ap, IFLOAT *xo, FLOAT *y, FLOAT alpha) +{ + IFLOAT *a0, *a1; + vec_bf16 zero = { 0, 0, 0, 0, 0, 0, 0, 0 }; + vec_f32 v_alpha = { alpha, alpha, alpha, alpha }; + + a0 = ap[0]; + a1 = ap[1]; + + vec_bf16 *va0 = (vec_bf16 *)a0; + vec_bf16 *va1 = (vec_bf16 *)a1; + + vec_bf16 *x_bf = (vec_bf16 *)(xo); + vec_f32 x_0 = vec_loadNHi(x_bf, 2, zero); + x_0 *= v_alpha; + + vec_f32 v_x0 = vec_splat(x_0, 0); + vec_f32 v_x1 = vec_splat(x_0, 1); + + vec_f32 *v_y = (vec_f32 *)y; + BLASLONG n8 = n / 8; + BLASLONG i = 0; + vec_f32 vy0[2]; + + for (; i < n8; i++) { + vec_load_pair(vy0, &v_y[(i * 2) + 0]); + + vec_load_mult2(v_x0, &va0[i], zero, vy0); + vec_load_mult2(v_x1, &va1[i], zero, vy0); + + vec_store_pair(&v_y[(i * 2) + 0], vy0); + } + + n &= 7; + if (n > 4) { + BLASLONG n3 = n & 3; + vec_loadN2_f32(vy0, &v_y[(i * 2) + 0], n3); + + vec_loadN_mult2(v_x0, &va0[i], n, zero, vy0); + vec_loadN_mult2(v_x1, &va1[i], n, zero, vy0); + + vec_storeN2_f32(vy0, &v_y[(i * 2) + 0], n3); + } else if (n) { + vy0[0] = vec_loadN_f32(&v_y[(i * 2) + 0], n); + + vy0[0] += vec_loadNHi_mult(&va0[i], v_x0, n, zero); + vy0[0] += vec_loadNHi_mult(&va1[i], v_x1, n, zero); + + vec_storeN_f32(vy0[0], &v_y[(i * 2) + 0], n); + } +} + +static void BF16GEMV_N_VSX_4(BLASLONG n, IFLOAT **ap, IFLOAT *xo, FLOAT *y, FLOAT alpha) +{ + IFLOAT *a0, *a1, *a2, *a3; + vec_bf16 zero = { 0, 0, 0, 0, 0, 0, 0, 0 }; + vec_f32 v_alpha = { alpha, alpha, alpha, alpha }; + + a0 = ap[0]; + a1 = ap[1]; + a2 = ap[2]; + a3 = ap[3]; + + vec_bf16 *va0 = (vec_bf16 *)a0; + vec_bf16 *va1 = (vec_bf16 
*)a1; + vec_bf16 *va2 = (vec_bf16 *)a2; + vec_bf16 *va3 = (vec_bf16 *)a3; + + vec_bf16 *x_bf = (vec_bf16 *)(xo); + vec_f32 x_0 = vec_loadNHi(x_bf, 4, zero); + x_0 *= v_alpha; + + vec_f32 v_x0 = vec_splat(x_0, 0); + vec_f32 v_x1 = vec_splat(x_0, 1); + vec_f32 v_x2 = vec_splat(x_0, 2); + vec_f32 v_x3 = vec_splat(x_0, 3); + + vec_f32 *v_y = (vec_f32 *)y; + BLASLONG n8 = n / 8; + BLASLONG i = 0; + vec_f32 vy0[2]; + + for (; i < n8; i++) { + vec_load_pair(vy0, &v_y[(i * 2) + 0]); + + vec_load_mult2(v_x0, &va0[i], zero, vy0); + vec_load_mult2(v_x1, &va1[i], zero, vy0); + vec_load_mult2(v_x2, &va2[i], zero, vy0); + vec_load_mult2(v_x3, &va3[i], zero, vy0); + + vec_store_pair(&v_y[(i * 2) + 0], vy0); + } + + n &= 7; + if (n > 4) { + BLASLONG n3 = n & 3; + vec_loadN2_f32(vy0, &v_y[(i * 2) + 0], n3); + + vec_loadN_mult2(v_x0, &va0[i], n, zero, vy0); + vec_loadN_mult2(v_x1, &va1[i], n, zero, vy0); + vec_loadN_mult2(v_x2, &va2[i], n, zero, vy0); + vec_loadN_mult2(v_x3, &va3[i], n, zero, vy0); + + vec_storeN2_f32(vy0, &v_y[(i * 2) + 0], n3); + } else if (n) { + vy0[0] = vec_loadN_f32(&v_y[(i * 2) + 0], n); + + vy0[0] += vec_loadNHi_mult(&va0[i], v_x0, n, zero); + vy0[0] += vec_loadNHi_mult(&va1[i], v_x1, n, zero); + vy0[0] += vec_loadNHi_mult(&va2[i], v_x2, n, zero); + vy0[0] += vec_loadNHi_mult(&va3[i], v_x3, n, zero); + + vec_storeN_f32(vy0[0], &v_y[(i * 2) + 0], n); + } +} + +#ifdef USE_BFGEMV_8_N_VSX +static void BF16GEMV_N_VSX_8(BLASLONG n, IFLOAT **ap, IFLOAT *xo, FLOAT *y, BLASLONG lda4, FLOAT alpha) +{ + IFLOAT *a0, *a1, *a2, *a3, *b0, *b1, *b2, *b3; + vec_bf16 zero = { 0, 0, 0, 0, 0, 0, 0, 0 }; + vec_f32 v_alpha = { alpha, alpha, alpha, alpha }; + + a0 = ap[0]; + a1 = ap[1]; + a2 = ap[2]; + a3 = ap[3]; + b0 = a0 + lda4; + b1 = a1 + lda4; + b2 = a2 + lda4; + b3 = a3 + lda4; + + vec_bf16 *va0 = (vec_bf16 *)a0; + vec_bf16 *va1 = (vec_bf16 *)a1; + vec_bf16 *va2 = (vec_bf16 *)a2; + vec_bf16 *va3 = (vec_bf16 *)a3; + vec_bf16 *vb0 = (vec_bf16 *)b0; + vec_bf16 *vb1 = (vec_bf16 *)b1; + vec_bf16 *vb2 = (vec_bf16 *)b2; + vec_bf16 *vb3 = (vec_bf16 *)b3; + + vec_bf16 *x_bf = (vec_bf16 *)(xo); + vec_bf16 x_in = (vec_bf16)vec_load_vec(x_bf); + vec_f32 x_0 = BF16_HI(x_in, zero); + vec_f32 x_1 = BF16_LO(x_in, zero); + x_0 *= v_alpha; + x_1 *= v_alpha; + + vec_f32 v_x0 = vec_splat(x_0, 0); + vec_f32 v_x1 = vec_splat(x_0, 1); + vec_f32 v_x2 = vec_splat(x_0, 2); + vec_f32 v_x3 = vec_splat(x_0, 3); + vec_f32 v_x4 = vec_splat(x_1, 0); + vec_f32 v_x5 = vec_splat(x_1, 1); + vec_f32 v_x6 = vec_splat(x_1, 2); + vec_f32 v_x7 = vec_splat(x_1, 3); + + vec_f32 *v_y = (vec_f32 *)y; + BLASLONG n8 = n / 8; + BLASLONG i = 0; + vec_f32 vy0[2]; + + for (; i < n8; i++) { + vec_load_pair(vy0, &v_y[(i * 2) + 0]); + + vec_load_mult2(v_x0, &va0[i], zero, vy0); + vec_load_mult2(v_x1, &va1[i], zero, vy0); + vec_load_mult2(v_x2, &va2[i], zero, vy0); + vec_load_mult2(v_x3, &va3[i], zero, vy0); + vec_load_mult2(v_x4, &vb0[i], zero, vy0); + vec_load_mult2(v_x5, &vb1[i], zero, vy0); + vec_load_mult2(v_x6, &vb2[i], zero, vy0); + vec_load_mult2(v_x7, &vb3[i], zero, vy0); + + vec_store_pair(&v_y[(i * 2) + 0], vy0); + } + + n &= 7; + if (n > 4) { + BLASLONG n3 = n & 3; + vec_loadN2_f32(vy0, &v_y[(i * 2) + 0], n3); + + vec_loadN_mult2(v_x0, &va0[i], n, zero, vy0); + vec_loadN_mult2(v_x1, &va1[i], n, zero, vy0); + vec_loadN_mult2(v_x2, &va2[i], n, zero, vy0); + vec_loadN_mult2(v_x3, &va3[i], n, zero, vy0); + vec_loadN_mult2(v_x4, &vb0[i], n, zero, vy0); + vec_loadN_mult2(v_x5, &vb1[i], n, zero, vy0); + vec_loadN_mult2(v_x6, &vb2[i], n, zero, 
vy0); + vec_loadN_mult2(v_x7, &vb3[i], n, zero, vy0); + + vec_storeN2_f32(vy0, &v_y[(i * 2) + 0], n3); + } else if (n) { + vy0[0] = vec_loadN_f32(&v_y[(i * 2) + 0], n); + + vy0[0] += vec_loadNHi_mult(&va0[i], v_x0, n, zero); + vy0[0] += vec_loadNHi_mult(&va1[i], v_x1, n, zero); + vy0[0] += vec_loadNHi_mult(&va2[i], v_x2, n, zero); + vy0[0] += vec_loadNHi_mult(&va3[i], v_x3, n, zero); + vy0[0] += vec_loadNHi_mult(&vb0[i], v_x4, n, zero); + vy0[0] += vec_loadNHi_mult(&vb1[i], v_x5, n, zero); + vy0[0] += vec_loadNHi_mult(&vb2[i], v_x6, n, zero); + vy0[0] += vec_loadNHi_mult(&vb3[i], v_x7, n, zero); + + vec_storeN_f32(vy0[0], &v_y[(i * 2) + 0], n); + } +} +#endif + +#include "sbgemv_n.c" +#endif diff --git a/kernel/power/sbgemv_t.c b/kernel/power/sbgemv_t.c new file mode 100644 index 0000000000..594b1fc57b --- /dev/null +++ b/kernel/power/sbgemv_t.c @@ -0,0 +1,137 @@ +/*************************************************************************** +Copyright (c) 2024, The OpenBLAS Project +All rights reserved. +Redistribution and use in source and binary forms, with or without +modification, are permitted provided that the following conditions are +met: +1. Redistributions of source code must retain the above copyright +notice, this list of conditions and the following disclaimer. +2. Redistributions in binary form must reproduce the above copyright +notice, this list of conditions and the following disclaimer in +the documentation and/or other materials provided with the +distribution. +3. Neither the name of the OpenBLAS project nor the names of +its contributors may be used to endorse or promote products +derived from this software without specific prior written permission. +THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" +AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE +IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE +ARE DISCLAIMED. IN NO EVENT SHALL THE OPENBLAS PROJECT OR CONTRIBUTORS BE +LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL +DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR +SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER +CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, +OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE +USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. 
+*****************************************************************************/ + +#ifndef SBGEMV_T_COMMON_C +#define SBGEMV_T_COMMON_C + +#if (defined(_ARCH_PWR10) && (defined(USE_BFGEMV_8_T_MMA) || (!defined(USE_BFGEMV_N_MMA) && defined(USE_BFGEMV_8_T_VSX)))) || (!defined(_ARCH_PWR10) && defined(USE_BFGEMV_8_T_VSX)) +#define USE_T_8 +#endif + +int CNAME(BLASLONG m, BLASLONG n, FLOAT alpha, IFLOAT *a, BLASLONG lda, IFLOAT *x, BLASLONG inc_x, FLOAT beta, FLOAT *y, BLASLONG inc_y) +{ + IFLOAT *xbuffer, *a_ptr; + IFLOAT buffer[NBMAX] __attribute__((aligned(16))); + FLOAT ybuffer[8] __attribute__((aligned(16))); + FLOAT *y_ptr; + + if ((m < 1) || (n < 1)) return 0; + + if (inc_y == 1) { + BF16GEMV_N_beta(n, y, y, beta); + } + + xbuffer = buffer; + + BLASLONG lda4 = lda << 2; +#ifdef USE_T_8 + BLASLONG lda8 = lda << 3; +#endif + BLASLONG NB = NBMAX; + BLASLONG m2 = (m & (NBMAX - 1)); + + while (NB == NBMAX) { + m -= NB; + if (m < 0) { + if (m2 == 0) break; + NB = m2; + } + + a_ptr = a; + a += NB; + y_ptr = y; + + if (inc_x != 1) { + copy_x(NB, x, xbuffer, inc_x); + x += NB * inc_x; + } else { + xbuffer = x; + x += NB; + } + + if (inc_y == 1) { +#ifdef USE_T_8 + for (BLASLONG j = 0; j + 8 <= n; j += 8) { + BF16GEMV_T_8(NB, lda, a_ptr, xbuffer, y_ptr, alpha); + y_ptr += 8; + a_ptr += lda8; + } + if (n & 4) { +#else + for (BLASLONG j = 0; j + 4 <= n; j += 4) { +#endif + BF16GEMV_T_4(NB, lda, a_ptr, xbuffer, y_ptr, alpha); + y_ptr += 4; + a_ptr += lda4; + } + if (n & 2) { + BF16GEMV_T_2(NB, lda, a_ptr, xbuffer, y_ptr, alpha); + y_ptr += 2; + a_ptr += (lda * 2); + } + if (n & 1) { + BF16GEMV_T_1(NB, lda, a_ptr, xbuffer, y_ptr, alpha); + } + } else { +#ifdef USE_T_8 + for (BLASLONG j = 0; j + 8 <= n; j += 8) { + memset(ybuffer, 0, sizeof(FLOAT) * 8); + BF16GEMV_T_8(NB, lda, a_ptr, xbuffer, ybuffer, alpha); + copy_y(8, ybuffer, y_ptr, inc_y, beta); + y_ptr += 8 * inc_y; + a_ptr += lda8; + } + if (n & 4) { +#else + for (BLASLONG j = 0; j + 4 <= n; j += 4) { +#endif + memset(ybuffer, 0, sizeof(FLOAT) * 4); + BF16GEMV_T_4(NB, lda, a_ptr, xbuffer, ybuffer, alpha); + copy_y(4, ybuffer, y_ptr, inc_y, beta); + y_ptr += 4 * inc_y; + a_ptr += lda4; + } + if (n & 2) { + memset(ybuffer, 0, sizeof(FLOAT) * 4); + BF16GEMV_T_2(NB, lda, a_ptr, xbuffer, ybuffer, alpha); + copy_y(2, ybuffer, y_ptr, inc_y, beta); + y_ptr += 2 * inc_y; + a_ptr += (lda * 2); + } + if (n & 1) { + memset(ybuffer, 0, sizeof(FLOAT) * 4); + BF16GEMV_T_1(NB, lda, a_ptr, xbuffer, ybuffer, alpha); + copy_y(1, ybuffer, y_ptr, inc_y, beta); + } + beta = (FLOAT)1; + } + } + + return 0; +} +#endif + diff --git a/kernel/power/sbgemv_t_power10.c b/kernel/power/sbgemv_t_power10.c new file mode 100644 index 0000000000..40c166354b --- /dev/null +++ b/kernel/power/sbgemv_t_power10.c @@ -0,0 +1,338 @@ +/*************************************************************************** +Copyright (c) 2024, The OpenBLAS Project +All rights reserved. +Redistribution and use in source and binary forms, with or without +modification, are permitted provided that the following conditions are +met: +1. Redistributions of source code must retain the above copyright +notice, this list of conditions and the following disclaimer. +2. Redistributions in binary form must reproduce the above copyright +notice, this list of conditions and the following disclaimer in +the documentation and/or other materials provided with the +distribution. +3. 
Neither the name of the OpenBLAS project nor the names of +its contributors may be used to endorse or promote products +derived from this software without specific prior written permission. +THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" +AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE +IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE +ARE DISCLAIMED. IN NO EVENT SHALL THE OPENBLAS PROJECT OR CONTRIBUTORS BE +LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL +DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR +SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER +CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, +OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE +USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. +*****************************************************************************/ + +#ifndef SBGEMV_T_MMA_C +#define SBGEMV_T_MMA_C + +#define USE_BFGEMV_T_MMA + +#ifdef USE_BFGEMV_T_MMA +#include "sbgemv_common_power10.c" + +#ifndef BF16GEMV_T_X +#define BF16GEMV_T_X +#define BF16GEMV_T_8 BF16GEMV_T_MMA_8 +#define BF16GEMV_T_4 BF16GEMV_T_MMA_4 +#define BF16GEMV_T_2 BF16GEMV_T_MMA_2 +#define BF16GEMV_T_1 BF16GEMV_T_MMA_1 +#endif + +#define USE_BFGEMV_8_T_MMA + +static void BF16GEMV_T_MMA_1(BLASLONG n, BLASLONG lda, IFLOAT *ap, IFLOAT *x, FLOAT *y, FLOAT alpha) +{ + IFLOAT *a0; + vec_bf16 *va0, *v_x; + __vector_quad temp0; + vec_f32 temp00[4]; + vec_bf16 inp[4]; + + __builtin_mma_xxsetaccz(&temp0); + + a0 = ap; + va0 = (vec_bf16 *)a0; + v_x = (vec_bf16 *)x; + BLASLONG n8 = n / 8; + BLASLONG i = 0; + + for (; i + 4 <= n8; i += 4) { + vec_load_pair2(inp, &v_x[i]); + + vec_load_mult4_mma(&temp0, &va0[i + 0], inp); + } + + if (n8 & 2) { + vec_load_pair((vec_f32 *)inp, (vec_f32 *)&v_x[i]); + + vec_load_mult2_mma(&temp0, &va0[i + 0], inp); + + i += 2; + } + + if (n8 & 1) { + inp[0] = (vec_bf16)vec_load_vec(&v_x[i]); + + vec_load_mult_mma(&temp0, &va0[i], inp[0]); + + i++; + } + + n &= 7; + if (n) { + inp[0] = vec_loadN(&v_x[i], n); + + vec_loadN_mult_mma(&temp0, &va0[i], inp[0], n); + } + + __builtin_mma_disassemble_acc((void*)temp00, &temp0); + + y[0] += (alpha * (temp00[0][0] + temp00[1][1] + temp00[2][2] + temp00[3][3])); +} + +static void BF16GEMV_T_MMA_2(BLASLONG n, BLASLONG lda, IFLOAT *ap, IFLOAT *x, FLOAT *y, FLOAT alpha) +{ + IFLOAT *a0, *a1; + vec_bf16 *va0, *va1, *v_x; + __vector_quad temp0[2]; + vec_f32 temp00[4*2]; + vec_bf16 inp[4]; + + vec_setzero_2(&temp0[0]); + + a0 = ap; + a1 = ap + lda; + va0 = (vec_bf16 *)a0; + va1 = (vec_bf16 *)a1; + v_x = (vec_bf16 *)x; + BLASLONG n8 = n / 8; + BLASLONG i = 0; + + for (; i + 4 <= n8; i += 4) { + vec_load_pair2(inp, &v_x[i]); + + vec_load_mult42_mma(&temp0[0], &va0[i + 0], &va1[i + 0], inp); + } + + if (n8 & 2) { + vec_load_pair((vec_f32 *)inp, (vec_f32 *)&v_x[i]); + + vec_load_mult22_mma(&temp0[0], &va0[i + 0], &va1[i + 0], inp); + + i += 2; + } + + if (n8 & 1) { + inp[0] = (vec_bf16)vec_load_vec(&v_x[i]); + + vec_load_mult12a_mma(&temp0[0], &va0[i], &va1[i], inp[0]); + + i++; + } + + n &= 7; + if (n) { + inp[0] = vec_loadN(&v_x[i], n); + + vec_loadN_mult12a_mma(&temp0[0], &va0[i], &va1[i], inp[0], n); + } + + vec_reduce_2(temp00, &temp0[0]); + + y[0] += (alpha * (temp00[0][0] + temp00[1][1] + temp00[2][2] + temp00[3][3])); + y[1] += (alpha * (temp00[4][0] + temp00[5][1] + temp00[6][2] + temp00[7][3])); +} 
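+
+/* Reduction note: each __vector_quad accumulator filled by
+ * __builtin_mma_xvbf16ger2pp holds a 4x4 tile of partial products, so the
+ * dot product for one output element is the sum of that tile's diagonal;
+ * roughly the scalar sketch
+ *
+ *   y[j] += alpha * (acc[0][0] + acc[1][1] + acc[2][2] + acc[3][3]);
+ *
+ * which is what BF16GEMV_T_MMA_1/_2 above do element by element.
+ * BF16GEMV_T_MMA_4/_8 below instead gather four diagonals at a time with
+ * vec_mergeh/vec_mergeo/vec_mergel and vec_xxpermdi, so that four results
+ * can be scaled by alpha and added to y in a single vec_f32 update. */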
+ +static void BF16GEMV_T_MMA_4(BLASLONG n, BLASLONG lda, IFLOAT *ap, IFLOAT *x, FLOAT *y, FLOAT alpha) +{ + IFLOAT *a0, *a1, *a2, *a3; + vec_bf16 *va0, *va1, *va2, *va3, *v_x; + __vector_quad temp0[4]; + vec_f32 temp00[4*4]; + vec_bf16 inp[4]; + + vec_setzero_4(&temp0[0]); + + a0 = ap; + a1 = ap + lda; + a2 = a1 + lda; + a3 = a2 + lda; + va0 = (vec_bf16 *)a0; + va1 = (vec_bf16 *)a1; + va2 = (vec_bf16 *)a2; + va3 = (vec_bf16 *)a3; + v_x = (vec_bf16 *)x; + BLASLONG n8 = n / 8; + BLASLONG i = 0; + + for (; i + 4 <= n8; i += 4) { + vec_load_pair2(inp, &v_x[i]); + + vec_load_mult44_mma(&temp0[0], &va0[i + 0], &va1[i + 0], &va2[i + 0], &va3[i + 0], inp); + } + + if (n8 & 2) { + vec_load_pair((vec_f32 *)inp, (vec_f32 *)&v_x[i]); + + vec_load_mult24_mma(&temp0[0], &va0[i + 0], &va1[i + 0], &va2[i + 0], &va3[i + 0], inp); + + i += 2; + } + + if (n8 & 1) { + inp[0] = (vec_bf16)vec_load_vec(&v_x[i]); + + vec_load_mult14_mma(&temp0[0], &va0[i], &va1[i], &va2[i], &va3[i], inp[0]); + + i++; + } + + n &= 7; + if (n) { + inp[0] = vec_loadN(&v_x[i], n); + + vec_loadN_mult14_mma(&temp0[0], &va0[i], &va1[i], &va2[i], &va3[i], inp[0], n); + } + + vec_reduce_4(temp00, &temp0[0]); + + vec_f32 t0, t1, t2, t3, t4, t5, t6, t7; + vec_f32 a = { alpha, alpha, alpha, alpha }; + vec_f32 *v_y = (vec_f32 *) y; + + t0 = vec_mergeh(temp00[ 0], temp00[ 4]); + t1 = vec_mergeh(temp00[ 8], temp00[12]); + t2 = vec_mergeo(temp00[ 1], temp00[ 5]); + t3 = vec_mergeo(temp00[ 9], temp00[13]); + t4 = vec_mergel(temp00[ 2], temp00[ 6]); + t5 = vec_mergel(temp00[10], temp00[14]); + t6 = vec_mergeo(temp00[ 3], temp00[ 7]); + t7 = vec_mergeo(temp00[11], temp00[15]); + t0 = vec_xxpermdi(t0, t1, 0); + t2 = vec_xxpermdi(t2, t3, 0); + t4 = vec_xxpermdi(t4, t5, 0); + t6 = vec_xxpermdi(t6, t7, 3); + + t0 += t2 + t4 + t6; + + v_y[0] += (a * t0); +} + +#ifdef USE_BFGEMV_8_T_MMA +static void BF16GEMV_T_MMA_8(BLASLONG n, BLASLONG lda, IFLOAT *ap, IFLOAT *x, FLOAT *y, FLOAT alpha) +{ + IFLOAT *a0, *a1, *a2, *a3, *a4, *a5, *a6, *a7; + vec_bf16 *va0, *va1, *va2, *va3, *va4, *va5, *va6, *va7, *v_x; + __vector_quad temp0[8]; + vec_f32 temp00[4*8]; + vec_bf16 inp[4]; + + vec_setzero_8(&temp0[0]); + + BLASLONG lda4 = lda << 2; + a0 = ap; + a1 = ap + lda; + a2 = a1 + lda; + a3 = a2 + lda; + a4 = a0 + lda4; + a5 = a1 + lda4; + a6 = a2 + lda4; + a7 = a3 + lda4; + va0 = (vec_bf16 *)a0; + va1 = (vec_bf16 *)a1; + va2 = (vec_bf16 *)a2; + va3 = (vec_bf16 *)a3; + va4 = (vec_bf16 *)a4; + va5 = (vec_bf16 *)a5; + va6 = (vec_bf16 *)a6; + va7 = (vec_bf16 *)a7; + v_x = (vec_bf16 *)x; + BLASLONG n8 = n / 8; + BLASLONG i = 0; + + for (; i + 4 <= n8; i += 4) { + vec_load_pair2(inp, &v_x[i]); + + vec_load_mult44_mma(&temp0[0], &va0[i + 0], &va1[i + 0], &va2[i + 0], &va3[i + 0], inp); + vec_load_mult44_mma(&temp0[4], &va4[i + 0], &va5[i + 0], &va6[i + 0], &va7[i + 0], inp); + } + + if (n8 & 2) { + vec_load_pair((vec_f32 *)inp, (vec_f32 *)&v_x[i]); + + vec_load_mult24_mma(&temp0[0], &va0[i + 0], &va1[i + 0], &va2[i + 0], &va3[i + 0], inp); + vec_load_mult24_mma(&temp0[4], &va4[i + 0], &va5[i + 0], &va6[i + 0], &va7[i + 0], inp); + + i += 2; + } + + if (n8 & 1) { + inp[0] = (vec_bf16)vec_load_vec(&v_x[i]); + + vec_load_mult14_mma(&temp0[0], &va0[i], &va1[i], &va2[i], &va3[i], inp[0]); + vec_load_mult14_mma(&temp0[4], &va4[i], &va5[i], &va6[i], &va7[i], inp[0]); + + i++; + } + + n &= 7; + if (n) { + inp[0] = vec_loadN(&v_x[i], n); + + vec_loadN_mult14_mma(&temp0[0], &va0[i], &va1[i], &va2[i], &va3[i], inp[0], n); + vec_loadN_mult14_mma(&temp0[4], &va4[i], &va5[i], &va6[i], 
&va7[i], inp[0], n); + } + + vec_reduce_8(temp00, &temp0[0]); + + vec_f32 t0, t1, t2, t3, t4, t5, t6, t7, t10, t11, t12, t13, t14, t15, t16, t17; + vec_f32 a = { alpha, alpha, alpha, alpha }; + vec_f32 *v_y = (vec_f32 *) y; + + t0 = vec_mergeh(temp00[ 0], temp00[ 4]); + t1 = vec_mergeh(temp00[ 8], temp00[12]); + t2 = vec_mergeo(temp00[ 1], temp00[ 5]); + t3 = vec_mergeo(temp00[ 9], temp00[13]); + t4 = vec_mergel(temp00[ 2], temp00[ 6]); + t5 = vec_mergel(temp00[10], temp00[14]); + t6 = vec_mergeo(temp00[ 3], temp00[ 7]); + t7 = vec_mergeo(temp00[11], temp00[15]); + t0 = vec_xxpermdi(t0, t1, 0); + t2 = vec_xxpermdi(t2, t3, 0); + t4 = vec_xxpermdi(t4, t5, 0); + t6 = vec_xxpermdi(t6, t7, 3); + + t0 += t2 + t4 + t6; + + t10 = vec_mergeh(temp00[16], temp00[20]); + t11 = vec_mergeh(temp00[24], temp00[28]); + t12 = vec_mergeo(temp00[17], temp00[21]); + t13 = vec_mergeo(temp00[25], temp00[29]); + t14 = vec_mergel(temp00[18], temp00[22]); + t15 = vec_mergel(temp00[26], temp00[30]); + t16 = vec_mergeo(temp00[19], temp00[23]); + t17 = vec_mergeo(temp00[27], temp00[31]); + t10 = vec_xxpermdi(t10, t11, 0); + t12 = vec_xxpermdi(t12, t13, 0); + t14 = vec_xxpermdi(t14, t15, 0); + t16 = vec_xxpermdi(t16, t17, 3); + + t10 += t12 + t14 + t16; + + vec_f32 inp2[2]; + vec_load_pair(inp2, v_y); + inp2[0] += (a * t0); + inp2[1] += (a * t10); + vec_store_pair(v_y, inp2); +} +#endif + +#include "sbgemv_t.c" +#else +#include "sbgemv_t_vsx.c" +#endif +#endif + diff --git a/kernel/power/sbgemv_t_vsx.c b/kernel/power/sbgemv_t_vsx.c new file mode 100644 index 0000000000..e72d2f31e0 --- /dev/null +++ b/kernel/power/sbgemv_t_vsx.c @@ -0,0 +1,292 @@ +/*************************************************************************** +Copyright (c) 2024, The OpenBLAS Project +All rights reserved. +Redistribution and use in source and binary forms, with or without +modification, are permitted provided that the following conditions are +met: +1. Redistributions of source code must retain the above copyright +notice, this list of conditions and the following disclaimer. +2. Redistributions in binary form must reproduce the above copyright +notice, this list of conditions and the following disclaimer in +the documentation and/or other materials provided with the +distribution. +3. Neither the name of the OpenBLAS project nor the names of +its contributors may be used to endorse or promote products +derived from this software without specific prior written permission. +THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" +AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE +IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE +ARE DISCLAIMED. IN NO EVENT SHALL THE OPENBLAS PROJECT OR CONTRIBUTORS BE +LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL +DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR +SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER +CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, +OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE +USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. 
+*****************************************************************************/ + +#ifndef SBGEMV_T_VSX_C +#define SBGEMV_T_VSX_C + +#include "sbgemv_common.c" + +#ifndef BF16GEMV_T_X +#define BF16GEMV_T_X +#define BF16GEMV_T_8 BF16GEMV_T_VSX_8 +#define BF16GEMV_T_4 BF16GEMV_T_VSX_4 +#define BF16GEMV_T_2 BF16GEMV_T_VSX_2 +#define BF16GEMV_T_1 BF16GEMV_T_VSX_1 +#endif + +#define USE_BFGEMV_8_T_VSX + +static void BF16GEMV_T_VSX_1(BLASLONG n, BLASLONG lda, IFLOAT *ap, IFLOAT *x, FLOAT *y, FLOAT alpha) +{ + IFLOAT *a0; + vec_bf16 *va0, *v_x; + vec_f32 temp0 = { 0, 0, 0, 0 }; + vec_bf16 zero = { 0, 0, 0, 0, 0, 0, 0, 0 }; + vec_f32 inp[2]; + + a0 = ap; + va0 = (vec_bf16 *)a0; + v_x = (vec_bf16 *)x; + BLASLONG n8 = n / 8; + BLASLONG i = 0; + + for (; i < n8; i++) { + vec_load_vec2(&v_x[i], inp, zero); + + temp0 += vec_load_mult(&va0[i], inp, zero); + } + + n &= 7; + if (n > 4) { + vec_loadN_vec2(&v_x[i], inp, n, zero); + + temp0 += vec_loadN_mult(&va0[i], inp, n, zero); + } else if (n) { + inp[0] = vec_loadNHi(&v_x[i], n, zero); + + temp0 += vec_loadNHi_mult(&va0[i], inp[0], n, zero); + } + + y[0] += (alpha * (temp0[0] + temp0[1] + temp0[2] + temp0[3])); +} + +static void BF16GEMV_T_VSX_2(BLASLONG n, BLASLONG lda, IFLOAT *ap, IFLOAT *x, FLOAT *y, FLOAT alpha) +{ + IFLOAT *a0, *a1; + vec_bf16 *va0, *va1, *v_x; + vec_f32 temp0 = { 0, 0, 0, 0 }; + vec_f32 temp1 = { 0, 0, 0, 0 }; + vec_bf16 zero = { 0, 0, 0, 0, 0, 0, 0, 0 }; + vec_f32 inp[2]; + + a0 = ap; + a1 = ap + lda; + va0 = (vec_bf16 *)a0; + va1 = (vec_bf16 *)a1; + v_x = (vec_bf16 *)x; + BLASLONG n8 = n / 8; + BLASLONG i = 0; + + for (; i < n8; i++) { + vec_load_vec2(&v_x[i], inp, zero); + + temp0 += vec_load_mult(&va0[i], inp, zero); + temp1 += vec_load_mult(&va1[i], inp, zero); + } + + n &= 7; + if (n > 4) { + vec_loadN_vec2(&v_x[i], inp, n, zero); + + temp0 += vec_loadN_mult(&va0[i], inp, n, zero); + temp1 += vec_loadN_mult(&va1[i], inp, n, zero); + } else if (n) { + inp[0] = vec_loadNHi(&v_x[i], n, zero); + + temp0 += vec_loadNHi_mult(&va0[i], inp[0], n, zero); + temp1 += vec_loadNHi_mult(&va1[i], inp[0], n, zero); + } + + y[0] += (alpha * (temp0[0] + temp0[1] + temp0[2] + temp0[3])); + y[1] += (alpha * (temp1[0] + temp1[1] + temp1[2] + temp1[3])); +} + +static void BF16GEMV_T_VSX_4(BLASLONG n, BLASLONG lda, IFLOAT *ap, IFLOAT *x, FLOAT *y, FLOAT alpha) +{ + IFLOAT *a0, *a1, *a2, *a3; + vec_bf16 *va0, *va1, *va2, *va3, *v_x; + vec_f32 temp0 = { 0, 0, 0, 0 }; + vec_f32 temp1 = { 0, 0, 0, 0 }; + vec_f32 temp2 = { 0, 0, 0, 0 }; + vec_f32 temp3 = { 0, 0, 0, 0 }; + vec_bf16 zero = { 0, 0, 0, 0, 0, 0, 0, 0 }; + vec_f32 inp[2]; + + a0 = ap; + a1 = ap + lda; + a2 = a1 + lda; + a3 = a2 + lda; + va0 = (vec_bf16 *)a0; + va1 = (vec_bf16 *)a1; + va2 = (vec_bf16 *)a2; + va3 = (vec_bf16 *)a3; + v_x = (vec_bf16 *)x; + BLASLONG n8 = n / 8; + BLASLONG i = 0; + + for (; i < n8; i++) { + vec_load_vec2(&v_x[i], inp, zero); + + temp0 += vec_load_mult(&va0[i], inp, zero); + temp1 += vec_load_mult(&va1[i], inp, zero); + temp2 += vec_load_mult(&va2[i], inp, zero); + temp3 += vec_load_mult(&va3[i], inp, zero); + } + + n &= 7; + if (n > 4) { + vec_loadN_vec2(&v_x[i], inp, n, zero); + + temp0 += vec_loadN_mult(&va0[i], inp, n, zero); + temp1 += vec_loadN_mult(&va1[i], inp, n, zero); + temp2 += vec_loadN_mult(&va2[i], inp, n, zero); + temp3 += vec_loadN_mult(&va3[i], inp, n, zero); + } else if (n) { + inp[0] = vec_loadNHi(&v_x[i], n, zero); + + temp0 += vec_loadNHi_mult(&va0[i], inp[0], n, zero); + temp1 += vec_loadNHi_mult(&va1[i], inp[0], n, zero); + temp2 += 
vec_loadNHi_mult(&va2[i], inp[0], n, zero); + temp3 += vec_loadNHi_mult(&va3[i], inp[0], n, zero); + } + + vec_f32 t0, t1, t2, t3; + vec_f32 a = { alpha, alpha, alpha, alpha }; + vec_f32 *v_y = (vec_f32 *) y; + + t0 = vec_mergeh(temp0, temp2); + t1 = vec_mergel(temp0, temp2); + t2 = vec_mergeh(temp1, temp3); + t3 = vec_mergel(temp1, temp3); + temp0 = vec_mergeh(t0, t2); + temp1 = vec_mergel(t0, t2); + temp2 = vec_mergeh(t1, t3); + temp3 = vec_mergel(t1, t3); + temp0 += temp1 + temp2 + temp3; + + v_y[0] += (a * temp0); +} + +#ifdef USE_BFGEMV_8_T_VSX +static void BF16GEMV_T_VSX_8(BLASLONG n, BLASLONG lda, IFLOAT *ap, IFLOAT *x, FLOAT *y, FLOAT alpha) +{ + IFLOAT *a0, *a1, *a2, *a3, *a4, *a5, *a6, *a7; + vec_bf16 *va0, *va1, *va2, *va3, *va4, *va5, *va6, *va7, *v_x; + vec_f32 temp0 = { 0, 0, 0, 0 }; + vec_f32 temp1 = { 0, 0, 0, 0 }; + vec_f32 temp2 = { 0, 0, 0, 0 }; + vec_f32 temp3 = { 0, 0, 0, 0 }; + vec_f32 temp4 = { 0, 0, 0, 0 }; + vec_f32 temp5 = { 0, 0, 0, 0 }; + vec_f32 temp6 = { 0, 0, 0, 0 }; + vec_f32 temp7 = { 0, 0, 0, 0 }; + vec_bf16 zero = { 0, 0, 0, 0, 0, 0, 0, 0 }; + vec_f32 inp[2]; + + BLASLONG lda4 = lda << 2; + a0 = ap; + a1 = ap + lda; + a2 = a1 + lda; + a3 = a2 + lda; + a4 = a0 + lda4; + a5 = a1 + lda4; + a6 = a2 + lda4; + a7 = a3 + lda4; + va0 = (vec_bf16 *)a0; + va1 = (vec_bf16 *)a1; + va2 = (vec_bf16 *)a2; + va3 = (vec_bf16 *)a3; + va4 = (vec_bf16 *)a4; + va5 = (vec_bf16 *)a5; + va6 = (vec_bf16 *)a6; + va7 = (vec_bf16 *)a7; + v_x = (vec_bf16 *)x; + BLASLONG n8 = n / 8; + BLASLONG i = 0; + + for (; i < n8; i++) { + vec_load_vec2(&v_x[i], inp, zero); + + temp0 += vec_load_mult(&va0[i], inp, zero); + temp1 += vec_load_mult(&va1[i], inp, zero); + temp2 += vec_load_mult(&va2[i], inp, zero); + temp3 += vec_load_mult(&va3[i], inp, zero); + temp4 += vec_load_mult(&va4[i], inp, zero); + temp5 += vec_load_mult(&va5[i], inp, zero); + temp6 += vec_load_mult(&va6[i], inp, zero); + temp7 += vec_load_mult(&va7[i], inp, zero); + } + + n &= 7; + if (n > 4) { + vec_loadN_vec2(&v_x[i], inp, n, zero); + + temp0 += vec_loadN_mult(&va0[i], inp, n, zero); + temp1 += vec_loadN_mult(&va1[i], inp, n, zero); + temp2 += vec_loadN_mult(&va2[i], inp, n, zero); + temp3 += vec_loadN_mult(&va3[i], inp, n, zero); + temp4 += vec_loadN_mult(&va4[i], inp, n, zero); + temp5 += vec_loadN_mult(&va5[i], inp, n, zero); + temp6 += vec_loadN_mult(&va6[i], inp, n, zero); + temp7 += vec_loadN_mult(&va7[i], inp, n, zero); + } else if (n) { + inp[0] = vec_loadNHi(&v_x[i], n, zero); + + temp0 += vec_loadNHi_mult(&va0[i], inp[0], n, zero); + temp1 += vec_loadNHi_mult(&va1[i], inp[0], n, zero); + temp2 += vec_loadNHi_mult(&va2[i], inp[0], n, zero); + temp3 += vec_loadNHi_mult(&va3[i], inp[0], n, zero); + temp4 += vec_loadNHi_mult(&va4[i], inp[0], n, zero); + temp5 += vec_loadNHi_mult(&va5[i], inp[0], n, zero); + temp6 += vec_loadNHi_mult(&va6[i], inp[0], n, zero); + temp7 += vec_loadNHi_mult(&va7[i], inp[0], n, zero); + } + + vec_f32 t0, t1, t2, t3, t10, t11, t12, t13; + vec_f32 a = { alpha, alpha, alpha, alpha }; + vec_f32 *v_y = (vec_f32 *) y; + + t0 = vec_mergeh(temp0, temp2); + t1 = vec_mergel(temp0, temp2); + t2 = vec_mergeh(temp1, temp3); + t3 = vec_mergel(temp1, temp3); + temp0 = vec_mergeh(t0, t2); + temp1 = vec_mergel(t0, t2); + temp2 = vec_mergeh(t1, t3); + temp3 = vec_mergel(t1, t3); + temp0 += temp1 + temp2 + temp3; + + t10 = vec_mergeh(temp4, temp6); + t11 = vec_mergel(temp4, temp6); + t12 = vec_mergeh(temp5, temp7); + t13 = vec_mergel(temp5, temp7); + temp4 = vec_mergeh(t10, t12); + temp5 = vec_mergel(t10, 
t12); + temp6 = vec_mergeh(t11, t13); + temp7 = vec_mergel(t11, t13); + temp4 += temp5 + temp6 + temp7; + + vec_load_pair(inp, v_y); + inp[0] += (a * temp0); + inp[1] += (a * temp4); + vec_store_pair(v_y, inp); +} +#endif + +#include "sbgemv_t.c" +#endif + diff --git a/test/compare_sgemm_sbgemm.c b/test/compare_sgemm_sbgemm.c index b8aaee8be3..05d9b33aba 100644 --- a/test/compare_sgemm_sbgemm.c +++ b/test/compare_sgemm_sbgemm.c @@ -202,16 +202,17 @@ main (int argc, char *argv[]) return ret; } + for (l = 0; l < 2; l++) { // l = 1 to test inc_x & inc_y not equal to one. for (x = 1; x <= loop; x++) { - k = (x == 0) ? 0 : 1; + k = (x == 0) ? 0 : l + 1; float *A = (float *)malloc_safe(x * x * sizeof(FLOAT)); - float *B = (float *)malloc_safe(x * sizeof(FLOAT)); - float *C = (float *)malloc_safe(x * sizeof(FLOAT)); + float *B = (float *)malloc_safe(x * sizeof(FLOAT) << l); + float *C = (float *)malloc_safe(x * sizeof(FLOAT) << l); bfloat16_bits *AA = (bfloat16_bits *)malloc_safe(x * x * sizeof(bfloat16_bits)); - bfloat16_bits *BB = (bfloat16_bits *)malloc_safe(x * sizeof(bfloat16_bits)); + bfloat16_bits *BB = (bfloat16_bits *)malloc_safe(x * sizeof(bfloat16_bits) << l); float *DD = (float *)malloc_safe(x * sizeof(FLOAT)); - float *CC = (float *)malloc_safe(x * sizeof(FLOAT)); + float *CC = (float *)malloc_safe(x * sizeof(FLOAT) << l); if ((A == NULL) || (B == NULL) || (C == NULL) || (AA == NULL) || (BB == NULL) || (DD == NULL) || (CC == NULL)) return 1; @@ -226,9 +227,9 @@ main (int argc, char *argv[]) sbstobf16_(&one, &A[j*x+i], &one, &atmp, &one); AA[j * x + i].v = atmp; } - B[j] = ((FLOAT) rand () / (FLOAT) RAND_MAX) + 0.5; - sbstobf16_(&one, &B[j], &one, &btmp, &one); - BB[j].v = btmp; + B[j << l] = ((FLOAT) rand () / (FLOAT) RAND_MAX) + 0.5; + sbstobf16_(&one, &B[j << l], &one, &btmp, &one); + BB[j << l].v = btmp; } for (y = 0; y < 2; y++) { @@ -238,9 +239,9 @@ main (int argc, char *argv[]) transA = 'T'; } - memset(CC, 0, x * sizeof(FLOAT)); + memset(CC, 0, x * sizeof(FLOAT) << l); memset(DD, 0, x * sizeof(FLOAT)); - memset(C, 0, x * sizeof(FLOAT)); + memset(C, 0, x * sizeof(FLOAT) << l); SGEMV (&transA, &x, &x, &alpha, A, &x, B, &k, &beta, C, &k); SBGEMV (&transA, &x, &x, &alpha, (bfloat16*) AA, &x, (bfloat16*) BB, &k, &beta, CC, &k); @@ -248,15 +249,15 @@ main (int argc, char *argv[]) for (j = 0; j < x; j++) for (i = 0; i < x; i++) if (transA == 'N') { - DD[i] += float16to32 (AA[j * x + i]) * float16to32 (BB[j]); + DD[i] += float16to32 (AA[j * x + i]) * float16to32 (BB[j << l]); } else if (transA == 'T') { - DD[j] += float16to32 (AA[j * x + i]) * float16to32 (BB[i]); + DD[j] += float16to32 (AA[j * x + i]) * float16to32 (BB[i << l]); } for (j = 0; j < x; j++) { - if (fabs (CC[j] - C[j]) > 1.0) + if (fabs (CC[j << l] - C[j << l]) > 1.0) ret++; - if (fabs (CC[j] - DD[j]) > 1.0) + if (fabs (CC[j << l] - DD[j]) > 1.0) ret++; } } @@ -268,6 +269,7 @@ main (int argc, char *argv[]) free(DD); free(CC); } + } if (ret != 0) fprintf (stderr, "FATAL ERROR SBGEMV - Return code: %d\n", ret); From ab71a1edf24e309f18013b97c6473b92fbfb9608 Mon Sep 17 00:00:00 2001 From: Chip Kerchner Date: Thu, 17 Oct 2024 08:25:02 -0500 Subject: [PATCH 24/24] Better VSX. 
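The restructuring below replaces the fused vec_load_mult/vec_loadN_mult calls in
BF16GEMV_T_VSX_8 with separate vec_load_vec2/vec_loadN_vec2 row loads followed by
explicit multiply-accumulates against both 4-float halves of the already-loaded x
vector, so the eight accumulator chains are expressed as independent statements the
compiler can schedule. A minimal sketch of that access pattern, using GCC's generic
vector extension in place of the kernel's vec_f32 helpers and assuming the bf16
inputs have already been widened to float (v4sf and rows8_accumulate are
illustrative names only):

  typedef float v4sf __attribute__((vector_size(16)));

  /* One iteration of the unrolled loop: load x once, then multiply-accumulate
     both 4-float halves of each of the eight rows into its own accumulator. */
  static void rows8_accumulate(const v4sf *xv, const v4sf *const rows[8],
                               long i, v4sf acc[8])
  {
      v4sf x_lo = xv[2 * i + 0];                 /* shared load of x, low half  */
      v4sf x_hi = xv[2 * i + 1];                 /* shared load of x, high half */
      for (int r = 0; r < 8; r++) {
          acc[r] += x_lo * rows[r][2 * i + 0];   /* per-row, per-half multiply-accumulate */
          acc[r] += x_hi * rows[r][2 * i + 1];
      }
  }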
--- kernel/power/sbgemv_t_vsx.c | 70 +++++++++++++++++++++++++++---------- 1 file changed, 51 insertions(+), 19 deletions(-) diff --git a/kernel/power/sbgemv_t_vsx.c b/kernel/power/sbgemv_t_vsx.c index e72d2f31e0..ecee23a0cf 100644 --- a/kernel/power/sbgemv_t_vsx.c +++ b/kernel/power/sbgemv_t_vsx.c @@ -195,7 +195,7 @@ static void BF16GEMV_T_VSX_8(BLASLONG n, BLASLONG lda, IFLOAT *ap, IFLOAT *x, FL vec_f32 temp6 = { 0, 0, 0, 0 }; vec_f32 temp7 = { 0, 0, 0, 0 }; vec_bf16 zero = { 0, 0, 0, 0, 0, 0, 0, 0 }; - vec_f32 inp[2]; + vec_f32 inp[2], inp0[2], inp1[2], inp2[2], inp3[2], inp4[2], inp5[2], inp6[2], inp7[2]; BLASLONG lda4 = lda << 2; a0 = ap; @@ -220,29 +220,61 @@ static void BF16GEMV_T_VSX_8(BLASLONG n, BLASLONG lda, IFLOAT *ap, IFLOAT *x, FL for (; i < n8; i++) { vec_load_vec2(&v_x[i], inp, zero); - - temp0 += vec_load_mult(&va0[i], inp, zero); - temp1 += vec_load_mult(&va1[i], inp, zero); - temp2 += vec_load_mult(&va2[i], inp, zero); - temp3 += vec_load_mult(&va3[i], inp, zero); - temp4 += vec_load_mult(&va4[i], inp, zero); - temp5 += vec_load_mult(&va5[i], inp, zero); - temp6 += vec_load_mult(&va6[i], inp, zero); - temp7 += vec_load_mult(&va7[i], inp, zero); + vec_load_vec2(&va0[i], inp0, zero); + vec_load_vec2(&va1[i], inp1, zero); + vec_load_vec2(&va2[i], inp2, zero); + vec_load_vec2(&va3[i], inp3, zero); + vec_load_vec2(&va4[i], inp4, zero); + vec_load_vec2(&va5[i], inp5, zero); + vec_load_vec2(&va6[i], inp6, zero); + vec_load_vec2(&va7[i], inp7, zero); + + temp0 += (inp[0] * inp0[0]); + temp1 += (inp[0] * inp1[0]); + temp2 += (inp[0] * inp2[0]); + temp3 += (inp[0] * inp3[0]); + temp4 += (inp[0] * inp4[0]); + temp5 += (inp[0] * inp5[0]); + temp6 += (inp[0] * inp6[0]); + temp7 += (inp[0] * inp7[0]); + temp0 += (inp[1] * inp0[1]); + temp1 += (inp[1] * inp1[1]); + temp2 += (inp[1] * inp2[1]); + temp3 += (inp[1] * inp3[1]); + temp4 += (inp[1] * inp4[1]); + temp5 += (inp[1] * inp5[1]); + temp6 += (inp[1] * inp6[1]); + temp7 += (inp[1] * inp7[1]); } n &= 7; if (n > 4) { vec_loadN_vec2(&v_x[i], inp, n, zero); - - temp0 += vec_loadN_mult(&va0[i], inp, n, zero); - temp1 += vec_loadN_mult(&va1[i], inp, n, zero); - temp2 += vec_loadN_mult(&va2[i], inp, n, zero); - temp3 += vec_loadN_mult(&va3[i], inp, n, zero); - temp4 += vec_loadN_mult(&va4[i], inp, n, zero); - temp5 += vec_loadN_mult(&va5[i], inp, n, zero); - temp6 += vec_loadN_mult(&va6[i], inp, n, zero); - temp7 += vec_loadN_mult(&va7[i], inp, n, zero); + vec_loadN_vec2(&va0[i], inp0, n, zero); + vec_loadN_vec2(&va1[i], inp1, n, zero); + vec_loadN_vec2(&va2[i], inp2, n, zero); + vec_loadN_vec2(&va3[i], inp3, n, zero); + vec_loadN_vec2(&va4[i], inp4, n, zero); + vec_loadN_vec2(&va5[i], inp5, n, zero); + vec_loadN_vec2(&va6[i], inp6, n, zero); + vec_loadN_vec2(&va7[i], inp7, n, zero); + + temp0 += (inp[0] * inp0[0]); + temp1 += (inp[0] * inp1[0]); + temp2 += (inp[0] * inp2[0]); + temp3 += (inp[0] * inp3[0]); + temp4 += (inp[0] * inp4[0]); + temp5 += (inp[0] * inp5[0]); + temp6 += (inp[0] * inp6[0]); + temp7 += (inp[0] * inp7[0]); + temp0 += (inp[1] * inp0[1]); + temp1 += (inp[1] * inp1[1]); + temp2 += (inp[1] * inp2[1]); + temp3 += (inp[1] * inp3[1]); + temp4 += (inp[1] * inp4[1]); + temp5 += (inp[1] * inp5[1]); + temp6 += (inp[1] * inp6[1]); + temp7 += (inp[1] * inp7[1]); } else if (n) { inp[0] = vec_loadNHi(&v_x[i], n, zero);