Skip to content

Commit

Permalink
Revise swish (#5983)
Browse files Browse the repository at this point in the history
* Update Swish OP description.

Signed-off-by: Luwei Zhou <[email protected]>

* Use RTTI to declare/define NGraph Swish OP.
Add input element type check when constructing Swish OP.

Signed-off-by: Luwei Zhou <[email protected]>

* Add Swish into activation serialization test list.

Signed-off-by: Luwei Zhou <[email protected]>

* Add Swish into IE CPU plugin activation single layer test suite.

Signed-off-by: Luwei Zhou <[email protected]>

* Add Swish NGraph backend and visitor API tests.

Signed-off-by: Luwei Zhou <[email protected]>

* Add Swish unsupported parameter data type test cases.

Signed-off-by: Luwei Zhou <[email protected]>

* Update the Swish OP visitor API to use typed test.

Signed-off-by: Luwei Zhou <[email protected]>
  • Loading branch information
luweizhou2016 authored Jun 23, 2021
1 parent f72e248 commit c483cdc
Show file tree
Hide file tree
Showing 9 changed files with 115 additions and 5 deletions.
3 changes: 2 additions & 1 deletion docs/ops/activation/Swish_4.md
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,8 @@
**Detailed description**

*Swish* operation is introduced in this [article](https://arxiv.org/abs/1710.05941).
It performs element-wise activation function on a given input tensor, based on the following mathematical formula:

*Swish* is a smooth, non-monotonic function. The non-monotonicity property of *Swish* distinguishes it from most common activation functions. It performs element-wise activation function on a given input tensor, based on the following mathematical formula:

\f[
Swish(x) = x\cdot \sigma(\beta x) = x \left(1 + e^{-(\beta x)}\right)^{-1}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -52,6 +52,7 @@ const std::map<ActivationTypes, std::vector<std::vector<float>>> activationTypes
{Ceiling, {}},
{Mish, {}},
{HSwish, {}},
{Swish, {{0.3f}}},
{SoftPlus, {}},
{HSigmoid, {}},
{RoundHalfToEven, {}},
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -59,7 +59,8 @@ const std::map<ActivationTypes, std::vector<std::vector<float>>> activationTypes
{RoundHalfAwayFromZero, {}},
{Erf, {}},
{GeluErf, {}},
{GeluTanh, {}}
{GeluTanh, {}},
{Swish, {{0.4f}}}
};

// List of operations that should be tested also with integer precision
Expand Down
3 changes: 1 addition & 2 deletions ngraph/core/include/ngraph/op/swish.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -20,8 +20,7 @@ namespace ngraph
class NGRAPH_API Swish : public ngraph::op::Op
{
public:
static constexpr NodeTypeInfo type_info{"Swish", 4};
const NodeTypeInfo& get_type_info() const override { return type_info; }
NGRAPH_RTTI_DECLARATION;
Swish() = default;

/// \brief Constructs an Swish operation.
Expand Down
8 changes: 7 additions & 1 deletion ngraph/core/src/op/swish.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,7 @@
using namespace std;
using namespace ngraph;

constexpr NodeTypeInfo op::v4::Swish::type_info;
NGRAPH_RTTI_DEFINITION(op::v4::Swish, "Swish", 4);

op::v4::Swish::Swish(const Output<Node>& arg)
: Op({arg})
Expand Down Expand Up @@ -43,6 +43,12 @@ void op::v4::Swish::validate_and_infer_types()
"Swish must have 1 or 2 inputs, but it has: ",
inputs_count);

NODE_VALIDATION_CHECK(this,
get_input_element_type(0).is_real(),
"Swish input tensor must be floating point type(",
get_input_element_type(0),
").");

if (inputs_count == 2)
{
NODE_VALIDATION_CHECK(this,
Expand Down
2 changes: 2 additions & 0 deletions ngraph/test/CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -300,6 +300,7 @@ set(SRC
visitors/op/squeeze.cpp
visitors/op/sqrt.cpp
visitors/op/strided_slice.cpp
visitors/op/swish.cpp
visitors/op/tanh.cpp
visitors/op/topk.cpp
visitors/op/transpose.cpp
Expand Down Expand Up @@ -476,6 +477,7 @@ set(MULTI_TEST_SRC
backend/squared_difference.in.cpp
backend/squeeze.in.cpp
backend/subtract.in.cpp
backend/swish.in.cpp
backend/tan.in.cpp
backend/tanh.in.cpp
backend/tile.in.cpp
Expand Down
75 changes: 75 additions & 0 deletions ngraph/test/backend/swish.in.cpp
Original file line number Diff line number Diff line change
@@ -0,0 +1,75 @@
// Copyright (C) 2018-2021 Intel Corporation
// SPDX-License-Identifier: Apache-2.0
//
#include "gtest/gtest.h"
#include "ngraph/ngraph.hpp"
#include "util/engine/test_engines.hpp"
#include "util/test_case.hpp"
#include "util/test_control.hpp"

using namespace std;
using namespace ngraph;

static string s_manifest = "${MANIFEST}";
using TestEngine = test::ENGINE_CLASS_NAME(${BACKEND_NAME});

NGRAPH_TEST(${BACKEND_NAME}, swish_2D_with_beta0_6)
{
    // Swish with an explicit beta input (beta = 0.6) on a 2D tensor.
    // Reference: Swish(x) = x * sigmoid(beta * x) = x / (1 + exp(-beta * x))
    const Shape data_shape{2, 4};
    const element::Type et = element::f32;
    const float beta = 0.6f;

    const auto data = make_shared<op::Parameter>(et, data_shape);
    const auto beta_param = make_shared<op::Parameter>(et, Shape{});
    const auto swish = make_shared<op::v4::Swish>(data, beta_param);
    const auto func = make_shared<Function>(swish, ParameterVector{data, beta_param});

    const vector<vector<float>> inputs{{0.4f, -5.7f, -6.0f, 3.0f, -0.9f, 23.0f, 5.0f, 3.3f},
                                       {beta}};
    vector<float> expected;
    expected.reserve(inputs[0].size());
    for (const float x : inputs[0])
    {
        expected.push_back(x / (1.0f + std::exp(x * beta * -1.0f)));
    }

    auto test_case = test::TestCase<TestEngine>(func);
    test_case.add_multiple_inputs<float>(inputs);
    test_case.add_expected_output<float>(data_shape, expected);
    test_case.run();
}

NGRAPH_TEST(${BACKEND_NAME}, swish_2D_without_beta)
{
    // Swish with the default beta (single-input form).
    // Reference: Swish(x) = x * sigmoid(x) = x / (1 + exp(-x))
    const Shape data_shape{2, 3};
    const element::Type et = element::f32;

    const auto data = make_shared<op::Parameter>(et, data_shape);
    const auto swish = make_shared<op::v4::Swish>(data);
    const auto func = make_shared<Function>(swish, ParameterVector{data});

    const vector<float> input{1.0f, 8.0f, -8.0f, 17.0f, -0.5f, -1.0f};
    vector<float> expected;
    expected.reserve(input.size());
    for (const float x : input)
    {
        expected.push_back(x / (1.0f + std::exp(x * -1.0f)));
    }

    auto test_case = test::TestCase<TestEngine>(func);
    test_case.add_input<float>(input);
    test_case.add_expected_output<float>(data_shape, expected);
    test_case.run();
}

NGRAPH_TEST(${BACKEND_NAME}, swish_4D_with_beta0_33)
{
    // Swish with an explicit beta input (beta = 0.33) on a 4D tensor.
    // Reference: Swish(x) = x * sigmoid(beta * x) = x / (1 + exp(-beta * x))
    const Shape data_shape{2, 2, 1, 2};
    const element::Type et = element::f32;
    const float beta = 0.33f;

    const auto data = make_shared<op::Parameter>(et, data_shape);
    const auto beta_param = make_shared<op::Parameter>(et, Shape{});
    const auto swish = make_shared<op::v4::Swish>(data, beta_param);
    const auto func = make_shared<Function>(swish, ParameterVector{data, beta_param});

    const vector<vector<float>> inputs{{0.1f, 0.6f, 20.0f, -7.0f, -5.3f, 3.5f, -9.0f, 11.0f},
                                       {beta}};
    vector<float> expected;
    expected.reserve(inputs[0].size());
    for (const float x : inputs[0])
    {
        expected.push_back(x / (1.0f + std::exp(x * beta * -1.0f)));
    }

    auto test_case = test::TestCase<TestEngine>(func);
    test_case.add_multiple_inputs<float>(inputs);
    test_case.add_expected_output<float>(data_shape, expected);
    test_case.run();
}
14 changes: 14 additions & 0 deletions ngraph/test/type_prop/swish.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -81,3 +81,17 @@ TEST(type_prop, swish_2_inputs)
ASSERT_TRUE(swish_func->get_output_partial_shape(0).same_scheme(data->get_output_shape(0)));
ASSERT_TRUE(swish_func->get_output_partial_shape(0).rank().is_static());
}

TEST(type_prop, swish_incompatible_type_boolean)
{
    // Swish requires a floating-point data tensor: constructing it with a
    // boolean input must fail validation (NodeValidationFailure).
    auto data = make_shared<op::Parameter>(element::boolean, Shape{1, 3, 6});
    auto beta = make_shared<op::Parameter>(element::f32, Shape{});
    // Fixed: removed the stray ';' that was inside the macro argument
    // (`make_shared<...>(data, beta);`), which expanded to a double semicolon.
    ASSERT_THROW(make_shared<op::v4::Swish>(data, beta), ngraph::NodeValidationFailure);
}

TEST(type_prop, swish_incompatible_types_u32)
{
    // The optional beta input must also be floating point: an unsigned-int
    // beta must be rejected during construction (NodeValidationFailure).
    auto data = make_shared<op::Parameter>(element::f32, Shape{1, 3, 6});
    auto beta = make_shared<op::Parameter>(element::u32, Shape{});
    // Fixed: removed the stray ';' that was inside the macro argument
    // (`make_shared<...>(data, beta);`), which expanded to a double semicolon.
    ASSERT_THROW(make_shared<op::v4::Swish>(data, beta), ngraph::NodeValidationFailure);
}
11 changes: 11 additions & 0 deletions ngraph/test/visitors/op/swish.cpp
Original file line number Diff line number Diff line change
@@ -0,0 +1,11 @@
// Copyright (C) 2021 Intel Corporation
// SPDX-License-Identifier: Apache-2.0
//
#include "unary_ops.hpp"

// Swish carries no serializable attributes, so it reuses the generic typed
// visitor test fixture for attribute-less unary operators.
using Type = ::testing::Types<UnaryOperatorType<ngraph::op::v4::Swish, element::f32>>;

// Fixed typo in the instantiation name: "atrribute" -> "attribute".
INSTANTIATE_TYPED_TEST_CASE_P(visitor_without_attribute,
                              UnaryOperatorVisitor,
                              Type,
                              UnaryOperatorTypeName);

0 comments on commit c483cdc

Please sign in to comment.