Skip to content

Commit

Permalink
[PyOV] Extend Python API with Squeeze-15 (#27281)
Browse files Browse the repository at this point in the history
### Details:
- This PR includes commits from
#26995

### Tickets:
 - CVS-154024

---------

Signed-off-by: p-wysocki <[email protected]>
Co-authored-by: Michal Barnas <[email protected]>
Co-authored-by: Roman Kazantsev <[email protected]>
Co-authored-by: Michal Lukaszewski <[email protected]>
  • Loading branch information
4 people authored Oct 31, 2024
1 parent c902a01 commit a0b73e0
Show file tree
Hide file tree
Showing 4 changed files with 91 additions and 12 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -188,7 +188,7 @@
from openvino.runtime.opset1.ops import split
from openvino.runtime.opset1.ops import sqrt
from openvino.runtime.opset1.ops import squared_difference
from openvino.runtime.opset1.ops import squeeze
from openvino.runtime.opset15.ops import squeeze
from openvino.runtime.opset15.ops import stft
from openvino.runtime.opset1.ops import strided_slice
from openvino.runtime.opset1.ops import subtract
Expand Down
39 changes: 39 additions & 0 deletions src/bindings/python/src/openvino/runtime/opset15/ops.py
Original file line number Diff line number Diff line change
Expand Up @@ -348,3 +348,42 @@ def search_sorted(
inputs = as_nodes(sorted_sequence, values, name=name)
attributes = {"right_mode": right_mode}
return _get_node_factory_opset15().create("SearchSorted", inputs, attributes)


@nameable_op
def squeeze(
data: NodeInput,
axes: Optional[NodeInput] = None,
allow_axis_skip: bool = False,
name: Optional[str] = None,
) -> Node:
"""Perform squeeze operation on input tensor.
:param data: The node with data tensor.
:param axes: Optional list of integers, indicating the dimensions to squeeze.
Negative indices are supported. One of: input node or array.
:param allow_axis_skip: If true, shape inference results in a dynamic rank, when
selected axis has value 1 in its dynamic range. Used only if axes input
is given. Defaults to false.
:param name: Optional new name for output node.
:return: The new node performing a squeeze operation on input tensor.
Remove single-dimensional entries from the shape of a tensor.
Takes an optional parameter `axes` with a list of axes to squeeze.
If `axes` is not provided, all the single dimensions will be removed from the shape.
For example:
Inputs: tensor with shape [1, 2, 1, 3, 1, 1], axes=[2, 4]
Result: tensor with shape [1, 2, 3, 1]
"""
if axes is None:
inputs = as_nodes(data, name=name)
else:
inputs = as_nodes(data, axes, name=name)
return _get_node_factory_opset15().create(
"Squeeze",
inputs,
{"allow_axis_skip": allow_axis_skip}
)
11 changes: 0 additions & 11 deletions src/bindings/python/tests/test_graph/test_ops_fused.py
Original file line number Diff line number Diff line change
Expand Up @@ -110,17 +110,6 @@ def test_clamp_operator():
assert list(model.get_output_shape(0)) == [2, 2]


def test_squeeze_operator():
data_shape = [1, 2, 1, 3, 1, 1]
parameter_data = ov.parameter(data_shape, name="Data", dtype=np.float32)
axes = [2, 4]
model = ov.squeeze(parameter_data, axes)

assert model.get_type_name() == "Squeeze"
assert model.get_output_size() == 1
assert list(model.get_output_shape(0)) == [1, 2, 3, 1]


def test_squared_difference_operator():
x1_shape = [1, 2, 3, 4]
x2_shape = [2, 3, 4]
Expand Down
51 changes: 51 additions & 0 deletions src/bindings/python/tests/test_graph/test_squeeze.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,51 @@
# -*- coding: utf-8 -*-
# Copyright (C) 2018-2024 Intel Corporation
# SPDX-License-Identifier: Apache-2.0

import openvino.runtime.opset1 as ov_opset1
import openvino.runtime.opset15 as ov_opset15
import numpy as np
import pytest


def test_squeeze_v1_operator():
data_shape = [1, 2, 1, 3, 1, 1]
parameter_data = ov_opset1.parameter(data_shape, name="Data", dtype=np.float32)
axes = [2, 4]
model = ov_opset1.squeeze(parameter_data, axes)

assert model.get_type_name() == "Squeeze"
assert model.get_output_size() == 1
assert list(model.get_output_shape(0)) == [1, 2, 3, 1]


@pytest.mark.parametrize(("input_shape", "axes", "allow_axis_skip", "expected_shape"), [
((1, 2, 1, 3, 1, 1), [1, 2, 4], True, [1, 2, 3, 1]),
((1, 2, 1, 3, 1, 1), [1, 2, 4], False, [1, 2, 3, 1]),
((2, -1, 3), [1], False, [2, 3])
])
def test_squeeze_v15_operator(input_shape, axes, allow_axis_skip, expected_shape):
parameter_data = ov_opset15.parameter(input_shape, name="Data", dtype=np.float32)
model = ov_opset15.squeeze(parameter_data, axes, allow_axis_skip, name="Squeeze")

assert model.get_type_name() == "Squeeze"
assert model.get_output_size() == 1
assert list(model.get_output_shape(0)) == expected_shape


def test_squeeze_v15_dynamic_rank_output():
parameter_data = ov_opset15.parameter((2, -1, 3), name="Data", dtype=np.float32)
model = ov_opset15.squeeze(parameter_data, [1], True, name="Squeeze")

assert model.get_type_name() == "Squeeze"
assert model.get_output_size() == 1
assert model.get_output_partial_shape(0).to_string() == "[...]"


def test_squeeze_v15_axes_not_given():
parameter_data = ov_opset15.parameter((1, 3, 1, 1, 3, 5), name="Data", dtype=np.float32)
model = ov_opset15.squeeze(data=parameter_data, name="Squeeze")

assert model.get_type_name() == "Squeeze"
assert model.get_output_size() == 1
assert list(model.get_output_shape(0)) == [3, 3, 5]

0 comments on commit a0b73e0

Please sign in to comment.