Skip to content

Commit

Permalink
Add initial version of u1 type support.
Browse files Browse the repository at this point in the history
  • Loading branch information
jdanieck committed May 14, 2021
1 parent c1b1e2e commit daebf97
Show file tree
Hide file tree
Showing 3 changed files with 119 additions and 22 deletions.
53 changes: 53 additions & 0 deletions ngraph/core/reference/include/ngraph/runtime/reference/convert.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@

#include <cstddef>

#include "ngraph/type/element_type.hpp"
#include "ngraph/type/float16.hpp"

namespace ngraph
Expand All @@ -14,6 +15,58 @@ namespace ngraph
{
namespace reference
{
namespace detail
{
inline void set_bit(uint8_t* buf, size_t idx)
{
const int byte_idx = idx / 8;
const int bit_idx = 7 - (idx % 8);
buf[byte_idx] |= (1 << bit_idx);
}

inline uint8_t get_bit(const uint8_t* buf, size_t idx)
{
const int byte_idx = idx / 8;
const int bit_idx = 7 - (idx % 8);
return buf[byte_idx] & (1 << bit_idx);
}
} // namespace detail

template <typename TI, typename TO>
void convert(const TI* arg,
TO* out,
size_t count,
element::Type_t src_type,
element::Type_t dst_type)
{
std::memset(out, 0, count * sizeof(TO));

if (dst_type == element::u1)
{
for (size_t i = 0; i < count; ++i)
{
if (arg[i])
{
detail::set_bit(reinterpret_cast<uint8_t*>(out), i);
}
}
}
else if (src_type == element::u1)
{
for (size_t i = 0; i < count; ++i)
{
if (detail::get_bit(reinterpret_cast<const uint8_t*>(arg), i))
{
out[i] = static_cast<TO>(1);
}
}
}
else
{
NGRAPH_CHECK(false, "Unimplemented");
}
}

template <typename TI, typename TO>
typename std::enable_if<!std::is_same<TO, char>::value>::type
convert(const TI* arg, TO* out, size_t count)
Expand Down
32 changes: 24 additions & 8 deletions ngraph/core/src/op/convert.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -30,15 +30,13 @@ void op::Convert::validate_and_infer_types()
const element::Type destination_et = m_destination_type;

NODE_VALIDATION_CHECK(this,
data_et != element::u1 && data_et != element::u4 &&
data_et != element::i4,
data_et != element::u4 && data_et != element::i4,
"Input element type '",
data_et,
"' is not supported.");

NODE_VALIDATION_CHECK(this,
destination_et != element::u1 && destination_et != element::u4 &&
destination_et != element::i4,
destination_et != element::u4 && destination_et != element::i4,
"Destination element type '",
destination_et,
"' is not supported.");
Expand Down Expand Up @@ -68,10 +66,26 @@ namespace convert
{
out->set_shape(arg->get_shape());
size_t element_count = shape_size(out->get_shape());
return (INPUT_ET == arg->get_element_type()) && OUTPUT_ET == out->get_element_type() &&
(runtime::reference::convert(
arg->get_data_ptr<INPUT_ET>(), out->get_data_ptr<OUTPUT_ET>(), element_count),
true);

if ((INPUT_ET != arg->get_element_type()) || OUTPUT_ET != out->get_element_type())
{
return false;
}
const std::unordered_set<element::Type_t> lp_types{element::u1};
if ((lp_types.count(INPUT_ET) || (lp_types.count(OUTPUT_ET))))
{
runtime::reference::convert(arg->get_data_ptr<INPUT_ET>(),
out->get_data_ptr<OUTPUT_ET>(),
element_count,
INPUT_ET,
OUTPUT_ET);
}
else
{
runtime::reference::convert(
arg->get_data_ptr<INPUT_ET>(), out->get_data_ptr<OUTPUT_ET>(), element_count);
}
return true;
}

#define TYPE_OUT_CASE(a, ...) \
Expand All @@ -93,6 +107,7 @@ namespace convert
TYPE_OUT_CASE(i16, arg, out);
TYPE_OUT_CASE(i32, arg, out);
TYPE_OUT_CASE(i64, arg, out);
TYPE_OUT_CASE(u1, arg, out);
TYPE_OUT_CASE(u8, arg, out);
TYPE_OUT_CASE(u16, arg, out);
TYPE_OUT_CASE(u32, arg, out);
Expand All @@ -112,6 +127,7 @@ namespace convert
bool rc = true;
switch (arg->get_element_type())
{
NGRAPH_TYPE_CASE(evaluate_convert, u1, arg, out);
NGRAPH_TYPE_CASE(evaluate_convert, u8, arg, out);
NGRAPH_TYPE_CASE(evaluate_convert, i8, arg, out);
NGRAPH_TYPE_CASE(evaluate_convert, i32, arg, out);
Expand Down
56 changes: 42 additions & 14 deletions ngraph/test/backend/convert.in.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -23,17 +23,23 @@ static string s_manifest = "${MANIFEST}";
using TestEngine = test::ENGINE_CLASS_NAME(${BACKEND_NAME});
namespace
{
std::shared_ptr<Function> CreateFunction(const Shape& input_shape,
const element::Type& input_type,
const element::Type& expected_output_type)
{
const auto in = make_shared<op::Parameter>(input_type, input_shape);
const auto convert = make_shared<op::Convert>(in, expected_output_type);
return make_shared<Function>(NodeVector{convert}, ParameterVector{in});
}

template <typename T_IN, typename T_OUT>
void ConvertTest(const std::vector<T_IN>& input,
const Shape& input_shape,
const ngraph::element::Type& input_type,
const std::vector<T_OUT>& expected_output,
const ngraph::element::Type& expected_output_type)
{
const auto in = make_shared<op::Parameter>(input_type, input_shape);
const auto convert = make_shared<op::Convert>(in, expected_output_type);
const auto f = make_shared<Function>(NodeVector{convert}, ParameterVector{in});

const auto f = CreateFunction(input_shape, input_type, expected_output_type);
auto test_case = test::TestCase<TestEngine>(f);
test_case.add_input(input);
test_case.add_expected_output(expected_output);
Expand Down Expand Up @@ -180,17 +186,28 @@ NGRAPH_TEST(${BACKEND_NAME}, convert_i64_to_f32)
ConvertTest(input, input_shape, input_type, expected_output, expected_output_type);
}

NGRAPH_TEST(${BACKEND_NAME}, convert_u1_to_f32_is_not_supported_yet)
NGRAPH_TEST(${BACKEND_NAME}, convert_u1_to_f32)
{
const std::vector<uint8_t> input{0x00};
const std::vector<uint8_t> input{0xA0};
const Shape input_shape{2, 2};
const element::Type input_type = ngraph::element::u1;

const std::vector<float> expected_output{0.0f, 0.0f, 0.0f, 0.0f};
const std::vector<float> expected_output{1.0f, 0.0f, 1.0f, 0.0f};
const element::Type expected_output_type = ngraph::element::f32;

ASSERT_THROW(ConvertTest(input, input_shape, input_type, expected_output, expected_output_type),
ngraph::NodeValidationFailure);
{
const auto f = CreateFunction(input_shape, input_type, expected_output_type);
auto backend = runtime::Backend::create("${BACKEND_NAME}");
auto input_tesnor = backend->create_tensor(input_type, input_shape);
copy_data(input_tesnor, input);
auto output = backend->create_tensor(expected_output_type, input_shape);
auto handle = backend->compile(f);
handle->call_with_validate({output}, {input_tesnor});

std::vector<float> result(expected_output.size());
output->read(result.data(), result.size() * sizeof(float));
EXPECT_TRUE(test::all_close_f(expected_output, result));
}
}

NGRAPH_TEST(${BACKEND_NAME}, convert_u4_to_f32_is_not_supported_yet)
Expand Down Expand Up @@ -361,17 +378,28 @@ NGRAPH_TEST(${BACKEND_NAME}, convert_u8_to_i64)
}

// destination: u1
NGRAPH_TEST(${BACKEND_NAME}, convert_u8_to_u1_is_not_supported_yet)
NGRAPH_TEST(${BACKEND_NAME}, convert_u8_to_u1)
{
const std::vector<uint8_t> input{0, 0, 0, 0};
const std::vector<uint8_t> input{1, 0, 1, 0};
const Shape input_shape{4};
const element::Type input_type = ngraph::element::u8;

const std::vector<uint8_t> expected_output{0x00};
const std::vector<uint8_t> expected_output{0xA0};
const element::Type expected_output_type = ngraph::element::u1;

ASSERT_THROW(ConvertTest(input, input_shape, input_type, expected_output, expected_output_type),
ngraph::NodeValidationFailure);
{
const auto f = CreateFunction(input_shape, input_type, expected_output_type);
auto backend = runtime::Backend::create("${BACKEND_NAME}");
auto input_tesnor = backend->create_tensor(input_type, input_shape);
copy_data(input_tesnor, input);
auto output = backend->create_tensor(expected_output_type, input_shape);
auto handle = backend->compile(f);
handle->call_with_validate({output}, {input_tesnor});

std::vector<uint8_t> result(expected_output.size());
output->read(result.data(), result.size() * sizeof(uint8_t));
EXPECT_TRUE(test::all_close(expected_output, result));
}
}

// destination: u4
Expand Down

0 comments on commit daebf97

Please sign in to comment.