Skip to content

Commit

Permalink
fix unit-tests
Browse files Browse the repository at this point in the history
  • Loading branch information
pavel-esir committed Sep 18, 2023
1 parent 372009b commit 5a18b11
Show file tree
Hide file tree
Showing 2 changed files with 16 additions and 5 deletions.
17 changes: 14 additions & 3 deletions src/bindings/python/src/openvino/frontend/tensorflow/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,9 +6,11 @@


import logging as log
import numpy as np
import sys
from distutils.version import LooseVersion
from typing import List, Dict, Union

import numpy as np
from openvino.runtime import PartialShape, Dimension


Expand Down Expand Up @@ -63,7 +65,7 @@ def get_imported_module_version(imported_module):
for attr in version_attrs:
installed_version = getattr(imported_module, attr, None)
if isinstance(installed_version, str):
return installed_version
return installed_version
else:
installed_version = None

Expand Down Expand Up @@ -191,10 +193,19 @@ def tf_function(args):
return model(*args)
input_needs_packing = True

def are_shapes_defined(shape: Union[List, Dict]):
if shape is None:
return False

if isinstance(shape, list):
return np.all([shape is not None for shape in input_shapes])
elif isinstance(shape, dict):
return np.all([shape is not None for name, shape in input_shapes.items()])

if example_input is not None:
concrete_func = get_concrete_func(tf_function, example_input, input_needs_packing,
"Could not trace the TF model with the following error: {}")
elif np.all([shape is not None for name, shape in input_shapes.items()]):
elif are_shapes_defined(input_shapes):
inp = create_example_input_by_user_shapes(input_shapes, input_types)
concrete_func = get_concrete_func(tf_function, inp, input_needs_packing,
"Could not trace the TF model with the following error: {}")
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -144,14 +144,14 @@ def test_conversion_pbtxt_model_with_inference(self, inputs, expected, dtype):
# new frontend
(
"model_add_with_undefined_constant.pbtxt",
"x[2,3]",
("x", [2, 3]),
{"x": np.array([[12, 13, 10], [11, 14, 16]], dtype=np.float32)},
np.array([[12, 13, 10], [11, 14, 16]], dtype=np.float32),
np.float32
),
(
"model_mul_with_undefined_constant.pbtxt",
"x[2]",
("x", [2]),
{"x": np.array([11, -12], dtype=np.int32)},
np.array([0, 0], dtype=np.int32),
np.int32
Expand Down

0 comments on commit 5a18b11

Please sign in to comment.