diff --git a/src/layer/x86/gelu_x86.cpp b/src/layer/x86/gelu_x86.cpp index bfcec9e3d566..352d330b8777 100644 --- a/src/layer/x86/gelu_x86.cpp +++ b/src/layer/x86/gelu_x86.cpp @@ -1,6 +1,6 @@ // Tencent is pleased to support the open source community by making ncnn available. // -// Copyright (C) 2017 THL A29 Limited, a Tencent company. All rights reserved. +// Copyright (C) 2022 THL A29 Limited, a Tencent company. All rights reserved. // // Licensed under the BSD 3-Clause License (the "License"); you may not use this file except // in compliance with the License. You may obtain a copy of the License at @@ -35,8 +35,22 @@ GELU_x86::GELU_x86() #endif // __SSE2__ } +int GELU_x86::create_pipeline(const Option& /*opt*/) +{ + if (!fast_gelu) + { + support_packing = false; + } + return 0; +} + int GELU_x86::forward_inplace(Mat& bottom_top_blob, const Option& opt) const { + if (!fast_gelu) + { + return GELU::forward_inplace(bottom_top_blob, opt); + } + int w = bottom_top_blob.w; int h = bottom_top_blob.h; int elempack = bottom_top_blob.elempack; @@ -51,21 +65,12 @@ int GELU_x86::forward_inplace(Mat& bottom_top_blob, const Option& opt) const int i = 0; #if __SSE2__ - __m128 _half128 = _mm_set1_ps(0.5f); - __m128 _one128 = _mm_set1_ps(1.f); - __m128 _fast1c128 = _mm_set1_ps(0.79788452f); - __m128 _fast2c128 = _mm_set1_ps(0.044715f); #if __AVX__ - __m256 _half256 = _mm256_set1_ps(0.5f); - __m256 _one256 = _mm256_set1_ps(1.f); - __m256 _fast1c256 = _mm256_set1_ps(0.79788452f); - __m256 _fast2c256 = _mm256_set1_ps(0.044715f); #if __AVX512F__ __m512 _half512 = _mm512_set1_ps(0.5f); __m512 _one512 = _mm512_set1_ps(1.f); __m512 _fast1c512 = _mm512_set1_ps(0.79788452f); __m512 _fast2c512 = _mm512_set1_ps(0.044715f); - for (; i + 15 < size; i += 16) { __m512 _pLoad = _mm512_loadu_ps(ptr); @@ -86,6 +91,10 @@ int GELU_x86::forward_inplace(Mat& bottom_top_blob, const Option& opt) const ptr += 16; } #endif // __AVX512F__ + __m256 _half256 = _mm256_set1_ps(0.5f); + __m256 _one256 = _mm256_set1_ps(1.f); + __m256 _fast1c256 = _mm256_set1_ps(0.79788452f); + __m256 _fast2c256 = _mm256_set1_ps(0.044715f); for (; i + 7 < size; i += 8) { __m256 _pLoad = _mm256_loadu_ps(ptr); @@ -106,6 +115,10 @@ int GELU_x86::forward_inplace(Mat& bottom_top_blob, const Option& opt) const ptr += 8; } #endif // __AVX__ + __m128 _half128 = _mm_set1_ps(0.5f); + __m128 _one128 = _mm_set1_ps(1.f); + __m128 _fast1c128 = _mm_set1_ps(0.79788452f); + __m128 _fast2c128 = _mm_set1_ps(0.044715f); for (; i + 3 < size; i += 4) { __m128 _pLoad = _mm_loadu_ps(ptr); diff --git a/src/layer/x86/gelu_x86.h b/src/layer/x86/gelu_x86.h index 8c75aeb88c4a..75d821bfd45d 100644 --- a/src/layer/x86/gelu_x86.h +++ b/src/layer/x86/gelu_x86.h @@ -1,6 +1,6 @@ // Tencent is pleased to support the open source community by making ncnn available. // -// Copyright (C) 2017 THL A29 Limited, a Tencent company. All rights reserved. +// Copyright (C) 2022 THL A29 Limited, a Tencent company. All rights reserved. // // Licensed under the BSD 3-Clause License (the "License"); you may not use this file except // in compliance with the License. You may obtain a copy of the License at @@ -24,6 +24,7 @@ class GELU_x86 : virtual public GELU public: GELU_x86(); + virtual int create_pipeline(const Option& opt); virtual int forward_inplace(Mat& bottom_top_blob, const Option& opt) const; };