Skip to content

Commit

Permalink
Support string dtype for OV backend (#2602)
Browse files Browse the repository at this point in the history
### Changes

- Support string dtype for OV backend

### Reason for changes

NotImplementedError: NNCF is not yet supported OpenVINO data type:
string.

### Related tickets

136751

### Tests

-  test_convert_to_nncf_dtype_supported_types()
- test_convert_to_nncf_dtype_unsupported_types()
  • Loading branch information
andrey-churkin authored Mar 26, 2024
1 parent c79111b commit f2f3bb7
Show file tree
Hide file tree
Showing 2 changed files with 53 additions and 9 deletions.
20 changes: 11 additions & 9 deletions nncf/openvino/graph/nncf_graph_builder.py
Original file line number Diff line number Diff line change
Expand Up @@ -35,31 +35,35 @@ class GraphConverter:
"""

@staticmethod
def convert_to_nncf_dtype(ov_dtype: str) -> Dtype:
def convert_to_nncf_dtype(ov_type: ov.Type) -> Dtype:
"""
Converts the primitive types from the OpenVINO domain to the NNCF domain.
:param ov_dtype: OpenVINO primitive typename.
:return: NNCF primitive type.
"""
type_name = ov_type.get_type_name()
conversion_map = {
"f16": "float",
"f32": "float",
"f64": "float",
"i4": "int",
"i8": "int",
"i16": "int",
"i32": "int",
"i64": "int",
"u1": "int",
"u4": "int",
"u8": "int",
"u16": "int",
"u32": "int",
"u64": "int",
"boolean": "int",
"string": "int",
}
if ov_dtype not in conversion_map:
raise NotImplementedError(f"NNCF is not yet supported OpenVINO data type: {ov_dtype}.")
return Dtype(conversion_map[ov_dtype])
if type_name not in conversion_map:
raise NotImplementedError(f"NNCF is not yet supported OpenVINO data type: {type_name}.")
return Dtype(conversion_map[type_name])

@staticmethod
def _filter_weight_input_ports(inputs: List[ov.Input], metatype: Type[OperatorMetatype]) -> List[ov.Input]:
Expand Down Expand Up @@ -96,8 +100,7 @@ def _add_edges_to_nncf_graph(model: ov.Model, graph: NNCFGraph) -> None:
for out_node, inputs in node_vs_target_inputs.items():
tensor_shape = list(out.partial_shape.get_max_shape())
output_node_id = graph.get_node_by_name(out_node.get_friendly_name()).node_id
ov_dtype = out.get_element_type().get_type_name()
nncf_dtype = GraphConverter.convert_to_nncf_dtype(ov_dtype)
nncf_dtype = GraphConverter.convert_to_nncf_dtype(out.get_element_type())

parallel_inputs = None
if len(inputs) > 1:
Expand All @@ -109,7 +112,7 @@ def _add_edges_to_nncf_graph(model: ov.Model, graph: NNCFGraph) -> None:
tensor_shape=tensor_shape,
input_port_id=inputs[0].get_index(),
output_port_id=output_port_id,
dtype=Dtype(nncf_dtype),
dtype=nncf_dtype,
parallel_input_port_ids=parallel_inputs,
)

Expand Down Expand Up @@ -189,8 +192,7 @@ def create_nncf_graph(model: ov.Model) -> NNCFGraph:
if const_node is None:
continue

ov_dtype = const_node.get_element_type().get_type_name()
if GraphConverter.convert_to_nncf_dtype(ov_dtype) == Dtype.INTEGER:
if GraphConverter.convert_to_nncf_dtype(const_node.get_element_type()) == Dtype.INTEGER:
continue

const_attrs[const_port_id] = {
Expand Down
42 changes: 42 additions & 0 deletions tests/openvino/native/test_nncf_graph_builder.py
Original file line number Diff line number Diff line change
Expand Up @@ -97,3 +97,45 @@ def _get_default_nncf_graph_edge(from_node, to_node, input_port_id, output_port_
)
assert set(nncf_graph.get_input_edges(mm_node)) == ref_input_edges
assert set(nncf_graph.get_output_edges(input_node)) == ref_output_edges


@pytest.mark.parametrize(
"ov_type,expected_nncf_dtype",
[
(ov.Type.f16, Dtype.FLOAT),
(ov.Type.f32, Dtype.FLOAT),
(ov.Type.f64, Dtype.FLOAT),
(ov.Type.i4, Dtype.INTEGER),
(ov.Type.i8, Dtype.INTEGER),
(ov.Type.i16, Dtype.INTEGER),
(ov.Type.i32, Dtype.INTEGER),
(ov.Type.i64, Dtype.INTEGER),
(ov.Type.u1, Dtype.INTEGER),
(ov.Type.u4, Dtype.INTEGER),
(ov.Type.u8, Dtype.INTEGER),
(ov.Type.u16, Dtype.INTEGER),
(ov.Type.u32, Dtype.INTEGER),
(ov.Type.u64, Dtype.INTEGER),
(ov.Type.boolean, Dtype.INTEGER),
(ov.Type.string, Dtype.INTEGER),
],
)
def test_convert_to_nncf_dtype_supported_types(ov_type: ov.Type, expected_nncf_dtype: Dtype):
actual_nncf_dtype = GraphConverter.convert_to_nncf_dtype(ov_type)
assert actual_nncf_dtype == expected_nncf_dtype


@pytest.mark.parametrize(
"ov_type",
[
ov.Type.bf16,
ov.Type.nf4,
ov.Type.undefined,
# TODO(andrey-churkin): Add in OV 2024.0
# ov.Type.f8e4m3,
# ov.Type.f8e5m2,
],
)
def test_convert_to_nncf_dtype_unsupported_types(ov_type: ov.Type):
with pytest.raises(NotImplementedError):
_ = GraphConverter.convert_to_nncf_dtype(ov_type)

0 comments on commit f2f3bb7

Please sign in to comment.