Skip to content

Commit

Permalink
[core] Migrate HSwish operator to new API (openvinotoolkit#20854)
Browse files Browse the repository at this point in the history
* Drop ngraph remains

* Use ov::Tensor

instaed of ngraph::HostTensor
  • Loading branch information
t-jankowski authored Nov 6, 2023
1 parent 64c21fd commit 5cd9659
Show file tree
Hide file tree
Showing 2 changed files with 36 additions and 54 deletions.
6 changes: 1 addition & 5 deletions src/core/include/openvino/op/hswish.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -25,12 +25,8 @@ class OPENVINO_API HSwish : public util::UnaryElementwiseArithmetic {
/// \param data Input tensor
HSwish(const Output<Node>& arg);

bool visit_attributes(AttributeVisitor& visitor) override;

std::shared_ptr<Node> clone_with_new_inputs(const OutputVector& new_args) const override;
OPENVINO_SUPPRESS_DEPRECATED_START
bool evaluate(const HostTensorVector& outputs, const HostTensorVector& inputs) const override;
OPENVINO_SUPPRESS_DEPRECATED_END
bool evaluate(TensorVector& outputs, const TensorVector& inputs) const override;
bool has_evaluate() const override;
};
} // namespace v4
Expand Down
84 changes: 35 additions & 49 deletions src/core/src/op/hswish.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -2,78 +2,64 @@
// SPDX-License-Identifier: Apache-2.0
//

#include "ngraph/op/hswish.hpp"

#include <ngraph/validation_util.hpp>
#include "openvino/op/hswish.hpp"

#include "element_visitor.hpp"
#include "itt.hpp"
#include "ngraph/attribute_visitor.hpp"
#include "ngraph/runtime/host_tensor.hpp"
#include "openvino/reference/hswish.hpp"

using namespace std;
using namespace ngraph;

op::v4::HSwish::HSwish(const Output<Node>& arg) : UnaryElementwiseArithmetic(arg) {
namespace ov {
namespace op {
namespace v4 {
HSwish::HSwish(const Output<Node>& arg) : UnaryElementwiseArithmetic(arg) {
constructor_validate_and_infer_types();
}

bool op::v4::HSwish::visit_attributes(AttributeVisitor& visitor) {
OV_OP_SCOPE(v4_HSwish_visit_attributes);
return true;
}

shared_ptr<Node> op::v4::HSwish::clone_with_new_inputs(const OutputVector& new_args) const {
std::shared_ptr<Node> HSwish::clone_with_new_inputs(const OutputVector& new_args) const {
OV_OP_SCOPE(v4_HSwish_clone_with_new_inputs);
return make_shared<op::v4::HSwish>(new_args.at(0));
return std::make_shared<HSwish>(new_args.at(0));
}

OPENVINO_SUPPRESS_DEPRECATED_START
namespace hswish {
namespace {
template <element::Type_t ET>
inline bool evaluate(const HostTensorPtr& arg, const HostTensorPtr& out, const size_t count) {
using T = typename element_type_traits<ET>::value_type;

ov::reference::hswish<T>(arg->get_data_ptr<ET>(), out->get_data_ptr<ET>(), count);
return true;
}
struct Evaluate : element::NoAction<bool> {
using element::NoAction<bool>::visit;

bool evaluate_hswish(const HostTensorPtr& arg, const HostTensorPtr& out) {
bool rc = true;
size_t count = shape_size(arg->get_shape());
out->set_unary(arg);

switch (arg->get_element_type()) {
OPENVINO_TYPE_CASE(evaluate_hswish, bf16, arg, out, count);
OPENVINO_TYPE_CASE(evaluate_hswish, f16, arg, out, count);
OPENVINO_TYPE_CASE(evaluate_hswish, f32, arg, out, count);
default:
rc = false;
break;
template <element::Type_t ET, class T = fundamental_type_for<ET>>
static result_type visit(const Tensor& in, Tensor& out, const size_t count) {
ov::reference::hswish(in.data<const T>(), out.data<T>(), count);
return true;
}
return rc;
}
};
} // namespace
} // namespace hswish

bool op::v4::HSwish::evaluate(const HostTensorVector& outputs, const HostTensorVector& inputs) const {
bool HSwish::evaluate(TensorVector& outputs, const TensorVector& inputs) const {
OV_OP_SCOPE(v4_HSwish_evaluate);
OPENVINO_SUPPRESS_DEPRECATED_START
OPENVINO_ASSERT(validate_host_tensor_vector(outputs, 1) && validate_host_tensor_vector(inputs, 1));
OPENVINO_SUPPRESS_DEPRECATED_END
return hswish::evaluate_hswish(inputs[0], outputs[0]);
OPENVINO_ASSERT(inputs.size() == 1);
OPENVINO_ASSERT(outputs.size() == 1);

const auto& input_shape = inputs[0].get_shape();
const auto count = shape_size(input_shape);
outputs[0].set_shape(input_shape);
using namespace ov::element;
return IfTypeOf<bf16, f16, f32>::apply<hswish::Evaluate>(inputs[0].get_element_type(),
inputs[0],
outputs[0],
count);
}

bool op::v4::HSwish::has_evaluate() const {
bool HSwish::has_evaluate() const {
OV_OP_SCOPE(v4_HSwish_has_evaluate);
switch (get_input_element_type(0)) {
case ngraph::element::bf16:
case ngraph::element::f16:
case ngraph::element::f32:
case element::bf16:
case element::f16:
case element::f32:
return true;
default:
break;
return false;
}
return false;
}
} // namespace v4
} // namespace op
} // namespace ov

0 comments on commit 5cd9659

Please sign in to comment.