Skip to content

Commit

Permalink
[ONNX] Added translation for string constants/inputs (openvinotoolkit…
Browse files Browse the repository at this point in the history
…#24189)

### Details:
 - Added support for accepting string as inputs and constants

### Tickets:
 - 139685
  • Loading branch information
gkrivor authored and allnes committed Jun 26, 2024
1 parent 28db18f commit 1dce452
Show file tree
Hide file tree
Showing 8 changed files with 185 additions and 9 deletions.
14 changes: 14 additions & 0 deletions src/frontends/onnx/frontend/src/core/tensor.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -202,6 +202,20 @@ std::vector<char> Tensor::get_data() const {
ONNX_INVALID_DATA_TYPE(m_tensor_proto->data_type(), "BOOL, raw data");
}

template <>
std::vector<std::string> Tensor::get_data() const {
if (has_external_data()) {
FRONT_END_THROW("External strings are not supported");
}
if (m_tensor_proto->has_raw_data()) {
FRONT_END_THROW("Loading strings from raw data isn't supported");
}
if (m_tensor_proto->data_type() == TensorProto_DataType::TensorProto_DataType_STRING) {
return detail::__get_data<std::string>(m_tensor_proto->string_data());
}
ONNX_INVALID_DATA_TYPE(m_tensor_proto->data_type(), "STRING");
}

} // namespace onnx
} // namespace frontend
} // namespace ov
23 changes: 16 additions & 7 deletions src/frontends/onnx/frontend/src/core/tensor.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -157,12 +157,14 @@ class Tensor {
return ov::element::u64;
case TensorProto_DataType::TensorProto_DataType_BFLOAT16:
return ov::element::bf16;
case TensorProto_DataType::TensorProto_DataType_STRING:
return ov::element::string;
case TensorProto_DataType::TensorProto_DataType_UNDEFINED:
FRONT_END_THROW("Data type is Undefined");
default:
ONNX_UNSUPPORTED_DATA_TYPE(
m_tensor_proto->data_type(),
"BOOL, BFLOAT16, FLOAT, FLOAT16, DOUBLE, INT8, INT16, INT32, INT64, UINT8, UINT16, UINT32, UINT64");
ONNX_UNSUPPORTED_DATA_TYPE(m_tensor_proto->data_type(),
"BOOL, BFLOAT16, FLOAT, FLOAT16, DOUBLE, INT8, INT16, INT32, INT64, UINT8, "
"UINT16, UINT32, UINT64, STRING");
}
}

Expand Down Expand Up @@ -201,10 +203,12 @@ class Tensor {
return make_ov_constant<uint64_t>(ov::element::u64);
case TensorProto_DataType::TensorProto_DataType_BFLOAT16:
return make_ov_constant<ov::bfloat16>(ov::element::bf16);
case TensorProto_DataType::TensorProto_DataType_STRING:
return make_ov_constant<std::string>(ov::element::string);
default:
ONNX_UNSUPPORTED_DATA_TYPE(
m_tensor_proto->data_type(),
"BOOL, BFLOAT16, FLOAT, FLOAT16, DOUBLE, INT8, INT16, INT32, INT64, UINT8, UINT16, UINT32, UINT64");
ONNX_UNSUPPORTED_DATA_TYPE(m_tensor_proto->data_type(),
"BOOL, BFLOAT16, FLOAT, FLOAT16, DOUBLE, INT8, INT16, INT32, INT64, UINT8, "
"UINT16, UINT32, UINT64, STRING");
}
}

Expand Down Expand Up @@ -320,8 +324,10 @@ class Tensor {
return m_tensor_proto->uint64_data_size();
case TensorProto_DataType::TensorProto_DataType_DOUBLE:
return m_tensor_proto->double_data_size();
case TensorProto_DataType::TensorProto_DataType_STRING:
return m_tensor_proto->string_data_size();
}
ONNX_INVALID_DATA_TYPE(m_tensor_proto->data_type(), "FLOAT, INT32, INT64, UINT64, DOUBLE");
ONNX_INVALID_DATA_TYPE(m_tensor_proto->data_type(), "FLOAT, INT32, INT64, UINT64, DOUBLE, STRING");
}

const TensorProto* m_tensor_proto;
Expand Down Expand Up @@ -373,6 +379,9 @@ std::vector<uint64_t> Tensor::get_data() const;
template <>
std::vector<char> Tensor::get_data() const;

template <>
std::vector<std::string> Tensor::get_data() const;

} // namespace onnx
} // namespace frontend
} // namespace ov
7 changes: 6 additions & 1 deletion src/frontends/onnx/frontend/src/utils/common.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@

#include <onnx/onnx_pb.h> // onnx types

#include "core/tensor.hpp"
#include "onnx_framework_node.hpp"
#include "openvino/frontend/exception.hpp"
#include "openvino/op/add.hpp"
Expand Down Expand Up @@ -60,8 +61,12 @@ const ov::element::Type& get_ov_element_type(int64_t onnx_type) {
return ov::element::dynamic;
case TensorProto_DataType::TensorProto_DataType_BFLOAT16:
return ov::element::bf16;
case TensorProto_DataType::TensorProto_DataType_STRING:
return ov::element::string;
}
OPENVINO_THROW("unsupported element type");
ONNX_UNSUPPORTED_DATA_TYPE(onnx_type,
"BOOL, BFLOAT16, FLOAT, FLOAT16, DOUBLE, INT8, INT16, INT32, INT64, UINT8, UINT16, "
"UINT32, UINT64, STRING, UNDEFINED");
}

