Skip to content

Commit

Permalink
Add create_pipeline to disable packing when fast_gelu is off, and improve the formatting
Browse files Browse the repository at this point in the history
  • Loading branch information
LRY89757 committed Sep 17, 2022
1 parent 778bce7 commit 6e9cf57
Show file tree
Hide file tree
Showing 2 changed files with 25 additions and 11 deletions.
33 changes: 23 additions & 10 deletions src/layer/x86/gelu_x86.cpp
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
// Tencent is pleased to support the open source community by making ncnn available.
//
// Copyright (C) 2017 THL A29 Limited, a Tencent company. All rights reserved.
// Copyright (C) 2022 THL A29 Limited, a Tencent company. All rights reserved.
//
// Licensed under the BSD 3-Clause License (the "License"); you may not use this file except
// in compliance with the License. You may obtain a copy of the License at
Expand Down Expand Up @@ -35,8 +35,22 @@ GELU_x86::GELU_x86()
#endif // __SSE2__
}

// Adjusts support_packing according to fast_gelu; the Option argument is ignored.
//
// When fast_gelu is off, forward_inplace() hands the blob straight to
// GELU::forward_inplace(), and packing support is switched off here
// (presumably because the base implementation does not handle packed
// layouts -- TODO confirm). When fast_gelu is on, support_packing is left
// exactly as it was.
//
// Always returns 0.
int GELU_x86::create_pipeline(const Option& /*opt*/)
{
    // Fast path enabled: nothing to change.
    if (fast_gelu)
        return 0;

    support_packing = false;
    return 0;
}

int GELU_x86::forward_inplace(Mat& bottom_top_blob, const Option& opt) const
{
if (!fast_gelu)
{
return GELU::forward_inplace(bottom_top_blob, opt);
}

int w = bottom_top_blob.w;
int h = bottom_top_blob.h;
int elempack = bottom_top_blob.elempack;
Expand All @@ -51,21 +65,12 @@ int GELU_x86::forward_inplace(Mat& bottom_top_blob, const Option& opt) const
int i = 0;

#if __SSE2__
__m128 _half128 = _mm_set1_ps(0.5f);
__m128 _one128 = _mm_set1_ps(1.f);
__m128 _fast1c128 = _mm_set1_ps(0.79788452f);
__m128 _fast2c128 = _mm_set1_ps(0.044715f);
#if __AVX__
__m256 _half256 = _mm256_set1_ps(0.5f);
__m256 _one256 = _mm256_set1_ps(1.f);
__m256 _fast1c256 = _mm256_set1_ps(0.79788452f);
__m256 _fast2c256 = _mm256_set1_ps(0.044715f);
#if __AVX512F__
__m512 _half512 = _mm512_set1_ps(0.5f);
__m512 _one512 = _mm512_set1_ps(1.f);
__m512 _fast1c512 = _mm512_set1_ps(0.79788452f);
__m512 _fast2c512 = _mm512_set1_ps(0.044715f);

for (; i + 15 < size; i += 16)
{
__m512 _pLoad = _mm512_loadu_ps(ptr);
Expand All @@ -86,6 +91,10 @@ int GELU_x86::forward_inplace(Mat& bottom_top_blob, const Option& opt) const
ptr += 16;
}
#endif // __AVX512F__
__m256 _half256 = _mm256_set1_ps(0.5f);
__m256 _one256 = _mm256_set1_ps(1.f);
__m256 _fast1c256 = _mm256_set1_ps(0.79788452f);
__m256 _fast2c256 = _mm256_set1_ps(0.044715f);
for (; i + 7 < size; i += 8)
{
__m256 _pLoad = _mm256_loadu_ps(ptr);
Expand All @@ -106,6 +115,10 @@ int GELU_x86::forward_inplace(Mat& bottom_top_blob, const Option& opt) const
ptr += 8;
}
#endif // __AVX__
__m128 _half128 = _mm_set1_ps(0.5f);
__m128 _one128 = _mm_set1_ps(1.f);
__m128 _fast1c128 = _mm_set1_ps(0.79788452f);
__m128 _fast2c128 = _mm_set1_ps(0.044715f);
for (; i + 3 < size; i += 4)
{
__m128 _pLoad = _mm_loadu_ps(ptr);
Expand Down
3 changes: 2 additions & 1 deletion src/layer/x86/gelu_x86.h
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
// Tencent is pleased to support the open source community by making ncnn available.
//
// Copyright (C) 2017 THL A29 Limited, a Tencent company. All rights reserved.
// Copyright (C) 2022 THL A29 Limited, a Tencent company. All rights reserved.
//
// Licensed under the BSD 3-Clause License (the "License"); you may not use this file except
// in compliance with the License. You may obtain a copy of the License at
Expand All @@ -24,6 +24,7 @@ class GELU_x86 : virtual public GELU
public:
GELU_x86();

virtual int create_pipeline(const Option& opt);
virtual int forward_inplace(Mat& bottom_top_blob, const Option& opt) const;
};

Expand Down

0 comments on commit 6e9cf57

Please sign in to comment.