Skip to content

Commit

Permalink
[blas] Fix sgemv and sgemm loop for fp16
Browse files Browse the repository at this point in the history
- Previously, we used full-fp16 variables for sgemv and sgemm loop code.
- However, this practice can cause accumulation error that exceeds our expected epsilon.
- Now, the loops accumulate into an intermediate fp32 value to preserve accuracy and avoid precision loss.

**Self evaluation:**
1. Build test:     [X]Passed [ ]Failed [ ]Skipped
2. Run test:     [X]Passed [ ]Failed [ ]Skipped

Signed-off-by: skykongkong8 <[email protected]>
  • Loading branch information
skykongkong8 committed Oct 4, 2023
1 parent 4f55742 commit 017f73f
Showing 1 changed file with 15 additions and 19 deletions.
34 changes: 15 additions & 19 deletions nntrainer/tensor/blas_interface.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -33,16 +33,16 @@
} \
} while (0);

#define sgemv_loop_fp16(ci, cj, cM, cN) \
do { \
_FP16 y0; \
unsigned int i, j; \
for (ci = 0; ci != cM; ci++) { \
y0 = Y[ci * incy] * static_cast<_FP16>(beta); \
for (cj = 0; cj != cN; cj++) \
y0 += A[i + j * lda] * X[cj * incx]; \
Y[ci * incy] = y0; \
} \
#define sgemv_loop_fp16(ci, cj, cM, cN) \
do { \
float y0; \
unsigned int i, j; \
for (ci = 0; ci != cM; ci++) { \
y0 = static_cast<float>(Y[ci * incy] * static_cast<_FP16>(beta)); \
for (cj = 0; cj != cN; cj++) \
y0 += static_cast<float>(A[i + j * lda] * X[cj * incx]); \
Y[ci * incy] = static_cast<_FP16>(y0); \
} \
} while (0);

#define saxpy_loop_fp16() \
Expand All @@ -56,15 +56,15 @@
do { \
for (unsigned int m = 0; m < M; ++m) { \
for (unsigned int n = 0; n < N; ++n) { \
_FP16 c = 0; \
float c = 0; \
_FP16 c_old = C[m * ldc + n]; \
for (unsigned int k = 0; k < K; ++k) { \
_FP16 a, b; \
a = ((TransA == CblasTrans) ? A[k * lda + m] : A[m * lda + k]); \
b = ((TransB == CblasTrans) ? B[n * ldb + k] : B[k * ldb + n]); \
c += a * b; \
c += static_cast<float>(a * b); \
} \
C[m * ldc + n] = static_cast<_FP16>(alpha) * c; \
C[m * ldc + n] = alpha * c; \
if (beta != 0.0) \
C[m * ldc + n] += static_cast<_FP16>(beta) * c_old; \
} \
Expand Down Expand Up @@ -116,7 +116,6 @@ static void sgemv_FP16(CBLAS_ORDER order, CBLAS_TRANSPOSE TransA,
}
}


static _FP16 sdot_FP16(const unsigned int N, const _FP16 *X,
const unsigned int incX, const _FP16 *Y,
const unsigned int incY) {
Expand Down Expand Up @@ -180,7 +179,6 @@ static void scopy_INT4(const unsigned int N, const uint8_t *X, const int incX,
#endif
}


static void ewvm_FP16(const unsigned int N, const _FP16 *X, const _FP16 *Y,
_FP16 *Z) {
#ifdef USE__FP16
Expand Down Expand Up @@ -247,9 +245,8 @@ static void sgemm_FP16(CBLAS_ORDER order, CBLAS_TRANSPOSE TransA,
const unsigned int ldc) {

#ifdef USE__FP16
nntrainer::neon::sgemm_neon_fp16(A, B, C, M, N, K, alpha, beta,
TransA == CblasTrans,
TransB == CblasTrans);
nntrainer::neon::sgemm_neon_fp16(A, B, C, M, N, K, alpha, beta,
TransA == CblasTrans, TransB == CblasTrans);
#else
sgemm_loop_fp16();
#endif
Expand Down Expand Up @@ -335,7 +332,6 @@ void sgemv(CBLAS_ORDER order, CBLAS_TRANSPOSE TransA, const unsigned int M,
sgemv_FP16(order, TransA, M, N, alpha, A, lda, X, incX, beta, Y, incY);
}


unsigned int isamax(const unsigned int N, const _FP16 *X, const int incX) {
/// @todo isamax_FP16 for BLAS_NUM_THREADS
return isamax_FP16(N, X, incX);
Expand Down

0 comments on commit 017f73f

Please sign in to comment.