std::shared_ptr<ov::Node> get_monotonic_range_along_node_rank(const ov::Output<ov::Node>& value,
Expand Down
3 changes: 2 additions & 1 deletion src/frontends/onnx/onnx_common/src/utils.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -60,7 +60,8 @@ const std::map<ov::element::Type_t, TensorProto_DataType> OV_2_ONNX_TYPES = {
{ov::element::Type_t::u16, TensorProto_DataType::TensorProto_DataType_UINT16},
{ov::element::Type_t::u32, TensorProto_DataType::TensorProto_DataType_UINT32},
{ov::element::Type_t::u64, TensorProto_DataType::TensorProto_DataType_UINT64},
{ov::element::Type_t::boolean, TensorProto_DataType::TensorProto_DataType_BOOL}};
{ov::element::Type_t::boolean, TensorProto_DataType::TensorProto_DataType_BOOL},
{ov::element::Type_t::string, TensorProto_DataType::TensorProto_DataType_STRING}};

ov::element::Type_t onnx_to_ov_data_type(const TensorProto_DataType& onnx_type) {
const auto result = std::find_if(OV_2_ONNX_TYPES.begin(),
Expand Down
59 changes: 59 additions & 0 deletions src/frontends/onnx/tests/models/string_constant.prototxt
Original file line number Diff line number Diff line change
@@ -0,0 +1,59 @@
ir_version: 9
producer_name: "OpenVINO ONNX Frontend"
graph {
name: "test"
node {
output: "Y"
op_type: "Constant"
attribute {
name: "value"
t {
dims: 2
data_type: 8
string_data: "string1"
string_data: "string2"
name: "str_const"
}
type: TENSOR
}
}
node {
input: "Y"
output: "O"
op_type: "Shape"
}
node {
input: "Y"
output: "V"
op_type: "Identity"
}
output {
name: "O"
type {
tensor_type {
elem_type: 7
shape {
dim {
dim_value: 1
}
}
}
}
}
output {
name: "V"
type {
tensor_type {
elem_type: 8
shape {
dim {
dim_value: 2
}
}
}
}
}
}
opset_import {
version: 20
}
57 changes: 57 additions & 0 deletions src/frontends/onnx/tests/models/string_input.prototxt
Original file line number Diff line number Diff line change
@@ -0,0 +1,57 @@
ir_version: 9
producer_name: "OpenVINO ONNX Frontend"
graph {
name: "test"
node {
input: "I"
output: "O"
op_type: "Shape"
}
node {
input: "I"
output: "V"
op_type: "Identity"
}
input {
name: "I"
type {
tensor_type {
elem_type: 8
shape {
dim {
dim_value: 2
}
}
}
}
}
output {
name: "O"
type {
tensor_type {
elem_type: 7
shape {
dim {
dim_value: 1
}
}
}
}
}
output {
name: "V"
type {
tensor_type {
elem_type: 8
shape {
dim {
dim_value: 2
}
}
}
}
}
}
opset_import {
version: 20
}
19 changes: 19 additions & 0 deletions src/frontends/onnx/tests/onnx_import.in.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -6736,6 +6736,25 @@ OPENVINO_TEST(${BACKEND_NAME}, onnx_model_reduce_min_20_boolean) {
test_case.run();
}

OPENVINO_TEST(${BACKEND_NAME}, onnx_string_input) {
const auto model = convert_model("string_input.onnx");
auto test_case = test::TestCase(model);
test_case.add_input<std::string>({"strinpt1", "strinpt2"});
test_case.add_expected_output<int64_t>({2});
test_case.add_expected_output<std::string>({"strinpt1", "strinpt2"});

test_case.run();
}

OPENVINO_TEST(${BACKEND_NAME}, onnx_string_constant) {
const auto model = convert_model("string_constant.onnx");
auto test_case = test::TestCase(model);
test_case.add_expected_output<int64_t>({2});
test_case.add_expected_output<std::string>({"string1", "string2"});

test_case.run();
}

OPENVINO_TEST(${BACKEND_NAME}, onnx_model_multinomial_7) {
auto model = convert_model("multinomial.onnx");

Expand Down
12 changes: 12 additions & 0 deletions src/tests/test_utils/common_test_utils/src/test_case.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -139,6 +139,18 @@ std::pair<testing::AssertionResult, size_t> TestCase::compare_results(size_t tol
case element::Type_t::boolean:
res = compare_values<char>(exp_result, result_tensor, tolerance_bits);
break;
case element::Type_t::string: {
res = ::testing::AssertionSuccess();
std::string* exp_strings = exp_result.data<std::string>();
std::string* res_strings = result_tensor.data<std::string>();
for (size_t i = 0; i < exp_result.get_size(); ++i) {
if (exp_strings[i] != res_strings[i]) {
res = ::testing::AssertionFailure() << "Wrong string value at index " << i << ", expected \""
<< exp_strings[i] << "\" got \"" << res_strings[i] << "\"";
break;
}
}
} break;
default:
res = testing::AssertionFailure() << "Unsupported data type encountered in 'res' method";
}
Expand Down

0 comments on commit 1dce452

Please sign in to comment.