Skip to content

Commit

Permalink
blas_neon: fix compiler errors in aarch64/Linux
Browse files Browse the repository at this point in the history
With stricter compilers, fp16 codes are not compilable.
To enable testing in non-android, fix type mismatches.

Signed-off-by: MyungJoo Ham <[email protected]>
  • Loading branch information
myungjoo committed Jan 29, 2024
1 parent 4a7f3c2 commit 859a336
Showing 1 changed file with 15 additions and 12 deletions.
27 changes: 15 additions & 12 deletions nntrainer/tensor/blas_neon.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,7 @@
*/

#include <blas_neon.h>
#include <memory>
#include <nntrainer_error.h>
namespace nntrainer::neon {

Expand Down Expand Up @@ -339,7 +340,7 @@ void scopy_neon_int4_to_fp32(const unsigned int N, const uint8_t *X, float *Y) {

// processing remaining batch of 8
for (; (N - idx) >= 8; idx += 8) {
int8x8_t batch = vld1_u8(&X[idx]);
uint8x8_t batch = vld1_u8(&X[idx]);

unsigned int i = 0;
for (; i < 8; ++i) {
Expand Down Expand Up @@ -444,7 +445,7 @@ void sgemv_neon_fp16(const __fp16 *A, const __fp16 *X, __fp16 *Y, uint32_t rows,
uint32_t cols, float alpha, float beta) {
const int batch = 0;
const __fp16 *__restrict x;
float Y32[rows];
float *Y32 = new float[rows];

unsigned int idx = 0;

Expand Down Expand Up @@ -688,20 +689,21 @@ void sgemv_neon_fp16(const __fp16 *A, const __fp16 *X, __fp16 *Y, uint32_t rows,

float32x4_t y0 = vmulq_f32(wvec0_3, x0_3);

for (int k = 0; k < cols - idx; ++k) {
for (unsigned int k = 0; k < cols - idx; ++k) {
Y32[j] += y0[k];
}
}
}

scopy_neon_fp32_to_fp16(rows, Y32, Y);
delete[] Y32;
return;
}

void sgemv_transpose_neon_fp16(const __fp16 *A, const __fp16 *X, __fp16 *Y,
uint32_t rows, uint32_t cols, float alpha,
float beta) {
float Y32[cols];
float *Y32 = new float[cols];
const int batch = 20;
unsigned int idx = 0;
for (; cols - idx >= 8; idx += 8) {
Expand Down Expand Up @@ -1092,7 +1094,7 @@ void sgemv_transpose_neon_fp16(const __fp16 *A, const __fp16 *X, __fp16 *Y,
if (cols != idx) {
float y0_3[4];
float wvec0_3[4];
for (int j = 0; j < cols - idx; ++j) {
for (unsigned int j = 0; j < cols - idx; ++j) {
y0_3[j] = Y32[idx + j];
wvec0_3[j] = A[i * cols + idx + j];
}
Expand All @@ -1111,6 +1113,7 @@ void sgemv_transpose_neon_fp16(const __fp16 *A, const __fp16 *X, __fp16 *Y,
}
}
scopy_neon_fp32_to_fp16(cols, Y32, Y);
delete[] Y32;
return;
}

Expand Down Expand Up @@ -1350,7 +1353,7 @@ void scopy_neon_int4_to_fp16(const unsigned int N, const uint8_t *X,

// processing remaining batch of 8
for (; (N - idx) >= 8; idx += 8) {
int8x8_t batch = vld1_u8(&X[idx]);
uint8x8_t batch = vld1_u8(&X[idx]);

for (int i = 0; i < 8; ++i) {
low0 = batch[i] >> 4;
Expand Down Expand Up @@ -1448,7 +1451,7 @@ void scopy_neon_int8_to_fp32(const unsigned int N, const uint8_t *X, float *Y) {
}

void scopy_neon_fp16_to_fp32(const unsigned int N, const __fp16 *X, float *Y) {
int idx = 0;
unsigned int idx = 0;

for (; N - idx >= 8; idx += 8) {
float32x4_t y1 = vcvt_f32_f16(vld1_f16(&X[idx]));
Expand All @@ -1470,7 +1473,7 @@ void scopy_neon_fp16_to_fp32(const unsigned int N, const __fp16 *X, float *Y) {
}

void scopy_neon_fp32_to_fp16(const unsigned int N, const float *X, __fp16 *Y) {
int idx = 0;
unsigned int idx = 0;

for (; N - idx >= 8; idx += 8) {
float32x4_t x1 = vld1q_f32(&X[idx]);
Expand Down Expand Up @@ -1556,14 +1559,14 @@ void sgemm_neon_fp16(const __fp16 *A, const __fp16 *B, __fp16 *C, uint32_t M,
// performing beta*C
unsigned int idx = 0;
unsigned int size = M * N;
for (; idx < (size - idx) >= 8; idx += 8) {
for (; idx < (size - idx) && (size - idx) >= 8; idx += 8) {
float16x8_t c = vmulq_n_f16(vld1q_f16(&C[idx]), static_cast<__fp16>(beta));

vst1q_f32(&C32[idx], vcvt_f32_f16(vget_low_f16(c)));
vst1q_f32(&C32[idx + 4], vcvt_f32_f16(vget_high_f16(c)));
}
// remaining 4
for (; idx < (size - idx) >= 4; idx += 4) {
for (; idx < (size - idx) && (size - idx) >= 4; idx += 4) {
float16x4_t c = vmul_n_f16(vld1_f16(&C[idx]), static_cast<__fp16>(beta));

vst1q_f32(&C32[idx], vcvt_f32_f16(c));
Expand Down Expand Up @@ -2025,7 +2028,7 @@ void sgemm_neon_fp16_transAB(const __fp16 *A, const __fp16 *B, float *C,
void elementwise_vector_multiplication_neon_fp16(const unsigned int N,
const __fp16 *X,
const __fp16 *Y, __fp16 *Z) {
int i = 0;
unsigned int i = 0;
for (; N - i >= 8; i += 8) {
float16x8_t x0_7 = vld1q_f16(&X[i]);
float16x8_t y0_7 = vld1q_f16(&Y[i]);
Expand All @@ -2042,7 +2045,7 @@ void elementwise_vector_multiplication_neon_fp16(const unsigned int N,
void elementwise_vector_addition_neon_fp16(const unsigned int N,
const __fp16 *X, const __fp16 *Y,
__fp16 *Z) {
int i = 0;
unsigned int i = 0;
for (; N - i >= 8; i += 8) {
float16x8_t x0_7 = vld1q_f16(&X[i]);
float16x8_t y0_7 = vld1q_f16(&Y[i]);
Expand Down

0 comments on commit 859a336

Please sign in to comment.