Skip to content

Commit

Permalink
Fix issue with np.expand_dims for numpy<18.x (openvinotoolkit#3436)
Browse files Browse the repository at this point in the history
* Fix trouble with np.expand_dims for numpy<18.x

* Delete function expand_dims

* Added additional line
  • Loading branch information
evolosen authored and mryzhov committed Dec 11, 2020
1 parent c5587ea commit 677ab2d
Showing 1 changed file with 5 additions and 2 deletions.
7 changes: 5 additions & 2 deletions model-optimizer/mo/utils/broadcasting.py
Original file line number Diff line number Diff line change
Expand Up @@ -134,7 +134,10 @@ def explicit_broadcasting(input_value: np.array, target_shape: np.array, axes_ma
:return: broadcasted value
"""
res_shape, normalized_axes_mapping = explicit_shape_broadcasting(input_value.shape, target_shape, axes_mapping)
#TODO: Function 'expand_dims' should be replaced with 'numpy.expand_dims' if numpy version will be >=18.x in requirements.
expand_dim_axis = set(np.arange(len(target_shape))) - set(normalized_axes_mapping)

input_expanded = np.expand_dims(input_value.copy(), axis=list(expand_dim_axis))
input_expanded = input_value.copy()

for axis in sorted(list(expand_dim_axis)):
input_expanded = np.expand_dims(input_expanded, axis)
return np.broadcast_to(input_expanded, res_shape)

0 comments on commit 677ab2d

Please sign in to comment.