Skip to content

Commit

Permalink
Moved op utils to ov namespace
Browse files Browse the repository at this point in the history
  • Loading branch information
ilyachur committed Aug 27, 2021
1 parent b7c7803 commit d41fc90
Show file tree
Hide file tree
Showing 69 changed files with 2,754 additions and 2,180 deletions.
8 changes: 2 additions & 6 deletions ngraph/core/include/ngraph/op/op.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -7,14 +7,10 @@
#include <string>

#include "ngraph/node.hpp"
#include "openvino/op/op.hpp"

namespace ngraph {
namespace op {
/// Root of all actual ops
class NGRAPH_API Op : public Node {
protected:
Op() : Node() {}
Op(const OutputVector& arguments);
};
using ov::op::Op;
} // namespace op
} // namespace ngraph
91 changes: 9 additions & 82 deletions ngraph/core/include/ngraph/op/util/activation_functions.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -9,98 +9,25 @@

#include "ngraph/except.hpp"
#include "ngraph/node.hpp"

#ifdef _WIN32
# pragma warning(push)

# pragma warning(disable : 4100)
#endif

// Prevents the compiler from complaining about or optimizing away variables
// that appear unused on Linux
#if (defined(__GNUC__) && !defined(__clang__))
# undef NG_ATTRIBUTE_UNUSED
# define NG_ATTRIBUTE_UNUSED __attribute__((__unused__))
#else
# define NG_ATTRIBUTE_UNUSED
#endif

#define UNUSED_PARAMETER NG_ATTRIBUTE_UNUSED = 0
#include "openvino/op/util/activation_functions.hpp"

namespace ngraph {
namespace op {
namespace util {
namespace error {
struct UnknownActivationFunction : ngraph_error {
UnknownActivationFunction(const std::string& func_name)
: ngraph_error{"Unknown activation function: " + func_name} {}
};
using ov::op::util::error::UnknownActivationFunction;
} // namespace error

namespace detail {
std::shared_ptr<Node> sigmoid(const std::shared_ptr<Node>& arg,
float alpha UNUSED_PARAMETER,
float beta UNUSED_PARAMETER);
std::shared_ptr<Node> tanh(const std::shared_ptr<Node>& arg, float alpha UNUSED_PARAMETER, float beta UNUSED_PARAMETER);
std::shared_ptr<Node> relu(const std::shared_ptr<Node>& arg, float alpha UNUSED_PARAMETER, float beta UNUSED_PARAMETER);
std::shared_ptr<Node> hardsigmoid(const std::shared_ptr<Node>& arg, float alpha, float beta);
using ov::op::util::detail::hardsigmoid;
using ov::op::util::detail::relu;
using ov::op::util::detail::sigmoid;
using ov::op::util::detail::tanh;
} // namespace detail

using ActivationFunctionType = std::shared_ptr<Node> (*)(const std::shared_ptr<Node>&, float, float);

///
/// \brief Class representing activation function used in RNN cells.
///
class NGRAPH_API ActivationFunction {
public:
ActivationFunction(ActivationFunctionType f, float alpha, float beta);
ActivationFunction(ActivationFunctionType f, float alpha);
ActivationFunction(ActivationFunctionType f);
ActivationFunction() = default;

///
/// \brief Calls stored activation function with provided node argument.
///
std::shared_ptr<Node> operator()(const std::shared_ptr<Node>& arg) const;

void set_alpha(float alpha) {
m_alpha = alpha;
}
void set_beta(float beta) {
m_beta = beta;
}

private:
/// \brief Activation function wrapper.
ActivationFunctionType m_function;
/// \brief Activation function alpha parameter (may be unused).
float m_alpha;
/// \brief Activation function beta parameter (may be unused).
float m_beta;
};

/// \brief Gets the activation function by name.
///
/// \param[in] func_name The function name
///
/// \throws UnknownActivationFunction When provided func_name is unknown.
///
/// \return The activation function object.
///
ActivationFunction get_activation_func_by_name(const std::string& func_name);
using ov::op::util::ActivationFunction;
using ov::op::util::ActivationFunctionType;
using ov::op::util::get_activation_func_by_name;
} // namespace util

} // namespace op

} // namespace ngraph

#ifdef _WIN32
# pragma warning(pop)
#endif

#ifdef UNUSED_PARAMETER
# undef UNUSED_PARAMETER
#endif
#ifdef NG_ATTRIBUTE_UNUSED
# undef NG_ATTRIBUTE_UNUSED
#endif
31 changes: 2 additions & 29 deletions ngraph/core/include/ngraph/op/util/arithmetic_reduction.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -6,39 +6,12 @@

#include "ngraph/op/op.hpp"
#include "ngraph/op/util/reduction_base.hpp"
#include "openvino/op/util/arithmetic_reduction.hpp"

namespace ngraph {
namespace op {
namespace util {
/// \brief Abstract base class for arithmetic reduction operations, i.e., operations
/// where chosen axes of the input tensors are eliminated (reduced out) by
/// repeated application of a particular binary arithmetic operation.
class NGRAPH_API ArithmeticReduction : public ReductionBase {
protected:
/// \brief Constructs an arithmetic reduction operation.
ArithmeticReduction();

/// \brief Constructs an arithmetic reduction operation.
///
/// \param arg Output that produces the first input tensor.
/// \param reduction_axes The axis positions (0-based) to be eliminated.
ArithmeticReduction(const Output<Node>& arg, const Output<Node>& reduction_axes);

public:
NGRAPH_RTTI_DECLARATION;
void validate_and_infer_types() override;

/// \return true if reduction axes are constant else false.
bool reduction_axes_constant() const;

/// \return The axis positions (0-based) to be eliminated through reduction.
/// \throws CheckFailure if the reduction axes are not constant. (Use
/// reduction_axes_constant to check.)
const AxisSet get_reduction_axes() const;

/// \brief Change the reduction axes
void set_reduction_axes(const AxisSet& reduction_axes);
};
using ov::op::util::ArithmeticReduction;
} // namespace util
} // namespace op
} // namespace ngraph
Original file line number Diff line number Diff line change
Expand Up @@ -6,37 +6,12 @@

#include "ngraph/op/op.hpp"
#include "ngraph/op/util/arithmetic_reduction.hpp"
#include "openvino/op/util/arithmetic_reductions_keep_dims.hpp"

namespace ngraph {
namespace op {
namespace util {
class NGRAPH_API ArithmeticReductionKeepDims : public util::ArithmeticReduction {
protected:
ArithmeticReductionKeepDims() = default;

/// \param arg The tensor to be summed.
/// \param reduction_axes The axis positions (0-based) to be eliminated.
/// \param keep_dims If set to 1 it holds axes that are used for reduction.
ArithmeticReductionKeepDims(const Output<Node>& arg, const Output<Node>& reduction_axes, bool keep_dims = false);

bool visit_attributes(AttributeVisitor& visitor) override;

public:
NGRAPH_RTTI_DECLARATION;
void validate_and_infer_types() override;

/// \return If set to 1 it holds axes that are used for reduction.
/// For each such axis, output dimension is equal to 1.
bool get_keep_dims() const {
return m_keep_dims;
}
void set_keep_dims(bool keep_dims) {
m_keep_dims = keep_dims;
}

private:
bool m_keep_dims = false;
};
using ov::op::util::ArithmeticReductionKeepDims;
} // namespace util
} // namespace op
} // namespace ngraph
Loading

0 comments on commit d41fc90

Please sign in to comment.