Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Introduce ov element type visitor #18189

Merged
Show file tree
Hide file tree
Changes from 6 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
118 changes: 118 additions & 0 deletions src/core/shape_inference/include/element_visitor.hpp
Original file line number Diff line number Diff line change
@@ -0,0 +1,118 @@
// Copyright (C) 2018-2023 Intel Corporation
// SPDX-License-Identifier: Apache-2.0
//

#pragma once

#include <functional>

#include "openvino/core/except.hpp"
#include "openvino/core/type/element_type.hpp"

namespace ov {
namespace element {

/**
* @brief Defines supported elements for applying the element visitor action.
*
* @tparam List of supported ov::element types.
*
* The apply function check input element type, if is on the list apply function of visitor for specific type
praasz marked this conversation as resolved.
Show resolved Hide resolved
* if not found apply default action.
praasz marked this conversation as resolved.
Show resolved Hide resolved
*/
template <Type_t...>
struct IfTypeOf;

/**
* @brief Applies visitor action for not supported ov::element type.
*/
template <>
struct IfTypeOf<> {
/**
* @brief Applies visitor default action for not supported element type using Visitor non-template visit function.
*
* @tparam Visitor Visitor class implementing visit function.
* @tparam Args Types of visit parameters.
*
* @param et Input element type.
* @param args Visitor arguments.
* @return Value of result type returned by Visitor.
*/
template <class Visitor, class... Args>
static auto apply(Type_t et, Args&&... args) -> typename Visitor::result_type {
return Visitor::visit();
}
};

/**
* @brief Applies visitor action for supported element type defined by template parameters.
*
* @tparam ET Current ov::element type used for check with input.
* @tparam Others Others supported ov::element.
*/
template <Type_t ET, Type_t... Others>
struct IfTypeOf<ET, Others...> {
/**
* @brief Applies visitor action for element type using Visitor visit function for ET.
*
* @tparam Visitor Visitor class implementing visit function.
* @tparam Args Types of visit parameters.
*
* @param et Input element type.
* @param args Visitor arguments.
* @return Value of result type returned by Visitor.
*/
template <class Visitor, class... Args>
static auto apply(Type_t et, Args&&... args) -> typename Visitor::result_type {
return (et == ET) ? Visitor::template visit<ET>(std::forward<Args>(args)...)
: IfTypeOf<Others...>::template apply<Visitor>(et, std::forward<Args>(args)...);
}
};

/**
* @brief Helper visitor which define no action for not supported type.
praasz marked this conversation as resolved.
Show resolved Hide resolved
*
* @tparam R Type of return value.
* @tparam value Default value returned.
*/
template <class R, R... value>
struct NoAction {
static_assert(sizeof...(value) < 2, "There should no more then one result value.");
t-jankowski marked this conversation as resolved.
Show resolved Hide resolved

using result_type = R;

static constexpr R visit() {
return {value...};
}
};

/**
* @brief Helper visitor which define no action for not supported type if result is void type.
t-jankowski marked this conversation as resolved.
Show resolved Hide resolved
*/
template <>
struct NoAction<void> {
using result_type = void;

static void visit() {}
};

/**
* @brief Helper visitor which throws ov::Exception for not supported element type.
*
* @tparam R Type of return type (used to be compatible with others call operator in Visitor).
praasz marked this conversation as resolved.
Show resolved Hide resolved
*/
template <class R>
struct NotSupported {
using result_type = R;

[[noreturn]] static R visit() {
throw_not_supported();
}

private:
[[noreturn]] static void throw_not_supported() {
OPENVINO_THROW("Element not supported");
}
};
} // namespace element
} // namespace ov
110 changes: 22 additions & 88 deletions src/core/shape_inference/include/utils.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,7 @@
#include <type_traits>

#include "bound_evaluation_util.hpp"
#include "element_visitor.hpp"
#include "shape_infer_type_utils.hpp"
#include "tensor_data_accessor.hpp"

Expand Down Expand Up @@ -50,6 +51,20 @@ void eltwise_shape_infer(const OpType* op, const std::vector<T>& input_shapes, s
}

