diff --git a/ngraph/core/reference/include/ngraph/runtime/reference/fake_quantize.hpp b/ngraph/core/reference/include/ngraph/runtime/reference/fake_quantize.hpp index 1cdb1e3543d246..b0dad60fe65f19 100644 --- a/ngraph/core/reference/include/ngraph/runtime/reference/fake_quantize.hpp +++ b/ngraph/core/reference/include/ngraph/runtime/reference/fake_quantize.hpp @@ -195,106 +195,72 @@ namespace ngraph } } // namespace fake_quantize_details - namespace v1 + + template + void + fake_quantize(const T* const arg, + const T* const in_low, + const T* const in_high, + const T* const out_low, + const T* const out_high, + T* const out, + const Shape& arg_shape, + const Shape& in_low_shape, + const Shape& in_high_shape, + const Shape& out_low_shape, + const Shape& out_high_shape, + size_t levels, + const op::AutoBroadcastSpec& broadcast = op::AutoBroadcastType::NUMPY) { - template - void fake_quantize(const T* const arg, - const T* const in_low, - const T* const in_high, - const T* const out_low, - const T* const out_high, - T* const out, - const Shape& arg_shape, - const Shape& in_low_shape, - const Shape& in_high_shape, - const Shape& out_low_shape, - const Shape& out_high_shape, - size_t levels, - const op::AutoBroadcastSpec& broadcast) - { - using namespace fake_quantize_details; + using namespace fake_quantize_details; - const FeRound round_mode(FE_TONEAREST); + const FeRound round_mode(FE_TONEAREST); - if (shape_size(in_low_shape) == 1 && shape_size(in_high_shape) == 1 && - shape_size(out_low_shape) == 1 && shape_size(out_high_shape) == 1) + if (shape_size(in_low_shape) == 1 && shape_size(in_high_shape) == 1 && + shape_size(out_low_shape) == 1 && shape_size(out_high_shape) == 1) + { + const size_t arg_size = shape_size(arg_shape); + const auto q = [=](const T& a) { + return quantize(a, *in_low, *in_high, *out_low, *out_high, levels); + }; + for (size_t i = 0; i < arg_size; ++i) { - const size_t arg_size = shape_size(arg_shape); - const auto q = [=](const T& a) { - return quantize(a, *in_low, *in_high, *out_low, *out_high, levels); - }; - for (size_t i = 0; i < arg_size; ++i) - { - out[i] = q(arg[i]); - } + out[i] = q(arg[i]); } - else - { - NGRAPH_CHECK(in_low_shape.size() <= arg_shape.size() && - in_high_shape.size() <= arg_shape.size() && - out_low_shape.size() <= arg_shape.size() && - out_high_shape.size() <= arg_shape.size(), - "Tensors with inout\\output ranges should have rank less or " - "equal to data tensor rank equal to ", - arg_shape.size()); + } + else + { + NGRAPH_CHECK(in_low_shape.size() <= arg_shape.size() && + in_high_shape.size() <= arg_shape.size() && + out_low_shape.size() <= arg_shape.size() && + out_high_shape.size() <= arg_shape.size(), + "Tensors with inout\\output ranges should have rank less or " + "equal to data tensor rank equal to ", + arg_shape.size()); - const QuantizationBound in_low_bound( - in_low, in_low_shape, arg_shape, broadcast); - const QuantizationBound in_high_bound( - in_high, in_high_shape, arg_shape, broadcast); - const QuantizationBound out_low_bound( - out_low, out_low_shape, arg_shape, broadcast); - const QuantizationBound out_high_bound( - out_high, out_high_shape, arg_shape, broadcast); + const QuantizationBound in_low_bound( + in_low, in_low_shape, arg_shape, broadcast); + const QuantizationBound in_high_bound( + in_high, in_high_shape, arg_shape, broadcast); + const QuantizationBound out_low_bound( + out_low, out_low_shape, arg_shape, broadcast); + const QuantizationBound out_high_bound( + out_high, out_high_shape, arg_shape, broadcast); - std::vector current_dim(arg_shape.size(), 0); - const auto arg_shape_size = shape_size(arg_shape); - for (size_t index = 0; index < arg_shape_size; ++index) - { - const T in_low_val = in_low_bound.get_value(current_dim, index); - const T in_high_val = in_high_bound.get_value(current_dim, index); - const T out_low_val = out_low_bound.get_value(current_dim, index); - const T out_high_val = out_high_bound.get_value(current_dim, index); + std::vector current_dim(arg_shape.size(), 0); + const auto arg_shape_size = shape_size(arg_shape); + for (size_t index = 0; index < arg_shape_size; ++index) + { + const T in_low_val = in_low_bound.get_value(current_dim, index); + const T in_high_val = in_high_bound.get_value(current_dim, index); + const T out_low_val = out_low_bound.get_value(current_dim, index); + const T out_high_val = out_high_bound.get_value(current_dim, index); - out[index] = quantize(arg[index], - in_low_val, - in_high_val, - out_low_val, - out_high_val, - levels); - increment_current_dim(current_dim, arg_shape); - } + out[index] = quantize( + arg[index], in_low_val, in_high_val, out_low_val, out_high_val, levels); + increment_current_dim(current_dim, arg_shape); } } - } // namespace v1 - - template - void fake_quantize(const T* const arg, - const T* const in_low, - const T* const in_high, - const T* const out_low, - const T* const out_high, - T* const out, - const Shape& arg_shape, - const Shape& in_low_shape, - const Shape& in_high_shape, - const Shape& out_low_shape, - const Shape& out_high_shape, - size_t levels) - { - v1::fake_quantize(arg, - in_low, - in_high, - out_low, - out_high, - out, - arg_shape, - in_low_shape, - in_high_shape, - out_low_shape, - out_high_shape, - levels, - op::AutoBroadcastType::NUMPY); } } // namespace reference } // namespace runtime diff --git a/ngraph/test/runtime/interpreter/evaluates_map.cpp b/ngraph/test/runtime/interpreter/evaluates_map.cpp index 360adf4008acb7..4414c70eb1bafe 100644 --- a/ngraph/test/runtime/interpreter/evaluates_map.cpp +++ b/ngraph/test/runtime/interpreter/evaluates_map.cpp @@ -1173,7 +1173,7 @@ namespace info.selected_outputs_shape, selected_indices.data(), info.selected_indices_shape, - valid_outputs.data()); + valid_outputs.data()); void* pscores = nullptr; void* pselected_num = nullptr; @@ -2383,19 +2383,19 @@ namespace const HostTensorVector& inputs) { using T = typename element_type_traits::value_type; - runtime::reference::v1::fake_quantize(inputs[0]->get_data_ptr(), - inputs[1]->get_data_ptr(), - inputs[2]->get_data_ptr(), - inputs[3]->get_data_ptr(), - inputs[4]->get_data_ptr(), - outputs[0]->get_data_ptr(), - op->get_input_shape(0), - op->get_input_shape(1), - op->get_input_shape(2), - op->get_input_shape(3), - op->get_input_shape(4), - op->get_levels(), - op->get_auto_broadcast()); + runtime::reference::fake_quantize(inputs[0]->get_data_ptr(), + inputs[1]->get_data_ptr(), + inputs[2]->get_data_ptr(), + inputs[3]->get_data_ptr(), + inputs[4]->get_data_ptr(), + outputs[0]->get_data_ptr(), + op->get_input_shape(0), + op->get_input_shape(1), + op->get_input_shape(2), + op->get_input_shape(3), + op->get_input_shape(4), + op->get_levels(), + op->get_auto_broadcast()); return true; }