Skip to content

Commit

Permalink
Implementation of BF16 based gemv
Browse files Browse the repository at this point in the history
1. Add a new API -- sbgemv to support bfloat16 based gemv
2. Implement a generic kernel for sbgemv
3. Implement an avx512-bf16 based kernel for sbgemv

Signed-off-by: Chen, Guobing <[email protected]>
  • Loading branch information
Guobing-Chen committed Oct 28, 2020
1 parent 67f39ad commit a7b1f9b
Show file tree
Hide file tree
Showing 24 changed files with 5,111 additions and 16 deletions.
1 change: 1 addition & 0 deletions cblas.h
Original file line number Diff line number Diff line change
Expand Up @@ -393,6 +393,7 @@ void cblas_sbf16tos(OPENBLAS_CONST blasint n, OPENBLAS_CONST bfloat16 *in, OPE
void cblas_dbf16tod(OPENBLAS_CONST blasint n, OPENBLAS_CONST bfloat16 *in, OPENBLAS_CONST blasint incin, double *out, OPENBLAS_CONST blasint incout);
/* dot production of BFLOAT16 input arrays, and output as float */
float cblas_sbdot(OPENBLAS_CONST blasint n, OPENBLAS_CONST bfloat16 *x, OPENBLAS_CONST blasint incx, OPENBLAS_CONST bfloat16 *y, OPENBLAS_CONST blasint incy);
void cblas_sbgemv(OPENBLAS_CONST enum CBLAS_ORDER order, OPENBLAS_CONST enum CBLAS_TRANSPOSE trans, OPENBLAS_CONST blasint m, OPENBLAS_CONST blasint n, OPENBLAS_CONST float alpha, OPENBLAS_CONST bfloat16 *a, OPENBLAS_CONST blasint lda, OPENBLAS_CONST bfloat16 *x, OPENBLAS_CONST blasint incx, OPENBLAS_CONST float beta, float *y, OPENBLAS_CONST blasint incy);

#ifdef __cplusplus
}
Expand Down
4 changes: 2 additions & 2 deletions cmake/kernel.cmake
Original file line number Diff line number Diff line change
Expand Up @@ -184,8 +184,8 @@ macro(SetDefaultL2)
set(XHEMV_V_KERNEL ../generic/zhemv_k.c)
set(XHEMV_M_KERNEL ../generic/zhemv_k.c)
if (BUILD_BFLOAT16)
set(SBGEMVNKERNEL ../arm/gemv_n.c)
set(SBGEMVTKERNEL ../arm/gemv_t.c)
set(SBGEMVNKERNEL ../x86_64/sbgemv_n.c)
set(SBGEMVTKERNEL ../x86_64/sbgemv_t.c)
set(SHGERKERNEL ../generic/ger.c)
endif ()
endmacro ()
Expand Down
2 changes: 2 additions & 0 deletions common_interface.h
Original file line number Diff line number Diff line change
Expand Up @@ -250,6 +250,8 @@ void BLASFUNC(xgeru)(blasint *, blasint *, xdouble *, xdouble *, blasint *,
void BLASFUNC(xgerc)(blasint *, blasint *, xdouble *, xdouble *, blasint *,
xdouble *, blasint *, xdouble *, blasint *);

void BLASFUNC(sbgemv)(char *, blasint *, blasint *, float *, bfloat16 *, blasint *,
bfloat16 *, blasint *, float *, float *, blasint *);
void BLASFUNC(sgemv)(char *, blasint *, blasint *, float *, float *, blasint *,
float *, blasint *, float *, float *, blasint *);
void BLASFUNC(dgemv)(char *, blasint *, blasint *, double *, double *, blasint *,
Expand Down
4 changes: 4 additions & 0 deletions common_level2.h
Original file line number Diff line number Diff line change
Expand Up @@ -44,6 +44,10 @@
extern "C" {
#endif

int sbgemv_n(BLASLONG, BLASLONG, float, bfloat16 *, BLASLONG, bfloat16 *, BLASLONG, float, float *, BLASLONG);
int sbgemv_t(BLASLONG, BLASLONG, float, bfloat16 *, BLASLONG, bfloat16 *, BLASLONG, float, float *, BLASLONG);
int sbgemv_thread_n(BLASLONG, BLASLONG, float, bfloat16 *, BLASLONG, bfloat16 *, BLASLONG, float, float *, BLASLONG, int);
int sbgemv_thread_t(BLASLONG, BLASLONG, float, bfloat16 *, BLASLONG, bfloat16 *, BLASLONG, float, float *, BLASLONG, int);
int sger_k (BLASLONG, BLASLONG, BLASLONG, float, float *, BLASLONG, float *, BLASLONG, float *, BLASLONG, float *);
int dger_k (BLASLONG, BLASLONG, BLASLONG, double, double *, BLASLONG, double *, BLASLONG, double *, BLASLONG, double *);
int qger_k (BLASLONG, BLASLONG, BLASLONG, xdouble, xdouble *, BLASLONG, xdouble *, BLASLONG, xdouble *, BLASLONG, xdouble *);
Expand Down
10 changes: 6 additions & 4 deletions common_macro.h
Original file line number Diff line number Diff line change
Expand Up @@ -646,10 +646,12 @@

#elif defined(BFLOAT16)

#define D_TO_BF16_K SBDTOBF16_K
#define D_BF16_TO_K DBF16TOD_K
#define S_TO_BF16_K SBSTOBF16_K
#define S_BF16_TO_K SBF16TOS_K
#define D_TO_BF16_K SBDTOBF16_K
#define D_BF16_TO_K DBF16TOD_K
#define S_TO_BF16_K SBSTOBF16_K
#define S_BF16_TO_K SBF16TOS_K
#define SBGEMV_N SBGEMV_N_K
#define SBGEMV_T SBGEMV_T_K

#define AMAX_K SAMAX_K
#define AMIN_K SAMIN_K
Expand Down
4 changes: 2 additions & 2 deletions common_param.h
Original file line number Diff line number Diff line change
Expand Up @@ -78,8 +78,8 @@ BLASLONG (*isbmin_k) (BLASLONG, float *, BLASLONG);
int (*sbscal_k) (BLASLONG, BLASLONG, BLASLONG, float, float *, BLASLONG, float *, BLASLONG, float *, BLASLONG);
int (*sbswap_k) (BLASLONG, BLASLONG, BLASLONG, float, float *, BLASLONG, float *, BLASLONG, float *, BLASLONG);

int (*sbgemv_n) (BLASLONG, BLASLONG, BLASLONG, float, float *, BLASLONG, float *, BLASLONG, float *, BLASLONG, float *);
int (*sbgemv_t) (BLASLONG, BLASLONG, BLASLONG, float, float *, BLASLONG, float *, BLASLONG, float *, BLASLONG, float *);
int (*sbgemv_n) (BLASLONG, BLASLONG, float, bfloat16 *, BLASLONG, bfloat16 *, BLASLONG, float, float *, BLASLONG);
int (*sbgemv_t) (BLASLONG, BLASLONG, float, bfloat16 *, BLASLONG, bfloat16 *, BLASLONG, float, float *, BLASLONG);
int (*sbger_k) (BLASLONG, BLASLONG, BLASLONG, float, float *, BLASLONG, float *, BLASLONG, float *, BLASLONG, float *);

int (*sbsymv_L) (BLASLONG, BLASLONG, float, float *, BLASLONG, float *, BLASLONG, float *, BLASLONG, float *);
Expand Down
4 changes: 4 additions & 0 deletions common_sb.h
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,8 @@
#define SBDTOBF16_K sbdtobf16_k
#define SBF16TOS_K sbf16tos_k
#define DBF16TOD_K dbf16tod_k
#define SBGEMV_N_K sbgemv_n
#define SBGEMV_T_K sbgemv_t

#define SBGEMM_ONCOPY sbgemm_oncopy
#define SBGEMM_OTCOPY sbgemm_otcopy
Expand All @@ -29,6 +31,8 @@
#define SBDTOBF16_K gotoblas -> sbdtobf16_k
#define SBF16TOS_K gotoblas -> sbf16tos_k
#define DBF16TOD_K gotoblas -> dbf16tod_k
#define SBGEMV_N_K gotoblas -> sbgemv_n
#define SBGEMV_T_K gotoblas -> sbgemv_t

#define SBGEMM_ONCOPY gotoblas -> sbgemm_oncopy
#define SBGEMM_OTCOPY gotoblas -> sbgemm_otcopy
Expand Down
16 changes: 15 additions & 1 deletion driver/level2/Makefile
Original file line number Diff line number Diff line change
Expand Up @@ -413,7 +413,13 @@ XBLASOBJS += \
xtbmv_thread_RUU.$(SUFFIX) xtbmv_thread_RUN.$(SUFFIX) \
xtbmv_thread_RLU.$(SUFFIX) xtbmv_thread_RLN.$(SUFFIX) \
xtbmv_thread_CUU.$(SUFFIX) xtbmv_thread_CUN.$(SUFFIX) \
xtbmv_thread_CLU.$(SUFFIX) xtbmv_thread_CLN.$(SUFFIX) \
xtbmv_thread_CLU.$(SUFFIX) xtbmv_thread_CLN.$(SUFFIX)

ifeq ($(BUILD_BFLOAT16),1)
SBBLASOBJS += \
sbgemv_thread_n$(TSUFFIX).$(SUFFIX) \
sbgemv_thread_t$(TSUFFIX).$(SUFFIX)
endif

endif

Expand Down Expand Up @@ -3693,4 +3699,12 @@ xtrsv_CUU.$(SUFFIX) xtrsv_CUU.$(PSUFFIX) : ztrsv_L.c ../../param.h
xtrsv_CUN.$(SUFFIX) xtrsv_CUN.$(PSUFFIX) : ztrsv_L.c ../../param.h
$(CC) -c $(CFLAGS) -DXDOUBLE -DCOMPLEX -DTRANSA=4 -UUNIT $< -o $(@F)

ifeq ($(BUILD_BFLOAT16),1)
sbgemv_thread_n.$(SUFFIX) sbgemv_thread_n.$(PSUFFIX) : sbgemv_thread.c ../../common.h
$(CC) -c $(CFLAGS) -UCOMPLEX -UDOUBLE -UTRANSA -UCONJ -UXCONJ $< -o $(@F)
sbgemv_thread_t.$(SUFFIX) sbgemv_thread_t.$(PSUFFIX) : sbgemv_thread.c ../../common.h
$(CC) -c $(CFLAGS) -UCOMPLEX -UDOUBLE -DTRANSA -UCONJ -UXCONJ $< -o $(@F)
endif


include ../../Makefile.tail
149 changes: 149 additions & 0 deletions driver/level2/sbgemv_thread.c
Original file line number Diff line number Diff line change
@@ -0,0 +1,149 @@
/*********************************************************************/
/* Copyright 2009, 2010 The University of Texas at Austin. */
/* All rights reserved. */
/* */
/* Redistribution and use in source and binary forms, with or */
/* without modification, are permitted provided that the following */
/* conditions are met: */
/* */
/* 1. Redistributions of source code must retain the above */
/* copyright notice, this list of conditions and the following */
/* disclaimer. */
/* */
/* 2. Redistributions in binary form must reproduce the above */
/* copyright notice, this list of conditions and the following */
/* disclaimer in the documentation and/or other materials */
/* provided with the distribution. */
/* */
/* THIS SOFTWARE IS PROVIDED BY THE UNIVERSITY OF TEXAS AT */
/* AUSTIN ``AS IS'' AND ANY EXPRESS OR IMPLIED WARRANTIES, */
/* INCLUDING, BUT NOT LIMITED TO, THE IMPLIED WARRANTIES OF */
/* MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE */
/* DISCLAIMED. IN NO EVENT SHALL THE UNIVERSITY OF TEXAS AT */
/* AUSTIN OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, */
/* INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES */
/* (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE */
/* GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS; OR */
/* BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF */
/* LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT */
/* (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT */
/* OF THE USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE */
/* POSSIBILITY OF SUCH DAMAGE. */
/* */
/* The views and conclusions contained in the software and */
/* documentation are those of the authors and should not be */
/* interpreted as representing official policies, either expressed */
/* or implied, of The University of Texas at Austin. */
/*********************************************************************/

#include <stdio.h>
#include <stdlib.h>
#include "common.h"

#ifndef TRANSA
#define SBGEMV SBGEMV_N
#else
#define SBGEMV SBGEMV_T
#endif

static int sbgemv_kernel(blas_arg_t *args, BLASLONG *range_m, BLASLONG *range_n, FLOAT *dummy1, FLOAT *dummy2, BLASLONG dummy3){

bfloat16 *a, *x;
float *y;
BLASLONG lda, incx, incy;
BLASLONG m_from, m_to, n_from, n_to;

a = (bfloat16 *)args->a;
x = (bfloat16 *)args->b;
y = (float *)args->c;

lda = args->lda;
incx = args->ldb;
incy = args->ldc;

#ifndef TRANSA // N
m_from = *(range_m + 0);
m_to = *(range_m + 1);
n_from = 0;
n_to = args -> n;
a += m_from;
y += m_from * incy;
#else // T
m_from = 0;
m_to = args->m;
n_from = *(range_n + 0);
n_to = *(range_n + 1);
a += n_from * lda;
y += n_from * incy;
#endif

SBGEMV(m_to - m_from, n_to - n_from, *((FLOAT *)(args->alpha)), a, lda, x, incx, *((FLOAT *)(args->beta)), y, incy);

return 0;
}

int CNAME(BLASLONG m, BLASLONG n, float alpha, bfloat16 *a, BLASLONG lda, bfloat16 *x, BLASLONG incx, float beta, float *y, BLASLONG incy, int threads)
{
blas_arg_t args;
blas_queue_t queue[MAX_CPU_NUMBER];
BLASLONG range[MAX_CPU_NUMBER + 1];

#ifndef TRANSA
BLASLONG width_for_split = m;
#else
BLASLONG width_for_split = n;
#endif

BLASLONG BLOCK_WIDTH = width_for_split/threads;

int mode = BLAS_BFLOAT16 | BLAS_REAL;

args.m = m;
args.n = n;
args.a = (void *)a;
args.b = (void *)x;
args.c = (void *)y;
args.lda = lda;
args.ldb = incx;
args.ldc = incy;
args.alpha = (void *)&alpha;
args.beta = (void *)&beta;

range[0] = 0;

int thread_idx;

for (thread_idx=0; thread_idx<threads; thread_idx++) {
if (thread_idx != threads-1) {
range[thread_idx + 1] = range[thread_idx] + BLOCK_WIDTH;
} else {
range[thread_idx + 1] = range[thread_idx] + width_for_split;
}

queue[thread_idx].mode = mode;
queue[thread_idx].routine = sbgemv_kernel;
queue[thread_idx].args = &args;
#ifndef TRANSA
queue[thread_idx].range_m = &range[thread_idx];
queue[thread_idx].range_n = NULL;
#else
queue[thread_idx].range_m = NULL;
queue[thread_idx].range_n = &range[thread_idx];
#endif
queue[thread_idx].sa = NULL;
queue[thread_idx].sb = NULL;
queue[thread_idx].next = &queue[thread_idx + 1];

width_for_split -= BLOCK_WIDTH;
}

if (thread_idx) {
queue[0].sa = NULL;
queue[0].sb = NULL;
queue[thread_idx - 1].next = NULL;

exec_blas(thread_idx, queue);
}

return 0;
}
1 change: 0 additions & 1 deletion driver/others/blas_server_omp.c
Original file line number Diff line number Diff line change
Expand Up @@ -352,7 +352,6 @@ fprintf(stderr,"UNHANDLED COMPLEX\n");
/* Other types in future */
}
}
if (!sb) fprintf(stderr,"SB not declared!!!\n");
queue->sb=sb;
}
}
Expand Down
4 changes: 2 additions & 2 deletions exports/gensymbol
Original file line number Diff line number Diff line change
Expand Up @@ -51,7 +51,7 @@
zgeadd, dzsum);