namespace ov {

struct TensorTransform : element::NotSupported<void> {
using element::NotSupported<void>::visit;

template <element::Type_t ET, class Iterator, class UnaryOperation>
static result_type visit(const void* const ptr, const size_t size, Iterator out_it, UnaryOperation&& func) {
using T = fundamental_type_for<ET>;
std::transform(static_cast<const T*>(ptr),
static_cast<const T*>(ptr) + size,
out_it,
std::forward<UnaryOperation>(func));
}
};

/**
* \brief Get the raw data as TResult object.
*
Expand All @@ -71,94 +86,13 @@ TResult get_raw_data_as(const element::Type_t et, const void* const ptr, const s
TResult out;
auto out_it = std::inserter(out, out.end());

switch (et) {
case element::Type_t::i4: {
using dtype = fundamental_type_for<element::Type_t::i4>;
std::transform(static_cast<const dtype*>(ptr),
static_cast<const dtype*>(ptr) + size,
out_it,
std::forward<UnaryOperation>(func));
} break;
case element::Type_t::i8: {
using dtype = fundamental_type_for<element::Type_t::i8>;
std::transform(static_cast<const dtype*>(ptr),
static_cast<const dtype*>(ptr) + size,
out_it,
std::forward<UnaryOperation>(func));
} break;
case element::Type_t::i16: {
using dtype = fundamental_type_for<element::Type_t::i16>;
std::transform(static_cast<const dtype*>(ptr),
static_cast<const dtype*>(ptr) + size,
out_it,
std::forward<UnaryOperation>(func));
} break;
case element::Type_t::i32: {
using dtype = fundamental_type_for<element::Type_t::i32>;
std::transform(static_cast<const dtype*>(ptr),
static_cast<const dtype*>(ptr) + size,
out_it,
std::forward<UnaryOperation>(func));
} break;
case element::Type_t::i64: {
using dtype = fundamental_type_for<element::Type_t::i64>;
std::transform(static_cast<const dtype*>(ptr),
static_cast<const dtype*>(ptr) + size,
out_it,
std::forward<UnaryOperation>(func));
} break;
case element::Type_t::u4: {
using dtype = fundamental_type_for<element::Type_t::u4>;
std::transform(static_cast<const dtype*>(ptr),
static_cast<const dtype*>(ptr) + size,
out_it,
std::forward<UnaryOperation>(func));
} break;
case element::Type_t::u8: {
using dtype = fundamental_type_for<element::Type_t::u8>;
std::transform(static_cast<const dtype*>(ptr),
static_cast<const dtype*>(ptr) + size,
out_it,
std::forward<UnaryOperation>(func));
} break;
case element::Type_t::u16: {
using dtype = fundamental_type_for<element::Type_t::u16>;
std::transform(static_cast<const dtype*>(ptr),
static_cast<const dtype*>(ptr) + size,
out_it,
std::forward<UnaryOperation>(func));
} break;
case element::Type_t::u32: {
using dtype = fundamental_type_for<element::Type_t::u32>;
std::transform(static_cast<const dtype*>(ptr),
static_cast<const dtype*>(ptr) + size,
out_it,
std::forward<UnaryOperation>(func));
} break;
case element::Type_t::u64: {
using dtype = fundamental_type_for<element::Type_t::u64>;
std::transform(static_cast<const dtype*>(ptr),
static_cast<const dtype*>(ptr) + size,
out_it,
std::forward<UnaryOperation>(func));
} break;
case element::Type_t::f16: {
using dtype = fundamental_type_for<element::Type_t::f16>;
std::transform(static_cast<const dtype*>(ptr),
static_cast<const dtype*>(ptr) + size,
out_it,
std::forward<UnaryOperation>(func));
} break;
case element::Type_t::f32: {
using dtype = fundamental_type_for<element::Type_t::f32>;
std::transform(static_cast<const dtype*>(ptr),
static_cast<const dtype*>(ptr) + size,
out_it,
std::forward<UnaryOperation>(func));
} break;
default:
OPENVINO_ASSERT(false, "Get raw data from tensor is not supported for element type: ", et);
};
using namespace ov::element;
IfTypeOf<i4, i8, i16, i32, i64, u4, u8, u16, u32, u64, f16, f32>::apply<TensorTransform>(
et,
ptr,
size,
out_it,
std::forward<UnaryOperation>(func));
return out;
}

Expand Down
74 changes: 36 additions & 38 deletions src/core/src/op/round.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@

#include "ngraph/op/round.hpp"

#include "element_visitor.hpp"
#include "itt.hpp"
#include "ngraph/attribute_visitor.hpp"
#include "ngraph/op/util/eval_copy.hpp"
Expand All @@ -15,50 +16,48 @@ using namespace std;
using namespace ngraph;

namespace roundop {
namespace {
// function used by TYPE_CASE
template <element::Type_t ET>
inline bool evaluate(const HostTensorPtr& arg0,
const HostTensorPtr& out,
const size_t count,
const op::v5::Round::RoundMode mode) {
using T = typename element_type_traits<ET>::value_type;
runtime::reference::round<T>(arg0->get_data_ptr<ET>(), out->get_data_ptr<ET>(), count, mode);
return true;
}

// function used by COPY_TENSOR
template <element::Type_t ET>
inline bool copy_tensor(const HostTensorPtr& arg0, const HostTensorPtr& out, const size_t count) {
runtime::reference::copy(arg0->get_data_ptr<ET>(), out->get_data_ptr<ET>(), count);
return true;
}
struct Evaluate : ov::element::NoAction<bool> {
using ov::element::NoAction<bool>::visit;

template <element::Type_t ET>
static constexpr bool is_floating() {
return (ET == element::f16) || (ET == element::f32) || (ET == element::bf16);
praasz marked this conversation as resolved.
Show resolved Hide resolved
}

template <element::Type_t ET>
static typename std::enable_if<!is_floating<ET>(), bool>::type visit(const HostTensorPtr& arg0,
const HostTensorPtr& out,
const size_t count,
const op::v5::Round::RoundMode) {
memcpy(out->get_data_ptr(), arg0->get_data_ptr(), out->get_size_in_bytes());
praasz marked this conversation as resolved.
Show resolved Hide resolved
return true;
}

template <ov::element::Type_t ET>
static typename std::enable_if<is_floating<ET>(), bool>::type visit(const HostTensorPtr& arg0,
const HostTensorPtr& out,
const size_t count,
const op::v5::Round::RoundMode mode) {
ngraph::runtime::reference::round(arg0->get_data_ptr<ET>(), out->get_data_ptr<ET>(), count, mode);
return true;
}
};

namespace {
bool evaluate_round(const HostTensorPtr& arg0,
const HostTensorPtr& out,
const size_t count,
const op::v5::Round::RoundMode mode) {
bool rc = true;
out->set_unary(arg0);

switch (arg0->get_element_type()) {
NGRAPH_COPY_TENSOR(evaluate_round, boolean, arg0, out, count);
NGRAPH_COPY_TENSOR(evaluate_round, i8, arg0, out, count);
NGRAPH_COPY_TENSOR(evaluate_round, i16, arg0, out, count);
NGRAPH_COPY_TENSOR(evaluate_round, i32, arg0, out, count);
NGRAPH_COPY_TENSOR(evaluate_round, i64, arg0, out, count);
NGRAPH_COPY_TENSOR(evaluate_round, u8, arg0, out, count);
NGRAPH_COPY_TENSOR(evaluate_round, u16, arg0, out, count);
NGRAPH_COPY_TENSOR(evaluate_round, u32, arg0, out, count);
NGRAPH_COPY_TENSOR(evaluate_round, u64, arg0, out, count);
NGRAPH_TYPE_CASE(evaluate_round, f16, arg0, out, count, mode);
NGRAPH_TYPE_CASE(evaluate_round, f32, arg0, out, count, mode);
NGRAPH_TYPE_CASE(evaluate_round, bf16, arg0, out, count, mode);
default:
rc = false;
break;
}
return rc;
using namespace ov::element;
return IfTypeOf<boolean, i8, i16, i32, i64, u8, u16, u32, u64, f16, f32, bf16>::apply<Evaluate>(
arg0->get_element_type(),
arg0,
out,
count,
mode);
}
} // namespace
} // namespace roundop
Expand Down Expand Up @@ -108,9 +107,8 @@ bool op::v5::Round::has_evaluate() const {
case ngraph::element::bf16:
return true;
default:
break;
return false;
}
return false;
}

std::ostream& ov::operator<<(std::ostream& s, const op::v5::Round::RoundMode& type) {
Expand Down
Loading