Stricter armv7 fp16 and armv8.4 bf16 compiler checks, fix #4147
nihui authored Oct 10, 2022
1 parent cef95f6 commit 3e2b3fa
Showing 6 changed files with 153 additions and 86 deletions.
2 changes: 1 addition & 1 deletion CMakeLists.txt
@@ -171,7 +171,7 @@ if((IOS AND CMAKE_OSX_ARCHITECTURES MATCHES "arm")
check_cxx_source_compiles("#include <arm_neon.h>\nint main() { float32x4_t _s; float16x8_t _a, _b; _s = vfmlalq_low_f16(_s, _a, _b); return 0; }" NCNN_COMPILER_SUPPORT_ARM82_FP16FML)

set(CMAKE_REQUIRED_FLAGS "-march=armv8.4-a+bf16")
check_cxx_source_compiles("#include <arm_neon.h>\nint main() { float32x4_t _s; bfloat16x8_t _a, _b; _s = vbfmmlaq_f32(_s, _a, _b); return 0; }" NCNN_COMPILER_SUPPORT_ARM84_BF16)
check_cxx_source_compiles("#include <arm_neon.h>\nint main() { float32x4_t _s; bfloat16x8_t _a, _b; _s = vcvt_f32_bf16(vcvt_bf16_f32(vbfmmlaq_f32(_s, _a, _b))); return 0; }" NCNN_COMPILER_SUPPORT_ARM84_BF16)

set(CMAKE_REQUIRED_FLAGS "-march=armv8.4-a+i8mm")
check_cxx_source_compiles("#include <arm_neon.h>\nint main() { int32x4_t _s; int8x16_t _a, _b; _s = vmmlaq_s32(_s, _a, _b); return 0; }" NCNN_COMPILER_SUPPORT_ARM84_I8MM)
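The stricter probe exercises the bf16 converts as well as the matmul intrinsic, so toolchains whose <arm_neon.h> declares vbfmmlaq_f32 but not vcvt_bf16_f32/vcvt_f32_bf16 no longer pass the check. Pretty-printed, the new test program reads:

// the probe from the check_cxx_source_compiles line above, built with
// CMAKE_REQUIRED_FLAGS "-march=armv8.4-a+bf16"; it only has to compile,
// not run, so the values stay deliberately uninitialized
#include <arm_neon.h>
int main()
{
    float32x4_t _s;
    bfloat16x8_t _a, _b;
    // the round trip through bfloat16x4_t exercises vcvt_bf16_f32 and
    // vcvt_f32_bf16 in addition to the bf16 matrix-multiply intrinsic
    _s = vcvt_f32_bf16(vcvt_bf16_f32(vbfmmlaq_f32(_s, _a, _b)));
    return 0;
}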
2 changes: 1 addition & 1 deletion src/layer/arm/cast_bf16.h
@@ -150,7 +150,7 @@ static void cast_fp32_to_bf16_neon(const Mat& bottom_blob, Mat& top_blob, const

static void cast_bf16_to_fp32_neon(const Mat& bottom_blob, Mat& top_blob, const Option& opt)
{
-#if NCNN_ARM84BF16 && __aarch64__ && !__ARM_FEATURE_BF16_VECTOR_ARITHMETIC
+#if NCNN_RUNTIME_CPU && NCNN_ARM84BF16 && __aarch64__ && !__ARM_FEATURE_BF16_VECTOR_ARITHMETIC
if (ncnn::cpu_support_arm_bf16())
{
cast_bf16_to_fp32_neon_bf16(bottom_blob, top_blob, opt);
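The added NCNN_RUNTIME_CPU term means the bf16 kernel is only reached through the CPU-feature probe when ncnn is built for runtime dispatch. A simplified sketch of the pattern as it reads after this change (names taken from the diff, generic path elided):

static void cast_bf16_to_fp32_neon(const Mat& bottom_blob, Mat& top_blob, const Option& opt)
{
#if NCNN_RUNTIME_CPU && NCNN_ARM84BF16 && __aarch64__ && !__ARM_FEATURE_BF16_VECTOR_ARITHMETIC
    // cast_bf16_to_fp32_neon_bf16 lives in a separate translation unit,
    // presumably compiled with the -march=armv8.4-a+bf16 flags probed in
    // CMakeLists.txt, and only exists when runtime dispatch builds it
    if (ncnn::cpu_support_arm_bf16())
    {
        cast_bf16_to_fp32_neon_bf16(bottom_blob, top_blob, opt);
        return;
    }
#endif
    // ... plain NEON fallback continues below in the actual file ...
}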
171 changes: 119 additions & 52 deletions src/layer/arm/cast_fp16.h
@@ -47,12 +47,12 @@ static void cast_fp32_to_fp16_neon(const Mat& bottom_blob, Mat& top_blob, const
{
#if __aarch64__
asm volatile(
"prfm pldl1keep, [%0, #512] \n"
"prfm pldl1keep, [%0, #512] \n"
"ld1 {v0.4s, v1.4s, v2.4s, v3.4s}, [%0], #64 \n"
"fcvtn v0.4h, v0.4s \n"
"fcvtn v1.4h, v1.4s \n"
"fcvtn v2.4h, v2.4s \n"
"fcvtn v3.4h, v3.4s \n"
"fcvtn v0.4h, v0.4s \n"
"fcvtn v1.4h, v1.4s \n"
"fcvtn v2.4h, v2.4s \n"
"fcvtn v3.4h, v3.4s \n"
"st1 {v0.4h, v1.4h, v2.4h, v3.4h}, [%1], #32 \n"
: "=r"(ptr), // %0
"=r"(outptr) // %1
@@ -61,12 +61,12 @@ static void cast_fp32_to_fp16_neon(const Mat& bottom_blob, Mat& top_blob, const
: "memory", "v0", "v1", "v2", "v3");
#else // __aarch64__
asm volatile(
"pld [%0, #512] \n"
"vldm %0!, {d0-d7} \n"
"vcvt.f16.f32 d0, q0 \n"
"vcvt.f16.f32 d1, q1 \n"
"vcvt.f16.f32 d2, q2 \n"
"vcvt.f16.f32 d3, q3 \n"
"pld [%0, #512] \n"
"vldm %0!, {d0-d7} \n"
"vcvt.f16.f32 d0, q0 \n"
"vcvt.f16.f32 d1, q1 \n"
"vcvt.f16.f32 d2, q2 \n"
"vcvt.f16.f32 d3, q3 \n"
"vst1.u16 {d0-d3}, [%1 :128]! \n"
: "=r"(ptr), // %0
"=r"(outptr) // %1
@@ -77,24 +77,61 @@ static void cast_fp32_to_fp16_neon(const Mat& bottom_blob, Mat& top_blob, const
}
for (; i + 7 < size; i += 8)
{
-        float32x4_t _p0_fp32 = vld1q_f32(ptr);
-        float32x4_t _p1_fp32 = vld1q_f32(ptr + 4);
-        float16x4_t _p0_fp16 = vcvt_f16_f32(_p0_fp32);
-        float16x4_t _p1_fp16 = vcvt_f16_f32(_p1_fp32);
-        uint16x8_t _p_fp16 = vcombine_u16(vreinterpret_u16_f16(_p0_fp16), vreinterpret_u16_f16(_p1_fp16));
-        vst1q_u16(outptr, _p_fp16);
-        ptr += 8;
-        outptr += 8;
+        // This was originally implemented with NEON fp16 intrinsics.
+        // Newer gcc requires __ARM_FP16_FORMAT_IEEE or __ARM_FP16_FORMAT_ALTERNATIVE to be defined before the float16x4_t type can be used.
+        // That leads to a compiler error when building with -mfpu=neon-vfpv4 but without the -mfp16-format=ieee flag.
+        // We could add more macro conditions to distinguish old and new compiler versions, but that's pretty ugly!
+        // Just use inline assembly for all of it ~
+        // --- nihui
+#if __aarch64__
+        asm volatile(
+            "ld1 {v0.4s, v1.4s}, [%0], #32 \n"
+            "fcvtn v0.4h, v0.4s \n"
+            "fcvtn v1.4h, v1.4s \n"
+            "st1 {v0.4h, v1.4h}, [%1], #16 \n"
+            : "=r"(ptr), // %0
+            "=r"(outptr) // %1
+            : "0"(ptr),
+            "1"(outptr)
+            : "memory", "v0", "v1");
+#else // __aarch64__
+        asm volatile(
+            "vld1.f32 {d0-d3}, [%0]! \n"
+            "vcvt.f16.f32 d0, q0 \n"
+            "vcvt.f16.f32 d1, q1 \n"
+            "vst1.u16 {d0-d1}, [%1]! \n"
+            : "=r"(ptr), // %0
+            "=r"(outptr) // %1
+            : "0"(ptr),
+            "1"(outptr)
+            : "memory", "q0", "q1");
+#endif // __aarch64__
}
for (; i + 3 < size; i += 4)
{
-        float32x4_t _p_fp32 = vld1q_f32(ptr);
-        float16x4_t _p_fp16 = vcvt_f16_f32(_p_fp32);
-        vst1_u16(outptr, vreinterpret_u16_f16(_p_fp16));
-        ptr += 4;
-        outptr += 4;
+#if __aarch64__
+        asm volatile(
+            "ld1 {v0.4s}, [%0], #16 \n"
+            "fcvtn v0.4h, v0.4s \n"
+            "st1 {v0.4h}, [%1], #8 \n"
+            : "=r"(ptr), // %0
+            "=r"(outptr) // %1
+            : "0"(ptr),
+            "1"(outptr)
+            : "memory", "v0");
+#else // __aarch64__
+        asm volatile(
+            "vld1.f32 {d0-d1}, [%0]! \n"
+            "vcvt.f16.f32 d0, q0 \n"
+            "vst1.u16 {d0}, [%1]! \n"
+            : "=r"(ptr), // %0
+            "=r"(outptr) // %1
+            : "0"(ptr),
+            "1"(outptr)
+            : "memory", "q0");
+#endif // __aarch64__
}
-#endif
+#endif // (__ARM_FP & 2)
for (; i < size; i++)
{
*outptr++ = float32_to_float16(*ptr++);
@@ -104,7 +141,7 @@ static void cast_fp32_to_fp16_neon(const Mat& bottom_blob, Mat& top_blob, const

static void cast_fp16_to_fp32_neon(const Mat& bottom_blob, Mat& top_blob, const Option& opt)
{
-#if NCNN_VFPV4 && __ARM_NEON && !(__ARM_FP & 2)
+#if NCNN_RUNTIME_CPU && NCNN_VFPV4 && __ARM_NEON && !(__ARM_FP & 2)
if (ncnn::cpu_support_arm_vfpv4())
{
cast_fp16_to_fp32_neon_vfpv4(bottom_blob, top_blob, opt);
@@ -132,12 +169,12 @@ static void cast_fp16_to_fp32_neon(const Mat& bottom_blob, Mat& top_blob, const
{
#if __aarch64__
asm volatile(
"prfm pldl1keep, [%0, #256] \n"
"prfm pldl1keep, [%0, #256] \n"
"ld1 {v0.4h, v1.4h, v2.4h, v3.4h}, [%0], #32 \n"
"fcvtl v0.4s, v0.4h \n"
"fcvtl v1.4s, v1.4h \n"
"fcvtl v2.4s, v2.4h \n"
"fcvtl v3.4s, v3.4h \n"
"fcvtl v0.4s, v0.4h \n"
"fcvtl v1.4s, v1.4h \n"
"fcvtl v2.4s, v2.4h \n"
"fcvtl v3.4s, v3.4h \n"
"st1 {v0.4s, v1.4s, v2.4s, v3.4s}, [%1], #64 \n"
: "=r"(ptr), // %0
"=r"(outptr) // %1
@@ -146,13 +183,13 @@ static void cast_fp16_to_fp32_neon(const Mat& bottom_blob, Mat& top_blob, const
: "memory", "v0", "v1", "v2", "v3");
#else // __aarch64__
asm volatile(
"pld [%0, #256] \n"
"pld [%0, #256] \n"
"vld1.u16 {d4-d7}, [%0 :128]! \n"
"vcvt.f32.f16 q0, d4 \n"
"vcvt.f32.f16 q1, d5 \n"
"vcvt.f32.f16 q2, d6 \n"
"vcvt.f32.f16 q3, d7 \n"
"vstm %1!, {d0-d7} \n"
"vcvt.f32.f16 q0, d4 \n"
"vcvt.f32.f16 q1, d5 \n"
"vcvt.f32.f16 q2, d6 \n"
"vcvt.f32.f16 q3, d7 \n"
"vstm %1!, {d0-d7} \n"
: "=r"(ptr), // %0
"=r"(outptr) // %1
: "0"(ptr),
@@ -162,25 +199,55 @@ static void cast_fp16_to_fp32_neon(const Mat& bottom_blob, Mat& top_blob, const
}
for (; i + 7 < size; i += 8)
{
-        uint16x8_t _p_fp16 = vld1q_u16(ptr);
-        float16x4_t _p0_fp16 = vreinterpret_f16_u16(vget_low_u16(_p_fp16));
-        float16x4_t _p1_fp16 = vreinterpret_f16_u16(vget_high_u16(_p_fp16));
-        float32x4_t _p0_fp32 = vcvt_f32_f16(_p0_fp16);
-        float32x4_t _p1_fp32 = vcvt_f32_f16(_p1_fp16);
-        vst1q_f32(outptr, _p0_fp32);
-        vst1q_f32(outptr + 4, _p1_fp32);
-        ptr += 8;
-        outptr += 8;
+#if __aarch64__
+        asm volatile(
+            "ld1 {v0.4h, v1.4h}, [%0], #16 \n"
+            "fcvtl v0.4s, v0.4h \n"
+            "fcvtl v1.4s, v1.4h \n"
+            "st1 {v0.4s, v1.4s}, [%1], #32 \n"
+            : "=r"(ptr), // %0
+            "=r"(outptr) // %1
+            : "0"(ptr),
+            "1"(outptr)
+            : "memory", "v0", "v1");
+#else // __aarch64__
+        asm volatile(
+            "vld1.u16 {d4-d5}, [%0]! \n"
+            "vcvt.f32.f16 q0, d4 \n"
+            "vcvt.f32.f16 q1, d5 \n"
+            "vst1.f32 {d0-d3}, [%1]! \n"
+            : "=r"(ptr), // %0
+            "=r"(outptr) // %1
+            : "0"(ptr),
+            "1"(outptr)
+            : "memory", "q0", "q1", "q2");
+#endif // __aarch64__
}
for (; i + 3 < size; i += 4)
{
-        float16x4_t _p_fp16 = vreinterpret_f16_u16(vld1_u16(ptr));
-        float32x4_t _p_fp32 = vcvt_f32_f16(_p_fp16);
-        vst1q_f32(outptr, _p_fp32);
-        ptr += 4;
-        outptr += 4;
+#if __aarch64__
+        asm volatile(
+            "ld1 {v0.4h}, [%0], #8 \n"
+            "fcvtl v0.4s, v0.4h \n"
+            "st1 {v0.4s}, [%1], #16 \n"
+            : "=r"(ptr), // %0
+            "=r"(outptr) // %1
+            : "0"(ptr),
+            "1"(outptr)
+            : "memory", "v0");
+#else // __aarch64__
+        asm volatile(
+            "vld1.u16 {d2}, [%0]! \n"
+            "vcvt.f32.f16 q0, d2 \n"
+            "vst1.f32 {d0-d1}, [%1]! \n"
+            : "=r"(ptr), // %0
+            "=r"(outptr) // %1
+            : "0"(ptr),
+            "1"(outptr)
+            : "memory", "q0", "q1");
+#endif // __aarch64__
}
-#endif
+#endif // (__ARM_FP & 2)
for (; i < size; i++)
{
*outptr++ = float16_to_float32(*ptr++);
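The long comment nihui added above pins down the root cause behind #4147. A minimal sketch of the failure mode (a hypothetical file, not the exact reproducer from the issue):

#include <arm_neon.h>

// builds with: -mfpu=neon-vfpv4 -mfp16-format=ieee
// fails with:  -mfpu=neon-vfpv4 alone on newer gcc, where float16x4_t is
//              only declared once __ARM_FP16_FORMAT_IEEE or
//              __ARM_FP16_FORMAT_ALTERNATIVE is defined
float16x4_t narrow_to_fp16(float32x4_t v)
{
    return vcvt_f16_f32(v);
}

Rewriting the loop bodies in inline assembly sidesteps the type entirely: fcvtn / vcvt.f16.f32 operate on registers, and the C side only touches float and unsigned short pointers.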
28 changes: 14 additions & 14 deletions src/layer/arm/innerproduct_fp16s.h
@@ -253,10 +253,10 @@ static void innerproduct_pack4_fp16s_neon(const Mat& bottom_blob, Mat& top_blob,
float32x4_t _val = vld1q_f32(sptr);
uint16x8_t _w01 = vld1q_u16(kptr);
uint16x8_t _w23 = vld1q_u16(kptr + 8);
-        float32x4_t _w0 = vcvt_f32_f16(vreinterpret_f16_u16(vget_low_u16(_w01)));
-        float32x4_t _w1 = vcvt_f32_f16(vreinterpret_f16_u16(vget_high_u16(_w01)));
-        float32x4_t _w2 = vcvt_f32_f16(vreinterpret_f16_u16(vget_low_u16(_w23)));
-        float32x4_t _w3 = vcvt_f32_f16(vreinterpret_f16_u16(vget_high_u16(_w23)));
+        float32x4_t _w0 = vcvt_f32_f16((float16x4_t)(vget_low_u16(_w01)));
+        float32x4_t _w1 = vcvt_f32_f16((float16x4_t)(vget_high_u16(_w01)));
+        float32x4_t _w2 = vcvt_f32_f16((float16x4_t)(vget_low_u16(_w23)));
+        float32x4_t _w3 = vcvt_f32_f16((float16x4_t)(vget_high_u16(_w23)));
#endif

#if __aarch64__
@@ -281,7 +281,7 @@ static void innerproduct_pack4_fp16s_neon(const Mat& bottom_blob, Mat& top_blob,
#if __ARM_FEATURE_FP16_VECTOR_ARITHMETIC
float32x4_t _w = vcvt_f32_f16(vld1_f16(kptr));
#else
-        float32x4_t _w = vcvt_f32_f16(vreinterpret_f16_u16(vld1_u16(kptr)));
+        float32x4_t _w = vcvt_f32_f16((float16x4_t)(vld1_u16(kptr)));
#endif
_sum0 = vfmaq_f32(_sum0, _val, _w);

@@ -410,10 +410,10 @@ static void innerproduct_fp16s_neon(const Mat& bottom_blob, Mat& top_blob, const
float32x4_t _w3 = vcvt_f32_f16(vld1_f16(kptr3));
#else
float32x4_t _val = vld1q_f32(sptr);
-        float32x4_t _w0 = vcvt_f32_f16(vreinterpret_f16_u16(vld1_u16(kptr0)));
-        float32x4_t _w1 = vcvt_f32_f16(vreinterpret_f16_u16(vld1_u16(kptr1)));
-        float32x4_t _w2 = vcvt_f32_f16(vreinterpret_f16_u16(vld1_u16(kptr2)));
-        float32x4_t _w3 = vcvt_f32_f16(vreinterpret_f16_u16(vld1_u16(kptr3)));
+        float32x4_t _w0 = vcvt_f32_f16((float16x4_t)(vld1_u16(kptr0)));
+        float32x4_t _w1 = vcvt_f32_f16((float16x4_t)(vld1_u16(kptr1)));
+        float32x4_t _w2 = vcvt_f32_f16((float16x4_t)(vld1_u16(kptr2)));
+        float32x4_t _w3 = vcvt_f32_f16((float16x4_t)(vld1_u16(kptr3)));
#endif

_sum0 = vfmaq_f32(_sum0, _val, _w0);
@@ -507,7 +507,7 @@ static void innerproduct_fp16s_neon(const Mat& bottom_blob, Mat& top_blob, const
float32x4_t _w = vcvt_f32_f16(vld1_f16(kptr));
#else
float32x4_t _val = vld1q_f32(sptr);
-        float32x4_t _w = vcvt_f32_f16(vreinterpret_f16_u16(vld1_u16(kptr)));
+        float32x4_t _w = vcvt_f32_f16((float16x4_t)(vld1_u16(kptr)));
#endif
_sum = vfmaq_f32(_sum, _val, _w);

@@ -713,10 +713,10 @@ static void innerproduct_transform_kernel_fp16s_neon(const Mat& weight_data, Mat
{
// transpose 4x4
uint16x4x4_t _p;
-        _p.val[0] = vreinterpret_u16_f16(vcvt_f16_f32(vld1q_f32(k0)));
-        _p.val[1] = vreinterpret_u16_f16(vcvt_f16_f32(vld1q_f32(k1)));
-        _p.val[2] = vreinterpret_u16_f16(vcvt_f16_f32(vld1q_f32(k2)));
-        _p.val[3] = vreinterpret_u16_f16(vcvt_f16_f32(vld1q_f32(k3)));
+        _p.val[0] = (uint16x4_t)(vcvt_f16_f32(vld1q_f32(k0)));
+        _p.val[1] = (uint16x4_t)(vcvt_f16_f32(vld1q_f32(k1)));
+        _p.val[2] = (uint16x4_t)(vcvt_f16_f32(vld1q_f32(k2)));
+        _p.val[3] = (uint16x4_t)(vcvt_f16_f32(vld1q_f32(k3)));
vst4_u16(g0, _p);

k0 += 4;
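Every innerproduct change in this commit swaps vreinterpret_f16_u16 / vreinterpret_u16_f16 for a plain vector cast. The commit doesn't spell out why; a plausible reading (an assumption, not stated in the source) is that some gcc releases define float16x4_t yet omit the f16 reinterpret prototypes, so the cast compiles wherever the type exists. Side by side, with kptr as in the diff:

#include <arm_neon.h>

// hypothetical helper contrasting both forms; kptr points at packed fp16 weights
static inline void load_w(const unsigned short* kptr, float32x4_t* w_before, float32x4_t* w_after)
{
    uint16x4_t bits = vld1_u16(kptr); // raw fp16 bit patterns

    // before: requires the vreinterpret_f16_u16 prototype in <arm_neon.h>
    *w_before = vcvt_f32_f16(vreinterpret_f16_u16(bits));

    // after: a compiler-level cast between same-sized vector types,
    // no reinterpret intrinsic involved
    *w_after = vcvt_f32_f16((float16x4_t)bits);
}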
20 changes: 10 additions & 10 deletions src/layer/arm/innerproduct_gemm_fp16s.h
@@ -120,7 +120,7 @@ static void innerproduct_gemm_fp16s_neon(const Mat& bottom_blob, Mat& top_blob,
float32x4_t _w = vcvt_f32_f16(vld1_f16(kptr));
#else
float32x4_t _val = vld1q_f32(m);
-        float32x4_t _w = vcvt_f32_f16(vreinterpret_f16_u16(vld1_u16(kptr)));
+        float32x4_t _w = vcvt_f32_f16((float16x4_t)(vld1_u16(kptr)));
#endif

#if __aarch64__
@@ -214,10 +214,10 @@ static void innerproduct_gemm_fp16s_neon(const Mat& bottom_blob, Mat& top_blob,
float32x4_t _val = vld1q_f32(m);
uint16x8_t _w01 = vld1q_u16(kptr);
uint16x8_t _w23 = vld1q_u16(kptr + 8);
-        float32x4_t _w0 = vcvt_f32_f16(vreinterpret_f16_u16(vget_low_u16(_w01)));
-        float32x4_t _w1 = vcvt_f32_f16(vreinterpret_f16_u16(vget_high_u16(_w01)));
-        float32x4_t _w2 = vcvt_f32_f16(vreinterpret_f16_u16(vget_low_u16(_w23)));
-        float32x4_t _w3 = vcvt_f32_f16(vreinterpret_f16_u16(vget_high_u16(_w23)));
+        float32x4_t _w0 = vcvt_f32_f16((float16x4_t)(vget_low_u16(_w01)));
+        float32x4_t _w1 = vcvt_f32_f16((float16x4_t)(vget_high_u16(_w01)));
+        float32x4_t _w2 = vcvt_f32_f16((float16x4_t)(vget_low_u16(_w23)));
+        float32x4_t _w3 = vcvt_f32_f16((float16x4_t)(vget_high_u16(_w23)));
#endif

#if __aarch64__
@@ -242,7 +242,7 @@ static void innerproduct_gemm_fp16s_neon(const Mat& bottom_blob, Mat& top_blob,
#if __ARM_FEATURE_FP16_VECTOR_ARITHMETIC
float32x4_t _w = vcvt_f32_f16(vld1_f16(kptr));
#else
-        float32x4_t _w = vcvt_f32_f16(vreinterpret_f16_u16(vld1_u16(kptr)));
+        float32x4_t _w = vcvt_f32_f16((float16x4_t)(vld1_u16(kptr)));
#endif
_sum0 = vfmaq_f32(_sum0, _val, _w);

@@ -317,7 +317,7 @@ static void innerproduct_gemm_fp16s_neon(const Mat& bottom_blob, Mat& top_blob,
float32x4_t _val1 = vld1q_f32(m + 4);
float32x4_t _val2 = vld1q_f32(m + 8);
float32x4_t _val3 = vld1q_f32(m + 12);
-        float32x4_t _w = vcvt_f32_f16(vreinterpret_f16_u16(vld1_u16(kptr)));
+        float32x4_t _w = vcvt_f32_f16((float16x4_t)(vld1_u16(kptr)));
#endif

#if __aarch64__
@@ -414,8 +414,8 @@ static void innerproduct_gemm_fp16s_neon(const Mat& bottom_blob, Mat& top_blob,
float32x4_t _val0 = vld1q_f32(m);
float32x4_t _val1 = vld1q_f32(m + 4);
uint16x8_t _w01 = vld1q_u16(kptr);
-        float32x4_t _w0 = vcvt_f32_f16(vreinterpret_f16_u16(vget_low_u16(_w01)));
-        float32x4_t _w1 = vcvt_f32_f16(vreinterpret_f16_u16(vget_high_u16(_w01)));
+        float32x4_t _w0 = vcvt_f32_f16((float16x4_t)(vget_low_u16(_w01)));
+        float32x4_t _w1 = vcvt_f32_f16((float16x4_t)(vget_high_u16(_w01)));
#endif

_sum0 = vfmaq_f32(_sum0, _val0, _w0);
@@ -433,7 +433,7 @@ static void innerproduct_gemm_fp16s_neon(const Mat& bottom_blob, Mat& top_blob,
float32x4_t _w = vcvt_f32_f16(vld1_f16(kptr));
#else
float32x4_t _val = vld1q_f32(m);
-        float32x4_t _w = vcvt_f32_f16(vreinterpret_f16_u16(vld1_u16(kptr)));
+        float32x4_t _w = vcvt_f32_f16((float16x4_t)(vld1_u16(kptr)));
#endif

_sum0 = vfmaq_f32(_sum0, _val, _w);
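For reference, the guard these gemm kernels (like the plain fp16s kernels) keep around every weight load, quoted verbatim as it reads after this commit: builds with fp16 vector arithmetic load half floats directly, and everything else loads raw u16 bits and converts through the cast.

#if __ARM_FEATURE_FP16_VECTOR_ARITHMETIC
    float32x4_t _w = vcvt_f32_f16(vld1_f16(kptr));
#else
    float32x4_t _w = vcvt_f32_f16((float16x4_t)(vld1_u16(kptr)));
#endif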