Skip to content

Commit

Permalink
Low precision types support in Convert operation (openvinotoolkit#5640)
Browse files Browse the repository at this point in the history
* Add initial version of u1 type support.

* Turn off u8_to_u1 test in IE.CPU.

* Fix compilation issue.

* Replace std::memset with std::fill.

* Add u4 type support.

* Add i4 support.

* LP types support generalized.

* Remove std::copy optimization.

* Fix backend test for LP types.

* Fixed arm plugin compilation.

* Add LP types to Serialization SLT.

* Add Convert to summarize.py report.
  • Loading branch information
jdanieck authored and yekruglov committed Jun 7, 2021
1 parent 0ebbf41 commit b7ecfb5
Show file tree
Hide file tree
Showing 6 changed files with 335 additions and 62 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -12,13 +12,14 @@ namespace {
const std::vector<std::vector<size_t>> inShape = {{1, 2, 3, 4}};

const std::vector<InferenceEngine::Precision> precisions = {
InferenceEngine::Precision::BOOL, InferenceEngine::Precision::U8,
InferenceEngine::Precision::I8, InferenceEngine::Precision::U16,
InferenceEngine::Precision::I16, InferenceEngine::Precision::U32,
InferenceEngine::Precision::I32, InferenceEngine::Precision::U64,
InferenceEngine::Precision::I64, InferenceEngine::Precision::BF16,
InferenceEngine::Precision::FP16, InferenceEngine::Precision::FP32,
InferenceEngine::Precision::FP64};
InferenceEngine::Precision::BOOL, InferenceEngine::Precision::BIN,
InferenceEngine::Precision::U4, InferenceEngine::Precision::U8,
InferenceEngine::Precision::I4, InferenceEngine::Precision::I8,
InferenceEngine::Precision::U16, InferenceEngine::Precision::I16,
InferenceEngine::Precision::U32, InferenceEngine::Precision::I32,
InferenceEngine::Precision::U64, InferenceEngine::Precision::I64,
InferenceEngine::Precision::BF16, InferenceEngine::Precision::FP16,
InferenceEngine::Precision::FP32, InferenceEngine::Precision::FP64};

TEST_P(ConvertLayerTest, Serialize) {
Serialize();
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,7 @@
'CTCGreedyDecoder-1',
'CTCGreedyDecoderSeqLen-6',
'Concat-1',
'Convert-1',
'ConvertLike-1',
'Convolution-1',
'Constant-1',
Expand Down
113 changes: 113 additions & 0 deletions ngraph/core/reference/include/ngraph/runtime/reference/convert.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@

#include <cstddef>

#include "ngraph/type/element_type.hpp"
#include "ngraph/type/float16.hpp"

namespace ngraph
Expand All @@ -14,6 +15,118 @@ namespace ngraph
{
namespace reference
{
namespace detail
{
inline void set_u1(uint8_t* buf, size_t idx, uint8_t val)
{
const int byte_idx = idx / 8;
const int bit_idx = 7 - (idx % 8);
if (val)
{
buf[byte_idx] |= (1 << bit_idx);
}
else
{
buf[byte_idx] &= ~(1 << bit_idx);
}
}

inline uint8_t get_u1(const uint8_t* buf, size_t idx)
{
const int byte_idx = idx / 8;
const int bit_idx = 7 - (idx % 8);
return (buf[byte_idx] & (1 << bit_idx)) ? 1 : 0;
}

inline void set_u4(uint8_t* buf, size_t idx, uint8_t val)
{
const int byte_idx = idx / 2;
const int bit_shift = 4 * (++idx % 2);
buf[byte_idx] &= ~(0xF << bit_shift); // half byte zeroed
buf[byte_idx] |= (val << bit_shift); // set 1's
}

inline uint8_t get_u4(const uint8_t* buf, size_t idx)
{
const int byte_idx = idx / 2;
const int bit_shift = 4 * (++idx % 2);
return (buf[byte_idx] >> bit_shift) & 0xF;
}

inline void set_i4(uint8_t* buf, size_t idx, int8_t val)
{
const int byte_idx = idx / 2;
const int bit_shift = 4 * (++idx % 2);
buf[byte_idx] &= ~(0xF << bit_shift); // half byte zeroed
buf[byte_idx] |= (val << bit_shift); // set 1's
}

inline int8_t get_i4(const uint8_t* buf, size_t idx)
{
const int byte_idx = idx / 2;
const int bit_shift = 4 * (++idx % 2);
uint8_t val = (buf[byte_idx] >> bit_shift) & 0xF;
if (val & 0x08)
{ // negative number
val |= 0xF0;
}
return val;
}
template <typename T>
T get_value(const uint8_t* buf, size_t idx, element::Type type)
{
if (type == element::u1)
{
return detail::get_u1(buf, idx);
}

if (type == element::u4)
{
return detail::get_u4(buf, idx);
}

if (type == element::i4)
{
return detail::get_i4(buf, idx);
}

return static_cast<T>(buf[idx]);
}

template <typename TI, typename TO>
void lp_convert(const TI* arg,
TO* out,
size_t count,
element::Type_t src_type,
element::Type_t dst_type)
{
const uint8_t* input = reinterpret_cast<const uint8_t*>(arg);
uint8_t* output = reinterpret_cast<uint8_t*>(out);
for (size_t i = 0; i < count; ++i)
{
if (dst_type == element::u1)
{
detail::set_u1(
output, i, detail::get_value<uint8_t>(input, i, src_type));
}
else if (dst_type == element::u4)
{
detail::set_u4(
output, i, detail::get_value<uint8_t>(input, i, src_type));
}
else if (dst_type == element::i4)
{
detail::set_i4(
output, i, detail::get_value<int8_t>(input, i, src_type));
}
else
{
out[i] = detail::get_value<TO>(input, i, src_type);
}
}
}
} // namespace detail

template <typename TI, typename TO>
typename std::enable_if<!std::is_same<TO, char>::value>::type
convert(const TI* arg, TO* out, size_t count)
Expand Down
47 changes: 27 additions & 20 deletions ngraph/core/src/op/convert.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -26,22 +26,6 @@ op::Convert::Convert(const Output<Node>& arg, const element::Type& destination_t
void op::Convert::validate_and_infer_types()
{
NGRAPH_OP_SCOPE(v0_Convert_validate_and_infer_types);
const element::Type data_et = get_input_element_type(0);
const element::Type destination_et = m_destination_type;

NODE_VALIDATION_CHECK(this,
data_et != element::u1 && data_et != element::u4 &&
data_et != element::i4,
"Input element type '",
data_et,
"' is not supported.");

NODE_VALIDATION_CHECK(this,
destination_et != element::u1 && destination_et != element::u4 &&
destination_et != element::i4,
"Destination element type '",
destination_et,
"' is not supported.");

set_output_type(0, m_destination_type, get_input_partial_shape(0));
}
Expand All @@ -68,10 +52,27 @@ namespace convert
{
out->set_shape(arg->get_shape());
size_t element_count = shape_size(out->get_shape());
return (INPUT_ET == arg->get_element_type()) && OUTPUT_ET == out->get_element_type() &&
(runtime::reference::convert(
arg->get_data_ptr<INPUT_ET>(), out->get_data_ptr<OUTPUT_ET>(), element_count),
true);

if ((INPUT_ET != arg->get_element_type()) || OUTPUT_ET != out->get_element_type())
{
return false;
}
if (((INPUT_ET == element::u1) || (OUTPUT_ET == element::u1)) ||
((INPUT_ET == element::u4) || (OUTPUT_ET == element::u4)) ||
((INPUT_ET == element::i4) || (OUTPUT_ET == element::i4)))
{
runtime::reference::detail::lp_convert(arg->get_data_ptr<INPUT_ET>(),
out->get_data_ptr<OUTPUT_ET>(),
element_count,
INPUT_ET,
OUTPUT_ET);
}
else
{
runtime::reference::convert(
arg->get_data_ptr<INPUT_ET>(), out->get_data_ptr<OUTPUT_ET>(), element_count);
}
return true;
}

#define TYPE_OUT_CASE(a, ...) \
Expand All @@ -89,10 +90,13 @@ namespace convert

switch (out->get_element_type())
{
TYPE_OUT_CASE(i4, arg, out);
TYPE_OUT_CASE(i8, arg, out);
TYPE_OUT_CASE(i16, arg, out);
TYPE_OUT_CASE(i32, arg, out);
TYPE_OUT_CASE(i64, arg, out);
TYPE_OUT_CASE(u1, arg, out);
TYPE_OUT_CASE(u4, arg, out);
TYPE_OUT_CASE(u8, arg, out);
TYPE_OUT_CASE(u16, arg, out);
TYPE_OUT_CASE(u32, arg, out);
Expand All @@ -112,7 +116,10 @@ namespace convert
bool rc = true;
switch (arg->get_element_type())
{
NGRAPH_TYPE_CASE(evaluate_convert, u1, arg, out);
NGRAPH_TYPE_CASE(evaluate_convert, u4, arg, out);
NGRAPH_TYPE_CASE(evaluate_convert, u8, arg, out);
NGRAPH_TYPE_CASE(evaluate_convert, i4, arg, out);
NGRAPH_TYPE_CASE(evaluate_convert, i8, arg, out);
NGRAPH_TYPE_CASE(evaluate_convert, i32, arg, out);
NGRAPH_TYPE_CASE(evaluate_convert, i16, arg, out);
Expand Down
Loading

0 comments on commit b7ecfb5

Please sign in to comment.