@blasobjs = (lsame, xerbla);
@bfblasobjs = (sbgemm, sbdot, sbstobf16, sbdtobf16, sbf16tos, dbf16tod);
@bfblasobjs = (sbgemm, sbgemv, sbdot, sbstobf16, sbdtobf16, sbf16tos, dbf16tod);
@cblasobjsc = (
cblas_caxpy, cblas_ccopy, cblas_cdotc, cblas_cdotu, cblas_cgbmv, cblas_cgemm, cblas_cgemv,
cblas_cgerc, cblas_cgeru, cblas_chbmv, cblas_chemm, cblas_chemv, cblas_cher2, cblas_cher2k,
Expand Down Expand Up @@ -94,7 +94,7 @@

@cblasobjs = ( cblas_xerbla );

@bfcblasobjs = (cblas_sbgemm, cblas_sbdot, cblas_sbstobf16, cblas_sbdtobf16, cblas_sbf16tos, cblas_dbf16tod);
@bfcblasobjs = (cblas_sbgemm, cblas_sbgemv, cblas_sbdot, cblas_sbstobf16, cblas_sbdtobf16, cblas_sbf16tos, cblas_dbf16tod);

@exblasobjs = (
qamax,qamin,qasum,qaxpy,qcabs1,qcopy,qdot,qgbmv,qgemm,
Expand Down
17 changes: 15 additions & 2 deletions interface/Makefile
Original file line number Diff line number Diff line change
Expand Up @@ -48,6 +48,7 @@ SBLAS3OBJS = \

ifeq ($(BUILD_BFLOAT16),1)
SBBLAS1OBJS = sbdot.$(SUFFIX)
SBBLAS2OBJS = sbgemv.$(SUFFIX)
SBBLAS3OBJS = sbgemm.$(SUFFIX)
SBEXTOBJS = sbstobf16.$(SUFFIX) sbdtobf16.$(SUFFIX) sbf16tos.$(SUFFIX) dbf16tod.$(SUFFIX)
endif
Expand Down Expand Up @@ -284,6 +285,7 @@ CSBLAS3OBJS = \

ifeq ($(BUILD_BFLOAT16),1)
CSBBLAS1OBJS = cblas_sbdot.$(SUFFIX)
CSBBLAS2OBJS = cblas_sbgemv.$(SUFFIX)
CSBBLAS3OBJS = cblas_sbgemm.$(SUFFIX)
CSBEXTOBJS = cblas_sbstobf16.$(SUFFIX) cblas_sbdtobf16.$(SUFFIX) cblas_sbf16tos.$(SUFFIX) cblas_dbf16tod.$(SUFFIX)
endif
Expand Down Expand Up @@ -382,6 +384,7 @@ SBLAS1OBJS += $(CSBLAS1OBJS)
SBLAS2OBJS += $(CSBLAS2OBJS)
SBLAS3OBJS += $(CSBLAS3OBJS)
SBBLAS1OBJS += $(CSBBLAS1OBJS)
SBBLAS2OBJS += $(CSBBLAS2OBJS)
SBBLAS3OBJS += $(CSBBLAS3OBJS)
DBLAS1OBJS += $(CDBLAS1OBJS)
DBLAS2OBJS += $(CDBLAS2OBJS)
Expand All @@ -399,7 +402,7 @@ CBAUXOBJS += $(CXERBLAOBJ)
endif

SBLASOBJS = $(SBLAS1OBJS) $(SBLAS2OBJS) $(SBLAS3OBJS)
SBBLASOBJS = $(SBBLAS1OBJS) $(SBBLAS3OBJS)
SBBLASOBJS = $(SBBLAS1OBJS) $(SBBLAS2OBJS) $(SBBLAS3OBJS)
DBLASOBJS = $(DBLAS1OBJS) $(DBLAS2OBJS) $(DBLAS3OBJS)
QBLASOBJS = $(QBLAS1OBJS) $(QBLAS2OBJS) $(QBLAS3OBJS)
CBLASOBJS = $(CBLAS1OBJS) $(CBLAS2OBJS) $(CBLAS3OBJS)
Expand Down Expand Up @@ -538,7 +541,7 @@ clean ::
level1 : $(SBEXTOBJS) $(SBBLAS1OBJS) $(SBLAS1OBJS) $(DBLAS1OBJS) $(QBLAS1OBJS) $(CBLAS1OBJS) $(ZBLAS1OBJS) $(XBLAS1OBJS)
$(AR) $(ARFLAGS) -ru $(TOPDIR)/$(LIBNAME) $^

level2 : $(SBLAS2OBJS) $(DBLAS2OBJS) $(QBLAS2OBJS) $(CBLAS2OBJS) $(ZBLAS2OBJS) $(XBLAS2OBJS)
level2 : $(SBBLAS2OBJS) $(SBLAS2OBJS) $(DBLAS2OBJS) $(QBLAS2OBJS) $(CBLAS2OBJS) $(ZBLAS2OBJS) $(XBLAS2OBJS)
$(AR) $(ARFLAGS) -ru $(TOPDIR)/$(LIBNAME) $^

level3 : $(SBBLAS3OBJS) $(SBLAS3OBJS) $(DBLAS3OBJS) $(QBLAS3OBJS) $(CBLAS3OBJS) $(ZBLAS3OBJS) $(XBLAS3OBJS)
Expand Down Expand Up @@ -929,6 +932,11 @@ xgeru.$(SUFFIX) xgeru.$(PSUFFIX) : zger.c
xgerc.$(SUFFIX) xgerc.$(PSUFFIX) : zger.c
$(CC) -c $(CFLAGS) -DCONJ $< -o $(@F)

ifeq ($(BUILD_BFLOAT16),1)
sbgemv.$(SUFFIX) sbgemv.$(PSUFFIX) : sbgemv.c
$(CC) $(CFLAGS) -c $< -o $(@F)
endif

ifndef USE_NETLIB_GEMV
sgemv.$(SUFFIX) sgemv.$(PSUFFIX): gemv.c
$(CC) -c $(CFLAGS) -o $(@F) $<
Expand Down Expand Up @@ -1656,6 +1664,11 @@ cblas_csscal.$(SUFFIX) cblas_csscal.$(PSUFFIX) : zscal.c
cblas_zdscal.$(SUFFIX) cblas_zdscal.$(PSUFFIX) : zscal.c
$(CC) $(CFLAGS) -DCBLAS -c -DSSCAL $< -o $(@F)

ifeq ($(BUILD_BFLOAT16),1)
cblas_sbgemv.$(SUFFIX) cblas_sbgemv.$(PSUFFIX) : sbgemv.c
$(CC) -DCBLAS -c $(CFLAGS) $< -o $(@F)
endif

cblas_sgemv.$(SUFFIX) cblas_sgemv.$(PSUFFIX): gemv.c
$(CC) -DCBLAS -c $(CFLAGS) -o $(@F) $<

Expand Down
1 change: 0 additions & 1 deletion interface/gemv.c
Original file line number Diff line number Diff line change
Expand Up @@ -191,7 +191,6 @@ void CNAME(enum CBLAS_ORDER order,
}

#endif
//printf("m=%d, n=%d, trans=%d, incx=%d, incy=%d, alpha=%f, beta=%f\n", m, n, trans, incx, incy, alpha, beta);
if ((m==0) || (n==0)) return;

lenx = n;
Expand Down
Loading

0 comments on commit a7b1f9b

Please sign in to comment.